Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Change id hash map #5304

Merged
merged 6 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
/**
* Copyright (c) 2023 by Contributors
* @file array/cpu/id_hash_map.cc
* @file array/cpu/concurrent_id_hash_map.cc
* @brief Class about id hash map
*/

#include "id_hash_map.h"
#include "concurrent_id_hash_map.h"

#ifdef _MSC_VER
#include <intrin.h>
Expand Down Expand Up @@ -35,7 +35,7 @@ namespace dgl {
namespace aten {

template <typename IdType>
IdType IdHashMap<IdType>::CompareAndSwap(
IdType ConcurrentIdHashMap<IdType>::CompareAndSwap(
IdType* ptr, IdType old_val, IdType new_val) {
#ifdef _MSC_VER
if (sizeof(IdType) == 4) {
Expand All @@ -55,7 +55,7 @@ IdType IdHashMap<IdType>::CompareAndSwap(
}

template <typename IdType>
IdHashMap<IdType>::IdHashMap() : mask_(0) {
ConcurrentIdHashMap<IdType>::ConcurrentIdHashMap() : mask_(0) {
// Used to deallocate the memory in hash_map_ with device api
// when the pointer is freed.
auto deleter = [](Mapping* mappings) {
Expand All @@ -69,36 +69,49 @@ IdHashMap<IdType>::IdHashMap() : mask_(0) {
}

template <typename IdType>
IdArray IdHashMap<IdType>::Init(const IdArray& ids) {
IdArray ConcurrentIdHashMap<IdType>::Init(
const IdArray& ids, size_t num_seeds) {
CHECK_EQ(ids.defined(), true);
const IdType* ids_data = ids.Ptr<IdType>();
const size_t num_ids = static_cast<size_t>(ids->shape[0]);
// Make sure `ids` is not 0 dim.
CHECK_GT(num_ids, 0);
CHECK_GE(num_seeds, 0);
CHECK_GE(num_ids, num_seeds);
size_t capacity = GetMapSize(num_ids);
mask_ = static_cast<IdType>(capacity - 1);

auto ctx = DGLContext{kDGLCPU, 0};
auto device = DeviceAPI::Get(ctx);
hash_map_.reset(static_cast<Mapping*>(
DeviceAPI::Get(ctx)->AllocWorkspace(ctx, sizeof(Mapping) * capacity)));
device->AllocWorkspace(ctx, sizeof(Mapping) * capacity)));
memset(hash_map_.get(), -1, sizeof(Mapping) * capacity);

// This code block is to fill the ids into hash_map_.
IdArray unique_ids = NewIdArray(num_ids, ctx, sizeof(IdType) * 8);
IdType* unique_ids_data = unique_ids.Ptr<IdType>();
// Fill in the first `num_seeds` ids.
parallel_for(0, num_seeds, kGrainSize, [&](int64_t s, int64_t e) {
Copy link
Collaborator

@frozenbugs frozenbugs Feb 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the common scale of num_seeds? Since kGrainSize is 256 already, do we need to use parallel?

Copy link
Collaborator Author

@peizhou001 peizhou001 Feb 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The scale depends on your fan-out; it is usually about 1/10 of the original nodes. When the number of input nodes is huge, it can also be very large. Parallelism doesn't introduce side effects here, so keeping it should be better.

for (int64_t i = s; i < e; i++) {
InsertAndSet(ids_data[i], static_cast<IdType>(i));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we use AttemptInsertAt for seed ids?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. A seed id's mapping value is exactly its index in the array, so the key and value need to be set at the same time.
  2. Seed ids are unique, so the insertion is simpler and some checks can be removed to save effort.

}
});
// Place the first `num_seeds` ids.
device->CopyDataFromTo(
ids_data, 0, unique_ids_data, 0, sizeof(IdType) * num_seeds, ctx, ctx,
ids->dtype);

// An auxiliary array indicates whether the corresponding elements
// are inserted into hash map or not. Use `int16_t` instead of `bool` as
// vector<bool> is unsafe when updating different elements from different
// threads. See https://en.cppreference.com/w/cpp/container#Thread_safety.
std::vector<int16_t> valid(num_ids);
auto thread_num = compute_num_threads(0, num_ids, kGrainSize);
std::vector<size_t> block_offset(thread_num + 1, 0);
IdType* unique_ids_data = unique_ids.Ptr<IdType>();

// Insert all elements in this loop.
parallel_for(0, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
parallel_for(num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
size_t count = 0;
for (int64_t i = s; i < e; i++) {
Insert(ids_data[i], &valid, i);
valid[i] = Insert(ids_data[i]);
count += valid[i];
}
block_offset[omp_get_thread_num() + 1] = count;
Expand All @@ -107,12 +120,12 @@ IdArray IdHashMap<IdType>::Init(const IdArray& ids) {
// Get ExclusiveSum of each block.
std::partial_sum(
block_offset.begin() + 1, block_offset.end(), block_offset.begin() + 1);
unique_ids->shape[0] = block_offset.back();
unique_ids->shape[0] = num_seeds + block_offset.back();

// Get unique array from ids and set value for hash map.
parallel_for(0, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
parallel_for(num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
auto tid = omp_get_thread_num();
auto pos = block_offset[tid];
auto pos = block_offset[tid] + num_seeds;
for (int64_t i = s; i < e; i++) {
if (valid[i]) {
unique_ids_data[pos] = ids_data[i];
Expand All @@ -125,7 +138,7 @@ IdArray IdHashMap<IdType>::Init(const IdArray& ids) {
}

template <typename IdType>
IdArray IdHashMap<IdType>::MapIds(const IdArray& ids) const {
IdArray ConcurrentIdHashMap<IdType>::MapIds(const IdArray& ids) const {
CHECK_EQ(ids.defined(), true);
const IdType* ids_data = ids.Ptr<IdType>();
const size_t num_ids = static_cast<size_t>(ids->shape[0]);
Expand All @@ -144,14 +157,15 @@ IdArray IdHashMap<IdType>::MapIds(const IdArray& ids) const {
}

template <typename IdType>
inline void IdHashMap<IdType>::Next(IdType* pos, IdType* delta) const {
inline void ConcurrentIdHashMap<IdType>::Next(
IdType* pos, IdType* delta) const {
// Use quadratic probing.
*pos = (*pos + (*delta) * (*delta)) & mask_;
*delta = *delta + 1;
}

template <typename IdType>
IdType IdHashMap<IdType>::MapId(IdType id) const {
inline IdType ConcurrentIdHashMap<IdType>::MapId(IdType id) const {
IdType pos = (id & mask_), delta = 1;
IdType empty_key = static_cast<IdType>(kEmptyKey);
while (hash_map_[pos].key != empty_key && hash_map_[pos].key != id) {
Expand All @@ -161,16 +175,19 @@ IdType IdHashMap<IdType>::MapId(IdType id) const {
}

template <typename IdType>
void IdHashMap<IdType>::Insert(
IdType id, std::vector<int16_t>* valid, size_t index) {
bool ConcurrentIdHashMap<IdType>::Insert(IdType id) {
IdType pos = (id & mask_), delta = 1;
while (!AttemptInsertAt(pos, id, valid, index)) {
InsertState state = AttemptInsertAt(pos, id);
while (state == InsertState::OCCUPIED) {
Next(&pos, &delta);
state = AttemptInsertAt(pos, id);
}

return state == InsertState::INSERTED;
}

template <typename IdType>
void IdHashMap<IdType>::Set(IdType key, IdType value) {
inline void ConcurrentIdHashMap<IdType>::Set(IdType key, IdType value) {
IdType pos = (key & mask_), delta = 1;
while (hash_map_[pos].key != key) {
Next(&pos, &delta);
Expand All @@ -180,21 +197,31 @@ void IdHashMap<IdType>::Set(IdType key, IdType value) {
}

template <typename IdType>
bool IdHashMap<IdType>::AttemptInsertAt(
int64_t pos, IdType key, std::vector<int16_t>* valid, size_t index) {
inline void ConcurrentIdHashMap<IdType>::InsertAndSet(IdType id, IdType value) {
// Inserts `id` as a key and stores `value` as its mapped value in one call.
// Probing starts at the slot selected by `id & mask_`.
IdType pos = (id & mask_), delta = 1;
// Keep probing while the slot is held by a different key. The loop stops as
// soon as the attempt reports anything other than OCCUPIED, i.e. the key was
// either newly claimed at `pos` or was already stored there.
while (AttemptInsertAt(pos, id) == InsertState::OCCUPIED) {
Next(&pos, &delta);
}

// `pos` now refers to the slot whose key is `id`; record its mapped value.
hash_map_[pos].value = value;
}

template <typename IdType>
inline typename ConcurrentIdHashMap<IdType>::InsertState
ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {
IdType empty_key = static_cast<IdType>(kEmptyKey);
IdType old_val = CompareAndSwap(&(hash_map_[pos].key), empty_key, key);

if (old_val != empty_key && old_val != key) {
return false;
if (old_val == empty_key) {
return InsertState::INSERTED;
} else if (old_val == key) {
return InsertState::EXISTED;
} else {
if (old_val == empty_key) (*valid)[index] = true;
return true;
return InsertState::OCCUPIED;
}
}

template class IdHashMap<int32_t>;
template class IdHashMap<int64_t>;
template class ConcurrentIdHashMap<int32_t>;
template class ConcurrentIdHashMap<int64_t>;

} // namespace aten
} // namespace dgl
119 changes: 69 additions & 50 deletions src/array/cpu/id_hash_map.h → src/array/cpu/concurrent_id_hash_map.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
/**
* Copyright (c) 2023 by Contributors
* @file array/cpu/id_hash_map.h
* @brief Class about id hash map
* @file array/cpu/concurrent_id_hash_map.h
* @brief Class about concurrent id hash map
*/

#ifndef DGL_ARRAY_CPU_ID_HASH_MAP_H_
#define DGL_ARRAY_CPU_ID_HASH_MAP_H_
#ifndef DGL_ARRAY_CPU_CONCURRENT_ID_HASH_MAP_H_
#define DGL_ARRAY_CPU_CONCURRENT_ID_HASH_MAP_H_

#include <dgl/aten/types.h>

Expand All @@ -24,34 +24,53 @@ namespace aten {
* key insertions once with Init function, and it does not support key deletion.
*
* The hash map should be prepared in two phases before using. With the first
* being creating the hashmap, and then init it with an id array.
* being creating the hashmap, and then initialize it with an id array which is
* divided into 2 parts: [`seed ids`, `sampled ids`]. `Seed ids` refer to
* a set of ids chosen as the input for the sampling process and `sampled ids`
* are the ids newly sampled from the process (note that the `seed ids` might
* also be sampled in the process and included in the `sampled ids`). As a
* result, `seed ids` are mapped to [0, num_seed_ids) and `sampled ids` to
* [num_seed_ids, num_unique_ids). Notice that the mapping order is stable for
* `seed ids` but not for the `sampled ids`.
*
* For example, for an array A with following entries:
* [98, 98, 100, 99, 97, 99, 101, 100, 102]
* Create the hashmap H with:
* `H = CpuIdHashMap()` (1)
* For example, for an array `A` having 4 seed ids with following entries:
* [99, 98, 100, 97, 97, 101, 101, 102, 101]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the comment below you mentioned that the num_seeds ids are unique. I am assuming the first 4 are seed ids, but I see a duplicated 97 in this example; is this intended?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Seed ids are unique among themselves, but they can be duplicated among the other ids. Putting it here may help clarify that for users.

* Create the hashmap `H` with:
* `H = ConcurrentIdHashMap()` (1)
* And Init it with:
* `U = H.Init(A, 4)` (2) (U is an id array used to store the unique
* ids in A).
* Then U should be (U is not exclusive as the element order is not
* guaranteed to be steady):
* [98, 100, 99, 97, 101, 102]
* Then `U` should be (this is not the only valid result, as the mapping of
* the `sampled ids` is not stable):
* [99, 98, 100, 97, 102, 101]
* And the hashmap should generate following mappings:
* * [
* {key: 98, value: 0},
* {key: 100, value: 1},
* {key: 99, value: 2},
* {key: 99, value: 0},
* {key: 98, value: 1},
* {key: 100, value: 2},
* {key: 97, value: 3},
* {key: 101, value: 4},
* {key: 102, value: 5}
* {key: 102, value: 4},
* {key: 101, value: 5}
* ]
* Search the hashmap with array I=[98, 99, 102]:
* Search the hashmap with array `I`=[98, 99, 102]:
* R = H.MapIds(I) (3)
* R should be:
* [0, 2, 5]
* [1, 0, 4]
**/
template <typename IdType>
class IdHashMap {
class ConcurrentIdHashMap {
private:
/**
* @brief The result state of an attempt to insert.
*/
enum class InsertState {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to private section?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

OCCUPIED, // Indicates that the space where an insertion is being
// attempted is already occupied by another element.
EXISTED, // Indicates that the element being inserted already exists in the
// map, and thus no insertion is performed.
INSERTED // Indicates that the insertion was successful and a new element
// was added to the map.
};

public:
/**
* @brief An entry in the hashtable.
Expand Down Expand Up @@ -81,25 +100,25 @@ class IdHashMap {
*/
static IdType CompareAndSwap(IdType* ptr, IdType old_val, IdType new_val);

IdHashMap();
ConcurrentIdHashMap();

IdHashMap(const IdHashMap& other) = delete;
IdHashMap& operator=(const IdHashMap& other) = delete;
ConcurrentIdHashMap(const ConcurrentIdHashMap& other) = delete;
ConcurrentIdHashMap& operator=(const ConcurrentIdHashMap& other) = delete;

/**
* @brief Init the hashmap with an array of ids.
* Firstly allocating the memeory and init the entire space with empty key.
* And then insert the items in `ids` concurrently to generate the
* mappings, in passing returning the unique ids in `ids`.
* @brief Initialize the hashmap with an array of ids. The first `num_seeds`
* ids are unique and must be mapped to a contiguous array starting
* from 0. The rest can be duplicated and the mapping result is not stable.
*
* @param ids The array of ids to be inserted as keys.
* @param ids The array of the ids to be inserted.
* @param num_seeds The number of seed ids.
*
* @return Unique ids for the input `ids`.
* @return Unique ids from the input `ids`.
*/
IdArray Init(const IdArray& ids);
IdArray Init(const IdArray& ids, size_t num_seeds);

/**
* @brief Find the mappings of given keys.
* @brief Find mappings of given keys.
*
* @param ids The keys to map for.
*
Expand All @@ -114,27 +133,25 @@ class IdHashMap {
* @param[in,out] pos Calculate the next position with quadratic probing.
* @param[in,out] delta Calculate the next delta by adding 1.
*/
void Next(IdType* pos, IdType* delta) const;
inline void Next(IdType* pos, IdType* delta) const;

/**
* @brief Find the mapping of a given key.
*
* @param id The key to map for.
*
* @return Mapping result for the `id`.
* @return Mapping result corresponding to `id`.
*/
IdType MapId(const IdType id) const;
inline IdType MapId(const IdType id) const;

/**
* @brief Insert an id into the hash map.
*
* @param id The id to be inserted.
* @param valid The item at index will be set to indicate
* whether the `id` at `index` is inserted or not.
* @param index The index of the `id`.
*
* @return Whether the `id` is inserted or not.
*/
void Insert(IdType id, std::vector<int16_t>* valid, size_t index);
inline bool Insert(IdType id);

/**
* @brief Set the value for the key in the hash map.
Expand All @@ -144,24 +161,26 @@ class IdHashMap {
*
* @warning Key must exist.
*/
void Set(IdType key, IdType value);
inline void Set(IdType key, IdType value);

/**
* @brief Insert a key into the hash map and set its mapped value.
*
* @param key The key to be inserted.
* @param value The value to be set for the `key`.
*
*/
inline void InsertAndSet(IdType key, IdType value);

/**
* @brief Attempt to insert the key into the hash map at the given position.
* 1. If the key at `pos` is empty -> Set the key, return true and set
* `valid[index]` to true.
* 2. If the key at `pos` is equal to `key` -> Return true.
* 3. If the key at `pos` is non-empty and not equal to `key` -> Return false.
*
* @param pos The position in the hash map to be inserted at.
* @param key The key to be inserted.
* @param valid The item at index will be set to indicate
* whether the `key` at `index` is inserted or not.
* @param index The index of the `key`.
*
* @return Whether the key exists in the map now.
* @return The state of the insertion.
*/
bool AttemptInsertAt(
int64_t pos, IdType key, std::vector<int16_t>* valid, size_t index);
inline InsertState AttemptInsertAt(int64_t pos, IdType key);

private:
/**
Expand All @@ -179,4 +198,4 @@ class IdHashMap {
} // namespace aten
} // namespace dgl

#endif // DGL_ARRAY_CPU_ID_HASH_MAP_H_
#endif // DGL_ARRAY_CPU_CONCURRENT_ID_HASH_MAP_H_
Loading