Skip to content

Commit

Permalink
Merge pull request #4 from colourful-tree/develop
Browse files Browse the repository at this point in the history
merge colourful-tree's change
  • Loading branch information
guru4elephant committed Dec 3, 2018
2 parents ee4c51a + 0a3d8a2 commit 8eb7f3a
Show file tree
Hide file tree
Showing 11 changed files with 978 additions and 13 deletions.
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ include(external/warpctc) # download, build, install warpctc
include(cupti)
include(external/gzstream)
endif (NOT WIN32)
include(external/libmct)
#include(external/pslib_brpc)
include(external/pslib)

if(WITH_DISTRIBUTE)
if(WITH_GRPC)
Expand Down Expand Up @@ -276,6 +279,9 @@ set(EXTERNAL_LIBS
protobuf
zlib
${PYTHON_LIBRARIES}
pslib
#pslib_brpc
libmct
)

if(WITH_AMD_GPU)
Expand Down
77 changes: 77 additions & 0 deletions cmake/external/libmct.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

IF(NOT ${WITH_LIBMCT})
return()
ENDIF(NOT ${WITH_LIBMCT})

IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with LIBMCT in Paddle yet."
"Force WITH_LIBMCT=OFF")
SET(WITH_LIBMCT OFF CACHE STRING "Disable LIBMCT package in Windows and MacOS" FORCE)
return()
ENDIF()

INCLUDE(ExternalProject)

SET(LIBMCT_PROJECT "extern_libmct")
IF((NOT DEFINED LIBMCT_VER) OR (NOT DEFINED LIBMCT_URL))
MESSAGE(STATUS "use pre defined download url")
SET(LIBMCT_VER "libmct" CACHE STRING "" FORCE) #todo libmct version
SET(LIBMCT_URL "http://bjyz-heqiaozhi-dev-new.epc.baidu.com:8000/${LIBMCT_VER}.tar.gz" CACHE STRING "" FORCE) #todo libmct url
ENDIF()
MESSAGE(STATUS "LIBMCT_VER: ${LIBMCT_VER}, LIBMCT_URL: ${LIBMCT_URL}")
SET(LIBMCT_SOURCE_DIR "${THIRD_PARTY_PATH}/libmct")
SET(LIBMCT_DOWNLOAD_DIR "${LIBMCT_SOURCE_DIR}/src/${LIBMCT_PROJECT}")
SET(LIBMCT_DST_DIR "libmct")
SET(LIBMCT_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(LIBMCT_INSTALL_DIR ${LIBMCT_INSTALL_ROOT}/${LIBMCT_DST_DIR})
SET(LIBMCT_ROOT ${LIBMCT_INSTALL_DIR})
SET(LIBMCT_INC_DIR ${LIBMCT_ROOT}/include)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${LIBMCT_ROOT}/lib")

INCLUDE_DIRECTORIES(${LIBMCT_INC_DIR})

FILE(WRITE ${LIBMCT_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(LIBMCT)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${LIBMCT_VER}/include ${LIBMCT_VER}/lib \n"
" DESTINATION ${LIBMCT_DST_DIR})\n")

ExternalProject_Add(
${LIBMCT_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LIBMCT_SOURCE_DIR}
DOWNLOAD_DIR ${LIBMCT_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${LIBMCT_URL} -c -q -O ${LIBMCT_VER}.tar.gz
&& tar zxvf ${LIBMCT_VER}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${LIBMCT_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${LIBMCT_INSTALL_ROOT}
)

if (${CMAKE_VERSION} VERSION_LESS "3.3.0" OR NOT WIN32)
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/boost_dummy.c)
file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";")
add_library(libmct STATIC ${dummyfile})
else()
add_library(libmct INTERFACE)
endif()

#ADD_LIBRARY(libmct SHARED IMPORTED GLOBAL)
ADD_DEPENDENCIES(libmct ${LIBMCT_PROJECT})
LIST(APPEND external_project_dependencies libmct)

76 changes: 76 additions & 0 deletions cmake/external/pslib.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

IF(NOT ${WITH_PSLIB})
return()
ENDIF(NOT ${WITH_PSLIB})

IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with PSLIB in Paddle yet."
"Force WITH_PSLIB=OFF")
SET(WITH_PSLIB OFF CACHE STRING "Disable PSLIB package in Windows and MacOS" FORCE)
return()
ENDIF()

INCLUDE(ExternalProject)

SET(PSLIB_PROJECT "extern_pslib")
IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL))
MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_VER "pslib" CACHE STRING "" FORCE) #todo pslib version
SET(PSLIB_URL "http://bjyz-heqiaozhi-dev-new.epc.baidu.com:8000/${PSLIB_VER}.tar.gz" CACHE STRING "" FORCE) #todo pslib url
ENDIF()
MESSAGE(STATUS "PSLIB_VER: ${PSLIB_VER}, PSLIB_URL: ${PSLIB_URL}")
SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib")
SET(PSLIB_DOWNLOAD_DIR "${PSLIB_SOURCE_DIR}/src/${PSLIB_PROJECT}")
SET(PSLIB_DST_DIR "pslib")
SET(PSLIB_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(PSLIB_INSTALL_DIR ${PSLIB_INSTALL_ROOT}/${PSLIB_DST_DIR})
SET(PSLIB_ROOT ${PSLIB_INSTALL_DIR})
SET(PSLIB_INC_DIR ${PSLIB_ROOT}/include)
SET(PSLIB_LIB_DIR ${PSLIB_ROOT}/lib)
SET(PSLIB_LIB ${PSLIB_LIB_DIR}/libps.so)
SET(PSLIB_IOMP_LIB ${PSLIB_LIB_DIR}/libiomp5.so) #todo what is this
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${PSLIB_ROOT}/lib")

INCLUDE_DIRECTORIES(${PSLIB_INC_DIR})

FILE(WRITE ${PSLIB_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(PSLIB)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${PSLIB_VER}/include ${PSLIB_VER}/lib \n"
" DESTINATION ${PSLIB_DST_DIR})\n")

ExternalProject_Add(
${PSLIB_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_VER}.tar.gz
&& tar zxvf ${PSLIB_VER}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PSLIB_INSTALL_ROOT}
)

ADD_LIBRARY(pslib STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET pslib PROPERTY IMPORTED_LOCATION ${PSLIB_LIB})
ADD_DEPENDENCIES(pslib ${PSLIB_PROJECT})
LIST(APPEND external_project_dependencies pslib)

IF(WITH_C_API)
INSTALL(FILES ${PSLIB_LIB} ${PSLIB_IOMP_LIB} DESTINATION lib)
ENDIF()
76 changes: 76 additions & 0 deletions cmake/external/pslib_brpc.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

IF(NOT ${WITH_PSLIB_BRPC})
return()
ENDIF(NOT ${WITH_PSLIB_BRPC})

IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with PSLIB_BRPC in Paddle yet."
"Force WITH_PSLIB_BRPC=OFF")
SET(WITH_PSLIB_BRPC OFF CACHE STRING "Disable PSLIB_BRPC package in Windows and MacOS" FORCE)
return()
ENDIF()

INCLUDE(ExternalProject)

SET(PSLIB_BRPC_PROJECT "extern_pslib_brpc")
IF((NOT DEFINED PSLIB_BRPC_VER) OR (NOT DEFINED PSLIB_BRPC_URL))
MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_BRPC_VER "pslib_brpc" CACHE STRING "" FORCE) #todo pslib version
SET(PSLIB_BRPC_URL "http://bjyz-heqiaozhi-dev-new.epc.baidu.com:8000/${PSLIB_BRPC_VER}.tar.gz" CACHE STRING "" FORCE) #todo pslib_brpc url
ENDIF()
MESSAGE(STATUS "PSLIB_BRPC_VER: ${PSLIB_BRPC_VER}, PSLIB_BRPC_URL: ${PSLIB_BRPC_URL}")
SET(PSLIB_BRPC_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib_brpc")
SET(PSLIB_BRPC_DOWNLOAD_DIR "${PSLIB_BRPC_SOURCE_DIR}/src/${PSLIB_BRPC_PROJECT}")
SET(PSLIB_BRPC_DST_DIR "pslib_brpc")
SET(PSLIB_BRPC_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(PSLIB_BRPC_INSTALL_DIR ${PSLIB_BRPC_INSTALL_ROOT}/${PSLIB_BRPC_DST_DIR})
SET(PSLIB_BRPC_ROOT ${PSLIB_BRPC_INSTALL_DIR})
SET(PSLIB_BRPC_INC_DIR ${PSLIB_BRPC_ROOT}/include)
SET(PSLIB_BRPC_LIB_DIR ${PSLIB_BRPC_ROOT}/lib)
SET(PSLIB_BRPC_LIB ${PSLIB_BRPC_LIB_DIR}/libps.so)
SET(PSLIB_BRPC_IOMP_LIB ${PSLIB_BRPC_LIB_DIR}/libiomp5.so) #todo what is this
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${PSLIB_BRPC_ROOT}/lib")

INCLUDE_DIRECTORIES(${PSLIB_BRPC_INC_DIR})

FILE(WRITE ${PSLIB_BRPC_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(PSLIB_BRPC)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${PSLIB_BRPC_VER}/include ${PSLIB_BRPC_VER}/lib \n"
" DESTINATION ${PSLIB_BRPC_DST_DIR})\n")

ExternalProject_Add(
${PSLIB_BRPC_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_BRPC_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_BRPC_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_BRPC_URL} -c -q -O ${PSLIB_BRPC_VER}.tar.gz
&& tar zxvf ${PSLIB_BRPC_VER}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_BRPC_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PSLIB_BRPC_INSTALL_ROOT}
)

ADD_LIBRARY(pslib_brpc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET pslib_brpc PROPERTY IMPORTED_LOCATION ${PSLIB_BRPC_LIB})
ADD_DEPENDENCIES(pslib_brpc ${PSLIB_BRPC_PROJECT})
LIST(APPEND external_project_dependencies pslib_brpc)

IF(WITH_C_API)
INSTALL(FILES ${PSLIB_BRPC_LIB} ${PSLIB_BRPC_IOMP_LIB} DESTINATION lib)
ENDIF()
84 changes: 80 additions & 4 deletions paddle/fluid/framework/async_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
#include "pslib.h"

namespace paddle {
namespace framework {
Expand All @@ -47,6 +48,10 @@ void AsyncExecutor::CreateThreads(
worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory();
worker->SetPSlibPtr(_pslib_ptr);
worker->SetPullDenseThread(_pull_dense_thread);
worker->BindingSlotVariableMemory();
worker->SetParamConfig(&_param_config);
}

void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
Expand All @@ -60,6 +65,77 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
readers[0]->SetFileList(filelist);
}

void AsyncExecutor::ConfigPslib(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib());
_pslib_ptr->init_and_config(dist_desc, host_sign_list, node_num, index);//TODO
}

void AsyncExecutor::StartServer() {
_pslib_ptr->run_server();
}

void AsyncExecutor::InitModel() {
//TODO only rank = 0 do this
std::vector<int> all_dense_table_id; //TODO
all_dense_table_id.push_back(0);
for (auto table_id: all_dense_table_id) {
std::vector<paddle::ps::Region> regions;
std::vector<std::string> variables; //TODO
for (auto& t : variables) {
Variable* var = root_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();

float* g = tensor->data<float>();
CHECK(g != nullptr) << "var[" << t << "] value not initialized";

float init_range = 0.2;
int rown = tensor->dims()[0];
init_range /= sqrt(rown);

std::normal_distribution<float> ndistr(0.0, 1.0);
for (auto i = 0u; i < tensor->numel(); ++i) {
g[i] = ndistr(local_random_engine()) * init_range;
}

paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}

auto push_status = _pslib_ptr->_worker_ptr->push_dense_param(regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push dense param failed, status[" << status << "]";
exit(-1);
}
}
}

void AsyncExecutor::SaveModel(const std::string& path) {
auto ret = _pslib_ptr->_worker_ptr->flush();
ret.wait();
ret = _pslib_ptr->_worker_ptr->save(path, 0);
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { // TODO should be feasign_cnt < 0, because server bug
LOG(FATAL) << "save model failed";
exit(-1);
}
}

void AsyncExecutor::PrepareDenseThread() {
DensePullThreadParam param;
param.ps_client = _pslib_ptr->_worker_ptr;;
param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO
param.training_thread_num = actual_thread_num;
param.root_scope = root_scope_;
//param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO

_pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));

}

void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist,
Expand All @@ -82,7 +158,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc);

int actual_thread_num = thread_num;
actual_thread_num = thread_num;
int file_cnt = filelist.size();
PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");

Expand All @@ -106,11 +182,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
// todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed>> readers;
PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);

PrepareDenseThread();
std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
workers.resize(actual_thread_num);
for (auto& worker : workers) {
worker.reset(new ExecutorThreadWorker);
worker.reset(new AsyncExecutorThreadWorker);
}

// prepare thread resource here
Expand All @@ -128,7 +204,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& th : threads) {
th.join();
}

_pull_dense_thread->stop();
root_scope_->DropKids();

return;
Expand Down
Loading

0 comments on commit 8eb7f3a

Please sign in to comment.