Stop using global_ring in hybrid parallelism, test=allcase
wangxicoding committed Aug 1, 2021
1 parent 2a8b4b2 commit ee5a94b
Showing 4 changed files with 101 additions and 123 deletions.
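This commit drops the single global communication ring from the hybrid-parallel sharding helpers: prune_fp16, sync_amp_check_nan_inf, prune_gradient_clip, and sync_global_norm now take a list of ring_ids and insert one collective per active ring (mp -> sharding -> pp) instead of a single collective on the global ring (rescaled by the pure data-parallel degree in the gradient-clip paths). The sketch below only illustrates the new calling convention; the concrete ring-id values are assumptions, while the ring_ids list and the -1 sentinel for an unused dimension come from the diff itself.

# Illustration only: the ring-id values below are assumptions, not taken
# from this commit. What the diff does establish is that callers now pass
# a list of ring ids ordered mp -> sharding -> pp, and that -1 marks a
# parallelism dimension that is not in use.
mp_ring_id = 0         # model-parallel ring (assumed value)
sharding_ring_id = 1   # sharding ring (assumed value)
pp_ring_id = -1        # pipeline parallelism disabled in this example

ring_ids = [mp_ring_id, sharding_ring_id, pp_ring_id]

# Each helper skips -1 entries, so collectives are only inserted for the
# rings that are actually active.
active_rings = [r for r in ring_ids if r != -1]
print(active_rings)  # [0, 1]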
@@ -73,7 +73,7 @@ def remove_cast_op(block, params, segment, offset):
return inserted_op_num

@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
"""
1. prune all cast_fp16_to_fp32 ops if the param does not belong to this shard
2. revise amp infinite grad checking for sharding
@@ -146,6 +146,7 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
name=inf_var_name + "@sharding",
shape=inf_var.shape,
dtype=inf_var.dtype)

block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
@@ -156,19 +157,26 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
# this allreduce communication should not overlap with calc
block._insert_op_without_sync(
update_loss_scaling_op_idx + 1,
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1

# allreduce(mp)->allreduce(sharding)->allreduce(pp)
for ring_id in ring_ids:
if ring_id == -1: continue
# this allreduce communication should not overlap with calc
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1

block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_sharding},
@@ -177,22 +185,27 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
"out_dtype": inf_var_sharding.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
block._sync_with_cpp()

# TODO (JZ-LIANG) revise this for uniform mixed parallelism
@staticmethod
def sync_amp_check_nan_inf(block, ring_id):
def sync_amp_check_nan_inf(block, ring_ids):
update_loss_scaling_op_idx = -1

for idx, op in reversed(list(enumerate(block.ops))):
if op.type == "update_loss_scaling":
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")
break

# amp is not used
if update_loss_scaling_op_idx == -1:
return
# 0. inf_var_int32 = cast(inf_var)
# 1. inf_var_int32 = allreduce_max(inf_var_int32)
# 2. inf_var = cast(inf_var_int32)
inf_var = block.var(inf_var_name)
inf_var_int32 = block.create_var(
name=inf_var_name + "@cast_int32",
@@ -212,18 +225,25 @@ def sync_amp_check_nan_inf(block, ring_id):
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1

# allreduce(mp)->allreduce(pp)
for ring_id in ring_ids:
if ring_id == -1: continue
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1

block._insert_op_without_sync(
update_loss_scaling_op_idx + 1,
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_global},
@@ -232,4 +252,5 @@ def sync_amp_check_nan_inf(block, ring_id):
"out_dtype": inf_var_global.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
block._sync_with_cpp()
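For both prune_fp16 and sync_amp_check_nan_inf, the inserted ops follow the same cast -> c_allreduce_max -> cast chain, now repeated once per active ring. The following framework-free sketch is a simplified model under stated assumptions, not the Paddle implementation; it shows why chaining the max over the mp, sharding, and pp rings behaves like a single global check: if any rank in any group saw inf/nan, the flag ends up set everywhere.

# Minimal model of the found-infinite reduction. `other_flags_per_ring`
# stands in for the values contributed by the other ranks of each active
# ring; None plays the role of ring_id == -1 (dimension not used).
def chained_allreduce_max(local_found_inf, other_flags_per_ring):
    flag = int(local_found_inf)                # cast(found_inf) -> int32
    for others in other_flags_per_ring:        # mp -> sharding -> pp
        if others is None:
            continue
        flag = max([flag] + [int(f) for f in others])  # c_allreduce_max
    return bool(flag)                          # cast back to the bool flag

# Any rank in any ring observing inf/nan poisons the combined flag.
print(chained_allreduce_max(False, [[False, False], None, [True]]))  # True
print(chained_allreduce_max(False, [[False], None, [False]]))        # False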
@@ -25,7 +25,7 @@ def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")

def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
def prune_gradient_clip(self, block, shard, ring_ids):
"""
prune gradient_clip related ops for params that do not belong to the current shard
prune: square, reduce_sum, elementwise_mul
@@ -82,33 +82,23 @@ def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]

# this allreduce should not overlap with calc and should be scheduled in calc stream
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': self.mp_ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})

# global norm should only be summed within each model parallelism world size when using the global group
if pure_dp_degree > 1:
# allreduce(mp)->allreduce(sharding)->allreduce(pp)
idx_offset = 1
for ring_id in ring_ids:
if ring_id == -1: continue
# this allreduce should not overlap with calc and should be scheduled in calc stream
block._insert_op_without_sync(
idx + 2,
type='scale',
idx + idx_offset,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'ring_id': ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
idx_offset += 1

# the grad sum here should include all and only the params in the current shard
to_check_param = set(reversed_x_paramname)
@@ -126,43 +116,32 @@ def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
return

# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def sync_global_norm(self, block, ring_id, pure_dp_degree=1):
def sync_global_norm(self, block, ring_ids):
"""
prune gradient_clip related ops for params that do not belong to the current shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
# FIXME(wangxi): mp should prune duplicated param_grads
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_gradient_clip_op(op):
continue

if op.type == "sum":
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})

# global norm should only be summed within each model parallelism world size
if pure_dp_degree > 1:
for ring_id in ring_ids:
if ring_id == -1: continue

idx = idx + 1
block._insert_op_without_sync(
idx + 2,
type='scale',
idx,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'ring_id': ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})

return
return
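In the gradient-clip path, the single c_allreduce_sum on the global ring (followed, for pure data parallelism, by a scale of 1/pure_dp_degree) is replaced by one c_allreduce_sum per active ring. A plausible reading of the diff is that summing only over the mp, sharding, and pp rings never double-counts the norm contribution replicated across pure data-parallel ranks, so the explicit rescale becomes unnecessary. The sketch below is a framework-free illustration of that accumulation, not the Paddle implementation; the partial sums are invented example numbers.

import math

# `peer_sums_per_ring` stands in for the squared-norm partial sums gathered
# when reducing over each active ring (mp, sharding, pp); None plays the
# role of ring_id == -1.
def chained_allreduce_sum(local_sq_sum, peer_sums_per_ring):
    total = local_sq_sum
    for peers in peer_sums_per_ring:   # allreduce(mp) -> allreduce(sharding) -> allreduce(pp)
        if peers is None:
            continue
        total += sum(peers)            # c_allreduce_sum on the current ring
    return math.sqrt(total)            # global norm used for clipping

# Example: squared partial sums 4 + 9 + 12 = 25, so the global norm is 5.
print(chained_allreduce_sum(4.0, [[9.0], None, [12.0]]))  # 5.0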
@@ -690,7 +690,6 @@ def get_grad_device(grad_name, shard):


def append_naive_sync(block, sync_var, ring_id):
if core.is_compiled_with_cuda(): return
# NOTE (JZ-LIANG) update this to use barrier sync for more elegant logic
# sync within global
block.append_op(