
[opt] Add regularization and Nesterov for merged_momentum op #37527

Merged
merged 7 commits on Nov 30, 2021

Changes from all commits
15 changes: 14 additions & 1 deletion paddle/fluid/operators/optimizers/merged_momentum_op.cc
@@ -50,7 +50,8 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDuplicable();
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
"Input learning rate")
.AsDuplicable();
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDispensable()
.AsDuplicable();
@@ -68,6 +69,18 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable()
.AsDuplicable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum or not.")
.SetDefault(false);
AddAttr<std::vector<std::string>>(
"regularization_method",
"(vector<string>) regularization_method; right now only "
"l2_decay or none is supported")
.SetDefault({});
AddAttr<std::vector<float>>("regularization_coeff",
"(vector<float>) regularization_coeff")
.SetDefault({});
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
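The new attributes mirror the single-tensor momentum op: `use_nesterov` toggles the look-ahead update, and the `regularization_method`/`regularization_coeff` vectors fold an optional L2 penalty into each gradient. Below is a minimal scalar sketch of the update rule these attributes select, based on the semantics of `DenseMomentumFunctor` in momentum_op.h; the free function and its names are our illustration, not code from this PR.

```cpp
#include <cstddef>
#include <string>

// Scalar sketch of the merged-momentum update for one parameter tensor.
void MomentumUpdate(float* param, const float* grad, float* velocity,
                    std::size_t numel, float lr, float mu, bool use_nesterov,
                    const std::string& regularization_method,
                    float regularization_coeff) {
  const bool l2_decay = (regularization_method == "l2_decay");
  for (std::size_t i = 0; i < numel; ++i) {
    // L2 decay folds the weight-decay term into the gradient.
    float g = l2_decay ? grad[i] + regularization_coeff * param[i] : grad[i];
    velocity[i] = mu * velocity[i] + g;
    // Nesterov looks one momentum step ahead; plain momentum does not.
    param[i] -= use_nesterov ? lr * (g + mu * velocity[i]) : lr * velocity[i];
  }
}
```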
174 changes: 148 additions & 26 deletions paddle/fluid/operators/optimizers/merged_momentum_op.h
@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"

@@ -85,33 +86,43 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
PADDLE_ENFORCE_EQ(
n, params_out.size(),
platform::errors::InvalidArgument(
"Output(ParamOut) number must be equal to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(
params[i], params_out[i],
platform::errors::InvalidArgument(
"Input(Param) and Output(ParamOut) must be the same Tensors."));
PADDLE_ENFORCE_EQ(params[i], params_out[i],
platform::errors::InvalidArgument(
"Input(Param) and Output(ParamOut) must be "
"the same Tensors."));
}

auto grads = ctx.MultiInput<framework::Tensor>("Grad");
PADDLE_ENFORCE_EQ(
n, grads.size(),
platform::errors::InvalidArgument(
"Input(Grad) number must be equal to Input(Param) number."));
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grads.size(), n));

auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
PADDLE_ENFORCE_EQ(n, velocitys.size(),
platform::errors::InvalidArgument(
"Input(Velocity) number and Input(Param) number."));
"The size of Input(Velocity) must be equal to "
"Input(Param), but got the size of Input(Velocity) "
"is %d, the size of Input(Param) is %d.",
velocitys.size(), n));

auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
PADDLE_ENFORCE_EQ(
n, velocitys_out.size(),
platform::errors::InvalidArgument("Output(VelocityOut) number must be "
"equal to Input(Param) number."));
platform::errors::InvalidArgument(
"The size of Output(VelocityOut) must be "
"equal to Input(Param), but got the size of Output(VelocityOut) is "
"%d, the size of Input(Param) is %d.",
velocitys_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
platform::errors::InvalidArgument(
@@ -126,12 +137,18 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
if (multi_precision) {
PADDLE_ENFORCE_EQ(
n, master_params.size(),
platform::errors::InvalidArgument("Input(MasterParam) number must be "
"equal to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, master_params_out.size(),
platform::errors::InvalidArgument(
"Output(MasterParamOut) number must be equal to "
"Input(MasterParam) number."));
platform::errors::InvalidArgument(
"The size of Input(MasterParam) must be "
"equal to Input(Param), but got the size of Input(MasterParam) "
"is %d, the size of Input(Param) is %d.",
master_params.size(), n));
PADDLE_ENFORCE_EQ(
n, master_params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(MasterParamOut) must be equal to "
"Input(Param), but got the size of Output(MasterParamOut) "
"is %d, the size of Input(Param) is %d.",
master_params_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i],
platform::errors::InvalidArgument(
@@ -147,20 +164,61 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
master_params_out.clear();
}

auto lr = ctx.Input<framework::Tensor>("LearningRate");
auto mu = ctx.Attr<float>("mu");
auto rescale_grad = ctx.Attr<float>("rescale_grad");
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ(
n, lrs.size(),
platform::errors::InvalidArgument(
"If the size of Input(LearningRate) is not 1, the size of "
"Input(LearningRate) must be "
"equal to Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d.",
lrs.size(), n));
}
auto use_nesterov = ctx.Attr<bool>("use_nesterov");
auto regularization_methods =
ctx.Attr<std::vector<std::string>>("regularization_method");
auto regularization_coeffs =
ctx.Attr<std::vector<float>>("regularization_coeff");
if (regularization_methods.size() != 0) {
PADDLE_ENFORCE_EQ(
n, regularization_methods.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_method) must be equal "
"to Input(Param), but got the size of "
"Attr(regularization_method) is %d, the size of Input(Param) is "
"%d.",
regularization_methods.size(), n));
PADDLE_ENFORCE_EQ(
n, regularization_coeffs.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_coeff) must be equal "
"to Input(Param), but got the size of Attr(regularization_coeff) "
"is %d, the size of Input(Param) is %d.",
regularization_coeffs.size(), n));
}

VLOG(5) << "use_nesterov: " << use_nesterov
<< ", regularization_methods.size(): "
<< regularization_methods.size()
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();

using MPType = typename operators::details::MPTypeTrait<T>::Type;

auto &dev_ctx = ctx.template device_context<DeviceContext>();

if (lrs.size() == 1 && use_nesterov == false &&
regularization_methods.size() == 0) {
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MPType>(mu); \
kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
kernel_params.lr = lr->data<MPType>(); \
kernel_params.lr = lrs[0]->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
@@ -182,14 +240,78 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
}

if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
for (size_t idx = 0; idx < n; idx++) {
RegularizationType regularization_flag =
regularization_methods.size() > 0 &&
regularization_methods[idx] == "l2_decay"
? RegularizationType::kL2DECAY
: RegularizationType::kNONE;

#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
MPType regularization_coeff = static_cast<MPType>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff =
static_cast<MPType>(regularization_coeffs[idx]);
}
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];

const MPType *master_in_data =
multi_precision ? master_params[idx]->data<MPType>() : nullptr;
MPType *master_out_data =
multi_precision ? master_params_out[idx]->data<MPType>() : nullptr;
if (platform::is_cpu_place(ctx.GetPlace())) {
CPUDenseMomentumFunctor<MPType> functor;
functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu,
use_nesterov, regularization_flag, regularization_coeff,
params_out[idx], velocitys_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext &>(ctx.device_context()),
params[idx]->numel());
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MPType, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), grads[idx]->data<T>(), \
velocitys[idx]->data<MPType>(), lr_temp->data<MPType>(), master_in_data, \
mu, rescale_grad, params[idx]->numel(), regularization_coeff, \
params_out[idx]->data<T>(), velocitys_out[idx]->data<MPType>(), \
master_out_data); \
for_range(functor);
if (use_nesterov) {
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
UseNesterov, RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(UseNesterov,
RegularizationType::kNONE);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
}
} else {
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
NoNesterov, RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(NoNesterov,
RegularizationType::kNONE);
VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE.";
}
}
}
}
Collaborator


Seems there is a lot of code duplicated from momentum_op.h. Maybe we can use a common function defined in momentum_op.h?

Contributor Author


I think this code already reuses the DenseMomentumFunctor from momentum_op.h.
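For context, here is a self-contained sketch of the dispatch pattern under discussion: two runtime flags select among four compile-time specializations of a functor like `DenseMomentumFunctor`. The stub names and the `DispatchMomentum` helper are hypothetical illustrations, not code from either file; a dispatcher like this could be defined once in momentum_op.h and shared by both kernels.

```cpp
#include <cstdio>
#include <type_traits>

// Stand-ins for the tag types and enum used by DenseMomentumFunctor.
struct UseNesterov {};
struct NoNesterov {};
enum class RegularizationType { kNONE, kL2DECAY };

// FunctorStub stands in for DenseMomentumFunctor<T, MPType, kReg, Update>.
template <RegularizationType kReg, typename UpdateMethod>
struct FunctorStub {
  void operator()() const {
    std::printf("l2_decay=%d, nesterov=%d\n",
                static_cast<int>(kReg == RegularizationType::kL2DECAY),
                static_cast<int>(std::is_same<UpdateMethod, UseNesterov>::value));
  }
};

// One shared dispatcher would replace the duplicated if/else ladders.
void DispatchMomentum(bool use_nesterov, bool l2_decay) {
  if (use_nesterov && l2_decay)
    FunctorStub<RegularizationType::kL2DECAY, UseNesterov>{}();
  else if (use_nesterov)
    FunctorStub<RegularizationType::kNONE, UseNesterov>{}();
  else if (l2_decay)
    FunctorStub<RegularizationType::kL2DECAY, NoNesterov>{}();
  else
    FunctorStub<RegularizationType::kNONE, NoNesterov>{}();
}

int main() { DispatchMomentum(true, false); }
```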

VLOG(10)
<< "Launch MergedMomentum kernel with multi_lr and regularization.";
}
}
};

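Note how the kernel above keeps the fused multi-tensor fast path only when every parameter shares a single learning rate and neither Nesterov nor regularization is requested; any other configuration falls back to one `DenseMomentumFunctor` launch per parameter. A one-line paraphrase of that branch condition (our own, simplified from the kernel):

```cpp
#include <cstddef>

// Sketch of the branch at the top of MergedMomentumOpKernel::Compute:
// the fused kernel only supports the pre-existing configuration.
bool UseFusedFastPath(std::size_t num_lrs, bool use_nesterov,
                      std::size_t num_regularization_methods) {
  return num_lrs == 1 && !use_nesterov && num_regularization_methods == 0;
}
```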
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/op_function.h
@@ -827,7 +827,7 @@ GetVarBaseListFromArgs(const std::string& op_type, const std::string& arg_name,
bool dispensable = false) {
PyObject* list = PyTuple_GET_ITEM(args, arg_idx);

if (list == nullptr) {
if (list == nullptr || list == Py_None) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensor, but got "
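The pybind change above makes `GetVarBaseListFromArgs` treat an explicit Python `None` the same as a missing argument, which matters for dispensable tensor-list inputs such as `MasterParam` when AMP is off. A minimal sketch of the guard's effect, assuming only the CPython C API (this helper function is ours, not Paddle's):

```cpp
#include <Python.h>

// Treat both "argument absent" and an explicit Python None as "no list
// provided"; dispensable inputs then take the absent path instead of
// failing the "must be list of Tensor" check.
static bool ListArgIsAbsent(PyObject* list) {
  return list == nullptr || list == Py_None;
}
```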