From d28b3094dd75bce8df079fce5d1fa2f33654ba56 Mon Sep 17 00:00:00 2001
From: sidgoyal78
Date: Mon, 2 Oct 2017 20:52:31 -0700
Subject: [PATCH 1/3] Add momentum operator

---
 paddle/operators/momentum_op.cc                      | 89 +++++++++++++++++++
 paddle/operators/momentum_op.cu                      | 20 +++++
 paddle/operators/momentum_op.h                       | 53 +++++++++++
 .../v2/framework/tests/test_momentum_op.py           | 35 ++++++++
 4 files changed, 197 insertions(+)
 create mode 100644 paddle/operators/momentum_op.cc
 create mode 100644 paddle/operators/momentum_op.cu
 create mode 100644 paddle/operators/momentum_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_momentum_op.py

diff --git a/paddle/operators/momentum_op.cc b/paddle/operators/momentum_op.cc
new file mode 100644
index 0000000000000..2c6ffd618a6e9
--- /dev/null
+++ b/paddle/operators/momentum_op.cc
@@ -0,0 +1,89 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/momentum_op.h"
+
+namespace paddle {
+namespace operators {
+
+class MomentumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Param"),
+                   "Input(Param) of Momentum should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Grad"),
+                   "Input(Grad) of Momentum should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Velocity"),
+                   "Input(Velocity) of Momentum should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
+                   "Input(LearningRate) of Momentum should not be null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
+                   "Output(ParamOut) of Momentum should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"),
+                   "Output(VelocityOut) of Momentum should not be null.");
+
+    auto param_dim = ctx->GetInputDim("Param");
+    PADDLE_ENFORCE_EQ(
+        param_dim, ctx->GetInputDim("Grad"),
+        "Param and Grad input of MomentumOp should have the same dimension.");
+    PADDLE_ENFORCE_EQ(
+        param_dim, ctx->GetInputDim("Velocity"),
+        "Param and Velocity of MomentumOp should have the same dimension.");
+    PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
+                      "LearningRate should be a scalar.");
+
+    ctx->SetOutputDim("ParamOut", param_dim);
+    ctx->SetOutputDim("VelocityOut", param_dim);
+  }
+};
+
+class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  MomentumOpMaker(framework::OpProto *proto,
+                  framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Param", "Input parameter");
+    AddInput("Grad", "Input gradient");
+    AddInput("Velocity", "Input velocity");
+    AddInput("LearningRate", "Input learning rate");
+
+    AddOutput("ParamOut", "Output parameter");
+    AddOutput("VelocityOut", "Output velocity");
+
+    AddAttr<float>("mu", "Momentum coefficient");
+    AddComment(R"DOC(
+
+Momentum Algorithm (momentum).
+
+velocity_out = mu * velocity - learning_rate * grad
+param_out = param + velocity_out
+
+Ref: Sutskever, Ilya, et al. "On the importance of initialization
+     and momentum in deep learning." ICML 2013;
+     http://jmlr.org/proceedings/papers/v28/sutskever13.pdf
+
+)DOC");
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    momentum, ops::MomentumOpKernel<paddle::platform::CPUPlace, float>);
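
Note: the update rule documented above can be checked by hand. A minimal
NumPy sketch of a single step under this first formulation (illustrative
numbers, not part of the patch):

    import numpy as np

    mu, lr = 0.9, 0.01
    param = np.array([1.0, 2.0])
    grad = np.array([0.5, -0.5])
    velocity = np.zeros(2)

    velocity_out = mu * velocity - lr * grad  # [-0.005,  0.005]
    param_out = param + velocity_out          # [ 0.995,  2.005]

Since the velocity starts at zero, the first step reduces to plain SGD
(param - lr * grad); the momentum term only contributes from the second
step onward.
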
diff --git a/paddle/operators/momentum_op.cu b/paddle/operators/momentum_op.cu
new file mode 100644
index 0000000000000..efc24e795e059
--- /dev/null
+++ b/paddle/operators/momentum_op.cu
@@ -0,0 +1,20 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/momentum_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    momentum, ops::MomentumOpKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/momentum_op.h b/paddle/operators/momentum_op.h
new file mode 100644
index 0000000000000..60ff2b759018d
--- /dev/null
+++ b/paddle/operators/momentum_op.h
@@ -0,0 +1,53 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class MomentumOpKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto param_out = ctx.Output<Tensor>("ParamOut");
+    auto velocity_out = ctx.Output<Tensor>("VelocityOut");
+
+    param_out->mutable_data<T>(ctx.GetPlace());
+    velocity_out->mutable_data<T>(ctx.GetPlace());
+
+    float mu = ctx.Attr<float>("mu");
+
+    auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
+    auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
+    auto v = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Velocity"));
+    float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
+    auto p_out = EigenVector<T>::Flatten(*param_out);
+    auto v_out = EigenVector<T>::Flatten(*velocity_out);
+    auto place = ctx.GetEigenDevice<Place>();
+
+    v_out.device(place) = mu * v - lr * g;
+    p_out.device(place) = p + v_out;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
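
Note: the kernel flattens every tensor to a 1-D Eigen vector, so parameters
of any rank are updated uniformly, and the same templated kernel serves CPU
and GPU by switching the Eigen device. A rough NumPy analogue of Compute
(the helper name is hypothetical, for intuition only):

    import numpy as np

    def momentum_step(param, grad, velocity, lr, mu):
        # Mirror EigenVector<T>::Flatten: update 1-D views, then restore shape.
        p, g, v = param.ravel(), grad.ravel(), velocity.ravel()
        v_out = mu * v - lr * g
        p_out = p + v_out
        return p_out.reshape(param.shape), v_out.reshape(velocity.shape)
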
diff --git a/python/paddle/v2/framework/tests/test_momentum_op.py b/python/paddle/v2/framework/tests/test_momentum_op.py
new file mode 100644
index 0000000000000..cb455bdc9f2b9
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_momentum_op.py
@@ -0,0 +1,35 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestMomentumOp(OpTest):
+    def setUp(self):
+        self.op_type = "momentum"
+
+        param = np.random.random((123, 321)).astype("float32")
+        grad = np.random.random((123, 321)).astype("float32")
+        velocity = np.zeros((123, 321)).astype("float32")
+        learning_rate = np.array([0.001]).astype("float32")
+        mu = 0.0001
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'Velocity': velocity,
+            'LearningRate': learning_rate
+        }
+
+        self.attrs = {'mu': mu}
+
+        velocity_out = mu * velocity - learning_rate * grad
+        param_out = param + velocity_out
+
+        self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == "__main__":
+    unittest.main()

From c10da26cf5626c6b34ab31b864fc49ea9c6e725c Mon Sep 17 00:00:00 2001
From: sidgoyal78
Date: Thu, 5 Oct 2017 14:54:33 -0700
Subject: [PATCH 2/3] Modify implementation

---
 paddle/operators/momentum_op.cc                      | 35 +++++++++++--------
 paddle/operators/momentum_op.h                       | 12 +++----
 .../v2/framework/tests/test_momentum_op.py           |  4 +--
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/paddle/operators/momentum_op.cc b/paddle/operators/momentum_op.cc
index 2c6ffd618a6e9..efa0b59992fe6 100644
--- a/paddle/operators/momentum_op.cc
+++ b/paddle/operators/momentum_op.cc
@@ -57,25 +57,30 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
   MomentumOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("Param", "Input parameter");
-    AddInput("Grad", "Input gradient");
-    AddInput("Velocity", "Input velocity");
-    AddInput("LearningRate", "Input learning rate");
-
-    AddOutput("ParamOut", "Output parameter");
-    AddOutput("VelocityOut", "Output velocity");
-
-    AddAttr<float>("mu", "Momentum coefficient");
+    AddInput("Param",
+             "(Tensor, default Tensor<float>) "
+             "Input parameter that has to be updated");
+    AddInput("Grad",
+             "(Tensor, default Tensor<float>) "
+             "Input gradient of the parameter");
+    AddInput("Velocity",
+             "(Tensor, default Tensor<float>) "
+             "Input velocity (corresponding to the parameter) "
+             "that has to be updated");
+    AddInput("LearningRate",
+             "(Tensor, default Tensor<float>) "
+             "Input learning rate");
+
+    AddOutput("ParamOut", "(Tensor) Output updated parameter");
+    AddOutput("VelocityOut", "(Tensor) Output updated velocity");
+
+    AddAttr<float>("mu", "(float) Momentum coefficient");
     AddComment(R"DOC(
 
 Momentum Algorithm (momentum).
 
-velocity_out = mu * velocity - learning_rate * grad
-param_out = param + velocity_out
-
-Ref: Sutskever, Ilya, et al. "On the importance of initialization
-     and momentum in deep learning." ICML 2013;
-     http://jmlr.org/proceedings/papers/v28/sutskever13.pdf
+velocity = mu * velocity + gradient
+param = param - learning_rate * velocity
 
 )DOC");
   }
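
Note: this commit switches the documented rule from the Sutskever-style form
of PATCH 1/3, where the velocity carries the whole step including the
learning rate, to the more common form where the velocity accumulates raw
gradients. For a constant learning rate and zero initial velocity the two
produce identical parameter trajectories; a quick NumPy check (illustrative
only, not part of the patch):

    import numpy as np

    mu, lr = 0.9, 0.1
    p1 = p2 = np.array([1.0, -2.0])
    v1 = v2 = np.zeros(2)

    for g in (np.array([0.5, -0.5]), np.array([0.1, 0.2]), np.array([-0.3, 0.7])):
        v1 = mu * v1 - lr * g   # PATCH 1/3: step lives in the velocity
        p1 = p1 + v1
        v2 = mu * v2 + g        # PATCH 2/3: velocity accumulates gradients
        p2 = p2 - lr * v2

    assert np.allclose(p1, p2)        # same trajectory when lr is constant
    assert np.allclose(v1, -lr * v2)  # velocities differ by a factor of -lr
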
updated"); + AddInput("LearningRate", + "(Tensor, default Tensor) " + "Input learning rate"); + + AddOutput("ParamOut", "(Tensor) Output updated parameter"); + AddOutput("VelocityOut", "(Tensor) Output updated velocity"); + + AddAttr("mu", "(float) Momentum coefficient"); AddComment(R"DOC( Momentum Algorithm (momentum). -velocity_out = mu * velocity - learning_rate * grad -param_out = param + velocity_out - -Ref: Sutskever, Ilya, et al. "On the importance of initialization - and momentum in deep learning." ICML 2013; - http://jmlr.org/proceedings/papers/v28/sutskever13.pdf +velocity = mu * velocity + gradient +param = param - learning_rate * velocity )DOC"); } diff --git a/paddle/operators/momentum_op.h b/paddle/operators/momentum_op.h index 60ff2b759018d..fa3788a8abd14 100644 --- a/paddle/operators/momentum_op.h +++ b/paddle/operators/momentum_op.h @@ -36,16 +36,16 @@ class MomentumOpKernel : public framework::OpKernel { float mu = ctx.Attr("mu"); - auto p = EigenVector::Flatten(*ctx.Input("Param")); - auto g = EigenVector::Flatten(*ctx.Input("Grad")); - auto v = EigenVector::Flatten(*ctx.Input("Velocity")); - float lr = ctx.Input("LearningRate")->data()[0]; + auto param = EigenVector::Flatten(*ctx.Input("Param")); + auto grad = EigenVector::Flatten(*ctx.Input("Grad")); + auto velocity = EigenVector::Flatten(*ctx.Input("Velocity")); + float learning_rate = ctx.Input("LearningRate")->data()[0]; auto p_out = EigenVector::Flatten(*param_out); auto v_out = EigenVector::Flatten(*velocity_out); auto place = ctx.GetEigenDevice(); - v_out.device(place) = mu * v - lr * g; - p_out.device(place) = p + v_out; + v_out.device(place) = velocity * mu + grad; + p_out.device(place) = param - learning_rate * v_out; } }; diff --git a/python/paddle/v2/framework/tests/test_momentum_op.py b/python/paddle/v2/framework/tests/test_momentum_op.py index cb455bdc9f2b9..d3353ff6e4f4d 100644 --- a/python/paddle/v2/framework/tests/test_momentum_op.py +++ b/python/paddle/v2/framework/tests/test_momentum_op.py @@ -22,8 +22,8 @@ def setUp(self): self.attrs = {'mu': mu} - velocity_out = mu * velocity - learning_rate * grad - param_out = param + velocity_out + velocity_out = mu * velocity + grad + param_out = param - learning_rate * velocity_out self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} From db77937ea4985dfc6404cc120457d21774fcd3ed Mon Sep 17 00:00:00 2001 From: sidgoyal78 Date: Thu, 5 Oct 2017 16:11:45 -0700 Subject: [PATCH 3/3] Fix learning_rate usage for momentum --- paddle/operators/momentum_op.h | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/paddle/operators/momentum_op.h b/paddle/operators/momentum_op.h index fa3788a8abd14..f7a724f048782 100644 --- a/paddle/operators/momentum_op.h +++ b/paddle/operators/momentum_op.h @@ -19,33 +19,35 @@ limitations under the License. 
diff --git a/paddle/operators/momentum_op.h b/paddle/operators/momentum_op.h
index fa3788a8abd14..f7a724f048782 100644
--- a/paddle/operators/momentum_op.h
+++ b/paddle/operators/momentum_op.h
@@ -19,33 +19,35 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-
 template <typename Place, typename T>
 class MomentumOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto param_out = ctx.Output<Tensor>("ParamOut");
-    auto velocity_out = ctx.Output<Tensor>("VelocityOut");
+    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
+    auto param = ctx.Input<framework::Tensor>("Param");
+    auto velocity = ctx.Input<framework::Tensor>("Velocity");
+    auto grad = ctx.Input<framework::Tensor>("Grad");
+    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
 
     param_out->mutable_data<T>(ctx.GetPlace());
     velocity_out->mutable_data<T>(ctx.GetPlace());
 
     float mu = ctx.Attr<float>("mu");
 
-    auto param = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
-    auto grad = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
-    auto velocity = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Velocity"));
-    float learning_rate = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
-    auto p_out = EigenVector<T>::Flatten(*param_out);
-    auto v_out = EigenVector<T>::Flatten(*velocity_out);
+    auto p_out = framework::EigenVector<T>::Flatten(*param_out);
+    auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
+
+    auto p = framework::EigenVector<T>::Flatten(*param);
+    auto v = framework::EigenVector<T>::Flatten(*velocity);
+    auto g = framework::EigenVector<T>::Flatten(*grad);
+    auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
+
     auto place = ctx.GetEigenDevice<Place>();
 
-    v_out.device(place) = velocity * mu + grad;
-    p_out.device(place) = param - learning_rate * v_out;
+    Eigen::DSizes<int, 1> grad_dsize(grad->numel());
+    v_out.device(place) = v * mu + g;
+    p_out.device(place) = p - lr.broadcast(grad_dsize) * v_out;
   }
 };
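
Note: the fix above keeps the learning rate inside the Eigen expression by
broadcasting the 1-element LearningRate tensor to the gradient's size,
rather than reading it with data<float>()[0], which dereferences the buffer
on the host and would presumably not be safe once the tensor lives in GPU
memory. A NumPy analogue of the final kernel (the helper and the explicit
broadcast are illustrative, not the operator's API):

    import numpy as np

    def momentum_step(param, grad, velocity, learning_rate, mu):
        # learning_rate is a 1-element array, like the LearningRate tensor;
        # broadcast_to mirrors lr.broadcast(grad_dsize) in the Eigen code.
        lr = np.broadcast_to(learning_rate, grad.shape)
        velocity_out = mu * velocity + grad
        param_out = param - lr * velocity_out
        return param_out, velocity_out

    p = np.ones(3, dtype=np.float32)
    v = np.zeros(3, dtype=np.float32)
    g = np.array([0.5, -1.0, 2.0], dtype=np.float32)
    p_new, v_new = momentum_step(p, g, v, np.array([0.001], np.float32), mu=0.9)
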