[zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api
ver217 committed Jul 11, 2024
1 parent dd9e1cd commit c068ef0
Showing 7 changed files with 119 additions and 25 deletions.
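The change follows a simple pattern: after the optimizer step, each rank's updated shard is all-gathered asynchronously into a padded buffer that aliases the working parameter, and the handle is only waited on when that parameter is next used. A minimal sketch of the pattern, not the ColossalAI API itself (the process group, shapes, and helper names below are illustrative assumptions):

import torch
import torch.distributed as dist

def launch_allgather(padded_param: torch.Tensor, local_shard: torch.Tensor, pg):
    # Non-blocking all-gather; the padded working param doubles as the output buffer.
    # The returned handle would be stashed on the parameter for later waiting.
    return dist.all_gather_into_tensor(padded_param, local_shard, group=pg, async_op=True)

def wait_before_use(handle) -> None:
    # Called from a pre-forward hook right before the parameter is read.
    if handle is not None:
        handle.wait()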
1 change: 1 addition & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -677,6 +677,7 @@ def __init__(
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
overlap_allgather=False,
)

def sync_dp_grads(self):
50 changes: 46 additions & 4 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -2,6 +2,7 @@
import logging
import os
import warnings
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from types import MethodType
@@ -34,7 +35,10 @@
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):


class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None:
super().__init__(module)
self.dtype = None
if precision == "fp16":
@@ -72,12 +76,25 @@ def __init__(self, module: nn.Module, precision: str) -> None:
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_communication = overlap_communication
if overlap_communication:
self.op_hook = ZeroOpHook()
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
p.__init__(p, requires_grad=True)

def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext()
with ctx:
return super().forward(*args, **kwargs)

def _force_wait_all_gather(self):
for p in self.module.parameters():
wait_all_gather_handle(p)


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
@@ -209,6 +226,7 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s

def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_unsharded_model(model, checkpoint, strict)
model.update_master_params()

@@ -221,16 +239,38 @@ def load_sharded_model(
load_sub_module: bool = True,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()

def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!"
model._force_wait_all_gather()
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)

def save_sharded_model(
self,
model: ModelWrapper,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!"
model._force_wait_all_gather()
return super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)

def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
from peft import PeftModel

assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
@@ -290,6 +330,7 @@ def __init__(
reduce_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
overlap_allgather: bool = False,
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
@@ -316,6 +357,7 @@ def __init__(
cpu_offload=cpu_offload,
master_weights=master_weights,
)
self.overlap_allgather = overlap_allgather
self.lora_enabled = False
self.verbose = verbose

@@ -431,11 +473,11 @@ def configure(
self.add_lora_params_to_optimizer(model, optimizer)

if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)
model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather)

# TODO: Support Galore + ZeRO
zero_stage = self.stage
zero_optim_kwargs = {**self.zero_optim_kwargs}
zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather}
dp_size = dist.get_world_size()

# Replace with the distributed implementation if exists
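Before the next file, a quick usage sketch of the flag added above. Only the overlap_allgather keyword comes from this commit; the launch call, model, and optimizer are placeholders standing in for a real distributed setup:

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # assumes a torchrun-style distributed environment

plugin = LowLevelZeroPlugin(stage=2, precision="bf16", overlap_allgather=True)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)                # placeholder model
optimizer = torch.optim.AdamW(model.parameters())  # placeholder optimizer
model, optimizer, *_ = booster.boost(model, optimizer)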
50 changes: 31 additions & 19 deletions colossalai/zero/low_level/low_level_optim.py
@@ -23,6 +23,7 @@

from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, TensorBucket
from .zero_hook import set_all_gather_handle, wait_all_gather_handle


class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -83,6 +84,7 @@ def __init__(
dp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
overlap_allgather: bool = False,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)

@@ -121,6 +123,7 @@ def __init__(

# communication params
self._overlap_communication = overlap_communication
self._overlap_allgather = overlap_allgather
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype

@@ -145,6 +148,8 @@ def __init__(

# record the padding size of each param
self._padding_map = dict()
# the padded working param is the all-gather buffer; it shares the same memory as the working param
self._working_param_to_padded_working_param = dict()

# mapping working param and master param
self.master_to_working_param = dict()
@@ -245,11 +250,12 @@ def _create_master_param_current_rank(self, param_list):
with torch.no_grad():
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
# reset working params' ptr when no master weights
if self._master_weights == False:
param.data = padding_param[: param.numel()].view(param.shape)
# # reset working params' ptr when no master weights
# if self._master_weights == False:
param.data = padding_param[: param.numel()].view(param.shape)
else:
padding_param = param.data.view(-1)
self._working_param_to_padded_working_param[param] = padding_param

splited_params = padding_param.split(
padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
@@ -258,7 +264,7 @@

# use fp32 when master_weights is True
if self._master_weights is True:
splited_param_current_rank = splited_params.detach().float().to(device)
splited_param_current_rank = splited_params.detach().clone().float().to(device)
else:
splited_param_current_rank = splited_params

@@ -549,22 +555,24 @@ def step(self, closure=None):
working_param = real_working_params[group_id][idx]
param_to_gather = master_param.to(device).to(self._dtype)
pg = self.param_to_pg[working_param]
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
buffer_tensor = torch.empty_like(
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
)
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
padded_working_param = self._working_param_to_padded_working_param[working_param]
if self._overlap_allgather:
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
set_all_gather_handle(working_param, handle)
else:
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)
if not self._overlap_allgather:
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)

def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
@@ -892,3 +900,7 @@ def get_working_grad_by_param_id(self, param_id: int) -> Tensor:
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
grad_store = self.pid_to_grad_store[param_id]
return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)

def _force_wait_all_gather(self):
for param in self._working_param_to_padded_working_param.keys():
wait_all_gather_handle(param)
33 changes: 33 additions & 0 deletions colossalai/zero/low_level/zero_hook.py
@@ -0,0 +1,33 @@
from typing import List

from torch._tensor import Tensor

from colossalai.tensor.param_op_hook import ColoParamOpHook

_ALL_GATHER_HANDLE = "_all_gather_handle"


def wait_all_gather_handle(p):
if hasattr(p, _ALL_GATHER_HANDLE):
handle = getattr(p, _ALL_GATHER_HANDLE)
handle.wait()
delattr(p, _ALL_GATHER_HANDLE)


def set_all_gather_handle(p, handle):
setattr(p, _ALL_GATHER_HANDLE, handle)


class ZeroOpHook(ColoParamOpHook):
def pre_forward(self, params: List[Tensor]) -> None:
for p in params:
wait_all_gather_handle(p)

def post_forward(self, params: List[Tensor]) -> None:
pass

def pre_backward(self, params: List[Tensor]) -> None:
pass

def post_backward(self, params: List[Tensor]) -> None:
pass
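The hook above is what LowLevelZeroModel wraps around its forward pass. A minimal, hedged sketch of the wiring (the module is a placeholder, and this omits the ColoParameter conversion the plugin performs, which is what actually routes parameter accesses through the hook):

import torch
import torch.nn as nn

from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level.zero_hook import ZeroOpHook

module = nn.Linear(16, 16)  # placeholder module
hook = ZeroOpHook()

# pre_forward fires when a (Colo)parameter is first touched, waiting on any pending all-gather.
with ColoParamOpHookManager.use_hooks(hook):
    out = module(torch.randn(2, 16))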
4 changes: 2 additions & 2 deletions examples/language/performance_evaluator.py
@@ -113,13 +113,13 @@ def on_step_start(self, step: int) -> None:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
get_accelerator().synchronize()
# get_accelerator().synchronize()
self.timer.start()

def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
get_accelerator().synchronize()
# get_accelerator().synchronize()
self.timer.end()

batch_size, seq_len = input_ids.shape
4 changes: 4 additions & 0 deletions tests/test_zero/test_low_level/test_grad_acc.py
@@ -64,8 +64,12 @@ def fwd_bwd_func(number, cur_data, check_flag):
zero1_optimizer.step()
zero2_optimizer.step()

zero1_optimizer._force_wait_all_gather()
zero2_optimizer._force_wait_all_gather()

# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert not hasattr(z1p, "_all_gather_handle")
assert torch.equal(z1p.data, z2p.data)


2 changes: 2 additions & 0 deletions tests/test_zero/test_low_level/test_zero1_2.py
@@ -177,6 +177,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
# torch ddp step
torch_optimizer.step()

zero_optimizer._force_wait_all_gather()

# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
loose_close(p, z1p, dtype=dtype)
