Skip to content

Commit

Permalink
[graphbolt] skip non-existent types in input_nodes (#7386)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Rhett-Ying committed May 10, 2024
1 parent 3b01cd4 commit 8083ffd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
28 changes: 14 additions & 14 deletions python/dgl/graphbolt/feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,12 @@ def record_stream(tensor):

if self.node_feature_keys and input_nodes is not None:
if is_heterogeneous:
for type_name, feature_names in self.node_feature_keys.items():
nodes = input_nodes[type_name]
if nodes is None:
for type_name, nodes in input_nodes.items():
if type_name not in self.node_feature_keys or nodes is None:
continue
if nodes.is_cuda:
nodes.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
for feature_name in self.node_feature_keys[type_name]:
node_features[
(type_name, feature_name)
] = record_stream(
Expand Down Expand Up @@ -126,21 +125,22 @@ def record_stream(tensor):
if is_heterogeneous:
# Convert edge type to string.
original_edge_ids = {
etype_tuple_to_str(key)
if isinstance(key, tuple)
else key: value
(
etype_tuple_to_str(key)
if isinstance(key, tuple)
else key
): value
for key, value in original_edge_ids.items()
}
for (
type_name,
feature_names,
) in self.edge_feature_keys.items():
edges = original_edge_ids.get(type_name, None)
if edges is None:
for type_name, edges in original_edge_ids.items():
if (
type_name not in self.edge_feature_keys
or edges is None
):
continue
if edges.is_cuda:
edges.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
for feature_name in self.edge_feature_keys[type_name]:
edge_features[i][
(type_name, feature_name)
] = record_stream(
Expand Down
29 changes: 25 additions & 4 deletions tests/python/pytorch/graphbolt/test_feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,21 @@ def test_FeatureFetcher_hetero():
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
# "n3" is not in the sampled input nodes.
node_feature_keys = {"n1": ["a"], "n2": ["a"], "n3": ["a"]}
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
sampler_dp, feature_store, node_feature_keys=node_feature_keys
)

assert len(list(fetcher_dp)) == 3

# Do not fetch features for "n1".
node_feature_keys = {"n2": ["a"]}
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, node_feature_keys=node_feature_keys
)
for mini_batch in fetcher_dp:
assert ("n1", "a") not in mini_batch.node_features


def test_FeatureFetcher_with_edges_hetero():
a = torch.tensor([[random.randint(0, 10)] for _ in range(20)])
Expand Down Expand Up @@ -208,7 +217,11 @@ def add_node_and_edge_ids(minibatch):
return data

features = {}
keys = [("node", "n1", "a"), ("edge", "n1:e1:n2", "a")]
keys = [
("node", "n1", "a"),
("edge", "n1:e1:n2", "a"),
("edge", "n2:e2:n1", "a"),
]
features[keys[0]] = gb.TorchBasedFeature(a)
features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
Expand All @@ -220,8 +233,15 @@ def add_node_and_edge_ids(minibatch):
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
# "n3:e3:n3" is not in the sampled edges.
# Do not fetch features for "n2:e2:n1".
node_feature_keys = {"n1": ["a"]}
edge_feature_keys = {"n1:e1:n2": ["a"], "n3:e3:n3": ["a"]}
fetcher_dp = gb.FeatureFetcher(
converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
converter_dp,
feature_store,
node_feature_keys=node_feature_keys,
edge_feature_keys=edge_feature_keys,
)

assert len(list(fetcher_dp)) == 5
Expand All @@ -230,3 +250,4 @@ def add_node_and_edge_ids(minibatch):
assert len(data.edge_features) == 3
for edge_feature in data.edge_features:
assert edge_feature[("n1:e1:n2", "a")].size(0) == 10
assert ("n2:e2:n1", "a") not in edge_feature

0 comments on commit 8083ffd

Please sign in to comment.