Skip to content

Commit

Permalink
[GraphBolt] Make unique_and_compact deterministic (#7217)
Browse files Browse the repository at this point in the history
  • Loading branch information
RamonZhou committed Mar 25, 2024
1 parent 1ad78fb commit 3c39153
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
36 changes: 30 additions & 6 deletions graphbolt/src/concurrent_id_hash_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ torch::Tensor ConcurrentIdHashMap<IdType>::Init(
// This code block is to fill the ids into hash_map_.
auto unique_ids = torch::empty_like(ids);
IdType* unique_ids_data = unique_ids.data_ptr<IdType>();
// Fill in the first `num_seeds` ids.
torch::parallel_for(0, num_seeds, kGrainSize, [&](int64_t s, int64_t e) {
// Insert all ids into the hash map.
torch::parallel_for(0, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
for (int64_t i = s; i < e; i++) {
InsertAndSet(ids_data[i], static_cast<IdType>(i));
InsertAndSetMin(ids_data[i], static_cast<IdType>(i));
}
});
// Place the first `num_seeds` ids.
Expand All @@ -82,13 +82,16 @@ torch::Tensor ConcurrentIdHashMap<IdType>::Init(

const int64_t num_threads = torch::get_num_threads();
std::vector<size_t> block_offset(num_threads + 1, 0);
// Insert all elements in this loop.

// Count the valid numbers in each thread.
torch::parallel_for(
num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
size_t count = 0;
for (int64_t i = s; i < e; i++) {
valid[i] = Insert(ids_data[i]);
count += valid[i];
if (MapId(ids_data[i]) == i) {
count++;
valid[i] = 1;
}
}
auto thread_id = torch::get_thread_num();
block_offset[thread_id + 1] = count;
Expand Down Expand Up @@ -199,6 +202,27 @@ inline void ConcurrentIdHashMap<IdType>::InsertAndSet(IdType id, IdType value) {
hash_map_.data_ptr<IdType>()[getValueIndex(pos)] = value;
}

/**
 * Insert `id` into the hash map and keep the smallest `value` ever offered
 * for it.
 *
 * The slot is located by probing from `id & mask_` until `AttemptInsertAt`
 * stops reporting `InsertState::OCCUPIED`. The value stored for that slot is
 * then lowered to `value` with a compare-and-swap retry loop, so the stored
 * value only changes when it is still `kEmptyKey` or is larger than `value`.
 *
 * @param id The key to insert.
 * @param value The candidate value for `id`.
 */
template <typename IdType>
void ConcurrentIdHashMap<IdType>::InsertAndSetMin(IdType id, IdType value) {
  // Probe until the slot for `id` is found (any state other than OCCUPIED).
  IdType slot = (id & mask_);
  IdType step = 1;
  while (AttemptInsertAt(slot, id) == InsertState::OCCUPIED) {
    Next(&slot, &step);
  }

  const IdType value_index = getValueIndex(slot);
  IdType* const value_ptr = &(hash_map_.data_ptr<IdType>()[value_index]);

  // The first attempt expects an unset value; after a failed swap, retry
  // against whatever value was actually observed.
  const IdType unset = static_cast<IdType>(kEmptyKey);
  IdType expected = unset;
  for (;;) {
    // A set value that is already <= `value` must not be overwritten.
    if (expected != unset && expected <= value) {
      return;
    }
    const IdType observed = CompareAndSwap(value_ptr, expected, value);
    if (observed == expected) {
      return;  // The swap took effect.
    }
    expected = observed;
  }
}

template <typename IdType>
inline typename ConcurrentIdHashMap<IdType>::InsertState
ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {
Expand Down
10 changes: 10 additions & 0 deletions graphbolt/src/concurrent_id_hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ class ConcurrentIdHashMap {
*/
inline void InsertAndSet(IdType key, IdType value);

/**
* @brief Insert a key into the hash map and associate it with the minimum of
* `value` and any value already stored for that key. If the key has no value
* yet, `value` is stored; if the stored value is larger than `value`, it is
* replaced by `value`; otherwise the stored value is left unchanged.
*
* @param id The key to be inserted.
* @param value The candidate value to be set for `id`.
*
*/
inline void InsertAndSetMin(IdType id, IdType value);

/**
* @brief Attempt to insert the key into the hash map at the given position.
*
Expand Down

0 comments on commit 3c39153

Please sign in to comment.