From 30b7ed5d5221b4ea86eb3fafc1e20d07603a7319 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Fri, 12 Apr 2024 17:18:06 -0400 Subject: [PATCH] Move processor interface init from learner to communicator functional --- src/collective/communicator.cc | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc index aee7c0051325..14be1b036c21 100644 --- a/src/collective/communicator.cc +++ b/src/collective/communicator.cc @@ -41,13 +41,25 @@ void Communicator::Init(Json const& config) { case CommunicatorType::kFederated: { #if defined(XGBOOST_USE_FEDERATED) communicator_.reset(FederatedCommunicator::Create(config)); - std::cout << "!!!!!!!! Communicator Initialization!!!!!!!!!!!!!!!!!!!! " << std::endl; - auto plugin_name = "dummy"; - std::map loader_params = {{"LIBRARY_PATH", "/tmp"}}; - std::map proc_params = {}; - processing::ProcessorLoader loader(loader_params); - processor_instance = loader.load(plugin_name); - processor_instance->Initialize(collective::GetRank() == 0, proc_params); + // Get processor configs + std::string plugin_name{}; + std::string loader_params_key{}; + std::string loader_params_map{}; + std::string proc_params_key{}; + std::string proc_params_map{}; + plugin_name = OptionalArg(config, "plugin_name", plugin_name); + loader_params_key = OptionalArg(config, "loader_params_key", loader_params_key); + loader_params_map = OptionalArg(config, "loader_params_map", loader_params_map); + proc_params_key = OptionalArg(config, "proc_params_key", proc_params_key); + proc_params_map = OptionalArg(config, "proc_params_map", proc_params_map); + // Initialize processor if plugin_name is provided + if (!plugin_name.empty()){ + std::map loader_params = {{loader_params_key, loader_params_map}}; + std::map proc_params = {{proc_params_key, proc_params_map}}; + processing::ProcessorLoader loader(loader_params); + processor_instance = loader.load(plugin_name); + processor_instance->Initialize(collective::GetRank() == 0, proc_params); + } #else LOG(FATAL) << "XGBoost is not compiled with Federated Learning support."; #endif