Skip to content

Commit

Permalink
[GraphBolt] Refactor and extend FeatureStore. (#7558)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 22, 2024
1 parent 69fd95e commit d775ab1
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 108 deletions.
39 changes: 34 additions & 5 deletions python/dgl/graphbolt/feature_store.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
"""Feature store for GraphBolt."""

from typing import NamedTuple

import torch

__all__ = ["Feature", "FeatureStore"]
__all__ = ["Feature", "FeatureStore", "FeatureKey"]


class FeatureKey(NamedTuple):
    """A named tuple used as the key of a feature in FeatureStore classes.

    The fields are ``domain``, ``type`` and ``name``, in that order. Being a
    ``NamedTuple``, an instance compares and hashes equal to the plain
    ``(domain, type, name)`` tuple with the same values, so either form can
    be used to index a mapping keyed by the other.
    """

    domain: str
    type: str
    # Feature names are strings (e.g. "a"), as stated by the class contract
    # that all three fields take string values.
    name: str


class Feature:
Expand Down Expand Up @@ -109,6 +121,23 @@ class FeatureStore:
def __init__(self):
    """Construct the base feature store; no state is set up here."""

def __getitem__(self, feature_key: FeatureKey) -> Feature:
    """Return the `Feature` registered under ``feature_key``.

    ``feature_key`` is the (domain, type, name) triple identifying the
    feature. The ``read``, ``size``, ``metadata`` and ``update`` methods of
    this class are implemented on top of this lookup.

    Raises
    ------
    NotImplementedError
        Always; this base implementation provides no storage.
    """
    raise NotImplementedError

def __setitem__(self, feature_key: FeatureKey, feature: Feature):
    """Register ``feature`` under ``feature_key``.

    ``feature_key`` is the (domain, type, name) triple identifying the
    feature and ``feature`` is the `Feature` object to associate with it.

    Raises
    ------
    NotImplementedError
        Always; this base implementation provides no storage.
    """
    raise NotImplementedError

def __contains__(self, feature_key: FeatureKey) -> bool:
    """Check whether ``feature_key`` is contained in the FeatureStore.

    ``feature_key`` is the (domain, type, name) triple identifying the
    feature.

    Raises
    ------
    NotImplementedError
        Always; this base implementation provides no storage.
    """
    raise NotImplementedError

def read(
self,
domain: str,
Expand All @@ -135,7 +164,7 @@ def read(
torch.Tensor
The read feature.
"""
raise NotImplementedError
return self.__getitem__((domain, type_name, feature_name)).read(ids)

def size(
self,
Expand All @@ -158,7 +187,7 @@ def size(
torch.Size
The size of the specified feature in the feature store.
"""
raise NotImplementedError
return self.__getitem__((domain, type_name, feature_name)).size()

def metadata(
self,
Expand All @@ -181,7 +210,7 @@ def metadata(
Dict
The metadata of the feature.
"""
raise NotImplementedError
return self.__getitem__((domain, type_name, feature_name)).metadata()

def update(
self,
Expand Down Expand Up @@ -210,7 +239,7 @@ def update(
must have the same length. If None, the entire feature will be
updated.
"""
raise NotImplementedError
self.__getitem__((domain, type_name, feature_name)).update(value, ids)

def keys(self):
"""Get the keys of the features.
Expand Down
115 changes: 13 additions & 102 deletions python/dgl/graphbolt/impl/basic_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from typing import Dict, Tuple

import torch

from ..feature_store import Feature, FeatureStore
from ..feature_store import Feature, FeatureKey, FeatureStore

__all__ = ["BasicFeatureStore"]

Expand All @@ -29,109 +27,22 @@ def __init__(self, features: Dict[Tuple[str, str, str], Feature]):
super().__init__()
self._features = features

def read(
self,
domain: str,
type_name: str,
feature_name: str,
ids: torch.Tensor = None,
):
"""Read from the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
Returns
-------
torch.Tensor
The read feature.
"""
return self._features[(domain, type_name, feature_name)].read(ids)

def size(
self,
domain: str,
type_name: str,
feature_name: str,
):
"""Get the size of the specified feature in the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
Returns
-------
torch.Size
The size of the specified feature in the feature store.
def __getitem__(self, feature_key: FeatureKey) -> Feature:
"""Access the underlying `Feature` with its (domain, type, name) as
the feature_key.
"""
return self._features[(domain, type_name, feature_name)].size()
return self._features[feature_key]

def metadata(
self,
domain: str,
type_name: str,
feature_name: str,
):
"""Get the metadata of the specified feature in the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
Returns
-------
Dict
The metadata of the feature.
def __setitem__(self, feature_key: FeatureKey, feature: Feature):
"""Set the underlying `Feature` with its (domain, type, name) as
the feature_key and feature as the value.
"""
return self._features[(domain, type_name, feature_name)].metadata()

def update(
self,
domain: str,
type_name: str,
feature_name: str,
value: torch.Tensor,
ids: torch.Tensor = None,
):
"""Update the feature store.
self._features[feature_key] = feature

Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
The indices of the feature to update. If specified, only the
specified indices of the feature will be updated. For the feature,
the `ids[i]` row is updated to `value[i]`. So the indices and value
must have the same length. If None, the entire feature will be
updated.
"""
self._features[(domain, type_name, feature_name)].update(value, ids)
def __contains__(self, feature_key: FeatureKey) -> bool:
    """Check whether ``feature_key`` is contained in the BasicFeatureStore.

    ``feature_key`` is the (domain, type, name) triple identifying the
    feature; the result is the membership test of that key against the
    underlying ``self._features`` mapping.
    """
    stored_features = self._features
    return feature_key in stored_features

def __len__(self):
"""Return the number of features."""
Expand Down
35 changes: 34 additions & 1 deletion tests/python/pytorch/graphbolt/impl/test_basic_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ def test_basic_feature_store_homo():

feature_store = gb.BasicFeatureStore(features)

# Test __getitem__ to access the stored Feature.
feature = feature_store[("node", None, "a")]
assert isinstance(feature, gb.Feature)
assert torch.equal(
feature.read(),
torch.tensor([[1, 2, 4], [2, 5, 3]]),
)

# Test read the entire feature.
assert torch.equal(
feature_store.read("node", None, "a"),
Expand Down Expand Up @@ -43,8 +51,17 @@ def test_basic_feature_store_homo():
assert feature_store.metadata("node", None, "a") == metadata
assert feature_store.metadata("node", None, "b") == {}

# Test __setitem__ and __contains__ of FeatureStore.
assert ("node", None, "c") not in feature_store
feature_store[("node", None, "c")] = feature_store[("node", None, "a")]
assert ("node", None, "c") in feature_store

# Test get keys of the features.
assert feature_store.keys() == [("node", None, "a"), ("node", None, "b")]
assert feature_store.keys() == [
("node", None, "a"),
("node", None, "b"),
("node", None, "c"),
]


def test_basic_feature_store_hetero():
Expand All @@ -60,6 +77,14 @@ def test_basic_feature_store_hetero():

feature_store = gb.BasicFeatureStore(features)

# Test __getitem__ to access the stored Feature.
feature = feature_store[("node", "author", "a")]
assert isinstance(feature, gb.Feature)
assert torch.equal(
feature.read(),
torch.tensor([[1, 2, 4], [2, 5, 3]]),
)

# Test read the entire feature.
assert torch.equal(
feature_store.read("node", "author", "a"),
Expand All @@ -84,10 +109,18 @@ def test_basic_feature_store_hetero():
assert feature_store.metadata("node", "author", "a") == metadata
assert feature_store.metadata("edge", "paper:cites", "b") == {}

# Test __setitem__ and __contains__ of FeatureStore.
assert ("node", "author", "c") not in feature_store
feature_store[("node", "author", "c")] = feature_store[
("node", "author", "a")
]
assert ("node", "author", "c") in feature_store

# Test get keys of the features.
assert feature_store.keys() == [
("node", "author", "a"),
("edge", "paper:cites", "b"),
("node", "author", "c"),
]


Expand Down

0 comments on commit d775ab1

Please sign in to comment.