Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 4, 2024
1 parent 650e504 commit 202388c
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 21 deletions.
25 changes: 23 additions & 2 deletions python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

__all__ = ["TemporalNeighborSampler", "TemporalLayerNeighborSampler"]


class TemporalNeighborSamplerImpl(SubgraphSampler):
"""Base class for TemporalNeighborSamplers."""

Expand Down Expand Up @@ -100,6 +101,7 @@ def sample_subgraphs(
seeds_timestamp = row_timestamps
return seeds, subgraphs


@functional_datapipe("temporal_sample_neighbor")
class TemporalNeighborSampler(TemporalNeighborSamplerImpl):
"""Temporally sample neighbor edges from a graph and return sampled
Expand Down Expand Up @@ -164,7 +166,17 @@ def __init__(
node_timestamp_attr_name=None,
edge_timestamp_attr_name=None,
):
super().__init__(datapipe, graph, fanouts, replace, prob_name, node_timestamp_attr_name, edge_timestamp_attr_name, graph.temporal_sample_neighbors)
super().__init__(
datapipe,
graph,
fanouts,
replace,
prob_name,
node_timestamp_attr_name,
edge_timestamp_attr_name,
graph.temporal_sample_neighbors,
)


@functional_datapipe("temporal_sample_layer_neighbor")
class TemporalLayerNeighborSampler(TemporalNeighborSamplerImpl):
Expand Down Expand Up @@ -245,4 +257,13 @@ def __init__(
node_timestamp_attr_name=None,
edge_timestamp_attr_name=None,
):
super().__init__(datapipe, graph, fanouts, replace, prob_name, node_timestamp_attr_name, edge_timestamp_attr_name, graph.temporal_sample_layer_neighbors)
super().__init__(
datapipe,
graph,
fanouts,
replace,
prob_name,
node_timestamp_attr_name,
edge_timestamp_attr_name,
graph.temporal_sample_layer_neighbors,
)
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,12 @@ def _get_available_neighbors():
@pytest.mark.parametrize("use_node_timestamp", [False, True])
@pytest.mark.parametrize("use_edge_timestamp", [False, True])
def test_temporal_sample_neighbors_hetero(
indptr_dtype, indices_dtype, replace, labor, use_node_timestamp, use_edge_timestamp
indptr_dtype,
indices_dtype,
replace,
labor,
use_node_timestamp,
use_edge_timestamp,
):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
Expand Down
127 changes: 109 additions & 18 deletions tests/python/pytorch/graphbolt/test_subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _get_sampler(sampler_type):
edge_timestamp_attr_name="timestamp",
)


def _is_temporal(sampler_type):
return sampler_type in [SamplerType.Temporal, SamplerType.TemporalLayer]

Expand Down Expand Up @@ -190,7 +191,12 @@ def test_NeighborSampler_fanouts(labor):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_Node(sampler_type):
_check_sampler_type(sampler_type)
Expand All @@ -217,7 +223,12 @@ def test_SubgraphSampler_Node(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_Link(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -249,7 +260,12 @@ def test_SubgraphSampler_Link(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_Link_With_Negative(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -278,7 +294,12 @@ def test_SubgraphSampler_Link_With_Negative(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_HyperLink(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -310,7 +331,12 @@ def test_SubgraphSampler_HyperLink(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_Node_Hetero(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -341,7 +367,12 @@ def test_SubgraphSampler_Node_Hetero(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_Link_Hetero(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -396,7 +427,12 @@ def test_SubgraphSampler_Link_Hetero(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -441,7 +477,12 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -486,7 +527,12 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -532,7 +578,12 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_HyperLink_Hetero(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -583,7 +634,12 @@ def test_SubgraphSampler_HyperLink_Hetero(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
@pytest.mark.parametrize(
"replace",
Expand Down Expand Up @@ -684,7 +740,12 @@ def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -753,7 +814,12 @@ def test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_without_deduplication_Hetero_Node(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -1010,7 +1076,12 @@ def test_SubgraphSampler_unique_csc_format_Hetero_Node(labor):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -1081,7 +1152,12 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -1143,7 +1219,12 @@ def test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_without_deduplication_Hetero_Link(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -1453,7 +1534,12 @@ def test_SubgraphSampler_unique_csc_format_Hetero_Link(labor):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type):
_check_sampler_type(sampler_type)
Expand Down Expand Up @@ -1514,7 +1600,12 @@ def test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type):

@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal, SamplerType.TemporalLayer],
[
SamplerType.Normal,
SamplerType.Layer,
SamplerType.Temporal,
SamplerType.TemporalLayer,
],
)
def test_SubgraphSampler_without_deduplication_Hetero_HyperLink(sampler_type):
_check_sampler_type(sampler_type)
Expand Down

0 comments on commit 202388c

Please sign in to comment.