Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Temporal] Optimize num pick by early stop #7370

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
37 changes: 32 additions & 5 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1160,14 +1160,41 @@ int64_t TemporalNumPick(
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors) {
constexpr int64_t kFastPathThreshold = 1000;
if (num_neighbors > kFastPathThreshold && !probs_or_mask.has_value()) {
if (num_neighbors > kFastPathThreshold && !probs_or_mask.has_value() &&
fanout != -1) {
// TODO: Currently we use the fast path both in TemporalNumPick and
// TemporalPick. We may only sample once in TemporalNumPick and use the
// sampled edges in TemporalPick to avoid sampling twice.
auto [success, sampled_edges] = FastTemporalPick(
seed_timestamp, csc_indics, fanout, replace, seed_pre_time_window,
node_timestamp, edge_timestamp, seed_offset, offset, num_neighbors);
if (success) return sampled_edges.size();
int64_t sampled_count = 0;
auto timestamp =
utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset);
for (int64_t edge_id = offset;
edge_id < offset + num_neighbors && sampled_count < fanout;
edge_id++) {
if (replace && sampled_count > 0) {
peizhou001 marked this conversation as resolved.
Show resolved Hide resolved
sampled_count = fanout;
break;
}
if (node_timestamp.has_value()) {
bool flag = true;
AT_DISPATCH_INDEX_TYPES(
csc_indics.scalar_type(), "CheckNodeTimeStamp", ([&] {
int64_t neighbor_id =
utils::GetValueByIndex<index_t>(csc_indics, edge_id);
if (utils::GetValueByIndex<int64_t>(
node_timestamp.value(), neighbor_id) >= timestamp)
peizhou001 marked this conversation as resolved.
Show resolved Hide resolved
flag = false;
}));
if (!flag) continue;
}
if (edge_timestamp.has_value() &&
utils::GetValueByIndex<int64_t>(edge_timestamp.value(), edge_id) >=
timestamp) {
continue;
}
sampled_count++;
}
return sampled_count;
}
torch::optional<int64_t> time_window = torch::nullopt;
if (seed_pre_time_window.has_value()) {
Expand Down
Loading