From 281891c5afca435113f8843858fd448bd10b67c4 Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Fri, 21 Oct 2022 20:33:47 +0800
Subject: [PATCH] support qat in sharding stage2 (#47169) (#47240)

---
 .../sharding/group_sharded_optimizer_stage2.py            | 4 +++-
 .../fleet/meta_parallel/sharding/group_sharded_stage2.py  | 7 ++++---
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index d366291b6bebf..de5743a226683 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -294,7 +294,9 @@ def dtype_rank_params(self):
         """
         if len(self._dtype_rank_params) == 0:
             # Assign the parameters of each rank according to the type
-            for param in self._local_params:
+            trainable_params = list(
+                filter(lambda x: x.trainable, self._local_params))
+            for param in trainable_params:
                 if param.dtype not in self._dtype_rank_params.keys():
                     self._dtype_rank_params[param.dtype] = [
                         [] for _ in range(self.world_size)
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index babf9391b928d..4caf2d6013a4f 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -103,11 +103,12 @@ def __init__(
         # sharing stage 2 comm overlap flag
         self._reduce_overlap = False
 
-        self._trainable_params = []
         self._grad_reduced = []
         self._trainable_param2rank = {}
         self._trainable_param2align = {}
-        self._trainable_mask = list(map(_trainable, self._all_params))
+        self._trainable_params = list(
+            filter(lambda x: x.trainable, self._all_params))
+        self._trainable_mask = list(map(_trainable, self._trainable_params))
         self._param_grads = []
 
         # Set grad storage size & Display param sizes and model sizes
@@ -488,7 +489,7 @@ def _setup_use_grad_storage(self):
 
     def _detect_train_change(self):
         # Current trainable parameters
-        trainable_mask = list(map(_trainable, self._all_params))
+        trainable_mask = list(map(_trainable, self._trainable_params))
 
         # Whether parameters trainability changed
        trainability_changed = trainable_mask != self._trainable_mask
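
The sketch below is a minimal, self-contained illustration (not Paddle code) of the pattern this patch applies: sharding bookkeeping is built only from trainable parameters, so non-trainable tensors introduced by QAT (for example frozen quantization scales) no longer enter the rank/mask structures. The parameter names and the SimpleNamespace stand-ins are hypothetical.

# Minimal sketch, assuming parameters expose a boolean `trainable` attribute
# as Paddle parameters do.
from types import SimpleNamespace


def _trainable(param):
    # Mirrors the helper used in group_sharded_stage2.py: a parameter counts
    # as trainable when its `trainable` flag is set.
    return param.trainable


# Hypothetical parameter list: two weights plus a frozen QAT scale tensor.
all_params = [
    SimpleNamespace(name="linear_0.w_0", trainable=True),
    SimpleNamespace(name="linear_1.w_0", trainable=True),
    SimpleNamespace(name="linear_0.w_0@scale", trainable=False),
]

# Before the patch: the mask covered every parameter, frozen ones included.
mask_before = list(map(_trainable, all_params))        # [True, True, False]

# After the patch: filter first, then build the mask from trainable params only.
trainable_params = list(filter(lambda x: x.trainable, all_params))
mask_after = list(map(_trainable, trainable_params))   # [True, True]

print([p.name for p in trainable_params])  # ['linear_0.w_0', 'linear_1.w_0']
print(mask_before, mask_after)

Because the frozen scale never reaches the dtype/rank assignment or the trainability mask, a QAT model no longer trips the stage-2 sharding logic that assumes every tracked parameter has a gradient to reduce.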