From 4212aba02712e0617b679e6117adf39650d657e0 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 15 Aug 2018 23:35:47 +0800 Subject: [PATCH] Implement transform to reduce CPU/GPU code duplication. * Implement Transform class. * Add tests for softmax. * Use Transform in regression, softmax and hinge objectives, except for Cox. * Mark old gpu objective functions deprecated. * static_assert for softmax. * Split up multi-gpu tests. --- src/common/common.cc | 4 +- src/common/common.cu | 2 +- src/common/common.h | 36 +- src/common/device_helpers.cuh | 21 +- src/common/host_device_vector.cu | 22 +- src/common/host_device_vector.h | 7 +- src/common/math.h | 36 +- src/common/span.h | 10 +- src/common/transform.h | 203 +++++++ src/objective/hinge.cc | 73 +-- src/objective/hinge.cu | 109 ++++ src/objective/multiclass_obj.cc | 139 +---- src/objective/multiclass_obj.cu | 195 ++++++ src/objective/objective.cc | 11 +- src/objective/regression_obj.cc | 424 +------------ src/objective/regression_obj.cu | 560 ++++++++++++++++++ src/objective/regression_obj_gpu.cu | 202 ------- .../common/test_gpu_compressed_iterator.cu | 4 +- tests/cpp/common/test_gpu_hist_util.cu | 16 +- tests/cpp/common/test_host_device_vector.cu | 77 ++- tests/cpp/common/test_span.h | 16 +- tests/cpp/common/test_transform_range.cc | 61 ++ tests/cpp/common/test_transform_range.cu | 43 ++ tests/cpp/helpers.cc | 14 + tests/cpp/helpers.h | 12 + tests/cpp/objective/test_hinge.cc | 8 +- tests/cpp/objective/test_hinge.cu | 1 + tests/cpp/objective/test_multiclass_obj.cc | 60 ++ .../cpp/objective/test_multiclass_obj_gpu.cu | 1 + tests/cpp/objective/test_regression_obj.cc | 65 +- .../cpp/objective/test_regression_obj_gpu.cu | 78 +-- 31 files changed, 1513 insertions(+), 997 deletions(-) create mode 100644 src/common/transform.h create mode 100644 src/objective/hinge.cu create mode 100644 src/objective/multiclass_obj.cu create mode 100644 src/objective/regression_obj.cu delete mode 100644 src/objective/regression_obj_gpu.cu create mode 100644 tests/cpp/common/test_transform_range.cc create mode 100644 tests/cpp/common/test_transform_range.cu create mode 100644 tests/cpp/objective/test_hinge.cu create mode 100644 tests/cpp/objective/test_multiclass_obj.cc create mode 100644 tests/cpp/objective/test_multiclass_obj_gpu.cu diff --git a/src/common/common.cc b/src/common/common.cc index c9899bc99f0e..e1602004b5c2 100644 --- a/src/common/common.cc +++ b/src/common/common.cc @@ -1,9 +1,11 @@ /*! - * Copyright 2015 by Contributors + * Copyright 2015-2018 by Contributors * \file common.cc * \brief Enable all kinds of global variables in common. */ #include + +#include "common.h" #include "./random.h" namespace xgboost { diff --git a/src/common/common.cu b/src/common/common.cu index bf434b272257..d51bd8c262db 100644 --- a/src/common/common.cu +++ b/src/common/common.cu @@ -11,7 +11,7 @@ int AllVisibleImpl::AllVisible() { // When compiled with CUDA but running on CPU only device, // cudaGetDeviceCount will fail. dh::safe_cuda(cudaGetDeviceCount(&n_visgpus)); - } catch(const std::exception& e) { + } catch(const thrust::system::system_error& err) { return 0; } return n_visgpus; diff --git a/src/common/common.h b/src/common/common.h index ead260000cfe..f521d972d417 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -1,5 +1,5 @@ /*! - * Copyright 2015 by Contributors + * Copyright 2015-2018 by Contributors * \file common.h * \brief Common utilities */ @@ -19,6 +19,13 @@ #if defined(__CUDACC__) #include #include + +#define WITH_CUDA() true + +#else + +#define WITH_CUDA() false + #endif namespace dh { @@ -29,11 +36,11 @@ namespace dh { #define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__) inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, - int line) { + int line) { if (code != cudaSuccess) { - throw thrust::system_error(code, thrust::cuda_category(), - std::string{file} + "(" + // NOLINT - std::to_string(line) + ")"); + LOG(FATAL) << thrust::system_error(code, thrust::cuda_category(), + std::string{file} + ": " + // NOLINT + std::to_string(line)).what(); } return code; } @@ -70,13 +77,13 @@ inline std::string ToString(const T& data) { */ class Range { public: + using DifferenceType = int64_t; + class Iterator { friend class Range; public: - using DifferenceType = int64_t; - - XGBOOST_DEVICE int64_t operator*() const { return i_; } + XGBOOST_DEVICE DifferenceType operator*() const { return i_; } XGBOOST_DEVICE const Iterator &operator++() { i_ += step_; return *this; @@ -97,8 +104,8 @@ class Range { XGBOOST_DEVICE void Step(DifferenceType s) { step_ = s; } protected: - XGBOOST_DEVICE explicit Iterator(int64_t start) : i_(start) {} - XGBOOST_DEVICE explicit Iterator(int64_t start, int step) : + XGBOOST_DEVICE explicit Iterator(DifferenceType start) : i_(start) {} + XGBOOST_DEVICE explicit Iterator(DifferenceType start, DifferenceType step) : i_{start}, step_{step} {} public: @@ -109,9 +116,10 @@ class Range { XGBOOST_DEVICE Iterator begin() const { return begin_; } // NOLINT XGBOOST_DEVICE Iterator end() const { return end_; } // NOLINT - XGBOOST_DEVICE Range(int64_t begin, int64_t end) + XGBOOST_DEVICE Range(DifferenceType begin, DifferenceType end) : begin_(begin), end_(end) {} - XGBOOST_DEVICE Range(int64_t begin, int64_t end, Iterator::DifferenceType step) + XGBOOST_DEVICE Range(DifferenceType begin, DifferenceType end, + DifferenceType step) : begin_(begin, step), end_(end) {} XGBOOST_DEVICE bool operator==(const Range& other) const { @@ -121,9 +129,7 @@ class Range { return !(*this == other); } - XGBOOST_DEVICE void Step(Iterator::DifferenceType s) { begin_.Step(s); } - - XGBOOST_DEVICE Iterator::DifferenceType GetStep() const { return begin_.step_; } + XGBOOST_DEVICE void Step(DifferenceType s) { begin_.Step(s); } private: Iterator begin_; diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index c2d8a75d5fe2..81abed25c4d7 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -9,6 +9,7 @@ #include #include "common.h" +#include "span.h" #include #include @@ -955,7 +956,7 @@ class SaveCudaContext { // cudaGetDevice will fail. try { safe_cuda(cudaGetDevice(&saved_device_)); - } catch (thrust::system::system_error & err) { + } catch (const thrust::system::system_error & err) { saved_device_ = -1; } func(); @@ -1035,4 +1036,22 @@ ReduceT ReduceShards(std::vector *shards, FunctionT f) { }; return std::accumulate(sums.begin(), sums.end(), ReduceT()); } + +template ::index_type> +xgboost::common::Span ToSpan( + thrust::device_vector& vec, + IndexT offset = 0, + IndexT size = -1) { + size = size == -1 ? vec.size() : size; + CHECK_LE(offset + size, vec.size()); + return {vec.data().get() + offset, static_cast(size)}; +} + +template +xgboost::common::Span ToSpan(thrust::device_vector& vec, + size_t offset, size_t size) { + using IndexT = typename xgboost::common::Span::index_type; + return ToSpan(vec, static_cast(offset), static_cast(size)); +} } // namespace dh diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 9c979df9d42c..5477394b7856 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -116,6 +116,7 @@ struct HostDeviceVectorImpl { int ndevices = vec_->distribution_.devices_.Size(); start_ = vec_->distribution_.ShardStart(new_size, index_); proper_size_ = vec_->distribution_.ShardProperSize(new_size, index_); + // The size on this device. size_t size_d = vec_->distribution_.ShardSize(new_size, index_); SetDevice(); data_.resize(size_d); @@ -230,7 +231,7 @@ struct HostDeviceVectorImpl { CHECK(devices.Contains(device)); LazySyncDevice(device, GPUAccess::kWrite); return {shards_[devices.Index(device)].data_.data().get(), - static_cast::index_type>(DeviceSize(device))}; + static_cast::index_type>(DeviceSize(device))}; } common::Span ConstDeviceSpan(int device) { @@ -238,7 +239,7 @@ struct HostDeviceVectorImpl { CHECK(devices.Contains(device)); LazySyncDevice(device, GPUAccess::kRead); return {shards_[devices.Index(device)].data_.data().get(), - static_cast::index_type>(DeviceSize(device))}; + static_cast::index_type>(DeviceSize(device))}; } size_t DeviceSize(int device) { @@ -289,7 +290,6 @@ struct HostDeviceVectorImpl { data_h_.size() * sizeof(T), cudaMemcpyHostToDevice)); } else { - // dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); }); } } @@ -304,14 +304,20 @@ struct HostDeviceVectorImpl { void Copy(HostDeviceVectorImpl* other) { CHECK_EQ(Size(), other->Size()); + // Data is on host. if (perm_h_.CanWrite() && other->perm_h_.CanWrite()) { std::copy(other->data_h_.begin(), other->data_h_.end(), data_h_.begin()); - } else { - CHECK(distribution_ == other->distribution_); - dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) { - shard.Copy(&other->shards_[i]); - }); + return; } + // Data is on device; + if (distribution_ != other->distribution_) { + distribution_ = GPUDistribution(); + Reshard(other->Distribution()); + size_d_ = other->size_d_; + } + dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) { + shard.Copy(&other->shards_[i]); + }); } void Copy(const std::vector& other) { diff --git a/src/common/host_device_vector.h b/src/common/host_device_vector.h index 881cf06423b3..8daa19fe436b 100644 --- a/src/common/host_device_vector.h +++ b/src/common/host_device_vector.h @@ -111,8 +111,11 @@ class GPUDistribution { } friend bool operator==(const GPUDistribution& a, const GPUDistribution& b) { - return a.devices_ == b.devices_ && a.granularity_ == b.granularity_ && - a.overlap_ == b.overlap_ && a.offsets_ == b.offsets_; + bool const res = a.devices_ == b.devices_ && + a.granularity_ == b.granularity_ && + a.overlap_ == b.overlap_ && + a.offsets_ == b.offsets_; + return res; } friend bool operator!=(const GPUDistribution& a, const GPUDistribution& b) { diff --git a/src/common/math.h b/src/common/math.h index be2598e3a108..d6019beadbd8 100644 --- a/src/common/math.h +++ b/src/common/math.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "avx_helpers.h" namespace xgboost { @@ -29,22 +30,31 @@ inline avx::Float8 Sigmoid(avx::Float8 x) { } /*! - * \brief do inplace softmax transformaton on p_rec - * \param p_rec the input/output vector of the values. + * \brief Do inplace softmax transformaton on start to end + * + * \tparam Iterator Input iterator type + * + * \param start Start iterator of input + * \param end end iterator of input */ -inline void Softmax(std::vector* p_rec) { - std::vector &rec = *p_rec; - float wmax = rec[0]; - for (size_t i = 1; i < rec.size(); ++i) { - wmax = std::max(rec[i], wmax); +template +XGBOOST_DEVICE inline void Softmax(Iterator start, Iterator end) { + static_assert(std::is_same().operator*())>::type + >::value, + "Values should be of type bst_float"); + bst_float wmax = *start; + for (Iterator i = start+1; i != end; ++i) { + wmax = fmaxf(*i, wmax); } double wsum = 0.0f; - for (float & elem : rec) { - elem = std::exp(elem - wmax); - wsum += elem; + for (Iterator i = start; i != end; ++i) { + *i = expf(*i - wmax); + wsum += *i; } - for (float & elem : rec) { - elem /= static_cast(wsum); + for (Iterator i = start; i != end; ++i) { + *i /= static_cast(wsum); } } @@ -56,7 +66,7 @@ inline void Softmax(std::vector* p_rec) { * \tparam Iterator The type of the iterator. */ template -inline Iterator FindMaxIndex(Iterator begin, Iterator end) { +XGBOOST_DEVICE inline Iterator FindMaxIndex(Iterator begin, Iterator end) { Iterator maxit = begin; for (Iterator it = begin; it != end; ++it) { if (*it > *maxit) maxit = it; diff --git a/src/common/span.h b/src/common/span.h index 75cfd7059506..a618b682622b 100644 --- a/src/common/span.h +++ b/src/common/span.h @@ -49,7 +49,7 @@ * * https://github.com/Microsoft/GSL/pull/664 * - * FIXME: Group these MSVC workarounds into a manageable place. + * TODO(trivialfis): Group these MSVC workarounds into a manageable place. */ #if defined(_MSC_VER) && _MSC_VER < 1910 @@ -68,7 +68,7 @@ namespace xgboost { namespace common { // Usual logging facility is not available inside device code. -// FIXME: Make dmlc check more generic. +// TODO(trivialfis): Make dmlc check more generic. #define KERNEL_CHECK(cond) \ do { \ if (!(cond)) { \ @@ -104,11 +104,11 @@ constexpr detail::ptrdiff_t dynamic_extent = -1; // NOLINT enum class byte : unsigned char {}; // NOLINT -namespace detail { - -template +template class Span; +namespace detail { + template class SpanIterator { using ElementType = typename SpanType::element_type; diff --git a/src/common/transform.h b/src/common/transform.h new file mode 100644 index 000000000000..e854f4d1cbf1 --- /dev/null +++ b/src/common/transform.h @@ -0,0 +1,203 @@ +/*! + * Copyright 2018 XGBoost contributors + */ +#ifndef XGBOOST_COMMON_TRANSFORM_H_ +#define XGBOOST_COMMON_TRANSFORM_H_ + +#include +#include +#include +#include // enable_if + +#include "host_device_vector.h" +#include "common.h" +#include "span.h" + +#if defined (__CUDACC__) +#include "device_helpers.cuh" +#endif + +namespace xgboost { +namespace common { + +constexpr size_t kBlockThreads = 256; + +namespace detail { + +#if defined(__CUDACC__) +template +__global__ void LaunchCUDAKernel(Functor _func, Range _range, + SpanType... _spans) { + for (auto i : dh::GridStrideRange(*_range.begin(), *_range.end())) { + _func(i, _spans...); + } +} +#endif + +} // namespace detail + +/*! \brief Do Transformation on HostDeviceVectors. + * + * \tparam CompiledWithCuda A bool parameter used to distinguish compilation + * trajectories, users do not need to use it. + * + * Note: Using Transform is a VERY tricky thing to do. Transform uses template + * argument to duplicate itself into two different types, one for CPU, + * another for CUDA. The trick is not without its flaw: + * + * If you use it in a function that can be compiled by both nvcc and host + * compiler, the behaviour is un-defined! Because your function is NOT + * duplicated by `CompiledWithCuda`. At link time, cuda compiler resolution + * will merge functions with same signature. + */ +template +class Transform { + private: + template + struct Evaluator { + public: + Evaluator(Functor func, const Range& range, const GPUSet& devices, + bool reshard) : + func_(func), range_{range}, reshard_{reshard}, + distribution_{GPUDistribution::Block(devices)} {} + Evaluator(Functor func, const Range& range, const GPUDistribution& dist, + bool reshard) : + func_(func), range_{range}, reshard_{reshard}, distribution_{dist} {} + + /*! + * \brief Evaluate the functor with input pointers to HostDeviceVector. + * + * \tparam HDV... HostDeviceVectors type. + * \param vectors Pointers to HostDeviceVector. + */ + template + void Eval(HDV... vectors) const { + bool on_device = !distribution_.IsEmpty(); + + if (on_device) { + LaunchCUDA(func_, vectors...); + } else { + LaunchCPU(func_, vectors...); + } + } + + private: + // CUDA UnpackHDV + template + Span UnpackHDV(HostDeviceVector* _vec, int _device) const { + return _vec->DeviceSpan(_device); + } + template + Span UnpackHDV(const HostDeviceVector* _vec, int _device) const { + return _vec->ConstDeviceSpan(_device); + } + // CPU UnpackHDV + template + Span UnpackHDV(HostDeviceVector* _vec) const { + return Span {_vec->HostPointer(), + static_cast::index_type>(_vec->Size())}; + } + template + Span UnpackHDV(const HostDeviceVector* _vec) const { + return Span {_vec->ConstHostPointer(), + static_cast::index_type>(_vec->Size())}; + } + // Recursive unpack for Reshard. + template + void UnpackReshard(GPUDistribution dist, const HostDeviceVector* vector) const { + vector->Reshard(dist); + } + template + void UnpackReshard(GPUDistribution dist, + const HostDeviceVector* _vector, + const HostDeviceVector*... _vectors) const { + _vector->Reshard(dist); + UnpackReshard(dist, _vectors...); + } + +#if defined(__CUDACC__) + template ::type* = nullptr, + typename... HDV> + void LaunchCUDA(Functor _func, HDV*... _vectors) const { + if (reshard_) + UnpackReshard(distribution_, _vectors...); + + GPUSet devices = distribution_.Devices(); + size_t range_size = *range_.end() - *range_.begin(); +#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1) + for (omp_ulong i = 0; i < devices.Size(); ++i) { + int d = devices.Index(i); + // Ignore other attributes of GPUDistribution for spliting index. + size_t shard_size = + GPUDistribution::Block(devices).ShardSize(range_size, d); + Range shard_range {0, static_cast(shard_size)}; + dh::safe_cuda(cudaSetDevice(d)); + const int GRID_SIZE = + static_cast(dh::DivRoundUp(*(range_.end()), kBlockThreads)); + + detail::LaunchCUDAKernel<<>>( + _func, shard_range, UnpackHDV(_vectors, d)...); + dh::safe_cuda(cudaGetLastError()); + dh::safe_cuda(cudaDeviceSynchronize()); + } + } +#else + /*! \brief Dummy funtion defined when compiling for CPU. */ + template ::type* = nullptr, + typename... HDV> + void LaunchCUDA(Functor _func, HDV*... _vectors) const { + LOG(FATAL) << "Not part of device code. WITH_CUDA: " << WITH_CUDA(); + } +#endif + + template + void LaunchCPU(Functor func, HDV*... vectors) const { + auto end = *(range_.end()); +#pragma omp parallel for schedule(static) + for (omp_ulong idx = 0; idx < end; ++idx) { + func(idx, UnpackHDV(vectors)...); + } + } + + private: + /*! \brief Callable object. */ + Functor func_; + /*! \brief Range object specifying parallel threads index range. */ + Range range_; + /*! \brief Whether resharding for vectors is required. */ + bool reshard_; + GPUDistribution distribution_; + }; + + public: + /*! + * \brief Initialize a Transform object. + * + * \tparam Functor A callable object type. + * \return A Evaluator having one method Eval. + * + * \param func A callable object, accepting a size_t thread index, + * followed by a set of Span classes. + * \param range Range object specifying parallel threads index range. + * \param devices GPUSet specifying GPUs to use, when compiling for CPU, + * this should be GPUSet::Empty(). + * \param reshard Whether Reshard for HostDeviceVector is needed. + */ + template + static Evaluator Init(Functor func, Range const range, + GPUSet const devices, + bool const reshard = true) { + return Evaluator {func, range, devices, reshard}; + } + template + static Evaluator Init(Functor func, Range const range, + GPUDistribution const dist, + bool const reshard = true) { + return Evaluator {func, range, dist, reshard}; + } +}; + +} // namespace common +} // namespace xgboost + +#endif // XGBOOST_COMMON_TRANSFORM_H_ diff --git a/src/objective/hinge.cc b/src/objective/hinge.cc index 503cd1e924e2..f7cdeeb86bed 100644 --- a/src/objective/hinge.cc +++ b/src/objective/hinge.cc @@ -1,73 +1,18 @@ /*! - * Copyright 2018 by Contributors - * \file hinge.cc - * \brief Provides an implementation of the hinge loss function - * \author Henry Gouk + * Copyright 2018 XGBoost contributors */ -#include -#include "../common/math.h" +// Dummy file to keep the CUDA conditional compile trick. + +#include namespace xgboost { namespace obj { -DMLC_REGISTRY_FILE_TAG(hinge); - -class HingeObj : public ObjFunction { - public: - HingeObj() = default; - - void Configure( - const std::vector > &args) override { - // This objective does not take any parameters - } - - void GetGradient(const HostDeviceVector &preds, - const MetaInfo &info, - int iter, - HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) - << "labels are not correctly provided" - << "preds.size=" << preds.Size() - << ", label.size=" << info.labels_.Size(); - const auto& preds_h = preds.HostVector(); - const auto& labels_h = info.labels_.HostVector(); - const auto& weights_h = info.weights_.HostVector(); - - out_gpair->Resize(preds_h.size()); - auto& gpair = out_gpair->HostVector(); - - for (size_t i = 0; i < preds_h.size(); ++i) { - auto y = labels_h[i] * 2.0 - 1.0; - bst_float p = preds_h[i]; - bst_float w = weights_h.size() > 0 ? weights_h[i] : 1.0f; - bst_float g, h; - if (p * y < 1.0) { - g = -y * w; - h = w; - } else { - g = 0.0; - h = std::numeric_limits::min(); - } - gpair[i] = GradientPair(g, h); - } - } - - void PredTransform(HostDeviceVector *io_preds) override { - std::vector &preds = io_preds->HostVector(); - for (auto& p : preds) { - p = p > 0.0 ? 1.0 : 0.0; - } - } - - const char* DefaultEvalMetric() const override { - return "error"; - } -}; - -XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge") -.describe("Hinge loss. Expects labels to be in [0,1f]") -.set_body([]() { return new HingeObj(); }); +DMLC_REGISTRY_FILE_TAG(hinge_obj); } // namespace obj } // namespace xgboost + +#ifndef XGBOOST_USE_CUDA +#include "hinge.cu" +#endif diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu new file mode 100644 index 000000000000..fdc5505fc6e5 --- /dev/null +++ b/src/objective/hinge.cu @@ -0,0 +1,109 @@ +/*! + * Copyright 2018 by Contributors + * \file hinge.cc + * \brief Provides an implementation of the hinge loss function + * \author Henry Gouk + */ +#include +#include "../common/math.h" +#include "../common/transform.h" +#include "../common/common.h" +#include "../common/span.h" +#include "../common/host_device_vector.h" + +namespace xgboost { +namespace obj { + +#if defined(XGBOOST_USE_CUDA) +DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu); +#endif + +struct HingeObjParam : public dmlc::Parameter { + int n_gpus; + int gpu_id; + DMLC_DECLARE_PARAMETER(HingeObjParam) { + DMLC_DECLARE_FIELD(n_gpus).set_default(0).set_lower_bound(0) + .describe("Number of GPUs to use for multi-gpu algorithms."); + DMLC_DECLARE_FIELD(gpu_id) + .set_lower_bound(0) + .set_default(0) + .describe("gpu to use for objective function evaluation"); + } +}; + +class HingeObj : public ObjFunction { + public: + HingeObj() = default; + + void Configure( + const std::vector > &args) override { + param_.InitAllowUnknown(args); + devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); + } + + void GetGradient(const HostDeviceVector &preds, + const MetaInfo &info, + int iter, + HostDeviceVector *out_gpair) override { + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) + << "labels are not correctly provided" + << "preds.size=" << preds.Size() + << ", label.size=" << info.labels_.Size(); + + const bool is_null_weight = info.weights_.Size() == 0; + const size_t ndata = preds.Size(); + out_gpair->Resize(ndata); + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t _idx, + common::Span _label_correct, + common::Span _out_gpair, + common::Span _preds, + common::Span _labels, + common::Span _weights) { + bst_float p = _preds[_idx]; + bst_float w = is_null_weight ? 1.0f : _weights[_idx]; + bst_float y = _labels[_idx] * 2.0 - 1.0; + bst_float g, h; + if (p * y < 1.0) { + g = -y * w; + h = w; + } else { + g = 0.0; + h = std::numeric_limits::min(); + } + _out_gpair[_idx] = GradientPair(g, h); + }, + common::Range{0, static_cast(ndata)}, devices_).Eval( + &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + } + + void PredTransform(HostDeviceVector *io_preds) override { + common::Transform<>::Init( + [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { + _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0; + }, + common::Range{0, static_cast(io_preds->Size()), 1}, devices_) + .Eval(io_preds); + } + + const char* DefaultEvalMetric() const override { + return "error"; + } + + private: + GPUSet devices_; + HostDeviceVector label_correct_; + HingeObjParam param_; +}; + +// register the objective functions +DMLC_REGISTER_PARAMETER(HingeObjParam); +// register the objective functions +XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge") +.describe("Hinge loss. Expects labels to be in [0,1f]") +.set_body([]() { return new HingeObj(); }); + +} // namespace obj +} // namespace xgboost diff --git a/src/objective/multiclass_obj.cc b/src/objective/multiclass_obj.cc index dc43f932764c..2041bd56e560 100644 --- a/src/objective/multiclass_obj.cc +++ b/src/objective/multiclass_obj.cc @@ -1,141 +1,18 @@ /*! - * Copyright 2015 by Contributors - * \file multi_class.cc - * \brief Definition of multi-class classification objectives. - * \author Tianqi Chen + * Copyright 2018 XGBoost contributors */ -#include -#include -#include -#include -#include -#include -#include -#include "../common/math.h" +// Dummy file to keep the CUDA conditional compile trick. + +#include namespace xgboost { namespace obj { DMLC_REGISTRY_FILE_TAG(multiclass_obj); -struct SoftmaxMultiClassParam : public dmlc::Parameter { - int num_class; - // declare parameters - DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) { - DMLC_DECLARE_FIELD(num_class).set_lower_bound(1) - .describe("Number of output class in the multi-class classification."); - } -}; - -class SoftmaxMultiClassObj : public ObjFunction { - public: - explicit SoftmaxMultiClassObj(bool output_prob) - : output_prob_(output_prob) { - } - void Configure(const std::vector >& args) override { - param_.InitAllowUnknown(args); - } - void GetGradient(const HostDeviceVector& preds, - const MetaInfo& info, - int iter, - HostDeviceVector* out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK(preds.Size() == (static_cast(param_.num_class) * info.labels_.Size())) - << "SoftmaxMultiClassObj: label size and pred size does not match"; - const std::vector& preds_h = preds.HostVector(); - out_gpair->Resize(preds_h.size()); - std::vector& gpair = out_gpair->HostVector(); - const int nclass = param_.num_class; - const auto ndata = static_cast(preds_h.size() / nclass); - - const auto& labels = info.labels_.HostVector(); - int label_error = 0; - #pragma omp parallel - { - std::vector rec(nclass); - #pragma omp for schedule(static) - for (omp_ulong i = 0; i < ndata; ++i) { - for (int k = 0; k < nclass; ++k) { - rec[k] = preds_h[i * nclass + k]; - } - common::Softmax(&rec); - auto label = static_cast(labels[i]); - if (label < 0 || label >= nclass) { - label_error = label; label = 0; - } - const bst_float wt = info.GetWeight(i); - for (int k = 0; k < nclass; ++k) { - bst_float p = rec[k]; - const float eps = 1e-16f; - const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, eps); - if (label == k) { - gpair[i * nclass + k] = GradientPair((p - 1.0f) * wt, h); - } else { - gpair[i * nclass + k] = GradientPair(p* wt, h); - } - } - } - } - CHECK(label_error >= 0 && label_error < nclass) - << "SoftmaxMultiClassObj: label must be in [0, num_class)," - << " num_class=" << nclass - << " but found " << label_error << " in label."; - } - void PredTransform(HostDeviceVector* io_preds) override { - this->Transform(io_preds, output_prob_); - } - void EvalTransform(HostDeviceVector* io_preds) override { - this->Transform(io_preds, true); - } - const char* DefaultEvalMetric() const override { - return "merror"; - } - - private: - inline void Transform(HostDeviceVector *io_preds, bool prob) { - std::vector &preds = io_preds->HostVector(); - std::vector tmp; - const int nclass = param_.num_class; - const auto ndata = static_cast(preds.size() / nclass); - if (!prob) tmp.resize(ndata); - - #pragma omp parallel - { - std::vector rec(nclass); - #pragma omp for schedule(static) - for (omp_ulong j = 0; j < ndata; ++j) { - for (int k = 0; k < nclass; ++k) { - rec[k] = preds[j * nclass + k]; - } - if (!prob) { - tmp[j] = static_cast( - common::FindMaxIndex(rec.begin(), rec.end()) - rec.begin()); - } else { - common::Softmax(&rec); - for (int k = 0; k < nclass; ++k) { - preds[j * nclass + k] = rec[k]; - } - } - } - } - if (!prob) preds = tmp; - } - // output probability - bool output_prob_; - // parameter - SoftmaxMultiClassParam param_; -}; - -// register the objective functions -DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam); - -XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax") -.describe("Softmax for multi-class classification, output class index.") -.set_body([]() { return new SoftmaxMultiClassObj(false); }); - -XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob") -.describe("Softmax for multi-class classification, output probability distribution.") -.set_body([]() { return new SoftmaxMultiClassObj(true); }); - } // namespace obj } // namespace xgboost + +#ifndef XGBOOST_USE_CUDA +#include "multiclass_obj.cu" +#endif diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu new file mode 100644 index 000000000000..d912615777fc --- /dev/null +++ b/src/objective/multiclass_obj.cu @@ -0,0 +1,195 @@ +/*! + * Copyright 2015-2018 by Contributors + * \file multi_class.cc + * \brief Definition of multi-class classification objectives. + * \author Tianqi Chen + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../common/math.h" +#include "../common/transform.h" + +namespace xgboost { +namespace obj { + +#if defined(XGBOOST_USE_CUDA) +DMLC_REGISTRY_FILE_TAG(multiclass_obj_gpu); +#endif + +struct SoftmaxMultiClassParam : public dmlc::Parameter { + int num_class; + int n_gpus; + int gpu_id; + // declare parameters + DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) { + DMLC_DECLARE_FIELD(num_class).set_lower_bound(1) + .describe("Number of output class in the multi-class classification."); + DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1) + .describe("Number of GPUs to use for multi-gpu algorithms."); + DMLC_DECLARE_FIELD(gpu_id) + .set_lower_bound(0) + .set_default(0) + .describe("gpu to use for objective function evaluation"); + } +}; +// TODO(trivialfis): Currently the resharding in softmax is less than ideal +// due to repeated copying data between CPU and GPUs. Maybe we just use single +// GPU? +class SoftmaxMultiClassObj : public ObjFunction { + public: + explicit SoftmaxMultiClassObj(bool output_prob) + : output_prob_(output_prob) { + } + void Configure(const std::vector >& args) override { + param_.InitAllowUnknown(args); + CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1 + devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); + } + void GetGradient(const HostDeviceVector& preds, + const MetaInfo& info, + int iter, + HostDeviceVector* out_gpair) override { + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK(preds.Size() == (static_cast(param_.num_class) * info.labels_.Size())) + << "SoftmaxMultiClassObj: label size and pred size does not match"; + + const int nclass = param_.num_class; + const auto ndata = static_cast(preds.Size() / nclass); + + // clear out device memory; + out_gpair->Reshard(GPUSet::Empty()); + preds.Reshard(GPUSet::Empty()); + + out_gpair->Reshard(GPUDistribution::Granular(devices_, nclass)); + info.labels_.Reshard(GPUDistribution::Block(devices_)); + info.weights_.Reshard(GPUDistribution::Block(devices_)); + preds.Reshard(GPUDistribution::Granular(devices_, nclass)); + label_correct_.Reshard(GPUDistribution::Block(devices_)); + + out_gpair->Resize(preds.Size()); + label_correct_.Fill(1); + + const bool is_null_weight = info.weights_.Size() == 0; + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t idx, + common::Span gpair, + common::Span labels, + common::Span preds, + common::Span weights, + common::Span _label_correct) { + common::Span point = preds.subspan(idx * nclass, nclass); + + // Part of Softmax function + bst_float wmax = std::numeric_limits::min(); + for (auto const i : point) { wmax = fmaxf(i, wmax); } + double wsum = 0.0f; + for (auto const i : point) { wsum += expf(i - wmax); } + auto label = labels[idx]; + if (label < 0 || label >= nclass) { + _label_correct[0] = 0; + label = 0; + } + bst_float wt = is_null_weight ? 1.0f : weights[idx]; + for (int k = 0; k < nclass; ++k) { + // Computation duplicated to avoid creating a cache. + bst_float p = expf(point[k] - wmax) / static_cast(wsum); + const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, kRtEps); + p = label == k ? p - 1.0f : p; + gpair[idx * nclass + k] = GradientPair(p * wt, h); + } + }, common::Range{0, ndata}, devices_, false) + .Eval(out_gpair, &info.labels_, &preds, &info.weights_, &label_correct_); + + out_gpair->Reshard(GPUSet::Empty()); + out_gpair->Reshard(GPUDistribution::Block(devices_)); + preds.Reshard(GPUSet::Empty()); + preds.Reshard(GPUDistribution::Block(devices_)); + + std::vector& label_correct_h = label_correct_.HostVector(); + for (auto const flag : label_correct_h) { + if (flag != 1) { + LOG(FATAL) << "SoftmaxMultiClassObj: label must be in [0, num_class)."; + } + } + } + void PredTransform(HostDeviceVector* io_preds) override { + this->Transform(io_preds, output_prob_); + } + void EvalTransform(HostDeviceVector* io_preds) override { + this->Transform(io_preds, true); + } + const char* DefaultEvalMetric() const override { + return "merror"; + } + + inline void Transform(HostDeviceVector *io_preds, bool prob) { + const int nclass = param_.num_class; + const auto ndata = static_cast(io_preds->Size() / nclass); + max_preds_.Resize(ndata); + + io_preds->Reshard(GPUSet::Empty()); // clear out device memory + if (prob) { + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { + common::Span point = + _preds.subspan(_idx * nclass, nclass); + common::Softmax(point.begin(), point.end()); + }, + common::Range{0, ndata}, GPUDistribution::Granular(devices_, nclass)) + .Eval(io_preds); + } else { + io_preds->Reshard(GPUDistribution::Granular(devices_, nclass)); + max_preds_.Reshard(GPUDistribution::Block(devices_)); + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t _idx, + common::Span _preds, + common::Span _max_preds) { + common::Span point = + _preds.subspan(_idx * nclass, nclass); + _max_preds[_idx] = + common::FindMaxIndex(point.cbegin(), + point.cend()) - point.cbegin(); + }, + common::Range{0, ndata}, devices_, false) + .Eval(io_preds, &max_preds_); + } + if (!prob) { + io_preds->Resize(max_preds_.Size()); + io_preds->Copy(max_preds_); + } + io_preds->Reshard(GPUSet::Empty()); // clear out device memory + io_preds->Reshard(GPUDistribution::Block(devices_)); + } + + private: + // output probability + bool output_prob_; + // parameter + SoftmaxMultiClassParam param_; + GPUSet devices_; + // Cache for max_preds + HostDeviceVector max_preds_; + HostDeviceVector label_correct_; +}; + +// register the objective functions +DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam); + +XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax") +.describe("Softmax for multi-class classification, output class index.") +.set_body([]() { return new SoftmaxMultiClassObj(false); }); + +XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob") +.describe("Softmax for multi-class classification, output probability distribution.") +.set_body([]() { return new SoftmaxMultiClassObj(true); }); + +} // namespace obj +} // namespace xgboost diff --git a/src/objective/objective.cc b/src/objective/objective.cc index ebc68ebca390..9f6f190fa9e2 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -30,12 +30,15 @@ ObjFunction* ObjFunction::Create(const std::string& name) { namespace xgboost { namespace obj { // List of files that will be force linked in static links. -DMLC_REGISTRY_LINK_TAG(regression_obj); #ifdef XGBOOST_USE_CUDA - DMLC_REGISTRY_LINK_TAG(regression_obj_gpu); -#endif +DMLC_REGISTRY_LINK_TAG(regression_obj_gpu); +DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu); +DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu); +#else +DMLC_REGISTRY_LINK_TAG(regression_obj); +DMLC_REGISTRY_LINK_TAG(hinge_obj); DMLC_REGISTRY_LINK_TAG(multiclass_obj); +#endif DMLC_REGISTRY_LINK_TAG(rank_obj); -DMLC_REGISTRY_LINK_TAG(hinge); } // namespace obj } // namespace xgboost diff --git a/src/objective/regression_obj.cc b/src/objective/regression_obj.cc index 5a69e3825611..e74bbf688535 100644 --- a/src/objective/regression_obj.cc +++ b/src/objective/regression_obj.cc @@ -1,426 +1,18 @@ /*! - * Copyright 2015 by Contributors - * \file regression_obj.cc - * \brief Definition of single-value regression and classification objectives. - * \author Tianqi Chen, Kailong Chen + * Copyright 2018 XGBoost contributors */ -#include -#include -#include -#include -#include -#include -#include "../common/math.h" -#include "../common/avx_helpers.h" -#include "./regression_loss.h" +// Dummy file to keep the CUDA conditional compile trick. + +#include namespace xgboost { namespace obj { DMLC_REGISTRY_FILE_TAG(regression_obj); -struct RegLossParam : public dmlc::Parameter { - float scale_pos_weight; - // declare parameters - DMLC_DECLARE_PARAMETER(RegLossParam) { - DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f) - .describe("Scale the weight of positive examples by this factor"); - } -}; - -// regression loss function -template -class RegLossObj : public ObjFunction { - public: - RegLossObj() = default; - - void Configure( - const std::vector > &args) override { - param_.InitAllowUnknown(args); - } - void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, - int iter, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) - << "labels are not correctly provided" - << "preds.size=" << preds.Size() - << ", label.size=" << info.labels_.Size(); - const auto& preds_h = preds.HostVector(); - const auto& labels = info.labels_.HostVector(); - const auto& weights = info.weights_.HostVector(); - - this->LazyCheckLabels(labels); - out_gpair->Resize(preds_h.size()); - auto& gpair = out_gpair->HostVector(); - const auto n = static_cast(preds_h.size()); - auto gpair_ptr = out_gpair->HostPointer(); - avx::Float8 scale(param_.scale_pos_weight); - - const omp_ulong remainder = n % 8; -#pragma omp parallel for schedule(static) - for (omp_ulong i = 0; i < n - remainder; i += 8) { - avx::Float8 y(&labels[i]); - avx::Float8 p = Loss::PredTransform(avx::Float8(&preds_h[i])); - avx::Float8 w = weights.empty() ? avx::Float8(1.0f) - : avx::Float8(&weights[i]); - // Adjust weight - w += y * (scale * w - w); - avx::Float8 grad = Loss::FirstOrderGradient(p, y); - avx::Float8 hess = Loss::SecondOrderGradient(p, y); - avx::StoreGpair(gpair_ptr + i, grad * w, hess * w); - } - for (omp_ulong i = n - remainder; i < n; ++i) { - auto y = labels[i]; - bst_float p = Loss::PredTransform(preds_h[i]); - bst_float w = info.GetWeight(i); - w += y * ((param_.scale_pos_weight * w) - w); - gpair[i] = GradientPair(Loss::FirstOrderGradient(p, y) * w, - Loss::SecondOrderGradient(p, y) * w); - } - } - const char *DefaultEvalMetric() const override { - return Loss::DefaultEvalMetric(); - } - void PredTransform(HostDeviceVector *io_preds) override { - std::vector &preds = io_preds->HostVector(); - const auto ndata = static_cast(preds.size()); -#pragma omp parallel for schedule(static) - for (bst_omp_uint j = 0; j < ndata; ++j) { - preds[j] = Loss::PredTransform(preds[j]); - } - } - bst_float ProbToMargin(bst_float base_score) const override { - return Loss::ProbToMargin(base_score); - } - - protected: - void LazyCheckLabels(const std::vector &labels) { - if (labels_checked_) return; - for (auto &y : labels) { - CHECK(Loss::CheckLabel(y)) << Loss::LabelErrorMsg(); - } - labels_checked_ = true; - } - RegLossParam param_; - bool labels_checked_{false}; -}; - -// register the objective functions -DMLC_REGISTER_PARAMETER(RegLossParam); - -XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear") -.describe("Linear regression.") -.set_body([]() { return new RegLossObj(); }); - -XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, "reg:logistic") -.describe("Logistic regression for probability regression task.") -.set_body([]() { return new RegLossObj(); }); - -XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, "binary:logistic") -.describe("Logistic regression for binary classification task.") -.set_body([]() { return new RegLossObj(); }); - -XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw") -.describe("Logistic regression for classification, output score before logistic transformation") -.set_body([]() { return new RegLossObj(); }); - -// declare parameter -struct PoissonRegressionParam : public dmlc::Parameter { - float max_delta_step; - DMLC_DECLARE_PARAMETER(PoissonRegressionParam) { - DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f) - .describe("Maximum delta step we allow each weight estimation to be." \ - " This parameter is required for possion regression."); - } -}; - -// poisson regression for count -class PoissonRegression : public ObjFunction { - public: - // declare functions - void Configure(const std::vector >& args) override { - param_.InitAllowUnknown(args); - } - - void GetGradient(const HostDeviceVector &preds, - const MetaInfo &info, - int iter, - HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; - const auto& preds_h = preds.HostVector(); - out_gpair->Resize(preds.Size()); - auto& gpair = out_gpair->HostVector(); - const auto& labels = info.labels_.HostVector(); - // check if label in range - bool label_correct = true; - // start calculating gradient - const omp_ulong ndata = static_cast(preds_h.size()); // NOLINT(*) -#pragma omp parallel for schedule(static) - for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) - bst_float p = preds_h[i]; - bst_float w = info.GetWeight(i); - bst_float y = labels[i]; - if (y >= 0.0f) { - gpair[i] = GradientPair((std::exp(p) - y) * w, - std::exp(p + param_.max_delta_step) * w); - } else { - label_correct = false; - } - } - CHECK(label_correct) << "PoissonRegression: label must be nonnegative"; - } - void PredTransform(HostDeviceVector *io_preds) override { - std::vector &preds = io_preds->HostVector(); - const long ndata = static_cast(preds.size()); // NOLINT(*) -#pragma omp parallel for schedule(static) - for (long j = 0; j < ndata; ++j) { // NOLINT(*) - preds[j] = std::exp(preds[j]); - } - } - void EvalTransform(HostDeviceVector *io_preds) override { - PredTransform(io_preds); - } - bst_float ProbToMargin(bst_float base_score) const override { - return std::log(base_score); - } - const char* DefaultEvalMetric() const override { - return "poisson-nloglik"; - } - - private: - PoissonRegressionParam param_; -}; - -// register the objective functions -DMLC_REGISTER_PARAMETER(PoissonRegressionParam); - -XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson") -.describe("Possion regression for count data.") -.set_body([]() { return new PoissonRegression(); }); - -// cox regression for survival data (negative values mean they are censored) -class CoxRegression : public ObjFunction { - public: - // declare functions - void Configure(const std::vector >& args) override {} - void GetGradient(const HostDeviceVector &preds, - const MetaInfo &info, - int iter, - HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; - const auto& preds_h = preds.HostVector(); - out_gpair->Resize(preds_h.size()); - auto& gpair = out_gpair->HostVector(); - const std::vector &label_order = info.LabelAbsSort(); - - const omp_ulong ndata = static_cast(preds_h.size()); // NOLINT(*) - - // pre-compute a sum - double exp_p_sum = 0; // we use double because we might need the precision with large datasets - for (omp_ulong i = 0; i < ndata; ++i) { - exp_p_sum += std::exp(preds_h[label_order[i]]); - } - - // start calculating grad and hess - const auto& labels = info.labels_.HostVector(); - double r_k = 0; - double s_k = 0; - double last_exp_p = 0.0; - double last_abs_y = 0.0; - double accumulated_sum = 0; - for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) - const size_t ind = label_order[i]; - const double p = preds_h[ind]; - const double exp_p = std::exp(p); - const double w = info.GetWeight(ind); - const double y = labels[ind]; - const double abs_y = std::abs(y); - - // only update the denominator after we move forward in time (labels are sorted) - // this is Breslow's method for ties - accumulated_sum += last_exp_p; - if (last_abs_y < abs_y) { - exp_p_sum -= accumulated_sum; - accumulated_sum = 0; - } else { - CHECK(last_abs_y <= abs_y) << "CoxRegression: labels must be in sorted order, " << - "MetaInfo::LabelArgsort failed!"; - } - - if (y > 0) { - r_k += 1.0/exp_p_sum; - s_k += 1.0/(exp_p_sum*exp_p_sum); - } - - const double grad = exp_p*r_k - static_cast(y > 0); - const double hess = exp_p*r_k - exp_p*exp_p * s_k; - gpair.at(ind) = GradientPair(grad * w, hess * w); - - last_abs_y = abs_y; - last_exp_p = exp_p; - } - } - void PredTransform(HostDeviceVector *io_preds) override { - std::vector &preds = io_preds->HostVector(); - const long ndata = static_cast(preds.size()); // NOLINT(*) - #pragma omp parallel for schedule(static) - for (long j = 0; j < ndata; ++j) { // NOLINT(*) - preds[j] = std::exp(preds[j]); - } - } - void EvalTransform(HostDeviceVector *io_preds) override { - PredTransform(io_preds); - } - bst_float ProbToMargin(bst_float base_score) const override { - return std::log(base_score); - } - const char* DefaultEvalMetric() const override { - return "cox-nloglik"; - } -}; - -// register the objective function -XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox") -.describe("Cox regression for censored survival data (negative labels are considered censored).") -.set_body([]() { return new CoxRegression(); }); - -// gamma regression -class GammaRegression : public ObjFunction { - public: - // declare functions - void Configure(const std::vector >& args) override { - } - - void GetGradient(const HostDeviceVector &preds, - const MetaInfo &info, - int iter, - HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; - const auto& preds_h = preds.HostVector(); - out_gpair->Resize(preds_h.size()); - auto& gpair = out_gpair->HostVector(); - const auto& labels = info.labels_.HostVector(); - // check if label in range - bool label_correct = true; - // start calculating gradient - const omp_ulong ndata = static_cast(preds_h.size()); // NOLINT(*) - #pragma omp parallel for schedule(static) - for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) - bst_float p = preds_h[i]; - bst_float w = info.GetWeight(i); - bst_float y = labels[i]; - if (y >= 0.0f) { - gpair[i] = GradientPair((1 - y / std::exp(p)) * w, y / std::exp(p) * w); - } else { - label_correct = false; - } - } - CHECK(label_correct) << "GammaRegression: label must be positive"; - } - void PredTransform(HostDeviceVector *io_preds) override { - std::vector &preds = io_preds->HostVector(); - const long ndata = static_cast(preds.size()); // NOLINT(*) - #pragma omp parallel for schedule(static) - for (long j = 0; j < ndata; ++j) { // NOLINT(*) - preds[j] = std::exp(preds[j]); - } - } - void EvalTransform(HostDeviceVector *io_preds) override { - PredTransform(io_preds); - } - bst_float ProbToMargin(bst_float base_score) const override { - return std::log(base_score); - } - const char* DefaultEvalMetric() const override { - return "gamma-nloglik"; - } -}; - -// register the objective functions -XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma") -.describe("Gamma regression for severity data.") -.set_body([]() { return new GammaRegression(); }); - -// declare parameter -struct TweedieRegressionParam : public dmlc::Parameter { - float tweedie_variance_power; - DMLC_DECLARE_PARAMETER(TweedieRegressionParam) { - DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f) - .describe("Tweedie variance power. Must be between in range [1, 2)."); - } -}; - -// tweedie regression -class TweedieRegression : public ObjFunction { - public: - // declare functions - void Configure(const std::vector >& args) override { - param_.InitAllowUnknown(args); - } - - void GetGradient(const HostDeviceVector &preds, - const MetaInfo &info, - int iter, - HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; - const auto& preds_h = preds.HostVector(); - out_gpair->Resize(preds.Size()); - auto& gpair = out_gpair->HostVector(); - const auto& labels = info.labels_.HostVector(); - // check if label in range - bool label_correct = true; - // start calculating gradient - const omp_ulong ndata = static_cast(preds.Size()); // NOLINT(*) - #pragma omp parallel for schedule(static) - for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) - bst_float p = preds_h[i]; - bst_float w = info.GetWeight(i); - bst_float y = labels[i]; - float rho = param_.tweedie_variance_power; - if (y >= 0.0f) { - bst_float grad = -y * std::exp((1 - rho) * p) + std::exp((2 - rho) * p); - bst_float hess = -y * (1 - rho) * \ - std::exp((1 - rho) * p) + (2 - rho) * std::exp((2 - rho) * p); - gpair[i] = GradientPair(grad * w, hess * w); - } else { - label_correct = false; - } - } - CHECK(label_correct) << "TweedieRegression: label must be nonnegative"; - } - void PredTransform(HostDeviceVector *io_preds) override { - std::vector &preds = io_preds->HostVector(); - const long ndata = static_cast(preds.size()); // NOLINT(*) -#pragma omp parallel for schedule(static) - for (long j = 0; j < ndata; ++j) { // NOLINT(*) - preds[j] = std::exp(preds[j]); - } - } - - bst_float ProbToMargin(bst_float base_score) const override { - return std::log(base_score); - } - - const char* DefaultEvalMetric() const override { - std::ostringstream os; - os << "tweedie-nloglik@" << param_.tweedie_variance_power; - std::string metric = os.str(); - return metric.c_str(); - } - - private: - TweedieRegressionParam param_; -}; - -// register the objective functions -DMLC_REGISTER_PARAMETER(TweedieRegressionParam); - -XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie") -.describe("Tweedie regression for insurance data.") -.set_body([]() { return new TweedieRegression(); }); } // namespace obj } // namespace xgboost + +#ifndef XGBOOST_USE_CUDA +#include "regression_obj.cu" +#endif diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu new file mode 100644 index 000000000000..e74c82af1c95 --- /dev/null +++ b/src/objective/regression_obj.cu @@ -0,0 +1,560 @@ +/*! + * Copyright 2015-2018 by Contributors + * \file regression_obj.cu + * \brief Definition of single-value regression and classification objectives. + * \author Tianqi Chen, Kailong Chen + */ + +#include +#include +#include +#include +#include +#include + +#include "../common/span.h" +#include "../common/transform.h" +#include "../common/common.h" +#include "../common/host_device_vector.h" +#include "./regression_loss.h" + + +namespace xgboost { +namespace obj { + +#if defined(XGBOOST_USE_CUDA) +DMLC_REGISTRY_FILE_TAG(regression_obj_gpu); +#endif + +struct RegLossParam : public dmlc::Parameter { + float scale_pos_weight; + int n_gpus; + int gpu_id; + // declare parameters + DMLC_DECLARE_PARAMETER(RegLossParam) { + DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f) + .describe("Scale the weight of positive examples by this factor"); + DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1) + .describe("Number of GPUs to use for multi-gpu algorithms."); + DMLC_DECLARE_FIELD(gpu_id) + .set_lower_bound(0) + .set_default(0) + .describe("gpu to use for objective function evaluation"); + } +}; + +template +class RegLossObj : public ObjFunction { + protected: + HostDeviceVector label_correct_; + + public: + RegLossObj() = default; + + void Configure(const std::vector >& args) override { + param_.InitAllowUnknown(args); + CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1 + devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); + } + + void GetGradient(const HostDeviceVector& preds, + const MetaInfo &info, + int iter, + HostDeviceVector* out_gpair) override { + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) + << "labels are not correctly provided" + << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size(); + size_t ndata = preds.Size(); + out_gpair->Resize(ndata); + label_correct_.Fill(1); + + bool is_null_weight = info.weights_.Size() == 0; + auto scale_pos_weight = param_.scale_pos_weight; + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t _idx, + common::Span _label_correct, + common::Span _out_gpair, + common::Span _preds, + common::Span _labels, + common::Span _weights) { + bst_float p = Loss::PredTransform(_preds[_idx]); + bst_float w = is_null_weight ? 1.0f : _weights[_idx]; + bst_float label = _labels[_idx]; + if (label == 1.0f) { + w *= scale_pos_weight; + } + if (!Loss::CheckLabel(label)) { + // If there is an incorrect label, the host code will know. + _label_correct[0] = 0; + } + _out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w, + Loss::SecondOrderGradient(p, label) * w); + }, + common::Range{0, static_cast(ndata)}, devices_).Eval( + &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + + // copy "label correct" flags back to host + std::vector& label_correct_h = label_correct_.HostVector(); + for (auto const flag : label_correct_h) { + if (flag == 0) { + LOG(FATAL) << Loss::LabelErrorMsg(); + } + } + } + + public: + const char* DefaultEvalMetric() const override { + return Loss::DefaultEvalMetric(); + } + + void PredTransform(HostDeviceVector *io_preds) override { + common::Transform<>::Init( + [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { + _preds[_idx] = Loss::PredTransform(_preds[_idx]); + }, common::Range{0, static_cast(io_preds->Size())}, + devices_).Eval(io_preds); + } + + float ProbToMargin(float base_score) const override { + return Loss::ProbToMargin(base_score); + } + + protected: + RegLossParam param_; + GPUSet devices_; +}; + +// register the objective functions +DMLC_REGISTER_PARAMETER(RegLossParam); + +XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear") +.describe("Linear regression.") +.set_body([]() { return new RegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, "reg:logistic") +.describe("Logistic regression for probability regression task.") +.set_body([]() { return new RegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, "binary:logistic") +.describe("Logistic regression for binary classification task.") +.set_body([]() { return new RegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw") +.describe("Logistic regression for classification, output score " + "before logistic transformation.") +.set_body([]() { return new RegLossObj(); }); + +// Deprecated GPU functions +XGBOOST_REGISTER_OBJECTIVE(GPULinearRegression, "gpu:reg:linear") +.describe("Deprecated. Linear regression (computed on GPU).") +.set_body([]() { + LOG(WARNING) << "gpu:reg:linear is now deprecated, use reg:linear instead."; + return new RegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(GPULogisticRegression, "gpu:reg:logistic") +.describe("Deprecated. Logistic regression for probability regression task (computed on GPU).") +.set_body([]() { + LOG(WARNING) << "gpu:reg:logistic is now deprecated, use reg:logistic instead."; + return new RegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(GPULogisticClassification, "gpu:binary:logistic") +.describe("Deprecated. Logistic regression for binary classification task (computed on GPU).") +.set_body([]() { + LOG(WARNING) << "gpu:binary:logistic is now deprecated, use binary:logistic instead."; + return new RegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(GPULogisticRaw, "gpu:binary:logitraw") +.describe("Deprecated. Logistic regression for classification, output score " + "before logistic transformation (computed on GPU)") +.set_body([]() { + LOG(WARNING) << "gpu:binary:logitraw is now deprecated, use binary:logitraw instead."; + return new RegLossObj(); }); +// End deprecated + +// declare parameter +struct PoissonRegressionParam : public dmlc::Parameter { + float max_delta_step; + int n_gpus; + int gpu_id; + DMLC_DECLARE_PARAMETER(PoissonRegressionParam) { + DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f) + .describe("Maximum delta step we allow each weight estimation to be." \ + " This parameter is required for possion regression."); + DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1) + .describe("Number of GPUs to use for multi-gpu algorithms."); + DMLC_DECLARE_FIELD(gpu_id) + .set_lower_bound(0) + .set_default(0) + .describe("gpu to use for objective function evaluation"); + } +}; + +// poisson regression for count +class PoissonRegression : public ObjFunction { + public: + // declare functions + void Configure(const std::vector >& args) override { + param_.InitAllowUnknown(args); + CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1 + devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); + } + + void GetGradient(const HostDeviceVector& preds, + const MetaInfo &info, + int iter, + HostDeviceVector *out_gpair) override { + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + size_t ndata = preds.Size(); + out_gpair->Resize(ndata); + label_correct_.Fill(1); + + bool is_null_weight = info.weights_.Size() == 0; + bst_float max_delta_step = param_.max_delta_step; + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t _idx, + common::Span _label_correct, + common::Span _out_gpair, + common::Span _preds, + common::Span _labels, + common::Span _weights) { + bst_float p = _preds[_idx]; + bst_float w = is_null_weight ? 1.0f : _weights[_idx]; + bst_float y = _labels[_idx]; + if (y < 0.0f) { + _label_correct[0] = 0; + } + _out_gpair[_idx] = GradientPair{(expf(p) - y) * w, + expf(p + max_delta_step) * w}; + }, + common::Range{0, static_cast(ndata)}, devices_).Eval( + &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + // copy "label correct" flags back to host + std::vector& label_correct_h = label_correct_.HostVector(); + for (auto const flag : label_correct_h) { + if (flag == 0) { + LOG(FATAL) << "PoissonRegression: label must be nonnegative"; + } + } + } + void PredTransform(HostDeviceVector *io_preds) override { + common::Transform<>::Init( + [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { + _preds[_idx] = expf(_preds[_idx]); + }, + common::Range{0, static_cast(io_preds->Size())}, devices_) + .Eval(io_preds); + } + void EvalTransform(HostDeviceVector *io_preds) override { + PredTransform(io_preds); + } + bst_float ProbToMargin(bst_float base_score) const override { + return std::log(base_score); + } + const char* DefaultEvalMetric() const override { + return "poisson-nloglik"; + } + + private: + GPUSet devices_; + PoissonRegressionParam param_; + HostDeviceVector label_correct_; +}; + +// register the objective functions +DMLC_REGISTER_PARAMETER(PoissonRegressionParam); + +XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson") +.describe("Possion regression for count data.") +.set_body([]() { return new PoissonRegression(); }); + + +// cox regression for survival data (negative values mean they are censored) +class CoxRegression : public ObjFunction { + public: + // declare functions + void Configure(const std::vector >& args) override {} + void GetGradient(const HostDeviceVector& preds, + const MetaInfo &info, + int iter, + HostDeviceVector *out_gpair) override { + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + const auto& preds_h = preds.HostVector(); + out_gpair->Resize(preds_h.size()); + auto& gpair = out_gpair->HostVector(); + const std::vector &label_order = info.LabelAbsSort(); + + const omp_ulong ndata = static_cast(preds_h.size()); // NOLINT(*) + + // pre-compute a sum + double exp_p_sum = 0; // we use double because we might need the precision with large datasets + for (omp_ulong i = 0; i < ndata; ++i) { + exp_p_sum += std::exp(preds_h[label_order[i]]); + } + + // start calculating grad and hess + const auto& labels = info.labels_.HostVector(); + double r_k = 0; + double s_k = 0; + double last_exp_p = 0.0; + double last_abs_y = 0.0; + double accumulated_sum = 0; + for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) + const size_t ind = label_order[i]; + const double p = preds_h[ind]; + const double exp_p = std::exp(p); + const double w = info.GetWeight(ind); + const double y = labels[ind]; + const double abs_y = std::abs(y); + + // only update the denominator after we move forward in time (labels are sorted) + // this is Breslow's method for ties + accumulated_sum += last_exp_p; + if (last_abs_y < abs_y) { + exp_p_sum -= accumulated_sum; + accumulated_sum = 0; + } else { + CHECK(last_abs_y <= abs_y) << "CoxRegression: labels must be in sorted order, " << + "MetaInfo::LabelArgsort failed!"; + } + + if (y > 0) { + r_k += 1.0/exp_p_sum; + s_k += 1.0/(exp_p_sum*exp_p_sum); + } + + const double grad = exp_p*r_k - static_cast(y > 0); + const double hess = exp_p*r_k - exp_p*exp_p * s_k; + gpair.at(ind) = GradientPair(grad * w, hess * w); + + last_abs_y = abs_y; + last_exp_p = exp_p; + } + } + void PredTransform(HostDeviceVector *io_preds) override { + std::vector &preds = io_preds->HostVector(); + const long ndata = static_cast(preds.size()); // NOLINT(*) +#pragma omp parallel for schedule(static) + for (long j = 0; j < ndata; ++j) { // NOLINT(*) + preds[j] = std::exp(preds[j]); + } + } + void EvalTransform(HostDeviceVector *io_preds) override { + PredTransform(io_preds); + } + bst_float ProbToMargin(bst_float base_score) const override { + return std::log(base_score); + } + const char* DefaultEvalMetric() const override { + return "cox-nloglik"; + } +}; + +// register the objective function +XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox") +.describe("Cox regression for censored survival data (negative labels are considered censored).") +.set_body([]() { return new CoxRegression(); }); + + +struct GammaRegressionParam : public dmlc::Parameter { + int n_gpus; + int gpu_id; + DMLC_DECLARE_PARAMETER(GammaRegressionParam) { + DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1) + .describe("Number of GPUs to use for multi-gpu algorithms."); + DMLC_DECLARE_FIELD(gpu_id) + .set_lower_bound(0) + .set_default(0) + .describe("gpu to use for objective function evaluation"); + } +}; + +// gamma regression +class GammaRegression : public ObjFunction { + public: + // declare functions + void Configure(const std::vector >& args) override { + param_.InitAllowUnknown(args); + CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1 + devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); + } + + void GetGradient(const HostDeviceVector &preds, + const MetaInfo &info, + int iter, + HostDeviceVector *out_gpair) override { + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + const size_t ndata = preds.Size(); + out_gpair->Resize(ndata); + label_correct_.Fill(1); + + const bool is_null_weight = info.weights_.Size() == 0; + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t _idx, + common::Span _label_correct, + common::Span _out_gpair, + common::Span _preds, + common::Span _labels, + common::Span _weights) { + bst_float p = _preds[_idx]; + bst_float w = is_null_weight ? 1.0f : _weights[_idx]; + bst_float y = _labels[_idx]; + if (y < 0.0f) { + _label_correct[0] = 0; + } + _out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w); + }, + common::Range{0, static_cast(ndata)}, devices_).Eval( + &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + + // copy "label correct" flags back to host + std::vector& label_correct_h = label_correct_.HostVector(); + for (auto const flag : label_correct_h) { + if (flag == 0) { + LOG(FATAL) << "GammaRegression: label must be nonnegative"; + } + } + } + void PredTransform(HostDeviceVector *io_preds) override { + common::Transform<>::Init( + [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { + _preds[_idx] = expf(_preds[_idx]); + }, + common::Range{0, static_cast(io_preds->Size())}, devices_) + .Eval(io_preds); + } + void EvalTransform(HostDeviceVector *io_preds) override { + PredTransform(io_preds); + } + bst_float ProbToMargin(bst_float base_score) const override { + return std::log(base_score); + } + const char* DefaultEvalMetric() const override { + return "gamma-nloglik"; + } + + private: + GPUSet devices_; + GammaRegressionParam param_; + HostDeviceVector label_correct_; +}; + +// register the objective functions +DMLC_REGISTER_PARAMETER(GammaRegressionParam); +// register the objective functions +XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma") +.describe("Gamma regression for severity data.") +.set_body([]() { return new GammaRegression(); }); + + +// declare parameter +struct TweedieRegressionParam : public dmlc::Parameter { + float tweedie_variance_power; + int n_gpus; + int gpu_id; + DMLC_DECLARE_PARAMETER(TweedieRegressionParam) { + DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f) + .describe("Tweedie variance power. Must be between in range [1, 2)."); + DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1) + .describe("Number of GPUs to use for multi-gpu algorithms."); + DMLC_DECLARE_FIELD(gpu_id) + .set_lower_bound(0) + .set_default(0) + .describe("gpu to use for objective function evaluation"); + } +}; + +// tweedie regression +class TweedieRegression : public ObjFunction { + public: + // declare functions + void Configure(const std::vector >& args) override { + param_.InitAllowUnknown(args); + CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1 + devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); + } + + void GetGradient(const HostDeviceVector& preds, + const MetaInfo &info, + int iter, + HostDeviceVector *out_gpair) override { + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + const size_t ndata = preds.Size(); + out_gpair->Resize(ndata); + label_correct_.Fill(1); + + const bool is_null_weight = info.weights_.Size() == 0; + const float rho = param_.tweedie_variance_power; + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t _idx, + common::Span _label_correct, + common::Span _out_gpair, + common::Span _preds, + common::Span _labels, + common::Span _weights) { + bst_float p = _preds[_idx]; + bst_float w = is_null_weight ? 1.0f : _weights[_idx]; + bst_float y = _labels[_idx]; + if (y < 0.0f) { + _label_correct[0] = 0; + } + bst_float grad = -y * expf((1 - rho) * p) + expf((2 - rho) * p); + bst_float hess = + -y * (1 - rho) * \ + std::exp((1 - rho) * p) + (2 - rho) * expf((2 - rho) * p); + _out_gpair[_idx] = GradientPair(grad * w, hess * w); + }, + common::Range{0, static_cast(ndata), 1}, devices_) + .Eval(&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + + // copy "label correct" flags back to host + std::vector& label_correct_h = label_correct_.HostVector(); + for (auto const flag : label_correct_h) { + if (flag == 0) { + LOG(FATAL) << "TweedieRegression: label must be nonnegative"; + } + } + } + void PredTransform(HostDeviceVector *io_preds) override { + common::Transform<>::Init( + [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { + _preds[_idx] = expf(_preds[_idx]); + }, + common::Range{0, static_cast(io_preds->Size())}, devices_) + .Eval(io_preds); + } + + bst_float ProbToMargin(bst_float base_score) const override { + return std::log(base_score); + } + + const char* DefaultEvalMetric() const override { + std::ostringstream os; + os << "tweedie-nloglik@" << param_.tweedie_variance_power; + std::string metric = os.str(); + return metric.c_str(); + } + + private: + GPUSet devices_; + TweedieRegressionParam param_; + HostDeviceVector label_correct_; +}; + +// register the objective functions +DMLC_REGISTER_PARAMETER(TweedieRegressionParam); + +XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie") +.describe("Tweedie regression for insurance data.") +.set_body([]() { return new TweedieRegression(); }); + +} // namespace obj +} // namespace xgboost diff --git a/src/objective/regression_obj_gpu.cu b/src/objective/regression_obj_gpu.cu deleted file mode 100644 index ab1a11a72766..000000000000 --- a/src/objective/regression_obj_gpu.cu +++ /dev/null @@ -1,202 +0,0 @@ -/*! - * Copyright 2017 XGBoost contributors - */ -// GPU implementation of objective function. -// Necessary to avoid extra copying of data to CPU. -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../common/span.h" -#include "../common/device_helpers.cuh" -#include "../common/host_device_vector.h" -#include "./regression_loss.h" - - -namespace xgboost { -namespace obj { - -using dh::DVec; - -DMLC_REGISTRY_FILE_TAG(regression_obj_gpu); - -struct GPURegLossParam : public dmlc::Parameter { - float scale_pos_weight; - int n_gpus; - int gpu_id; - // declare parameters - DMLC_DECLARE_PARAMETER(GPURegLossParam) { - DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f) - .describe("Scale the weight of positive examples by this factor"); - DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1) - .describe("Number of GPUs to use for multi-gpu algorithms (NOT IMPLEMENTED)"); - DMLC_DECLARE_FIELD(gpu_id) - .set_lower_bound(0) - .set_default(0) - .describe("gpu to use for objective function evaluation"); - } -}; - -// GPU kernel for gradient computation -template -__global__ void get_gradient_k -(common::Span out_gpair, common::Span label_correct, - common::Span preds, common::Span labels, - const float * __restrict__ weights, int n, float scale_pos_weight) { - int i = threadIdx.x + blockIdx.x * blockDim.x; - if (i >= n) - return; - float p = Loss::PredTransform(preds[i]); - float w = weights == nullptr ? 1.0f : weights[i]; - float label = labels[i]; - if (label == 1.0f) - w *= scale_pos_weight; - if (!Loss::CheckLabel(label)) - atomicAnd(label_correct.data(), 0); - out_gpair[i] = GradientPair - (Loss::FirstOrderGradient(p, label) * w, Loss::SecondOrderGradient(p, label) * w); -} - -// GPU kernel for predicate transformation -template -__global__ void pred_transform_k(common::Span preds, int n) { - int i = threadIdx.x + blockIdx.x * blockDim.x; - if (i >= n) - return; - preds[i] = Loss::PredTransform(preds[i]); -} - -// regression loss function for evaluation on GPU (eventually) -template -class GPURegLossObj : public ObjFunction { - protected: - HostDeviceVector label_correct_; - - // allocate device data for n elements, do nothing if memory is allocated already - void LazyResize() { - } - - public: - GPURegLossObj() {} - - void Configure(const std::vector >& args) override { - param_.InitAllowUnknown(args); - CHECK(param_.n_gpus != 0) << "Must have at least one device"; - devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); - label_correct_.Reshard(devices_); - label_correct_.Resize(devices_.Size()); - } - - void GetGradient(const HostDeviceVector &preds, - const MetaInfo &info, - int iter, - HostDeviceVector* out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) - << "labels are not correctly provided" - << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size(); - size_t ndata = preds.Size(); - preds.Reshard(devices_); - info.labels_.Reshard(devices_); - info.weights_.Reshard(devices_); - out_gpair->Reshard(devices_); - out_gpair->Resize(ndata); - GetGradientDevice(preds, info, iter, out_gpair); - } - - private: - void GetGradientDevice(const HostDeviceVector& preds, - const MetaInfo &info, - int iter, - HostDeviceVector* out_gpair) { - label_correct_.Fill(1); - - // run the kernel -#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1) - for (int i = 0; i < devices_.Size(); ++i) { - int d = devices_[i]; - dh::safe_cuda(cudaSetDevice(d)); - const int block = 256; - size_t n = preds.DeviceSize(d); - if (n > 0) { - get_gradient_k<<>> - (out_gpair->DeviceSpan(d), label_correct_.DeviceSpan(d), - preds.DeviceSpan(d), info.labels_.DeviceSpan(d), - info.weights_.Size() > 0 ? info.weights_.DevicePointer(d) : nullptr, - n, param_.scale_pos_weight); - dh::safe_cuda(cudaGetLastError()); - } - dh::safe_cuda(cudaDeviceSynchronize()); - } - - // copy "label correct" flags back to host - std::vector& label_correct_h = label_correct_.HostVector(); - for (int i = 0; i < devices_.Size(); ++i) { - if (label_correct_h[i] == 0) - LOG(FATAL) << Loss::LabelErrorMsg(); - } - } - - public: - const char* DefaultEvalMetric() const override { - return Loss::DefaultEvalMetric(); - } - - void PredTransform(HostDeviceVector *io_preds) override { - io_preds->Reshard(devices_); - size_t ndata = io_preds->Size(); - PredTransformDevice(io_preds); - } - - void PredTransformDevice(HostDeviceVector* preds) { -#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1) - for (int i = 0; i < devices_.Size(); ++i) { - int d = devices_[i]; - dh::safe_cuda(cudaSetDevice(d)); - const int block = 256; - size_t n = preds->DeviceSize(d); - if (n > 0) { - pred_transform_k<<>>( - preds->DeviceSpan(d), n); - dh::safe_cuda(cudaGetLastError()); - } - dh::safe_cuda(cudaDeviceSynchronize()); - } - } - - float ProbToMargin(float base_score) const override { - return Loss::ProbToMargin(base_score); - } - - protected: - GPURegLossParam param_; - GPUSet devices_; -}; - -// register the objective functions -DMLC_REGISTER_PARAMETER(GPURegLossParam); - -XGBOOST_REGISTER_OBJECTIVE(GPULinearRegression, "gpu:reg:linear") -.describe("Linear regression (computed on GPU).") -.set_body([]() { return new GPURegLossObj(); }); - -XGBOOST_REGISTER_OBJECTIVE(GPULogisticRegression, "gpu:reg:logistic") -.describe("Logistic regression for probability regression task (computed on GPU).") -.set_body([]() { return new GPURegLossObj(); }); - -XGBOOST_REGISTER_OBJECTIVE(GPULogisticClassification, "gpu:binary:logistic") -.describe("Logistic regression for binary classification task (computed on GPU).") -.set_body([]() { return new GPURegLossObj(); }); - -XGBOOST_REGISTER_OBJECTIVE(GPULogisticRaw, "gpu:binary:logitraw") -.describe("Logistic regression for classification, output score " - "before logistic transformation (computed on GPU)") -.set_body([]() { return new GPURegLossObj(); }); - -} // namespace obj -} // namespace xgboost diff --git a/tests/cpp/common/test_gpu_compressed_iterator.cu b/tests/cpp/common/test_gpu_compressed_iterator.cu index b462b78a5751..4c53527a6513 100644 --- a/tests/cpp/common/test_gpu_compressed_iterator.cu +++ b/tests/cpp/common/test_gpu_compressed_iterator.cu @@ -14,7 +14,7 @@ struct WriteSymbolFunction { WriteSymbolFunction(CompressedBufferWriter cbw, unsigned char* buffer_data_d, int* input_data_d) : cbw(cbw), buffer_data_d(buffer_data_d), input_data_d(input_data_d) {} - + __device__ void operator()(size_t i) { cbw.AtomicWriteSymbol(buffer_data_d, input_data_d[i], i); } @@ -28,7 +28,7 @@ struct ReadSymbolFunction { __device__ void operator()(size_t i) { output_data_d[i] = ci[i]; - } + } }; TEST(CompressedIterator, TestGPU) { diff --git a/tests/cpp/common/test_gpu_hist_util.cu b/tests/cpp/common/test_gpu_hist_util.cu index f8f550687506..79945ff919b2 100644 --- a/tests/cpp/common/test_gpu_hist_util.cu +++ b/tests/cpp/common/test_gpu_hist_util.cu @@ -10,7 +10,7 @@ namespace xgboost { namespace common { -TEST(gpu_hist_util, TestDeviceSketch) { +void TestDeviceSketch(const GPUSet& devices) { // create the data int nrows = 10001; std::vector test_data(nrows); @@ -28,7 +28,7 @@ TEST(gpu_hist_util, TestDeviceSketch) { tree::TrainParam p; p.max_bin = 20; p.gpu_id = 0; - p.n_gpus = GPUSet::AllVisible().Size(); + p.n_gpus = devices.Size(); // ensure that the exact quantiles are found p.gpu_batch_nrows = nrows * 10; @@ -58,5 +58,17 @@ TEST(gpu_hist_util, TestDeviceSketch) { delete dmat; } +TEST(gpu_hist_util, DeviceSketch) { + TestDeviceSketch(GPUSet::Range(0, 1)); +} + +#if defined(XGBOOST_USE_NCCL) +TEST(gpu_hist_util, MGPU_DeviceSketch) { + auto devices = GPUSet::AllVisible(); + CHECK_GT(devices.Size(), 1); + TestDeviceSketch(devices); +} +#endif + } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_host_device_vector.cu b/tests/cpp/common/test_host_device_vector.cu index 5f6252f21ded..bac9b026a043 100644 --- a/tests/cpp/common/test_host_device_vector.cu +++ b/tests/cpp/common/test_host_device_vector.cu @@ -178,18 +178,57 @@ TEST(HostDeviceVector, TestCopy) { SetCudaSetDeviceHandler(nullptr); } -// The test is not really useful if n_gpus < 2 TEST(HostDeviceVector, Reshard) { std::vector h_vec (2345); for (size_t i = 0; i < h_vec.size(); ++i) { h_vec[i] = i; } HostDeviceVector vec (h_vec); + auto devices = GPUSet::Range(0, 1); + + vec.Reshard(devices); + ASSERT_EQ(vec.DeviceSize(0), h_vec.size()); + ASSERT_EQ(vec.Size(), h_vec.size()); + auto span = vec.DeviceSpan(0); // sync to device + + vec.Reshard(GPUSet::Empty()); // pull back to cpu, empty devices. + ASSERT_EQ(vec.Size(), h_vec.size()); + ASSERT_TRUE(vec.Devices().IsEmpty()); + + auto h_vec_1 = vec.HostVector(); + ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin())); +} + +TEST(HostDeviceVector, Span) { + HostDeviceVector vec {1.0f, 2.0f, 3.0f, 4.0f}; + vec.Reshard(GPUSet{0, 1}); + auto span = vec.DeviceSpan(0); + ASSERT_EQ(vec.DeviceSize(0), span.size()); + ASSERT_EQ(vec.DevicePointer(0), span.data()); + auto const_span = vec.ConstDeviceSpan(0); + ASSERT_EQ(vec.DeviceSize(0), span.size()); + ASSERT_EQ(vec.ConstDevicePointer(0), span.data()); +} + +// Multi-GPUs' test +#if defined(XGBOOST_USE_NCCL) +TEST(HostDeviceVector, MGPU_Reshard) { auto devices = GPUSet::AllVisible(); + if (devices.Size() < 2) { + LOG(WARNING) << "Not testing in multi-gpu environment."; + return; + } + + std::vector h_vec (2345); + for (size_t i = 0; i < h_vec.size(); ++i) { + h_vec[i] = i; + } + HostDeviceVector vec (h_vec); + + // Data size for each device. std::vector devices_size (devices.Size()); // From CPU to GPUs. - // Assuming we have > 1 devices. vec.Reshard(devices); size_t total_size = 0; for (size_t i = 0; i < devices.Size(); ++i) { @@ -198,42 +237,26 @@ TEST(HostDeviceVector, Reshard) { } ASSERT_EQ(total_size, h_vec.size()); ASSERT_EQ(total_size, vec.Size()); - auto h_vec_1 = vec.HostVector(); - ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin())); - vec.Reshard(GPUSet::Empty()); // clear out devices memory + // Reshard from devices to devices with different distribution. + EXPECT_ANY_THROW( + vec.Reshard(GPUDistribution::Granular(devices, 12))); - // Shrink down the number of devices. - vec.Reshard(GPUSet::Range(0, 1)); + // All data is drawn back to CPU + vec.Reshard(GPUSet::Empty()); + ASSERT_TRUE(vec.Devices().IsEmpty()); ASSERT_EQ(vec.Size(), h_vec.size()); - ASSERT_EQ(vec.DeviceSize(0), h_vec.size()); - h_vec_1 = vec.HostVector(); - ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin())); - vec.Reshard(GPUSet::Empty()); // clear out devices memory - // Grow the number of devices. - vec.Reshard(devices); + vec.Reshard(GPUDistribution::Granular(devices, 12)); total_size = 0; for (size_t i = 0; i < devices.Size(); ++i) { total_size += vec.DeviceSize(i); - ASSERT_EQ(devices_size[i], vec.DeviceSize(i)); + devices_size[i] = vec.DeviceSize(i); } ASSERT_EQ(total_size, h_vec.size()); ASSERT_EQ(total_size, vec.Size()); - h_vec_1 = vec.HostVector(); - ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin())); -} - -TEST(HostDeviceVector, Span) { - HostDeviceVector vec {1.0f, 2.0f, 3.0f, 4.0f}; - vec.Reshard(GPUSet{0, 1}); - auto span = vec.DeviceSpan(0); - ASSERT_EQ(vec.Size(), span.size()); - ASSERT_EQ(vec.DevicePointer(0), span.data()); - auto const_span = vec.ConstDeviceSpan(0); - ASSERT_EQ(vec.Size(), span.size()); - ASSERT_EQ(vec.ConstDevicePointer(0), span.data()); } +#endif } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_span.h b/tests/cpp/common/test_span.h index 194a356ce9ca..bfefb6a91b23 100644 --- a/tests/cpp/common/test_span.h +++ b/tests/cpp/common/test_span.h @@ -7,6 +7,14 @@ #include "../../include/xgboost/base.h" #include "../../../src/common/span.h" +template +XGBOOST_DEVICE void InitializeRange(Iter _begin, Iter _end) { + float j = 0; + for (Iter i = _begin; i != _end; ++i, ++j) { + *i = j; + } +} + namespace xgboost { namespace common { @@ -20,14 +28,6 @@ namespace common { *(status) = -1; \ } -template -XGBOOST_DEVICE void InitializeRange(Iter _begin, Iter _end) { - float j = 0; - for (Iter i = _begin; i != _end; ++i, ++j) { - *i = j; - } -} - struct TestTestStatus { int * status_; diff --git a/tests/cpp/common/test_transform_range.cc b/tests/cpp/common/test_transform_range.cc new file mode 100644 index 000000000000..b09e4cd0d28e --- /dev/null +++ b/tests/cpp/common/test_transform_range.cc @@ -0,0 +1,61 @@ +#include +#include +#include + +#include "../../../src/common/host_device_vector.h" +#include "../../../src/common/transform.h" +#include "../../../src/common/span.h" +#include "../helpers.h" + +#if defined(__CUDACC__) + +#define TRANSFORM_GPU_RANGE GPUSet::Range(0, 1) +#define TRANSFORM_GPU_DIST GPUDistribution::Block(GPUSet::Range(0, 1)) + +#else + +#define TRANSFORM_GPU_RANGE GPUSet::Empty() +#define TRANSFORM_GPU_DIST GPUDistribution::Block(GPUSet::Empty()) + +#endif + +template +void InitializeRange(Iter _begin, Iter _end) { + float j = 0; + for (Iter i = _begin; i != _end; ++i, ++j) { + *i = j; + } +} + +namespace xgboost { +namespace common { + +template +struct TestTransformRange { + void XGBOOST_DEVICE operator()(size_t _idx, + Span _out, Span _in) { + _out[_idx] = _in[_idx]; + } +}; + +TEST(Transform, DeclareUnifiedTest(Basic)) { + const size_t size {256}; + std::vector h_in(size); + std::vector h_out(size); + InitializeRange(h_in.begin(), h_in.end()); + std::vector h_sol(size); + InitializeRange(h_sol.begin(), h_sol.end()); + + const HostDeviceVector in_vec{h_in, TRANSFORM_GPU_DIST}; + HostDeviceVector out_vec{h_out, TRANSFORM_GPU_DIST}; + out_vec.Fill(0); + + Transform<>::Init(TestTransformRange{}, Range{0, size}, TRANSFORM_GPU_RANGE) + .Eval(&out_vec, &in_vec); + std::vector res = out_vec.HostVector(); + + ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin())); +} + +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/common/test_transform_range.cu b/tests/cpp/common/test_transform_range.cu new file mode 100644 index 000000000000..39517cedcda2 --- /dev/null +++ b/tests/cpp/common/test_transform_range.cu @@ -0,0 +1,43 @@ +// This converts all tests from CPU to GPU. +#include "test_transform_range.cc" + +#if defined(XGBOOST_USE_NCCL) +namespace xgboost { +namespace common { + +// Test here is multi gpu specific +TEST(Transform, MGPU_Basic) { + auto devices = GPUSet::AllVisible(); + CHECK_GT(devices.Size(), 1); + const size_t size {256}; + std::vector h_in(size); + std::vector h_out(size); + InitializeRange(h_in.begin(), h_in.end()); + std::vector h_sol(size); + InitializeRange(h_sol.begin(), h_sol.end()); + + const HostDeviceVector in_vec {h_in, + GPUDistribution::Block(GPUSet::Empty())}; + HostDeviceVector out_vec {h_out, + GPUDistribution::Block(GPUSet::Empty())}; + out_vec.Fill(0); + + in_vec.Reshard(GPUDistribution::Granular(devices, 8)); + out_vec.Reshard(GPUDistribution::Block(devices)); + + // Granularity is different, resharding will throw. + EXPECT_ANY_THROW( + Transform<>::Init(TestTransformRange{}, Range{0, size}, devices) + .Eval(&out_vec, &in_vec)); + + + Transform<>::Init(TestTransformRange{}, Range{0, size}, + devices, false).Eval(&out_vec, &in_vec); + std::vector res = out_vec.HostVector(); + + ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin())); +} + +} // namespace xgboost +} // namespace common +#endif \ No newline at end of file diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 9f13b3868432..b8c897eac2f7 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -125,3 +125,17 @@ std::shared_ptr* CreateDMatrix(int rows, int columns, &handle); return static_cast *>(handle); } + +namespace xgboost { +bool IsNear(std::vector::const_iterator _beg1, + std::vector::const_iterator _end1, + std::vector::const_iterator _beg2) { + for (auto iter1 = _beg1, iter2 = _beg2; iter1 != _end1; ++iter1, ++iter2) { + if (std::abs(*iter1 - *iter2) > xgboost::kRtEps){ + return false; + } + } + return true; +} +} + diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 998249a53735..bdc0e81f97e4 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -15,6 +15,12 @@ #include #include +#if defined(__CUDACC__) +#define DeclareUnifiedTest(name) GPU ## name +#else +#define DeclareUnifiedTest(name) name +#endif + std::string TempFileName(); bool FileExists(const std::string name); @@ -46,6 +52,12 @@ xgboost::bst_float GetMetricEval( std::vector labels, std::vector weights = std::vector ()); +namespace xgboost { +bool IsNear(std::vector::const_iterator _beg1, + std::vector::const_iterator _end1, + std::vector::const_iterator _beg2); +} + /** * \fn std::shared_ptr CreateDMatrix(int rows, int columns, float sparsity, int seed); * diff --git a/tests/cpp/objective/test_hinge.cc b/tests/cpp/objective/test_hinge.cc index 5986ca620035..cbfd70e80d68 100644 --- a/tests/cpp/objective/test_hinge.cc +++ b/tests/cpp/objective/test_hinge.cc @@ -4,7 +4,7 @@ #include "../helpers.h" -TEST(Objective, HingeObj) { +TEST(Objective, DeclareUnifiedTest(HingeObj)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:hinge"); std::vector > args; obj->Configure(args); @@ -15,6 +15,12 @@ TEST(Objective, HingeObj) { { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, { 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f}, { eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps }); + CheckObjFunction(obj, + {-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f}, + { 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f}, + {}, // Empty weight. + { 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f}, + { eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps }); ASSERT_NO_THROW(obj->DefaultEvalMetric()); diff --git a/tests/cpp/objective/test_hinge.cu b/tests/cpp/objective/test_hinge.cu new file mode 100644 index 000000000000..9decd79a43e3 --- /dev/null +++ b/tests/cpp/objective/test_hinge.cu @@ -0,0 +1 @@ +#include "test_hinge.cc" diff --git a/tests/cpp/objective/test_multiclass_obj.cc b/tests/cpp/objective/test_multiclass_obj.cc new file mode 100644 index 000000000000..3ff229bdd5a1 --- /dev/null +++ b/tests/cpp/objective/test_multiclass_obj.cc @@ -0,0 +1,60 @@ +/*! + * Copyright 2018 XGBoost contributors + */ +#include + +#include "../helpers.h" + +TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) { + xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softmax"); + std::vector> args {{"num_class", "3"}}; + obj->Configure(args); + CheckObjFunction(obj, + {1, 0, 2, 2, 0, 1}, // preds + {1.0, 0.0}, // labels + {1.0, 1.0}, // weights + {0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad + {0.36, 0.16, 0.44, 0.45, 0.16, 0.37}); // hess + + ASSERT_NO_THROW(obj->DefaultEvalMetric()); + + delete obj; +} + +TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) { + xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softmax"); + std::vector> args + {std::pair("num_class", "3")}; + obj->Configure(args); + + xgboost::HostDeviceVector io_preds = {2.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 2.0f}; + std::vector out_preds = {0.0f, 2.0f}; + obj->PredTransform(&io_preds); + + auto& preds = io_preds.HostVector(); + + for (int i = 0; i < static_cast(io_preds.Size()); ++i) { + EXPECT_NEAR(preds[i], out_preds[i], 0.01f); + } + + delete obj; +} + +TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) { + xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softprob"); + std::vector> args + {std::pair("num_class", "3")}; + obj->Configure(args); + + xgboost::HostDeviceVector io_preds = {2.0f, 0.0f, 1.0f}; + std::vector out_preds = {0.66524096f, 0.09003057f, 0.24472847f}; + + obj->PredTransform(&io_preds); + auto& preds = io_preds.HostVector(); + + for (int i = 0; i < static_cast(io_preds.Size()); ++i) { + EXPECT_NEAR(preds[i], out_preds[i], 0.01f); + } + delete obj; +} diff --git a/tests/cpp/objective/test_multiclass_obj_gpu.cu b/tests/cpp/objective/test_multiclass_obj_gpu.cu new file mode 100644 index 000000000000..7567d3242296 --- /dev/null +++ b/tests/cpp/objective/test_multiclass_obj_gpu.cu @@ -0,0 +1 @@ +#include "test_multiclass_obj.cc" diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index fc31d4f968a1..7843cc12fa28 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -1,9 +1,11 @@ -// Copyright by Contributors +/*! + * Copyright 2017-2018 XGBoost contributors + */ #include #include "../helpers.h" -TEST(Objective, LinearRegressionGPair) { +TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:linear"); std::vector > args; obj->Configure(args); @@ -13,27 +15,32 @@ TEST(Objective, LinearRegressionGPair) { {1, 1, 1, 1, 1, 1, 1, 1}, {0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0}, {1, 1, 1, 1, 1, 1, 1, 1}); - + CheckObjFunction(obj, + {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + {0, 0, 0, 0, 1, 1, 1, 1}, + {}, // empty weight + {0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0}, + {1, 1, 1, 1, 1, 1, 1, 1}); ASSERT_NO_THROW(obj->DefaultEvalMetric()); delete obj; } -TEST(Objective, LogisticRegressionGPair) { +TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:logistic"); std::vector > args; obj->Configure(args); CheckObjFunction(obj, - { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - { 0, 0, 0, 0, 1, 1, 1, 1}, - { 1, 1, 1, 1, 1, 1, 1, 1}, - { 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f}, - {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); + { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, // preds + { 0, 0, 0, 0, 1, 1, 1, 1}, // labels + { 1, 1, 1, 1, 1, 1, 1, 1}, // weights + { 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f}, // out_grad + {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); // out_hess delete obj; } -TEST(Objective, LogisticRegressionBasic) { +TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:logistic"); std::vector > args; obj->Configure(args); @@ -61,7 +68,7 @@ TEST(Objective, LogisticRegressionBasic) { delete obj; } -TEST(Objective, LogisticRawGPair) { +TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:logitraw"); std::vector > args; obj->Configure(args); @@ -75,7 +82,7 @@ TEST(Objective, LogisticRawGPair) { delete obj; } -TEST(Objective, PoissonRegressionGPair) { +TEST(Objective, DeclareUnifiedTest(PoissonRegressionGPair)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("count:poisson"); std::vector > args; args.push_back(std::make_pair("max_delta_step", "0.1f")); @@ -86,11 +93,16 @@ TEST(Objective, PoissonRegressionGPair) { { 1, 1, 1, 1, 1, 1, 1, 1}, { 1, 1.10f, 2.45f, 2.71f, 0, 0.10f, 1.45f, 1.71f}, {1.10f, 1.22f, 2.71f, 3.00f, 1.10f, 1.22f, 2.71f, 3.00f}); - + CheckObjFunction(obj, + { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + { 0, 0, 0, 0, 1, 1, 1, 1}, + {}, // Empty weight + { 1, 1.10f, 2.45f, 2.71f, 0, 0.10f, 1.45f, 1.71f}, + {1.10f, 1.22f, 2.71f, 3.00f, 1.10f, 1.22f, 2.71f, 3.00f}); delete obj; } -TEST(Objective, PoissonRegressionBasic) { +TEST(Objective, DeclareUnifiedTest(PoissonRegressionBasic)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("count:poisson"); std::vector > args; obj->Configure(args); @@ -116,7 +128,7 @@ TEST(Objective, PoissonRegressionBasic) { delete obj; } -TEST(Objective, GammaRegressionGPair) { +TEST(Objective, DeclareUnifiedTest(GammaRegressionGPair)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:gamma"); std::vector > args; obj->Configure(args); @@ -126,11 +138,16 @@ TEST(Objective, GammaRegressionGPair) { {1, 1, 1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 0, 0.09f, 0.59f, 0.63f}, {0, 0, 0, 0, 1, 0.90f, 0.40f, 0.36f}); - + CheckObjFunction(obj, + {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + {0, 0, 0, 0, 1, 1, 1, 1}, + {}, // Empty weight + {1, 1, 1, 1, 0, 0.09f, 0.59f, 0.63f}, + {0, 0, 0, 0, 1, 0.90f, 0.40f, 0.36f}); delete obj; } -TEST(Objective, GammaRegressionBasic) { +TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:gamma"); std::vector > args; obj->Configure(args); @@ -156,7 +173,7 @@ TEST(Objective, GammaRegressionBasic) { delete obj; } -TEST(Objective, TweedieRegressionGPair) { +TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie"); std::vector > args; args.push_back(std::make_pair("tweedie_variance_power", "1.1f")); @@ -167,11 +184,17 @@ TEST(Objective, TweedieRegressionGPair) { { 1, 1, 1, 1, 1, 1, 1, 1}, { 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f}, {0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f}); + CheckObjFunction(obj, + { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + { 0, 0, 0, 0, 1, 1, 1, 1}, + {}, // Empty weight. + { 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f}, + {0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f}); delete obj; } -TEST(Objective, TweedieRegressionBasic) { +TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie"); std::vector > args; obj->Configure(args); @@ -197,6 +220,9 @@ TEST(Objective, TweedieRegressionBasic) { delete obj; } + +// CoxRegression not implemented in GPU code, no need for testing. +#if !defined(__CUDACC__) TEST(Objective, CoxRegressionGPair) { xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("survival:cox"); std::vector > args; @@ -210,3 +236,4 @@ TEST(Objective, CoxRegressionGPair) { delete obj; } +#endif diff --git a/tests/cpp/objective/test_regression_obj_gpu.cu b/tests/cpp/objective/test_regression_obj_gpu.cu index 6bcede3614d4..38f29b8a8800 100644 --- a/tests/cpp/objective/test_regression_obj_gpu.cu +++ b/tests/cpp/objective/test_regression_obj_gpu.cu @@ -1,78 +1,6 @@ /*! - * Copyright 2017 XGBoost contributors + * Copyright 2018 XGBoost contributors */ -#include +// Dummy file to keep the CUDA tests. -#include "../helpers.h" - -TEST(Objective, GPULinearRegressionGPair) { - xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:linear"); - std::vector > args; - obj->Configure(args); - CheckObjFunction(obj, - {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - {0, 0, 0, 0, 1, 1, 1, 1}, - {1, 1, 1, 1, 1, 1, 1, 1}, - {0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0}, - {1, 1, 1, 1, 1, 1, 1, 1}); - - ASSERT_NO_THROW(obj->DefaultEvalMetric()); - - delete obj; -} - -TEST(Objective, GPULogisticRegressionGPair) { - xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic"); - std::vector > args; - obj->Configure(args); - CheckObjFunction(obj, - { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - { 0, 0, 0, 0, 1, 1, 1, 1}, - { 1, 1, 1, 1, 1, 1, 1, 1}, - { 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f}, - {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); - - delete obj; -} - -TEST(Objective, GPULogisticRegressionBasic) { - xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic"); - std::vector > args; - obj->Configure(args); - - // test label validation - EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {10}, {1}, {0}, {0})) - << "Expected error when label not in range [0,1f] for LogisticRegression"; - - // test ProbToMargin - EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.197f, 0.01f); - EXPECT_NEAR(obj->ProbToMargin(0.5f), 0, 0.01f); - EXPECT_NEAR(obj->ProbToMargin(0.9f), 2.197f, 0.01f); - EXPECT_ANY_THROW(obj->ProbToMargin(10)) - << "Expected error when base_score not in range [0,1f] for LogisticRegression"; - - // test PredTransform - xgboost::HostDeviceVector io_preds = {0, 0.1f, 0.5f, 0.9f, 1}; - std::vector out_preds = {0.5f, 0.524f, 0.622f, 0.710f, 0.731f}; - obj->PredTransform(&io_preds); - auto& preds = io_preds.HostVector(); - for (int i = 0; i < static_cast(io_preds.Size()); ++i) { - EXPECT_NEAR(preds[i], out_preds[i], 0.01f); - } - - delete obj; -} - -TEST(Objective, GPULogisticRawGPair) { - xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:binary:logitraw"); - std::vector > args; - obj->Configure(args); - CheckObjFunction(obj, - { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - { 0, 0, 0, 0, 1, 1, 1, 1}, - { 1, 1, 1, 1, 1, 1, 1, 1}, - { 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f}, - {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); - - delete obj; -} +#include "test_regression_obj.cc"