add horizontal federation learning ps feature (#44327)
* back fl

* delete ssl cert

* .

* make warning

* .

* unittest paral degree

* solve unittest

* heter & multi cloud comm ready

* .

* .

* fl-ps v1.0

* .

* support N + N mode

* .

* .

* .

* .

* delete print

* .

* .

* .

* .

* fix bug

* .

* .

* fl-ps with coordinator ready

* merge dev

* update message parse only

* update fl client scheduler

* fix bug

* update multithreads sync

* fix ci errors

* update role_maker.py

* update role_maker.py

* fix ci error: windows py import error

* fix ci error: windows py import error

* fix windows ci pylib import error

* add dump fields & params

* try to fix windows import fleet error

* fix ps FLAGS error
ziyoujiyi committed Jul 26, 2022
1 parent e3ee510 commit 4bc22b6
Showing 36 changed files with 1,676 additions and 115 deletions.
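
The heart of this commit is a client-side federated-learning workflow built on three new BrpcPsClient calls, all visible in the diffs below: InitializeFlWorker connects to the coordinator list and starts a local FL service, PushFLClientInfoSync reports client state to the coordinators, and PullFlStrategy blocks until a strategy string arrives. The following is a minimal sketch of how a worker might drive these calls; the driver function, endpoint value, and info payload are assumptions for illustration, not part of this diff.

// Hypothetical FL worker driver (sketch only, not from this commit).
// Only InitializeFlWorker / PushFLClientInfoSync / PullFlStrategy are
// introduced by this diff; everything else here is assumed.
#include <string>
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"

int RunFlWorker(paddle::distributed::BrpcPsClient *client) {
  const std::string self_endpoint = "127.0.0.1:8190";  // assumed endpoint
  // Connect to every coordinator and start the local FL service.
  if (client->InitializeFlWorker(self_endpoint) != 0) {
    return -1;
  }
  // Report local client state; the payload format is an assumption.
  client->PushFLClientInfoSync("{\"client_id\": 0, \"sample_num\": 1024}");
  // Block until the coordinator pushes an aggregation strategy back.
  std::string strategy = client->PullFlStrategy();
  // ... apply `strategy` to local training ...
  return 0;
}
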
1 change: 0 additions & 1 deletion cmake/external/brpc.cmake
@@ -47,7 +47,6 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
# TODO(gongwb): change to the newest repo when it is updated
GIT_REPOSITORY "https://github.com/wangjiawei04/brpc"
#GIT_REPOSITORY "https://github.com/ziyoujiyi/brpc" # ssl error in the previous repo (can be manually fixed)
GIT_TAG "e203afb794caf027da0f1e0776443e7d20c0c28e"
PREFIX ${BRPC_PREFIX_DIR}
UPDATE_COMMAND ""
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/ps/service/CMakeLists.txt
@@ -78,6 +78,10 @@ set_source_files_properties(
graph_brpc_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
graph_brpc_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

set_source_files_properties(
coordinator_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

cc_library(
brpc_utils
SRCS brpc_utils.cc
@@ -90,6 +94,7 @@ cc_library(
cc_library(
downpour_client
SRCS graph_brpc_client.cc brpc_ps_client.cc ps_local_client.cc
coordinator_client.cc
DEPS eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})

cc_library(
161 changes: 150 additions & 11 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc
@@ -18,10 +18,22 @@
#include <sstream>
#include <string>

#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/split.h"

static const int max_port = 65535;

namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace distributed {

DEFINE_int32(pserver_push_dense_merge_limit,
12,
"limit max push_dense local merge requests");
@@ -66,16 +78,6 @@ DEFINE_int32(pserver_sparse_table_shard_num,
1000,
"sparse table shard for save & load");

namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace distributed {

inline size_t get_sparse_shard(uint32_t shard_num,
uint32_t server_num,
uint64_t key) {
@@ -101,7 +103,7 @@ void DownpourPsClientService::service(
}
}

// 启动client端RpcService 用于数据互发等操作
// Start the client-side RpcService, used for client-to-client data exchange and similar operations
int32_t BrpcPsClient::StartClientService() {
if (_service.Configure(this, _client_id) != 0) {
LOG(ERROR)
@@ -122,6 +124,35 @@ int32_t BrpcPsClient::StartClientService() {
_server_started = true;
_env->RegistePsClient(
butil::my_ip_cstr(), _server.listen_address().port, _client_id);
VLOG(0) << "BrpcPsClient Service addr: " << butil::my_ip_cstr() << ", "
<< _server.listen_address().port << ", " << _client_id;
return 0;
}

// Start FlClientService, used to receive data pushed from the coordinator
int32_t BrpcPsClient::StartFlClientService(const std::string &self_endpoint) {
_fl_server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
if (self_endpoint.empty()) {
LOG(ERROR) << "fl-ps > fl client endpoint not set";
return -1;
}

if (_fl_server.Start(self_endpoint.c_str(), &options) != 0) {
VLOG(0) << "fl-ps > StartFlClientService failed. Try again.";
auto ip_port = paddle::string::Split(self_endpoint, ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (_fl_server.Start(int_ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "fl-ps > StartFlClientService failed, ip_port= "
<< int_ip_port;
return -1;
}
} else {
VLOG(0) << "fl-ps > StartFlClientService succeed! listen on "
<< self_endpoint;
}
return 0;
}
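
A note on the design here: _fl_server registers the same _service object that the client-to-client server already exposes, passing brpc::SERVER_DOESNT_OWN_SERVICE so that neither brpc::Server tries to delete it. That sharing is what lets the coordinator's FLService RPC (see the brpc_ps_client.h diff below) write _fl_strategy on the very object that PullFlStrategy later polls.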

@@ -166,6 +197,96 @@ int32_t BrpcPsClient::CreateClient2ClientConnection(
return 0;
}

int32_t BrpcPsClient::InitializeFlWorker(const std::string &self_endpoint) {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = FLAGS_pserver_timeout_ms;
options.connection_type = "pooled";
options.connect_timeout_ms =
paddle::distributed::FLAGS_pserver_connect_timeout_ms;
options.max_retry = 3;
// Fetch the coordinator list and connect to each coordinator
std::string coordinator_ip_port;
std::vector<PSHost> coordinator_list = _env->GetCoordinators();
_coordinator_channels.resize(coordinator_list.size());
for (size_t i = 0; i < coordinator_list.size(); ++i) {
coordinator_ip_port.assign(coordinator_list[i].ip.c_str());
coordinator_ip_port.append(":");
coordinator_ip_port.append(std::to_string(coordinator_list[i].port));
VLOG(0) << "fl-ps > BrpcFlclient connetcting to coordinator: "
<< coordinator_ip_port;
for (size_t j = 0; j < _coordinator_channels[i].size(); ++j) {
_coordinator_channels[i][j].reset(new brpc::Channel());
if (_coordinator_channels[i][j]->Init(
coordinator_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:"
<< coordinator_ip_port << " Failed! Try again.";
std::string int_ip_port = GetIntTypeEndpoint(coordinator_list[i].ip,
coordinator_list[i].port);
if (_coordinator_channels[i][j]->Init(
int_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:"
<< int_ip_port << " Failed!";
return -1;
}
}
}
}
if (StartFlClientService(self_endpoint) != 0) {
return -1;
}
VLOG(0) << "fl-ps > InitializeFlWorker finished!";
return 0;
}
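
The inner loop above runs exactly once per coordinator: _coordinator_channels is declared in the header diff below as a vector of std::array&lt;std::shared_ptr&lt;brpc::Channel&gt;, 1&gt;, i.e. one pooled channel per coordinator endpoint.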

void BrpcPsClient::PushFLClientInfoSync(const std::string &fl_client_info) {
size_t request_call_num = _coordinator_channels.size();
FlClientBrpcClosure *closure =
new FlClientBrpcClosure(request_call_num, [request_call_num](void *done) {
auto *closure = reinterpret_cast<FlClientBrpcClosure *>(done);
int ret = 0;
for (size_t i = 0; i < request_call_num; i++) {
if (closure->check_response(i, PUSH_FL_CLIENT_INFO_SYNC) != 0) {
LOG(ERROR) << "fl-ps > PushFLClientInfoSync response from "
"coordinator is failed";
ret = -1;
return;
} else {
VLOG(0) << "fl-ps > rpc service call cost time: "
<< (closure->cntl(i)->latency_us() / 1000) << " ms";
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int32_t> fut = promise->get_future();
closure->add_promise(promise);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PUSH_FL_CLIENT_INFO_SYNC);
closure->request(i)->set_client_id(_client_id);
closure->request(i)->set_str_params(fl_client_info);
brpc::Channel *rpc_channel = _coordinator_channels[i][0].get();
if (rpc_channel == nullptr) {
LOG(ERROR) << "_coordinator_channels[" << i << "] is null";
return;
}
PsService_Stub rpc_stub(rpc_channel); // CoordinatorService
rpc_stub.FLService(
closure->cntl(i), closure->request(i), closure->response(i), closure);
}
// wait once after all requests are issued; waiting inside the loop would
// block before the remaining requests were sent
fut.wait();
VLOG(0) << "fl-ps > PushFLClientInfoSync finished, client id: " << _client_id;
return;
}
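
PushFLClientInfoSync uses the same countdown-closure idiom as DownpourBrpcClosure: each completed RPC calls Run(), the callback fires only when the atomic counter reaches zero, and a promise/future pair converts that into a blocking wait. Below is a stripped-down, brpc-free illustration of the idiom; all names in it are invented for the sketch.

// Generic countdown closure (sketch only, not PaddlePaddle code).
#include <atomic>
#include <cstddef>
#include <functional>
#include <future>

class CountdownClosure {
 public:
  CountdownClosure(size_t n, std::function<void(CountdownClosure *)> cb)
      : waiting_num_(static_cast<int32_t>(n)), callback_(std::move(cb)) {}
  // Invoked once per finished response, possibly from different threads.
  void Run() {
    if (waiting_num_.fetch_sub(1) == 1) {  // last response just arrived
      callback_(this);
      delete this;  // self-owned, like brpc closures
    }
  }
  std::promise<int32_t> promise;  // fulfilled inside the callback

 private:
  std::atomic<int32_t> waiting_num_;
  std::function<void(CountdownClosure *)> callback_;
};

// Usage: create with n == number of async calls, set the promise in the
// callback, issue all n calls, then block on promise.get_future().wait().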

std::string BrpcPsClient::PullFlStrategy() {
while (!_service._is_fl_strategy_ready) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
VLOG(0) << "fl-ps > waiting for the fl strategy from the coordinator";
}
_service._is_fl_strategy_ready =
false; // only a single consumer thread reads the strategy, so no locking is used
return _service._fl_strategy;
}
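
PullFlStrategy busy-polls _is_fl_strategy_ready, which FLService (see the brpc_ps_client.h diff below) sets from a brpc worker thread. The in-code comment justifies this with a single-consumer assumption, but a plain bool shared between threads is still a data race under the C++ memory model; a hedged sketch of the same handshake on std::atomic&lt;bool&gt; (names invented for the sketch) would be:

// Sketch only: the real flag and strategy live in DownpourPsClientService.
#include <atomic>
#include <chrono>
#include <string>
#include <thread>

std::atomic<bool> is_fl_strategy_ready{false};
std::string fl_strategy;  // written by the RPC thread before the flag flips

std::string PullFlStrategySketch() {
  while (!is_fl_strategy_ready.load(std::memory_order_acquire)) {
    std::this_thread::sleep_for(std::chrono::seconds(1));
  }
  is_fl_strategy_ready.store(false);  // re-arm for the next round
  return fl_strategy;
}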

int32_t BrpcPsClient::Initialize() {
_async_call_num = 0;

@@ -300,6 +421,24 @@ std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
return data;
}

int FlClientBrpcClosure::check_response(size_t request_idx, int cmd_id) {
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id
<< " failed, "
"err:"
<< _cntls[request_idx]->ErrorText();
return -1;
}
if (_responses[request_idx].err_code() != 0) {
LOG(ERROR) << "response ret bad, server_idx:" << request_idx
<< "cmd_id:" << cmd_id
<< " err_code:" << _responses[request_idx].err_code()
<< " err_msg:" << _responses[request_idx].err_msg();
return -1;
}
return 0;
}

std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
75 changes: 71 additions & 4 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.h
100644 → 100755
@@ -25,6 +25,7 @@
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
@@ -56,16 +57,71 @@ class DownpourPsClientService : public PsService {
_rank = rank_id;
return 0;
}
void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;

virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done);

virtual void FLService(::google::protobuf::RpcController *controller,
const CoordinatorReqMessage *request,
CoordinatorResMessage *response,
::google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
size_t client_id = request->client_id();
CHECK(_client->_client_id == client_id)
<< "request client id not matched self";
_fl_strategy = request->str_params();
_is_fl_strategy_ready = true;
response->set_err_code(0);
response->set_err_msg("");
VLOG(0) << "fl-ps > DownpourPsClientService::FLService finished!";
return;
}

public:
std::string _fl_strategy;
bool _is_fl_strategy_ready = false;

protected:
size_t _rank;
PSClient *_client;
};

class FlClientBrpcClosure : public PSClientClosure {
public:
FlClientBrpcClosure(size_t num, PSClientCallBack callback)
: PSClientClosure(callback) {
_waiting_num = num;

_cntls.resize(num);
_requests.resize(num);
_responses.resize(num);
for (size_t i = 0; i < num; ++i) {
_cntls[i].reset(new brpc::Controller());
}
}
virtual ~FlClientBrpcClosure() {}
void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
}
}
CoordinatorReqMessage *request(size_t i) { return &_requests[i]; }
CoordinatorResMessage *response(size_t i) { return &_responses[i]; }
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
int check_response(size_t request_idx, int cmd_id);
int check_save_response(size_t request_idx, int cmd_id);
std::string get_response(size_t request_idx, int cmd_id);

private:
std::atomic<int32_t> _waiting_num;
std::vector<CoordinatorReqMessage> _requests;
std::vector<CoordinatorResMessage> _responses;
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};

class DownpourBrpcClosure : public PSClientClosure {
public:
DownpourBrpcClosure(size_t num, PSClientCallBack callback)
@@ -267,6 +323,14 @@ class BrpcPsClient : public PSClient {
}
int32_t Initialize() override;

// for fl
public:
virtual int32_t InitializeFlWorker(const std::string &self_endpoint);
int32_t StartFlClientService(const std::string &self_endpoint);
virtual void PushFLClientInfoSync(const std::string &fl_client_info);
std::string PullFlStrategy();
// for fl

private:
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
@@ -320,6 +384,8 @@ class BrpcPsClient : public PSClient {
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
std::vector<std::array<std::shared_ptr<brpc::Channel>, 1>>
_coordinator_channels; // client2coordinator
std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
@@ -360,6 +426,7 @@
float _mse = 0;
uint16_t _push_times = 0;
brpc::Server _server;
brpc::Server _fl_server;
DownpourPsClientService _service;
bool _server_started = false;
std::atomic_uint grad_num_{0};
1 change: 1 addition & 0 deletions paddle/fluid/distributed/ps/service/communicator/CMakeLists.txt
100644 → 100755
@@ -1,5 +1,6 @@
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
set_source_files_properties(
communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
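
-faligned-new is a GCC/Clang option that enables C++17-style aligned operator new; appending it to DISTRIBUTE_COMPILE_FLAGS presumably lets over-aligned types in the communicator sources allocate correctly, though the commit gives no explicit rationale, so that reading is an inference.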

(diff for the remaining 31 changed files not shown)
