Skip to content

Commit

Permalink
caffe2fluid:upgrade argmax implementtion (#866)
Browse files Browse the repository at this point in the history
  • Loading branch information
walloollaw authored and qingqing01 committed Apr 19, 2018
1 parent 237fe2f commit e7684f0
Show file tree
Hide file tree
Showing 15 changed files with 507 additions and 52 deletions.
68 changes: 48 additions & 20 deletions fluid/image_classification/caffe2fluid/README.md
Original file line number Diff line number Diff line change
@@ -1,35 +1,63 @@
### Caffe2Fluid
This tool is used to convert a Caffe model to Fluid model
This tool is used to convert a Caffe model to a Fluid model

### Howto
### HowTo
1. Prepare caffepb.py in ./proto if your python has no 'pycaffe' module, two options provided here:
- Generate pycaffe from caffe.proto
<pre><code>bash ./proto/compile.sh</code></pre>
- Generate pycaffe from caffe.proto
```
bash ./proto/compile.sh
```

- download one from github directly
<pre><code>cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py
</code></pre>
- Download one from github directly
```
cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py
```

2. Convert the Caffe model to Fluid model
- generate fluid code and weight file
<pre><code>python convert.py alexnet.prototxt \
--caffemodel alexnet.caffemodel \
--data-output-path alexnet.npy \
--code-output-path alexnet.py
</code></pre>
- Generate fluid code and weight file
```
python convert.py alexnet.prototxt \
--caffemodel alexnet.caffemodel \
--data-output-path alexnet.npy \
--code-output-path alexnet.py
```

- save weights as fluid model file
<pre><code>python alexnet.py alexnet.npy ./fluid_model
</code></pre>
- Save weights as fluid model file
```
python alexnet.py alexnet.npy ./fluid
```

3. Use the converted model to infer
- see more details in '*examples/imagenet/run.sh*'
- See more details in '*examples/imagenet/run.sh*'

4. compare the inference results with caffe
- see more details in '*examples/imagenet/diff.sh*'
4. Compare the inference results with caffe
- See more details in '*examples/imagenet/diff.sh*'

### How to convert custom layer
1. Implement your custom layer in a file under '*kaffe/custom_layers*', eg: mylayer.py
- Implement ```shape_func(input_shape, [other_caffe_params])``` to calculate the output shape
- Implement ```layer_func(inputs, name, [other_caffe_params])``` to construct a fluid layer
- Register these two functions ```register(kind='MyType', shape=shape_func, layer=layer_func)```
- Notes: more examples can be found in '*kaffe/custom_layers*'

2. Add ```import mylayer``` to '*kaffe/custom_layers/\_\_init__.py*'

3. Prepare your pycaffe as your customized version(same as previous env prepare)
- (option1) replace 'proto/caffe.proto' with your own caffe.proto and compile it
- (option2) change your pycaffe to the customized version

4. Convert the Caffe model to Fluid model

5. Set env $CAFFE2FLUID_CUSTOM_LAYERS to the parent directory of 'custom_layers'
```
export CAFFE2FLUID_CUSTOM_LAYERS=/path/to/caffe2fluid/kaffe
```

6. Use the converted model when loading model in 'xxxnet.py' and 'xxxnet.npy'(no need if model is already in 'fluid/model' and 'fluid/params')

### Tested models
- Lenet
- Lenet:
[model addr](https://github.com/ethereon/caffe-tensorflow/blob/master/examples/mnist)

- ResNets:(ResNet-50, ResNet-101, ResNet-152)
[model addr](https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)
Expand Down
8 changes: 7 additions & 1 deletion fluid/image_classification/caffe2fluid/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,17 @@ def convert(def_path, caffemodel_path, data_output_path, code_output_path,
print_stderr('Saving source...')
with open(code_output_path, 'wb') as src_out:
src_out.write(transformer.transform_source())
print_stderr('set env variable before using converted model '\
'if used custom_layers:')
custom_pk_path = os.path.dirname(os.path.abspath(__file__))
custom_pk_path = os.path.join(custom_pk_path, 'kaffe')
print_stderr('export CAFFE2FLUID_CUSTOM_LAYERS=%s' % (custom_pk_path))
print_stderr('Done.')
return 0
except KaffeError as err:
fatal_error('Error encountered: {}'.format(err))

return 0
return 1


def main():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ def infer(model_path, imgfile, net_file=None, net_name=None, debug=True):
debug = False
print('found a inference model for fluid')
except ValueError as e:
pass
print('try to load model using net file and weight file')
net_weight = model_path
ret = load_model(exe, place, net_file, net_name, net_weight, debug)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import sys
import os
import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.fluid as fluid


def test_model(exe, test_program, fetch_list, test_reader, feeder):
Expand All @@ -34,9 +34,6 @@ def evaluate(net_file, model_file):

from lenet import LeNet as MyNet

with_gpu = False
paddle.init(use_gpu=with_gpu)

#1, define network topology
images = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
Expand All @@ -45,7 +42,7 @@ def evaluate(net_file, model_file):
prediction = net.layers['prob']
acc = fluid.layers.accuracy(input=prediction, label=label)

place = fluid.CUDAPlace(0) if with_gpu is True else fluid.CPUPlace()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

Expand Down
104 changes: 104 additions & 0 deletions fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
"""

from .register import get_registered_layers
#custom layer import begins

import axpy
import flatten
import argmax

#custom layer import ends

custom_layers = get_registered_layers()


def set_args(f, params):
""" set args for function 'f' using the parameters in node.layer.parameters
Args:
f (function): a python function object
params (object): a object contains attributes needed by f's arguments
Returns:
arg_names (list): a list of argument names
kwargs (dict): a dict contains needed arguments
"""
argc = f.__code__.co_argcount
arg_list = f.__code__.co_varnames[0:argc]

kwargs = {}
for arg_name in arg_list:
try:
v = getattr(node.layer.parameters, arg_name, None)
except Exception as e:
v = None

if v is not None:
kwargs[arg_name] = v

return arg_list, kwargs


def has_layer(kind):
""" test whether this layer exists in custom layer
"""
return kind in custom_layers


def compute_output_shape(kind, node):
assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
kind)
shape_func = custom_layers[kind]['shape']

parents = node.parents
inputs = [list(p.output_shape) for p in parents]
arg_names, kwargs = set_args(shape_func, node.layer.parameters)

if len(inputs) == 1:
inputs = inputs[0]

return shape_func(inputs, **kwargs)


def make_node(template, kind, node):
""" make a TensorFlowNode for custom layer which means construct
a piece of code to define a layer implemented in 'custom_layers'
Args:
@template (TensorFlowNode): a factory to new a instance of TensorFLowNode
@kind (str): type of custom layer
@node (graph.Node): a layer in the net
Returns:
instance of TensorFlowNode
"""
assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
kind)

layer_func = custom_layers[kind]['layer']

#construct arguments needed by custom layer function from node's parameters
arg_names, kwargs = set_args(layer_func, node.layer.parameters)

return template('custom_layer', kind, **kwargs)


def make_custom_layer(kind, inputs, name, *args, **kwargs):
""" execute a custom layer which is implemented by users
Args:
@kind (str): type name of this layer
@inputs (vars): variable list created by fluid
@namme (str): name for this layer
@args (tuple): other positional arguments
@kwargs (dict): other kv arguments
Returns:
output (var): output variable for this layer
"""
assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
kind)

layer_func = custom_layers[kind]['layer']
return layer_func(inputs, name, *args, **kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
""" a custom layer for 'argmax', maybe we should implement this in standard way.
more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/argmax.html
"""
from .register import register


def import_fluid():
import paddle.fluid as fluid
return fluid


def argmax_shape(input_shape, out_max_val=False, top_k=1, axis=-1):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (list of num): a list of number which represents the input shape
@out_max_val (bool): parameter from caffe's ArgMax layer
@top_k (int): parameter from caffe's ArgMax layer
@axis (int): parameter from caffe's ArgMax layer
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
input_shape = list(input_shape)

if axis < 0:
axis += len(input_shape)

assert (axis + 1 == len(input_shape)
), 'only can be applied on the last dimension now'

output_shape = input_shape
output_shape[-1] = top_k
if out_max_val is True:
output_shape[-1] *= 2

return output_shape


def argmax_layer(input, name, out_max_val=False, top_k=1, axis=-1):
""" build a layer of type 'ArgMax' using fluid
Args:
@input (variable): input fluid variable for this layer
@name (str): name for this layer
@out_max_val (bool): parameter from caffe's ArgMax layer
@top_k (int): parameter from caffe's ArgMax layer
@axis (int): parameter from caffe's ArgMax layer
Returns:
output (variable): output variable for this layer
"""

fluid = import_fluid()

if axis < 0:
axis += len(input.shape)

assert (axis + 1 == len(input_shape)
), 'only can be applied on the last dimension now'

topk_var, index_var = fluid.layers.topk(input=input, k=top_k)
if out_max_val is True:
output = fluid.layers.concate([topk_var, index_var], axis=axis)
else:
output = topk_var
return output


register(kind='ArgMax', shape=argmax_shape, layer=argmax_layer)
51 changes: 51 additions & 0 deletions fluid/image_classification/caffe2fluid/kaffe/custom_layers/axpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
""" A custom layer for 'axpy' which receives 3 tensors and output 1 tensor.
the function performed is:(the mupltiplication and add are elementewise)
output = inputs[0] * inputs[1] + inputs[2]
"""

from .register import register


def axpy_shape(input_shapes):
""" calculate the output shape of this layer using input shapes
Args:
@input_shapes (list of tuples): a list of input shapes
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
assert len(input_shapes) == 3, "not valid input shape for axpy layer"
assert len(input_shapes[0]) == len(input_shapes[1]), 'should have same dims'

output_shape = input_shapes[1]
assert (input_shapes[2] == output_shape),\
"shape not consistent for axpy[%s <--> %s]" \
% (str(output_shape), str(input_shapes[2]))

return output_shape


def axpy_layer(inputs, name):
""" build a layer of type 'Axpy' using fluid
Args:
@inputs (list of variables): input fluid variables for this layer
@name (str): name for this layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid

assert len(inputs) == 3, "invalid inputs for axpy[%s]" % (name)
alpha = inputs[0]
x = inputs[1]
y = inputs[2]
output = fluid.layers.elementwise_mul(x, alpha, axis=0)
output = fluid.layers.elementwise_add(output, y)

return output


register(kind='Axpy', shape=axpy_shape, layer=axpy_layer)
Loading

0 comments on commit e7684f0

Please sign in to comment.