Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Phi] Migrate gelu/log_softmax/prelu op kernel and infershape #40393

Merged
merged 24 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
dcb3348
add gelu
shentanyue Mar 8, 2022
3a6b3a7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
liyancas Mar 8, 2022
5d29f87
fix gelu
shentanyue Mar 9, 2022
aa3e4cf
add log_softmax
shentanyue Mar 10, 2022
ba41947
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
shentanyue Mar 10, 2022
8ac6c62
add prelu kernel and prelu/gelu/logsoftmax infershape
shentanyue Mar 10, 2022
3b524a2
fix
shentanyue Mar 11, 2022
8053ea4
fix
shentanyue Mar 11, 2022
8be57c9
fix
shentanyue Mar 12, 2022
8d309d7
fix
shentanyue Mar 14, 2022
991ce44
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
shentanyue Mar 14, 2022
ddc6d28
fix ci
shentanyue Mar 14, 2022
8f4a1b4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
shentanyue Mar 14, 2022
c62be31
log_softmax rewrite
shentanyue Mar 14, 2022
cea4826
fix
shentanyue Mar 15, 2022
5c60299
fix
shentanyue Mar 15, 2022
5c27111
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
shentanyue Mar 15, 2022
266aa70
fix conflict
shentanyue Mar 15, 2022
ff12a05
fix compile error
shentanyue Mar 15, 2022
f505864
fix comment
shentanyue Mar 16, 2022
e677d26
fix
shentanyue Mar 17, 2022
dbd470a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
shentanyue Mar 17, 2022
acd4196
ci_fix
shentanyue Mar 18, 2022
83869c8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
shentanyue Mar 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ USE_OP(conv2d_transpose);
USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(gelu);
USE_OP_ITSELF(gelu);
USE_OP_DEVICE_KERNEL(gelu, MKLDNN);
PD_DECLARE_ARG_MAPPING_FN(gelu);

namespace paddle {
namespace framework {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <unordered_set>

#include <boost/logic/tribool.hpp>

#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"

Expand All @@ -27,10 +28,11 @@ USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP_ITSELF(leaky_relu);
USE_OP_DEVICE_KERNEL(leaky_relu, MKLDNN);
USE_OP(gelu);
USE_OP_ITSELF(gelu);
USE_OP_ITSELF(relu);
USE_OP_ITSELF(tanh);
USE_OP_DEVICE_KERNEL(tanh, MKLDNN);
PD_DECLARE_ARG_MAPPING_FN(gelu);

namespace paddle {
namespace framework {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
platform::EventRole::kUniqueOp);

reorder_p->execute(astream, *reorder_src_memory_p, *dst_memory);
}

// elementwise_mul & elementwise_div
else {
} else { // elementwise_mul & elementwise_div
platform::BinaryMKLDNNHandler<T> binary_handler(
BINARY_OP, axis, onednn_engine, ctx.GetPlace(), dout, y, dx, 1.0f,
1.0f, 1.0f);
Expand Down Expand Up @@ -253,10 +250,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
} else {
broadcast_src_memory = reorder_src_memory_p;
}
}

// elementwise_mul & elementwise_div
else {
} else { // elementwise_mul & elementwise_div
std::unordered_map<int, dnnl::memory> args;
std::shared_ptr<dnnl::binary> binary_prim;
std::shared_ptr<dnnl::memory> post_op_memory;
Expand Down
32 changes: 9 additions & 23 deletions paddle/fluid/operators/gelu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ limitations under the License. */

#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
Expand All @@ -29,18 +30,6 @@ class GeluOp : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(%s) of GeluOp should not be null.", "X"));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(%s) of GeluOp should not be null.", "Out"));

ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
Expand Down Expand Up @@ -156,13 +145,10 @@ class GeluGradOpMaker : public framework::SingleGradOpMaker<T> {

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(gelu, GeluInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(gelu, ops::GeluOp, ops::GeluOpMaker,
ops::GeluGradOpMaker<paddle::framework::OpDesc>,
ops::GeluGradOpMaker<paddle::imperative::OpBase>);
ops::GeluGradOpMaker<paddle::imperative::OpBase>,
GeluInferShapeFunctor);
REGISTER_OPERATOR(gelu_grad, ops::GeluGradOp);
REGISTER_OP_CPU_KERNEL(
gelu, ops::GeluKernel<paddle::platform::CPUDeviceContext, float>,
ops::GeluKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
gelu_grad, ops::GeluGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GeluGradKernel<paddle::platform::CPUDeviceContext, double>);
Loading