Skip to content

Commit

Permalink
Fix second linter errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ayushnoori committed Jul 6, 2024
1 parent c11b871 commit 161edd6
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions python/dgl/dataloading/capped_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ def __init__(
self.prefetch_edge_feats = prefetch_edge_feats
self.output_device = output_device

def sample(self, g, seed_nodes, exclude_eids=None):
def sample(self, g, indices, exclude_eids=None):
"""Sampling function.
Parameters
----------
g : DGLGraph
The graph to sample from.
seed_nodes : Tensor or dict[str, Tensor]
indices : Tensor or dict[str, Tensor]
Nodes which induce the subgraph.
exclude_eids : Tensor or dict[etype, Tensor], optional
The edges to exclude from the sampled subgraph.
Expand All @@ -85,15 +85,15 @@ def sample(self, g, seed_nodes, exclude_eids=None):
"""

# Define empty dictionary to store reached nodes.
output_nodes = seed_nodes
all_reached_nodes = [seed_nodes]
output_nodes = indices
all_reached_nodes = [indices]

# Iterate over fanout.
for fanout in reversed(self.fanouts):

# Sample frontier.
frontier = g.sample_neighbors(
seed_nodes,
indices,
fanout,
output_device=self.output_device,
replace=self.replace,
Expand All @@ -104,7 +104,7 @@ def sample(self, g, seed_nodes, exclude_eids=None):
# Get reached nodes.
curr_reached = defaultdict(list)
for c_etype in frontier.canonical_etypes:
(src_type, rel_type, dst_type) = c_etype
(src_type, _, _) = c_etype
src, _ = frontier.edges(etype=c_etype)
curr_reached[src_type].append(src)

Expand Down Expand Up @@ -153,18 +153,18 @@ def sample(self, g, seed_nodes, exclude_eids=None):

# Downsample nodes.
curr_reached_k = {}
for node_type, node_IDs in curr_reached.items():
for node_type, node_ids in curr_reached.items():

# Get number of total nodes and number to sample.
num_nodes = node_IDs.shape[0]
num_nodes = node_ids.shape[0]
n_to_sample = min(num_nodes, n_per_type[node_type])

# Downsample nodes of current type.
random_indices = torch.randperm(num_nodes)[:n_to_sample]
curr_reached_k[node_type] = node_IDs[random_indices]
curr_reached_k[node_type] = node_ids[random_indices]

# Update seed nodes.
seed_nodes = curr_reached_k
indices = curr_reached_k
all_reached_nodes.append(curr_reached_k)

# Merge all reached nodes before sending to `DGLGraph.subgraph`.
Expand All @@ -185,4 +185,4 @@ def sample(self, g, seed_nodes, exclude_eids=None):
set_node_lazy_features(subg, self.prefetch_node_feats)
set_edge_lazy_features(subg, self.prefetch_edge_feats)

return seed_nodes, output_nodes, subg
return indices, output_nodes, subg

0 comments on commit 161edd6

Please sign in to comment.