Skip to content

Commit

Permalink
[GraphBolt] Per-thread RandomEngine initialization fix. (#7557)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 22, 2024
1 parent 5b4635a commit c83acdd
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
1 change: 0 additions & 1 deletion graphbolt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ endif()
add_library(${LIB_GRAPHBOLT_NAME} SHARED ${BOLT_SRC} ${BOLT_HEADERS})
include_directories(BEFORE ${BOLT_DIR}
${BOLT_HEADERS}
"../third_party/dmlc-core/include"
"../third_party/pcg/include"
"../third_party/phmap")
target_link_libraries(${LIB_GRAPHBOLT_NAME} "${TORCH_LIBRARIES}")
Expand Down
9 changes: 5 additions & 4 deletions graphbolt/src/random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ std::optional<uint64_t> RandomEngine::manual_seed;
RandomEngine::RandomEngine() {
std::random_device rd;
std::lock_guard lock(manual_seed_mutex);
uint64_t seed = manual_seed.value_or(rd());
SetSeed(seed);
if (!manual_seed.has_value()) manual_seed = rd();
SetSeed(manual_seed.value());
}

/** @brief Constructor with given seed. */
RandomEngine::RandomEngine(uint64_t seed) { RandomEngine(seed, GetThreadId()); }
RandomEngine::RandomEngine(uint64_t seed) : RandomEngine(seed, GetThreadId()) {}

/** @brief Constructor with given seed. */
RandomEngine::RandomEngine(uint64_t seed, uint64_t stream) {
Expand All @@ -49,7 +49,8 @@ RandomEngine::RandomEngine(uint64_t seed, uint64_t stream) {

/** @brief Get the thread-local random number generator instance. */
RandomEngine* RandomEngine::ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get();
static thread_local RandomEngine engine;
return &engine;
}

/** @brief Set the seed. */
Expand Down
2 changes: 0 additions & 2 deletions graphbolt/src/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#ifndef GRAPHBOLT_RANDOM_H_
#define GRAPHBOLT_RANDOM_H_

#include <dmlc/thread_local.h>

#include <mutex>
#include <optional>
#include <pcg_random.hpp>
Expand Down

0 comments on commit c83acdd

Please sign in to comment.