diff --git a/fluid/image_classification/caffe2fluid/README.md b/fluid/image_classification/caffe2fluid/README.md index 64f6b9cf90..ef0123b669 100644 --- a/fluid/image_classification/caffe2fluid/README.md +++ b/fluid/image_classification/caffe2fluid/README.md @@ -1,35 +1,63 @@ ### Caffe2Fluid -This tool is used to convert a Caffe model to Fluid model +This tool is used to convert a Caffe model to a Fluid model -### Howto +### HowTo 1. Prepare caffepb.py in ./proto if your python has no 'pycaffe' module, two options provided here: -- Generate pycaffe from caffe.proto -
bash ./proto/compile.sh
+ - Generate pycaffe from caffe.proto + ``` + bash ./proto/compile.sh + ``` -- download one from github directly -
cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py
-
+ - Download one from github directly + ``` + cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py + ``` 2. Convert the Caffe model to Fluid model -- generate fluid code and weight file -
python convert.py alexnet.prototxt \
-        --caffemodel alexnet.caffemodel \
-        --data-output-path alexnet.npy \
-        --code-output-path alexnet.py
-
+ - Generate fluid code and weight file + ``` + python convert.py alexnet.prototxt \ + --caffemodel alexnet.caffemodel \ + --data-output-path alexnet.npy \ + --code-output-path alexnet.py + ``` -- save weights as fluid model file -
python alexnet.py alexnet.npy ./fluid_model
-
+ - Save weights as fluid model file + ``` + python alexnet.py alexnet.npy ./fluid + ``` 3. Use the converted model to infer -- see more details in '*examples/imagenet/run.sh*' + - See more details in '*examples/imagenet/run.sh*' -4. compare the inference results with caffe -- see more details in '*examples/imagenet/diff.sh*' +4. Compare the inference results with caffe + - See more details in '*examples/imagenet/diff.sh*' + +### How to convert custom layer +1. Implement your custom layer in a file under '*kaffe/custom_layers*', eg: mylayer.py + - Implement ```shape_func(input_shape, [other_caffe_params])``` to calculate the output shape + - Implement ```layer_func(inputs, name, [other_caffe_params])``` to construct a fluid layer + - Register these two functions ```register(kind='MyType', shape=shape_func, layer=layer_func)``` + - Notes: more examples can be found in '*kaffe/custom_layers*' + +2. Add ```import mylayer``` to '*kaffe/custom_layers/\_\_init__.py*' + +3. Prepare your pycaffe as your customized version(same as previous env prepare) + - (option1) replace 'proto/caffe.proto' with your own caffe.proto and compile it + - (option2) change your pycaffe to the customized version + +4. Convert the Caffe model to Fluid model + +5. Set env $CAFFE2FLUID_CUSTOM_LAYERS to the parent directory of 'custom_layers' + ``` + export CAFFE2FLUID_CUSTOM_LAYERS=/path/to/caffe2fluid/kaffe + ``` + +6. Use the converted model when loading model in 'xxxnet.py' and 'xxxnet.npy'(no need if model is already in 'fluid/model' and 'fluid/params') ### Tested models -- Lenet +- Lenet: +[model addr](https://github.com/ethereon/caffe-tensorflow/blob/master/examples/mnist) - ResNets:(ResNet-50, ResNet-101, ResNet-152) [model addr](https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777) diff --git a/fluid/image_classification/caffe2fluid/convert.py b/fluid/image_classification/caffe2fluid/convert.py index 379f1a2636..b0252e3c03 100755 --- a/fluid/image_classification/caffe2fluid/convert.py +++ b/fluid/image_classification/caffe2fluid/convert.py @@ -43,11 +43,17 @@ def convert(def_path, caffemodel_path, data_output_path, code_output_path, print_stderr('Saving source...') with open(code_output_path, 'wb') as src_out: src_out.write(transformer.transform_source()) + print_stderr('set env variable before using converted model '\ + 'if used custom_layers:') + custom_pk_path = os.path.dirname(os.path.abspath(__file__)) + custom_pk_path = os.path.join(custom_pk_path, 'kaffe') + print_stderr('export CAFFE2FLUID_CUSTOM_LAYERS=%s' % (custom_pk_path)) print_stderr('Done.') + return 0 except KaffeError as err: fatal_error('Error encountered: {}'.format(err)) - return 0 + return 1 def main(): diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py b/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py index 099c0abb2e..ef36b6975f 100644 --- a/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py +++ b/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py @@ -164,7 +164,6 @@ def infer(model_path, imgfile, net_file=None, net_name=None, debug=True): debug = False print('found a inference model for fluid') except ValueError as e: - pass print('try to load model using net file and weight file') net_weight = model_path ret = load_model(exe, place, net_file, net_name, net_weight, debug) diff --git a/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py b/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py index 5c86635d5a..946fa94372 100644 --- a/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py +++ b/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py @@ -7,8 +7,8 @@ import sys import os import numpy as np +import paddle.fluid as fluid import paddle.v2 as paddle -import paddle.v2.fluid as fluid def test_model(exe, test_program, fetch_list, test_reader, feeder): @@ -34,9 +34,6 @@ def evaluate(net_file, model_file): from lenet import LeNet as MyNet - with_gpu = False - paddle.init(use_gpu=with_gpu) - #1, define network topology images = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') @@ -45,7 +42,7 @@ def evaluate(net_file, model_file): prediction = net.layers['prob'] acc = fluid.layers.accuracy(input=prediction, label=label) - place = fluid.CUDAPlace(0) if with_gpu is True else fluid.CPUPlace() + place = fluid.CPUPlace() exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py new file mode 100644 index 0000000000..ba3e53bee8 --- /dev/null +++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py @@ -0,0 +1,104 @@ +""" +""" + +from .register import get_registered_layers +#custom layer import begins + +import axpy +import flatten +import argmax + +#custom layer import ends + +custom_layers = get_registered_layers() + + +def set_args(f, params): + """ set args for function 'f' using the parameters in node.layer.parameters + + Args: + f (function): a python function object + params (object): a object contains attributes needed by f's arguments + + Returns: + arg_names (list): a list of argument names + kwargs (dict): a dict contains needed arguments + """ + argc = f.__code__.co_argcount + arg_list = f.__code__.co_varnames[0:argc] + + kwargs = {} + for arg_name in arg_list: + try: + v = getattr(node.layer.parameters, arg_name, None) + except Exception as e: + v = None + + if v is not None: + kwargs[arg_name] = v + + return arg_list, kwargs + + +def has_layer(kind): + """ test whether this layer exists in custom layer + """ + return kind in custom_layers + + +def compute_output_shape(kind, node): + assert kind in custom_layers, "layer[%s] not exist in custom layers" % ( + kind) + shape_func = custom_layers[kind]['shape'] + + parents = node.parents + inputs = [list(p.output_shape) for p in parents] + arg_names, kwargs = set_args(shape_func, node.layer.parameters) + + if len(inputs) == 1: + inputs = inputs[0] + + return shape_func(inputs, **kwargs) + + +def make_node(template, kind, node): + """ make a TensorFlowNode for custom layer which means construct + a piece of code to define a layer implemented in 'custom_layers' + + Args: + @template (TensorFlowNode): a factory to new a instance of TensorFLowNode + @kind (str): type of custom layer + @node (graph.Node): a layer in the net + + Returns: + instance of TensorFlowNode + """ + assert kind in custom_layers, "layer[%s] not exist in custom layers" % ( + kind) + + layer_func = custom_layers[kind]['layer'] + + #construct arguments needed by custom layer function from node's parameters + arg_names, kwargs = set_args(layer_func, node.layer.parameters) + + return template('custom_layer', kind, **kwargs) + + +def make_custom_layer(kind, inputs, name, *args, **kwargs): + """ execute a custom layer which is implemented by users + + Args: + @kind (str): type name of this layer + @inputs (vars): variable list created by fluid + @namme (str): name for this layer + @args (tuple): other positional arguments + @kwargs (dict): other kv arguments + + Returns: + output (var): output variable for this layer + """ + assert kind in custom_layers, "layer[%s] not exist in custom layers" % ( + kind) + + layer_func = custom_layers[kind]['layer'] + return layer_func(inputs, name, *args, **kwargs) diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/argmax.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/argmax.py new file mode 100644 index 0000000000..17ad683614 --- /dev/null +++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/argmax.py @@ -0,0 +1,70 @@ +""" a custom layer for 'argmax', maybe we should implement this in standard way. + more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/argmax.html +""" +from .register import register + + +def import_fluid(): + import paddle.fluid as fluid + return fluid + + +def argmax_shape(input_shape, out_max_val=False, top_k=1, axis=-1): + """ calculate the output shape of this layer using input shape + + Args: + @input_shape (list of num): a list of number which represents the input shape + @out_max_val (bool): parameter from caffe's ArgMax layer + @top_k (int): parameter from caffe's ArgMax layer + @axis (int): parameter from caffe's ArgMax layer + + Returns: + @output_shape (list of num): a list of numbers represent the output shape + """ + input_shape = list(input_shape) + + if axis < 0: + axis += len(input_shape) + + assert (axis + 1 == len(input_shape) + ), 'only can be applied on the last dimension now' + + output_shape = input_shape + output_shape[-1] = top_k + if out_max_val is True: + output_shape[-1] *= 2 + + return output_shape + + +def argmax_layer(input, name, out_max_val=False, top_k=1, axis=-1): + """ build a layer of type 'ArgMax' using fluid + + Args: + @input (variable): input fluid variable for this layer + @name (str): name for this layer + @out_max_val (bool): parameter from caffe's ArgMax layer + @top_k (int): parameter from caffe's ArgMax layer + @axis (int): parameter from caffe's ArgMax layer + + Returns: + output (variable): output variable for this layer + """ + + fluid = import_fluid() + + if axis < 0: + axis += len(input.shape) + + assert (axis + 1 == len(input_shape) + ), 'only can be applied on the last dimension now' + + topk_var, index_var = fluid.layers.topk(input=input, k=top_k) + if out_max_val is True: + output = fluid.layers.concate([topk_var, index_var], axis=axis) + else: + output = topk_var + return output + + +register(kind='ArgMax', shape=argmax_shape, layer=argmax_layer) diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/axpy.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/axpy.py new file mode 100644 index 0000000000..389bb7996e --- /dev/null +++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/axpy.py @@ -0,0 +1,51 @@ +""" A custom layer for 'axpy' which receives 3 tensors and output 1 tensor. + the function performed is:(the mupltiplication and add are elementewise) + output = inputs[0] * inputs[1] + inputs[2] +""" + +from .register import register + + +def axpy_shape(input_shapes): + """ calculate the output shape of this layer using input shapes + + Args: + @input_shapes (list of tuples): a list of input shapes + + Returns: + @output_shape (list of num): a list of numbers represent the output shape + """ + assert len(input_shapes) == 3, "not valid input shape for axpy layer" + assert len(input_shapes[0]) == len(input_shapes[1]), 'should have same dims' + + output_shape = input_shapes[1] + assert (input_shapes[2] == output_shape),\ + "shape not consistent for axpy[%s <--> %s]" \ + % (str(output_shape), str(input_shapes[2])) + + return output_shape + + +def axpy_layer(inputs, name): + """ build a layer of type 'Axpy' using fluid + + Args: + @inputs (list of variables): input fluid variables for this layer + @name (str): name for this layer + + Returns: + output (variable): output variable for this layer + """ + import paddle.fluid as fluid + + assert len(inputs) == 3, "invalid inputs for axpy[%s]" % (name) + alpha = inputs[0] + x = inputs[1] + y = inputs[2] + output = fluid.layers.elementwise_mul(x, alpha, axis=0) + output = fluid.layers.elementwise_add(output, y) + + return output + + +register(kind='Axpy', shape=axpy_shape, layer=axpy_layer) diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/flatten.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/flatten.py new file mode 100644 index 0000000000..8f7af4266f --- /dev/null +++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/flatten.py @@ -0,0 +1,73 @@ +""" a custom layer for 'flatten', maybe we should implement this in standard way. + more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/flatten.html +""" +from .register import register + + +def import_fluid(): + import paddle.fluid as fluid + return fluid + + +def flatten_shape(input_shape, axis=1, end_axis=-1): + """ calculate the output shape of this layer using input shape + + Args: + @input_shape (list of num): a list of number which represents the input shape + @axis (int): parameter from caffe's Flatten layer + @end_axis (int): parameter from caffe's Flatten layer + + Returns: + @output_shape (list of num): a list of numbers represent the output shape + """ + + start_axis = axis + end_axis = end_axis + input_shape = list(input_shape) + if start_axis < 0: + start_axis += len(input_shape) + + if end_axis < 0: + end_axis += len(input_shape) + + assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\ + % (start_axis, end_axis) + output_shape = input_shape[0:start_axis] + flat_sz = reduce(lambda a, b: a * b, input_shape[start_axis:end_axis]) + output_shape += [flat_sz] + output_shape += input_shape[end_axis:-1] + + return output_shape + + +def flatten_layer(input, name, axis=1, end_axis=-1): + """ build a layer of type 'Flatten' using fluid + + Args: + @input (variable): input fluid variable for this layer + @name (str): name for this layer + @axis (int): parameter from caffe's Flatten layer + @end_axis (int): parameter from caffe's Flatten layer + + Returns: + output (variable): output variable for this layer + """ + fluid = import_fluid() + + input_shape = list(input.shape) + dims = len(input_shape) + start_axis = axis if axis >= 0 else axis + dims + end_axis = end_axis if end_axis >= 0 else end_axis + dims + + assert start_axis <= end_axis, 'invalid axis or end_axis params' + output_shape = input_shape[0:start_axis] + flat_sz = reduce(lambda a, b: a * b, input_shape[start_axis:end_axis]) + output_shape += [flat_sz] + output_shape += input_shape[end_axis:-1] + + output = fluid.layers.reshape(input, shape=output_shape, name=name) + + return output + + +register(kind='Flatten', shape=flatten_shape, layer=flatten_layer) diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/register.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/register.py new file mode 100644 index 0000000000..ae806cd469 --- /dev/null +++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/register.py @@ -0,0 +1,37 @@ +""" this module provides 'register' for registering customized layers +""" + +g_custom_layers = {} + + +def register(kind, shape, layer): + """ register a custom layer or a list of custom layers + + Args: + @kind (str or list): type name of the layer + @shape (function): a function to generate the shape of layer's output + @layer (function): a function to generate the shape of layer's output + + Returns: + None + """ + assert type(shape).__name__ == 'function', 'shape should be a function' + assert type(layer).__name__ == 'function', 'layer should be a function' + + if type(kind) is str: + kind = [kind] + else: + assert type( + kind) is list, 'invalid param "kind" for register, not a list or str' + + for k in kind: + assert type( + k) is str, 'invalid param "kind" for register, not a list of str' + assert k not in g_custom_layers, 'this type[%s] has already been registered' % ( + k) + print('register layer[%s]' % (k)) + g_custom_layers[k] = {'shape': shape, 'layer': layer} + + +def get_registered_layers(): + return g_custom_layers diff --git a/fluid/image_classification/caffe2fluid/kaffe/graph.py b/fluid/image_classification/caffe2fluid/kaffe/graph.py index c6fdada6e7..8f43b86ff0 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/graph.py +++ b/fluid/image_classification/caffe2fluid/kaffe/graph.py @@ -3,7 +3,7 @@ from .caffe import get_caffe_resolver from .errors import KaffeError, print_stderr from .layers import LayerAdapter, LayerType, NodeKind, NodeDispatch -from .shapes import TensorShape +from .shapes import make_tensor class Node(object): @@ -98,7 +98,7 @@ def visit(node): def compute_output_shapes(self): sorted_nodes = self.topologically_sorted() for node in sorted_nodes: - node.output_shape = TensorShape( + node.output_shape = make_tensor( *NodeKind.compute_output_shape(node)) def replaced(self, new_nodes): @@ -111,6 +111,7 @@ def transformed(self, transformers): if graph is None: raise KaffeError('Transformer failed: {}'.format(transformer)) assert isinstance(graph, Graph) + return graph def __contains__(self, key): @@ -237,6 +238,7 @@ def build(self): if (parent_node is None) or (parent_node == node): parent_node = graph.get_node(input_name) node.add_parent(parent_node) + if len(layer.top) > 1: raise KaffeError('Multiple top nodes are not supported.') diff --git a/fluid/image_classification/caffe2fluid/kaffe/layers.py b/fluid/image_classification/caffe2fluid/kaffe/layers.py index f263407ab4..dcdd26040b 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/layers.py +++ b/fluid/image_classification/caffe2fluid/kaffe/layers.py @@ -2,6 +2,7 @@ import numbers from collections import namedtuple +import custom_layers from .shapes import * LAYER_DESCRIPTORS = { @@ -116,6 +117,9 @@ def get_v1_layer_map(): class NodeKind(LayerType): @staticmethod def map_raw_kind(kind): + if custom_layers.has_layer(kind): + return kind + if kind in LAYER_TYPES: return kind @@ -127,6 +131,9 @@ def map_raw_kind(kind): @staticmethod def compute_output_shape(node): + if custom_layers.has_layer(node.kind): + return custom_layers.compute_output_shape(node.kind, node) + try: val = LAYER_DESCRIPTORS[node.kind](node) return val @@ -137,14 +144,13 @@ def compute_output_shape(node): class NodeDispatchError(KaffeError): - pass class NodeDispatch(object): @staticmethod def get_handler_name(node_kind): - if len(node_kind) <= 4: + if len(node_kind) <= 6: # A catch-all for things like ReLU and tanh return node_kind.lower() # Convert from CamelCase to under_scored @@ -152,6 +158,9 @@ def get_handler_name(node_kind): return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower() def get_handler(self, node_kind, prefix): + if custom_layers.has_layer(node_kind): + return getattr(self, 'map_custom') + name = self.get_handler_name(node_kind) name = '_'.join((prefix, name)) try: @@ -174,8 +183,10 @@ def parameters(self): try: return getattr(self.layer, name) except AttributeError: + print(dir(self.layer)) raise NodeDispatchError( - 'Caffe parameters not found for layer kind: %s' % (self.kind)) + 'Caffe parameters not found attr[%s] for layer kind[%s]' % + (name, self.kind)) @staticmethod def get_kernel_value(scalar, repeated, idx, default=None): diff --git a/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py b/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py index ac5ecf1d44..ab080cd013 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py +++ b/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py @@ -1,5 +1,6 @@ -import math +import sys import os +import math import numpy as np @@ -161,7 +162,8 @@ def relu(self, input, name): output = fluid.layers.relu(x=input) return output - def pool(self, pool_type, input, k_h, k_w, s_h, s_w, name, padding): + def pool(self, pool_type, input, k_h, k_w, s_h, s_w, ceil_mode, padding, + name): # Get the number of channels in the input in_hw = input.shape[2:] k_hw = [k_h, k_w] @@ -173,17 +175,40 @@ def pool(self, pool_type, input, k_h, k_w, s_h, s_w, name, padding): pool_size=k_hw, pool_stride=s_hw, pool_padding=padding, - ceil_mode=True, + ceil_mode=ceil_mode, pool_type=pool_type) return output @layer - def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=[0, 0]): - return self.pool('max', input, k_h, k_w, s_h, s_w, name, padding) + def max_pool(self, + input, + k_h, + k_w, + s_h, + s_w, + ceil_mode, + padding=[0, 0], + name=None): + return self.pool('max', input, k_h, k_w, s_h, s_w, ceil_mode, padding, + name) @layer - def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=[0, 0]): - return self.pool('avg', input, k_h, k_w, s_h, s_w, name, padding) + def avg_pool(self, + input, + k_h, + k_w, + s_h, + s_w, + ceil_mode, + padding=[0, 0], + name=None): + return self.pool('avg', input, k_h, k_w, s_h, s_w, ceil_mode, padding, + name) + + @layer + def sigmoid(self, input, name): + fluid = import_fluid() + return fluid.layers.sigmoid(input) @layer def lrn(self, input, radius, alpha, beta, name, bias=1.0): @@ -264,3 +289,21 @@ def dropout(self, input, drop_prob, name, is_test=True): output = fluid.layers.dropout( input, dropout_prob=drop_prob, is_test=is_test, name=name) return output + + @layer + def custom_layer(self, inputs, kind, name, *args, **kwargs): + """ make custom layer from the package specified by '$CAFFE2FLUID_CUSTOM_LAYERS' + """ + #fluid = import_fluid() + #import custom package + default = os.path.dirname(os.path.abspath(__file__)) + p = os.environ.get('CAFFE2FLUID_CUSTOM_LAYERS', default) + pk = os.path.join(p, 'custom_layers') + assert os.path.exists(pk) is True, "not found custom_layer package [%s],"\ + "you need to set $CAFFE2FLUID_CUSTOM_LAYERS" % (pk) + + if p not in sys.path: + sys.path.insert(0, p) + + from custom_layers import make_custom_layer + return make_custom_layer(kind, inputs, name, *args, **kwargs) diff --git a/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py b/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py index 20155e992f..7ce77bad21 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py +++ b/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py @@ -109,9 +109,17 @@ def map_pooling(self, node): # Stochastic pooling, for instance. raise KaffeError('Unsupported pooling type.') (kernel_params, padding) = self.get_kernel_params(node) + ceil_mode = getattr(node.layer.parameters, 'ceil_mode', True) return TensorFlowNode(pool_op, kernel_params.kernel_h, kernel_params.kernel_w, kernel_params.stride_h, - kernel_params.stride_w, **padding) + kernel_params.stride_w, ceil_mode, **padding) + + def map_sigmoid(self, node): + return TensorFlowNode('sigmoid') + + def map_custom(self, node): + from .. import custom_layers + return custom_layers.make_node(TensorFlowNode, node.kind, node) def map_inner_product(self, node): #TODO: Axis @@ -347,6 +355,7 @@ def load(self, def_path, data_path, phase): # (Caffe's GoogLeNet implementation uses slashes) NodeRenamer(lambda node: node.name.replace('/', '_')) ] + self.graph = graph.transformed(transformers) # Display the graph diff --git a/fluid/image_classification/caffe2fluid/kaffe/shapes.py b/fluid/image_classification/caffe2fluid/kaffe/shapes.py index e8124730c6..a2ce26362b 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/shapes.py +++ b/fluid/image_classification/caffe2fluid/kaffe/shapes.py @@ -3,8 +3,24 @@ from .errors import KaffeError -TensorShape = namedtuple('TensorShape', - ['batch_size', 'channels', 'height', 'width']) +Tensor4DShape = namedtuple('Tensor4DShape', + ['batch_size', 'channels', 'height', 'width']) + +Tensor2DShape = namedtuple('Tensor2DShape', ['batch_size', 'data']) + +ScalarShape = namedtuple('ScalarShape', ['batch_size']) + + +def make_tensor(batch_size, d1=None, d2=None, d3=None): + if d3 is not None: + return Tensor4DShape(batch_size, d1, d2, d3) + elif d1 is not None and d2 is None: + return Tensor2DShape(batch_size, d1) + elif d1 is None and d2 is None and d3 is None: + return ScalarShape(batch_size) + else: + raise NotImplementedError('invalid params for make_tensor %s' \ + % (str((batch_size, d1, d2, d3)))) def get_filter_output_shape(i_h, i_w, params, round_func): @@ -23,7 +39,7 @@ def get_strided_kernel_output_shape(node, round_func): params = node.layer.parameters has_c_o = hasattr(params, 'num_output') c = params.num_output if has_c_o else input_shape.channels - return TensorShape(input_shape.batch_size, c, o_h, o_w) + return make_tensor(input_shape.batch_size, c, o_h, o_w) def shape_not_implemented(node): @@ -36,7 +52,7 @@ def shape_identity(node): def shape_scalar(node): - return TensorShape(1, 1, 1, 1) + return make_tensor(1, 1, 1, 1) def shape_data(node): @@ -59,7 +75,7 @@ def shape_data(node): def shape_mem_data(node): params = node.parameters - return TensorShape(params.batch_size, params.channels, params.height, + return make_tensor(params.batch_size, params.channels, params.height, params.width) @@ -79,10 +95,15 @@ def shape_convolution(node): def shape_pool(node): - return get_strided_kernel_output_shape(node, math.ceil) + ceil_mode = getattr(node.layer.parameters, 'ceil_mode', True) + if ceil_mode is True: + method = math.ceil + else: + method = math.floor + + return get_strided_kernel_output_shape(node, method) def shape_inner_product(node): input_shape = node.get_only_parent().output_shape - return TensorShape(input_shape.batch_size, node.layer.parameters.num_output, - 1, 1) + return make_tensor(input_shape.batch_size, node.layer.parameters.num_output) diff --git a/fluid/image_classification/caffe2fluid/kaffe/transformers.py b/fluid/image_classification/caffe2fluid/kaffe/transformers.py index 9d300ca9c9..8f08149698 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/transformers.py +++ b/fluid/image_classification/caffe2fluid/kaffe/transformers.py @@ -113,7 +113,10 @@ def has_spatial_parent(self, node): try: parent = node.get_only_parent() s = parent.output_shape - return s.height > 1 or s.width > 1 + if len(s) == 4: + return s.height > 1 or s.width > 1 + else: + return False except KaffeError: return False @@ -121,8 +124,8 @@ def map(self, node_kind): try: return self.mapping[node_kind] except KeyError: - raise - #raise KaffeError('Ordering not found for node kind: {}'.format(node_kind)) + raise KaffeError('Ordering not found for node kind: {}'.format( + node_kind)) def __call__(self, graph): for node in graph.nodes: @@ -178,7 +181,8 @@ def __call__(self, graph): continue # Rewrite the fused node's children to its parent. for child in node.children: - child.parents.remove(node) + pos = child.parents.index(node) + child.parents[pos] = parent parent.add_child(child) # Disconnect the fused node from the graph. parent.children.remove(node)