
support qat in sharding stage2 (#47169) (#47240)
haohongxiang committed Oct 21, 2022
Parent: d1fedc5 · Commit: 281891c
Showing 2 changed files with 7 additions and 4 deletions.
File 1 of 2:

```diff
@@ -294,7 +294,9 @@ def dtype_rank_params(self):
         """
         if len(self._dtype_rank_params) == 0:
             # Assign the parameters of each rank according to the type
-            for param in self._local_params:
+            trainable_params = list(
+                filter(lambda x: x.trainable, self._local_params))
+            for param in trainable_params:
                 if param.dtype not in self._dtype_rank_params.keys():
                     self._dtype_rank_params[param.dtype] = [
                         [] for _ in range(self.world_size)
```
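With this hunk, `dtype_rank_params` buckets only trainable parameters by dtype before assigning them to ranks. Presumably this is what QAT (quantization-aware training) support needs: QAT mixes non-trainable tensors such as fake-quant scales in among the model parameters, and those should not be sharded across ranks. Below is a framework-free sketch of the filtering effect, not Paddle's actual code: `Param` is a hypothetical stand-in for a paddle Tensor, and round-robin rank choice simplifies the real size-balancing logic.

```python
from dataclasses import dataclass

@dataclass
class Param:
    name: str
    dtype: str
    trainable: bool

world_size = 2
local_params = [
    Param("linear.w", "float32", True),
    Param("out_quant.scale", "float32", False),  # frozen QAT tensor
    Param("linear.b", "float16", True),
]

# Condensed version of the patched loop: non-trainable tensors
# never enter the per-rank buckets.
dtype_rank_params = {}
trainable_params = list(filter(lambda x: x.trainable, local_params))
for i, param in enumerate(trainable_params):
    buckets = dtype_rank_params.setdefault(
        param.dtype, [[] for _ in range(world_size)])
    buckets[i % world_size].append(param)  # real code balances by size

print({dt: [[p.name for p in rank] for rank in ranks]
       for dt, ranks in dtype_rank_params.items()})
# {'float32': [['linear.w'], []], 'float16': [[], ['linear.b']]}
```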
File 2 of 2:

```diff
@@ -103,11 +103,12 @@ def __init__(
         # sharding stage 2 comm overlap flag
         self._reduce_overlap = False
 
-        self._trainable_params = []
         self._grad_reduced = []
         self._trainable_param2rank = {}
         self._trainable_param2align = {}
-        self._trainable_mask = list(map(_trainable, self._all_params))
+        self._trainable_params = list(
+            filter(lambda x: x.trainable, self._all_params))
+        self._trainable_mask = list(map(_trainable, self._trainable_params))
         self._param_grads = []
 
         # Set grad storage size & Display param sizes and model sizes
```
```diff
@@ -488,7 +489,7 @@ def _setup_use_grad_storage(self):
 
     def _detect_train_change(self):
         # Current trainable parameters
-        trainable_mask = list(map(_trainable, self._all_params))
+        trainable_mask = list(map(_trainable, self._trainable_params))
 
         # Whether parameters trainability changed
         trainability_changed = trainable_mask != self._trainable_mask
```
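Taken together with the `__init__` hunk above, `_detect_train_change` now recomputes the trainability mask over the same pre-filtered `self._trainable_params` list that produced the cached `self._trainable_mask`, so the two masks are compared like for like and diverge only when a parameter's trainability actually flips after construction. A hedged sketch of that behavior (`Param` is again a hypothetical stand-in; `_trainable` mirrors the helper named in the diff):

```python
from types import SimpleNamespace as Param

def _trainable(p):
    return p.trainable  # mirrors the helper used in the diff

all_params = [Param(trainable=True), Param(trainable=False)]  # 2nd: QAT tensor
trainable_params = list(filter(lambda x: x.trainable, all_params))
init_mask = list(map(_trainable, trainable_params))  # [True]

# No change is reported merely because frozen QAT tensors exist:
assert list(map(_trainable, trainable_params)) == init_mask

# Freezing a layer after construction is still detected:
trainable_params[0].trainable = False
assert list(map(_trainable, trainable_params)) != init_mask
```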
