Skip to content

Commit

Permalink
[Relay][Frontend][ONNX] Add Sign and Equal operators to ONNX frontend (
Browse files Browse the repository at this point in the history
…#3760)

* [Relay][Frontend][ONNX] Add Sign and Equal operators to ONNX frontend

* Dummy change to retrigger integration test
  • Loading branch information
soiferj authored and tmoreau89 committed Aug 15, 2019
1 parent 7eb1f35 commit 674feba
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
14 changes: 14 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,18 @@ def _impl_v1(cls, inputs, attr, params):
shape = shape + attr.pop('extra_shape')
return _op.full(inputs[0], shape)

class Sign(OnnxOpConverter):
""" Operator converter for Sign.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.sign(inputs[0])

class Equal(Elemwise):
""" Operator converter for Equal.
"""
name = 'equal'

# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -964,6 +976,8 @@ def _get_convert_map(opset):
'Unsqueeze': Unsqueeze.get_converter(opset),
'Pad': Pad.get_converter(opset),
'Shape': Shape.get_converter(opset),
'Sign': Sign.get_converter(opset),
'Equal': Equal.get_converter(opset)
}


Expand Down
11 changes: 11 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,7 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None):
verify_binary_ops("Sum", x, y, x + y, broadcast=None)
verify_binary_ops("Greater", x, y, x > y, broadcast=True)
verify_binary_ops("Less", x, y, x < y, broadcast=True)
verify_binary_ops("Equal", x, y, x == y, broadcast=True)

def test_single_ops():
in_shape = (1, 2, 3, 3)
Expand Down Expand Up @@ -1116,6 +1117,15 @@ def test_inception():
# def test_shufflenetv2():
# check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))

def test_sign():
def Sign_x(x):
return np.sign(x)
_test_onnx_op_elementwise((3, 4, 5, 6),
Sign_x,
{},
'float32',
'Sign',
{})

if __name__ == '__main__':
test_flatten()
Expand Down Expand Up @@ -1159,3 +1169,4 @@ def test_inception():
test_resnet()
test_inception()
test_densenet()
test_sign()

0 comments on commit 674feba

Please sign in to comment.