Skip to content

Commit

Permalink
[Relay][Frontend][TensorFlow] Support BatchMatMul with input dimensio…
Browse files Browse the repository at this point in the history
…ns larger than 3 (apache#3732)

* Support BatchMatMul with shapes greater than length 3

* Fixes

* Add tests

* Remove dependency on Python3

* Clean up

* Merge with master

* Resolve comments
  • Loading branch information
soiferj authored and wweic committed Sep 6, 2019
1 parent 999506b commit f136528
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
24 changes: 22 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,11 +448,31 @@ def _impl(inputs, attr, params):

def _batch_matmul():
def _impl(inputs, attr, params):
input_x = inputs[0]
input_y = inputs[1]
orig_shape_x = attr['_input_shapes'][input_x]
orig_shape_y = attr['_input_shapes'][input_y]

# reshape n-dimensional batch matmul into 3d
if len(orig_shape_x) > 3:
outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)]
num_outer_elts = np.prod(outer_dims)
new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
input_x = _op.reshape(input_x, newshape=new_shape_x)
input_y = _op.reshape(input_y, newshape=new_shape_y)

adj_x = attr['adj_x']
adj_y = attr['adj_y']
input_x = _op.transpose(inputs[0], axes=[0, 2, 1]) if adj_x else inputs[0]
input_y = _op.transpose(inputs[1], axes=[0, 2, 1]) if not adj_y else inputs[1]
input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y
ret = get_relay_op('batch_matmul')(input_x, input_y)

# reshape result back to n-dimensional
if len(orig_shape_x) > 3:
final_shape = attr['_output_shapes'][0]
ret = _op.reshape(ret, newshape=final_shape)

return ret
return _impl

Expand Down
4 changes: 4 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,10 @@ def test_forward_batch_matmul():
_test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True)
_test_batch_matmul((3, 5, 4), (3, 5, 4), 'int32', True, False)
_test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True)
_test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'int32')
_test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), 'float32', True, True)
_test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), 'int32', True, False)
_test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), 'float32', False, True)


#######################################################################
Expand Down

0 comments on commit f136528

Please sign in to comment.