diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py
index 5b305325f3d2d..a2ecbf53e5b35 100644
--- a/python/paddle/fluid/dygraph/math_op_patch.py
+++ b/python/paddle/fluid/dygraph/math_op_patch.py
@@ -226,7 +226,9 @@ def __impl__(self, other_var):
             # so the calculation result here and the calculation result of numpy are
             # different after 6 decimal point. If necessary, we can also use float64 here.
             # torch's behavior here is consistent with ours
-            if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_:
+            if (op_type == "final_state_divide" or
+                    op_type == "elementwise_div"
+                ) and self.dtype in _supported_int_dtype_:
                 self = astype(self, 'float32')
             # here use `scale` replace `elementwise` to get better performance
             # but only +, -, *, / can use this method
@@ -281,7 +283,8 @@ def __impl__(self, other_var):
                 self = other_var
                 other_var = tmp
 
-            if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_:
+            if (op_type == "final_state_divide" or op_type == "elementwise_div"
+                ) and self.dtype in _supported_int_dtype_:
                 self = astype(self, 'float32')
                 other_var = astype(other_var, 'float32')
 
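Note for reviewers: in eager mode the patched binary operators are dispatched with op_type "final_state_divide" rather than "elementwise_div", so the int-to-float32 promotion check above has to accept both names; otherwise integer division would truncate under the eager guard. Below is a minimal sketch of the behavior this preserves (`check_int_div_promotion` is a hypothetical helper for illustration; it assumes a Paddle build that exposes `_test_eager_guard`, as the test file below does):

```python
import paddle
from paddle.fluid.framework import _test_eager_guard


def check_int_div_promotion():
    a = paddle.ones([2, 2], dtype='int64')
    c = a / 2  # int tensor / int scalar: operands are cast to float32 first
    assert c.dtype == paddle.float32


check_int_div_promotion()        # legacy dygraph: op_type 'elementwise_div'
with _test_eager_guard():
    check_int_div_promotion()    # eager mode: op_type 'final_state_divide'
```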
diff --git a/python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_dynamic.py b/python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_dynamic.py
index 5f2dfbdd99e16..774d40a17c66d 100644
--- a/python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_dynamic.py
+++ b/python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_dynamic.py
@@ -18,6 +18,7 @@
 import numpy as np
 
 import paddle
+from paddle.fluid.framework import _test_eager_guard
 
 # Support types are ref from `paddle.tensor.math`
 # - Related paddle dtypes:
@@ -50,7 +51,7 @@ def check_operation(self, a, b, c, op):
         self.assertEqual(c_rlt.dtype, c.dtype)
         self.assertTrue(np.array_equal(c_rlt.numpy(), c.numpy()))
 
-    def test_tensor_add_scalar(self):
+    def func_tensor_add_scalar(self):
         # tensor(int64) + scalar(int)
         a = paddle.ones([2, 2, 2], dtype='int64')
         b = 1
@@ -81,7 +82,12 @@ def test_tensor_add_scalar(self):
         c = paddle.full([2, 2, 2], 2.5, dtype="float32")
         self.check_operation(a, b, c, '+')
 
-    def test_tensor_sub_scalar(self):
+    def test_tensor_add_scalar(self):
+        with _test_eager_guard():
+            self.func_tensor_add_scalar()
+        self.func_tensor_add_scalar()
+
+    def func_tensor_sub_scalar(self):
         # tensor(int64) - scalar(int)
         a = paddle.ones([2, 2, 2], dtype='int64')
         b = 1
@@ -112,7 +118,12 @@ def test_tensor_sub_scalar(self):
         c = paddle.full([2, 2, 2], 0.5, dtype="float32")
         self.check_operation(a, b, c, '-')
 
-    def test_scalar_sub_tensor(self):
+    def test_tensor_sub_scalar(self):
+        with _test_eager_guard():
+            self.func_tensor_sub_scalar()
+        self.func_tensor_sub_scalar()
+
+    def func_scalar_sub_tensor(self):
         # scalar(int) - tensor(int64)
         a = 1
         b = paddle.ones([2, 2, 2], dtype='int64')
@@ -143,7 +154,12 @@ def test_scalar_sub_tensor(self):
         c = paddle.full([2, 2, 2], -0.5, dtype="float32")
         self.check_operation(a, b, c, '-')
 
-    def test_tensor_mul_tensor(self):
+    def test_scalar_sub_tensor(self):
+        with _test_eager_guard():
+            self.func_scalar_sub_tensor()
+        self.func_scalar_sub_tensor()
+
+    def func_tensor_mul_tensor(self):
         # tensor(int64) * scalar(int)
         a = paddle.ones([2, 2, 2], dtype='int64')
         b = 1
@@ -174,7 +190,12 @@ def test_tensor_mul_tensor(self):
         c = paddle.full([2, 2, 2], 1.5, dtype="float32")
         self.check_operation(a, b, c, '*')
 
-    def test_tensor_div_scalar(self):
+    def test_tensor_mul_tensor(self):
+        with _test_eager_guard():
+            self.func_tensor_mul_tensor()
+        self.func_tensor_mul_tensor()
+
+    def func_tensor_div_scalar(self):
         # tensor(int64) / scalar(int)
         a = paddle.ones([2, 2, 2], dtype='int64')
         b = 2
@@ -205,7 +226,12 @@ def test_tensor_div_scalar(self):
         c = paddle.full([2, 2, 2], 2, dtype="float32")
         self.check_operation(a, b, c, '/')
 
-    def test_scalar_div_tensor(self):
+    def test_tensor_div_scalar(self):
+        with _test_eager_guard():
+            self.func_tensor_div_scalar()
+        self.func_tensor_div_scalar()
+
+    def func_scalar_div_tensor(self):
         # scalar(int) / tensor(int64)
         a = 1
         b = paddle.full([2, 2, 2], 2, dtype='int64')
@@ -230,7 +256,12 @@ def test_scalar_div_tensor(self):
         c = paddle.full([2, 2, 2], 2, dtype="float32")
         self.check_operation(a, b, c, '/')
 
-    def test_tensor_pow_scalar(self):
+    def test_scalar_div_tensor(self):
+        with _test_eager_guard():
+            self.func_scalar_div_tensor()
+        self.func_scalar_div_tensor()
+
+    def func_tensor_pow_scalar(self):
         # tensor(int64) ** scalar(int)
         a = paddle.full([2, 2, 2], 2, dtype='int64')
         b = 3
@@ -255,7 +286,12 @@ def test_tensor_pow_scalar(self):
         c = paddle.full([2, 2, 2], 8, dtype="float32")
         self.check_operation(a, b, c, '**')
 
-    def test_scalar_pow_tensor(self):
+    def test_tensor_pow_scalar(self):
+        with _test_eager_guard():
+            self.func_tensor_pow_scalar()
+        self.func_tensor_pow_scalar()
+
+    def func_scalar_pow_tensor(self):
         # scalar(int) ** tensor(int64)
         a = 3
         b = paddle.full([2, 2, 2], 2, dtype='int64')
@@ -280,15 +316,25 @@ def test_scalar_pow_tensor(self):
         c = paddle.full([2, 2, 2], 9, dtype="float32")
         self.check_operation(a, b, c, '**')
 
+    def test_scalar_pow_tensor(self):
+        with _test_eager_guard():
+            self.func_scalar_pow_tensor()
+        self.func_scalar_pow_tensor()
+
     ## TODO: floordiv op kernel doesn't support float
-    def test_tensor_floordiv_scalar(self):
+    def func_tensor_floordiv_scalar(self):
         # tensor(int64) // scalar(int)
         a = paddle.full([2, 2, 2], 3, dtype='int64')
         b = 2
         c = paddle.full([2, 2, 2], 1, dtype="int64")
         self.check_operation(a, b, c, '//')
 
-    def test_tensor_mod_scalar(self):
+    def test_tensor_floordiv_scalar(self):
+        with _test_eager_guard():
+            self.func_tensor_floordiv_scalar()
+        self.func_tensor_floordiv_scalar()
+
+    def func_tensor_mod_scalar(self):
         # tensor(int64) % scalar(int)
         a = paddle.full([2, 2, 2], 3, dtype='int64')
         b = 2
@@ -313,6 +359,11 @@ def test_tensor_mod_scalar(self):
         c = paddle.full([2, 2, 2], 1, dtype="float32")
         self.check_operation(a, b, c, '%')
 
+    def test_tensor_mod_scalar(self):
+        with _test_eager_guard():
+            self.func_tensor_mod_scalar()
+        self.func_tensor_mod_scalar()
+
 
 if __name__ == '__main__':
     unittest.main()
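For context, the test changes all follow one mechanical pattern: each original `test_*` body is renamed to `func_*`, and a new `test_*` wrapper runs that body twice, once inside `_test_eager_guard()` (eager mode) and once outside (legacy dygraph). A condensed sketch of the pattern, using a hypothetical test class for illustration:

```python
import unittest

import paddle
from paddle.fluid.framework import _test_eager_guard


class TestIntDivPromotion(unittest.TestCase):  # hypothetical, for illustration
    def func_tensor_div_scalar(self):
        # tensor(int64) / scalar(int) should yield float32, not truncate
        a = paddle.ones([2, 2, 2], dtype='int64')
        c = a / 2
        self.assertEqual(c.dtype, paddle.float32)

    def test_tensor_div_scalar(self):
        with _test_eager_guard():            # first run: eager mode
            self.func_tensor_div_scalar()
        self.func_tensor_div_scalar()        # second run: legacy dygraph


if __name__ == '__main__':
    unittest.main()
```

Running the body a second time outside the guard keeps the legacy dygraph path covered by the same unittest entry until eager mode becomes the default.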