Skip to content

Commit

Permalink
[DistDGL] sort node/edge_map to obtain expected id ranges (#5872)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying committed Jun 15, 2023
1 parent d9da420 commit 0ea36f3
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 4 deletions.
32 changes: 28 additions & 4 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def _dump_part_config(part_config, part_metadata):
'''Format and dump part config.
'''
part_metadata = _format_part_metadata(part_metadata, _etype_tuple_to_str)
with open(part_config, 'w') as outfile:
json.dump(part_metadata, outfile, sort_keys=True, indent=4)
with open(part_config, "w") as outfile:
json.dump(part_metadata, outfile, sort_keys=False, indent=4)

def _save_graphs(filename, g_list, formats=None, sort_etypes=False):
'''Preprocess partitions before saving:
Expand Down Expand Up @@ -335,8 +335,32 @@ def load_partition_book(part_config, part_id):

node_map = _get_part_ranges(node_map)
edge_map = _get_part_ranges(edge_map)
return RangePartitionBook(part_id, num_parts, node_map, edge_map, ntypes, etypes), \
part_metadata['graph_name'], ntypes, etypes

# Sort the node/edge maps by the node/edge type ID.
node_map = dict(sorted(node_map.items(), key=lambda x: ntypes[x[0]]))
edge_map = dict(sorted(edge_map.items(), key=lambda x: etypes[x[0]]))

def _assert_is_sorted(id_map):
    """Check that the ID ranges in ``id_map`` are globally non-decreasing.

    The ranges are read partition by partition (``num_parts`` is taken from
    the enclosing scope) and, within one partition, in the iteration order
    of ``id_map``. Every boundary in that flattened sequence must be less
    than or equal to the boundary that follows it.

    Parameters
    ----------
    id_map : dict
        Maps a type name to its per-partition ID ranges.
        NOTE(review): presumably each value is a ``(num_parts, 2)`` array
        of ``[start, end]`` pairs -- confirm against ``_get_part_ranges``.

    Raises
    ------
    AssertionError
        If the flattened boundaries are not in non-decreasing order.
    """
    # An empty map is trivially sorted. Without this guard the 3-D indexing
    # below fails with an unrelated IndexError on an empty array.
    if not id_map:
        return
    id_ranges = np.array(list(id_map.values()))
    # Partition-major order: all types of partition 0, then all types of
    # partition 1, and so on.
    ids = np.array(
        [id_ranges[:, i, :] for i in range(num_parts)]
    ).flatten()
    # Raise explicitly rather than via an ``assert`` statement so that the
    # validation is not stripped when Python runs with -O.
    if not np.all(ids[:-1] <= ids[1:]):
        raise AssertionError(f"The node/edge map is not sorted: {ids}")

_assert_is_sorted(node_map)
_assert_is_sorted(edge_map)

return (
RangePartitionBook(
part_id, num_parts, node_map, edge_map, ntypes, etypes
),
part_metadata["graph_name"],
ntypes,
etypes,
)

def _get_orig_ids(g, sim_g, orig_nids, orig_eids):
'''Convert/construct the original node IDs and edge IDs.
Expand Down
112 changes: 112 additions & 0 deletions tests/distributed/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,3 +683,115 @@ def test_UnknownPartitionBook():
except Exception as e:
if not isinstance(e, TypeError):
raise e


def test_not_sorted_node_edge_map():
    """Load a partition book from a config whose maps are not in type-ID order.

    In the config below, ``node_map`` lists "user" (type ID 1) before "item"
    (type ID 0), and the ``edge_map`` keys are in alphabetical order while
    their type IDs run 0, 2, 1, 3. The test writes the config to a temporary
    file, loads the partition book for each of the two partitions, and
    checks the local node/edge type offsets it reports.
    """
    # Partition config with node/edge maps that are not sorted by type ID.
    unsorted_config = """
{
"edge_map": {
"item:likes-rev:user": [
[
0,
100
],
[
1000,
1500
]
],
"user:follows-rev:user": [
[
300,
600
],
[
2100,
2800
]
],
"user:follows:user": [
[
100,
300
],
[
1500,
2100
]
],
"user:likes:item": [
[
600,
1000
],
[
2800,
3600
]
]
},
"etypes": {
"item:likes-rev:user": 0,
"user:follows-rev:user": 2,
"user:follows:user": 1,
"user:likes:item": 3
},
"graph_name": "test_graph",
"halo_hops": 1,
"node_map": {
"user": [
[
100,
300
],
[
600,
1000
]
],
"item": [
[
0,
100
],
[
300,
600
]
]
},
"ntypes": {
"user": 1,
"item": 0
},
"num_edges": 3600,
"num_nodes": 1000,
"num_parts": 2,
"part-0": {
"edge_feats": "part0/edge_feat.dgl",
"node_feats": "part0/node_feat.dgl",
"part_graph": "part0/graph.dgl"
},
"part-1": {
"edge_feats": "part1/edge_feat.dgl",
"node_feats": "part1/node_feat.dgl",
"part_graph": "part1/graph.dgl"
},
"part_method": "metis"
}
"""
    # (partition ID, expected local_ntype_offset, expected local_etype_offset)
    expected_offsets = [
        (0, [0, 100, 300], [0, 100, 300, 600, 1000]),
        (1, [0, 300, 700], [0, 500, 1100, 1800, 2600]),
    ]
    with tempfile.TemporaryDirectory() as tmp_dir:
        part_config = os.path.join(tmp_dir, "test_graph.json")
        print(part_config)
        with open(part_config, "w") as config_file:
            config_file.write(unsorted_config)
        # Part 0 first, then part 1.
        for part_id, ntype_offset, etype_offset in expected_offsets:
            gpb, _, _, _ = load_partition_book(part_config, part_id)
            assert gpb.local_ntype_offset == ntype_offset
            assert gpb.local_etype_offset == etype_offset

0 comments on commit 0ea36f3

Please sign in to comment.