diff --git a/paddle/fluid/operators/transpose_op.h b/paddle/fluid/operators/transpose_op.h index 2a6849b1d2584..64bd5aaecb780 100644 --- a/paddle/fluid/operators/transpose_op.h +++ b/paddle/fluid/operators/transpose_op.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/math_function.h" @@ -32,6 +33,9 @@ inline void TransCompute(const int dim, framework::Tensor* out, const std::vector& axis) { switch (dim) { + case 0: + phi::Copy(dev_ctx, in, dev_ctx.GetPlace(), false, out); + break; case 1: phi::funcs::Transpose trans1; trans1(dev_ctx, in, out, axis); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 2fdb32644adde..cc54314a8f719 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3713,7 +3713,7 @@ void TileInferMeta(const MetaTensor& x, repeat_times_data.size())); PADDLE_ENFORCE_GE( repeat_times_data.size(), - 1, + 0, errors::InvalidArgument( "The size of the shape of input 'repeat_times' for tile op " -"must be positive integers, but the value received is %d.", +"must be non-negative integers, but the value received is %d.", repeat_times_data.size())); diff --git a/paddle/phi/kernels/cpu/transpose_kernel.cc b/paddle/phi/kernels/cpu/transpose_kernel.cc index a2f5aa2a29795..583df78cc25f3 100644 --- a/paddle/phi/kernels/cpu/transpose_kernel.cc +++ b/paddle/phi/kernels/cpu/transpose_kernel.cc @@ -35,6 +35,9 @@ void TransposeKernel(const Context& ctx, } int rank = axis.size(); switch (rank) { + case 0: + phi::Copy(ctx, x, ctx.GetPlace(), false, out); + break; case 1: funcs::Transpose trans1; trans1(ctx, x, out, axis); diff --git a/paddle/phi/kernels/impl/tile_kernel_impl.h b/paddle/phi/kernels/impl/tile_kernel_impl.h index d19a6a7800671..f7b923b00b1ca 100644 --- a/paddle/phi/kernels/impl/tile_kernel_impl.h +++ b/paddle/phi/kernels/impl/tile_kernel_impl.h @@ -54,6 +54,10 @@ void Tile(const Context& dev_ctx, 
vec_x_dims.size(), repeat_times.size())); + if (Rank == 0) { + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + return; + } Eigen::DSizes bcast_dims; for (size_t i = 0; i < repeat_times.size(); ++i) { bcast_dims[i] = repeat_times[i]; @@ -71,6 +75,7 @@ void Tile(const Context& dev_ctx, auto eigen_out = EigenTensor::From(*out, out_dims); auto& place = *dev_ctx.eigen_device(); + // use 32-bit index to speed up bool use_32bit_index = eigen_out.size() < Eigen::NumTraits::highest(); if (use_32bit_index) { @@ -93,6 +98,9 @@ void TileKernel(const Context& dev_ctx, rank = std::max(rank, repeat_times_size); switch (rank) { + case 0: + Tile(dev_ctx, x, repeat_times_data, out); + break; case 1: Tile(dev_ctx, x, repeat_times_data, out); break; diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index 6fe9392a2b610..e8e5b07e85da7 100755 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -50,6 +50,30 @@ def test_check_grad(self): self.check_grad(["X"], "Out") +class TestReshapeOpZeroDim1(OpTest): + + def init_data(self): + self.ori_shape = () + self.new_shape = (1) + self.infered_shape = (1) + + +class TestReshapeOpZeroDim2(OpTest): + + def init_data(self): + self.ori_shape = (1) + self.new_shape = () + self.infered_shape = () + + +class TestReshapeOpZeroDim3(OpTest): + + def init_data(self): + self.ori_shape = () + self.new_shape = (-1) + self.infered_shape = (1) + + class TestReshapeBF16Op(OpTest): def setUp(self): @@ -530,6 +554,93 @@ def test_reshape_zero_tensor_error(self): zero_tensor.reshape([2, 3]) +class TestReshapeAPI_ZeroDim(unittest.TestCase): + + def test_dygraph(self): + paddle.disable_static() + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + x = paddle.rand([]) + x.stop_gradient = False + + out = paddle.reshape(x, [1]) + out.backward() + self.assertEqual(out.shape, [1]) + 
self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, [1]) + + out = paddle.reshape(x, [-1, 1]) + out.backward() + self.assertEqual(out.shape, [1, 1]) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, [1, 1]) + + paddle.enable_static() + + def test_static(self): + main_prog = fluid.Program() + with fluid.program_guard(main_prog, fluid.Program()): + x = paddle.rand([]) + x.stop_gradient = False + out = paddle.reshape(x, [-1]) + fluid.backward.append_backward(out) + + prog = paddle.static.default_main_program() + block = prog.global_block() + + x_grad = block.var(fluid.framework.grad_var_name(x.name)) + out_grad = block.var(fluid.framework.grad_var_name(out.name)) + + # Test compile shape + self.assertEqual(x.shape, ()) + self.assertEqual(out.shape, (1, )) + self.assertEqual(x_grad.shape, ()) + self.assertEqual(out_grad.shape, (1, )) + + exe = fluid.Executor() + result = exe.run(main_prog, fetch_list=[x, out, x_grad, out_grad]) + + # Test runtime shape + self.assertEqual(result[0].shape, ()) + self.assertEqual(result[1].shape, (1, )) + self.assertEqual(result[2].shape, ()) + self.assertEqual(result[3].shape, (1, )) + + if paddle.device.is_compiled_with_cuda(): + places = [paddle.CUDAPlace(0)] + device_num = 1 + expect_merge_shape = () + else: + places = [paddle.CPUPlace()] * 4 + device_num = 4 + expect_merge_shape = (device_num, ) + + compiled_program = fluid.CompiledProgram( + main_prog).with_data_parallel(out.name, places=places) + result = exe.run(compiled_program, + fetch_list=[x, x_grad, out, out_grad], + return_merged=True) + + # Test runtime parallel shape + # 0D will be stacked, due to it cannot be concated + # [ x-place1 .concat(stack) x-place2, ...] 
+ self.assertEqual(result[0].shape, expect_merge_shape) + self.assertEqual(result[1].shape, expect_merge_shape) + self.assertEqual(result[2].shape, (device_num, )) + self.assertEqual(result[3].shape, (device_num, )) + + compiled_program = fluid.CompiledProgram( + main_prog).with_data_parallel(out.name, places=places) + result = exe.run(compiled_program, + fetch_list=[x, x_grad, out, out_grad], + return_merged=False) + + # [[x-place1, x-place2, ...], [], [], ...] + self.assertEqual(np.array(result[0]).shape, (device_num, )) + self.assertEqual(np.array(result[1]).shape, (device_num, )) + self.assertEqual(np.array(result[2]).shape, (device_num, 1)) + self.assertEqual(np.array(result[3]).shape, (device_num, 1)) + + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_stack_op.py b/python/paddle/fluid/tests/unittests/test_stack_op.py index 8fc004838da6c..2bee1feeb7b99 100644 --- a/python/paddle/fluid/tests/unittests/test_stack_op.py +++ b/python/paddle/fluid/tests/unittests/test_stack_op.py @@ -20,6 +20,8 @@ import paddle.fluid.core as core from paddle.fluid.framework import Program, program_guard +paddle.enable_static() + class TestStackOpBase(OpTest): @@ -100,6 +102,12 @@ def initParameters(self): self.axis = 3 +class TestStackOp_ZeroDim(TestStackOpBase): + + def initParameters(self): + self.input_dim = () + + class TestStackBF16Op(OpTest): def initDefaultParameters(self): @@ -294,5 +302,26 @@ def test_out(self): rtol=1e-05) +class TestStackAPI_ZeroDim(unittest.TestCase): + + def test_dygraph(self): + paddle.disable_static() + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + + x1 = paddle.rand([]) + x2 = paddle.rand([]) + x1.stop_gradient = False + x2.stop_gradient = False + out = paddle.stack([x1, x2]) + out.backward() + + self.assertEqual(out.shape, [2]) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(x2.grad.shape, []) + self.assertEqual(out.grad.shape, [2]) + + 
paddle.enable_static() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tile_op.py b/python/paddle/fluid/tests/unittests/test_tile_op.py index 9f694ab3319f3..0fcea269e9e3d 100644 --- a/python/paddle/fluid/tests/unittests/test_tile_op.py +++ b/python/paddle/fluid/tests/unittests/test_tile_op.py @@ -23,8 +23,7 @@ import gradient_checker from decorator_helper import prog_scope import paddle.fluid.layers as layers - - + #Situation 1: repeat_times is a list (without tensor) class TestTileOpRank1(OpTest): @@ -48,6 +47,27 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') + +class TestTileOpRank_ZeroDim1(TestTileOpRank1): + + def init_data(self): + self.ori_shape = [] + self.repeat_times = [] + + +class TestTileOpRank_ZeroDim2(TestTileOpRank1): + + def init_data(self): + self.ori_shape = [] + self.repeat_times = [2] + + +class TestTileOpRank_ZeroDim3(TestTileOpRank1): + + def init_data(self): + self.ori_shape = [] + self.repeat_times = [2, 3] + + # with dimension expanding class TestTileOpRank2Expanding(TestTileOpRank1): @@ -338,6 +358,37 @@ def test_grad(self): places.append(fluid.CUDAPlace(0)) for p in places: self.func(p) + + + +class TestTileAPI_ZeroDim(unittest.TestCase): + + def test_dygraph(self): + paddle.disable_static() + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + + x = paddle.rand([]) + x.stop_gradient = False + + out = paddle.tile(x, []) + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, []) + + out = paddle.tile(x, [3]) + out.backward() + self.assertEqual(out.shape, [3]) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, [3]) + + out = paddle.tile(x, [2, 3]) + out.backward() + self.assertEqual(out.shape, [2, 3]) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, [2, 3]) + + paddle.enable_static() if __name__ == "__main__": diff --git 
a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index adef1b25aaa97..f5aa1c6608a48 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -129,6 +129,13 @@ def initTestCase(self): self.axis = (6, 1, 3, 5, 0, 2, 4, 7) +class TestCase_ZeroDim(TestTransposeOp): + + def initTestCase(self): + self.shape = () + self.axis = () + + class TestAutoTuneTransposeOp(OpTest): def setUp(self): @@ -603,6 +610,24 @@ def test_grad(self): self.func(p) +class TestTransposeAPI_ZeroDim(unittest.TestCase): + + def test_dygraph(self): + paddle.disable_static() + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + + x = paddle.rand([]) + x.stop_gradient = False + out = paddle.transpose(x, []) + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, []) + + paddle.enable_static() + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py index c80555a66d08b..cd085a4619256 100755 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py @@ -90,6 +90,30 @@ def init_test_case(self): self.new_shape = (10, 1, 1, 2, 5, 1) +class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp): + + def init_test_case(self): + self.ori_shape = () + self.axes = (-1, ) + self.new_shape = (1) + + +class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp): + + def init_test_case(self): + self.ori_shape = () + self.axes = (-1, 1) + self.new_shape = (1, 1) + + +class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp): + + def init_test_case(self): + self.ori_shape = () + self.axes = (0, 1, 2) + self.new_shape = (1, 1, 1) + + # axes is a list(with tensor) class TestUnsqueezeOp_AxesTensorList(OpTest): @@ -285,5 
+309,36 @@ def executed_api(self): self.unsqueeze = paddle.unsqueeze_ + + +class TestUnsqueezeAPI_ZeroDim(unittest.TestCase): + + def test_dygraph(self): + paddle.disable_static() + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + + x = paddle.rand([]) + x.stop_gradient = False + + out = paddle.unsqueeze(x, [-1]) + out.backward() + self.assertEqual(out.shape, [1]) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, [1]) + + out = paddle.unsqueeze(x, [-1, 1]) + out.backward() + self.assertEqual(out.shape, [1, 1]) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, [1, 1]) + + out = paddle.unsqueeze(x, [0, 1, 2]) + out.backward() + self.assertEqual(out.shape, [1, 1, 1]) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, [1, 1, 1]) + + paddle.enable_static() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 29f5e37cd0cec..7fed47c462831 100755 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -116,6 +116,30 @@ def init_test_case(self): self.new_shape = (10, 1, 1, 2, 5, 1) + +class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp): + + def init_test_case(self): + self.ori_shape = () + self.axes = (-1, ) + self.new_shape = (1) + + +class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp): + + def init_test_case(self): + self.ori_shape = () + self.axes = (-1, 1) + self.new_shape = (1, 1) + + +class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp): + + def init_test_case(self): + self.ori_shape = () + self.axes = (0, 1, 2) + self.new_shape = (1, 1, 1) + + class API_TestUnsqueeze(unittest.TestCase): def test_out(self):