Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt] Implement labor dependent minibatching - python side. #7208

Merged
merged 5 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion graphbolt/include/graphbolt/cuda_sampling_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ namespace ops {
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be
* a 1D tensor, with the number of elements equaling the total number of edges.
* @param random_seed The random seed for the sampler for layer=True.
* @param seed2_contribution The contribution of the second random seed, [0, 1)
* for layer=True.
*
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information.
Expand All @@ -54,7 +57,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt);
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
torch::optional<torch::Tensor> random_seed = torch::nullopt,
float seed2_contribution = .0f);

/**
* @brief Return the subgraph induced on the inbound edges of the given nodes.
Expand Down
7 changes: 6 additions & 1 deletion graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,19 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* probabilities corresponding to each neighboring edge of a node. It must be
* a 1D floating-point or boolean tensor, with the number of elements
* equalling the total number of edges.
* @param random_seed The random seed for the sampler for layer=True.
* @param seed2_contribution The contribution of the second random seed,
* [0, 1) for layer=True.
*
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information.
*/
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const;
torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const;

/**
* @brief Sample neighboring edges of the given nodes with a temporal
Expand Down
14 changes: 11 additions & 3 deletions graphbolt/src/cuda/neighbor_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> probs_or_mask) {
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> random_seed_tensor,
float seed2_contribution) {
TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
// Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
// are all resident on the GPU. If not, it is better to first extract them
Expand Down Expand Up @@ -202,8 +204,14 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto coo_rows = ExpandIndptrImpl(
sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
num_edges = coo_rows.size(0);
const continuous_seed random_seed(RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()));
const continuous_seed random_seed = [&] {
if (random_seed_tensor.has_value()) {
return continuous_seed(random_seed_tensor.value(), seed2_contribution);
} else {
return continuous_seed{RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max())};
}
}();
auto output_indptr = torch::empty_like(sub_indptr);
torch::Tensor picked_eids;
torch::Tensor output_indices;
Expand Down
23 changes: 18 additions & 5 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,9 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const {
torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const {
auto probs_or_mask = this->EdgeAttribute(probs_name);

// If nodes does not have a value, then we expect all arguments to be resident
Expand All @@ -642,7 +644,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
c10::DeviceType::CUDA, "SampleNeighbors", {
return ops::SampleNeighbors(
indptr_, indices_, nodes, fanouts, replace, layer, return_eids,
type_per_edge_, probs_or_mask);
type_per_edge_, probs_or_mask, random_seed, seed2_contribution);
});
}
TORCH_CHECK(nodes.has_value(), "Nodes can not be None on the CPU.");
Expand All @@ -658,9 +660,20 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}

if (layer) {
const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
SamplerArgs<SamplerType::LABOR> args = [&] {
if (random_seed.has_value()) {
return SamplerArgs<SamplerType::LABOR>{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
Expand Down
8 changes: 7 additions & 1 deletion python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,11 @@ def _sample_neighbors(
nodes,
fanouts.tolist(),
replace,
False,
False, # is_labor
return_eids,
probs_name,
None, # random_seed, labor parameter
0, # seed2_contribution, labor_parameter
)

def sample_layer_neighbors(
Expand All @@ -746,6 +748,8 @@ def sample_layer_neighbors(
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
random_seed: torch.Tensor = None,
seed2_contribution: float = 0.0,
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
Expand Down Expand Up @@ -833,6 +837,8 @@ def sample_layer_neighbors(
True,
has_original_eids,
probs_name,
random_seed,
seed2_contribution,
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)

Expand Down
88 changes: 83 additions & 5 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,17 @@ def __init__(self, datapipe, sample_per_layer_obj):

def _sample_per_layer_from_fetched_subgraph(self, minibatch):
subgraph = minibatch.sampled_subgraphs[0]

kwargs = {
key[1:]: getattr(minibatch, key)
for key in ["_random_seed", "_seed2_contribution"]
if hasattr(minibatch, key)
}
sampled_subgraph = getattr(subgraph, self.sampler_name)(
minibatch._subgraph_seed_nodes,
self.fanout,
self.replace,
self.prob_name,
**kwargs,
)
delattr(minibatch, "_subgraph_seed_nodes")
sampled_subgraph.original_column_node_ids = minibatch._seed_nodes
Expand All @@ -172,8 +177,17 @@ def __init__(self, datapipe, sampler, fanout, replace, prob_name):
self.prob_name = prob_name

def _sample_per_layer(self, minibatch):
kwargs = {
key[1:]: getattr(minibatch, key)
for key in ["_random_seed", "_seed2_contribution"]
if hasattr(minibatch, key)
}
subgraph = self.sampler(
minibatch._seed_nodes, self.fanout, self.replace, self.prob_name
minibatch._seed_nodes,
self.fanout,
self.replace,
self.prob_name,
**kwargs,
)
minibatch.sampled_subgraphs.insert(0, subgraph)
return minibatch
Expand Down Expand Up @@ -244,10 +258,56 @@ def __init__(
prob_name,
deduplicate,
sampler,
layer_dependency=None,
batch_dependency=None,
):
if sampler.__name__ == "sample_layer_neighbors":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of checking the sampler name here, why not move the self._init_seed(batch_dependency) call into the __init__ of LayerNeighborSampler?

self._init_seed(batch_dependency)
super().__init__(
datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
datapipe,
graph,
fanouts,
replace,
prob_name,
deduplicate,
sampler,
layer_dependency,
)

# Initialise the seed state used for labor-dependent minibatching: a generator
# (self.rng), a two-element counter list (self.cnt) and an int64 seed tensor
# (self.random_seed).
def _init_seed(self, batch_dependency):
# NOTE(review): torch.random.manual_seed re-seeds PyTorch's default generator
# and returns it, so self.rng is presumably the global generator rather than a
# private one -- confirm this is intended.
self.rng = torch.random.manual_seed(
torch.randint(0, int(1e18), size=tuple())
)
# cnt[0] is the running minibatch counter (starts at -1 and is incremented in
# _set_seed); cnt[1] is batch_dependency converted to int.
self.cnt = [-1, int(batch_dependency)]
# Two seeds are stored only when batch_dependency > 1, otherwise one.
self.random_seed = torch.empty(
2 if self.cnt[1] > 1 else 1, dtype=torch.int64
)
# Fill the seed tensor in place with values drawn from self.rng.
self.random_seed.random_(generator=self.rng)

# Advance the minibatch counter and stamp the minibatch with the seed state.
# Writes minibatch._random_seed, minibatch._seed2_contribution and
# minibatch._iter, then returns the same minibatch object.
def _set_seed(self, minibatch):
self.cnt[0] += 1
# Condition holds when cnt[1] is positive and the counter is a multiple of it.
if self.cnt[1] > 0 and self.cnt[0] % self.cnt[1] == 0:
# Copy the last stored seed into slot 0.
self.random_seed[0] = self.random_seed[-1]
# Redraw the last stored seed from self.rng.
self.random_seed[-1:].random_(generator=self.rng)
# A clone is attached so later changes to self.random_seed do not alter the
# tensor held by this minibatch.
minibatch._random_seed = self.random_seed.clone()
# 0.0 when cnt[1] <= 1; otherwise the fraction (cnt[0] % cnt[1]) / cnt[1],
# which lies in [0, 1).
minibatch._seed2_contribution = (
0.0
if self.cnt[1] <= 1
else (self.cnt[0] % self.cnt[1]) / self.cnt[1]
)
# Record the counter value on the minibatch.
minibatch._iter = self.cnt[0]
return minibatch

# Replace minibatch._random_seed with the same seed plus one and return the
# minibatch. sampling_stages adds this transform when is_labor is set and
# layer_dependency is falsy.
@staticmethod
def _increment_seed(minibatch):
# Builds a new tensor (1 + seed) rather than modifying the existing one in place.
minibatch._random_seed = 1 + minibatch._random_seed
return minibatch

# Remove the temporary seed attributes (_random_seed, _seed2_contribution) from
# the minibatch and return it. delattr raises AttributeError if either
# attribute is missing.
@staticmethod
def _delattr_dependency(minibatch):
delattr(minibatch, "_random_seed")
delattr(minibatch, "_seed2_contribution")
return minibatch

@staticmethod
def _prepare(node_type_to_id, minibatch):
Expand Down Expand Up @@ -277,11 +337,22 @@ def _set_input_nodes(minibatch):

# pylint: disable=arguments-differ
def sampling_stages(
self, datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
self,
datapipe,
graph,
fanouts,
replace,
prob_name,
deduplicate,
sampler,
layer_dependency,
):
datapipe = datapipe.transform(
partial(self._prepare, graph.node_type_to_id)
)
is_labor = sampler.__name__ == "sample_layer_neighbors"
if is_labor:
datapipe = datapipe.transform(self._set_seed)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: is it possible to move this sample_layer_neighbors-specific code into its own class, instead of hacking it into the parent class?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you propose we do that without replicating the whole implementation?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to replicate some code if needed. It is usually error-prone to have a parent class's behavior depend on a child class's type, which is an anti-"OOP" pattern.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that, but I feel an anti-OOP pattern is better than replicating the code. I plan to base the temporal samplers on NeighborSamplerImpl as well, so that the GPU sampling optimizations can be enabled for the temporal case too.

for fanout in reversed(fanouts):
# Convert fanout to tensor.
if not isinstance(fanout, torch.Tensor):
Expand All @@ -290,7 +361,10 @@ def sampling_stages(
sampler, fanout, replace, prob_name
)
datapipe = datapipe.compact_per_layer(deduplicate)

if is_labor and not layer_dependency:
datapipe = datapipe.transform(self._increment_seed)
if is_labor:
datapipe = datapipe.transform(self._delattr_dependency)
return datapipe.transform(self._set_input_nodes)


Expand Down Expand Up @@ -504,6 +578,8 @@ def __init__(
replace=False,
prob_name=None,
deduplicate=True,
layer_dependency=False,
batch_dependency=1,
):
super().__init__(
datapipe,
Expand All @@ -513,4 +589,6 @@ def __init__(
prob_name,
deduplicate,
graph.sample_layer_neighbors,
layer_dependency,
batch_dependency,
)
56 changes: 56 additions & 0 deletions tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,59 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
assert len(expected_results) == len(new_results)
for a, b in zip(expected_results, new_results):
assert repr(a) == repr(b)


# Checks labor-dependent minibatching through sample_layer_neighbor for every
# combination of layer_dependency and overlap_graph_fetch.
@pytest.mark.parametrize("layer_dependency", [False, True])
@pytest.mark.parametrize("overlap_graph_fetch", [False, True])
def test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch):
num_edges = 200
# indptr is [0, num_edges, num_edges, ...]: the first node owns every edge and
# all remaining nodes own none.
csc_indptr = torch.cat(
(
torch.zeros(1, dtype=torch.int64),
torch.ones(num_edges + 1, dtype=torch.int64) * num_edges,
)
)
# Edge endpoints are the node ids 1..num_edges.
indices = torch.arange(1, num_edges + 1)
graph = gb.fused_csc_sampling_graph(
csc_indptr.int(),
indices.int(),
).to(F.ctx())
# Fix the global RNG state so the run is reproducible.
torch.random.set_rng_state(torch.manual_seed(123).get_state())
batch_dependency = 100
# batch_dependency + 1 seed items, all equal to node 0.
itemset = gb.ItemSet(
torch.zeros(batch_dependency + 1).int(), names="seed_nodes"
)
# batch_size=1, so each minibatch carries a single seed node.
datapipe = gb.ItemSampler(itemset, batch_size=1).copy_to(F.ctx())
fanouts = [5, 5]
datapipe = datapipe.sample_layer_neighbor(
graph,
fanouts,
layer_dependency=layer_dependency,
batch_dependency=batch_dependency,
)
dataloader = gb.DataLoader(
datapipe, overlap_graph_fetch=overlap_graph_fetch
)
res = list(dataloader)
# One minibatch per item.
assert len(res) == batch_dependency + 1
# With layer_dependency the first minibatch's input nodes must equal the row
# node ids of sampled_subgraphs[1]; without it the input node set must be
# strictly larger.
if layer_dependency:
assert torch.equal(
res[0].input_nodes,
res[0].sampled_subgraphs[1].original_row_node_ids,
)
else:
assert res[0].input_nodes.size(0) > res[0].sampled_subgraphs[
1
].original_row_node_ids.size(0)
# Compare the last sampled subgraph's row node ids between consecutive
# minibatches.
# NOTE(review): delta presumably measures how far the sampled set drifts over
# the batch_dependency window -- confirm against the sampler's design.
delta = 0
for i in range(batch_dependency):
res_current = (
res[i].sampled_subgraphs[-1].original_row_node_ids.tolist()
)
res_next = (
res[i + 1].sampled_subgraphs[-1].original_row_node_ids.tolist()
)
intersect_len = len(set(res_current).intersection(set(res_next)))
# Consecutive minibatches must share at least fanouts[-1] node ids.
assert intersect_len >= fanouts[-1]
delta += 1 + fanouts[-1] - intersect_len
assert delta >= fanouts[-1]
Loading