Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinearWarmup.get_lr should not increase epoch of built-in LRScheduler #31843

Merged
merged 1 commit into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/paddle/fluid/tests/unittests/test_lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,18 @@ def test_scheduler(self):
self._test_dygraph(python_func, paddle_api, kwarg, place)
paddle.enable_static()

def test_linear_warmp(self):
natural_lr = paddle.optimizer.lr.NaturalExpDecay(
learning_rate=0.5, gamma=0.1)
natural_lr_warmup = paddle.optimizer.lr.LinearWarmup(
learning_rate=natural_lr, warmup_steps=10, start_lr=0.0, end_lr=0.1)
for idx in range(30):
if idx >= 10:
self.assertEqual(natural_lr_warmup.get_lr(),
natural_lr.get_lr())
natural_lr.step()
natural_lr_warmup.step()


if __name__ == '__main__':
unittest.main()
5 changes: 2 additions & 3 deletions python/paddle/optimizer/lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,8 @@ def get_lr(self):
self.last_epoch) / float(self.warmup_steps) + self.start_lr
else:
if isinstance(self.learning_rate, LRScheduler):
lr_value = self.learning_rate()
self.learning_rate.step()
return lr_value
self.learning_rate.step(self.last_epoch - self.warmup_steps)
return self.learning_rate()

return self.learning_rate

Expand Down