[Distributed] Specify the graph format for distributed training #2948

Merged · 7 commits · May 26, 2021
Changes from 3 commits
5 changes: 4 additions & 1 deletion python/dgl/distributed/dist_context.py
@@ -93,11 +93,14 @@ def initialize(ip_config, num_servers=1, num_workers=0,
'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server'
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
formats = [f.strip() for f in formats]
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'))
os.environ.get('DGL_CONF_PATH'),
graph_format=formats)
serv.start()
sys.exit()
else:
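For context, a minimal sketch of how dist_context.initialize interprets the new environment variable, mirroring the parsing in the diff above (the example value is illustrative):

import os

# Comma-separated list of graph formats; surrounding whitespace is stripped.
# When the variable is unset, the default is ['csc'].
formats = [f.strip() for f in os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')]
# e.g. DGL_GRAPH_FORMAT="csr, coo" yields ['csr', 'coo']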
14 changes: 10 additions & 4 deletions python/dgl/distributed/dist_graph.py
@@ -60,8 +60,8 @@ def __getstate__(self):
def __setstate__(self, state):
self._graph_name = state

def _copy_graph_to_shared_mem(g, graph_name):
new_g = g.shared_memory(graph_name, formats='csc')
def _copy_graph_to_shared_mem(g, graph_name, graph_format):
new_g = g.shared_memory(graph_name, formats=graph_format)
# We should share the node/edge data to the client explicitly instead of putting them
# in the KVStore because some of the node/edge data may be duplicated.
new_g.ndata['inner_node'] = _to_shared_mem(g.ndata['inner_node'],
@@ -289,9 +289,12 @@ class DistGraphServer(KVServer):
The path of the config file generated by the partition tool.
disable_shared_mem : bool
Disable shared memory.
graph_format : str or list of str
The graph formats.
'''
def __init__(self, server_id, ip_config, num_servers,
num_clients, part_config, disable_shared_mem=False):
num_clients, part_config, disable_shared_mem=False,
graph_format='csc'):
super(DistGraphServer, self).__init__(server_id=server_id,
ip_config=ip_config,
num_servers=num_servers,
@@ -307,8 +310,11 @@ def __init__(self, server_id, ip_config, num_servers,
self.client_g, node_feats, edge_feats, self.gpb, graph_name, \
ntypes, etypes = load_partition(part_config, self.part_id)
print('load ' + graph_name)
# Create the graph formats specified by the users.
self.client_g = self.client_g.formats(graph_format)
self.client_g.create_formats_()
if not disable_shared_mem:
self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)
self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name, graph_format)

if not disable_shared_mem:
self.gpb.shared_memory(graph_name)
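A minimal usage sketch of the extended constructor, following the signature in the diff; the ip_config and part_config paths are placeholders:

from dgl.distributed import DistGraphServer

# Hypothetical example: serve partition 0 with both CSC and COO materialized,
# so the server never needs to convert formats at sampling time.
serv = DistGraphServer(server_id=0,
                       ip_config='ip_config.txt',     # placeholder path
                       num_servers=1,
                       num_clients=1,
                       part_config='data/graph.json', # placeholder path
                       graph_format=['csc', 'coo'])
serv.start()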
8 changes: 5 additions & 3 deletions tests/distributed/test_distributed_sampling.py
@@ -16,9 +16,10 @@
from dgl.distributed import DistGraphServer, DistGraph


def start_server(rank, tmpdir, disable_shared_mem, graph_name):
def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format='csc'):
g = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem)
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem,
graph_format=graph_format)
g.start()


@@ -102,7 +103,8 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_find_edges'))
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1,
'test_find_edges', ['csr', 'coo']))
p.start()
time.sleep(1)
pserver_list.append(p)
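The test requests ['csr', 'coo'] here, presumably because find_edges maps edge IDs to endpoint pairs, which the COO format stores directly. A hedged sketch of how client code could confirm which formats were materialized on a DGLGraph obtained from the server (formats() without arguments is the standard DGLGraph query):

# Assuming g is the local DGLGraph backing the partition.
created = g.formats()['created']        # e.g. ['coo', 'csr']
assert set(created) == {'coo', 'csr'}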
5 changes: 5 additions & 0 deletions tools/launch.py
@@ -185,6 +185,7 @@ def submit_jobs(args, udf_command):
client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + str(args.num_omp_threads)
if os.environ.get('PYTHONPATH') is not None:
client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH')
client_cmd = client_cmd + ' ' + 'DGL_GRAPH_FORMAT=' + str(args.graph_format)

torch_cmd = '-m torch.distributed.launch'
torch_cmd = torch_cmd + ' ' + '--nproc_per_node=' + str(args.num_trainers)
@@ -248,6 +249,10 @@ def main():
help='The number of OMP threads in the server process. \
It should be small if server processes and trainer processes run on \
the same machine. By default, it is 1.')
parser.add_argument('--graph_format', type=str, default='csc',
help='The format of the graph structure of each partition. \
The allowed formats are csr, csc and coo. A user can specify multiple \
formats, separated by ",". For example, "csr,csc" creates both formats.')
args, udf_command = parser.parse_known_args()
assert len(udf_command) == 1, 'Please provide user command line.'
assert args.num_trainers is not None and args.num_trainers > 0, \
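With the new flag, a single launch can request multiple formats. In the illustrative command below, everything besides --graph_format is elided because it depends on the cluster setup:

python3 tools/launch.py --graph_format csr,coo ... "python3 train_dist.py"

launch.py then prefixes each client command with DGL_GRAPH_FORMAT=csr,coo, which dist_context.initialize parses as shown earlier.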