Skip to content

Commit

Permalink
[GraphBolt][CUDA] Enable recent optimizations in the examples.
Browse files: browse the repository at this point in the history.
  • Loading branch information
mfbalin committed Aug 15, 2024
1 parent 0d68130 commit ba86576
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 4 deletions.
5 changes: 4 additions & 1 deletion examples/graphbolt/link_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = datapipe.sample_neighbor(
- graph, args.fanout if is_train else [-1]
+ graph,
+ args.fanout if is_train else [-1],
+ overlap_fetch=args.storage_device == "pinned",
+ asynchronous=args.storage_device != "cpu",
)

############################################################################
Expand Down
1 change: 1 addition & 0 deletions examples/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def create_dataloader(
graph,
fanout if job != "infer" else [-1],
overlap_fetch=args.storage_device == "pinned",
+ asynchronous=args.storage_device != "cpu",
)

############################################################################
Expand Down
1 change: 1 addition & 0 deletions examples/graphbolt/pyg/labor/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def create_dataloader(
graph,
fanout if job != "infer" else [-1],
overlap_fetch=args.overlap_graph_fetch,
+ asynchronous=args.graph_device != "cpu",
**kwargs,
)
# Copy the data to the specified device.
Expand Down
1 change: 1 addition & 0 deletions examples/graphbolt/pyg/node_classification_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def create_dataloader(
overlap_fetch=args.overlap_graph_fetch,
num_gpu_cached_edges=args.num_gpu_cached_edges,
gpu_cache_threshold=args.gpu_graph_caching_threshold,
+ asynchronous=args.graph_device != "cpu",
)
# Copy the data to the specified device.
if args.feature_device != "cpu" and need_copy:
Expand Down
12 changes: 9 additions & 3 deletions examples/graphbolt/rgcn/hetero_rgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ def create_dataloader(
# `fanouts`:
# The number of neighbors to sample for each node in each layer.
datapipe = datapipe.sample_neighbor(
- graph, fanouts=fanouts, overlap_fetch=args.overlap_graph_fetch
+ graph,
+ fanouts=fanouts,
+ overlap_fetch=args.overlap_graph_fetch,
+ asynchronous=args.asynchronous,
)

# Fetch the features for each node in the mini-batch.
Expand Down Expand Up @@ -571,10 +574,13 @@ def main(args):

# Move the dataset to the pinned memory to enable GPU access.
args.overlap_graph_fetch = False
+ args.asynchronous = False
if device == torch.device("cuda"):
- g.pin_memory_()
- features.pin_memory_()
+ g = g.pin_memory_()
+ features = features.pin_memory_()
+ # Enable optimizations for sampling on the GPU.
args.overlap_graph_fetch = True
+ args.asynchronous = True

feat_size = features.size("node", "paper", "feat")[0]

Expand Down

0 comments on commit ba86576

Please sign in to comment.