From 0cd1f926909ecbaacbc04bd9321a18856f5a2d89 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Fri, 15 Nov 2019 19:38:26 +0800 Subject: [PATCH] [Relay][Frontend][ONNX] operator support: DepthToSpace, SpaceToDepth --- python/tvm/relay/frontend/onnx.py | 72 ++++++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 71 ++++++++++++++++++--- 2 files changed, 135 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 0c581c96d4e5..3d90d15e1916 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -472,6 +472,76 @@ def _impl_v5(cls, inputs, attr, params): static_shape.asnumpy().astype('int32'))) return out + +class DepthToSpace(OnnxOpConverter): + """ Operator converter for DepthToSpace. + """ + + @classmethod + def _impl_v11(cls, inputs, attr, params): + + block_size = int(attr['blocksize']) + mode = attr.get("mode", "DCR") + + # handle NCHW layout + indata = infer_value_simulated(inputs[0], params) + in_n, in_c, in_h, in_w = indata.shape + + # reshape to proper output + new_c = int(in_c / (block_size * block_size)) + new_h = in_h * block_size + new_w = in_w * block_size + newshape = (in_n, new_c, new_h, new_w) + + if mode == "DCR": + # expand input to larger dimension. + expanded = _op.reshape(inputs[0], + newshape=(in_n, block_size, block_size, new_c, in_h, in_w)) + # reorder to expand spatial blocks. + transposed = _op.transpose(expanded, axes=(0, 3, 4, 1, 5, 2)) + + else: # CRD mode + # expand input to larger dimension. + expanded = _op.reshape(inputs[0], + newshape=(in_n, new_c, block_size, block_size, in_h, in_w)) + # reorder to expand spatial blocks. + transposed = _op.transpose(expanded, axes=(0, 1, 4, 2, 5, 3)) + + return AttrCvt(op_name="reshape", + extras={'newshape': newshape}, + ignores=['mode', 'blocksize'])([transposed], attr) + + +class SpaceToDepth(OnnxOpConverter): + """ Operator converter for SpaceToDepth. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + + block_size = int(attr['blocksize']) + + # handle NCHW layout + indata = infer_value_simulated(inputs[0], params) + in_n, in_c, in_h, in_w = indata.shape + + # reshape to proper output + new_c = in_c * (block_size * block_size) + new_h = int(in_h / block_size) + new_w = int(in_w / block_size) + newshape = (in_n, new_c, new_h, new_w) + + # expand input to larger dimension. + expanded = _op.reshape(inputs[0], + newshape=(in_n, in_c, new_h, block_size, new_w, block_size)) + # reorder to expand spatial blocks. + transposed = _op.transpose(expanded, axes=(0, 3, 5, 1, 2, 4)) + + return AttrCvt(op_name="reshape", + extras={'newshape': newshape}, + ignores=['blocksize'])([transposed], attr) + + class Concat(OnnxOpConverter): """ Operator converter for Concat. """ @@ -1121,6 +1191,8 @@ def _get_convert_map(opset): 'Split': Split.get_converter(opset), 'Slice': Slice.get_converter(opset), 'Transpose': AttrCvt('transpose', {'perm': 'axes'}), + 'DepthToSpace': DepthToSpace.get_converter(opset), + 'SpaceToDepth': SpaceToDepth.get_converter(opset), 'Gather': Gather.get_converter(opset), 'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}), 'Unsqueeze': Unsqueeze.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6391a1a9504d..e074bac90f2a 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -77,19 +77,19 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output return tvm_output.asnumpy() -def get_caffe2_output(model, x, dtype='float32'): - import caffe2.python.onnx.backend - prepared_backend = caffe2.python.onnx.backend.prepare(model) - W = {model.graph.input[0].name: x.astype(dtype)} - c2_out = prepared_backend.run(W)[0] - return c2_out +def get_onnxruntime_output(model, x, dtype='float32'): + import onnxruntime.backend + rep = onnxruntime.backend.prepare(model, 'CPU') + x = x.astype(dtype) + ort_out = rep.run(x)[0] + return ort_out def verify_onnx_forward_impl(graph_file, data_shape, out_shape): dtype = 'float32' x = np.random.uniform(size=data_shape) model = onnx.load_model(graph_file) - c2_out = get_caffe2_output(model, x, dtype) + c2_out = get_onnxruntime_output(model, x, dtype) for target, ctx in ctx_list(): tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype) tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) @@ -142,6 +142,57 @@ def test_reshape(): tvm.testing.assert_allclose(ref_shape, tvm_out.shape) +def verify_depth_to_space(inshape, outshape, mode, blockSize): + node = onnx.helper.make_node('DepthToSpace', + inputs=['x'], + outputs=['y'], + blocksize=blockSize) + + graph = helper.make_graph([node], + "depth_to_space_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))]) + + model = helper.make_model(graph, producer_name='depth_to_space_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=inshape).astype('float32') + tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32') + onnx_out = get_onnxruntime_output(model, x, 'float32') + tvm.testing.assert_allclose(onnx_out, tvm_out) + + +def test_depth_to_space(): + # current onnx.checker use OpSet-1 version of DepthToSpace, which doesn't have a mode argument. + # TO-DO, we can add mode arguement to test CRD mode and DCR mode + # in the future when we update to a newer onnx version. + verify_depth_to_space((1, 8, 2, 3), (1, 2, 4, 6), mode="CRD", blockSize=2) + + +def verify_space_to_depth(inshape, outshape, blockSize): + node = onnx.helper.make_node('SpaceToDepth', + inputs=['x'], + outputs=['y'], + blocksize=blockSize) + + graph = helper.make_graph([node], + "space_to_depth_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))]) + + model = helper.make_model(graph, producer_name='space_to_depth_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=inshape).astype('float32') + tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32') + onnx_out = get_onnxruntime_output(model, x, 'float32') + tvm.testing.assert_allclose(onnx_out, tvm_out) + + +def test_space_to_depth(): + verify_space_to_depth((1, 1, 4, 6), (1, 4, 2, 3), 2) + + def test_shape(): in_shape = (4, 3, 3, 4) ref_shape = (6, 2, 4, 3) @@ -1372,7 +1423,7 @@ def check_torch_conversion(model, input_size): onnx_model = onnx.load(file_name) for target, ctx in ctx_list(): input_data = np.random.uniform(size=input_size).astype('int32') - c2_out = get_caffe2_output(onnx_model, input_data) + c2_out = get_onnxruntime_output(onnx_model, input_data) tvm_out = get_tvm_output(onnx_model, input_data, target, ctx) tvm.testing.assert_allclose(c2_out, tvm_out) @@ -1574,6 +1625,7 @@ def test_erf(): z = scipy.special.erf(x) verify_erf(x, z) + def verify_where(condition, x, y, dtype, outdata): node = helper.make_node('Where', inputs=['condition', 'x', 'y'], outputs=['out']) graph = helper.make_graph([node], @@ -1588,6 +1640,7 @@ def verify_where(condition, x, y, dtype, outdata): tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape) tvm.testing.assert_allclose(outdata, tvm_out) + def test_where(): condition = np.array([[1, 0], [1, 1]], dtype=np.bool) x = np.array([[1, 2], [3, 4]], dtype=np.int64) @@ -1704,3 +1757,5 @@ def test_or(): test_erf() test_where() test_or() + test_depth_to_space() + test_space_to_depth()