diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1b904bba7713..b7d668b93837 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -850,6 +850,18 @@ def _impl_v1(cls, inputs, attr, params): shape = shape + attr.pop('extra_shape') return _op.full(inputs[0], shape) +class Sign(OnnxOpConverter): + """ Operator converter for Sign. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + return _op.sign(inputs[0]) + +class Equal(Elemwise): + """ Operator converter for Equal. + """ + name = 'equal' + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -964,6 +976,8 @@ def _get_convert_map(opset): 'Unsqueeze': Unsqueeze.get_converter(opset), 'Pad': Pad.get_converter(opset), 'Shape': Shape.get_converter(opset), + 'Sign': Sign.get_converter(opset), + 'Equal': Equal.get_converter(opset) } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d40996011896..87d38e0d8801 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -962,6 +962,7 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None): verify_binary_ops("Sum", x, y, x + y, broadcast=None) verify_binary_ops("Greater", x, y, x > y, broadcast=True) verify_binary_ops("Less", x, y, x < y, broadcast=True) + verify_binary_ops("Equal", x, y, x == y, broadcast=True) def test_single_ops(): in_shape = (1, 2, 3, 3) @@ -1116,6 +1117,15 @@ def test_inception(): # def test_shufflenetv2(): # check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224)) +def test_sign(): + def Sign_x(x): + return np.sign(x) + _test_onnx_op_elementwise((3, 4, 5, 6), + Sign_x, + {}, + 'float32', + 'Sign', + {}) if __name__ == '__main__': test_flatten() @@ -1159,3 +1169,4 @@ def test_inception(): test_resnet() test_inception() test_densenet() + test_sign()