Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FleetExecutor] Move SendIntra to Carrier && Using BlockingQueue #38322

Merged
merged 1 commit into from
Dec 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/fleet_executor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ endif()
proto_library(interceptor_message_proto SRCS interceptor_message.proto)

if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog)
else()
set(BRPC_DEPS "")
endif()

cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
executor_gc_helper ${BRPC_DEPS})
executor_gc_helper gflags glog ${BRPC_DEPS})

if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
Expand Down
71 changes: 48 additions & 23 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ namespace distributed {
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);

void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph,
void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope,
framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init."));
rank_ = rank;
runtime_graph_ = runtime_graph;
interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank();
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
place_ = place;
Expand All @@ -48,12 +50,6 @@ void Carrier::Release() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.

// Sending STOP msg to the source interceptor
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
Expand All @@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }

bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor
if (interceptor_message.ctrl_message()) {
// handle control message
return true;
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
} else {
{
std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
Expand All @@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage(
}
int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id);
bool rst =
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
if (rst) {
std::condition_variable& interceptor_cond_var =
dst_interceptor->GetCondVar();
interceptor_cond_var.notify_all();
}
return rst;
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
}
return true;
}

Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
Expand Down Expand Up @@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }

bool Carrier::IsInit() const { return is_init_; }

// TODO(liyurui): Move SendIntra into carrier
bool Carrier::Send(const InterceptorMessage& msg) const {
return msg_bus_->Send(msg);
// Looks up the rank recorded for `interceptor_id` in interceptor_id_to_rank_.
//
// Raises a NotFound enforce error when the id has no entry in the map.
// The iterator returned by the single find() is reused for the result, so the
// key is hashed only once instead of once for the check and again for at().
int64_t Carrier::GetRank(int64_t interceptor_id) const {
  const auto rank_iter = interceptor_id_to_rank_.find(interceptor_id);
  PADDLE_ENFORCE_NE(
      rank_iter, interceptor_id_to_rank_.end(),
      platform::errors::NotFound("Cannot find rank for interceptor id %lld.",
                                 interceptor_id));
  return rank_iter->second;
}

// Dispatches `msg` according to the ranks of its source and destination
// interceptors.
//
// The source rank must equal this carrier's rank_, otherwise a Fatal enforce
// error is raised. When both interceptors map to the same rank the message is
// handed to EnqueueInterceptorMessage(); otherwise it is passed to
// msg_bus_->Send() together with the destination rank, after checking that
// the message bus exists and reports itself as initialized.
//
// Returns whatever the chosen delivery call returns.
bool Carrier::Send(const InterceptorMessage& msg) {
  const int64_t receiver_id = msg.dst_id();
  // A source id of -1 is substituted with the destination id before the rank
  // lookup.
  const int64_t sender_id = (msg.src_id() != -1) ? msg.src_id() : receiver_id;

  // Resolve the sender first so a missing sender id is reported before a
  // missing receiver id.
  const int64_t sender_rank = GetRank(sender_id);
  const int64_t receiver_rank = GetRank(receiver_id);

  PADDLE_ENFORCE_EQ(
      sender_rank, rank_,
      platform::errors::Fatal("The source rank id %lld, which is not equal to "
                              "the carrier rank id %lld.",
                              sender_rank, rank_));

  if (sender_rank == receiver_rank) {
    VLOG(3) << "Send a message from interceptor " << sender_id
            << " to interceptor " << receiver_id
            << ", which are in the same ranks.";
    return EnqueueInterceptorMessage(msg);
  }

  // Different ranks: the message bus is required from here on.
  PADDLE_ENFORCE_NOT_NULL(
      msg_bus_.get(),
      platform::errors::Unavailable("Message bus is released accidently"));
  PADDLE_ENFORCE_EQ(
      msg_bus_->IsInit(), true,
      platform::errors::PreconditionNotMet(
          "Using message bus since it has not been initialized. "
          "Please invoke MessageBus::Init() before using it or "
          "neccessary components are not ready."));
  VLOG(3) << "Send a message from interceptor " << sender_id
          << " to interceptor " << receiver_id
          << ", which are in different ranks.";
  return msg_bus_->Send(receiver_rank, msg);
}

Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
Expand Down Expand Up @@ -222,13 +247,13 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}

void Carrier::CreateInterceptors() {
if (runtime_graph_->intercepter_id_to_node().empty()) return;
if (runtime_graph_->interceptor_id_to_node().empty()) return;

auto gc = GetGC(place_);

// create each Interceptor
// no auto init since there is no config
for (const auto& item : runtime_graph_->intercepter_id_to_node()) {
for (const auto& item : runtime_graph_->interceptor_id_to_node()) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;

Expand Down
10 changes: 8 additions & 2 deletions paddle/fluid/distributed/fleet_executor/carrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ class MessageBus;
class Carrier final {
public:
Carrier() = default;
// Constructs a Carrier with its rank and interceptor-id-to-rank table set up
// front: `rank` is stored in rank_ and `interceptor_id_to_rank` is copied
// into the member map of the same name. No other member is touched here.
Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {}
~Carrier();
void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
Expand Down Expand Up @@ -75,7 +78,7 @@ class Carrier final {

bool IsInit() const;

bool Send(const InterceptorMessage& msg) const;
bool Send(const InterceptorMessage& msg);

// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
Expand All @@ -90,6 +93,8 @@ class Carrier final {

void HandleTmpMessages();

int64_t GetRank(int64_t interceptor_id) const;

// interceptor logic id to actually interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;
Expand All @@ -111,6 +116,7 @@ class Carrier final {
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
};

Expand Down
30 changes: 15 additions & 15 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
Expand All @@ -28,6 +27,8 @@
namespace paddle {
namespace distributed {

std::unique_ptr<Carrier> FleetExecutor::carrier_;

FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
Expand All @@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {

FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
GetCarrier().Release();
GetCarrier()->Release();
}

Carrier& FleetExecutor::GetCarrier() {
static Carrier carrier;
return carrier;
Carrier* FleetExecutor::GetCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_.get(), platform::errors::NotFound(
"Carrier has not been created."));
return carrier_.get();
}

void FleetExecutor::Init(
Expand Down Expand Up @@ -84,16 +86,16 @@ void FleetExecutor::Init(
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
CreateCarrier();
InitCarrier();
InitMessageBus();
}

void FleetExecutor::InitCarrier() {
Carrier& carrier = GetCarrier();
if (!carrier.IsInit()) {
carrier.SetMsgBus(msg_bus_);
carrier.Init(runtime_graph_, root_scope_, minibatch_scope_,
microbatch_scopes_, place_);
if (!GetCarrier()->IsInit()) {
GetCarrier()->SetMsgBus(msg_bus_);
GetCarrier()->Init(exe_desc_.cur_rank(), runtime_graph_, root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
}

Expand Down Expand Up @@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() {
<< (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
VLOG(5) << ss.str();
if (!msg_bus_->IsInit()) {
msg_bus_->Init(runtime_graph_->intercepter_id_to_rank(), rank_to_addr,
addr);
msg_bus_->Init(cur_rank, rank_to_addr, addr);
}
}

void FleetExecutor::Run() {
// Run
Carrier& carrier = GetCarrier();
PADDLE_ENFORCE_EQ(
carrier.IsInit(), true,
GetCarrier()->IsInit(), true,
platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
carrier.Start();
GetCarrier()->Start();
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
Expand Down
13 changes: 11 additions & 2 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <memory>
#include <string>

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
Expand All @@ -30,7 +31,6 @@ namespace distributed {
class RuntimeGraph;
class MessageBus;
class TaskNode;
class Carrier;

class FleetExecutor final {
public:
Expand All @@ -43,7 +43,15 @@ class FleetExecutor final {
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
// TODO(liyurui): Change to use registry table for multi-carrier.
static Carrier& GetCarrier();
static Carrier* GetCarrier();
// Creates the Carrier held in the static carrier_ member.
//
// `args` are perfectly forwarded to Carrier's constructor through
// std::make_unique. An AlreadyExists enforce error is raised if carrier_
// already holds an object, so this succeeds at most once while that object
// is alive.
//
// Returns a non-owning pointer to the new Carrier; ownership stays with
// carrier_.
template <typename... Args>
static Carrier* CreateCarrier(Args&&... args) {
PADDLE_ENFORCE_EQ(
carrier_.get(), nullptr,
platform::errors::AlreadyExists("Carrier has been created already."));
carrier_ = std::make_unique<Carrier>(std::forward<Args>(args)...);
return carrier_.get();
}

private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
Expand All @@ -59,6 +67,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
static std::unique_ptr<Carrier> carrier_;
};

} // namespace distributed
Expand Down
28 changes: 5 additions & 23 deletions paddle/fluid/distributed/fleet_executor/interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,17 @@ void Interceptor::StopCarrier() {
cond_var.notify_all();
}

std::condition_variable& Interceptor::GetCondVar() {
// get the conditional var
return cond_var_;
}

// Read-only accessor: yields the value stored in interceptor_id_.
int64_t Interceptor::GetInterceptorId() const { return interceptor_id_; }

bool Interceptor::EnqueueRemoteInterceptorMessage(
void Interceptor::EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG(3) << "Enqueue message: " << interceptor_message.message_type()
<< " into " << interceptor_id_ << "'s remote mailbox.";
std::unique_lock<std::mutex> lock(remote_mailbox_mutex_);
remote_mailbox_.push(interceptor_message);
return true;
remote_mailbox_.Push(interceptor_message);
}

bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
Expand All @@ -92,7 +85,7 @@ void Interceptor::PoolTheMailbox() {
"Error encountered when fetch remote mailbox."));
}
const InterceptorMessage interceptor_message = local_mailbox_.front();
local_mailbox_.pop();
local_mailbox_.pop_front();
const MessageType message_type = interceptor_message.message_type();
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << interceptor_message.src_id()
Expand All @@ -109,19 +102,8 @@ void Interceptor::PoolTheMailbox() {
}

bool Interceptor::FetchRemoteMailbox() {
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
std::unique_lock<std::mutex> lock(remote_mailbox_mutex_);
cond_var_.wait(lock, [this]() { return !remote_mailbox_.empty(); });
if (remote_mailbox_.empty()) {
// the thread has been unblocked accidentally
return false;
}
while (!remote_mailbox_.empty()) {
local_mailbox_.push(std::move(remote_mailbox_.front()));
remote_mailbox_.pop();
}
return true;
remote_mailbox_.PopAll(&local_mailbox_);
return !local_mailbox_.empty();
}

static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
Expand Down
Loading