From 9da0de576ed5656d864685873b56a60e3577f3f4 Mon Sep 17 00:00:00 2001 From: guguguzi <1075040010@qq.com> Date: Sun, 26 Dec 2021 15:53:40 +0800 Subject: [PATCH 1/4] delete the modification of dygraph --- .../tests/unittests/test_lr_scheduler.py | 11 +++ python/paddle/optimizer/lr.py | 95 ++++++++++++++++++- 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py index 04a0d47e47c86..d62a633c28576 100644 --- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py @@ -205,6 +205,13 @@ def lambda_lr(epoch_num, learning_rate, lr_lambda, verbose=False): return learning_rate * lr_lambda(epoch_num) +def multiplicative_lr(epoch_num, learning_rate, lr_lambda, verbose=False): + latest_lr = learning_rate + for i in range(epoch_num): + latest_lr = latest_lr * lr_lambda(i + 1) + return latest_lr + + def piecewise_lr(epoch_num, boundaries, values, verbose=False): assert len(boundaries) + 1 == len(values) for i in range(len(boundaries)): @@ -519,6 +526,10 @@ def test_scheduler(self): "learning_rate": 0.5, "lr_lambda": lambda x: 0.95**x, "verbose": True + }), (multiplicative_lr, paddle.optimizer.lr.MultiplicativeDecay, { + "learning_rate": 0.5, + "lr_lambda": lambda x: 0.95, + "verbose": True }), (cosine_annealing_lr, paddle.optimizer.lr.CosineAnnealingDecay, { "learning_rate": 0.5, "T_max": 10, diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index be1786696bd92..60d7b00dcbb02 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -30,7 +30,8 @@ 'StepDecay', 'LambdaDecay', 'ReduceOnPlateau', - 'CosineAnnealingDecay' + 'CosineAnnealingDecay', + 'MultiplicativeDecay' ] @@ -1513,3 +1514,95 @@ def get_lr(self): def _get_closed_form_lr(self): return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos( math.pi * self.last_epoch / self.T_max)) / 2 + + +class MultiplicativeDecay(LRScheduler): + """ + Multiply the learning rate of ``optimizer`` by the factor given in function ``lr_lambda`` . + + The algorithm can be described as the code below. + + .. code-block:: text + + learning_rate = 0.5 # init learning_rate + lr_lambda = lambda epoch: 0.95 + + learning_rate = 0.5 # epoch 0, + learning_rate = 0.475 # epoch 1, 0.5*0.95 + learning_rate = 0.45125 # epoch 2, 0.475*0.95 + + Args: + learning_rate (float): The initial learning rate. It is a python float number. + lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the last learning rate by this factor. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``MultiplicativeDecay`` instance to schedule learning rate. + + Examples: + + .. 
code-block:: python + + import paddle + import numpy as np + + # train on default dynamic graph mode + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x:0.95, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters()) + for epoch in range(20): + for batch_id in range(5): + x = paddle.uniform([10, 10]) + out = linear(x) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_gradients() + scheduler.step() # If you update learning rate each step + # scheduler.step() # If you update learning rate each epoch + + # train on static graph mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[None, 4, 5]) + y = paddle.static.data(name='y', shape=[None, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x:0.95, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(5): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=loss.name) + scheduler.step() # If you update learning rate each step + # scheduler.step() # If you update learning rate each epoch + + """ + + def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False): + if not callable(lr_lambda): + raise TypeError( + "The type of 'lr_lambda' in 'MultiplicativeDecay' must be 'function', but received %s." + % type(lr_lambda)) + + self.lr_lambda = lr_lambda + super(MultiplicativeDecay, self).__init__(learning_rate, last_epoch, + verbose) + + def get_lr(self): + if self.last_epoch > 0: + return self.last_lr * self.lr_lambda(self.last_epoch) + else: + return self.last_lr From 0e524d7dc49ae783396d1b3a2e72ecc80d28d5f5 Mon Sep 17 00:00:00 2001 From: guguguzi <1075040010@qq.com> Date: Wed, 5 Jan 2022 15:03:12 +0800 Subject: [PATCH 2/4] CI --- python/paddle/optimizer/lr.py | 163 ++++++++++++++-------------------- 1 file changed, 69 insertions(+), 94 deletions(-) diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 60d7b00dcbb02..5d13680fbb976 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -17,7 +17,7 @@ import warnings from paddle import Tensor -__all__ = [ #noqa +__all__ = [ # noqa 'LRScheduler', 'NoamDecay', 'PiecewiseDecay', @@ -56,9 +56,9 @@ class LRScheduler(object): Examples: Here is an example of a simple ``StepDecay`` implementation. - + .. code-block:: python - + import paddle from paddle.optimizer.lr import LRScheduler @@ -100,7 +100,7 @@ def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False): self.step() def __call__(self): - """ + """ Return lastest computed learning rate on current epoch. """ return self.last_lr @@ -108,7 +108,7 @@ def __call__(self): def step(self, epoch=None): """ - ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` . + ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` . The new learning rate will take effect on next ``optimizer.step`` . 
Args: @@ -192,7 +192,7 @@ def set_state_dict(self, state_dict): def get_lr(self): """ - + For those subclass who overload ``LRScheduler`` (Base Class), User should have a custom implementation of ``get_lr()`` . Otherwise, an ``NotImplementedError`` exception will be thrown. @@ -204,7 +204,7 @@ def get_lr(self): class NoamDecay(LRScheduler): r""" - Applies Noam Decay to the initial learning rate. + Applies Noam Decay to the initial learning rate. The algorithm can be described as following. @@ -212,7 +212,7 @@ class NoamDecay(LRScheduler): new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5}) - Please reference `attention is all you need `_ + Please reference `attention is all you need `_ Args: @@ -313,8 +313,8 @@ class PiecewiseDecay(LRScheduler): learning_rate = 0.1 Args: - boundaries(list|tuple): A list/tuple of steps numbers. The type of element in the list is python int. - values(list|tuple): A list/tuple of learning rate values that will be picked during different epoch boundaries. + boundaries(list|tuple): A list/tuple of steps numbers. The type of element in the list is python int. + values(list|tuple): A list/tuple of learning rate values that will be picked during different epoch boundaries. The type of element in the list is python float. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . @@ -323,7 +323,7 @@ class PiecewiseDecay(LRScheduler): ``PiecewiseDecay`` instance to schedule learning rate. Examples: - + .. code-block:: python import paddle @@ -389,7 +389,7 @@ class NaturalExpDecay(LRScheduler): r""" Applies natural exponential decay to the initial learning rate. - + The algorithm can be described as following: .. math:: @@ -406,7 +406,7 @@ class NaturalExpDecay(LRScheduler): ``NaturalExpDecay`` instance to schedule learning rate. Examples: - + .. code-block:: python import paddle @@ -477,7 +477,7 @@ class InverseTimeDecay(LRScheduler): Args: learning_rate (float): The initial learning rate. It is a python float number. - gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . + gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . It should be less than 1.0. Default: 0.1. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . @@ -486,7 +486,7 @@ class InverseTimeDecay(LRScheduler): ``InverseTimeDecay`` instance to schedule learning rate. Examples: - + .. code-block:: python import paddle @@ -556,7 +556,7 @@ class PolynomialDecay(LRScheduler): .. math:: - decay\_steps & = decay\_steps * math.ceil(\frac{epoch}{decay\_steps}) + decay\_steps & = decay\_steps * math.ceil(\frac{epoch}{decay\_steps}) new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr @@ -564,7 +564,7 @@ class PolynomialDecay(LRScheduler): .. math:: - epoch & = min(epoch, decay\_steps) + epoch & = min(epoch, decay\_steps) new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr @@ -574,7 +574,7 @@ class PolynomialDecay(LRScheduler): decay_steps(int): The decay step size. It determines the decay cycle. 
It must be a positive integer. end_lr(float, optional): The minimum final learning rate. Default: 0.0001. power(float, optional): Power of polynomial. Default: 1.0. - cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decrease + cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decrease to ``end_lr`` . If False, the learning rate is monotone decreasing. Default: False. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . @@ -583,7 +583,7 @@ class PolynomialDecay(LRScheduler): ``PolynomialDecay`` instance to schedule learning rate. Examples: - + .. code-block:: python import paddle @@ -672,21 +672,21 @@ class LinearWarmup(LRScheduler): Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler. For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks `_ - + When epoch < warmup_steps, learning rate is updated as: - + .. math:: - + lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps} - + where start_lr is the initial learning rate, and end_lr is the final learning rate; - + When epoch >= warmup_steps, learning rate is updated as: - + .. math:: - + lr = learning_rate - + where ``learning_rate`` is float or any subclass of ``LRScheduler`` . Args: @@ -701,7 +701,7 @@ class LinearWarmup(LRScheduler): ``LinearWarmup`` instance to schedule learning rate. Examples: - + .. code-block:: python import paddle @@ -812,14 +812,14 @@ class ExponentialDecay(LRScheduler): Update learning rate by `gamma` each epoch. The algorithm can be described as following. - + .. math:: new\_learning\_rate = last\_learning\_rate * gamma Args: learning_rate (float): The initial learning rate. It is a python float number. - gamma (float): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . + gamma (float): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . It should be less than 1.0. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . @@ -828,7 +828,7 @@ class ExponentialDecay(LRScheduler): ``ExponentialDecay`` instance to schedule learning rate. Examples: - + .. code-block:: python import paddle @@ -890,7 +890,7 @@ class MultiStepDecay(LRScheduler): """ Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones. - The algorithm can be described as the code below. + The algorithm can be described as the code below. .. code-block:: text @@ -907,17 +907,17 @@ class MultiStepDecay(LRScheduler): Args: learning_rate (float): The initial learning rate. It is a python float number. milestones (tuple|list): List or tuple of each boundaries. Must be increasing. - gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . + gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . It should be less than 1.0. Default: 0.1. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. 
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . - + Returns: ``MultiStepDecay`` instance to schedule learning rate. Examples: - + .. code-block:: python import paddle @@ -1000,7 +1000,7 @@ class StepDecay(LRScheduler): """ Update the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` number of epoch. - The algorithm can be described as the code below. + The algorithm can be described as the code below. .. code-block:: text @@ -1016,7 +1016,7 @@ class StepDecay(LRScheduler): Args: learning_rate (float): The initial learning rate. It is a python float number. step_size (int): the interval to update. It must be a positive integer. - gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . + gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . It should be less than 1.0. Default: 0.1. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . @@ -1026,7 +1026,7 @@ class StepDecay(LRScheduler): Examples: - + .. code-block:: python import paddle @@ -1103,7 +1103,7 @@ class LambdaDecay(LRScheduler): """ Sets the learning rate of ``optimizer`` by function ``lr_lambda`` . ``lr_lambda`` is funciton which receives ``epoch`` . - The algorithm can be described as the code below. + The algorithm can be described as the code below. .. code-block:: text @@ -1119,12 +1119,12 @@ class LambdaDecay(LRScheduler): lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the initial learning rate by this factor. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . - + Returns: ``LambdaDecay`` instance to schedule learning rate. Examples: - + .. code-block:: python import paddle @@ -1189,37 +1189,37 @@ def get_lr(self): class ReduceOnPlateau(LRScheduler): """ - Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate + Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate by 2 to 10 times once model performance has no longer improvement. - The ``metrics`` is the one which has been pass into ``step`` , it must be 1-D Tensor with shape [1]. When ``metrics`` - stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` . - (Specially, ``mode`` can also be set to ``'max`` , in this case, when ``metrics`` stop ascending for a ``patience`` + The ``metrics`` is the one which has been pass into ``step`` , it must be 1-D Tensor with shape [1]. When ``metrics`` + stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` . + (Specially, ``mode`` can also be set to ``'max`` , in this case, when ``metrics`` stop ascending for a ``patience`` number of epochs, the learning rate will be reduced.) In addition, After each reduction, it will wait a ``cooldown`` number of epochs before resuming above operation. Args: learning_rate (float): The initial learning rate. It is a python float number. 
- mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the - learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` , the learning + mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the + learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` , the learning rate will reduce when ``loss`` stops ascending. Default: ``'min'`` . - factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` . + factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` . It should be less than 1.0. Default: 0.1. - patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced. + patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced. Default: 10. - threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` . + threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` . This make tiny changes of ``loss`` will be ignored. Default: 1e-4. threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss`` - is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum + is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum change of ``loss`` is ``threshold`` . Default: ``'rel'`` . cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0. min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0. - epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon, + epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon, the update is ignored. Default: 1e-8. verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``. - + Returns: ``ReduceOnPlateau`` instance to schedule learning rate. @@ -1332,18 +1332,18 @@ def state_keys(self): def step(self, metrics, epoch=None): """ - step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` . + step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` . The new learning rate will take effect on next epoch. Args: - metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce. + metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce. If it stop descending for a ``patience`` number of epochs, the learning rate will reduce. If it's 'Tensor' or 'numpy.ndarray', its shape must be [1]. epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1. Returns: None - + Examples: Please refer to the example of current LRScheduler. 
""" @@ -1355,8 +1355,9 @@ def step(self, metrics, epoch=None): # loss must be float, numpy.ndarray or 1-D Tensor with shape [1] if isinstance(metrics, (Tensor, numpy.ndarray)): assert len(metrics.shape) == 1 and metrics.shape[0] == 1, "the metrics.shape " \ - "should be (1L,), but the current metrics.shape is {}. Maybe that " \ - "you should call paddle.mean to process it first.".format(metrics.shape) + "should be (1L,), but the current metrics.shape is {}. Maybe that " \ + "you should call paddle.mean to process it first.".format( + metrics.shape) elif not isinstance(metrics, (int, float, numpy.float32, numpy.float64)): raise TypeError( @@ -1400,8 +1401,8 @@ def _is_better(self, current, best): class CosineAnnealingDecay(LRScheduler): r""" - Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to - the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in + Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to + the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in SGDR. The algorithm can be described as following. @@ -1410,15 +1411,15 @@ class CosineAnnealingDecay(LRScheduler): \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), - & T_{cur} \neq (2k+1)T_{max}; + & T_{cur} \neq (2k+1)T_{max}; \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), & T_{cur} = (2k+1)T_{max}. - - It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts `_. + + It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts `_. Note that this only implements the cosine annealing part of SGDR, and not the restarts. - + Args: learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number. T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. It must be a positive integer. @@ -1430,7 +1431,7 @@ class CosineAnnealingDecay(LRScheduler): ``CosineAnnealingDecay`` instance to schedule learning rate. Examples: - + .. code-block:: python import paddle @@ -1520,7 +1521,7 @@ class MultiplicativeDecay(LRScheduler): """ Multiply the learning rate of ``optimizer`` by the factor given in function ``lr_lambda`` . - The algorithm can be described as the code below. + The algorithm can be described as the code below. .. code-block:: text @@ -1536,12 +1537,12 @@ class MultiplicativeDecay(LRScheduler): lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the last learning rate by this factor. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . - + Returns: ``MultiplicativeDecay`` instance to schedule learning rate. Examples: - + .. 
code-block:: python import paddle @@ -1562,32 +1563,6 @@ class MultiplicativeDecay(LRScheduler): scheduler.step() # If you update learning rate each step # scheduler.step() # If you update learning rate each epoch - # train on static graph mode - paddle.enable_static() - main_prog = paddle.static.Program() - start_prog = paddle.static.Program() - with paddle.static.program_guard(main_prog, start_prog): - x = paddle.static.data(name='x', shape=[None, 4, 5]) - y = paddle.static.data(name='y', shape=[None, 4, 5]) - z = paddle.static.nn.fc(x, 100) - loss = paddle.mean(z) - scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x:0.95, verbose=True) - sgd = paddle.optimizer.SGD(learning_rate=scheduler) - sgd.minimize(loss) - - exe = paddle.static.Executor() - exe.run(start_prog) - for epoch in range(20): - for batch_id in range(5): - out = exe.run( - main_prog, - feed={ - 'x': np.random.randn(3, 4, 5).astype('float32'), - 'y': np.random.randn(3, 4, 5).astype('float32') - }, - fetch_list=loss.name) - scheduler.step() # If you update learning rate each step - # scheduler.step() # If you update learning rate each epoch """ From 30206baa234506320e81576d7aa818e37be43225 Mon Sep 17 00:00:00 2001 From: guguguzi <1075040010@qq.com> Date: Wed, 5 Jan 2022 15:23:03 +0800 Subject: [PATCH 3/4] check CI --- python/paddle/optimizer/lr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 5d13680fbb976..6dcb53479e7d0 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1563,7 +1563,6 @@ class MultiplicativeDecay(LRScheduler): scheduler.step() # If you update learning rate each step # scheduler.step() # If you update learning rate each epoch - """ def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False): From ed015d436d078af80c7a223bbcab9e8dcd92196f Mon Sep 17 00:00:00 2001 From: guguguzi <1075040010@qq.com> Date: Fri, 7 Jan 2022 11:32:37 +0800 Subject: [PATCH 4/4] modify the retrun value of get_lr --- python/paddle/optimizer/lr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 6dcb53479e7d0..d4fafba9229b0 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1579,4 +1579,4 @@ def get_lr(self): if self.last_epoch > 0: return self.last_lr * self.lr_lambda(self.last_epoch) else: - return self.last_lr + return self.base_lr