Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug of scalar_integer_logger not deal with NOT_ADD_ID #99

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions xdl/ps-plus/ps-plus/scheduler/asynchronizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ namespace {

function<void (const Status&)> MkCb(int id) {
return [id](const Status& st) {
//LOG_WARN("Null callback for worker %d invoked", id);
};
}

Expand Down Expand Up @@ -71,7 +70,6 @@ void Asynchronizer::Enter(int id, function<void (const Status&)> cb) {
return;
}
if (removed_workers_.find(id) != removed_workers_.end()) {
//LOG_FATAL("Worker %d revived", id);
abort();
}
Context* ctx = contexts_.get() + id;
Expand All @@ -92,12 +90,11 @@ void Asynchronizer::Enter(int id, function<void (const Status&)> cb) {
cb(Status::Ok());
}

void Asynchronizer::WorkerReportFinish(int id, std::function<void (const Status&)> cb) {
Status Asynchronizer::WorkerReportFinish(int id) {
if (id < 0 || id >= worker_count_) {
cb(Status::ArgumentError("Offset out of bound: min=0, max="
+ to_string(worker_count_) + ", actual="
+ to_string(id)));
return;
return Status::ArgumentError("Offset out of bound: min=0, max="
+ to_string(worker_count_) + ", actual="
+ to_string(id));
}
removed_workers_.insert(id);
Context* ctx = contexts_.get() + id;
Expand All @@ -109,7 +106,7 @@ void Asynchronizer::WorkerReportFinish(int id, std::function<void (const Status&
UnlockNewSteps(least_step);
}
}
cb(Status::Ok());
return Status::Ok();
}

void Asynchronizer::Reset() {
Expand Down
23 changes: 18 additions & 5 deletions xdl/ps-plus/ps-plus/scheduler/scheduler_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -841,12 +841,25 @@ void SchedulerImpl::InternalWorkerReportFinish(Version version, int id, function
cb(VersionMismatch(version_, version));
return;
}
Asynchronizer* sync = dynamic_cast<Asynchronizer*>(sync_.get());
if (sync == nullptr) {
LOG(ERROR) << "Call async method in sync mode.";
cb(Status::ArgumentError("Call async method in sync mode."));
if (sync_.get() != nullptr) {
Status st = sync_->WorkerReportFinish(id);
if (!st.IsOk()) {
cb(st);
return;
}
}
finished_workers_.insert(id);
auto iter = worker_barriers_.find(id);
if (iter != worker_barriers_.end()) {
worker_barriers_.erase(iter);
}
if (worker_barriers_.size() == worker_count_ - finished_workers_.size()) {
for (auto iter : worker_barriers_) {
(iter.second)(Status::Ok());
}
worker_barriers_.clear();
}
sync->WorkerReportFinish(id, cb);
cb(Status::Ok());
}

void SchedulerImpl::InternalWorkerBarrier(Version version, int id, int worker_count, function<void (const Status&)> cb) {
Expand Down
22 changes: 15 additions & 7 deletions xdl/ps-plus/ps-plus/scheduler/synchronizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ limitations under the License.
#include "synchronizer.h"

#include "ps-plus/common/status.h"

#include <glog/logging.h>
#include <string>
#include <iostream>
#include <glog/logging.h>

using namespace std;
using namespace std::chrono;
Expand All @@ -33,7 +32,7 @@ namespace {

function<void (int, const Status&)> MkCb(int id) {
return [id](int, const Status& st) {
LOG(WARNING) << "Null callback for worker" << id << " invoked";
LOG(WARNING) << "Null callback for worker " << id << " invoked";
};
}

Expand Down Expand Up @@ -77,15 +76,24 @@ void Synchronizer::Enter(int id, function<void (int64_t, const Status&)> cb) {
waiting_list_.insert(ctx);
}

Status Synchronizer::WorkerReportFinish(int id) {
if (working_list_.find(id) == working_list_.end()) {
return Status::Ok();
}
working_list_.erase(id);
if (left_token_ == 0 && working_list_.empty()) {
UnlockNewToken();
}
return Status::Ok();
}

void Synchronizer::Leave(int id, int64_t token, function<void (const Status&)> cb) {
if (token != current_token_) {
LOG(WARNING) << "Receive token " << token << " from " << id <<
"while current_token_ is " << current_token_;
LOG(WARNING) << "Receive token " << token << " from " << id << " while current_token_ is " << current_token_;
cb(Status::Ok());
}
if (working_list_.find(id) == working_list_.end()) {
LOG(FATAL) << "Worker " << id << " not granted token, but it call leave with token " << token <<
", current token is " << current_token_;
LOG(FATAL) << "Worker " << id << " not granted token, but it call leave with token " << token << ", current token is " << current_token_;
abort();
}
working_list_.erase(id);
Expand Down
6 changes: 4 additions & 2 deletions xdl/ps-plus/ps-plus/scheduler/synchronizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class SyncMechanism {
SyncMechanism() {}
virtual ~SyncMechanism() {}
virtual void Reset() = 0;
virtual Status WorkerReportFinish(int id) = 0;
};

class Asynchronizer : public SyncMechanism {
Expand All @@ -53,7 +54,7 @@ class Asynchronizer : public SyncMechanism {
Asynchronizer(int staleness, int worker_count);
~Asynchronizer();
void Enter(int id, std::function<void (const Status&)> cb);
void WorkerReportFinish(int id, std::function<void (const Status&)> cb);
Status WorkerReportFinish(int id);
void Reset();
};

Expand All @@ -74,7 +75,8 @@ class Synchronizer : public SyncMechanism {
Synchronizer(int worker_count);
~Synchronizer() {}
void Enter(int id, std::function<void (int64_t, const Status&)> cb);
void Leave(int id, int64_t token, std::function<void (const Status&)> cb);
void Leave(int id, int64_t token, std::function<void (const Status&)> cb);
Status WorkerReportFinish(int id);
void Reset();
};

Expand Down
14 changes: 6 additions & 8 deletions xdl/ps-plus/ps-plus/scheduler/test/asynchronizer_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,12 @@ TEST(Asynchronizer, EnterAndFinish) {
execute_log += "0-2";
});
EXPECT_EQ(execute_log, "0-01-00-12-0");
async->WorkerReportFinish(1, [&execute_log](const Status& st) {
execute_log += "1-finish";
});
EXPECT_EQ(execute_log, "0-01-00-12-01-finish");
async->WorkerReportFinish(2, [&execute_log](const Status& st) {
execute_log += "2-finish";
});
EXPECT_EQ(execute_log, "0-01-00-12-01-finish0-22-finish");
Status st = async->WorkerReportFinish(1);
EXPECT_TRUE(st.IsOk());
EXPECT_EQ(execute_log, "0-01-00-12-0");
st = async->WorkerReportFinish(2);
EXPECT_TRUE(st.IsOk());
EXPECT_EQ(execute_log, "0-01-00-12-00-2");
}

TEST(Asynchronizer, Reset) {
Expand Down
4 changes: 4 additions & 0 deletions xdl/ps-plus/ps-plus/server/udf/scalar_integer_logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "ps-plus/server/udf/simple_udf.h"
#include "ps-plus/server/slice.h"
#include "ps-plus/common/initializer/constant_initializer.h"
#include "ps-plus/common/hashmap.h"

namespace ps {
namespace server {
Expand All @@ -32,6 +33,9 @@ class ScalarIntegerLogger : public SimpleUdf<Slices, std::string, int64_t> {
int64_t* data = t->Raw<int64_t>();
int64_t val = pval;
for (size_t slice : slices.slice_id) {
if ((int64_t)slice == ps::HashMap::NOT_ADD_ID) {
continue;
}
data[slice] = val;
}
return Status::Ok();
Expand Down
2 changes: 1 addition & 1 deletion xdl/xdl/core/ops/ps_ops/ps_convert_ckpt_variable_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class PsConvertCkptVariableOp : public xdl::OpKernelAsync {
PS_CHECK_STATUS(ps::FileSystem::OpenWriteStreamAny(file_name, &output_stream));
for (size_t i = 0; i < info.parts.size(); i++) {
ps::server::CheckpointUtils::VariableStruct vs;
printf("Start convert [%s], part[%d]\n", info.name.c_str(), i);
printf("Start convert [%s], part[%ld]\n", info.name.c_str(), i);
PS_CHECK_STATUS(utils.LoadVariable(info.name, i, &vs));
if (!vs.initialized) {
return ps::Status::DataLoss("Load variable " + info.name + " failed.");
Expand Down
7 changes: 3 additions & 4 deletions xdl/xdl/core/ops/ps_ops/ps_synchronizer_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class PsSynchronizeLeaveOp: public xdl::OpKernelAsync {
}
};

class PsSemiSynchronizeLeaveOp: public xdl::OpKernelAsync {
class WorkerReportFinishOp: public xdl::OpKernelAsync {
public:
Status Init(OpKernelConstruction* ctx) override {
return Status::Ok();
Expand All @@ -113,7 +113,6 @@ class WorkerBarrierOp: public xdl::OpKernelAsync {
Status Init(OpKernelConstruction* ctx) override {
return Status::Ok();
}

void Compute(OpKernelContext* ctx, Callback done) override {
ps::client::BaseClient* client;
XDL_CHECK_STATUS_ASYNC(GetClient(&client), done);
Expand All @@ -139,7 +138,7 @@ XDL_DEFINE_OP(PsSynchronizeEnterOp)
XDL_DEFINE_OP(PsSynchronizeLeaveOp)
.Input("id", DataType::kInt32);

XDL_DEFINE_OP(PsSemiSynchronizeLeaveOp)
XDL_DEFINE_OP(WorkerReportFinishOp)
.Input("id", DataType::kInt32);

XDL_DEFINE_OP(WorkerBarrierOp)
Expand All @@ -149,7 +148,7 @@ XDL_DEFINE_OP(WorkerBarrierOp)
XDL_REGISTER_KERNEL(PsAsynchronizeEnterOp, PsAsynchronizeEnterOp).Device("CPU");
XDL_REGISTER_KERNEL(PsSynchronizeEnterOp, PsSynchronizeEnterOp).Device("CPU");
XDL_REGISTER_KERNEL(PsSynchronizeLeaveOp, PsSynchronizeLeaveOp).Device("CPU");
XDL_REGISTER_KERNEL(PsSemiSynchronizeLeaveOp, PsSemiSynchronizeLeaveOp).Device("CPU");
XDL_REGISTER_KERNEL(WorkerReportFinishOp, WorkerReportFinishOp).Device("CPU");
XDL_REGISTER_KERNEL(WorkerBarrierOp, WorkerBarrierOp).Device("CPU");

}
Expand Down