
[LAUNCH] add distributed launch check tools #44495

Merged 3 commits on Jul 28, 2022
3 changes: 2 additions & 1 deletion python/paddle/distributed/launch/context/__init__.py
@@ -101,6 +101,7 @@ def continous_log(self) -> bool:
         return False
 
     def set_env_in_args(self):
+        # this logic may not be proper to replace args with env, but ...
Collaborator: Why change this?

Member Author: bug fix

         for k, v in env_args_mapping.items():
             if k in self.envs:
-                setattr(self.args, v, self.envs[k])
+                setattr(self.args, v, type(getattr(self.args, v))(self.envs[k]))
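Why the cast amounts to a bug fix, in a minimal standalone sketch (the argument name and env mapping below are made up for illustration, not the launcher's real ones): environment variables are always strings, so writing them into the parsed args verbatim would silently turn a numeric option into a string; re-applying the attribute's original type preserves it.

class Args:
    nnodes = 1  # parsed by the CLI as an int

args = Args()
envs = {'PADDLE_NNODES': '2'}                    # env values arrive as strings
env_args_mapping = {'PADDLE_NNODES': 'nnodes'}   # hypothetical mapping

for k, v in env_args_mapping.items():
    if k in envs:
        # cast the string back to the attribute's original type (int here)
        setattr(args, v, type(getattr(args, v))(envs[k]))

assert args.nnodes == 2 and isinstance(args.nnodes, int)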
10 changes: 7 additions & 3 deletions python/paddle/distributed/launch/controllers/collective.py
@@ -97,10 +97,14 @@ def build_pod(self):
"PADDLE_TRAINERS_NUM": "{}".format(global_size),
"PADDLE_RANK_IN_NODE": str(i),
}
if self.pod.replicas == 1:
e.update({selected_dev_key: ",".join(selected_dev_list)})
if len(selected_dev_list) > 0:
kuizhiqing marked this conversation as resolved.
Show resolved Hide resolved
if self.pod.replicas == 1:
e.update({selected_dev_key: ",".join(selected_dev_list)})
else:
e.update({selected_dev_key: selected_dev_list[i]})
else:
e.update({selected_dev_key: selected_dev_list[i]})
e.update({'PADDLE_DISTRI_BACKEND': 'gloo'})

self.add_container(envs=e, log_tag=i)

return True
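The effect of this change, in a rough standalone sketch (selected_dev_key names whichever device-selection environment variable the node uses; CUDA_VISIBLE_DEVICES below is only a stand-in, not necessarily the launcher's real key): with accelerators present, a single-replica pod exposes the whole device list to its one container while a multi-replica pod gives each container only its own slot; with no devices at all, the launcher now falls back to the gloo backend instead of emitting an empty device list.

def device_env_for(i, replicas, selected_dev_list,
                   selected_dev_key='CUDA_VISIBLE_DEVICES'):  # key is an assumption
    e = {}
    if len(selected_dev_list) > 0:
        if replicas == 1:
            e[selected_dev_key] = ",".join(selected_dev_list)   # one proc sees all devices
        else:
            e[selected_dev_key] = selected_dev_list[i]           # one device per proc
    else:
        e['PADDLE_DISTRI_BACKEND'] = 'gloo'                      # CPU-only fallback
    return e

print(device_env_for(0, 1, ['0', '1']))   # {'CUDA_VISIBLE_DEVICES': '0,1'}
print(device_env_for(1, 2, ['0', '1']))   # {'CUDA_VISIBLE_DEVICES': '1'}
print(device_env_for(0, 3, []))           # {'PADDLE_DISTRI_BACKEND': 'gloo'}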
14 changes: 13 additions & 1 deletion python/paddle/distributed/launch/plugins/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import six
+import os
 
 __all__ = []

@@ -60,4 +61,15 @@ def rewrite_host_ip(ctx):
         ctx.node.ip = ctx.args.host
 
 
-enabled_plugins = [collective_compatible, rewrite_host_ip, process_args]
+def test_mode(ctx):
+    if ctx.args.training_script == 'test':
Collaborator @sneaxiy (Jul 21, 2022): Recommend to rename to run_check. You can keep the name in this PR. The final name should be discussed offline in the meeting.
+        ctx.logger.info('Paddle Distributed Test begin...')
+        if int(ctx.args.nnodes) < 2:
+            ctx.args.nnodes = 2
+        ctx.args.training_script = '{}/test.py'.format(
+            os.path.dirname(__file__))
+
+
+enabled_plugins = [
+    test_mode, collective_compatible, rewrite_host_ip, process_args
+]
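Reading the plugin above, the check is presumably triggered by passing the literal name test as the training script, i.e. something along the lines of python -m paddle.distributed.launch test (the exact CLI spelling is inferred from the training_script == 'test' comparison, not stated in the PR); the plugin then raises nnodes to at least 2 and substitutes the bundled plugins/test.py shown below.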
100 changes: 100 additions & 0 deletions python/paddle/distributed/launch/plugins/test.py
@@ -0,0 +1,100 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.vision.models import ResNet
from paddle.vision.models.resnet import BottleneckBlock
from paddle.io import Dataset, BatchSampler, DataLoader

base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4

epoch = 3
batch_num = 1
batch_size = 1
class_dim = 102


# define a random dataset
class RandomDataset(Dataset):

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([3, 224, 224]).astype('float32')
        label = np.random.randint(0, class_dim - 1, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples


def optimizer_setting(parameter_list=None):
    optimizer = paddle.optimizer.Momentum(
        learning_rate=base_lr,
        momentum=momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(l2_decay),
        parameters=parameter_list)
    return optimizer


def train_resnet():
    fleet.init(is_collective=True)

    resnet = ResNet(BottleneckBlock, 18, num_classes=class_dim)
    optimizer = optimizer_setting(parameter_list=resnet.parameters())
    optimizer = fleet.distributed_optimizer(optimizer)
    resnet = fleet.distributed_model(resnet)

    dataset = RandomDataset(batch_num * batch_size)
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=2)

    print("Distributed training start...")
    for eop in range(epoch):
        resnet.train()

        for batch_id, data in enumerate(train_loader()):
            img, label = data
            label.stop_gradient = True

            out = resnet(img)
            loss = paddle.nn.functional.cross_entropy(input=out, label=label)
            avg_loss = paddle.mean(x=loss)
            acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
            acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

            avg_loss.backward()
            optimizer.step()
            resnet.clear_gradients()

            print("[Epoch %d, batch %d] loss: %.5f, acc1: %.5f, acc5: %.5f" %
                  (eop, batch_id, avg_loss, acc_top1, acc_top5))

    print("Distributed training completed")


if __name__ == '__main__':
    import os
    nnodes = os.getenv('PADDLE_NNODES')
    cn = os.getenv('PADDLE_LOCAL_SIZE')
    print(f"Prepare distributed training with {nnodes} nodes {cn} cards")
    train_resnet()
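If the check succeeds, each launched worker should print the banner above with the node and card counts taken from PADDLE_NNODES and PADDLE_LOCAL_SIZE, followed by "Distributed training start..." and "Distributed training completed"; a hang or a collective-backend initialization error at this stage most likely points at the cluster, network, or device configuration rather than at the user's own training script.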