Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
thvasilo committed Jul 22, 2024
1 parent 9dc890b commit 4d61a06
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import concurrent
import concurrent.futures
import copy
from functools import partial
import json
import logging
import os
import time
from functools import partial

import numpy as np

Expand Down Expand Up @@ -1257,9 +1257,9 @@ def get_homogeneous(g, balance_ntypes):
for name in g.edges[etype].data:
if name in [EID, "inner_edge"]:
continue

Check warning on line 1259 in python/dgl/distributed/partition.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
edge_feats[
_etype_tuple_to_str(etype) + "/" + name
] = F.gather_row(g.edges[etype].data[name], local_edges)
edge_feats[_etype_tuple_to_str(etype) + "/" + name] = (
F.gather_row(g.edges[etype].data[name], local_edges)
)
else:
for ntype in g.ntypes:
if len(g.ntypes) > 1:
Expand Down Expand Up @@ -1294,9 +1294,9 @@ def get_homogeneous(g, balance_ntypes):
for name in g.edges[etype].data:
if name in [EID, "inner_edge"]:
continue
edge_feats[
_etype_tuple_to_str(etype) + "/" + name
] = F.gather_row(g.edges[etype].data[name], local_edges)
edge_feats[_etype_tuple_to_str(etype) + "/" + name] = (
F.gather_row(g.edges[etype].data[name], local_edges)
)
# delete `orig_id` from ndata/edata
del part.ndata["orig_id"]
del part.edata["orig_id"]
Expand Down Expand Up @@ -1570,17 +1570,20 @@ def convert_partition(part_id, graph_formats):
)
torch.save(csc_graph, csc_graph_path)


return os.path.relpath(csc_graph_path, os.path.dirname(part_config))
# Update graph path.

# Iterate over partitions.
convert_with_format = partial(convert_partition, graph_formats=graph_formats)
with concurrent.futures.ProcessPoolExecutor(max_workers=min(num_parts, n_jobs)) as executor:
for part_id, part_path in enumerate(executor.map(convert_with_format, range(num_parts))):
new_part_meta[f"part-{part_id}"][
"part_graph_graphbolt"
] = part_path
convert_with_format = partial(
convert_partition, graph_formats=graph_formats
)
with concurrent.futures.ProcessPoolExecutor(
max_workers=min(num_parts, n_jobs)
) as executor:
for part_id, part_path in enumerate(
executor.map(convert_with_format, range(num_parts))
):
new_part_meta[f"part-{part_id}"]["part_graph_graphbolt"] = part_path

# Save dtype info into partition config.
# [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more
Expand Down

0 comments on commit 4d61a06

Please sign in to comment.