Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Apr 26, 2024
1 parent d402e97 commit 14a6284
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions python/dgl/graphbolt/subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def _seeds_preprocess(minibatch):
nodes_timestamp = None
if use_timestamp:
nodes_timestamp = defaultdict(list)
is_hyperlink = False
for type, typed_seeds in seeds.items():
for seed_type, typed_seeds in seeds.items():
# When typed_seeds is a one-dimensional tensor, it represents
# seed nodes, which does not need to do unique and compact.
if typed_seeds.ndim == 1:
Expand All @@ -136,20 +135,24 @@ def _seeds_preprocess(minibatch):
"Only tensor with shape 1*N and N*M is "
+ f"supported now, but got {typed_seeds.shape}."
)
ntypes = seed_type_str_to_ntypes(type, typed_seeds.shape[1])
ntypes = seed_type_str_to_ntypes(
seed_type, typed_seeds.shape[1]
)
if use_timestamp:
negative_ratio = (
typed_seeds.shape[0]
// minibatch.timestamp[type].shape[0]
// minibatch.timestamp[seed_type].shape[0]
- 1
)
neg_timestamp = minibatch.timestamp[type].repeat_interleave(
negative_ratio
)
neg_timestamp = minibatch.timestamp[
seed_type
].repeat_interleave(negative_ratio)
for i, ntype in enumerate(ntypes):
nodes[ntype].append(typed_seeds[:, i])
if use_timestamp:
nodes_timestamp[ntype].append(minibatch.timestamp[type])
nodes_timestamp[ntype].append(
minibatch.timestamp[seed_type]
)
nodes_timestamp[ntype].append(neg_timestamp)
# Unique and compact the collected nodes.
if use_timestamp:
Expand All @@ -163,12 +166,14 @@ def _seeds_preprocess(minibatch):
nodes_timestamp = None
compacted_seeds = {}
# Map back in same order as collect.
for type, typed_seeds in seeds.items():
ntypes = seed_type_str_to_ntypes(type, typed_seeds.shape[1])
for seed_type, typed_seeds in seeds.items():
ntypes = seed_type_str_to_ntypes(
seed_type, typed_seeds.shape[1]
)
compacted_seed = []
for ntype in ntypes:
compacted_seed.append(compacted[ntype].pop(0))
compacted_seeds[type] = (
compacted_seeds[seed_type] = (
torch.cat(compacted_seed).view(len(ntypes), -1).T
)
else:
Expand Down

0 comments on commit 14a6284

Please sign in to comment.