Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Breaking] Switch from rabit to the collective communicator #8257

Merged
merged 29 commits into from
Oct 5, 2022

Conversation

rongou
Copy link
Contributor

@rongou rongou commented Sep 21, 2022

This PR switches the Rabit api to Communicator, which gives us more flexibility in the collective communication implementation. It's a breaking change, but to most uses, it's a straightforward swap.

Copy link
Member

@trivialfis trivialfis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the work on swapping out rabit. When it's ready, could you please split up the PR into smaller ones and start with internal C++ changes?

@rongou
Copy link
Contributor Author

rongou commented Sep 22, 2022

@trivialfis the issue is this is a breaking change. Once we change the c++ portion, we have to change the python and java apis too, otherwise the communicator would be uninitialized.

@rongou
Copy link
Contributor Author

rongou commented Sep 23, 2022

@wbo4958 any ideas about the JVM test failures? They seem to pass on my local desktop.

@wbo4958
Copy link
Contributor

wbo4958 commented Sep 26, 2022

I will check it today.

@wbo4958
Copy link
Contributor

wbo4958 commented Sep 26, 2022

I can repro it locally, which will make life easy.

@wbo4958
Copy link
Contributor

wbo4958 commented Sep 26, 2022

Seems the test rabit timeout fail handle has affected others.

After replacing it with below code, it worked for me.

  test("test rabit timeout fail handle") {
    val training = buildDataFrame(Classification.train)

    try {
      // mock rank 0 failure during 8th allreduce synchronization
      Communicator.mockList = Array("0,8,0,0").toList.asJava
      intercept[SparkException] {
        new XGBoostClassifier(Map(
          "eta" -> "0.1",
          "max_depth" -> "10",
          "verbosity" -> "1",
          "objective" -> "binary:logistic",
          "num_round" -> 5,
          "num_workers" -> numWorkers,
          "rabit_timeout" -> 0))
          .fit(training)
      }
    } finally {
      Communicator.mockList = Array.empty.toList.asJava
    }
  }

@rongou rongou changed the title [WIP] Switch from rabit to the collective communicator [Breaking] Switch from rabit to the collective communicator Sep 26, 2022
@rongou
Copy link
Contributor Author

rongou commented Sep 26, 2022

@wbo4958 Thanks for the help with the debugging. Resetting the mocklist seems to have fixed it.

@trivialfis I think this PR is ready for review. It touches a lot of files, but mostly it's a one-to-one swap from rabit to communicator. Thanks!

@rongou rongou requested a review from trivialfis September 26, 2022 19:36
Copy link
Member

@trivialfis trivialfis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the great work on swapping out rabit. Could you please layout the plan for future PRs for 1.7 release? I'm trying to estimate an ETA.

@@ -16,7 +17,7 @@ thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicat

void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(nullptr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please share under which case the communicator can still be called after being shut down? I think a nullptr can be a guard for unintentional calls.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have some tests that mix distributed training and local training.

/**
* A no-op communicator, used for non-distributed training.
*/
class NoOpCommunicator : public Communicator {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that you have already added checks for non-distributed env in various communicator implementations. Is this still necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above, this is needed to replicate the existing rabit behavior.

@@ -46,7 +45,7 @@ struct DeviceAUCCache {
dh::device_vector<size_t> unique_idx;
// p^T: transposed prediction matrix, used by MultiClassAUC
dh::device_vector<float> predts_t;
std::unique_ptr<dh::AllReducer> reducer;
collective::DeviceCommunicator* communicator;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is now a global instance, we don't have to maintain a pointer to it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need the device id to get the communicator, which in this class is only passed in during Init and not saved. We can probably clean this up once we have better device id management.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ended up removing these pointers.

@@ -158,7 +158,7 @@ def _try_start_tracker(
if isinstance(addrs[0], tuple):
host_ip = addrs[0][0]
port = addrs[0][1]
rabit_context = RabitTracker(
rabit_tracker = RabitTracker(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm .. so we still need rabit tracker for downstream projects. Could you please share how federated learning communicates the worker addresses across all workers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes if we use the rabit communicator, which is the default, we still need to start a tracker. I imagine if we switch to something like gloo, then we can get rid of rabit completely.

For federated learning, since we have to start a gRPC server first, we just pass the server address (host:port) to each client.

@@ -12,14 +13,10 @@
namespace xgboost {
namespace collective {

thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See other comments on the no-op, I have some concerns that we will use this accidentally, we have issues where XGBoost failed to establish a working communicator group but proceed with distributed training without explicit error. I haven't been able to tackle those issues due to the complicated network setup others use. But still, it's something that we should have in mind.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On this line https://github.com/dmlc/xgboost/blob/master/rabit/src/engine.cc#L73, rabit provides a default engine if it's not initialized. We have some code that depends on this behavior. The NoOpCommunicator is one way to replicate this behavior.

I think the user has to tell us if they are doing distributed training, either by entering the CommunicatorContext, or calling Init directly. Otherwise we'd have no way of knowing, and can't just return a nullptr in case they are not in the distributed mode.

Copy link
Contributor Author

@rongou rongou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After this PR, I think the only code change needed is #8279, and then we just need to weak the CI to build with federated learning enabled.

@@ -158,7 +158,7 @@ def _try_start_tracker(
if isinstance(addrs[0], tuple):
host_ip = addrs[0][0]
port = addrs[0][1]
rabit_context = RabitTracker(
rabit_tracker = RabitTracker(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes if we use the rabit communicator, which is the default, we still need to start a tracker. I imagine if we switch to something like gloo, then we can get rid of rabit completely.

For federated learning, since we have to start a gRPC server first, we just pass the server address (host:port) to each client.

@@ -12,14 +13,10 @@
namespace xgboost {
namespace collective {

thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On this line https://github.com/dmlc/xgboost/blob/master/rabit/src/engine.cc#L73, rabit provides a default engine if it's not initialized. We have some code that depends on this behavior. The NoOpCommunicator is one way to replicate this behavior.

I think the user has to tell us if they are doing distributed training, either by entering the CommunicatorContext, or calling Init directly. Otherwise we'd have no way of knowing, and can't just return a nullptr in case they are not in the distributed mode.

@@ -16,7 +17,7 @@ thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicat

void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(nullptr);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have some tests that mix distributed training and local training.

/**
* A no-op communicator, used for non-distributed training.
*/
class NoOpCommunicator : public Communicator {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above, this is needed to replicate the existing rabit behavior.

@@ -46,7 +45,7 @@ struct DeviceAUCCache {
dh::device_vector<size_t> unique_idx;
// p^T: transposed prediction matrix, used by MultiClassAUC
dh::device_vector<float> predts_t;
std::unique_ptr<dh::AllReducer> reducer;
collective::DeviceCommunicator* communicator;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need the device id to get the communicator, which in this class is only passed in during Init and not saved. We can probably clean this up once we have better device id management.

@trivialfis
Copy link
Member

Could you please merge master again?

@rongou
Copy link
Contributor Author

rongou commented Sep 30, 2022

It's up to date.

@rongou
Copy link
Contributor Author

rongou commented Oct 3, 2022

@trivialfis can this be merged? Thanks!

@hcho3
Copy link
Collaborator

hcho3 commented Oct 3, 2022

@rongou We'll have to first merge #8298 to fix the CI.

@rongou
Copy link
Contributor Author

rongou commented Oct 5, 2022

@trivialfis @hcho3 can this be merged now? Thanks!

@gnaggnoyil
Copy link

I noticed that the 1.7.0 release note is still indicating that users "can choose between rabit and federated". Did someone forgot to change the wording in the release note, or just I'm misunderstanding something?

@hcho3
Copy link
Collaborator

hcho3 commented Nov 3, 2022

@gnaggnoyil You can choose between Rabit and Federated by passing {"xgboost_communicator": "[tracker type]"} to xgboost.collective.init().

As for xgboost.rabit, we got rid of the API in 1.7.0, but it broke some downstream projects. So we plan to release patch release 1.7.1 to restore xgboost.rabit.

@rongou rongou deleted the switch-to-communicator branch November 18, 2022 19:01
NikhilRaverkar pushed a commit to NikhilRaverkar/sagemaker-xgboost-container that referenced this pull request Mar 17, 2023
NikhilRaverkar added a commit to aws/sagemaker-xgboost-container that referenced this pull request Mar 17, 2023
…ning (#384)

Addressing NCCL issue with binary classification for distributed training.
dmlc/xgboost#7982 (comment) dmlc/xgboost#8257

Co-authored-by: Nikhil Raverkar <nraverka@amazon.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants