
[RFC] External memory support for GPU #4357

Closed
rongou opened this issue Apr 10, 2019 · 16 comments
@rongou
Contributor

rongou commented Apr 10, 2019

Motivation

XGBoost has experimental support for external memory that allows for training on datasets that don’t fit in main (a.k.a. host) memory. However, it’s not available with the GPU algorithms. To train on large datasets with GPUs, we have to either downsample the data, which defeats the purpose, or scale out to multi-GPU and multi-node settings, which have their own complexities and limitations. It’s desirable to make the GPU algorithms more flexible by adding external memory support.

Goals

  • Support external memory for the gpu_hist tree method.
  • Provide reasonable performance.

Non-Goals

This proposal is for XGBoost only, and doesn’t apply to other libraries in RAPIDS such as cuDF/cuML.

Assumptions

  • We’ll build on top of the existing external memory version in XGBoost, which only supports importing from the libsvm format.
  • We’ll rely on existing cache systems to leverage host memory, for example the Linux page cache for local files and FS-Cache for network filesystems.
  • Although the approach should work if data is spilled to disk, for reasonable performance we assume the whole training dataset fits in host memory and is thus cached after the first round.

Risks

  • A naive implementation may be unacceptably slow.
  • A fully tuned implementation that is asynchronous and fully pipelined may introduce a lot of complexity to the code base.

Design

Existing Code

The current XGBoost code allows the user to specify the input data as

filename#cacheprefix

where filename is the normal path to a libsvm file, and cacheprefix is a path to a cache file that XGBoost will use for external memory cache. During training, the data is read in, parsed, and written out to the cache in 32MB pages. For each round, the pages are read back in and fed to the learner as batches.
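For reference, here is how the existing external memory cache is typically requested from the Python API; the file names below are only illustrative:

import xgboost as xgb

# 'train.libsvm' is the libsvm input; 'dtrain.cache' is the cache prefix that
# XGBoost uses for its on-disk page cache (both paths are placeholders).
dtrain = xgb.DMatrix("train.libsvm#dtrain.cache")

# This works with the CPU hist method today; gpu_hist rejects multi-batch
# input until the changes in this RFC land.
booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=10)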

Currently the gpu_hist tree method only accepts the first batch and errors out if there is more than one batch. In the training loop, the data is initialized once by going through feature quantile generation and data compression. The compressed feature quantiles are kept in GPU memory during training.

The GPU algorithm is described in detail in the paper.

Adding external memory support can be split into the following phases.

Phase 1: Basic Correctness

Here we aim to support external memory in the simplest way:

  • Feature quantile generation and data compression are done for each batch. This means we can’t cache the compressed features in GPU memory. Feature quantiles should be consistent across batches if multiple batches are used to build a single tree. It will not be possible to combine histograms that use different cut points. The current quantile algorithm already works in batches (but does not support external memory). Therefore, it should be possible to find the quantiles for the entire dataset at the price of reading all the data once from the external memory. The quantile values can then be used for binning and compression during each tree construction.
  • During tree construction, histograms are accumulated locally with each batch. After the final batch, the histograms are combined across all GPUs using allreduce (a toy sketch of this per-batch flow follows this list).
  • Prediction also needs to be modified accordingly to operate on batches. This should be much easier than with training. Instead of assuming a single SparsePage within a DMatrix, the algorithm should iterate over all SparsePages. One existing limitation is that only training/prediction data is read from external memory, while predictions are kept fully in memory.
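To make the Phase 1 flow concrete, here is a toy single-feature sketch with NumPy standing in for the GPU code; the function names and the plain quantile call are illustrative only (the real implementation uses the quantile sketch and runs on the device):

import numpy as np

def global_cuts(batches, n_bins=256):
    # One extra pass over all external-memory batches so that every batch is
    # binned with the same cut points (stand-in for the quantile sketch).
    values = np.concatenate(batches)
    return np.quantile(values, np.linspace(0, 1, n_bins + 1)[1:-1])

def accumulate_histogram(batches, gradients, cuts):
    hist = np.zeros(len(cuts) + 1)
    for x, g in zip(batches, gradients):
        bins = np.searchsorted(cuts, x)   # bin each batch with the shared cuts
        np.add.at(hist, bins, g)          # accumulate gradient sums per bin
    # In multi-GPU training this local histogram would now be allreduced.
    return hist

batches = [np.random.rand(1000) for _ in range(4)]       # stand-ins for SparsePage batches
gradients = [np.random.randn(1000) for _ in range(4)]
hist = accumulate_histogram(batches, gradients, global_cuts(batches))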

Phase 2: Cache Compressed Features

Based on current timing information, feature quantile generation and data compression are relatively expensive (equivalent to building around 40 trees), and the cost will add up quickly if we perform them on the fly for every batch. We can write the compressed features back to disk, similar to how data caching is currently done for external memory. Compressed and binned features take less space (~12 bits after compression and binning vs 64 bits before), and are therefore cheaper to read. This is tied to the DMatrix refactoring (#4354).
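As a rough sketch of the caching idea, with NumPy standing in for the compressed feature pages (the file name, bin count, and dtype are illustrative; as noted above, the real cache packs entries into roughly 12 bits):

import numpy as np

raw = np.random.rand(10_000, 50)                             # one raw batch (float64)
cuts = np.quantile(raw, np.linspace(0, 1, 257)[1:-1], axis=0)

binned = np.empty(raw.shape, dtype=np.uint8)                 # 256 bins fit in one byte here
for j in range(raw.shape[1]):
    binned[:, j] = np.searchsorted(cuts[:, j], raw[:, j])

np.save("page_0_binned.npy", binned)                         # write the compact page once
page = np.load("page_0_binned.npy", mmap_mode="r")           # later rounds reread the cached page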

In the multi-GPU setup, each data batch is currently split between the different GPUs. Once we properly support multiple batches, we can instead split at the batch level and feed whole batches to each GPU.

Phase 3: Overlap Data Transfers and Tree Construction

Regardless of the type of storage used to cache the compressed features, it’s likely to be very slow to feed each batch to the GPU and then build the histograms serially. We need to put the data transfer and compute for each batch into a separate CUDA stream, and launch multiple streams asynchronously.
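A minimal sketch of this overlap pattern, using CuPy as a stand-in for the CUDA C++ code (the batch contents, sizes, and the two-stream round-robin are illustrative; the real implementation would use explicit async copies from pinned host buffers):

import numpy as np
import cupy as cp

batches = [np.random.rand(1 << 20).astype(np.float32) for _ in range(8)]
streams = [cp.cuda.Stream(non_blocking=True) for _ in range(2)]
partial = [cp.zeros(256, dtype=cp.float32) for _ in streams]    # one accumulator per stream

for i, batch in enumerate(batches):
    s = i % len(streams)
    with streams[s]:                                            # issue work on this stream
        d_batch = cp.asarray(batch)                             # host-to-device copy
        bins = cp.minimum((d_batch * 256).astype(cp.int32), 255)
        partial[s] += cp.bincount(bins, minlength=256).astype(cp.float32)

for stream in streams:
    stream.synchronize()
histogram = partial[0] + partial[1]                             # combine per-stream results
# Note: true copy/compute overlap also requires pinned host memory for the batches.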

Phase 4: (Optionally) Use CUDA Graphs

We can potentially define the data transfers and tree construction in a CUDA graph and launch the graph repeatedly for each batch. This is only available for CUDA 10 and above.

Alternatives Considered

When training on a very large dataset (e.g. 10-20 TB), one approach is to scale out to many nodes. However, it may require hundreds or even thousands of GPUs, which may be prohibitively expensive and hard to schedule in a shared cluster. With that many nodes, network communication may also become a bottleneck, especially in a cloud/enterprise environment relying on relatively slow Ethernet connections. The goal of external memory support is not to replace distributed training, but to provide an alternative that may be better suited to some situations.

Another approach is stochastic gradient boosting: at each iteration a subsample of the training data is drawn at random (without replacement) from the full training data set. This randomly selected subsample is then used in place of the full sample to fit the base learner and compute the model update for the current iteration (see paper). However, the sampling ratio is a hyperparameter that needs to be tuned, so the subsample may still not fit in memory. There are other sampling techniques (for example, see paper), which are algorithmic improvements largely orthogonal to external memory support.
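For comparison, the subsampling alternative described above is already exposed through the existing subsample parameter; the data here is random and only for illustration:

import numpy as np
import xgboost as xgb

X = np.random.rand(100_000, 20)
y = np.random.randint(0, 2, size=100_000)
dtrain = xgb.DMatrix(X, label=y)

params = {
    "tree_method": "gpu_hist",
    "objective": "binary:logistic",
    "subsample": 0.1,          # fit each tree on a random 10% of the rows
    "max_depth": 6,
}
booster = xgb.train(params, dtrain, num_boost_round=10)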

@RAMitchell @canonizer @sriramch

@rongou
Contributor Author

rongou commented Jul 3, 2019

Adding a task list here to keep track of the progress:

  • Get the GPU predictor to accept batches (still keeping the full prediction output in GPU memory)
  • Get the GPU tree updater to accept batches (still keeping the full compressed features in GPU memory)
  • Add a preprocessing step to do quantile generation
  • Save/Load compressed features
  • Gradient-based sampling (e.g. https://arxiv.org/abs/1803.00841)
  • User guide

@trivialfis
Member

@rongou Sounds awesome. Could you add some details around the preprocessing step? I have been looking into sketching recently.

hcho3 pinned this issue Jul 4, 2019
@rongou
Contributor Author

rongou commented Jul 6, 2019

What I'm thinking of is more or less a pure refactoring. Right now we hand off the sparse page DMatrix to the tree updater, which loops through it once to build the quantiles, then loops through it again to compress each batch. We probably need to do these at a higher level so that we can write out the compressed features. Hand-wavy pseudo-code:

learner->preprocess(dmat);                    // one pass over all batches to generate quantiles
for (auto& batch : dmat) {                    // iterate over the external-memory pages
  auto compressed = learner->process(batch);  // bin and compress this batch
  save(compressed);                           // write the compressed page out to disk
}

Will try to get a draft PR out next week.

@trivialfis
Member

trivialfis commented Jul 7, 2019

@rongou Preferably this is done in DMatrix; you can add a new method to DMatrix called get compressed hist index ...

@sriramch
Contributor

sriramch commented Jul 8, 2019

@rongou Preferably this is done in DMatrix; you can add a new method to DMatrix called get compressed hist index ...

@trivialfis Isn't DMatrix just a data representation type, and isn't it agnostic to tree methods such as histograms? If so, is it right to pollute this interface with such methods?

@trivialfis
Member

@sriramch We think building histogram indices inside DMatrix can give us some opportunities to save memory, for example by avoiding a copy of the original data set. To me it makes sense for us to consider histogram indices as data, since the histogram is used to replace the actual input data. This is also inspired by LightGBM.

I think @hcho3 and @RAMitchell can provide some more input on other issues with the current index-building method.

We can talk about it in detail in the DMatrix RFC, since I'm not yet sure how to integrate it with external memory.

@trivialfis
Member

trivialfis commented Aug 9, 2019

@rongou @sriramch Would you agree that the initial support is complete? The other items on the list are more about optimization.

@rongou
Contributor Author

rongou commented Aug 9, 2019

I guess it depends on your definition of "support". :) It now supports the filename#cacheprefix syntax, but the whole dataset is still kept in GPU memory.

@trivialfis
Member

trivialfis commented Aug 9, 2019 via email

@SemanticBeeng

"XGBoost only, and doesn’t apply to other libraries in RAPIDS such as cuDF/cuML"

Would Apache Arrow's support for off-heap / zero-copy memory and memory-mapped files, plus its CUDA integration, be a more general solution?

https://arrow.apache.org/docs/python/cuda.html

Has it been considered?
Are there any known reasons why it would not work?
Would you be interested in having someone explore it?

@trivialfis
Member

@SemanticBeeng

By off-heap do you mean reading from disk? Maybe in the future, but currently not a priority.

Zero copy memory:

Did you mean running XGBoost on a dataset with the Arrow memory layout without making a copy? If so, not in the near future. Internally XGBoost uses CSR (for most parts), and the two layouts are quite different. Having said that, I think it's still possible.

Memory mapped files

I'm interested in this feature. But again, so many to-dos ...

CUDA integration

I have a PR for initial cuDF support. I think the Arrow specification still has some room to grow; let's see what can be done after merging that PR. Hope that helps.

@trivialfis
Member

@SemanticBeeng If you have some ideas, please do share.

@SemanticBeeng

SemanticBeeng commented Aug 17, 2019

Thanks.
Some context for my thinking (maybe my question is really about XGBoost4j?).

  1. XGBoost4j as a unifying #ProgrammingModel over Apache Spark and Apache Flink
    This suggests a shared form of data.

  2. Apache Arrow is meant as a shared runtime
    It is used by cuDF. This suggests (again) using its data format for the runtime.

I am looking to determine whether the goal of "standardizing in-memory data" can be adopted here.

"XGBoost uses CSR (for most of the parts)."

Ah, yes, this is important to know.
Would you be kind enough to share some resources for my education?

Do you know if XGBoost4j is different in this regard?

@zwqjoy

zwqjoy commented Nov 1, 2019

I use XGBoost 0.81 on CPU. External memory is supposed to allow training on datasets that don’t fit in main (a.k.a. host) memory, but when I train, memory usage is the same (17 GB) whether I use external memory or not.

@trivialfis
Member

Documentation should be the last remaining item now.

@trivialfis
Member

External memory support for GPU and gradient-based sampling are now available! Thanks to @rongou

trivialfis unpinned this issue Feb 22, 2020