add fp16 training code (PaddlePaddle#4)
* add support for mixed precision training
danleifeng authored and lilong12 committed Dec 27, 2019
1 parent be5ba88 commit 371d8e2
Showing 4 changed files with 933 additions and 59 deletions.
plsc/entry.py (25 changes: 22 additions & 3 deletions)
@@ -99,6 +99,10 @@ def __init__(self):
         self.fs_ugi = None
         self.fs_dir = None
 
+        self.use_fp16 = False
+        self.init_loss_scaling = 1.0
+        self.fp16_user_dict = None
+
         self.val_targets = self.config.val_targets
         self.dataset_dir = self.config.dataset_dir
         self.num_classes = self.config.num_classes
@@ -145,6 +149,17 @@ def set_train_batch_size(self, batch_size):
         self.global_train_batch_size = batch_size * self.num_trainers
         logger.info("Set train batch size to {}.".format(batch_size))
 
+    def set_mixed_precision(self, use_fp16, loss_scaling):
+        """
+        Whether to use mixed precision training.
+        """
+        self.use_fp16 = use_fp16
+        self.init_loss_scaling = loss_scaling
+        self.fp16_user_dict = dict()
+        self.fp16_user_dict['init_loss_scaling'] = self.init_loss_scaling
+        logger.info("Use mixed precision training: {}.".format(use_fp16))
+        logger.info("Set init loss scaling to {}.".format(loss_scaling))
+
     def set_test_batch_size(self, batch_size):
         self.test_batch_size = batch_size
         self.global_test_batch_size = batch_size * self.num_trainers
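
For context, a minimal usage sketch of the new setter, assuming the standard PLSC Entry workflow (dataset, model, and loss configuration omitted); the batch size and loss-scaling value here are illustrative, not defaults:

    from plsc import Entry

    ins = Entry()
    ins.set_train_batch_size(128)
    # Enable mixed precision training; the loss-scaling factor is forwarded
    # to the optimizer through fp16_user_dict.
    ins.set_mixed_precision(True, 1.0)
    ins.train()
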
@@ -293,8 +308,12 @@ def _get_optimizer(self):
 
         if self.loss_type in ["dist_softmax", "dist_arcface"]:
             self.optimizer = DistributedClassificationOptimizer(
-                self.optimizer, global_batch_size)
-
+                self.optimizer, global_batch_size, use_fp16=self.use_fp16,
+                loss_type=self.loss_type,
+                fp16_user_dict=self.fp16_user_dict)
+        elif self.use_fp16:
+            self.optimizer = fluid.contrib.mixed_precision.decorate(
+                optimizer=optimizer, init_loss_scaling=self.init_loss_scaling)
         return self.optimizer
 
     def build_program(self,
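
The new elif branch relies on Paddle's AMP decorator for the non-"dist" loss types. A standalone sketch of that decoration pattern, assuming the Paddle 1.x fluid APIs of this era (the toy network and hyperparameters are illustrative):

    import paddle.fluid as fluid

    # Toy network: one fc layer with softmax cross-entropy loss.
    image = fluid.data(name='image', shape=[None, 784], dtype='float32')
    label = fluid.data(name='label', shape=[None, 1], dtype='int64')
    logits = fluid.layers.fc(input=image, size=10)
    loss = fluid.layers.mean(
        fluid.layers.softmax_with_cross_entropy(logits=logits, label=label))

    optimizer = fluid.optimizer.Momentum(learning_rate=0.1, momentum=0.9)
    # Wrap the optimizer with AMP, mirroring the elif branch above: the loss
    # is multiplied by init_loss_scaling before backward to reduce fp16
    # gradient underflow.
    optimizer = fluid.contrib.mixed_precision.decorate(
        optimizer=optimizer, init_loss_scaling=1.0)
    optimizer.minimize(loss)
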
@@ -358,7 +377,7 @@ def build_program(self,
                     dist_optimizer = self.fleet.distributed_optimizer(
                         optimizer, strategy=self.strategy)
                     dist_optimizer.minimize(loss)
-                    if "dist" in self.loss_type:
+                    if "dist" in self.loss_type or self.use_fp16:
                         optimizer = optimizer._optimizer
                 elif use_parallel_test:
                     emb = fluid.layers.collective._c_allgather(emb,
(Diffs for the other 3 changed files are not shown.)
