[Phi] migrate clip_by_norm to phi
affectionlu committed Jul 20, 2022
1 parent 1047cb1 commit fd8329d
Showing 18 changed files with 429 additions and 203 deletions.
14 changes: 10 additions & 4 deletions paddle/fluid/operators/clip_by_norm_op.cc
@@ -13,11 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/clip_by_norm_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"

 namespace ops = paddle::operators;

+DECLARE_INFER_SHAPE_FUNCTOR(clip_by_norm,
+                            ClipByNormInferShapeFunctor,
+                            PD_INFER_META(phi::ClipByNormInferMeta));
+
 REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm,
                              ops::ClipByNormOp,
-                             ops::ClipByNormOpMaker);
-
-REGISTER_OP_CPU_KERNEL(clip_by_norm,
-                       ops::ClipByNormKernel<phi::CPUContext, float>);
+                             ops::ClipByNormOpMaker,
+                             ClipByNormInferShapeFunctor);
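Note: the REGISTER_OP_CPU_KERNEL call disappears because the CPU kernel now lives in phi (see paddle/phi/kernels/cpu/clip_by_norm_kernel.cc below), while DECLARE_INFER_SHAPE_FUNCTOR adapts the function phi::ClipByNormInferMeta into the functor interface the fluid op registry expects. A minimal sketch of that adapter idea, with all names hypothetical rather than Paddle's actual macro expansion:

// Toy illustration of the infer-shape adapter: wrap a free "InferMeta"
// function into a functor type that a registry could store and invoke.
// All names are hypothetical.
#include <cstdio>

struct MetaTensor {
  int dims;
};

// Analogous role to phi::ClipByNormInferMeta: output metadata mirrors input.
void ToyInferMeta(const MetaTensor& x, MetaTensor* out) { out->dims = x.dims; }

// Conceptually what a DECLARE_INFER_SHAPE_FUNCTOR-style macro produces.
struct ToyInferShapeFunctor {
  void operator()(const MetaTensor& x, MetaTensor* out) const {
    ToyInferMeta(x, out);
  }
};

int main() {
  MetaTensor x{4};
  MetaTensor out{0};
  ToyInferShapeFunctor{}(x, &out);
  std::printf("out.dims = %d\n", out.dims);  // prints 4
  return 0;
}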
122 changes: 0 additions & 122 deletions paddle/fluid/operators/clip_by_norm_op.cu

This file was deleted.

70 changes: 0 additions & 70 deletions paddle/fluid/operators/clip_by_norm_op.h
@@ -30,76 +30,6 @@ template <typename T,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

-template <typename DeviceContext, typename T>
-class ClipByNormKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto max_norm = context.Attr<T>("max_norm");
-    auto in_var = context.InputVar("X");
-
-    Tensor* output = nullptr;
-    const Tensor* input = nullptr;
-    if (in_var->IsType<framework::LoDTensor>()) {
-      input = context.Input<Tensor>("X");
-
-      output = context.Output<Tensor>("Out");
-      output->mutable_data<T>(context.GetPlace());
-    } else if (in_var->IsType<phi::SelectedRows>()) {
-      auto* x = context.Input<phi::SelectedRows>("X");
-
-      // merge ids in selected rows first
-      math::scatter::MergeAdd<DeviceContext, T> merge_func;
-      phi::SelectedRows* merged_input =
-          const_cast<framework::Scope&>(context.scope())
-              .Var()
-              ->GetMutable<phi::SelectedRows>();
-      merge_func(
-          context.template device_context<DeviceContext>(), *x, merged_input);
-      input = &(merged_input->value());
-
-      phi::SelectedRows* output_selected_rows =
-          context.Output<phi::SelectedRows>("Out");
-      output_selected_rows->set_rows(merged_input->rows());
-      output_selected_rows->set_height(merged_input->height());
-      output = output_selected_rows->mutable_value();
-      output->Resize(merged_input->value().dims());
-      output->mutable_data<T>(context.GetPlace());
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Invalid input variable type, only support LodTensor and "
-          "SelectedRows types, but got type is %s.",
-          framework::ToTypeName(in_var->Type())));
-    }
-
-    PADDLE_ENFORCE_NOT_NULL(input,
-                            platform::errors::InvalidArgument(
-                                "Input(X) of ClipByNormOp should not be null. "
-                                "Please check if it is created correctly."));
-
-    auto x = EigenVector<T>::Flatten(*input);
-    auto out = EigenVector<T>::Flatten(*output);
-    auto x_norm = x.square().sum().sqrt();
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-
-    auto temp = (x_norm <= max_norm).template cast<T>();
-    auto epsilon =
-        ((x_norm <= static_cast<T>(1e-30)).all().template cast<T>()) *
-        static_cast<T>(1e-6);
-
-    auto scaling =
-        temp + (static_cast<T>(1) - temp) * max_norm / (x_norm + epsilon);
-    Eigen::array<int, 1> one_dim{{1}};
-    Eigen::DSizes<int, 1> m_dsize(input->numel());
-    if (context.GetPlace() == platform::CPUPlace()) {
-      out.device(place) =
-          x * scaling.reshape(one_dim).eval().broadcast(m_dsize);
-    } else {
-      out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize);
-    }
-  }
-};
-
 class ClipByNormOp : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
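For reference, the deleted kernel computes out = x when ||x||_2 <= max_norm and out = x * max_norm / ||x||_2 otherwise, with an epsilon guard against a vanishing norm; the phi kernel registered below is expected to preserve this behavior. A self-contained sketch of the same scaling rule on a plain float vector (illustrative only, not the phi implementation):

// Standalone restatement of the deleted kernel's scaling logic.
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> ClipByNorm(const std::vector<float>& x, float max_norm) {
  float sum_sq = 0.0f;
  for (float v : x) sum_sq += v * v;
  const float x_norm = std::sqrt(sum_sq);
  // Same guard as the Eigen code above: epsilon is non-zero only when the
  // norm is effectively zero, preventing a division by zero below.
  const float epsilon = (x_norm <= 1e-30f) ? 1e-6f : 0.0f;
  const float scaling =
      (x_norm <= max_norm) ? 1.0f : max_norm / (x_norm + epsilon);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] * scaling;
  return out;
}

int main() {
  const std::vector<float> y = ClipByNorm({3.0f, 4.0f}, 1.0f);  // ||x|| = 5
  std::printf("%.2f %.2f\n", y[0], y[1]);  // 0.60 0.80
  return 0;
}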
37 changes: 32 additions & 5 deletions paddle/fluid/operators/dgc_clip_by_norm_op.h
@@ -15,20 +15,24 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/operators/clip_by_norm_op.h"
+#include "paddle/phi/kernels/clip_by_norm_kernel.h"
+#include "paddle/phi/kernels/selected_rows/clip_by_norm_kernel.h"

 namespace paddle {
 namespace operators {

+using Tensor = framework::Tensor;
+
 template <typename DeviceContext, typename T>
-class DGCClipByNormKernel : public ClipByNormKernel<DeviceContext, T> {
+class DGCClipByNormKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto rampup_begin_step = ctx.Attr<float>("rampup_begin_step");
     if (static_cast<int>(rampup_begin_step) < 0) {
       return;
     }

-    auto current_step_tensor = context.Input<framework::Tensor>("current_step");
+    auto current_step_tensor = ctx.Input<framework::Tensor>("current_step");
     auto* current_step = current_step_tensor->data<T>();

     VLOG(10) << "current_step:" << *current_step
@@ -41,7 +45,30 @@ class DGCClipByNormKernel : public ClipByNormKernel<DeviceContext, T> {
       return;
     }

-    return ClipByNormKernel<DeviceContext, T>::Compute(context);
+    auto in_var = ctx.InputVar("X");
+    auto max_norm = ctx.Attr<float>("max_norm");
+    auto& dev_ctx = ctx.device_context<DeviceContext>();
+
+    if (in_var->IsType<framework::LoDTensor>()) {
+      auto* x = ctx.Input<Tensor>("X");
+      auto* y = ctx.Output<Tensor>("Out");
+      return phi::ClipByNormKernel<T>(
+          static_cast<const typename framework::ConvertToPhiContext<
+              DeviceContext>::TYPE&>(dev_ctx),
+          *x,
+          max_norm,
+          y);
+    } else if (in_var->IsType<phi::SelectedRows>()) {
+      auto* x = ctx.Input<phi::SelectedRows>("X");
+      phi::SelectedRows* output_selected_rows =
+          ctx.Output<phi::SelectedRows>("Out");
+      return phi::sr::ClipByNormKernel<T>(
+          static_cast<const typename framework::ConvertToPhiContext<
+              DeviceContext>::TYPE&>(dev_ctx),
+          *x,
+          max_norm,
+          output_selected_rows);
+    }
   };
 };
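The static_cast through framework::ConvertToPhiContext works because each fluid device context is compatible with its phi counterpart; the trait maps a DeviceContext type to the phi context type the phi kernels accept. An illustrative sketch of that trait pattern with toy types (not Paddle's actual definitions):

// Toy model of the ConvertToPhiContext trait used above: map a "fluid"
// context type to a compatible "phi" context type at compile time.
#include <cstdio>

struct ToyPhiCPUContext {
  const char* name() const { return "phi-style CPU context"; }
};

// The fluid context derives from the phi one, so the cast below is safe.
struct ToyFluidCPUContext : ToyPhiCPUContext {};

template <typename FluidCtx>
struct ToyConvertToPhiContext;  // primary template intentionally undefined

template <>
struct ToyConvertToPhiContext<ToyFluidCPUContext> {
  using TYPE = ToyPhiCPUContext;
};

// A kernel written against the phi-style context, like phi::ClipByNormKernel.
template <typename T>
void ToyKernel(const ToyPhiCPUContext& ctx, T value) {
  std::printf("%s: %f\n", ctx.name(), static_cast<double>(value));
}

int main() {
  ToyFluidCPUContext fluid_ctx;
  // Same call shape as DGCClipByNormKernel::Compute above.
  ToyKernel<float>(
      static_cast<
          const ToyConvertToPhiContext<ToyFluidCPUContext>::TYPE&>(fluid_ctx),
      1.0f);
  return 0;
}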
8 changes: 8 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
@@ -360,6 +360,14 @@
     func : clip
   backward : clip_grad

+- api : clip_by_norm
+  args : (Tensor x, float max_norm)
+  output : Tensor(out)
+  infer_meta :
+    func : ClipByNormInferMeta
+  kernel :
+    func : clip_by_norm
+
 - api : complex
   args : (Tensor x, Tensor y)
   output : Tensor
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -237,6 +237,18 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) {
   out->set_dtype(x.dtype());
 }

+void ClipByNormInferMeta(const MetaTensor& x, float max_norm, MetaTensor* out) {
+  PADDLE_ENFORCE_GT(
+      max_norm,
+      0,
+      phi::errors::InvalidArgument("max_norm should be greater than 0. "
+                                   "Received max_norm is %f.",
+                                   max_norm));
+  out->set_dims(x.dims());
+  out->set_dtype(x.dtype());
+  out->share_lod(x);
+}
+
 void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
   out->set_dims(x.dims());
   out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -60,6 +60,8 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);

 void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);

+void ClipByNormInferMeta(const MetaTensor& x, float max_norm, MetaTensor* out);
+
 void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);

 void CumInferMeta(const MetaTensor& x,
27 changes: 27 additions & 0 deletions paddle/phi/kernels/clip_by_norm_kernel.h
@@ -0,0 +1,27 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void ClipByNormKernel(const Context& dev_ctx,
const DenseTensor& x,
float max_norm,
DenseTensor* out);

} // namespace phi
34 changes: 34 additions & 0 deletions paddle/phi/kernels/cpu/clip_by_norm_kernel.cc
@@ -0,0 +1,34 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/clip_by_norm_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h"

namespace phi {

template <typename T, typename Context>
void ClipByNormKernel(const Context& dev_ctx,
const DenseTensor& in,
float max_norm,
DenseTensor* output) {
return ClipByNormFunctor<T, Context>(dev_ctx, in, max_norm, output);
}

} // namespace phi

PD_REGISTER_KERNEL(
clip_by_norm, CPU, ALL_LAYOUT, phi::ClipByNormKernel, float) {}
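PD_REGISTER_KERNEL instantiates phi::ClipByNormKernel for the listed data types (here only float on CPU) and records it in phi's kernel registry under the name clip_by_norm. A toy sketch of the static-registration idea behind such macros, with all names hypothetical:

// Toy kernel registry: static registration of a named kernel at startup,
// loosely modeling what a PD_REGISTER_KERNEL-style macro arranges.
#include <cstdio>
#include <map>
#include <string>

using ToyKernelFn = void (*)(const float* in, float* out, int n);

std::map<std::string, ToyKernelFn>& Registry() {
  static std::map<std::string, ToyKernelFn> registry;
  return registry;
}

struct ToyRegistrar {
  ToyRegistrar(const std::string& name, ToyKernelFn fn) {
    Registry()[name] = fn;
  }
};

void ToyClipByNorm(const float* in, float* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[i];  // placeholder body only
}

// What the registration macro boils down to: a static object whose
// constructor runs before main and populates the registry.
static ToyRegistrar reg("clip_by_norm", ToyClipByNorm);

int main() {
  float in[2] = {3.0f, 4.0f};
  float out[2];
  Registry()["clip_by_norm"](in, out, 2);
  std::printf("%f %f\n", out[0], out[1]);
  return 0;
}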