From 334bffdf25701d37494a2f97050e61384861dd21 Mon Sep 17 00:00:00 2001 From: zhwesky2010 <1183042833@qq.com> Date: Thu, 25 May 2023 11:40:01 +0800 Subject: [PATCH] [Zero-Dim] support ReshapeTransform/nll_loss/matmul support 0D (#53828) --- paddle/fluid/operators/matmul_op.cc | 4 - paddle/fluid/operators/matmul_v2_op.cc | 3 - .../operators/mkldnn/matmul_mkldnn_op.cc | 5 +- paddle/phi/backends/onednn/matmul_utils.h | 10 +- paddle/phi/infermeta/backward.cc | 20 +- paddle/phi/infermeta/binary.cc | 3 - paddle/phi/infermeta/ternary.cc | 6 +- paddle/phi/kernels/impl/matmul_kernel_impl.h | 4 +- paddle/phi/kernels/nll_loss_kernel.cc | 3 +- python/paddle/distribution/transform.py | 7 +- .../test_distribution_transform.py | 4 +- .../fluid/tests/unittests/test_matmul_op.py | 13 +- .../unittests/test_matmul_op_with_head.py | 12 - .../tests/unittests/test_matmul_v2_op.py | 6 - .../tests/unittests/test_zero_dim_tensor.py | 329 ++++++++++++------ python/paddle/tensor/linalg.py | 23 +- test/autograd/test_autograd_dynamic.py | 46 +-- test/autograd/utils.py | 1 + test/mkldnn/test_matmul_v2_mkldnn_op.py | 2 +- test/xpu/test_matmul_op_xpu.py | 13 +- test/xpu/test_matmul_v2_op_xpu.py | 6 - test/xpu/test_zero_dim_tensor_xpu.py | 27 ++ 22 files changed, 309 insertions(+), 238 deletions(-) diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 73c2577caa81c..e1a36fa41894d 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -666,10 +666,6 @@ class MatMulOp : public framework::OperatorWithKernel { dim_out.resize(dim_out.size() - 1); } - if (dim_out.empty()) { - dim_out = {1}; - } - phi::DDim ddim_out = phi::make_ddim(dim_out); context->SetOutputDim("Out", ddim_out); diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 4d46a87ccef22..7e61d1c9c814a 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -91,9 +91,6 @@ void MatMulV2Op::InferShape(framework::InferShapeContext* ctx) const { if (!y_broadcasted) { new_dims.push_back(N); } - if (x_broadcasted && y_broadcasted) { - new_dims.push_back(1); - } ctx->SetOutputDim("Out", phi::make_ddim(new_dims)); ctx->ShareLoD("X", "Out"); diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 04338fc866954..44d4ae664e0de 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -334,8 +334,9 @@ void ExecuteMatMulV1(const ExecutionContext &ctx, matmul_p->execute(astream, matmul_args); astream.wait(); - out->set_mem_desc( - dst_memory_p->get_desc().reshape(vectorize(out->dims()))); + auto reshape_dims = out->dims().size() != 0 ? vectorize(out->dims()) + : std::vector{1}; + out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims)); } template diff --git a/paddle/phi/backends/onednn/matmul_utils.h b/paddle/phi/backends/onednn/matmul_utils.h index 70ed7e29c9637..7248e64fe60b1 100644 --- a/paddle/phi/backends/onednn/matmul_utils.h +++ b/paddle/phi/backends/onednn/matmul_utils.h @@ -146,8 +146,9 @@ inline void ExecuteMul(const OneDNNContext& dev_ctx, // This kernel is flattening dims so then we need to unflattened version // that should be set in out reshape require plain layout, but // MatmulV2MKLDNNHanlder enforces one so it should work - out->set_mem_desc( - dst_memory_p->get_desc().reshape(vectorize(out->dims()))); + auto reshape_dims = out->dims().size() != 0 ? vectorize(out->dims()) + : std::vector{1}; + out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims)); } template @@ -177,8 +178,9 @@ inline void ExecuteMatmul(const OneDNNContext& dev_ctx, matmul_p->execute(astream, matmul_args); astream.wait(); - out->set_mem_desc( - dst_memory_p->get_desc().reshape(vectorize(out->dims()))); + auto reshape_dims = out->dims().size() != 0 ? vectorize(out->dims()) + : std::vector{1}; + out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims)); } } // namespace funcs diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index bd9fae6bd155b..853a8750a2b1a 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -822,11 +822,11 @@ void NllLossGradInferMeta(const MetaTensor& x, if (check) { auto batch_size = x_dims[0]; if (x_dims.size() == 2) { - PADDLE_ENFORCE_EQ(dout_dims.size(), - 1, - phi::errors::InvalidArgument( - "The dimensions of Input(Out@Grad) must be 1")); if (reduction == "none") { + PADDLE_ENFORCE_EQ(dout_dims.size(), + 1, + phi::errors::InvalidArgument( + "The dimensions of Input(Out@Grad) must be 1")); PADDLE_ENFORCE_EQ( dout_dims[0], batch_size, @@ -834,10 +834,10 @@ void NllLossGradInferMeta(const MetaTensor& x, "The unreduced size ofInput(Out@Grad) must be the " "same as batch_size.")); } else { - PADDLE_ENFORCE_EQ(dout_dims[0], - 1, + PADDLE_ENFORCE_EQ(dout_dims.size(), + 0, phi::errors::InvalidArgument( - "The reduced size of Input(Out@Grad) must be 1")); + "The dimensions of Input(Out@Grad) must be 0")); } } else if (x_dims.size() == 4) { if (reduction == "none") { @@ -855,10 +855,10 @@ void NllLossGradInferMeta(const MetaTensor& x, "The dimensions of Input(Out@Grad) must be match " "to Input(Label) dimensions.")); } else { - PADDLE_ENFORCE_EQ(dout_dims[0], - 1, + PADDLE_ENFORCE_EQ(dout_dims.size(), + 0, phi::errors::InvalidArgument( - "The reduced size of Input(Out@Grad) must be 1")); + "The dimensions of Input(Out@Grad) must be 0")); } } } diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index a1273d5bc2a8d..ae9f941fd663d 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2056,9 +2056,6 @@ void MatmulInferMeta(const MetaTensor& x, if (!y_broadcasted) { new_dims.push_back(N); } - if (x_broadcasted && y_broadcasted) { - new_dims.push_back(1); - } auto ddim_out = phi::make_ddim(new_dims); diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 646767703154e..61b824271c87d 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -818,7 +818,7 @@ void NllLossRawInferMeta(const MetaTensor& input, if (reduction == "none") { out->set_dims({x_dims[0]}); } else { - out->set_dims({1}); + out->set_dims(phi::make_ddim({})); } } else if (x_dims.size() == 4) { PADDLE_ENFORCE_EQ(label_dims.size(), @@ -841,10 +841,10 @@ void NllLossRawInferMeta(const MetaTensor& input, if (reduction == "none") { out->set_dims({x_dims[0], x_dims[2], x_dims[3]}); } else { - out->set_dims({1}); + out->set_dims(phi::make_ddim({})); } } - total_weight->set_dims({1}); + total_weight->set_dims(phi::make_ddim({})); out->set_dtype(input.dtype()); total_weight->set_dtype(input.dtype()); } diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index acc7affc00e26..94f4ee6a411be 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -126,7 +126,7 @@ void MatMulFunctionImplWithBlas( M, N)); VLOG(3) << "MatMul's case 1"; - Out->Resize({1}); + Out->Resize(phi::make_ddim({})); dev_ctx.template Alloc(Out); blas.GEMM(CblasNoTrans, CblasTrans, @@ -516,7 +516,7 @@ void MatMulFunctionImplWithCublasLt( N)); // MatMul's case 0 => vector * vector - Out->Resize({1}); + Out->Resize(phi::make_ddim({})); dev_ctx.template Alloc(Out); VLOG(3) << "MatMul with blaslt case 1"; blaslt::Run(dev_ctx, diff --git a/paddle/phi/kernels/nll_loss_kernel.cc b/paddle/phi/kernels/nll_loss_kernel.cc index c1a16f175c14b..b9278d152269c 100644 --- a/paddle/phi/kernels/nll_loss_kernel.cc +++ b/paddle/phi/kernels/nll_loss_kernel.cc @@ -24,8 +24,7 @@ void NllLossKernel(const Context& dev_ctx, const std::string& reduction, DenseTensor* out) { DenseTensor total_weight; - total_weight.set_meta( - DenseTensorMeta(phi::CppTypeToDataType::Type(), {1})); + total_weight.set_meta(DenseTensorMeta(phi::CppTypeToDataType::Type(), {})); dev_ctx.template Alloc(total_weight); NllLossRawKernel(dev_ctx, input, diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py index f1ee702c15b66..957a39d1ab4e5 100644 --- a/python/paddle/distribution/transform.py +++ b/python/paddle/distribution/transform.py @@ -856,8 +856,8 @@ class ReshapeTransform(Transform): # [[[1., 1., 1.], # [1., 1., 1.]]]) print(reshape_transform.forward_log_det_jacobian(x)) - # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, - # [0.]) + # Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # 0.) """ _type = Type.BIJECTION @@ -945,8 +945,7 @@ def _inverse_shape(self, shape): ) def _forward_log_det_jacobian(self, x): - # TODO(zhouwei): should not set shape to [1], which is [] - shape = x.shape[: x.dim() - len(self._in_event_shape)] or [1] + shape = x.shape[: x.dim() - len(self._in_event_shape)] return paddle.zeros(shape, dtype=x.dtype) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transform.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transform.py index 640391b472d7a..38ea037a65d9c 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transform.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transform.py @@ -1029,8 +1029,8 @@ def test_zerodim(self, input, expected): self.assertEqual(out.shape, [1, 1]) self.assertEqual(reshape.inverse(out).shape, []) - # self.assertEqual(reshape.forward_log_det_jacobian(x).shape, []) - # self.assertEqual(reshape.inverse_log_det_jacobian(out).shape, []) + self.assertEqual(reshape.forward_log_det_jacobian(x).shape, []) + self.assertEqual(reshape.inverse_log_det_jacobian(out).shape, []) self.assertEqual(reshape.forward_shape(x.shape), (1, 1)) self.assertEqual(reshape.inverse_shape(out.shape), ()) diff --git a/python/paddle/fluid/tests/unittests/test_matmul_op.py b/python/paddle/fluid/tests/unittests/test_matmul_op.py index 30085a841de31..c9a480fb0895a 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_op.py @@ -77,12 +77,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): Y = np.transpose(Y, tuple(dim)) Out = np.matmul(X, Y) - if not Out.shape: - # We do not support 0-dimensional Tensors (scalars). So where - # np.matmul outputs a scalar, we must convert to a Tensor of - # shape (1, ) instead. - # Everywhere else, we are compatible with np.matmul. - Out = np.array([Out], dtype="float32") return Out @@ -167,9 +161,6 @@ def test_out(self): with fluid.program_guard(fluid.Program()): x = paddle.static.data(name="x", shape=[2], dtype="float64") y = paddle.static.data(name='y', shape=[2], dtype='float64') - res = paddle.static.data( - name="output", shape=[1], dtype="float64" - ) result = paddle.mm(x, y) exe = fluid.Executor(fluid.CPUPlace()) data1 = np.random.rand(2) @@ -177,9 +168,7 @@ def test_out(self): np_res = exe.run( feed={'x': data1, 'y': data2}, fetch_list=[result] ) - expected_result = np.matmul( - data1.reshape(1, 2), data2.reshape(2, 1) - ) + expected_result = np.matmul(data1, data2) np.testing.assert_allclose( np_res, diff --git a/python/paddle/fluid/tests/unittests/test_matmul_op_with_head.py b/python/paddle/fluid/tests/unittests/test_matmul_op_with_head.py index 37ac37b6a99d9..475ede6fd05d7 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_op_with_head.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_op_with_head.py @@ -102,12 +102,6 @@ def reference_matmul_mul_head( Y = transpose_mat(Y) Out = matmul_head(X, Y, head_number) - if not Out.shape: - # We do not support 0-dimensional Tensors (scalars). So where - # np.matmul outputs a scalar, we must convert to a Tensor of - # shape (1, ) instead. - # Everywhere else, we are compatible with np.matmul. - Out = np.array([Out], dtype="float32") return Out @@ -196,12 +190,6 @@ def reference_matmul_mul_head2( Y = transpose_mat(Y) Out = matmul_head2(X, Y, head_number) - if not Out.shape: - # We do not support 0-dimensional Tensors (scalars). So where - # np.matmul outputs a scalar, we must convert to a Tensor of - # shape (1, ) instead. - # Everywhere else, we are compatible with np.matmul. - Out = np.array([Out], dtype="float32") return Out diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py index a0c41b63b05f2..9463afead2be4 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py @@ -45,12 +45,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): Y = np.transpose(Y, tuple(dim)) Out = np.matmul(X, Y) - if not Out.shape: - # We do not support 0-dimensional Tensors (scalars). So where - # np.matmul outputs a scalar, we must convert to a Tensor of - # shape (1, ) instead. - # Everywhere else, we are compatible with np.matmul. - Out = np.array([Out], dtype="float64") return Out diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index ca9f369d01cf6..12183b9cad82c 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -2557,6 +2557,47 @@ def body(i, x): self.assertEqual(x.grad.shape, []) np.testing.assert_allclose(x.grad, np.array(1.0)) + def test_to_tensor(self): + out1 = paddle.to_tensor(1) + out2 = paddle.to_tensor(2.5) + + out1.retain_grads() + out1.backward() + out2.retain_grads() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out1, 1) + self.assertEqual(out2.shape, []) + self.assertEqual(out2, 2.5) + + def test_matmul(self): + # 1) no transpose + x = paddle.randn([10]) + x.stop_gradient = False + y = paddle.randn([10]) + y.stop_gradient = False + out1 = paddle.matmul(x, y) + out1.retain_grads() + out1.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(x.grad.shape, [10]) + self.assertEqual(y.grad.shape, [10]) + + # 2) transpose x and y + x = paddle.randn([10]) + x.stop_gradient = False + y = paddle.randn([10]) + y.stop_gradient = False + out2 = paddle.matmul(x, y, True, True) + out2.retain_grads() + out2.backward() + + self.assertEqual(out2.shape, []) + self.assertEqual(x.grad.shape, [10]) + self.assertEqual(y.grad.shape, [10]) + def test_linalg_slogdet(self): # 2-D input x = paddle.randn([3, 3]) @@ -2595,6 +2636,42 @@ def test_multi_dot(self): self.assertEqual(b.grad.shape, [4, 5]) self.assertEqual(c.grad.shape, [5]) + def test_cov(self): + xt = paddle.randn((3, 4)) + xt.stop_gradient = False + xt_1 = paddle.randn((12,)) + xt_1.stop_gradient = False + + xt_out = paddle.linalg.cov(xt) + xt_out.retain_grads() + xt_out.backward() + self.assertEqual(xt_out.shape, [3, 3]) + self.assertEqual(xt.grad.shape, [3, 4]) + + xt_1_out = paddle.linalg.cov(xt_1) + xt_1.retain_grads() + xt_1_out.backward() + self.assertEqual(xt_1_out.shape, []) + self.assertEqual(xt_1.grad.shape, [12]) + + def test_det(self): + xt = paddle.randn([3, 3, 3]) + xt.stop_gradient = False + xt_1 = paddle.randn([3, 3]) + xt_1.stop_gradient = False + + xt_out = paddle.linalg.det(xt) + xt.retain_grads() + xt_out.backward() + self.assertEqual(xt_out.shape, [3]) + self.assertEqual(xt.grad.shape, [3, 3, 3]) + + xt_1_out = paddle.linalg.det(xt_1) + xt_1.retain_grads() + xt_1_out.backward() + self.assertEqual(xt_1_out.shape, []) + self.assertEqual(xt_1.grad.shape, [3, 3]) + def test_dist(self): x = paddle.to_tensor([[3, 3], [3, 3]], dtype="float32") y = paddle.to_tensor([[3, 3], [3, 1]], dtype="float32") @@ -2608,16 +2685,6 @@ def test_dist(self): self.assertEqual(x.grad.shape, [2, 2]) self.assertEqual(y.grad.shape, [2, 2]) - def test_trace(self): - x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32") - x.stop_gradient = False - out = paddle.trace(x) - out.backward() - - self.assertEqual(out.shape, []) - np.testing.assert_allclose(out, np.array(12)) - self.assertEqual(x.grad.shape, [2, 2]) - def test_linalg_norm(self): # 1D input, p = fro ,axis = None, using reduceInferMeta x_1 = paddle.arange(24, dtype="float32") - 12 @@ -2769,55 +2836,15 @@ def assert_shape(out): self.assertEqual(len(a_cond_fro.shape), 1) self.assertEqual(a.grad.shape, [2, 4, 4]) - def test_cov(self): - xt = paddle.randn((3, 4)) - xt.stop_gradient = False - xt_1 = paddle.randn((12,)) - xt_1.stop_gradient = False - - xt_out = paddle.linalg.cov(xt) - xt_out.retain_grads() - xt_out.backward() - self.assertEqual(xt_out.shape, [3, 3]) - self.assertEqual(xt.grad.shape, [3, 4]) - - xt_1_out = paddle.linalg.cov(xt_1) - xt_1.retain_grads() - xt_1_out.backward() - self.assertEqual(xt_1_out.shape, []) - self.assertEqual(xt_1.grad.shape, [12]) - - def test_det(self): - xt = paddle.randn([3, 3, 3]) - xt.stop_gradient = False - xt_1 = paddle.randn([3, 3]) - xt_1.stop_gradient = False - - xt_out = paddle.linalg.det(xt) - xt.retain_grads() - xt_out.backward() - self.assertEqual(xt_out.shape, [3]) - self.assertEqual(xt.grad.shape, [3, 3, 3]) - - xt_1_out = paddle.linalg.det(xt_1) - xt_1.retain_grads() - xt_1_out.backward() - self.assertEqual(xt_1_out.shape, []) - self.assertEqual(xt_1.grad.shape, [3, 3]) - - def test_to_tensor(self): - out1 = paddle.to_tensor(1) - out2 = paddle.to_tensor(2.5) - - out1.retain_grads() - out1.backward() - out2.retain_grads() - out2.backward() + def test_trace(self): + x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32") + x.stop_gradient = False + out = paddle.trace(x) + out.backward() - self.assertEqual(out1.shape, []) - self.assertEqual(out1, 1) - self.assertEqual(out2.shape, []) - self.assertEqual(out2, 2.5) + self.assertEqual(out.shape, []) + np.testing.assert_allclose(out, np.array(12)) + self.assertEqual(x.grad.shape, [2, 2]) class TestSundryAPIStatic(unittest.TestCase): @@ -4859,6 +4886,53 @@ def test_broadcast_tensors(self): self.assertEqual(out1.shape, (2, 3)) self.assertEqual(out2.shape, (2, 3)) + @prog_scope() + def test_to_tensor(self): + out1 = paddle.to_tensor(1) + out2 = paddle.to_tensor(2.5) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out1, out2]) + + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[0], 1) + self.assertEqual(res[1].shape, ()) + self.assertEqual(res[1], 2.5) + + @prog_scope() + def test_matmul(self): + # 1) no transpose + x = paddle.randn([10]) + x.stop_gradient = False + y = paddle.randn([10]) + y.stop_gradient = False + out = paddle.matmul(x, y) + paddle.static.append_backward(out) + + self.assertEqual(out.shape, ()) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out, x.grad_name, y.grad_name]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (10,)) + self.assertEqual(res[2].shape, (10,)) + + # 2) transpose x and y + x = paddle.randn([10]) + x.stop_gradient = False + y = paddle.randn([10]) + y.stop_gradient = False + out = paddle.matmul(x, y, True, True) + paddle.static.append_backward(out) + + self.assertEqual(out.shape, ()) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out, x.grad_name, y.grad_name]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (10,)) + self.assertEqual(res[2].shape, (10,)) + @prog_scope() def test_linalg_slogdet(self): # 2-D input @@ -4903,6 +4977,33 @@ def test_multi_dot(self): self.assertEqual(res[2].shape, (4, 5)) self.assertEqual(res[3].shape, (5,)) + @prog_scope() + def test_cov(self): + xt_1 = paddle.randn((12,)) + xt_1.stop_gradient = False + + out = paddle.linalg.cov(xt_1) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + + res = self.exe.run(prog, fetch_list=[out, xt_1.grad_name]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (12,)) + + @prog_scope() + def test_det(self): + xt_1 = paddle.randn((3, 3)) + xt_1.stop_gradient = False + + out = paddle.linalg.det(xt_1) + paddle.static.append_backward(out.sum()) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out, xt_1.grad_name]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (3, 3)) + @prog_scope() def test_dist(self): x = paddle.to_tensor([[3, 3], [3, 3]], dtype="float32") @@ -4920,20 +5021,6 @@ def test_dist(self): self.assertEqual(res[1].shape, (2, 2)) np.testing.assert_array_equal(res[0], np.array(2).astype(np.float32)) - @prog_scope() - def test_trace(self): - x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32") - x.stop_gradient = False - out = paddle.trace(x) - paddle.static.append_backward(out) - - prog = paddle.static.default_main_program() - res = self.exe.run(prog, fetch_list=[out, x.grad_name]) - - self.assertEqual(res[0].shape, ()) - self.assertEqual(res[1].shape, (2, 2)) - np.testing.assert_allclose(res[0], np.array(12)) - @prog_scope() def test_linalg_norm(self): # 1D input, p = fro ,axis = None, using reduceInferMeta @@ -5128,44 +5215,18 @@ def test_linalg_cond(self): self.assertEqual(res[1].shape, (2, 4, 4)) @prog_scope() - def test_cov(self): - xt_1 = paddle.randn((12,)) - xt_1.stop_gradient = False - - out = paddle.linalg.cov(xt_1) + def test_trace(self): + x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32") + x.stop_gradient = False + out = paddle.trace(x) paddle.static.append_backward(out) prog = paddle.static.default_main_program() - - res = self.exe.run(prog, fetch_list=[out, xt_1.grad_name]) - self.assertEqual(res[0].shape, ()) - self.assertEqual(res[1].shape, (12,)) - - @prog_scope() - def test_det(self): - xt_1 = paddle.randn((3, 3)) - xt_1.stop_gradient = False - - out = paddle.linalg.det(xt_1) - paddle.static.append_backward(out.sum()) - - prog = paddle.static.default_main_program() - res = self.exe.run(prog, fetch_list=[out, xt_1.grad_name]) - self.assertEqual(res[0].shape, ()) - self.assertEqual(res[1].shape, (3, 3)) - - @prog_scope() - def test_to_tensor(self): - out1 = paddle.to_tensor(1) - out2 = paddle.to_tensor(2.5) - - prog = paddle.static.default_main_program() - res = self.exe.run(prog, fetch_list=[out1, out2]) + res = self.exe.run(prog, fetch_list=[out, x.grad_name]) self.assertEqual(res[0].shape, ()) - self.assertEqual(res[0], 1) - self.assertEqual(res[1].shape, ()) - self.assertEqual(res[1], 2.5) + self.assertEqual(res[1].shape, (2, 2)) + np.testing.assert_allclose(res[0], np.array(12)) # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. @@ -5994,6 +6055,30 @@ def test_l1_loss(self): self.assertEqual(loss.shape, []) self.assertEqual(input.grad.shape, [3, 5]) + def test_nll_loss(self): + input = paddle.rand([5, 3]) + input.stop_gradient = False + log_softmax = paddle.nn.LogSoftmax(axis=1) + log_out = log_softmax(input) + label = paddle.randint(0, 3, [5], "int64") + + loss = paddle.nn.functional.nll_loss(log_out, label) + loss.backward() + + self.assertEqual(loss.shape, []) + self.assertEqual(input.grad.shape, [5, 3]) + + input = paddle.rand([5, 3, 2, 4]) + input.stop_gradient = False + log_softmax = paddle.nn.LogSoftmax(axis=1) + log_out = log_softmax(input) + label = paddle.randint(0, 3, [5, 2, 4], "int64") + loss = paddle.nn.functional.nll_loss(log_out, label) + loss.backward() + + self.assertEqual(loss.shape, []) + self.assertEqual(input.grad.shape, [5, 3, 2, 4]) + class TestLossAPIStatic(unittest.TestCase): def setUp(self): @@ -6060,6 +6145,40 @@ def test_l1_loss(self): self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, (3, 5)) + @prog_scope() + def test_nll_loss(self): + input = paddle.rand([5, 3]) + input.stop_gradient = False + log_softmax = paddle.nn.LogSoftmax(axis=1) + log_out = log_softmax(input) + + label = paddle.randint(0, 3, shape=[5]) + label.stop_gradient = False + + loss = paddle.nn.functional.nll_loss(log_out, label) + paddle.static.append_backward(loss) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[loss, input.grad_name]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (5, 3)) + + input = paddle.rand([5, 3, 2, 4]) + input.stop_gradient = False + log_softmax = paddle.nn.LogSoftmax(axis=1) + log_out = log_softmax(input) + + label = paddle.randint(0, 3, shape=[5, 2, 4]) + label.stop_gradient = False + + loss = paddle.nn.functional.nll_loss(log_out, label) + paddle.static.append_backward(loss) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[loss, input.grad_name]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (5, 3, 2, 4)) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index a0fc7c1cfc27f..3bc6e21169a68 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -28,7 +28,7 @@ from .creation import full from .logic import logical_not from .manipulation import cast -from .math import add, multiply +from .math import _get_reduce_axis, add, multiply __all__ = [] @@ -200,7 +200,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): y = paddle.rand([10]) z = paddle.matmul(x, y) print(z.shape) - # (1,) + # () # matrix * vector x = paddle.rand([10, 5]) @@ -461,12 +461,7 @@ def inf_norm( reduce_out = helper.create_variable_for_type_inference( dtype=helper.input_dtype() ) - - reduce_all = ( - True if axis is None or axis == [] or asvector else False - ) - axis = axis if axis is not None and axis != [] else [0] - + reduce_all, axis = _get_reduce_axis(axis, x) reduce_type = ( 'reduce_max' if porder == np.float64('inf') else 'reduce_min' ) @@ -518,6 +513,7 @@ def p_matrix_norm(input, porder=1.0, axis=axis, keepdim=False, name=None): sum_out = block.create_variable_for_type_inference( dtype=block.input_dtype() ) + reduce_all, axis = _get_reduce_axis(axis, x) block.append_op( type='reduce_sum', inputs={'X': pow_out}, @@ -525,7 +521,7 @@ def p_matrix_norm(input, porder=1.0, axis=axis, keepdim=False, name=None): attrs={ 'dim': axis, 'keep_dim': keepdim, - 'reduce_all': True if axis is None else False, + 'reduce_all': reduce_all, }, ) block.append_op( @@ -836,8 +832,6 @@ def mat_norm(input, porder=1.0, axis=None): if porder == -1 or porder == -np.inf: return _C_ops.min(sum_out, [-1], False) else: - reduce_all = True if axis is None or axis == [] else False - axis = axis if axis is not None and axis != [] else [0] block = LayerHelper('norm', **locals()) abs_out = block.create_variable_for_type_inference( dtype=block.input_dtype() @@ -851,6 +845,8 @@ def mat_norm(input, porder=1.0, axis=None): block.append_op( type='abs', inputs={'X': input}, outputs={'Out': abs_out} ) + + reduce_all, axis = _get_reduce_axis(axis, x) block.append_op( type='reduce_sum', inputs={'X': abs_out}, @@ -896,7 +892,6 @@ def fro_norm(input, porder=2, axis=[-1]): sum_out_2 = _C_ops.sum(sum_out_1, axis, None, False) return _C_ops.pow(sum_out_2, float(1.0 / porder)) else: - reduce_all = True if axis is None or axis == [] else False block = LayerHelper('norm', **locals()) pow_out = block.create_variable_for_type_inference( dtype=block.input_dtype() @@ -916,6 +911,8 @@ def fro_norm(input, porder=2, axis=[-1]): outputs={'Out': pow_out}, attrs={'factor': porder}, ) + + reduce_all, axis = _get_reduce_axis(axis, x) block.append_op( type='reduce_sum', inputs={'X': pow_out}, @@ -962,7 +959,7 @@ def svd_norm(input, porder, axis=[-1]): if porder == -2: return _C_ops.divide(min_out, max_out) else: - reduce_all = True if axis is None or axis == [] else False + reduce_all, axis = _get_reduce_axis(axis, x) block = LayerHelper('norm', **locals()) out = block.create_variable_for_type_inference( dtype=block.input_dtype() diff --git a/test/autograd/test_autograd_dynamic.py b/test/autograd/test_autograd_dynamic.py index 2e6033a6ee937..cd3e54a814d5c 100644 --- a/test/autograd/test_autograd_dynamic.py +++ b/test/autograd/test_autograd_dynamic.py @@ -90,25 +90,16 @@ def test_jacobian(self): ) self._actual = paddle.autograd.jacobian(ys, xs, batch_axis=None) if isinstance(self._actual, (tuple, list)): - self._actual = paddle.concat([x[:] for x in self._actual], axis=1) + self._actual = paddle.concat([x[:] for x in self._actual], axis=0) self._expected = self._get_expected() - Index = collections.namedtuple('Index', ('type', 'value')) - indexes = ( - Index('all', (slice(0, None, None), slice(0, None, None))), - Index('row', (0, slice(0, None, None))), - Index('col', (slice(0, None, None), 0)), - Index('multi-row', (slice(0, 2, 1), slice(0, None, None))), + self.assertEqual(self._actual.numpy().dtype, self._expected.dtype) + np.testing.assert_allclose( + self._actual.flatten(), + self._expected.flatten(), + rtol=self._rtol, + atol=self._atol, ) - self.assertEqual(self._actual[:].numpy().dtype, self._expected.dtype) - for index in indexes: - np.testing.assert_allclose( - self._actual.__getitem__(index.value), - self._expected.__getitem__(index.value), - rtol=self._rtol, - atol=self._atol, - err_msg=f'Testcase {index.type} index not passed, value is {index.value}', - ) def test_jacobian_attribute_operator(self): xs = ( @@ -121,25 +112,16 @@ def test_jacobian_attribute_operator(self): ) self._actual = paddle.autograd.jacobian(ys, xs, batch_axis=None) if isinstance(self._actual, (tuple, list)): - self._actual = paddle.concat([x[:] for x in self._actual], axis=1) + self._actual = paddle.concat([x[:] for x in self._actual], axis=0) self._expected = self._get_expected() - Index = collections.namedtuple('Index', ('type', 'value')) - indexes = ( - Index('all', (slice(0, None, None), slice(0, None, None))), - Index('row', (0, slice(0, None, None))), - Index('col', (slice(0, None, None), 0)), - Index('multi-row', (slice(0, 2, 1), slice(0, None, None))), - ) self.assertEqual(self._actual.numpy().dtype, self._expected.dtype) - for index in indexes: - np.testing.assert_allclose( - self._actual.__getitem__(index.value), - self._expected.__getitem__(index.value), - rtol=self._rtol, - atol=self._atol, - err_msg=f'Testcase {index.type} index not passed, value is {index.value}', - ) + np.testing.assert_allclose( + self._actual.flatten(), + self._expected.flatten(), + rtol=self._rtol, + atol=self._atol, + ) def _get_expected(self): xs = ( diff --git a/test/autograd/utils.py b/test/autograd/utils.py index de1db9f2a19f5..74f14c38b2a59 100644 --- a/test/autograd/utils.py +++ b/test/autograd/utils.py @@ -398,6 +398,7 @@ def concat_row(xs): return src if not isinstance(src[0], typing.Sequence): src = [src] + return concat_row(tuple(concat_col(xs) for xs in src)) diff --git a/test/mkldnn/test_matmul_v2_mkldnn_op.py b/test/mkldnn/test_matmul_v2_mkldnn_op.py index 958a8cca21e51..a7c7fcfad4602 100644 --- a/test/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/test/mkldnn/test_matmul_v2_mkldnn_op.py @@ -46,7 +46,7 @@ def reference_matmul(X, Y, transpose_x=False, transpose_y=False): dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1] Y = np.transpose(Y, tuple(dim)) - Out = np.atleast_1d(np.matmul(X, Y)) + Out = np.matmul(X, Y) return Out diff --git a/test/xpu/test_matmul_op_xpu.py b/test/xpu/test_matmul_op_xpu.py index 07cea1b943c91..fb30c2ecbe879 100644 --- a/test/xpu/test_matmul_op_xpu.py +++ b/test/xpu/test_matmul_op_xpu.py @@ -56,12 +56,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): y_dims = Y.shape Y = Y.reshape((y_dims[0] * y_dims[1], y_dims[2])) Out = np.matmul(X, Y) - if not Out.shape: - # We do not support 0-dimensional Tensors (scalars). So where - # np.matmul outputs a scalar, we must convert to a Tensor of - # shape (1, ) instead. - # Everywhere else, we are compatible with np.matmul. - Out = np.array([Out], dtype="float32") return Out @@ -141,9 +135,6 @@ def test_out(self): with fluid.program_guard(fluid.Program()): x = paddle.static.data(name="x", shape=[2], dtype=self.in_type) y = paddle.static.data(name='y', shape=[2], dtype=self.in_type) - res = paddle.static.data( - name="output", shape=[1], dtype=self.in_type - ) result = paddle.mm(x, y) exe = fluid.Executor(fluid.XPUPlace(0)) data1 = np.random.rand(2).astype(self.in_type) @@ -151,9 +142,7 @@ def test_out(self): np_res = exe.run( feed={'x': data1, 'y': data2}, fetch_list=[result] ) - expected_result = np.matmul( - data1.reshape(1, 2), data2.reshape(2, 1) - ) + expected_result = np.matmul(data1, data2) np.testing.assert_allclose(np_res, expected_result, atol=1e-3) diff --git a/test/xpu/test_matmul_v2_op_xpu.py b/test/xpu/test_matmul_v2_op_xpu.py index eb10d1462e466..76909c5fd8325 100644 --- a/test/xpu/test_matmul_v2_op_xpu.py +++ b/test/xpu/test_matmul_v2_op_xpu.py @@ -46,12 +46,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1] Y = np.transpose(Y, tuple(dim)) Out = np.matmul(X, Y) - if not Out.shape: - # We do not support 0-dimensional Tensors (scalars). So where - # np.matmul outputs a scalar, we must convert to a Tensor of - # shape (1, ) instead. - # Everywhere else, we are compatible with np.matmul. - Out = np.array([Out], dtype="float64") return Out diff --git a/test/xpu/test_zero_dim_tensor_xpu.py b/test/xpu/test_zero_dim_tensor_xpu.py index 7591b3a402f5c..a836e2e7fb58e 100644 --- a/test/xpu/test_zero_dim_tensor_xpu.py +++ b/test/xpu/test_zero_dim_tensor_xpu.py @@ -2256,6 +2256,33 @@ def test_to_tensor(self): self.assertEqual(out2.shape, []) self.assertEqual(out2, 2.5) + def test_matmul(self): + # 1) no transpose + x = paddle.randn([10]) + x.stop_gradient = False + y = paddle.randn([10]) + y.stop_gradient = False + out1 = paddle.matmul(x, y) + out1.retain_grads() + out1.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(x.grad.shape, [10]) + self.assertEqual(y.grad.shape, [10]) + + # 2) transpose x and y + x = paddle.randn([10]) + x.stop_gradient = False + y = paddle.randn([10]) + y.stop_gradient = False + out2 = paddle.matmul(x, y, True, True) + out2.retain_grads() + out2.backward() + + self.assertEqual(out2.shape, []) + self.assertEqual(x.grad.shape, [10]) + self.assertEqual(y.grad.shape, [10]) + def test_linalg_slogdet(self): # 2-D input x = paddle.randn([3, 3])