[Auto para] Relaunch with auto mapping function #37326

Merged: 55 commits, Dec 7, 2021

Commits
68b2c10
[Auto Parallel] Add the unified cluster representation
aoyulong Nov 10, 2021
70e188a
[Auto Parallel] Add the graph class for physical mapping
aoyulong Nov 10, 2021
1a44c06
[Auto Parallel] Add the simple physical mapper
aoyulong Nov 10, 2021
b00f5fb
Set the timeout of the mapper
aoyulong Nov 11, 2021
76498be
Merge the upstream develop unittests cmake files
aoyulong Nov 11, 2021
b8a7be4
Merge branch 'develop' into auto_para_mapping
aoyulong Nov 11, 2021
9127177
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Nov 11, 2021
2315472
Fix a bug of the process group
aoyulong Nov 11, 2021
ab7bea4
Merge branch 'auto_para_mapping' of github.com:aoyulong/Paddle into a…
aoyulong Nov 11, 2021
8f3b236
Remove mapper unittest from platforms which is not GPU
aoyulong Nov 11, 2021
72086ae
Merge branch 'develop' into auto_para_mapping
aoyulong Nov 12, 2021
95d6d3a
Move the instantiation of process group after resharding
aoyulong Nov 12, 2021
6402759
Merge branch 'develop' into auto_para_mapping
aoyulong Nov 12, 2021
6f1559d
Merge branch 'auto_para_mapping' of github.com:aoyulong/Paddle into a…
aoyulong Nov 12, 2021
e50494f
Add the local id for devices
aoyulong Nov 14, 2021
14be54b
Merge branch 'auto_para_cluster' into auto_para_mapping
aoyulong Nov 14, 2021
0ccb242
Update the rank mapping format
aoyulong Nov 14, 2021
4060856
[Auto Parallel] Relaunch with the rank mapping file
aoyulong Nov 18, 2021
c287b5a
Merge branch 'develop' of github.com:aoyulong/Paddle into auto_para_l…
aoyulong Nov 18, 2021
a0127f1
Remove the unnecessary json file
aoyulong Nov 18, 2021
48936b8
Avoid entering get_device_proc_info for auto mapping
aoyulong Nov 18, 2021
9cd37a6
Correct the mapper unit test
aoyulong Nov 19, 2021
7349999
Add some comments
aoyulong Nov 23, 2021
d8647be
Merge branch 'auto_para_cluster' into auto_para_mapping
aoyulong Nov 23, 2021
11d41b4
Remove the related files about mapping
aoyulong Nov 23, 2021
cb8de4c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Nov 23, 2021
f56cacf
Update the unittest for auto mapping
aoyulong Nov 24, 2021
f36849e
Merge branch 'auto_para_mapping' into auto_para_launch
aoyulong Nov 24, 2021
9cb742d
Merge branch 'develop' into auto_para_graph
aoyulong Nov 24, 2021
cb9041a
Merge branch 'develop' into auto_para_graph
aoyulong Nov 24, 2021
7b831ae
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Nov 24, 2021
677d3e3
Remove unused rank_mapping unittest
aoyulong Nov 25, 2021
dc2ba12
Improve the unittest coverage
aoyulong Nov 26, 2021
5494547
Merge branch 'auto_para_graph' into auto_para_mapping
aoyulong Nov 28, 2021
6c268b5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Nov 28, 2021
d56ebf8
Improve the unittest coverage
aoyulong Nov 29, 2021
55870d2
Merge branch 'auto_para_mapping' into auto_para_launch
aoyulong Nov 30, 2021
9e8cc18
Improve the unittest of relaunch
aoyulong Nov 30, 2021
7b24059
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Nov 30, 2021
e71ce76
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Nov 30, 2021
fd8ff31
Fix the unittest problem in CI
aoyulong Nov 30, 2021
df19fa2
Merge branch 'develop' into auto_para_launch
aoyulong Nov 30, 2021
a65acab
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Dec 1, 2021
8002a63
Merge branch 'auto_para_launch' of github.com:aoyulong/Paddle into au…
aoyulong Dec 1, 2021
35828dd
Improve the unittest of relaunch
aoyulong Dec 1, 2021
8d4199c
Remove unnecessary statements
aoyulong Dec 1, 2021
d2e3737
Update the unittest cmakefile
aoyulong Dec 1, 2021
3aef5c5
Correct the cmakefile of auto parallel unittests
aoyulong Dec 3, 2021
6040fea
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Dec 3, 2021
e746224
Modify codes based on the new elastic change
aoyulong Dec 3, 2021
ce25444
Use the GPUs exclusively in the unittest
aoyulong Dec 3, 2021
8706b24
Correct the cmakefile
aoyulong Dec 3, 2021
9a23b7f
Set the timeout of the unittest
aoyulong Dec 4, 2021
31ef42c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Dec 6, 2021
6db885e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Dec 6, 2021
14 changes: 11 additions & 3 deletions python/paddle/distributed/auto_parallel/mapper.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License

import os
import operator
import functools
import json
@@ -175,9 +176,19 @@ def build_process_graph(distributed_program):

def build_cluster_graph(cluster):
graph = Graph()
cuda_visible_devices_env = os.getenv("CUDA_VISIBLE_DEVICES")
cuda_visible_devices = []
if cuda_visible_devices_env is not None and cuda_visible_devices_env != "":
cuda_visible_devices = [
int(d.strip()) for d in cuda_visible_devices_env.split(",")
]
for machine in cluster.machines.values():
for device in machine.devices.values():
graph.add_node(device.global_id, device=device)
if cuda_visible_devices and device.local_id not in cuda_visible_devices:
graph.nodes[device.global_id]["occupied"] = True
else:
graph.nodes[device.global_id]["occupied"] = False
for link in machine.links.values():
graph.add_edge(
link.source.global_id, link.target.global_id, link=link)
@@ -195,9 +206,6 @@ def mapping(distributed_program, cluster):
for cur_rank_node in process_graph:
cur_rank_node["visited"] = False

for cur_device_node in cluster_graph:
cur_device_node["occupied"] = False

def sort_by_comm_volume(rank_edge):
return rank_edge["comm_requirements"]["comm_volume"]

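The hunk above moves the "occupied" initialization into build_cluster_graph so that devices hidden by CUDA_VISIBLE_DEVICES are pre-marked as unusable before the mapping pass runs. A minimal standalone sketch of that logic (simplified from the diff; the function names below are illustrative and not part of the PR):

```python
import os

def parse_cuda_visible_devices():
    # Mirrors the parsing in build_cluster_graph: an unset or empty variable
    # yields an empty list, which the mapper treats as "no restriction".
    env = os.getenv("CUDA_VISIBLE_DEVICES")
    if env is None or env == "":
        return []
    return [int(d.strip()) for d in env.split(",")]

def pre_mark_occupied(local_ids, visible):
    # A device is pre-marked occupied only when a restriction exists and the
    # device's local id is not listed in it, so the mapper skips it.
    return {local_id: bool(visible) and local_id not in visible
            for local_id in local_ids}

# With CUDA_VISIBLE_DEVICES="0,1" and four local devices, devices 2 and 3
# are marked occupied and receive no ranks.
print(pre_mark_occupied([0, 1, 2, 3], [0, 1]))
# {0: False, 1: False, 2: True, 3: True}
```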
158 changes: 113 additions & 45 deletions python/paddle/distributed/auto_parallel/parallelizer.py
@@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import json
import shlex
import copy
import pathlib
import subprocess
import logging
import paddle
from paddle.distributed.utils import get_logger
@@ -23,9 +30,12 @@
from .completion import complete_annotation, complete_backward_annotation
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .process_group import get_world_process_groups
from .utils import make_data_unshard
from .utils import set_grad_var_shape
from .reshard import reshard
from .cluster import Cluster
from .mapper import mapping
# from .auto_search import auto_search

_logger = get_logger(logging.INFO)
@@ -46,6 +56,21 @@ def __init__(self, fleet):
self._optimizer = self._fleet.user_defined_optimizer
self._dist_strategy = self._fleet._user_defined_strategy
self._dist_context = DistributedContext()
self._cluster = None
self._cluster_topo_path = os.getenv("PADDLE_CLUSTER_TOPO_PATH", None)
if self._cluster_topo_path is not None:
self._cluster = Cluster()
self._cluster.build_from_file(self._cluster_topo_path)
# Prepare information for auto mapping
self._rank_mapping_path = os.getenv("PADDLE_RANK_MAPPING_PATH", None)
enable_auto_mapping_env = os.getenv("PADDLE_ENABLE_AUTO_MAPPING", None)
if enable_auto_mapping_env is None:
self._enable_auto_mapping = False
else:
self._enable_auto_mapping = True
self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
self._need_rank_mapping = True if self._need_rank_mapping and \
self._need_rank_mapping.lower() == 'true' else False

def _remove_distributed_attrs(self, main_program):
suffix = core.kAutoParallelSuffix()
@@ -57,60 +82,103 @@ def _remove_distributed_attrs(self, main_program):
if suffix in attr_name:
op._remove_attr(attr_name)

def _get_dist_program(self, dist_context, rank):
# Annotation completion
completed_main_program = complete_annotation(self._main_program,
dist_context)
# Logical partition
partitioner = Partitioner(self._dist_strategy, dist_context, rank)
dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
completed_main_program, self._startup_program)
dist_params_grads = partitioner.apply_backward(
self._loss, completed_main_program, self._startup_program,
dist_main_prog, dist_startup_prog)
dist_optimize_ops = partitioner.apply_optimize(
copy.deepcopy(self._optimizer), dist_params_grads, dist_main_prog,
dist_startup_prog)

make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)

reshard(dist_main_prog, dist_startup_prog, rank, dist_context)

return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog

def parallelize(self,
loss,
startup_program,
parameter_list=None,
no_grad_set=None):
assert startup_program is not None
main_program = loss.block.program

if self._dist_strategy.auto_search:
# auto search
_logger.info("Start search dist attr.")
# self._dist_context, _ = auto_search(main_program, startup_program,
# loss, self._optimizer)
# completed_main_program = main_program
raise NotImplementedError("Auto search has not implemented")
else:
# Annotation completion
_logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(main_program,
self._dist_context)

# Logical partition
rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
partitioned_main_prog, partitioned_startup_prog = partitioner.transpile_forward(
completed_main_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, completed_main_program, startup_program,
partitioned_main_prog, partitioned_startup_prog)
dist_optimize_ops = partitioner.apply_optimize(
self._optimizer, dist_params_grads, partitioned_main_prog,
partitioned_startup_prog)
self._loss = loss
self._startup_program = startup_program
self._main_program = loss.block.program
self._parameter_list = parameter_list
self._no_grad_set = no_grad_set

if self._enable_auto_mapping and self._need_rank_mapping:
# Do the mapping pass before parallelization
assert self._cluster is not None, \
"The cluster must not be none when using auto mapping."
dist_programs = {}
world_process_group = get_world_process_groups()
for rank in world_process_group.ranks:
dist_context = DistributedContext()
dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog = self._get_dist_program(
dist_context, rank)
dist_programs[rank] = dist_main_prog

# Do the mapping between the distributed program graph and the cluster graph
rank_mapping_dict = mapping(dist_programs, self._cluster)
rank_mapping = list(rank_mapping_dict.values())

# set the grad var shape
set_grad_var_shape(partitioned_main_prog, self._dist_context)
# Relaunch the training by using the rank mapping file
with open(self._rank_mapping_path, "w") as rank_mapping_file:
json.dump(rank_mapping, rank_mapping_file)

enable_elastic = os.getenv("PADDLE_ENABLE_ELASTIC")
enable_elastic = True if enable_elastic and enable_elastic.lower(
) == 'true' else False
if enable_elastic:
print("Auto mapping finished, now do elastic re-launch")
sys.exit(paddle.distributed.fleet.elastic.manager.
ELASTIC_AUTO_PARALLEL_EXIT_CODE)

original_cmd_args = os.getenv("PADDLE_ORIGINAL_CMD_ARGS")
rank_mapping_args = " ".join(
["--rank_mapping_path", self._rank_mapping_path])
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
new_cmd_args = "-m paddle.distributed.fleet.launch" + " " + rank_mapping_args + " " + original_cmd_args
new_cmd = [sys.executable, "-u"] + coverage_args + shlex.split(
new_cmd_args)
new_process = subprocess.Popen(new_cmd)
new_process.wait()
assert new_process.returncode == 0, \
"Launch failed with rank mapping"
print("Successfully do the second launch for auto mapping!")
sys.exit(0)
else:
# Parallelization after the mapping pass
rank = paddle.distributed.get_rank()

# The last step: remove all distributed attributes to be compatiable
# with inference.
self._remove_distributed_attrs(partitioned_main_prog)
make_data_unshard(partitioned_main_prog, partitioned_startup_prog,
self._dist_context)
dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog = self._get_dist_program(
self._dist_context, rank)

reshard(partitioned_main_prog, partitioned_startup_prog, rank,
self._dist_context)
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if rank not in process_group.ranks:
continue
process_group.instantiate()

# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if rank not in process_group.ranks:
continue
process_group.instantiate()
# Copy distributed info to the default context
set_default_distributed_context(self._dist_context)

# Copy distributed info to the default context
set_default_distributed_context(self._dist_context)
# The last step: remove all distributed attributes to be compatible
# with inference.
self._remove_distributed_attrs(dist_main_prog)

return dist_optimize_ops, dist_params_grads, partitioned_startup_prog, partitioned_main_prog
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog
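For reference, a condensed sketch of the relaunch step introduced above, isolated from the class so it can be read on its own (the helper name is illustrative; the coverage and elastic branches and error handling are omitted):

```python
import os
import shlex
import subprocess
import sys

def relaunch_with_rank_mapping(rank_mapping_path):
    # The first launch exported the user's original command line; rebuild it
    # behind "-m paddle.distributed.fleet.launch --rank_mapping_path ...".
    original_cmd_args = os.getenv("PADDLE_ORIGINAL_CMD_ARGS", "")
    new_cmd_args = ("-m paddle.distributed.fleet.launch"
                    " --rank_mapping_path " + rank_mapping_path +
                    " " + original_cmd_args)
    new_cmd = [sys.executable, "-u"] + shlex.split(new_cmd_args)
    new_process = subprocess.Popen(new_cmd)
    new_process.wait()
    assert new_process.returncode == 0, "Launch failed with rank mapping"
```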
12 changes: 8 additions & 4 deletions python/paddle/distributed/auto_parallel/process_group.py
@@ -19,10 +19,6 @@
from ...fluid.framework import in_dygraph_mode
from ...fluid.layers.tensor import fill_constant

# Note that Process group 0 is reserved for representing all ranks.
# At the beginning, group 0 is empty and new ranks will be added automatically.
_g_process_group_map = {}


def get_all_process_groups():
global _g_process_group_map
@@ -34,6 +30,11 @@ def get_process_group(group_id):
return _g_process_group_map.get(group_id, None)


def get_world_process_groups():
global _g_process_group_map
return _g_process_group_map[0]


def new_process_group(ranks):
global _g_process_group_map
# A key constructed from ranks is used for avoiding duplication
Expand Down Expand Up @@ -151,4 +152,7 @@ def __str__(self):
return string


# Note that Process group 0 is reserved for representing all ranks.
# At the beginning, group 0 is empty and new ranks will be added automatically.
_g_process_group_map = {}
_g_process_group_map[0] = ProcessGroup(0, [])
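A small usage sketch of the reworked module state (hypothetical driver code, assuming the helpers behave as the diff suggests): group 0 is created at import time as the reserved world group, and further groups are registered on demand.

```python
from paddle.distributed.auto_parallel.process_group import (
    get_all_process_groups, get_world_process_groups, new_process_group)

world = get_world_process_groups()   # ProcessGroup 0, reserved for all ranks
pg = new_process_group([0, 1])       # a new group covering ranks 0 and 1

# Both the reserved world group and the newly created group are registered.
for group in get_all_process_groups():
    print(group)
```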
73 changes: 51 additions & 22 deletions python/paddle/distributed/fleet/launch.py
@@ -175,25 +175,17 @@ def _parse_args():
default="127.0.0.1",
help="Paddle cluster nodes ips, such as 192.168.0.16,192.168.0.17..")
collective_group.add_argument(
"--rank_mapping_file",
type=argparse.FileType('r'),
default=sys.stdin,
help="This rank mapping information in json format is used specifically "
"for lazy launch for auto parallel. Some of the ranks in each node "
"may not be used, and the indices of rank should be kept the same "
"as the indices of sub-task splited by auto parallel. "
" { "
" \"ip_ranks\": [ "
" { "
" \"ip\": \"127.0.0.1\", "
" \"ranks\": [0,1] "
" }, "
" { "
" \"ip\": \"127.0.0.2\", "
" \"ranks\": [2,3,4] "
" } "
" ] "
" } ")
"--cluster_topo_path",
type=str,
default=None,
help="A json format file will be stored in this path which is used"
"to represent the cluster topology information for auto parallel.")
collective_group.add_argument(
"--rank_mapping_path",
type=str,
default=None,
help="A json format file will be stored in this path which is used"
"to map processes to machines for auto parallel.")
Review comment on lines +178 to +188 (Member): Adding one more config file is expensive. Is it possible to use a single xxx_config to hold everything, maybe a paddle_config in which you hold some sections?

Reply (Contributor, author): The rank_mapping file will be automatically generated by our framework in the pre-launch analysis pass and must not be exposed to users.
collective_group.add_argument(
"--enable_auto_mapping",
type=bool,
@@ -297,20 +289,56 @@ def cpuonly_check(args):
def get_cluster_info(args):
# parse arguments, used for cloud-single-machine and local
if args.backend == 'gloo': cpuonly_check(args)
(device_mode, devices_per_proc) = launch_utils.get_device_proc_info(args)
if args.enable_auto_mapping:
(device_mode, devices_per_proc) = (DeviceMode.GPU, [])
else:
(device_mode,
devices_per_proc) = launch_utils.get_device_proc_info(args)
trainers_num = cloud_utils.get_trainers_num()
logger.debug("parsed from args trainerss_num:{} mode:{} devices:{}".format(
trainers_num, device_mode, devices_per_proc))

cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")

cluster = None
pod = None

start_port = 6170
if os.environ.get('FLAGS_START_PORT') is not None:
start_port = os.environ.get('FLAGS_START_PORT')
# lazy launch for auto-parallel
# auto mapping between processes and devices for auto-parallel
if args.enable_auto_mapping == True:
cluster, pod = get_mapped_cluster_from_args(args, device_mode)
assert args.cluster_topo_path is not None, \
"The cluster topology must be provied when enabling auto mapping."
rank_mapping_path = args.rank_mapping_path or os.getenv(
"PADDLE_RANK_MAPPING_PATH")
if not rank_mapping_path:
os.environ["PADDLE_NEED_RANK_MAPPING"] = str(True)
os.environ["PADDLE_ENABLE_ELASTIC"] = str(
enable_elastic(args, device_mode))
cwd = pathlib.Path().resolve()
rank_mapping_path = os.path.join(cwd,
"auto_parallel_rank_mapping.json")
os.environ["PADDLE_RANK_MAPPING_PATH"] = str(rank_mapping_path)

original_args = sys.argv[1:]
os.environ["PADDLE_ORIGINAL_CMD_ARGS"] = " ".join(original_args)
Review comment on lines +324 to +325 (Member): This part looks fragile.

Reply (Contributor, author): This has been dealt with by shlex.split on line 154 of the above parallelizer.py.
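The fragility and the shlex.split answer can be seen in isolation with a small standalone experiment (not part of the PR): a plain space-join loses quoting for arguments that contain spaces, while shlex.quote preserves the round trip.

```python
import shlex

original_args = ["train.py", "--data_dir", "/tmp/my data"]   # note the embedded space

naive = " ".join(original_args)    # what launch.py stores in PADDLE_ORIGINAL_CMD_ARGS
print(shlex.split(naive))          # ['train.py', '--data_dir', '/tmp/my', 'data']  -- split apart

quoted = " ".join(shlex.quote(a) for a in original_args)
print(shlex.split(quoted))         # ['train.py', '--data_dir', '/tmp/my data']     -- round-trips
```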
os.environ["PADDLE_CLUSTER_TOPO_PATH"] = str(args.cluster_topo_path)
os.environ["PADDLE_ENABLE_AUTO_MAPPING"] = str(
args.enable_auto_mapping)
cluster, pod = launch_utils.get_mapped_cluster_from_args_without_rank_mapping(
args, device_mode)
else:
os.environ["PADDLE_NEED_RANK_MAPPING"] = str(False)
os.environ["PADDLE_ENABLE_ELASTIC"] = str(
enable_elastic(args, device_mode))

os.environ["PADDLE_CLUSTER_TOPO_PATH"] = str(args.cluster_topo_path)
os.environ["PADDLE_RANK_MAPPING_PATH"] = str(rank_mapping_path)
os.environ["PADDLE_ENABLE_AUTO_MAPPING"] = str(
args.enable_auto_mapping)
cluster, pod = launch_utils.get_mapped_cluster_from_args_with_rank_mapping(
args, device_mode)
elif cloud_utils.use_paddlecloud() and trainers_num != 1:
cluster, pod = cloud_utils.get_cloud_cluster(
args.ips, device_mode, devices_per_proc, start_port)
@@ -328,6 +356,7 @@ def get_global_envs(args, tmp_dir):
logger.debug("get cluster from args:{}".format(cluster))
return cluster, pod


def get_global_envs(args, tmp_dir):
global_envs = copy.copy(os.environ.copy())
# add gloo env
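To summarize the handshake between launch.py and parallelizer.py in this PR, here is a minimal sketch (not part of the PR; the function name is illustrative) that reads the environment variables set by the launcher and reports which phase of the auto-mapping flow a process is in:

```python
import os

def describe_auto_mapping_phase():
    # launch.py sets PADDLE_ENABLE_AUTO_MAPPING only when --enable_auto_mapping is on.
    enabled = os.getenv("PADDLE_ENABLE_AUTO_MAPPING") is not None
    need_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING", "").lower() == "true"
    if not enabled:
        return "auto mapping disabled: normal single-phase launch"
    if need_mapping:
        # First launch: the parallelizer builds programs for every rank, runs the
        # mapping pass, writes the file at PADDLE_RANK_MAPPING_PATH, then relaunches.
        return "first phase: compute rank mapping, then relaunch"
    # Second launch: the mapping file already exists and processes are placed by it.
    return ("second phase: using rank mapping at " +
            os.getenv("PADDLE_RANK_MAPPING_PATH", "<unset>"))

print(describe_auto_mapping_phase())
```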