diff --git a/paddle/fluid/framework/new_executor/event_count.h b/paddle/fluid/framework/new_executor/event_count.h
index 0c6d49042d22d..7f1e3670056fc 100644
--- a/paddle/fluid/framework/new_executor/event_count.h
+++ b/paddle/fluid/framework/new_executor/event_count.h
@@ -50,11 +50,13 @@
 #include
 #include
 #include
-#include "paddle/fluid/framework/new_executor/workqueue_utils.h"

 namespace paddle {
 namespace framework {

+void* AlignedMalloc(size_t size, size_t alignment);
+void AlignedFree(void* memory_ptr);
+
 class EventCount {
  public:
  class Waiter;
diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index 083d989cb5267..c8acf8cc5dbb8 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -37,7 +37,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
       main_program_(main_prog),
       global_scope_(global_scope),
       stream_analyzer_(place),
-      async_work_queue_(kHostNumThreads) {
+      async_work_queue_(kHostNumThreads, &main_thread_blocker_) {
   is_build_ = false;
   feed_names_ = feed_names;
@@ -365,7 +365,8 @@ void InterpreterCore::ExecuteInstructionList(
     }
   }

-  async_work_queue_.WaitEmpty();
+  auto event_id = main_thread_blocker_.WaitEvent();
+  VLOG(3) << "event_id " << event_id;

   PADDLE_ENFORCE_EQ(
       op_run_number_.load(), vec_instr.size(),
diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h
index 47f23aff4f00e..eac3131ca1b31 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.h
+++ b/paddle/fluid/framework/new_executor/interpretercore.h
@@ -94,6 +94,7 @@ class InterpreterCore {
   InterpreterProfiler dry_run_profiler_;
   StreamAnalyzer stream_analyzer_;
   EventManager event_manager_;
+  EventsWaiter main_thread_blocker_;
   interpretercore::AsyncWorkQueue async_work_queue_;
   InterpreterCoreGarbageCollector gc_;
diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h
index 2a5942c712365..3c927a8d81d16 100644
--- a/paddle/fluid/framework/new_executor/interpretercore_util.h
+++ b/paddle/fluid/framework/new_executor/interpretercore_util.h
@@ -33,6 +33,7 @@
 #include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
 #include "paddle/fluid/framework/new_executor/workqueue.h"
+#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -53,16 +54,19 @@
 using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;

 class AsyncWorkQueue {
  public:
-  explicit AsyncWorkQueue(size_t host_num_threads)
+  AsyncWorkQueue(size_t host_num_threads, EventsWaiter* waiter)
       : host_num_thread_(host_num_threads) {
     std::vector<WorkQueueOptions> group_options;
     // for execute host Kernel
     group_options.emplace_back(/*num_threads*/ host_num_threads,
                                /*allow_spinning*/ true,
-                               /*track_task*/ true);
+                               /*track_task*/ true,
+                               /*queue_empty_waiter*/ waiter);
     // for launch device Kernel
     group_options.emplace_back(/*num_threads*/ 1,
-                               /*allow_spinning*/ true, /*track_task*/ true);
+                               /*allow_spinning*/ true,
+                               /*track_task*/ true,
+                               /*queue_empty_waiter*/ waiter);
     queue_group_ = CreateWorkQueueGroup(group_options);
   }

@@ -71,7 +75,7 @@ class AsyncWorkQueue {
   AtomicVectorSizeT& PrepareAtomicVarRef(
       const std::vector<VariableMetaInfo>& vec_meta_info);

-  void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); }
+  // void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); }

   void AddTask(const OpFuncType& op_func_type, std::function<void()> fn) {
     queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
diff --git a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h
index 2997ce1fe2473..667723c67165c 100644
--- a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h
+++ b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h
@@ -19,9 +19,12 @@
 namespace paddle {
 namespace framework {

+template <typename Notifier>
 class TaskTracker {
  public:
-  TaskTracker() : wait_empty_cv_(1) {}
+  TaskTracker() = default;
+
+  explicit TaskTracker(Notifier& notifier) : notifier_(&notifier) {}

   TaskTracker(const TaskTracker&) = delete;
@@ -33,32 +36,17 @@ class TaskTracker {
   void SubCounter() {
     if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {
-      wait_empty_cv_.Notify(true);
+      if (notifier_ != nullptr) {
+        notifier_->NotifyEvent();
+      }
     }
   }

-  // only one user can wait at any time
-  void WaitTaskNumToZero() {
-    bool waiting = false;
-    if (!wait_empty_.compare_exchange_strong(waiting, true,
-                                             std::memory_order_seq_cst,
-                                             std::memory_order_relaxed)) {
-      abort();
-    }
-    EventCount::Waiter* w = wait_empty_cv_.GetWaiter(0);
-    wait_empty_cv_.Prewait();
-    if (num_tasks_.load(std::memory_order_relaxed) == 0) {
-      wait_empty_cv_.CancelWait();
-    } else {
-      wait_empty_cv_.CommitWait(w);
-    }
-    wait_empty_.store(false);
-  }
+  uint64_t PendingTaskNum() { return num_tasks_.load(); }

  private:
   alignas(64) std::atomic<uint64_t> num_tasks_{0};
-  alignas(64) EventCount wait_empty_cv_;
-  alignas(64) std::atomic<bool> wait_empty_{false};
+  Notifier* notifier_{nullptr};
 };

 template <typename Environment>
diff --git a/paddle/fluid/framework/new_executor/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue.cc
index 8c6eeab4d5c0a..deeed1eb72af2 100644
--- a/paddle/fluid/framework/new_executor/workqueue.cc
+++ b/paddle/fluid/framework/new_executor/workqueue.cc
@@ -13,13 +13,18 @@
 namespace paddle {
 namespace framework {
 namespace {

+using TaskTracker = TaskTracker<EventsWaiter::Notifier>;
+
 class WorkQueueImpl : public WorkQueue {
  public:
-  explicit WorkQueueImpl(const WorkQueueOptions& options)
-      : WorkQueue(options), queue_(nullptr), tracker_(nullptr) {
-    if (options_.track_task) {
+  explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) {
+    if (options_.track_task && options.queue_empty_waiter != nullptr) {
       void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
-      tracker_ = new (storage) TaskTracker;
+      TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
+      auto event_id = options.queue_empty_waiter->RegisterEvent(
+          [tracker]() { return tracker->PendingTaskNum() == 0; });
+      auto notifier = options.queue_empty_waiter->GetEventNotifier(event_id);
+      tracker_ = new (storage) TaskTracker(notifier.get());
     }
     queue_ = new NonblockingThreadPool(options_.num_threads,
                                        options_.allow_spinning);
@@ -44,20 +49,11 @@ class WorkQueueImpl : public WorkQueue {
     queue_->AddTask(std::move(fn));
   }

-  void WaitQueueEmpty() override {
-    if (tracker_ == nullptr) {
-      PADDLE_THROW(
-          platform::errors::Unavailable("set WorkQueueOptions.track_task = "
-                                        "true before call this interface."));
-    }
-    tracker_->WaitTaskNumToZero();
-  }
-
   size_t NumThreads() const override { return queue_->NumThreads(); }

  private:
-  NonblockingThreadPool* queue_;
-  TaskTracker* tracker_;
+  NonblockingThreadPool* queue_{nullptr};
+  TaskTracker* tracker_{nullptr};
 };

 class WorkQueueGroupImpl : public WorkQueueGroup {
@@ -69,8 +65,6 @@ class WorkQueueGroupImpl : public WorkQueueGroup {

   void AddTask(size_t queue_idx, std::function<void()> fn) override;

-  void WaitQueueGroupEmpty() override;
-
   size_t QueueNumThreads(size_t queue_idx) const override;

   size_t QueueGroupNumThreads() const override;
@@ -92,9 +86,14 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
   queues_storage_ = reinterpret_cast<NonblockingThreadPool*>(buffer);
   for (size_t idx = 0; idx < num_queues; ++idx) {
     const auto& options = queues_options_[idx];
-    if (options.track_task && tracker_ == nullptr) {
+    if (options.track_task && tracker_ == nullptr &&
+        options.queue_empty_waiter != nullptr) {
       void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
-      tracker_ = new (storage) TaskTracker;
+      TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
+      auto event_id = options.queue_empty_waiter->RegisterEvent(
+          [tracker]() { return tracker->PendingTaskNum() == 0; });
+      auto notifier = options.queue_empty_waiter->GetEventNotifier(event_id);
+      tracker_ = new (storage) TaskTracker(notifier.get());
     }
     queues_[idx] = new (&queues_storage_[idx])
         NonblockingThreadPool(options.num_threads, options.allow_spinning);
@@ -124,15 +123,6 @@ void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) {
   queues_[queue_idx]->AddTask(std::move(fn));
 }

-void WorkQueueGroupImpl::WaitQueueGroupEmpty() {
-  if (nullptr == tracker_) {
-    PADDLE_THROW(platform::errors::Unavailable(
-        "set WorkQueueOptions.track_task = true for at least one of queues "
-        "before call this interface."));
-  }
-  tracker_->WaitTaskNumToZero();
-}
-
 size_t WorkQueueGroupImpl::QueueNumThreads(size_t queue_idx) const {
   assert(queue_idx < queues_.size());
   return queues_.at(queue_idx)->NumThreads();
diff --git a/paddle/fluid/framework/new_executor/workqueue.h b/paddle/fluid/framework/new_executor/workqueue.h
index ead9d9949b700..e184566f914c3 100644
--- a/paddle/fluid/framework/new_executor/workqueue.h
+++ b/paddle/fluid/framework/new_executor/workqueue.h
@@ -21,15 +21,29 @@
 namespace paddle {
 namespace framework {

+class EventsWaiter;
+
 struct WorkQueueOptions {
   WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task)
       : num_threads(num_threads),
         allow_spinning(allow_spinning),
         track_task(track_task) {}

+  WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task,
+                   EventsWaiter* waiter)
+      : num_threads(num_threads),
+        allow_spinning(allow_spinning),
+        track_task(track_task),
+        queue_empty_waiter(waiter) {}
+
   size_t num_threads;
   bool allow_spinning;
+  // If you need to block the calling thread until the queue is empty, set
+  // track_task = true and set queue_empty_waiter. EventsWaiter::WaitEvent
+  // will then block the calling thread until any registered event
+  // (including "queue empty") occurs.
   bool track_task;
+  EventsWaiter* queue_empty_waiter{nullptr};
 };

 class WorkQueue {
@@ -44,9 +58,8 @@ class WorkQueue {

   virtual void AddTask(std::function<void()> fn) = 0;

-  // set WorkQueueOptions.track_task = true before call this
-  // interface, otherwise will abort()
-  virtual void WaitQueueEmpty() = 0;
+  // See WorkQueueOptions.track_task for details.
+  // virtual void WaitQueueEmpty() = 0;

   virtual size_t NumThreads() const = 0;

@@ -67,9 +80,8 @@ class WorkQueueGroup {

   virtual void AddTask(size_t queue_idx, std::function<void()> fn) = 0;

-  // set WorkQueueOptions.track_task = true for at least one of queues
-  // before call this interface, otherwise will abort()
-  virtual void WaitQueueGroupEmpty() = 0;
+  // See WorkQueueOptions.track_task for details.
+  // virtual void WaitQueueGroupEmpty() = 0;

   virtual size_t QueueNumThreads(size_t queue_idx) const = 0;
diff --git a/paddle/fluid/framework/new_executor/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue_test.cc
index c229a84b145ab..51a0b58a3865f 100644
--- a/paddle/fluid/framework/new_executor/workqueue_test.cc
+++ b/paddle/fluid/framework/new_executor/workqueue_test.cc
@@ -16,18 +16,21 @@
 #include
 #include "glog/logging.h"
 #include "gtest/gtest.h"
+#include "paddle/fluid/framework/new_executor/workqueue_utils.h"

 TEST(WorkQueue, TestSingleThreadedWorkQueue) {
   VLOG(1) << "In Test";
   using paddle::framework::WorkQueueOptions;
   using paddle::framework::WorkQueue;
   using paddle::framework::CreateSingleThreadedWorkQueue;
+  using paddle::framework::EventsWaiter;
   std::atomic<bool> finished{false};
   std::atomic<unsigned> counter{0};
   constexpr unsigned kLoopNum = 1000000;
   // CreateSingleThreadedWorkQueue
+  EventsWaiter events_waiter;
   WorkQueueOptions options(/*num_threads*/ 1, /*allow_spinning*/ true,
-                           /*track_task*/ true);
+                           /*track_task*/ true, &events_waiter);
   auto work_queue = CreateSingleThreadedWorkQueue(options);
   // NumThreads
   EXPECT_EQ(work_queue->NumThreads(), 1u);
@@ -42,7 +45,7 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
   });
   // WaitQueueEmpty
   EXPECT_EQ(finished.load(), false);
-  work_queue->WaitQueueEmpty();
+  events_waiter.WaitEvent();
   EXPECT_EQ(finished.load(), true);
   EXPECT_EQ(counter.load(), kLoopNum);
 }
@@ -52,13 +55,15 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
   using paddle::framework::WorkQueueOptions;
   using paddle::framework::WorkQueue;
   using paddle::framework::CreateMultiThreadedWorkQueue;
+  using paddle::framework::EventsWaiter;
   std::atomic<bool> finished{false};
   std::atomic<unsigned> counter{0};
   constexpr unsigned kExternalLoopNum = 100;
   constexpr unsigned kLoopNum = 1000000;
   // CreateMultiThreadedWorkQueue
+  EventsWaiter events_waiter;
   WorkQueueOptions options(/*num_threads*/ 10, /*allow_spinning*/ true,
-                           /*track_task*/ true);
+                           /*track_task*/ true, &events_waiter);
   auto work_queue = CreateMultiThreadedWorkQueue(options);
   // NumThreads
   EXPECT_EQ(work_queue->NumThreads(), 10u);
@@ -75,7 +80,7 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
   }
   // WaitQueueEmpty
   EXPECT_EQ(finished.load(), false);
-  work_queue->WaitQueueEmpty();
+  events_waiter.WaitEvent();
   EXPECT_EQ(finished.load(), true);
   EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum);
 }
@@ -84,15 +89,17 @@ TEST(WorkQueue, TestWorkQueueGroup) {
   using paddle::framework::WorkQueueOptions;
   using paddle::framework::WorkQueueGroup;
   using paddle::framework::CreateWorkQueueGroup;
+  using paddle::framework::EventsWaiter;
   std::atomic<bool> finished{false};
   std::atomic<unsigned> counter{0};
   constexpr unsigned kExternalLoopNum = 100;
   constexpr unsigned kLoopNum = 1000000;
-  // CreateMultiThreadedWorkQueue
+  // ThreadedWorkQueueGroup
+  EventsWaiter events_waiter;
   WorkQueueOptions sq_options(/*num_threads*/ 1, /*allow_spinning*/ true,
-                              /*track_task*/ true);
+                              /*track_task*/ true, &events_waiter);
   WorkQueueOptions mq_options(/*num_threads*/ 10, /*allow_spinning*/ true,
-                              /*track_task*/ true);
+                              /*track_task*/ true, &events_waiter);
   auto queue_group = CreateWorkQueueGroup({sq_options, mq_options});
   // NumThreads
   EXPECT_EQ(queue_group->QueueNumThreads(0), 1u);
@@ -113,6 +120,6 @@ TEST(WorkQueue, TestWorkQueueGroup) {
     }
   });
   // WaitQueueGroupEmpty()
-  queue_group->WaitQueueGroupEmpty();
+  events_waiter.WaitEvent();
   EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
 }
diff --git a/paddle/fluid/framework/new_executor/workqueue_utils.cc b/paddle/fluid/framework/new_executor/workqueue_utils.cc
index 2ea49e676a807..9303588152c31 100644
--- a/paddle/fluid/framework/new_executor/workqueue_utils.cc
+++ b/paddle/fluid/framework/new_executor/workqueue_utils.cc
@@ -55,5 +55,79 @@ void AlignedFree(void* mem_ptr) {
 #endif
 }

+constexpr EventsWaiter::EventId kEmptyEventId = -1;
+
+EventsWaiter::EventsWaiter()
+    : trigger_event_(kEmptyEventId), waiting_(false), cv_(1) {}
+
+EventsWaiter::EventId EventsWaiter::RegisterEvent(EventChecker checker) {
+  checkers_.push_back(std::move(checker));
+  EventId id = checkers_.size() - 1;
+  Notifier n(id, *this);
+  notifiers_.push_back(n);
+  return id;
+}
+
+std::reference_wrapper<EventsWaiter::Notifier> EventsWaiter::GetEventNotifier(
+    const EventId& id) {
+  int64_t event_num = checkers_.size();
+  if (id < 0 || id >= event_num) {
+    PADDLE_THROW(platform::errors::OutOfRange("Invalid EventId"));
+  }
+  return notifiers_[id];
+}
+
+EventsWaiter::EventId EventsWaiter::WaitEvent() {
+  // only one user can wait at any time
+  bool waiting = false;
+  if (!waiting_.compare_exchange_strong(waiting, true,
+                                        std::memory_order_seq_cst,
+                                        std::memory_order_relaxed)) {
+    PADDLE_THROW(
+        platform::errors::ResourceExhausted("Another thread is waiting."));
+  }
+  EventId id = trigger_event_.load(std::memory_order_acquire);
+  auto w = cv_.GetWaiter(0);
+  cv_.Prewait();
+  int64_t event_num = checkers_.size();
+  for (int64_t i = 0; id == kEmptyEventId && i < event_num; ++i) {
+    if (checkers_[i]()) {
+      id = i;
+    }
+  }
+  if (id != kEmptyEventId) {
+    cv_.CancelWait();
+  } else {
+    cv_.CommitWait(w);
+  }
+  trigger_event_.store(kEmptyEventId);
+  waiting_.store(false);
+  return id;
+}
+
+const std::set<EventsWaiter::EventId> EventsWaiter::WaitEvents() {
+  std::set<EventId> ids;
+  auto trigger_ev = WaitEvent();
+  ids.insert(trigger_ev);
+  for (int64_t i = 0; i < static_cast<int64_t>(checkers_.size()); ++i) {
+    if (checkers_[i]()) {
+      ids.insert(i);
+    }
+  }
+  return ids;
+}
+
+void EventsWaiter::SetTriggerEvent(const EventId& id) {
+  EventId expected_id = kEmptyEventId;
+  if (!trigger_event_.compare_exchange_strong(expected_id, id,
+                                              std::memory_order_seq_cst,
+                                              std::memory_order_relaxed)) {
+    return;
+  }
+  cv_.Notify(true);
+}
+
+void EventsWaiter::Notifier::NotifyEvent() { waiter_.SetTriggerEvent(id_); }
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/workqueue_utils.h b/paddle/fluid/framework/new_executor/workqueue_utils.h
index bb219fea36267..d04bf3bfe3c0c 100644
--- a/paddle/fluid/framework/new_executor/workqueue_utils.h
+++ b/paddle/fluid/framework/new_executor/workqueue_utils.h
@@ -18,6 +18,9 @@
 #include
 #include
 #include
+#include
+#include
+#include "paddle/fluid/framework/new_executor/event_count.h"
 #include "paddle/fluid/platform/enforce.h"
 namespace paddle {
 namespace framework {
@@ -64,5 +67,52 @@
 void* AlignedMalloc(size_t size, size_t alignment);
 void AlignedFree(void* memory_ptr);

+// A multiplexing waiter, able to wait for multiple events simultaneously.
+// It blocks the calling thread until any of the registered events occurs.
+class EventsWaiter {
+ public:
+  using EventId = int64_t;
+
+  using EventChecker = std::function<bool()>;
+
+  class Notifier {
+   public:
+    void NotifyEvent();
+
+   private:
+    friend EventsWaiter;
+    Notifier(EventId id, EventsWaiter& waiter) : id_(id), waiter_(waiter) {}
+
+    EventId id_;
+    EventsWaiter& waiter_;
+  };
+
+  EventsWaiter();
+
+  EventsWaiter(const EventsWaiter&) = delete;
+
+  EventsWaiter& operator=(const EventsWaiter&) = delete;
+
+  // RegisterEvent must be called before WaitEvent
+  EventId RegisterEvent(EventChecker checker);
+
+  std::reference_wrapper<Notifier> GetEventNotifier(const EventId& id);
+
+  // Wait for any of the registered events
+  EventId WaitEvent();
+
+  const std::set<EventId> WaitEvents();
+
+ private:
+  friend Notifier;
+  void SetTriggerEvent(const EventId& id);
+
+  std::vector<EventChecker> checkers_;
+  std::vector<Notifier> notifiers_;
+  std::atomic<EventId> trigger_event_;
+  std::atomic<bool> waiting_;
+  EventCount cv_;
+};
+
 }  // namespace framework
 }  // namespace paddle
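
A minimal usage sketch of the new waiting pattern follows (not part of the patch). It relies only on APIs that appear in the hunks above: the four-argument WorkQueueOptions, CreateSingleThreadedWorkQueue as exercised in workqueue_test.cc, and EventsWaiter::WaitEvent. The Demo function name and the exact include list are assumptions for illustration. Internally, the queue registers a "queue empty" checker (PendingTaskNum() == 0) with the waiter, and its TaskTracker notifies the waiter when the pending task count drops to zero.

// Usage sketch (assumed example, not part of the patch).
#include <atomic>

#include "paddle/fluid/framework/new_executor/workqueue.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"

using paddle::framework::CreateSingleThreadedWorkQueue;
using paddle::framework::EventsWaiter;
using paddle::framework::WorkQueueOptions;

void Demo() {  // hypothetical helper for illustration
  EventsWaiter events_waiter;
  // track_task = true together with queue_empty_waiter makes the queue
  // register a "queue empty" event with the waiter.
  WorkQueueOptions options(/*num_threads*/ 1, /*allow_spinning*/ true,
                           /*track_task*/ true, &events_waiter);
  auto work_queue = CreateSingleThreadedWorkQueue(options);

  std::atomic<unsigned> counter{0};
  work_queue->AddTask([&counter] { counter.fetch_add(1); });

  // Blocks until any registered event fires; the only event registered here
  // is "queue empty", so this returns once the task above has finished.
  auto event_id = events_waiter.WaitEvent();
  (void)event_id;
}

The benefit of the indirection is that the blocked thread waits on the EventsWaiter rather than on a particular queue, so a whole queue group (or several unrelated event sources) can share one waiter; this is how InterpreterCore's main_thread_blocker_ is wired to its AsyncWorkQueue in the hunks above.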