Skip to content

Commit

Permalink
[GraphBolt] Add temporal labor sampling to graph.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 3, 2024
1 parent 713ffb5 commit 6b29acb
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 17 deletions.
4 changes: 3 additions & 1 deletion graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
torch::optional<torch::Tensor> input_nodes_pre_time_window,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const;
torch::optional<std::string> edge_timestamp_attr_name,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const;

/**
* @brief Copy the graph to shared memory.
Expand Down
56 changes: 42 additions & 14 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,9 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
torch::optional<torch::Tensor> input_nodes_pre_time_window,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const {
torch::optional<std::string> edge_timestamp_attr_name,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const {
torch::optional<std::vector<int64_t>> seed_offsets = torch::nullopt;
// 1. Get probs_or_mask.
if (probs_or_mask.has_value()) {
Expand All @@ -928,19 +930,45 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name);
// 4. Call SampleNeighborsImpl
if (layer) {
const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
input_nodes, seed_offsets, fanouts, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, input_nodes_pre_time_window, probs_or_mask,
node_timestamp, edge_timestamp),
GetTemporalPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
indptr_.options(), type_per_edge_, input_nodes_pre_time_window,
probs_or_mask, node_timestamp, edge_timestamp, args));
if (random_seed.has_value() && random_seed->numel() >= 2) {
SamplerArgs<SamplerType::LABOR_DEPENDENT> args{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
input_nodes, seed_offsets, fanouts, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, indices_, fanouts, replace, type_per_edge_,
input_nodes_pre_time_window, probs_or_mask, node_timestamp,
edge_timestamp),
GetTemporalPickFn(
input_nodes_timestamp, indices_, fanouts, replace,
indptr_.options(), type_per_edge_, input_nodes_pre_time_window,
probs_or_mask, node_timestamp, edge_timestamp, args));
} else {
auto args = [&] {
if (random_seed.has_value() && random_seed->numel() == 1) {
return SamplerArgs<SamplerType::LABOR>{
indices_, random_seed.value(), NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
input_nodes, seed_offsets, fanouts, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, indices_, fanouts, replace, type_per_edge_,
input_nodes_pre_time_window, probs_or_mask, node_timestamp,
edge_timestamp),
GetTemporalPickFn(
input_nodes_timestamp, indices_, fanouts, replace,
indptr_.options(), type_per_edge_, input_nodes_pre_time_window,
probs_or_mask, node_timestamp, edge_timestamp, args));
}
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
Expand Down
133 changes: 133 additions & 0 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,139 @@ def temporal_sample_neighbors(
probs_or_mask,
node_timestamp_attr_name,
edge_timestamp_attr_name,
None, # random_seed, labor parameter
0, # seed2_contribution, labor_parameter
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)

def temporal_sample_layer_neighbors(
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
input_nodes_timestamp: Union[torch.Tensor, Dict[str, torch.Tensor]],
fanouts: torch.Tensor,
replace: bool = False,
input_nodes_pre_time_window: Optional[
Union[torch.Tensor, Dict[str, torch.Tensor]]
] = None,
probs_name: Optional[str] = None,
node_timestamp_attr_name: Optional[str] = None,
edge_timestamp_attr_name: Optional[str] = None,
random_seed: torch.Tensor = None,
seed2_contribution: float = 0.0,
) -> torch.ScriptObject:
"""Temporally Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
<https://proceedings.neurips.cc/paper_files/paper/2023/file/51f9036d5e7ae822da8f6d4adda1fb39-Paper-Conference.pdf>`__
If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given,
the sampled neighbor or edge of an input node must have a timestamp
that is smaller than that of the input node.
Parameters
----------
nodes: torch.Tensor
IDs of the given seed nodes.
input_nodes_timestamp: torch.Tensor
Timestamps of the given seed nodes.
fanouts: torch.Tensor
The number of edges to be sampled for each node with or without
considering edge types.
- When the length is 1, it indicates that the fanout applies to
all neighbors of the node as a collective, regardless of the
edge type.
- Otherwise, the length should equal to the number of edge
types, and each fanout value corresponds to a specific edge
type of the nodes.
The value of each fanout should be >= 0 or = -1.
- When the value is -1, all neighbors (with non-zero probability,
if weighted) will be sampled once regardless of replacement. It
is equivalent to selecting all neighbors with non-zero
probability when the fanout is >= the number of neighbors (and
replace is set to false).
- When the value is a non-negative integer, it serves as a
minimum threshold for selecting neighbors.
replace: bool
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
input_nodes_pre_time_window: torch.Tensor
The time window of the nodes represents a period of time before
`input_nodes_timestamp`. If provided, only neighbors and related
edges whose timestamps fall within `[input_nodes_timestamp -
input_nodes_pre_time_window, input_nodes_timestamp]` will be
filtered.
probs_name: str, optional
An optional string specifying the name of an edge attribute. This
attribute tensor should contain (unnormalized) probabilities
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
node_timestamp_attr_name: str, optional
An optional string specifying the name of an node attribute.
edge_timestamp_attr_name: str, optional
An optional string specifying the name of an edge attribute.
random_seed: torch.Tensor, optional
An int64 tensor with one or two elements.
The passed random_seed makes it so that for any seed node ``s`` and
its neighbor ``t``, the rolled random variate ``r_t`` is the same
for any call to this function with the same random seed. When
sampling as part of the same batch, one would want identical seeds
so that LABOR can globally sample. One example is that for
heterogenous graphs, there is a single random seed passed for each
edge type. This will sample much fewer nodes compared to having
unique random seeds for each edge type. If one called this function
individually for each edge type for a heterogenous graph with
different random seeds, then it would run LABOR locally for each
edge type, resulting into a larger number of nodes being sampled.
If this function is called without a ``random_seed``, we get the
random seed by getting a random number from GraphBolt. Use this
argument with identical random_seed if multiple calls to this
function are used to sample as part of a single batch.
If given two numbers, then the ``seed2_contribution`` argument
determines the interpolation between the two random seeds.
seed2_contribution: float, optional
A float value between [0, 1) that determines the contribution of the
second random seed, ``random_seed[-1]``, to generate the random
variates.
Returns
-------
SampledSubgraphImpl
The sampled subgraph.
"""
if isinstance(nodes, dict):
(
nodes,
input_nodes_timestamp,
input_nodes_pre_time_window,
) = self._convert_to_homogeneous_nodes(
nodes, input_nodes_timestamp, input_nodes_pre_time_window
)

# Ensure nodes is 1-D tensor.
probs_or_mask = self.edge_attributes[probs_name] if probs_name else None
self._check_sampler_arguments(nodes, fanouts, probs_or_mask)
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
C_sampled_subgraph = self._c_csc_graph.temporal_sample_neighbors(
nodes,
input_nodes_timestamp,
fanouts.tolist(),
replace,
True,
has_original_eids,
input_nodes_pre_time_window,
probs_or_mask,
node_timestamp_attr_name,
edge_timestamp_attr_name,
random_seed,
seed2_contribution,
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -825,10 +825,16 @@ def test_in_subgraph_hetero():
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("use_node_timestamp", [False, True])
@pytest.mark.parametrize("use_edge_timestamp", [False, True])
def test_temporal_sample_neighbors_homo(
indptr_dtype, indices_dtype, replace, use_node_timestamp, use_edge_timestamp
indptr_dtype,
indices_dtype,
replace,
labor,
use_node_timestamp,
use_edge_timestamp,
):
"""Original graph in COO:
1 0 1 0 1
Expand All @@ -853,7 +859,11 @@ def test_temporal_sample_neighbors_homo(

# Generate subgraph via sample neighbors.
fanouts = torch.LongTensor([2])
sampler = graph.temporal_sample_neighbors
sampler = (
graph.temporal_sample_layer_neighbors
if labor
else graph.temporal_sample_neighbors
)

seed_list = [1, 3, 4]
seed_timestamp = torch.randint(0, 100, (len(seed_list),), dtype=torch.int64)
Expand Down

0 comments on commit 6b29acb

Please sign in to comment.