diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index 12b1cba9125a..34e6a63233dd 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -274,9 +274,8 @@ FusedCSCSamplingGraph::GetState() const { c10::intrusive_ptr FusedCSCSamplingGraph::InSubgraph( const torch::Tensor& nodes) const { - if (utils::is_accessible_from_gpu(indptr_) && + if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr_) && utils::is_accessible_from_gpu(indices_) && - utils::is_accessible_from_gpu(nodes) && (!type_per_edge_.has_value() || utils::is_accessible_from_gpu(type_per_edge_.value()))) { GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "InSubgraph", { @@ -614,9 +613,9 @@ c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( probs_or_mask = this->EdgeAttribute(probs_name); } - if (!replace && utils::is_accessible_from_gpu(indptr_) && + if (!replace && utils::is_on_gpu(nodes) && + utils::is_accessible_from_gpu(indptr_) && utils::is_accessible_from_gpu(indices_) && - utils::is_accessible_from_gpu(nodes) && (!probs_or_mask.has_value() || utils::is_accessible_from_gpu(probs_or_mask.value()))) { GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( diff --git a/graphbolt/src/index_select.cc b/graphbolt/src/index_select.cc index 34191a8f8257..00257061c675 100644 --- a/graphbolt/src/index_select.cc +++ b/graphbolt/src/index_select.cc @@ -13,8 +13,7 @@ namespace graphbolt { namespace ops { torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) { - if (input.is_pinned() && - (index.is_pinned() || index.device().type() == c10::DeviceType::CUDA)) { + if (utils::is_on_gpu(index) && input.is_pinned()) { GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( c10::DeviceType::CUDA, "UVAIndexSelect", { return UVAIndexSelectImpl(input, index); }); @@ -26,9 +25,8 @@ std::tuple IndexSelectCSC( torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) { TORCH_CHECK( indices.sizes().size() == 1, "IndexSelectCSC only supports 1d tensors"); - if (utils::is_accessible_from_gpu(indptr) && - utils::is_accessible_from_gpu(indices) && - utils::is_accessible_from_gpu(nodes)) { + if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr) && + utils::is_accessible_from_gpu(indices)) { GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( c10::DeviceType::CUDA, "IndexSelectCSCImpl", { return IndexSelectCSCImpl(indptr, indices, nodes); }); diff --git a/graphbolt/src/isin.cc b/graphbolt/src/isin.cc index ff9a976b03b9..c41b839b1651 100644 --- a/graphbolt/src/isin.cc +++ b/graphbolt/src/isin.cc @@ -48,8 +48,7 @@ torch::Tensor IsInCPU( torch::Tensor IsIn( const torch::Tensor& elements, const torch::Tensor& test_elements) { - if (utils::is_accessible_from_gpu(elements) && - utils::is_accessible_from_gpu(test_elements)) { + if (utils::is_on_gpu(elements) && utils::is_on_gpu(test_elements)) { GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( c10::DeviceType::CUDA, "IsInOperation", { return ops::IsIn(elements, test_elements); }); diff --git a/graphbolt/src/unique_and_compact.cc b/graphbolt/src/unique_and_compact.cc index fd0d23e84e63..3a7e4963bbdd 100644 --- a/graphbolt/src/unique_and_compact.cc +++ b/graphbolt/src/unique_and_compact.cc @@ -19,9 +19,8 @@ namespace sampling { std::tuple UniqueAndCompact( const torch::Tensor& src_ids, const torch::Tensor& dst_ids, const torch::Tensor unique_dst_ids) { - if (utils::is_accessible_from_gpu(src_ids) && - utils::is_accessible_from_gpu(dst_ids) && - utils::is_accessible_from_gpu(unique_dst_ids)) { + if (utils::is_on_gpu(src_ids) && utils::is_on_gpu(dst_ids) && + utils::is_on_gpu(unique_dst_ids)) { GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( c10::DeviceType::CUDA, "unique_and_compact", { return ops::UniqueAndCompact(src_ids, dst_ids, unique_dst_ids); }); diff --git a/graphbolt/src/utils.h b/graphbolt/src/utils.h index 093e920af017..fcd66ef93d3b 100644 --- a/graphbolt/src/utils.h +++ b/graphbolt/src/utils.h @@ -12,11 +12,18 @@ namespace graphbolt { namespace utils { +/** + * @brief Checks whether the tensor is stored on the GPU. + */ +inline bool is_on_gpu(torch::Tensor tensor) { + return tensor.device().is_cuda(); +} + /** * @brief Checks whether the tensor is stored on the GPU or the pinned memory. */ inline bool is_accessible_from_gpu(torch::Tensor tensor) { - return tensor.is_pinned() || tensor.device().type() == c10::DeviceType::CUDA; + return is_on_gpu(tensor) || tensor.is_pinned(); } /** diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index 04cef3958356..b59260cf654e 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -9,9 +9,6 @@ def test_DataLoader(): - # https://pytorch.org/docs/master/notes/multiprocessing.html#cuda-in-multiprocessing - mp.set_start_method("spawn", force=True) - N = 40 B = 4 itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")