From bae06c8750b870f4f9b9d279ac6c4ba73a9f41c7 Mon Sep 17 00:00:00 2001
From: zhouwei25
Date: Wed, 24 Mar 2021 09:29:44 +0000
Subject: [PATCH] LRScheduler.get_lr should not update lr in LinearWarmup

---
 .../fluid/tests/unittests/test_lr_scheduler.py | 12 ++++++++++++
 python/paddle/optimizer/lr.py                  |  5 ++---
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py
index 8c6383cd6ef52..04a0d47e47c86 100644
--- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py
+++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py
@@ -537,6 +537,18 @@ def test_scheduler(self):
                     self._test_dygraph(python_func, paddle_api, kwarg, place)
         paddle.enable_static()
 
+    def test_linear_warmup(self):
+        natural_lr = paddle.optimizer.lr.NaturalExpDecay(
+            learning_rate=0.5, gamma=0.1)
+        natural_lr_warmup = paddle.optimizer.lr.LinearWarmup(
+            learning_rate=natural_lr, warmup_steps=10, start_lr=0.0, end_lr=0.1)
+        for idx in range(30):
+            if idx >= 10:
+                self.assertEqual(natural_lr_warmup.get_lr(),
+                                 natural_lr.get_lr())
+            natural_lr.step()
+            natural_lr_warmup.step()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py
index 5085911ce927a..484b4fb7246a7 100644
--- a/python/paddle/optimizer/lr.py
+++ b/python/paddle/optimizer/lr.py
@@ -786,9 +786,8 @@ def get_lr(self):
                 self.last_epoch) / float(self.warmup_steps) + self.start_lr
         else:
             if isinstance(self.learning_rate, LRScheduler):
-                lr_value = self.learning_rate()
-                self.learning_rate.step()
-                return lr_value
+                self.learning_rate.step(self.last_epoch - self.warmup_steps)
+                return self.learning_rate()
 
         return self.learning_rate