[NNVM] Support argmax/argmin in tensorflow frontend (apache#1514)
sergei-mironov authored and tqchen committed Aug 2, 2018
1 parent 71cff3e commit 6d1dc4a
Showing 2 changed files with 79 additions and 8 deletions.
56 changes: 48 additions & 8 deletions nnvm/python/nnvm/frontend/tensorflow.py
@@ -91,6 +91,20 @@ def _impl(inputs, attr, *args):
        return AttrCvt(op_name="__pow_scalar__", extras={'scalar': -0.5})(inputs, attr)
    return _impl

def _argx(func, func_name):
    """ A common wrapper for argmin and argmax operations """
    def _impl(inputs, attr, params):
        try:
            # In TensorFlow, the `axis` argument is a Tensor, not an attribute.
            # We support the case where it comes from a scalar constant.
            axis_input_name = inputs[1].list_output_names()[0]
            axis_input_value = params[axis_input_name].asnumpy()[0]
        except (IndexError, KeyError):
            raise TypeError(
                "Unsupported argument for `{}`: `axis` should be a constant".format(func_name))
        return func(inputs[0], axis=axis_input_value, keepdims=False)
    return _impl
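
For context, a minimal sketch (TF 1.x APIs; node names are illustrative) of the graph shape this wrapper expects: the `axis` passed to `tf.argmax` materializes as a `Const` node feeding the `ArgMax` op, and its value is what the frontend later finds in `params`:

import numpy as np
import tensorflow as tf

with tf.Graph().as_default() as g:
    inp = tf.constant(np.zeros((8, 4, 9), dtype='float32'), name="c0")
    out = tf.argmax(inp, axis=1, name="argx0")
    graph_def = g.as_graph_def(add_shapes=True)

# ArgMax has two inputs: the data tensor and the axis tensor. The axis comes
# from a Const node (e.g. 'argx0/dimension'), so after import its value is
# available in the frontend's `params` dict and _argx can read it as a scalar.
argmax_node = [n for n in graph_def.node if n.op == "ArgMax"][0]
print(argmax_node.input)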

def _elemwise(name):
    def _impl(inputs, attr, *args):
        assert len(inputs) == 2, "Math op takes 2 inputs, {} given".format(len(inputs))
@@ -664,6 +678,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
_convert_map = {
    'ArgMax' : _argx(_sym.argmax, 'argmax'),
    'ArgMin' : _argx(_sym.argmin, 'argmin'),
    'AvgPool' : _pooling('avg_pool'),
    'BatchNormWithGlobalNormalization' : _batch_norm(),
    'BiasAdd' : _bias_add(),
@@ -879,6 +895,28 @@ def _get_abs_layer_name(node):
                                         params, num_layers)
        return sym


def _parse_import_prerequisites(graph):
    """ Calculate the prerequisites for importing the TensorFlow `graph`.
    Return prerequisites for parsing:
    a. Set of operator names which don't have their mapping in TVM, i.e.
       which are not supported
    """
    missing_operators = set()
    for node in graph.node:
        if node.op in ("Placeholder", "Const"):
            continue
        if not any(node.op in t for t in (_identity_list, _convert_map, _convert_map_rnn)):
            missing_operators.add(node.op)

    return missing_operators
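
A usage sketch (assuming `graph_def` is a frozen tf.GraphDef the caller has already loaded): `from_tensorflow` below performs this check itself, but it can also be used to probe a model up front:

unsupported = _parse_import_prerequisites(graph_def)
if unsupported:
    print("Cannot convert, missing ops: {}".format(sorted(unsupported)))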


class GraphProto(object):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Definition:
@@ -901,7 +939,7 @@ def from_tensorflow(self, graph):
        Follow the tensorflow graph definition to parse and convert it to NNVM.
        Some of the assumptions are listed below.
-       -> First Const or Placeholder node will be considered as graph input.
+       -> First Placeholder or Const node will be considered as graph input.
        -> Rest all Const nodes are params.
        -> Last node is assumed as graph output.
        -> _output_shapes : Attribute should be present in the tensorflow frozen graph.
@@ -910,6 +948,7 @@ def from_tensorflow(self, graph):
        -> CheckNumerics: No implementation as of now for this.
           Just copies input to output.
        TODO: Change algorithm to stop treating first 'Const' in a special way.

        Parameters
        ----------
@@ -923,23 +962,25 @@ def from_tensorflow(self, graph):
        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        # Parse through all nodes and start extracting
        # params aka Const nodes
        # input nodes : First const node
        # normal nodes : other normal nodes

        try:
            from tensorflow.python.framework import tensor_util
        except ImportError as e:
            raise ImportError(
                "Unable to import tensorflow which is required: {}".format(e))

        missing_operators = _parse_import_prerequisites(graph)

        if missing_operators:
            raise NotImplementedError(
                "The following operators are not implemented: {}".format(missing_operators))

        # Parse the nodes to re-create TF graph using Symbol API of NNVM
        for node in graph.node:
            # Tensorflow doesn't have a separate list for params extraction.
            # Operator name 'Const' is treated as a parameter to build NNVM params dict.
            input_shapes = {}
            if node.op == "Placeholder":
                # Assuming only one graph input of type 'Placeholder'
                self._input_node = node.name
                self._num_input += 1

@@ -954,7 +995,6 @@ def from_tensorflow(self, graph):
                    raise NotImplementedError(
                        "Please freeze the graph with add_shapes=True")
            elif node.op == "Const":
                # Assuming first Const node as Graph Input node
                if self._input_node == '':
                    self._input_node = node.name
                    self._num_input += 1
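
The add_shapes=True requirement above is what guarantees every node carries the `_output_shapes` attribute the importer assumes; a minimal sketch of producing such a frozen graph (TF 1.x, mirroring the test added below):

import tensorflow as tf

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=(1, 4), name="x")
    y = tf.nn.relu(x, name="out")
    with tf.Session() as sess:
        # add_shapes=True attaches _output_shapes to every node in the
        # serialized GraphDef, which GraphProto.from_tensorflow requires.
        frozen = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(add_shapes=True), ["out"])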
@@ -997,7 +1037,7 @@ def from_tensorflow(self, graph):
            # Pass the node name too in attr
            attr["_node_name"] = node.name

-           #ToDo: Some of the tensorflow operators maintain internaly maintain
+           #ToDo: Some of the tensorflow operators internally maintain
            #execution layers and its output name will be the layer number along with
            #graph node name. eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
            #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
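
The ':0' output-name convention this ToDo describes (the comment is truncated by the hunk) amounts to splitting the tensor name on the final colon; a hypothetical helper, not part of this commit:

def _split_tensor_name(name):
    # 'Model/RNN/cell_0/RnnCell:0' -> ('Model/RNN/cell_0/RnnCell', 0)
    if ':' in name:
        node_name, index = name.rsplit(':', 1)
        return node_name, int(index)
    return name, 0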
31 changes: 31 additions & 0 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -404,6 +404,37 @@ def test_forward_sigmoid():

    _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))

#######################################################################
# Argmin/Argmax
# -------------

def _test_argx(func, data, **kwargs):

    with tf.Graph().as_default():
        inp = constant_op.constant(data, shape=data.shape, dtype=data.dtype, name="c0")

        # pylint: disable=unused-variable
        out = func(inp, name="argx0", **kwargs)
        # pylint: enable=unused-variable

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=sess.graph.as_graph_def(add_shapes=True),
                output_node_names=["argx0"])

            tf_output = run_tf_graph(sess, data, input_node="c0:0", output_node="argx0:0")
            tvm_output = run_tvm_graph(graph_def, data, "c0", tf_output.shape, output_dtype='int32')

            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)

            sess.close()

def test_argmin_argmax():
    for axis in [None, 0, 1, 2]:
        data = np.random.uniform(size=(8, 4, 9)).astype('float32')
        _test_argx(tf.argmax, data=data, axis=axis)
        _test_argx(tf.argmin, data=data, axis=axis)
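
When run standalone, the new test can be invoked through the file's usual `__main__` entry point (a sketch, assuming the module's existing convention):

if __name__ == '__main__':
    test_argmin_argmax()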

#######################################################################
# Variable
