diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index d62eb6890f12b..9e4589124a72a 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -849,6 +849,99 @@ class BoxWrapper { protected: std::string mask_varname_; }; + + class CmatchRankMaskMetricMsg : public MetricMsg { + public: + CmatchRankMaskMetricMsg(const std::string& label_varname, + const std::string& pred_varname, int metric_phase, + const std::string& cmatch_rank_group, + const std::string& cmatch_rank_varname, + bool ignore_rank = false, + const std::string& mask_varname = "", + int bucket_size = 1000000) { + label_varname_ = label_varname; + pred_varname_ = pred_varname; + cmatch_rank_varname_ = cmatch_rank_varname; + metric_phase_ = metric_phase; + ignore_rank_ = ignore_rank; + mask_varname_ = mask_varname; + calculator = new BasicAucCalculator(); + calculator->init(bucket_size); + for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) { + if (ignore_rank) { // CmatchAUC + cmatch_rank_v.emplace_back(atoi(cmatch_rank.c_str()), 0); + continue; + } + const std::vector& cur_cmatch_rank = + string::split_string(cmatch_rank, "_"); + PADDLE_ENFORCE_EQ( + cur_cmatch_rank.size(), 2, + platform::errors::PreconditionNotMet( + "illegal cmatch_rank auc spec: %s", cmatch_rank.c_str())); + cmatch_rank_v.emplace_back(atoi(cur_cmatch_rank[0].c_str()), + atoi(cur_cmatch_rank[1].c_str())); + } + } + virtual ~CmatchRankMaskMetricMsg() {} + void add_data(const Scope* exe_scope, + const paddle::platform::Place& place) override { + std::vector cmatch_rank_data; + get_data(exe_scope, cmatch_rank_varname_, &cmatch_rank_data); + std::vector label_data; + get_data(exe_scope, label_varname_, &label_data); + std::vector pred_data; + get_data(exe_scope, pred_varname_, &pred_data); + size_t batch_size = cmatch_rank_data.size(); + PADDLE_ENFORCE_EQ( + batch_size, label_data.size(), + platform::errors::PreconditionNotMet( + "illegal batch size: cmatch_rank[%lu] and label_data[%lu]", + batch_size, label_data.size())); + PADDLE_ENFORCE_EQ( + batch_size, pred_data.size(), + platform::errors::PreconditionNotMet( + "illegal batch size: cmatch_rank[%lu] and pred_data[%lu]", + batch_size, pred_data.size())); + + std::vector mask_data; + if (!mask_varname_.empty()) { + get_data(exe_scope, mask_varname_, &mask_data); + PADDLE_ENFORCE_EQ( + batch_size, mask_data.size(), + platform::errors::PreconditionNotMet( + "illegal batch size: cmatch_rank[%lu] and mask_data[%lu]", + batch_size, mask_data.size())); + } + + auto cal = GetCalculator(); + std::lock_guard lock(cal->table_mutex()); + for (size_t i = 0; i < batch_size; ++i) { + const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]); + for (size_t j = 0; j < cmatch_rank_v.size(); ++j) { + if (!mask_data.empty() && !mask_data[i]) { + continue; + } + bool is_matched = false; + if (ignore_rank_) { + is_matched = cmatch_rank_v[j].first == cur_cmatch_rank.first; + } else { + is_matched = cmatch_rank_v[j] == cur_cmatch_rank; + } + if (is_matched) { + cal->add_unlock_data(pred_data[i], label_data[i]); + break; + } + } + } + } + + protected: + std::vector> cmatch_rank_v; + std::string cmatch_rank_varname_; + bool ignore_rank_; + std::string mask_varname_; + }; + const std::vector GetMetricNameList( int metric_phase = -1) const { VLOG(0) << "Want to Get metric phase: " << metric_phase; @@ -909,10 +1002,16 @@ class BoxWrapper { name, new MaskMetricMsg(label_varname, pred_varname, metric_phase, mask_varname, bucket_size, mode_collect_in_gpu, max_batch_size)); + } else if (method == "CmatchRankMaskAucCalculator") { + metric_lists_.emplace(name, new CmatchRankMaskMetricMsg( + label_varname, pred_varname, metric_phase, + cmatch_rank_group, cmatch_rank_varname, + ignore_rank, mask_varname, bucket_size)); } else { PADDLE_THROW(platform::errors::Unimplemented( - "PaddleBox only support AucCalculator, MultiTaskAucCalculator " - "CmatchRankAucCalculator and MaskAucCalculator")); + "PaddleBox only support AucCalculator, MultiTaskAucCalculator, " + "CmatchRankAucCalculator, MaskAucCalculator and " + "CmatchRankMaskAucCalculator")); } metric_name_list_.emplace_back(name); }