[Tools] use torchrun instead of torch.distributed.launch (dmlc#6304)
9rum authored and DominikaJedynak committed Mar 12, 2024
1 parent 2eb2b65 commit d46cf5b
Showing 2 changed files with 10 additions and 10 deletions.
tests/tools/test_launch.py (8 changes: 4 additions & 4 deletions)
@@ -21,7 +21,7 @@ def test_simple(self):
master_port=1234,
)
expected = (
"python3.7 -m torch.distributed.launch "
"python3.7 -m torch.distributed.run "
"--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
"--master_port=1234 path/to/some/trainer.py arg1 arg2"
)
@@ -41,7 +41,7 @@ def test_chained_udf(self):
master_port=1234,
)
expected = (
"cd path/to && python3.7 -m torch.distributed.launch "
"cd path/to && python3.7 -m torch.distributed.run "
"--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
"--master_port=1234 path/to/some/trainer.py arg1 arg2"
)
@@ -68,7 +68,7 @@ def test_py_versions(self):
master_port=1234,
)
expected = (
"{python_bin} -m torch.distributed.launch ".format(
"{python_bin} -m torch.distributed.run ".format(
python_bin=py_bin
)
+ "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
@@ -221,7 +221,7 @@ def common_checks():
assert "DGL_ROLE=client" in cmd
assert "DGL_GROUP_ID=0" in cmd
assert (
f"python3 -m torch.distributed.launch --nproc_per_node={args.num_trainers} --nnodes={num_machines}"
f"python3 -m torch.distributed.run --nproc_per_node={args.num_trainers} --nnodes={num_machines}"
in cmd
)
assert "--master_addr=127.0.0" in cmd
tools/launch.py (12 changes: 6 additions & 6 deletions)
@@ -17,7 +17,7 @@

def cleanup_proc(get_all_remote_pids, conn):
"""This process tries to clean up the remote training tasks."""
print("cleanupu process runs")
print("cleanup process runs")
# This process should not handle SIGINT.
signal.signal(signal.SIGINT, signal.SIG_IGN)

@@ -228,7 +228,7 @@ def construct_torch_dist_launcher_cmd(
cmd_str.
"""
torch_cmd_template = (
"-m torch.distributed.launch "
"-m torch.distributed.run "
"--nproc_per_node={nproc_per_node} "
"--nnodes={nnodes} "
"--node_rank={node_rank} "
@@ -252,10 +252,10 @@ def wrap_udf_in_torch_dist_launcher(
master_addr: str,
master_port: int,
) -> str:
"""Wraps the user-defined function (udf_command) with the torch.distributed.launch module.
"""Wraps the user-defined function (udf_command) with the torch.distributed.run module.
Example: if udf_command is "python3 run/some/trainer.py arg1 arg2", then new_df_command becomes:
"python3 -m torch.distributed.launch <TORCH DIST ARGS> run/some/trainer.py arg1 arg2
"python3 -m torch.distributed.run <TORCH DIST ARGS> run/some/trainer.py arg1 arg2
udf_command is assumed to consist of pre-commands (optional) followed by the python launcher script (required):
Examples:
@@ -310,7 +310,7 @@ def wrap_udf_in_torch_dist_launcher(
# transforms the udf_command from:
# python path/to/dist_trainer.py arg0 arg1
# to:
-# python -m torch.distributed.launch [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
+# python -m torch.distributed.run [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
# Note: if there are multiple python commands in `udf_command`, this may do the Wrong Thing, eg launch each
# python command within the torch distributed launcher.
new_udf_command = udf_command.replace(
@@ -593,7 +593,7 @@ def submit_jobs(args, udf_command, dry_run=False):
master_port = get_available_port(master_addr)
for node_id, host in enumerate(hosts):
ip, _ = host
-# Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.launch ... UDF`
+# Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.run ... UDF`
torch_dist_udf_command = wrap_udf_in_torch_dist_launcher(
udf_command=udf_command,
num_trainers=args.num_trainers,
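For context, torch.distributed.run is the module behind the torchrun console script, which supersedes the deprecated torch.distributed.launch. Below is a minimal sketch, not part of this commit, of how the launcher prefix touched here is assembled from the template in construct_torch_dist_launcher_cmd; the values are illustrative, and the --master_addr/--master_port fields are inferred from the test expectations above.

# Minimal sketch (assumed values): build the torch.distributed.run prefix the
# way the template above formats it, then prepend it to an illustrative user
# command, matching the expected strings in tests/tools/test_launch.py.
torch_cmd_template = (
    "-m torch.distributed.run "
    "--nproc_per_node={nproc_per_node} "
    "--nnodes={nnodes} "
    "--node_rank={node_rank} "
    "--master_addr={master_addr} "
    "--master_port={master_port}"
)

torch_dist_cmd = torch_cmd_template.format(
    nproc_per_node=2,
    nnodes=2,
    node_rank=1,
    master_addr="127.0.0.1",
    master_port=1234,
)

udf_command = "path/to/some/trainer.py arg1 arg2"
print("python3.7 " + torch_dist_cmd + " " + udf_command)
# python3.7 -m torch.distributed.run --nproc_per_node=2 --nnodes=2
# --node_rank=1 --master_addr=127.0.0.1 --master_port=1234
# path/to/some/trainer.py arg1 arg2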
