Skip to content

Commit

Permalink
fix the replace fnction
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jun 20, 2024
1 parent 0c0f115 commit 2a74efc
Showing 1 changed file with 34 additions and 28 deletions.
62 changes: 34 additions & 28 deletions graphbolt/src/cuda/extension/gpu_graph_cache.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ using map_t = cuco::static_map<
template <typename index_t, typename map_t>
__global__ void _Insert(
const int64_t num_nodes, const index_t num_existing, const index_t* seeds,
const index_t* indices, map_t map) {
const index_t* missing_indices, const index_t* indices, map_t map) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x;

while (i < num_nodes) {
const auto key = seeds[indices[i]];
const auto key = seeds[missing_indices[indices[i]]];

auto slot = map.find(key);
slot->second = num_existing + i;
Expand Down Expand Up @@ -241,41 +241,47 @@ void GpuGraphCache::Replace(
torch::Tensor sindptr, sindices;
auto [in_degree, sliced_indptr] =
ops::SliceCSCIndptr(indptr, output_indices);
std::tie(sindptr, sindices) =
ops::IndexSelectCSCImpl(indptr, indices, output_indices);
torch::optional<int64_t> output_size;
if (num_edges_ + output_size < indices_.size(0) &&
num_nodes_ + num_entering < indptr_.size(0) - 1) {
if (num_nodes_ + num_entering < indptr_.size(0) - 1) {
torch::Tensor sindptr;
bool enough_space;
for (size_t i = 0; i < edge_tensors.size(); i++) {
torch::Tensor sindices;
std::tie(sindptr, sindices) = ops::IndexSelectCSCImpl(
in_degree, sliced_indptr, edge_tensors[i], output_indices,
indptr.size(0) - 2, output_size);
output_size = sindices.size(0);
cached_edge_tensors_.at(i).slice(
0, num_edges_, num_edges_ + *output_size) = sindices;
enough_space =
num_edges_ + *output_size < cached_edge_tensors_.at(i).size(0);
if (enough_space) {
cached_edge_tensors_.at(i).slice(
0, num_edges_, num_edges_ + *output_size) = sindices;
}
}
if (enough_space) {
AT_DISPATCH_INDEX_TYPES(
sindptr.scalar_type(), "GpuGraphCache::Replace", ([&] {
auto adjusted_indptr = thrust::make_transform_iterator(
sindptr.data_ptr<index_t>(),
[=] __host__ __device__(index_t x) {
return x + num_edges_;
});
CUB_CALL(
DeviceScan::ExclusiveSum, adjusted_indptr + 1,
indptr_.data_ptr<index_t>() + num_nodes_ + 1,
num_entering);
}));
const dim3 block(BLOCK_SIZE);
const dim3 grid((num_entering + BLOCK_SIZE - 1) / BLOCK_SIZE);
CUDA_KERNEL_CALL(
_Insert, grid, block, 0, output_indices.size(0),
static_cast<index_t>(num_nodes_), seeds.data_ptr<index_t>(),
missing_indices.data_ptr<index_t>(),
output_indices.data_ptr<index_t>(),
reinterpret_cast<map_t<index_t>*>(map_)->ref(cuco::find));
num_edges_ += *output_size;
num_nodes_ += num_entering;
}
AT_DISPATCH_INTEGRAL_TYPES(
sindptr.scalar_type(), "GpuGraphCache::Replace", ([&] {
auto adjusted_indptr = thrust::make_transform_iterator(
sindptr.data_ptr<scalar_t>(),
[=] __host__ __device__(scalar_t x) {
return x + num_edges_;
});
CUB_CALL(
DeviceScan::ExclusiveSum, adjusted_indptr + 1,
indptr_.data_ptr<int64_t>() + num_nodes_ + 1, num_entering);
}));
const dim3 block(BLOCK_SIZE);
const dim3 grid((num_entering + BLOCK_SIZE - 1) / BLOCK_SIZE);
CUDA_KERNEL_CALL(
_Insert, grid, block, 0, output_indices.size(0),
static_cast<index_t>(num_nodes_), seeds.data_ptr<index_t>(),
output_indices.data_ptr<index_t>(),
reinterpret_cast<map_t<index_t>*>(map_)->ref(cuco::find));
num_edges_ += output_size;
num_nodes_ += num_entering;
}
}));
}
Expand Down

0 comments on commit 2a74efc

Please sign in to comment.