[EM] Avoid synchronous calls and unnecessary ATS access. (#10811)

- Pass context into various functions.
- Factor out some CUDA algorithms.
- Use ATS only for update position.
Jiaming Yuan 2024-09-10 14:33:14 +08:00 committed by GitHub
parent ed5f33df16
commit d94f6679fc
16 changed files with 161 additions and 201 deletions
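
The theme running through this diff: device work is issued on the stream owned by the Context rather than the legacy default stream, and host-device synchronization happens only where the host actually consumes a result. A minimal sketch of the two copy flavors involved, using only the CUDA runtime API (the names CopySync/CopyAsync are illustrative, not from the commit):

#include <cstddef>
#include <cuda_runtime.h>

// Synchronous: blocks the host until the copy completes and serializes
// against other streams through the legacy default stream.
void CopySync(int* dst, int const* src, std::size_t n) {
  cudaMemcpy(dst, src, n * sizeof(int), cudaMemcpyDefault);
}

// Stream-ordered: enqueues the copy on `stream` and returns immediately;
// the caller synchronizes the stream only when the data is needed.
void CopyAsync(int* dst, int const* src, std::size_t n, cudaStream_t stream) {
  cudaMemcpyAsync(dst, src, n * sizeof(int), cudaMemcpyDefault, stream);
}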

View File

@@ -190,8 +190,7 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
 }
 
 template <bool accending, typename IdxT, typename U>
-void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
-             xgboost::common::Span<IdxT> sorted_idx) {
+void ArgSort(Context const *ctx, Span<U> keys, Span<IdxT> sorted_idx) {
   std::size_t bytes = 0;
   auto cuctx = ctx->CUDACtx();
   dh::Iota(sorted_idx, cuctx->Stream());
@@ -272,5 +271,40 @@ void CopyIf(CUDAContext const *cuctx, InIt in_first, InIt in_second, OutIt out_f
     out_first = thrust::copy_if(cuctx->CTP(), begin_input, end_input, out_first, pred);
   }
 }
+
+// Go one level down into cub::DeviceScan API to set OffsetT as 64 bit So we don't crash
+// on n > 2^31.
+template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename OffsetT>
+void InclusiveScan(xgboost::Context const *ctx, InputIteratorT d_in, OutputIteratorT d_out,
+                   ScanOpT scan_op, OffsetT num_items) {
+  auto cuctx = ctx->CUDACtx();
+  std::size_t bytes = 0;
+#if THRUST_MAJOR_VERSION >= 2
+  dh::safe_cuda((
+      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType, OffsetT>::Dispatch(
+          nullptr, bytes, d_in, d_out, scan_op, cub::NullType(), num_items, cuctx->Stream())));
+#else
+  dh::safe_cuda((
+      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType, OffsetT>::Dispatch(
+          nullptr, bytes, d_in, d_out, scan_op, cub::NullType(), num_items, cuctx->Stream(),
+          false)));
+#endif
+  dh::TemporaryArray<char> storage(bytes);
+#if THRUST_MAJOR_VERSION >= 2
+  dh::safe_cuda((
+      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType, OffsetT>::Dispatch(
+          storage.data().get(), bytes, d_in, d_out, scan_op, cub::NullType(), num_items,
+          cuctx->Stream())));
+#else
+  dh::safe_cuda((
+      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType, OffsetT>::Dispatch(
+          storage.data().get(), bytes, d_in, d_out, scan_op, cub::NullType(), num_items,
+          cuctx->Stream(), false)));
+#endif
+}
+
+template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
+void InclusiveSum(Context const *ctx, InputIteratorT d_in, OutputIteratorT d_out,
+                  OffsetT num_items) {
+  InclusiveScan(ctx, d_in, d_out, cub::Sum{}, num_items);
+}
 }  // namespace xgboost::common
 #endif  // XGBOOST_COMMON_ALGORITHM_CUH_
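
For context, a hedged sketch of how the factored-out wrapper would be called; only common::InclusiveSum and Context come from the diff, the surrounding function and data names are illustrative:

#include <cstdint>  // for int64_t

// A 64-bit OffsetT keeps the scan safe even when num_items exceeds 2^31.
void ExampleSum(xgboost::Context const *ctx, double const *d_in, double *d_out,
                std::int64_t num_items) {
  xgboost::common::InclusiveSum(ctx, d_in, d_out, num_items);
}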

View File

@@ -372,21 +372,6 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
                               cudaMemcpyDeviceToHost));
 }
 
-template <class Src, class Dst>
-void CopyTo(Src const &src, Dst *dst) {
-  if (src.empty()) {
-    dst->clear();
-    return;
-  }
-  dst->resize(src.size());
-  using SVT = std::remove_cv_t<typename Src::value_type>;
-  using DVT = std::remove_cv_t<typename Dst::value_type>;
-  static_assert(std::is_same_v<SVT, DVT>,
-                "Host and device containers must have same value type.");
-  dh::safe_cuda(cudaMemcpyAsync(thrust::raw_pointer_cast(dst->data()), src.data(),
-                                src.size() * sizeof(SVT), cudaMemcpyDefault));
-}
-
 // Keep track of pinned memory allocation
 struct PinnedMemory {
   void *temp_storage{nullptr};
@@ -748,45 +733,6 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce
   return aggregate;
 }
 
-// wrapper to avoid integer `num_items`.
-template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
-          typename OffsetT>
-void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
-                   OffsetT num_items) {
-  size_t bytes = 0;
-#if THRUST_MAJOR_VERSION >= 2
-  safe_cuda((
-      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
-                        OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
-                                           cub::NullType(), num_items, nullptr)));
-#else
-  safe_cuda((
-      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
-                        OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
-                                           cub::NullType(), num_items, nullptr,
-                                           false)));
-#endif
-  TemporaryArray<char> storage(bytes);
-#if THRUST_MAJOR_VERSION >= 2
-  safe_cuda((
-      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
-                        OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
-                                           d_out, scan_op, cub::NullType(),
-                                           num_items, nullptr)));
-#else
-  safe_cuda((
-      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
-                        OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
-                                           d_out, scan_op, cub::NullType(),
-                                           num_items, nullptr, false)));
-#endif
-}
-
-template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
-void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
-  InclusiveScan(d_in, d_out, cub::Sum(), num_items);
-}
-
 class CUDAStreamView;
 
 class CUDAEvent {
@@ -857,8 +803,23 @@ class CUDAStream {
   [[nodiscard]] cudaStream_t Handle() const { return stream_; }
 
   void Sync() { this->View().Sync(); }
+  void Wait(CUDAEvent const &e) { this->View().Wait(e); }
 };
 
+template <class Src, class Dst>
+void CopyTo(Src const &src, Dst *dst, CUDAStreamView stream = DefaultStream()) {
+  if (src.empty()) {
+    dst->clear();
+    return;
+  }
+  dst->resize(src.size());
+  using SVT = std::remove_cv_t<typename Src::value_type>;
+  using DVT = std::remove_cv_t<typename Dst::value_type>;
+  static_assert(std::is_same_v<SVT, DVT>, "Host and device containers must have same value type.");
+  dh::safe_cuda(cudaMemcpyAsync(thrust::raw_pointer_cast(dst->data()), src.data(),
+                                src.size() * sizeof(SVT), cudaMemcpyDefault, stream));
+}
+
 inline auto CachingThrustPolicy() {
   XGBCachingDeviceAllocator<char> alloc;
 #if THRUST_MAJOR_VERSION >= 2 || defined(XGBOOST_USE_RMM)
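
A hedged usage sketch for the relocated CopyTo, assuming a GPU-initialized Context; only CopyTo, CUDACtx, and Stream come from the diff, the function and variable names are illustrative:

#include <vector>

void BackupToHost(xgboost::Context const *ctx, xgboost::common::Span<float const> d_values,
                  std::vector<float> *h_values) {
  // The copy is enqueued on the context's stream instead of forcing an
  // implicit device-wide synchronization.
  dh::CopyTo(d_values, h_values, ctx->CUDACtx()->Stream());
  // Synchronize only at the point where the host actually reads the data.
  ctx->CUDACtx()->Stream().Sync();
}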

View File

@@ -1,5 +1,5 @@
 /**
- * Copyright 2023 by XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #include <thrust/functional.h>                  // for maximum
 #include <thrust/iterator/counting_iterator.h>  // for make_counting_iterator

@@ -158,7 +158,7 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
   auto d_threads_group_ptr = threads_group_ptr_.DeviceSpan();
   if (param_.HasTruncation()) {
     n_cuda_threads_ =
-        common::SegmentedTrapezoidThreads(d_group_ptr, d_threads_group_ptr, Param().NumPair());
+        common::SegmentedTrapezoidThreads(ctx, d_group_ptr, d_threads_group_ptr, Param().NumPair());
   } else {
     auto n_pairs = Param().NumPair();
     dh::LaunchN(n_groups, cuctx->Stream(),

View File

@@ -1,20 +1,20 @@
 /**
- * Copyright 2021-2023 by XGBoost Contributors
+ * Copyright 2021-2024, XGBoost Contributors
  */
 #ifndef XGBOOST_COMMON_THREADING_UTILS_CUH_
 #define XGBOOST_COMMON_THREADING_UTILS_CUH_
 
 #include <algorithm>  // std::min
 #include <cstddef>    // std::size_t
 
 #include "./math.h"  // Sqr
-#include "common.h"
+#include "algorithm.cuh"  // for InclusiveSum
+#include "common.h"       // for safe_cuda
 #include "device_helpers.cuh"  // LaunchN
 #include "xgboost/base.h"      // XGBOOST_DEVICE
 #include "xgboost/span.h"      // Span
 
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 /**
  * \param n Number of items (length of the base)
  * \param h hight

@@ -43,9 +43,8 @@ XGBOOST_DEVICE inline std::size_t DiscreteTrapezoidArea(std::size_t n, std::size
  * with h <= n
  */
 template <typename U>
-std::size_t SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
-                                      xgboost::common::Span<std::size_t> out_group_threads_ptr,
-                                      std::size_t h) {
+std::size_t SegmentedTrapezoidThreads(Context const *ctx, Span<U> group_ptr,
+                                      Span<std::size_t> out_group_threads_ptr, std::size_t h) {
   CHECK_GE(group_ptr.size(), 1);
   CHECK_EQ(group_ptr.size(), out_group_threads_ptr.size());
   dh::LaunchN(group_ptr.size(), [=] XGBOOST_DEVICE(std::size_t idx) {

@@ -57,8 +56,8 @@ std::size_t SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
     std::size_t cnt = static_cast<std::size_t>(group_ptr[idx] - group_ptr[idx - 1]);
     out_group_threads_ptr[idx] = DiscreteTrapezoidArea(cnt, h);
   });
-  dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(),
-                   out_group_threads_ptr.size());
+  InclusiveSum(ctx, out_group_threads_ptr.data(), out_group_threads_ptr.data(),
+               out_group_threads_ptr.size());
 
   std::size_t total = 0;
   dh::safe_cuda(cudaMemcpy(&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
                            sizeof(total), cudaMemcpyDeviceToHost));

@@ -82,6 +81,5 @@ XGBOOST_DEVICE inline void UnravelTrapeziodIdx(std::size_t i_idx, std::size_t n,
     j = idx - n_elems + i + 1;
   }
 }
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
 #endif  // XGBOOST_COMMON_THREADING_UTILS_CUH_
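
As a sanity check on the thread accounting: assuming DiscreteTrapezoidArea(n, h) counts the pairs (i, j) with i < j <= i + h within one group, the per-group thread count has the closed form sketched below. The values match the unit test later in this diff, where a group of 8 items yields 7 threads at h = 1 and 28 threads at h = 7:

#include <cstddef>

// Hedged closed form for the per-group area, valid for h <= n - 1.
constexpr std::size_t TrapezoidArea(std::size_t n, std::size_t h) {
  return n * h - h * (h + 1) / 2;
}
static_assert(TrapezoidArea(8, 1) == 7, "h = 1: one pair per interior item");
static_assert(TrapezoidArea(8, 7) == 28, "h = n - 1: all n*(n-1)/2 pairs");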

View File

@@ -254,30 +254,7 @@ void CopyDataToEllpack(Context const* ctx, const AdapterBatchT& batch,
       d_compressed_buffer, writer, batch, device_accessor, feature_types, is_valid};
   thrust::transform_output_iterator<decltype(functor), decltype(discard)> out(discard, functor);
 
-  // Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
-  // So we don't crash on n > 2^31
-  size_t temp_storage_bytes = 0;
-  using DispatchScan = cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
-                                         TupleScanOp<Tuple>, cub::NullType, std::int64_t>;
-#if THRUST_MAJOR_VERSION >= 2
-  dh::safe_cuda(DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
-                                       TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
-                                       ctx->CUDACtx()->Stream()));
-#else
-  DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
-                         TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
-                         nullptr, false);
-#endif
-  dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
-#if THRUST_MAJOR_VERSION >= 2
-  dh::safe_cuda(DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
-                                       key_value_index_iter, out, TupleScanOp<Tuple>(),
-                                       cub::NullType(), batch.Size(), ctx->CUDACtx()->Stream()));
-#else
-  DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
-                         key_value_index_iter, out, TupleScanOp<Tuple>(),
-                         cub::NullType(), batch.Size(), nullptr, false);
-#endif
+  common::InclusiveScan(ctx, key_value_index_iter, out, TupleScanOp<Tuple>{}, batch.Size());
 }
 
 void WriteNullValues(Context const* ctx, EllpackPageImpl* dst, common::Span<size_t> row_counts) {

View File

@@ -13,7 +13,7 @@
 #include <utility>
 
 #include "../collective/allreduce.h"
-#include "../common/algorithm.cuh"        // SegmentedArgSort
+#include "../common/algorithm.cuh"        // SegmentedArgSort, InclusiveScan
 #include "../common/optional_weight.h"    // OptionalWeights
 #include "../common/threading_utils.cuh"  // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
 #include "auc.h"

@@ -128,8 +128,8 @@ std::tuple<double, double, double> GPUBinaryAUC(Context const *ctx,
                                      dh::tbegin(d_unique_idx));
   d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx));
 
-  dh::InclusiveScan(dh::tbegin(d_fptp), dh::tbegin(d_fptp),
-                    PairPlus<double, double>{}, d_fptp.size());
+  common::InclusiveScan(ctx, dh::tbegin(d_fptp), dh::tbegin(d_fptp), PairPlus<double, double>{},
+                        d_fptp.size());
 
   auto d_neg_pos = dh::ToSpan(cache->neg_pos);
   // scatter unique negaive/positive values

@@ -239,7 +239,7 @@ double ScaleClasses(Context const *ctx, bool is_column_split, common::Span<doubl
  * getting class id or group id given scan index.
  */
 template <typename Fn>
-void SegmentedFPTP(common::Span<Pair> d_fptp, Fn segment_id) {
+void SegmentedFPTP(Context const *ctx, common::Span<Pair> d_fptp, Fn segment_id) {
   using Triple = thrust::tuple<uint32_t, double, double>;
   // expand to tuple to include idx
   auto fptp_it_in = dh::MakeTransformIterator<Triple>(

@@ -253,8 +253,8 @@ void SegmentedFPTP(common::Span<Pair> d_fptp, Fn segment_id) {
         thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t));
     return t;
   });
-  dh::InclusiveScan(
-      fptp_it_in, fptp_it_out,
+  common::InclusiveScan(
+      ctx, fptp_it_in, fptp_it_out,
       [=] XGBOOST_DEVICE(Triple const &l, Triple const &r) {
         uint32_t l_gid = segment_id(thrust::get<0>(l));
         uint32_t r_gid = segment_id(thrust::get<0>(r));

@@ -391,7 +391,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
   d_unique_idx = d_unique_idx.subspan(0, n_uniques);
 
   auto get_class_id = [=] XGBOOST_DEVICE(size_t idx) { return idx / n_samples; };
-  SegmentedFPTP(d_fptp, get_class_id);
+  SegmentedFPTP(ctx, d_fptp, get_class_id);
 
   // scatter unique FP_PREV/TP_PREV values
   auto d_neg_pos = dh::ToSpan(cache->neg_pos);

@@ -528,8 +528,8 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
   dh::caching_device_vector<size_t> threads_group_ptr(group_ptr.size(), 0);
   auto d_threads_group_ptr = dh::ToSpan(threads_group_ptr);
   // Use max to represent triangle
-  auto n_threads = common::SegmentedTrapezoidThreads(
-      d_group_ptr, d_threads_group_ptr, std::numeric_limits<size_t>::max());
+  auto n_threads = common::SegmentedTrapezoidThreads(ctx, d_group_ptr, d_threads_group_ptr,
+                                                     std::numeric_limits<std::size_t>::max());
   CHECK_LT(n_threads, std::numeric_limits<int32_t>::max());
   // get the coordinate in nested summation
   auto get_i_j = [=]XGBOOST_DEVICE(size_t idx, size_t query_group_idx) {

@@ -591,8 +591,8 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
     }
     return {};  // discard
   });
-  dh::InclusiveScan(
-      in, out,
+  common::InclusiveScan(
+      ctx, in, out,
       [] XGBOOST_DEVICE(RankScanItem const &l, RankScanItem const &r) {
         if (l.group_id != r.group_id) {
           return r;

@@ -774,7 +774,7 @@ std::pair<double, uint32_t> GPURankingPRAUCImpl(Context const *ctx,
   auto get_group_id = [=] XGBOOST_DEVICE(size_t idx) {
     return dh::SegmentId(d_group_ptr, idx);
   };
-  SegmentedFPTP(d_fptp, get_group_id);
+  SegmentedFPTP(ctx, d_fptp, get_group_id);
 
   // scatter unique FP_PREV/TP_PREV values
   auto d_neg_pos = dh::ToSpan(cache->neg_pos);

View File

@@ -12,7 +12,6 @@
 #include <cmath>
 #include <numeric>  // for accumulate
 
-#include "../common/common.h"  // for AssertGPUSupport
 #include "../common/math.h"
 #include "../common/optional_weight.h"  // OptionalWeights
 #include "../common/pseudo_huber.h"

@@ -28,7 +27,9 @@
 #include <thrust/iterator/counting_iterator.h>
 #include <thrust/transform_reduce.h>
 
-#include "../common/device_helpers.cuh"
+#include "../common/cuda_context.cuh"  // for CUDAContext
+#else
+#include "../common/common.h"  // for AssertGPUSupport
 #endif  // XGBOOST_USE_CUDA
 
 namespace xgboost::metric {

@@ -48,11 +49,10 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
   auto labels = info.labels.View(ctx->Device());
   if (ctx->IsCUDA()) {
 #if defined(XGBOOST_USE_CUDA)
-    dh::XGBCachingDeviceAllocator<char> alloc;
     thrust::counting_iterator<size_t> begin(0);
     thrust::counting_iterator<size_t> end = begin + labels.Size();
     result = thrust::transform_reduce(
-        thrust::cuda::par(alloc), begin, end,
+        ctx->CUDACtx()->CTP(), begin, end,
         [=] XGBOOST_DEVICE(size_t i) {
           auto idx = linalg::UnravelIndex(i, labels.Shape());
           auto sample_id = std::get<0>(idx);
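
Several hunks in this commit make the same substitution as above: drop the locally constructed caching allocator plus thrust::cuda::par in favor of the context's cached Thrust policy. A hedged sketch of the pattern in isolation (the reduction itself is illustrative, not from the commit):

#include <thrust/device_vector.h>
#include <thrust/reduce.h>

double SumOnStream(xgboost::Context const *ctx, thrust::device_vector<double> const &values) {
  // CTP() couples the caching allocator with the context's stream, so Thrust's
  // temporary allocations neither malloc per call nor sync the whole device.
  return thrust::reduce(ctx->CUDACtx()->CTP(), values.cbegin(), values.cend(), 0.0);
}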

View File

@@ -6,14 +6,15 @@
 #include <thrust/execution_policy.h>
 #include <thrust/iterator/counting_iterator.h>
 
-#include <string>
 #include <set>
+#include <string>
 
-#include "xgboost/logging.h"
-#include "xgboost/span.h"
+#include "../common/cuda_context.cuh"  // for CUDAContext
+#include "../common/device_helpers.cuh"
 #include "constraints.cuh"
 #include "param.h"
-#include "../common/device_helpers.cuh"
+#include "xgboost/logging.h"
+#include "xgboost/span.h"
 
 namespace xgboost {

@@ -130,9 +131,9 @@ FeatureInteractionConstraintDevice::FeatureInteractionConstraintDevice(
   this->Configure(param, n_features);
 }
 
-void FeatureInteractionConstraintDevice::Reset() {
+void FeatureInteractionConstraintDevice::Reset(Context const* ctx) {
   for (auto& node : node_constraints_storage_) {
-    thrust::fill(node.begin(), node.end(), 0);
+    thrust::fill(ctx->CUDACtx()->CTP(), node.begin(), node.end(), 0);
   }
 }

View File

@@ -78,7 +78,7 @@ struct FeatureInteractionConstraintDevice {
   FeatureInteractionConstraintDevice(FeatureInteractionConstraintDevice const& that) = default;
   FeatureInteractionConstraintDevice(FeatureInteractionConstraintDevice&& that) = default;
   /*! \brief Reset before constructing a new tree. */
-  void Reset();
+  void Reset(Context const* ctx);
   /*! \brief Return a list of features given node id */
   common::Span<bst_feature_t> QueryNode(int32_t nid);
   /*!

View File

@@ -138,9 +138,9 @@ class GPUHistEvaluator {
   /**
   * \brief Reset the evaluator, should be called before any use.
   */
-  void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft,
-             bst_feature_t n_features, TrainParam const &param, bool is_column_split,
-             DeviceOrd device);
+  void Reset(Context const *ctx, common::HistogramCuts const &cuts,
+             common::Span<FeatureType const> ft, bst_feature_t n_features, TrainParam const &param,
+             bool is_column_split);
 
   /**
   * \brief Get host category storage for nidx. Different from the internal version, this

@@ -154,8 +154,8 @@ class GPUHistEvaluator {
   }
 
   [[nodiscard]] auto GetDeviceNodeCats(bst_node_t nidx) {
-    copy_stream_.View().Sync();
     if (has_categoricals_) {
+      copy_stream_.View().Sync();
       CatAccessor accessor = {dh::ToSpan(split_cats_), node_categorical_storage_size_};
       return common::KCatBitField{accessor.GetNodeCatStorage(nidx)};
     } else {

View File

@@ -13,14 +13,13 @@
 #include "xgboost/data.h"
 
 namespace xgboost::tree {
-void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft,
-                             bst_feature_t n_features, TrainParam const &param,
-                             bool is_column_split, DeviceOrd device) {
+void GPUHistEvaluator::Reset(Context const *ctx, common::HistogramCuts const &cuts,
+                             common::Span<FeatureType const> ft, bst_feature_t n_features,
+                             TrainParam const &param, bool is_column_split) {
   param_ = param;
-  tree_evaluator_ = TreeEvaluator{param, n_features, device};
+  tree_evaluator_ = TreeEvaluator{param, n_features, ctx->Device()};
   has_categoricals_ = cuts.HasCategorical();
   if (cuts.HasCategorical()) {
-    dh::XGBCachingDeviceAllocator<char> alloc;
     auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
     auto beg = thrust::make_counting_iterator<size_t>(1ul);
     auto end = thrust::make_counting_iterator<size_t>(ptrs.size());

@@ -29,7 +28,7 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, common::Span<Fea
     // onehot-encoding-based splits.
     // For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x.
     need_sort_histogram_ =
-        thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) {
+        thrust::any_of(ctx->CUDACtx()->CTP(), beg, end, [=] XGBOOST_DEVICE(size_t i) {
           auto idx = i - 1;
           if (common::IsCat(ft, idx)) {
             auto n_bins = ptrs[i] - ptrs[idx];

@@ -44,8 +43,8 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, common::Span<Fea
       CHECK_NE(node_categorical_storage_size_, 0);
       split_cats_.resize(node_categorical_storage_size_);
       h_split_cats_.resize(node_categorical_storage_size_);
-      dh::safe_cuda(
-          cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));
+      dh::safe_cuda(cudaMemsetAsync(split_cats_.data().get(), '\0',
+                                    split_cats_.size() * sizeof(CatST), ctx->CUDACtx()->Stream()));
 
       cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2);  // evaluate 2 nodes at a time.
       sort_input_.resize(cat_sorted_idx_.size());

@@ -57,14 +56,14 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, common::Span<Fea
     auto d_fidxes = dh::ToSpan(feature_idx_);
     auto it = thrust::make_counting_iterator(0ul);
     auto values = cuts.cut_values_.ConstDeviceSpan();
-    thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), feature_idx_.begin(),
+    thrust::transform(ctx->CUDACtx()->CTP(), it, it + feature_idx_.size(), feature_idx_.begin(),
                       [=] XGBOOST_DEVICE(size_t i) {
                         auto fidx = dh::SegmentId(ptrs, i);
                         return fidx;
                       });
   }
 
   is_column_split_ = is_column_split;
-  device_ = device;
+  device_ = ctx->Device();
 }
 
 common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(

View File

@@ -66,12 +66,10 @@ GradientQuantiser::GradientQuantiser(Context const* ctx, common::Span<GradientPa
                                      MetaInfo const& info) {
   using GradientSumT = GradientPairPrecise;
   using T = typename GradientSumT::ValueT;
-  dh::XGBCachingDeviceAllocator<char> alloc;
   thrust::device_ptr<GradientPair const> gpair_beg{gpair.data()};
   auto beg = thrust::make_transform_iterator(gpair_beg, Clip());
-  Pair p =
-      dh::Reduce(thrust::cuda::par(alloc), beg, beg + gpair.size(), Pair{}, thrust::plus<Pair>{});
+  Pair p = dh::Reduce(ctx->CUDACtx()->CTP(), beg, beg + gpair.size(), Pair{}, thrust::plus<Pair>{});
   // Treat pair as array of 4 primitive types to allreduce
   using ReduceT = typename decltype(p.first)::ValueT;
   static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");

View File

@@ -11,6 +11,7 @@
 #include <cstdint>  // for int32_t, uint32_t
 #include <vector>   // for vector
 
+#include "../../common/cuda_context.cuh"    // for CUDAContext
 #include "../../common/device_helpers.cuh"  // for MakeTransformIterator
 #include "xgboost/base.h"                   // for bst_idx_t
 #include "xgboost/context.h"                // for Context

@@ -356,18 +357,18 @@ class RowPartitioner {
    * argument and return the new position for this training instance.
    */
   template <typename FinalisePositionOpT>
-  void FinalisePosition(common::Span<bst_node_t> d_out_position, bst_idx_t base_ridx,
-                        FinalisePositionOpT op) const {
+  void FinalisePosition(Context const* ctx, common::Span<bst_node_t> d_out_position,
+                        bst_idx_t base_ridx, FinalisePositionOpT op) const {
     dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
     dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
                                   sizeof(NodePositionInfo) * ridx_segments_.size(),
-                                  cudaMemcpyDefault));
+                                  cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
 
     constexpr int kBlockSize = 512;
     const int kItemsThread = 8;
     const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
     common::Span<RowIndexT const> d_ridx{ridx_.data(), ridx_.size()};
-    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0>>>(
+    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0, ctx->CUDACtx()->Stream()>>>(
         dh::ToSpan(d_node_info_storage), base_ridx, d_ridx, d_out_position, op);
   }
 };
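
To make the FinalisePositionOpT contract concrete, a hedged sketch of a trivial op: the callable receives the row index and its current node and returns the final position. The identity op here is made up for illustration; the real caller in updater_gpu_hist.cu below walks the tree to a leaf first:

void KeepSampledPositions(xgboost::Context const* ctx, xgboost::tree::RowPartitioner const& part,
                          xgboost::common::Span<xgboost::bst_node_t> d_out_position) {
  // Identity op: return the sampled node id unchanged (illustrative only).
  part.FinalisePosition(ctx, d_out_position, /*base_ridx=*/0,
                        [] __device__(xgboost::bst_idx_t, xgboost::bst_node_t nidx) {
                          return nidx;
                        });
}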

View File

@@ -64,14 +64,10 @@ struct NodeSplitData {
 };
 static_assert(std::is_trivially_copyable_v<NodeSplitData>);
 
-// To be tuned.
-constexpr double ExtMemPrefetchThresh() { return 4.0; }
-
 // Some nodes we will manually compute histograms, others we will do by subtraction
-[[nodiscard]] bool AssignNodes(RegTree const* p_tree, GradientQuantiser const* quantizer,
-                               std::vector<GPUExpandEntry> const& candidates,
-                               common::Span<bst_node_t> nodes_to_build,
-                               common::Span<bst_node_t> nodes_to_sub) {
+void AssignNodes(RegTree const* p_tree, GradientQuantiser const* quantizer,
+                 std::vector<GPUExpandEntry> const& candidates,
+                 common::Span<bst_node_t> nodes_to_build, common::Span<bst_node_t> nodes_to_sub) {
   auto const& tree = *p_tree;
   std::size_t nidx_in_set{0};
   double total{0.0}, smaller{0.0};

@@ -97,12 +93,6 @@ constexpr double ExtMemPrefetchThresh() { return 4.0; }
     }
     ++nidx_in_set;
   }
-
-  if (-kRtEps < smaller && smaller < kRtEps) {  // Too close to 0, don't prefetch.
-    return false;
-  }
-  // Prefetch if these smaller nodes are not quite small.
-  return (total / smaller) < ExtMemPrefetchThresh();
 }
 
 // GPU tree updater implementation.
@@ -201,16 +191,19 @@ struct GPUHistMakerDevice {
   // Reset values for each update iteration
   [[nodiscard]] DMatrix* Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* p_fmat) {
     this->monitor.Start(__func__);
+    common::SetDevice(ctx_->Ordinal());
     auto const& info = p_fmat->Info();
+    // backup the gradient
+    dh::CopyTo(dh_gpair->ConstDeviceSpan(), &this->d_gpair, ctx_->CUDACtx()->Stream());
     this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(),
                                 param.colsample_bynode, param.colsample_bylevel,
                                 param.colsample_bytree);
-    common::SetDevice(ctx_->Ordinal());
-    this->interaction_constraints.Reset();
+    this->interaction_constraints.Reset(ctx_);
+    this->evaluator_.Reset(this->ctx_, *cuts_, p_fmat->Info().feature_types.ConstDeviceSpan(),
+                           p_fmat->Info().num_col_, this->param, p_fmat->Info().IsColumnSplit());
 
     // Sampling
-    dh::CopyTo(dh_gpair->ConstDeviceSpan(), &this->d_gpair);  // backup the gradient
     auto sample = this->sampler->Sample(ctx_, dh::ToSpan(d_gpair), p_fmat);
     this->gpair = sample.gpair;
     p_fmat = sample.p_fmat;  // Update p_fmat before allocating partitioners
@@ -242,10 +235,6 @@ struct GPUHistMakerDevice {
     }
 
     // Other initializations
-    this->evaluator_.Reset(*cuts_, p_fmat->Info().feature_types.ConstDeviceSpan(),
-                           p_fmat->Info().num_col_, this->param, p_fmat->Info().IsColumnSplit(),
-                           this->ctx_->Device());
-
     quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, p_fmat->Info());
 
     this->InitFeatureGroupsOnce(info);
@@ -488,8 +477,8 @@ struct GPUHistMakerDevice {
     // Prepare for build hist
     std::vector<bst_node_t> build_nidx(candidates.size());
     std::vector<bst_node_t> subtraction_nidx(candidates.size());
-    auto prefetch_copy =
-        AssignNodes(p_tree, this->quantiser.get(), candidates, build_nidx, subtraction_nidx);
+    AssignNodes(p_tree, this->quantiser.get(), candidates, build_nidx, subtraction_nidx);
+    auto prefetch_copy = !build_nidx.empty();
 
     this->histogram_.AllocateHistograms(ctx_, build_nidx, subtraction_nidx);
@@ -534,10 +523,13 @@ struct GPUHistMakerDevice {
     if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
       LOG(FATAL) << "Current objective function can not be used with external memory.";
     }
+
+    monitor.Start(__func__);
     if (static_cast<std::size_t>(p_fmat->NumBatches() + 1) != this->batch_ptr_.size()) {
       // External memory with concatenation. Not supported.
       p_out_position->Resize(0);
       positions_.clear();
+      monitor.Stop(__func__);
       return;
     }
@@ -557,14 +549,16 @@ struct GPUHistMakerDevice {
         CHECK_EQ(part->GetNumNodes(), p_tree->NumNodes());
         auto base_ridx = batch_ptr_[k];
         auto n_samples = batch_ptr_.at(k + 1) - base_ridx;
-        part->FinalisePosition(d_out_position.subspan(base_ridx, n_samples), base_ridx, encode_op);
+        part->FinalisePosition(ctx_, d_out_position.subspan(base_ridx, n_samples), base_ridx,
+                               encode_op);
       }
-      dh::CopyTo(d_out_position, &positions_);
+      dh::CopyTo(d_out_position, &positions_, this->ctx_->CUDACtx()->Stream());
+      monitor.Stop(__func__);
       return;
     }
 
     dh::caching_device_vector<uint32_t> categories;
-    dh::CopyTo(p_tree->GetSplitCategories(), &categories);
+    dh::CopyTo(p_tree->GetSplitCategories(), &categories, this->ctx_->CUDACtx()->Stream());
     auto const& cat_segments = p_tree->GetSplitCategoriesPtr();
     auto d_categories = dh::ToSpan(categories);
     auto ft = p_fmat->Info().feature_types.ConstDeviceSpan();
@@ -583,22 +577,24 @@ struct GPUHistMakerDevice {
       auto go_left_op = GoLeftOp{d_matrix};
 
       dh::caching_device_vector<NodeSplitData> d_split_data;
-      dh::CopyTo(split_data, &d_split_data);
+      dh::CopyTo(split_data, &d_split_data, this->ctx_->CUDACtx()->Stream());
       auto s_split_data = dh::ToSpan(d_split_data);
 
-      partitioners_.front()->FinalisePosition(
-          d_out_position, page.BaseRowId(), [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
-            auto split_data = s_split_data[nidx];
-            auto node = split_data.split_node;
-            while (!node.IsLeaf()) {
-              auto go_left = go_left_op(row_id, split_data);
-              nidx = go_left ? node.LeftChild() : node.RightChild();
-              node = s_split_data[nidx].split_node;
-            }
-            return encode_op(row_id, nidx);
-          });
-      dh::CopyTo(d_out_position, &positions_);
+      partitioners_.front()->FinalisePosition(ctx_, d_out_position, page.BaseRowId(),
+                                              [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
+                                                auto split_data = s_split_data[nidx];
+                                                auto node = split_data.split_node;
+                                                while (!node.IsLeaf()) {
+                                                  auto go_left = go_left_op(row_id, split_data);
+                                                  nidx = go_left ? node.LeftChild()
+                                                                 : node.RightChild();
+                                                  node = s_split_data[nidx].split_node;
+                                                }
+                                                return encode_op(row_id, nidx);
+                                              });
+      dh::CopyTo(d_out_position, &positions_, this->ctx_->CUDACtx()->Stream());
     }
+    monitor.Stop(__func__);
   }
 
   bool UpdatePredictionCache(linalg::MatrixView<float> out_preds_d, RegTree const* p_tree) {

@@ -616,7 +612,7 @@ struct GPUHistMakerDevice {
     // Use the nodes from tree, the leaf value might be changed by the objective since the
     // last update tree call.
     dh::caching_device_vector<RegTree::Node> nodes;
-    dh::CopyTo(p_tree->GetNodes(), &nodes);
+    dh::CopyTo(p_tree->GetNodes(), &nodes, this->ctx_->CUDACtx()->Stream());
     common::Span<RegTree::Node> d_nodes = dh::ToSpan(nodes);
     CHECK_EQ(out_preds_d.Shape(1), 1);
     dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(),

View File

@@ -1,16 +1,17 @@
 /**
- * Copyright 2021-2023 by XGBoost Contributors
+ * Copyright 2021-2024, XGBoost Contributors
  */
 #include <gtest/gtest.h>
 #include <thrust/copy.h>  // thrust::copy
 
 #include "../../../src/common/device_helpers.cuh"
 #include "../../../src/common/threading_utils.cuh"
+#include "../helpers.h"  // for MakeCUDACtx
 
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 TEST(SegmentedTrapezoidThreads, Basic) {
   size_t constexpr kElements = 24, kGroups = 3;
+  auto ctx = MakeCUDACtx(0);
   dh::device_vector<size_t> offset_ptr(kGroups + 1, 0);
   offset_ptr[0] = 0;
   offset_ptr[1] = 8;

@@ -19,11 +20,11 @@ TEST(SegmentedTrapezoidThreads, Basic) {
 
   size_t h = 1;
   dh::device_vector<size_t> thread_ptr(kGroups + 1, 0);
-  size_t total = SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
+  size_t total = SegmentedTrapezoidThreads(&ctx, dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
   ASSERT_EQ(total, kElements - kGroups);
 
   h = 2;
-  SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
+  SegmentedTrapezoidThreads(&ctx, dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
   std::vector<size_t> h_thread_ptr(thread_ptr.size());
   thrust::copy(thread_ptr.cbegin(), thread_ptr.cend(), h_thread_ptr.begin());
   for (size_t i = 1; i < h_thread_ptr.size(); ++i) {

@@ -31,7 +32,7 @@ TEST(SegmentedTrapezoidThreads, Basic) {
   }
 
   h = 7;
-  SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
+  SegmentedTrapezoidThreads(&ctx, dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
   thrust::copy(thread_ptr.cbegin(), thread_ptr.cend(), h_thread_ptr.begin());
   for (size_t i = 1; i < h_thread_ptr.size(); ++i) {
     ASSERT_EQ(h_thread_ptr[i] - h_thread_ptr[i - 1], 28);

@@ -66,5 +67,4 @@ TEST(SegmentedTrapezoidThreads, Unravel) {
   ASSERT_EQ(i, 6);
   ASSERT_EQ(j, 7);
 }
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common

View File

@@ -60,8 +60,7 @@ TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
   GPUHistEvaluator evaluator{param_, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
-  evaluator.Reset(cuts_, dh::ToSpan(feature_types), feature_set.size(), param_, false,
-                  ctx.Device());
+  evaluator.Reset(&ctx, cuts_, dh::ToSpan(feature_types), feature_set.size(), param_, false);
 
   DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
   ASSERT_EQ(result.thresh, 1);

@@ -104,7 +103,7 @@ TEST(GpuHist, PartitionBasic) {
   };
 
   GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
-  evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false, ctx.Device());
+  evaluator.Reset(&ctx, cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false);
 
   {
     // -1.0s go right

@@ -217,7 +216,7 @@ TEST(GpuHist, PartitionTwoFeatures) {
                                  false};
 
   GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
-  evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false, ctx.Device());
+  evaluator.Reset(&ctx, cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false);
 
   {
     auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});

@@ -277,10 +276,8 @@ TEST(GpuHist, PartitionTwoNodes) {
                                  cuts.min_vals_.ConstDeviceSpan(),
                                  false};
 
-  GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()),
-                             ctx.Device()};
-  evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false,
-                  ctx.Device());
+  GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
+  evaluator.Reset(&ctx, cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false);
 
   {
     auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});

@@ -336,10 +333,8 @@ void TestEvaluateSingleSplit(bool is_categorical) {
                                  cuts.min_vals_.ConstDeviceSpan(),
                                  false};
 
-  GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()),
-                             ctx.Device()};
-  evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false,
-                  ctx.Device());
+  GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
+  evaluator.Reset(&ctx, cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false);
 
   DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
   EXPECT_EQ(result.findex, 1);

@@ -522,7 +517,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
   cuts_.cut_values_.SetDevice(ctx.Device());
   cuts_.min_vals_.SetDevice(ctx.Device());
 
-  evaluator.Reset(cuts_, dh::ToSpan(ft), info_.num_col_, param_, false, ctx.Device());
+  evaluator.Reset(&ctx, cuts_, dh::ToSpan(ft), info_.num_col_, param_, false);
 
   // Convert the sample histogram to fixed point
   auto quantiser = DummyRoundingFactor(&ctx);

@@ -586,7 +581,7 @@ void VerifyColumnSplitEvaluateSingleSplit(bool is_categorical) {
                                  false};
 
   GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
-  evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, true, ctx.Device());
+  evaluator.Reset(&ctx, cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, true);
 
   DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
   EXPECT_EQ(result.findex, 1);