Commit 115b649

Don't use zip iterator / tuple

hcho3 committed Sep 9, 2021
1 parent 11c8bcd commit 115b649
Showing 5 changed files with 128 additions and 130 deletions.
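Before this commit, the forward and backward prefix scans over the histogram bins were fused into a single cub::DeviceScan::InclusiveScan pass whose element type was a thrust::tuple of two ScanElem values, built by zipping a forward counting iterator with a reverse one. After the commit, the two directions run as two independent scans over plain ScanElem values, with the direction passed as a leading forward flag to ScanValueOp, ScanOp, and WriteScan. The sketch below shows the resulting shape of each pass: CUB's two-call idiom (the first call only sizes the temporary workspace, the second call scans) and the segmented restart that ScanOp::DoIt performs at feature/node boundaries. It is a minimal illustration only; Elem, SegmentedSum, and RunInclusiveScan are stand-in names, not XGBoost code.

// Illustrative sketch only -- Elem, SegmentedSum, and RunInclusiveScan are
// stand-ins, not XGBoost types.
#include <cstddef>
#include <vector>
#include <cub/cub.cuh>
#include <thrust/device_vector.h>

// Stand-in for ScanElem<GradientSumT>: a segment key plus a running sum.
struct Elem {
  int key;
  float sum;
};

// Stand-in for ScanOp: restart the running sum whenever the segment key
// changes, mirroring the "Segmented Scan" branch in ScanOp::DoIt.
struct SegmentedSum {
  __device__ Elem operator()(Elem lhs, Elem rhs) const {
    if (lhs.key != rhs.key) {
      return rhs;  // first element of a new segment
    }
    return Elem{rhs.key, lhs.sum + rhs.sum};
  }
};

// CUB's two-call idiom, as used in the commit: the first call only computes
// the workspace size; the second call performs the scan.
void RunInclusiveScan(const thrust::device_vector<Elem>& in,
                      thrust::device_vector<Elem>& out) {
  std::size_t n_temp_bytes = 0;
  cub::DeviceScan::InclusiveScan(nullptr, n_temp_bytes, in.data().get(),
                                 out.data().get(), SegmentedSum{},
                                 static_cast<int>(in.size()));
  thrust::device_vector<char> temp(n_temp_bytes);
  cub::DeviceScan::InclusiveScan(temp.data().get(), n_temp_bytes,
                                 in.data().get(), out.data().get(),
                                 SegmentedSum{}, static_cast<int>(in.size()));
}

int main() {
  // Two segments (keys 0 and 1); the running sum restarts at the boundary:
  // expected output sums are 1, 3, 6, 4, 9.
  std::vector<Elem> host{{0, 1.f}, {0, 2.f}, {0, 3.f}, {1, 4.f}, {1, 5.f}};
  thrust::device_vector<Elem> in(host);
  thrust::device_vector<Elem> out(in.size());
  RunInclusiveScan(in, out);
  cudaDeviceSynchronize();
  return 0;
}

Splitting the passes doubles the kernel launches, but each scan now carries a single ScanElem per element where the fused version carried a tuple of two (one per direction) through every stage of the scan.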
200 changes: 104 additions & 96 deletions src/tree/gpu_hist/evaluate_splits.cu
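In this file the commit also shrinks the scan state: ScanComputedElem now carries a single running partial_sum plus best_partial_sum in place of the old left_sum/right_sum and best_left_sum/best_right_sum pairs, and ScanComputedElem::Update takes one partial sum instead of two child sums. The left/right child sums are reconstructed from the parent sum only at the end, when the best candidate is turned into a DeviceSplitCandidate (first hunk below). Here is a hedged sketch of that reconstruction rule with simplified stand-in types; GradientPairF and RecoverChildSums are illustrative names, not XGBoost's.

// Illustrative stand-ins; XGBoost's real types are GradientPair and
// DeviceSplitCandidate.
struct GradientPairF {
  float grad, hess;
};

__host__ __device__ GradientPairF operator-(GradientPairF a, GradientPairF b) {
  return {a.grad - b.grad, a.hess - b.hess};
}

enum class DefaultDirection { kLeftDir, kRightDir };

// For a categorical split, the partial sum is the matching category's mass
// and goes right. For a numerical split, the partial sum covers the bins
// scanned so far: the left child in a forward scan (best_direction is
// kRightDir) and the right child in a backward scan (kLeftDir).
__host__ __device__ void RecoverChildSums(GradientPairF parent_sum,
                                          GradientPairF best_partial_sum,
                                          bool is_cat, DefaultDirection dir,
                                          GradientPairF* left_sum,
                                          GradientPairF* right_sum) {
  if (is_cat || dir == DefaultDirection::kLeftDir) {
    *left_sum = parent_sum - best_partial_sum;
    *right_sum = best_partial_sum;
  } else {
    *left_sum = best_partial_sum;
    *right_sum = parent_sum - best_partial_sum;
  }
}

int main() {
  GradientPairF parent{10.f, 8.f}, partial{4.f, 3.f}, left{}, right{};
  RecoverChildSums(parent, partial, /*is_cat=*/false,
                   DefaultDirection::kRightDir, &left, &right);
  // Forward-scan case: left = partial {4,3}, right = parent - partial {6,5}.
  return 0;
}

Consistent with this rule, the min_child_weight check in Update now tests partial_sum_in.GetHess() and parent_sum_in.GetHess() - partial_sum_in.GetHess(), which are exactly the two child hessians.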
@@ -43,9 +43,22 @@ void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits,
       thrust::make_counting_iterator<std::size_t>(0),
       [d_out_scan] __device__(std::size_t i) {
         ScanComputedElem<GradientSumT> c = d_out_scan[i];
+        GradientSumT left_sum, right_sum;
+        if (c.is_cat) {
+          left_sum = c.parent_sum - c.best_partial_sum;
+          right_sum = c.best_partial_sum;
+        } else {
+          if (c.best_direction == DefaultDirection::kRightDir) {
+            left_sum = c.best_partial_sum;
+            right_sum = c.parent_sum - c.best_partial_sum;
+          } else {
+            left_sum = c.parent_sum - c.best_partial_sum;
+            right_sum = c.best_partial_sum;
+          }
+        }
         return DeviceSplitCandidate{c.best_loss_chg, c.best_direction, c.best_findex,
-                                    c.best_fvalue, c.is_cat, GradientPair{c.best_left_sum},
-                                    GradientPair{c.best_right_sum}};
+                                    c.best_fvalue, c.is_cat, GradientPair{left_sum},
+                                    GradientPair{right_sum}};
       });
   GPUTrainingParam param = left.param;
   thrust::reduce_by_key(
@@ -77,54 +90,67 @@ EvaluateSplitsFindOptimalSplitsViaScan(
     throw std::runtime_error("Invariant violated");
   }
 
-  uint64_t left_hist_size = static_cast<uint64_t>(left.gradient_histogram.size());
-  auto map_to_left_right = [left_hist_size] __device__(uint64_t idx) {
+  uint32_t left_hist_size = static_cast<uint32_t>(left.gradient_histogram.size());
+  auto map_to_hist_bin = [left_hist_size] __device__(uint32_t idx) {
     if (idx < left_hist_size) {
       // Left child node
-      return EvaluateSplitsHistEntry{ChildNodeIndicator::kLeftChild, idx};
+      return EvaluateSplitsHistEntry{0, idx};
     } else {
       // Right child node
-      return EvaluateSplitsHistEntry{ChildNodeIndicator::kRightChild, idx - left_hist_size};
+      return EvaluateSplitsHistEntry{1, idx - left_hist_size};
     }
   };
 
   std::size_t size = left.gradient_histogram.size() + right.gradient_histogram.size();
-  auto for_count_iter = thrust::make_counting_iterator<uint64_t>(0);
-  auto for_loc_iter = dh::MakeTransformIterator<EvaluateSplitsHistEntry>(
-      for_count_iter, map_to_left_right);
-  auto rev_count_iter = thrust::make_reverse_iterator(
-      thrust::make_counting_iterator<uint64_t>(0) + static_cast<std::ptrdiff_t>(size));
-  auto rev_loc_iter = dh::MakeTransformIterator<EvaluateSplitsHistEntry>(
-      rev_count_iter, map_to_left_right);
-  auto zip_loc_iter = thrust::make_zip_iterator(thrust::make_tuple(for_loc_iter, rev_loc_iter));
-
-  auto scan_input_iter =
-      dh::MakeTransformIterator<thrust::tuple<ScanElem<GradientSumT>, ScanElem<GradientSumT>>>(
-          zip_loc_iter, ScanValueOp<GradientSumT>{left, right, evaluator});
+  auto forward_count_iter = thrust::make_counting_iterator<uint32_t>(0);
+  auto forward_bin_iter = dh::MakeTransformIterator<EvaluateSplitsHistEntry>(
+      forward_count_iter, map_to_hist_bin);
+  auto forward_scan_input_iter = dh::MakeTransformIterator<ScanElem<GradientSumT>>(
+      forward_bin_iter, ScanValueOp<GradientSumT>{true, left, right, evaluator});
 
   dh::device_vector<ScanComputedElem<GradientSumT>> out_scan(l_n_features + r_n_features);
-  auto scan_out_iter = thrust::make_transform_output_iterator(
+  auto forward_scan_out_iter = thrust::make_transform_output_iterator(
       thrust::make_discard_iterator(),
-      WriteScan<GradientSumT>{left, right, evaluator, dh::ToSpan(out_scan)});
+      WriteScan<GradientSumT>{true, left, right, evaluator, dh::ToSpan(out_scan)});
+  {
+    auto scan_op = ScanOp<GradientSumT>{true, left, right, evaluator};
+    std::size_t n_temp_bytes = 0;
+    cub::DeviceScan::InclusiveScan(nullptr, n_temp_bytes, forward_scan_input_iter,
+                                   forward_scan_out_iter, scan_op, size);
+    dh::TemporaryArray<int8_t> temp(n_temp_bytes);
+    cub::DeviceScan::InclusiveScan(temp.data().get(), n_temp_bytes, forward_scan_input_iter,
+                                   forward_scan_out_iter, scan_op, size);
+  }
+
+  auto backward_count_iter = thrust::make_reverse_iterator(
+      thrust::make_counting_iterator<uint32_t>(0) + static_cast<std::ptrdiff_t>(size));
+  auto backward_bin_iter = dh::MakeTransformIterator<EvaluateSplitsHistEntry>(
+      backward_count_iter, map_to_hist_bin);
+  auto backward_scan_input_iter = dh::MakeTransformIterator<ScanElem<GradientSumT>>(
+      backward_bin_iter, ScanValueOp<GradientSumT>{false, left, right, evaluator});
+  auto backward_scan_out_iter = thrust::make_transform_output_iterator(
+      thrust::make_discard_iterator(),
+      WriteScan<GradientSumT>{false, left, right, evaluator, dh::ToSpan(out_scan)});
+  {
+    auto scan_op = ScanOp<GradientSumT>{false, left, right, evaluator};
+    std::size_t n_temp_bytes = 0;
+    cub::DeviceScan::InclusiveScan(nullptr, n_temp_bytes, backward_scan_input_iter,
+                                   backward_scan_out_iter, scan_op, size);
+    dh::TemporaryArray<int8_t> temp(n_temp_bytes);
+    cub::DeviceScan::InclusiveScan(temp.data().get(), n_temp_bytes, backward_scan_input_iter,
+                                   backward_scan_out_iter, scan_op, size);
+  }
 
-  auto scan_op = ScanOp<GradientSumT>{left, right, evaluator};
-  std::size_t n_temp_bytes = 0;
-  cub::DeviceScan::InclusiveScan(nullptr, n_temp_bytes, scan_input_iter, scan_out_iter,
-                                 scan_op, size);
-  dh::TemporaryArray<int8_t> temp(n_temp_bytes);
-  cub::DeviceScan::InclusiveScan(temp.data().get(), n_temp_bytes, scan_input_iter, scan_out_iter,
-                                 scan_op, size);
   return out_scan;
 }
 
 template <typename GradientSumT>
-template <bool forward>
 __noinline__ __device__ ScanElem<GradientSumT>
 ScanValueOp<GradientSumT>::MapEvaluateSplitsHistEntryToScanElem(
     EvaluateSplitsHistEntry entry,
     EvaluateSplitInputs<GradientSumT> split_input) {
   ScanElem<GradientSumT> ret;
-  ret.indicator = entry.indicator;
+  ret.node_idx = entry.node_idx;
   ret.hist_idx = entry.hist_idx;
   ret.gpair = split_input.gradient_histogram[entry.hist_idx];
   ret.findex = static_cast<int32_t>(dh::SegmentId(split_input.feature_segments, entry.hist_idx));
@@ -136,68 +162,61 @@ ScanValueOp<GradientSumT>::MapEvaluateSplitsHistEntryToScanElem(
   * For the element at the beginning of each segment, compute gradient sums and loss_chg
   * ahead of time. These will be later used by the inclusive scan operator.
   **/
+    GradientSumT partial_sum = ret.gpair;
+    GradientSumT complement_sum = split_input.parent_sum - partial_sum;
+    GradientSumT *left_sum, *right_sum;
     if (ret.is_cat) {
-      ret.computed_result.left_sum = split_input.parent_sum - ret.gpair;
-      ret.computed_result.right_sum = ret.gpair;
+      left_sum = &complement_sum;
+      right_sum = &partial_sum;
     } else {
       if (forward) {
-        ret.computed_result.left_sum = ret.gpair;
-        ret.computed_result.right_sum = split_input.parent_sum - ret.gpair;
+        left_sum = &partial_sum;
+        right_sum = &complement_sum;
       } else {
-        ret.computed_result.left_sum = split_input.parent_sum - ret.gpair;
-        ret.computed_result.right_sum = ret.gpair;
+        left_sum = &complement_sum;
+        right_sum = &partial_sum;
      }
    }
-    ret.computed_result.best_left_sum = ret.computed_result.left_sum;
-    ret.computed_result.best_right_sum = ret.computed_result.right_sum;
+    ret.computed_result.partial_sum = partial_sum;
+    ret.computed_result.best_partial_sum = partial_sum;
+    ret.computed_result.parent_sum = split_input.parent_sum;
     float parent_gain = evaluator.CalcGain(split_input.nidx, split_input.param,
-                                           GradStats{ret.computed_result.parent_sum});
+                                           GradStats{split_input.parent_sum});
     float gain = evaluator.CalcSplitGain(split_input.param, split_input.nidx, ret.findex,
-                                         GradStats{ret.computed_result.left_sum},
-                                         GradStats{ret.computed_result.right_sum});
+                                         GradStats{*left_sum}, GradStats{*right_sum});
     ret.computed_result.best_loss_chg = gain - parent_gain;
     ret.computed_result.best_findex = ret.findex;
     ret.computed_result.best_fvalue = ret.fvalue;
     ret.computed_result.best_direction =
         (forward ? DefaultDirection::kRightDir : DefaultDirection::kLeftDir);
     ret.computed_result.is_cat = ret.is_cat;
   }
 
   return ret;
 }
 
 template <typename GradientSumT>
-__noinline__ __device__ thrust::tuple<ScanElem<GradientSumT>, ScanElem<GradientSumT>>
-ScanValueOp<GradientSumT>::operator() (
-    thrust::tuple<EvaluateSplitsHistEntry, EvaluateSplitsHistEntry> entry_tup) {
-  const auto& fw = thrust::get<0>(entry_tup);
-  const auto& bw = thrust::get<1>(entry_tup);
-  ScanElem<GradientSumT> ret_fw, ret_bw;
-  ret_fw = MapEvaluateSplitsHistEntryToScanElem<true>(
-      fw,
-      (fw.indicator == ChildNodeIndicator::kLeftChild ? this->left : this->right));
-  ret_bw = MapEvaluateSplitsHistEntryToScanElem<false>(
-      bw,
-      (bw.indicator == ChildNodeIndicator::kLeftChild ? this->left : this->right));
-  return thrust::make_tuple(ret_fw, ret_bw);
+__noinline__ __device__ ScanElem<GradientSumT>
+ScanValueOp<GradientSumT>::operator() (EvaluateSplitsHistEntry entry) {
+  return MapEvaluateSplitsHistEntryToScanElem(
+      entry, (entry.node_idx == 0 ? this->left : this->right));
 }
 
 template <typename GradientSumT>
-template <bool forward>
 __noinline__ __device__ ScanElem<GradientSumT>
 ScanOp<GradientSumT>::DoIt(ScanElem<GradientSumT> lhs, ScanElem<GradientSumT> rhs) {
   ScanElem<GradientSumT> ret;
   ret = rhs;
   ret.computed_result = {};
-  if (lhs.findex != rhs.findex || lhs.indicator != rhs.indicator) {
+  if (lhs.findex != rhs.findex || lhs.node_idx != rhs.node_idx) {
     // Segmented Scan
     return rhs;
   }
-  if (((lhs.indicator == ChildNodeIndicator::kLeftChild) &&
+  if (((lhs.node_idx == 0) &&
        (left.feature_set.size() != left.feature_segments.size()) &&
        !thrust::binary_search(thrust::seq, left.feature_set.begin(), left.feature_set.end(),
                               lhs.findex)) ||
-      ((lhs.indicator == ChildNodeIndicator::kRightChild) &&
+      ((lhs.node_idx == 1) &&
       (right.feature_set.size() != right.feature_segments.size()) &&
       !thrust::binary_search(thrust::seq, right.feature_set.begin(), right.feature_set.end(),
                              lhs.findex))) {
@@ -206,53 +225,49 @@ ScanOp<GradientSumT>::DoIt(ScanElem<GradientSumT> lhs, ScanElem<GradientSumT> rh
   }
 
   GradientSumT parent_sum = lhs.computed_result.parent_sum;
-  GradientSumT left_sum, right_sum;
+  GradientSumT partial_sum, complement_sum;
+  GradientSumT *left_sum, *right_sum;
   if (lhs.is_cat) {
-    left_sum = lhs.computed_result.parent_sum - rhs.gpair;
-    right_sum = rhs.gpair;
+    partial_sum = rhs.gpair;
+    complement_sum = lhs.computed_result.parent_sum - rhs.gpair;
+    left_sum = &complement_sum;
+    right_sum = &partial_sum;
   } else {
+    partial_sum = lhs.computed_result.partial_sum + rhs.gpair;
+    complement_sum = parent_sum - partial_sum;
     if (forward) {
-      left_sum = lhs.computed_result.left_sum + rhs.gpair;
-      right_sum = lhs.computed_result.parent_sum - left_sum;
+      left_sum = &partial_sum;
+      right_sum = &complement_sum;
    } else {
-      right_sum = lhs.computed_result.right_sum + rhs.gpair;
-      left_sum = lhs.computed_result.parent_sum - right_sum;
+      left_sum = &complement_sum;
+      right_sum = &partial_sum;
    }
  }
-  bst_node_t nidx = (lhs.indicator == ChildNodeIndicator::kLeftChild) ? left.nidx : right.nidx;
+  bst_node_t nidx = (lhs.node_idx == 0) ? left.nidx : right.nidx;
   float gain = evaluator.CalcSplitGain(
-      left.param, nidx, lhs.findex, GradStats{left_sum}, GradStats{right_sum});
+      left.param, nidx, lhs.findex, GradStats{*left_sum}, GradStats{*right_sum});
   float parent_gain = evaluator.CalcGain(left.nidx, left.param, GradStats{parent_sum});
   float loss_chg = gain - parent_gain;
   ret.computed_result = lhs.computed_result;
-  ret.computed_result.Update(left_sum, right_sum, parent_sum,
-                             loss_chg, rhs.findex, rhs.is_cat, rhs.fvalue,
+  ret.computed_result.Update(partial_sum, parent_sum, loss_chg, rhs.findex, rhs.is_cat, rhs.fvalue,
                              (forward ? DefaultDirection::kRightDir : DefaultDirection::kLeftDir),
                              left.param);
   return ret;
 }
 
 template <typename GradientSumT>
-__noinline__ __device__ thrust::tuple<ScanElem<GradientSumT>, ScanElem<GradientSumT>>
-ScanOp<GradientSumT>::operator() (
-    thrust::tuple<ScanElem<GradientSumT>, ScanElem<GradientSumT>> lhs,
-    thrust::tuple<ScanElem<GradientSumT>, ScanElem<GradientSumT>> rhs) {
-  const auto& lhs_fw = thrust::get<0>(lhs);
-  const auto& lhs_bw = thrust::get<1>(lhs);
-  const auto& rhs_fw = thrust::get<0>(rhs);
-  const auto& rhs_bw = thrust::get<1>(rhs);
-  return thrust::make_tuple(DoIt<true>(lhs_fw, rhs_fw), DoIt<false>(lhs_bw, rhs_bw));
+__noinline__ __device__ ScanElem<GradientSumT>
+ScanOp<GradientSumT>::operator() (ScanElem<GradientSumT> lhs, ScanElem<GradientSumT> rhs) {
+  return DoIt(lhs, rhs);
 };
 
 template <typename GradientSumT>
-template <bool forward>
 void
 __noinline__ __device__ WriteScan<GradientSumT>::DoIt(ScanElem<GradientSumT> e) {
-  EvaluateSplitInputs<GradientSumT>& split_input =
-      (e.indicator == ChildNodeIndicator::kLeftChild) ? left : right;
+  EvaluateSplitInputs<GradientSumT>& split_input = (e.node_idx == 0) ? left : right;
   std::size_t offset = 0;
   std::size_t n_features = left.feature_segments.empty() ? 0 : left.feature_segments.size() - 1;
-  if (e.indicator == ChildNodeIndicator::kRightChild) {
+  if (e.node_idx == 1) {
     offset = n_features;
   }
   if ((!forward && split_input.feature_segments[e.findex] == e.hist_idx) ||
@@ -264,41 +279,34 @@ __noinline__ __device__ WriteScan<GradientSumT>::DoIt(ScanElem<GradientSumT> e)
 }
 
 template <typename GradientSumT>
-thrust::tuple<ScanElem<GradientSumT>, ScanElem<GradientSumT>>
-__noinline__ __device__ WriteScan<GradientSumT>::operator() (
-    thrust::tuple<ScanElem<GradientSumT>, ScanElem<GradientSumT>> e) {
-  const auto& fw = thrust::get<0>(e);
-  const auto& bw = thrust::get<1>(e);
-  DoIt<true>(fw);
-  DoIt<false>(bw);
+ScanElem<GradientSumT>
+__noinline__ __device__ WriteScan<GradientSumT>::operator() (ScanElem<GradientSumT> e) {
+  DoIt(e);
   return {}; // discard
 }
 
 template <typename GradientSumT>
 __noinline__ __device__ bool
 ScanComputedElem<GradientSumT>::Update(
-    GradientSumT left_sum_in,
-    GradientSumT right_sum_in,
+    GradientSumT partial_sum_in,
     GradientSumT parent_sum_in,
     float loss_chg_in,
     int32_t findex_in,
     bool is_cat_in,
     float fvalue_in,
     DefaultDirection dir_in,
     const GPUTrainingParam& param) {
-  left_sum = left_sum_in;
-  right_sum = right_sum_in;
+  partial_sum = partial_sum_in;
   parent_sum = parent_sum_in;
   if (loss_chg_in > best_loss_chg &&
-      left_sum_in.GetHess() >= param.min_child_weight &&
-      right_sum_in.GetHess() >= param.min_child_weight) {
+      partial_sum_in.GetHess() >= param.min_child_weight &&
+      (parent_sum_in.GetHess() - partial_sum_in.GetHess()) >= param.min_child_weight) {
     best_loss_chg = loss_chg_in;
     best_findex = findex_in;
     is_cat = is_cat_in;
     best_fvalue = fvalue_in;
     best_direction = dir_in;
-    best_left_sum = left_sum_in;
-    best_right_sum = right_sum_in;
+    best_partial_sum = partial_sum_in;
     return true;
   }
   return false;
[Diffs for the remaining 4 changed files are not shown in this view.]
