multi precision for lars op and lars optimizer
add unit test for new lars op

update momentum op to add dim for MasterParamOut

add fp16 unit test for lars optimizer
FeixLiu committed Jun 2, 2021
1 parent 47774d9 commit f4a8eeb
Showing 8 changed files with 326 additions and 58 deletions.
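
For reference, the update rule implemented by the new MomentumLarsKernel below can be summarized as follows (symbols follow the kernel arguments; this restatement is not part of the commit itself):

    p_norm   = || p ||                              (norm of the stored parameter)
    g_norm   = || rescale_grad * g ||
    local_lr = lr * lars_coeff * p_norm / (g_norm + lars_weight_decay * p_norm + epsilon)
               (falls back to plain lr when lars_weight_decay, p_norm, or g_norm is zero)
    v_out    = mu * v + local_lr * (rescale_grad * g + lars_weight_decay * param)
    p_out    = param - v_out

When multi_precision is true, param is read from the FP32 MasterParam, the velocity and MasterParamOut are kept in FP32, and only ParamOut is cast back to the FP16 parameter dtype.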
14 changes: 14 additions & 0 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cc
100755 → 100644
@@ -34,13 +34,18 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("LearningRate",
"(LoDTensor, default LoDTensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();

AddOutput("ParamOut",
"(LoDTensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(LoDTensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();

AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
@@ -51,6 +56,15 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("epsilon",
"(float, default 0.0) epsilon to avoid Division by Zero.")
.SetDefault(0.0);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.")
.SetDefault(1.0f);

AddComment(R"DOC(
Lars Momentum Optimizer.
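
For context, a minimal sketch of the op's interface after this change, written as a Python dict (the Param, Grad, and Velocity slots sit above the hunk shown here and are assumed unchanged; the CUDA kernel below reads them under exactly these names):

    lars_momentum_io = {
        "inputs":  ["Param", "Grad", "Velocity", "LearningRate",
                    "MasterParam"],        # MasterParam is dispensable, AMP only
        "outputs": ["ParamOut", "VelocityOut",
                    "MasterParamOut"],     # MasterParamOut is dispensable, AMP only
        "attrs":   ["mu", "lars_coeff", "lars_weight_decay", "epsilon",
                    "multi_precision",     # new, default False
                    "rescale_grad"],       # new, default 1.0, often 1.0/batch_size
    }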
119 changes: 89 additions & 30 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cu
@@ -13,55 +13,105 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void MomentumLarsKernel(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu,
const int64_t num, const T lars_coeff,
const T lars_weight_decay, const T* p_norm,
const T* g_norm, T* p_out, T* v_out,
const T epsilon) {
T lr = learning_rate[0];
T local_lr = learning_rate[0];
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

template <typename T, typename MT>
__global__ void MomentumLarsKernel(
const T* p, const T* g, const MT* v,
const MultiPrecisionType<T>* learning_rate, const MT mu, const int64_t num,
const MT lars_coeff, const MT lars_weight_decay,
const MultiPrecisionType<T>* p_norm, const MultiPrecisionType<T>* g_norm,
T* p_out, MT* v_out, const MT epsilon, const MT* master_p, MT* master_p_out,
const MultiPrecisionType<T> rescale_grad) {
const MT lr = static_cast<MT>(learning_rate[0]);
MT local_lr = lr;
const MT p_n = static_cast<MT>(p_norm[0]);
const MT g_n = static_cast<MT>(g_norm[0]);

if (lars_weight_decay > static_cast<MT>(0) && p_n > static_cast<MT>(0) &&
g_n > static_cast<MT>(0)) {
local_lr =
lr * lars_coeff * p_n / (g_n + lars_weight_decay * p_n + epsilon);
}
CUDA_KERNEL_LOOP(i, num) {
if (lars_weight_decay > 0 && p_norm[0] > 0 && g_norm[0] > 0) {
local_lr = lr * lars_coeff * p_norm[0] /
(g_norm[0] + lars_weight_decay * p_norm[0] + epsilon);
}
MT grad = static_cast<MT>(g[i]) * static_cast<MT>(rescale_grad);
MT param = master_p ? master_p[i] : static_cast<MT>(p[i]);

MT v_new = v[i] * mu + local_lr * (grad + lars_weight_decay * param);
MT p_new = param - v_new;

T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]);
v_out[i] = v_new;
p_out[i] = p[i] - v_new;
p_out[i] = static_cast<T>(p_new);
if (master_p_out) master_p_out[i] = p_new;
}
}

template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
using MPDType = MultiPrecisionType<T>;

public:
void Compute(const framework::ExecutionContext& ctx) const override {
const bool multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
InnerCompute<MPDType>(ctx, multi_precision);
} else {
InnerCompute<T>(ctx, multi_precision);
}
}

private:
template <typename MT>
void InnerCompute(const framework::ExecutionContext& ctx,
const bool multi_precision) const {
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto grad = ctx.Input<framework::LoDTensor>("Grad");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");

const framework::Tensor* master_param = nullptr;
framework::Tensor* master_param_out = nullptr;
if (multi_precision) {
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<framework::Tensor>("MasterParam");
master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
}

const MT* master_p = multi_precision ? master_param->data<MT>() : nullptr;
MT* master_p_out = multi_precision
? master_param_out->mutable_data<MT>(ctx.GetPlace())
: nullptr;

T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
MT* v_out = velocity_out->mutable_data<MT>(ctx.GetPlace());

T mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff");
T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
T epsilon = ctx.Attr<float>("epsilon");
MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
MT lars_weight_decay =
static_cast<MT>(ctx.Attr<float>("lars_weight_decay"));
MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
MPDType rescale_grad =
static_cast<MPDType>(ctx.Attr<float>("rescale_grad"));

auto* p = param->data<T>();
auto* v = velocity->data<T>();
auto* g = grad->data<T>();
auto* lr = learning_rate->data<T>();
auto* v = velocity->data<MT>();
auto* lr = learning_rate->data<MPDType>();

int block = 512;
int grid = (param->numel() + block - 1) / block;
@@ -72,17 +122,24 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
g_norm_t.Resize({1});
auto* p_norm_data = p_norm_t.mutable_data<T>(ctx.GetPlace());
auto* g_norm_data = g_norm_t.mutable_data<T>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
auto* p_norm_data = p_norm_t.mutable_data<MPDType>(ctx.GetPlace());
auto* g_norm_data = g_norm_t.mutable_data<MPDType>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<MPDType>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<MPDType>::From(g_norm_t);

auto* place = ctx.template device_context<DeviceContext>().eigen_device();
ep_norm.device(*place) = eigen_p.square().sum().sqrt();
eg_norm.device(*place) = eigen_g.square().sum().sqrt();
MomentumLarsKernel<<<grid, block, 0, ctx.cuda_device_context().stream()>>>(

// Eigen does not support FP16 l2-norm, so compute both norms in FP32.
ep_norm.device(*place) =
eigen_p.template cast<MPDType>().square().sum().sqrt();
eg_norm.device(*place) =
(eigen_g.template cast<MPDType>() * rescale_grad).square().sum().sqrt();

MomentumLarsKernel<
T, MT><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay,
p_norm_data, g_norm_data, p_out, v_out, epsilon);
p_norm_data, g_norm_data, p_out, v_out, epsilon, master_p, master_p_out,
rescale_grad);
}
};

@@ -93,4 +150,6 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lars_momentum,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>);
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
2 changes: 2 additions & 0 deletions paddle/fluid/operators/optimizers/momentum_op.h
@@ -135,6 +135,8 @@ class MomentumOp : public framework::OperatorWithKernel {

ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
if (ctx->HasOutput("MasterParamOut"))
ctx->SetOutputDim("MasterParamOut", param_dim);
}

framework::OpKernelType GetExpectedKernelType(
@@ -73,7 +73,10 @@ def layer_warp(block_func, input, ch_in, ch_out, count, stride):
return pool


def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
def train(use_pure_fp16=True,
use_nesterov=False,
use_adam=False,
use_lars=False):
classdim = 10
data_shape = [3, 32, 32]
BATCH_SIZE = 32
@@ -102,6 +105,11 @@ def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
epsilon=1e-8,
weight_decay=0.0,
multi_precision=True)
elif use_lars:
optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
learning_rate=0.001,
momentum=0.9,
multi_precision=use_pure_fp16)
else:
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
@@ -169,25 +177,41 @@ def test_resnet_pure_fp16(self):
if not fluid.core.is_compiled_with_cuda():
return

def do_test(use_nesterov=False, use_adam=False):
def do_test(use_nesterov=False, use_adam=False, use_lars=False):
assert not (use_adam and
use_lars), "cannot use adam and lars at the same time"
if use_adam:
suffix = "use Adam"
elif use_lars:
suffix = "use Lars"
else:
suffix = "with Nesterov" if use_nesterov else "without Nesterov"
with self.scope_prog_guard():
print("-----------------FP16 Train {}-----------------".format(
suffix))
train_loss_fp16, test_loss_fp16 = train(
use_pure_fp16=True,
use_nesterov=use_nesterov,
use_adam=use_adam)
if use_lars:
train_loss_fp16, test_loss_fp16 = train(
use_pure_fp16=True,
use_nesterov=use_nesterov,
use_lars=use_lars)
else:
train_loss_fp16, test_loss_fp16 = train(
use_pure_fp16=True,
use_nesterov=use_nesterov,
use_adam=use_adam)
with self.scope_prog_guard():
print("-----------------FP32 Train {}-----------------".format(
suffix))
train_loss_fp32, test_loss_fp32 = train(
use_pure_fp16=False,
use_nesterov=use_nesterov,
use_adam=use_adam)
if use_lars:
train_loss_fp32, test_loss_fp32 = train(
use_pure_fp16=False,
use_nesterov=use_nesterov,
use_lars=use_lars)
else:
train_loss_fp32, test_loss_fp32 = train(
use_pure_fp16=False,
use_nesterov=use_nesterov,
use_adam=use_adam)

self.assertTrue(
np.allclose(
@@ -209,6 +233,7 @@ def do_test(use_nesterov=False, use_adam=False):
do_test(use_nesterov=False)
do_test(use_nesterov=True)
do_test(use_adam=True)
do_test(use_lars=True)

@contextlib.contextmanager
def scope_prog_guard(self):
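
Taken together, the new branch is exercised the same way as the existing Adam and Nesterov paths. A short sketch of the call flow, using the names from the diff above (the allclose tolerance lives in the assertion that is elided here):

    # FP16 vs. FP32 training with the LARS optimizer; both runs build
    # LarsMomentumOptimizer, with multi_precision following use_pure_fp16.
    train_loss_fp16, test_loss_fp16 = train(use_pure_fp16=True, use_lars=True)
    train_loss_fp32, test_loss_fp32 = train(use_pure_fp16=False, use_lars=True)
    # do_test(use_lars=True) then asserts the two loss curves stay close via np.allclose.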