Skip to content

Commit

Permalink
Merge pull request #26 from songyuwen0808/paddlebox
Browse files Browse the repository at this point in the history
添加auc monitor类型MultiMaskAucCalculator
  • Loading branch information
qingshui committed Jan 18, 2022
2 parents 495be05 + 5222f6e commit 56c50c2
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions paddle/fluid/framework/fleet/box_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,75 @@ class MaskMetricMsg : public MetricMsg {
std::string mask_varname_;
};

class MultiMaskMetricMsg : public MetricMsg {
public:
MultiMaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int metric_phase,
const std::string& mask_varname_list, const std::string& mask_varvalue_list,
int bucket_size = 1000000,
bool mode_collect_in_gpu = false, int max_batch_size = 0) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
mask_varname_list_ = string::split_string(mask_varname_list, " ");
const std::vector<std::string> tmp_val_lst = string::split_string(mask_varvalue_list, " ");
for (const auto& it : tmp_val_lst) {
mask_varvalue_list_.emplace_back(atoi(it.c_str()));
}
PADDLE_ENFORCE_EQ(mask_varname_list_.size(), mask_varvalue_list_.size(),
platform::errors::PreconditionNotMet("mast var num[%zu] should be equal to mask val num[%zu]",
mask_varname_list_.size(), mask_varvalue_list_.size()));

metric_phase_ = metric_phase;
calculator = new BasicAucCalculator(mode_collect_in_gpu);
calculator->init(bucket_size);
}
virtual ~MultiMaskMetricMsg() {}
void add_data(const Scope* exe_scope,
const paddle::platform::Place& place) override {
std::vector<int64_t> label_data;
get_data<int64_t>(exe_scope, label_varname_, &label_data);

std::vector<float> pred_data;
get_data<float>(exe_scope, pred_varname_, &pred_data);

PADDLE_ENFORCE_EQ(label_data.size(), pred_data.size(),
platform::errors::PreconditionNotMet(
"the predict data length should be consistent with "
"the label data length"));

std::vector<std::vector<int64_t>> mask_value_data_list(mask_varname_list_.size());
for (size_t name_idx = 0; name_idx < mask_varname_list_.size(); ++name_idx) {
get_data<int64_t>(exe_scope, mask_varname_list_[name_idx], &mask_value_data_list[name_idx]);
PADDLE_ENFORCE_EQ(label_data.size(), mask_value_data_list[name_idx].size(),
platform::errors::PreconditionNotMet(
"the label data length[%d] should be consistent with "
"the %s[%zu] length", label_data.size(), mask_value_data_list[name_idx].size()));
}
auto cal = GetCalculator();
std::lock_guard<std::mutex> lock(cal->table_mutex());
size_t batch_size = label_data.size();
bool flag = true;
for (size_t ins_idx = 0; ins_idx < batch_size; ++ins_idx) {
flag = true;
for (size_t val_idx = 0; val_idx < mask_varvalue_list_.size(); ++val_idx) {
if (mask_value_data_list[val_idx][ins_idx] != mask_varvalue_list_[val_idx]) {
flag = false;
break;
}
}
if (flag) {
cal->add_unlock_data(pred_data[ins_idx], label_data[ins_idx]);
}
}

}

protected:
std::vector<int> mask_varvalue_list_;
std::vector<std::string> mask_varname_list_;
std::string cmatch_rank_varname_;
};

class CmatchRankMaskMetricMsg : public MetricMsg {
public:
CmatchRankMaskMetricMsg(const std::string& label_varname,
Expand Down Expand Up @@ -1137,6 +1206,11 @@ void BoxWrapper::InitMetric(const std::string& method, const std::string& name,
name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
mask_varname, bucket_size, mode_collect_in_gpu,
max_batch_size));
} else if (method == "MultiMaskAucCalculator") {
metric_lists_.emplace(
name, new MultiMaskMetricMsg(label_varname, pred_varname, metric_phase,
mask_varname, cmatch_rank_group, bucket_size, mode_collect_in_gpu,
max_batch_size));
} else if (method == "CmatchRankMaskAucCalculator") {
metric_lists_.emplace(
name, new CmatchRankMaskMetricMsg(
Expand Down

0 comments on commit 56c50c2

Please sign in to comment.