[GraphBolt][CUDA] Dataloader num_workers > 0 fix. (#6924) (#6928)
Co-authored-by: Muhammed Fatih BALIN <m.f.balin@gmail.com>
Rhett-Ying and mfbalin committed Jan 10, 2024
1 parent 92d4ba9 commit c047950
Showing 6 changed files with 17 additions and 18 deletions.
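
The C++ changes all follow one pattern: the CUDA-only branch at each call site is now entered only when the seed/index tensor (`nodes`, `index`, `elements`, the id tensors) actually resides on the GPU, while the graph tensors may still be either on the GPU or in pinned host memory. Below is a minimal libtorch sketch of the old versus new guard, added here for illustration only: the helper bodies mirror graphbolt/src/utils.h after this commit, the main() driver is hypothetical, and the reading that this keeps num_workers > 0 DataLoader workers (whose seed tensors live on the CPU) on the CPU code path is an inference from the commit title, not something stated in the diff.

#include <iostream>
#include <torch/torch.h>

// Stand-ins mirroring graphbolt/src/utils.h after this commit.
inline bool is_on_gpu(const torch::Tensor& t) { return t.device().is_cuda(); }
inline bool is_accessible_from_gpu(const torch::Tensor& t) {
  return is_on_gpu(t) || t.is_pinned();
}

int main() {
  // Seed nodes as a DataLoader worker would produce them: host memory, possibly pinned.
  auto nodes = torch::arange(8, torch::kInt64);
  if (torch::cuda::is_available()) nodes = nodes.pin_memory();

  std::cout << std::boolalpha;
  // Old guard: pinned host seeds were enough to enter the CUDA-only branch.
  std::cout << "old guard (is_accessible_from_gpu): "
            << is_accessible_from_gpu(nodes) << "\n";
  // New guard: the CUDA branch is taken only when the seeds live on the GPU.
  std::cout << "new guard (is_on_gpu): " << is_on_gpu(nodes) << "\n";
}
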
7 changes: 3 additions & 4 deletions graphbolt/src/fused_csc_sampling_graph.cc
@@ -274,9 +274,8 @@ FusedCSCSamplingGraph::GetState() const {
 
 c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
     const torch::Tensor& nodes) const {
-  if (utils::is_accessible_from_gpu(indptr_) &&
+  if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr_) &&
       utils::is_accessible_from_gpu(indices_) &&
-      utils::is_accessible_from_gpu(nodes) &&
       (!type_per_edge_.has_value() ||
        utils::is_accessible_from_gpu(type_per_edge_.value()))) {
     GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "InSubgraph", {
@@ -614,9 +613,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
     probs_or_mask = this->EdgeAttribute(probs_name);
   }
 
-  if (!replace && utils::is_accessible_from_gpu(indptr_) &&
+  if (!replace && utils::is_on_gpu(nodes) &&
+      utils::is_accessible_from_gpu(indptr_) &&
       utils::is_accessible_from_gpu(indices_) &&
-      utils::is_accessible_from_gpu(nodes) &&
       (!probs_or_mask.has_value() ||
        utils::is_accessible_from_gpu(probs_or_mask.value()))) {
     GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
8 changes: 3 additions & 5 deletions graphbolt/src/index_select.cc
@@ -13,8 +13,7 @@ namespace graphbolt {
 namespace ops {
 
 torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
-  if (input.is_pinned() &&
-      (index.is_pinned() || index.device().type() == c10::DeviceType::CUDA)) {
+  if (utils::is_on_gpu(index) && input.is_pinned()) {
     GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
         c10::DeviceType::CUDA, "UVAIndexSelect",
         { return UVAIndexSelectImpl(input, index); });
@@ -26,9 +25,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
     torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
   TORCH_CHECK(
       indices.sizes().size() == 1, "IndexSelectCSC only supports 1d tensors");
-  if (utils::is_accessible_from_gpu(indptr) &&
-      utils::is_accessible_from_gpu(indices) &&
-      utils::is_accessible_from_gpu(nodes)) {
+  if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr) &&
+      utils::is_accessible_from_gpu(indices)) {
     GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
         c10::DeviceType::CUDA, "IndexSelectCSCImpl",
         { return IndexSelectCSCImpl(indptr, indices, nodes); });
3 changes: 1 addition & 2 deletions graphbolt/src/isin.cc
@@ -48,8 +48,7 @@ torch::Tensor IsInCPU(
 
 torch::Tensor IsIn(
     const torch::Tensor& elements, const torch::Tensor& test_elements) {
-  if (utils::is_accessible_from_gpu(elements) &&
-      utils::is_accessible_from_gpu(test_elements)) {
+  if (utils::is_on_gpu(elements) && utils::is_on_gpu(test_elements)) {
     GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
         c10::DeviceType::CUDA, "IsInOperation",
         { return ops::IsIn(elements, test_elements); });
5 changes: 2 additions & 3 deletions graphbolt/src/unique_and_compact.cc
@@ -19,9 +19,8 @@ namespace sampling {
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
     const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
     const torch::Tensor unique_dst_ids) {
-  if (utils::is_accessible_from_gpu(src_ids) &&
-      utils::is_accessible_from_gpu(dst_ids) &&
-      utils::is_accessible_from_gpu(unique_dst_ids)) {
+  if (utils::is_on_gpu(src_ids) && utils::is_on_gpu(dst_ids) &&
+      utils::is_on_gpu(unique_dst_ids)) {
     GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
         c10::DeviceType::CUDA, "unique_and_compact",
         { return ops::UniqueAndCompact(src_ids, dst_ids, unique_dst_ids); });
9 changes: 8 additions & 1 deletion graphbolt/src/utils.h
@@ -12,11 +12,18 @@
 namespace graphbolt {
 namespace utils {
 
+/**
+ * @brief Checks whether the tensor is stored on the GPU.
+ */
+inline bool is_on_gpu(torch::Tensor tensor) {
+  return tensor.device().is_cuda();
+}
+
 /**
  * @brief Checks whether the tensor is stored on the GPU or the pinned memory.
  */
 inline bool is_accessible_from_gpu(torch::Tensor tensor) {
-  return tensor.is_pinned() || tensor.device().type() == c10::DeviceType::CUDA;
+  return is_on_gpu(tensor) || tensor.is_pinned();
 }
 
 /**
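
As a reference for the two predicates defined above, here is a small sketch (not part of the repository) showing what they report for the three placements a tensor can have; it assumes a CUDA-capable libtorch build and recomputes the checks inline instead of calling the GraphBolt helpers.

#include <iostream>
#include <torch/torch.h>

void report(const char* name, const torch::Tensor& t) {
  const bool on_gpu = t.device().is_cuda();            // utils::is_on_gpu
  const bool gpu_accessible = on_gpu || t.is_pinned();  // utils::is_accessible_from_gpu
  std::cout << name << ": is_on_gpu=" << std::boolalpha << on_gpu
            << ", is_accessible_from_gpu=" << gpu_accessible << "\n";
}

int main() {
  auto t = torch::arange(4);
  report("plain CPU", t);                    // false, false
  if (torch::cuda::is_available()) {
    report("pinned CPU", t.pin_memory());    // false, true
    report("CUDA", t.to(torch::kCUDA));      // true,  true
  }
}
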
3 changes: 0 additions & 3 deletions tests/python/pytorch/graphbolt/test_dataloader.py
@@ -9,9 +9,6 @@
 
 
 def test_DataLoader():
-    # https://pytorch.org/docs/master/notes/multiprocessing.html#cuda-in-multiprocessing
-    mp.set_start_method("spawn", force=True)
-
     N = 40
     B = 4
     itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
