Skip to content

Commit

Permalink
add pipelined sampling optimization datapipes
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jan 27, 2024
1 parent 2473722 commit 1a5abae
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Neighbor subgraph samplers for GraphBolt."""

from functools import partial

import torch
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe, Mapper
Expand All @@ -13,6 +15,66 @@
# Names exported by ``from <module> import *``.
__all__ = ["NeighborSampler", "LayerNeighborSampler", "NeighborSampler2"]


@functional_datapipe("fetch_insubgraph_data")
class FetchInsubgraphData(Mapper):
    """Fetch the in-subgraph induced by each minibatch's input nodes.

    For every minibatch, the CSC slices that belong to
    ``minibatch.input_nodes`` are gathered from the full graph with
    ``torch.ops.graphbolt.index_select_csc`` and wrapped in a new
    ``FusedCSCSamplingGraph``. The datapipe yields ``(subgraph, minibatch)``
    pairs.

    Parameters
    ----------
    datapipe : DataPipe
        Upstream datapipe yielding minibatches that expose ``input_nodes``.
    sample_per_layer_obj : object
        Object exposing ``sampler`` (a bound method whose ``__self__`` is the
        graph to slice) and ``prob_name`` (name of the edge attribute used as
        sampling probability or mask, or ``None``).
    """

    def __init__(self, datapipe, sample_per_layer_obj):
        super().__init__(datapipe, self._fetch_per_layer)
        # The sampler is a bound method; the object it is bound to is the
        # graph whose data is sliced below.
        self.graph = sample_per_layer_obj.sampler.__self__
        self.prob_name = sample_per_layer_obj.prob_name

    def _fetch_per_layer(self, minibatch):
        """Slice the graph down to the in-edges of ``minibatch.input_nodes``.

        Parameters
        ----------
        minibatch : object
            Minibatch exposing ``input_nodes``.

        Returns
        -------
        tuple
            ``(subgraph, minibatch)`` where ``subgraph`` is built from the
            sliced ``indptr``/``indices`` plus, when available, the sliced
            ``type_per_edge`` and probability/mask edge attribute.
        """
        index = minibatch.input_nodes
        index_select = partial(
            torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
        )
        indptr, indices = index_select(self.graph.indices, index, None)
        # The first call establishes the number of selected edges; it is
        # handed to the later calls, which select the same edges.
        output_size = len(indices)

        if self.graph.type_per_edge is not None:
            _, type_per_edge = index_select(
                self.graph.type_per_edge, index, output_size
            )
        else:
            type_per_edge = None

        # Only slice the probability/mask attribute when one was requested
        # and the graph actually carries edge attributes. The attribute name
        # is stored as ``prob_name`` in ``__init__``.
        edge_attributes = {}
        graph_attributes = self.graph.edge_attributes
        if self.prob_name is not None and graph_attributes is not None:
            _, probs_or_mask = index_select(
                graph_attributes[self.prob_name], index, output_size
            )
            edge_attributes[self.prob_name] = probs_or_mask

        # NOTE(review): confirm FusedCSCSamplingGraph accepts these
        # constructor arguments; its definition is not visible here.
        subgraph = FusedCSCSamplingGraph(
            indptr,
            indices,
            type_per_edge=type_per_edge,
            edge_attributes=edge_attributes,
        )

        return subgraph, minibatch


@functional_datapipe("sample_per_layer_from_fetched_subgraph")
class SamplePerLayerFromFetchedSubgraph(Mapper):
    """Sample neighbor edges for a single layer from a pre-fetched subgraph.

    Consumes ``(subgraph, minibatch)`` pairs and runs the configured sampling
    method on the subgraph instead of on the full graph.

    The class has its own name so that it is not shadowed by the
    ``SamplePerLayer`` datapipe defined later in this module; the functional
    name ``sample_per_layer_from_fetched_subgraph`` is how it is reached.

    Parameters
    ----------
    datapipe : DataPipe
        Upstream datapipe yielding ``(subgraph, minibatch)`` pairs.
    sample_per_layer_obj : object
        Object exposing ``sampler`` (its ``__name__`` selects the method to
        call on the subgraph), ``fanout``, ``replace`` and ``prob_name``.
    """

    def __init__(self, datapipe, sample_per_layer_obj):
        super().__init__(datapipe, self._sample_per_layer_from_fetched_subgraph)
        # Keep only the method *name*: the same sampling method is looked up
        # on each fetched subgraph rather than called on the original graph.
        self.sampler_name = sample_per_layer_obj.sampler.__name__
        self.fanout = sample_per_layer_obj.fanout
        self.replace = sample_per_layer_obj.replace
        self.prob_name = sample_per_layer_obj.prob_name

    def _sample_per_layer_from_fetched_subgraph(self, subgraph_minibatch):
        """Sample one layer from the fetched subgraph.

        Parameters
        ----------
        subgraph_minibatch : tuple
            ``(subgraph, minibatch)`` pair.

        Returns
        -------
        tuple
            ``(sampled_subgraph, minibatch)``; ``sampled_subgraph`` has its
            ``original_column_node_ids`` set to ``minibatch.input_nodes``.
        """
        subgraph, minibatch = subgraph_minibatch

        sample = getattr(subgraph, self.sampler_name)
        sampled_subgraph = sample(
            minibatch.input_nodes, self.fanout, self.replace, self.prob_name
        )
        sampled_subgraph.original_column_node_ids = minibatch.input_nodes

        return sampled_subgraph, minibatch


@functional_datapipe("sample_per_layer")
class SamplePerLayer(Mapper):
"""Sample neighbor edges from a graph for a single layer."""
Expand Down

0 comments on commit 1a5abae

Please sign in to comment.