Add ring_id to bkcl hccl gen_comm_id_op, test=allcase
wangxicoding committed Aug 3, 2021
1 parent 3efe777 commit a45d3e1
Showing 4 changed files with 28 additions and 33 deletions.
9 changes: 5 additions & 4 deletions paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc
@@ -62,7 +62,7 @@ class CGenBKCLIdOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& dev_place) const override {
     int rank = Attr<int>("rank");
-    framework::Scope& local_scope = scope.NewScope();
+    int ring_id = Attr<int>("ring_id");
 
     std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
       return Output("Out");
@@ -75,14 +75,13 @@ class CGenBKCLIdOp : public framework::OperatorBase {
       GenBKCLID(&bkcl_ids);
       std::vector<std::string> endpoint_list =
           Attr<std::vector<std::string>>("other_endpoints");
-      platform::SendBroadCastCommID(endpoint_list, &bkcl_ids);
+      platform::SendBroadCastCommID(endpoint_list, &bkcl_ids, ring_id);
     } else {
       std::string endpoint = Attr<std::string>("endpoint");
-      platform::RecvBroadCastCommID(endpoint, &bkcl_ids);
+      platform::RecvBroadCastCommID(endpoint, &bkcl_ids, ring_id);
     }
 
     CopyBKCLIDToVar(bkcl_ids, func, scope);
-    scope.DeleteScope(&local_scope);
   }
 };
 
@@ -108,6 +107,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
             "(int default 0) "
             "The rank of the trainer in distributed training.")
         .SetDefault(0);
+    AddAttr<int>("ring_id", "(int default 0) user specified ring id")
+        .SetDefault(0);
   }
 };

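For reference, a minimal sketch of how a startup block would append this op with the new attribute, mirroring the common.py change later in this commit. Illustrative only: it assumes a Paddle build that registers c_gen_bkcl_id (i.e. compiled with XPU/BKCL support), and the rank, endpoints, and ring_id values are placeholders.

import paddle.fluid as fluid
from paddle.fluid import core, unique_name

startup_prog = fluid.Program()
block = startup_prog.global_block()

# Persistable RAW var that will hold the generated BKCL unique id.
comm_id_var = block.create_var(
    name=unique_name.generate('comm_id'),
    persistable=True,
    type=core.VarDesc.VarType.RAW)
block.append_op(
    type='c_gen_bkcl_id',
    inputs={},
    outputs={'Out': comm_id_var},
    attrs={
        'rank': 0,                               # this trainer's rank
        'endpoint': '127.0.0.1:36001',           # this trainer's endpoint
        'other_endpoints': ['127.0.0.1:36002'],  # peers to broadcast to
        'ring_id': 0,                            # the attribute added here
    })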
9 changes: 5 additions & 4 deletions paddle/fluid/operators/collective/c_gen_hccl_id_op.cc
@@ -63,7 +63,7 @@ class CGenHCCLIdOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& dev_place) const override {
     int rank = Attr<int>("rank");
-    framework::Scope& local_scope = scope.NewScope();
+    int ring_id = Attr<int>("ring_id");
 
     std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
      return Output("Out");
@@ -79,13 +79,12 @@ class CGenHCCLIdOp : public framework::OperatorBase {
       GenHCCLID(&hccl_ids);
       std::vector<std::string> endpoint_list =
           Attr<std::vector<std::string>>("other_endpoints");
-      platform::SendBroadCastCommID(endpoint_list, &hccl_ids);
+      platform::SendBroadCastCommID(endpoint_list, &hccl_ids, ring_id);
     } else {
-      platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids);
+      platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids, ring_id);
     }
 
     CopyHCCLIDToVar(hccl_ids, func, scope);
-    scope.DeleteScope(&local_scope);
   }
 };
 
@@ -128,6 +127,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
             "(int default 0) "
             "The rank of the trainer in distributed training.")
         .SetDefault(0);
+    AddAttr<int>("ring_id", "(int default 0) user specified ring id")
+        .SetDefault(0);
   }
 };

23 changes: 8 additions & 15 deletions python/paddle/distributed/fleet/meta_optimizers/common.py
@@ -126,11 +126,11 @@ def _add_sync_by_allreduce(block):
         _add_sync_by_allreduce(block)
         return
 
+    comm_id_var = block.create_var(
+        name=unique_name.generate('comm_id'),
+        persistable=True,
+        type=core.VarDesc.VarType.RAW)
     if core.is_compiled_with_cuda():
-        comm_id_var = block.create_var(
-            name=unique_name.generate('nccl_id'),
-            persistable=True,
-            type=core.VarDesc.VarType.RAW)
         block.append_op(
             type='c_gen_nccl_id',
             inputs={},
@@ -153,10 +153,6 @@ def _add_sync_by_allreduce(block):
                 OP_ROLE_KEY: OpRole.Forward
             })
     elif core.is_compiled_with_xpu():
-        comm_id_var = block.create_var(
-            name=unique_name.generate('bkcl_id'),
-            persistable=True,
-            type=core.VarDesc.VarType.RAW)
         block.append_op(
             type='c_gen_bkcl_id',
             inputs={},
@@ -165,6 +161,7 @@ def _add_sync_by_allreduce(block):
                 'rank': rank,
                 'endpoint': current_endpoint,
                 'other_endpoints': other_endpoints,
+                'ring_id': ring_id,
                 OP_ROLE_KEY: OpRole.Forward
             })
         block.append_op(
@@ -178,24 +175,20 @@ def _add_sync_by_allreduce(block):
                 OP_ROLE_KEY: OpRole.Forward
             })
     elif core.is_compiled_with_npu():
-        hccl_id_var = block.create_var(
-            name=unique_name.generate('hccl_id'),
-            persistable=True,
-            type=core.VarDesc.VarType.RAW)
         endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
         block.append_op(
             type='c_gen_hccl_id',
             inputs={},
-            outputs={'Out': hccl_id_var},
+            outputs={'Out': comm_id_var},
             attrs={
                 'rank': rank,
                 'endpoint': current_endpoint,
                 'other_endpoints': other_endpoints,
+                'ring_id': ring_id,
                 OP_ROLE_KEY: OpRole.Forward
             })
         block.append_op(
             type='c_comm_init_hccl',
-            inputs={'X': hccl_id_var},
+            inputs={'X': comm_id_var},
             outputs={},
             attrs={
                 'rank': rank,
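The unified 'comm_id' variable name above is what drives the test updates below: unique_name.generate appends an increasing counter, so successive communicator-id variables come out as comm_id_0, comm_id_1, and so on. A minimal sketch, assuming a fresh name scope:

from paddle.fluid import unique_name

# Inside a guard the name counter starts fresh, so the first two
# generated communicator-id variables get the suffixes the tests expect.
with unique_name.guard():
    print(unique_name.generate('comm_id'))  # comm_id_0
    print(unique_name.generate('comm_id'))  # comm_id_1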
@@ -336,7 +336,7 @@ def test_sharding_with_mp(self):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_0":
+                    0] == "comm_id_0":
                 sharding_group_waiting_ports = op.desc.attr("other_endpoints")
 
         self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
@@ -345,7 +345,7 @@ def test_sharding_with_mp(self):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_1":
                 dp_group_waiting_ports = op.desc.attr("other_endpoints")
 
         self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
@@ -381,7 +381,7 @@ def test_sharding_hybrid_dp(self):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_0":
+                    0] == "comm_id_0":
                 sharding_group_waiting_ports = op.desc.attr("other_endpoints")
 
         self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
@@ -390,7 +390,7 @@ def test_sharding_hybrid_dp(self):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_1":
                 dp_group_waiting_ports = op.desc.attr("other_endpoints")
         self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
@@ -450,7 +450,7 @@ def test_sharding_hybrid_dp_gm(self):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_0":
+                    0] == "comm_id_0":
                 sharding_group_waiting_ports = op.desc.attr("other_endpoints")
 
         self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
@@ -459,7 +459,7 @@ def test_sharding_hybrid_dp_gm(self):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_1":
                 dp_group_waiting_ports = op.desc.attr("other_endpoints")
         self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
@@ -568,7 +568,7 @@ def test_sharding_with_pp(self):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_0":
+                    0] == "comm_id_0":
                 sharding_group_waiting_ports = op.desc.attr("other_endpoints")
 
         self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
@@ -577,7 +577,7 @@ def test_sharding_with_pp(self):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_1":
                 dp_group_waiting_ports = op.desc.attr("other_endpoints")
 
         self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
@@ -678,7 +678,7 @@ def test_hybrid_with_mp_pp_amp_gclip(self):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_0":
+                    0] == "comm_id_0":
                 mp_group_waiting_ports = op.desc.attr("other_endpoints")
 
         self.assertEqual(mp_group_waiting_ports, ['127.0.0.1:36003'])
@@ -687,7 +687,7 @@ def test_hybrid_with_mp_pp_amp_gclip(self):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_1":
                 pp_group_waiting_ports = op.desc.attr("other_endpoints")
 
         self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])
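Every updated assertion above shares one lookup pattern: scan the startup program's ops for the comm-id generator whose first output variable names the ring of interest, then read its other_endpoints attribute. A self-contained restatement of that pattern (the helper name is ours, not the test file's):

def find_other_endpoints(startup_prog_ops, comm_id_name,
                         op_type="c_gen_nccl_id"):
    # Return the other_endpoints attr of the comm-id generator op whose
    # first output variable is comm_id_name, or None if no op matches.
    for op in startup_prog_ops:
        if (op.type == op_type and
                op.desc.output_arg_names()[0] == comm_id_name):
            return op.desc.attr("other_endpoints")
    return None

# e.g., equivalent to the inline loops above:
# sharding_ports = find_other_endpoints(startup_prog_ops, "comm_id_0")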
