Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[opt] Add regularation and Nesterov for mergerd_momentum op #37527

Merged
merged 7 commits into from
Nov 30, 2021

Conversation

zhangbo9674
Copy link
Contributor

@zhangbo9674 zhangbo9674 commented Nov 24, 2021

PR types

Performance optimization

PR changes

OPs

Describe

增强mergerd_momentum op功能,包括:

  • 由仅支持输入单个lr到支持输入多lrs(数量与输入的参数一致);
  • 添加use_nesterov属性,支持use_nesterov策略的计算;
  • 添加regularization属性,支持regularization计算。

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

"Attr(regularization_coeff) number must be equal "
"to Input(Param) number."));
}
VLOG(1) << use_nesterov << regularization_methods.size()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
VLOG(1) << use_nesterov << regularization_methods.size()
VLOG(5) << "use_nesterov: " << use_nesterov <<", regularization_methods.size(): " << regularization_methods.size()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, tkx!

"to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, regularization_coeffs.size(),
platform::errors::InvalidArgument(
"Attr(regularization_coeff) number must be equal "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The size of Attr(regularization_coeff) must be equal to the size of Input(Param), but got the size of Attr(regularization_coeff) is %d, the size of Input(Param) is %d

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try to make the error message helpful, same for others.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, tks!

@@ -68,6 +69,18 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable()
.AsDuplicable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Use Nesterov Momentum")
"Use Nesterov Momentum or not")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, tkx!

PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
for (size_t idx = 0; idx < n; idx++) {
std::string regularization_method = " ";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::string regularization_method = " ";
std::string regularization_method = "";

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, tkx!


#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
params_out[idx]->data<T>();
velocitys_out[idx]->data<MPType>();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not know what is the purpose to write these 2 lines? Just check whether params_out[idx] and velocitys_out[idx] is properly initialized?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks, this code has been deleted.

if (regularization_methods.size() != 0) {
regularization_method = regularization_methods[idx];
}
RegularizationType regularization_flag{RegularizationType::kNONE};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT. Not required to change. Maybe the following code would be simpler:

RegularizationType regularization_flag = regularization_methods.size() > 0 &&  regularization_methods[idx] == "l2_decay" ? RegularizationType::kL2DECAY : RegularizationType::kNONE.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks, this code has been modified according to the comments.

}
}
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems too many duplicate codes with momentum_op.h. Maybe we can use a common function defined in momentum_op.h?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these codes have reused the DenseMomentumFunctor function in momentum_op.h.

Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@Superjomn Superjomn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhiqiu zhiqiu merged commit c8ffdec into PaddlePaddle:develop Nov 30, 2021
Zjq9409 pushed a commit to Zjq9409/Paddle that referenced this pull request Dec 10, 2021
…ddle#37527)

* add regularation and Nesterov for mergerd_momentum

* refine unittest for use_nesterov attr

* refine op check

* refine code

* fix bug

* refine code of regularization_flag

* delete useless code
@zhangbo9674 zhangbo9674 deleted the dev/merge_momentum branch March 2, 2023 02:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants