[POC] Experimental support for l1 error. (#7812)
Support adaptive tree, a feature supported by both sklearn and lightgbm. The tree leaf is recomputed based on residue of labels and predictions after construction. For l1 error, the optimal value is the median (50 percentile). This is marked as experimental support for the following reasons: - The value is not well defined for distributed training, where we might have empty leaves for local workers. Right now I just use the original leaf value for computing the average with other workers, which might cause significant errors. - Some follow-ups are required, for exact, pruner, and optimization for quantile function. Also, we need to calculate the initial estimation.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2018 by Contributors
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
* \file common.h
|
||||
* \brief Common utilities
|
||||
*/
|
||||
@@ -14,12 +14,12 @@
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#include <thrust/system/cuda/error.h>
|
||||
@@ -164,6 +164,67 @@ class Range {
|
||||
Iterator end_;
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Transform iterator that takes an index and calls transform operator.
|
||||
*
|
||||
* This is CPU-only right now as taking host device function as operator complicates the
|
||||
* code. For device side one can use `thrust::transform_iterator` instead.
|
||||
*/
|
||||
template <typename Fn>
|
||||
class IndexTransformIter {
|
||||
size_t iter_{0};
|
||||
Fn fn_;
|
||||
|
||||
public:
|
||||
using iterator_category = std::random_access_iterator_tag; // NOLINT
|
||||
using value_type = std::result_of_t<Fn(size_t)>; // NOLINT
|
||||
using difference_type = detail::ptrdiff_t; // NOLINT
|
||||
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
|
||||
using pointer = std::add_pointer_t<value_type>; // NOLINT
|
||||
|
||||
public:
|
||||
/**
|
||||
* \param op Transform operator, takes a size_t index as input.
|
||||
*/
|
||||
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
|
||||
IndexTransformIter(IndexTransformIter const &) = default;
|
||||
|
||||
value_type operator*() const { return fn_(iter_); }
|
||||
|
||||
auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
|
||||
|
||||
IndexTransformIter &operator++() {
|
||||
iter_++;
|
||||
return *this;
|
||||
}
|
||||
IndexTransformIter operator++(int) {
|
||||
auto ret = *this;
|
||||
++(*this);
|
||||
return ret;
|
||||
}
|
||||
IndexTransformIter &operator+=(difference_type n) {
|
||||
iter_ += n;
|
||||
return *this;
|
||||
}
|
||||
IndexTransformIter &operator-=(difference_type n) {
|
||||
(*this) += -n;
|
||||
return *this;
|
||||
}
|
||||
IndexTransformIter operator+(difference_type n) const {
|
||||
auto ret = *this;
|
||||
return ret += n;
|
||||
}
|
||||
IndexTransformIter operator-(difference_type n) const {
|
||||
auto ret = *this;
|
||||
return ret -= n;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Fn>
|
||||
auto MakeIndexTransformIter(Fn&& fn) {
|
||||
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
|
||||
}
|
||||
|
||||
int AllVisibleGPUs();
|
||||
|
||||
inline void AssertGPUSupport() {
|
||||
@@ -191,13 +252,39 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
|
||||
|
||||
struct OptionalWeights {
|
||||
Span<float const> weights;
|
||||
float dft{1.0f};
|
||||
float dft{1.0f}; // fixme: make this compile time constant
|
||||
|
||||
explicit OptionalWeights(Span<float const> w) : weights{w} {}
|
||||
explicit OptionalWeights(float w) : dft{w} {}
|
||||
|
||||
XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
|
||||
};
|
||||
|
||||
/**
|
||||
* Last index of a group in a CSR style of index pointer.
|
||||
*/
|
||||
template <typename Indexable>
|
||||
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
|
||||
return indptr[group + 1] - 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Run length encode on CPU, input must be sorted.
|
||||
*/
|
||||
template <typename Iter, typename Idx>
|
||||
void RunLengthEncode(Iter begin, Iter end, std::vector<Idx> *p_out) {
|
||||
auto &out = *p_out;
|
||||
out = std::vector<Idx>{0};
|
||||
size_t n = std::distance(begin, end);
|
||||
for (size_t i = 1; i < n; ++i) {
|
||||
if (begin[i] != begin[i - 1]) {
|
||||
out.push_back(i);
|
||||
}
|
||||
}
|
||||
if (out.back() != n) {
|
||||
out.push_back(n);
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_COMMON_H_
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 XGBoost contributors
|
||||
* Copyright 2017-2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <thrust/device_ptr.h>
|
||||
@@ -1537,6 +1537,43 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
|
||||
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Different from the above one, this one can handle cases where segment doesn't
|
||||
* start from 0, but as a result it uses comparison sort.
|
||||
*/
|
||||
template <typename SegIt, typename ValIt>
|
||||
void SegmentedArgSort(SegIt seg_begin, SegIt seg_end, ValIt val_begin, ValIt val_end,
|
||||
dh::device_vector<size_t> *p_sorted_idx) {
|
||||
using Tup = thrust::tuple<int32_t, float>;
|
||||
auto &sorted_idx = *p_sorted_idx;
|
||||
size_t n = std::distance(val_begin, val_end);
|
||||
sorted_idx.resize(n);
|
||||
dh::Iota(dh::ToSpan(sorted_idx));
|
||||
dh::device_vector<Tup> keys(sorted_idx.size());
|
||||
auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(size_t i) -> Tup {
|
||||
int32_t leaf_idx;
|
||||
if (i < *seg_begin) {
|
||||
leaf_idx = -1;
|
||||
} else {
|
||||
leaf_idx = dh::SegmentId(seg_begin, seg_end, i);
|
||||
}
|
||||
auto residue = val_begin[i];
|
||||
return thrust::make_tuple(leaf_idx, residue);
|
||||
});
|
||||
dh::XGBCachingDeviceAllocator<char> caching;
|
||||
thrust::copy(thrust::cuda::par(caching), key_it, key_it + keys.size(), keys.begin());
|
||||
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc), keys.begin(), keys.end(), sorted_idx.begin(),
|
||||
[=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
|
||||
if (thrust::get<0>(l) != thrust::get<0>(r)) {
|
||||
return thrust::get<0>(l) < thrust::get<0>(r); // segment index
|
||||
}
|
||||
return thrust::get<1>(l) < thrust::get<1>(r); // residue
|
||||
});
|
||||
}
|
||||
|
||||
class CUDAStreamView;
|
||||
|
||||
class CUDAEvent {
|
||||
@@ -1600,5 +1637,6 @@ class CUDAStream {
|
||||
}
|
||||
|
||||
CUDAStreamView View() const { return CUDAStreamView{stream_}; }
|
||||
void Sync() { this->View().Sync(); }
|
||||
};
|
||||
} // namespace dh
|
||||
|
||||
@@ -13,6 +13,7 @@ namespace xgboost {
|
||||
namespace linalg {
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
|
||||
dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
|
||||
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
|
||||
"For function with return, use transform instead.");
|
||||
if (t.Contiguous()) {
|
||||
@@ -40,7 +41,7 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_
|
||||
}
|
||||
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
|
||||
void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
|
||||
ctx->IsCPU() ? ElementWiseKernelHost(t, ctx->Threads(), fn) : ElementWiseKernelDevice(t, fn);
|
||||
}
|
||||
} // namespace linalg
|
||||
|
||||
@@ -12,10 +12,12 @@
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "categorical.h"
|
||||
#include "column_matrix.h"
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -254,7 +256,7 @@ class PartitionBuilder {
|
||||
n_left += mem_blocks_[j]->n_left;
|
||||
}
|
||||
size_t n_right = 0;
|
||||
for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) {
|
||||
for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i + 1]; ++j) {
|
||||
mem_blocks_[j]->n_offset_right = n_left + n_right;
|
||||
n_right += mem_blocks_[j]->n_right;
|
||||
}
|
||||
@@ -279,6 +281,30 @@ class PartitionBuilder {
|
||||
return blocks_offsets_[nid] + begin / BlockSize;
|
||||
}
|
||||
|
||||
// Copy row partitions into global cache for reuse in objective
|
||||
template <typename Sampledp>
|
||||
void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
|
||||
std::vector<bst_node_t>* p_position, Sampledp sampledp) const {
|
||||
auto& h_pos = *p_position;
|
||||
h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
|
||||
|
||||
auto p_begin = row_set.Data()->data();
|
||||
ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
|
||||
auto const& node = row_set[i];
|
||||
if (node.node_id < 0) {
|
||||
return;
|
||||
}
|
||||
CHECK(tree[node.node_id].IsLeaf());
|
||||
if (node.begin) { // guard for empty node.
|
||||
size_t ptr_offset = node.end - p_begin;
|
||||
CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
|
||||
for (auto idx = node.begin; idx != node.end; ++idx) {
|
||||
h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
protected:
|
||||
struct BlockInfo{
|
||||
size_t n_left;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017 by Contributors
|
||||
* Copyright 2017-2022 by Contributors
|
||||
* \file row_set.h
|
||||
* \brief Quick Utility to compute subset of rows
|
||||
* \author Philip Cho, Tianqi Chen
|
||||
@@ -15,10 +15,15 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
/*! \brief collection of rowset */
|
||||
class RowSetCollection {
|
||||
public:
|
||||
RowSetCollection() = default;
|
||||
RowSetCollection(RowSetCollection const&) = delete;
|
||||
RowSetCollection(RowSetCollection&&) = default;
|
||||
RowSetCollection& operator=(RowSetCollection const&) = delete;
|
||||
RowSetCollection& operator=(RowSetCollection&&) = default;
|
||||
|
||||
/*! \brief data structure to store an instance set, a subset of
|
||||
* rows (instances) associated with a particular node in a decision
|
||||
* tree. */
|
||||
@@ -38,20 +43,17 @@ class RowSetCollection {
|
||||
return end - begin;
|
||||
}
|
||||
};
|
||||
/* \brief specifies how to split a rowset into two */
|
||||
struct Split {
|
||||
std::vector<size_t> left;
|
||||
std::vector<size_t> right;
|
||||
};
|
||||
|
||||
inline std::vector<Elem>::const_iterator begin() const { // NOLINT
|
||||
std::vector<Elem>::const_iterator begin() const { // NOLINT
|
||||
return elem_of_each_node_.begin();
|
||||
}
|
||||
|
||||
inline std::vector<Elem>::const_iterator end() const { // NOLINT
|
||||
std::vector<Elem>::const_iterator end() const { // NOLINT
|
||||
return elem_of_each_node_.end();
|
||||
}
|
||||
|
||||
size_t Size() const { return std::distance(begin(), end()); }
|
||||
|
||||
/*! \brief return corresponding element set given the node_id */
|
||||
inline const Elem& operator[](unsigned node_id) const {
|
||||
const Elem& e = elem_of_each_node_[node_id];
|
||||
@@ -86,6 +88,8 @@ class RowSetCollection {
|
||||
}
|
||||
|
||||
std::vector<size_t>* Data() { return &row_indices_; }
|
||||
std::vector<size_t> const* Data() const { return &row_indices_; }
|
||||
|
||||
// split rowset into two
|
||||
inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
|
||||
size_t n_left, size_t n_right) {
|
||||
@@ -123,7 +127,6 @@ class RowSetCollection {
|
||||
// vector: node_id -> elements
|
||||
std::vector<Elem> elem_of_each_node_;
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
127
src/common/stats.cuh
Normal file
127
src/common/stats.cuh
Normal file
@@ -0,0 +1,127 @@
|
||||
/*!
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_STATS_CUH_
|
||||
#define XGBOOST_COMMON_STATS_CUH_
|
||||
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/iterator/permutation_iterator.h>
|
||||
|
||||
#include <iterator> // std::distance
|
||||
|
||||
#include "device_helpers.cuh"
|
||||
#include "linalg_op.cuh"
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
/**
|
||||
* \brief Compute segmented quantile on GPU.
|
||||
*
|
||||
* \tparam SegIt Iterator for CSR style segments indptr
|
||||
* \tparam ValIt Iterator for values
|
||||
*
|
||||
* \param alpha The p^th quantile we want to compute
|
||||
*
|
||||
* std::distance(ptr_begin, ptr_end) should be equal to n_segments + 1
|
||||
*/
|
||||
template <typename SegIt, typename ValIt>
|
||||
void SegmentedQuantile(Context const* ctx, double alpha, SegIt seg_begin, SegIt seg_end,
|
||||
ValIt val_begin, ValIt val_end, HostDeviceVector<float>* quantiles) {
|
||||
CHECK(alpha >= 0 && alpha <= 1);
|
||||
|
||||
dh::device_vector<size_t> sorted_idx;
|
||||
using Tup = thrust::tuple<size_t, float>;
|
||||
dh::SegmentedArgSort(seg_begin, seg_end, val_begin, val_end, &sorted_idx);
|
||||
auto n_segments = std::distance(seg_begin, seg_end) - 1;
|
||||
if (n_segments <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
quantiles->SetDevice(ctx->gpu_id);
|
||||
quantiles->Resize(n_segments);
|
||||
auto d_results = quantiles->DeviceSpan();
|
||||
auto d_sorted_idx = dh::ToSpan(sorted_idx);
|
||||
|
||||
auto val = thrust::make_permutation_iterator(val_begin, dh::tcbegin(d_sorted_idx));
|
||||
|
||||
dh::LaunchN(n_segments, [=] XGBOOST_DEVICE(size_t i) {
|
||||
// each segment is the index of a leaf.
|
||||
size_t seg_idx = i;
|
||||
size_t begin = seg_begin[seg_idx];
|
||||
auto n = static_cast<double>(seg_begin[seg_idx + 1] - begin);
|
||||
if (n == 0) {
|
||||
d_results[i] = std::numeric_limits<float>::quiet_NaN();
|
||||
return;
|
||||
}
|
||||
|
||||
if (alpha <= (1 / (n + 1))) {
|
||||
d_results[i] = val[begin];
|
||||
return;
|
||||
}
|
||||
if (alpha >= (n / (n + 1))) {
|
||||
d_results[i] = val[common::LastOf(seg_idx, seg_begin)];
|
||||
return;
|
||||
}
|
||||
|
||||
double x = alpha * static_cast<double>(n + 1);
|
||||
double k = std::floor(x) - 1;
|
||||
double d = (x - 1) - k;
|
||||
|
||||
auto v0 = val[begin + static_cast<size_t>(k)];
|
||||
auto v1 = val[begin + static_cast<size_t>(k) + 1];
|
||||
d_results[seg_idx] = v0 + d * (v1 - v0);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SegIt, typename ValIt, typename WIter>
|
||||
void SegmentedWeightedQuantile(Context const* ctx, double alpha, SegIt seg_beg, SegIt seg_end,
|
||||
ValIt val_begin, ValIt val_end, WIter w_begin, WIter w_end,
|
||||
HostDeviceVector<float>* quantiles) {
|
||||
CHECK(alpha >= 0 && alpha <= 1);
|
||||
dh::device_vector<size_t> sorted_idx;
|
||||
dh::SegmentedArgSort(seg_beg, seg_end, val_begin, val_end, &sorted_idx);
|
||||
auto d_sorted_idx = dh::ToSpan(sorted_idx);
|
||||
size_t n_weights = std::distance(w_begin, w_end);
|
||||
dh::device_vector<float> weights_cdf(n_weights);
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> caching;
|
||||
auto scan_key = dh::MakeTransformIterator<size_t>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(seg_beg, seg_end, i); });
|
||||
auto scan_val = dh::MakeTransformIterator<float>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(size_t i) { return w_begin[d_sorted_idx[i]]; });
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
|
||||
scan_val, weights_cdf.begin());
|
||||
|
||||
auto n_segments = std::distance(seg_beg, seg_end) - 1;
|
||||
quantiles->SetDevice(ctx->gpu_id);
|
||||
quantiles->Resize(n_segments);
|
||||
auto d_results = quantiles->DeviceSpan();
|
||||
auto d_weight_cdf = dh::ToSpan(weights_cdf);
|
||||
|
||||
dh::LaunchN(n_segments, [=] XGBOOST_DEVICE(size_t i) {
|
||||
size_t seg_idx = i;
|
||||
size_t begin = seg_beg[seg_idx];
|
||||
auto n = static_cast<double>(seg_beg[seg_idx + 1] - begin);
|
||||
if (n == 0) {
|
||||
d_results[i] = std::numeric_limits<float>::quiet_NaN();
|
||||
return;
|
||||
}
|
||||
auto leaf_cdf = d_weight_cdf.subspan(begin, static_cast<size_t>(n));
|
||||
auto leaf_sorted_idx = d_sorted_idx.subspan(begin, static_cast<size_t>(n));
|
||||
float thresh = leaf_cdf.back() * alpha;
|
||||
|
||||
size_t idx = thrust::lower_bound(thrust::seq, leaf_cdf.data(),
|
||||
leaf_cdf.data() + leaf_cdf.size(), thresh) -
|
||||
leaf_cdf.data();
|
||||
idx = std::min(idx, static_cast<size_t>(n - 1));
|
||||
d_results[i] = val_begin[leaf_sorted_idx[idx]];
|
||||
});
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_STATS_CUH_
|
||||
95
src/common/stats.h
Normal file
95
src/common/stats.h
Normal file
@@ -0,0 +1,95 @@
|
||||
/*!
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_STATS_H_
|
||||
#define XGBOOST_COMMON_STATS_H_
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "common.h"
|
||||
#include "xgboost/linalg.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
/**
|
||||
* \brief Percentile with masked array using linear interpolation.
|
||||
*
|
||||
* https://www.itl.nist.gov/div898/handbook/prc/section2/prc262.htm
|
||||
*
|
||||
* \param alpha Percentile, must be in range [0, 1].
|
||||
* \param begin Iterator begin for input array.
|
||||
* \param end Iterator end for input array.
|
||||
*
|
||||
* \return The result of interpolation.
|
||||
*/
|
||||
template <typename Iter>
|
||||
float Quantile(double alpha, Iter const& begin, Iter const& end) {
|
||||
CHECK(alpha >= 0 && alpha <= 1);
|
||||
auto n = static_cast<double>(std::distance(begin, end));
|
||||
if (n == 0) {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
|
||||
std::vector<size_t> sorted_idx(n);
|
||||
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
|
||||
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
|
||||
[&](size_t l, size_t r) { return *(begin + l) < *(begin + r); });
|
||||
|
||||
auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
|
||||
static_assert(std::is_same<decltype(val(0)), float>::value, "");
|
||||
|
||||
if (alpha <= (1 / (n + 1))) {
|
||||
return val(0);
|
||||
}
|
||||
if (alpha >= (n / (n + 1))) {
|
||||
return val(sorted_idx.size() - 1);
|
||||
}
|
||||
assert(n != 0 && "The number of rows in a leaf can not be zero.");
|
||||
double x = alpha * static_cast<double>((n + 1));
|
||||
double k = std::floor(x) - 1;
|
||||
CHECK_GE(k, 0);
|
||||
double d = (x - 1) - k;
|
||||
|
||||
auto v0 = val(static_cast<size_t>(k));
|
||||
auto v1 = val(static_cast<size_t>(k) + 1);
|
||||
return v0 + d * (v1 - v0);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Calculate the weighted quantile with step function. Unlike the unweighted
|
||||
* version, no interpolation is used.
|
||||
*
|
||||
* See https://aakinshin.net/posts/weighted-quantiles/ for some discussion on computing
|
||||
* weighted quantile with interpolation.
|
||||
*/
|
||||
template <typename Iter, typename WeightIter>
|
||||
float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
|
||||
auto n = static_cast<double>(std::distance(begin, end));
|
||||
if (n == 0) {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
std::vector<size_t> sorted_idx(n);
|
||||
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
|
||||
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
|
||||
[&](size_t l, size_t r) { return *(begin + l) < *(begin + r); });
|
||||
|
||||
auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
|
||||
|
||||
std::vector<float> weight_cdf(n); // S_n
|
||||
// weighted cdf is sorted during construction
|
||||
weight_cdf[0] = *(weights + sorted_idx[0]);
|
||||
for (size_t i = 1; i < n; ++i) {
|
||||
weight_cdf[i] = weight_cdf[i - 1] + *(weights + sorted_idx[i]);
|
||||
}
|
||||
float thresh = weight_cdf.back() * alpha;
|
||||
size_t idx =
|
||||
std::lower_bound(weight_cdf.cbegin(), weight_cdf.cend(), thresh) - weight_cdf.cbegin();
|
||||
idx = std::min(idx, static_cast<size_t>(n - 1));
|
||||
return val(idx);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_STATS_H_
|
||||
Reference in New Issue
Block a user