From b655946c4e49a5f362bc84f92eebc2b7f3593a85 Mon Sep 17 00:00:00 2001 From: wangxinxin08 Date: Wed, 23 Mar 2022 07:04:51 +0000 Subject: [PATCH] add unittest of elementwise mul, sub and div --- .../inference/test_trt_convert_elementwise.py | 10 ++++++++-- .../ir/inference/test_trt_elementwise_op.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py index 505060e31a0a2..047a6094ec1e1 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py @@ -150,7 +150,10 @@ def generate_input(shape): return np.random.random(shape).astype(np.float32) for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]: - for op_type in ["elementwise_add", "elementwise_mul"]: + for op_type in [ + "elementwise_add", "elementwise_mul", "elementwise_sub", + "elementwise_div" + ]: for axis in [0, -1]: self.dims = len(shape) dics = [{"axis": axis}] @@ -306,7 +309,10 @@ def generate_input(shape): input1_shape = input1_shape_list[i] for j in range(6): input2_shape = input2_shape_list[j][i] - for op_type in ["elementwise_add", "elementwise_mul"]: + for op_type in [ + "elementwise_add", "elementwise_mul", "elementwise_sub", + "elementwise_div" + ]: for axis in axis_list[j][i]: self.shape1 = input1_shape self.shape2 = input2_shape diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py index f84202df5fb93..b40daba48689b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py @@ -56,5 +56,23 @@ def test_check_output(self): PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) +class TensorRTSubgraphPassElementwiseBroadcastTest1( + TensorRTSubgraphPassElementwiseBroadcastTest): + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_sub(x=data1, y=data2, axis=0) + + +class TensorRTSubgraphPassElementwiseBroadcastTest2( + TensorRTSubgraphPassElementwiseBroadcastTest): + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_mul(x=data1, y=data2, axis=0) + + +class TensorRTSubgraphPassElementwiseBroadcastTest3( + TensorRTSubgraphPassElementwiseBroadcastTest): + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_div(x=data1, y=data2, axis=0) + + if __name__ == "__main__": unittest.main()