[GraphBolt] Avoid initializing CUDAContext in DataLoader workers. (#…
mfbalin committed Jul 22, 2024
1 parent d775ab1 commit c218b19
Showing 9 changed files with 93 additions and 18 deletions.
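This commit threads a per-worker id from the Python DataLoader into the GraphBolt C++ layer so that pinned-memory queries no longer touch CUDA inside worker processes. A minimal sketch of the wiring, modeled on the dataloader.py hunk below (my_datapipe is a placeholder, not part of the commit):

    import torch
    from torch.utils.data import DataLoader

    def _set_worker_id(worker_id):
        # Runs once in each worker process; PyTorch passes ids 0..num_workers-1.
        # The operator is registered in python_binding.cc below.
        torch.ops.graphbolt.set_worker_id(worker_id)

    loader = DataLoader(
        my_datapipe,  # placeholder for a GraphBolt datapipe
        batch_size=None,
        num_workers=4,
        worker_init_fn=_set_worker_id,
    )

Once the id is recorded, utils::is_pinned() (added in utils.h below) reports false in workers instead of calling Tensor::is_pinned(), which can lazily initialize a CUDA context in every worker.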
21 changes: 13 additions & 8 deletions graphbolt/src/cache_policy.cc
@@ -19,20 +19,24 @@
  */
 #include "./cache_policy.h"
 
+#include "./utils.h"
+
 namespace graphbolt {
 namespace storage {
 
 template <typename CachePolicy>
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
   auto positions = torch::empty_like(
-      keys,
-      keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
+      keys, keys.options()
+                .dtype(torch::kInt64)
+                .pinned_memory(utils::is_pinned(keys)));
   auto indices = torch::empty_like(
-      keys,
-      keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
-  auto filtered_keys =
-      torch::empty_like(keys, keys.options().pinned_memory(keys.is_pinned()));
+      keys, keys.options()
+                .dtype(torch::kInt64)
+                .pinned_memory(utils::is_pinned(keys)));
+  auto filtered_keys = torch::empty_like(
+      keys, keys.options().pinned_memory(utils::is_pinned(keys)));
   int64_t found_cnt = 0;
   int64_t missing_cnt = keys.size(0);
   AT_DISPATCH_INDEX_TYPES(
@@ -63,8 +67,9 @@ template <typename CachePolicy>
 torch::Tensor BaseCachePolicy::ReplaceImpl(
     CachePolicy& policy, torch::Tensor keys) {
   auto positions = torch::empty_like(
-      keys,
-      keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
+      keys, keys.options()
+                .dtype(torch::kInt64)
+                .pinned_memory(utils::is_pinned(keys)));
   AT_DISPATCH_INDEX_TYPES(
       keys.scalar_type(), "BaseCachePolicy::Replace", ([&] {
         auto keys_ptr = keys.data_ptr<index_t>();
3 changes: 2 additions & 1 deletion graphbolt/src/cnumpy.cc
@@ -23,6 +23,7 @@
 #include <stdexcept>
 
 #include "./circular_queue.h"
+#include "./utils.h"
 
 namespace graphbolt {
 namespace storage {
@@ -152,7 +153,7 @@ torch::Tensor OnDiskNpyArray::IndexSelectIOUringImpl(torch::Tensor index) {
       shape, index.options()
                  .dtype(dtype_)
                  .layout(torch::kStrided)
-                 .pinned_memory(index.is_pinned())
+                 .pinned_memory(utils::is_pinned(index))
                  .requires_grad(false));
   auto result_buffer = reinterpret_cast<char *>(result.data_ptr());
 
4 changes: 3 additions & 1 deletion graphbolt/src/feature_cache.cc
@@ -20,6 +20,7 @@
 #include "./feature_cache.h"
 
 #include "./index_select.h"
+#include "./utils.h"
 
 namespace graphbolt {
 namespace storage {
@@ -34,7 +35,8 @@ FeatureCache::FeatureCache(
 
 torch::Tensor FeatureCache::Query(
     torch::Tensor positions, torch::Tensor indices, int64_t size) {
-  const bool pin_memory = positions.is_pinned() || indices.is_pinned();
+  const bool pin_memory =
+      utils::is_pinned(positions) || utils::is_pinned(indices);
   std::vector<int64_t> output_shape{
       tensor_.sizes().begin(), tensor_.sizes().end()};
   output_shape[0] = size;
5 changes: 3 additions & 2 deletions graphbolt/src/index_select.cc
@@ -23,8 +23,9 @@ torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
   auto output_shape = input.sizes().vec();
   output_shape[0] = index.numel();
   auto result = torch::empty(
-      output_shape,
-      index.options().dtype(input.dtype()).pinned_memory(index.is_pinned()));
+      output_shape, index.options()
+                        .dtype(input.dtype())
+                        .pinned_memory(utils::is_pinned(index)));
   return torch::index_select_out(result, input, 0, index);
 }
 
15 changes: 9 additions & 6 deletions graphbolt/src/partitioned_cache_policy.cc
@@ -21,6 +21,8 @@
 
 #include <numeric>
 
+#include "./utils.h"
+
 namespace graphbolt {
 namespace storage {
 
@@ -140,15 +142,15 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
       result_offsets, 0);
   torch::Tensor positions = torch::empty(
       result_offsets[policies_.size()],
-      std::get<0>(results[0]).options().pinned_memory(keys.is_pinned()));
+      std::get<0>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
   torch::Tensor output_indices = torch::empty_like(
-      indices, indices.options().pinned_memory(keys.is_pinned()));
+      indices, indices.options().pinned_memory(utils::is_pinned(keys)));
   torch::Tensor missing_keys = torch::empty(
       indices.size(0) - positions.size(0),
-      std::get<2>(results[0]).options().pinned_memory(keys.is_pinned()));
+      std::get<2>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
   torch::Tensor found_keys = torch::empty(
       positions.size(0),
-      std::get<3>(results[0]).options().pinned_memory(keys.is_pinned()));
+      std::get<3>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
   auto output_indices_ptr = output_indices.data_ptr<int64_t>();
   torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
     if (begin == end) return;
@@ -200,8 +202,9 @@ torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) {
   torch::Tensor offsets, indices, permuted_keys;
   std::tie(offsets, indices, permuted_keys) = Partition(keys);
   auto output_positions = torch::empty_like(
-      keys,
-      keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
+      keys, keys.options()
+                .dtype(torch::kInt64)
+                .pinned_memory(utils::is_pinned(keys)));
   auto offsets_ptr = offsets.data_ptr<int64_t>();
   auto indices_ptr = indices.data_ptr<int64_t>();
   auto output_positions_ptr = output_positions.data_ptr<int64_t>();
2 changes: 2 additions & 0 deletions graphbolt/src/python_binding.cc
@@ -19,6 +19,7 @@
 #include "./index_select.h"
 #include "./partitioned_cache_policy.h"
 #include "./random.h"
+#include "./utils.h"
 
 #ifdef GRAPHBOLT_USE_CUDA
 #include "./cuda/extension/gpu_graph_cache.h"
@@ -145,6 +146,7 @@ TORCH_LIBRARY(graphbolt, m) {
   m.def("index_select_csc_batched", &ops::IndexSelectCSCBatched);
   m.def("ondisk_npy_array", &storage::OnDiskNpyArray::Create);
   m.def("detect_io_uring", &io_uring::IsAvailable);
+  m.def("set_worker_id", &utils::SetWorkerId);
   m.def("set_seed", &RandomEngine::SetManualSeed);
 #ifdef GRAPHBOLT_USE_CUDA
   m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
36 changes: 36 additions & 0 deletions graphbolt/src/utils.cc
@@ -0,0 +1,36 @@
+/**
+ * Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * @file utils.cc
+ * @brief Graphbolt utils implementations.
+ */
+#include "./utils.h"
+
+#include <optional>
+
+namespace graphbolt {
+namespace utils {
+
+namespace {
+std::optional<int64_t> worker_id;
+}
+
+std::optional<int64_t> GetWorkerId() { return worker_id; }
+
+void SetWorkerId(int64_t worker_id_value) { worker_id = worker_id_value; }
+
+}  // namespace utils
+}  // namespace graphbolt
20 changes: 20 additions & 0 deletions graphbolt/src/utils.h
@@ -12,6 +12,18 @@
 namespace graphbolt {
 namespace utils {
 
+/**
+ * @brief If this process is a DataLoader worker, returns its assigned worker
+ * id, which is less than the number of workers.
+ */
+std::optional<int64_t> GetWorkerId();
+
+/**
+ * @brief If this process is a DataLoader worker, this function is called to
+ * initialize its worker id, which is less than the number of workers.
+ */
+void SetWorkerId(int64_t worker_id_value);
+
 /**
  * @brief Checks whether the tensor is stored on the GPU.
  */
@@ -26,6 +38,14 @@ inline bool is_accessible_from_gpu(const torch::Tensor& tensor) {
   return is_on_gpu(tensor) || tensor.is_pinned();
 }
 
+/**
+ * @brief Checks whether the tensor is stored in pinned memory.
+ */
+inline bool is_pinned(const torch::Tensor& tensor) {
+  // If this process is a worker, we should avoid initializing the CUDA context.
+  return !GetWorkerId() && tensor.is_pinned();
+}
+
 /**
  * @brief Checks whether the tensors are all stored on the GPU or the pinned
  * memory.
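The guard above is the heart of the change: Tensor::is_pinned() can lazily create a CUDA context, so a process that has a worker id set skips the check and treats every tensor as unpinned. A hedged Python mirror of the same logic, for illustration only (GraphBolt implements this in C++):

    from typing import Optional

    import torch

    _worker_id: Optional[int] = None  # mirrors the std::optional in utils.cc

    def set_worker_id(worker_id_value: int) -> None:
        global _worker_id
        _worker_id = worker_id_value

    def is_pinned(tensor: torch.Tensor) -> bool:
        # In a DataLoader worker, answer False without asking PyTorch, since
        # Tensor.is_pinned() may lazily initialize a CUDA context there.
        return _worker_id is None and tensor.is_pinned()

The trade-off: inside a worker, outputs are allocated in pageable memory even when the inputs really are pinned, exchanging some copy bandwidth for not paying a per-worker CUDA context.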
5 changes: 5 additions & 0 deletions python/dgl/graphbolt/dataloader.py
@@ -66,6 +66,10 @@ def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
     return datapipe_graph
 
 
+def _set_worker_id(worker_id):
+    torch.ops.graphbolt.set_worker_id(worker_id)
+
+
 class MultiprocessingWrapper(dp.iter.IterDataPipe):
     """Wraps a datapipe with multiprocessing.
@@ -89,6 +93,7 @@ def __init__(self, datapipe, num_workers=0, persistent_workers=True):
             batch_size=None,
             num_workers=num_workers,
             persistent_workers=(num_workers > 0) and persistent_workers,
+            worker_init_fn=_set_worker_id if num_workers > 0 else None,
         )
 
     def __iter__(self):
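PyTorch invokes worker_init_fn(worker_id) once per worker before any batches are produced, so the id is always set before the C++ ops run. A quick, hedged sanity check of the commit's intent (assuming a CUDA build of PyTorch):

    import torch

    def _set_worker_id(worker_id):
        torch.ops.graphbolt.set_worker_id(worker_id)
        # The point of this commit: a worker should start, and stay, without
        # a CUDA context; this asserts the starting condition.
        assert not torch.cuda.is_initialized()

Passing this variant as worker_init_fn checks the precondition; the hunks above keep it holding afterwards, because utils::is_pinned() never calls Tensor::is_pinned() once a worker id is set.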
