Skip to content

Commit

Permalink
Fix inconsistent calls to nvml::Init and nvml::Shutdown
Browse files Browse the repository at this point in the history
- in the worker thread and thread poll the nvml is called
  only for non-CPU pipelines but the shutdown is called
  unconditionally

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
  • Loading branch information
JanuszL committed Feb 13, 2024
1 parent 9fbbe5e commit 1941091
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
8 changes: 5 additions & 3 deletions dali/pipeline/util/thread_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ namespace dali {

ThreadPool::ThreadPool(int num_thread, int device_id, bool set_affinity, const char* name)
: threads_(num_thread), running_(true), work_complete_(true), started_(false)
, active_threads_(0) {
, active_threads_(0), device_id_(device_id) {
DALI_ENFORCE(num_thread > 0, "Thread pool must have non-zero size");
#if NVML_ENABLED
// only for the CPU pipeline
if (device_id != CPU_ONLY_DEVICE_ID) {
if (device_id_ != CPU_ONLY_DEVICE_ID) {
nvml::Init();
}
#endif
Expand All @@ -55,7 +55,9 @@ ThreadPool::~ThreadPool() {
thread.join();
}
#if NVML_ENABLED
nvml::Shutdown();
if (device_id_ != CPU_ONLY_DEVICE_ID) {
nvml::Shutdown();
}
#endif
}

Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/util/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class DLL_PUBLIC ThreadPool {
bool work_complete_;
bool started_;
int active_threads_;
int device_id_;
std::mutex mutex_;
std::condition_variable condition_;
std::condition_variable completed_;
Expand Down
9 changes: 6 additions & 3 deletions dali/pipeline/util/worker_thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ class WorkerThread {
typedef std::function<void(void)> Work;

inline WorkerThread(int device_id, bool set_affinity, const std::string &name) :
running_(true), work_complete_(true), barrier_(2) {
running_(true), work_complete_(true), barrier_(2), device_id_(device_id) {
#if NVML_ENABLED
if (device_id != CPU_ONLY_DEVICE_ID) {
if (device_id_ != CPU_ONLY_DEVICE_ID) {
nvml::Init();
}
#endif
Expand All @@ -81,7 +81,9 @@ class WorkerThread {
inline ~WorkerThread() {
Shutdown();
#if NVML_ENABLED
nvml::Shutdown();
if (device_id_ != CPU_ONLY_DEVICE_ID) {
nvml::Shutdown();
}
#endif
}

Expand Down Expand Up @@ -236,6 +238,7 @@ class WorkerThread {
std::queue<string> errors_;

Barrier barrier_;
int device_id_;
};

} // namespace dali
Expand Down

0 comments on commit 1941091

Please sign in to comment.