Skip to content

Commit

Permalink
fix cast bug (#60054)
Browse files Browse the repository at this point in the history
  • Loading branch information
YibinLiu666 committed Dec 19, 2023
1 parent 9ce3a54 commit 1134816
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
8 changes: 5 additions & 3 deletions paddle/phi/kernels/cpu/cast_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ void CastInplaceKernelImpl(const CPUContext& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
auto x_origin = x;
auto* in_begin = x_origin.data<InT>();
auto numel = x_origin.numel();
auto numel = x.numel();
auto* in_begin = new InT[numel];
auto* in_end = in_begin + numel;
auto* data_origin = x.data<InT>();
memcpy(in_begin, data_origin, sizeof(InT) * numel);

auto* out_begin = dev_ctx.Alloc<OutT>(out);
out->set_type(out_dtype);
Expand All @@ -65,6 +66,7 @@ void CastInplaceKernelImpl(const CPUContext& dev_ctx,
in_end,
out_begin,
CastOpTransformFunctor<InT, OutT>());
delete[] in_begin;
}

} // namespace phi
15 changes: 15 additions & 0 deletions test/legacy_test/test_cast_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,21 @@ def test_grad(self):
self.func(p)


class TestCastInplaceContinuous(unittest.TestCase):
def test_api_dygraph(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
target = x.cast("uint8")
x.cast_("uint8")
np.testing.assert_array_equal(target.numpy(), x.numpy())
target = x.cast("float32")
x.cast_("float32")
np.testing.assert_array_equal(target.numpy(), x.numpy())

run(paddle.CPUPlace())


if __name__ == '__main__':
paddle.enable_static()
unittest.main()

0 comments on commit 1134816

Please sign in to comment.