[EM] Enable prediction cache for GPU. (#10707)
- Use `UpdatePosition` for all nodes and skip `FinalizePosition` when external memory is used.
- Create `encode`/`decode` helpers for the node position; this is purely a refactor.
- Reuse code between position update and finalization.
parent 0def8e0bae
commit 582ea104b5
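The encode/decode refactor mentioned above boils down to one small rule: a row that is not used in the current iteration (for example, dropped by subsampling) stores the bitwise complement of its node index, so validity is a sign check and the original index is always recoverable. A minimal standalone sketch mirroring the `SamplePosition` helper this commit adds in `src/tree/sample_position.h` (the `main` function is illustrative only):

#include <cassert>
#include <cstdint>

using bst_node_t = std::int32_t;

// Mirrors the SamplePosition helper added by this commit: an invalid (unsampled)
// row carries the bitwise complement of its node index, so valid indices stay
// non-negative and the original index is always recoverable.
struct SamplePosition {
  static bst_node_t Encode(bst_node_t nidx, bool is_valid) { return is_valid ? nidx : ~nidx; }
  static bool IsValid(bst_node_t nidx) { return nidx >= 0; }
  static bst_node_t Decode(bst_node_t nidx) { return IsValid(nidx) ? nidx : ~nidx; }
};

int main() {
  bst_node_t unsampled = SamplePosition::Encode(3, /*is_valid=*/false);  // ~3 == -4
  assert(!SamplePosition::IsValid(unsampled));
  assert(SamplePosition::Decode(unsampled) == 3);  // round-trips back to the node index
  bst_node_t sampled = SamplePosition::Encode(3, /*is_valid=*/true);
  assert(SamplePosition::IsValid(sampled) && SamplePosition::Decode(sampled) == 3);
  return 0;
}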
@ -1,20 +1,17 @@
/**
 * Copyright 2020-2023, XGBoost Contributors
 * Copyright 2020-2024, XGBoost Contributors
 * \file categorical.h
 */
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
#define XGBOOST_COMMON_CATEGORICAL_H_

#include <limits>

#include "bitfield.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/span.h"
#include "xgboost/tree_model.h"

namespace xgboost {
namespace common {

namespace xgboost::common {
using CatBitField = LBitField32;
using KCatBitField = CLBitField32;

@ -94,7 +91,12 @@ XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot
struct IsCatOp {
  XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
};
} // namespace common
} // namespace xgboost

inline auto GetNodeCats(common::Span<CatBitField::value_type const> categories,
                        RegTree::CategoricalSplitMatrix::Segment seg) {
  KCatBitField node_cats{categories.subspan(seg.beg, seg.size)};
  return node_cats;
}
} // namespace xgboost::common

#endif // XGBOOST_COMMON_CATEGORICAL_H_

@ -16,12 +16,9 @@
#include <cstddef> // for size_t
#include <cub/cub.cuh>
#include <cub/util_type.cuh> // for UnitWord
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#include "../collective/communicator-inl.h"
#include "common.h"
#include "device_vector.cuh"
#include "xgboost/host_device_vector.h"
@ -375,19 +372,24 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
                            cudaMemcpyDeviceToHost));
}

template <class HContainer, class DContainer>
void CopyToD(HContainer const &h, DContainer *d) {
  if (h.empty()) {
    d->clear();
template <class Src, class Dst>
void CopyTo(Src const &src, Dst *dst) {
  if (src.empty()) {
    dst->clear();
    return;
  }
  d->resize(h.size());
  using HVT = std::remove_cv_t<typename HContainer::value_type>;
  using DVT = std::remove_cv_t<typename DContainer::value_type>;
  static_assert(std::is_same<HVT, DVT>::value,
  dst->resize(src.size());
  using SVT = std::remove_cv_t<typename Src::value_type>;
  using DVT = std::remove_cv_t<typename Dst::value_type>;
  static_assert(std::is_same<SVT, DVT>::value,
                "Host and device containers must have same value type.");
  dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT),
                                cudaMemcpyHostToDevice));
  dh::safe_cuda(cudaMemcpyAsync(thrust::raw_pointer_cast(dst->data()), src.data(),
                                src.size() * sizeof(SVT), cudaMemcpyDefault));
}

template <class HContainer, class DContainer>
void CopyToD(HContainer const &h, DContainer *d) {
  CopyTo(h, d);
}

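The generalized `CopyTo` above uses `cudaMemcpyDefault`, so the copy direction is inferred from the pointer types and the old host-to-device `CopyToD` becomes a thin wrapper. A hedged usage sketch (the `Example` function is hypothetical; `dh::CopyTo` is the helper defined above):

#include <thrust/device_vector.h>

#include <vector>

#include "device_helpers.cuh"  // for dh::CopyTo (the helper defined above)

void Example() {
  std::vector<int> host{1, 2, 3};
  thrust::device_vector<int> device;
  // Host -> device: the old dh::CopyToD now simply forwards to this call.
  dh::CopyTo(host, &device);
  // Later in this diff the same helper copies a device span into a device
  // vector, e.g. dh::CopyTo(d_out_position, &positions_), relying on
  // cudaMemcpyDefault to pick the right direction.
}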
// Keep track of pinned memory allocation
|
||||
|
||||
@ -307,6 +307,7 @@ class DeviceUVector {
|
||||
|
||||
public:
|
||||
DeviceUVector() = default;
|
||||
explicit DeviceUVector(std::size_t n) { this->resize(n); }
|
||||
DeviceUVector(DeviceUVector const &that) = delete;
|
||||
DeviceUVector &operator=(DeviceUVector const &that) = delete;
|
||||
DeviceUVector(DeviceUVector &&that) = default;
|
||||
@ -330,7 +331,17 @@ class DeviceUVector {
|
||||
data_.resize(n, v);
|
||||
#endif
|
||||
}
|
||||
|
||||
void clear() { // NOLINT
|
||||
#if defined(XGBOOST_USE_RMM)
|
||||
this->data_.resize(0, rmm::cuda_stream_per_thread);
|
||||
#else
|
||||
this->data_.clear();
|
||||
#endif // defined(XGBOOST_USE_RMM)
|
||||
}
|
||||
|
||||
[[nodiscard]] std::size_t size() const { return data_.size(); } // NOLINT
|
||||
[[nodiscard]] bool empty() const { return this->size() == 0; } // NOLINT
|
||||
|
||||
[[nodiscard]] auto begin() { return data_.begin(); } // NOLINT
|
||||
[[nodiscard]] auto end() { return data_.end(); } // NOLINT
|
||||
|
||||
@ -20,6 +20,7 @@
#include "column_matrix.h"
#include "xgboost/context.h"
#include "xgboost/tree_model.h"
#include "../tree/sample_position.h" // for SamplePosition

namespace xgboost::common {
// The builder is required for samples partition to left and rights children for set of nodes
@ -364,13 +365,14 @@ class PartitionBuilder {
  }

  // Copy row partitions into global cache for reuse in objective
  template <typename Sampledp>
  template <typename Invalidp>
  void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
                     std::vector<bst_node_t>* p_position, Sampledp sampledp) const {
                     std::vector<bst_node_t>* p_position, Invalidp invalidp) const {
    auto& h_pos = *p_position;
    h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());

    auto p_begin = row_set.Data()->data();
    // For each node, walk through all the samples that fall in this node.
    ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
      auto const& node = row_set[i];
      if (node.node_id < 0) {
@ -381,7 +383,7 @@ class PartitionBuilder {
      size_t ptr_offset = node.end() - p_begin;
      CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
      for (auto idx = node.begin(); idx != node.end(); ++idx) {
        h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id;
        h_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx));
      }
    }
    });

@ -14,6 +14,7 @@
|
||||
|
||||
#include "../collective/allgather.h"
|
||||
#include "../collective/allreduce.h"
|
||||
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank
|
||||
#include "categorical.h"
|
||||
#include "common.h"
|
||||
#include "device_helpers.cuh"
|
||||
|
||||
@ -6,6 +6,8 @@
|
||||
|
||||
#include <thrust/binary_search.h>
|
||||
|
||||
#include <limits> // for numeric_limits
|
||||
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/compressed_iterator.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
@ -21,22 +23,26 @@ namespace xgboost {
|
||||
* Does not own underlying memory and may be trivially copied into kernels.
|
||||
*/
|
||||
struct EllpackDeviceAccessor {
|
||||
/*! \brief Whether or not if the matrix is dense. */
|
||||
/** @brief Whether or not if the matrix is dense. */
|
||||
bool is_dense;
|
||||
/*! \brief Row length for ELLPACK, equal to number of features. */
|
||||
/** @brief Row length for ELLPACK, equal to number of features when the data is dense. */
|
||||
bst_idx_t row_stride;
|
||||
bst_idx_t base_rowid{0};
|
||||
bst_idx_t n_rows{0};
|
||||
/** @brief Starting index of the rows. Used for external memory. */
|
||||
bst_idx_t base_rowid;
|
||||
/** @brief Number of rows in this batch. */
|
||||
bst_idx_t n_rows;
|
||||
/** @brief Acessor for the gradient index. */
|
||||
common::CompressedIterator<std::uint32_t> gidx_iter;
|
||||
/*! \brief Minimum value for each feature. Size equals to number of features. */
|
||||
/** @brief Minimum value for each feature. Size equals to number of features. */
|
||||
common::Span<const float> min_fvalue;
|
||||
/*! \brief Histogram cut pointers. Size equals to (number of features + 1). */
|
||||
/** @brief Histogram cut pointers. Size equals to (number of features + 1). */
|
||||
common::Span<const std::uint32_t> feature_segments;
|
||||
/*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */
|
||||
/** @brief Histogram cut values. Size equals to (bins per feature * number of features). */
|
||||
common::Span<const float> gidx_fvalue_map;
|
||||
|
||||
/** @brief Type of each feature, categorical or numerical. */
|
||||
common::Span<const FeatureType> feature_types;
|
||||
|
||||
EllpackDeviceAccessor() = delete;
|
||||
EllpackDeviceAccessor(DeviceOrd device, std::shared_ptr<const common::HistogramCuts> cuts,
|
||||
bool is_dense, size_t row_stride, size_t base_rowid, size_t n_rows,
|
||||
common::CompressedIterator<uint32_t> gidx_iter,
|
||||
@ -108,10 +114,10 @@ struct EllpackDeviceAccessor {
|
||||
return idx;
|
||||
}
|
||||
|
||||
[[nodiscard]] __device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
|
||||
[[nodiscard]] __device__ float GetFvalue(size_t ridx, size_t fidx) const {
|
||||
auto gidx = GetBinIndex(ridx, fidx);
|
||||
if (gidx == -1) {
|
||||
return nan("");
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
return gidx_fvalue_map[gidx];
|
||||
}
|
||||
|
||||
@ -3,18 +3,18 @@
|
||||
*/
|
||||
#include "adaptive.h"
|
||||
|
||||
#include <algorithm> // std::transform,std::find_if,std::copy,std::unique
|
||||
#include <cmath> // std::isnan
|
||||
#include <cstddef> // std::size_t
|
||||
#include <iterator> // std::distance
|
||||
#include <vector> // std::vector
|
||||
#include <algorithm> // std::transform,std::find_if,std::copy,std::unique
|
||||
#include <cmath> // std::isnan
|
||||
#include <cstddef> // std::size_t
|
||||
#include <iterator> // std::distance
|
||||
#include <vector> // std::vector
|
||||
|
||||
#include "../common/algorithm.h" // ArgSort
|
||||
#include "../common/common.h" // AssertGPUSupport
|
||||
#include "../common/numeric.h" // RunLengthEncode
|
||||
#include "../common/stats.h" // Quantile,WeightedQuantile
|
||||
#include "../common/threading_utils.h" // ParallelFor
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "../tree/sample_position.h" // for SamplePosition
|
||||
#include "xgboost/base.h" // bst_node_t
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/data.h" // MetaInfo
|
||||
@ -23,6 +23,10 @@
|
||||
#include "xgboost/span.h" // Span
|
||||
#include "xgboost/tree_model.h" // RegTree
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
#include "../common/common.h" // AssertGPUSupport
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
namespace xgboost::obj::detail {
|
||||
void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
|
||||
std::vector<bst_node_t> const& position, std::vector<size_t>* p_nptr,
|
||||
@ -37,9 +41,10 @@ void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
|
||||
sorted_pos[i] = position[ridx[i]];
|
||||
}
|
||||
// find the first non-sampled row
|
||||
size_t begin_pos =
|
||||
std::distance(sorted_pos.cbegin(), std::find_if(sorted_pos.cbegin(), sorted_pos.cend(),
|
||||
[](bst_node_t nidx) { return nidx >= 0; }));
|
||||
size_t begin_pos = std::distance(
|
||||
sorted_pos.cbegin(),
|
||||
std::find_if(sorted_pos.cbegin(), sorted_pos.cend(),
|
||||
[](bst_node_t nidx) { return tree::SamplePosition::IsValid(nidx); }));
|
||||
CHECK_LE(begin_pos, sorted_pos.size());
|
||||
|
||||
std::vector<bst_node_t> leaf;
|
||||
|
||||
@ -3,13 +3,14 @@
|
||||
*/
|
||||
#include <thrust/sort.h>
|
||||
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cub/cub.cuh> // NOLINT
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cub/cub.cuh> // NOLINT
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/stats.cuh"
|
||||
#include "../tree/sample_position.h" // for SamplePosition
|
||||
#include "adaptive.h"
|
||||
#include "xgboost/context.h"
|
||||
|
||||
@ -30,10 +31,12 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
||||
// sort row index according to node index
|
||||
thrust::stable_sort_by_key(cuctx->TP(), sorted_position.begin(),
|
||||
sorted_position.begin() + n_samples, p_ridx->begin());
|
||||
size_t beg_pos =
|
||||
thrust::find_if(cuctx->CTP(), sorted_position.cbegin(), sorted_position.cend(),
|
||||
[] XGBOOST_DEVICE(bst_node_t nidx) { return nidx >= 0; }) -
|
||||
sorted_position.cbegin();
|
||||
// Find the first one that's not sampled (nidx not been negated).
|
||||
size_t beg_pos = thrust::find_if(cuctx->CTP(), sorted_position.cbegin(), sorted_position.cend(),
|
||||
[] XGBOOST_DEVICE(bst_node_t nidx) {
|
||||
return tree::SamplePosition::IsValid(nidx);
|
||||
}) -
|
||||
sorted_position.cbegin();
|
||||
if (beg_pos == sorted_position.size()) {
|
||||
auto& leaf = p_nidx->HostVector();
|
||||
tree.WalkTree([&](bst_node_t nidx) {
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
/**
|
||||
* Copyright 2020-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <algorithm> // std::max
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <algorithm> // for :max
|
||||
#include <limits> // for numeric_limits
|
||||
|
||||
#include "../../collective/allgather.h"
|
||||
#include "../../collective/communicator-inl.h" // for GetWorldSize, GetRank
|
||||
#include "../../common/categorical.h"
|
||||
#include "../../data/ellpack_page.cuh"
|
||||
#include "evaluate_splits.cuh"
|
||||
#include "expand_entry.cuh"
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t ba
|
||||
ridx_.resize(n_samples);
|
||||
ridx_tmp_.resize(n_samples);
|
||||
tmp_.clear();
|
||||
n_nodes_ = 1; // Root
|
||||
|
||||
CHECK_LE(n_samples, std::numeric_limits<cuda_impl::RowIndexT>::max());
|
||||
ridx_segments_.emplace_back(
|
||||
|
||||
@ -19,7 +19,9 @@
|
||||
namespace xgboost::tree {
|
||||
namespace cuda_impl {
|
||||
using RowIndexT = std::uint32_t;
|
||||
}
|
||||
// TODO(Rory): Can be larger. To be tuned alongside other batch operations.
|
||||
static const std::int32_t kMaxUpdatePositionBatchSize = 32;
|
||||
} // namespace cuda_impl
|
||||
|
||||
/**
|
||||
* @brief Used to demarcate a contiguous set of row indices associated with some tree
|
||||
@ -37,8 +39,6 @@ struct Segment {
|
||||
__host__ __device__ bst_idx_t Size() const { return end - begin; }
|
||||
};
|
||||
|
||||
// TODO(Rory): Can be larger. To be tuned alongside other batch operations.
|
||||
static const int kMaxUpdatePositionBatchSize = 32;
|
||||
template <typename OpDataT>
|
||||
struct PerNodeData {
|
||||
Segment segment;
|
||||
@ -46,10 +46,10 @@ struct PerNodeData {
|
||||
};
|
||||
|
||||
template <typename BatchIterT>
|
||||
__device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
|
||||
int* batch_idx, std::size_t* item_idx) {
|
||||
XGBOOST_DEV_INLINE void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
|
||||
int* batch_idx, std::size_t* item_idx) {
|
||||
cuda_impl::RowIndexT sum = 0;
|
||||
for (int i = 0; i < kMaxUpdatePositionBatchSize; i++) {
|
||||
for (int i = 0; i < cuda_impl::kMaxUpdatePositionBatchSize; i++) {
|
||||
if (sum + batch_info[i].segment.Size() > global_thread_idx) {
|
||||
*batch_idx = i;
|
||||
*item_idx = (global_thread_idx - sum) + batch_info[i].segment.begin;
|
||||
@ -59,10 +59,10 @@ __device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t g
|
||||
}
|
||||
}
|
||||
|
||||
template <int kBlockSize, typename RowIndexT, typename OpDataT>
|
||||
template <int kBlockSize, typename OpDataT>
|
||||
__global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
|
||||
dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<RowIndexT> d_ridx,
|
||||
const common::Span<const RowIndexT> ridx_tmp, std::size_t total_rows) {
|
||||
dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<cuda_impl::RowIndexT> d_ridx,
|
||||
const common::Span<const cuda_impl::RowIndexT> ridx_tmp, bst_idx_t total_rows) {
|
||||
for (auto idx : dh::GridStrideRange<std::size_t>(0, total_rows)) {
|
||||
int batch_idx;
|
||||
std::size_t item_idx;
|
||||
@ -92,6 +92,7 @@ struct IndexFlagOp {
|
||||
}
|
||||
};
|
||||
|
||||
// Scatter from `ridx_in` to `ridx_out`.
|
||||
template <typename OpDataT>
|
||||
struct WriteResultsFunctor {
|
||||
dh::LDGIterator<PerNodeData<OpDataT>> batch_info;
|
||||
@ -99,10 +100,12 @@ struct WriteResultsFunctor {
|
||||
cuda_impl::RowIndexT* ridx_out;
|
||||
cuda_impl::RowIndexT* counts;
|
||||
|
||||
__device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
|
||||
std::size_t scatter_address;
|
||||
__device__ IndexFlagTuple operator()(IndexFlagTuple const& x) {
|
||||
cuda_impl::RowIndexT scatter_address;
|
||||
// Get the segment that this row belongs to.
|
||||
const Segment& segment = batch_info[x.batch_idx].segment;
|
||||
if (x.flag) {
|
||||
// Go left.
|
||||
cuda_impl::RowIndexT num_previous_flagged = x.flag_scan - 1; // -1 because inclusive scan
|
||||
scatter_address = segment.begin + num_previous_flagged;
|
||||
} else {
|
||||
@ -121,10 +124,14 @@ struct WriteResultsFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename RowIndexT, typename OpT, typename OpDataT>
|
||||
/**
|
||||
* @param d_batch_info Node data, with the size of the input number of nodes.
|
||||
*/
|
||||
template <typename OpT, typename OpDataT>
|
||||
void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
|
||||
common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
|
||||
common::Span<cuda_impl::RowIndexT> d_counts, std::size_t total_rows, OpT op,
|
||||
common::Span<cuda_impl::RowIndexT> ridx,
|
||||
common::Span<cuda_impl::RowIndexT> ridx_tmp,
|
||||
common::Span<cuda_impl::RowIndexT> d_counts, bst_idx_t total_rows, OpT op,
|
||||
dh::device_vector<int8_t>* tmp) {
|
||||
dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
|
||||
WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
|
||||
@ -134,22 +141,23 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
|
||||
thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
auto input_iterator =
|
||||
dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(size_t idx) {
|
||||
int batch_idx;
|
||||
dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(std::size_t idx) {
|
||||
int nidx_in_batch;
|
||||
std::size_t item_idx;
|
||||
AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
|
||||
auto op_res = op(ridx[item_idx], batch_idx, batch_info_itr[batch_idx].data);
|
||||
return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), op_res, batch_idx, op_res};
|
||||
AssignBatch(batch_info_itr, idx, &nidx_in_batch, &item_idx);
|
||||
auto go_left = op(ridx[item_idx], nidx_in_batch, batch_info_itr[nidx_in_batch].data);
|
||||
return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), go_left, nidx_in_batch,
|
||||
go_left};
|
||||
});
|
||||
size_t temp_bytes = 0;
|
||||
std::size_t temp_bytes = 0;
|
||||
if (tmp->empty()) {
|
||||
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
|
||||
IndexFlagOp(), total_rows);
|
||||
IndexFlagOp{}, total_rows);
|
||||
tmp->resize(temp_bytes);
|
||||
}
|
||||
temp_bytes = tmp->size();
|
||||
cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
|
||||
discard_write_iterator, IndexFlagOp(), total_rows);
|
||||
discard_write_iterator, IndexFlagOp{}, total_rows);
|
||||
|
||||
constexpr int kBlockSize = 256;
|
||||
|
||||
@ -157,7 +165,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
|
||||
const int kItemsThread = 12;
|
||||
const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread);
|
||||
|
||||
SortPositionCopyKernel<kBlockSize, RowIndexT, OpDataT>
|
||||
SortPositionCopyKernel<kBlockSize, OpDataT>
|
||||
<<<grid_size, kBlockSize, 0>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
|
||||
}
|
||||
|
||||
@ -168,8 +176,8 @@ struct NodePositionInfo {
|
||||
__device__ bool IsLeaf() { return left_child == -1; }
|
||||
};
|
||||
|
||||
__device__ __forceinline__ int GetPositionFromSegments(std::size_t idx,
|
||||
const NodePositionInfo* d_node_info) {
|
||||
XGBOOST_DEV_INLINE int GetPositionFromSegments(std::size_t idx,
|
||||
const NodePositionInfo* d_node_info) {
|
||||
int position = 0;
|
||||
NodePositionInfo node = d_node_info[position];
|
||||
while (!node.IsLeaf()) {
|
||||
@ -205,7 +213,6 @@ __global__ __launch_bounds__(kBlockSize) void FinalisePositionKernel(
|
||||
class RowPartitioner {
|
||||
public:
|
||||
using RowIndexT = cuda_impl::RowIndexT;
|
||||
static constexpr bst_node_t kIgnoredTreePosition = -1;
|
||||
|
||||
private:
|
||||
/**
|
||||
@ -232,6 +239,7 @@ class RowPartitioner {
|
||||
dh::device_vector<int8_t> tmp_;
|
||||
dh::PinnedMemory pinned_;
|
||||
dh::PinnedMemory pinned2_;
|
||||
bst_node_t n_nodes_{0}; // Counter for internal checks.
|
||||
|
||||
public:
|
||||
/**
|
||||
@ -255,6 +263,7 @@ class RowPartitioner {
|
||||
* \brief Gets all training rows in the set.
|
||||
*/
|
||||
common::Span<const RowIndexT> GetRows();
|
||||
[[nodiscard]] bst_node_t GetNumNodes() const { return n_nodes_; }
|
||||
|
||||
/**
|
||||
* \brief Convenience method for testing
|
||||
@ -280,10 +289,14 @@ class RowPartitioner {
|
||||
const std::vector<bst_node_t>& left_nidx,
|
||||
const std::vector<bst_node_t>& right_nidx,
|
||||
const std::vector<OpDataT>& op_data, UpdatePositionOpT op) {
|
||||
if (nidx.empty()) return;
|
||||
if (nidx.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
CHECK_EQ(nidx.size(), left_nidx.size());
|
||||
CHECK_EQ(nidx.size(), right_nidx.size());
|
||||
CHECK_EQ(nidx.size(), op_data.size());
|
||||
this->n_nodes_ += (left_nidx.size() + right_nidx.size());
|
||||
|
||||
auto h_batch_info = pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
|
||||
dh::TemporaryArray<PerNodeData<OpDataT>> d_batch_info(nidx.size());
|
||||
@ -302,9 +315,9 @@ class RowPartitioner {
|
||||
dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0);
|
||||
|
||||
// Partition the rows according to the operator
|
||||
SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
|
||||
dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
|
||||
total_rows, op, &tmp_);
|
||||
SortPositionBatch<UpdatePositionOpT, OpDataT>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx_),
|
||||
dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
|
||||
total_rows, op, &tmp_);
|
||||
dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
|
||||
cudaMemcpyDefault));
|
||||
// TODO(Rory): this synchronisation hurts performance a lot
|
||||
@ -327,20 +340,16 @@
  }

  /**
   * \brief Finalise the position of all training instances after tree construction is
   * @brief Finalise the position of all training instances after tree construction is
   * complete. Does not update any other meta information in this data structure, so
   * should only be used at the end of training.
   *
   * When the task requires update leaf, this function will copy the node index into
   * p_out_position. The index is negated if it's being sampled in current iteration.
   *
   * \param p_out_position Node index for each row.
   * \param op Device lambda. Should provide the row index and current position as an
   * @param p_out_position Node index for each row.
   * @param op Device lambda. Should provide the row index and current position as an
   * argument and return the new position for this training instance.
   * \param sampled A device lambda to inform the partitioner whether a row is sampled.
   */
  template <typename FinalisePositionOpT>
  void FinalisePosition(common::Span<bst_node_t> d_out_position, FinalisePositionOpT op) {
  void FinalisePosition(common::Span<bst_node_t> d_out_position, FinalisePositionOpT op) const {
    dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
    dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
                                  sizeof(NodePositionInfo) * ridx_segments_.size(),

@ -10,14 +10,11 @@
#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <string>
#include <vector>

#include "../common/categorical.h"
#include "../common/linalg_op.h"
#include "../common/math.h"
#include "xgboost/data.h"
#include "xgboost/linalg.h"
#include "xgboost/parameter.h"

src/tree/sample_position.h (new file, 21 lines)
@ -0,0 +1,21 @@
/**
 * Copyright 2024, XGBoost Contributors
 */
#pragma once
#include "xgboost/base.h" // for bst_node_t

namespace xgboost::tree {
// Utility for maniputing the node index. This is used by the tree methods and the
// adaptive objectives to share the node index. A row is invalid if it's not used in the
// last iteration (due to sampling). For these rows, the corresponding tree node index is
// negated.
struct SamplePosition {
  [[nodiscard]] bst_node_t static XGBOOST_HOST_DEV_INLINE Encode(bst_node_t nidx, bool is_valid) {
    return is_valid ? nidx : ~nidx;
  }
  [[nodiscard]] bst_node_t static XGBOOST_HOST_DEV_INLINE Decode(bst_node_t nidx) {
    return IsValid(nidx) ? nidx : ~nidx;
  }
  [[nodiscard]] bool static XGBOOST_HOST_DEV_INLINE IsValid(bst_node_t nidx) { return nidx >= 0; }
};
} // namespace xgboost::tree
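For context on how this encoding is consumed: the GPU prediction-cache update later in this diff decodes each cached position before adding the corresponding leaf weight. A simplified host-side sketch of that pattern (the real code is a CUDA kernel over `d_position` and `d_nodes`; the names `Node`, `Decode`, and `UpdateCache` here are illustrative):

#include <cstddef>
#include <cstdint>
#include <vector>

using bst_node_t = std::int32_t;

struct Node {
  float leaf_value;
};

// Same rule as SamplePosition::Decode above: a negated index marks a row that was
// not part of the last iteration, and ~nidx recovers the node index either way.
inline bst_node_t Decode(bst_node_t nidx) { return nidx >= 0 ? nidx : ~nidx; }

// Host-side stand-in for the device loop: decode the cached position of each row,
// then accumulate the corresponding leaf weight into the cached prediction.
void UpdateCache(std::vector<bst_node_t> const& positions, std::vector<Node> const& nodes,
                 std::vector<float>* out_preds) {
  for (std::size_t i = 0; i < positions.size(); ++i) {
    bst_node_t nidx = Decode(positions[i]);
    (*out_preds)[i] += nodes[nidx].leaf_value;
  }
}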
@ -8,15 +8,14 @@
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
|
||||
#include <array> // for array
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/common.h" // for EscapeU8
|
||||
#include "../common/categorical.h" // for GetNodeCats
|
||||
#include "../common/common.h" // for EscapeU8
|
||||
#include "../predictor/predict_fn.h"
|
||||
#include "io_utils.h" // for GetElem
|
||||
#include "param.h"
|
||||
@ -1038,9 +1037,8 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
|
||||
categories_nodes.GetArray().emplace_back(i);
|
||||
auto begin = categories.Size();
|
||||
categories_segments.GetArray().emplace_back(begin);
|
||||
auto segment = split_categories_segments_[i];
|
||||
auto node_categories = this->GetSplitCategories().subspan(segment.beg, segment.size);
|
||||
common::KCatBitField const cat_bits(node_categories);
|
||||
auto segment = this->split_categories_segments_[i];
|
||||
auto cat_bits = common::GetNodeCats(this->GetSplitCategories(), segment);
|
||||
for (size_t i = 0; i < cat_bits.Capacity(); ++i) {
|
||||
if (cat_bits.Check(i)) {
|
||||
categories.GetArray().emplace_back(i);
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
|
||||
#include "../common/error_msg.h" // for NoCategorical
|
||||
#include "../common/random.h"
|
||||
#include "sample_position.h" // for SamplePosition
|
||||
#include "constraints.h"
|
||||
#include "param.h"
|
||||
#include "split_evaluator.h"
|
||||
@ -515,7 +516,7 @@ class ColMaker: public TreeUpdater {
|
||||
common::ParallelFor(p_fmat->Info().num_row_, this->ctx_->Threads(), [&](auto ridx) {
|
||||
CHECK_LT(ridx, position_.size()) << "ridx exceed bound "
|
||||
<< "ridx=" << ridx << " pos=" << position_.size();
|
||||
const int nid = this->DecodePosition(ridx);
|
||||
const int nid = SamplePosition::Decode(position_[ridx]);
|
||||
if (tree[nid].IsLeaf()) {
|
||||
// mark finish when it is not a fresh leaf
|
||||
if (tree[nid].RightChild() == -1) {
|
||||
@ -560,14 +561,14 @@ class ColMaker: public TreeUpdater {
|
||||
auto col = page[fid];
|
||||
common::ParallelFor(col.size(), this->ctx_->Threads(), [&](auto j) {
|
||||
const bst_uint ridx = col[j].index;
|
||||
const int nid = this->DecodePosition(ridx);
|
||||
bst_node_t nidx = SamplePosition::Decode(position_[ridx]);
|
||||
const bst_float fvalue = col[j].fvalue;
|
||||
// go back to parent, correct those who are not default
|
||||
if (!tree[nid].IsLeaf() && tree[nid].SplitIndex() == fid) {
|
||||
if (fvalue < tree[nid].SplitCond()) {
|
||||
this->SetEncodePosition(ridx, tree[nid].LeftChild());
|
||||
if (!tree[nidx].IsLeaf() && tree[nidx].SplitIndex() == fid) {
|
||||
if (fvalue < tree[nidx].SplitCond()) {
|
||||
this->SetEncodePosition(ridx, tree[nidx].LeftChild());
|
||||
} else {
|
||||
this->SetEncodePosition(ridx, tree[nid].RightChild());
|
||||
this->SetEncodePosition(ridx, tree[nidx].RightChild());
|
||||
}
|
||||
}
|
||||
});
|
||||
@ -576,17 +577,10 @@ class ColMaker: public TreeUpdater {
|
||||
}
|
||||
// utils to get/set position, with encoded format
|
||||
// return decoded position
|
||||
inline int DecodePosition(bst_uint ridx) const {
|
||||
const int pid = position_[ridx];
|
||||
return pid < 0 ? ~pid : pid;
|
||||
}
|
||||
// encode the encoded position value for ridx
|
||||
inline void SetEncodePosition(bst_uint ridx, int nid) {
|
||||
if (position_[ridx] < 0) {
|
||||
position_[ridx] = ~nid;
|
||||
} else {
|
||||
position_[ridx] = nid;
|
||||
}
|
||||
void SetEncodePosition(bst_idx_t ridx, bst_node_t nidx) {
|
||||
bool is_invalid = position_[ridx] < 0;
|
||||
position_[ridx] = SamplePosition::Encode(nidx, !is_invalid);
|
||||
}
|
||||
// --data fields--
|
||||
const TrainParam& param_;
|
||||
|
||||
@ -6,8 +6,9 @@
|
||||
#include <ostream> // for ostream
|
||||
|
||||
#include "gpu_hist/histogram.cuh"
|
||||
#include "param.h"
|
||||
#include "param.h" // for TrainParam
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
|
||||
namespace xgboost::tree {
|
||||
struct GPUTrainingParam {
|
||||
@ -117,6 +118,21 @@ struct DeviceSplitCandidate {
|
||||
}
|
||||
};
|
||||
|
||||
namespace cuda_impl {
|
||||
inline BatchParam HistBatch(TrainParam const& param) {
|
||||
return {param.max_bin, TrainParam::DftSparseThreshold()};
|
||||
}
|
||||
|
||||
inline BatchParam HistBatch(bst_bin_t max_bin) {
|
||||
return {max_bin, TrainParam::DftSparseThreshold()};
|
||||
}
|
||||
|
||||
inline BatchParam ApproxBatch(TrainParam const& p, common::Span<float const> hess,
|
||||
ObjInfo const& task) {
|
||||
return BatchParam{p.max_bin, hess, !task.const_hess};
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
|
||||
template <typename T>
|
||||
struct SumCallbackOp {
|
||||
// Running prefix
|
||||
|
||||
@ -34,7 +34,8 @@
|
||||
#include "gpu_hist/row_partitioner.cuh"
|
||||
#include "hist/param.h"
|
||||
#include "param.h"
|
||||
#include "updater_gpu_common.cuh"
|
||||
#include "sample_position.h" // for SamplePosition
|
||||
#include "updater_gpu_common.cuh" // for HistBatch
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/data.h"
|
||||
@ -43,11 +44,15 @@
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
|
||||
namespace xgboost::tree {
|
||||
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
||||
|
||||
// Manage memory for a single GPU
|
||||
using cuda_impl::ApproxBatch;
|
||||
using cuda_impl::HistBatch;
|
||||
|
||||
// GPU tree updater implementation.
|
||||
struct GPUHistMakerDevice {
|
||||
private:
|
||||
GPUHistEvaluator evaluator_;
|
||||
@ -56,20 +61,29 @@ struct GPUHistMakerDevice {
|
||||
MetaInfo const& info_;
|
||||
|
||||
DeviceHistogramBuilder histogram_;
|
||||
// node idx for each sample
|
||||
dh::device_vector<bst_node_t> positions_;
|
||||
std::unique_ptr<RowPartitioner> row_partitioner_;
|
||||
|
||||
public:
|
||||
// Extra data for each node that is passed to the update position function
|
||||
struct NodeSplitData {
|
||||
RegTree::Node split_node;
|
||||
FeatureType split_type;
|
||||
common::KCatBitField node_cats;
|
||||
};
|
||||
static_assert(std::is_trivially_copyable_v<NodeSplitData>);
|
||||
|
||||
public:
|
||||
EllpackPageImpl const* page{nullptr};
|
||||
common::Span<FeatureType const> feature_types;
|
||||
|
||||
std::unique_ptr<RowPartitioner> row_partitioner;
|
||||
DeviceHistogramStorage<> hist{};
|
||||
|
||||
dh::device_vector<GradientPair> d_gpair; // storage for gpair;
|
||||
common::Span<GradientPair> gpair;
|
||||
|
||||
dh::device_vector<int> monotone_constraints;
|
||||
// node idx for each sample
|
||||
dh::device_vector<bst_node_t> positions;
|
||||
|
||||
TrainParam param;
|
||||
|
||||
@ -143,10 +157,10 @@ struct GPUHistMakerDevice {
|
||||
|
||||
quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());
|
||||
|
||||
if (!row_partitioner) {
|
||||
row_partitioner = std::make_unique<RowPartitioner>();
|
||||
if (!row_partitioner_) {
|
||||
row_partitioner_ = std::make_unique<RowPartitioner>();
|
||||
}
|
||||
row_partitioner->Reset(ctx_, sample.sample_rows, page->base_rowid);
|
||||
row_partitioner_->Reset(ctx_, sample.sample_rows, page->base_rowid);
|
||||
CHECK_EQ(page->base_rowid, 0);
|
||||
|
||||
// Init histogram
|
||||
@ -182,7 +196,10 @@ struct GPUHistMakerDevice {
|
||||
|
||||
void EvaluateSplits(const std::vector<GPUExpandEntry>& candidates, const RegTree& tree,
|
||||
common::Span<GPUExpandEntry> pinned_candidates_out) {
|
||||
if (candidates.empty()) return;
|
||||
if (candidates.empty()) {
|
||||
return;
|
||||
}
|
||||
this->monitor.Start(__func__);
|
||||
dh::TemporaryArray<EvaluateSplitInputs> d_node_inputs(2 * candidates.size());
|
||||
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2 * candidates.size());
|
||||
std::vector<bst_node_t> nidx(2 * candidates.size());
|
||||
@ -234,12 +251,12 @@ struct GPUHistMakerDevice {
|
||||
dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(),
|
||||
entries.data().get(), sizeof(GPUExpandEntry) * entries.size(),
|
||||
cudaMemcpyDeviceToHost));
|
||||
dh::DefaultStream().Sync();
|
||||
this->monitor.Stop(__func__);
|
||||
}
|
||||
|
||||
void BuildHist(int nidx) {
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx);
|
||||
auto d_ridx = row_partitioner->GetRows(nidx);
|
||||
auto d_ridx = row_partitioner_->GetRows(nidx);
|
||||
this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
|
||||
feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
|
||||
d_node_hist, *quantiser);
|
||||
@ -262,14 +279,6 @@ struct GPUHistMakerDevice {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Extra data for each node that is passed
|
||||
// to the update position function
|
||||
struct NodeSplitData {
|
||||
RegTree::Node split_node;
|
||||
FeatureType split_type;
|
||||
common::KCatBitField node_cats;
|
||||
};
|
||||
|
||||
void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix,
|
||||
std::vector<NodeSplitData> const& split_data,
|
||||
std::vector<bst_node_t> const& nidx,
|
||||
@ -321,10 +330,10 @@ struct GPUHistMakerDevice {
|
||||
};
|
||||
collective::SafeColl(rc);
|
||||
|
||||
row_partitioner->UpdatePositionBatch(
|
||||
row_partitioner_->UpdatePositionBatch(
|
||||
nidx, left_nidx, right_nidx, split_data,
|
||||
[=] __device__(bst_uint ridx, int split_index, NodeSplitData const& data) {
|
||||
auto const index = ridx * num_candidates + split_index;
|
||||
[=] __device__(bst_uint ridx, int nidx_in_batch, NodeSplitData const& data) {
|
||||
auto const index = ridx * num_candidates + nidx_in_batch;
|
||||
bool go_left;
|
||||
if (missing_bits.Check(index)) {
|
||||
go_left = data.split_node.DefaultLeft();
|
||||
@ -335,11 +344,35 @@ struct GPUHistMakerDevice {
|
||||
});
|
||||
}
|
||||
|
||||
struct GoLeftOp {
|
||||
EllpackDeviceAccessor d_matrix;
|
||||
|
||||
__device__ bool operator()(cuda_impl::RowIndexT ridx, NodeSplitData const& data) const {
|
||||
RegTree::Node const& node = data.split_node;
|
||||
// given a row index, returns the node id it belongs to
|
||||
float cut_value = d_matrix.GetFvalue(ridx, node.SplitIndex());
|
||||
// Missing value
|
||||
bool go_left = true;
|
||||
if (isnan(cut_value)) {
|
||||
go_left = node.DefaultLeft();
|
||||
} else {
|
||||
if (data.split_type == FeatureType::kCategorical) {
|
||||
go_left = common::Decision(data.node_cats.Bits(), cut_value);
|
||||
} else {
|
||||
go_left = cut_value <= node.SplitCond();
|
||||
}
|
||||
}
|
||||
return go_left;
|
||||
}
|
||||
};
|
||||
|
||||
void UpdatePosition(std::vector<GPUExpandEntry> const& candidates, RegTree* p_tree) {
|
||||
if (candidates.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
monitor.Start(__func__);
|
||||
|
||||
std::vector<bst_node_t> nidx(candidates.size());
|
||||
std::vector<bst_node_t> left_nidx(candidates.size());
|
||||
std::vector<bst_node_t> right_nidx(candidates.size());
|
||||
@ -347,12 +380,12 @@ struct GPUHistMakerDevice {
|
||||
|
||||
for (size_t i = 0; i < candidates.size(); i++) {
|
||||
auto const& e = candidates[i];
|
||||
RegTree::Node split_node = (*p_tree)[e.nid];
|
||||
RegTree::Node const& split_node = (*p_tree)[e.nid];
|
||||
auto split_type = p_tree->NodeSplitType(e.nid);
|
||||
nidx.at(i) = e.nid;
|
||||
left_nidx.at(i) = split_node.LeftChild();
|
||||
right_nidx.at(i) = split_node.RightChild();
|
||||
split_data.at(i) = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)};
|
||||
nidx[i] = e.nid;
|
||||
left_nidx[i] = split_node.LeftChild();
|
||||
right_nidx[i] = split_node.RightChild();
|
||||
split_data[i] = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)};
|
||||
|
||||
CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
|
||||
}
|
||||
@ -361,27 +394,15 @@ struct GPUHistMakerDevice {
|
||||
|
||||
if (info_.IsColumnSplit()) {
|
||||
UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
|
||||
monitor.Stop(__func__);
|
||||
return;
|
||||
}
|
||||
|
||||
row_partitioner->UpdatePositionBatch(
|
||||
auto go_left = GoLeftOp{d_matrix};
|
||||
row_partitioner_->UpdatePositionBatch(
|
||||
nidx, left_nidx, right_nidx, split_data,
|
||||
[=] __device__(bst_uint ridx, int split_index, const NodeSplitData& data) {
|
||||
// given a row index, returns the node id it belongs to
|
||||
float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
|
||||
// Missing value
|
||||
bool go_left = true;
|
||||
if (isnan(cut_value)) {
|
||||
go_left = data.split_node.DefaultLeft();
|
||||
} else {
|
||||
if (data.split_type == FeatureType::kCategorical) {
|
||||
go_left = common::Decision(data.node_cats.Bits(), cut_value);
|
||||
} else {
|
||||
go_left = cut_value <= data.split_node.SplitCond();
|
||||
}
|
||||
}
|
||||
return go_left;
|
||||
});
|
||||
[=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
|
||||
const NodeSplitData& data) { return go_left(ridx, data); });
|
||||
monitor.Stop(__func__);
|
||||
}
|
||||
|
||||
// After tree update is finished, update the position of all training
|
||||
@ -389,101 +410,70 @@ struct GPUHistMakerDevice {
|
||||
// prediction cache
|
||||
void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat, ObjInfo task,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
// Prediction cache will not be used with external memory
|
||||
if (!p_fmat->SingleColBlock()) {
|
||||
if (task.UpdateTreeLeaf()) {
|
||||
LOG(FATAL) << "Current objective function can not be used with external memory.";
|
||||
}
|
||||
if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
|
||||
LOG(FATAL) << "Current objective function can not be used with external memory.";
|
||||
}
|
||||
if (p_fmat->Info().num_row_ != row_partitioner_->GetRows().size()) {
|
||||
// Subsampling with external memory. Not supported.
|
||||
p_out_position->Resize(0);
|
||||
positions.clear();
|
||||
positions_.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
dh::TemporaryArray<RegTree::Node> d_nodes(p_tree->GetNodes().size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(),
|
||||
d_nodes.size() * sizeof(RegTree::Node),
|
||||
cudaMemcpyHostToDevice));
|
||||
auto const& h_split_types = p_tree->GetSplitTypes();
|
||||
auto const& categories = p_tree->GetSplitCategories();
|
||||
auto const& categories_segments = p_tree->GetSplitCategoriesPtr();
|
||||
|
||||
dh::caching_device_vector<FeatureType> d_split_types;
|
||||
dh::caching_device_vector<uint32_t> d_categories;
|
||||
dh::caching_device_vector<RegTree::CategoricalSplitMatrix::Segment> d_categories_segments;
|
||||
|
||||
if (!categories.empty()) {
|
||||
dh::CopyToD(h_split_types, &d_split_types);
|
||||
dh::CopyToD(categories, &d_categories);
|
||||
dh::CopyToD(categories_segments, &d_categories_segments);
|
||||
}
|
||||
|
||||
FinalisePositionInPage(page, dh::ToSpan(d_nodes), dh::ToSpan(d_split_types),
|
||||
dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments),
|
||||
p_out_position);
|
||||
}
|
||||
|
||||
void FinalisePositionInPage(
|
||||
EllpackPageImpl const* page, const common::Span<RegTree::Node> d_nodes,
|
||||
common::Span<FeatureType const> d_feature_types, common::Span<uint32_t const> categories,
|
||||
common::Span<RegTree::CategoricalSplitMatrix::Segment> categories_segments,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
auto d_matrix = page->GetDeviceAccessor(ctx_->Device());
|
||||
auto d_gpair = this->gpair;
|
||||
p_out_position->SetDevice(ctx_->Device());
|
||||
p_out_position->Resize(row_partitioner->GetRows().size());
|
||||
p_out_position->Resize(row_partitioner_->GetRows().size());
|
||||
auto d_out_position = p_out_position->DeviceSpan();
|
||||
|
||||
auto new_position_op = [=] __device__(size_t row_id, int position) {
|
||||
// What happens if user prune the tree?
|
||||
if (!d_matrix.IsInRange(row_id)) {
|
||||
return RowPartitioner::kIgnoredTreePosition;
|
||||
}
|
||||
auto node = d_nodes[position];
|
||||
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
} else {
|
||||
bool go_left = true;
|
||||
if (common::IsCat(d_feature_types, position)) {
|
||||
auto node_cats = categories.subspan(categories_segments[position].beg,
|
||||
categories_segments[position].size);
|
||||
go_left = common::Decision(node_cats, element);
|
||||
} else {
|
||||
go_left = element <= node.SplitCond();
|
||||
}
|
||||
if (go_left) {
|
||||
position = node.LeftChild();
|
||||
} else {
|
||||
position = node.RightChild();
|
||||
}
|
||||
}
|
||||
|
||||
node = d_nodes[position];
|
||||
}
|
||||
|
||||
return position;
|
||||
auto d_gpair = this->gpair;
|
||||
auto encode_op = [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
|
||||
bool is_invalid = d_gpair[row_id].GetHess() - .0f == 0.f;
|
||||
return SamplePosition::Encode(nidx, !is_invalid);
|
||||
}; // NOLINT
|
||||
|
||||
auto d_out_position = p_out_position->DeviceSpan();
|
||||
row_partitioner->FinalisePosition(d_out_position, new_position_op);
|
||||
if (!p_fmat->SingleColBlock()) {
|
||||
CHECK_EQ(row_partitioner_->GetNumNodes(), p_tree->NumNodes());
|
||||
row_partitioner_->FinalisePosition(d_out_position, encode_op);
|
||||
dh::CopyTo(d_out_position, &positions_);
|
||||
return;
|
||||
}
|
||||
|
||||
auto s_position = p_out_position->ConstDeviceSpan();
|
||||
positions.resize(s_position.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(positions.data().get(), s_position.data(),
|
||||
s_position.size_bytes(), cudaMemcpyDeviceToDevice,
|
||||
ctx_->CUDACtx()->Stream()));
|
||||
dh::caching_device_vector<uint32_t> categories;
|
||||
dh::CopyToD(p_tree->GetSplitCategories(), &categories);
|
||||
auto const& cat_segments = p_tree->GetSplitCategoriesPtr();
|
||||
auto d_categories = dh::ToSpan(categories);
|
||||
|
||||
dh::LaunchN(row_partitioner->GetRows().size(), [=] __device__(size_t idx) {
|
||||
bst_node_t position = d_out_position[idx];
|
||||
bool is_row_sampled = d_gpair[idx].GetHess() - .0f == 0.f;
|
||||
d_out_position[idx] = is_row_sampled ? ~position : position;
|
||||
});
|
||||
auto d_matrix = page->GetDeviceAccessor(ctx_->Device());
|
||||
|
||||
std::vector<NodeSplitData> split_data(p_tree->NumNodes());
|
||||
auto const& tree = *p_tree;
|
||||
for (std::size_t i = 0, n = split_data.size(); i < n; ++i) {
|
||||
RegTree::Node split_node = tree[i];
|
||||
auto split_type = p_tree->NodeSplitType(i);
|
||||
auto node_cats = common::GetNodeCats(d_categories, cat_segments[i]);
|
||||
split_data[i] = NodeSplitData{std::move(split_node), split_type, node_cats};
|
||||
}
|
||||
|
||||
auto go_left_op = GoLeftOp{d_matrix};
|
||||
dh::caching_device_vector<NodeSplitData> d_split_data;
|
||||
dh::CopyToD(split_data, &d_split_data);
|
||||
auto s_split_data = dh::ToSpan(d_split_data);
|
||||
|
||||
row_partitioner_->FinalisePosition(d_out_position,
|
||||
[=] __device__(bst_idx_t row_id, bst_node_t nidx) {
|
||||
auto split_data = s_split_data[nidx];
|
||||
auto node = split_data.split_node;
|
||||
while (!node.IsLeaf()) {
|
||||
auto go_left = go_left_op(row_id, split_data);
|
||||
nidx = go_left ? node.LeftChild() : node.RightChild();
|
||||
node = s_split_data[nidx].split_node;
|
||||
}
|
||||
return encode_op(row_id, nidx);
|
||||
});
|
||||
dh::CopyTo(d_out_position, &positions_);
|
||||
}
|
||||

  bool UpdatePredictionCache(linalg::MatrixView<float> out_preds_d, RegTree const* p_tree) {
    if (positions.empty()) {
    if (positions_.empty()) {
      return false;
    }

@ -491,20 +481,19 @@ struct GPUHistMakerDevice {
    CHECK(out_preds_d.Device().IsCUDA());
    CHECK_EQ(out_preds_d.Device().ordinal, ctx_->Ordinal());

    dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
    auto d_position = dh::ToSpan(positions);
    auto d_position = dh::ToSpan(positions_);
    CHECK_EQ(out_preds_d.Size(), d_position.size());

    auto const& h_nodes = p_tree->GetNodes();
    dh::caching_device_vector<RegTree::Node> nodes(h_nodes.size());
    dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
                                  h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice,
                                  ctx_->CUDACtx()->Stream()));
    auto d_nodes = dh::ToSpan(nodes);
    // Use the nodes from tree, the leaf value might be changed by the objective since the
    // last update tree call.
    dh::caching_device_vector<RegTree::Node> nodes;
    dh::CopyTo(p_tree->GetNodes(), &nodes);
    common::Span<RegTree::Node> d_nodes = dh::ToSpan(nodes);
    CHECK_EQ(out_preds_d.Shape(1), 1);
    dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(),
                [=] XGBOOST_DEVICE(std::size_t idx) mutable {
                  bst_node_t nidx = d_position[idx];
                  nidx = SamplePosition::Decode(nidx);
                  auto weight = d_nodes[nidx].LeafValue();
                  out_preds_d(idx, 0) += weight;
                });
@ -512,7 +501,7 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
// num histograms is the number of contiguous histograms in memory to reduce over
|
||||
void AllReduceHist(int nidx, int num_histograms) {
|
||||
void AllReduceHist(bst_node_t nidx, int num_histograms) {
|
||||
monitor.Start("AllReduce");
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
|
||||
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
|
||||
@ -529,7 +518,10 @@ struct GPUHistMakerDevice {
|
||||
* \brief Build GPU local histograms for the left and right child of some parent node
|
||||
*/
|
||||
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, const RegTree& tree) {
|
||||
if (candidates.empty()) return;
|
||||
if (candidates.empty()) {
|
||||
return;
|
||||
}
|
||||
this->monitor.Start(__func__);
|
||||
// Some nodes we will manually compute histograms
|
||||
// others we will do by subtraction
|
||||
std::vector<int> hist_nidx;
|
||||
@ -572,14 +564,15 @@ struct GPUHistMakerDevice {
|
||||
this->AllReduceHist(subtraction_trick_nidx, 1);
|
||||
}
|
||||
}
|
||||
this->monitor.Stop(__func__);
|
||||
}
|
||||
|
||||
void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
|
||||
RegTree& tree = *p_tree;
|
||||
|
||||
// Sanity check - have we created a leaf with no training instances?
|
||||
if (!collective::IsDistributed() && row_partitioner) {
|
||||
CHECK(row_partitioner->GetRows(candidate.nid).size() > 0)
|
||||
if (!collective::IsDistributed() && row_partitioner_) {
|
||||
CHECK(row_partitioner_->GetRows(candidate.nid).size() > 0)
|
||||
<< "No training instances in this leaf!";
|
||||
}
|
||||
|
||||
@ -659,6 +652,8 @@ struct GPUHistMakerDevice {
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
|
||||
RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
bool const is_single_block = p_fmat->SingleColBlock();
|
||||
|
||||
auto& tree = *p_tree;
|
||||
// Process maximum 32 nodes at a time
|
||||
Driver<GPUExpandEntry> driver(param, 32);
|
||||
@ -684,30 +679,29 @@ struct GPUHistMakerDevice {
|
||||
[&](const auto& e) { return driver.IsChildValid(e); });
|
||||
|
||||
auto new_candidates =
|
||||
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry());
|
||||
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry{});
|
||||
// Update all the nodes if working with external memory, this saves us from working
|
||||
// with the finalize position call, which adds an additional iteration and requires
|
||||
// special handling for row index.
|
||||
this->UpdatePosition(is_single_block ? filtered_expand_set : expand_set, p_tree);
|
||||
|
||||
monitor.Start("UpdatePosition");
|
||||
// Update position is only run when child is valid, instead of right after apply
|
||||
// split (as in approx tree method). Hense we have the finalise position call
|
||||
// in GPU Hist.
|
||||
this->UpdatePosition(filtered_expand_set, p_tree);
|
||||
monitor.Stop("UpdatePosition");
|
||||
|
||||
monitor.Start("BuildHist");
|
||||
this->BuildHistLeftRight(filtered_expand_set, tree);
|
||||
monitor.Stop("BuildHist");
|
||||
|
||||
monitor.Start("EvaluateSplits");
|
||||
this->EvaluateSplits(filtered_expand_set, *p_tree, new_candidates);
|
||||
monitor.Stop("EvaluateSplits");
|
||||
dh::DefaultStream().Sync();
|
||||
|
||||
driver.Push(new_candidates.begin(), new_candidates.end());
|
||||
expand_set = driver.Pop();
|
||||
}
|
||||
|
||||
monitor.Start("FinalisePosition");
|
||||
// Row partitioner can have lesser nodes than the tree since we skip some leaf
|
||||
// nodes. These nodes are handled in the `FinalisePosition` call. However, a leaf can
|
||||
// be spliable before evaluation but invalid after evaluation as we have more
|
||||
// restrictions like min loss change after evalaution. Therefore, the check condition
|
||||
// is greater than or equal to.
|
||||
if (is_single_block) {
|
||||
CHECK_GE(p_tree->NumNodes(), this->row_partitioner_->GetNumNodes());
|
||||
}
|
||||
this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
|
||||
monitor.Stop("FinalisePosition");
|
||||
}
|
||||
};
|
||||
|
||||
@ -767,12 +761,11 @@ class GPUHistMaker : public TreeUpdater {
|
||||
SafeColl(rc);
|
||||
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
|
||||
|
||||
auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
|
||||
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
|
||||
info_->feature_types.SetDevice(ctx_->Device());
|
||||
maker = std::make_unique<GPUHistMakerDevice>(
|
||||
ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_,
|
||||
*param, column_sampler_, info_->num_col_, batch_param, dmat->Info());
|
||||
*param, column_sampler_, info_->num_col_, HistBatch(*param), dmat->Info());
|
||||
|
||||
p_last_fmat_ = dmat;
|
||||
initialised_ = true;
|
||||
@ -798,14 +791,13 @@ class GPUHistMaker : public TreeUpdater {
|
||||
maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
linalg::MatrixView<bst_float> p_out_preds) override {
|
||||
bool UpdatePredictionCache(const DMatrix* data, linalg::MatrixView<float> p_out_preds) override {
|
||||
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
monitor_.Start("UpdatePredictionCache");
|
||||
monitor_.Start(__func__);
|
||||
bool result = maker->UpdatePredictionCache(p_out_preds, p_last_tree_);
|
||||
monitor_.Stop("UpdatePredictionCache");
|
||||
monitor_.Stop(__func__);
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -881,10 +873,9 @@ class GPUGlobalApproxMaker : public TreeUpdater {
|
||||
|
||||
auto const& info = p_fmat->Info();
|
||||
info.feature_types.SetDevice(ctx_->Device());
|
||||
auto batch = BatchParam{param->max_bin, hess, !task_->const_hess};
|
||||
maker_ = std::make_unique<GPUHistMakerDevice>(
|
||||
ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_,
|
||||
*param, column_sampler_, info.num_col_, batch, p_fmat->Info());
|
||||
*param, column_sampler_, info.num_col_, ApproxBatch(*param, hess, *task_), p_fmat->Info());
|
||||
|
||||
std::size_t t_idx{0};
|
||||
for (xgboost::RegTree* tree : trees) {
|
||||
@ -927,14 +918,13 @@ class GPUGlobalApproxMaker : public TreeUpdater {
|
||||
maker_->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
linalg::MatrixView<bst_float> p_out_preds) override {
|
||||
bool UpdatePredictionCache(const DMatrix* data, linalg::MatrixView<float> p_out_preds) override {
|
||||
if (maker_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
monitor_.Start("UpdatePredictionCache");
|
||||
monitor_.Start(__func__);
|
||||
bool result = maker_->UpdatePredictionCache(p_out_preds, p_last_tree_);
|
||||
monitor_.Stop("UpdatePredictionCache");
|
||||
monitor_.Stop(__func__);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@ -67,9 +67,9 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
|
||||
h_batch_info.size() * sizeof(PerNodeData<int>), cudaMemcpyDefault,
|
||||
nullptr));
|
||||
dh::device_vector<int8_t> tmp;
|
||||
SortPositionBatch<uint32_t, decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
|
||||
dh::ToSpan(ridx_tmp), dh::ToSpan(counts),
|
||||
total_rows, op, &tmp);
|
||||
SortPositionBatch<decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
|
||||
dh::ToSpan(ridx_tmp), dh::ToSpan(counts), total_rows, op,
|
||||
&tmp);
|
||||
|
||||
auto op_without_data = [=] __device__(auto ridx) { return ridx % 2 == 0; };
|
||||
for (size_t i = 0; i < segments.size(); i++) {
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/json.h> // for Json
|
||||
#include <xgboost/tree_model.h> // for RegTree
|
||||
|
||||
#include "../../../../src/common/categorical.h" // for CatBitField
|
||||
#include "../../../../src/tree/hist/expand_entry.h"
|
||||
|
||||
namespace xgboost::tree {
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
#include <xgboost/base.h> // for Args
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/host_device_vector.h> // for HostDeviceVector
|
||||
#include <xgboost/json.h> // for Jons
|
||||
#include <xgboost/json.h> // for Json
|
||||
#include <xgboost/task.h> // for ObjInfo
|
||||
#include <xgboost/tree_model.h> // for RegTree
|
||||
#include <xgboost/tree_updater.h> // for TreeUpdater
|
||||
@ -14,32 +14,17 @@
|
||||
#include <string> // for string
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/common/random.h" // for GlobalRandom
|
||||
#include "../../../src/data/ellpack_page.h" // for EllpackPage
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../collective/test_worker.h" // for BaseMGPUTest
|
||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||
#include "../../../src/common/random.h" // for GlobalRandom
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../collective/test_worker.h" // for BaseMGPUTest
|
||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost::tree {
|
||||
void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
|
||||
size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds,
|
||||
float subsample = 1.0f, const std::string& sampling_method = "uniform",
|
||||
int max_bin = 2) {
|
||||
if (gpu_page_size > 0) {
|
||||
// Loop over the batches and count the records
|
||||
int64_t batch_count = 0;
|
||||
int64_t row_count = 0;
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>(
|
||||
ctx, BatchParam{max_bin, TrainParam::DftSparseThreshold()})) {
|
||||
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
||||
batch_count++;
|
||||
row_count += batch.Size();
|
||||
}
|
||||
EXPECT_GE(batch_count, 2);
|
||||
EXPECT_EQ(row_count, dmat->Info().num_row_);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat, bool is_ext,
|
||||
RegTree* tree, HostDeviceVector<bst_float>* preds, float subsample,
|
||||
const std::string& sampling_method, bst_bin_t max_bin) {
|
||||
Args args{
|
||||
{"max_depth", "2"},
|
||||
{"max_bin", std::to_string(max_bin)},
|
||||
@ -60,8 +45,13 @@ void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix
|
||||
hist_maker->Update(¶m, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
||||
{tree});
|
||||
auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
|
||||
hist_maker->UpdatePredictionCache(dmat, cache);
|
||||
if (subsample < 1.0 && is_ext) {
|
||||
ASSERT_FALSE(hist_maker->UpdatePredictionCache(dmat, cache));
|
||||
} else {
|
||||
ASSERT_TRUE(hist_maker->UpdatePredictionCache(dmat, cache));
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(GpuHist, UniformSampling) {
|
||||
constexpr size_t kRows = 4096;
|
||||
@ -79,11 +69,11 @@ TEST(GpuHist, UniformSampling) {
|
||||
RegTree tree;
|
||||
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
|
||||
Context ctx(MakeCUDACtx(0));
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
|
||||
// Build another tree using sampling.
|
||||
RegTree tree_sampling;
|
||||
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0));
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, "uniform",
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), false, &tree_sampling, &preds_sampling, kSubsample, "uniform",
|
||||
kRows);
|
||||
|
||||
// Make sure the predictions are the same.
|
||||
@ -110,12 +100,12 @@ TEST(GpuHist, GradientBasedSampling) {
|
||||
RegTree tree;
|
||||
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
|
||||
Context ctx(MakeCUDACtx(0));
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
|
||||
|
||||
// Build another tree using sampling.
|
||||
RegTree tree_sampling;
|
||||
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0));
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), false, &tree_sampling, &preds_sampling, kSubsample,
|
||||
"gradient_based", kRows);
|
||||
|
||||
// Make sure the predictions are the same.
|
||||
@ -147,11 +137,11 @@ TEST(GpuHist, ExternalMemory) {
|
||||
// Build a tree using the in-memory DMatrix.
|
||||
RegTree tree;
|
||||
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
|
||||
// Build another tree using multiple ELLPACK pages.
|
||||
RegTree tree_ext;
|
||||
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0));
|
||||
UpdateTree(&ctx, &gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
|
||||
UpdateTree(&ctx, &gpair, dmat_ext.get(), true, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
|
||||
|
||||
// Make sure the predictions are the same.
|
||||
auto preds_h = preds.ConstHostVector();
|
||||
@ -162,23 +152,26 @@ TEST(GpuHist, ExternalMemory) {
|
||||
}
|
||||
|
||||
TEST(GpuHist, ExternalMemoryWithSampling) {
|
||||
constexpr size_t kRows = 4096;
|
||||
constexpr size_t kCols = 2;
|
||||
constexpr size_t kPageSize = 1024;
|
||||
constexpr size_t kRows = 4096, kCols = 2;
|
||||
constexpr float kSubsample = 0.5;
|
||||
const std::string kSamplingMethod = "gradient_based";
|
||||
common::GlobalRandom().seed(0);
|
||||
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
Context ctx(MakeCUDACtx(0));
|
||||
|
||||
// Create a single batch DMatrix.
|
||||
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(kRows, kCols, 1, tmpdir.path + "/cache"));
|
||||
auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}
|
||||
.Device(ctx.Device())
|
||||
.Batches(1)
|
||||
.GenerateSparsePageDMatrix("temp", true);
|
||||
|
||||
// Create a DMatrix with multiple batches.
|
||||
std::unique_ptr<DMatrix> dmat_ext(
|
||||
CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache"));
|
||||
auto p_fmat_ext = RandomDataGenerator{kRows, kCols, 0.0f}
|
||||
.Device(ctx.Device())
|
||||
.Batches(4)
|
||||
.GenerateSparsePageDMatrix("temp", true);
|
||||
|
||||
Context ctx(MakeCUDACtx(0));
|
||||
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
|
||||
gpair.Data()->Copy(GenerateRandomGradients(kRows));
|
||||
|
||||
@ -187,13 +180,13 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
|
||||
|
||||
RegTree tree;
|
||||
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod, kRows);
|
||||
UpdateTree(&ctx, &gpair, p_fmat.get(), true, &tree, &preds, kSubsample, kSamplingMethod, kRows);
|
||||
|
||||
// Build another tree using multiple ELLPACK pages.
|
||||
common::GlobalRandom() = rng;
|
||||
RegTree tree_ext;
|
||||
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0));
|
||||
UpdateTree(&ctx, &gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample,
|
||||
UpdateTree(&ctx, &gpair, p_fmat_ext.get(), true, &tree_ext, &preds_ext, kSubsample,
|
||||
kSamplingMethod, kRows);
|
||||
|
||||
// Make sure the predictions are the same.