diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 9c4dd5f25083d..aaea7427b9512 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -105,7 +105,12 @@ def save_parameter_to_tar(self, f): self.__parameters__.to_tar(f) self.__parameter_updater__.restore() - def train(self, reader, num_passes=1, event_handler=None, feeding=None): + def train(self, + reader, + num_passes=1, + event_handler=None, + feeding=None, + step_size=1): """ Training method. Will train num_passes of input data. @@ -119,6 +124,7 @@ def train(self, reader, num_passes=1, event_handler=None, feeding=None): :param feeding: Feeding is a map of neural network input name and array index that reader returns. :type feeding: dict|list + :param step_size: the step size for parameter updater :return: """ import py_paddle.swig_paddle as api @@ -167,7 +173,8 @@ def train(self, reader, num_passes=1, event_handler=None, feeding=None): batch_id=batch_id, cost=cost, evaluator=batch_evaluator)) - self.__parameter_updater__.finishBatch(cost) + if batch_id >= step_size and batch_id % step_size == 0: + self.__parameter_updater__.finishBatch(cost) batch_evaluator.finish() self.__parameter_updater__.finishPass()