[EM] Enable prediction cache for GPU. (#10707)

- Use `UpdatePosition` for all nodes and skip `FinalizePosition` when external memory is used.
- Create `encode/decode` for node position, this is just as a refactor.
- Reuse code between update position and finalization.
This commit is contained in:
Jiaming Yuan 2024-08-15 21:41:59 +08:00 committed by GitHub
parent 0def8e0bae
commit 582ea104b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 378 additions and 327 deletions

View File

@ -1,20 +1,17 @@
/** /**
* Copyright 2020-2023, XGBoost Contributors * Copyright 2020-2024, XGBoost Contributors
* \file categorical.h * \file categorical.h
*/ */
#ifndef XGBOOST_COMMON_CATEGORICAL_H_ #ifndef XGBOOST_COMMON_CATEGORICAL_H_
#define XGBOOST_COMMON_CATEGORICAL_H_ #define XGBOOST_COMMON_CATEGORICAL_H_
#include <limits>
#include "bitfield.h" #include "bitfield.h"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/span.h" #include "xgboost/span.h"
#include "xgboost/tree_model.h"
namespace xgboost { namespace xgboost::common {
namespace common {
using CatBitField = LBitField32; using CatBitField = LBitField32;
using KCatBitField = CLBitField32; using KCatBitField = CLBitField32;
@ -94,7 +91,12 @@ XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot
struct IsCatOp { struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; } XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
}; };
} // namespace common
} // namespace xgboost inline auto GetNodeCats(common::Span<CatBitField::value_type const> categories,
RegTree::CategoricalSplitMatrix::Segment seg) {
KCatBitField node_cats{categories.subspan(seg.beg, seg.size)};
return node_cats;
}
} // namespace xgboost::common
#endif // XGBOOST_COMMON_CATEGORICAL_H_ #endif // XGBOOST_COMMON_CATEGORICAL_H_

View File

@ -16,12 +16,9 @@
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <cub/util_type.cuh> // for UnitWord #include <cub/util_type.cuh> // for UnitWord
#include <sstream>
#include <string>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "../collective/communicator-inl.h"
#include "common.h" #include "common.h"
#include "device_vector.cuh" #include "device_vector.cuh"
#include "xgboost/host_device_vector.h" #include "xgboost/host_device_vector.h"
@ -375,19 +372,24 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
cudaMemcpyDeviceToHost)); cudaMemcpyDeviceToHost));
} }
template <class HContainer, class DContainer> template <class Src, class Dst>
void CopyToD(HContainer const &h, DContainer *d) { void CopyTo(Src const &src, Dst *dst) {
if (h.empty()) { if (src.empty()) {
d->clear(); dst->clear();
return; return;
} }
d->resize(h.size()); dst->resize(src.size());
using HVT = std::remove_cv_t<typename HContainer::value_type>; using SVT = std::remove_cv_t<typename Src::value_type>;
using DVT = std::remove_cv_t<typename DContainer::value_type>; using DVT = std::remove_cv_t<typename Dst::value_type>;
static_assert(std::is_same<HVT, DVT>::value, static_assert(std::is_same<SVT, DVT>::value,
"Host and device containers must have same value type."); "Host and device containers must have same value type.");
dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT), dh::safe_cuda(cudaMemcpyAsync(thrust::raw_pointer_cast(dst->data()), src.data(),
cudaMemcpyHostToDevice)); src.size() * sizeof(SVT), cudaMemcpyDefault));
}
template <class HContainer, class DContainer>
void CopyToD(HContainer const &h, DContainer *d) {
CopyTo(h, d);
} }
// Keep track of pinned memory allocation // Keep track of pinned memory allocation

View File

@ -307,6 +307,7 @@ class DeviceUVector {
public: public:
DeviceUVector() = default; DeviceUVector() = default;
explicit DeviceUVector(std::size_t n) { this->resize(n); }
DeviceUVector(DeviceUVector const &that) = delete; DeviceUVector(DeviceUVector const &that) = delete;
DeviceUVector &operator=(DeviceUVector const &that) = delete; DeviceUVector &operator=(DeviceUVector const &that) = delete;
DeviceUVector(DeviceUVector &&that) = default; DeviceUVector(DeviceUVector &&that) = default;
@ -330,7 +331,17 @@ class DeviceUVector {
data_.resize(n, v); data_.resize(n, v);
#endif #endif
} }
void clear() { // NOLINT
#if defined(XGBOOST_USE_RMM)
this->data_.resize(0, rmm::cuda_stream_per_thread);
#else
this->data_.clear();
#endif // defined(XGBOOST_USE_RMM)
}
[[nodiscard]] std::size_t size() const { return data_.size(); } // NOLINT [[nodiscard]] std::size_t size() const { return data_.size(); } // NOLINT
[[nodiscard]] bool empty() const { return this->size() == 0; } // NOLINT
[[nodiscard]] auto begin() { return data_.begin(); } // NOLINT [[nodiscard]] auto begin() { return data_.begin(); } // NOLINT
[[nodiscard]] auto end() { return data_.end(); } // NOLINT [[nodiscard]] auto end() { return data_.end(); } // NOLINT

View File

@ -20,6 +20,7 @@
#include "column_matrix.h" #include "column_matrix.h"
#include "xgboost/context.h" #include "xgboost/context.h"
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
#include "../tree/sample_position.h" // for SamplePosition
namespace xgboost::common { namespace xgboost::common {
// The builder is required for samples partition to left and rights children for set of nodes // The builder is required for samples partition to left and rights children for set of nodes
@ -364,13 +365,14 @@ class PartitionBuilder {
} }
// Copy row partitions into global cache for reuse in objective // Copy row partitions into global cache for reuse in objective
template <typename Sampledp> template <typename Invalidp>
void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set, void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
std::vector<bst_node_t>* p_position, Sampledp sampledp) const { std::vector<bst_node_t>* p_position, Invalidp invalidp) const {
auto& h_pos = *p_position; auto& h_pos = *p_position;
h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max()); h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
auto p_begin = row_set.Data()->data(); auto p_begin = row_set.Data()->data();
// For each node, walk through all the samples that fall in this node.
ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) { ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
auto const& node = row_set[i]; auto const& node = row_set[i];
if (node.node_id < 0) { if (node.node_id < 0) {
@ -381,7 +383,7 @@ class PartitionBuilder {
size_t ptr_offset = node.end() - p_begin; size_t ptr_offset = node.end() - p_begin;
CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id; CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
for (auto idx = node.begin(); idx != node.end(); ++idx) { for (auto idx = node.begin(); idx != node.end(); ++idx) {
h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id; h_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx));
} }
} }
}); });

View File

@ -14,6 +14,7 @@
#include "../collective/allgather.h" #include "../collective/allgather.h"
#include "../collective/allreduce.h" #include "../collective/allreduce.h"
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank
#include "categorical.h" #include "categorical.h"
#include "common.h" #include "common.h"
#include "device_helpers.cuh" #include "device_helpers.cuh"

View File

@ -6,6 +6,8 @@
#include <thrust/binary_search.h> #include <thrust/binary_search.h>
#include <limits> // for numeric_limits
#include "../common/categorical.h" #include "../common/categorical.h"
#include "../common/compressed_iterator.h" #include "../common/compressed_iterator.h"
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
@ -21,22 +23,26 @@ namespace xgboost {
* Does not own underlying memory and may be trivially copied into kernels. * Does not own underlying memory and may be trivially copied into kernels.
*/ */
struct EllpackDeviceAccessor { struct EllpackDeviceAccessor {
/*! \brief Whether or not if the matrix is dense. */ /** @brief Whether or not if the matrix is dense. */
bool is_dense; bool is_dense;
/*! \brief Row length for ELLPACK, equal to number of features. */ /** @brief Row length for ELLPACK, equal to number of features when the data is dense. */
bst_idx_t row_stride; bst_idx_t row_stride;
bst_idx_t base_rowid{0}; /** @brief Starting index of the rows. Used for external memory. */
bst_idx_t n_rows{0}; bst_idx_t base_rowid;
/** @brief Number of rows in this batch. */
bst_idx_t n_rows;
/** @brief Acessor for the gradient index. */
common::CompressedIterator<std::uint32_t> gidx_iter; common::CompressedIterator<std::uint32_t> gidx_iter;
/*! \brief Minimum value for each feature. Size equals to number of features. */ /** @brief Minimum value for each feature. Size equals to number of features. */
common::Span<const float> min_fvalue; common::Span<const float> min_fvalue;
/*! \brief Histogram cut pointers. Size equals to (number of features + 1). */ /** @brief Histogram cut pointers. Size equals to (number of features + 1). */
common::Span<const std::uint32_t> feature_segments; common::Span<const std::uint32_t> feature_segments;
/*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */ /** @brief Histogram cut values. Size equals to (bins per feature * number of features). */
common::Span<const float> gidx_fvalue_map; common::Span<const float> gidx_fvalue_map;
/** @brief Type of each feature, categorical or numerical. */
common::Span<const FeatureType> feature_types; common::Span<const FeatureType> feature_types;
EllpackDeviceAccessor() = delete;
EllpackDeviceAccessor(DeviceOrd device, std::shared_ptr<const common::HistogramCuts> cuts, EllpackDeviceAccessor(DeviceOrd device, std::shared_ptr<const common::HistogramCuts> cuts,
bool is_dense, size_t row_stride, size_t base_rowid, size_t n_rows, bool is_dense, size_t row_stride, size_t base_rowid, size_t n_rows,
common::CompressedIterator<uint32_t> gidx_iter, common::CompressedIterator<uint32_t> gidx_iter,
@ -108,10 +114,10 @@ struct EllpackDeviceAccessor {
return idx; return idx;
} }
[[nodiscard]] __device__ bst_float GetFvalue(size_t ridx, size_t fidx) const { [[nodiscard]] __device__ float GetFvalue(size_t ridx, size_t fidx) const {
auto gidx = GetBinIndex(ridx, fidx); auto gidx = GetBinIndex(ridx, fidx);
if (gidx == -1) { if (gidx == -1) {
return nan(""); return std::numeric_limits<float>::quiet_NaN();
} }
return gidx_fvalue_map[gidx]; return gidx_fvalue_map[gidx];
} }

View File

@ -10,11 +10,11 @@
#include <vector> // std::vector #include <vector> // std::vector
#include "../common/algorithm.h" // ArgSort #include "../common/algorithm.h" // ArgSort
#include "../common/common.h" // AssertGPUSupport
#include "../common/numeric.h" // RunLengthEncode #include "../common/numeric.h" // RunLengthEncode
#include "../common/stats.h" // Quantile,WeightedQuantile #include "../common/stats.h" // Quantile,WeightedQuantile
#include "../common/threading_utils.h" // ParallelFor #include "../common/threading_utils.h" // ParallelFor
#include "../common/transform_iterator.h" // MakeIndexTransformIter #include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "../tree/sample_position.h" // for SamplePosition
#include "xgboost/base.h" // bst_node_t #include "xgboost/base.h" // bst_node_t
#include "xgboost/context.h" // Context #include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo #include "xgboost/data.h" // MetaInfo
@ -23,6 +23,10 @@
#include "xgboost/span.h" // Span #include "xgboost/span.h" // Span
#include "xgboost/tree_model.h" // RegTree #include "xgboost/tree_model.h" // RegTree
#if !defined(XGBOOST_USE_CUDA)
#include "../common/common.h" // AssertGPUSupport
#endif // !defined(XGBOOST_USE_CUDA)
namespace xgboost::obj::detail { namespace xgboost::obj::detail {
void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree, void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
std::vector<bst_node_t> const& position, std::vector<size_t>* p_nptr, std::vector<bst_node_t> const& position, std::vector<size_t>* p_nptr,
@ -37,9 +41,10 @@ void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
sorted_pos[i] = position[ridx[i]]; sorted_pos[i] = position[ridx[i]];
} }
// find the first non-sampled row // find the first non-sampled row
size_t begin_pos = size_t begin_pos = std::distance(
std::distance(sorted_pos.cbegin(), std::find_if(sorted_pos.cbegin(), sorted_pos.cend(), sorted_pos.cbegin(),
[](bst_node_t nidx) { return nidx >= 0; })); std::find_if(sorted_pos.cbegin(), sorted_pos.cend(),
[](bst_node_t nidx) { return tree::SamplePosition::IsValid(nidx); }));
CHECK_LE(begin_pos, sorted_pos.size()); CHECK_LE(begin_pos, sorted_pos.size());
std::vector<bst_node_t> leaf; std::vector<bst_node_t> leaf;

View File

@ -10,6 +10,7 @@
#include "../common/cuda_context.cuh" // CUDAContext #include "../common/cuda_context.cuh" // CUDAContext
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#include "../common/stats.cuh" #include "../common/stats.cuh"
#include "../tree/sample_position.h" // for SamplePosition
#include "adaptive.h" #include "adaptive.h"
#include "xgboost/context.h" #include "xgboost/context.h"
@ -30,9 +31,11 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
// sort row index according to node index // sort row index according to node index
thrust::stable_sort_by_key(cuctx->TP(), sorted_position.begin(), thrust::stable_sort_by_key(cuctx->TP(), sorted_position.begin(),
sorted_position.begin() + n_samples, p_ridx->begin()); sorted_position.begin() + n_samples, p_ridx->begin());
size_t beg_pos = // Find the first one that's not sampled (nidx not been negated).
thrust::find_if(cuctx->CTP(), sorted_position.cbegin(), sorted_position.cend(), size_t beg_pos = thrust::find_if(cuctx->CTP(), sorted_position.cbegin(), sorted_position.cend(),
[] XGBOOST_DEVICE(bst_node_t nidx) { return nidx >= 0; }) - [] XGBOOST_DEVICE(bst_node_t nidx) {
return tree::SamplePosition::IsValid(nidx);
}) -
sorted_position.cbegin(); sorted_position.cbegin();
if (beg_pos == sorted_position.size()) { if (beg_pos == sorted_position.size()) {
auto& leaf = p_nidx->HostVector(); auto& leaf = p_nidx->HostVector();

View File

@ -1,13 +1,12 @@
/** /**
* Copyright 2020-2024, XGBoost Contributors * Copyright 2020-2024, XGBoost Contributors
*/ */
#include <algorithm> // std::max #include <algorithm> // for :max
#include <vector> #include <limits> // for numeric_limits
#include <limits>
#include "../../collective/allgather.h" #include "../../collective/allgather.h"
#include "../../collective/communicator-inl.h" // for GetWorldSize, GetRank
#include "../../common/categorical.h" #include "../../common/categorical.h"
#include "../../data/ellpack_page.cuh"
#include "evaluate_splits.cuh" #include "evaluate_splits.cuh"
#include "expand_entry.cuh" #include "expand_entry.cuh"

View File

@ -15,6 +15,7 @@ void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t ba
ridx_.resize(n_samples); ridx_.resize(n_samples);
ridx_tmp_.resize(n_samples); ridx_tmp_.resize(n_samples);
tmp_.clear(); tmp_.clear();
n_nodes_ = 1; // Root
CHECK_LE(n_samples, std::numeric_limits<cuda_impl::RowIndexT>::max()); CHECK_LE(n_samples, std::numeric_limits<cuda_impl::RowIndexT>::max());
ridx_segments_.emplace_back( ridx_segments_.emplace_back(

View File

@ -19,7 +19,9 @@
namespace xgboost::tree { namespace xgboost::tree {
namespace cuda_impl { namespace cuda_impl {
using RowIndexT = std::uint32_t; using RowIndexT = std::uint32_t;
} // TODO(Rory): Can be larger. To be tuned alongside other batch operations.
static const std::int32_t kMaxUpdatePositionBatchSize = 32;
} // namespace cuda_impl
/** /**
* @brief Used to demarcate a contiguous set of row indices associated with some tree * @brief Used to demarcate a contiguous set of row indices associated with some tree
@ -37,8 +39,6 @@ struct Segment {
__host__ __device__ bst_idx_t Size() const { return end - begin; } __host__ __device__ bst_idx_t Size() const { return end - begin; }
}; };
// TODO(Rory): Can be larger. To be tuned alongside other batch operations.
static const int kMaxUpdatePositionBatchSize = 32;
template <typename OpDataT> template <typename OpDataT>
struct PerNodeData { struct PerNodeData {
Segment segment; Segment segment;
@ -46,10 +46,10 @@ struct PerNodeData {
}; };
template <typename BatchIterT> template <typename BatchIterT>
__device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx, XGBOOST_DEV_INLINE void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
int* batch_idx, std::size_t* item_idx) { int* batch_idx, std::size_t* item_idx) {
cuda_impl::RowIndexT sum = 0; cuda_impl::RowIndexT sum = 0;
for (int i = 0; i < kMaxUpdatePositionBatchSize; i++) { for (int i = 0; i < cuda_impl::kMaxUpdatePositionBatchSize; i++) {
if (sum + batch_info[i].segment.Size() > global_thread_idx) { if (sum + batch_info[i].segment.Size() > global_thread_idx) {
*batch_idx = i; *batch_idx = i;
*item_idx = (global_thread_idx - sum) + batch_info[i].segment.begin; *item_idx = (global_thread_idx - sum) + batch_info[i].segment.begin;
@ -59,10 +59,10 @@ __device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t g
} }
} }
template <int kBlockSize, typename RowIndexT, typename OpDataT> template <int kBlockSize, typename OpDataT>
__global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel( __global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<RowIndexT> d_ridx, dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<cuda_impl::RowIndexT> d_ridx,
const common::Span<const RowIndexT> ridx_tmp, std::size_t total_rows) { const common::Span<const cuda_impl::RowIndexT> ridx_tmp, bst_idx_t total_rows) {
for (auto idx : dh::GridStrideRange<std::size_t>(0, total_rows)) { for (auto idx : dh::GridStrideRange<std::size_t>(0, total_rows)) {
int batch_idx; int batch_idx;
std::size_t item_idx; std::size_t item_idx;
@ -92,6 +92,7 @@ struct IndexFlagOp {
} }
}; };
// Scatter from `ridx_in` to `ridx_out`.
template <typename OpDataT> template <typename OpDataT>
struct WriteResultsFunctor { struct WriteResultsFunctor {
dh::LDGIterator<PerNodeData<OpDataT>> batch_info; dh::LDGIterator<PerNodeData<OpDataT>> batch_info;
@ -99,10 +100,12 @@ struct WriteResultsFunctor {
cuda_impl::RowIndexT* ridx_out; cuda_impl::RowIndexT* ridx_out;
cuda_impl::RowIndexT* counts; cuda_impl::RowIndexT* counts;
__device__ IndexFlagTuple operator()(const IndexFlagTuple& x) { __device__ IndexFlagTuple operator()(IndexFlagTuple const& x) {
std::size_t scatter_address; cuda_impl::RowIndexT scatter_address;
// Get the segment that this row belongs to.
const Segment& segment = batch_info[x.batch_idx].segment; const Segment& segment = batch_info[x.batch_idx].segment;
if (x.flag) { if (x.flag) {
// Go left.
cuda_impl::RowIndexT num_previous_flagged = x.flag_scan - 1; // -1 because inclusive scan cuda_impl::RowIndexT num_previous_flagged = x.flag_scan - 1; // -1 because inclusive scan
scatter_address = segment.begin + num_previous_flagged; scatter_address = segment.begin + num_previous_flagged;
} else { } else {
@ -121,10 +124,14 @@ struct WriteResultsFunctor {
} }
}; };
template <typename RowIndexT, typename OpT, typename OpDataT> /**
* @param d_batch_info Node data, with the size of the input number of nodes.
*/
template <typename OpT, typename OpDataT>
void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info, void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp, common::Span<cuda_impl::RowIndexT> ridx,
common::Span<cuda_impl::RowIndexT> d_counts, std::size_t total_rows, OpT op, common::Span<cuda_impl::RowIndexT> ridx_tmp,
common::Span<cuda_impl::RowIndexT> d_counts, bst_idx_t total_rows, OpT op,
dh::device_vector<int8_t>* tmp) { dh::device_vector<int8_t>* tmp) {
dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data()); dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(), WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
@ -134,22 +141,23 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results); thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
auto counting = thrust::make_counting_iterator(0llu); auto counting = thrust::make_counting_iterator(0llu);
auto input_iterator = auto input_iterator =
dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(size_t idx) { dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(std::size_t idx) {
int batch_idx; int nidx_in_batch;
std::size_t item_idx; std::size_t item_idx;
AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx); AssignBatch(batch_info_itr, idx, &nidx_in_batch, &item_idx);
auto op_res = op(ridx[item_idx], batch_idx, batch_info_itr[batch_idx].data); auto go_left = op(ridx[item_idx], nidx_in_batch, batch_info_itr[nidx_in_batch].data);
return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), op_res, batch_idx, op_res}; return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), go_left, nidx_in_batch,
go_left};
}); });
size_t temp_bytes = 0; std::size_t temp_bytes = 0;
if (tmp->empty()) { if (tmp->empty()) {
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator, cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
IndexFlagOp(), total_rows); IndexFlagOp{}, total_rows);
tmp->resize(temp_bytes); tmp->resize(temp_bytes);
} }
temp_bytes = tmp->size(); temp_bytes = tmp->size();
cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator, cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(), total_rows); discard_write_iterator, IndexFlagOp{}, total_rows);
constexpr int kBlockSize = 256; constexpr int kBlockSize = 256;
@ -157,7 +165,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
const int kItemsThread = 12; const int kItemsThread = 12;
const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread); const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread);
SortPositionCopyKernel<kBlockSize, RowIndexT, OpDataT> SortPositionCopyKernel<kBlockSize, OpDataT>
<<<grid_size, kBlockSize, 0>>>(batch_info_itr, ridx, ridx_tmp, total_rows); <<<grid_size, kBlockSize, 0>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
} }
@ -168,7 +176,7 @@ struct NodePositionInfo {
__device__ bool IsLeaf() { return left_child == -1; } __device__ bool IsLeaf() { return left_child == -1; }
}; };
__device__ __forceinline__ int GetPositionFromSegments(std::size_t idx, XGBOOST_DEV_INLINE int GetPositionFromSegments(std::size_t idx,
const NodePositionInfo* d_node_info) { const NodePositionInfo* d_node_info) {
int position = 0; int position = 0;
NodePositionInfo node = d_node_info[position]; NodePositionInfo node = d_node_info[position];
@ -205,7 +213,6 @@ __global__ __launch_bounds__(kBlockSize) void FinalisePositionKernel(
class RowPartitioner { class RowPartitioner {
public: public:
using RowIndexT = cuda_impl::RowIndexT; using RowIndexT = cuda_impl::RowIndexT;
static constexpr bst_node_t kIgnoredTreePosition = -1;
private: private:
/** /**
@ -232,6 +239,7 @@ class RowPartitioner {
dh::device_vector<int8_t> tmp_; dh::device_vector<int8_t> tmp_;
dh::PinnedMemory pinned_; dh::PinnedMemory pinned_;
dh::PinnedMemory pinned2_; dh::PinnedMemory pinned2_;
bst_node_t n_nodes_{0}; // Counter for internal checks.
public: public:
/** /**
@ -255,6 +263,7 @@ class RowPartitioner {
* \brief Gets all training rows in the set. * \brief Gets all training rows in the set.
*/ */
common::Span<const RowIndexT> GetRows(); common::Span<const RowIndexT> GetRows();
[[nodiscard]] bst_node_t GetNumNodes() const { return n_nodes_; }
/** /**
* \brief Convenience method for testing * \brief Convenience method for testing
@ -280,10 +289,14 @@ class RowPartitioner {
const std::vector<bst_node_t>& left_nidx, const std::vector<bst_node_t>& left_nidx,
const std::vector<bst_node_t>& right_nidx, const std::vector<bst_node_t>& right_nidx,
const std::vector<OpDataT>& op_data, UpdatePositionOpT op) { const std::vector<OpDataT>& op_data, UpdatePositionOpT op) {
if (nidx.empty()) return; if (nidx.empty()) {
return;
}
CHECK_EQ(nidx.size(), left_nidx.size()); CHECK_EQ(nidx.size(), left_nidx.size());
CHECK_EQ(nidx.size(), right_nidx.size()); CHECK_EQ(nidx.size(), right_nidx.size());
CHECK_EQ(nidx.size(), op_data.size()); CHECK_EQ(nidx.size(), op_data.size());
this->n_nodes_ += (left_nidx.size() + right_nidx.size());
auto h_batch_info = pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size()); auto h_batch_info = pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
dh::TemporaryArray<PerNodeData<OpDataT>> d_batch_info(nidx.size()); dh::TemporaryArray<PerNodeData<OpDataT>> d_batch_info(nidx.size());
@ -302,8 +315,8 @@ class RowPartitioner {
dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0); dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0);
// Partition the rows according to the operator // Partition the rows according to the operator
SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>( SortPositionBatch<UpdatePositionOpT, OpDataT>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx_),
dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
total_rows, op, &tmp_); total_rows, op, &tmp_);
dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(), dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
cudaMemcpyDefault)); cudaMemcpyDefault));
@ -327,20 +340,16 @@ class RowPartitioner {
} }
/** /**
* \brief Finalise the position of all training instances after tree construction is * @brief Finalise the position of all training instances after tree construction is
* complete. Does not update any other meta information in this data structure, so * complete. Does not update any other meta information in this data structure, so
* should only be used at the end of training. * should only be used at the end of training.
* *
* When the task requires update leaf, this function will copy the node index into * @param p_out_position Node index for each row.
* p_out_position. The index is negated if it's being sampled in current iteration. * @param op Device lambda. Should provide the row index and current position as an
*
* \param p_out_position Node index for each row.
* \param op Device lambda. Should provide the row index and current position as an
* argument and return the new position for this training instance. * argument and return the new position for this training instance.
* \param sampled A device lambda to inform the partitioner whether a row is sampled.
*/ */
template <typename FinalisePositionOpT> template <typename FinalisePositionOpT>
void FinalisePosition(common::Span<bst_node_t> d_out_position, FinalisePositionOpT op) { void FinalisePosition(common::Span<bst_node_t> d_out_position, FinalisePositionOpT op) const {
dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size()); dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(), dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
sizeof(NodePositionInfo) * ridx_segments_.size(), sizeof(NodePositionInfo) * ridx_segments_.size(),

View File

@ -10,14 +10,11 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <limits>
#include <string> #include <string>
#include <vector> #include <vector>
#include "../common/categorical.h"
#include "../common/linalg_op.h" #include "../common/linalg_op.h"
#include "../common/math.h" #include "../common/math.h"
#include "xgboost/data.h"
#include "xgboost/linalg.h" #include "xgboost/linalg.h"
#include "xgboost/parameter.h" #include "xgboost/parameter.h"

View File

@ -0,0 +1,21 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#pragma once
#include "xgboost/base.h" // for bst_node_t
namespace xgboost::tree {
// Utility for maniputing the node index. This is used by the tree methods and the
// adaptive objectives to share the node index. A row is invalid if it's not used in the
// last iteration (due to sampling). For these rows, the corresponding tree node index is
// negated.
struct SamplePosition {
[[nodiscard]] bst_node_t static XGBOOST_HOST_DEV_INLINE Encode(bst_node_t nidx, bool is_valid) {
return is_valid ? nidx : ~nidx;
}
[[nodiscard]] bst_node_t static XGBOOST_HOST_DEV_INLINE Decode(bst_node_t nidx) {
return IsValid(nidx) ? nidx : ~nidx;
}
[[nodiscard]] bool static XGBOOST_HOST_DEV_INLINE IsValid(bst_node_t nidx) { return nidx >= 0; }
};
} // namespace xgboost::tree

View File

@ -8,14 +8,13 @@
#include <xgboost/json.h> #include <xgboost/json.h>
#include <xgboost/tree_model.h> #include <xgboost/tree_model.h>
#include <array> // for array
#include <cmath> #include <cmath>
#include <iomanip> #include <iomanip>
#include <limits> #include <limits>
#include <sstream> #include <sstream>
#include <type_traits> #include <type_traits>
#include "../common/categorical.h" #include "../common/categorical.h" // for GetNodeCats
#include "../common/common.h" // for EscapeU8 #include "../common/common.h" // for EscapeU8
#include "../predictor/predict_fn.h" #include "../predictor/predict_fn.h"
#include "io_utils.h" // for GetElem #include "io_utils.h" // for GetElem
@ -1038,9 +1037,8 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
categories_nodes.GetArray().emplace_back(i); categories_nodes.GetArray().emplace_back(i);
auto begin = categories.Size(); auto begin = categories.Size();
categories_segments.GetArray().emplace_back(begin); categories_segments.GetArray().emplace_back(begin);
auto segment = split_categories_segments_[i]; auto segment = this->split_categories_segments_[i];
auto node_categories = this->GetSplitCategories().subspan(segment.beg, segment.size); auto cat_bits = common::GetNodeCats(this->GetSplitCategories(), segment);
common::KCatBitField const cat_bits(node_categories);
for (size_t i = 0; i < cat_bits.Capacity(); ++i) { for (size_t i = 0; i < cat_bits.Capacity(); ++i) {
if (cat_bits.Check(i)) { if (cat_bits.Check(i)) {
categories.GetArray().emplace_back(i); categories.GetArray().emplace_back(i);

View File

@ -10,6 +10,7 @@
#include "../common/error_msg.h" // for NoCategorical #include "../common/error_msg.h" // for NoCategorical
#include "../common/random.h" #include "../common/random.h"
#include "sample_position.h" // for SamplePosition
#include "constraints.h" #include "constraints.h"
#include "param.h" #include "param.h"
#include "split_evaluator.h" #include "split_evaluator.h"
@ -515,7 +516,7 @@ class ColMaker: public TreeUpdater {
common::ParallelFor(p_fmat->Info().num_row_, this->ctx_->Threads(), [&](auto ridx) { common::ParallelFor(p_fmat->Info().num_row_, this->ctx_->Threads(), [&](auto ridx) {
CHECK_LT(ridx, position_.size()) << "ridx exceed bound " CHECK_LT(ridx, position_.size()) << "ridx exceed bound "
<< "ridx=" << ridx << " pos=" << position_.size(); << "ridx=" << ridx << " pos=" << position_.size();
const int nid = this->DecodePosition(ridx); const int nid = SamplePosition::Decode(position_[ridx]);
if (tree[nid].IsLeaf()) { if (tree[nid].IsLeaf()) {
// mark finish when it is not a fresh leaf // mark finish when it is not a fresh leaf
if (tree[nid].RightChild() == -1) { if (tree[nid].RightChild() == -1) {
@ -560,14 +561,14 @@ class ColMaker: public TreeUpdater {
auto col = page[fid]; auto col = page[fid];
common::ParallelFor(col.size(), this->ctx_->Threads(), [&](auto j) { common::ParallelFor(col.size(), this->ctx_->Threads(), [&](auto j) {
const bst_uint ridx = col[j].index; const bst_uint ridx = col[j].index;
const int nid = this->DecodePosition(ridx); bst_node_t nidx = SamplePosition::Decode(position_[ridx]);
const bst_float fvalue = col[j].fvalue; const bst_float fvalue = col[j].fvalue;
// go back to parent, correct those who are not default // go back to parent, correct those who are not default
if (!tree[nid].IsLeaf() && tree[nid].SplitIndex() == fid) { if (!tree[nidx].IsLeaf() && tree[nidx].SplitIndex() == fid) {
if (fvalue < tree[nid].SplitCond()) { if (fvalue < tree[nidx].SplitCond()) {
this->SetEncodePosition(ridx, tree[nid].LeftChild()); this->SetEncodePosition(ridx, tree[nidx].LeftChild());
} else { } else {
this->SetEncodePosition(ridx, tree[nid].RightChild()); this->SetEncodePosition(ridx, tree[nidx].RightChild());
} }
} }
}); });
@ -576,17 +577,10 @@ class ColMaker: public TreeUpdater {
} }
// utils to get/set position, with encoded format // utils to get/set position, with encoded format
// return decoded position // return decoded position
inline int DecodePosition(bst_uint ridx) const {
const int pid = position_[ridx];
return pid < 0 ? ~pid : pid;
}
// encode the encoded position value for ridx // encode the encoded position value for ridx
inline void SetEncodePosition(bst_uint ridx, int nid) { void SetEncodePosition(bst_idx_t ridx, bst_node_t nidx) {
if (position_[ridx] < 0) { bool is_invalid = position_[ridx] < 0;
position_[ridx] = ~nid; position_[ridx] = SamplePosition::Encode(nidx, !is_invalid);
} else {
position_[ridx] = nid;
}
} }
// --data fields-- // --data fields--
const TrainParam& param_; const TrainParam& param_;

View File

@ -6,8 +6,9 @@
#include <ostream> // for ostream #include <ostream> // for ostream
#include "gpu_hist/histogram.cuh" #include "gpu_hist/histogram.cuh"
#include "param.h" #include "param.h" // for TrainParam
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/task.h" // for ObjInfo
namespace xgboost::tree { namespace xgboost::tree {
struct GPUTrainingParam { struct GPUTrainingParam {
@ -117,6 +118,21 @@ struct DeviceSplitCandidate {
} }
}; };
namespace cuda_impl {
inline BatchParam HistBatch(TrainParam const& param) {
return {param.max_bin, TrainParam::DftSparseThreshold()};
}
inline BatchParam HistBatch(bst_bin_t max_bin) {
return {max_bin, TrainParam::DftSparseThreshold()};
}
inline BatchParam ApproxBatch(TrainParam const& p, common::Span<float const> hess,
ObjInfo const& task) {
return BatchParam{p.max_bin, hess, !task.const_hess};
}
} // namespace cuda_impl
template <typename T> template <typename T>
struct SumCallbackOp { struct SumCallbackOp {
// Running prefix // Running prefix

View File

@ -34,7 +34,8 @@
#include "gpu_hist/row_partitioner.cuh" #include "gpu_hist/row_partitioner.cuh"
#include "hist/param.h" #include "hist/param.h"
#include "param.h" #include "param.h"
#include "updater_gpu_common.cuh" #include "sample_position.h" // for SamplePosition
#include "updater_gpu_common.cuh" // for HistBatch
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/context.h" #include "xgboost/context.h"
#include "xgboost/data.h" #include "xgboost/data.h"
@ -43,11 +44,15 @@
#include "xgboost/span.h" #include "xgboost/span.h"
#include "xgboost/task.h" // for ObjInfo #include "xgboost/task.h" // for ObjInfo
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
#include "xgboost/tree_updater.h"
namespace xgboost::tree { namespace xgboost::tree {
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
// Manage memory for a single GPU using cuda_impl::ApproxBatch;
using cuda_impl::HistBatch;
// GPU tree updater implementation.
struct GPUHistMakerDevice { struct GPUHistMakerDevice {
private: private:
GPUHistEvaluator evaluator_; GPUHistEvaluator evaluator_;
@ -56,20 +61,29 @@ struct GPUHistMakerDevice {
MetaInfo const& info_; MetaInfo const& info_;
DeviceHistogramBuilder histogram_; DeviceHistogramBuilder histogram_;
// node idx for each sample
dh::device_vector<bst_node_t> positions_;
std::unique_ptr<RowPartitioner> row_partitioner_;
public:
// Extra data for each node that is passed to the update position function
struct NodeSplitData {
RegTree::Node split_node;
FeatureType split_type;
common::KCatBitField node_cats;
};
static_assert(std::is_trivially_copyable_v<NodeSplitData>);
public: public:
EllpackPageImpl const* page{nullptr}; EllpackPageImpl const* page{nullptr};
common::Span<FeatureType const> feature_types; common::Span<FeatureType const> feature_types;
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogramStorage<> hist{}; DeviceHistogramStorage<> hist{};
dh::device_vector<GradientPair> d_gpair; // storage for gpair; dh::device_vector<GradientPair> d_gpair; // storage for gpair;
common::Span<GradientPair> gpair; common::Span<GradientPair> gpair;
dh::device_vector<int> monotone_constraints; dh::device_vector<int> monotone_constraints;
// node idx for each sample
dh::device_vector<bst_node_t> positions;
TrainParam param; TrainParam param;
@ -143,10 +157,10 @@ struct GPUHistMakerDevice {
quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info()); quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());
if (!row_partitioner) { if (!row_partitioner_) {
row_partitioner = std::make_unique<RowPartitioner>(); row_partitioner_ = std::make_unique<RowPartitioner>();
} }
row_partitioner->Reset(ctx_, sample.sample_rows, page->base_rowid); row_partitioner_->Reset(ctx_, sample.sample_rows, page->base_rowid);
CHECK_EQ(page->base_rowid, 0); CHECK_EQ(page->base_rowid, 0);
// Init histogram // Init histogram
@ -182,7 +196,10 @@ struct GPUHistMakerDevice {
void EvaluateSplits(const std::vector<GPUExpandEntry>& candidates, const RegTree& tree, void EvaluateSplits(const std::vector<GPUExpandEntry>& candidates, const RegTree& tree,
common::Span<GPUExpandEntry> pinned_candidates_out) { common::Span<GPUExpandEntry> pinned_candidates_out) {
if (candidates.empty()) return; if (candidates.empty()) {
return;
}
this->monitor.Start(__func__);
dh::TemporaryArray<EvaluateSplitInputs> d_node_inputs(2 * candidates.size()); dh::TemporaryArray<EvaluateSplitInputs> d_node_inputs(2 * candidates.size());
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2 * candidates.size()); dh::TemporaryArray<DeviceSplitCandidate> splits_out(2 * candidates.size());
std::vector<bst_node_t> nidx(2 * candidates.size()); std::vector<bst_node_t> nidx(2 * candidates.size());
@ -234,12 +251,12 @@ struct GPUHistMakerDevice {
dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(), dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(),
entries.data().get(), sizeof(GPUExpandEntry) * entries.size(), entries.data().get(), sizeof(GPUExpandEntry) * entries.size(),
cudaMemcpyDeviceToHost)); cudaMemcpyDeviceToHost));
dh::DefaultStream().Sync(); this->monitor.Stop(__func__);
} }
void BuildHist(int nidx) { void BuildHist(int nidx) {
auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx);
auto d_ridx = row_partitioner->GetRows(nidx); auto d_ridx = row_partitioner_->GetRows(nidx);
this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()), this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx, feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
d_node_hist, *quantiser); d_node_hist, *quantiser);
@ -262,14 +279,6 @@ struct GPUHistMakerDevice {
return true; return true;
} }
// Extra data for each node that is passed
// to the update position function
struct NodeSplitData {
RegTree::Node split_node;
FeatureType split_type;
common::KCatBitField node_cats;
};
void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix, void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix,
std::vector<NodeSplitData> const& split_data, std::vector<NodeSplitData> const& split_data,
std::vector<bst_node_t> const& nidx, std::vector<bst_node_t> const& nidx,
@ -321,10 +330,10 @@ struct GPUHistMakerDevice {
}; };
collective::SafeColl(rc); collective::SafeColl(rc);
row_partitioner->UpdatePositionBatch( row_partitioner_->UpdatePositionBatch(
nidx, left_nidx, right_nidx, split_data, nidx, left_nidx, right_nidx, split_data,
[=] __device__(bst_uint ridx, int split_index, NodeSplitData const& data) { [=] __device__(bst_uint ridx, int nidx_in_batch, NodeSplitData const& data) {
auto const index = ridx * num_candidates + split_index; auto const index = ridx * num_candidates + nidx_in_batch;
bool go_left; bool go_left;
if (missing_bits.Check(index)) { if (missing_bits.Check(index)) {
go_left = data.split_node.DefaultLeft(); go_left = data.split_node.DefaultLeft();
@ -335,11 +344,35 @@ struct GPUHistMakerDevice {
}); });
} }
struct GoLeftOp {
EllpackDeviceAccessor d_matrix;
__device__ bool operator()(cuda_impl::RowIndexT ridx, NodeSplitData const& data) const {
RegTree::Node const& node = data.split_node;
// given a row index, returns the node id it belongs to
float cut_value = d_matrix.GetFvalue(ridx, node.SplitIndex());
// Missing value
bool go_left = true;
if (isnan(cut_value)) {
go_left = node.DefaultLeft();
} else {
if (data.split_type == FeatureType::kCategorical) {
go_left = common::Decision(data.node_cats.Bits(), cut_value);
} else {
go_left = cut_value <= node.SplitCond();
}
}
return go_left;
}
};
void UpdatePosition(std::vector<GPUExpandEntry> const& candidates, RegTree* p_tree) { void UpdatePosition(std::vector<GPUExpandEntry> const& candidates, RegTree* p_tree) {
if (candidates.empty()) { if (candidates.empty()) {
return; return;
} }
monitor.Start(__func__);
std::vector<bst_node_t> nidx(candidates.size()); std::vector<bst_node_t> nidx(candidates.size());
std::vector<bst_node_t> left_nidx(candidates.size()); std::vector<bst_node_t> left_nidx(candidates.size());
std::vector<bst_node_t> right_nidx(candidates.size()); std::vector<bst_node_t> right_nidx(candidates.size());
@ -347,12 +380,12 @@ struct GPUHistMakerDevice {
for (size_t i = 0; i < candidates.size(); i++) { for (size_t i = 0; i < candidates.size(); i++) {
auto const& e = candidates[i]; auto const& e = candidates[i];
RegTree::Node split_node = (*p_tree)[e.nid]; RegTree::Node const& split_node = (*p_tree)[e.nid];
auto split_type = p_tree->NodeSplitType(e.nid); auto split_type = p_tree->NodeSplitType(e.nid);
nidx.at(i) = e.nid; nidx[i] = e.nid;
left_nidx.at(i) = split_node.LeftChild(); left_nidx[i] = split_node.LeftChild();
right_nidx.at(i) = split_node.RightChild(); right_nidx[i] = split_node.RightChild();
split_data.at(i) = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)}; split_data[i] = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)};
CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat); CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
} }
@ -361,27 +394,15 @@ struct GPUHistMakerDevice {
if (info_.IsColumnSplit()) { if (info_.IsColumnSplit()) {
UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx); UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
monitor.Stop(__func__);
return; return;
} }
auto go_left = GoLeftOp{d_matrix};
row_partitioner->UpdatePositionBatch( row_partitioner_->UpdatePositionBatch(
nidx, left_nidx, right_nidx, split_data, nidx, left_nidx, right_nidx, split_data,
[=] __device__(bst_uint ridx, int split_index, const NodeSplitData& data) { [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
// given a row index, returns the node id it belongs to const NodeSplitData& data) { return go_left(ridx, data); });
float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex()); monitor.Stop(__func__);
// Missing value
bool go_left = true;
if (isnan(cut_value)) {
go_left = data.split_node.DefaultLeft();
} else {
if (data.split_type == FeatureType::kCategorical) {
go_left = common::Decision(data.node_cats.Bits(), cut_value);
} else {
go_left = cut_value <= data.split_node.SplitCond();
}
}
return go_left;
});
} }
// After tree update is finished, update the position of all training // After tree update is finished, update the position of all training
@ -389,101 +410,70 @@ struct GPUHistMakerDevice {
// prediction cache // prediction cache
void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat, ObjInfo task, void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat, ObjInfo task,
HostDeviceVector<bst_node_t>* p_out_position) { HostDeviceVector<bst_node_t>* p_out_position) {
// Prediction cache will not be used with external memory if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
if (!p_fmat->SingleColBlock()) {
if (task.UpdateTreeLeaf()) {
LOG(FATAL) << "Current objective function can not be used with external memory."; LOG(FATAL) << "Current objective function can not be used with external memory.";
} }
if (p_fmat->Info().num_row_ != row_partitioner_->GetRows().size()) {
// Subsampling with external memory. Not supported.
p_out_position->Resize(0); p_out_position->Resize(0);
positions.clear(); positions_.clear();
return; return;
} }
dh::TemporaryArray<RegTree::Node> d_nodes(p_tree->GetNodes().size());
dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(),
d_nodes.size() * sizeof(RegTree::Node),
cudaMemcpyHostToDevice));
auto const& h_split_types = p_tree->GetSplitTypes();
auto const& categories = p_tree->GetSplitCategories();
auto const& categories_segments = p_tree->GetSplitCategoriesPtr();
dh::caching_device_vector<FeatureType> d_split_types;
dh::caching_device_vector<uint32_t> d_categories;
dh::caching_device_vector<RegTree::CategoricalSplitMatrix::Segment> d_categories_segments;
if (!categories.empty()) {
dh::CopyToD(h_split_types, &d_split_types);
dh::CopyToD(categories, &d_categories);
dh::CopyToD(categories_segments, &d_categories_segments);
}
FinalisePositionInPage(page, dh::ToSpan(d_nodes), dh::ToSpan(d_split_types),
dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments),
p_out_position);
}
void FinalisePositionInPage(
EllpackPageImpl const* page, const common::Span<RegTree::Node> d_nodes,
common::Span<FeatureType const> d_feature_types, common::Span<uint32_t const> categories,
common::Span<RegTree::CategoricalSplitMatrix::Segment> categories_segments,
HostDeviceVector<bst_node_t>* p_out_position) {
auto d_matrix = page->GetDeviceAccessor(ctx_->Device());
auto d_gpair = this->gpair;
p_out_position->SetDevice(ctx_->Device()); p_out_position->SetDevice(ctx_->Device());
p_out_position->Resize(row_partitioner->GetRows().size()); p_out_position->Resize(row_partitioner_->GetRows().size());
auto d_out_position = p_out_position->DeviceSpan();
auto new_position_op = [=] __device__(size_t row_id, int position) { auto d_gpair = this->gpair;
// What happens if user prune the tree? auto encode_op = [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
if (!d_matrix.IsInRange(row_id)) { bool is_invalid = d_gpair[row_id].GetHess() - .0f == 0.f;
return RowPartitioner::kIgnoredTreePosition; return SamplePosition::Encode(nidx, !is_invalid);
}
auto node = d_nodes[position];
while (!node.IsLeaf()) {
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
// Missing value
if (isnan(element)) {
position = node.DefaultChild();
} else {
bool go_left = true;
if (common::IsCat(d_feature_types, position)) {
auto node_cats = categories.subspan(categories_segments[position].beg,
categories_segments[position].size);
go_left = common::Decision(node_cats, element);
} else {
go_left = element <= node.SplitCond();
}
if (go_left) {
position = node.LeftChild();
} else {
position = node.RightChild();
}
}
node = d_nodes[position];
}
return position;
}; // NOLINT }; // NOLINT
auto d_out_position = p_out_position->DeviceSpan(); if (!p_fmat->SingleColBlock()) {
row_partitioner->FinalisePosition(d_out_position, new_position_op); CHECK_EQ(row_partitioner_->GetNumNodes(), p_tree->NumNodes());
row_partitioner_->FinalisePosition(d_out_position, encode_op);
dh::CopyTo(d_out_position, &positions_);
return;
}
auto s_position = p_out_position->ConstDeviceSpan(); dh::caching_device_vector<uint32_t> categories;
positions.resize(s_position.size()); dh::CopyToD(p_tree->GetSplitCategories(), &categories);
dh::safe_cuda(cudaMemcpyAsync(positions.data().get(), s_position.data(), auto const& cat_segments = p_tree->GetSplitCategoriesPtr();
s_position.size_bytes(), cudaMemcpyDeviceToDevice, auto d_categories = dh::ToSpan(categories);
ctx_->CUDACtx()->Stream()));
dh::LaunchN(row_partitioner->GetRows().size(), [=] __device__(size_t idx) { auto d_matrix = page->GetDeviceAccessor(ctx_->Device());
bst_node_t position = d_out_position[idx];
bool is_row_sampled = d_gpair[idx].GetHess() - .0f == 0.f; std::vector<NodeSplitData> split_data(p_tree->NumNodes());
d_out_position[idx] = is_row_sampled ? ~position : position; auto const& tree = *p_tree;
for (std::size_t i = 0, n = split_data.size(); i < n; ++i) {
RegTree::Node split_node = tree[i];
auto split_type = p_tree->NodeSplitType(i);
auto node_cats = common::GetNodeCats(d_categories, cat_segments[i]);
split_data[i] = NodeSplitData{std::move(split_node), split_type, node_cats};
}
auto go_left_op = GoLeftOp{d_matrix};
dh::caching_device_vector<NodeSplitData> d_split_data;
dh::CopyToD(split_data, &d_split_data);
auto s_split_data = dh::ToSpan(d_split_data);
row_partitioner_->FinalisePosition(d_out_position,
[=] __device__(bst_idx_t row_id, bst_node_t nidx) {
auto split_data = s_split_data[nidx];
auto node = split_data.split_node;
while (!node.IsLeaf()) {
auto go_left = go_left_op(row_id, split_data);
nidx = go_left ? node.LeftChild() : node.RightChild();
node = s_split_data[nidx].split_node;
}
return encode_op(row_id, nidx);
}); });
dh::CopyTo(d_out_position, &positions_);
} }
bool UpdatePredictionCache(linalg::MatrixView<float> out_preds_d, RegTree const* p_tree) { bool UpdatePredictionCache(linalg::MatrixView<float> out_preds_d, RegTree const* p_tree) {
if (positions.empty()) { if (positions_.empty()) {
return false; return false;
} }
@ -491,20 +481,19 @@ struct GPUHistMakerDevice {
CHECK(out_preds_d.Device().IsCUDA()); CHECK(out_preds_d.Device().IsCUDA());
CHECK_EQ(out_preds_d.Device().ordinal, ctx_->Ordinal()); CHECK_EQ(out_preds_d.Device().ordinal, ctx_->Ordinal());
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); auto d_position = dh::ToSpan(positions_);
auto d_position = dh::ToSpan(positions);
CHECK_EQ(out_preds_d.Size(), d_position.size()); CHECK_EQ(out_preds_d.Size(), d_position.size());
auto const& h_nodes = p_tree->GetNodes(); // Use the nodes from tree, the leaf value might be changed by the objective since the
dh::caching_device_vector<RegTree::Node> nodes(h_nodes.size()); // last update tree call.
dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(), dh::caching_device_vector<RegTree::Node> nodes;
h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice, dh::CopyTo(p_tree->GetNodes(), &nodes);
ctx_->CUDACtx()->Stream())); common::Span<RegTree::Node> d_nodes = dh::ToSpan(nodes);
auto d_nodes = dh::ToSpan(nodes);
CHECK_EQ(out_preds_d.Shape(1), 1); CHECK_EQ(out_preds_d.Shape(1), 1);
dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(), dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(),
[=] XGBOOST_DEVICE(std::size_t idx) mutable { [=] XGBOOST_DEVICE(std::size_t idx) mutable {
bst_node_t nidx = d_position[idx]; bst_node_t nidx = d_position[idx];
nidx = SamplePosition::Decode(nidx);
auto weight = d_nodes[nidx].LeafValue(); auto weight = d_nodes[nidx].LeafValue();
out_preds_d(idx, 0) += weight; out_preds_d(idx, 0) += weight;
}); });
@ -512,7 +501,7 @@ struct GPUHistMakerDevice {
} }
// num histograms is the number of contiguous histograms in memory to reduce over // num histograms is the number of contiguous histograms in memory to reduce over
void AllReduceHist(int nidx, int num_histograms) { void AllReduceHist(bst_node_t nidx, int num_histograms) {
monitor.Start("AllReduce"); monitor.Start("AllReduce");
auto d_node_hist = hist.GetNodeHistogram(nidx).data(); auto d_node_hist = hist.GetNodeHistogram(nidx).data();
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT; using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
@ -529,7 +518,10 @@ struct GPUHistMakerDevice {
* \brief Build GPU local histograms for the left and right child of some parent node * \brief Build GPU local histograms for the left and right child of some parent node
*/ */
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, const RegTree& tree) { void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, const RegTree& tree) {
if (candidates.empty()) return; if (candidates.empty()) {
return;
}
this->monitor.Start(__func__);
// Some nodes we will manually compute histograms // Some nodes we will manually compute histograms
// others we will do by subtraction // others we will do by subtraction
std::vector<int> hist_nidx; std::vector<int> hist_nidx;
@ -572,14 +564,15 @@ struct GPUHistMakerDevice {
this->AllReduceHist(subtraction_trick_nidx, 1); this->AllReduceHist(subtraction_trick_nidx, 1);
} }
} }
this->monitor.Stop(__func__);
} }
void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) { void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
RegTree& tree = *p_tree; RegTree& tree = *p_tree;
// Sanity check - have we created a leaf with no training instances? // Sanity check - have we created a leaf with no training instances?
if (!collective::IsDistributed() && row_partitioner) { if (!collective::IsDistributed() && row_partitioner_) {
CHECK(row_partitioner->GetRows(candidate.nid).size() > 0) CHECK(row_partitioner_->GetRows(candidate.nid).size() > 0)
<< "No training instances in this leaf!"; << "No training instances in this leaf!";
} }
@ -659,6 +652,8 @@ struct GPUHistMakerDevice {
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task, void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) { RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
bool const is_single_block = p_fmat->SingleColBlock();
auto& tree = *p_tree; auto& tree = *p_tree;
// Process maximum 32 nodes at a time // Process maximum 32 nodes at a time
Driver<GPUExpandEntry> driver(param, 32); Driver<GPUExpandEntry> driver(param, 32);
@ -684,30 +679,29 @@ struct GPUHistMakerDevice {
[&](const auto& e) { return driver.IsChildValid(e); }); [&](const auto& e) { return driver.IsChildValid(e); });
auto new_candidates = auto new_candidates =
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry()); pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry{});
// Update all the nodes if working with external memory, this saves us from working
// with the finalize position call, which adds an additional iteration and requires
// special handling for row index.
this->UpdatePosition(is_single_block ? filtered_expand_set : expand_set, p_tree);
monitor.Start("UpdatePosition");
// Update position is only run when child is valid, instead of right after apply
// split (as in approx tree method). Hense we have the finalise position call
// in GPU Hist.
this->UpdatePosition(filtered_expand_set, p_tree);
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist");
this->BuildHistLeftRight(filtered_expand_set, tree); this->BuildHistLeftRight(filtered_expand_set, tree);
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
this->EvaluateSplits(filtered_expand_set, *p_tree, new_candidates); this->EvaluateSplits(filtered_expand_set, *p_tree, new_candidates);
monitor.Stop("EvaluateSplits");
dh::DefaultStream().Sync(); dh::DefaultStream().Sync();
driver.Push(new_candidates.begin(), new_candidates.end()); driver.Push(new_candidates.begin(), new_candidates.end());
expand_set = driver.Pop(); expand_set = driver.Pop();
} }
// Row partitioner can have lesser nodes than the tree since we skip some leaf
monitor.Start("FinalisePosition"); // nodes. These nodes are handled in the `FinalisePosition` call. However, a leaf can
// be spliable before evaluation but invalid after evaluation as we have more
// restrictions like min loss change after evalaution. Therefore, the check condition
// is greater than or equal to.
if (is_single_block) {
CHECK_GE(p_tree->NumNodes(), this->row_partitioner_->GetNumNodes());
}
this->FinalisePosition(p_tree, p_fmat, *task, p_out_position); this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
monitor.Stop("FinalisePosition");
} }
}; };
@ -767,12 +761,11 @@ class GPUHistMaker : public TreeUpdater {
SafeColl(rc); SafeColl(rc);
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed); this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
info_->feature_types.SetDevice(ctx_->Device()); info_->feature_types.SetDevice(ctx_->Device());
maker = std::make_unique<GPUHistMakerDevice>( maker = std::make_unique<GPUHistMakerDevice>(
ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_, ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_,
*param, column_sampler_, info_->num_col_, batch_param, dmat->Info()); *param, column_sampler_, info_->num_col_, HistBatch(*param), dmat->Info());
p_last_fmat_ = dmat; p_last_fmat_ = dmat;
initialised_ = true; initialised_ = true;
@ -798,14 +791,13 @@ class GPUHistMaker : public TreeUpdater {
maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position); maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
} }
bool UpdatePredictionCache(const DMatrix* data, bool UpdatePredictionCache(const DMatrix* data, linalg::MatrixView<float> p_out_preds) override {
linalg::MatrixView<bst_float> p_out_preds) override {
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false; return false;
} }
monitor_.Start("UpdatePredictionCache"); monitor_.Start(__func__);
bool result = maker->UpdatePredictionCache(p_out_preds, p_last_tree_); bool result = maker->UpdatePredictionCache(p_out_preds, p_last_tree_);
monitor_.Stop("UpdatePredictionCache"); monitor_.Stop(__func__);
return result; return result;
} }
@ -881,10 +873,9 @@ class GPUGlobalApproxMaker : public TreeUpdater {
auto const& info = p_fmat->Info(); auto const& info = p_fmat->Info();
info.feature_types.SetDevice(ctx_->Device()); info.feature_types.SetDevice(ctx_->Device());
auto batch = BatchParam{param->max_bin, hess, !task_->const_hess};
maker_ = std::make_unique<GPUHistMakerDevice>( maker_ = std::make_unique<GPUHistMakerDevice>(
ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_, ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_,
*param, column_sampler_, info.num_col_, batch, p_fmat->Info()); *param, column_sampler_, info.num_col_, ApproxBatch(*param, hess, *task_), p_fmat->Info());
std::size_t t_idx{0}; std::size_t t_idx{0};
for (xgboost::RegTree* tree : trees) { for (xgboost::RegTree* tree : trees) {
@ -927,14 +918,13 @@ class GPUGlobalApproxMaker : public TreeUpdater {
maker_->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position); maker_->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
} }
bool UpdatePredictionCache(const DMatrix* data, bool UpdatePredictionCache(const DMatrix* data, linalg::MatrixView<float> p_out_preds) override {
linalg::MatrixView<bst_float> p_out_preds) override {
if (maker_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { if (maker_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false; return false;
} }
monitor_.Start("UpdatePredictionCache"); monitor_.Start(__func__);
bool result = maker_->UpdatePredictionCache(p_out_preds, p_last_tree_); bool result = maker_->UpdatePredictionCache(p_out_preds, p_last_tree_);
monitor_.Stop("UpdatePredictionCache"); monitor_.Stop(__func__);
return result; return result;
} }

View File

@ -67,9 +67,9 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
h_batch_info.size() * sizeof(PerNodeData<int>), cudaMemcpyDefault, h_batch_info.size() * sizeof(PerNodeData<int>), cudaMemcpyDefault,
nullptr)); nullptr));
dh::device_vector<int8_t> tmp; dh::device_vector<int8_t> tmp;
SortPositionBatch<uint32_t, decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx), SortPositionBatch<decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
dh::ToSpan(ridx_tmp), dh::ToSpan(counts), dh::ToSpan(ridx_tmp), dh::ToSpan(counts), total_rows, op,
total_rows, op, &tmp); &tmp);
auto op_without_data = [=] __device__(auto ridx) { return ridx % 2 == 0; }; auto op_without_data = [=] __device__(auto ridx) { return ridx % 2 == 0; };
for (size_t i = 0; i < segments.size(); i++) { for (size_t i = 0; i < segments.size(); i++) {

View File

@ -1,10 +1,11 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/json.h> // for Json #include <xgboost/json.h> // for Json
#include <xgboost/tree_model.h> // for RegTree #include <xgboost/tree_model.h> // for RegTree
#include "../../../../src/common/categorical.h" // for CatBitField
#include "../../../../src/tree/hist/expand_entry.h" #include "../../../../src/tree/hist/expand_entry.h"
namespace xgboost::tree { namespace xgboost::tree {

View File

@ -5,7 +5,7 @@
#include <xgboost/base.h> // for Args #include <xgboost/base.h> // for Args
#include <xgboost/context.h> // for Context #include <xgboost/context.h> // for Context
#include <xgboost/host_device_vector.h> // for HostDeviceVector #include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/json.h> // for Jons #include <xgboost/json.h> // for Json
#include <xgboost/task.h> // for ObjInfo #include <xgboost/task.h> // for ObjInfo
#include <xgboost/tree_model.h> // for RegTree #include <xgboost/tree_model.h> // for RegTree
#include <xgboost/tree_updater.h> // for TreeUpdater #include <xgboost/tree_updater.h> // for TreeUpdater
@ -15,31 +15,16 @@
#include <vector> // for vector #include <vector> // for vector
#include "../../../src/common/random.h" // for GlobalRandom #include "../../../src/common/random.h" // for GlobalRandom
#include "../../../src/data/ellpack_page.h" // for EllpackPage
#include "../../../src/tree/param.h" // for TrainParam #include "../../../src/tree/param.h" // for TrainParam
#include "../collective/test_worker.h" // for BaseMGPUTest #include "../collective/test_worker.h" // for BaseMGPUTest
#include "../filesystem.h" // dmlc::TemporaryDirectory #include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h" #include "../helpers.h"
namespace xgboost::tree { namespace xgboost::tree {
void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat, namespace {
size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds, void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat, bool is_ext,
float subsample = 1.0f, const std::string& sampling_method = "uniform", RegTree* tree, HostDeviceVector<bst_float>* preds, float subsample,
int max_bin = 2) { const std::string& sampling_method, bst_bin_t max_bin) {
if (gpu_page_size > 0) {
// Loop over the batches and count the records
int64_t batch_count = 0;
int64_t row_count = 0;
for (const auto& batch : dmat->GetBatches<EllpackPage>(
ctx, BatchParam{max_bin, TrainParam::DftSparseThreshold()})) {
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
batch_count++;
row_count += batch.Size();
}
EXPECT_GE(batch_count, 2);
EXPECT_EQ(row_count, dmat->Info().num_row_);
}
Args args{ Args args{
{"max_depth", "2"}, {"max_depth", "2"},
{"max_bin", std::to_string(max_bin)}, {"max_bin", std::to_string(max_bin)},
@ -60,8 +45,13 @@ void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix
hist_maker->Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position}, hist_maker->Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
{tree}); {tree});
auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1); auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
hist_maker->UpdatePredictionCache(dmat, cache); if (subsample < 1.0 && is_ext) {
ASSERT_FALSE(hist_maker->UpdatePredictionCache(dmat, cache));
} else {
ASSERT_TRUE(hist_maker->UpdatePredictionCache(dmat, cache));
}
} }
} // anonymous namespace
TEST(GpuHist, UniformSampling) { TEST(GpuHist, UniformSampling) {
constexpr size_t kRows = 4096; constexpr size_t kRows = 4096;
@ -79,11 +69,11 @@ TEST(GpuHist, UniformSampling) {
RegTree tree; RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0)); HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
Context ctx(MakeCUDACtx(0)); Context ctx(MakeCUDACtx(0));
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows); UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
// Build another tree using sampling. // Build another tree using sampling.
RegTree tree_sampling; RegTree tree_sampling;
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0)); HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, "uniform", UpdateTree(&ctx, &gpair, dmat.get(), false, &tree_sampling, &preds_sampling, kSubsample, "uniform",
kRows); kRows);
// Make sure the predictions are the same. // Make sure the predictions are the same.
@ -110,12 +100,12 @@ TEST(GpuHist, GradientBasedSampling) {
RegTree tree; RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0)); HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
Context ctx(MakeCUDACtx(0)); Context ctx(MakeCUDACtx(0));
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows); UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
// Build another tree using sampling. // Build another tree using sampling.
RegTree tree_sampling; RegTree tree_sampling;
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0)); HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, UpdateTree(&ctx, &gpair, dmat.get(), false, &tree_sampling, &preds_sampling, kSubsample,
"gradient_based", kRows); "gradient_based", kRows);
// Make sure the predictions are the same. // Make sure the predictions are the same.
@ -147,11 +137,11 @@ TEST(GpuHist, ExternalMemory) {
// Build a tree using the in-memory DMatrix. // Build a tree using the in-memory DMatrix.
RegTree tree; RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0)); HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows); UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
// Build another tree using multiple ELLPACK pages. // Build another tree using multiple ELLPACK pages.
RegTree tree_ext; RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0)); HostDeviceVector<bst_float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, 1.0, "uniform", kRows); UpdateTree(&ctx, &gpair, dmat_ext.get(), true, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
// Make sure the predictions are the same. // Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector(); auto preds_h = preds.ConstHostVector();
@ -162,23 +152,26 @@ TEST(GpuHist, ExternalMemory) {
} }
TEST(GpuHist, ExternalMemoryWithSampling) { TEST(GpuHist, ExternalMemoryWithSampling) {
constexpr size_t kRows = 4096; constexpr size_t kRows = 4096, kCols = 2;
constexpr size_t kCols = 2;
constexpr size_t kPageSize = 1024;
constexpr float kSubsample = 0.5; constexpr float kSubsample = 0.5;
const std::string kSamplingMethod = "gradient_based"; const std::string kSamplingMethod = "gradient_based";
common::GlobalRandom().seed(0); common::GlobalRandom().seed(0);
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
Context ctx(MakeCUDACtx(0));
// Create a single batch DMatrix. // Create a single batch DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(kRows, kCols, 1, tmpdir.path + "/cache")); auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}
.Device(ctx.Device())
.Batches(1)
.GenerateSparsePageDMatrix("temp", true);
// Create a DMatrix with multiple batches. // Create a DMatrix with multiple batches.
std::unique_ptr<DMatrix> dmat_ext( auto p_fmat_ext = RandomDataGenerator{kRows, kCols, 0.0f}
CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache")); .Device(ctx.Device())
.Batches(4)
.GenerateSparsePageDMatrix("temp", true);
Context ctx(MakeCUDACtx(0));
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device()); linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
gpair.Data()->Copy(GenerateRandomGradients(kRows)); gpair.Data()->Copy(GenerateRandomGradients(kRows));
@ -187,13 +180,13 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
RegTree tree; RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0)); HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod, kRows); UpdateTree(&ctx, &gpair, p_fmat.get(), true, &tree, &preds, kSubsample, kSamplingMethod, kRows);
// Build another tree using multiple ELLPACK pages. // Build another tree using multiple ELLPACK pages.
common::GlobalRandom() = rng; common::GlobalRandom() = rng;
RegTree tree_ext; RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0)); HostDeviceVector<bst_float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample, UpdateTree(&ctx, &gpair, p_fmat_ext.get(), true, &tree_ext, &preds_ext, kSubsample,
kSamplingMethod, kRows); kSamplingMethod, kRows);
// Make sure the predictions are the same. // Make sure the predictions are the same.