Support multiple alphas for segmented quantile. (#8758)
This commit is contained in:
parent
c4802bfcd0
commit
48cefa012e
@ -1,29 +1,179 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2022 by XGBoost Contributors
|
* Copyright 2022-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_COMMON_STATS_CUH_
|
#ifndef XGBOOST_COMMON_STATS_CUH_
|
||||||
#define XGBOOST_COMMON_STATS_CUH_
|
#define XGBOOST_COMMON_STATS_CUH_
|
||||||
|
|
||||||
#include <thrust/iterator/counting_iterator.h>
|
#include <thrust/binary_search.h> // lower_bound
|
||||||
#include <thrust/iterator/permutation_iterator.h>
|
#include <thrust/for_each.h> // for_each_n
|
||||||
|
#include <thrust/iterator/constant_iterator.h> // make_constant_iterator
|
||||||
|
#include <thrust/iterator/counting_iterator.h> // make_counting_iterator
|
||||||
|
#include <thrust/iterator/permutation_iterator.h> // make_permutation_iterator
|
||||||
|
#include <thrust/scan.h> // inclusive_scan_by_key
|
||||||
|
|
||||||
|
#include <algorithm> // std::min
|
||||||
|
#include <cstddef> // std::size_t
|
||||||
#include <iterator> // std::distance
|
#include <iterator> // std::distance
|
||||||
|
#include <limits> // std::numeric_limits
|
||||||
|
#include <type_traits> // std::is_floating_point,std::iterator_traits
|
||||||
|
|
||||||
|
#include "cuda_context.cuh" // CUDAContext
|
||||||
#include "device_helpers.cuh"
|
#include "device_helpers.cuh"
|
||||||
#include "linalg_op.cuh"
|
#include "xgboost/context.h" // Context
|
||||||
#include "xgboost/context.h"
|
#include "xgboost/span.h" // Span
|
||||||
#include "xgboost/linalg.h"
|
|
||||||
#include "xgboost/tree_model.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
|
namespace detail {
|
||||||
|
// This should be a lambda function, but for some reason gcc-11 + nvcc-11.8 failed to
|
||||||
|
// compile it. As a result, a functor is extracted instead.
|
||||||
|
//
|
||||||
|
// error: ‘__T288’ was not declared in this scope
|
||||||
|
template <typename SegIt, typename ValIt, typename AlphaIt>
|
||||||
|
struct QuantileSegmentOp {
|
||||||
|
SegIt seg_begin;
|
||||||
|
ValIt val;
|
||||||
|
AlphaIt alpha_it;
|
||||||
|
Span<float> d_results;
|
||||||
|
|
||||||
|
static_assert(std::is_floating_point<typename std::iterator_traits<ValIt>::value_type>::value,
|
||||||
|
"Invalid value for quantile.");
|
||||||
|
static_assert(std::is_floating_point<typename std::iterator_traits<ValIt>::value_type>::value,
|
||||||
|
"Invalid alpha.");
|
||||||
|
|
||||||
|
XGBOOST_DEVICE void operator()(std::size_t seg_idx) {
|
||||||
|
std::size_t begin = seg_begin[seg_idx];
|
||||||
|
auto n = static_cast<double>(seg_begin[seg_idx + 1] - begin);
|
||||||
|
double a = alpha_it[seg_idx];
|
||||||
|
|
||||||
|
if (n == 0) {
|
||||||
|
d_results[seg_idx] = std::numeric_limits<float>::quiet_NaN();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (a <= (1 / (n + 1))) {
|
||||||
|
d_results[seg_idx] = val[begin];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (a >= (n / (n + 1))) {
|
||||||
|
d_results[seg_idx] = val[common::LastOf(seg_idx, seg_begin)];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
double x = a * static_cast<double>(n + 1);
|
||||||
|
double k = std::floor(x) - 1;
|
||||||
|
double d = (x - 1) - k;
|
||||||
|
|
||||||
|
auto v0 = val[begin + static_cast<std::size_t>(k)];
|
||||||
|
auto v1 = val[begin + static_cast<std::size_t>(k) + 1];
|
||||||
|
|
||||||
|
d_results[seg_idx] = v0 + d * (v1 - v0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SegIt, typename ValIt, typename AlphaIt>
|
||||||
|
auto MakeQSegOp(SegIt seg_it, ValIt val_it, AlphaIt alpha_it, Span<float> d_results) {
|
||||||
|
return QuantileSegmentOp<SegIt, ValIt, AlphaIt>{seg_it, val_it, alpha_it, d_results};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SegIt>
|
||||||
|
struct SegOp {
|
||||||
|
SegIt seg_beg;
|
||||||
|
SegIt seg_end;
|
||||||
|
|
||||||
|
XGBOOST_DEVICE std::size_t operator()(std::size_t i) {
|
||||||
|
return dh::SegmentId(seg_beg, seg_end, i);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename WIter>
|
||||||
|
struct WeightOp {
|
||||||
|
WIter w_begin;
|
||||||
|
Span<std::size_t const> d_sorted_idx;
|
||||||
|
XGBOOST_DEVICE float operator()(std::size_t i) { return w_begin[d_sorted_idx[i]]; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SegIt, typename ValIt, typename AlphaIt>
|
||||||
|
struct WeightedQuantileSegOp {
|
||||||
|
AlphaIt alpha_it;
|
||||||
|
SegIt seg_beg;
|
||||||
|
ValIt val_begin;
|
||||||
|
Span<float const> d_weight_cdf;
|
||||||
|
Span<std::size_t const> d_sorted_idx;
|
||||||
|
Span<float> d_results;
|
||||||
|
static_assert(std::is_floating_point<typename std::iterator_traits<AlphaIt>::value_type>::value,
|
||||||
|
"Invalid alpha.");
|
||||||
|
static_assert(std::is_floating_point<typename std::iterator_traits<ValIt>::value_type>::value,
|
||||||
|
"Invalid value for quantile.");
|
||||||
|
|
||||||
|
XGBOOST_DEVICE void operator()(std::size_t seg_idx) {
|
||||||
|
std::size_t begin = seg_beg[seg_idx];
|
||||||
|
auto n = static_cast<double>(seg_beg[seg_idx + 1] - begin);
|
||||||
|
if (n == 0) {
|
||||||
|
d_results[seg_idx] = std::numeric_limits<float>::quiet_NaN();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto seg_cdf = d_weight_cdf.subspan(begin, static_cast<std::size_t>(n));
|
||||||
|
auto seg_sorted_idx = d_sorted_idx.subspan(begin, static_cast<std::size_t>(n));
|
||||||
|
double a = alpha_it[seg_idx];
|
||||||
|
double thresh = seg_cdf.back() * a;
|
||||||
|
|
||||||
|
std::size_t idx =
|
||||||
|
thrust::lower_bound(thrust::seq, seg_cdf.data(), seg_cdf.data() + seg_cdf.size(), thresh) -
|
||||||
|
seg_cdf.data();
|
||||||
|
idx = std::min(idx, static_cast<std::size_t>(n - 1));
|
||||||
|
d_results[seg_idx] = val_begin[seg_sorted_idx[idx]];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SegIt, typename ValIt, typename AlphaIt>
|
||||||
|
auto MakeWQSegOp(SegIt seg_it, ValIt val_it, AlphaIt alpha_it, Span<float const> d_weight_cdf,
|
||||||
|
Span<std::size_t const> d_sorted_idx, Span<float> d_results) {
|
||||||
|
return WeightedQuantileSegOp<SegIt, ValIt, AlphaIt>{alpha_it, seg_it, val_it,
|
||||||
|
d_weight_cdf, d_sorted_idx, d_results};
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
/**
|
/**
|
||||||
* \brief Compute segmented quantile on GPU.
|
* @brief Compute segmented quantile on GPU.
|
||||||
*
|
*
|
||||||
* \tparam SegIt Iterator for CSR style segments indptr
|
* @tparam SegIt Iterator for CSR style segments indptr
|
||||||
* \tparam ValIt Iterator for values
|
* @tparam ValIt Iterator for values
|
||||||
|
* @tparam AlphaIt Iterator to alphas
|
||||||
*
|
*
|
||||||
* \param alpha The p^th quantile we want to compute
|
 * @param alpha_it Iterator to the p^th quantile we want to compute, one for each segment.
|
||||||
|
*
|
||||||
|
* std::distance(seg_begin, seg_end) should be equal to n_segments + 1
|
||||||
|
*/
|
||||||
|
template <typename SegIt, typename ValIt, typename AlphaIt,
|
||||||
|
std::enable_if_t<!std::is_floating_point<AlphaIt>::value>* = nullptr>
|
||||||
|
void SegmentedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_begin, SegIt seg_end,
|
||||||
|
ValIt val_begin, ValIt val_end, HostDeviceVector<float>* quantiles) {
|
||||||
|
dh::device_vector<std::size_t> sorted_idx;
|
||||||
|
using Tup = thrust::tuple<std::size_t, float>;
|
||||||
|
dh::SegmentedArgSort(seg_begin, seg_end, val_begin, val_end, &sorted_idx);
|
||||||
|
auto n_segments = std::distance(seg_begin, seg_end) - 1;
|
||||||
|
if (n_segments <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto d_sorted_idx = dh::ToSpan(sorted_idx);
|
||||||
|
auto val = thrust::make_permutation_iterator(val_begin, dh::tcbegin(d_sorted_idx));
|
||||||
|
|
||||||
|
quantiles->SetDevice(ctx->gpu_id);
|
||||||
|
quantiles->Resize(n_segments);
|
||||||
|
auto d_results = quantiles->DeviceSpan();
|
||||||
|
|
||||||
|
dh::LaunchN(n_segments, ctx->CUDACtx()->Stream(),
|
||||||
|
detail::MakeQSegOp(seg_begin, val, alpha_it, d_results));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute segmented quantile on GPU.
|
||||||
|
*
|
||||||
|
* @tparam SegIt Iterator for CSR style segments indptr
|
||||||
|
* @tparam ValIt Iterator for values
|
||||||
|
*
|
||||||
|
* @param alpha The p^th quantile we want to compute
|
||||||
*
|
*
|
||||||
* std::distance(ptr_begin, ptr_end) should be equal to n_segments + 1
|
* std::distance(ptr_begin, ptr_end) should be equal to n_segments + 1
|
||||||
*/
|
*/
|
||||||
@ -31,69 +181,40 @@ template <typename SegIt, typename ValIt>
|
|||||||
void SegmentedQuantile(Context const* ctx, double alpha, SegIt seg_begin, SegIt seg_end,
|
void SegmentedQuantile(Context const* ctx, double alpha, SegIt seg_begin, SegIt seg_end,
|
||||||
ValIt val_begin, ValIt val_end, HostDeviceVector<float>* quantiles) {
|
ValIt val_begin, ValIt val_end, HostDeviceVector<float>* quantiles) {
|
||||||
CHECK(alpha >= 0 && alpha <= 1);
|
CHECK(alpha >= 0 && alpha <= 1);
|
||||||
|
auto alpha_it = thrust::make_constant_iterator(alpha);
|
||||||
dh::device_vector<size_t> sorted_idx;
|
return SegmentedQuantile(ctx, alpha_it, seg_begin, seg_end, val_begin, val_end, quantiles);
|
||||||
using Tup = thrust::tuple<size_t, float>;
|
|
||||||
dh::SegmentedArgSort(seg_begin, seg_end, val_begin, val_end, &sorted_idx);
|
|
||||||
auto n_segments = std::distance(seg_begin, seg_end) - 1;
|
|
||||||
if (n_segments <= 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
quantiles->SetDevice(ctx->gpu_id);
|
|
||||||
quantiles->Resize(n_segments);
|
|
||||||
auto d_results = quantiles->DeviceSpan();
|
|
||||||
auto d_sorted_idx = dh::ToSpan(sorted_idx);
|
|
||||||
|
|
||||||
auto val = thrust::make_permutation_iterator(val_begin, dh::tcbegin(d_sorted_idx));
|
|
||||||
|
|
||||||
dh::LaunchN(n_segments, [=] XGBOOST_DEVICE(size_t i) {
|
|
||||||
// each segment is the index of a leaf.
|
|
||||||
size_t seg_idx = i;
|
|
||||||
size_t begin = seg_begin[seg_idx];
|
|
||||||
auto n = static_cast<double>(seg_begin[seg_idx + 1] - begin);
|
|
||||||
if (n == 0) {
|
|
||||||
d_results[i] = std::numeric_limits<float>::quiet_NaN();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (alpha <= (1 / (n + 1))) {
|
|
||||||
d_results[i] = val[begin];
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (alpha >= (n / (n + 1))) {
|
|
||||||
d_results[i] = val[common::LastOf(seg_idx, seg_begin)];
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
double x = alpha * static_cast<double>(n + 1);
|
|
||||||
double k = std::floor(x) - 1;
|
|
||||||
double d = (x - 1) - k;
|
|
||||||
|
|
||||||
auto v0 = val[begin + static_cast<size_t>(k)];
|
|
||||||
auto v1 = val[begin + static_cast<size_t>(k) + 1];
|
|
||||||
d_results[seg_idx] = v0 + d * (v1 - v0);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SegIt, typename ValIt, typename WIter>
|
/**
|
||||||
void SegmentedWeightedQuantile(Context const* ctx, double alpha, SegIt seg_beg, SegIt seg_end,
|
* @brief Compute segmented quantile on GPU with weighted inputs.
|
||||||
|
*
|
||||||
|
* @tparam SegIt Iterator for CSR style segments indptr
|
||||||
|
* @tparam ValIt Iterator for values
|
||||||
|
* @tparam WIter Iterator for weights
|
||||||
|
*
|
||||||
|
* @param alpha_it Iterator for the p^th quantile we want to compute, one per-segment
|
||||||
|
* @param w_begin Iterator for weight for each input element
|
||||||
|
*/
|
||||||
|
template <typename SegIt, typename ValIt, typename AlphaIt, typename WIter,
|
||||||
|
typename std::enable_if_t<!std::is_same<
|
||||||
|
typename std::iterator_traits<AlphaIt>::value_type, void>::value>* = nullptr>
|
||||||
|
void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_beg, SegIt seg_end,
|
||||||
ValIt val_begin, ValIt val_end, WIter w_begin, WIter w_end,
|
ValIt val_begin, ValIt val_end, WIter w_begin, WIter w_end,
|
||||||
HostDeviceVector<float>* quantiles) {
|
HostDeviceVector<float>* quantiles) {
|
||||||
CHECK(alpha >= 0 && alpha <= 1);
|
auto cuctx = ctx->CUDACtx();
|
||||||
dh::device_vector<size_t> sorted_idx;
|
dh::device_vector<std::size_t> sorted_idx;
|
||||||
dh::SegmentedArgSort(seg_beg, seg_end, val_begin, val_end, &sorted_idx);
|
dh::SegmentedArgSort(seg_beg, seg_end, val_begin, val_end, &sorted_idx);
|
||||||
auto d_sorted_idx = dh::ToSpan(sorted_idx);
|
auto d_sorted_idx = dh::ToSpan(sorted_idx);
|
||||||
size_t n_weights = std::distance(w_begin, w_end);
|
std::size_t n_weights = std::distance(w_begin, w_end);
|
||||||
dh::device_vector<float> weights_cdf(n_weights);
|
dh::device_vector<float> weights_cdf(n_weights);
|
||||||
|
std::size_t n_elems = std::distance(val_begin, val_end);
|
||||||
|
CHECK_EQ(n_weights, n_elems);
|
||||||
|
|
||||||
dh::XGBCachingDeviceAllocator<char> caching;
|
dh::XGBCachingDeviceAllocator<char> caching;
|
||||||
auto scan_key = dh::MakeTransformIterator<size_t>(
|
auto scan_key = dh::MakeTransformIterator<std::size_t>(thrust::make_counting_iterator(0ul),
|
||||||
thrust::make_counting_iterator(0ul),
|
detail::SegOp<SegIt>{seg_beg, seg_end});
|
||||||
[=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(seg_beg, seg_end, i); });
|
auto scan_val = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||||
auto scan_val = dh::MakeTransformIterator<float>(
|
detail::WeightOp<WIter>{w_begin, d_sorted_idx});
|
||||||
thrust::make_counting_iterator(0ul),
|
|
||||||
[=] XGBOOST_DEVICE(size_t i) { return w_begin[d_sorted_idx[i]]; });
|
|
||||||
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
|
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
|
||||||
scan_val, weights_cdf.begin());
|
scan_val, weights_cdf.begin());
|
||||||
|
|
||||||
@ -103,24 +224,18 @@ void SegmentedWeightedQuantile(Context const* ctx, double alpha, SegIt seg_beg,
|
|||||||
auto d_results = quantiles->DeviceSpan();
|
auto d_results = quantiles->DeviceSpan();
|
||||||
auto d_weight_cdf = dh::ToSpan(weights_cdf);
|
auto d_weight_cdf = dh::ToSpan(weights_cdf);
|
||||||
|
|
||||||
dh::LaunchN(n_segments, [=] XGBOOST_DEVICE(size_t i) {
|
thrust::for_each_n(
|
||||||
size_t seg_idx = i;
|
cuctx->CTP(), thrust::make_counting_iterator(0ul), n_segments,
|
||||||
size_t begin = seg_beg[seg_idx];
|
detail::MakeWQSegOp(seg_beg, val_begin, alpha_it, d_weight_cdf, d_sorted_idx, d_results));
|
||||||
auto n = static_cast<double>(seg_beg[seg_idx + 1] - begin);
|
}
|
||||||
if (n == 0) {
|
|
||||||
d_results[i] = std::numeric_limits<float>::quiet_NaN();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto leaf_cdf = d_weight_cdf.subspan(begin, static_cast<size_t>(n));
|
|
||||||
auto leaf_sorted_idx = d_sorted_idx.subspan(begin, static_cast<size_t>(n));
|
|
||||||
float thresh = leaf_cdf.back() * alpha;
|
|
||||||
|
|
||||||
size_t idx = thrust::lower_bound(thrust::seq, leaf_cdf.data(),
|
template <typename SegIt, typename ValIt, typename WIter>
|
||||||
leaf_cdf.data() + leaf_cdf.size(), thresh) -
|
void SegmentedWeightedQuantile(Context const* ctx, double alpha, SegIt seg_beg, SegIt seg_end,
|
||||||
leaf_cdf.data();
|
ValIt val_begin, ValIt val_end, WIter w_begin, WIter w_end,
|
||||||
idx = std::min(idx, static_cast<size_t>(n - 1));
|
HostDeviceVector<float>* quantiles) {
|
||||||
d_results[i] = val_begin[leaf_sorted_idx[idx]];
|
CHECK(alpha >= 0 && alpha <= 1);
|
||||||
});
|
return SegmentedWeightedQuantile(ctx, thrust::make_constant_iterator(alpha), seg_beg, seg_end,
|
||||||
|
val_begin, val_end, w_begin, w_end, quantiles);
|
||||||
}
|
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,79 +1,155 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2022 by XGBoost Contributors
|
* Copyright 2022-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <utility>
|
#include <cstddef> // std::size_t
|
||||||
#include <vector>
|
#include <utility> // std::pair
|
||||||
|
#include <vector> // std::vector
|
||||||
|
|
||||||
|
#include "../../../src/common/linalg_op.cuh" // ElementWiseTransformDevice
|
||||||
#include "../../../src/common/stats.cuh"
|
#include "../../../src/common/stats.cuh"
|
||||||
#include "../../../src/common/stats.h"
|
#include "xgboost/base.h" // XGBOOST_DEVICE
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/context.h" // Context
|
||||||
#include "xgboost/context.h"
|
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "xgboost/linalg.h" // Tensor
|
||||||
#include "xgboost/linalg.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
namespace {
|
namespace {
|
||||||
class StatsGPU : public ::testing::Test {
|
class StatsGPU : public ::testing::Test {
|
||||||
private:
|
private:
|
||||||
linalg::Tensor<float, 1> arr_{
|
linalg::Tensor<float, 1> arr_{{1.f, 2.f, 3.f, 4.f, 5.f, 2.f, 4.f, 5.f, 3.f, 1.f}, {10}, 0};
|
||||||
{1.f, 2.f, 3.f, 4.f, 5.f,
|
linalg::Tensor<std::size_t, 1> indptr_{{0, 5, 10}, {3}, 0};
|
||||||
2.f, 4.f, 5.f, 3.f, 1.f},
|
HostDeviceVector<float> results_;
|
||||||
{10}, 0};
|
|
||||||
linalg::Tensor<size_t, 1> indptr_{{0, 5, 10}, {3}, 0};
|
|
||||||
HostDeviceVector<float> resutls_;
|
|
||||||
using TestSet = std::vector<std::pair<float, float>>;
|
using TestSet = std::vector<std::pair<float, float>>;
|
||||||
Context ctx_;
|
Context ctx_;
|
||||||
|
|
||||||
void Check(float expected) {
|
void Check(float expected) {
|
||||||
auto const& h_results = resutls_.HostVector();
|
auto const& h_results = results_.HostVector();
|
||||||
ASSERT_EQ(h_results.size(), indptr_.Size() - 1);
|
ASSERT_EQ(h_results.size(), indptr_.Size() - 1);
|
||||||
ASSERT_EQ(h_results.front(), expected);
|
ASSERT_EQ(h_results.front(), expected);
|
||||||
EXPECT_EQ(h_results.back(), expected);
|
ASSERT_EQ(h_results.back(), expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void SetUp() override { ctx_.gpu_id = 0; }
|
void SetUp() override { ctx_.gpu_id = 0; }
|
||||||
|
|
||||||
|
void WeightedMulti() {
|
||||||
|
// data for one segment
|
||||||
|
std::vector<float> seg{1.f, 2.f, 3.f, 4.f, 5.f};
|
||||||
|
auto seg_size = seg.size();
|
||||||
|
|
||||||
|
// 3 segments
|
||||||
|
std::vector<float> data;
|
||||||
|
data.insert(data.cend(), seg.begin(), seg.end());
|
||||||
|
data.insert(data.cend(), seg.begin(), seg.end());
|
||||||
|
data.insert(data.cend(), seg.begin(), seg.end());
|
||||||
|
linalg::Tensor<float, 1> arr{data.cbegin(), data.cend(), {data.size()}, 0};
|
||||||
|
auto d_arr = arr.View(0);
|
||||||
|
|
||||||
|
auto key_it = dh::MakeTransformIterator<std::size_t>(
|
||||||
|
thrust::make_counting_iterator(0ul),
|
||||||
|
[=] XGBOOST_DEVICE(std::size_t i) { return i * seg_size; });
|
||||||
|
auto val_it =
|
||||||
|
dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||||
|
[=] XGBOOST_DEVICE(std::size_t i) { return d_arr(i); });
|
||||||
|
|
||||||
|
// one alpha for each segment
|
||||||
|
HostDeviceVector<float> alphas{0.0f, 0.5f, 1.0f};
|
||||||
|
alphas.SetDevice(0);
|
||||||
|
auto d_alphas = alphas.ConstDeviceSpan();
|
||||||
|
auto w_it = thrust::make_constant_iterator(0.1f);
|
||||||
|
SegmentedWeightedQuantile(&ctx_, d_alphas.data(), key_it, key_it + d_alphas.size() + 1, val_it,
|
||||||
|
val_it + d_arr.Size(), w_it, w_it + d_arr.Size(), &results_);
|
||||||
|
|
||||||
|
auto const& h_results = results_.HostVector();
|
||||||
|
ASSERT_EQ(1.0f, h_results[0]);
|
||||||
|
ASSERT_EQ(3.0f, h_results[1]);
|
||||||
|
ASSERT_EQ(5.0f, h_results[2]);
|
||||||
|
}
|
||||||
|
|
||||||
void Weighted() {
|
void Weighted() {
|
||||||
auto d_arr = arr_.View(0);
|
auto d_arr = arr_.View(0);
|
||||||
auto d_key = indptr_.View(0);
|
auto d_key = indptr_.View(0);
|
||||||
|
|
||||||
auto key_it = dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0ul),
|
auto key_it = dh::MakeTransformIterator<std::size_t>(
|
||||||
[=] __device__(size_t i) { return d_key(i); });
|
thrust::make_counting_iterator(0ul),
|
||||||
auto val_it = dh::MakeTransformIterator<float>(
|
[=] XGBOOST_DEVICE(std::size_t i) { return d_key(i); });
|
||||||
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { return d_arr(i); });
|
auto val_it =
|
||||||
|
dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||||
|
[=] XGBOOST_DEVICE(std::size_t i) { return d_arr(i); });
|
||||||
linalg::Tensor<float, 1> weights{{10}, 0};
|
linalg::Tensor<float, 1> weights{{10}, 0};
|
||||||
linalg::ElementWiseTransformDevice(weights.View(0),
|
linalg::ElementWiseTransformDevice(weights.View(0),
|
||||||
[=] XGBOOST_DEVICE(size_t, float) { return 1.0; });
|
[=] XGBOOST_DEVICE(std::size_t, float) { return 1.0; });
|
||||||
auto w_it = weights.Data()->ConstDevicePointer();
|
auto w_it = weights.Data()->ConstDevicePointer();
|
||||||
for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) {
|
for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) {
|
||||||
SegmentedWeightedQuantile(&ctx_, pair.first, key_it, key_it + indptr_.Size(), val_it,
|
SegmentedWeightedQuantile(&ctx_, pair.first, key_it, key_it + indptr_.Size(), val_it,
|
||||||
val_it + arr_.Size(), w_it, w_it + weights.Size(), &resutls_);
|
val_it + arr_.Size(), w_it, w_it + weights.Size(), &results_);
|
||||||
this->Check(pair.second);
|
this->Check(pair.second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void NonWeightedMulti() {
|
||||||
|
// data for one segment
|
||||||
|
std::vector<float> seg{20.f, 15.f, 50.f, 40.f, 35.f};
|
||||||
|
auto seg_size = seg.size();
|
||||||
|
|
||||||
|
// 3 segments
|
||||||
|
std::vector<float> data;
|
||||||
|
data.insert(data.cend(), seg.begin(), seg.end());
|
||||||
|
data.insert(data.cend(), seg.begin(), seg.end());
|
||||||
|
data.insert(data.cend(), seg.begin(), seg.end());
|
||||||
|
linalg::Tensor<float, 1> arr{data.cbegin(), data.cend(), {data.size()}, 0};
|
||||||
|
auto d_arr = arr.View(0);
|
||||||
|
|
||||||
|
auto key_it = dh::MakeTransformIterator<std::size_t>(
|
||||||
|
thrust::make_counting_iterator(0ul),
|
||||||
|
[=] XGBOOST_DEVICE(std::size_t i) { return i * seg_size; });
|
||||||
|
auto val_it =
|
||||||
|
dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||||
|
[=] XGBOOST_DEVICE(std::size_t i) { return d_arr(i); });
|
||||||
|
|
||||||
|
// one alpha for each segment
|
||||||
|
HostDeviceVector<float> alphas{0.1f, 0.2f, 0.4f};
|
||||||
|
alphas.SetDevice(0);
|
||||||
|
auto d_alphas = alphas.ConstDeviceSpan();
|
||||||
|
SegmentedQuantile(&ctx_, d_alphas.data(), key_it, key_it + d_alphas.size() + 1, val_it,
|
||||||
|
val_it + d_arr.Size(), &results_);
|
||||||
|
|
||||||
|
auto const& h_results = results_.HostVector();
|
||||||
|
EXPECT_EQ(15.0f, h_results[0]);
|
||||||
|
EXPECT_EQ(16.0f, h_results[1]);
|
||||||
|
ASSERT_EQ(26.0f, h_results[2]);
|
||||||
|
}
|
||||||
|
|
||||||
void NonWeighted() {
|
void NonWeighted() {
|
||||||
auto d_arr = arr_.View(0);
|
auto d_arr = arr_.View(0);
|
||||||
auto d_key = indptr_.View(0);
|
auto d_key = indptr_.View(0);
|
||||||
|
|
||||||
auto key_it = dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0ul),
|
auto key_it = dh::MakeTransformIterator<std::size_t>(
|
||||||
[=] __device__(size_t i) { return d_key(i); });
|
thrust::make_counting_iterator(0ul), [=] __device__(std::size_t i) { return d_key(i); });
|
||||||
auto val_it = dh::MakeTransformIterator<float>(
|
auto val_it =
|
||||||
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { return d_arr(i); });
|
dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||||
|
[=] XGBOOST_DEVICE(std::size_t i) { return d_arr(i); });
|
||||||
|
|
||||||
for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) {
|
for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) {
|
||||||
SegmentedQuantile(&ctx_, pair.first, key_it, key_it + indptr_.Size(), val_it,
|
SegmentedQuantile(&ctx_, pair.first, key_it, key_it + indptr_.Size(), val_it,
|
||||||
val_it + arr_.Size(), &resutls_);
|
val_it + arr_.Size(), &results_);
|
||||||
this->Check(pair.second);
|
this->Check(pair.second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
TEST_F(StatsGPU, Quantile) { this->NonWeighted(); }
|
TEST_F(StatsGPU, Quantile) {
|
||||||
TEST_F(StatsGPU, WeightedQuantile) { this->Weighted(); }
|
this->NonWeighted();
|
||||||
|
this->NonWeightedMulti();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(StatsGPU, WeightedQuantile) {
|
||||||
|
this->Weighted();
|
||||||
|
this->WeightedMulti();
|
||||||
|
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user