Skip to content

Commit

Permalink
[AMP] refine paddle.amp.decorate code example (#40159)
Browse files Browse the repository at this point in the history
* refine amp.decorate code example

* refine code
  • Loading branch information
zhangbo9674 committed Mar 7, 2022
1 parent d30d85d commit da3de72
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions python/paddle/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def decorate(models,
import paddle
model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimzier = paddle.optimizer.SGD(parameters=model.parameters())
optimizer = paddle.optimizer.SGD(parameters=model.parameters())
model, optimizer = paddle.amp.decorate(models=model, optimizers=optimzier, level='O2')
model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2')
data = paddle.rand([10, 3, 32, 32])
Expand All @@ -122,7 +122,7 @@ def decorate(models,
model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())
models, optimizers = paddle.amp.decorate(models=[model, model2], optimizers=[optimzier, optimizer2], level='O2')
models, optimizers = paddle.amp.decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')
data = paddle.rand([10, 3, 32, 32])
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/fluid/dygraph/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,9 @@ def amp_decorate(models,
import paddle
model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimzier = paddle.optimizer.SGD(parameters=model.parameters())
optimizer = paddle.optimizer.SGD(parameters=model.parameters())
model, optimizer = paddle.fluid.dygraph.amp_decorate(models=model, optimizers=optimzier, level='O2')
model, optimizer = paddle.fluid.dygraph.amp_decorate(models=model, optimizers=optimizer, level='O2')
data = paddle.rand([10, 3, 32, 32])
Expand All @@ -426,7 +426,7 @@ def amp_decorate(models,
model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())
models, optimizers = paddle.fluid.dygraph.amp_decorate(models=[model, model2], optimizers=[optimzier, optimizer2], level='O2')
models, optimizers = paddle.fluid.dygraph.amp_decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')
data = paddle.rand([10, 3, 32, 32])
Expand Down

0 comments on commit da3de72

Please sign in to comment.