Add momentum operator #4571
Conversation
paddle/operators/momentum_op.h (outdated)
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
auto v = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Velocity"));
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
That might not be good for GPU. If the LearningRate is in GPU memory, we cannot read the float value directly on the host.
Okay, thanks.
Fixed as per #4598
LGTM.
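For reference, here is a minimal sketch of the direction this review points in: keep LearningRate as a 1-element tensor and broadcast it inside the Eigen expression, so the scalar is never dereferenced on the host. This sketch uses plain Eigen tensors with illustrative sizes and values; it is not the framework's EigenVector wrapper and not the actual change from #4598.

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // Flat stand-ins for Param, Grad, Velocity, plus a 1-element LearningRate.
  Eigen::Tensor<float, 1> p(4), g(4), v(4), lr(1);
  p.setConstant(1.0f);
  g.setConstant(0.5f);
  v.setConstant(0.0f);
  lr.setConstant(0.1f);
  const float mu = 0.9f;  // momentum coefficient

  // velocity_out = mu * velocity + grad
  Eigen::Tensor<float, 1> v_out = v * mu + g;

  // Broadcast the 1-element learning-rate tensor to the gradient's size so
  // the multiply stays inside the Eigen expression (and can therefore be
  // evaluated on a device), instead of reading data<float>()[0] on the host.
  Eigen::DSizes<Eigen::DenseIndex, 1> bcast(4);
  Eigen::Tensor<float, 1> p_out = p - lr.broadcast(bcast) * v_out;

  std::cout << p_out(0) << "\n";  // 1 - 0.1 * (0.9*0 + 0.5) = 0.95
  return 0;
}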
class MomentumOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
These variables would be clearer if declared with auto* to indicate the real (pointer) type, for example:
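The declaration above would then read (a style-only change, behavior unchanged):

auto* param_out = ctx.Output<framework::Tensor>("ParamOut");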
Thanks for this PR, @sidgoyal78! Since our book chapters heavily depend on these optimizer operators, let's merge this PR ASAP. We can leave unifying the naming style for future work.
This PR adds the implementation of the momentum operator.
In summary, we want to perform the update with a new velocity vector, such that

velocity_out = mu * velocity + grad
param_out = param - learning_rate * velocity_out

(where mu is the momentum coefficient).
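For concreteness, here is a minimal host-side reference of that update rule applied element-wise (a hedged sketch with made-up sizes and values, not the operator kernel itself):

#include <cstdio>
#include <vector>

int main() {
  // Illustrative values; the real operator works on framework tensors.
  std::vector<float> param{1.0f, 2.0f}, grad{0.5f, 0.25f}, velocity{0.0f, 0.0f};
  const float mu = 0.9f;  // momentum coefficient
  const float lr = 0.1f;  // learning rate
  for (size_t i = 0; i < param.size(); ++i) {
    velocity[i] = mu * velocity[i] + grad[i];  // velocity_out = mu * velocity + grad
    param[i] -= lr * velocity[i];              // param_out = param - lr * velocity_out
  }
  std::printf("%f %f\n", param[0], param[1]);  // 0.950000 1.975000
  return 0;
}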