diff --git a/Jenkinsfile-win64 b/Jenkinsfile-win64 index b156833426c9..b919cb1acac0 100644 --- a/Jenkinsfile-win64 +++ b/Jenkinsfile-win64 @@ -40,7 +40,8 @@ pipeline { steps { script { parallel ([ - 'build-win64-cuda10.1': { BuildWin64() } + 'build-win64-cuda10.1': { BuildWin64() }, + 'build-rpkg-win64-cuda10.1': { BuildRPackageWithCUDAWin64() } ]) } } @@ -75,6 +76,7 @@ def checkoutSrcs() { def BuildWin64() { node('win64 && cuda10_unified') { + deleteDir() unstash name: 'srcs' echo "Building XGBoost for Windows AMD64 target..." bat "nvcc --version" @@ -115,8 +117,26 @@ def BuildWin64() { } } +def BuildRPackageWithCUDAWin64() { + node('win64 && cuda10_unified') { + deleteDir() + unstash name: 'srcs' + bat "nvcc --version" + if (env.BRANCH_NAME == 'master' || env.BRANCH_NAME.startsWith('release')) { + bat """ + bash tests/ci_build/build_r_pkg_with_cuda_win64.sh ${commit_id} + """ + echo 'Uploading R tarball...' + path = ("${BRANCH_NAME}" == 'master') ? '' : "${BRANCH_NAME}/" + s3Upload bucket: 'xgboost-nightly-builds', path: path, acl: 'PublicRead', includePathPattern:'xgboost_r_gpu_win64_*.tar.gz' + } + deleteDir() + } +} + def TestWin64() { node('win64 && cuda10_unified') { + deleteDir() unstash name: 'srcs' unstash name: 'xgboost_whl' unstash name: 'xgboost_cli' @@ -127,7 +147,7 @@ def TestWin64() { bat "build\\testxgboost.exe" echo "Installing Python dependencies..." def env_name = 'win64_' + UUID.randomUUID().toString().replaceAll('-', '') - bat "conda env create -n ${env_name} --file=tests/ci_build/conda_env/win64_test.yml" + bat "conda activate && mamba env create -n ${env_name} --file=tests/ci_build/conda_env/win64_test.yml" echo "Installing Python wheel..." bat """ conda activate ${env_name} && for /R %%i in (python-package\\dist\\*.whl) DO python -m pip install "%%i" diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake index 269641ea617d..3684c250c89e 100644 --- a/cmake/Utils.cmake +++ b/cmake/Utils.cmake @@ -90,7 +90,9 @@ function(format_gencode_flags flags out) endif() # Set up architecture flags if(NOT flags) - if (CUDA_VERSION VERSION_GREATER_EQUAL "11.0") + if (CUDA_VERSION VERSION_GREATER_EQUAL "11.1") + set(flags "50;52;60;61;70;75;80;86") + elseif (CUDA_VERSION VERSION_GREATER_EQUAL "11.0") set(flags "35;50;52;60;61;70;75;80") elseif(CUDA_VERSION VERSION_GREATER_EQUAL "10.0") set(flags "35;50;52;60;61;70;75") diff --git a/doc/install.rst b/doc/install.rst index 31635d3a2202..31224b6ed44c 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -61,9 +61,12 @@ R and then run ``install.packages("xgboost")``. Without OpenMP, XGBoost will only use a single CPU core, leading to suboptimal training speed. -* We also provide **experimental** pre-built binary on Linux x86_64 with GPU support. +* We also provide **experimental** pre-built binary with GPU support. With this binary, + you will be able to use the GPU algorithm without building XGBoost from the source. Download the binary package from the Releases page. The file name will be of the form - ``xgboost_r_gpu_linux_[version].tar.gz``. Then install XGBoost by running: + ``xgboost_r_gpu_[os]_[version].tar.gz``, where ``[os]`` is either ``linux`` or ``win64``. + (We build the binaries for 64-bit Linux and Windows.) + Then install XGBoost by running: .. code-block:: bash @@ -142,9 +145,11 @@ R - Other than standard CRAN installation, we also provide *experimental* pre-built binary on -Linux x86_64 with GPU support. You can go to `this page +with GPU support. 
You can go to `this page `_, Find the commit -ID you want to install: ``xgboost_r_gpu_linux_[commit].tar.gz``, download it then run: +ID you want to install and then locate the file ``xgboost_r_gpu_[os]_[commit].tar.gz``, +where ``[os]`` is either ``linux`` or ``win64``. (We build the binaries for 64-bit Linux +and Windows.) Download it and run the following commands: .. code-block:: bash diff --git a/doc/parameter.rst b/doc/parameter.rst index df8bfcc086cf..4e2ee2c3958b 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -243,16 +243,6 @@ Additional parameters for ``hist`` and ``gpu_hist`` tree method - Use single precision to build histograms instead of double precision. -Additional parameters for ``gpu_hist`` tree method -================================================== - -* ``deterministic_histogram``, [default=``true``] - - - Build histogram on GPU deterministically. Histogram building is not deterministic due - to the non-associative aspect of floating point summation. We employ a pre-rounding - routine to mitigate the issue, which may lead to slightly lower accuracy. Set to - ``false`` to disable it. - Additional parameters for Dart Booster (``booster=dart``) ========================================================= diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 0e3fbd981280..4ea2de31a436 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -255,9 +255,12 @@ class GradientPairInternal { /*! \brief gradient statistics pair usually needed in gradient boosting */ using GradientPair = detail::GradientPairInternal; - /*! \brief High precision gradient statistics pair */ using GradientPairPrecise = detail::GradientPairInternal; +/*! \brief Fixed point representation for gradient pair. */ +using GradientPairInt32 = detail::GradientPairInternal; +/*! \brief Fixed point representation for high precision gradient pair. 
*/ +using GradientPairInt64 = detail::GradientPairInternal; using Args = std::vector >; diff --git a/python-package/setup.py b/python-package/setup.py index 8882e1c0372a..3947b53b39ea 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -302,7 +302,7 @@ def run(self): with open(os.path.join(CURRENT_DIR, 'README.rst'), encoding='utf-8') as fd: description = fd.read() - with open(os.path.join(CURRENT_DIR, 'xgboost/VERSION')) as fd: + with open(os.path.join(CURRENT_DIR, 'xgboost/VERSION'), encoding="ascii") as fd: version = fd.read().strip() setup(name='xgboost', diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index 663b5a5a294f..7c1078c13669 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -22,7 +22,7 @@ pass VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION') -with open(VERSION_FILE) as f: +with open(VERSION_FILE, encoding="ascii") as f: __version__ = f.read().strip() __all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster', 'DataIter', diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py index cea85c8c6c6c..77480f79c8a0 100644 --- a/python-package/xgboost/compat.py +++ b/python-package/xgboost/compat.py @@ -70,7 +70,7 @@ class XGBoostLabelEncoder(LabelEncoder): '''Label encoder with JSON serialization methods.''' def to_json(self): '''Returns a JSON compatible dictionary''' - meta = dict() + meta = {} for k, v in self.__dict__.items(): if isinstance(v, np.ndarray): meta[k] = v.tolist() @@ -81,7 +81,7 @@ def to_json(self): def from_json(self, doc): # pylint: disable=attribute-defined-outside-init '''Load the encoder back from a JSON compatible dict.''' - meta = dict() + meta = {} for k, v in doc.items(): if k == 'classes_': self.classes_ = np.array(v) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index e25a15e4371a..0bbfbca62ba2 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2197,7 +2197,8 @@ def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"): """ if isinstance(fout, (STRING_TYPES, os.PathLike)): fout = os.fspath(os.path.expanduser(fout)) - fout = open(fout, 'w') # pylint: disable=consider-using-with + # pylint: disable=consider-using-with + fout = open(fout, 'w', encoding="utf-8") need_close = True else: need_close = False diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 757778f65782..999caae45a6a 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -538,7 +538,7 @@ def get_xgb_params(self) -> Dict[str, Any]: 'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder', "enable_categorical" } - filtered = dict() + filtered = {} for k, v in params.items(): if k not in wrapper_specific and not callable(v): filtered[k] = v @@ -557,7 +557,7 @@ def _get_type(self) -> str: return self._estimator_type # pylint: disable=no-member def save_model(self, fname: Union[str, os.PathLike]) -> None: - meta = dict() + meta = {} for k, v in self.__dict__.items(): if k == '_le': meta['_le'] = self._le.to_json() @@ -596,7 +596,7 @@ def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None: ) return meta = json.loads(meta_str) - states = dict() + states = {} for k, v in meta.items(): if k == '_le': self._le = XGBoostLabelEncoder() diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index f2e6ac629cec..699a3f6277ec 100644 --- a/src/common/device_helpers.cuh 
+++ b/src/common/device_helpers.cuh @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2020 XGBoost contributors + * Copyright 2017-2021 XGBoost contributors */ #pragma once #include @@ -98,24 +98,28 @@ template ::value && !std::is_same::value> * = // NOLINT nullptr> -T __device__ __forceinline__ atomicAdd(T *addr, T v) { // NOLINT +XGBOOST_DEV_INLINE T atomicAdd(T *addr, T v) { // NOLINT using Type = typename dh::detail::AtomicDispatcher::Type; Type ret = ::atomicAdd(reinterpret_cast(addr), static_cast(v)); return static_cast(ret); } - namespace dh { #ifdef XGBOOST_USE_NCCL #define safe_nccl(ans) ThrowOnNcclError((ans), __FILE__, __LINE__) inline ncclResult_t ThrowOnNcclError(ncclResult_t code, const char *file, - int line) { + int line) { if (code != ncclSuccess) { std::stringstream ss; - ss << "NCCL failure :" << ncclGetErrorString(code) << " "; - ss << file << "(" << line << ")"; - throw std::runtime_error(ss.str()); + ss << "NCCL failure :" << ncclGetErrorString(code); + if (code == ncclUnhandledCudaError) { + // nccl usually preserves the last error so we can get more details. + auto err = cudaPeekAtLastError(); + ss << " " << thrust::system_error(err, thrust::cuda_category()).what(); + } + ss << " " << file << "(" << line << ")"; + LOG(FATAL) << ss.str(); } return code; @@ -1104,6 +1108,44 @@ XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest, static_cast(gpair.GetHess())); } +/** + * \brief An atomicAdd designed for gradient pair with better performance. For general + * int64_t atomicAdd, one can simply cast it to unsigned long long. + */ +XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t *dst, int64_t src) { + uint32_t* y_low = reinterpret_cast(dst); + uint32_t *y_high = y_low + 1; + + auto cast_src = reinterpret_cast(&src); + + uint32_t const x_low = static_cast(src); + uint32_t const x_high = (*cast_src) >> 32; + + auto const old = atomicAdd(y_low, x_low); + uint32_t const carry = old > (std::numeric_limits::max() - x_low) ? 1 : 0; + uint32_t const sig = x_high + carry; + atomicAdd(y_high, sig); +} + +XGBOOST_DEV_INLINE void +AtomicAddGpair(xgboost::GradientPairInt64 *dest, + xgboost::GradientPairInt64 const &gpair) { + auto dst_ptr = reinterpret_cast(dest); + auto g = gpair.GetGrad(); + auto h = gpair.GetHess(); + + AtomicAdd64As32(dst_ptr, g); + AtomicAdd64As32(dst_ptr + 1, h); +} + +XGBOOST_DEV_INLINE void +AtomicAddGpair(xgboost::GradientPairInt32 *dest, + xgboost::GradientPairInt32 const &gpair) { + auto dst_ptr = reinterpret_cast(dest); + + ::atomicAdd(dst_ptr, static_cast(gpair.GetGrad())); + ::atomicAdd(dst_ptr + 1, static_cast(gpair.GetHess())); +} // Thrust version of this function causes error on Windows template diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 52ad0d7cd92d..933e6ea37927 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -142,7 +142,7 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) { LOG(INFO) << "Tree method is automatically set to 'approx' " "since external-memory data matrix is used."; tparam_.tree_method = TreeMethod::kApprox; - } else if (fmat->Info().num_row_ >= (4UL << 20UL)) { + } else if (fmat->Info().num_row_ >= (1UL << 18UL)) { /* Choose tree_method='hist' automatically for large data matrix */ LOG(INFO) << "Tree method is automatically selected to be " "'hist' for faster work. To use the old behavior " diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index aae2fbc04da7..791363a05cdd 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -1,5 +1,5 @@ /*! 
- * Copyright 2020 by XGBoost Contributors + * Copyright 2020-2021 by XGBoost Contributors */ #include #include @@ -34,7 +34,7 @@ namespace tree { * to avoid outliers, as the full reduction is reproducible on GPU with reduction tree. */ template -XGBOOST_DEV_INLINE __host__ T CreateRoundingFactor(T max_abs, int n) { +T CreateRoundingFactor(T max_abs, int n) { T delta = max_abs / (static_cast(1.0) - 2 * n * std::numeric_limits::epsilon()); // Calculate ceil(log_2(delta)). @@ -78,7 +78,7 @@ struct Clip : public thrust::unary_function { }; template -GradientSumT CreateRoundingFactor(common::Span gpair) { +HistRounding CreateRoundingFactor(common::Span gpair) { using T = typename GradientSumT::ValueT; dh::XGBCachingDeviceAllocator alloc; @@ -94,26 +94,51 @@ GradientSumT CreateRoundingFactor(common::Span gpair) { gpair.size()), CreateRoundingFactor(std::max(positive_sum.GetHess(), negative_sum.GetHess()), gpair.size()) }; - return histogram_rounding; + + using IntT = typename HistRounding::SharedSumT::ValueT; + + /** + * Factor for converting gradients from fixed-point to floating-point. + */ + GradientSumT to_floating_point = + histogram_rounding / + T(IntT(1) << (sizeof(typename GradientSumT::ValueT) * 8 - + 2)); // keep 1 for sign bit + /** + * Factor for converting gradients from floating-point to fixed-point. For + * f64: + * + * Precision = 64 - 1 - log2(rounding) + * + * rounding is calculated as exp(m), see the rounding factor calculation for + * details. + */ + GradientSumT to_fixed_point = GradientSumT( + T(1) / to_floating_point.GetGrad(), T(1) / to_floating_point.GetHess()); + + return {histogram_rounding, to_fixed_point, to_floating_point}; } -template GradientPairPrecise CreateRoundingFactor(common::Span gpair); -template GradientPair CreateRoundingFactor(common::Span gpair); +template HistRounding +CreateRoundingFactor(common::Span gpair); +template HistRounding +CreateRoundingFactor(common::Span gpair); -template +template __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix, FeatureGroupsAccessor feature_groups, common::Span d_ridx, GradientSumT* __restrict__ d_node_hist, const GradientPair* __restrict__ d_gpair, - GradientSumT const rounding, - bool use_shared_memory_histograms) { + HistRounding const rounding) { + using SharedSumT = typename HistRounding::SharedSumT; using T = typename GradientSumT::ValueT; + extern __shared__ char smem[]; FeatureGroup group = feature_groups[blockIdx.y]; - GradientSumT* smem_arr = reinterpret_cast(smem); // NOLINT + SharedSumT *smem_arr = reinterpret_cast(smem); if (use_shared_memory_histograms) { - dh::BlockFill(smem_arr, group.num_bins, GradientSumT()); + dh::BlockFill(smem_arr, group.num_bins, SharedSumT()); __syncthreads(); } int feature_stride = matrix.is_dense ? group.num_features : matrix.row_stride; @@ -123,16 +148,21 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix, int gidx = matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + idx % feature_stride]; if (gidx != matrix.NumBins()) { - GradientSumT truncated { - TruncateWithRoundingFactor(rounding.GetGrad(), d_gpair[ridx].GetGrad()), - TruncateWithRoundingFactor(rounding.GetHess(), d_gpair[ridx].GetHess()), - }; // If we are not using shared memory, accumulate the values directly into // global memory - GradientSumT* atomic_add_ptr = - use_shared_memory_histograms ? smem_arr : d_node_hist; gidx = use_shared_memory_histograms ?
gidx - group.start_bin : gidx; - dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated); + if (use_shared_memory_histograms) { + auto adjusted = rounding.ToFixedPoint(d_gpair[ridx]); + dh::AtomicAddGpair(smem_arr + gidx, adjusted); + } else { + GradientSumT truncated{ + TruncateWithRoundingFactor(rounding.rounding.GetGrad(), + d_gpair[ridx].GetGrad()), + TruncateWithRoundingFactor(rounding.rounding.GetHess(), + d_gpair[ridx].GetHess()), + }; + dh::AtomicAddGpair(d_node_hist + gidx, truncated); + } } } @@ -140,12 +170,7 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix, // Write shared memory back to global memory __syncthreads(); for (auto i : dh::BlockStrideRange(0, group.num_bins)) { - GradientSumT truncated{ - TruncateWithRoundingFactor(rounding.GetGrad(), - smem_arr[i].GetGrad()), - TruncateWithRoundingFactor(rounding.GetHess(), - smem_arr[i].GetHess()), - }; + auto truncated = rounding.ToFloatingPoint(smem_arr[i]); dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated); } } @@ -157,57 +182,68 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, common::Span gpair, common::Span d_ridx, common::Span histogram, - GradientSumT rounding) { + HistRounding rounding, + bool force_global_memory) { // decide whether to use shared memory int device = 0; dh::safe_cuda(cudaGetDevice(&device)); + // opt into maximum shared memory for the kernel if necessary int max_shared_memory = dh::MaxSharedMemoryOptin(device); - size_t smem_size = sizeof(GradientSumT) * feature_groups.max_group_bins; - bool shared = smem_size <= max_shared_memory; + + size_t smem_size = sizeof(typename HistRounding::SharedSumT) * + feature_groups.max_group_bins; + bool shared = !force_global_memory && smem_size <= max_shared_memory; smem_size = shared ? smem_size : 0; - // opt into maximum shared memory for the kernel if necessary - auto kernel = SharedMemHistKernel; + auto runit = [&](auto kernel) { + if (shared) { + dh::safe_cuda(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_memory)); + } + + // determine the launch configuration + int min_grid_size; + int block_threads = 1024; + dh::safe_cuda(cudaOccupancyMaxPotentialBlockSize( + &min_grid_size, &block_threads, kernel, smem_size, 0)); + + int num_groups = feature_groups.NumGroups(); + int n_mps = 0; + dh::safe_cuda( + cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device)); + int n_blocks_per_mp = 0; + dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &n_blocks_per_mp, kernel, block_threads, smem_size)); + unsigned grid_size = n_blocks_per_mp * n_mps; + + // TODO(canonizer): This is really a hack, find a better way to distribute + // the data among thread blocks. The intention is to generate enough thread + // blocks to fill the GPU, but avoid having too many thread blocks, as this + // is less efficient when the number of rows is low. At least one thread + // block per feature group is required. 
The number of thread blocks: + // - for num_groups <= num_groups_threshold, around grid_size * num_groups + // - for num_groups_threshold <= num_groups <= num_groups_threshold * + // grid_size, + // around grid_size * num_groups_threshold + // - for num_groups_threshold * grid_size <= num_groups, around num_groups + int num_groups_threshold = 4; + grid_size = common::DivRoundUp( + grid_size, common::DivRoundUp(num_groups, num_groups_threshold)); + + using T = typename GradientSumT::ValueT; + dh::LaunchKernel {dim3(grid_size, num_groups), + static_cast(block_threads), + smem_size} (kernel, matrix, feature_groups, d_ridx, + histogram.data(), gpair.data(), rounding); + }; + if (shared) { - dh::safe_cuda(cudaFuncSetAttribute - (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_memory)); + runit(SharedMemHistKernel); + } else { + runit(SharedMemHistKernel); } - // determine the launch configuration - int min_grid_size; - int block_threads = 1024; - dh::safe_cuda(cudaOccupancyMaxPotentialBlockSize( - &min_grid_size, &block_threads, kernel, smem_size, 0)); - - int num_groups = feature_groups.NumGroups(); - int n_mps = 0; - dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device)); - int n_blocks_per_mp = 0; - dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor - (&n_blocks_per_mp, kernel, block_threads, smem_size)); - unsigned grid_size = n_blocks_per_mp * n_mps; - - // TODO(canonizer): This is really a hack, find a better way to distribute the - // data among thread blocks. - // The intention is to generate enough thread blocks to fill the GPU, but - // avoid having too many thread blocks, as this is less efficient when the - // number of rows is low. At least one thread block per feature group is - // required. - // The number of thread blocks: - // - for num_groups <= num_groups_threshold, around grid_size * num_groups - // - for num_groups_threshold <= num_groups <= num_groups_threshold * grid_size, - // around grid_size * num_groups_threshold - // - for num_groups_threshold * grid_size <= num_groups, around num_groups - int num_groups_threshold = 4; - grid_size = common::DivRoundUp(grid_size, - common::DivRoundUp(num_groups, num_groups_threshold)); - - dh::LaunchKernel { - dim3(grid_size, num_groups), static_cast(block_threads), smem_size} ( - kernel, - matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding, - shared); dh::safe_cuda(cudaGetLastError()); } @@ -217,7 +253,8 @@ template void BuildGradientHistogram( common::Span gpair, common::Span ridx, common::Span histogram, - GradientPair rounding); + HistRounding rounding, + bool force_global_memory); template void BuildGradientHistogram( EllpackDeviceAccessor const& matrix, @@ -225,7 +262,8 @@ template void BuildGradientHistogram( common::Span gpair, common::Span ridx, common::Span histogram, - GradientPairPrecise rounding); + HistRounding rounding, + bool force_global_memory); } // namespace tree } // namespace xgboost diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh index 84c79568fe23..a45083f76875 100644 --- a/src/tree/gpu_hist/histogram.cuh +++ b/src/tree/gpu_hist/histogram.cuh @@ -1,5 +1,5 @@ /*! 
- * Copyright 2020 by XGBoost Contributors + * Copyright 2020-2021 by XGBoost Contributors */ #ifndef HISTOGRAM_CUH_ #define HISTOGRAM_CUH_ @@ -12,21 +12,57 @@ namespace xgboost { namespace tree { -template -GradientSumT CreateRoundingFactor(common::Span gpair); - -template -XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x) { +template +XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, U const x) { + static_assert(sizeof(T) >= sizeof(U), "Rounding must have higher or equal precision."); return (rounding_factor + static_cast(x)) - rounding_factor; } +/** + * Truncation factor for gradient, see comments in `CreateRoundingFactor()` for details. + */ +template +struct HistRounding { + /* Factor to truncate the gradient before building histogram for deterministic result. */ + GradientSumT rounding; + /* Convert gradient to fixed point representation. */ + GradientSumT to_fixed_point; + /* Convert fixed point representation back to floating point. */ + GradientSumT to_floating_point; + + /* Type used in shared memory. */ + using SharedSumT = std::conditional_t< + std::is_same::value, + GradientPairInt32, GradientPairInt64>; + using T = typename GradientSumT::ValueT; + + XGBOOST_DEV_INLINE SharedSumT ToFixedPoint(GradientPair const& gpair) const { + auto adjusted = SharedSumT(T(gpair.GetGrad() * to_fixed_point.GetGrad()), + T(gpair.GetHess() * to_fixed_point.GetHess())); + return adjusted; + } + XGBOOST_DEV_INLINE GradientSumT ToFloatingPoint(SharedSumT const &gpair) const { + auto g = gpair.GetGrad() * to_floating_point.GetGrad(); + auto h = gpair.GetHess() * to_floating_point.GetHess(); + GradientSumT truncated{ + TruncateWithRoundingFactor(rounding.GetGrad(), g), + TruncateWithRoundingFactor(rounding.GetHess(), h), + }; + return truncated; + } +}; + +template +HistRounding CreateRoundingFactor(common::Span gpair); + template void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, common::Span gpair, common::Span ridx, common::Span histogram, - GradientSumT rounding); + HistRounding rounding, + bool force_global_memory = false); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 1e2673f055b1..7499293e7860 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -46,14 +46,11 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); struct GPUHistMakerTrainParam : public XGBoostParameter { bool single_precision_histogram; - bool deterministic_histogram; bool debug_synchronize; // declare parameters DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) { DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( "Use single precision to build histograms."); - DMLC_DECLARE_FIELD(deterministic_histogram).set_default(true).describe( - "Pre-round the gradient for obtaining deterministic gradient histogram."); DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe( "Check if all distributed tree are identical after tree construction."); } @@ -153,7 +150,7 @@ class DeviceHistogram { */ common::Span GetNodeHistogram(int nidx) { CHECK(this->HistogramExists(nidx)); - auto ptr = data_.data().get() + nidx_map_[nidx]; + auto ptr = data_.data().get() + nidx_map_.at(nidx); return common::Span( reinterpret_cast(ptr), n_bins_); } @@ -179,9 +176,8 @@ struct GPUHistMakerDevice { std::vector node_sum_gradients; TrainParam param; - bool deterministic_histogram; - GradientSumT histogram_rounding; + 
HistRounding histogram_rounding; dh::PinnedMemory pinned; @@ -205,7 +201,6 @@ struct GPUHistMakerDevice { TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features, - bool deterministic_histogram, BatchParam _batch_param) : device_id(_device_id), page(_page), @@ -214,7 +209,6 @@ struct GPUHistMakerDevice { tree_evaluator(param, n_features, _device_id), column_sampler(column_sampler_seed), interaction_constraints(param, n_features), - deterministic_histogram{deterministic_histogram}, batch_param(_batch_param) { sampler.reset(new GradientBasedSampler( page, _n_rows, batch_param, param.subsample, param.sampling_method)); @@ -227,9 +221,9 @@ struct GPUHistMakerDevice { // Init histogram hist.Init(device_id, page->Cuts().TotalBins()); monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); - feature_groups.reset(new FeatureGroups( - page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id), - sizeof(GradientSumT))); + feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, + dh::MaxSharedMemoryOptin(device_id), + sizeof(GradientSumT))); } ~GPUHistMakerDevice() { // NOLINT @@ -263,11 +257,7 @@ struct GPUHistMakerDevice { page = sample.page; gpair = sample.gpair; - if (deterministic_histogram) { - histogram_rounding = CreateRoundingFactor(this->gpair); - } else { - histogram_rounding = GradientSumT{0.0, 0.0}; - } + histogram_rounding = CreateRoundingFactor(this->gpair); row_partitioner.reset(); // Release the device memory first before reallocating row_partitioner.reset(new RowPartitioner(device_id, sample.sample_rows)); @@ -805,7 +795,6 @@ class GPUHistMakerSpecialised { param_, column_sampling_seed, info_->num_col_, - hist_maker_param_.deterministic_histogram, batch_param)); p_last_fmat_ = dmat; diff --git a/tests/ci_build/build_r_pkg_with_cuda_win64.sh b/tests/ci_build/build_r_pkg_with_cuda_win64.sh new file mode 100644 index 000000000000..f83795775b2c --- /dev/null +++ b/tests/ci_build/build_r_pkg_with_cuda_win64.sh @@ -0,0 +1,34 @@ +#!/bin/bash +set -e +set -x + +if [ "$#" -ne 1 ] +then + echo "Build the R package tarball with CUDA code. Usage: $0 [commit hash]" + exit 1 +fi + +commit_hash="$1" + +MAKE="/c/Rtools/bin/make" /c/Rtools/bin/make Rpack +mv xgboost/ xgboost_rpack/ + +mkdir build +cd build +cmake .. -G"Visual Studio 15 2017 Win64" -DUSE_CUDA=ON -DR_LIB=ON -DLIBR_HOME="c:\\Program Files\\R\\R-3.6.3" +cmake --build . --config Release --parallel +cd .. + +rm xgboost +# This super wacky hack is found in cmake/RPackageInstall.cmake.in and +# cmake/RPackageInstallTargetSetup.cmake. This hack lets us bypass the normal build process of R +# and have R use xgboost.dll that we've already built. 
+rm -v xgboost_rpack/configure +rm -rfv xgboost_rpack/src +mkdir -p xgboost_rpack/src +cp -v lib/xgboost.dll xgboost_rpack/src/ +echo 'all:' > xgboost_rpack/src/Makefile +echo 'all:' > xgboost_rpack/src/Makefile.win +mv xgboost_rpack/ xgboost/ +/c/Rtools/bin/tar -cvf xgboost_r_gpu_win64_${commit_hash}.tar xgboost/ +/c/Rtools/bin/gzip -9c xgboost_r_gpu_win64_${commit_hash}.tar > xgboost_r_gpu_win64_${commit_hash}.tar.gz \ No newline at end of file diff --git a/tests/ci_build/conda_env/win64_cpu_test.yml b/tests/ci_build/conda_env/win64_cpu_test.yml index dab8e2d2d987..25378d52f552 100644 --- a/tests/ci_build/conda_env/win64_cpu_test.yml +++ b/tests/ci_build/conda_env/win64_cpu_test.yml @@ -2,7 +2,7 @@ name: win64_env channels: - conda-forge dependencies: -- python +- python=3.8 - wheel - numpy - scipy @@ -16,6 +16,5 @@ dependencies: - jsonschema - hypothesis - jsonschema +- python-graphviz - pip -- pip: - - graphviz diff --git a/tests/ci_build/conda_env/win64_test.yml b/tests/ci_build/conda_env/win64_test.yml index f353c8af7358..f8eb30f6a354 100644 --- a/tests/ci_build/conda_env/win64_test.yml +++ b/tests/ci_build/conda_env/win64_test.yml @@ -2,7 +2,7 @@ name: win64_env channels: - conda-forge dependencies: -- python=3.7 +- python=3.8 - numpy - scipy - matplotlib @@ -12,8 +12,7 @@ dependencies: - boto3 - hypothesis - jsonschema +- cupy +- python-graphviz +- modin-ray - pip -- pip: - - cupy-cuda101 - - modin[all] - - graphviz diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index cb7176c00758..6e8668bd2581 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -1,7 +1,10 @@ /*! - * Copyright 2017 XGBoost contributors + * Copyright 2017-2021 XGBoost contributors */ +#include +#include #include +#include #include #include "../../../src/common/device_helpers.cuh" #include "../../../src/common/quantile.h" @@ -101,8 +104,6 @@ struct IsSorted { } // namespace namespace xgboost { -namespace common { - void TestSegmentedUniqueRegression(std::vector values, size_t n_duplicated) { std::vector segments{0, static_cast(values.size())}; @@ -194,5 +195,73 @@ TEST(DeviceHelpers, ArgSort) { ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(), thrust::greater{})); } -} // namespace common + +namespace { +// Atomic add as type cast for test. +XGBOOST_DEV_INLINE int64_t atomicAdd(int64_t *dst, int64_t src) { // NOLINT + uint64_t* u_dst = reinterpret_cast(dst); + uint64_t u_src = *reinterpret_cast(&src); + uint64_t ret = ::atomicAdd(u_dst, u_src); + return *reinterpret_cast(&ret); +} +} + +void TestAtomicAdd() { + size_t n_elements = 1024; + dh::device_vector result_a(1, 0); + auto d_result_a = result_a.data().get(); + + dh::device_vector result_b(1, 0); + auto d_result_b = result_b.data().get(); + + /** + * Test for simple inputs + */ + std::vector h_inputs(n_elements); + for (size_t i = 0; i < h_inputs.size(); ++i) { + h_inputs[i] = (i % 2 == 0) ? i : -i; + } + dh::device_vector inputs(h_inputs); + auto d_inputs = inputs.data().get(); + + dh::LaunchN(n_elements, [=] __device__(size_t i) { + dh::AtomicAdd64As32(d_result_a, d_inputs[i]); + atomicAdd(d_result_b, d_inputs[i]); + }); + ASSERT_EQ(result_a[0], result_b[0]); + + /** + * Test for positive values that don't fit into 32 bit integer. 
+ */ + thrust::fill(inputs.begin(), inputs.end(), + (std::numeric_limits::max() / 2)); + thrust::fill(result_a.begin(), result_a.end(), 0); + thrust::fill(result_b.begin(), result_b.end(), 0); + dh::LaunchN(n_elements, [=] __device__(size_t i) { + dh::AtomicAdd64As32(d_result_a, d_inputs[i]); + atomicAdd(d_result_b, d_inputs[i]); + }); + ASSERT_EQ(result_a[0], result_b[0]); + ASSERT_GT(result_a[0], std::numeric_limits::max()); + CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]); + + /** + * Test for negative values that don't fit into 32 bit integer. + */ + thrust::fill(inputs.begin(), inputs.end(), + (std::numeric_limits::min() / 2)); + thrust::fill(result_a.begin(), result_a.end(), 0); + thrust::fill(result_b.begin(), result_b.end(), 0); + dh::LaunchN(n_elements, [=] __device__(size_t i) { + dh::AtomicAdd64As32(d_result_a, d_inputs[i]); + atomicAdd(d_result_b, d_inputs[i]); + }); + ASSERT_EQ(result_a[0], result_b[0]); + ASSERT_LT(result_a[0], std::numeric_limits::min()); + CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]); +} + +TEST(AtomicAdd, Int64) { + TestAtomicAdd(); +} } // namespace xgboost diff --git a/tests/cpp/objective/test_ranking_obj_gpu.cu b/tests/cpp/objective/test_ranking_obj_gpu.cu index dc8fd267d059..4cf736bf6f9b 100644 --- a/tests/cpp/objective/test_ranking_obj_gpu.cu +++ b/tests/cpp/objective/test_ranking_obj_gpu.cu @@ -1,5 +1,9 @@ -#include "test_ranking_obj.cc" +/*! + * Copyright 2019-2021 by XGBoost Contributors + */ +#include +#include "test_ranking_obj.cc" #include "../../../src/objective/rank_obj.cu" namespace xgboost { diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu index 4879ca080937..9b16cca5362d 100644 --- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu +++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu @@ -1,8 +1,13 @@ +/*! + * Copyright 2019-2021 by XGBoost Contributors + */ #include #include #include +#include #include + #include "../../../../src/tree/gpu_hist/row_partitioner.cuh" #include "../../helpers.h" diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 591dc43d27a0..72c22539679f 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -1,8 +1,9 @@ /*! 
- * Copyright 2017-2020 XGBoost contributors + * Copyright 2017-2021 XGBoost contributors */ #include #include +#include #include #include #include @@ -80,8 +81,8 @@ void TestBuildHist(bool use_shared_memory_histograms) { param.Init(args); auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; - GPUHistMakerDevice maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, - true, batch_param); + GPUHistMakerDevice maker(0, page.get(), {}, kNRows, param, + kNCols, kNCols, batch_param); xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); HostDeviceVector gpair(kNRows); @@ -93,14 +94,18 @@ void TestBuildHist(bool use_shared_memory_histograms) { gpair.SetDevice(0); thrust::host_vector h_gidx_buffer (page->gidx_buffer.HostVector()); - - maker.row_partitioner.reset(new RowPartitioner(0, kNRows)); maker.hist.AllocateHistogram(0); maker.gpair = gpair.DeviceSpan(); + maker.histogram_rounding = CreateRoundingFactor(maker.gpair);; + + BuildGradientHistogram( + page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0), + gpair.DeviceSpan(), maker.row_partitioner->GetRows(0), + maker.hist.GetNodeHistogram(0), maker.histogram_rounding, + !use_shared_memory_histograms); - maker.BuildHist(0); - DeviceHistogram d_hist = maker.hist; + DeviceHistogram& d_hist = maker.hist; auto node_histogram = d_hist.GetNodeHistogram(0); // d_hist.data stored in float, not gradient pair @@ -115,6 +120,7 @@ void TestBuildHist(bool use_shared_memory_histograms) { std::vector solution = GetHostHistGpair(); std::cout << std::fixed; for (size_t i = 0; i < h_result.size(); ++i) { + ASSERT_FALSE(std::isnan(h_result[i].GetGrad())); EXPECT_NEAR(h_result[i].GetGrad(), solution[i].GetGrad(), 0.01f); EXPECT_NEAR(h_result[i].GetHess(), solution[i].GetHess(), 0.01f); } @@ -158,7 +164,8 @@ TEST(GpuHist, ApplySplit) { HostDeviceVector feature_types(10, FeatureType::kCategorical); feature_types.SetDevice(bparam.gpu_id); tree::GPUHistMakerDevice updater( - 0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, true, bparam); + 0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, + bparam); updater.ApplySplit(candidate, &tree); ASSERT_EQ(tree.GetSplitTypes().size(), 3); @@ -217,8 +224,8 @@ TEST(GpuHist, EvaluateRootSplit) { // Initialize GPUHistMakerDevice auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; - GPUHistMakerDevice - maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param); + GPUHistMakerDevice maker( + 0, page.get(), {}, kNRows, param, kNCols, kNCols, batch_param); // Initialize GPUHistMakerDevice::node_sum_gradients maker.node_sum_gradients = {}; diff --git a/tests/python-gpu/test_gpu_basic_models.py b/tests/python-gpu/test_gpu_basic_models.py index dc556fdaefaf..3f40999861b9 100644 --- a/tests/python-gpu/test_gpu_basic_models.py +++ b/tests/python-gpu/test_gpu_basic_models.py @@ -55,9 +55,6 @@ def test_deterministic_gpu_hist(self): model_0, model_1 = self.run_cls(X, y, True) assert model_0 == model_1 - model_0, model_1 = self.run_cls(X, y, False) - assert model_0 != model_1 - def test_invalid_gpu_id(self): X = np.random.randn(10, 5) * 1e4 y = np.random.randint(0, 2, size=10) * 1e4
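
The carry handling in the new ``dh::AtomicAdd64As32`` can be sanity-checked on the host with a small stand-alone program. The following is only a sketch: it is single-threaded, uses plain stores where the device code uses ``::atomicAdd``, assumes a little-endian host so that the low word sits at the lower address, and the helper name ``Add64As32`` is hypothetical rather than part of the library.

.. code-block:: cpp

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <vector>

    // Host-side, single-threaded sketch of the carry propagation used by
    // dh::AtomicAdd64As32: a signed 64-bit addition is expressed as two
    // unsigned 32-bit additions, with the carry out of the low word folded
    // into the high word.
    void Add64As32(int64_t* dst, int64_t src) {
      auto* y_low = reinterpret_cast<uint32_t*>(dst);   // little-endian assumed
      auto* y_high = y_low + 1;

      uint32_t const x_low = static_cast<uint32_t>(src);
      uint32_t const x_high =
          static_cast<uint32_t>(static_cast<uint64_t>(src) >> 32);

      uint32_t const old = *y_low;  // what atomicAdd would have returned
      *y_low = old + x_low;         // may wrap around, which is the intent
      uint32_t const carry =
          old > (std::numeric_limits<uint32_t>::max() - x_low) ? 1u : 0u;
      *y_high += x_high + carry;
    }

    int main() {
      std::vector<int64_t> values;
      for (int64_t i = 0; i < 1024; ++i) {
        values.push_back(i % 2 == 0 ? i : -i);
      }
      // Force several low-word overflows, mirroring the unit test in the diff.
      for (int i = 0; i < 8; ++i) {
        values.push_back(std::numeric_limits<int64_t>::max() / 2);
        values.push_back(std::numeric_limits<int64_t>::min() / 2);
      }

      int64_t split_sum = 0;
      int64_t direct_sum = 0;
      for (int64_t v : values) {
        Add64As32(&split_sum, v);
        direct_sum += v;
      }
      assert(split_sum == direct_sum);
      std::cout << "split=" << split_sum << " direct=" << direct_sum << std::endl;
      return 0;
    }

Per the comment added in device_helpers.cuh, the motivation for splitting the 64-bit add into two 32-bit adds is performance relative to simply casting to unsigned long long and using the generic 64-bit atomicAdd.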
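
The fixed-point round trip introduced by ``HistRounding`` reduces to two scale factors that are exact inverses of each other: gradients are multiplied by ``to_fixed_point`` before being accumulated with integer atomics in shared memory, and each bin is multiplied by ``to_floating_point`` once at the end. A minimal scalar sketch, assuming a hand-picked rounding factor of 2^10 and plain ``double``/``int64_t`` in place of the GradientPair types:

.. code-block:: cpp

    #include <cmath>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Assume every partial sum stays below 2^10; the library computes this
      // bound with CreateRoundingFactor() instead.
      double const rounding = std::ldexp(1.0, 10);
      // 2^62, matching `IntT(1) << (sizeof(ValueT) * 8 - 2)` in the diff for
      // the double / int64_t case.
      double const to_floating_point = rounding / std::ldexp(1.0, 62);
      double const to_fixed_point = 1.0 / to_floating_point;

      std::vector<double> gradients = {0.25, -1.5, 3.75, -0.125, 2.0};

      int64_t bin = 0;      // what a shared-memory histogram bin accumulates
      double reference = 0.0;
      for (double g : gradients) {
        bin += static_cast<int64_t>(g * to_fixed_point);  // ToFixedPoint()
        reference += g;
      }
      double recovered = bin * to_floating_point;          // ToFloatingPoint()
      std::cout << "fixed point: " << recovered
                << " floating point: " << reference << std::endl;
      return 0;
    }

Because the integer accumulation is exact, the result no longer depends on the order in which rows hit a bin, which is what lets the deterministic_histogram switch be removed.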
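
The global-memory path still relies on ``TruncateWithRoundingFactor()``, i.e. on the pre-rounding identity (rounding + x) - rounding. The sketch below illustrates why this makes the sum independent of accumulation order; it uses a hand-picked rounding factor rather than one computed by ``CreateRoundingFactor()``, and a scalar ``Truncate()`` standing in for the templated helper. Build without ``-ffast-math``, which would fold the expression away.

.. code-block:: cpp

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // (rounding + x) - rounding discards the low-order bits of x, so every
    // truncated value shares the same ulp and later additions incur no
    // rounding, making the sum independent of accumulation order at the cost
    // of a little accuracy.
    float Truncate(float rounding, float x) { return (rounding + x) - rounding; }

    int main() {
      float const rounding = std::ldexp(1.0f, 29);  // assumes |partial sums| < 2^29
      std::vector<float> values = {1e8f, 1.0f, -1e8f, 1.0f};

      float forward = 0.0f;
      float backward = 0.0f;
      for (std::size_t i = 0; i < values.size(); ++i) {
        forward += Truncate(rounding, values[i]);
      }
      for (std::size_t i = values.size(); i-- > 0;) {
        backward += Truncate(rounding, values[i]);
      }
      // Without truncation these two orders give different results for this
      // input; with it they are bitwise identical.
      std::printf("forward=%g backward=%g\n", forward, backward);
      return 0;
    }

The truncated sum here is 0 while the exact sum is 2, which is the slight loss of accuracy the removed deterministic_histogram documentation referred to.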