diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py index 505060e31a0a2..047a6094ec1e1 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py @@ -150,7 +150,10 @@ def generate_input(shape): return np.random.random(shape).astype(np.float32) for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]: - for op_type in ["elementwise_add", "elementwise_mul"]: + for op_type in [ + "elementwise_add", "elementwise_mul", "elementwise_sub", + "elementwise_div" + ]: for axis in [0, -1]: self.dims = len(shape) dics = [{"axis": axis}] @@ -306,7 +309,10 @@ def generate_input(shape): input1_shape = input1_shape_list[i] for j in range(6): input2_shape = input2_shape_list[j][i] - for op_type in ["elementwise_add", "elementwise_mul"]: + for op_type in [ + "elementwise_add", "elementwise_mul", "elementwise_sub", + "elementwise_div" + ]: for axis in axis_list[j][i]: self.shape1 = input1_shape self.shape2 = input2_shape diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py index f84202df5fb93..b40daba48689b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py @@ -56,5 +56,23 @@ def test_check_output(self): PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) +class TensorRTSubgraphPassElementwiseBroadcastTest1( + TensorRTSubgraphPassElementwiseBroadcastTest): + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_sub(x=data1, y=data2, axis=0) + + +class TensorRTSubgraphPassElementwiseBroadcastTest2( + TensorRTSubgraphPassElementwiseBroadcastTest): + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_mul(x=data1, y=data2, axis=0) + + +class TensorRTSubgraphPassElementwiseBroadcastTest3( + TensorRTSubgraphPassElementwiseBroadcastTest): + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_div(x=data1, y=data2, axis=0) + + if __name__ == "__main__": unittest.main()