Skip to content

Commit

Permalink
[graphbolt] skip non-existent types in input_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying committed May 9, 2024
1 parent 2da713f commit fad21e9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/dgl/graphbolt/feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def record_stream(tensor):
if self.node_feature_keys and input_nodes is not None:
if is_heterogeneous:
for type_name, feature_names in self.node_feature_keys.items():
nodes = input_nodes[type_name]
nodes = input_nodes.get(type_name, None)
if nodes is None:
continue
if nodes.is_cuda:
Expand Down
4 changes: 3 additions & 1 deletion tests/python/pytorch/graphbolt/test_feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ def test_FeatureFetcher_hetero():
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
# "n3" is not in the sampled input nodes.
node_feature_keys = {"n1": ["a"], "n2": ["a"], "n3": ["a"]}
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
sampler_dp, feature_store, node_feature_keys=node_feature_keys
)

assert len(list(fetcher_dp)) == 3
Expand Down

0 comments on commit fad21e9

Please sign in to comment.