Fix compiler warnings. (#7974)
- Remove unused parameters. Many warnings remain unaddressed; currently the warnings from dmlc-core dominate the build log.
- Remove the `distributed` parameter from the metric interface.
- Fix some warnings about signed/unsigned comparison. An illustrative sketch of the unused-parameter and signed-comparison patterns follows the change summary below.
trivialfis authored Jun 6, 2022
1 parent d48123d commit 1a33b50
Showing 46 changed files with 149 additions and 189 deletions.
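
Two warning-silencing patterns recur throughout this commit: commenting out (or otherwise marking) parameter names that a default implementation does not use, and making both sides of a comparison the same signedness. Below is a minimal, self-contained C++ sketch of both patterns — illustrative only; the `Booster` class and `CountPositive` function are hypothetical and not part of the XGBoost sources.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Pattern 1: a virtual hook with a default no-op body. Commenting out the
// parameter names keeps the signature self-documenting while silencing
// -Wunused-parameter in the default implementation.
class Booster {
 public:
  virtual void Slice(std::int32_t /*layer_begin*/, std::int32_t /*layer_end*/,
                     Booster* /*out*/) const {
    std::cerr << "Slice is not supported by this booster.\n";
  }
  virtual ~Booster() = default;
};

// Pattern 2: comparing a signed loop index against an unsigned size() triggers
// -Wsign-compare; using a size_t index (or an explicit cast on one side)
// removes the warning.
std::size_t CountPositive(const std::vector<float>& values) {
  std::size_t n = 0;
  for (std::size_t i = 0; i < values.size(); ++i) {  // size_t vs. size(): same signedness
    if (values[i] > 0.0f) {
      ++n;
    }
  }
  return n;
}

int main() {
  Booster b;
  b.Slice(0, 0, nullptr);
  std::cout << CountPositive({-1.0f, 2.0f, 3.0f}) << "\n";  // prints 2
  return 0;
}
```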
9 changes: 9 additions & 0 deletions cmake/Utils.cmake
@@ -144,6 +144,15 @@ function(xgboost_set_cuda_flags target)
set_property(TARGET ${target} PROPERTY CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")

if (FORCE_COLORED_OUTPUT)
if (FORCE_COLORED_OUTPUT AND (CMAKE_GENERATOR STREQUAL "Ninja") AND
((CMAKE_CXX_COMPILER_ID STREQUAL "GNU") OR
(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")))
target_compile_options(${target} PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fdiagnostics-color=always>)
endif()
endif (FORCE_COLORED_OUTPUT)

if (USE_DEVICE_DEBUG)
target_compile_options(${target} PRIVATE
$<$<AND:$<CONFIG:DEBUG>,$<COMPILE_LANGUAGE:CUDA>>:-G;-src-in-ptx>)
4 changes: 2 additions & 2 deletions include/xgboost/gbm.h
@@ -68,8 +68,8 @@ class GradientBooster : public Model, public Configurable {
* \param layer_end End of booster layer. 0 means do not limit trees.
* \param out Output gradient booster
*/
virtual void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const {
virtual void Slice(int32_t /*layer_begin*/, int32_t /*layer_end*/, int32_t /*step*/,
GradientBooster* /*out*/, bool* /*out_of_bound*/) const {
LOG(FATAL) << "Slice is not supported by current booster.";
}
/*!
2 changes: 1 addition & 1 deletion include/xgboost/json_io.h
@@ -89,7 +89,7 @@ class JsonReader {
} else if (got == 0) {
msg += "\\0\"";
} else {
msg += (got <= 127 ? std::string{got} : std::to_string(got)) + " \""; // NOLINT
msg += (got <= static_cast<char>(127) ? std::string{got} : std::to_string(got)) + " \"";
}
Error(msg);
}
3 changes: 2 additions & 1 deletion include/xgboost/linalg.h
@@ -317,7 +317,8 @@ class TensorView {
}

template <size_t old_dim, size_t new_dim, int32_t D, typename Index>
LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], Index i) const {
LINALG_HD size_t MakeSliceDim(DMLC_ATTRIBUTE_UNUSED size_t new_shape[D],
DMLC_ATTRIBUTE_UNUSED size_t new_stride[D], Index i) const {
static_assert(old_dim < kDim, "");
return stride_[old_dim] * i;
}
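
`DMLC_ATTRIBUTE_UNUSED` comes from dmlc-core; in standard C++17 the same effect can be had with the `[[maybe_unused]]` attribute. A standalone sketch under that assumption — `SliceOffset` and its body are hypothetical, not the XGBoost implementation:

```cpp
#include <cstddef>
#include <iostream>

// Only the stride participates in this overload; the attribute documents that
// the shape argument is intentionally ignored and keeps -Wunused-parameter quiet.
template <std::size_t D>
std::size_t SliceOffset([[maybe_unused]] const std::size_t (&shape)[D],
                        const std::size_t (&stride)[D], std::size_t i) {
  return stride[0] * i;
}

int main() {
  std::size_t shape[2] = {4, 8};
  std::size_t stride[2] = {8, 1};
  std::cout << SliceOffset(shape, stride, 3) << "\n";  // prints 24
  return 0;
}
```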
6 changes: 1 addition & 5 deletions include/xgboost/metric.h
@@ -57,12 +57,8 @@ class Metric : public Configurable {
* \brief evaluate a specific metric
* \param preds prediction
* \param info information, including label etc.
* \param distributed whether a call to Allreduce is needed to gather
* the average statistics across all the node,
* this is only supported by some metrics
*/
virtual double Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info, bool distributed) = 0;
virtual double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) = 0;
/*! \return name of metric */
virtual const char* Name() const = 0;
/*! \brief virtual destructor */
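With the `distributed` flag gone, a metric implementation sees only the predictions and the `MetaInfo`; whether distributed aggregation is needed is no longer the caller's decision. A self-contained mock of the simplified shape of the interface — the `Mock*` types and the RMSE body are hypothetical, not the XGBoost classes:

```cpp
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

struct MockInfo {
  std::vector<float> labels;
};

struct MockMetric {
  // The former `distributed` parameter is gone: callers pass only predictions
  // and metadata, mirroring the two-argument Eval above.
  virtual double Eval(const std::vector<float>& preds, const MockInfo& info) = 0;
  virtual const char* Name() const = 0;
  virtual ~MockMetric() = default;
};

struct MockRMSE : MockMetric {
  double Eval(const std::vector<float>& preds, const MockInfo& info) override {
    double sum = 0.0;
    for (std::size_t i = 0; i < preds.size(); ++i) {
      double d = static_cast<double>(preds[i]) - info.labels[i];
      sum += d * d;
    }
    return std::sqrt(sum / static_cast<double>(preds.size()));
  }
  const char* Name() const override { return "mock-rmse"; }
};

int main() {
  MockRMSE metric;
  MockInfo info{{0.f, 1.f, 2.f}};
  std::printf("%s = %f\n", metric.Name(), metric.Eval({0.f, 1.f, 1.f}, info));
  return 0;
}
```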
6 changes: 4 additions & 2 deletions include/xgboost/objective.h
@@ -103,8 +103,10 @@ class ObjFunction : public Configurable {
* \param prediction Model prediction after transformation.
* \param p_tree Tree that needs to be updated.
*/
virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
HostDeviceVector<float> const& prediction, RegTree* p_tree) const {}
virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
MetaInfo const& /*info*/,
HostDeviceVector<float> const& /*prediction*/,
RegTree* /*p_tree*/) const {}

/*!
* \brief Create an objective function according to name.
6 changes: 3 additions & 3 deletions src/common/hist_util.h
@@ -171,14 +171,14 @@ inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_thr

if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
hessian, n_threads);
n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
container.MakeCuts(&out);
} else {
SortedSketchContainer container{
max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, n_threads};
SortedSketchContainer container{max_bins, m->Info(), reduced,
HostSketchContainer::UseGroup(info), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian);
}
15 changes: 7 additions & 8 deletions src/common/partition_builder.h
@@ -168,8 +168,8 @@ class PartitionBuilder {
const size_t n_left = child_nodes_sizes.first;
const size_t n_right = child_nodes_sizes.second;

SetNLeftElems(node_in_set, range.begin(), range.end(), n_left);
SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
SetNLeftElems(node_in_set, range.begin(), n_left);
SetNRightElems(node_in_set, range.begin(), n_right);
}

/**
@@ -188,8 +188,7 @@
*/
template <typename Pred>
void PartitionRange(const size_t node_in_set, const size_t nid, common::Range1d range,
bst_feature_t fidx, common::RowSetCollection* p_row_set_collection,
Pred pred) {
common::RowSetCollection* p_row_set_collection, Pred pred) {
auto& row_set_collection = *p_row_set_collection;
const size_t* p_ridx = row_set_collection[nid].begin;
common::Span<const size_t> ridx(p_ridx + range.begin(), p_ridx + range.end());
Expand All @@ -200,8 +199,8 @@ class PartitionBuilder {
const size_t n_left = child_nodes_sizes.first;
const size_t n_right = child_nodes_sizes.second;

this->SetNLeftElems(node_in_set, range.begin(), range.end(), n_left);
this->SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
this->SetNLeftElems(node_in_set, range.begin(), n_left);
this->SetNRightElems(node_in_set, range.begin(), n_right);
}

// allocate thread local memory, should be called for each specific task
Expand All @@ -223,12 +222,12 @@ class PartitionBuilder {
return { mem_blocks_.at(task_idx)->Right(), end - begin };
}

void SetNLeftElems(int nid, size_t begin, size_t end, size_t n_left) {
void SetNLeftElems(int nid, size_t begin, size_t n_left) {
size_t task_idx = GetTaskIdx(nid, begin);
mem_blocks_.at(task_idx)->n_left = n_left;
}

void SetNRightElems(int nid, size_t begin, size_t end, size_t n_right) {
void SetNRightElems(int nid, size_t begin, size_t n_right) {
size_t task_idx = GetTaskIdx(nid, begin);
mem_blocks_.at(task_idx)->n_right = n_right;
}
2 changes: 1 addition & 1 deletion src/common/quantile.cc
@@ -543,7 +543,7 @@ template class SketchContainerImpl<WXQuantileSketch<float, float>>;

HostSketchContainer::HostSketchContainer(int32_t max_bins, MetaInfo const &info,
std::vector<size_t> columns_size, bool use_group,
Span<float const> hessian, int32_t n_threads)
int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
n_threads} {
monitor_.Init(__func__);
4 changes: 2 additions & 2 deletions src/common/quantile.h
@@ -774,7 +774,7 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl

public:
HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
bool use_group, Span<float const> hessian, int32_t n_threads);
bool use_group, int32_t n_threads);
};

/**
Expand Down Expand Up @@ -868,7 +868,7 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,
public:
explicit SortedSketchContainer(int32_t max_bins, MetaInfo const &info,
std::vector<size_t> columns_size, bool use_group,
Span<float const> hessian, int32_t n_threads)
int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
n_threads} {
monitor_.Init(__func__);
3 changes: 1 addition & 2 deletions src/data/iterative_device_dmatrix.cu
@@ -163,8 +163,7 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin

BatchSet<EllpackPage> IterativeDeviceDMatrix::GetEllpackBatches(const BatchParam& param) {
CHECK(page_);
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(page_));
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(page_));
return BatchSet<EllpackPage>(begin_iter);
}
} // namespace data
4 changes: 2 additions & 2 deletions src/data/iterative_device_dmatrix.h
@@ -45,8 +45,8 @@ class IterativeDeviceDMatrix : public DMatrix {

bool EllpackExists() const override { return true; }
bool SparsePageExists() const override { return false; }
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix.";
DMatrix *Slice(common::Span<int32_t const>) override {
LOG(FATAL) << "Slicing DMatrix is not supported for Quantile DMatrix.";
return nullptr;
}
BatchSet<SparsePage> GetRowBatches() override {
4 changes: 2 additions & 2 deletions src/data/proxy_dmatrix.h
@@ -84,7 +84,7 @@ class DMatrixProxy : public DMatrix {
bool SingleColBlock() const override { return true; }
bool EllpackExists() const override { return true; }
bool SparsePageExists() const override { return false; }
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
DMatrix* Slice(common::Span<int32_t const> /*ridxs*/) override {
LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
return nullptr;
}
@@ -100,7 +100,7 @@ class DMatrixProxy : public DMatrix {
LOG(FATAL) << "Not implemented.";
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
}
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override {
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam&) override {
LOG(FATAL) << "Not implemented.";
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(nullptr));
}
42 changes: 18 additions & 24 deletions src/gbm/gbtree.cc
@@ -218,7 +218,7 @@ void CopyGradient(HostDeviceVector<GradientPair> const* in_gpair, int32_t n_thre
}

void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const& predictions,
ObjFunction const* obj, size_t gidx,
ObjFunction const* obj,
std::vector<std::unique_ptr<RegTree>>* p_trees) {
CHECK(!updaters_.empty());
if (!updaters_.back()->HasNodePosition()) {
@@ -257,7 +257,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
if (ngroup == 1) {
std::vector<std::unique_ptr<RegTree>> ret;
BoostNewTrees(in_gpair, p_fmat, 0, &ret);
UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, &ret);
UpdateTreeLeaf(p_fmat, predt->predictions, obj, &ret);
const size_t num_new_trees = ret.size();
new_trees.push_back(std::move(ret));
auto v_predt = out.Slice(linalg::All(), 0);
@@ -274,7 +274,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
CopyGradient(in_gpair, ctx_->Threads(), ngroup, gid, &tmp);
std::vector<std::unique_ptr<RegTree>> ret;
BoostNewTrees(&tmp, p_fmat, gid, &ret);
UpdateTreeLeaf(p_fmat, predt->predictions, obj, gid, &ret);
UpdateTreeLeaf(p_fmat, predt->predictions, obj, &ret);
const size_t num_new_trees = ret.size();
new_trees.push_back(std::move(ret));
auto v_predt = out.Slice(linalg::All(), gid);
@@ -289,7 +289,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
}

monitor_.Stop("BoostNewTrees");
this->CommitModel(std::move(new_trees), p_fmat, predt);
this->CommitModel(std::move(new_trees));
}

void GBTree::InitUpdater(Args const& cfg) {
@@ -378,9 +378,7 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
}
}

void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
DMatrix* m,
PredictionCacheEntry* predts) {
void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) {
monitor_.Start("CommitModel");
for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
model_.CommitModel(std::move(new_trees[gid]), gid);
@@ -490,15 +488,14 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
"want to update a portion of trees.";
}

*out_of_bound = detail::SliceTrees(
layer_begin, layer_end, step, this->model_, tparam_, layer_trees,
[&](auto const &in_it, auto const &out_it) {
auto new_tree =
std::make_unique<RegTree>(*this->model_.trees.at(in_it));
bst_group_t group = this->model_.tree_info[in_it];
out_trees.at(out_it) = std::move(new_tree);
out_trees_info.at(out_it) = group;
});
*out_of_bound = detail::SliceTrees(layer_begin, layer_end, step, this->model_, layer_trees,
[&](auto const& in_it, auto const& out_it) {
auto new_tree =
std::make_unique<RegTree>(*this->model_.trees.at(in_it));
bst_group_t group = this->model_.tree_info[in_it];
out_trees.at(out_it) = std::move(new_tree);
out_trees_info.at(out_it) = group;
});
}

void GBTree::PredictBatch(DMatrix* p_fmat,
@@ -674,11 +671,10 @@ class Dart : public GBTree {
auto p_dart = dynamic_cast<Dart*>(out);
CHECK(p_dart);
CHECK(p_dart->weight_drop_.empty());
detail::SliceTrees(
layer_begin, layer_end, step, model_, tparam_, this->LayerTrees(),
[&](auto const& in_it, auto const&) {
p_dart->weight_drop_.push_back(this->weight_drop_.at(in_it));
});
detail::SliceTrees(layer_begin, layer_end, step, model_, this->LayerTrees(),
[&](auto const& in_it, auto const&) {
p_dart->weight_drop_.push_back(this->weight_drop_.at(in_it));
});
}

void SaveModel(Json *p_out) const override {
@@ -901,9 +897,7 @@

protected:
// commit new trees all at once
void
CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
DMatrix*, PredictionCacheEntry*) override {
void CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) override {
int num_new_trees = 0;
for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
num_new_trees += new_trees[gid].size();
14 changes: 5 additions & 9 deletions src/gbm/gbtree.h
@@ -162,9 +162,8 @@ inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const &model,

// Call fn for each pair of input output tree. Return true if index is out of bound.
template <typename Func>
inline bool SliceTrees(int32_t layer_begin, int32_t layer_end, int32_t step,
GBTreeModel const &model, GBTreeTrainParam const &tparam,
uint32_t layer_trees, Func fn) {
bool SliceTrees(int32_t layer_begin, int32_t layer_end, int32_t step, GBTreeModel const& model,
uint32_t layer_trees, Func fn) {
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model, layer_begin, layer_end);
if (tree_end > model.trees.size()) {
@@ -206,8 +205,7 @@ class GBTree : public GradientBooster {
* \brief Optionally update the leaf value.
*/
void UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const& predictions,
ObjFunction const* obj, size_t gidx,
std::vector<std::unique_ptr<RegTree>>* p_trees);
ObjFunction const* obj, std::vector<std::unique_ptr<RegTree>>* p_trees);

/*! \brief Carry out one iteration of boosting */
void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
@@ -325,7 +323,7 @@ class GBTree : public GradientBooster {
};

if (importance_type == "weight") {
add_score([&](auto const &p_tree, bst_node_t, bst_feature_t split) {
add_score([&](auto const&, bst_node_t, bst_feature_t split) {
gain_map[split] = split_counts[split];
});
} else if (importance_type == "gain" || importance_type == "total_gain") {
@@ -423,9 +421,7 @@
DMatrix* f_dmat = nullptr) const;

// commit new trees all at once
virtual void CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
DMatrix* m,
PredictionCacheEntry* predts);
virtual void CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees);

// --- data structure ---
GBTreeModel model_;
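The `SliceTrees` helper above drops its unused `GBTreeTrainParam` argument; its job is to walk the trees of the selected boosting layers and hand matching input/output indices to a callback. A rough standalone mock of that traversal pattern — the layer-to-tree bookkeeping below is a simplifying assumption, not XGBoost's exact arithmetic:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Visit the trees of layers [layer_begin, layer_end) with the given step,
// calling fn(input_tree_index, output_tree_index). Returns true if the
// requested range runs past the available layers (the out-of-bound flag).
bool SliceTreesMock(std::int32_t layer_begin, std::int32_t layer_end, std::int32_t step,
                    std::int32_t n_layers, std::int32_t trees_per_layer,
                    const std::function<void(std::int32_t, std::int32_t)>& fn) {
  if (layer_end == 0) layer_end = n_layers;  // 0 means "do not limit"
  bool out_of_bound = layer_end > n_layers;
  if (out_of_bound) layer_end = n_layers;
  std::int32_t out_idx = 0;
  for (std::int32_t layer = layer_begin; layer < layer_end; layer += step) {
    for (std::int32_t t = 0; t < trees_per_layer; ++t) {
      fn(layer * trees_per_layer + t, out_idx++);
    }
  }
  return out_of_bound;
}

int main() {
  // Copy every other layer of a 4-layer model with 2 trees per layer.
  std::vector<std::pair<std::int32_t, std::int32_t>> mapping;
  bool oob = SliceTreesMock(0, 4, 2, 4, 2, [&](std::int32_t in, std::int32_t out) {
    mapping.emplace_back(in, out);
  });
  std::cout << "out_of_bound=" << oob << ", copied " << mapping.size() << " trees\n";
  return 0;
}
```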
3 changes: 1 addition & 2 deletions src/learner.cc
@@ -1234,8 +1234,7 @@ class LearnerImpl : public LearnerIO {

obj_->EvalTransform(&out);
for (auto& ev : metrics_) {
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
<< ev->Eval(out, m->Info(), tparam_.dsplit == DataSplitMode::kRow);
os << '\t' << data_names[i] << '-' << ev->Name() << ':' << ev->Eval(out, m->Info());
}
}

3 changes: 1 addition & 2 deletions src/metric/auc.cc
@@ -254,8 +254,7 @@ std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,

template <typename Curve>
class EvalAUC : public Metric {
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
bool distributed) override {
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info) override {
double auc {0};
if (tparam_->gpu_id != GenericParameter::kCpuId) {
preds.SetDevice(tparam_->gpu_id);
12 changes: 4 additions & 8 deletions src/metric/auc.cu
@@ -312,10 +312,8 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
* up each class in all kernels.
*/
template <bool scale, typename Fn>
double GPUMultiClassAUCOVR(common::Span<float const> predts,
MetaInfo const &info, int32_t device,
common::Span<uint32_t> d_class_ptr, size_t n_classes,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<uint32_t> d_class_ptr,
size_t n_classes, std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
dh::safe_cuda(cudaSetDevice(device));
/**
* Sorted idx
@@ -478,8 +476,7 @@ double GPUMultiClassROCAUC(common::Span<float const> predts,
double tp, size_t /*class_id*/) {
return TrapezoidArea(fp_prev, fp, tp_prev, tp);
};
return GPUMultiClassAUCOVR<true>(predts, info, device, dh::ToSpan(class_ptr),
n_classes, cache, fn);
return GPUMultiClassAUCOVR<true>(info, device, dh::ToSpan(class_ptr), n_classes, cache, fn);
}

namespace {
@@ -704,8 +701,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts,
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[class_id].first);
};
return GPUMultiClassAUCOVR<false>(predts, info, device, d_class_ptr,
n_classes, cache, fn);
return GPUMultiClassAUCOVR<false>(info, device, d_class_ptr, n_classes, cache, fn);
}

template <typename Fn>