Skip to content

Commit

Permalink
[Relay][Frontend][ONNX] Remove the 'try catch' block
Browse files Browse the repository at this point in the history
  • Loading branch information
cchung100m committed Nov 15, 2019
1 parent 2df495a commit c2a51dd
Showing 1 changed file with 9 additions and 24 deletions.
33 changes: 9 additions & 24 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,10 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
dtype = 'float32'
x = np.random.uniform(size=data_shape)
model = onnx.load_model(graph_file)
try:
c2_out = get_onnxruntime_output(model, x, dtype)
except onnx.onnx_cpp2py_export.checker.ValidationError as e:
import warnings
warnings.warn(str(e))
else:
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
c2_out = get_onnxruntime_output(model, x, dtype)
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)


def verify_super_resolution_example():
Expand Down Expand Up @@ -197,13 +192,8 @@ def verify_space_to_depth(inshape, outshape, blockSize):
for target, ctx in ctx_list():
x = np.random.uniform(size=inshape).astype('float32')
tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32')
try:
onnx_out = get_onnxruntime_output(model, x, 'float32')
except onnx.onnx_cpp2py_export.checker.ValidationError as e:
import warnings
warnings.warn(str(e))
else:
tvm.testing.assert_allclose(onnx_out, tvm_out)
onnx_out = get_onnxruntime_output(model, x, 'float32')
tvm.testing.assert_allclose(onnx_out, tvm_out)


def test_space_to_depth():
Expand Down Expand Up @@ -1440,14 +1430,9 @@ def check_torch_conversion(model, input_size):
onnx_model = onnx.load(file_name)
for target, ctx in ctx_list():
input_data = np.random.uniform(size=input_size).astype('int32')
try:
c2_out = get_onnxruntime_output(onnx_model, input_data)
except onnx.onnx_cpp2py_export.checker.ValidationError as e:
import warnings
warnings.warn(str(e))
else:
tvm_out = get_tvm_output(onnx_model, input_data, target, ctx)
tvm.testing.assert_allclose(c2_out, tvm_out)
c2_out = get_onnxruntime_output(onnx_model, input_data)
tvm_out = get_tvm_output(onnx_model, input_data, target, ctx)
tvm.testing.assert_allclose(c2_out, tvm_out)


def test_resnet():
Expand Down

0 comments on commit c2a51dd

Please sign in to comment.