Skip to content

Commit

Permalink
[Heterps]Refactor Heter Pipeline Parameter Server (#36845)
Browse files Browse the repository at this point in the history
* change username

* fix

* fix

* fix

* fix

* fix

* update

* update

* update unittests

* fix

* update

* fix

* update

* fix

* fix

* fix

* update

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update send_and_recv op. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix unit. notest,test=coverage

* fix ut. notest, test=coverage

* update. notest,test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* fix. notest, test=coverage

* fix. notest, test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* add func. notest, test=coverage

* fix ut. notest, test=coverage

* fix. test=develop

* fix. test=develop
  • Loading branch information
zmxdream committed Nov 11, 2021
1 parent 5264566 commit a2da1ef
Show file tree
Hide file tree
Showing 51 changed files with 4,073 additions and 822 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/service/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/service/communicator.h"

#include <google/protobuf/text_format.h>

#include "gflags/gflags.h"
Expand Down Expand Up @@ -361,6 +360,8 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
<< " from 0' trainer done";
}
}
std::this_thread::sleep_for(
std::chrono::milliseconds(100 + trainer_id_ * 10));
BarrierWithTable(1);
return;
}
Expand Down Expand Up @@ -518,7 +519,6 @@ void AsyncCommunicator::SendByCommunicator() {
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
}
}

if (ctx.is_tensor_table) {
SendGlobalStep(ctx, merged_var_num, send_scope_.get());
} else if (ctx.is_sparse) {
Expand Down
101 changes: 78 additions & 23 deletions paddle/fluid/distributed/service/heter_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,36 @@ namespace distributed {
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
bool HeterClient::is_initialized_ = false;

int GetMicroId(const platform::DeviceContext& ctx,
const framework::Scope* scope) {
framework::Variable* var = scope->FindVar("microbatch_id");
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"the type of micro id shoulde be LoDTensor."));
auto micro_id = -1;
auto* tensor = var->GetMutable<framework::LoDTensor>();
if (platform::is_cpu_place(tensor->place())) {
auto data = reinterpret_cast<const float*>(tensor->data<void>());
micro_id = static_cast<int>(data[0]);
} else {
#ifdef PADDLE_WITH_CUDA
std::vector<char> temp;
temp.resize(tensor->numel() * framework::SizeOfType(tensor->type()));
char* temp_ptr = temp.data();
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
micro_id = static_cast<int>(temp_ptr_float[0]);
#endif
}
return micro_id;
}

void HeterClient::MainThread() {
while (running_) {
RpcProfilerControl();
Expand Down Expand Up @@ -99,43 +129,68 @@ void HeterClient::CreateClient2XpuConnection() {
}
}
}
previous_xpu_channels_.resize(previous_xpu_list_.size());
for (size_t i = 0; i < previous_xpu_list_.size(); ++i) {
previous_xpu_channels_[i].reset(new brpc::Channel());
if (previous_xpu_channels_[i]->Init(previous_xpu_list_[i].c_str(), "",
&options) != 0) {
VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = paddle::string::Split(previous_xpu_list_[i], ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (previous_xpu_channels_[i]->Init(int_ip_port.c_str(), "", &options) !=
0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
}
}
}
}

void HeterClient::SendAndRecvAsync(
const std::vector<std::string>& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& message_name,
const platform::DeviceContext& ctx, const framework::Scope& scope,
const std::string& message_name,
const std::vector<std::string>& send_var_name,
const std::vector<std::string>& recv_var_name) {
const std::vector<std::string>& recv_var_name, const std::string& mode) {
platform::RecordEvent record_event("HeterClient->SendAndRecvAsync");
const platform::DeviceContext* p_ctx = &ctx;
const framework::Scope* p_scope = &scope;
const std::string message_name_val = message_name;
const std::vector<std::string> send_var_name_val = send_var_name;
const std::vector<std::string> recv_var_name_val = recv_var_name;

VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: "
VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: "
<< message_name_val;
// Todo: get correct channel
int num = trainer_id_ % xpu_channels_.size();

brpc::Controller cntl;
cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request, response;
auto& request_io_buffer = cntl.request_attachment();
::paddle::distributed::PsService_Stub stub(xpu_channels_[num].get());
brpc::Channel* channel = nullptr;
distributed::MultiVarMsg request;
OnHeterRpcDone* closure = new OnHeterRpcDone([p_ctx, p_scope](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
PADDLE_ENFORCE_NE(
closure->cntl.Failed(), true,
platform::errors::Unimplemented(
"HeterClient::SendAndRecv meets brpc error, error message is %s",
closure->cntl.ErrorText()));

VLOG(4) << "call heter_worker success";
});
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
auto& request_io_buffer = closure->cntl.request_attachment();
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer);
stub.SendAndRecvVariable(&cntl, &request, &response, NULL);
PADDLE_ENFORCE_NE(
cntl.Failed(), true,
platform::errors::Unimplemented(
"HeterClient::SendAndRecv meets brpc error, error message is %s",
cntl.ErrorText()));
VLOG(4) << "call heter_worker success";
auto& response_io_buffer = cntl.response_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer,
ctx, p_scope);

int micro_id = GetMicroId(ctx, p_scope);
auto minibatch_id = micro_id / 10;
// select channel according to micro id
if (mode == "forward") {
int num = minibatch_id % xpu_channels_.size();
channel = xpu_channels_[num].get();
} else if (mode == "backward") {
int num = minibatch_id % previous_xpu_channels_.size();
channel = previous_xpu_channels_[num].get();
}
::paddle::distributed::PsService_Stub stub(channel);
stub.SendAndRecvVariable(&closure->cntl, &request, &closure->response,
closure);
}

std::future<int32_t> HeterClient::SendCmd(
Expand Down
17 changes: 13 additions & 4 deletions paddle/fluid/distributed/service/heter_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,23 @@ class HeterClient {

void CreateClient2XpuConnection();

void SendAndRecvAsync(const std::vector<std::string>& ep,
const platform::DeviceContext& ctx,
void SendAndRecvAsync(const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& message_name,
const std::vector<std::string>& send_var_name,
const std::vector<std::string>& recv_var_name);
const std::vector<std::string>& recv_var_name,
const std::string& mode = "forward");

// HeterClient singleton
static std::shared_ptr<HeterClient> GetInstance(
const std::vector<std::string>& endpoint, const int& trainer_id) {
const std::vector<std::string>& endpoint,
const std::vector<std::string>& previous_endpoint,
const int& trainer_id) {
if (NULL == s_instance_) {
is_initialized_ = true;
s_instance_.reset(new paddle::distributed::HeterClient());
s_instance_->SetXpuList(endpoint);
s_instance_->SetPreviousXpuList(previous_endpoint);
s_instance_->SetTrainerID(trainer_id);
s_instance_->CreateClient2XpuConnection();
}
Expand Down Expand Up @@ -118,16 +121,22 @@ class HeterClient {
xpu_list_ = xpu_list;
}

void SetPreviousXpuList(const std::vector<std::string>& xpu_list) {
previous_xpu_list_ = xpu_list;
}

void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }

private:
static std::shared_ptr<HeterClient> s_instance_;
static bool is_initialized_;
std::unique_ptr<std::thread> main_thread_{nullptr};
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;

DISABLE_COPY_AND_ASSIGN(HeterClient);
std::vector<std::string> xpu_list_;
std::vector<std::string> previous_xpu_list_;

bool running_ = false;
int trainer_id_;
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/service/heter_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,20 @@ void HeterServer::StartHeterService() {
ready_ = 1;
}
condition_ready_.notify_all();

std::unique_lock<std::mutex> running_lock(mutex_);
stoped_ = false;
cv_.wait(running_lock, [&] {
VLOG(1) << "Heter Server is Stop? " << stoped_;
return stoped_;
});
}

void HeterServer::SetEndPoint(std::string& endpoint) {
void HeterServer::SetEndPoint(const std::string& endpoint) {
endpoint_ = endpoint;
service_.SetEndpoint(endpoint);
}

void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); }
void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }

void HeterServer::WaitServerReady() {
std::unique_lock<std::mutex> lock(this->mutex_ready_);
Expand Down
Loading

0 comments on commit a2da1ef

Please sign in to comment.