
Commit ea2e5a4

Merge branch 'master' into gb_pyg_hetero_example

mfbalin committed Aug 20, 2024
2 parents a282b58 + 513a50f
Showing 11 changed files with 76 additions and 27 deletions.
examples/graphbolt/disk_based_feature/node_classification.py (2 additions & 2 deletions)

@@ -464,7 +464,7 @@ def main():
args.feature_device == "pinned",
)
cpu_cached_feature = features[("node", None, "feat")]
- cpu_cache_miss_rate_fn = lambda: cpu_cached_feature._feature.miss_rate
+ cpu_cache_miss_rate_fn = lambda: cpu_cached_feature.miss_rate
else:
cpu_cache_miss_rate_fn = lambda: 1

@@ -479,7 +479,7 @@
int(args.gpu_cache_size_in_gigabytes * 1024 * 1024 * 1024),
)
gpu_cached_feature = features[("node", None, "feat")]
- gpu_cache_miss_rate_fn = lambda: gpu_cached_feature._feature.miss_rate
+ gpu_cache_miss_rate_fn = lambda: gpu_cached_feature.miss_rate
else:
gpu_cache_miss_rate_fn = lambda: 1

examples/graphbolt/pyg/labor/node_classification.py (2 additions & 2 deletions)

@@ -501,7 +501,7 @@ def main():
args.feature_device == "pinned",
)
cpu_cached_feature = features[("node", None, "feat")]
- cpu_cache_miss_rate_fn = lambda: cpu_cached_feature._feature.miss_rate
+ cpu_cache_miss_rate_fn = lambda: cpu_cached_feature.miss_rate
else:
cpu_cache_miss_rate_fn = lambda: 1
if args.num_gpu_cached_features > 0 and args.feature_device != "cuda":
@@ -510,7 +510,7 @@
args.num_gpu_cached_features * feature_num_bytes,
)
gpu_cached_feature = features[("node", None, "feat")]
- gpu_cache_miss_rate_fn = lambda: gpu_cached_feature._feature.miss_rate
+ gpu_cache_miss_rate_fn = lambda: gpu_cached_feature.miss_rate
else:
gpu_cache_miss_rate_fn = lambda: 1

graphbolt/src/cnumpy.cc (4 additions & 4 deletions)

@@ -179,10 +179,10 @@ torch::Tensor OnDiskNpyArray::IndexSelectIOUringImpl(torch::Tensor index) {
CircularQueue<ReadRequest> read_queue(8 * kGroupSize);
int64_t num_submitted = 0;
int64_t num_completed = 0;
- auto [acquired_queue_handle, my_read_buffer2] = queue_source.get();
+ auto [acquired_queue_handle, read_buffer_source2] = queue_source.get();
auto &io_uring_queue = acquired_queue_handle.get();
// Capturing structured binding is available only in C++20, so we rename.
- auto my_read_buffer = my_read_buffer2;
+ auto read_buffer_source = read_buffer_source2;
auto submit_fn = [&](int64_t submission_minimum_batch_size) {
if (read_queue.Size() < submission_minimum_batch_size) return;
TORCH_CHECK( // Check for sqe overflow.
@@ -200,8 +200,8 @@ torch::Tensor OnDiskNpyArray::IndexSelectIOUringImpl(torch::Tensor index) {
};
for (int64_t read_buffer_slot = 0; true;) {
auto request_read_buffer = [&]() {
- return my_read_buffer + (aligned_length_ + block_size_) *
-     (read_buffer_slot++ % (8 * kGroupSize));
+ return read_buffer_source + (aligned_length_ + block_size_) *
+     (read_buffer_slot++ % (8 * kGroupSize));
};
const auto num_requested_items = std::max(
std::min(
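A note on the slot arithmetic in request_read_buffer above: the read buffers come from one contiguous allocation holding 8 * kGroupSize slots, each aligned_length_ + block_size_ bytes wide, reused round-robin as requests retire. A minimal sketch of the same indexing, with illustrative names (SlotAddress, pool_base, slot_bytes, pool_size are not from the source):

#include <cstdint>

// Sketch: round-robin slot indexing over one contiguous buffer pool.
// Slot i and slot i + pool_size alias the same memory, so at most
// pool_size reads may be in flight at a time.
char* SlotAddress(
    char* pool_base, int64_t slot_bytes, int64_t pool_size,
    int64_t& slot_counter) {
  return pool_base + slot_bytes * (slot_counter++ % pool_size);
}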
graphbolt/src/cnumpy.h (22 additions & 2 deletions)

@@ -146,12 +146,21 @@ class OnDiskNpyArray : public torch::CustomClassHolder {
static inline std::mutex available_queues_mtx_; // available_queues_ mutex.
static inline std::vector<int> available_queues_;

- struct QueueAndBufferAcquirer {
-   struct UniqueQueue {
+ /**
+  * @brief This class is meant to distribute the available read buffers and the
+  * statically declared io_uring queues among the worker threads.
+  */
+ class QueueAndBufferAcquirer {
+  public:
+   class UniqueQueue {
+    public:
UniqueQueue(int thread_id) : thread_id_(thread_id) {}
UniqueQueue(const UniqueQueue&) = delete;
UniqueQueue& operator=(const UniqueQueue&) = delete;

+ /**
+  * @brief Returns the queue back to the pool.
+  */
~UniqueQueue() {
{
// We give back the slot we used.
@@ -161,6 +170,9 @@ class OnDiskNpyArray : public torch::CustomClassHolder {
semaphore_.release();
}

+ /**
+  * @brief Returns the raw io_uring queue.
+  */
::io_uring& get() const { return io_uring_queue_[thread_id_]; }

private:
@@ -179,6 +191,14 @@ class OnDiskNpyArray : public torch::CustomClassHolder {
}
}

+ /**
+  * @brief Returns the secured io_uring queue and the read buffer as a pair.
+  * The raw io_uring queue can be accessed by calling `.get()` on the
+  * returned UniqueQueue object.
+  *
+  * @note The returned UniqueQueue object manages the lifetime of the
+  * io_uring queue. Its destructor returns the queue back to the pool.
+  */
std::pair<UniqueQueue, char*> get() {
// We consume a slot from the semaphore to use a queue.
if (entering_first_.test_and_set(std::memory_order_relaxed)) {
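Taken together with the call site in cnumpy.cc above, the intended usage pattern is roughly the following sketch (UseQueue is a hypothetical wrapper; scoping and access specifiers are elided):

// Sketch only: mirrors the call site in IndexSelectIOUringImpl.
void UseQueue(QueueAndBufferAcquirer& queue_source) {
  // get() waits on the semaphore until a queue slot is free.
  auto [acquired_queue_handle, read_buffer] = queue_source.get();
  // C++17 lambdas cannot capture structured bindings, hence the renaming
  // in cnumpy.cc; C++20 lifts this restriction.
  ::io_uring& ring = acquired_queue_handle.get();
  // ... submit reads on `ring` that write into `read_buffer` ...
}  // ~UniqueQueue() runs here, returning the queue slot to the pool.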
graphbolt/src/fused_csc_sampling_graph.cc (2 additions & 1 deletion)

@@ -430,9 +430,10 @@ auto GetPickFn(
type_per_edge.value(), probs_or_mask, args, picked_data_ptr,
seed_offset, subgraph_indptr_ptr, etype_id_to_num_picked_offset);
} else {
+ picked_data_ptr += subgraph_indptr_ptr[seed_offset];
int64_t num_sampled = Pick(
offset, num_neighbors, fanouts[0], replace, options, probs_or_mask,
- args, picked_data_ptr + subgraph_indptr_ptr[seed_offset]);
+ args, picked_data_ptr);
if (type_per_edge) {
std::sort(picked_data_ptr, picked_data_ptr + num_sampled);
}
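The point of this change: Pick() used to write its output starting at picked_data_ptr + subgraph_indptr_ptr[seed_offset], while the follow-up std::sort ran on the unoffset picked_data_ptr, so the sorted range and the written range disagreed. Advancing the base pointer first makes them coincide; the new test_sample_neighbors_hetero_single_fanout test below appears to exercise exactly this path. A self-contained sketch of the pattern, with illustrative values:

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative only: the sort range must match the write range.
int main() {
  std::vector<int64_t> picked(8, -1);
  int64_t* picked_data_ptr = picked.data();
  const int64_t offset = 4;   // stands in for subgraph_indptr_ptr[seed_offset]
  picked_data_ptr += offset;  // the fix: fold the offset in before writing
  picked_data_ptr[0] = 7;     // stands in for Pick() writing its samples
  picked_data_ptr[1] = 2;
  const int64_t num_sampled = 2;
  // The sort now covers exactly the freshly written range.
  std::sort(picked_data_ptr, picked_data_ptr + num_sampled);
  return 0;
}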
python/dgl/graphbolt/impl/cpu_cached_feature.py (10 additions & 5 deletions)

@@ -30,8 +30,8 @@ class CPUCachedFeature(Feature):
will hang due to all cache entries being read and/or write locked,
resulting in a deadlock.
policy : str
- The cache eviction policy algorithm name. See gb.impl.CPUFeatureCache
- for the list of available policies.
+ The cache eviction policy algorithm name. The available policies are
+ ["s3-fifo", "sieve", "lru", "clock"]. Default is "sieve".
pin_memory : bool
Whether the cache storage should be allocated on system pinned memory.
Default is False.
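Based only on the parameters this docstring lists, construction looks roughly like the sketch below; the exact constructor signature, argument names, and module path are assumptions, not confirmed by this diff:

import torch
from dgl import graphbolt as gb

# Hedged sketch: wrap an in-memory feature with a CPU cache using one of
# the documented eviction policies. Argument names and order are assumed.
fallback = gb.TorchBasedFeature(torch.randn(1000, 128))
cached = gb.CPUCachedFeature(
    fallback,
    32 * 1024 * 1024,  # cache capacity in bytes (assumed positional)
    policy="sieve",    # one of "s3-fifo", "sieve", "lru", "clock"
    pin_memory=False,
)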
@@ -94,9 +94,9 @@ def read_async(self, ids: torch.Tensor):
-------
A generator object.
The returned generator object returns a future on
- `read_async_num_stages(ids.device)`th invocation. The return result
- can be accessed by calling `.wait()`. on the returned future object.
- It is undefined behavior to call `.wait()` more than once.
+ ``read_async_num_stages(ids.device)``th invocation. The return result
+ can be accessed by calling ``.wait()`` on the returned future object.
+ It is undefined behavior to call ``.wait()`` more than once.
Examples
--------
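The same wording recurs in gpu_cached_feature.py and torch_based_feature_store.py below, so one hedged sketch of the protocol it describes suffices here: drive the generator once per stage, then call .wait() exactly once on the future produced by the final invocation (read_all_async is an illustrative helper, not part of the API):

import torch

def read_all_async(feature, ids: torch.Tensor):
    """Drive a Feature.read_async generator to completion (sketch)."""
    reader = feature.read_async(ids)
    future = None
    # One next() per stage; only the value yielded on the
    # read_async_num_stages(ids.device)-th invocation is the future.
    for _ in range(feature.read_async_num_stages(ids.device)):
        future = next(reader)
    return future.wait()  # undefined behavior if called more than once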
@@ -449,3 +449,8 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None):
else:
self._fallback_feature.update(value, ids)
self._feature.replace(ids, value)

+     @property
+     def miss_rate(self):
+         """Returns the cache miss rate since creation."""
+         return self._feature.miss_rate
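This new public property is what lets the example scripts above stop reaching into the private _feature member; a short usage sketch (variable naming follows those examples):

# cpu_cached_feature is a CPUCachedFeature, as in the examples above.
cpu_cache_miss_rate_fn = lambda: cpu_cached_feature.miss_rate
print(f"CPU cache miss rate: {cpu_cache_miss_rate_fn():.2%}")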
python/dgl/graphbolt/impl/gpu_cached_feature.py (10 additions & 4 deletions)

@@ -17,7 +17,8 @@ def num_cache_items(cache_capacity_in_bytes, single_item):


class GPUCachedFeature(Feature):
r"""GPU cached feature wrapping a fallback feature.
r"""GPU cached feature wrapping a fallback feature. It uses the least
recently used (LRU) algorithm as the cache eviction policy.
Places the GPU cache to torch.cuda.current_device().
@@ -100,9 +100,9 @@ def read_async(self, ids: torch.Tensor):
-------
A generator object.
The returned generator object returns a future on
- `read_async_num_stages(ids.device)`th invocation. The return result
- can be accessed by calling `.wait()`. on the returned future object.
- It is undefined behavior to call `.wait()` more than once.
+ ``read_async_num_stages(ids.device)``th invocation. The return result
+ can be accessed by calling ``.wait()`` on the returned future object.
+ It is undefined behavior to call ``.wait()`` more than once.
Examples
--------
@@ -219,3 +219,8 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None):
else:
self._fallback_feature.update(value, ids)
self._feature.replace(ids, value)

+     @property
+     def miss_rate(self):
+         """Returns the cache miss rate since creation."""
+         return self._feature.miss_rate
python/dgl/graphbolt/impl/torch_based_feature_store.py (7 additions & 7 deletions)

@@ -151,9 +151,9 @@ def read_async(self, ids: torch.Tensor):
-------
A generator object.
The returned generator object returns a future on
- `read_async_num_stages(ids.device)`th invocation. The return result
- can be accessed by calling `.wait()`. on the returned future object.
- It is undefined behavior to call `.wait()` more than once.
+ ``read_async_num_stages(ids.device)``th invocation. The return result
+ can be accessed by calling ``.wait()`` on the returned future object.
+ It is undefined behavior to call ``.wait()`` more than once.
Examples
--------
@@ -424,9 +424,9 @@ def read_async(self, ids: torch.Tensor):
-------
A generator object.
The returned generator object returns a future on
- `read_async_num_stages(ids.device)`th invocation. The return result
- can be accessed by calling `.wait()`. on the returned future object.
- It is undefined behavior to call `.wait()` more than once.
+ ``read_async_num_stages(ids.device)``th invocation. The return result
+ can be accessed by calling ``.wait()`` on the returned future object.
+ It is undefined behavior to call ``.wait()`` more than once.
Examples
--------
@@ -520,7 +520,7 @@ def to(self, _):  # pylint: disable=invalid-name
return self

def pin_memory_(self): # pylint: disable=invalid-name
"""Placeholder `DiskBasedFeature` pin_memory_ implementation. It is a no-op."""
r"""Placeholder `DiskBasedFeature` pin_memory_ implementation. It is a no-op."""
gb_warning(
"`DiskBasedFeature.pin_memory_()` is not supported. Leaving unmodified."
)
@@ -79,6 +79,7 @@ def test_cpu_cached_feature(dtype, policy):
total_miss = feat_store_b._feature.total_miss
feat_store_b.read(torch.tensor([0, 1]))
assert total_miss == feat_store_b._feature.total_miss
+ assert feat_store_a._feature.miss_rate == feat_store_a.miss_rate

# Test get the size of the entire feature with ids.
assert feat_store_a.size() == torch.Size([3])
@@ -1702,6 +1702,21 @@ def test_sample_neighbors_homo(
assert subgraph.original_row_node_ids is None


@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero_single_fanout(labor):
u, i = torch.randint(20, size=(1000,)), torch.randint(10, size=(1000,))
graph = dgl.heterograph({("u", "w", "i"): (u, i), ("i", "b", "u"): (i, u)})

graph = gb.from_dglgraph(graph).to(F.ctx())

sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors

for i in range(11):
nodes = {"u": torch.randint(10, (100,), device=F.ctx())}
sampler(nodes, fanouts=torch.tensor([-1]))
# Should reach here without crashing.


@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("labor", [False, True])
@@ -83,6 +83,7 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):
total_miss = feat_store_b._feature.total_miss
feat_store_b.read(torch.tensor([0, 1]).to("cuda"))
assert total_miss == feat_store_b._feature.total_miss
+ assert feat_store_a._feature.miss_rate == feat_store_a.miss_rate

# Test get the size of the entire feature with ids.
assert feat_store_a.size() == torch.Size([3])
