
Commit

[hybrid] out data parallel as optimizer sharding parallel (PaddlePadd…
wangxicoding authored and AnnaTrainingG committed Sep 29, 2021
1 parent eb373c9 commit bc43b5d
Showing 16 changed files with 967 additions and 177 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -43,6 +43,8 @@ message ShardingConfig {
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional int32 pp_degree = 11 [ default = 1 ];
optional bool optimize_cast = 12 [ default = false ];
// Optimizer sharding. Temporary plan; may be deprecated
optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
}

message HybridConfig {
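The new field can be toggled through fleet's DistributedStrategy like the other ShardingConfig options. A minimal sketch (not part of this commit; the degrees, optimizer and launch setup are illustrative):

import paddle
import paddle.distributed.fleet as fleet

# run under `python -m paddle.distributed.launch ...`
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_degree": 2,
    "dp_degree": 2,
    # new option added above: reuse the outer data-parallel group as an
    # optimizer-sharding group (temporary, may be deprecated)
    "_dp_as_optimizer_sharding": True,
}

adam = paddle.optimizer.Adam(learning_rate=1e-3)
dist_adam = fleet.distributed_optimizer(adam, strategy=strategy)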
31 changes: 16 additions & 15 deletions paddle/fluid/operators/amp/check_finite_and_unscale_op.cc
@@ -26,27 +26,28 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X",
"check_finite_and_unscale");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
"check_finite_and_unscale");
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(check_finite_and_unscale), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(check_finite_and_unscale), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
}
ctx->SetOutputDim("FoundInfinite", {1});
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) {
dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
}
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};

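These guards matter because, once the outer data-parallel group does optimizer sharding, a rank may own none of the checked tensors, so check_finite_and_unscale can legitimately be built with an empty X. A rough NumPy mirror of the op's contract after this change (hypothetical helper, not the framework kernel):

import numpy as np

def check_finite_and_unscale(xs, scale):
    # An empty input list is now legal: report "no inf found" and return.
    found_inf = np.array([False])
    if len(xs) == 0:
        return [], found_inf
    inv_scale = 1.0 / scale
    outs = []
    for x in xs:
        out = x * inv_scale
        if not np.all(np.isfinite(out)):
            found_inf[0] = True
        outs.append(out)
    return outs, found_inf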
2 changes: 2 additions & 0 deletions paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
@@ -97,6 +97,8 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
scale_data, inverse_scale_v, found_inf_data);

size_t xs_size = xs.size();
if (xs_size == 0) return;

const auto& cpu_place = platform::CPUPlace();
// calculate each tensor's start index and copy to device
auto h_starts_tensor =
26 changes: 19 additions & 7 deletions paddle/fluid/operators/amp/update_loss_scaling_op.cc
@@ -26,7 +26,6 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("FoundInfinite"), "Input", "FoundInfinite",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("PrevLossScaling"), "Input", "PrevLossScaling",
@@ -35,16 +34,25 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("InBadSteps"), "Input", "InBadSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("LossScaling"), "Output", "LossScaling",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("OutGoodSteps"), "Output", "OutGoodSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("OutBadSteps"), "Output", "OutBadSteps",
"update_loss_scaling");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);

if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(update_loss_scaling), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
}

ctx->SetOutputDim("LossScaling", {1});
ctx->SetOutputDim("OutGoodSteps", {1});
ctx->SetOutputDim("OutBadSteps", {1});
@@ -53,8 +61,12 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) {
dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
}

return framework::OpKernelType(dtype, ctx.GetPlace());
}
};

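update_loss_scaling gets the same tolerance for an empty X; the scaling state itself is still updated on every rank. For reference, a rough sketch of the dynamic loss-scaling rule the op implements (default thresholds and the lower clamp are assumptions, not taken from this diff):

def update_loss_scaling(found_inf, loss_scaling, good_steps, bad_steps,
                        incr_every_n_steps=1000, decr_every_n_nan_or_inf=2,
                        incr_ratio=2.0, decr_ratio=0.5):
    if found_inf:
        good_steps, bad_steps = 0, bad_steps + 1
        if bad_steps == decr_every_n_nan_or_inf:
            # shrink the scale, keeping it at least 1.0
            loss_scaling = max(loss_scaling * decr_ratio, 1.0)
            bad_steps = 0
    else:
        bad_steps, good_steps = 0, good_steps + 1
        if good_steps == incr_every_n_steps:
            loss_scaling *= incr_ratio
            good_steps = 0
    return loss_scaling, good_steps, bad_steps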
2 changes: 2 additions & 0 deletions paddle/fluid/operators/amp/update_loss_scaling_op.cu
@@ -95,6 +95,8 @@ class LazyZeros<platform::CUDADeviceContext, T> {
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const {
size_t xs_size = xs.size();
if (xs_size == 0) return;

const auto& cpu_place = platform::CPUPlace();
// alloc each tensor's start index and copy to device
auto h_in_starts_mem =
@@ -105,7 +105,6 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
if op.type == "update_loss_scaling":
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@sharding")
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
reversed_x = []
reversed_x_paramname = []
@@ -142,10 +141,6 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_sharding = block.create_var(
name=inf_var_name + "@sharding",
shape=inf_var.shape,
dtype=inf_var.dtype)

block._insert_op_without_sync(
update_loss_scaling_op_idx,
@@ -179,10 +174,10 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_sharding},
outputs={'Out': inf_var},
attrs={
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_sharding.dtype,
"out_dtype": inf_var.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
@@ -210,10 +205,6 @@ def sync_amp_check_nan_inf(block, ring_ids):
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_global = block.create_var(
name=inf_var_name + "@GLOBAL_WORLD",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
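After this change the cast/allreduce/cast chain writes the synchronized flag straight back into the original FoundInfinite variable instead of the removed @sharding / @GLOBAL_WORLD temporaries. A hedged dynamic-mode sketch of that pattern (the group handles and the MAX reduction are illustrative):

import paddle
import paddle.distributed as dist

def sync_found_inf(found_inf, groups):
    # cast bool -> int32, reduce across each communication group,
    # then cast back into the original variable's dtype
    flag = paddle.cast(found_inf, "int32")
    for group in groups:            # e.g. [sharding_group, pp_group]
        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return paddle.cast(flag, found_inf.dtype)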
@@ -39,6 +39,7 @@ def prune_gradient_clip(self, block, shard, ring_ids):
if not self._is_gradient_clip_op(op):
continue
if op.type == "sum":
global_norm_sum_op_idx = idx
continue
deperate_op = False
for input_name in op.desc.input_arg_names():
@@ -61,7 +62,10 @@ def prune_gradient_clip(self, block, shard, ring_ids):
if output_name not in op.desc.input_arg_names():
deperated_vars.add(output_name)

if not deperated_vars:
# NOTE(wangxi): With only 2 sharding ranks and 1 param, sharding rank 0
# gets no deperated_vars and returns early, while sharding rank 1 still
# inserts the allreduce, which then hangs.
if not deperated_vars and global_norm_sum_op_idx == -1:
# got no gradient_clip op
return

@@ -71,8 +75,8 @@ def prune_gradient_clip(self, block, shard, ring_ids):
if idx in deperate_op_idx:
block._remove_op(idx, sync=False)
continue
reversed_inputs = []
if op.type == "sum":
reversed_inputs = []
global_norm_sum_op_idx = idx
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
@@ -82,6 +86,28 @@ def prune_gradient_clip(self, block, shard, ring_ids):
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]

# NOTE(wangxi): If there are 2 params but the sharding degree is 4,
# the sum op on some cards will have no input. So we use a
# fill_constant op to set `sum_var` to zero, which does not
# affect correctness.
if len(reversed_inputs) == 0:
sum_var = block.var(sum_res)
namescope = op.attr("op_namescope")

block._remove_op(idx, sync=False)
op = block._insert_op_without_sync(
idx,
type='fill_constant',
inputs={},
outputs={'Out': sum_res},
attrs={
'shape': sum_var.shape,
'dtype': sum_var.dtype,
'value': 0.0,
OP_ROLE_KEY: OpRole.Optimize
})
op._set_attr('op_namescope', namescope)

# allreduce(mp)->allreduce(sharding)->allreduce(pp)
idx_offset = 1
for ring_id in ring_ids:
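The comment above gives the intended order: the local sum of squared gradient norms is allreduced over the mp, sharding and pp rings in turn, and the fill_constant fallback guarantees that a rank owning no gradients still contributes a valid zero tensor. A hedged dynamic-mode sketch of the overall effect (group handles are illustrative):

import paddle
import paddle.distributed as dist

def global_grad_norm(local_grads, groups):
    # ranks that own no gradients still join every collective with a
    # zero tensor -- the role played by fill_constant above
    local_sq = paddle.zeros([1], dtype="float32")
    for g in local_grads:
        local_sq += paddle.sum(paddle.cast(g, "float32") ** 2)
    for group in groups:  # e.g. [mp_group, sharding_group, pp_group]
        dist.all_reduce(local_sq, group=group)
    return paddle.sqrt(local_sq)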
17 changes: 12 additions & 5 deletions python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
@@ -117,21 +117,28 @@ def crop_output_var_from_op(self, op_idx, var_name):
var_name] == []:
self._block._remove_var(var_name, sync=False)

def remove_op(self, op_idx):
def remove_op(self, op_idx, reserved_vars=None):
# update deps
op = self._block.ops[op_idx]
for input_name in op.desc.input_arg_names():
if reserved_vars is not None and input_name in reserved_vars:
continue
self.crop_input_var_from_op(op_idx, input_name)
for output_name in op.desc.output_arg_names():
if reserved_vars is not None and output_name in reserved_vars:
continue
self.crop_output_var_from_op(op_idx, output_name)
self._block._remove_op(op_idx, sync=False)

def should_remove_op(self, op_idx):
op = self._block.ops[op_idx]
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# remove check_finite_and_unscale op if its input 'X' is empty
if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
return True

# NOTE: At present, the only ops without outputs are send_v2 and
# partial_send, which are used on all devices
if len(op.desc.output_arg_names()) == 0:
return False

for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
@@ -24,7 +24,8 @@ def __init__(self, ):
self.global_params = set([])
self.worker_idx = -1
self.worker_num = -1
self.global_param2device = {}
self.global_param2device = dict()
self.device2global_params = dict()

def setup(self, params_grads, worker_idx, worker_num):
# param names of all devices
@@ -33,8 +34,9 @@ def setup(self, params_grads, worker_idx, worker_num):
self.worker_idx = worker_idx
self.worker_num = worker_num
# global_param2device contains fp32 params and fp16 params
self.global_param2device = self._split_params(params_grads, worker_idx,
worker_num)
# device2global_params only contains fp32 params
self.global_param2device, self.device2global_params \
= self._split_params(params_grads, worker_idx, worker_num)

def has_param(self, var_name):
return var_name in self.global_param2device and \
@@ -64,7 +66,7 @@ def _split_params(self, params_grads, worker_idx, worker_num):
device2params[device_idx].append(param_name)
param2device[param_name] = device_idx
mem_accu += mem
return param2device
return param2device, device2params

def _var_device_id(self, var_name):
if var_name in self.global_param2device:
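A framework-free sketch of the kind of greedy split that yields the two mappings now returned by _split_params: param2device assigns each parameter to the currently lightest rank, and device2global_params is the inverse grouping consumed by the optimizer-sharding pass (the balancing rule here is illustrative, not copied from the repository):

def split_params(params_with_mem, worker_num):
    # params_with_mem: list of (param_name, mem_bytes) pairs
    device2params = {i: [] for i in range(worker_num)}
    param2device = {}
    mem_per_device = [0] * worker_num
    for name, mem in params_with_mem:
        # place the next parameter on the currently lightest rank
        device_idx = mem_per_device.index(min(mem_per_device))
        device2params[device_idx].append(name)
        param2device[name] = device_idx
        mem_per_device[device_idx] += mem
    return param2device, device2params

# example: three params split across two ranks
print(split_params([("w0", 4), ("w1", 4), ("b0", 1)], worker_num=2))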
