Skip to content

Commit

Permalink
Merge pull request #2 from qingshui/paddlebox
Browse files Browse the repository at this point in the history
fix aucrunner bug,  fused_seqpool_cvm add click filter
  • Loading branch information
jiaoxuewu committed May 12, 2021
2 parents 6e1191c + 86abab9 commit 01bf0c2
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 65 deletions.
3 changes: 2 additions & 1 deletion cmake/external/box_ps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ IF((NOT DEFINED BOX_PS_VER) OR (NOT DEFINED BOX_PS_URL))
MESSAGE(STATUS "use pre defined download url")
SET(BOX_PS_VER "0.1.1" CACHE STRING "" FORCE)
SET(BOX_PS_NAME "box_ps" CACHE STRING "" FORCE)
SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps.tar.gz" CACHE STRING "" FORCE)
#SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps.tar.gz" CACHE STRING "" FORCE)
SET(BOX_PS_URL "data-im.baidu.com:/home/work/var/CI_DATA/im/static/box_ps.tar.gz/box_ps.tar.gz.11" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "BOX_PS_NAME: ${BOX_PS_NAME}, BOX_PS_URL: ${BOX_PS_URL}")
SET(BOX_PS_SOURCE_DIR "${THIRD_PARTY_PATH}/box_ps")
Expand Down
39 changes: 20 additions & 19 deletions paddle/fluid/framework/data_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,21 +242,6 @@ class DatasetImpl : public Dataset {
virtual void SetFleetSendSleepSeconds(int seconds);
virtual std::vector<T>& GetInputRecord() { return input_records_; }

virtual std::set<uint16_t> GetSlotsIdx(
const std::set<std::string>& str_slots) {
std::set<uint16_t> slots_idx;

auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
std::string cur_slot = multi_slot_desc.slots(i).name();
if (str_slots.find(cur_slot) != str_slots.end()) {
slots_idx.insert(i);
}
}

return slots_idx;
}

protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg);
Expand Down Expand Up @@ -376,7 +361,7 @@ class PadBoxSlotDataset : public DatasetImpl<SlotRecord> {
virtual int64_t GetShuffleDataSize() { return input_records_.size(); }
// merge ins from multiple sources and unroll
virtual void UnrollInstance();

virtual void ReceiveSuffleData(const int client_id, const char* msg, int len);

// pre load
virtual void LoadIndexIntoMemory() {}
Expand All @@ -387,15 +372,30 @@ class PadBoxSlotDataset : public DatasetImpl<SlotRecord> {
// shuffle data
virtual void ShuffleData(int thread_num = -1);

public:
virtual void ReceiveSuffleData(const int client_id, const char* msg, int len);

public:
void SetPSAgent(boxps::PSAgentBase* agent) { p_agent_ = agent; }
boxps::PSAgentBase* GetPSAgent(void) { return p_agent_; }
double GetReadInsTime(void) { return max_read_ins_span_; }
double GetOtherTime(void) { return other_timer_.ElapsedSec(); }
double GetMergeTime(void) { return max_merge_ins_span_; }
// aucrunner
std::set<uint16_t> GetSlotsIdx(const std::set<std::string>& str_slots) {
std::set<uint16_t> slots_idx;
uint16_t idx = 0;
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
auto slot = multi_slot_desc.slots(i);
if (!slot.is_used() || slot.type().at(0) != 'u') {
continue;
}
if (str_slots.find(slot.name()) != str_slots.end()) {
slots_idx.insert(idx);
}
++idx;
}

return slots_idx;
}

protected:
void MergeInsKeys(const Channel<SlotRecord>& in);
Expand Down Expand Up @@ -437,6 +437,7 @@ class InputTableDataset : public PadBoxSlotDataset {
index_filelist_ = filelist;
}
virtual void LoadIndexIntoMemory();

private:
std::vector<std::string> index_filelist_;
};
Expand Down
61 changes: 36 additions & 25 deletions paddle/fluid/framework/fleet/box_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -1119,27 +1119,20 @@ class BoxWrapper {
random_ins_pool_list[i].Resize(pool_size);
}

std::unordered_set<std::string> slot_set;
slot_eval_set_.clear();
for (size_t i = 0; i < slot_eval.size(); ++i) {
for (const auto& slot : slot_eval[i]) {
slot_set.insert(slot);
slot_eval_set_.insert(slot);
}
}
for (size_t i = 0; i < slot_list.size(); ++i) {
if (slot_set.find(slot_list[i]) != slot_set.end()) {
slot_index_to_replace_.insert(static_cast<uint16_t>(i));
}
}
for (int i = 0; i < auc_runner_thread_num_; ++i) {
random_ins_pool_list[i].SetReplacedSlots(slot_index_to_replace_);
}

VLOG(0) << "AucRunner configuration: thread number[" << thread_num
<< "], pool size[" << pool_size << "], runner_group[" << phase_num_
<< "]";
VLOG(0) << "Slots that need to be evaluated:";
for (auto e : slot_index_to_replace_) {
VLOG(0) << e << ": " << slot_list[e];
}
<< "], eval size:[" << slot_eval_set_.size() << "]";
// VLOG(0) << "Slots that need to be evaluated:";
// for (auto e : slot_index_to_replace_) {
// VLOG(0) << e << ": " << slot_list[e];
// }
}
void GetRandomReplace(std::vector<SlotRecord>* records);
void PostUpdate();
Expand Down Expand Up @@ -1184,15 +1177,23 @@ class BoxWrapper {
void RecordReplaceBack(std::vector<SlotRecord>* records,
const std::set<uint16_t>& slots);

// aucrunner
void SetReplacedSlots(const std::set<uint16_t>& slot_index_to_replace) {
for (int i = 0; i < auc_runner_thread_num_; ++i) {
random_ins_pool_list[i].SetReplacedSlots(slot_index_to_replace);
}
}
const std::set<std::string>& GetEvalSlotSet() { return slot_eval_set_; }

private:
int mode_ = 0; // 0 means train/test 1 means auc_runner
int auc_runner_thread_num_ = 1;
bool init_done_ = false;
paddle::framework::Channel<int> pass_done_semi_;

std::set<uint16_t> slot_index_to_replace_;
std::vector<FeasignValuesCandidateList> random_ins_pool_list;
std::mutex mutex4random_pool_;
std::set<std::string> slot_eval_set_;
};
/**
* @brief file mgr
Expand Down Expand Up @@ -1250,7 +1251,23 @@ class BoxHelper {
}
#endif
}

#ifdef PADDLE_WITH_BOX_PS
void LoadAucRunnerData(PadBoxSlotDataset* dataset,
boxps::PSAgentBase* agent) {
auto box_ptr = BoxWrapper::GetInstance();
// init random pool slots replace
static bool slot_init = false;
if (!slot_init) {
slot_init = true;
auto slots_set = dataset->GetSlotsIdx(box_ptr->GetEvalSlotSet());
box_ptr->SetReplacedSlots(slots_set);
}
box_ptr->AddReplaceFeasign(agent, box_ptr->GetFeedpassThreadNum());
auto& records = dataset->GetInputRecord();
box_ptr->PushAucRunnerResource(records.size());
box_ptr->GetRandomReplace(&records);
}
#endif
void ReadData2Memory() {
platform::Timer timer;
VLOG(3) << "Begin ReadData2Memory(), dataset[" << dataset_ << "]";
Expand Down Expand Up @@ -1287,10 +1304,7 @@ class BoxHelper {
timer.Start();
// auc runner
if (box_ptr->Mode() == 1) {
box_ptr->AddReplaceFeasign(agent, box_ptr->GetFeedpassThreadNum());
auto& records = dataset->GetInputRecord();
box_ptr->PushAucRunnerResource(records.size());
box_ptr->GetRandomReplace(&records);
LoadAucRunnerData(dataset, agent);
}
box_ptr->EndFeedPass(agent);
#endif
Expand Down Expand Up @@ -1350,10 +1364,7 @@ class BoxHelper {
auto box_ptr = BoxWrapper::GetInstance();
// auc runner
if (box_ptr->Mode() == 1) {
box_ptr->AddReplaceFeasign(agent, box_ptr->GetFeedpassThreadNum());
auto& records = dataset->GetInputRecord();
box_ptr->PushAucRunnerResource(records.size());
box_ptr->GetRandomReplace(&records);
LoadAucRunnerData(dataset, agent);
}
box_ptr->EndFeedPass(agent);
timer.Pause();
Expand Down
23 changes: 18 additions & 5 deletions paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
const size_t num_inputs = ins_dims.size();
std::vector<framework::DDim> outs_dims;
outs_dims.resize(num_inputs);
bool use_cvm = ctx->Attrs().Get<bool>("use_cvm");
bool clk_filter = ctx->Attrs().Get<bool>("clk_filter");

// need filter quant_ratio more than zero
if (ctx->Attrs().Get<bool>("need_filter")) {
Expand Down Expand Up @@ -66,7 +68,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
for (size_t i = 0; i < num_inputs; ++i) {
const auto dims = ins_dims[i];
int rank = dims.size();
if (ctx->Attrs().Get<bool>("use_cvm")) {
if (use_cvm) {
PADDLE_ENFORCE_GT(
dims[rank - 1], 2,
"Shape error in %lu id, the last dimension(embedding) of the "
Expand All @@ -75,8 +77,12 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
}
// input lod is not accessible here
std::vector<int64_t> out_dim;
if (ctx->Attrs().Get<bool>("use_cvm")) {
out_dim = {-1, dims[rank - 1]};
if (use_cvm) {
if (clk_filter) {
out_dim = {-1, dims[rank - 1] - 1};
} else {
out_dim = {-1, dims[rank - 1]};
}
} else {
out_dim = {-1, dims[rank - 1] - cvm_offset};
}
Expand Down Expand Up @@ -122,6 +128,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("threshold", "(float, default 0.96)").SetDefault(0.96);
AddAttr<int>("cvm_offset", "(int, default 2)").SetDefault(2);
AddAttr<int>("quant_ratio", "(int, default 128)").SetDefault(0);
AddAttr<bool>("clk_filter", "(bool, default false)").SetDefault(false);

AddComment(R"DOC(
Fuse multiple pairs of Sequence Pool and CVM Operator.
Expand All @@ -139,6 +146,8 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputsDim("X");
auto cvm_dims = ctx->GetInputDim("CVM");
const int cvm_offset = ctx->Attrs().Get<int>("cvm_offset");
bool use_cvm = ctx->Attrs().Get<bool>("use_cvm");
bool clk_filter = ctx->Attrs().Get<bool>("clk_filter");

PADDLE_ENFORCE_EQ(
cvm_dims.size(), 2,
Expand All @@ -151,9 +160,13 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
"The rank of output grad must equal to Input(X). But "
"received: input rank %u, input shape [%s].",
og_dims[i].size(), og_dims[i]));
if (ctx->Attrs().Get<bool>("use_cvm")) {
if (use_cvm) {
auto o_dim = og_dims[i][og_dims[i].size() - 1];
if (clk_filter) { // filter clk need + 1
o_dim = o_dim + 1;
}
PADDLE_ENFORCE_EQ(
og_dims[i][og_dims[i].size() - 1], x_dims[i][og_dims[i].size() - 1],
o_dim, x_dims[i][og_dims[i].size() - 1],
platform::errors::InvalidArgument(
"The dimension mismatch between Input(OUT@GRAD) and "
"Input(X). Received Input(OUT@GRAD): input rank %u, "
Expand Down
Loading

0 comments on commit 01bf0c2

Please sign in to comment.