Skip to content

Commit

Permalink
feat: CmatchRankMaskMetricMsg (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
laipaang authored Feb 22, 2021
1 parent de7313b commit 1a6b2f2
Showing 1 changed file with 101 additions and 2 deletions.
103 changes: 101 additions & 2 deletions paddle/fluid/framework/fleet/box_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,99 @@ class BoxWrapper {
protected:
std::string mask_varname_;
};

class CmatchRankMaskMetricMsg : public MetricMsg {
public:
CmatchRankMaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int metric_phase,
const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname,
bool ignore_rank = false,
const std::string& mask_varname = "",
int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
cmatch_rank_varname_ = cmatch_rank_varname;
metric_phase_ = metric_phase;
ignore_rank_ = ignore_rank;
mask_varname_ = mask_varname;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
if (ignore_rank) { // CmatchAUC
cmatch_rank_v.emplace_back(atoi(cmatch_rank.c_str()), 0);
continue;
}
const std::vector<std::string>& cur_cmatch_rank =
string::split_string(cmatch_rank, "_");
PADDLE_ENFORCE_EQ(
cur_cmatch_rank.size(), 2,
platform::errors::PreconditionNotMet(
"illegal cmatch_rank auc spec: %s", cmatch_rank.c_str()));
cmatch_rank_v.emplace_back(atoi(cur_cmatch_rank[0].c_str()),
atoi(cur_cmatch_rank[1].c_str()));
}
}
virtual ~CmatchRankMaskMetricMsg() {}
void add_data(const Scope* exe_scope,
const paddle::platform::Place& place) override {
std::vector<int64_t> cmatch_rank_data;
get_data<int64_t>(exe_scope, cmatch_rank_varname_, &cmatch_rank_data);
std::vector<int64_t> label_data;
get_data<int64_t>(exe_scope, label_varname_, &label_data);
std::vector<float> pred_data;
get_data<float>(exe_scope, pred_varname_, &pred_data);
size_t batch_size = cmatch_rank_data.size();
PADDLE_ENFORCE_EQ(
batch_size, label_data.size(),
platform::errors::PreconditionNotMet(
"illegal batch size: cmatch_rank[%lu] and label_data[%lu]",
batch_size, label_data.size()));
PADDLE_ENFORCE_EQ(
batch_size, pred_data.size(),
platform::errors::PreconditionNotMet(
"illegal batch size: cmatch_rank[%lu] and pred_data[%lu]",
batch_size, pred_data.size()));

std::vector<int64_t> mask_data;
if (!mask_varname_.empty()) {
get_data<int64_t>(exe_scope, mask_varname_, &mask_data);
PADDLE_ENFORCE_EQ(
batch_size, mask_data.size(),
platform::errors::PreconditionNotMet(
"illegal batch size: cmatch_rank[%lu] and mask_data[%lu]",
batch_size, mask_data.size()));
}

auto cal = GetCalculator();
std::lock_guard<std::mutex> lock(cal->table_mutex());
for (size_t i = 0; i < batch_size; ++i) {
const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]);
for (size_t j = 0; j < cmatch_rank_v.size(); ++j) {
if (!mask_data.empty() && !mask_data[i]) {
continue;
}
bool is_matched = false;
if (ignore_rank_) {
is_matched = cmatch_rank_v[j].first == cur_cmatch_rank.first;
} else {
is_matched = cmatch_rank_v[j] == cur_cmatch_rank;
}
if (is_matched) {
cal->add_unlock_data(pred_data[i], label_data[i]);
break;
}
}
}
}

protected:
std::vector<std::pair<int, int>> cmatch_rank_v;
std::string cmatch_rank_varname_;
bool ignore_rank_;
std::string mask_varname_;
};

const std::vector<std::string> GetMetricNameList(
int metric_phase = -1) const {
VLOG(0) << "Want to Get metric phase: " << metric_phase;
Expand Down Expand Up @@ -909,10 +1002,16 @@ class BoxWrapper {
name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
mask_varname, bucket_size,
mode_collect_in_gpu, max_batch_size));
} else if (method == "CmatchRankMaskAucCalculator") {
metric_lists_.emplace(name, new CmatchRankMaskMetricMsg(
label_varname, pred_varname, metric_phase,
cmatch_rank_group, cmatch_rank_varname,
ignore_rank, mask_varname, bucket_size));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"PaddleBox only support AucCalculator, MultiTaskAucCalculator "
"CmatchRankAucCalculator and MaskAucCalculator"));
"PaddleBox only support AucCalculator, MultiTaskAucCalculator, "
"CmatchRankAucCalculator, MaskAucCalculator and "
"CmatchRankMaskAucCalculator"));
}
metric_name_list_.emplace_back(name);
}
Expand Down

0 comments on commit 1a6b2f2

Please sign in to comment.