-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GraphBolt] Implement labor dependent minibatching - python side. #7208
Changes from all commits
ce907f3
c48fd33
0f820b8
1a176e6
7a0ec6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -146,12 +146,17 @@ def __init__(self, datapipe, sample_per_layer_obj): | |
|
||
def _sample_per_layer_from_fetched_subgraph(self, minibatch): | ||
subgraph = minibatch.sampled_subgraphs[0] | ||
|
||
kwargs = { | ||
key[1:]: getattr(minibatch, key) | ||
for key in ["_random_seed", "_seed2_contribution"] | ||
if hasattr(minibatch, key) | ||
} | ||
sampled_subgraph = getattr(subgraph, self.sampler_name)( | ||
minibatch._subgraph_seed_nodes, | ||
self.fanout, | ||
self.replace, | ||
self.prob_name, | ||
**kwargs, | ||
) | ||
delattr(minibatch, "_subgraph_seed_nodes") | ||
sampled_subgraph.original_column_node_ids = minibatch._seed_nodes | ||
|
@@ -172,8 +177,17 @@ def __init__(self, datapipe, sampler, fanout, replace, prob_name): | |
self.prob_name = prob_name | ||
|
||
def _sample_per_layer(self, minibatch): | ||
kwargs = { | ||
key[1:]: getattr(minibatch, key) | ||
for key in ["_random_seed", "_seed2_contribution"] | ||
if hasattr(minibatch, key) | ||
} | ||
subgraph = self.sampler( | ||
minibatch._seed_nodes, self.fanout, self.replace, self.prob_name | ||
minibatch._seed_nodes, | ||
self.fanout, | ||
self.replace, | ||
self.prob_name, | ||
**kwargs, | ||
) | ||
minibatch.sampled_subgraphs.insert(0, subgraph) | ||
return minibatch | ||
|
@@ -244,10 +258,56 @@ def __init__( | |
prob_name, | ||
deduplicate, | ||
sampler, | ||
layer_dependency=None, | ||
batch_dependency=None, | ||
): | ||
if sampler.__name__ == "sample_layer_neighbors": | ||
self._init_seed(batch_dependency) | ||
super().__init__( | ||
datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler | ||
datapipe, | ||
graph, | ||
fanouts, | ||
replace, | ||
prob_name, | ||
deduplicate, | ||
sampler, | ||
layer_dependency, | ||
) | ||
|
||
def _init_seed(self, batch_dependency): | ||
self.rng = torch.random.manual_seed( | ||
torch.randint(0, int(1e18), size=tuple()) | ||
) | ||
self.cnt = [-1, int(batch_dependency)] | ||
self.random_seed = torch.empty( | ||
2 if self.cnt[1] > 1 else 1, dtype=torch.int64 | ||
) | ||
self.random_seed.random_(generator=self.rng) | ||
|
||
def _set_seed(self, minibatch): | ||
self.cnt[0] += 1 | ||
if self.cnt[1] > 0 and self.cnt[0] % self.cnt[1] == 0: | ||
self.random_seed[0] = self.random_seed[-1] | ||
self.random_seed[-1:].random_(generator=self.rng) | ||
minibatch._random_seed = self.random_seed.clone() | ||
minibatch._seed2_contribution = ( | ||
0.0 | ||
if self.cnt[1] <= 1 | ||
else (self.cnt[0] % self.cnt[1]) / self.cnt[1] | ||
) | ||
minibatch._iter = self.cnt[0] | ||
return minibatch | ||
|
||
@staticmethod | ||
def _increment_seed(minibatch): | ||
minibatch._random_seed = 1 + minibatch._random_seed | ||
return minibatch | ||
|
||
@staticmethod | ||
def _delattr_dependency(minibatch): | ||
delattr(minibatch, "_random_seed") | ||
delattr(minibatch, "_seed2_contribution") | ||
return minibatch | ||
|
||
@staticmethod | ||
def _prepare(node_type_to_id, minibatch): | ||
|
@@ -277,11 +337,22 @@ def _set_input_nodes(minibatch): | |
|
||
# pylint: disable=arguments-differ | ||
def sampling_stages( | ||
self, datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler | ||
self, | ||
datapipe, | ||
graph, | ||
fanouts, | ||
replace, | ||
prob_name, | ||
deduplicate, | ||
sampler, | ||
layer_dependency, | ||
): | ||
datapipe = datapipe.transform( | ||
partial(self._prepare, graph.node_type_to_id) | ||
) | ||
is_labor = sampler.__name__ == "sample_layer_neighbors" | ||
if is_labor: | ||
datapipe = datapipe.transform(self._set_seed) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here — is it possible to move this sample_layer_neighbors-specific code into its own class, instead of hacking it into the parent class? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How do you propose we do that without replicating the whole implementation? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It makes sense to replicate some code if needed; it is usually error-prone to have the parent class's behavior depend on the child class's type, which is an anti-OOP pattern. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I understand that. But I feel an anti-OOP pattern is better than replicating the code. I plan to base the Temporal samplers on the NeighborSamplerImpl as well, so that the GPU sampling optimizations can be enabled for the temporal case too. |
||
for fanout in reversed(fanouts): | ||
# Convert fanout to tensor. | ||
if not isinstance(fanout, torch.Tensor): | ||
|
@@ -290,7 +361,10 @@ def sampling_stages( | |
sampler, fanout, replace, prob_name | ||
) | ||
datapipe = datapipe.compact_per_layer(deduplicate) | ||
|
||
if is_labor and not layer_dependency: | ||
datapipe = datapipe.transform(self._increment_seed) | ||
if is_labor: | ||
datapipe = datapipe.transform(self._delattr_dependency) | ||
return datapipe.transform(self._set_input_nodes) | ||
|
||
|
||
|
@@ -504,6 +578,8 @@ def __init__( | |
replace=False, | ||
prob_name=None, | ||
deduplicate=True, | ||
layer_dependency=False, | ||
batch_dependency=1, | ||
): | ||
super().__init__( | ||
datapipe, | ||
|
@@ -513,4 +589,6 @@ def __init__( | |
prob_name, | ||
deduplicate, | ||
graph.sample_layer_neighbors, | ||
layer_dependency, | ||
batch_dependency, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of checking the sampler name here, why not move the self._init_seed(batch_dependency) call to the __init__ of LayerNeighborSampler?