Skip to content

Commit

Permalink
Initial exception rethrowing in executor
Browse files Browse the repository at this point in the history
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki committed Mar 4, 2024
1 parent 73beab6 commit cc23c51
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions dali/pipeline/executor/executor.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -16,9 +16,11 @@
#define DALI_PIPELINE_EXECUTOR_EXECUTOR_H_

#include <atomic>
#include <exception>
#include <map>
#include <memory>
#include <queue>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
Expand All @@ -38,8 +40,9 @@
#include "dali/pipeline/graph/op_graph_verifier.h"
#include "dali/pipeline/operator/batch_size_provider.h"
#include "dali/pipeline/operator/builtin/conditional/split_merge.h"
#include "dali/pipeline/operator/common.h"
#include "dali/pipeline/operator/checkpointing/checkpoint.h"
#include "dali/pipeline/operator/common.h"
#include "dali/pipeline/operator/error_reporting.h"
#include "dali/pipeline/util/batch_utils.h"
#include "dali/pipeline/util/event_pool.h"
#include "dali/pipeline/util/stream_pool.h"
Expand Down Expand Up @@ -265,10 +268,10 @@ class DLL_PUBLIC Executor : public ExecutorBase, public QueuePolicy {
}
}

void HandleError(const std::string& message = "Unknown exception") {
void HandleError(const std::string& context = "") {
{
std::lock_guard<std::mutex> errors_lock(errors_mutex_);
errors_.push_back(message);
errors_.push_back({std::current_exception(), context});
}
exec_error_ = true;
ShutdownQueue();
Expand Down Expand Up @@ -359,7 +362,11 @@ class DLL_PUBLIC Executor : public ExecutorBase, public QueuePolicy {
OpGraph *graph_ = nullptr;
EventPool event_pool_;
ThreadPool thread_pool_;
std::vector<std::string> errors_;
// Record of a single failure captured by HandleError().
// Member order matters: HandleError() brace-initializes this aggregate
// positionally as {exception, context}.
struct ErrorInfo {
// The captured exception; HandleError() stores std::current_exception() here.
std::exception_ptr exception;
// Caller-supplied context string; may be empty. RethrowError() attaches it to
// a DaliError through AddOriginInfo() when it is non-empty.
std::string context_info;
};
std::vector<ErrorInfo> errors_;
mutable std::mutex errors_mutex_;
bool exec_error_;
QueueSizes queue_sizes_;
Expand Down Expand Up @@ -392,16 +399,23 @@ class DLL_PUBLIC Executor : public ExecutorBase, public QueuePolicy {

void RethrowError() const {
std::lock_guard<std::mutex> errors_lock(errors_mutex_);
if (errors_.empty()) {
if (QueuePolicy::IsStopSignaled() && !exec_error_) {
throw std::runtime_error("Stop signaled");
}
throw std::runtime_error("Unknown error");
}

// TODO(klecki): collect all errors
std::string message = errors_.empty()
? QueuePolicy::IsStopSignaled() && !exec_error_
? "Stop signaled"
: "Unknown error"
: errors_.front();

// TODO(michalz): rethrow actual error through std::exception_ptr instead of
// converting everything to runtime_error
throw std::runtime_error(message);
auto &error = errors_.front();
try {
std::rethrow_exception(error.exception);
} catch (DaliError &e) {
if (error.context_info.size()) {
e.AddOriginInfo(error.context_info);
}
throw;
}
}

void DiscoverBatchSizeProviders() {
Expand Down

0 comments on commit cc23c51

Please sign in to comment.