
[amp] dygraph amp support param_group #34899

Merged: 3 commits into PaddlePaddle:develop on Aug 16, 2021

Conversation

@zhiqiu (Contributor) commented Aug 14, 2021

PR types

New features

PR changes

Others

Describe

  • dygraph amp supports param_group:
import paddle

linear_1 = paddle.nn.Linear(10, 10)
linear_2 = paddle.nn.Linear(10, 10)
inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
with paddle.amp.auto_cast():
    out = linear_1(inp)
    out = linear_2(out)
loss = paddle.mean(out)
sgd = paddle.optimizer.SGD(
    learning_rate=0.1,
    parameters=[{
        'params': linear_1.parameters()
    }, {
        'params': linear_2.parameters(),
        'weight_decay': 0.001,
        'learning_rate': 0.1
    }],
    weight_decay=0.01)
scaler = paddle.amp.GradScaler()
scaler.scale(loss).backward()  # scale the loss, then backpropagate
scaler.step(sgd)               # unscale gradients, then update parameters
sgd.clear_grad()
  • add a step method to class GradScaler, to be consistent with class Optimizer, since Optimizer.minimize() does not support param_group (see the comparison sketch below)
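For context, a hedged side-by-side of the two call patterns, reusing the names from the example above; scaler.minimize is the pre-existing GradScaler API (its signature is recalled from memory, not quoted from this PR):

# Either: the pre-existing pattern. GradScaler.minimize unscales and
# updates in one call, but it routes through Optimizer.minimize, which
# (per this PR's description) does not support param_group.
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(sgd, scaled)

# Or: the pattern added by this PR. An explicit step that mirrors
# Optimizer.step and therefore works with param_group-style
# parameter lists.
scaler.scale(loss).backward()
scaler.step(sgd)
sgd.clear_grad()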

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@jerrywgz (Contributor) left a comment

LGTM

optimizer(Optimizer): The optimizer used to update parameters.
Examples:
.. code-block:: python
import paddle
Contributor:

Add a blank line before the import.

Contributor Author:

done, thx


If the scaled gradients of parameters contain NaN or Inf, the parameter update is skipped.
Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.
Args:
Contributor:

Add a blank line before Args.

Contributor Author:

done, thx
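Put together, the docstring quoted above implies a control flow roughly like the following hedged sketch; _unscale and _found_inf are assumed internal names introduced here for illustration, not confirmed Paddle internals:

# Hedged sketch of what GradScaler.step presumably does internally,
# inferred from the docstring quoted above.
def step(self, optimizer):
    # unscale the previously scaled gradients in place, recording
    # whether any of them contain NaN/Inf
    self._unscale(optimizer)
    if self._found_inf:
        # overflow detected: skip this parameter update entirely
        return
    # gradients are finite: perform the normal parameter update
    optimizer.step()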

Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.
Args:
optimizer(Optimizer): The optimizer used to update parameters.
Examples:
Contributor:

Add a blank line before Examples.

Contributor Author:

done, thx

phlrain previously approved these changes Aug 16, 2021
@TCChenlong (Contributor) left a comment

LGTM for API docs

@phlrain self-requested a review August 16, 2021 04:06
@zhiqiu merged commit e29c2d1 into PaddlePaddle:develop Aug 16, 2021