Skip to content

Commit

Permalink
[GraphBolt] Implement labor dependent minibatching - python side. (#7208
Browse files Browse the repository at this point in the history
)
  • Loading branch information
mfbalin committed Mar 13, 2024
1 parent 93990a9 commit a272efe
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 16 deletions.
7 changes: 6 additions & 1 deletion graphbolt/include/graphbolt/cuda_sampling_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ namespace ops {
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be
* a 1D tensor, with the number of elements equaling the total number of edges.
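* @param random_seed An optional random seed tensor for the sampler, used when
* layer=True. If not provided, a seed is drawn from the thread-local random
* engine.
* @param seed2_contribution The contribution of the second random seed, in the
* range [0, 1), used when layer=True.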
*
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information.
Expand All @@ -54,7 +57,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt);
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
torch::optional<torch::Tensor> random_seed = torch::nullopt,
float seed2_contribution = .0f);

/**
* @brief Return the subgraph induced on the inbound edges of the given nodes.
Expand Down
7 changes: 6 additions & 1 deletion graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,19 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* probabilities corresponding to each neighboring edge of a node. It must be
* a 1D floating-point or boolean tensor, with the number of elements
* equalling the total number of edges.
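* @param random_seed An optional random seed tensor for the sampler, used when
* layer=True. If not provided, a seed is drawn from the thread-local random
* engine.
* @param seed2_contribution The contribution of the second random seed, in the
* range [0, 1), used when layer=True.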
*
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information.
*/
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const;
torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const;

/**
* @brief Sample neighboring edges of the given nodes with a temporal
Expand Down
14 changes: 11 additions & 3 deletions graphbolt/src/cuda/neighbor_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> probs_or_mask) {
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> random_seed_tensor,
float seed2_contribution) {
TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
// Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
// are all resident on the GPU. If not, it is better to first extract them
Expand Down Expand Up @@ -202,8 +204,14 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto coo_rows = ExpandIndptrImpl(
sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
num_edges = coo_rows.size(0);
const continuous_seed random_seed(RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()));
const continuous_seed random_seed = [&] {
if (random_seed_tensor.has_value()) {
return continuous_seed(random_seed_tensor.value(), seed2_contribution);
} else {
return continuous_seed{RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max())};
}
}();
auto output_indptr = torch::empty_like(sub_indptr);
torch::Tensor picked_eids;
torch::Tensor output_indices;
Expand Down
23 changes: 18 additions & 5 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,9 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const {
torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const {
auto probs_or_mask = this->EdgeAttribute(probs_name);

// If nodes does not have a value, then we expect all arguments to be resident
Expand All @@ -642,7 +644,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
c10::DeviceType::CUDA, "SampleNeighbors", {
return ops::SampleNeighbors(
indptr_, indices_, nodes, fanouts, replace, layer, return_eids,
type_per_edge_, probs_or_mask);
type_per_edge_, probs_or_mask, random_seed, seed2_contribution);
});
}
TORCH_CHECK(nodes.has_value(), "Nodes can not be None on the CPU.");
Expand All @@ -658,9 +660,20 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}

if (layer) {
const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
SamplerArgs<SamplerType::LABOR> args = [&] {
if (random_seed.has_value()) {
return SamplerArgs<SamplerType::LABOR>{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
Expand Down
8 changes: 7 additions & 1 deletion python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,11 @@ def _sample_neighbors(
nodes,
fanouts.tolist(),
replace,
False,
False, # is_labor
return_eids,
probs_name,
None, # random_seed, labor parameter
0, # seed2_contribution, labor_parameter
)

def sample_layer_neighbors(
Expand All @@ -746,6 +748,8 @@ def sample_layer_neighbors(
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
random_seed: torch.Tensor = None,
seed2_contribution: float = 0.0,
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
Expand Down Expand Up @@ -833,6 +837,8 @@ def sample_layer_neighbors(
True,
has_original_eids,
probs_name,
random_seed,
seed2_contribution,
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)

Expand Down
88 changes: 83 additions & 5 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,17 @@ def __init__(self, datapipe, sample_per_layer_obj):

def _sample_per_layer_from_fetched_subgraph(self, minibatch):
subgraph = minibatch.sampled_subgraphs[0]

kwargs = {
key[1:]: getattr(minibatch, key)
for key in ["_random_seed", "_seed2_contribution"]
if hasattr(minibatch, key)
}
sampled_subgraph = getattr(subgraph, self.sampler_name)(
minibatch._subgraph_seed_nodes,
self.fanout,
self.replace,
self.prob_name,
**kwargs,
)
delattr(minibatch, "_subgraph_seed_nodes")
sampled_subgraph.original_column_node_ids = minibatch._seed_nodes
Expand All @@ -172,8 +177,17 @@ def __init__(self, datapipe, sampler, fanout, replace, prob_name):
self.prob_name = prob_name

def _sample_per_layer(self, minibatch):
kwargs = {
key[1:]: getattr(minibatch, key)
for key in ["_random_seed", "_seed2_contribution"]
if hasattr(minibatch, key)
}
subgraph = self.sampler(
minibatch._seed_nodes, self.fanout, self.replace, self.prob_name
minibatch._seed_nodes,
self.fanout,
self.replace,
self.prob_name,
**kwargs,
)
minibatch.sampled_subgraphs.insert(0, subgraph)
return minibatch
Expand Down Expand Up @@ -244,10 +258,56 @@ def __init__(
prob_name,
deduplicate,
sampler,
layer_dependency=None,
batch_dependency=None,
):
if sampler.__name__ == "sample_layer_neighbors":
self._init_seed(batch_dependency)
super().__init__(
datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
datapipe,
graph,
fanouts,
replace,
prob_name,
deduplicate,
sampler,
layer_dependency,
)

def _init_seed(self, batch_dependency):
self.rng = torch.random.manual_seed(
torch.randint(0, int(1e18), size=tuple())
)
self.cnt = [-1, int(batch_dependency)]
self.random_seed = torch.empty(
2 if self.cnt[1] > 1 else 1, dtype=torch.int64
)
self.random_seed.random_(generator=self.rng)

def _set_seed(self, minibatch):
self.cnt[0] += 1
if self.cnt[1] > 0 and self.cnt[0] % self.cnt[1] == 0:
self.random_seed[0] = self.random_seed[-1]
self.random_seed[-1:].random_(generator=self.rng)
minibatch._random_seed = self.random_seed.clone()
minibatch._seed2_contribution = (
0.0
if self.cnt[1] <= 1
else (self.cnt[0] % self.cnt[1]) / self.cnt[1]
)
minibatch._iter = self.cnt[0]
return minibatch

@staticmethod
def _increment_seed(minibatch):
minibatch._random_seed = 1 + minibatch._random_seed
return minibatch

@staticmethod
def _delattr_dependency(minibatch):
delattr(minibatch, "_random_seed")
delattr(minibatch, "_seed2_contribution")
return minibatch

@staticmethod
def _prepare(node_type_to_id, minibatch):
Expand Down Expand Up @@ -277,11 +337,22 @@ def _set_input_nodes(minibatch):

# pylint: disable=arguments-differ
def sampling_stages(
self, datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
self,
datapipe,
graph,
fanouts,
replace,
prob_name,
deduplicate,
sampler,
layer_dependency,
):
datapipe = datapipe.transform(
partial(self._prepare, graph.node_type_to_id)
)
is_labor = sampler.__name__ == "sample_layer_neighbors"
if is_labor:
datapipe = datapipe.transform(self._set_seed)
for fanout in reversed(fanouts):
# Convert fanout to tensor.
if not isinstance(fanout, torch.Tensor):
Expand All @@ -290,7 +361,10 @@ def sampling_stages(
sampler, fanout, replace, prob_name
)
datapipe = datapipe.compact_per_layer(deduplicate)

if is_labor and not layer_dependency:
datapipe = datapipe.transform(self._increment_seed)
if is_labor:
datapipe = datapipe.transform(self._delattr_dependency)
return datapipe.transform(self._set_input_nodes)


Expand Down Expand Up @@ -504,6 +578,8 @@ def __init__(
replace=False,
prob_name=None,
deduplicate=True,
layer_dependency=False,
batch_dependency=1,
):
super().__init__(
datapipe,
Expand All @@ -513,4 +589,6 @@ def __init__(
prob_name,
deduplicate,
graph.sample_layer_neighbors,
layer_dependency,
batch_dependency,
)
56 changes: 56 additions & 0 deletions tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,59 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
assert len(expected_results) == len(new_results)
for a, b in zip(expected_results, new_results):
assert repr(a) == repr(b)


@pytest.mark.parametrize("layer_dependency", [False, True])
@pytest.mark.parametrize("overlap_graph_fetch", [False, True])
def test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch):
    """Check layer and batch dependency of the layer-neighbor sampler.

    The graph is built so that node 0 has ``num_edges`` neighbors (nodes
    ``1..num_edges``) and every other node has none. The item set repeats
    seed node 0 ``batch_dependency + 1`` times with a batch size of 1.
    """
    num_edges = 200
    # indptr is [0, num_edges, num_edges, ...]: only node 0 has edges.
    indptr_head = torch.zeros(1, dtype=torch.int64)
    indptr_tail = torch.ones(num_edges + 1, dtype=torch.int64) * num_edges
    csc_indptr = torch.cat((indptr_head, indptr_tail))
    indices = torch.arange(1, num_edges + 1)
    graph = gb.fused_csc_sampling_graph(
        csc_indptr.int(),
        indices.int(),
    ).to(F.ctx())

    seeded_state = torch.manual_seed(123).get_state()
    torch.random.set_rng_state(seeded_state)

    batch_dependency = 100
    seed_items = torch.zeros(batch_dependency + 1).int()
    itemset = gb.ItemSet(seed_items, names="seed_nodes")
    datapipe = gb.ItemSampler(itemset, batch_size=1).copy_to(F.ctx())
    fanouts = [5, 5]
    datapipe = datapipe.sample_layer_neighbor(
        graph,
        fanouts,
        layer_dependency=layer_dependency,
        batch_dependency=batch_dependency,
    )
    dataloader = gb.DataLoader(
        datapipe, overlap_graph_fetch=overlap_graph_fetch
    )
    minibatches = list(dataloader)
    assert len(minibatches) == batch_dependency + 1

    first = minibatches[0]
    second_subgraph_rows = first.sampled_subgraphs[1].original_row_node_ids
    if layer_dependency:
        assert torch.equal(first.input_nodes, second_subgraph_rows)
    else:
        assert first.input_nodes.size(0) > second_subgraph_rows.size(0)

    def last_layer_rows(minibatch):
        # Row node ids of the last sampled subgraph, as a set.
        rows = minibatch.sampled_subgraphs[-1].original_row_node_ids
        return set(rows.tolist())

    # Consecutive minibatches must overlap in at least ``fanouts[-1]`` nodes,
    # while the accumulated difference over the whole run must reach
    # ``fanouts[-1]``.
    last_fanout = fanouts[-1]
    delta = 0
    for current, following in zip(minibatches[:-1], minibatches[1:]):
        shared = last_layer_rows(current) & last_layer_rows(following)
        intersect_len = len(shared)
        assert intersect_len >= last_fanout
        delta += 1 + last_fanout - intersect_len
    assert delta >= last_fanout

0 comments on commit a272efe

Please sign in to comment.