From 2c0a131cc449e02a0e2e2a757f59363ee172d597 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Thu, 28 Apr 2022 20:14:34 +0800 Subject: [PATCH 01/31] 2022-04-28 --- paddle/fluid/operators/soft_margin_loss_op.cc | 146 ++++++++++++++ .../phi/core/compat/soft_margin_loss_sig.cc | 12 ++ paddle/phi/infermeta/binary.cc | 39 ++++ paddle/phi/infermeta/binary.h | 5 + .../cpu/soft_margin_loss_grad_kernel.cc | 48 +++++ .../kernels/cpu/soft_margin_loss_kernel.cc | 40 ++++ .../gpu/soft_margin_loss_grad_kernel.cu | 59 ++++++ .../kernels/gpu/soft_margin_loss_kernel.cu | 57 ++++++ .../kernels/soft_margin_loss_grad_kernel.h | 28 +++ paddle/phi/kernels/soft_margin_loss_kernel.h | 27 +++ .../tests/unittests/test_soft_margin_loss.py | 187 ++++++++++++++++++ python/paddle/nn/__init__.py | 2 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/loss.py | 77 ++++++++ python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/loss.py | 58 ++++++ 16 files changed, 788 insertions(+) create mode 100644 paddle/fluid/operators/soft_margin_loss_op.cc create mode 100644 paddle/phi/core/compat/soft_margin_loss_sig.cc create mode 100644 paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc create mode 100644 paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu create mode 100644 paddle/phi/kernels/soft_margin_loss_grad_kernel.h create mode 100644 paddle/phi/kernels/soft_margin_loss_kernel.h create mode 100644 python/paddle/fluid/tests/unittests/test_soft_margin_loss.py diff --git a/paddle/fluid/operators/soft_margin_loss_op.cc b/paddle/fluid/operators/soft_margin_loss_op.cc new file mode 100644 index 0000000000000..f8a921e3311db --- /dev/null +++ b/paddle/fluid/operators/soft_margin_loss_op.cc @@ -0,0 +1,146 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/binary.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class SoftMarginLossOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
+  }
+};
+
+class SoftMarginLossGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SoftMarginLossGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "SoftMarginLossGrad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "SoftMarginLossGrad");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
+                   framework::GradVarName("X"), "SoftMarginLossGrad");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto labels_dims = ctx->GetInputDim("Label");
+    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+
+    bool check = true;
+    if ((!ctx->IsRuntime()) &&
+        (phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
+      check = false;
+    }
+
+    if (check) {
+      PADDLE_ENFORCE_EQ(x_dims, labels_dims,
+                        platform::errors::InvalidArgument(
+                            "Input(X) and Input(Label) shall have the same "
+                            "shape. But received: the shape of Input(X) is "
+                            "[%s], the shape of Input(Label) is [%s].",
+                            x_dims, labels_dims));
+
+      PADDLE_ENFORCE_EQ(x_dims, dout_dims,
+                        platform::errors::InvalidArgument(
+                            "Input(X) and Input(Out@Grad) shall have the same "
+                            "shape. But received: the shape of Input(X) is "
+                            "[%s], the shape of Input(Out@Grad) is [%s].",
+                            x_dims, dout_dims));
+    }
+
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->ShareLoD("X", framework::GradVarName("X"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
+  }
+};
+
+class SoftMarginLossOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, default Tensor), the input is a tensor of logits"
+             "computed by the previous operator. ");
+    AddInput("Label",
+             "(Tensor, default Tensor), have same shape with input"
+             "label should between in 0 and 1.");
+    AddOutput("Out",
+              "(Tensor, default Tensor), have same shape with"
+              "input");
+    AddComment(R"DOC(
+SoftMarginLoss operator.
+This measures the element-wise probability error in classification tasks
+in which each class is independent.
+The logistic loss is given as follows:
+         $$loss = log(1+exp(-Label * X))$$
+)DOC");
+  }
+};
+
+template <typename T>
+class SoftMarginLossGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("soft_margin_loss_grad");
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Label", this->Input("Label"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+DECLARE_INPLACE_OP_INFERER(SoftMarginLossInplaceInferer, {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(SoftMarginLossGradInplaceInferer,
+                           {framework::GradVarName("Out"),
+                           framework::GradVarName("X")});
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(soft_margin_loss,
+                            SoftMarginLossInferShapeFunctor,
+                            PD_INFER_META(phi::SoftMarginLossInferMeta));
+
+REGISTER_OPERATOR(soft_margin_loss, ops::SoftMarginLossOp,
+                  ops::SoftMarginLossOpMaker,
+                  ops::SoftMarginLossGradOpMaker<paddle::framework::OpDesc>,
+                  ops::SoftMarginLossGradOpMaker<paddle::imperative::OpBase>,
+                  ops::SoftMarginLossInplaceInferer, SoftMarginLossInferShapeFunctor);
+REGISTER_OPERATOR(soft_margin_loss_grad, ops::SoftMarginLossGradOp,
+                  ops::SoftMarginLossGradInplaceInferer);
diff --git a/paddle/phi/core/compat/soft_margin_loss_sig.cc b/paddle/phi/core/compat/soft_margin_loss_sig.cc
new file mode 100644
index 0000000000000..2eb16a7772b92
--- /dev/null
+++ b/paddle/phi/core/compat/soft_margin_loss_sig.cc
@@ -0,0 +1,12 @@
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi{
+KernelSignature SoftMarginLossGradOpArgumentMapping(const ArgumentMappingContext& ctx){
+return KernelSignature("soft_margin_loss_grad",
+                       {GradVarName("Out"),"X","Label"},
+                       {},
+                       {GradVarName("X")});
+  }
+}// namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(soft_margin_loss_grad,phi::SoftMarginLossGradOpArgumentMapping);
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 2139605fb2048..d8c99423b5c1c 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -1757,6 +1757,45 @@ void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void SoftMarginLossInferMeta(const MetaTensor& input,
+                             const MetaTensor& label,
+                             MetaTensor* out,
+                             MetaConfig config) {
+  auto input_dims = input.dims();
+  auto label_dims = label.dims();
+
+  int rank = input_dims.size();
+  PADDLE_ENFORCE_EQ(rank,
+                    label_dims.size(),
+                    phi::errors::InvalidArgument(
+                        "Input(X) and Input(Label) shall have the same rank."
+                        "But received: the rank of Input(X) is [%d], "
+                        "the rank of Input(Label) is [%d].",
+                        rank,
+                        label_dims.size()));
+
+  bool check = true;
+  if ((!config.is_runtime) &&
+      (phi::product(input_dims) <= 0 || phi::product(label_dims) <= 0)) {
+    check = false;
+  }
+
+  if (check) {
+    PADDLE_ENFORCE_EQ(input_dims,
+                      label_dims,
+                      phi::errors::InvalidArgument(
+                          "Input(X) and Input(Label) shall have the same "
+                          "shape. But received: the shape of Input(X) is "
+                          "[%s], the shape of Input(Label) is [%s].",
+                          input_dims,
+                          label_dims));
+  }
+
+  out->set_dims(input_dims);
+  out->set_dtype(input.dtype());
+  out->share_lod(input);
+}
+
 void TakeAlongAxisInferMeta(const MetaTensor& x,
                             const MetaTensor& index,
                             int axis,
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 192fa214c905f..a99480e5b31d6 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -253,6 +253,11 @@ void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
                                             MetaTensor* out,
                                             MetaConfig config = MetaConfig());
 
+void SoftMarginLossInferMeta(const MetaTensor& input,
+                             const MetaTensor& label,
+                             MetaTensor* out,
+                             MetaConfig config = MetaConfig());
+
 void TakeAlongAxisInferMeta(const MetaTensor& x,
                             const MetaTensor& index,
                             int axis,
diff --git a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc
new file mode 100644
index 0000000000000..d7c0f64f805e3
--- /dev/null
+++ b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc
@@ -0,0 +1,48 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/soft_margin_loss_grad_kernel.h"
+
+#include <algorithm>
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+
+namespace phi {
+
+template <typename T, typename Context>
+void SoftMarginLossGradKernel(const Context& dev_ctx,
+                              const DenseTensor& input,
+                              const DenseTensor& label,
+                              const DenseTensor& out_grad,
+                              DenseTensor* input_grad) {
+  auto dx_data = dev_ctx.template Alloc<T>(input_grad);
+  auto dout_data = out_grad.data<T>();
+  auto x_data = input.data<T>();
+  auto label_data = label.data<T>();
+
+  int x_numel = input.numel();
+
+  // dx = dout * (-label * exp(-label * x))/(1 + exp(-label * x ))
+  for (int i = 0; i < x_numel; ++i) {
+    dx_data[i] =
+        dout_data[i] * ((- label_data[i]*std::exp(-label_data[i]*x_data[i] )) /
+        std::max((static_cast<T>(1) + std::exp(-label_data[i]*x_data[i])),
+        static_cast<T>(1e-12)));
+  }
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    soft_margin_loss_grad, CPU, ALL_LAYOUT, phi::SoftMarginLossGradKernel, float, double) {}
\ No newline at end of file
diff --git a/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc b/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc
new file mode 100644
index 0000000000000..1c17850cc2789
--- /dev/null
+++ b/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc
@@ -0,0 +1,40 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/soft_margin_loss_kernel.h"
+
+#include <algorithm>
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SoftMarginLossKernel(const Context& dev_ctx,
+                          const DenseTensor& input,
+                          const DenseTensor& label,
+                          DenseTensor* out) {
+  auto x_data = input.data<T>();
+  auto label_data = label.data<T>();
+  auto out_data = dev_ctx.template Alloc<T>(out);
+  auto x_numel = input.numel();
+
+  // out = ln(1 + exp(-label * x))
+  for (int64_t i = 0; i < x_numel; ++i) {
+    out_data[i] = std::log(static_cast<T>(1) + std::exp(-label_data[i]* x_data[i]));
+  }
+}
+}  // namespace phi
+PD_REGISTER_KERNEL(
+    soft_margin_loss, CPU, ALL_LAYOUT, phi::SoftMarginLossKernel, float, double) {}
\ No newline at end of file
diff --git a/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu
new file mode 100644
index 0000000000000..1c54ece47f842
--- /dev/null
+++ b/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu
@@ -0,0 +1,59 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/soft_margin_loss_grad_kernel.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+
+namespace phi {
+
+template <typename T>
+struct SoftMarginLossGradFunctor {
+  T one;
+  T eps;
+
+  HOSTDEVICE inline SoftMarginLossGradFunctor() {
+    one = static_cast<T>(1.0f);
+    eps = static_cast<T>(1e-12);
+  }
+
+  HOSTDEVICE inline T operator()(const T x, const T label, const T dout) const {
+    T term1 = max((one + std::exp(-label * x)), eps);
+    return (dout * (-label * std::exp(-label * x)) / term1);
+  }
+};
+
+template <typename T, typename Context>
+void SoftMarginLossGradKernel(const Context& dev_ctx,
+                              const DenseTensor& input,
+                              const DenseTensor& label,
+                              const DenseTensor& out_grad,
+                              DenseTensor* input_grad) {
+  dev_ctx.template Alloc<T>(input_grad);
+  std::vector<const DenseTensor*> ins = {&input, &label, &out_grad};
+  std::vector<DenseTensor*> outs = {input_grad};
+  auto functor = SoftMarginLossGradFunctor<T>();
+  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    soft_margin_loss_grad, GPU, ALL_LAYOUT, phi::SoftMarginLossGradKernel, float, double) {}
\ No newline at end of file
diff --git a/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu b/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu
new file mode 100644
index 0000000000000..751ed2697fae7
--- /dev/null
+++ b/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu
@@ -0,0 +1,57 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/soft_margin_loss_kernel.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/primitive/functor_primitives.h"
+
+namespace phi {
+
+template <typename T>
+struct SoftMarginLossFunctor {
+  T one;
+
+  HOSTDEVICE inline SoftMarginLossFunctor() {
+    one = static_cast<T>(1.0f);
+  }
+
+  HOSTDEVICE inline T operator()(const T x, const T label) const {
+    T term1 = std::log(one + std::exp(-label * x));
+    return term1;
+  }
+};
+
+template <typename T, typename Context>
+void SoftMarginLossKernel(const Context& dev_ctx,
+                          const DenseTensor& input,
+                          const DenseTensor& label,
+                          DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  std::vector<const DenseTensor*> ins = {&input, &label};
+  std::vector<DenseTensor*> outs = {out};
+  auto functor = SoftMarginLossFunctor<T>();
+  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    soft_margin_loss, GPU, ALL_LAYOUT, phi::SoftMarginLossKernel, float, double) {}
\ No newline at end of file
diff --git a/paddle/phi/kernels/soft_margin_loss_grad_kernel.h b/paddle/phi/kernels/soft_margin_loss_grad_kernel.h
new file mode 100644
index 0000000000000..6008360b59270
--- /dev/null
+++ b/paddle/phi/kernels/soft_margin_loss_grad_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SoftMarginLossGradKernel(const Context& dev_ctx,
+                              const DenseTensor& input,
+                              const DenseTensor& label,
+                              const DenseTensor& out_grad,
+                              DenseTensor* input_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/soft_margin_loss_kernel.h b/paddle/phi/kernels/soft_margin_loss_kernel.h
new file mode 100644
index 0000000000000..171add14f328c
--- /dev/null
+++ b/paddle/phi/kernels/soft_margin_loss_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SoftMarginLossKernel(const Context& dev_ctx,
+                          const DenseTensor& input,
+                          const DenseTensor& label,
+                          DenseTensor* out);
+
+}  // namespace phi
diff --git a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py
new file mode 100644
index 0000000000000..3bbc8ee0dc4be
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numpy as np
+import unittest
+from op_test import OpTest
+
+
+def test_static_layer(place,
+                      input_np,
+                      label_np,
+                      reduction='mean',):
+    paddle.enable_static()
+    prog = paddle.static.Program()
+    startup_prog = paddle.static.Program()
+    with paddle.static.program_guard(prog, startup_prog):
+        input = paddle.static.data(
+            name='input', shape=input_np.shape, dtype='float64')
+        label = paddle.static.data(
+            name='label', shape=label_np.shape, dtype='float64')
+        sm_loss = paddle.nn.loss.SoftMarginLoss(reduction=reduction)
+        res = sm_loss(input, label)
+        exe = paddle.static.Executor(place)
+        static_result = exe.run(prog,
+                                feed={"input": input_np,
+                                      "label": label_np},
+                                fetch_list=[res])
+    return static_result
+
+
+def test_static_functional(place,
+                           input_np,
+                           label_np,
+                           reduction='mean',):
+    paddle.enable_static()
+    prog = paddle.static.Program()
+    startup_prog = paddle.static.Program()
+    with paddle.static.program_guard(prog, startup_prog):
+        input = paddle.static.data(
+            name='input', shape=input_np.shape, dtype='float64')
+        label = paddle.static.data(
+            name='label', shape=label_np.shape, dtype='float64')
+
+        res = paddle.nn.functional.soft_margin_loss(
+            input, label, reduction=reduction)
+        exe = paddle.static.Executor(place)
+        static_result = exe.run(prog,
+                                feed={"input": input_np,
+                                      "label": label_np},
+                                fetch_list=[res])
+    return static_result
+
+
+def test_dygraph_layer(place,
+                       input_np,
+                       label_np,
+                       reduction='mean',):
+    paddle.disable_static()
+    sm_loss = paddle.nn.loss.SoftMarginLoss(reduction=reduction)
+    dy_res = sm_loss(paddle.to_tensor(input_np), paddle.to_tensor(label_np))
+    dy_result = dy_res.numpy()
+    paddle.enable_static()
+    return dy_result
+
+
+def test_dygraph_functional(place,
+                            input_np,
+                            label_np,
+                            reduction='mean',):
+    paddle.disable_static()
+    input = paddle.to_tensor(input_np)
+    label = paddle.to_tensor(label_np)
+
+    dy_res = paddle.nn.functional.soft_margin_loss(
+        input, label, reduction=reduction)
+    dy_result = dy_res.numpy()
+    paddle.enable_static()
+    return dy_result
+
+
+def calc_softmarginloss(input_np, label_np, reduction='mean',):
+
+    expected = np.log(1+np.exp(-label_np * input_np))
+    # expected = np.mean(expected, axis=-1)
+
+    if reduction == 'mean':
+        expected = np.mean(expected)
+    elif reduction == 'sum':
+        expected = np.sum(expected)
+    else:
+        expected = expected
+
+    return expected
+
+
+class TestSoftMarginLoss(unittest.TestCase):
+    def test_SoftMarginLoss(self):
+        input_np = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float64)
+        label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float64)
+        places = ['cpu']
+        if paddle.device.is_compiled_with_cuda():
+            places.append('gpu')
+        reductions = ['sum', 'mean', 'none']
+        for place in places:
+            for reduction in reductions:
+                static_result = test_static_layer(place, input_np, label_np,
+                                                  reduction)
+                dy_result = test_dygraph_layer(place, input_np, label_np,
+                                               reduction)
+                expected = calc_softmarginloss(input_np, label_np, reduction)
+                self.assertTrue(np.allclose(static_result, expected))
+                self.assertTrue(np.allclose(static_result, dy_result))
+                self.assertTrue(np.allclose(dy_result, expected))
+                static_functional = test_static_functional(place, input_np,
+                                                           label_np, reduction)
+                dy_functional = test_dygraph_functional(place, input_np,
+                                                        label_np, reduction)
+                self.assertTrue(np.allclose(static_functional, expected))
+                self.assertTrue(np.allclose(static_functional, dy_functional))
+                self.assertTrue(np.allclose(dy_functional, expected))
+
+    def test_SoftMarginLoss_error(self):
+        paddle.disable_static()
+        self.assertRaises(
+            ValueError,
+            paddle.nn.loss.SoftMarginLoss,
+            reduction="unsupport reduction")
+        input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
+        label = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
+        self.assertRaises(
+            ValueError,
+            paddle.nn.functional.soft_margin_loss,
+            input=input,
+            label=label,
+            reduction="unsupport reduction")
+        paddle.enable_static()
+
+
+def soft_margin_loss(input, label):
+    return np.log(1+np.exp(-label * input))
+
+
+class TestSoftMarginLossOp(OpTest):
+    def setUp(self):
+        self.init_test_case()
+        self.op_type = "soft_margin_loss"
+        input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64")
+        label_np = np.random.randint(0, 2, self.shape).astype("float64")
+        output_np = soft_margin_loss(input_np, label_np)
+
+        self.inputs = {'X': input_np, 'Label': label_np}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+    def init_test_case(self):
+        self.shape = [10, 10]
+
+
+class TestSoftMarginLossOpCase1(OpTest):
+    def init_test_cast(self):
+        self.shape = [2, 3, 4, 5]
+
+
+class TestSoftMarginLossOpCase2(OpTest):
+    def init_test_cast(self):
+        self.shape = [2, 3, 20]
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index bceee4b964a33..13c27c266c758 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -106,6 +106,7 @@
 from .layer.loss import CTCLoss  # noqa: F401
 from .layer.loss import SmoothL1Loss  # noqa: F401
 from .layer.loss import HingeEmbeddingLoss  # noqa: F401
+from .layer.loss import SoftMarginLoss
 from .layer.norm import BatchNorm  # noqa: F401
 from .layer.norm import SyncBatchNorm  # noqa: F401
 from .layer.norm import GroupNorm  # noqa: F401
@@ -313,4 +314,5 @@ def weight_norm(*args):
     'MaxUnPool3D',
     'HingeEmbeddingLoss',
     'Identity',
+    'SoftMarginLoss',
 ]
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 68213d831c550..cee1c035693c5 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -89,6 +89,7 @@ from .loss import
square_error_cost # noqa: F401 from .loss import ctc_loss # noqa: F401 from .loss import hinge_embedding_loss # noqa: F401 +from .loss import soft_margin_loss from .norm import batch_norm # noqa: F401 from .norm import instance_norm # noqa: F401 from .norm import layer_norm # noqa: F401 @@ -228,4 +229,5 @@ 'class_center_sample', 'sparse_attention', 'fold', + 'soft_margin_loss', ] diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index ca3ac1772829d..e816be705a0e4 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2225,3 +2225,80 @@ def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None): return paddle.sum(loss, name=name) elif reduction == 'none': return loss + +def soft_margin_loss(input, label,reduction='mean', + name=None): + """ + This op measures the soft margin loss between input predictions ``input`` + and target labels ``label`` . It can be described as: + .. math:: + Out = log(1 + exp((-label * input))) + If :attr:`reduction` set to ``'none'``, the interface will return the original loss `Out`. + If :attr:`reduction` set to ``'mean'``, the reduced mean loss is: + .. math:: + Out = MEAN(Out) + If :attr:`reduction` set to ``'sum'``, the reduced sum loss is: + .. math:: + Out = SUM(Out) + Parameters: + input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *], + N is batch_size, `*` means number of additional dimensions. The ``input`` + should always be the output of sigmod. Available dtype is float32, float64. + label (Tensor): The target labels tensor. 2-D tensor with the same shape as + ``input``. The target labels which values should be numbers between 0 and 1. + Available dtype is float32, float64. + reduction (str, optional): Indicate how to average the loss by batch_size, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`reduction` is ``'sum'``, the summed loss is returned. + Default is ``'mean'``. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + Returns: + output (Tensor): If ``reduction`` is ``'none'``, the shape of output is + same as ``input`` , else the shape of output is scalar. + Examples: + .. code-block:: python + import paddle + input = paddle.to_tensor([0.5, 0.6, 0.7], 'float32') + label = paddle.to_tensor([1.0, 0.0, 1.0], 'float32') + output = paddle.nn.functional.soft_margin_loss(input, label) + """ + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in soft_margin_loss should be 'sum', " + "'mean' or 'none', but received %s, which is not allowed." 
% + reduction) + if _non_static_mode(): + out = _C_ops.soft_margin_loss(input, label) + if reduction == 'sum': + return _C_ops.reduce_sum(out, "reduce_all", True) + elif reduction == 'mean': + return _C_ops.mean(out) + else: + return out + else: + + fluid.data_feeder.check_variable_and_dtype( + input, 'input', ['float32', 'float64'], 'soft_margin_loss') + fluid.data_feeder.check_variable_and_dtype( + label, 'label', ['float32', 'float64'], 'soft_margin_loss') + + sub_name = name if reduction == 'none' else None + helper = LayerHelper("soft_margin_loss", name=sub_name) + out = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type='soft_margin_loss', + inputs={ + 'X': [input], + 'Label': [label], + }, + outputs={'Out': [out]}) + + if reduction == 'sum': + return paddle.sum(out, name=name) + elif reduction == 'mean': + return paddle.mean(out, name=name) + else: + return out \ No newline at end of file diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 31364f0281c8a..ed2ca98af5a4c 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -78,6 +78,7 @@ from .loss import CTCLoss # noqa: F401 from .loss import SmoothL1Loss # noqa: F401 from .loss import HingeEmbeddingLoss # noqa: F401 +from .loss import SoftMarginLoss from .norm import BatchNorm1D # noqa: F401 from .norm import BatchNorm2D # noqa: F401 from .norm import BatchNorm3D # noqa: F401 diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index d4e059b6dfa49..1b0bae10b6742 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1302,3 +1302,61 @@ def forward(self, input, label): reduction=self.reduction, margin=self.margin, name=self.name) +class SoftMarginLoss(Layer): + r""" + This op measures the soft margin loss between input predictions ``input`` + and target labels ``label`` . It can be described as: + .. math:: + Out = log(1 + exp((-label * input))) + And this operator applies reduce operation on the loss. + If :attr:`reduction` set to ``'none'``, the operator will return the original loss `Out`. + If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`. + If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`. + Parameters: + reduction (str, optional): Indicate how to average the loss by batch_size, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`reduction` is ``'sum'``, the summed loss is returned. + Default is ``'mean'``. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + Call Parameters: + input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *], + N is batch_size, `*` means number of additional dimensions. The ``logit`` + is usually the output of Linear layer. Available dtype is float32, float64. + label (Tensor): The target labels tensor. 2-D tensor with the same shape as + ``logit``. The target labels which values should be numbers between 0 and 1. + Available dtype is float32, float64. + output (Tensor): If ``reduction`` is ``'none'``, the shape of output is + same as ``logit`` , else the shape of output is scalar. + Returns: + A callable object of SoftMarginLoss + Examples: + .. 
code-block:: python + import paddle + input = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32") + label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32") + soft_margin_loss = paddle.nn.SoftMarginLoss() + output = soft_margin_loss(input, label) + """ + + def __init__(self, + reduction='mean', + name=None): + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in SoftMarginLoss should be 'sum', 'mean' or 'none', but " + "received %s, which is not allowed." % reduction) + + super(SoftMarginLoss, self).__init__() + self.reduction = reduction + self.name = name + + def forward(self, input, label): + out = paddle.nn.functional.soft_margin_loss( + input, + label, + self.reduction, + self.name) + return out From 5cb644cc8a7aaf66c02c07f3355ea58ba900b681 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Thu, 28 Apr 2022 23:08:51 +0800 Subject: [PATCH 02/31] 2022-04-28_V2 --- .../phi/core/compat/soft_margin_loss_sig.cc | 12 --------- paddle/phi/ops/compat/soft_margin_loss_sig.cc | 25 +++++++++++++++++++ 2 files changed, 25 insertions(+), 12 deletions(-) delete mode 100644 paddle/phi/core/compat/soft_margin_loss_sig.cc create mode 100644 paddle/phi/ops/compat/soft_margin_loss_sig.cc diff --git a/paddle/phi/core/compat/soft_margin_loss_sig.cc b/paddle/phi/core/compat/soft_margin_loss_sig.cc deleted file mode 100644 index 2eb16a7772b92..0000000000000 --- a/paddle/phi/core/compat/soft_margin_loss_sig.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi{ -KernelSignature SoftMarginLossGradOpArgumentMapping(const ArgumentMappingContext& ctx){ -return KernelSignature("soft_margin_loss_grad", - {GradVarName("Out"),"X","Label"}, - {}, - {GradVarName("X")}); - } -}// namespace phi - -PD_REGISTER_ARG_MAPPING_FN(soft_margin_loss_grad,phi::SoftMarginLossGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/soft_margin_loss_sig.cc b/paddle/phi/ops/compat/soft_margin_loss_sig.cc new file mode 100644 index 0000000000000..2aa30bfc37408 --- /dev/null +++ b/paddle/phi/ops/compat/soft_margin_loss_sig.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "paddle/phi/core/compat/op_utils.h" + +namespace phi{ +KernelSignature SoftMarginLossGradOpArgumentMapping(const ArgumentMappingContext& ctx){ +return KernelSignature("soft_margin_loss_grad", + {"Out@GRAD","X","Label"}, + {}, + {"X@GRAD"}); + } +}// namespace phi + +PD_REGISTER_ARG_MAPPING_FN(soft_margin_loss_grad,phi::SoftMarginLossGradOpArgumentMapping); \ No newline at end of file From eec4f11acac8b2f0cb30f7b5e2006492a991d4d2 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Sat, 30 Apr 2022 20:05:33 +0800 Subject: [PATCH 03/31] 2022-04-30 --- paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc | 5 ++--- paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu | 2 +- paddle/phi/ops/compat/soft_margin_loss_sig.cc | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc index d7c0f64f805e3..34f3474ecec12 100644 --- a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc @@ -37,9 +37,8 @@ void SoftMarginLossGradKernel(const Context& dev_ctx, // dx = dout * (-label * exp(-label * x))/(1 + exp(-label * x )) for (int i = 0; i < x_numel; ++i) { dx_data[i] = - dout_data[i] * ((- label_data[i]*std::exp(-label_data[i]*x_data[i] )) / - std::max((static_cast(1) + std::exp(-label_data[i]*x_data[i])), - static_cast(1e-12))); + dout_data[i] * ((- label_data[i]* std::exp(-label_data[i]*x_data[i] )) / + (static_cast(1) + std::exp(-label_data[i]*x_data[i])); } } } // namespace phi diff --git a/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu index 1c54ece47f842..826cc2f6a9e36 100644 --- a/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu @@ -35,7 +35,7 @@ struct SoftMarginLossGradFunctor { } HOSTDEVICE inline T operator()(const T x, const T label, const T dout) const { - T term1 = max((one + std::exp(-label * x)), eps); + T term1 = (one + std::exp(-label * x)); return (dout * (-label * std::exp(-label * x)) / term1); } }; diff --git a/paddle/phi/ops/compat/soft_margin_loss_sig.cc b/paddle/phi/ops/compat/soft_margin_loss_sig.cc index 2aa30bfc37408..7179aa4c5fef8 100644 --- a/paddle/phi/ops/compat/soft_margin_loss_sig.cc +++ b/paddle/phi/ops/compat/soft_margin_loss_sig.cc @@ -16,7 +16,7 @@ namespace phi{ KernelSignature SoftMarginLossGradOpArgumentMapping(const ArgumentMappingContext& ctx){ return KernelSignature("soft_margin_loss_grad", - {"Out@GRAD","X","Label"}, + {"X","Label","Out@GRAD"}, {}, {"X@GRAD"}); } From 7099799b7c3d20b18b17122175b0f1b3396a86af Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Sat, 30 Apr 2022 22:31:33 +0800 Subject: [PATCH 04/31] 2022-04-30_V2 --- paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc index 34f3474ecec12..91692c4eddb5d 100644 --- a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc @@ -38,7 +38,7 @@ void SoftMarginLossGradKernel(const Context& dev_ctx, for (int i = 0; i < x_numel; ++i) { dx_data[i] = dout_data[i] * ((- label_data[i]* std::exp(-label_data[i]*x_data[i] )) / - (static_cast(1) + std::exp(-label_data[i]*x_data[i])); + (static_cast(1) + 
std::exp(-label_data[i]*x_data[i]))); } } } // namespace phi From 04a8215483a584ea6b35274f1e7f7a330c4a22bd Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Sun, 1 May 2022 21:07:34 +0800 Subject: [PATCH 05/31] 2022-05-01 --- paddle/fluid/operators/soft_margin_loss_op.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/soft_margin_loss_op.cc b/paddle/fluid/operators/soft_margin_loss_op.cc index f8a921e3311db..f5b278d32c3ec 100644 --- a/paddle/fluid/operators/soft_margin_loss_op.cc +++ b/paddle/fluid/operators/soft_margin_loss_op.cc @@ -124,10 +124,10 @@ class SoftMarginLossGradOpMaker : public framework::SingleGradOpMaker { } }; -DECLARE_INPLACE_OP_INFERER(SoftMarginLossInplaceInferer, {"X", "Out"}); -DECLARE_INPLACE_OP_INFERER(SoftMarginLossGradInplaceInferer, +//DECLARE_INPLACE_OP_INFERER(SoftMarginLossInplaceInferer, {"X", "Out"}); +/*DECLARE_INPLACE_OP_INFERER(SoftMarginLossGradInplaceInferer, {framework::GradVarName("Out"), - framework::GradVarName("X")}); + framework::GradVarName("X")});*/ } // namespace operators } // namespace paddle From 25e62dbbd5e96060c30363d0c97df431555c141f Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Mon, 2 May 2022 03:49:22 +0800 Subject: [PATCH 06/31] 2022-05-02 --- paddle/fluid/operators/soft_margin_loss_op.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/soft_margin_loss_op.cc b/paddle/fluid/operators/soft_margin_loss_op.cc index f5b278d32c3ec..8820885175e29 100644 --- a/paddle/fluid/operators/soft_margin_loss_op.cc +++ b/paddle/fluid/operators/soft_margin_loss_op.cc @@ -141,6 +141,7 @@ REGISTER_OPERATOR(soft_margin_loss, ops::SoftMarginLossOp, ops::SoftMarginLossOpMaker, ops::SoftMarginLossGradOpMaker, ops::SoftMarginLossGradOpMaker, - ops::SoftMarginLossInplaceInferer, SoftMarginLossInferShapeFunctor); -REGISTER_OPERATOR(soft_margin_loss_grad, ops::SoftMarginLossGradOp, - ops::SoftMarginLossGradInplaceInferer); + //ops::SoftMarginLossInplaceInferer, + SoftMarginLossInferShapeFunctor); +REGISTER_OPERATOR(soft_margin_loss_grad, ops::SoftMarginLossGradOp,); + //ops::SoftMarginLossGradInplaceInferer); From bebb9c37c2d8a3b9aca247dffe2d9b067fcd43ca Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Mon, 2 May 2022 10:58:46 +0800 Subject: [PATCH 07/31] 2022-05-02_V2 --- paddle/fluid/operators/soft_margin_loss_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/soft_margin_loss_op.cc b/paddle/fluid/operators/soft_margin_loss_op.cc index 8820885175e29..0a4d31272c76d 100644 --- a/paddle/fluid/operators/soft_margin_loss_op.cc +++ b/paddle/fluid/operators/soft_margin_loss_op.cc @@ -143,5 +143,5 @@ REGISTER_OPERATOR(soft_margin_loss, ops::SoftMarginLossOp, ops::SoftMarginLossGradOpMaker, //ops::SoftMarginLossInplaceInferer, SoftMarginLossInferShapeFunctor); -REGISTER_OPERATOR(soft_margin_loss_grad, ops::SoftMarginLossGradOp,); +REGISTER_OPERATOR(soft_margin_loss_grad, ops::SoftMarginLossGradOp); //ops::SoftMarginLossGradInplaceInferer); From eeee0075f310bd74b075b5f3cd216df8550bd0bc Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Thu, 5 May 2022 14:10:34 +0800 Subject: [PATCH 08/31] 2022-05-05_V1 --- python/paddle/nn/functional/loss.py | 15 ++++++--------- python/paddle/nn/layer/loss.py | 13 ++++++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/python/paddle/nn/functional/loss.py 
b/python/paddle/nn/functional/loss.py index e816be705a0e4..058224a669d0c 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2229,17 +2229,11 @@ def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None): def soft_margin_loss(input, label,reduction='mean', name=None): """ - This op measures the soft margin loss between input predictions ``input`` + This APIs measures the soft margin loss between input predictions ``input`` and target labels ``label`` . It can be described as: .. math:: Out = log(1 + exp((-label * input))) - If :attr:`reduction` set to ``'none'``, the interface will return the original loss `Out`. - If :attr:`reduction` set to ``'mean'``, the reduced mean loss is: - .. math:: - Out = MEAN(Out) - If :attr:`reduction` set to ``'sum'``, the reduced sum loss is: - .. math:: - Out = SUM(Out) + Parameters: input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *], N is batch_size, `*` means number of additional dimensions. The ``input`` @@ -2255,11 +2249,14 @@ def soft_margin_loss(input, label,reduction='mean', Default is ``'mean'``. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + Returns: output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is scalar. + Examples: .. code-block:: python + import paddle input = paddle.to_tensor([0.5, 0.6, 0.7], 'float32') label = paddle.to_tensor([1.0, 0.0, 1.0], 'float32') @@ -2301,4 +2298,4 @@ def soft_margin_loss(input, label,reduction='mean', elif reduction == 'mean': return paddle.mean(out, name=name) else: - return out \ No newline at end of file + return out diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 1b0bae10b6742..e2cd14968d641 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1302,16 +1302,15 @@ def forward(self, input, label): reduction=self.reduction, margin=self.margin, name=self.name) + class SoftMarginLoss(Layer): r""" - This op measures the soft margin loss between input predictions ``input`` + Creates a criterion that measures a two-class + soft margin loss between input predictions ``input`` and target labels ``label`` . It can be described as: .. math:: Out = log(1 + exp((-label * input))) - And this operator applies reduce operation on the loss. - If :attr:`reduction` set to ``'none'``, the operator will return the original loss `Out`. - If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`. - If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`. + Parameters: reduction (str, optional): Indicate how to average the loss by batch_size, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. @@ -1321,6 +1320,7 @@ class SoftMarginLoss(Layer): Default is ``'mean'``. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + Call Parameters: input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *], N is batch_size, `*` means number of additional dimensions. The ``logit`` @@ -1330,10 +1330,13 @@ class SoftMarginLoss(Layer): Available dtype is float32, float64. output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``logit`` , else the shape of output is scalar. + Returns: A callable object of SoftMarginLoss + Examples: .. 
code-block:: python + import paddle input = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32") label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32") From 298c2385191307eb7022aa8a5e7e4c7dc77e44f4 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Fri, 6 May 2022 23:03:09 +0800 Subject: [PATCH 09/31] 2022-05-06_V1 --- python/paddle/nn/functional/loss.py | 4 +++- python/paddle/nn/layer/loss.py | 19 +++++++++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 058224a669d0c..d093b472feeb5 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2235,6 +2235,7 @@ def soft_margin_loss(input, label,reduction='mean', Out = log(1 + exp((-label * input))) Parameters: + input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *], N is batch_size, `*` means number of additional dimensions. The ``input`` should always be the output of sigmod. Available dtype is float32, float64. @@ -2251,7 +2252,8 @@ def soft_margin_loss(input, label,reduction='mean', For more information, please refer to :ref:`api_guide_Name`. Returns: - output (Tensor): If ``reduction`` is ``'none'``, the shape of output is + + Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is scalar. Examples: diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index e2cd14968d641..0d0e303e6fe76 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1308,30 +1308,33 @@ class SoftMarginLoss(Layer): Creates a criterion that measures a two-class soft margin loss between input predictions ``input`` and target labels ``label`` . It can be described as: + .. math:: Out = log(1 + exp((-label * input))) Parameters: + reduction (str, optional): Indicate how to average the loss by batch_size, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If :attr:`reduction` is ``'sum'``, the summed loss is returned. Default is ``'mean'``. - name (str, optional): Name for the operation (optional, default is None). - For more information, please refer to :ref:`api_guide_Name`. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Call Parameters: - input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *], - N is batch_size, `*` means number of additional dimensions. The ``logit`` + + Input (Tensor): The input tensor. 2-D tensor with shape: [N, *], + N is batch_size, `*` means number of additional dimensions. The ``input`` is usually the output of Linear layer. Available dtype is float32, float64. - label (Tensor): The target labels tensor. 2-D tensor with the same shape as - ``logit``. The target labels which values should be numbers between 0 and 1. + Label (Tensor): The target labels tensor. 2-D tensor with the same shape as + ``input``. The target labels which values should be numbers between 0 and 1. Available dtype is float32, float64. - output (Tensor): If ``reduction`` is ``'none'``, the shape of output is - same as ``logit`` , else the shape of output is scalar. + Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is + same as ``input`` , else the shape of output is scalar. 
Returns: + A callable object of SoftMarginLoss Examples: From d133b4bf122212cf73159b25105aa0ebbe114a69 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Sat, 7 May 2022 22:16:43 +0800 Subject: [PATCH 10/31] 2022-05-07_V1 --- paddle/fluid/operators/soft_margin_loss_op.cc | 6 +----- python/paddle/nn/functional/loss.py | 19 ++++++++++++------- python/paddle/nn/layer/loss.py | 15 +++++++++------ 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/soft_margin_loss_op.cc b/paddle/fluid/operators/soft_margin_loss_op.cc index 0a4d31272c76d..89c315e00dfd8 100644 --- a/paddle/fluid/operators/soft_margin_loss_op.cc +++ b/paddle/fluid/operators/soft_margin_loss_op.cc @@ -94,7 +94,7 @@ class SoftMarginLossOpMaker : public framework::OpProtoAndCheckerMaker { "computed by the previous operator. "); AddInput("Label", "(Tensor, default Tensor), have same shape with input" - "label should between in 0 and 1."); + "label should be -1 or 1."); AddOutput("Out", "(Tensor, default Tensor), have same shape with" "input"); @@ -124,10 +124,6 @@ class SoftMarginLossGradOpMaker : public framework::SingleGradOpMaker { } }; -//DECLARE_INPLACE_OP_INFERER(SoftMarginLossInplaceInferer, {"X", "Out"}); -/*DECLARE_INPLACE_OP_INFERER(SoftMarginLossGradInplaceInferer, - {framework::GradVarName("Out"), - framework::GradVarName("X")});*/ } // namespace operators } // namespace paddle diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index d093b472feeb5..80b6e411ec74b 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2226,28 +2226,33 @@ def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None): elif reduction == 'none': return loss + def soft_margin_loss(input, label,reduction='mean', name=None): """ This APIs measures the soft margin loss between input predictions ``input`` and target labels ``label`` . It can be described as: + .. math:: Out = log(1 + exp((-label * input))) Parameters: - input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *], - N is batch_size, `*` means number of additional dimensions. The ``input`` - should always be the output of sigmod. Available dtype is float32, float64. - label (Tensor): The target labels tensor. 2-D tensor with the same shape as - ``input``. The target labels which values should be numbers between 0 and 1. - Available dtype is float32, float64. + input (Tensor): The input predications tensor with shape: [N, *], + N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf. + Available dtype is float32, float64. + + label (Tensor): The target labels tensor with the same shape as + ``input``. The target labels which values should be numbers -1 or 1. + Available dtype is int32, int64, float32, float64. + reduction (str, optional): Indicate how to average the loss by batch_size, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If :attr:`reduction` is ``'sum'``, the summed loss is returned. Default is ``'mean'``. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. 
@@ -2261,7 +2266,7 @@ def soft_margin_loss(input, label,reduction='mean', import paddle input = paddle.to_tensor([0.5, 0.6, 0.7], 'float32') - label = paddle.to_tensor([1.0, 0.0, 1.0], 'float32') + label = paddle.to_tensor([1.0, -1.0, 1.0], 'float32') output = paddle.nn.functional.soft_margin_loss(input, label) """ if reduction not in ['sum', 'mean', 'none']: diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 0d0e303e6fe76..023995e49e9f3 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1320,16 +1320,19 @@ class SoftMarginLoss(Layer): If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If :attr:`reduction` is ``'sum'``, the summed loss is returned. Default is ``'mean'``. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Call Parameters: - Input (Tensor): The input tensor. 2-D tensor with shape: [N, *], - N is batch_size, `*` means number of additional dimensions. The ``input`` - is usually the output of Linear layer. Available dtype is float32, float64. - Label (Tensor): The target labels tensor. 2-D tensor with the same shape as - ``input``. The target labels which values should be numbers between 0 and 1. + Input (Tensor): The input tensor with shape: [N, *], + N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf Available dtype is float32, float64. + + Label (Tensor): The target labels tensor with the same shape as + ``input``. The target labels which values should be numbers -1 or 1. + Available dtype is int32, int64, float32, float64. + Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is scalar. @@ -1342,7 +1345,7 @@ class SoftMarginLoss(Layer): import paddle input = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32") - label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32") + label = paddle.to_tensor([1, -1, 1], dtype="float32") soft_margin_loss = paddle.nn.SoftMarginLoss() output = soft_margin_loss(input, label) """ From 90dd6ed3fc87ec8e80867db441c5b8a427518c0e Mon Sep 17 00:00:00 2001 From: yangguohao <70266361+yangguohao@users.noreply.github.com> Date: Mon, 9 May 2022 22:20:00 +0800 Subject: [PATCH 11/31] Update loss.py --- python/paddle/nn/layer/loss.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 023995e49e9f3..07c1c224434b8 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1313,7 +1313,11 @@ class SoftMarginLoss(Layer): Out = log(1 + exp((-label * input))) Parameters: - + weight (Tensor, optional): Weight tensor, a manual rescaling weight given + to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise, + it treated as if having all ones. the data type is + float32, float64, Default is ``'None'``. + reduction (str, optional): Indicate how to average the loss by batch_size, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; @@ -1323,15 +1327,15 @@ class SoftMarginLoss(Layer): name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. - Call Parameters: + Shapes: Input (Tensor): The input tensor with shape: [N, *], - N is batch_size, `*` means any number of additional dimensions. 
The ``input`` ranges from -inf to inf - Available dtype is float32, float64. + N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf + Available dtype is float32, float64. Label (Tensor): The target labels tensor with the same shape as - ``input``. The target labels which values should be numbers -1 or 1. - Available dtype is int32, int64, float32, float64. + ``input``. The target labels which values should be numbers -1 or 1. + Available dtype is int32, int64, float32, float64. Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is scalar. From 371b955e9f5edc6f78f2dc6c7c2bfcea0f8b1fbb Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Sat, 7 May 2022 22:35:10 +0800 Subject: [PATCH 12/31] 2022-05-07_V2 --- python/paddle/nn/layer/loss.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 07c1c224434b8..84025b646d1cb 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1337,9 +1337,6 @@ class SoftMarginLoss(Layer): ``input``. The target labels which values should be numbers -1 or 1. Available dtype is int32, int64, float32, float64. - Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is - same as ``input`` , else the shape of output is scalar. - Returns: A callable object of SoftMarginLoss From 4a85b06e4a160d7f4d93e58b1b4da75080353427 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Fri, 13 May 2022 16:52:40 +0800 Subject: [PATCH 13/31] 2022-05-13_V1 --- .../cpu/soft_margin_loss_grad_kernel.cc | 2 +- .../kernels/cpu/soft_margin_loss_kernel.cc | 2 +- .../gpu/soft_margin_loss_grad_kernel.cu | 2 +- .../kernels/gpu/soft_margin_loss_kernel.cu | 2 +- paddle/phi/ops/compat/soft_margin_loss_sig.cc | 2 +- .../tests/unittests/test_soft_margin_loss.py | 12 ++--- python/paddle/nn/functional/loss.py | 46 ++++++++++--------- python/paddle/nn/layer/loss.py | 20 ++++---- 8 files changed, 47 insertions(+), 41 deletions(-) diff --git a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc index 91692c4eddb5d..80d09ffa75d22 100644 --- a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc @@ -44,4 +44,4 @@ void SoftMarginLossGradKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - soft_margin_loss_grad, CPU, ALL_LAYOUT, phi::SoftMarginLossGradKernel, float, double) {} \ No newline at end of file + soft_margin_loss_grad, CPU, ALL_LAYOUT, phi::SoftMarginLossGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc b/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc index 1c17850cc2789..181e8ae5911af 100644 --- a/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc +++ b/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc @@ -37,4 +37,4 @@ void SoftMarginLossKernel(const Context& dev_ctx, } } // namespace phi PD_REGISTER_KERNEL( - soft_margin_loss, CPU, ALL_LAYOUT, phi::SoftMarginLossKernel, float, double) {} \ No newline at end of file + soft_margin_loss, CPU, ALL_LAYOUT, phi::SoftMarginLossKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu index 826cc2f6a9e36..f774b0c5d7598 100644 --- a/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu +++ 
b/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu @@ -56,4 +56,4 @@ void SoftMarginLossGradKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - soft_margin_loss_grad, GPU, ALL_LAYOUT, phi::SoftMarginLossGradKernel, float, double) {} \ No newline at end of file + soft_margin_loss_grad, GPU, ALL_LAYOUT, phi::SoftMarginLossGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu b/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu index 751ed2697fae7..d0d087eb87759 100644 --- a/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu @@ -54,4 +54,4 @@ void SoftMarginLossKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - soft_margin_loss, GPU, ALL_LAYOUT, phi::SoftMarginLossKernel, float, double) {} \ No newline at end of file + soft_margin_loss, GPU, ALL_LAYOUT, phi::SoftMarginLossKernel, float, double) {} diff --git a/paddle/phi/ops/compat/soft_margin_loss_sig.cc b/paddle/phi/ops/compat/soft_margin_loss_sig.cc index 7179aa4c5fef8..f81557a1a22d4 100644 --- a/paddle/phi/ops/compat/soft_margin_loss_sig.cc +++ b/paddle/phi/ops/compat/soft_margin_loss_sig.cc @@ -22,4 +22,4 @@ return KernelSignature("soft_margin_loss_grad", } }// namespace phi -PD_REGISTER_ARG_MAPPING_FN(soft_margin_loss_grad,phi::SoftMarginLossGradOpArgumentMapping); \ No newline at end of file +PD_REGISTER_ARG_MAPPING_FN(soft_margin_loss_grad,phi::SoftMarginLossGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py index 3bbc8ee0dc4be..4418cb9e495cd 100644 --- a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py +++ b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py @@ -107,8 +107,8 @@ def calc_softmarginloss(input_np, label_np, reduction='mean',): class TestSoftMarginLoss(unittest.TestCase): def test_SoftMarginLoss(self): - input_np = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float64) - label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float64) + input_np = np.random.uniform(0.1, 0.8, size=(10, 10)).astype(np.float64) + label_np = np.random.randint(1, size=(10, 10)).astype(np.float64) places = ['cpu'] if paddle.device.is_compiled_with_cuda(): places.append('gpu') @@ -138,7 +138,7 @@ def test_SoftMarginLoss_error(self): paddle.nn.loss.SoftMarginLoss, reduction="unsupport reduction") input = paddle.to_tensor([[0.1, 0.3]], dtype='float32') - label = paddle.to_tensor([[0.0, 1.0]], dtype='float32') + label = paddle.to_tensor([[-1.0, 1.0]], dtype='float32') self.assertRaises( ValueError, paddle.nn.functional.soft_margin_loss, @@ -157,7 +157,7 @@ def setUp(self): self.init_test_case() self.op_type = "soft_margin_loss" input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64") - label_np = np.random.randint(0, 2, self.shape).astype("float64") + label_np = np.random.randint(1, self.shape).astype("float64") output_np = soft_margin_loss(input_np, label_np) self.inputs = {'X': input_np, 'Label': label_np} @@ -170,7 +170,7 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') def init_test_case(self): - self.shape = [10, 10] + self.shape = [5, 5] class TestSoftMarginLossOpCase1(OpTest): @@ -184,4 +184,4 @@ def init_test_cast(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 80b6e411ec74b..be2f372399843 100755 
--- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2268,6 +2268,11 @@ def soft_margin_loss(input, label,reduction='mean', input = paddle.to_tensor([0.5, 0.6, 0.7], 'float32') label = paddle.to_tensor([1.0, -1.0, 1.0], 'float32') output = paddle.nn.functional.soft_margin_loss(input, label) + + shape = (5, 5) + input = np.random.uniform(0, 2, shape).astype('float32') + label = np.random.randint(1, shape).astype('float32') + output = paddle.nn.functional.soft_margin_loss(input, label,reduction='none') """ if reduction not in ['sum', 'mean', 'none']: raise ValueError( @@ -2282,27 +2287,26 @@ def soft_margin_loss(input, label,reduction='mean', return _C_ops.mean(out) else: return out - else: - fluid.data_feeder.check_variable_and_dtype( - input, 'input', ['float32', 'float64'], 'soft_margin_loss') - fluid.data_feeder.check_variable_and_dtype( - label, 'label', ['float32', 'float64'], 'soft_margin_loss') + fluid.data_feeder.check_variable_and_dtype( + input, 'input', ['float32', 'float64'], 'soft_margin_loss') + fluid.data_feeder.check_variable_and_dtype( + label, 'label', ['int32','int64','float32', 'float64'], 'soft_margin_loss') - sub_name = name if reduction == 'none' else None - helper = LayerHelper("soft_margin_loss", name=sub_name) - out = helper.create_variable_for_type_inference(dtype=input.dtype) - helper.append_op( - type='soft_margin_loss', - inputs={ - 'X': [input], - 'Label': [label], - }, - outputs={'Out': [out]}) + sub_name = name if reduction == 'none' else None + helper = LayerHelper("soft_margin_loss", name=sub_name) + out = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type='soft_margin_loss', + inputs={ + 'X': [input], + 'Label': [label], + }, + outputs={'Out': [out]}) - if reduction == 'sum': - return paddle.sum(out, name=name) - elif reduction == 'mean': - return paddle.mean(out, name=name) - else: - return out + if reduction == 'sum': + return paddle.sum(out, name=name) + elif reduction == 'mean': + return paddle.mean(out, name=name) + else: + return out diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 84025b646d1cb..368017b64c167 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1313,11 +1313,7 @@ class SoftMarginLoss(Layer): Out = log(1 + exp((-label * input))) Parameters: - weight (Tensor, optional): Weight tensor, a manual rescaling weight given - to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise, - it treated as if having all ones. the data type is - float32, float64, Default is ``'None'``. - + reduction (str, optional): Indicate how to average the loss by batch_size, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; @@ -1330,12 +1326,12 @@ class SoftMarginLoss(Layer): Shapes: Input (Tensor): The input tensor with shape: [N, *], - N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf - Available dtype is float32, float64. + N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf + Available dtype is float32, float64. Label (Tensor): The target labels tensor with the same shape as - ``input``. The target labels which values should be numbers -1 or 1. - Available dtype is int32, int64, float32, float64. + ``input``. The target labels which values should be numbers -1 or 1. + Available dtype is int32, int64, float32, float64. 
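With the static-graph branch flattened as in the hunk above, both execution modes follow the same recipe: the op emits a per-element loss, and the requested reduction is applied afterwards as a separate sum or mean. A usage sketch of the three modes, assuming the functional API as added by this series:

.. code-block:: python

    import paddle

    x = paddle.to_tensor([[0.5, 0.6], [0.3, 0.2]], 'float32')
    y = paddle.to_tensor([[1.0, -1.0], [-1.0, 1.0]], 'float32')

    per_elem = paddle.nn.functional.soft_margin_loss(x, y, reduction='none')  # shape [2, 2]
    mean_all = paddle.nn.functional.soft_margin_loss(x, y, reduction='mean')  # reduced to one value
    sum_all = paddle.nn.functional.soft_margin_loss(x, y, reduction='sum')    # reduced to one value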
Returns: @@ -1349,6 +1345,12 @@ class SoftMarginLoss(Layer): label = paddle.to_tensor([1, -1, 1], dtype="float32") soft_margin_loss = paddle.nn.SoftMarginLoss() output = soft_margin_loss(input, label) + + shape = (5, 5) + input = np.random.uniform(0, 2, shape).astype('float32') + label = np.random.randint(1, shape).astype('float32') + soft_margin_loss = paddle.nn.SoftMarginLoss() + output = soft_margin_loss(input, label,reduction='none') """ def __init__(self, From ec502c6611072c3bd34dd018467e9ac6d36a0174 Mon Sep 17 00:00:00 2001 From: yangguohao <70266361+yangguohao@users.noreply.github.com> Date: Mon, 16 May 2022 14:09:36 +0800 Subject: [PATCH 14/31] Update test_soft_margin_loss.py --- .../paddle/fluid/tests/unittests/test_soft_margin_loss.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py index 4418cb9e495cd..c42e1fb914092 100644 --- a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py +++ b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py @@ -108,7 +108,8 @@ def calc_softmarginloss(input_np, label_np, reduction='mean',): class TestSoftMarginLoss(unittest.TestCase): def test_SoftMarginLoss(self): input_np = np.random.uniform(0.1, 0.8, size=(10, 10)).astype(np.float64) - label_np = np.random.randint(1, size=(10, 10)).astype(np.float64) + label_np = np.random.randint(0, 2, size=(10, 10)).astype(np.float64) + label_np[label_np==0]=-1 places = ['cpu'] if paddle.device.is_compiled_with_cuda(): places.append('gpu') @@ -157,7 +158,8 @@ def setUp(self): self.init_test_case() self.op_type = "soft_margin_loss" input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64") - label_np = np.random.randint(1, self.shape).astype("float64") + label_np = np.random.randint(0, 2, self.shape).astype("float64") + label_np[label_np==0]=-1 output_np = soft_margin_loss(input_np, label_np) self.inputs = {'X': input_np, 'Label': label_np} @@ -170,7 +172,7 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') def init_test_case(self): - self.shape = [5, 5] + self.shape = [10, 10] class TestSoftMarginLossOpCase1(OpTest): From d593f0c08d567efcaad02a7cb51668e1045f1a4e Mon Sep 17 00:00:00 2001 From: yangguohao <70266361+yangguohao@users.noreply.github.com> Date: Mon, 16 May 2022 14:12:15 +0800 Subject: [PATCH 15/31] Update loss.py --- python/paddle/nn/functional/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index be2f372399843..d63e9c59bb308 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2259,7 +2259,7 @@ def soft_margin_loss(input, label,reduction='mean', Returns: Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is - same as ``input`` , else the shape of output is scalar. + same as ``input`` , else the shape of output is [1]. Examples: .. 
code-block:: python From 8e49f6e859ae48dd42b3d1eba40fa0941ee9cfea Mon Sep 17 00:00:00 2001 From: yangguohao <70266361+yangguohao@users.noreply.github.com> Date: Mon, 16 May 2022 18:25:55 +0800 Subject: [PATCH 16/31] Update loss.py --- python/paddle/nn/functional/loss.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index d63e9c59bb308..0b0476f012327 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2265,13 +2265,14 @@ def soft_margin_loss(input, label,reduction='mean', .. code-block:: python import paddle - input = paddle.to_tensor([0.5, 0.6, 0.7], 'float32') - label = paddle.to_tensor([1.0, -1.0, 1.0], 'float32') + input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32') + label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32') output = paddle.nn.functional.soft_margin_loss(input, label) shape = (5, 5) input = np.random.uniform(0, 2, shape).astype('float32') - label = np.random.randint(1, shape).astype('float32') + label = np.random.randint(0, 2, shape).astype('float32') + label[label==0]=-1 output = paddle.nn.functional.soft_margin_loss(input, label,reduction='none') """ if reduction not in ['sum', 'mean', 'none']: @@ -2293,7 +2294,7 @@ def soft_margin_loss(input, label,reduction='mean', fluid.data_feeder.check_variable_and_dtype( label, 'label', ['int32','int64','float32', 'float64'], 'soft_margin_loss') - sub_name = name if reduction == 'none' else None + sub_name = name helper = LayerHelper("soft_margin_loss", name=sub_name) out = helper.create_variable_for_type_inference(dtype=input.dtype) helper.append_op( From 01012016042b7e56115a4c549f01590fa1bc3869 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Mon, 16 May 2022 22:45:49 +0800 Subject: [PATCH 17/31] 2022-05-16_V1 --- paddle/phi/ops/compat/soft_margin_loss_sig.cc | 1 + .../tests/unittests/test_soft_margin_loss.py | 38 ++++++++++--------- python/paddle/nn/functional/loss.py | 2 +- python/paddle/nn/layer/loss.py | 7 ++-- 4 files changed, 26 insertions(+), 22 deletions(-) diff --git a/paddle/phi/ops/compat/soft_margin_loss_sig.cc b/paddle/phi/ops/compat/soft_margin_loss_sig.cc index f81557a1a22d4..e358795f9f5bc 100644 --- a/paddle/phi/ops/compat/soft_margin_loss_sig.cc +++ b/paddle/phi/ops/compat/soft_margin_loss_sig.cc @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
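A small idiom recurs in the examples and tests from here on: valid targets are drawn as Bernoulli {0, 1} samples and the zeros are remapped, since the loss is only defined for labels in {-1, +1}. In isolation (NumPy, illustrative):

.. code-block:: python

    import numpy as np

    label_np = np.random.randint(0, 2, size=(5, 5)).astype(np.float64)
    label_np[label_np == 0] = -1  # remap 0 -> -1; targets must be -1 or +1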
+ #include "paddle/phi/core/compat/op_utils.h" namespace phi{ diff --git a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py index c42e1fb914092..1e046f426676f 100644 --- a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py +++ b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py @@ -107,30 +107,32 @@ def calc_softmarginloss(input_np, label_np, reduction='mean',): class TestSoftMarginLoss(unittest.TestCase): def test_SoftMarginLoss(self): - input_np = np.random.uniform(0.1, 0.8, size=(10, 10)).astype(np.float64) - label_np = np.random.randint(0, 2, size=(10, 10)).astype(np.float64) - label_np[label_np==0]=-1 + input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64) + types = ['int32','int64','float32','float64'] places = ['cpu'] if paddle.device.is_compiled_with_cuda(): places.append('gpu') reductions = ['sum', 'mean', 'none'] for place in places: for reduction in reductions: - static_result = test_static_layer(place, input_np, label_np, - reduction) - dy_result = test_dygraph_layer(place, input_np, label_np, - reduction) - expected = calc_softmarginloss(input_np, label_np, reduction) - self.assertTrue(np.allclose(static_result, expected)) - self.assertTrue(np.allclose(static_result, dy_result)) - self.assertTrue(np.allclose(dy_result, expected)) - static_functional = test_static_functional(place, input_np, - label_np, reduction) - dy_functional = test_dygraph_functional(place, input_np, - label_np, reduction) - self.assertTrue(np.allclose(static_functional, expected)) - self.assertTrue(np.allclose(static_functional, dy_functional)) - self.assertTrue(np.allclose(dy_functional, expected)) + for _type in types: + label_np = np.random.randint(0, 2, size=(5, 5)).astype(_type) + label_np[label_np == 0] = -1 + static_result = test_static_layer(place, input_np, label_np, + reduction) + dy_result = test_dygraph_layer(place, input_np, label_np, + reduction) + expected = calc_softmarginloss(input_np, label_np, reduction) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static_functional(place, input_np, + label_np, reduction) + dy_functional = test_dygraph_functional(place, input_np, + label_np, reduction) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) def test_SoftMarginLoss_error(self): paddle.disable_static() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 0b0476f012327..df625dfe8ffcb 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2272,7 +2272,7 @@ def soft_margin_loss(input, label,reduction='mean', shape = (5, 5) input = np.random.uniform(0, 2, shape).astype('float32') label = np.random.randint(0, 2, shape).astype('float32') - label[label==0]=-1 + label[label==0]=-1 output = paddle.nn.functional.soft_margin_loss(input, label,reduction='none') """ if reduction not in ['sum', 'mean', 'none']: diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 368017b64c167..2f6b9cc4aa8e0 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1341,14 +1341,15 @@ class SoftMarginLoss(Layer): .. 
code-block:: python import paddle - input = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32") - label = paddle.to_tensor([1, -1, 1], dtype="float32") + input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32') + label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32') soft_margin_loss = paddle.nn.SoftMarginLoss() output = soft_margin_loss(input, label) shape = (5, 5) input = np.random.uniform(0, 2, shape).astype('float32') - label = np.random.randint(1, shape).astype('float32') + label = np.random.randint(0, 2, shape).astype('float32') + label[label==0]=-1 soft_margin_loss = paddle.nn.SoftMarginLoss() output = soft_margin_loss(input, label,reduction='none') """ From 4c5bd1615f7d4eb8a8a658938ae6da55bd3e27ea Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Thu, 19 May 2022 20:14:15 +0800 Subject: [PATCH 18/31] 2022-05-19_V1 --- paddle/fluid/operators/soft_margin_loss_op.cc | 2 - .../tests/unittests/test_soft_margin_loss.py | 44 +++++++++---------- python/paddle/nn/functional/loss.py | 4 +- python/paddle/nn/layer/loss.py | 2 +- 4 files changed, 24 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/soft_margin_loss_op.cc b/paddle/fluid/operators/soft_margin_loss_op.cc index 89c315e00dfd8..e3a6278572ad8 100644 --- a/paddle/fluid/operators/soft_margin_loss_op.cc +++ b/paddle/fluid/operators/soft_margin_loss_op.cc @@ -137,7 +137,5 @@ REGISTER_OPERATOR(soft_margin_loss, ops::SoftMarginLossOp, ops::SoftMarginLossOpMaker, ops::SoftMarginLossGradOpMaker, ops::SoftMarginLossGradOpMaker, - //ops::SoftMarginLossInplaceInferer, SoftMarginLossInferShapeFunctor); REGISTER_OPERATOR(soft_margin_loss_grad, ops::SoftMarginLossGradOp); - //ops::SoftMarginLossGradInplaceInferer); diff --git a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py index 1e046f426676f..ce7af805c55bb 100644 --- a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py +++ b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py @@ -27,9 +27,9 @@ def test_static_layer(place, startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): input = paddle.static.data( - name='input', shape=input_np.shape, dtype='float64') + name='input', shape=input_np.shape, dtype=input_np.dtype) label = paddle.static.data( - name='label', shape=label_np.shape, dtype='float64') + name='label', shape=label_np.shape, dtype=label_np.dtype) sm_loss = paddle.nn.loss.SoftMarginLoss(reduction=reduction) res = sm_loss(input, label) exe = paddle.static.Executor(place) @@ -49,9 +49,9 @@ def test_static_functional(place, startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): input = paddle.static.data( - name='input', shape=input_np.shape, dtype='float64') + name='input', shape=input_np.shape, dtype=input_np.dtype) label = paddle.static.data( - name='label', shape=label_np.shape, dtype='float64') + name='label', shape=label_np.shape, dtype=label_np.dtype) res = paddle.nn.functional.soft_margin_loss( input, label, reduction=reduction) @@ -108,31 +108,29 @@ def calc_softmarginloss(input_np, label_np, reduction='mean',): class TestSoftMarginLoss(unittest.TestCase): def test_SoftMarginLoss(self): input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64) - types = ['int32','int64','float32','float64'] + label_np = np.random.randint(0, 2, size=(5, 5)).astype(np.float64) + label_np[label_np == 0] = -1 places = ['cpu'] if 
paddle.device.is_compiled_with_cuda(): places.append('gpu') reductions = ['sum', 'mean', 'none'] for place in places: for reduction in reductions: - for _type in types: - label_np = np.random.randint(0, 2, size=(5, 5)).astype(_type) - label_np[label_np == 0] = -1 - static_result = test_static_layer(place, input_np, label_np, - reduction) - dy_result = test_dygraph_layer(place, input_np, label_np, - reduction) - expected = calc_softmarginloss(input_np, label_np, reduction) - self.assertTrue(np.allclose(static_result, expected)) - self.assertTrue(np.allclose(static_result, dy_result)) - self.assertTrue(np.allclose(dy_result, expected)) - static_functional = test_static_functional(place, input_np, - label_np, reduction) - dy_functional = test_dygraph_functional(place, input_np, - label_np, reduction) - self.assertTrue(np.allclose(static_functional, expected)) - self.assertTrue(np.allclose(static_functional, dy_functional)) - self.assertTrue(np.allclose(dy_functional, expected)) + static_result = test_static_layer(place, input_np, label_np, + reduction) + dy_result = test_dygraph_layer(place, input_np, label_np, + reduction) + expected = calc_softmarginloss(input_np, label_np, reduction) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static_functional(place, input_np, + label_np, reduction) + dy_functional = test_dygraph_functional(place, input_np, + label_np, reduction) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) def test_SoftMarginLoss_error(self): paddle.disable_static() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index df625dfe8ffcb..240ecf2565126 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2244,7 +2244,7 @@ def soft_margin_loss(input, label,reduction='mean', label (Tensor): The target labels tensor with the same shape as ``input``. The target labels which values should be numbers -1 or 1. - Available dtype is int32, int64, float32, float64. + Available dtype is float32, float64. reduction (str, optional): Indicate how to average the loss by batch_size, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. @@ -2292,7 +2292,7 @@ def soft_margin_loss(input, label,reduction='mean', fluid.data_feeder.check_variable_and_dtype( input, 'input', ['float32', 'float64'], 'soft_margin_loss') fluid.data_feeder.check_variable_and_dtype( - label, 'label', ['int32','int64','float32', 'float64'], 'soft_margin_loss') + label, 'label', ['float32', 'float64'], 'soft_margin_loss') sub_name = name helper = LayerHelper("soft_margin_loss", name=sub_name) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 2f6b9cc4aa8e0..508013ca1268a 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1331,7 +1331,7 @@ class SoftMarginLoss(Layer): Label (Tensor): The target labels tensor with the same shape as ``input``. The target labels which values should be numbers -1 or 1. - Available dtype is int32, int64, float32, float64. + Available dtype is float32, float64. 
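Restricting labels to float dtypes here turns out to be temporary: the next patch in the series restores integer labels by casting them to the input's dtype before the elementwise kernel runs. A sketch of that order of operations using public ops (the patch itself goes through _C_ops.cast / fluid.layers.cast internally):

.. code-block:: python

    import paddle

    x = paddle.to_tensor([[0.5, 0.6]], dtype='float32')
    y = paddle.to_tensor([[1, -1]], dtype='int64')  # integer +1/-1 targets
    y = paddle.cast(y, x.dtype)                     # cast before log/exp
    out = paddle.log(1 + paddle.exp(-y * x))        # elementwise soft margin loss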
Returns: From 84dae9a685fc65e1b7389aefac689adc6a992110 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Fri, 20 May 2022 16:03:49 +0800 Subject: [PATCH 19/31] 2022-05-20_V1 --- .../tests/unittests/test_soft_margin_loss.py | 36 ++++++++++--------- python/paddle/nn/functional/loss.py | 7 ++-- python/paddle/nn/layer/loss.py | 2 +- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py index ce7af805c55bb..14df63ce93274 100644 --- a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py +++ b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py @@ -108,29 +108,31 @@ def calc_softmarginloss(input_np, label_np, reduction='mean',): class TestSoftMarginLoss(unittest.TestCase): def test_SoftMarginLoss(self): input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64) - label_np = np.random.randint(0, 2, size=(5, 5)).astype(np.float64) - label_np[label_np == 0] = -1 + types = [np.int32,np.int64,np.float32,np.float64] places = ['cpu'] if paddle.device.is_compiled_with_cuda(): places.append('gpu') reductions = ['sum', 'mean', 'none'] for place in places: for reduction in reductions: - static_result = test_static_layer(place, input_np, label_np, - reduction) - dy_result = test_dygraph_layer(place, input_np, label_np, - reduction) - expected = calc_softmarginloss(input_np, label_np, reduction) - self.assertTrue(np.allclose(static_result, expected)) - self.assertTrue(np.allclose(static_result, dy_result)) - self.assertTrue(np.allclose(dy_result, expected)) - static_functional = test_static_functional(place, input_np, - label_np, reduction) - dy_functional = test_dygraph_functional(place, input_np, - label_np, reduction) - self.assertTrue(np.allclose(static_functional, expected)) - self.assertTrue(np.allclose(static_functional, dy_functional)) - self.assertTrue(np.allclose(dy_functional, expected)) + for _type in types: + label_np = np.random.randint(0, 2, size=(5, 5)).astype(_types) + label_np[label_np == 0] = -1 + static_result = test_static_layer(place, input_np, label_np, + reduction) + dy_result = test_dygraph_layer(place, input_np, label_np, + reduction) + expected = calc_softmarginloss(input_np, label_np, reduction) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static_functional(place, input_np, + label_np, reduction) + dy_functional = test_dygraph_functional(place, input_np, + label_np, reduction) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) def test_SoftMarginLoss_error(self): paddle.disable_static() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 94bd7c227a919..85fe40f3a14c8 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2244,7 +2244,7 @@ def soft_margin_loss(input, label,reduction='mean', input (Tensor): The input predications tensor with shape: [N, *], N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf. - Available dtype is float32, float64. + Available dtype is int32, int64, float32, float64. label (Tensor): The target labels tensor with the same shape as ``input``. 
The target labels which values should be numbers -1 or 1. @@ -2285,6 +2285,7 @@ def soft_margin_loss(input, label,reduction='mean', "'mean' or 'none', but received %s, which is not allowed." % reduction) if _non_static_mode(): + label = _C_ops.cast(label, 'in_dtype', label.dtype, 'out_dtype', input.dtype) out = _C_ops.soft_margin_loss(input, label) if reduction == 'sum': return _C_ops.reduce_sum(out, "reduce_all", True) @@ -2296,7 +2297,9 @@ def soft_margin_loss(input, label,reduction='mean', fluid.data_feeder.check_variable_and_dtype( input, 'input', ['float32', 'float64'], 'soft_margin_loss') fluid.data_feeder.check_variable_and_dtype( - label, 'label', ['float32', 'float64'], 'soft_margin_loss') + label, 'label', ['int32', 'int64', 'float32', 'float64'], 'soft_margin_loss') + + label = fluid.layers.cast(label, input.dtype) sub_name = name helper = LayerHelper("soft_margin_loss", name=sub_name) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 5d81d59202352..c7806a3a2a61a 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1335,7 +1335,7 @@ class SoftMarginLoss(Layer): Label (Tensor): The target labels tensor with the same shape as ``input``. The target labels which values should be numbers -1 or 1. - Available dtype is float32, float64. + Available dtype is int32, int64, float32, float64. Returns: From cb649181f3ead2789b09734da0a6bce2c94dbb8e Mon Sep 17 00:00:00 2001 From: yangguohao <70266361+yangguohao@users.noreply.github.com> Date: Fri, 20 May 2022 17:28:03 +0800 Subject: [PATCH 20/31] Update test_soft_margin_loss.py --- python/paddle/fluid/tests/unittests/test_soft_margin_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py index 14df63ce93274..94e8e2a186d85 100644 --- a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py +++ b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py @@ -116,7 +116,7 @@ def test_SoftMarginLoss(self): for place in places: for reduction in reductions: for _type in types: - label_np = np.random.randint(0, 2, size=(5, 5)).astype(_types) + label_np = np.random.randint(0, 2, size=(5, 5)).astype(_type) label_np[label_np == 0] = -1 static_result = test_static_layer(place, input_np, label_np, reduction) From e7d23ed539de93d14f2b29af4e9f9ba84d51210b Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Wed, 1 Jun 2022 22:30:06 +0800 Subject: [PATCH 21/31] 2022-06-01_V1 --- python/paddle/nn/__init__.py | 1 + python/paddle/nn/functional/__init__.py | 1 + python/paddle/nn/functional/loss.py | 13 +++++++------ python/paddle/nn/layer/loss.py | 19 +++++++++---------- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 13c27c266c758..30cfe8de412ec 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -314,5 +314,6 @@ def weight_norm(*args): 'MaxUnPool3D', 'HingeEmbeddingLoss', 'Identity', + 'RReLU', 'SoftMarginLoss', ] diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index cee1c035693c5..b57d01d5cc006 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -229,5 +229,6 @@ 'class_center_sample', 'sparse_attention', 'fold', + 'rrelu', 'soft_margin_loss', ] diff --git a/python/paddle/nn/functional/loss.py 
b/python/paddle/nn/functional/loss.py index 85fe40f3a14c8..8d959302590b1 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2231,8 +2231,7 @@ def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None): return loss -def soft_margin_loss(input, label,reduction='mean', - name=None): +def soft_margin_loss(input, label, reduction='mean', name=None): """ This APIs measures the soft margin loss between input predictions ``input`` and target labels ``label`` . It can be described as: @@ -2274,8 +2273,8 @@ def soft_margin_loss(input, label,reduction='mean', output = paddle.nn.functional.soft_margin_loss(input, label) shape = (5, 5) - input = np.random.uniform(0, 2, shape).astype('float32') - label = np.random.randint(0, 2, shape).astype('float32') + input = paddle.uniform(shape, 0, 2, dtype='float32') + label = paddle.randint(shape, 0, 2, dtype='float32') label[label==0]=-1 output = paddle.nn.functional.soft_margin_loss(input, label,reduction='none') """ @@ -2285,7 +2284,8 @@ def soft_margin_loss(input, label,reduction='mean', "'mean' or 'none', but received %s, which is not allowed." % reduction) if _non_static_mode(): - label = _C_ops.cast(label, 'in_dtype', label.dtype, 'out_dtype', input.dtype) + label = _C_ops.cast(label, 'in_dtype', label.dtype, 'out_dtype', + input.dtype) out = _C_ops.soft_margin_loss(input, label) if reduction == 'sum': return _C_ops.reduce_sum(out, "reduce_all", True) @@ -2297,7 +2297,8 @@ def soft_margin_loss(input, label,reduction='mean', fluid.data_feeder.check_variable_and_dtype( input, 'input', ['float32', 'float64'], 'soft_margin_loss') fluid.data_feeder.check_variable_and_dtype( - label, 'label', ['int32', 'int64', 'float32', 'float64'], 'soft_margin_loss') + label, 'label', ['int32', 'int64', 'float32', 'float64'], + 'soft_margin_loss') label = fluid.layers.cast(label, input.dtype) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index c7806a3a2a61a..f2c72fddb94f6 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1307,6 +1307,7 @@ def forward(self, input, label): margin=self.margin, name=self.name) + class SoftMarginLoss(Layer): r""" Creates a criterion that measures a two-class @@ -1337,6 +1338,9 @@ class SoftMarginLoss(Layer): ``input``. The target labels which values should be numbers -1 or 1. Available dtype is int32, int64, float32, float64. + Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is + same as ``input`` , else the shape of output is [1]. 
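The shape contract stated above can be exercised directly; a sketch (output shapes follow the docstring: 'none' preserves the input shape, 'mean' and 'sum' reduce to [1]):

.. code-block:: python

    import paddle

    x = paddle.rand([4, 3])
    y = paddle.ones([4, 3])  # all-(+1) targets; values must be -1 or +1

    print(paddle.nn.functional.soft_margin_loss(x, y, reduction='none').shape)  # [4, 3]
    print(paddle.nn.functional.soft_margin_loss(x, y, reduction='mean').shape)  # [1]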
+ Returns: A callable object of SoftMarginLoss @@ -1351,16 +1355,14 @@ class SoftMarginLoss(Layer): output = soft_margin_loss(input, label) shape = (5, 5) - input = np.random.uniform(0, 2, shape).astype('float32') - label = np.random.randint(0, 2, shape).astype('float32') + input = paddle.uniform(shape, 0, 2, dtype='float32') + label = paddle.randint(shape, 0, 2, dtype='float32') label[label==0]=-1 soft_margin_loss = paddle.nn.SoftMarginLoss() output = soft_margin_loss(input, label,reduction='none') """ - def __init__(self, - reduction='mean', - name=None): + def __init__(self, reduction='mean', name=None): if reduction not in ['sum', 'mean', 'none']: raise ValueError( "The value of 'reduction' in SoftMarginLoss should be 'sum', 'mean' or 'none', but " @@ -1371,9 +1373,6 @@ def __init__(self, self.name = name def forward(self, input, label): - out = paddle.nn.functional.soft_margin_loss( - input, - label, - self.reduction, - self.name) + out = paddle.nn.functional.soft_margin_loss(input, label, + self.reduction, self.name) return out From aab11368200b13f84f5c4b6cf6e74a4f5c66a330 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Sun, 5 Jun 2022 18:44:21 +0800 Subject: [PATCH 22/31] 2022-06-05 --- .../tests/unittests/test_soft_margin_loss.py | 63 +++++++++++-------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py index 94e8e2a186d85..086f7ae1dab93 100644 --- a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py +++ b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py @@ -18,10 +18,11 @@ from op_test import OpTest -def test_static_layer(place, - input_np, - label_np, - reduction='mean',): +def test_static_layer( + place, + input_np, + label_np, + reduction='mean', ): paddle.enable_static() prog = paddle.static.Program() startup_prog = paddle.static.Program() @@ -40,10 +41,11 @@ def test_static_layer(place, return static_result -def test_static_functional(place, - input_np, - label_np, - reduction='mean',): +def test_static_functional( + place, + input_np, + label_np, + reduction='mean', ): paddle.enable_static() prog = paddle.static.Program() startup_prog = paddle.static.Program() @@ -63,10 +65,11 @@ def test_static_functional(place, return static_result -def test_dygraph_layer(place, - input_np, - label_np, - reduction='mean',): +def test_dygraph_layer( + place, + input_np, + label_np, + reduction='mean', ): paddle.disable_static() sm_loss = paddle.nn.loss.SoftMarginLoss(reduction=reduction) dy_res = sm_loss(paddle.to_tensor(input_np), paddle.to_tensor(label_np)) @@ -75,10 +78,11 @@ def test_dygraph_layer(place, return dy_result -def test_dygraph_functional(place, - input_np, - label_np, - reduction='mean',): +def test_dygraph_functional( + place, + input_np, + label_np, + reduction='mean', ): paddle.disable_static() input = paddle.to_tensor(input_np) label = paddle.to_tensor(label_np) @@ -90,9 +94,11 @@ def test_dygraph_functional(place, return dy_result -def calc_softmarginloss(input_np, label_np, reduction='mean',): - - expected = np.log(1+np.exp(-label_np * input_np)) +def calc_softmarginloss( + input_np, + label_np, + reduction='mean', ): + expected = np.log(1 + np.exp(-label_np * input_np)) # expected = np.mean(expected, axis=-1) if reduction == 'mean': @@ -108,7 +114,7 @@ def calc_softmarginloss(input_np, label_np, reduction='mean',): class TestSoftMarginLoss(unittest.TestCase): def 
test_SoftMarginLoss(self): input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64) - types = [np.int32,np.int64,np.float32,np.float64] + types = [np.int32, np.int64, np.float32, np.float64] places = ['cpu'] if paddle.device.is_compiled_with_cuda(): places.append('gpu') @@ -116,22 +122,25 @@ def test_SoftMarginLoss(self): for place in places: for reduction in reductions: for _type in types: - label_np = np.random.randint(0, 2, size=(5, 5)).astype(_type) + label_np = np.random.randint( + 0, 2, size=(5, 5)).astype(_type) label_np[label_np == 0] = -1 static_result = test_static_layer(place, input_np, label_np, reduction) dy_result = test_dygraph_layer(place, input_np, label_np, reduction) - expected = calc_softmarginloss(input_np, label_np, reduction) + expected = calc_softmarginloss(input_np, label_np, + reduction) self.assertTrue(np.allclose(static_result, expected)) self.assertTrue(np.allclose(static_result, dy_result)) self.assertTrue(np.allclose(dy_result, expected)) - static_functional = test_static_functional(place, input_np, - label_np, reduction) + static_functional = test_static_functional( + place, input_np, label_np, reduction) dy_functional = test_dygraph_functional(place, input_np, label_np, reduction) self.assertTrue(np.allclose(static_functional, expected)) - self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue( + np.allclose(static_functional, dy_functional)) self.assertTrue(np.allclose(dy_functional, expected)) def test_SoftMarginLoss_error(self): @@ -152,7 +161,7 @@ def test_SoftMarginLoss_error(self): def soft_margin_loss(input, label): - return np.log(1+np.exp(-label * input)) + return np.log(1 + np.exp(-label * input)) class TestSoftMarginLossOp(OpTest): @@ -161,7 +170,7 @@ def setUp(self): self.op_type = "soft_margin_loss" input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64") label_np = np.random.randint(0, 2, self.shape).astype("float64") - label_np[label_np==0]=-1 + label_np[label_np == 0] = -1 output_np = soft_margin_loss(input_np, label_np) self.inputs = {'X': input_np, 'Label': label_np} From f5549be5ed74c9b3d95e29ac4850f56187769c6c Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Tue, 7 Jun 2022 00:46:04 +0800 Subject: [PATCH 23/31] 2022-06-07 --- .../tests/unittests/test_soft_margin_loss.py | 117 ++++++++++-------- 1 file changed, 67 insertions(+), 50 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py index 086f7ae1dab93..cb63f0443089c 100644 --- a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py +++ b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py @@ -19,57 +19,69 @@ def test_static_layer( - place, - input_np, - label_np, - reduction='mean', ): + place, + input_np, + label_np, + reduction='mean', +): paddle.enable_static() prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - input = paddle.static.data( - name='input', shape=input_np.shape, dtype=input_np.dtype) - label = paddle.static.data( - name='label', shape=label_np.shape, dtype=label_np.dtype) + input = paddle.static.data(name='input', + shape=input_np.shape, + dtype=input_np.dtype) + label = paddle.static.data(name='label', + shape=label_np.shape, + dtype=label_np.dtype) sm_loss = paddle.nn.loss.SoftMarginLoss(reduction=reduction) res = sm_loss(input, label) exe = paddle.static.Executor(place) static_result = exe.run(prog, - 
feed={"input": input_np, - "label": label_np}, + feed={ + "input": input_np, + "label": label_np + }, fetch_list=[res]) return static_result def test_static_functional( - place, - input_np, - label_np, - reduction='mean', ): + place, + input_np, + label_np, + reduction='mean', +): paddle.enable_static() prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - input = paddle.static.data( - name='input', shape=input_np.shape, dtype=input_np.dtype) - label = paddle.static.data( - name='label', shape=label_np.shape, dtype=label_np.dtype) - - res = paddle.nn.functional.soft_margin_loss( - input, label, reduction=reduction) + input = paddle.static.data(name='input', + shape=input_np.shape, + dtype=input_np.dtype) + label = paddle.static.data(name='label', + shape=label_np.shape, + dtype=label_np.dtype) + + res = paddle.nn.functional.soft_margin_loss(input, + label, + reduction=reduction) exe = paddle.static.Executor(place) static_result = exe.run(prog, - feed={"input": input_np, - "label": label_np}, + feed={ + "input": input_np, + "label": label_np + }, fetch_list=[res]) return static_result def test_dygraph_layer( - place, - input_np, - label_np, - reduction='mean', ): + place, + input_np, + label_np, + reduction='mean', +): paddle.disable_static() sm_loss = paddle.nn.loss.SoftMarginLoss(reduction=reduction) dy_res = sm_loss(paddle.to_tensor(input_np), paddle.to_tensor(label_np)) @@ -79,25 +91,28 @@ def test_dygraph_layer( def test_dygraph_functional( - place, - input_np, - label_np, - reduction='mean', ): + place, + input_np, + label_np, + reduction='mean', +): paddle.disable_static() input = paddle.to_tensor(input_np) label = paddle.to_tensor(label_np) - dy_res = paddle.nn.functional.soft_margin_loss( - input, label, reduction=reduction) + dy_res = paddle.nn.functional.soft_margin_loss(input, + label, + reduction=reduction) dy_result = dy_res.numpy() paddle.enable_static() return dy_result def calc_softmarginloss( - input_np, - label_np, - reduction='mean', ): + input_np, + label_np, + reduction='mean', +): expected = np.log(1 + np.exp(-label_np * input_np)) # expected = np.mean(expected, axis=-1) @@ -112,6 +127,7 @@ def calc_softmarginloss( class TestSoftMarginLoss(unittest.TestCase): + def test_SoftMarginLoss(self): input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64) types = [np.int32, np.int64, np.float32, np.float64] @@ -122,8 +138,8 @@ def test_SoftMarginLoss(self): for place in places: for reduction in reductions: for _type in types: - label_np = np.random.randint( - 0, 2, size=(5, 5)).astype(_type) + label_np = np.random.randint(0, 2, + size=(5, 5)).astype(_type) label_np[label_np == 0] = -1 static_result = test_static_layer(place, input_np, label_np, reduction) @@ -136,8 +152,8 @@ def test_SoftMarginLoss(self): self.assertTrue(np.allclose(dy_result, expected)) static_functional = test_static_functional( place, input_np, label_np, reduction) - dy_functional = test_dygraph_functional(place, input_np, - label_np, reduction) + dy_functional = test_dygraph_functional( + place, input_np, label_np, reduction) self.assertTrue(np.allclose(static_functional, expected)) self.assertTrue( np.allclose(static_functional, dy_functional)) @@ -145,18 +161,16 @@ def test_SoftMarginLoss(self): def test_SoftMarginLoss_error(self): paddle.disable_static() - self.assertRaises( - ValueError, - paddle.nn.loss.SoftMarginLoss, - reduction="unsupport reduction") + self.assertRaises(ValueError, + 
                          paddle.nn.loss.SoftMarginLoss,
                          reduction="unsupport reduction")
        input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
        label = paddle.to_tensor([[-1.0, 1.0]], dtype='float32')
-        self.assertRaises(
-            ValueError,
-            paddle.nn.functional.soft_margin_loss,
-            input=input,
-            label=label,
-            reduction="unsupport reduction")
+        self.assertRaises(ValueError,
+                          paddle.nn.functional.soft_margin_loss,
+                          input=input,
+                          label=label,
+                          reduction="unsupport reduction")
        paddle.enable_static()


@@ -165,6 +179,7 @@ def soft_margin_loss(input, label):


 class TestSoftMarginLossOp(OpTest):
+
    def setUp(self):
        self.init_test_case()
        self.op_type = "soft_margin_loss"
@@ -187,11 +202,13 @@ def init_test_case(self):


 class TestSoftMarginLossOpCase1(TestSoftMarginLossOp):
+
    def init_test_case(self):
        self.shape = [2, 3, 4, 5]


 class TestSoftMarginLossOpCase2(TestSoftMarginLossOp):
+
    def init_test_case(self):
        self.shape = [2, 3, 20]

From ff7d7a0007810107dfe1c2510db616a348017db2 Mon Sep 17 00:00:00 2001
From: yangguohao <1901212980@pku.edu.cn>
Date: Tue, 7 Jun 2022 21:05:33 +0800
Subject: [PATCH 24/31] 2022-06-07

---
 python/paddle/nn/__init__.py            | 3 +--
 python/paddle/nn/functional/__init__.py | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 7474ddbd4c045..bc23d2b1bd83d 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -191,8 +191,7 @@ def weight_norm(*args):
     return utils.weight_norm(*args)
 
 
-
-__all__ = [     #noqa
+__all__ = [  # noqa
     'BatchNorm',
     'CELU',
     'GroupNorm',
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 6e05176b20b7b..d13859a8ca0ae 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -125,8 +125,7 @@
 from .sparse_attention import sparse_attention
 
 
-
-__all__ = [     #noqa
+__all__ = [  # noqa
     'celu',
     'conv1d',
     'conv1d_transpose',
From 3403bf55d9a0f3a04f443ea9f45b723d01b6275e Mon Sep 17 00:00:00 2001
From: yangguohao <1901212980@pku.edu.cn>
Date: Wed, 8 Jun 2022 01:18:34 +0800
Subject: [PATCH 25/31] 2022-06-08

---
 paddle/fluid/operators/soft_margin_loss_op.cc |   7 +++---
 paddle/phi/infermeta/binary.cc                |   1 +
 .../cpu/soft_margin_loss_grad_kernel.cc       |  24 +++++++++++--------
 .../kernels/cpu/soft_margin_loss_kernel.cc    |  18 +++++++++-----
 .../gpu/soft_margin_loss_grad_kernel.cu       |  19 ++++++++-------
 .../kernels/gpu/soft_margin_loss_kernel.cu    |  21 ++++++++--------
 paddle/phi/ops/compat/soft_margin_loss_sig.cc |  18 +++++++-------
 python/paddle/nn/functional/loss.py           |  24 ++++++++++---------
 python/paddle/nn/layer/loss.py                |   4 ++--
 9 files changed, 76 insertions(+), 60 deletions(-)

diff --git a/paddle/fluid/operators/soft_margin_loss_op.cc b/paddle/fluid/operators/soft_margin_loss_op.cc
index e3a6278572ad8..c2b61078d4d44 100644
--- a/paddle/fluid/operators/soft_margin_loss_op.cc
+++ b/paddle/fluid/operators/soft_margin_loss_op.cc
@@ -41,7 +41,8 @@ class SoftMarginLossGradOp : public framework::OperatorWithKernel {
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SoftMarginLossGrad");
-    OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "SoftMarginLossGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label",
+                   "SoftMarginLossGrad");
     OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                    framework::GradVarName("Out"), "SoftMarginLossGrad");
     OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
@@ -124,13 +125,11 @@ class
SoftMarginLossGradOpMaker : public framework::SingleGradOpMaker { } }; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(soft_margin_loss, - SoftMarginLossInferShapeFunctor, +DECLARE_INFER_SHAPE_FUNCTOR(soft_margin_loss, SoftMarginLossInferShapeFunctor, PD_INFER_META(phi::SoftMarginLossInferMeta)); REGISTER_OPERATOR(soft_margin_loss, ops::SoftMarginLossOp, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index f49a1b50db7f0..07bb703f3fd96 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1804,6 +1804,7 @@ void SoftMarginLossInferMeta(const MetaTensor& input, label_dims.size())); bool check = true; + if ((!config.is_runtime) && (phi::product(input_dims) <= 0 || phi::product(label_dims) <= 0)) { check = false; diff --git a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc index 80d09ffa75d22..0f7c99450041b 100644 --- a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc @@ -15,18 +15,18 @@ #include "paddle/phi/kernels/soft_margin_loss_grad_kernel.h" #include + #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" - namespace phi { template void SoftMarginLossGradKernel(const Context& dev_ctx, - const DenseTensor& input, - const DenseTensor& label, - const DenseTensor& out_grad, - DenseTensor* input_grad) { + const DenseTensor& input, + const DenseTensor& label, + const DenseTensor& out_grad, + DenseTensor* input_grad) { auto dx_data = dev_ctx.template Alloc(input_grad); auto dout_data = out_grad.data(); auto x_data = input.data(); @@ -36,12 +36,16 @@ void SoftMarginLossGradKernel(const Context& dev_ctx, // dx = dout * (-label * exp(-label * x))/(1 + exp(-label * x )) for (int i = 0; i < x_numel; ++i) { - dx_data[i] = - dout_data[i] * ((- label_data[i]* std::exp(-label_data[i]*x_data[i] )) / - (static_cast(1) + std::exp(-label_data[i]*x_data[i]))); + dx_data[i] = dout_data[i] * + ((-label_data[i] * std::exp(-label_data[i] * x_data[i])) / + (static_cast(1) + std::exp(-label_data[i] * x_data[i]))); } } } // namespace phi -PD_REGISTER_KERNEL( - soft_margin_loss_grad, CPU, ALL_LAYOUT, phi::SoftMarginLossGradKernel, float, double) {} +PD_REGISTER_KERNEL(soft_margin_loss_grad, + CPU, + ALL_LAYOUT, + phi::SoftMarginLossGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc b/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc index 181e8ae5911af..078320eeb3967 100644 --- a/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc +++ b/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/soft_margin_loss_kernel.h" #include + #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" @@ -22,9 +23,9 @@ namespace phi { template void SoftMarginLossKernel(const Context& dev_ctx, - const DenseTensor& input, - const DenseTensor& label, - DenseTensor* out) { + const DenseTensor& input, + const DenseTensor& label, + DenseTensor* out) { auto x_data = input.data(); auto label_data = label.data(); auto out_data = dev_ctx.template Alloc(out); @@ -32,9 +33,14 @@ void SoftMarginLossKernel(const Context& dev_ctx, // out = ln(1+exp(-label * x)/(x_numel) for (int64_t i = 0; i < x_numel; ++i) { - out_data[i] =std::log(static_cast(1) + std::exp(-label_data[i]* x_data[i])); + out_data[i] = + std::log(static_cast(1) + 
std::exp(-label_data[i] * x_data[i])); } } } // namespace phi -PD_REGISTER_KERNEL( - soft_margin_loss, CPU, ALL_LAYOUT, phi::SoftMarginLossKernel, float, double) {} +PD_REGISTER_KERNEL(soft_margin_loss, + CPU, + ALL_LAYOUT, + phi::SoftMarginLossKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu index f774b0c5d7598..f26df06d4b900 100644 --- a/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/soft_margin_loss_grad_kernel.h" - #include #include @@ -21,6 +19,7 @@ #include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/soft_margin_loss_grad_kernel.h" namespace phi { @@ -42,10 +41,10 @@ struct SoftMarginLossGradFunctor { template void SoftMarginLossGradKernel(const Context& dev_ctx, - const DenseTensor& input, - const DenseTensor& label, - const DenseTensor& out_grad, - DenseTensor* input_grad) { + const DenseTensor& input, + const DenseTensor& label, + const DenseTensor& out_grad, + DenseTensor* input_grad) { dev_ctx.template Alloc(input_grad); std::vector ins = {&input, &label, &out_grad}; std::vector outs = {input_grad}; @@ -55,5 +54,9 @@ void SoftMarginLossGradKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - soft_margin_loss_grad, GPU, ALL_LAYOUT, phi::SoftMarginLossGradKernel, float, double) {} +PD_REGISTER_KERNEL(soft_margin_loss_grad, + GPU, + ALL_LAYOUT, + phi::SoftMarginLossGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu b/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu index d0d087eb87759..041b045780579 100644 --- a/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
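Both the CPU and GPU backward kernels implement the analytic derivative dX = dOut * (-label * exp(-label * x)) / (1 + exp(-label * x)). A quick finite-difference sanity check of that formula, independent of the kernels (NumPy, illustrative):

.. code-block:: python

    import numpy as np

    x, y, eps = 0.3, -1.0, 1e-6
    f = lambda v: np.log1p(np.exp(-y * v))                     # forward loss
    analytic = (-y * np.exp(-y * x)) / (1.0 + np.exp(-y * x))  # kernel's gradient
    numeric = (f(x + eps) - f(x - eps)) / (2 * eps)            # central difference
    assert abs(analytic - numeric) < 1e-6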
-#include "paddle/phi/kernels/soft_margin_loss_kernel.h" - #include #include @@ -22,6 +20,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/primitive/functor_primitives.h" +#include "paddle/phi/kernels/soft_margin_loss_kernel.h" namespace phi { @@ -29,9 +28,7 @@ template struct SoftMarginLossFunctor { T one; - HOSTDEVICE inline SoftMarginLossFunctor() { - one = static_cast(1.0f); - } + HOSTDEVICE inline SoftMarginLossFunctor() { one = static_cast(1.0f); } HOSTDEVICE inline T operator()(const T x, const T label) const { T term1 = std::log(one + std::exp(-label * x)); @@ -41,9 +38,9 @@ struct SoftMarginLossFunctor { template void SoftMarginLossKernel(const Context& dev_ctx, - const DenseTensor& input, - const DenseTensor& label, - DenseTensor* out) { + const DenseTensor& input, + const DenseTensor& label, + DenseTensor* out) { dev_ctx.template Alloc(out); std::vector ins = {&input, &label}; std::vector outs = {out}; @@ -53,5 +50,9 @@ void SoftMarginLossKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - soft_margin_loss, GPU, ALL_LAYOUT, phi::SoftMarginLossKernel, float, double) {} +PD_REGISTER_KERNEL(soft_margin_loss, + GPU, + ALL_LAYOUT, + phi::SoftMarginLossKernel, + float, + double) {} diff --git a/paddle/phi/ops/compat/soft_margin_loss_sig.cc b/paddle/phi/ops/compat/soft_margin_loss_sig.cc index e358795f9f5bc..7e84f774d8e6f 100644 --- a/paddle/phi/ops/compat/soft_margin_loss_sig.cc +++ b/paddle/phi/ops/compat/soft_margin_loss_sig.cc @@ -14,13 +14,13 @@ #include "paddle/phi/core/compat/op_utils.h" -namespace phi{ -KernelSignature SoftMarginLossGradOpArgumentMapping(const ArgumentMappingContext& ctx){ -return KernelSignature("soft_margin_loss_grad", - {"X","Label","Out@GRAD"}, - {}, - {"X@GRAD"}); - } -}// namespace phi +namespace phi { +KernelSignature SoftMarginLossGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "soft_margin_loss_grad", {"X", "Label", "Out@GRAD"}, {}, {"X@GRAD"}); +} +} // namespace phi -PD_REGISTER_ARG_MAPPING_FN(soft_margin_loss_grad,phi::SoftMarginLossGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(soft_margin_loss_grad, + phi::SoftMarginLossGradOpArgumentMapping); diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index a664deeb9bf2b..1cd49477c0a83 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2807,9 +2807,11 @@ def soft_margin_loss(input, label, reduction='mean', name=None): output = paddle.nn.functional.soft_margin_loss(input, label) shape = (5, 5) - input = paddle.uniform(shape, 0, 2, dtype='float32') - label = paddle.randint(shape, 0, 2, dtype='float32') + input_np = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + label_np = np.random.randint(0, 2, size=shape).astype(int64) label[label==0]=-1 + input = paddle.to_tensor(input_np) + label = paddle.to_tensor(label_np) output = paddle.nn.functional.soft_margin_loss(input, label,reduction='none') """ if reduction not in ['sum', 'mean', 'none']: @@ -2828,8 +2830,9 @@ def soft_margin_loss(input, label, reduction='mean', name=None): else: return out - fluid.data_feeder.check_variable_and_dtype( - input, 'input', ['float32', 'float64'], 'soft_margin_loss') + fluid.data_feeder.check_variable_and_dtype(input, 'input', + ['float32', 'float64'], + 'soft_margin_loss') fluid.data_feeder.check_variable_and_dtype( label, 'label', ['int32', 'int64', 'float32', 'float64'], 
                                               'soft_margin_loss')
 
     label = fluid.layers.cast(label, input.dtype)
 
@@ -2839,13 +2842,12 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
     sub_name = name
     helper = LayerHelper("soft_margin_loss", name=sub_name)
     out = helper.create_variable_for_type_inference(dtype=input.dtype)
-    helper.append_op(
-        type='soft_margin_loss',
-        inputs={
-            'X': [input],
-            'Label': [label],
-        },
-        outputs={'Out': [out]})
+    helper.append_op(type='soft_margin_loss',
+                     inputs={
+                         'X': [input],
+                         'Label': [label],
+                     },
+                     outputs={'Out': [out]})
 
     if reduction == 'sum':
         return paddle.sum(out, name=name)
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 70478788a1e23..3bf7434b2a587 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1361,8 +1361,8 @@ class SoftMarginLoss(Layer):
         input = paddle.uniform(shape, 0, 2, dtype='float32')
         label = paddle.randint(shape, 0, 2, dtype='float32')
         label[label==0]=-1
-        soft_margin_loss = paddle.nn.SoftMarginLoss()
-        output = soft_margin_loss(input, label,reduction='none')
+        soft_margin_loss = paddle.nn.SoftMarginLoss(reduction='none')
+        output = soft_margin_loss(input, label,)
     """
 
     def __init__(self, reduction='mean', name=None):
From 7bdf5d3dd933aa553ba86a049b62ea175fe6db4b Mon Sep 17 00:00:00 2001
From: yangguohao <1901212980@pku.edu.cn>
Date: Wed, 8 Jun 2022 10:15:24 +0800
Subject: [PATCH 26/31] 2022-06-08_V2

---
 python/paddle/nn/functional/loss.py | 19 ++++++++++---------
 python/paddle/nn/layer/loss.py      | 15 +++++++++------
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 1cd49477c0a83..a9aac33304294 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -2767,7 +2767,7 @@ def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):
 
 def soft_margin_loss(input, label, reduction='mean', name=None):
     """
-    This APIs measures the soft margin loss between input predictions ``input``
+    The API measures the soft margin loss between input predictions ``input``
     and target labels ``label``. It can be described as:
 
     .. math::
@@ -2777,14 +2777,14 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
         input (Tensor): The input predictions tensor with shape: [N, *],
             N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf.
-            Available dtype is int32, int64, float32, float64.
+            Available dtype is float32, float64.
 
         label (Tensor): The target labels tensor with the same shape as
             ``input``. The target labels whose values should be numbers -1 or 1.
-            Available dtype is float32, float64.
+            Available dtype is int32, int64, float32, float64.
 
         reduction (str, optional): Indicate how to average the loss by batch_size,
-            the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
             If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
             If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
             If :attr:`reduction` is ``'sum'``, the summed loss is returned.
@@ -2802,17 +2802,18 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
        ..
From 7bdf5d3dd933aa553ba86a049b62ea175fe6db4b Mon Sep 17 00:00:00 2001
From: yangguohao <1901212980@pku.edu.cn>
Date: Wed, 8 Jun 2022 10:15:24 +0800
Subject: [PATCH 26/31] 2022-06-08_V2

---
 python/paddle/nn/functional/loss.py | 19 ++++++++++---------
 python/paddle/nn/layer/loss.py      | 15 +++++++++------
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 1cd49477c0a83..a9aac33304294 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -2767,7 +2767,7 @@ def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):

 def soft_margin_loss(input, label, reduction='mean', name=None):
     """
-    This APIs measures the soft margin loss between input predictions ``input``
+    The API measures the soft margin loss between input predictions ``input``
     and target labels ``label`` . It can be described as:

     .. math::
@@ -2777,14 +2777,14 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
         input (Tensor): The input predictions tensor with shape: [N, *],
             N is batch_size, `*` means any number of additional dimensions.
             The ``input`` ranges from -inf to inf.
-            Available dtype is int32, int64, float32, float64.
+            Available dtype is float32, float64.
         label (Tensor): The target labels tensor with the same shape as
             ``input``. The values of the target labels should be -1 or 1.
-            Available dtype is float32, float64.
+            Available dtype is int32, int64, float32, float64.
         reduction (str, optional): Indicate how to average the loss by batch_size,
-            the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
             If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
             If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
             If :attr:`reduction` is ``'sum'``, the summed loss is returned.
@@ -2802,17 +2802,18 @@ def soft_margin_loss(input, label, reduction='mean', name=None):

         .. code-block:: python

             import paddle
+            import numpy as np
+
             input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32')
             label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32')
             output = paddle.nn.functional.soft_margin_loss(input, label)

-            shape = (5, 5)
-            input_np = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
-            label_np = np.random.randint(0, 2, size=shape).astype(int64)
-            label[label==0]=-1
+            input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64)
+            label_np = np.random.randint(0, 2, size=(5, 5)).astype(np.int64)
+            label_np[label_np==0]=-1
             input = paddle.to_tensor(input_np)
             label = paddle.to_tensor(label_np)
-            output = paddle.nn.functional.soft_margin_loss(input, label,reduction='none')
+            output = paddle.nn.functional.soft_margin_loss(input, label, reduction='none')
     """
     if reduction not in ['sum', 'mean', 'none']:
         raise ValueError(
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 3bf7434b2a587..1dc73dc57645d 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1323,7 +1323,7 @@ class SoftMarginLoss(Layer):

     Parameters:
         reduction (str, optional): Indicate how to average the loss by batch_size,
-            the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
             If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
             If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
             If :attr:`reduction` is ``'sum'``, the summed loss is returned.
@@ -1352,17 +1352,20 @@ class SoftMarginLoss(Layer):
         .. code-block:: python

             import paddle
+            import numpy as np
+
             input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32')
             label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32')
             soft_margin_loss = paddle.nn.SoftMarginLoss()
             output = soft_margin_loss(input, label)

-            shape = (5, 5)
-            input = paddle.uniform(shape, 0, 2, dtype='float32')
-            label = paddle.randint(shape, 0, 2, dtype='float32')
-            label[label==0]=-1
+            input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64)
+            label_np = np.random.randint(0, 2, size=(5, 5)).astype(np.int64)
+            label_np[label_np==0]=-1
+            input = paddle.to_tensor(input_np)
+            label = paddle.to_tensor(label_np)
             soft_margin_loss = paddle.nn.SoftMarginLoss(reduction='none')
-            output = soft_margin_loss(input, label,)
+            output = soft_margin_loss(input, label)
     """

     def __init__(self, reduction='mean', name=None):

From 91111ee62acdc3d7be9f4411028d494a1311ee0d Mon Sep 17 00:00:00 2001
From: yangguohao <1901212980@pku.edu.cn>
Date: Fri, 17 Jun 2022 09:24:38 +0800
Subject: [PATCH 27/31] 2022-06-17-code_style

---
 python/paddle/nn/layer/loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 663b6a9f2a1d2..161e76b2b7c8d 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1508,7 +1508,7 @@ def forward(self, input, positive, negative):
                                 reduction=self.reduction,
                                 name=self.name)

-
+
 class SoftMarginLoss(Layer):
     r"""
     Creates a criterion that measures a two-class
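The next commit drops the hand-written C++ kernels in favour of composing existing differentiable ops, so it is worth recording why that is safe: the analytic gradient the grad kernels implemented, d/dx log(1 + exp(-y * x)) = -y / (1 + exp(y * x)), is exactly what autodiff of log, exp, and multiply produces. A quick finite-difference sanity check of that closed form (my own sketch, not part of the patch):

    import numpy as np

    x = np.random.uniform(0.1, 0.8, size=(5, 5))
    y = np.where(np.random.randint(0, 2, size=(5, 5)) == 0, -1.0, 1.0)

    analytic = -y / (1.0 + np.exp(y * x))  # closed form used by the grad kernels
    eps = 1e-6
    numeric = (np.log(1.0 + np.exp(-y * (x + eps))) -
               np.log(1.0 + np.exp(-y * (x - eps)))) / (2.0 * eps)
    assert np.allclose(analytic, numeric, atol=1e-8)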
From c558d148c3145123863e592daeb5b930d1f8a4ea Mon Sep 17 00:00:00 2001
From: yangguohao <1901212980@pku.edu.cn>
Date: Fri, 17 Jun 2022 15:48:37 +0800
Subject: [PATCH 28/31] Modify python

---
 paddle/fluid/operators/soft_margin_loss_op.cc   | 140 ------------------
 paddle/phi/infermeta/binary.cc                  |  40 -----
 paddle/phi/infermeta/binary.h                   |   5 -
 .../cpu/soft_margin_loss_grad_kernel.cc         |  51 -------
 .../kernels/cpu/soft_margin_loss_kernel.cc      |  46 ------
 .../gpu/soft_margin_loss_grad_kernel.cu         |  62 --------
 .../kernels/gpu/soft_margin_loss_kernel.cu      |  58 --------
 .../kernels/soft_margin_loss_grad_kernel.h      |  28 ----
 paddle/phi/kernels/soft_margin_loss_kernel.h    |  27 ----
 paddle/phi/ops/compat/soft_margin_loss_sig.cc   |  26 ----
 python/paddle/nn/functional/loss.py             |  39 ++---
 11 files changed, 13 insertions(+), 509 deletions(-)
 delete mode 100644 paddle/fluid/operators/soft_margin_loss_op.cc
 delete mode 100644 paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc
 delete mode 100644 paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc
 delete mode 100644 paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu
 delete mode 100644 paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu
 delete mode 100644 paddle/phi/kernels/soft_margin_loss_grad_kernel.h
 delete mode 100644 paddle/phi/kernels/soft_margin_loss_kernel.h
 delete mode 100644 paddle/phi/ops/compat/soft_margin_loss_sig.cc

diff --git a/paddle/fluid/operators/soft_margin_loss_op.cc b/paddle/fluid/operators/soft_margin_loss_op.cc
deleted file mode 100644
index c2b61078d4d44..0000000000000
--- a/paddle/fluid/operators/soft_margin_loss_op.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/infermeta/binary.h"
-
-namespace paddle {
-namespace operators {
-
-using framework::Tensor;
-
-class SoftMarginLossOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
-  }
-};
-
-class SoftMarginLossGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SoftMarginLossGrad");
-    OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label",
-                   "SoftMarginLossGrad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
-                   framework::GradVarName("Out"), "SoftMarginLossGrad");
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
-                   framework::GradVarName("X"), "SoftMarginLossGrad");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto labels_dims = ctx->GetInputDim("Label");
-    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-
-    bool check = true;
-    if ((!ctx->IsRuntime()) &&
-        (phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
-      check = false;
-    }
-
-    if (check) {
-      PADDLE_ENFORCE_EQ(x_dims, labels_dims,
-                        platform::errors::InvalidArgument(
-                            "Input(X) and Input(Label) shall have the same "
-                            "shape. But received: the shape of Input(X) is "
-                            "[%s], the shape of Input(Label) is [%s].",
-                            x_dims, labels_dims));
-
-      PADDLE_ENFORCE_EQ(x_dims, dout_dims,
-                        platform::errors::InvalidArgument(
-                            "Input(X) and Input(Out@Grad) shall have the same "
-                            "shape. But received: the shape of Input(X) is "
-                            "[%s], the shape of Input(Out@Grad) is [%s].",
-                            x_dims, dout_dims));
-    }
-
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("X", framework::GradVarName("X"));
-  }
-
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
-  }
-};
-
-class SoftMarginLossOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "(Tensor, default Tensor<float>), the input is a tensor of logits"
-             "computed by the previous operator. ");
-    AddInput("Label",
-             "(Tensor, default Tensor<float>), have same shape with input"
-             "label should be -1 or 1.");
-    AddOutput("Out",
-              "(Tensor, default Tensor<float>), have same shape with"
-              "input");
-    AddComment(R"DOC(
-SoftMarginLoss operator.
-This measures the element-wise probability error in classification tasks
-in which each class is independent.
-The logistic loss is given as follows:
-    $$loss = log(1+exp(-Label * X))$$
-)DOC");
-  }
-};
-
-template <typename T>
-class SoftMarginLossGradOpMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> op) const override {
-    op->SetType("soft_margin_loss_grad");
-    op->SetInput("X", this->Input("X"));
-    op->SetInput("Label", this->Input("Label"));
-    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    op->SetAttrMap(this->Attrs());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-DECLARE_INFER_SHAPE_FUNCTOR(soft_margin_loss, SoftMarginLossInferShapeFunctor,
-                            PD_INFER_META(phi::SoftMarginLossInferMeta));
-
-REGISTER_OPERATOR(soft_margin_loss, ops::SoftMarginLossOp,
-                  ops::SoftMarginLossOpMaker,
-                  ops::SoftMarginLossGradOpMaker<paddle::framework::OpDesc>,
-                  ops::SoftMarginLossGradOpMaker<paddle::imperative::OpBase>,
-                  SoftMarginLossInferShapeFunctor);
-REGISTER_OPERATOR(soft_margin_loss_grad, ops::SoftMarginLossGradOp);
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 42b020b026334..add27da56b59a 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -1838,46 +1838,6 @@ void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }

-void SoftMarginLossInferMeta(const MetaTensor& input,
-                             const MetaTensor& label,
-                             MetaTensor* out,
-                             MetaConfig config) {
-  auto input_dims = input.dims();
-  auto label_dims = label.dims();
-
-  int rank = input_dims.size();
-  PADDLE_ENFORCE_EQ(rank,
-                    label_dims.size(),
-                    phi::errors::InvalidArgument(
-                        "Input(X) and Input(Label) shall have the same rank."
-                        "But received: the rank of Input(X) is [%d], "
-                        "the rank of Input(Label) is [%d].",
-                        rank,
-                        label_dims.size()));
-
-  bool check = true;
-
-  if ((!config.is_runtime) &&
-      (phi::product(input_dims) <= 0 || phi::product(label_dims) <= 0)) {
-    check = false;
-  }
-
-  if (check) {
-    PADDLE_ENFORCE_EQ(input_dims,
-                      label_dims,
-                      phi::errors::InvalidArgument(
-                          "Input(X) and Input(Label) shall have the same "
-                          "shape. But received: the shape of Input(X) is "
-                          "[%s], the shape of Input(Label) is [%s].",
-                          input_dims,
-                          label_dims));
-  }
-
-  out->set_dims(input_dims);
-  out->set_dtype(input.dtype());
-  out->share_lod(input);
-}
-
 void TakeAlongAxisInferMeta(const MetaTensor& x,
                             const MetaTensor& index,
                             int axis,
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index a42a9e4c29a52..9709edf63ccc0 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -270,11 +270,6 @@ void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
                                             MetaTensor* out,
                                             MetaConfig config = MetaConfig());

-void SoftMarginLossInferMeta(const MetaTensor& input,
-                             const MetaTensor& label,
-                             MetaTensor* out,
-                             MetaConfig config = MetaConfig());
-
 void TakeAlongAxisInferMeta(const MetaTensor& x,
                             const MetaTensor& index,
                             int axis,
diff --git a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc b/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc
deleted file mode 100644
index 0f7c99450041b..0000000000000
--- a/paddle/phi/kernels/cpu/soft_margin_loss_grad_kernel.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/soft_margin_loss_grad_kernel.h"
-
-#include <algorithm>
-
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void SoftMarginLossGradKernel(const Context& dev_ctx,
-                              const DenseTensor& input,
-                              const DenseTensor& label,
-                              const DenseTensor& out_grad,
-                              DenseTensor* input_grad) {
-  auto dx_data = dev_ctx.template Alloc<T>(input_grad);
-  auto dout_data = out_grad.data<T>();
-  auto x_data = input.data<T>();
-  auto label_data = label.data<T>();
-
-  int x_numel = input.numel();
-
-  // dx = dout * (-label * exp(-label * x)) / (1 + exp(-label * x))
-  for (int i = 0; i < x_numel; ++i) {
-    dx_data[i] = dout_data[i] *
-                 ((-label_data[i] * std::exp(-label_data[i] * x_data[i])) /
-                  (static_cast<T>(1) + std::exp(-label_data[i] * x_data[i])));
-  }
-}
-}  // namespace phi
-
-PD_REGISTER_KERNEL(soft_margin_loss_grad,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::SoftMarginLossGradKernel,
-                   float,
-                   double) {}
diff --git a/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc b/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc
deleted file mode 100644
index 078320eeb3967..0000000000000
--- a/paddle/phi/kernels/cpu/soft_margin_loss_kernel.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/soft_margin_loss_kernel.h"
-
-#include <algorithm>
-
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void SoftMarginLossKernel(const Context& dev_ctx,
-                          const DenseTensor& input,
-                          const DenseTensor& label,
-                          DenseTensor* out) {
-  auto x_data = input.data<T>();
-  auto label_data = label.data<T>();
-  auto out_data = dev_ctx.template Alloc<T>(out);
-  auto x_numel = input.numel();
-
-  // out = ln(1 + exp(-label * x))
-  for (int64_t i = 0; i < x_numel; ++i) {
-    out_data[i] =
-        std::log(static_cast<T>(1) + std::exp(-label_data[i] * x_data[i]));
-  }
-}
-}  // namespace phi
-PD_REGISTER_KERNEL(soft_margin_loss,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::SoftMarginLossKernel,
-                   float,
-                   double) {}
diff --git a/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu
deleted file mode 100644
index f26df06d4b900..0000000000000
--- a/paddle/phi/kernels/gpu/soft_margin_loss_grad_kernel.cu
+++ /dev/null
@@ -1,62 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <algorithm>
-#include <vector>
-
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/core/hostdevice.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/funcs/elementwise_base.h"
-#include "paddle/phi/kernels/soft_margin_loss_grad_kernel.h"
-
-namespace phi {
-
-template <typename T>
-struct SoftMarginLossGradFunctor {
-  T one;
-  T eps;
-
-  HOSTDEVICE inline SoftMarginLossGradFunctor() {
-    one = static_cast<T>(1.0f);
-    eps = static_cast<T>(1e-12);
-  }
-
-  HOSTDEVICE inline T operator()(const T x, const T label, const T dout) const {
-    T term1 = (one + std::exp(-label * x));
-    return (dout * (-label * std::exp(-label * x)) / term1);
-  }
-};
-
-template <typename T, typename Context>
-void SoftMarginLossGradKernel(const Context& dev_ctx,
-                              const DenseTensor& input,
-                              const DenseTensor& label,
-                              const DenseTensor& out_grad,
-                              DenseTensor* input_grad) {
-  dev_ctx.template Alloc<T>(input_grad);
-  std::vector<const DenseTensor*> ins = {&input, &label, &out_grad};
-  std::vector<DenseTensor*> outs = {input_grad};
-  auto functor = SoftMarginLossGradFunctor<T>();
-  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
-}
-
-}  // namespace phi
-
-PD_REGISTER_KERNEL(soft_margin_loss_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::SoftMarginLossGradKernel,
-                   float,
-                   double) {}
diff --git a/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu b/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu
deleted file mode 100644
index 041b045780579..0000000000000
--- a/paddle/phi/kernels/gpu/soft_margin_loss_kernel.cu
+++ /dev/null
@@ -1,58 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <algorithm>
-#include <vector>
-
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/core/hostdevice.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/funcs/elementwise_base.h"
-#include "paddle/phi/kernels/primitive/functor_primitives.h"
-#include "paddle/phi/kernels/soft_margin_loss_kernel.h"
-
-namespace phi {
-
-template <typename T>
-struct SoftMarginLossFunctor {
-  T one;
-
-  HOSTDEVICE inline SoftMarginLossFunctor() { one = static_cast<T>(1.0f); }
-
-  HOSTDEVICE inline T operator()(const T x, const T label) const {
-    T term1 = std::log(one + std::exp(-label * x));
-    return term1;
-  }
-};
-
-template <typename T, typename Context>
-void SoftMarginLossKernel(const Context& dev_ctx,
-                          const DenseTensor& input,
-                          const DenseTensor& label,
-                          DenseTensor* out) {
-  dev_ctx.template Alloc<T>(out);
-  std::vector<const DenseTensor*> ins = {&input, &label};
-  std::vector<DenseTensor*> outs = {out};
-  auto functor = SoftMarginLossFunctor<T>();
-  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
-}
-
-}  // namespace phi
-
-PD_REGISTER_KERNEL(soft_margin_loss,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::SoftMarginLossKernel,
-                   float,
-                   double) {}
diff --git a/paddle/phi/kernels/soft_margin_loss_grad_kernel.h b/paddle/phi/kernels/soft_margin_loss_grad_kernel.h
deleted file mode 100644
index 6008360b59270..0000000000000
--- a/paddle/phi/kernels/soft_margin_loss_grad_kernel.h
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/phi/core/dense_tensor.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void SoftMarginLossGradKernel(const Context& dev_ctx,
-                              const DenseTensor& input,
-                              const DenseTensor& label,
-                              const DenseTensor& out_grad,
-                              DenseTensor* input_grad);
-
-}  // namespace phi
diff --git a/paddle/phi/kernels/soft_margin_loss_kernel.h b/paddle/phi/kernels/soft_margin_loss_kernel.h
deleted file mode 100644
index 171add14f328c..0000000000000
--- a/paddle/phi/kernels/soft_margin_loss_kernel.h
+++ /dev/null
@@ -1,27 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/phi/core/dense_tensor.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void SoftMarginLossKernel(const Context& dev_ctx,
-                          const DenseTensor& input,
-                          const DenseTensor& label,
-                          DenseTensor* out);
-
-}  // namespace phi
diff --git a/paddle/phi/ops/compat/soft_margin_loss_sig.cc b/paddle/phi/ops/compat/soft_margin_loss_sig.cc
deleted file mode 100644
index 7e84f774d8e6f..0000000000000
--- a/paddle/phi/ops/compat/soft_margin_loss_sig.cc
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/core/compat/op_utils.h"
-
-namespace phi {
-KernelSignature SoftMarginLossGradOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "soft_margin_loss_grad", {"X", "Label", "Out@GRAD"}, {}, {"X@GRAD"});
-}
-}  // namespace phi
-
-PD_REGISTER_ARG_MAPPING_FN(soft_margin_loss_grad,
-                           phi::SoftMarginLossGradOpArgumentMapping);
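With the operator, kernels, infer-meta functions, and signature mapping all deleted, the remaining hunk below rebuilds the same loss purely from existing Paddle ops, so both forward and backward fall out of the framework's autograd. Condensed, the replacement logic amounts to the following (a sketch of the patch's approach, not a verbatim excerpt):

    import paddle

    def composed_soft_margin_loss(input, label, reduction='mean'):
        # same math the deleted C++ kernels implemented: log(1 + exp(-y * x))
        out = paddle.log(1 + paddle.exp(-label * input))
        if reduction == 'sum':
            return paddle.sum(out)
        if reduction == 'mean':
            return paddle.mean(out)
        return out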
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 7de4fcee1f374..88b6eda78d0a3 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -3177,35 +3177,22 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
             "The value of 'reduction' in soft_margin_loss should be 'sum', "
             "'mean' or 'none', but received %s, which is not allowed." %
             reduction)

-    if _non_static_mode():
-        label = _C_ops.cast(label, 'in_dtype', label.dtype, 'out_dtype',
-                            input.dtype)
-        out = _C_ops.soft_margin_loss(input, label)
-        if reduction == 'sum':
-            return _C_ops.reduce_sum(out, "reduce_all", True)
-        elif reduction == 'mean':
-            return _C_ops.mean(out)
-        else:
-            return out

-    fluid.data_feeder.check_variable_and_dtype(input, 'input',
-                                               ['float32', 'float64'],
-                                               'soft_margin_loss')
-    fluid.data_feeder.check_variable_and_dtype(
-        label, 'label', ['int32', 'int64', 'float32', 'float64'],
-        'soft_margin_loss')
+    if not _non_static_mode():
+        fluid.data_feeder.check_variable_and_dtype(input, 'input',
+                                                   ['float32', 'float64'],
+                                                   'soft_margin_loss')
+        fluid.data_feeder.check_variable_and_dtype(
+            label, 'label', ['int32', 'int64', 'float32', 'float64'],
+            'soft_margin_loss')

-    label = fluid.layers.cast(label, input.dtype)
+        label = fluid.layers.cast(label, input.dtype)

-    sub_name = name
-    helper = LayerHelper("soft_margin_loss", name=sub_name)
-    out = helper.create_variable_for_type_inference(dtype=input.dtype)
-    helper.append_op(type='soft_margin_loss',
-                     inputs={
-                         'X': [input],
-                         'Label': [label],
-                     },
-                     outputs={'Out': [out]})
+    if not (input.shape == label.shape):
+        raise ValueError("input's shape must equal "
+                         "label's shape")
+
+    out = paddle.log(1 + paddle.exp(-label * input))

     if reduction == 'sum':
         return paddle.sum(out, name=name)

From cc269a919ef93581f4e890d5e4220208d8aafa0e Mon Sep 17 00:00:00 2001
From: yangguohao <1901212980@pku.edu.cn>
Date: Mon, 20 Jun 2022 09:40:45 +0800
Subject: [PATCH 29/31] 2022-06-20

---
 .../tests/unittests/test_soft_margin_loss.py | 40 -------------------
 python/paddle/nn/functional/loss.py          |  3 +-
 2 files changed, 1 insertion(+), 42 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py
index cb63f0443089c..3cf61c556e72b 100644
--- a/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_soft_margin_loss.py
@@ -15,7 +15,6 @@
 import paddle
 import numpy as np
 import unittest
-from op_test import OpTest


 def test_static_layer(
@@ -174,44 +173,5 @@ def test_SoftMarginLoss_error(self):
         paddle.enable_static()


-def soft_margin_loss(input, label):
-    return np.log(1 + np.exp(-label * input))
-
-
-class TestSoftMarginLossOp(OpTest):
-
-    def setUp(self):
-        self.init_test_case()
-        self.op_type = "soft_margin_loss"
-        input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64")
-        label_np = np.random.randint(0, 2, self.shape).astype("float64")
-        label_np[label_np == 0] = -1
-        output_np = soft_margin_loss(input_np, label_np)
-
-        self.inputs = {'X': input_np, 'Label': label_np}
-        self.outputs = {'Out': output_np}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
-
-    def init_test_case(self):
-        self.shape = [10, 10]
-
-
-class TestSoftMarginLossOpCase1(OpTest):
-
-    def init_test_cast(self):
-        self.shape = [2, 3, 4, 5]
-
-
-class TestSoftMarginLossOpCase2(OpTest):
-
-    def init_test_cast(self):
-        self.shape = [2, 3, 20]
-
-
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 88b6eda78d0a3..67a1a3947a645 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -3186,12 +3186,11 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
             label, 'label', ['int32', 'int64', 'float32', 'float64'],
             'soft_margin_loss')

-        label = fluid.layers.cast(label, input.dtype)
-
     if not (input.shape == label.shape):
         raise ValueError("input's shape must equal "
                          "label's shape")

+    label = fluid.layers.cast(label, input.dtype)
     out = paddle.log(1 + paddle.exp(-label * input))

     if reduction == 'sum':
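One numerical caveat with the composed form, which this patch series does not address: paddle.exp(-label * input) overflows in float32 once -label * input grows past roughly 88. The standard remedy is the softplus identity log(1 + exp(z)) = max(z, 0) + log1p(exp(-|z|)); in NumPy terms (an observation of mine, not a change made by this PR):

    import numpy as np

    def stable_soft_margin_loss(x, y):
        z = -y * x
        # algebraically equal to log(1 + exp(z)), but safe for large |z|
        return np.maximum(z, 0.0) + np.log1p(np.exp(-np.abs(z)))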
From 038f50e501b9127ec1b8d4c64b9728d19459aeed Mon Sep 17 00:00:00 2001
From: Ligoml <39876205+Ligoml@users.noreply.github.com>
Date: Wed, 20 Jul 2022 11:34:41 +0800
Subject: [PATCH 30/31] for

---
 python/paddle/nn/layer/loss.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 3812194b3ab05..3f5126112dd75 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1727,8 +1727,7 @@ class SoftMarginLoss(Layer):
             same as ``input`` , else the shape of output is [1].

     Returns:
-
-        A callable object of SoftMarginLoss
+        A callable object of SoftMarginLoss.

     Examples:
         .. code-block:: python

From 0b0e00a459719fc236e7da74214a6478f7a6a98d Mon Sep 17 00:00:00 2001
From: Ligoml <39876205+Ligoml@users.noreply.github.com>
Date: Wed, 20 Jul 2022 11:35:35 +0800
Subject: [PATCH 31/31] for CI;test=document_fix

---
 python/paddle/nn/layer/loss.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 3f5126112dd75..54ef832d73179 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1695,8 +1695,7 @@ def forward(self, input, positive, negative):

 class SoftMarginLoss(Layer):
     r"""
-    Creates a criterion that measures a two-class
-    soft margin loss between input predictions ``input``
+    Creates a criterion that measures a two-class soft margin loss between input predictions ``input``
     and target labels ``label`` . It can be described as:

     .. math::