From 5f477882ea429e44af32438cfee5761d5246454f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 24 Aug 2021 18:44:22 +0900 Subject: [PATCH] test only structural equality --- tests/python/relay/test_to_mixed_precision.py | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index c5b754786de82..8736a201548ff 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -221,12 +221,9 @@ def test_do_not_convert_softmax(): b = relay.nn.softmax(a) mod = tvm.IRModule.from_expr(b) mod = tvm.relay.transform.InferType()(mod) - - mod_params = { - "a": np.random.uniform(-1, 1, size=shape).astype("float32"), - } - output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0) - assert tvm.ir.structural_equal(mod, output_mod) + out_mod = ToMixedPrecision("float16")(mod) + orig_mod = tvm.relay.transform.InferType()(mod) + assert tvm.ir.structural_equal(orig_mod, out_mod) def test_do_not_convert_arange(): @@ -234,10 +231,9 @@ def test_do_not_convert_arange(): dtype = "float32" arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype)) mod = tvm.IRModule.from_expr(arange) - mod = tvm.relay.transform.InferType()(mod) - - output_mod = verify_mixed_precision_output_close(mod, {}, atol=0.0, rtol=0) - assert tvm.ir.structural_equal(mod, output_mod) + out_mod = ToMixedPrecision("float16")(mod) + orig_mod = tvm.relay.transform.InferType()(mod) + assert tvm.ir.structural_equal(orig_mod, out_mod) def test_do_not_convert_summation(): @@ -252,14 +248,9 @@ def test_do_not_convert_summation(): ] for op in ops: mod = tvm.IRModule.from_expr(op(a)) - mod = tvm.relay.transform.InferType()(mod) - - mod_params = { - "a": np.random.uniform(-1, 1, size=shape).astype("float32"), - } - - output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0) - assert tvm.ir.structural_equal(mod, output_mod) + out_mod = ToMixedPrecision("float16")(mod) + orig_mod = tvm.relay.transform.InferType()(mod) + assert tvm.ir.structural_equal(orig_mod, out_mod) def test_green_gray_propagates_simple():