Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DistDGL] sort node/edge_map to obtain expected id ranges #5872

Merged
merged 3 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _dump_part_config(part_config, part_metadata):
"""Format and dump part config."""
part_metadata = _format_part_metadata(part_metadata, _etype_tuple_to_str)
with open(part_config, "w") as outfile:
json.dump(part_metadata, outfile, sort_keys=True, indent=4)
json.dump(part_metadata, outfile, sort_keys=False, indent=4)


def _save_graphs(filename, g_list, formats=None, sort_etypes=False):
Expand Down Expand Up @@ -420,6 +420,22 @@ def load_partition_book(part_config, part_id):

node_map = _get_part_ranges(node_map)
edge_map = _get_part_ranges(edge_map)

# Sort the node/edge maps by the node/edge type ID.
node_map = dict(sorted(node_map.items(), key=lambda x: ntypes[x[0]]))
edge_map = dict(sorted(edge_map.items(), key=lambda x: etypes[x[0]]))

def _assert_is_sorted(id_map):
ids = [[]] * num_parts
for values in id_map.values():
for i, v in enumerate(values):
ids[i].append(v)
ids = np.array(ids).flatten()
assert np.all(ids[:-1] <= ids[1:]), "The node/edge map is not sorted."

_assert_is_sorted(node_map)
_assert_is_sorted(edge_map)

return (
RangePartitionBook(
part_id, num_parts, node_map, edge_map, ntypes, etypes
Expand Down
75 changes: 75 additions & 0 deletions tests/distributed/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,3 +737,78 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
assert g.get_etype_id(edge_type) == type_id
assert new_g.node_type_offset is None
assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge)


def test_not_sorted_node_edge_map():
# Partition configure file which includes not sorted node/edge map.
part_config_str = """
{
"edge_map": {
"item:likes-rev:user": [
[
0,
10000
]
],
"user:follows-rev:user": [
[
20000,
30000
]
],
"user:follows:user": [
[
10000,
20000
]
],
"user:likes:item": [
[
30000,
40000
]
]
},
"etypes": {
"item:likes-rev:user": 0,
"user:follows-rev:user": 2,
"user:follows:user": 1,
"user:likes:item": 3
},
"graph_name": "test_graph",
"halo_hops": 1,
"node_map": {
"user": [
[
10000,
30000
]
],
"item": [
[
0,
10000
]
]
},
"ntypes": {
"user": 1,
"item": 0
},
"num_edges": 40000,
"num_nodes": 30000,
"num_parts": 1,
"part-0": {
"edge_feats": "part0/edge_feat.dgl",
"node_feats": "part0/node_feat.dgl",
"part_graph": "part0/graph.dgl"
},
"part_method": "metis"
}
"""
with tempfile.TemporaryDirectory() as test_dir:
part_config = os.path.join(test_dir, "test_graph.json")
print(part_config)
with open(part_config, "w") as file:
file.write(part_config_str)
load_partition_book(part_config, 0)
Rhett-Ying marked this conversation as resolved.
Show resolved Hide resolved