# Patch summary (commentary only; the diff payload below is unchanged).
#
# Files touched:
#   1. paddle/phi/kernels/cpu/cast_impl.h  -- CastInplaceKernelImpl
#   2. test/legacy_test/test_cast_op.py    -- new TestCastInplaceContinuous case
#
# cast_impl.h:
#   - Removes the local `x_origin = x` copy of the tensor object and no longer
#     takes `in_begin` from `x_origin.data()`.
#   - Instead allocates a temporary array with `new InT[numel]`, copies
#     `sizeof(InT) * numel` bytes from `x.data()` into it with memcpy, and uses
#     that array as the [in_begin, in_end) input range.
#   - Releases the temporary with `delete[] in_begin` after the call whose last
#     argument is `CastOpTransformFunctor()`.
#   - Presumably the copy keeps the source values alive when `out` shares
#     storage with `x` and `dev_ctx.Alloc(out)` touches that storage -- TODO
#     confirm against the callers of the in-place cast.
#
# test_cast_op.py:
#   - Adds TestCastInplaceContinuous.test_api_dygraph, run on paddle.CPUPlace().
#   - Builds a 2x2 tensor from [[1.0, 2.0], [3.0, 4.0]], then checks that
#     `x.cast_("uint8")` matches `x.cast("uint8")`, followed by the same check
#     for "float32", using np.testing.assert_array_equal.
#
# NOTE(review): `delete[] in_begin` is reached only on the normal path. Both
#   `dev_ctx.Alloc(out)` and the transform call sit between `new[]` and
#   `delete[]`; if either can throw, the temporary leaks. Consider an RAII
#   owner (e.g. std::vector or std::unique_ptr) -- confirm which headers
#   cast_impl.h already pulls in.
# NOTE(review): memcpy copies raw bytes, so this assumes InT is trivially
#   copyable -- confirm for every InT this kernel is instantiated with.
# NOTE(review): this text appears to have been collapsed onto one line and the
#   template argument lists look stripped (e.g. `x.data()` followed by pointer
#   arithmetic); verify against the upstream commit before applying.
diff --git a/paddle/phi/kernels/cpu/cast_impl.h b/paddle/phi/kernels/cpu/cast_impl.h index c3b602baaf8ce..93da1b7793487 100644 --- a/paddle/phi/kernels/cpu/cast_impl.h +++ b/paddle/phi/kernels/cpu/cast_impl.h @@ -51,10 +51,11 @@ void CastInplaceKernelImpl(const CPUContext& dev_ctx, const DenseTensor& x, DataType out_dtype, DenseTensor* out) { - auto x_origin = x; - auto* in_begin = x_origin.data(); - auto numel = x_origin.numel(); + auto numel = x.numel(); + auto* in_begin = new InT[numel]; auto* in_end = in_begin + numel; + auto* data_origin = x.data(); + memcpy(in_begin, data_origin, sizeof(InT) * numel); auto* out_begin = dev_ctx.Alloc(out); out->set_type(out_dtype); @@ -65,6 +66,7 @@ void CastInplaceKernelImpl(const CPUContext& dev_ctx, in_end, out_begin, CastOpTransformFunctor()); + delete[] in_begin; } } // namespace phi diff --git a/test/legacy_test/test_cast_op.py b/test/legacy_test/test_cast_op.py index 69e088f34ba85..fb84c7b5af5b2 100644 --- a/test/legacy_test/test_cast_op.py +++ b/test/legacy_test/test_cast_op.py @@ -284,6 +284,21 @@ def test_grad(self): self.func(p) +class TestCastInplaceContinuous(unittest.TestCase): + def test_api_dygraph(self): + def run(place): + paddle.disable_static(place) + x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]]) + target = x.cast("uint8") + x.cast_("uint8") + np.testing.assert_array_equal(target.numpy(), x.numpy()) + target = x.cast("float32") + x.cast_("float32") + np.testing.assert_array_equal(target.numpy(), x.numpy()) + + run(paddle.CPUPlace()) + + if __name__ == '__main__': paddle.enable_static() unittest.main()