Commit
Merge branch 'master' into gb_cuda_pipelined_sampling_optimization
mfbalin committed Feb 4, 2024
2 parents 6b21742 + 346197c commit 0a21186
Showing 6 changed files with 84 additions and 46 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/dgl.graphbolt.rst
@@ -187,6 +187,7 @@ Utilities
etype_tuple_to_str
isin
seed
index_select
expand_indptr
add_reverse_edges
exclude_seed_edges
36 changes: 14 additions & 22 deletions examples/sampling/graphbolt/link_prediction.py
@@ -79,14 +79,10 @@ def forward(self, blocks, x):
hidden_x = F.relu(hidden_x)
return hidden_x

def inference(self, graph, features, dataloader, device):
def inference(self, graph, features, dataloader, storage_device):
"""Conduct layer-wise inference to get all the node embeddings."""
feature = features.read("node", None, "feat")

buffer_device = torch.device("cpu")
# Enable pin_memory for faster CPU to GPU data transfer if the
# model is running on a GPU.
pin_memory = buffer_device != device
pin_memory = storage_device == "pinned"
buffer_device = torch.device("cpu" if pin_memory else storage_device)

print("Start node embedding inference.")
for layer_idx, layer in enumerate(self.layers):
@@ -99,17 +95,17 @@ def inference(self, graph, features, dataloader, device):
device=buffer_device,
pin_memory=pin_memory,
)
feature = feature.to(device)
for step, data in tqdm.tqdm(enumerate(dataloader)):
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
for data in tqdm.tqdm(dataloader):
# len(blocks) = 1
hidden_x = layer(data.blocks[0], data.node_features["feat"])
if not is_last_layer:
hidden_x = F.relu(hidden_x)
# By design, our seed nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
buffer_device, non_blocking=True
)
feature = y
if not is_last_layer:
features.update("node", None, "feat", y)

return y
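
A minimal sketch of the storage_device convention the refactored inference uses (shapes are illustrative; make_output_buffer is a hypothetical helper, not part of the example): "pinned" means the output buffer lives in page-locked CPU memory so results computed on the GPU can be copied back asynchronously, while any other value is treated as a literal torch device.

    import torch

    def make_output_buffer(storage_device, num_nodes, hidden_size):
        # "pinned" -> page-locked CPU buffer; otherwise allocate on the given device.
        pin_memory = storage_device == "pinned"
        buffer_device = torch.device("cpu" if pin_memory else storage_device)
        buffer = torch.empty(
            num_nodes, hidden_size, device=buffer_device, pin_memory=pin_memory
        )
        return buffer, buffer_device

    y, buffer_device = make_output_buffer("cpu", 1024, 256)
    hidden_x = torch.randn(64, 256)  # stand-in for one minibatch of layer outputs
    # With a pinned buffer and CUDA inputs, this copy can overlap with computation.
    y[0:64] = hidden_x.to(buffer_device, non_blocking=True)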

@@ -185,7 +181,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# [Role]:
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.sample_neighbor(
graph, args.fanout if is_train else [-1]
)

############################################################################
# [Input]:
@@ -213,12 +211,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# A FeatureFetcher object to fetch node features.
# [Role]:
# Initialize a feature fetcher for fetching features of the sampled
# subgraphs. This step is skipped in evaluation/inference because features
# are updated as a whole during it, thus storing features in minibatch is
# unnecessary.
# subgraphs.
############################################################################
if is_train:
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

############################################################################
# [Input]:
@@ -286,15 +281,12 @@ def evaluate(args, model, graph, features, all_nodes_set, valid_set, test_set):
model.eval()
evaluator = Evaluator(name="ogbl-citation2")

# Since we need to use all neighborhoods for evaluation, we set the fanout
# to -1.
args.fanout = [-1]
dataloader = create_dataloader(
args, graph, features, all_nodes_set, is_train=False
)

# Compute node embeddings for the entire graph.
node_emb = model.inference(graph, features, dataloader, args.device)
node_emb = model.inference(graph, features, dataloader, args.storage_device)
results = []

# Loop over both validation and test sets.
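For context, a rough sketch of how the evaluation dataloader is assembled after this change. Only sample_neighbor and fetch_feature appear in the hunks above; ItemSampler, copy_to, DataLoader, and the batch size are assumptions about the surrounding, unchanged example code.

    import dgl.graphbolt as gb

    def create_eval_dataloader(graph, features, itemset, device):
        datapipe = gb.ItemSampler(itemset, batch_size=1024)
        # Layer-wise inference uses the full neighborhood, expressed as a fanout of -1.
        datapipe = datapipe.sample_neighbor(graph, [-1])
        # Feature fetching is no longer skipped during evaluation; each minibatch
        # carries its node features in data.node_features["feat"].
        datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
        datapipe = datapipe.copy_to(device)
        return gb.DataLoader(datapipe)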
29 changes: 11 additions & 18 deletions examples/sampling/graphbolt/node_classification.py
@@ -131,11 +129,9 @@ def create_dataloader(
# A FeatureFetcher object to fetch node features.
# [Role]:
# Initialize a feature fetcher for fetching features of the sampled
# subgraphs. This step is skipped in inference because features are updated
# as a whole during it, thus storing features in minibatch is unnecessary.
# subgraphs.
############################################################################
if job != "infer":
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

############################################################################
# [Step-5]:
@@ -194,14 +192,10 @@ def forward(self, blocks, x):
hidden_x = self.dropout(hidden_x)
return hidden_x

def inference(self, graph, features, dataloader, device):
def inference(self, graph, features, dataloader, storage_device):
"""Conduct layer-wise inference to get all the node embeddings."""
feature = features.read("node", None, "feat")

buffer_device = torch.device("cpu")
# Enable pin_memory for faster CPU to GPU data transfer if the
# model is running on a GPU.
pin_memory = buffer_device != device
pin_memory = storage_device == "pinned"
buffer_device = torch.device("cpu" if pin_memory else storage_device)

for layer_idx, layer in enumerate(self.layers):
is_last_layer = layer_idx == len(self.layers) - 1
@@ -213,19 +207,18 @@ def inference(self, graph, features, dataloader, device):
device=buffer_device,
pin_memory=pin_memory,
)
feature = feature.to(device)

for step, data in tqdm(enumerate(dataloader)):
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
for data in tqdm(dataloader):
# len(blocks) = 1
hidden_x = layer(data.blocks[0], data.node_features["feat"])
if not is_last_layer:
hidden_x = F.relu(hidden_x)
hidden_x = self.dropout(hidden_x)
# By design, our output nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
buffer_device
)
feature = y
if not is_last_layer:
features.update("node", None, "feat", y)

return y

@@ -245,7 +238,7 @@ def layerwise_infer(
num_workers=args.num_workers,
job="infer",
)
pred = model.inference(graph, features, dataloader, args.device)
pred = model.inference(graph, features, dataloader, args.storage_device)
pred = pred[test_set._items[0]]
label = test_set._items[1].to(pred.device)

28 changes: 28 additions & 0 deletions python/dgl/graphbolt/base.py
@@ -20,6 +20,7 @@
"FutureWaiter",
"EndMarker",
"isin",
"index_select",
"expand_indptr",
"CSCFormatBase",
"seed",
@@ -107,6 +108,33 @@ def expand_indptr(indptr, dtype=None, node_ids=None, output_size=None):
)


def index_select(tensor, index):
"""Returns a new tensor which indexes the input tensor along dimension dim
using the entries in index.
The returned tensor has the same number of dimensions as the original tensor
(tensor). The first dimension has the same size as the length of index;
other dimensions have the same size as in the original tensor.
When tensor is a pinned tensor and index.is_cuda is True, the operation runs
on the CUDA device and the returned tensor will also be on CUDA.
Parameters
----------
tensor : torch.Tensor
The input tensor.
index : torch.Tensor
The 1-D tensor containing the indices to index.
Returns
-------
torch.Tensor
The indexed input tensor, equivalent to tensor[index].
"""
assert index.dim() == 1, "Index should be 1D tensor."
return torch.ops.graphbolt.index_select(tensor, index)


def etype_tuple_to_str(c_etype):
"""Convert canonical etype from tuple to string.
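A small usage sketch for the newly exported gb.index_select; the pinned-memory path is exercised only when CUDA is available.

    import torch
    import dgl.graphbolt as gb

    tensor = torch.arange(12, dtype=torch.float32).reshape(6, 2)
    index = torch.tensor([0, 3, 5])
    # For CPU inputs this is equivalent to tensor[index].
    print(gb.index_select(tensor, index))

    if torch.cuda.is_available():
        pinned = tensor.pin_memory()
        cuda_index = index.cuda()
        # Pinned tensor + CUDA index: the gather runs on the GPU and the result
        # is a CUDA tensor.
        result = gb.index_select(pinned, cuda_index)
        assert result.is_cuda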
8 changes: 2 additions & 6 deletions python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -7,6 +7,7 @@
import numpy as np
import torch

from ..base import index_select
from ..feature_store import Feature
from .basic_feature_store import BasicFeatureStore
from .ondisk_metadata import OnDiskFeatureData
@@ -117,7 +118,7 @@ def read(self, ids: torch.Tensor = None):
if self._tensor.is_pinned():
return self._tensor.cuda()
return self._tensor
return torch.ops.graphbolt.index_select(self._tensor, ids)
return index_select(self._tensor, ids)

def size(self):
"""Get the size of the feature.
@@ -144,11 +145,6 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None):
updated.
"""
if ids is None:
assert self.size() == value.size()[1:], (
f"ids is None, so the entire feature will be updated. "
f"But the size of the feature is {self.size()}, "
f"while the size of the value is {value.size()[1:]}."
)
self._tensor = value
else:
assert ids.shape[0] == value.shape[0], (
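The removed assertion guarded whole-tensor updates (ids is None). After this change the examples write hidden-size embeddings back into "feat" between layers, so the replacement tensor legitimately has a different feature dimension than the stored one. A minimal sketch of the relaxed behavior (the top-level TorchBasedFeature import path is an assumption):

    import torch
    from dgl.graphbolt import TorchBasedFeature

    feat = TorchBasedFeature(torch.randn(10, 100))  # original 100-dim features
    # Whole-tensor update: replace the stored features with 16-dim embeddings.
    feat.update(torch.randn(10, 16))
    assert feat.read().shape == (10, 16)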
28 changes: 28 additions & 0 deletions tests/python/pytorch/graphbolt/test_base.py
@@ -250,6 +250,34 @@ def test_isin_non_1D_dim():
gb.isin(elements, test_elements)


@pytest.mark.parametrize(
"dtype",
[
torch.bool,
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
],
)
@pytest.mark.parametrize("idtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("pinned", [False, True])
def test_index_select(dtype, idtype, pinned):
if F._default_context_str != "gpu" and pinned:
pytest.skip("Pinned tests are available only on GPU.")
tensor = torch.tensor([[2, 3], [5, 5], [20, 13]], dtype=dtype)
tensor = tensor.pin_memory() if pinned else tensor.to(F.ctx())
index = torch.tensor([0, 2], dtype=idtype, device=F.ctx())
gb_result = gb.index_select(tensor, index)
torch_result = tensor.to(F.ctx())[index.long()]
assert torch.equal(torch_result, gb_result)


def torch_expand_indptr(indptr, dtype, nodes=None):
if nodes is None:
nodes = torch.arange(len(indptr) - 1, dtype=dtype, device=indptr.device)
