Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PsCore] support ssd #33031

Merged
merged 10 commits into from
May 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions cmake/external/rocksdb.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

INCLUDE(ExternalProject)

# Paths used by the rocksdb external project:
#   ROCKSDB_SOURCES_DIR - ExternalProject prefix (checkout and in-source build)
#   ROCKSDB_INSTALL_DIR - staging area the headers and static archive are copied to
SET(ROCKSDB_SOURCES_DIR ${THIRD_PARTY_PATH}/rocksdb)
SET(ROCKSDB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/rocksdb)
SET(ROCKSDB_INCLUDE_DIR "${ROCKSDB_INSTALL_DIR}/include" CACHE PATH "rocksdb include directory." FORCE)
SET(ROCKSDB_LIBRARIES "${ROCKSDB_INSTALL_DIR}/lib/librocksdb.a" CACHE FILEPATH "rocksdb library." FORCE)
# rocksdb is compiled with the parent project's C++ flags plus -fPIC.
SET(ROCKSDB_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
INCLUDE_DIRECTORIES(${ROCKSDB_INCLUDE_DIR})

# Download rocksdb v6.10.1, build it in its source tree, then copy the static
# archive and the public headers into ROCKSDB_INSTALL_DIR.
ExternalProject_Add(
    extern_rocksdb
    ${EXTERNAL_PROJECT_LOG_ARGS}
    PREFIX          ${ROCKSDB_SOURCES_DIR}
    GIT_REPOSITORY  "https://github.com/facebook/rocksdb"
    GIT_TAG         v6.10.1
    UPDATE_COMMAND  ""
    CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
                    -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
                    -DWITH_BZ2=OFF
                    -DWITH_GFLAGS=OFF
                    -DCMAKE_CXX_FLAGS=${ROCKSDB_CMAKE_CXX_FLAGS}
                    -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
    INSTALL_COMMAND mkdir -p ${ROCKSDB_INSTALL_DIR}/lib/
        && cp ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/librocksdb.a ${ROCKSDB_LIBRARIES}
        && cp -r ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/include ${ROCKSDB_INSTALL_DIR}/
    BUILD_IN_SOURCE 1
    # Declare the installed archive as an output of this project so that
    # generators which require every consumed file to have a known producer
    # (e.g. Ninja) can resolve the IMPORTED_LOCATION of the rocksdb target.
    BUILD_BYPRODUCTS ${ROCKSDB_LIBRARIES}
)

# rocksdb must be built after the snappy target.
ADD_DEPENDENCIES(extern_rocksdb snappy)

# Expose the prebuilt archive as an imported static library that is only
# usable once the external project has been built.
ADD_LIBRARY(rocksdb STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET rocksdb PROPERTY IMPORTED_LOCATION ${ROCKSDB_LIBRARIES})
ADD_DEPENDENCIES(rocksdb extern_rocksdb)

LIST(APPEND external_project_dependencies rocksdb)

5 changes: 5 additions & 0 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ if (WITH_PSCORE)

include(external/libmct) # download, build, install libmct
list(APPEND third_party_deps extern_libmct)

# rocksdb is only pulled in when WITH_HETERPS is enabled.
if (WITH_HETERPS)
include(external/rocksdb) # download, build, install rocksdb
list(APPEND third_party_deps extern_rocksdb)
endif()
endif()

if(WITH_XBYAK)
Expand Down
11 changes: 8 additions & 3 deletions paddle/fluid/distributed/fleet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,10 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
return;
}

void FleetWrapper::LoadModel(const std::string& path, const int mode) {
auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode));
void FleetWrapper::LoadModel(const std::string& path, const std::string& mode) {
auto* communicator = Communicator::GetInstance();
auto ret = communicator->_worker_ptr->load(path, mode);
// auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
Expand All @@ -429,8 +431,11 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {

void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
const std::string& path, const int mode) {
auto* communicator = Communicator::GetInstance();
auto ret =
pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode));
communicator->_worker_ptr->load(table_id, path, std::to_string(mode));
// auto ret =
// pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model of table id: " << table_id
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/fleet.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class FleetWrapper {
void PrintTableStat(const uint64_t table_id);
// mode = 0, load all features
// mode = 1, load delta features, i.e. only the diff
void LoadModel(const std::string& path, const int mode);
void LoadModel(const std::string& path, const std::string& mode);
// mode = 0, load all features
// mode = 1, load delta features, i.e. only the diff
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
Expand Down
11 changes: 5 additions & 6 deletions paddle/fluid/distributed/service/ps_local_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@ ::std::future<int32_t> PsLocalClient::shrink(uint32_t table_id,
// Loads every table registered in _table_map by delegating to the per-table
// load() overload with the same epoch and mode.
// The futures returned by the per-table calls are discarded; the caller
// receives whatever done() produces.
::std::future<int32_t> PsLocalClient::load(const std::string& epoch,
                                           const std::string& mode) {
  for (auto& table_entry : _table_map) {
    load(table_entry.first, epoch, mode);
  }
  return done();
}
// Loads a single table: looks the table up by id and forwards epoch and mode
// to its load() method, then returns done().
// NOTE(review): the value returned by the table's load() is ignored and the
// looked-up pointer is not null-checked -- confirm this is intended.
::std::future<int32_t> PsLocalClient::load(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  table(table_id)->load(epoch, mode);
  return done();
}

Expand Down Expand Up @@ -245,7 +245,6 @@ ::std::future<int32_t> PsLocalClient::pull_sparse_ptr(char** select_values,
::std::future<int32_t> PsLocalClient::push_sparse_raw_gradient(
size_t table_id, const uint64_t* keys, const float** update_values,
size_t num, void* callback) {
VLOG(1) << "wxx push_sparse_raw_gradient";
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
Expand Down
7 changes: 6 additions & 1 deletion paddle/fluid/distributed/service/ps_local_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@ class PsLocalServer : public PSServer {
PsLocalServer() {}
virtual ~PsLocalServer() {}
virtual uint64_t start() { return 0; }
virtual uint64_t start(const std::string& ip, uint32_t port) { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; }
virtual int32_t port() { return 0; }
// Local-server version of configure(): ignores config, env, server_rank and
// server_sub_program and always returns 0.
virtual int32_t configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
return 0;
}

private:
virtual int32_t initialize() { return 0; }
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class PSServer {

virtual int32_t configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) final;
const std::vector<framework::ProgramDesc> &server_sub_program = {});

// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
Expand Down
15 changes: 12 additions & 3 deletions paddle/fluid/distributed/table/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,24 @@ set_source_files_properties(${graphDir}/graph_node.cc PROPERTIES COMPILE_FLAGS $
cc_library(graph_node SRCS ${graphDir}/graph_node.cc DEPS WeightedSampler)
set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc
sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS}
${RPC_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator)
set(EXTERN_DEP "")
if(WITH_HETERPS)
set(TABLE_SRC common_sparse_table.cc ssd_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc)
set(EXTERN_DEP rocksdb)
else()
set(TABLE_SRC common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc)
endif()

cc_library(common_table SRCS ${TABLE_SRC} DEPS ${TABLE_DEPS}
${RPC_DEPS} graph_edge graph_node device_context string_helper
simple_threadpool xxhash generator ${EXTERN_DEP})

set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
Expand Down
104 changes: 19 additions & 85 deletions paddle/fluid/distributed/table/common_sparse_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,83 +25,12 @@ class ValueBlock;
} // namespace distributed
} // namespace paddle

#define PSERVER_SAVE_SUFFIX ".shard"
using boost::lexical_cast;

namespace paddle {
namespace distributed {

enum SaveMode { all, base, delta };

// Description of a saved sparse-table shard, stored as "key=value" text lines.
//
// Recognised keys are: param, shard_id, row_names (comma separated),
// row_dims (comma separated integers) and count. Lines starting with '#' are
// skipped; unrecognised keys are ignored.
struct Meta {
  std::string param;
  // Defaults to 0 so the field is well defined when the key is missing.
  int shard_id = 0;
  std::vector<std::string> names;
  std::vector<int> dims;
  // Defaults to 0 so the field is well defined when the key is missing.
  uint64_t count = 0;
  // Maps each entry of `names` to the entry of `dims` at the same position.
  // Only populated by the file-parsing constructor.
  std::unordered_map<std::string, int> dims_map;

  // Parses the meta file at `metapath`.
  //
  // Raises InvalidArgument (through PADDLE_ENFORCE_EQ) when a non-comment
  // line is not of the form k=v, or when row_names and row_dims do not have
  // the same number of entries.
  explicit Meta(const std::string& metapath) {
    std::ifstream file(metapath);
    std::string line;
    while (std::getline(file, line)) {
      if (StartWith(line, "#")) {
        continue;
      }
      auto pairs = paddle::string::split_string<std::string>(line, "=");
      PADDLE_ENFORCE_EQ(
          pairs.size(), 2,
          paddle::platform::errors::InvalidArgument(
              "info in %s except k=v, but got %s", metapath, line));

      const std::string& key = pairs[0];
      const std::string& value = pairs[1];

      if (key == "param") {
        param = value;
      } else if (key == "shard_id") {
        shard_id = std::stoi(value);
      } else if (key == "row_names") {
        names = paddle::string::split_string<std::string>(value, ",");
      } else if (key == "row_dims") {
        auto dims_strs = paddle::string::split_string<std::string>(value, ",");
        for (auto& str : dims_strs) {
          dims.push_back(std::stoi(str));
        }
      } else if (key == "count") {
        count = std::stoull(value);
      }
    }

    // Indexing dims[x] below is only valid when both lists are the same
    // length; fail loudly instead of reading past the end of `dims`.
    PADDLE_ENFORCE_EQ(
        names.size(), dims.size(),
        paddle::platform::errors::InvalidArgument(
            "row_names and row_dims in %s should have the same number of "
            "entries, but got %d and %d",
            metapath, names.size(), dims.size()));
    for (size_t x = 0; x < names.size(); ++x) {
      dims_map[names[x]] = dims[x];
    }
  }

  // Builds a Meta directly from its fields. dims_map is left empty.
  Meta(std::string param, int shard_id, std::vector<std::string> row_names,
       std::vector<int> dims, uint64_t count)
      : param(param),
        shard_id(shard_id),
        names(row_names),
        dims(dims),
        count(count) {}

  // Serialises the fields back into the "key=value" text form, one per line.
  std::string ToString() const {
    std::stringstream ss;
    ss << "param=" << param << "\n";
    ss << "shard_id=" << shard_id << "\n";
    ss << "row_names=" << paddle::string::join_strings(names, ',') << "\n";
    ss << "row_dims=" << paddle::string::join_strings(dims, ',') << "\n";
    ss << "count=" << count << "\n";
    return ss.str();
  }
};

void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
const int64_t id, std::vector<std::vector<float>>* values) {
void CommonSparseTable::ProcessALine(const std::vector<std::string>& columns,
const Meta& meta, const int64_t id,
std::vector<std::vector<float>>* values) {
auto colunmn_size = columns.size();
auto load_values =
paddle::string::split_string<std::string>(columns[colunmn_size - 1], ",");
Expand Down Expand Up @@ -134,8 +63,10 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
}
}

void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total) {
void CommonSparseTable::SaveMetaToText(std::ostream* os,
const CommonAccessorParameter& common,
const size_t shard_idx,
const int64_t total) {
// save meta
std::stringstream stream;
stream << "param=" << common.table_name() << "\n";
Expand All @@ -148,8 +79,10 @@ void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
os->write(stream.str().c_str(), sizeof(char) * stream.str().size());
}

int64_t SaveValueToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
std::shared_ptr<::ThreadPool> pool, const int mode) {
int64_t CommonSparseTable::SaveValueToText(std::ostream* os,
std::shared_ptr<ValueBlock> block,
std::shared_ptr<::ThreadPool> pool,
const int mode, int shard_id) {
int64_t save_num = 0;
for (auto& table : block->values_) {
for (auto& value : table) {
Expand Down Expand Up @@ -186,10 +119,10 @@ int64_t SaveValueToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
return save_num;
}

int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
const int pserver_id, const int pserver_num,
const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks) {
int64_t CommonSparseTable::LoadFromText(
const std::string& valuepath, const std::string& metapath,
const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks) {
Meta meta = Meta(metapath);

int num_lines = 0;
Expand All @@ -198,7 +131,7 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,

while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
auto id = lexical_cast<int64_t>(values[0]);
auto id = lexical_cast<uint64_t>(values[0]);

if (id % pserver_num != pserver_id) {
VLOG(3) << "will not load " << values[0] << " from " << valuepath
Expand Down Expand Up @@ -388,8 +321,9 @@ int32_t CommonSparseTable::save(const std::string& dirname,
int64_t total_ins = 0;
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
// save values
auto shard_save_num = SaveValueToText(vs.get(), shard_values_[shard_id],
_shards_task_pool[shard_id], mode);
auto shard_save_num =
SaveValueToText(vs.get(), shard_values_[shard_id],
_shards_task_pool[shard_id], mode, shard_id);
total_ins += shard_save_num;
}
vs->close();
Expand Down
Loading