
Commit

Merge branch 'master' into exchange_feature
drivanov committed Jul 31, 2024
2 parents 7c06d14 + 65f85b5 commit 05534e1
Showing 6 changed files with 234 additions and 86 deletions.
2 changes: 2 additions & 0 deletions graphbolt/src/cache_policy.cc
@@ -130,6 +130,8 @@ BaseCachePolicy::QueryAndThenReplaceImpl(
           // we do not have to check for the uniqueness of the positions.
           std::get<1>(position_set.insert(it->second->getPos())),
           "Can't insert all, larger cache capacity is needed.");
+    } else {
+      policy.MarkExistingWriting(it);
     }
     auto& cache_key = *it->second;
     positions_ptr[i] = cache_key.getPos();
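The hunk above completes a read-or-mark branch in QueryAndThenReplaceImpl: a key seen for the first time gets a freshly inserted slot (with the capacity check), while a key that already exists but is still being filled by another request is handed back to the policy through the new MarkExistingWriting hook. A minimal standalone sketch of that caller-side pattern follows; CacheEntry, ToyPolicy, and the driver in main are illustrative stand-ins, not graphbolt types.

#include <cstdint>
#include <deque>
#include <iostream>
#include <unordered_map>
#include <utility>

// Illustrative stand-in for graphbolt's CacheKey bookkeeping.
struct CacheEntry {
  int64_t pos;
  bool being_written;
};

struct ToyPolicy {
  using map_iterator = std::unordered_map<int64_t, CacheEntry*>::iterator;

  CacheEntry* getMapSentinelValue() { return nullptr; }

  // Mirrors Emplace(): returns the map slot and whether it is ready to read.
  std::pair<map_iterator, bool> Emplace(int64_t key) {
    auto [it, _] = map_.emplace(key, getMapSentinelValue());
    if (it->second != getMapSentinelValue() && !it->second->being_written)
      return {it, true};  // Found and readable.
    return {it, false};   // Missing, or an in-flight write.
  }

  void Insert(map_iterator it) {
    storage_.push_back({static_cast<int64_t>(storage_.size()), true});
    it->second = &storage_.back();
  }

  void MarkExistingWriting(map_iterator /*it*/) {
    // graphbolt calls StartUse<true>() here so this request also waits for
    // the write that is already filling the entry.
  }

  std::unordered_map<int64_t, CacheEntry*> map_;
  std::deque<CacheEntry> storage_;  // std::deque keeps element addresses stable.
};

int main() {
  ToyPolicy policy;
  for (int64_t key : {7, 7, 42}) {
    auto [it, can_read] = policy.Emplace(key);
    if (!can_read) {
      if (it->second == policy.getMapSentinelValue()) {
        policy.Insert(it);               // First sighting: claim a slot.
      } else {
        policy.MarkExistingWriting(it);  // Write already in flight.
      }
    }
    std::cout << "key " << key << " -> pos " << it->second->pos << '\n';
  }
}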
35 changes: 22 additions & 13 deletions graphbolt/src/cache_policy.h
@@ -264,15 +264,13 @@ class S3FifoCachePolicy : public BaseCachePolicy {

   std::pair<map_iterator, bool> Emplace(int64_t key) {
     auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
-    if (it->second != nullptr) {
+    if (it->second != getMapSentinelValue()) {
       auto& cache_key = *it->second;
       if (!cache_key.BeingWritten()) {
         // Not being written so we use StartUse<write=false>() and return
         // true to indicate the key is ready to read.
         cache_key.Increment().StartUse<false>();
         return {it, true};
-      } else {
-        cache_key.Increment().StartUse<true>();
       }
     }
     // First time insertion, return false to indicate not ready to read.
@@ -296,6 +294,10 @@ class S3FifoCachePolicy : public BaseCachePolicy {
     it->second = queue.Push(CacheKey(key, pos));
   }
 
+  void MarkExistingWriting(map_iterator it) {
+    it->second->Increment().StartUse<true>();
+  }
+
   template <bool write>
   void Unmark(CacheKey* cache_key_ptr) {
     cache_key_ptr->EndUse<write>();
@@ -414,15 +416,13 @@ class SieveCachePolicy : public BaseCachePolicy {

   std::pair<map_iterator, bool> Emplace(int64_t key) {
     auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
-    if (it->second != nullptr) {
+    if (it->second != getMapSentinelValue()) {
       auto& cache_key = *it->second;
       if (!cache_key.BeingWritten()) {
         // Not being written so we use StartUse<write=false>() and return
         // true to indicate the key is ready to read.
         cache_key.SetFreq().StartUse<false>();
         return {it, true};
-      } else {
-        cache_key.SetFreq().StartUse<true>();
       }
     }
     // First time insertion, return false to indicate not ready to read.
@@ -444,6 +444,10 @@ class SieveCachePolicy : public BaseCachePolicy {
     it->second = &queue_.front();
   }
 
+  void MarkExistingWriting(map_iterator it) {
+    it->second->SetFreq().StartUse<true>();
+  }
+
   template <bool write>
   void Unmark(CacheKey* cache_key_ptr) {
     cache_key_ptr->EndUse<write>();
@@ -552,16 +556,14 @@ class LruCachePolicy : public BaseCachePolicy {

   std::pair<map_iterator, bool> Emplace(int64_t key) {
     auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
-    if (it->second != queue_.end()) {
-      MoveToFront(it->second);
+    if (it->second != getMapSentinelValue()) {
       auto& cache_key = *it->second;
       if (!cache_key.BeingWritten()) {
+        MoveToFront(it->second);
         // Not being written so we use StartUse<write=false>() and return
         // true to indicate the key is ready to read.
         cache_key.StartUse<false>();
         return {it, true};
-      } else {
-        cache_key.StartUse<true>();
       }
     }
     // First time insertion, return false to indicate not ready to read.
@@ -582,6 +584,11 @@ class LruCachePolicy : public BaseCachePolicy {
     it->second = queue_.begin();
   }
 
+  void MarkExistingWriting(map_iterator it) {
+    MoveToFront(it->second);
+    it->second->StartUse<true>();
+  }
+
   template <bool write>
   void Unmark(CacheKey* cache_key_ptr) {
     cache_key_ptr->EndUse<write>();
@@ -678,15 +685,13 @@ class ClockCachePolicy : public BaseCachePolicy {

   std::pair<map_iterator, bool> Emplace(int64_t key) {
     auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
-    if (it->second != nullptr) {
+    if (it->second != getMapSentinelValue()) {
       auto& cache_key = *it->second;
       if (!cache_key.BeingWritten()) {
         // Not being written so we use StartUse<write=false>() and return
         // true to indicate the key is ready to read.
         cache_key.SetFreq().StartUse<false>();
         return {it, true};
-      } else {
-        cache_key.SetFreq().StartUse<true>();
       }
     }
     // First time insertion, return false to indicate not ready to read.
@@ -706,6 +711,10 @@ class ClockCachePolicy : public BaseCachePolicy {
     it->second = queue_.Push(CacheKey(key, pos));
   }
 
+  void MarkExistingWriting(map_iterator it) {
+    it->second->SetFreq().StartUse<true>();
+  }
+
   template <bool write>
   void Unmark(CacheKey* cache_key_ptr) {
     cache_key_ptr->EndUse<write>();
4 changes: 2 additions & 2 deletions graphbolt/src/cuda/expand_indptr.cu
@@ -63,7 +63,7 @@ struct AdjacentDifference {
 torch::Tensor ExpandIndptrImpl(
     torch::Tensor indptr, torch::ScalarType dtype,
     torch::optional<torch::Tensor> nodes, torch::optional<int64_t> output_size,
-    const bool edge_ids) {
+    const bool is_edge_ids_variant) {
   if (!output_size.has_value()) {
     output_size = AT_DISPATCH_INTEGRAL_TYPES(
         indptr.scalar_type(), "ExpandIndptrIndptr[-1]", ([&]() -> int64_t {
@@ -102,7 +102,7 @@ torch::Tensor ExpandIndptrImpl(
       constexpr int64_t max_copy_at_once =
           std::numeric_limits<int32_t>::max();
 
-      if (edge_ids) {
+      if (is_edge_ids_variant) {
         auto input_buffer = thrust::make_transform_iterator(
             iota, IotaIndex<indices_t, nodes_t>{nodes_ptr});
         for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
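The only change in this file is renaming edge_ids to is_edge_ids_variant, which spells out that the flag picks between two outputs. A rough CPU sketch of that distinction, assuming nodes is absent (expand_cpu is an illustrative helper, not a graphbolt function): the node-ids variant repeats each row's id by its degree, while the edge-ids variant emits consecutive edge ids.

#include <cstdint>
#include <iostream>
#include <vector>

// For indptr = [0, 2, 5]:
//   is_edge_ids_variant == false: repeat each row id by its degree -> 0 0 1 1 1
//   is_edge_ids_variant == true:  emit consecutive edge ids        -> 0 1 2 3 4
std::vector<int64_t> expand_cpu(
    const std::vector<int64_t>& indptr, bool is_edge_ids_variant) {
  std::vector<int64_t> out;
  out.reserve(indptr.back());
  for (size_t row = 0; row + 1 < indptr.size(); ++row) {
    for (int64_t e = indptr[row]; e < indptr[row + 1]; ++e) {
      out.push_back(is_edge_ids_variant ? e : static_cast<int64_t>(row));
    }
  }
  return out;
}

int main() {
  for (bool variant : {false, true}) {
    for (auto v : expand_cpu({0, 2, 5}, variant)) std::cout << v << ' ';
    std::cout << '\n';
  }
}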
106 changes: 100 additions & 6 deletions graphbolt/src/partitioned_cache_policy.cc
@@ -206,11 +206,10 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
         selected_positions_ptr, selected_positions_ptr + num_selected,
         positions.data_ptr<int64_t>() + begin,
         [off = tid * capacity_ / policies_.size()](auto x) { return x + off; });
-    std::memcpy(
-        reinterpret_cast<std::byte*>(found_pointers.data_ptr()) +
-            begin * found_pointers.element_size(),
-        std::get<3>(results[tid]).data_ptr(),
-        num_selected * found_pointers.element_size());
+    auto selected_pointers_ptr = std::get<3>(results[tid]).data_ptr<int64_t>();
+    std::copy(
+        selected_pointers_ptr, selected_pointers_ptr + num_selected,
+        found_pointers.data_ptr<int64_t>() + begin);
     begin = result_offsets[policies_.size() + tid];
     end = result_offsets[policies_.size() + tid + 1];
     missing_offsets[tid + 1] = end - result_offsets[policies_.size()];
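The hunk above replaces a byte-wise std::memcpy into found_pointers with a typed std::copy; since data_ptr<int64_t>() checks the tensor's scalar type, a dtype mismatch now fails loudly instead of silently copying bytes. A minimal sketch of the equivalence, assuming both buffers really hold int64_t:

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  std::vector<int64_t> src = {1, 2, 3}, a(3), b(3);
  // Byte-wise copy, as on the removed lines: correct only while both sides
  // really are raw int64_t buffers.
  std::memcpy(a.data(), src.data(), src.size() * sizeof(int64_t));
  // Typed copy, as on the added lines: the element type is verified at the
  // data_ptr<int64_t>() call sites instead of being assumed.
  std::copy(src.begin(), src.end(), b.begin());
  return a == b ? 0 : 1;
}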
@@ -263,7 +262,102 @@ PartitionedCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
     auto missing_offsets = found_and_missing_offsets.slice(0, 2);
     return {positions, output_indices, pointers,
             missing_keys, found_offsets, missing_offsets};
-  };
+  }
+  torch::Tensor offsets, indices, permuted_keys;
+  std::tie(offsets, indices, permuted_keys) = Partition(keys);
+  auto offsets_ptr = offsets.data_ptr<int64_t>();
+  auto indices_ptr = indices.data_ptr<int64_t>();
+  std::vector<
+      std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
+      results(policies_.size());
+  torch::Tensor result_offsets_tensor =
+      torch::empty(policies_.size() * 2 + 1, offsets.options());
+  auto result_offsets = result_offsets_tensor.data_ptr<int64_t>();
+  namespace gb = graphbolt;
+  {
+    std::lock_guard lock(mtx_);
+    gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
+      if (begin == end) return;
+      TORCH_CHECK(end - begin == 1);
+      const auto tid = begin;
+      begin = offsets_ptr[tid];
+      end = offsets_ptr[tid + 1];
+      results[tid] = policies_.at(tid)->QueryAndThenReplace(
+          permuted_keys.slice(0, begin, end));
+      const auto missing_cnt = std::get<3>(results[tid]).size(0);
+      result_offsets[tid] = end - begin - missing_cnt;
+      result_offsets[tid + policies_.size()] = missing_cnt;
+    });
+  }
+  std::exclusive_scan(
+      result_offsets, result_offsets + result_offsets_tensor.size(0),
+      result_offsets, 0);
+  torch::Tensor positions = torch::empty(
+      keys.size(0),
+      std::get<0>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
+  torch::Tensor output_indices = torch::empty_like(
+      indices, indices.options().pinned_memory(utils::is_pinned(keys)));
+  torch::Tensor pointers = torch::empty(
+      keys.size(0),
+      std::get<2>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
+  torch::Tensor missing_keys = torch::empty(
+      result_offsets[2 * policies_.size()] - result_offsets[policies_.size()],
+      std::get<3>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
+  auto missing_offsets =
+      torch::empty(policies_.size() + 1, result_offsets_tensor.options());
+  auto positions_ptr = positions.data_ptr<int64_t>();
+  auto output_indices_ptr = output_indices.data_ptr<int64_t>();
+  auto pointers_ptr = pointers.data_ptr<int64_t>();
+  auto missing_offsets_ptr = missing_offsets.data_ptr<int64_t>();
+  missing_offsets_ptr[0] = 0;
+  gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
+    if (begin == end) return;
+    const auto tid = begin;
+    auto out_index_ptr = indices_ptr + offsets_ptr[tid];
+    begin = result_offsets[tid];
+    end = result_offsets[tid + 1];
+    const auto num_selected = end - begin;
+    auto indices_ptr = std::get<1>(results[tid]).data_ptr<int64_t>();
+    for (int64_t i = 0; i < num_selected; i++) {
+      output_indices_ptr[begin + i] = out_index_ptr[indices_ptr[i]];
+    }
+    auto selected_positions_ptr = std::get<0>(results[tid]).data_ptr<int64_t>();
+    std::transform(
+        selected_positions_ptr, selected_positions_ptr + num_selected,
+        positions_ptr + begin,
+        [off = tid * capacity_ / policies_.size()](auto x) { return x + off; });
+    auto selected_pointers_ptr = std::get<2>(results[tid]).data_ptr<int64_t>();
+    std::copy(
+        selected_pointers_ptr, selected_pointers_ptr + num_selected,
+        pointers_ptr + begin);
+    begin = result_offsets[policies_.size() + tid];
+    end = result_offsets[policies_.size() + tid + 1];
+    missing_offsets[tid + 1] = end - result_offsets[policies_.size()];
+    const auto num_missing = end - begin;
+    for (int64_t i = 0; i < num_missing; i++) {
+      output_indices_ptr[begin + i] =
+          out_index_ptr[indices_ptr[i + num_selected]];
+    }
+    auto missing_positions_ptr = selected_positions_ptr + num_selected;
+    std::transform(
+        missing_positions_ptr, missing_positions_ptr + num_missing,
+        positions_ptr + begin,
+        [off = tid * capacity_ / policies_.size()](auto x) { return x + off; });
+    auto missing_pointers_ptr = selected_pointers_ptr + num_selected;
+    std::copy(
+        missing_pointers_ptr, missing_pointers_ptr + num_missing,
+        pointers_ptr + begin);
+    std::memcpy(
+        reinterpret_cast<std::byte*>(missing_keys.data_ptr()) +
+            (begin - result_offsets[policies_.size()]) *
+                missing_keys.element_size(),
+        std::get<3>(results[tid]).data_ptr(),
+        num_missing * missing_keys.element_size());
+  });
+  auto found_offsets = result_offsets_tensor.slice(0, 0, policies_.size() + 1);
+  return std::make_tuple(
+      positions, output_indices, pointers, missing_keys, found_offsets,
+      missing_offsets);
+}

 c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
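A note on the offset bookkeeping in the new QueryAndThenReplace above: each policy writes its found count into slot tid and its missing count into slot policies_.size() + tid, and one std::exclusive_scan over all 2 * policies_.size() + 1 slots turns both halves into global offsets, so found_offsets is a plain slice and missing_offsets is the second half rebased to zero. A minimal sketch with two hypothetical policies:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical counts: policy 0 found 3 keys and missed 1, policy 1 found
  // 2 and missed 4. Layout: [found_0, found_1, missing_0, missing_1, pad].
  std::vector<int64_t> result_offsets = {3, 2, 1, 4, 0};
  std::exclusive_scan(
      result_offsets.begin(), result_offsets.end(), result_offsets.begin(),
      int64_t{0});
  // result_offsets == {0, 3, 5, 6, 10}: the first half {0, 3, 5} is
  // found_offsets; rebasing the second half {5, 6, 10} to zero gives the
  // missing_offsets {0, 1, 5}.
  for (auto v : result_offsets) std::cout << v << ' ';
  std::cout << '\n';
}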