Added fmax and fmin operators #37826

Merged Dec 16, 2021 (16 commits)
Changes from 3 commits
16 changes: 16 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_functor.h
@@ -113,5 +113,21 @@ struct MinFunctor {
  }
};

// FMax: follows std::fmax semantics; if exactly one operand is NaN,
// the non-NaN operand is returned.
template <typename T>
struct FMaxFunctor {
  inline HOSTDEVICE T operator()(const T& a, const T& b) const {
    return std::fmax(a, b);
  }
};

// FMin: follows std::fmin semantics; if exactly one operand is NaN,
// the non-NaN operand is returned.
template <typename T>
struct FMinFunctor {
  inline HOSTDEVICE T operator()(const T& a, const T& b) const {
    return std::fmin(a, b);
  }
};

} // namespace operators
} // namespace paddle
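
For context, the NaN behavior these functors inherit from std::fmax/std::fmin can be sketched in Python with NumPy, whose np.fmax/np.fmin follow the same C99 rules (illustration only, not part of this diff):

import numpy as np

x = np.array([1.0, np.nan, np.nan])
y = np.array([np.nan, 2.0, np.nan])

# A single NaN operand is ignored; only NaN-vs-NaN yields NaN.
print(np.fmax(x, y))  # [ 1.  2. nan]
print(np.fmin(x, y))  # [ 1.  2. nan]
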
46 changes: 46 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_max_op.cc
@@ -70,6 +70,23 @@ class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
  }
};

template <typename T>
class ElementwiseFMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("elementwise_fmax_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("Y", this->Input("Y"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
    op->SetAttrMap(this->Attrs());
  }
};

} // namespace operators
} // namespace paddle

@@ -103,3 +120,32 @@ REGISTER_OP_VERSION(elementwise_max)
        "In order to support the function of scaling the input Y when "
        "using the operator of elementwise_max.",
        1.0f));

REGISTER_OPERATOR(elementwise_fmax, ops::ElementwiseOp,
                  ops::ElementwiseMaxOpMaker, ops::ElementwiseOpInferVarType,
                  ops::ElementwiseFMaxGradOpMaker<paddle::framework::OpDesc>,
                  ops::ElementwiseFMaxGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(elementwise_fmax_grad, ops::ElementwiseOpGrad);

REGISTER_OP_CPU_KERNEL(
    elementwise_fmax,
    ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, double>,
    ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    elementwise_fmax_grad,
    ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext, int64_t>);

REGISTER_OP_VERSION(elementwise_fmax)
    .AddCheckpoint(
        R"ROC(Register elementwise_fmax for adding the attribute of Scale_y)ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "Scale_y",
            "In order to support the function of scaling the input Y when "
            "using the operator of elementwise_fmax.",
            1.0f));

Review comment: This block is unnecessary; REGISTER_OP_VERSION is only for compatibility upgrades of an existing op.
13 changes: 13 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_max_op.cu
@@ -56,3 +56,16 @@ REGISTER_OP_CUDA_KERNEL(
    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
                                  int64_t>);

REGISTER_OP_CUDA_KERNEL(
    elementwise_fmax,
    ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    elementwise_fmax_grad,
    ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext,
                                   int64_t>);
49 changes: 49 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_max_op.h
@@ -36,6 +36,22 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
  }
};

template <typename DeviceContext, typename T>
class ElementwiseFMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");
    auto* z = ctx.Output<framework::LoDTensor>("Out");

    z->mutable_data<T>(ctx.GetPlace());
    int axis = ctx.Attr<int>("axis");
    ElementwiseComputeEx<FMaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
                                                           FMaxFunctor<T>(), z);
  }
};

template <typename T>
struct MaxGradDx {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
@@ -68,5 +84,38 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
        ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
  }
};

// dX gets dout wherever fmax selected x: x >= y, or y is NaN.
template <typename T>
struct FMaxGradDx {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
    return dout * static_cast<T>((x >= y) || isnan(y));
  }
};

// dY gets the complement, so dX + dY == dOut elementwise.
template <typename T>
struct FMaxGradDy {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
    return dout * static_cast<T>(!((x >= y) || isnan(y)));
  }
};

template <typename DeviceContext, typename T>
class ElementwiseFMaxGradKernel : public ElemwiseGradKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    ElemwiseGradKernel<T>::Compute(ctx);
    using Tensor = framework::Tensor;

    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    auto* out = dout;  // Fake out; the grad functors do not use it.
    int axis = ctx.Attr<int>("axis");
    ElemwiseGradCompute<DeviceContext, T, FMaxGradDx<T>, FMaxGradDy<T>>(
        ctx, *x, *y, *out, *dout, axis, dx, dy, FMaxGradDx<T>(),
        FMaxGradDy<T>());
  }
};
} // namespace operators
} // namespace paddle
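
A reference sketch of the gradient routing implemented by FMaxGradDx/FMaxGradDy above, in NumPy (fmax_grad is a hypothetical helper; assumes equal shapes and no broadcasting; FMinGradDx/FMinGradDy below mirror it with x <= y):

import numpy as np

def fmax_grad(x, y, dout):
    # dX receives dout wherever fmax picked x (x >= y, or y is NaN);
    # dY receives the complement, so dX + dY == dout elementwise.
    mask = (x >= y) | np.isnan(y)
    return dout * mask, dout * ~mask

x = np.array([1.0, 3.0, np.nan])
y = np.array([2.0, np.nan, 5.0])
dx, dy = fmax_grad(x, y, np.ones_like(x))
print(dx)  # [0. 1. 0.]
print(dy)  # [1. 0. 1.]
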
46 changes: 46 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_min_op.cc
@@ -70,6 +70,23 @@ class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> {
  }
};

template <typename T>
class ElementwiseFMinGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("elementwise_fmin_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("Y", this->Input("Y"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
    op->SetAttrMap(this->Attrs());
  }
};

} // namespace operators
} // namespace paddle

@@ -103,3 +120,32 @@ REGISTER_OP_VERSION(elementwise_min)
        "In order to support the function of scaling the input Y when "
        "using the operator of elementwise_min.",
        1.0f));

REGISTER_OPERATOR(elementwise_fmin, ops::ElementwiseOp,
                  ops::ElementwiseMinOpMaker, ops::ElementwiseOpInferVarType,
                  ops::ElementwiseFMinGradOpMaker<paddle::framework::OpDesc>,
                  ops::ElementwiseFMinGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(elementwise_fmin_grad, ops::ElementwiseOpGrad);

REGISTER_OP_CPU_KERNEL(
    elementwise_fmin,
    ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, double>,
    ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    elementwise_fmin_grad,
    ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, int64_t>);

REGISTER_OP_VERSION(elementwise_fmin)
    .AddCheckpoint(
        R"ROC(Register elementwise_fmin for adding the attribute of Scale_y)ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "Scale_y",
            "In order to support the function of scaling the input Y when "
            "using the operator of elementwise_fmin.",
            1.0f));
13 changes: 13 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_min_op.cu
@@ -52,3 +52,16 @@ REGISTER_OP_CUDA_KERNEL(
    ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext,
                                  int64_t>);

REGISTER_OP_CUDA_KERNEL(
    elementwise_fmin,
    ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    elementwise_fmin_grad,
    ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext,
                                   int64_t>);
48 changes: 48 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_min_op.h
@@ -36,6 +36,21 @@ class ElementwiseMinKernel : public framework::OpKernel<T> {
  }
};

template <typename DeviceContext, typename T>
class ElementwiseFMinKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");
    auto* z = ctx.Output<framework::LoDTensor>("Out");

    z->mutable_data<T>(ctx.GetPlace());
    int axis = ctx.Attr<int>("axis");
    ElementwiseComputeEx<FMinFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
                                                           FMinFunctor<T>(), z);
  }
};

template <typename T>
struct MinGradDx {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
@@ -68,5 +83,38 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
        ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
  }
};

// dX gets dout wherever fmin selected x: x <= y, or y is NaN.
template <typename T>
struct FMinGradDx {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
    return dout * static_cast<T>((x <= y) || isnan(y));
  }
};

// dY gets the complement, so dX + dY == dOut elementwise.
template <typename T>
struct FMinGradDy {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
    return dout * static_cast<T>(!((x <= y) || isnan(y)));
  }
};

template <typename DeviceContext, typename T>
class ElementwiseFMinGradKernel : public ElemwiseGradKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    ElemwiseGradKernel<T>::Compute(ctx);
    using Tensor = framework::Tensor;

    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    auto* out = dout;  // Fake out; the grad functors do not use it.
    int axis = ctx.Attr<int>("axis");
    ElemwiseGradCompute<DeviceContext, T, FMinGradDx<T>, FMinGradDy<T>>(
        ctx, *x, *y, *out, *dout, axis, dx, dy, FMinGradDx<T>(),
        FMinGradDy<T>());
  }
};
} // namespace operators
} // namespace paddle
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
@@ -227,6 +227,8 @@
from .tensor.math import deg2rad # noqa: F401
from .tensor.math import diff # noqa: F401
from .tensor.math import angle # noqa: F401
from .tensor.math import fmax # noqa: F401
from .tensor.math import fmin # noqa: F401

from .tensor.random import multinomial # noqa: F401
from .tensor.random import standard_normal # noqa: F401
@@ -437,6 +439,8 @@
    'pow',
    'zeros_like',
    'maximum',
    'fmax',
    'fmin',

Review comment (Contributor): These need to be added to this file's __all__ list.

Reply (Contributor Author): They have already been added.

    'topk',
    'index_select',
    'CPUPlace',
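
With this PR applied, the new Python API can be exercised roughly as follows (a usage sketch; assumes a Paddle build that includes this patch):

import paddle

x = paddle.to_tensor([1.0, float('nan'), float('nan')])
y = paddle.to_tensor([float('nan'), 2.0, float('nan')])

# Unlike paddle.maximum/minimum, fmax/fmin prefer the non-NaN operand.
print(paddle.fmax(x, y))  # [1., 2., nan]
print(paddle.fmin(x, y))  # [1., 2., nan]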