Skip to content

Commit

Permalink
dist partition
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Sep 8, 2024
1 parent 7f18105 commit 7786f66
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
27 changes: 25 additions & 2 deletions tools/dispatch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def submit_jobs(args) -> str:
argslist += "--log-level {} ".format(args.log_level)
argslist += "--save-orig-nids " if args.save_orig_nids else ""
argslist += "--save-orig-eids " if args.save_orig_eids else ""
argslist += "--use-graphbolt" if args.use_graphbolt else ""
argslist += "--use-graphbolt " if args.use_graphbolt else ""
argslist += "--store-inner-edge " if args.store_inner_edge else ""
argslist += "--store-inner-node " if args.store_inner_node else ""
argslist += "--store-eids " if args.store_eids else ""
argslist += (
f"--graph-formats {args.graph_formats} " if args.graph_formats else ""
)
Expand Down Expand Up @@ -164,6 +167,26 @@ def main():
action="store_true",
help="Use GraphBolt for distributed train.",
)
parser.add_argument(
"--store-inner-node",
action="store_true",
default=False,
help="Store inner nodes.",
)

parser.add_argument(
"--store-inner-edge",
action="store_true",
default=False,
help="Store inner edges.",
)

parser.add_argument(
"--store-eids",
action="store_true",
default=False,
help="Store edge IDs.",
)
parser.add_argument(
"--graph-formats",
type=str,
Expand All @@ -175,7 +198,7 @@ def main():
)

args, _ = parser.parse_known_args()

assert args.store_inner_edge is True
fmt = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(
format=fmt,
Expand Down
13 changes: 8 additions & 5 deletions tools/distpartitioning/data_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from convert_partition import create_dgl_object, create_metadata_json
from convert_partition import create_graph_object, create_metadata_json
from dataset_utils import get_dataset
from dist_lookup import DistLookupService
from globalids import (
Expand Down Expand Up @@ -1323,7 +1323,7 @@ def prepare_local_data(src_data, local_part_id):
etypes_map,
orig_nids,
orig_eids,
) = create_dgl_object(
) = create_graph_object(
schema_map,
rank + local_part_id * world_size,
local_node_data,
Expand All @@ -1334,9 +1334,12 @@ def prepare_local_data(src_data, local_part_id):
schema_map[constants.STR_NUM_NODES_PER_TYPE],
),
edge_typecounts,
params.save_orig_nids,
params.save_orig_eids,
params.use_graphbolt,
return_orig_nids=params.save_orig_nids,
return_orig_eids=params.save_orig_eids,
use_graphbolt=params.use_graphbolt,

Check warning on line 1339 in tools/distpartitioning/data_shuffle.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
store_inner_node=params.store_inner_node,
store_inner_edge=params.store_inner_edge,
store_eids=params.store_eids,
)
sort_etypes = len(etypes_map) > 1
local_node_features = prepare_local_data(
Expand Down

0 comments on commit 7786f66

Please sign in to comment.