diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 756022ba663e..48f000e6e101 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -448,11 +448,31 @@ def _impl(inputs, attr, params): def _batch_matmul(): def _impl(inputs, attr, params): + input_x = inputs[0] + input_y = inputs[1] + orig_shape_x = attr['_input_shapes'][input_x] + orig_shape_y = attr['_input_shapes'][input_y] + + # reshape n-dimensional batch matmul into 3d + if len(orig_shape_x) > 3: + outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] + num_outer_elts = np.prod(outer_dims) + new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) + new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) + input_x = _op.reshape(input_x, newshape=new_shape_x) + input_y = _op.reshape(input_y, newshape=new_shape_y) + adj_x = attr['adj_x'] adj_y = attr['adj_y'] - input_x = _op.transpose(inputs[0], axes=[0, 2, 1]) if adj_x else inputs[0] - input_y = _op.transpose(inputs[1], axes=[0, 2, 1]) if not adj_y else inputs[1] + input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x + input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y ret = get_relay_op('batch_matmul')(input_x, input_y) + + # reshape result back to n-dimensional + if len(orig_shape_x) > 3: + final_shape = attr['_output_shapes'][0] + ret = _op.reshape(ret, newshape=final_shape) + return ret return _impl diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 576e3d9f71df..6c309cdf7292 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -669,6 +669,10 @@ def test_forward_batch_matmul(): _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True) _test_batch_matmul((3, 5, 4), (3, 5, 4), 'int32', True, False) _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True) + _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'int32') + _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), 'float32', True, True) + _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), 'int32', True, False) + _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), 'float32', False, True) #######################################################################