[Breaking] Switch from rabit to the collective communicator #8257
Thank you for the work on swapping out rabit. When it's ready, could you please split up the PR into smaller ones and start with internal C++ changes?
@trivialfis the issue is that this is a breaking change. Once we change the C++ portion, we have to change the Python and Java APIs too, otherwise the communicator would be uninitialized.
@wbo4958 any ideas about the JVM test failures? They seem to pass on my local desktop.
I will check it today.
I can repro it locally, which will make life easy.
Seems the After replacing it with the code below, it worked for me.
test("test rabit timeout fail handle") {
  val training = buildDataFrame(Classification.train)
  try {
    // mock rank 0 failure during 8th allreduce synchronization
    Communicator.mockList = Array("0,8,0,0").toList.asJava
    intercept[SparkException] {
      new XGBoostClassifier(Map(
        "eta" -> "0.1",
        "max_depth" -> "10",
        "verbosity" -> "1",
        "objective" -> "binary:logistic",
        "num_round" -> 5,
        "num_workers" -> numWorkers,
        "rabit_timeout" -> 0))
        .fit(training)
    }
  } finally {
    // reset the mock list so later tests are not affected
    Communicator.mockList = Array.empty.toList.asJava
  }
}
@wbo4958 Thanks for the help with the debugging. Resetting the mockList seems to have fixed it. @trivialfis I think this PR is ready for review. It touches a lot of files, but mostly it's a one-to-one swap from rabit to communicator. Thanks!
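Why the reset matters can be shown with a small Python model of the same pattern. This is a sketch, not XGBoost code: `MockCommunicator`, `mock_list`, and `run_training` are hypothetical names, and the reading of the entry format (first field is the rank, second is the allreduce call number) is an assumption based on the comment in the Scala test. The point is that class-level mock state outlives the test that set it unless a `finally` block clears it.

```python
class MockCommunicator:
    # Class-level state, shared by every test in the process.
    mock_list = []

    def __init__(self, rank):
        self.rank = rank
        self.allreduce_calls = 0

    def allreduce_sum(self, values):
        self.allreduce_calls += 1
        for entry in MockCommunicator.mock_list:
            fields = [int(x) for x in entry.split(",")]
            # Assumed layout: fields[0] = rank, fields[1] = call number.
            if fields[0] == self.rank and fields[1] == self.allreduce_calls:
                raise RuntimeError(f"mock failure on rank {self.rank}")
        return list(values)


def run_training(rank, rounds):
    comm = MockCommunicator(rank)
    for _ in range(rounds):
        comm.allreduce_sum([1.0])
    return comm.allreduce_calls


def failing_test():
    try:
        MockCommunicator.mock_list = ["0,8,0,0"]
        try:
            run_training(rank=0, rounds=10)
        except RuntimeError:
            return True  # the injected failure was observed
        return False
    finally:
        # Without this reset, later tests would inherit the failure.
        MockCommunicator.mock_list = []


saw_failure = failing_test()
calls_in_next_test = run_training(rank=0, rounds=10)
```

If the `finally` block is removed, the second `run_training` call raises on its eighth allreduce, which matches the kind of cross-test interference described in this thread.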
Thank you for the great work on swapping out rabit. Could you please lay out the plan for future PRs for the 1.7 release? I'm trying to estimate an ETA.
jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -16,7 +17,7 @@ thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicat

 void Communicator::Finalize() {
   communicator_->Shutdown();
   communicator_.reset(nullptr);
Could you please share under which case the communicator can still be called after being shut down? I think a nullptr can be a guard for unintentional calls.
I think we have some tests that mix distributed training and local training.
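The trade-off in this exchange can be sketched in a few lines. The following Python toy model is not XGBoost code; `Registry`, `NoOp`, and `Distributed` are hypothetical names, and the peers are simulated. It assumes the slot is reset to a no-op object on finalize, so that local training run after a distributed session still finds a usable communicator. Resetting to `None` instead makes the same call fail, which is the guard behaviour the reviewer describes.

```python
class NoOp:
    """Single-process stand-in: every collective is the identity."""

    def allreduce_sum(self, values):
        return list(values)

    def shutdown(self):
        pass


class Distributed:
    """Hypothetical distributed communicator; the peers are simulated."""

    def __init__(self, peer_values):
        self.peer_values = peer_values  # what the other workers would send
        self.closed = False

    def allreduce_sum(self, values):
        columns = zip(values, *self.peer_values)
        return [sum(col) for col in columns]

    def shutdown(self):
        self.closed = True


class Registry:
    """Holds the active communicator, like a static member would in C++."""

    def __init__(self, reset_to_noop):
        self.reset_to_noop = reset_to_noop
        self.active = NoOp()

    def init(self, communicator):
        self.active = communicator

    def finalize(self):
        self.active.shutdown()
        self.active = NoOp() if self.reset_to_noop else None


# Distributed phase followed by a local phase in the same process.
registry = Registry(reset_to_noop=True)
registry.init(Distributed(peer_values=[[10, 20]]))
distributed_sum = registry.active.allreduce_sum([1, 2])
registry.finalize()
local_sum = registry.active.allreduce_sum([1, 2])

# With a None reset, the same local call fails instead of silently passing.
strict = Registry(reset_to_noop=False)
strict.finalize()
try:
    strict.active.allreduce_sum([1, 2])
    strict_failed = False
except AttributeError:
    strict_failed = True
```

Both behaviours are defensible: the no-op reset keeps mixed test suites working, while the `None` reset surfaces unintended calls early.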
/**
 * A no-op communicator, used for non-distributed training.
 */
class NoOpCommunicator : public Communicator {
I see that you have already added checks for non-distributed env in various communicator implementations. Is this still necessary?
As mentioned above, this is needed to replicate the existing rabit behavior.
src/metric/auc.cu
Outdated
@@ -46,7 +45,7 @@ struct DeviceAUCCache {
   dh::device_vector<size_t> unique_idx;
   // p^T: transposed prediction matrix, used by MultiClassAUC
   dh::device_vector<float> predts_t;
-  std::unique_ptr<dh::AllReducer> reducer;
+  collective::DeviceCommunicator* communicator;
If this is now a global instance, we don't have to maintain a pointer to it.
We still need the device id to get the communicator, which in this class is only passed in during Init and not saved. We can probably clean this up once we have better device id management.
Ended up removing these pointers.
@@ -158,7 +158,7 @@ def _try_start_tracker(
         if isinstance(addrs[0], tuple):
             host_ip = addrs[0][0]
             port = addrs[0][1]
-            rabit_context = RabitTracker(
+            rabit_tracker = RabitTracker(
Hmm, so we still need the rabit tracker for downstream projects. Could you please share how federated learning communicates the worker addresses across all workers?
Yes, if we use the rabit communicator, which is the default, we still need to start a tracker. I imagine that if we switch to something like gloo, we can get rid of rabit completely.
For federated learning, since we have to start a gRPC server first, we just pass the server address (host:port) to each client.
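As a rough illustration of the two bootstrap paths described here, the sketch below builds the arguments a worker would need in each case. It is a toy, not the XGBoost API: `split_address`, `rabit_worker_args`, `federated_worker_args`, and all dictionary keys are hypothetical names chosen for the example. The only facts taken from the comment are that the rabit path needs a running tracker and that the federated path only needs the gRPC server's host:port passed to each client.

```python
def split_address(address):
    """Split 'host:port' into (host, port), validating the port."""
    host, sep, port = address.rpartition(":")
    if not sep or not host or not port.isdigit():
        raise ValueError(f"expected host:port, got {address!r}")
    return host, int(port)


def rabit_worker_args(tracker_host, tracker_port, task_id):
    # Hypothetical keys: the worker is pointed at a tracker, which then
    # tells the workers about each other.
    return {
        "communicator": "rabit",
        "tracker_host": tracker_host,
        "tracker_port": tracker_port,
        "task_id": task_id,
    }


def federated_worker_args(server_address, rank, world_size):
    # Hypothetical keys: no tracker; every client gets the same server address.
    host, port = split_address(server_address)
    return {
        "communicator": "federated",
        "server_address": f"{host}:{port}",
        "rank": rank,
        "world_size": world_size,
    }


# Every federated client receives the same server address and its own rank.
clients = [federated_worker_args("localhost:9091", rank, 2) for rank in range(2)]
```

The difference in shape is the design point: with a tracker, peer discovery is a service that must be started first, while in the federated case the server address is the only shared piece of configuration.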
@@ -12,14 +13,10 @@
 namespace xgboost {
 namespace collective {

-thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
+thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
See my other comments on the no-op communicator. I have some concerns that we will use this accidentally: we have had issues where XGBoost failed to establish a working communicator group but proceeded with distributed training without an explicit error. I haven't been able to tackle those issues due to the complicated network setups others use, but it's still something we should keep in mind.
On this line https://github.com/dmlc/xgboost/blob/master/rabit/src/engine.cc#L73, rabit provides a default engine if it's not initialized. We have some code that depends on this behavior. The NoOpCommunicator is one way to replicate this behavior.
I think the user has to tell us if they are doing distributed training, either by entering the CommunicatorContext, or calling Init directly. Otherwise we'd have no way of knowing, and can't just return a nullptr in case they are not in the distributed mode.
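The pattern argued for here, a per-thread default that does nothing plus an explicit opt-in for distributed mode, can be sketched as follows. This is a self-contained Python toy, not the XGBoost implementation: `communicator_context`, `get_communicator`, `FakeDistributedCommunicator`, and `require_distributed` are hypothetical names. The last helper is one possible way to address the reviewer's worry about silently falling back to the no-op; it is an assumption for illustration, not something this PR is stated to add.

```python
import threading
from contextlib import contextmanager


class NoOpCommunicator:
    """Default used when nobody asked for distributed training."""

    world_size = 1
    rank = 0

    def allreduce_sum(self, values):
        return list(values)


class FakeDistributedCommunicator:
    """Hypothetical stand-in; a real one would talk to other workers."""

    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size

    def allreduce_sum(self, values):
        # Pretend every worker contributed the same values.
        return [v * self.world_size for v in values]


_state = threading.local()


def get_communicator():
    # Each thread lazily gets its own no-op default; the slot is never None.
    if not hasattr(_state, "communicator"):
        _state.communicator = NoOpCommunicator()
    return _state.communicator


def is_distributed():
    return get_communicator().world_size > 1


@contextmanager
def communicator_context(rank, world_size):
    """Explicit opt-in: the caller states that this block is distributed."""
    _state.communicator = FakeDistributedCommunicator(rank, world_size)
    try:
        yield get_communicator()
    finally:
        _state.communicator = NoOpCommunicator()


def require_distributed():
    """Fail loudly instead of silently training on the no-op default."""
    if not is_distributed():
        raise RuntimeError("expected a distributed communicator, got the no-op default")


before = is_distributed()
with communicator_context(rank=0, world_size=4) as comm:
    inside = is_distributed()
    reduced = comm.allreduce_sum([1.0, 2.0])
after = is_distributed()
```

Code that must never run single-process could call `require_distributed()` at its entry point, while code that is valid in both modes just calls `get_communicator()` and works either way.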
After this PR, I think the only code change needed is #8279, and then we just need to tweak the CI to build with federated learning enabled.
Could you please merge master again?
It's up to date.
@trivialfis can this be merged? Thanks!
@trivialfis @hcho3 can this be merged now? Thanks!
I noticed that the 1.7.0 release note is still indicating that users "can choose between
@gnaggnoyil You can choose between Rabit and Federated by passing As for |
…ning (#384) Addressing NCCL issue with binary classification for distributed training. dmlc/xgboost#7982 (comment) dmlc/xgboost#8257 Co-authored-by: Nikhil Raverkar <nraverka@amazon.com>
This PR switches the Rabit API to Communicator, which gives us more flexibility in the collective communication implementation. It's a breaking change, but for most uses, it's a straightforward swap.