Skip to content

Commit

Permalink
[PSLIB] Add Metrics Module, Support User-defined Add Metric (#38230)
Browse files Browse the repository at this point in the history
* 12.3 first add metrics module

* add Mask/MultiTask

* add WuAUC

* [PSLIB] Update WuAUC Compute

* [PSLIB] Change WuAUC Compute Mehod

* [PSLIB] Clean WuAUC Compute

* [PSLIB] Clean Metric Module Unused Code

* mv metric instance

* [PSLIB] Add Metrics Module, Support User-defined Add Metric (#38789)

* [PSLIB] Add Metrics Module, Support User-defined Add Metric

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* modify role_maker

* update CMakeLists.txt
  • Loading branch information
WorgenZhang authored Jan 28, 2022
1 parent d3011c7 commit 7460a89
Show file tree
Hide file tree
Showing 27 changed files with 1,719 additions and 10 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ if(WITH_DISTRIBUTE)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper metrics lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
Expand All @@ -199,7 +199,7 @@ else()
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method
lod_rank_table fs shell fleet_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
Expand Down
18 changes: 18 additions & 0 deletions paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_id_ = 0;
this->thread_num_ = 1;
this->parse_ins_id_ = false;
this->parse_uid_ = false;
this->parse_content_ = false;
this->parse_logkey_ = false;
this->enable_pv_merge_ = false;
Expand Down Expand Up @@ -362,6 +363,11 @@ void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
}

template <typename T>
void InMemoryDataFeed<T>::SetParseUid(bool parse_uid) {
parse_uid_ = parse_uid;
}

template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX
Expand Down Expand Up @@ -838,6 +844,7 @@ void MultiSlotInMemoryDataFeed::Init(
use_slots_shape_.push_back(local_shape);
}
}
uid_slot_ = multi_slot_desc.uid_slot();
feed_vec_.resize(use_slots_.size());
pipe_command_ = data_feed_desc.pipe_command();
finish_init_ = true;
Expand Down Expand Up @@ -929,6 +936,17 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));

if (parse_uid_ && all_slots_[i] == uid_slot_) {
PADDLE_ENFORCE(num == 1 && all_slots_type_[i][0] == 'u',
"The uid has to be uint64 and single.\n"
"please check this error line: %s",
str);

char* uidptr = endptr;
uint64_t feasign = (uint64_t)strtoull(uidptr, &uidptr, 10);
instance->uid_ = feasign;
}
if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float
for (int j = 0; j < num; ++j) {
Expand Down
7 changes: 6 additions & 1 deletion paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ struct Record {
uint64_t search_id;
uint32_t rank;
uint32_t cmatch;
std::string uid_;
};

struct PvInstanceObject {
Expand Down Expand Up @@ -157,6 +158,7 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseUid(bool parse_uid) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {}
Expand Down Expand Up @@ -232,6 +234,7 @@ class DataFeed {
std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_;
platform::Place place_;
std::string uid_slot_;
};

// PrivateQueueDataFeed is the base virtual class for ohther DataFeeds.
Expand Down Expand Up @@ -293,6 +296,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseUid(bool parse_uid);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
Expand All @@ -307,6 +311,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_;
int thread_num_;
bool parse_ins_id_;
bool parse_uid_;
bool parse_content_;
bool parse_logkey_;
bool enable_pv_merge_;
Expand Down Expand Up @@ -471,7 +476,7 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
for (size_t& x : offset) {
uint64_t t;
ar >> t;
x = (size_t)t;
x = static_cast<size_t>(t);
}
#endif
ar >> ins.MutableFloatData();
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/data_feed.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ message Slot {
repeated int32 shape = 5; // we can define N-D Tensor
}

message MultiSlotDesc { repeated Slot slots = 1; }
message MultiSlotDesc {
repeated Slot slots = 1;
optional string uid_slot = 2;
}

message DataFeedDesc {
optional string name = 1;
Expand Down
19 changes: 16 additions & 3 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ DatasetImpl<T>::DatasetImpl() {
parse_logkey_ = false;
preload_thread_num_ = 0;
global_index_ = 0;
shuffle_by_uid_ = false;
parse_uid_ = false;
}

// set filelist, file_idx_ will reset to zero.
Expand Down Expand Up @@ -147,6 +149,12 @@ void DatasetImpl<T>::SetMergeBySid(bool is_merge) {
merge_by_sid_ = is_merge;
}

template <typename T>
void DatasetImpl<T>::SetShuffleByUid(bool enable_shuffle_uid) {
shuffle_by_uid_ = enable_shuffle_uid;
parse_uid_ = true;
}

template <typename T>
void DatasetImpl<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge;
Expand Down Expand Up @@ -386,11 +394,14 @@ void DatasetImpl<T>::GlobalShuffle(int thread_num) {
<< input_channel_->Size();

auto get_client_id = [this, fleet_ptr](const T& data) -> size_t {
if (!this->merge_by_insid_) {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
} else {
if (this->merge_by_insid_) {
return XXH64(data.ins_id_.data(), data.ins_id_.length(), 0) %
this->trainer_num_;
} else if (this->shuffle_by_uid_) {
return XXH64(data.uid_.data(), data.uid_.length(), 0) %
this->trainer_num_;
} else {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
}
};

Expand Down Expand Up @@ -618,6 +629,7 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFileListIndex(&file_idx_);
readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseUid(parse_uid_);
readers_[i]->SetParseContent(parse_content_);
readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_);
Expand Down Expand Up @@ -686,6 +698,7 @@ void DatasetImpl<T>::CreatePreLoadReaders() {
preload_readers_[i]->SetFileListIndex(&file_idx_);
preload_readers_[i]->SetFileList(filelist_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseUid(parse_uid_);
preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetParseLogKey(parse_logkey_);
preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_);
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/data_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class Dataset {
virtual void SetParseLogKey(bool parse_logkey) = 0;
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
virtual void SetShuffleByUid(bool enable_shuffle_uid) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
Expand Down Expand Up @@ -175,6 +176,7 @@ class DatasetImpl : public Dataset {
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetMergeBySid(bool is_merge);
virtual void SetShuffleByUid(bool enable_shuffle_uid);

virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
Expand Down Expand Up @@ -263,6 +265,8 @@ class DatasetImpl : public Dataset {
bool parse_content_;
bool parse_logkey_;
bool merge_by_sid_;
bool shuffle_by_uid_;
bool parse_uid_;
bool enable_pv_merge_; // True means to merge pv
int current_phase_; // 1 join, 0 update
size_t merge_size_;
Expand Down
22 changes: 21 additions & 1 deletion paddle/fluid/framework/downpour_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/string/string_helper.h"

Expand All @@ -25,7 +26,6 @@ limitations under the License. */

namespace paddle {
namespace framework {

void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (int i = 0; i < param_.sparse_table_size(); ++i) {
Expand Down Expand Up @@ -780,6 +780,21 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}

/**
* @brief add auc monitor
*/
inline void AddAucMonitor(const Scope* scope, const platform::Place& place) {
auto metric_ptr = Metric::GetInstance();
auto& metric_list = metric_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_ptr->Phase() != metric_msg->MetricPhase()) {
continue;
}
metric_msg->add_data(scope, place);
}
}

void DownpourWorker::TrainFiles() {
VLOG(3) << "Begin to train files";
platform::SetNumThreads(1);
Expand Down Expand Up @@ -877,6 +892,11 @@ void DownpourWorker::TrainFiles() {
}
}

// add data for MetricMsg
if (Metric::GetInstance() != nullptr) {
AddAucMonitor(thread_scope_, place_);
}

// check inf and nan
for (std::string& var_name : check_nan_var_names_) {
Variable* var = thread_scope_->FindVar(var_name);
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/fleet/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ endif(WITH_BOX_PS)

if(WITH_GLOO)
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope gloo)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
else()
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
endif(WITH_GLOO)

cc_test(test_fleet SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell)
7 changes: 7 additions & 0 deletions paddle/fluid/framework/fleet/gloo_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ enum GlooStoreType { HDFS, HTTP };

class GlooWrapper {
public:
static std::shared_ptr<GlooWrapper> GetInstance() {
static auto s_instance = std::make_shared<GlooWrapper>();
return s_instance;
}

GlooWrapper() {}

virtual ~GlooWrapper() {}
Expand Down Expand Up @@ -153,6 +158,8 @@ class GlooWrapper {
#endif
}

bool IsInitialized() { return is_initialized_; }

template <typename T>
std::vector<T> AllReduce(std::vector<T>& sendbuf, // NOLINT
const std::string& mode = "sum") { // NOLINT
Expand Down
Loading

0 comments on commit 7460a89

Please sign in to comment.