Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi pricison for lars op and lars optimizer #33280

Merged
merged 2 commits into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,18 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("LearningRate",
"(LoDTensor, default LoDTensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();

AddOutput("ParamOut",
"(LoDTensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(LoDTensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();

AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
Expand All @@ -51,6 +56,15 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("epsilon",
"(float, default 0.0) epsilon to avoid Division by Zero.")
.SetDefault(0.0);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.")
.SetDefault(1.0f);

AddComment(R"DOC(
Lars Momentum Optimizer.
Expand Down
119 changes: 89 additions & 30 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,55 +13,105 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void MomentumLarsKernel(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu,
const int64_t num, const T lars_coeff,
const T lars_weight_decay, const T* p_norm,
const T* g_norm, T* p_out, T* v_out,
const T epsilon) {
T lr = learning_rate[0];
T local_lr = learning_rate[0];
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

template <typename T, typename MT>
__global__ void MomentumLarsKernel(
const T* p, const T* g, const MT* v,
const MultiPrecisionType<T>* learning_rate, const MT mu, const int64_t num,
const MT lars_coeff, const MT lars_weight_decay,
const MultiPrecisionType<T>* p_norm, const MultiPrecisionType<T>* g_norm,
T* p_out, MT* v_out, const MT epsilon, const MT* master_p, MT* master_p_out,
const MultiPrecisionType<T> rescale_grad) {
const MT lr = static_cast<MT>(learning_rate[0]);
MT local_lr = lr;
const MT p_n = static_cast<MT>(p_norm[0]);
const MT g_n = static_cast<MT>(g_norm[0]);

if (lars_weight_decay > static_cast<MT>(0) && p_n > static_cast<MT>(0) &&
g_n > static_cast<MT>(0)) {
local_lr =
lr * lars_coeff * p_n / (g_n + lars_weight_decay * p_n + epsilon);
}
CUDA_KERNEL_LOOP(i, num) {
if (lars_weight_decay > 0 && p_norm[0] > 0 && g_norm[0] > 0) {
local_lr = lr * lars_coeff * p_norm[0] /
(g_norm[0] + lars_weight_decay * p_norm[0] + epsilon);
}
MT grad = static_cast<MT>(g[i]) * static_cast<MT>(rescale_grad);
MT param = master_p ? master_p[i] : static_cast<MT>(p[i]);

MT v_new = v[i] * mu + local_lr * (grad + lars_weight_decay * param);
MT p_new = param - v_new;

T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]);
v_out[i] = v_new;
p_out[i] = p[i] - v_new;
p_out[i] = static_cast<T>(p_new);
if (master_p_out) master_p_out[i] = p_new;
}
}

template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
using MPDType = MultiPrecisionType<T>;

public:
void Compute(const framework::ExecutionContext& ctx) const override {
const bool multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
InnerCompute<MPDType>(ctx, multi_precision);
} else {
InnerCompute<T>(ctx, multi_precision);
}
}

private:
template <typename MT>
void InnerCompute(const framework::ExecutionContext& ctx,
const bool multi_precision) const {
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto grad = ctx.Input<framework::LoDTensor>("Grad");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");

const framework::Tensor* master_param = nullptr;
framework::Tensor* master_param_out = nullptr;
if (multi_precision) {
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<framework::Tensor>("MasterParam");
master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
}

const MT* master_p = multi_precision ? master_param->data<MT>() : nullptr;
MT* master_p_out = multi_precision
? master_param_out->mutable_data<MT>(ctx.GetPlace())
: nullptr;

T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
MT* v_out = velocity_out->mutable_data<MT>(ctx.GetPlace());

T mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff");
T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
T epsilon = ctx.Attr<float>("epsilon");
MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
MT lars_weight_decay =
static_cast<MT>(ctx.Attr<float>("lars_weight_decay"));
MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
MPDType rescale_grad =
static_cast<MPDType>(ctx.Attr<float>("rescale_grad"));

auto* p = param->data<T>();
auto* v = velocity->data<T>();
auto* g = grad->data<T>();
auto* lr = learning_rate->data<T>();
auto* v = velocity->data<MT>();
auto* lr = learning_rate->data<MPDType>();

int block = 512;
int grid = (param->numel() + block - 1) / block;
Expand All @@ -72,17 +122,24 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
g_norm_t.Resize({1});
auto* p_norm_data = p_norm_t.mutable_data<T>(ctx.GetPlace());
auto* g_norm_data = g_norm_t.mutable_data<T>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
auto* p_norm_data = p_norm_t.mutable_data<MPDType>(ctx.GetPlace());
auto* g_norm_data = g_norm_t.mutable_data<MPDType>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<MPDType>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<MPDType>::From(g_norm_t);

auto* place = ctx.template device_context<DeviceContext>().eigen_device();
ep_norm.device(*place) = eigen_p.square().sum().sqrt();
eg_norm.device(*place) = eigen_g.square().sum().sqrt();
MomentumLarsKernel<<<grid, block, 0, ctx.cuda_device_context().stream()>>>(

// eigen unsupport fp16 l2-norm
ep_norm.device(*place) =
eigen_p.template cast<MPDType>().square().sum().sqrt();
eg_norm.device(*place) =
(eigen_g.template cast<MPDType>() * rescale_grad).square().sum().sqrt();

MomentumLarsKernel<
T, MT><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay,
p_norm_data, g_norm_data, p_out, v_out, epsilon);
p_norm_data, g_norm_data, p_out, v_out, epsilon, master_p, master_p_out,
rescale_grad);
}
};

Expand All @@ -93,4 +150,6 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lars_momentum,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>);
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
3 changes: 3 additions & 0 deletions paddle/fluid/operators/optimizers/momentum_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ class MomentumOp : public framework::OperatorWithKernel {

ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
if (ctx->HasOutput("MasterParamOut")) {
ctx->SetOutputDim("MasterParamOut", param_dim);
FeixLiu marked this conversation as resolved.
Show resolved Hide resolved
}
}

framework::OpKernelType GetExpectedKernelType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def layer_warp(block_func, input, ch_in, ch_out, count, stride):
return pool


def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
classdim = 10
data_shape = [3, 32, 32]
BATCH_SIZE = 32
Expand All @@ -96,12 +96,17 @@ def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
# Test program
test_program = train_program.clone(for_test=True)

if use_adam:
if optimizer == "Adam":
optimizer = paddle.optimizer.AdamW(
learning_rate=0.001,
epsilon=1e-8,
weight_decay=0.0,
multi_precision=True)
elif optimizer == "Lars":
optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
learning_rate=0.001,
momentum=0.9,
multi_precision=use_pure_fp16)
else:
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
Expand Down Expand Up @@ -169,9 +174,11 @@ def test_resnet_pure_fp16(self):
if not fluid.core.is_compiled_with_cuda():
return

def do_test(use_nesterov=False, use_adam=False):
if use_adam:
def do_test(use_nesterov=False, optimizer=""):
if optimizer == "Adam":
suffix = "use Adam"
elif optimizer == "Lars":
suffix = "use Lars"
else:
suffix = "with Nesterov" if use_nesterov else "without Nesterov"
with self.scope_prog_guard():
Expand All @@ -180,14 +187,14 @@ def do_test(use_nesterov=False, use_adam=False):
train_loss_fp16, test_loss_fp16 = train(
use_pure_fp16=True,
use_nesterov=use_nesterov,
use_adam=use_adam)
optimizer=optimizer)
with self.scope_prog_guard():
print("-----------------FP32 Train {}-----------------".format(
suffix))
train_loss_fp32, test_loss_fp32 = train(
use_pure_fp16=False,
use_nesterov=use_nesterov,
use_adam=use_adam)
optimizer=optimizer)

self.assertTrue(
np.allclose(
Expand All @@ -208,7 +215,8 @@ def do_test(use_nesterov=False, use_adam=False):

do_test(use_nesterov=False)
do_test(use_nesterov=True)
do_test(use_adam=True)
do_test(optimizer="Adam")
do_test(optimizer="Lars")

@contextlib.contextmanager
def scope_prog_guard(self):
Expand Down
Loading