[EM] Return a full DMatrix instead of an Ellpack from the GPU sampler. (#10753)
This commit is contained in:
parent d6ebcfb032
commit bde1265caf
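In short: the GPU gradient samplers previously handed back a raw EllpackPageImpl pointer, which forced the histogram updater to special-case sampled data. With this commit they return a DMatrix instead — the external-memory paths wrap the concatenated or compacted ELLPACK page together with a sliced MetaInfo into an IterativeDMatrix — so downstream code can iterate batches uniformly. Supporting changes: a device-aware MetaInfo::Slice overload, an EllpackPageImpl::NumNonMissing helper, mutable linalg iterators, and a StaticBatch parameter that avoids regenerating cached ELLPACK pages.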
.gitignore (vendored, 4 additions)
@@ -63,6 +63,7 @@ java/xgboost4j-demo/data/
 java/xgboost4j-demo/tmp/
 java/xgboost4j-demo/model/
 nb-configuration*
+
 # Eclipse
 .project
 .cproject
@@ -154,3 +155,6 @@ model*.json
 *.rds
 Rplots.pdf
 *.zip
+
+# nsys
+*.nsys-rep
include/xgboost/data.h
@@ -110,8 +110,15 @@ class MetaInfo {
    * @brief Validate all metainfo.
    */
   void Validate(DeviceOrd device) const;
-  MetaInfo Slice(common::Span<int32_t const> ridxs) const;
+  /**
+   * @brief Slice the meta info.
+   *
+   * The device of ridxs is specified by the ctx object.
+   *
+   * @param ridxs Index of selected rows.
+   * @param nnz The number of non-missing values.
+   */
+  MetaInfo Slice(Context const* ctx, common::Span<bst_idx_t const> ridxs, bst_idx_t nnz) const;
 
   MetaInfo Copy() const;
   /**
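A minimal sketch of calling the new overload (the caller, ctx, info, and the concrete values here are illustrative, not part of this commit):

    std::vector<bst_idx_t> h_ridx{0, 2, 5};  // rows to keep
    bst_idx_t nnz = 9;                       // non-missing values in those rows
    MetaInfo sliced =
        info.Slice(&ctx, common::Span<bst_idx_t const>{h_ridx.data(), h_ridx.size()}, nnz);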
src/common/device_helpers.cuh
@@ -508,6 +508,11 @@ xgboost::common::Span<T> ToSpan(DeviceUVector<T> &vec) {
   return {vec.data(), vec.size()};
 }
+
+template <typename T>
+xgboost::common::Span<std::add_const_t<T>> ToSpan(DeviceUVector<T> const &vec) {
+  return {vec.data(), vec.size()};
+}
 
 // thrust begin, similiar to std::begin
 template <typename T>
 thrust::device_ptr<T> tbegin(xgboost::HostDeviceVector<T>& vector) {  // NOLINT
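The new overload returns a read-only span from a const vector. A minimal sketch (assuming DeviceUVector lives in the dh namespace like the rest of device_helpers.cuh):

    void ReadOnly(dh::DeviceUVector<float> const& vec) {
      // Previously ToSpan required a mutable vector; the const overload
      // yields Span<float const> instead.
      xgboost::common::Span<float const> s = dh::ToSpan(vec);
      (void)s;
    }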
src/common/linalg_op.cuh
@@ -76,7 +76,7 @@ struct IterOp {
 // returns a thrust iterator for a tensor view.
 template <typename T, std::int32_t kDim>
 auto tcbegin(TensorView<T, kDim> v) {  // NOLINT
-  return dh::MakeTransformIterator<T>(
+  return thrust::make_transform_iterator(
       thrust::make_counting_iterator(0ul),
       detail::IterOp<std::add_const_t<std::remove_const_t<T>>, kDim>{v});
 }
@@ -85,5 +85,16 @@ template <typename T, std::int32_t kDim>
 auto tcend(TensorView<T, kDim> v) {  // NOLINT
   return tcbegin(v) + v.Size();
 }
+
+template <typename T, std::int32_t kDim>
+auto tbegin(TensorView<T, kDim> v) {  // NOLINT
+  return thrust::make_transform_iterator(thrust::make_counting_iterator(0ul),
+                                         detail::IterOp<std::remove_const_t<T>, kDim>{v});
+}
+
+template <typename T, std::int32_t kDim>
+auto tend(TensorView<T, kDim> v) {  // NOLINT
+  return tbegin(v) + v.Size();
+}
 }  // namespace xgboost::linalg
 #endif  // XGBOOST_COMMON_LINALG_OP_CUH_
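tcbegin/tcend remain the const iterators; the new tbegin/tend yield mutable iterators so a TensorView can serve as a thrust output sequence — data.cu below uses tbegin as the destination of thrust::gather. A minimal sketch under that same assumption (names illustrative, not from this commit):

    // Double every element of a device matrix in place.
    void ScaleInPlace(xgboost::Context const* ctx, xgboost::linalg::Matrix<float>* p_m) {
      auto d_m = p_m->View(ctx->Device());
      thrust::transform(xgboost::linalg::tcbegin(d_m), xgboost::linalg::tcend(d_m),
                        xgboost::linalg::tbegin(d_m),  // mutable output iterator
                        [] __device__(float v) { return v * 2.0f; });
    }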
src/data/data.cc
@@ -351,8 +351,10 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
   this->has_categorical_ = LoadFeatureType(feature_type_names, &feature_types.HostVector());
 }
+
+namespace {
 template <typename T>
-std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, size_t stride = 1) {
+std::vector<T> Gather(const std::vector<T>& in, common::Span<bst_idx_t const> ridxs,
+                      size_t stride = 1) {
   if (in.empty()) {
     return {};
   }
@@ -366,11 +368,51 @@ std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, s
   }
   return out;
 }
+}  // namespace
+
+namespace cuda_impl {
+void SliceMetaInfo(Context const* ctx, MetaInfo const& info, common::Span<bst_idx_t const> ridx,
+                   MetaInfo* p_out);
+#if !defined(XGBOOST_USE_CUDA)
+void SliceMetaInfo(Context const*, MetaInfo const&, common::Span<bst_idx_t const>, MetaInfo*) {
+  common::AssertGPUSupport();
+}
+#endif
+}  // namespace cuda_impl
 
-MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
+MetaInfo MetaInfo::Slice(Context const* ctx, common::Span<bst_idx_t const> ridxs,
+                         bst_idx_t nnz) const {
+  /**
+   * Shape
+   */
   MetaInfo out;
   out.num_row_ = ridxs.size();
   out.num_col_ = this->num_col_;
+  out.num_nonzero_ = nnz;
+
+  /**
+   * Feature Info
+   */
+  out.feature_weights.SetDevice(ctx->Device());
+  out.feature_weights.Resize(this->feature_weights.Size());
+  out.feature_weights.Copy(this->feature_weights);
+
+  out.feature_names = this->feature_names;
+
+  out.feature_types.SetDevice(ctx->Device());
+  out.feature_types.Resize(this->feature_types.Size());
+  out.feature_types.Copy(this->feature_types);
+
+  out.feature_type_names = this->feature_type_names;
+
+  /**
+   * Sample Info
+   */
+  if (ctx->IsCUDA()) {
+    cuda_impl::SliceMetaInfo(ctx, *this, ridxs, &out);
+    return out;
+  }
+
   // Groups is maintained by a higher level Python function. We should aim at deprecating
   // the slice function.
   if (this->labels.Size() != this->num_row_) {
@@ -386,10 +428,8 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
     });
   }
 
-  out.labels_upper_bound_.HostVector() =
-      Gather(this->labels_upper_bound_.HostVector(), ridxs);
-  out.labels_lower_bound_.HostVector() =
-      Gather(this->labels_lower_bound_.HostVector(), ridxs);
+  out.labels_upper_bound_.HostVector() = Gather(this->labels_upper_bound_.HostVector(), ridxs);
+  out.labels_lower_bound_.HostVector() = Gather(this->labels_lower_bound_.HostVector(), ridxs);
   // weights
   if (this->weights_.Size() + 1 == this->group_ptr_.size()) {
     auto& h_weights = out.weights_.HostVector();
@@ -414,14 +454,6 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
     });
   }
-
-  out.feature_weights.Resize(this->feature_weights.Size());
-  out.feature_weights.Copy(this->feature_weights);
-
-  out.feature_names = this->feature_names;
-  out.feature_types.Resize(this->feature_types.Size());
-  out.feature_types.Copy(this->feature_types);
-  out.feature_type_names = this->feature_type_names;
 
   return out;
 }
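The rewritten Slice keeps the host-side Gather path for CPU contexts and forwards CUDA contexts to cuda_impl::SliceMetaInfo, whose real definition lives in data.cu. The dispatch pattern in isolation (a sketch with a hypothetical DoWork, not part of this commit):

    namespace cuda_impl {
    void DoWork(Context const* ctx);  // real definition compiled in a .cu file
    #if !defined(XGBOOST_USE_CUDA)
    // CPU-only builds link this stub instead; reaching it means a GPU path
    // was requested without CUDA support.
    void DoWork(Context const*) { common::AssertGPUSupport(); }
    #endif
    }  // namespace cuda_impl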
src/data/data.cu
@@ -1,9 +1,11 @@
 /**
- * Copyright 2019-2022 by XGBoost Contributors
+ * Copyright 2019-2024, XGBoost Contributors
  *
  * \file data.cu
  * \brief Handles setting metainfo from array interface.
  */
+#include <thrust/gather.h>  // for gather
+
 #include "../common/cuda_context.cuh"
 #include "../common/device_helpers.cuh"
 #include "../common/linalg_op.cuh"
@@ -169,6 +171,62 @@ void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) {
   }
 }
+
+namespace {
+void Gather(Context const* ctx, linalg::MatrixView<float const> in,
+            common::Span<bst_idx_t const> ridx, linalg::Matrix<float>* p_out) {
+  if (in.Empty()) {
+    return;
+  }
+  auto& out = *p_out;
+  out.Reshape(ridx.size(), in.Shape(1));
+  auto d_out = out.View(ctx->Device());
+
+  auto cuctx = ctx->CUDACtx();
+  auto map_it = thrust::make_transform_iterator(thrust::make_counting_iterator(0ull),
+                                                [=] XGBOOST_DEVICE(bst_idx_t i) {
+                                                  auto [r, c] = linalg::UnravelIndex(i, in.Shape());
+                                                  return (ridx[r] * in.Shape(1)) + c;
+                                                });
+  CHECK_NE(in.Shape(1), 0);
+  thrust::gather(cuctx->TP(), map_it, map_it + out.Size(), linalg::tcbegin(in),
+                 linalg::tbegin(d_out));
+}
+
+template <typename T>
+void Gather(Context const* ctx, HostDeviceVector<T> const& in, common::Span<bst_idx_t const> ridx,
+            HostDeviceVector<T>* p_out) {
+  if (in.Empty()) {
+    return;
+  }
+  in.SetDevice(ctx->Device());
+
+  auto& out = *p_out;
+  out.SetDevice(ctx->Device());
+  out.Resize(ridx.size());
+  auto d_out = out.DeviceSpan();
+
+  auto cuctx = ctx->CUDACtx();
+  auto d_in = in.ConstDeviceSpan();
+  thrust::gather(cuctx->TP(), dh::tcbegin(ridx), dh::tcend(ridx), dh::tcbegin(d_in),
+                 dh::tbegin(d_out));
+}
+}  // anonymous namespace
+
+namespace cuda_impl {
+void SliceMetaInfo(Context const* ctx, MetaInfo const& info, common::Span<bst_idx_t const> ridx,
+                   MetaInfo* p_out) {
+  auto& out = *p_out;
+
+  Gather(ctx, info.labels.View(ctx->Device()), ridx, &p_out->labels);
+  Gather(ctx, info.base_margin_.View(ctx->Device()), ridx, &p_out->base_margin_);
+
+  Gather(ctx, info.labels_lower_bound_, ridx, &out.labels_lower_bound_);
+  Gather(ctx, info.labels_upper_bound_, ridx, &out.labels_upper_bound_);
+
+  Gather(ctx, info.weights_, ridx, &out.weights_);
+}
+}  // namespace cuda_impl
+
 template <typename AdapterT>
 DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
                          const std::string& cache_prefix, DataSplitMode data_split_mode) {
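Worked example of the index map in the matrix Gather above: if in has 3 columns and ridx = {4, 7}, then out has shape (2, 3) and output element i reads input flat index ridx[i / 3] * 3 + (i % 3); e.g. i = 0..2 read 12, 13, 14 (row 4) and i = 3..5 read 21, 22, 23 (row 7). Because in and out share the column count, unraveling the output index with in.Shape() yields the correct (row, column) pair.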
src/data/ellpack_page.cu
@@ -1,6 +1,7 @@
 /**
  * Copyright 2019-2024, XGBoost contributors
  */
+#include <cuda/functional>  // for proclaim_return_type
 #include <thrust/iterator/discard_iterator.h>
 #include <thrust/iterator/transform_output_iterator.h>
 
@@ -576,4 +577,17 @@ EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
       common::CompressedIterator<uint32_t>(h_gidx_buffer->data(), NumSymbols()),
       feature_types};
 }
+
+[[nodiscard]] bst_idx_t EllpackPageImpl::NumNonMissing(
+    Context const* ctx, common::Span<FeatureType const> feature_types) const {
+  auto d_acc = this->GetDeviceAccessor(ctx->Device(), feature_types);
+  using T = typename decltype(d_acc.gidx_iter)::value_type;
+  auto it = thrust::make_transform_iterator(
+      thrust::make_counting_iterator(0ull),
+      cuda::proclaim_return_type<T>([=] __device__(std::size_t i) { return d_acc.gidx_iter[i]; }));
+  auto nnz = thrust::count_if(ctx->CUDACtx()->CTP(), it, it + d_acc.row_stride * d_acc.n_rows,
+                              cuda::proclaim_return_type<bool>(
+                                  [=] __device__(T gidx) { return gidx != d_acc.NullValue(); }));
+  return nnz;
+}
 }  // namespace xgboost
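For intuition, a host-side analogue of the device count above (a sketch; the real code reads compressed bin indices through the device accessor):

    bst_idx_t NumNonMissingHost(std::vector<std::uint32_t> const& gidx,
                                std::uint32_t null_value, std::size_t row_stride,
                                std::size_t n_rows) {
      bst_idx_t nnz = 0;
      // ELLPACK stores row_stride slots per row; missing slots hold null_value.
      for (std::size_t i = 0; i < row_stride * n_rows; ++i) {
        if (gidx[i] != null_value) { ++nnz; }
      }
      return nnz;
    }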
src/data/ellpack_page.cuh
@@ -236,6 +236,11 @@ class EllpackPageImpl {
   [[nodiscard]] EllpackDeviceAccessor GetHostAccessor(
       Context const* ctx, std::vector<common::CompressedByteT>* h_gidx_buffer,
       common::Span<FeatureType const> feature_types = {}) const;
+  /**
+   * @brief Calculate the number of non-missing values.
+   */
+  [[nodiscard]] bst_idx_t NumNonMissing(Context const* ctx,
+                                        common::Span<FeatureType const> feature_types) const;
 
  private:
   /**
src/data/iterative_dmatrix.cu
@@ -101,6 +101,17 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
   // Synchronise worker columns
 }
+
+IterativeDMatrix::IterativeDMatrix(std::shared_ptr<EllpackPage> ellpack, MetaInfo const& info,
+                                   BatchParam batch) {
+  this->ellpack_ = ellpack;
+  CHECK_EQ(this->Info().num_row_, 0);
+  CHECK_EQ(this->Info().num_col_, 0);
+  this->Info().Extend(info, true, true);
+  this->Info().num_nonzero_ = info.num_nonzero_;
+  CHECK_EQ(this->Info().num_row_, info.num_row_);
+  this->batch_ = batch;
+}
+
 BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
                                                           BatchParam const& param) {
   if (param.Initialized()) {
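This constructor is what the GPU sampler uses to promote a sampled page to a full DMatrix; roughly (mirrors gradient_based_sampler.cu later in this commit):

    // Wrap a freshly compacted page and a sliced MetaInfo into a quantile DMatrix.
    auto new_page = std::make_shared<EllpackPage>();
    // ... fill *new_page->Impl() with the compacted sample ...
    std::unique_ptr<DMatrix> p_fmat = std::make_unique<data::IterativeDMatrix>(
        new_page, info.Slice(ctx, ridx_span, nnz), batch_param);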
src/data/iterative_dmatrix.h
@@ -48,6 +48,11 @@ class IterativeDMatrix : public QuantileDMatrix {
                    std::shared_ptr<DMatrix> ref, DataIterResetCallback *reset,
                    XGDMatrixCallbackNext *next, float missing, int nthread,
                    bst_bin_t max_bin);
+  /**
+   * @param Directly construct a QDM from an existing one.
+   */
+  IterativeDMatrix(std::shared_ptr<EllpackPage> ellpack, MetaInfo const &info, BatchParam batch);
+
   ~IterativeDMatrix() override = default;
 
   bool EllpackExists() const override { return static_cast<bool>(ellpack_); }
src/data/simple_dmatrix.cc
@@ -31,6 +31,9 @@ const MetaInfo& SimpleDMatrix::Info() const { return info_; }
 DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
   auto out = new SimpleDMatrix;
   SparsePage& out_page = *out->sparse_page_;
+  // Convert to uint64 to avoid a breaking change in the C API. The performance impact is
+  // small since we have to iteratve through the sparse page.
+  std::vector<bst_idx_t> h_ridx(ridxs.data(), ridxs.data() + ridxs.size());
   for (auto const& page : this->GetBatches<SparsePage>()) {
     auto batch = page.GetView();
     auto& h_data = out_page.data.HostVector();
@@ -42,8 +45,8 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
       std::copy(inst.begin(), inst.end(), std::back_inserter(h_data));
       h_offset.emplace_back(rptr);
     }
-    out->Info() = this->Info().Slice(ridxs);
-    out->Info().num_nonzero_ = h_offset.back();
+    auto ctx = this->fmat_ctx_.MakeCPU();
+    out->Info() = this->Info().Slice(&ctx, h_ridx, h_offset.back());
   }
   out->fmat_ctx_ = this->fmat_ctx_;
   return out;
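Note that SimpleDMatrix::Slice keeps its int32_t span so the C API is unchanged; the indices are widened into a bst_idx_t vector up front, and the new MetaInfo::Slice(ctx, ridxs, nnz) overload is called with the non-zero count taken from the freshly built sparse page's offsets.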
src/tree/gpu_hist/gradient_based_sampler.cu
@@ -14,12 +14,12 @@
 
 #include "../../common/cuda_context.cuh"  // for CUDAContext
 #include "../../common/random.h"
+#include "../../data/ellpack_page.cuh"     // for EllpackPageImpl
+#include "../../data/iterative_dmatrix.h"  // for IterativeDMatrix
 #include "../param.h"
 #include "gradient_based_sampler.cuh"
 
-namespace xgboost {
-namespace tree {
+namespace xgboost::tree {
 
 /*! \brief A functor that returns random weights. */
 class RandomWeight : public thrust::unary_function<size_t, float> {
  public:
@@ -58,12 +58,14 @@ struct IsNonZero : public thrust::unary_function<GradientPair, bool> {
 };
 
 /*! \brief A functor that clears the row indexes with empty gradient. */
-struct ClearEmptyRows : public thrust::binary_function<GradientPair, size_t, size_t> {
+struct ClearEmptyRows : public thrust::binary_function<GradientPair, bst_idx_t, bst_idx_t> {
+  static constexpr bst_idx_t InvalidRow() { return std::numeric_limits<std::size_t>::max(); }
+
   XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const {
     if (gpair.GetGrad() != 0 || gpair.GetHess() != 0) {
       return row_index;
     } else {
-      return std::numeric_limits<std::size_t>::max();
+      return InvalidRow();
     }
   }
 };
@@ -148,10 +150,9 @@ class PoissonSampling : public thrust::binary_function<GradientPair, size_t, Gra
 
 NoSampling::NoSampling(BatchParam batch_param) : batch_param_(std::move(batch_param)) {}
 
-GradientBasedSample NoSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
+GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair> gpair,
                                        DMatrix* dmat) {
-  auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
-  return {dmat->Info().num_row_, page, gpair};
+  return {dmat->Info().num_row_, dmat, gpair};
 }
 
 ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
@@ -159,37 +160,39 @@ ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
 
 GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
                                                      common::Span<GradientPair> gpair,
-                                                     DMatrix* dmat) {
+                                                     DMatrix* p_fmat) {
+  std::shared_ptr<EllpackPage> new_page;
   if (!page_concatenated_) {
     // Concatenate all the external memory ELLPACK pages into a single in-memory page.
-    page_.reset(nullptr);
     bst_idx_t offset = 0;
-    for (auto& batch : dmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
+    for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
       auto page = batch.Impl();
-      if (!page_) {
-        page_ = std::make_unique<EllpackPageImpl>(ctx, page->CutsShared(), page->is_dense,
-                                                  page->row_stride, dmat->Info().num_row_);
+      if (!new_page) {
+        new_page = std::make_shared<EllpackPage>();
+        *new_page->Impl() = EllpackPageImpl(ctx, page->CutsShared(), page->is_dense,
                                            page->row_stride, p_fmat->Info().num_row_);
      }
-      bst_idx_t num_elements = page_->Copy(ctx, page, offset);
+      bst_idx_t num_elements = new_page->Impl()->Copy(ctx, page, offset);
      offset += num_elements;
    }
    page_concatenated_ = true;
+    this->p_fmat_new_ =
+        std::make_unique<data::IterativeDMatrix>(new_page, p_fmat->Info(), batch_param_);
  }
-  return {dmat->Info().num_row_, page_.get(), gpair};
+  return {p_fmat->Info().num_row_, this->p_fmat_new_.get(), gpair};
 }
 
 UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
-    : batch_param_{std::move(batch_param)}, subsample_(subsample) {}
+    : batch_param_{std::move(batch_param)}, subsample_{subsample} {}
 
 GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
-                                            DMatrix* dmat) {
+                                            DMatrix* p_fmat) {
   // Set gradient pair to 0 with p = 1 - subsample
   auto cuctx = ctx->CUDACtx();
   thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
                      thrust::counting_iterator<std::size_t>(0),
                      BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair());
-  auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
-  return {dmat->Info().num_row_, page, gpair};
+  return {p_fmat->Info().num_row_, p_fmat, gpair};
 }
 
 ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
@@ -203,13 +206,17 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
                                                           common::Span<GradientPair> gpair,
                                                           DMatrix* dmat) {
   auto cuctx = ctx->CUDACtx();
+
+  std::shared_ptr<EllpackPage> new_page = std::make_shared<EllpackPage>();
+  auto page = new_page->Impl();
+
   // Set gradient pair to 0 with p = 1 - subsample
   thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
                      thrust::counting_iterator<std::size_t>(0),
                      BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair{});
 
   // Count the sampled rows.
-  size_t sample_rows =
+  bst_idx_t sample_rows =
       thrust::count_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), IsNonZero{});
 
   // Compact gradient pairs.
@@ -227,17 +234,25 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
   auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
   auto first_page = (*batch_iterator.begin()).Impl();
   // Create a new ELLPACK page with empty rows.
-  page_.reset();  // Release the device memory first before reallocating
-  page_.reset(new EllpackPageImpl(ctx, first_page->CutsShared(), first_page->is_dense,
-                                  first_page->row_stride, sample_rows));
+  *page = EllpackPageImpl{ctx, first_page->CutsShared(), first_page->is_dense,
+                          first_page->row_stride, sample_rows};
 
   // Compact the ELLPACK pages into the single sample page.
-  thrust::fill(cuctx->CTP(), page_->gidx_buffer.begin(), page_->gidx_buffer.end(), 0);
+  thrust::fill(cuctx->CTP(), page->gidx_buffer.begin(), page->gidx_buffer.end(), 0);
   for (auto& batch : batch_iterator) {
-    page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
+    page->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
   }
-  return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
+  // Select the metainfo
+  dmat->Info().feature_types.SetDevice(ctx->Device());
+  auto nnz = page->NumNonMissing(ctx, dmat->Info().feature_types.ConstDeviceSpan());
+  compact_row_index_.resize(sample_rows);
+  thrust::copy_if(
+      cuctx->TP(), sample_row_index_.cbegin(), sample_row_index_.cend(), compact_row_index_.begin(),
+      [] XGBOOST_DEVICE(std::size_t idx) { return idx != ClearEmptyRows::InvalidRow(); });
+  // Create the new DMatrix
+  this->p_fmat_new_ = std::make_unique<data::IterativeDMatrix>(
+      new_page, dmat->Info().Slice(ctx, dh::ToSpan(compact_row_index_), nnz), batch_param_);
+  return {sample_rows, this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
 }
 
 GradientBasedSampling::GradientBasedSampling(std::size_t n_rows, BatchParam batch_param,
@@ -254,14 +269,12 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
   size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
       ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
 
-  auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
-
   // Perform Poisson sampling in place.
   thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
                     thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
                     PoissonSampling(dh::ToSpan(threshold_), threshold_index,
                                     RandomWeight(common::GlobalRandom()())));
-  return {n_rows, page, gpair};
+  return {n_rows, dmat, gpair};
 }
 
 ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(size_t n_rows,
@@ -277,6 +290,8 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
                                                                 common::Span<GradientPair> gpair,
                                                                 DMatrix* dmat) {
   auto cuctx = ctx->CUDACtx();
+  std::shared_ptr<EllpackPage> new_page = std::make_shared<EllpackPage>();
+  auto page = new_page->Impl();
   bst_idx_t n_rows = dmat->Info().num_row_;
   size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
       ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
@@ -293,24 +308,33 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
   thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
   // Index the sample rows.
   thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
-                    IsNonZero());
+                    IsNonZero{});
   thrust::exclusive_scan(cuctx->CTP(), sample_row_index_.begin(), sample_row_index_.end(),
                          sample_row_index_.begin());
   thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
-                    sample_row_index_.begin(), ClearEmptyRows());
+                    sample_row_index_.begin(), ClearEmptyRows{});
   auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
   auto first_page = (*batch_iterator.begin()).Impl();
   // Create a new ELLPACK page with empty rows.
-  page_.reset();  // Release the device memory first before reallocating
-  page_.reset(new EllpackPageImpl{ctx, first_page->CutsShared(), dmat->IsDense(),
-                                  first_page->row_stride, sample_rows});
-  // Compact the ELLPACK pages into the single sample page.
-  thrust::fill(cuctx->CTP(), page_->gidx_buffer.begin(), page_->gidx_buffer.end(), 0);
-  for (auto& batch : batch_iterator) {
-    page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
-  }
-
-  return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
+  *page = EllpackPageImpl{ctx, first_page->CutsShared(), dmat->IsDense(), first_page->row_stride,
+                          sample_rows};
+  // Compact the ELLPACK pages into the single sample page.
+  thrust::fill(cuctx->CTP(), page->gidx_buffer.begin(), page->gidx_buffer.end(), 0);
+  for (auto& batch : batch_iterator) {
+    page->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
+  }
+  // Select the metainfo
+  dmat->Info().feature_types.SetDevice(ctx->Device());
+  auto nnz = page->NumNonMissing(ctx, dmat->Info().feature_types.ConstDeviceSpan());
+  compact_row_index_.resize(sample_rows);
+  thrust::copy_if(
+      cuctx->TP(), sample_row_index_.cbegin(), sample_row_index_.cend(), compact_row_index_.begin(),
+      [] XGBOOST_DEVICE(std::size_t idx) { return idx != ClearEmptyRows::InvalidRow(); });
+  // Create the new DMatrix
+  this->p_fmat_new_ = std::make_unique<data::IterativeDMatrix>(
+      new_page, dmat->Info().Slice(ctx, dh::ToSpan(compact_row_index_), nnz), batch_param_);
+  return {sample_rows, this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
 }
 
 GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows,
@@ -378,5 +402,4 @@ size_t GradientBasedSampler::CalculateThresholdIndex(Context const* ctx,
       thrust::min_element(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum));
   return thrust::distance(dh::tbegin(grad_sum), min) + 1;
 }
-};  // namespace tree
-};  // namespace xgboost
+};  // namespace xgboost::tree
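Both external-memory samplers now share the same tail: compact the ELLPACK pages into new_page, compute the number of non-missing values with NumNonMissing, gather the surviving entries of sample_row_index_ into compact_row_index_, slice the MetaInfo on device, and wrap page plus metadata into an IterativeDMatrix handed back through GradientBasedSample::p_fmat.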
src/tree/gpu_hist/gradient_based_sampler.cuh
@@ -5,7 +5,7 @@
 #include <cstddef>  // for size_t
 
 #include "../../common/device_vector.cuh"  // for device_vector, caching_device_vector
-#include "../../data/ellpack_page.cuh"  // for EllpackPageImpl
+#include "../../common/timer.h"  // for Monitor
 #include "xgboost/base.h"  // for GradientPair
 #include "xgboost/data.h"  // for BatchParam
 #include "xgboost/span.h"  // for Span
@@ -13,11 +13,11 @@
 namespace xgboost::tree {
 struct GradientBasedSample {
   /*!\brief Number of sampled rows. */
-  std::size_t sample_rows;
+  bst_idx_t sample_rows;
   /*!\brief Sampled rows in ELLPACK format. */
-  EllpackPageImpl const* page;
+  DMatrix* p_fmat;
   /*!\brief Gradient pairs for the sampled rows. */
-  common::Span<GradientPair> gpair;
+  common::Span<GradientPair const> gpair;
 };
 
 class SamplingStrategy {
@@ -48,7 +48,7 @@ class ExternalMemoryNoSampling : public SamplingStrategy {
 
  private:
   BatchParam batch_param_;
-  std::unique_ptr<EllpackPageImpl> page_{nullptr};
+  std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
   bool page_concatenated_{false};
 };
 
@@ -74,9 +74,10 @@ class ExternalMemoryUniformSampling : public SamplingStrategy {
  private:
   BatchParam batch_param_;
   float subsample_;
-  std::unique_ptr<EllpackPageImpl> page_;
+  std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
   dh::device_vector<GradientPair> gpair_{};
-  dh::caching_device_vector<size_t> sample_row_index_;
+  dh::caching_device_vector<bst_idx_t> sample_row_index_;
+  dh::device_vector<bst_idx_t> compact_row_index_;
 };
 
 /*! \brief Gradient-based sampling in in-memory mode.. */
@@ -105,9 +106,10 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
   float subsample_;
   dh::device_vector<float> threshold_;
   dh::device_vector<float> grad_sum_;
-  std::unique_ptr<EllpackPageImpl> page_;
+  std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
   dh::device_vector<GradientPair> gpair_;
-  dh::device_vector<size_t> sample_row_index_;
+  dh::device_vector<bst_idx_t> sample_row_index_;
+  dh::device_vector<bst_idx_t> compact_row_index_;
 };
 
 /*! \brief Draw a sample of rows from a DMatrix.
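With p_fmat in GradientBasedSample, the in-memory strategies return the input DMatrix unchanged while the external-memory strategies return the freshly built IterativeDMatrix they own (p_fmat_new_); the gradient span also becomes const, since consumers only read it.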
src/tree/updater_gpu_common.cuh
@@ -119,9 +119,9 @@ struct DeviceSplitCandidate {
 };
 
 namespace cuda_impl {
-inline BatchParam HistBatch(TrainParam const& param, bool prefetch_copy = true) {
+inline BatchParam HistBatch(TrainParam const& param) {
   auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()};
-  p.prefetch_copy = prefetch_copy;
+  p.prefetch_copy = true;
   p.n_prefetch_batches = 1;
   return p;
 }
@@ -134,6 +134,14 @@ inline BatchParam ApproxBatch(TrainParam const& p, common::Span<float const> hes
                               ObjInfo const& task) {
   return BatchParam{p.max_bin, hess, !task.const_hess};
 }
+
+// Empty parameter to prevent regen, only used to control external memory prefetching.
+inline BatchParam StaticBatch(bool prefetch_copy) {
+  BatchParam p;
+  p.prefetch_copy = prefetch_copy;
+  p.n_prefetch_batches = 1;
+  return p;
+}
 }  // namespace cuda_impl
 
 template <typename T>
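A sketch of the intended use (this mirrors the loops added in updater_gpu_hist.cu below):

    // Iterate cached ELLPACK batches without triggering a regen; the empty
    // BatchParam only carries prefetching hints for external memory.
    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx, cuda_impl::StaticBatch(true))) {
      auto const* impl = page.Impl();
      // ... build histograms, update positions, etc.
    }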
@ -52,6 +52,11 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
|||||||
using cuda_impl::ApproxBatch;
|
using cuda_impl::ApproxBatch;
|
||||||
using cuda_impl::HistBatch;
|
using cuda_impl::HistBatch;
|
||||||
|
|
||||||
|
// Both the approx and hist initializes the DMatrix before creating the actual
|
||||||
|
// implementation (InitDataOnce). Therefore, the `GPUHistMakerDevice` can use an empty
|
||||||
|
// parameter to avoid any regen.
|
||||||
|
using cuda_impl::StaticBatch;
|
||||||
|
|
||||||
// GPU tree updater implementation.
|
// GPU tree updater implementation.
|
||||||
struct GPUHistMakerDevice {
|
struct GPUHistMakerDevice {
|
||||||
private:
|
private:
|
||||||
@ -64,6 +69,7 @@ struct GPUHistMakerDevice {
|
|||||||
// node idx for each sample
|
// node idx for each sample
|
||||||
dh::device_vector<bst_node_t> positions_;
|
dh::device_vector<bst_node_t> positions_;
|
||||||
std::unique_ptr<RowPartitioner> row_partitioner_;
|
std::unique_ptr<RowPartitioner> row_partitioner_;
|
||||||
|
std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// Extra data for each node that is passed to the update position function
|
// Extra data for each node that is passed to the update position function
|
||||||
@ -75,13 +81,12 @@ struct GPUHistMakerDevice {
|
|||||||
static_assert(std::is_trivially_copyable_v<NodeSplitData>);
|
static_assert(std::is_trivially_copyable_v<NodeSplitData>);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
EllpackPageImpl const* page{nullptr};
|
|
||||||
common::Span<FeatureType const> feature_types;
|
common::Span<FeatureType const> feature_types;
|
||||||
|
|
||||||
DeviceHistogramStorage<> hist{};
|
DeviceHistogramStorage<> hist{};
|
||||||
|
|
||||||
dh::device_vector<GradientPair> d_gpair; // storage for gpair;
|
dh::device_vector<GradientPair> d_gpair; // storage for gpair;
|
||||||
common::Span<GradientPair> gpair;
|
common::Span<GradientPair const> gpair;
|
||||||
|
|
||||||
dh::device_vector<int> monotone_constraints;
|
dh::device_vector<int> monotone_constraints;
|
||||||
|
|
||||||
@ -99,18 +104,20 @@ struct GPUHistMakerDevice {
|
|||||||
|
|
||||||
std::unique_ptr<FeatureGroups> feature_groups;
|
std::unique_ptr<FeatureGroups> feature_groups;
|
||||||
|
|
||||||
GPUHistMakerDevice(Context const* ctx, bool is_external_memory,
|
GPUHistMakerDevice(Context const* ctx, std::shared_ptr<common::HistogramCuts const> cuts,
|
||||||
common::Span<FeatureType const> _feature_types, bst_idx_t _n_rows,
|
bool is_external_memory, common::Span<FeatureType const> _feature_types,
|
||||||
TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
|
TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
|
||||||
uint32_t n_features, BatchParam batch_param, MetaInfo const& info)
|
BatchParam batch_param, MetaInfo const& info)
|
||||||
: evaluator_{_param, n_features, ctx->Device()},
|
: evaluator_{_param, static_cast<bst_feature_t>(info.num_col_), ctx->Device()},
|
||||||
ctx_(ctx),
|
ctx_(ctx),
|
||||||
feature_types{_feature_types},
|
feature_types{_feature_types},
|
||||||
param(std::move(_param)),
|
param(std::move(_param)),
|
||||||
column_sampler_(std::move(column_sampler)),
|
column_sampler_(std::move(column_sampler)),
|
||||||
interaction_constraints(param, n_features),
|
interaction_constraints(param, info.num_col_),
|
||||||
info_{info} {
|
info_{info},
|
||||||
sampler = std::make_unique<GradientBasedSampler>(ctx, _n_rows, batch_param, param.subsample,
|
cuts_{std::move(cuts)} {
|
||||||
|
sampler =
|
||||||
|
std::make_unique<GradientBasedSampler>(ctx, info.num_row_, batch_param, param.subsample,
|
||||||
param.sampling_method, is_external_memory);
|
param.sampling_method, is_external_memory);
|
||||||
if (!param.monotone_constraints.empty()) {
|
if (!param.monotone_constraints.empty()) {
|
||||||
// Copy assigning an empty vector causes an exception in MSVC debug builds
|
// Copy assigning an empty vector causes an exception in MSVC debug builds
|
||||||
@ -123,19 +130,19 @@ struct GPUHistMakerDevice {
|
|||||||
|
|
||||||
~GPUHistMakerDevice() = default;
|
~GPUHistMakerDevice() = default;
|
||||||
|
|
||||||
void InitFeatureGroupsOnce() {
|
void InitFeatureGroupsOnce(MetaInfo const& info) {
|
||||||
if (!feature_groups) {
|
if (!feature_groups) {
|
||||||
CHECK(page);
|
CHECK(cuts_);
|
||||||
feature_groups = std::make_unique<FeatureGroups>(page->Cuts(), page->is_dense,
|
feature_groups = std::make_unique<FeatureGroups>(*cuts_, info.IsDense(),
|
||||||
dh::MaxSharedMemoryOptin(ctx_->Ordinal()),
|
dh::MaxSharedMemoryOptin(ctx_->Ordinal()),
|
||||||
sizeof(GradientPairPrecise));
|
sizeof(GradientPairPrecise));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset values for each update iteration
|
// Reset values for each update iteration
|
||||||
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
|
[[nodiscard]] DMatrix* Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* p_fmat) {
|
||||||
auto const& info = dmat->Info();
|
auto const& info = p_fmat->Info();
|
||||||
this->column_sampler_->Init(ctx_, num_columns, info.feature_weights.HostVector(),
|
this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(),
|
||||||
param.colsample_bynode, param.colsample_bylevel,
|
param.colsample_bynode, param.colsample_bylevel,
|
||||||
param.colsample_bytree);
|
param.colsample_bytree);
|
||||||
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
|
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
|
||||||
@ -148,54 +155,54 @@ struct GPUHistMakerDevice {
|
|||||||
dh::safe_cuda(cudaMemcpyAsync(d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
|
dh::safe_cuda(cudaMemcpyAsync(d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
|
||||||
dh_gpair->Size() * sizeof(GradientPair),
|
dh_gpair->Size() * sizeof(GradientPair),
|
||||||
cudaMemcpyDeviceToDevice));
|
cudaMemcpyDeviceToDevice));
|
||||||
auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), dmat);
|
auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), p_fmat);
|
||||||
page = sample.page;
|
this->gpair = sample.gpair;
|
||||||
gpair = sample.gpair;
|
p_fmat = sample.p_fmat;
|
||||||
|
CHECK(p_fmat->SingleColBlock());
|
||||||
|
|
||||||
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
|
this->evaluator_.Reset(*cuts_, feature_types, p_fmat->Info().num_col_, param,
|
||||||
dmat->Info().IsColumnSplit(), ctx_->Device());
|
p_fmat->Info().IsColumnSplit(), ctx_->Device());
|
||||||
|
|
||||||
quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());
|
quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, p_fmat->Info());
|
||||||
|
|
||||||
if (!row_partitioner_) {
|
if (!row_partitioner_) {
|
||||||
row_partitioner_ = std::make_unique<RowPartitioner>();
|
row_partitioner_ = std::make_unique<RowPartitioner>();
|
||||||
}
|
}
|
||||||
row_partitioner_->Reset(ctx_, sample.sample_rows, page->base_rowid);
|
row_partitioner_->Reset(ctx_, sample.sample_rows, 0);
|
||||||
CHECK_EQ(page->base_rowid, 0);
|
|
||||||
|
|
||||||
// Init histogram
|
// Init histogram
|
||||||
hist.Init(ctx_->Device(), page->Cuts().TotalBins());
|
hist.Init(ctx_->Device(), this->cuts_->TotalBins());
|
||||||
hist.Reset(ctx_);
|
hist.Reset(ctx_);
|
||||||
|
|
||||||
this->InitFeatureGroupsOnce();
|
this->InitFeatureGroupsOnce(info);
|
||||||
|
|
||||||
this->histogram_.Reset(ctx_, feature_groups->DeviceAccessor(ctx_->Device()), false);
|
this->histogram_.Reset(ctx_, feature_groups->DeviceAccessor(ctx_->Device()), false);
|
||||||
|
return p_fmat;
|
||||||
}
|
}
|
||||||
|
|
||||||
GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {
|
GPUExpandEntry EvaluateRootSplit(DMatrix const * p_fmat, GradientPairInt64 root_sum) {
|
||||||
int nidx = RegTree::kRoot;
|
int nidx = RegTree::kRoot;
|
||||||
GPUTrainingParam gpu_param(param);
|
GPUTrainingParam gpu_param(param);
|
||||||
auto sampled_features = column_sampler_->GetFeatureSet(0);
|
auto sampled_features = column_sampler_->GetFeatureSet(0);
|
||||||
sampled_features->SetDevice(ctx_->Device());
|
sampled_features->SetDevice(ctx_->Device());
|
||||||
common::Span<bst_feature_t> feature_set =
|
common::Span<bst_feature_t> feature_set =
|
||||||
interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
|
interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
|
||||||
auto matrix = page->GetDeviceAccessor(ctx_->Device());
|
|
||||||
EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, hist.GetNodeHistogram(nidx)};
|
EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, hist.GetNodeHistogram(nidx)};
|
||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{
|
||||||
gpu_param,
|
gpu_param,
|
||||||
*quantiser,
|
*quantiser,
|
||||||
feature_types,
|
feature_types,
|
||||||
matrix.feature_segments,
|
cuts_->cut_ptrs_.ConstDeviceSpan(),
|
||||||
matrix.gidx_fvalue_map,
|
cuts_->cut_values_.ConstDeviceSpan(),
|
||||||
matrix.min_fvalue,
|
cuts_->min_vals_.ConstDeviceSpan(),
|
||||||
matrix.is_dense && !collective::IsDistributed()
|
p_fmat->IsDense() && !collective::IsDistributed()
|
||||||
};
|
};
|
||||||
auto split = this->evaluator_.EvaluateSingleSplit(ctx_, inputs, shared_inputs);
|
auto split = this->evaluator_.EvaluateSingleSplit(ctx_, inputs, shared_inputs);
|
||||||
return split;
|
return split;
|
||||||
}
|
}
|
||||||
|
|
||||||
void EvaluateSplits(const std::vector<GPUExpandEntry>& candidates, const RegTree& tree,
|
void EvaluateSplits(DMatrix const* p_fmat, const std::vector<GPUExpandEntry>& candidates,
|
||||||
common::Span<GPUExpandEntry> pinned_candidates_out) {
|
const RegTree& tree, common::Span<GPUExpandEntry> pinned_candidates_out) {
|
||||||
if (candidates.empty()) {
|
if (candidates.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -204,12 +211,11 @@ struct GPUHistMakerDevice {
|
|||||||
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2 * candidates.size());
|
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2 * candidates.size());
|
||||||
std::vector<bst_node_t> nidx(2 * candidates.size());
|
std::vector<bst_node_t> nidx(2 * candidates.size());
|
||||||
auto h_node_inputs = pinned2.GetSpan<EvaluateSplitInputs>(2 * candidates.size());
|
auto h_node_inputs = pinned2.GetSpan<EvaluateSplitInputs>(2 * candidates.size());
|
||||||
auto matrix = page->GetDeviceAccessor(ctx_->Device());
|
EvaluateSplitSharedInputs shared_inputs{
|
||||||
EvaluateSplitSharedInputs shared_inputs{GPUTrainingParam{param}, *quantiser, feature_types,
|
GPUTrainingParam{param}, *quantiser, feature_types, cuts_->cut_ptrs_.ConstDeviceSpan(),
|
||||||
matrix.feature_segments, matrix.gidx_fvalue_map,
|
cuts_->cut_values_.ConstDeviceSpan(), cuts_->min_vals_.ConstDeviceSpan(),
|
||||||
matrix.min_fvalue,
|
|
||||||
// is_dense represents the local data
|
// is_dense represents the local data
|
||||||
matrix.is_dense && !collective::IsDistributed()};
|
p_fmat->IsDense() && !collective::IsDistributed()};
|
||||||
dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
|
dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
|
||||||
// Store the feature set ptrs so they dont go out of scope before the kernel is called
|
// Store the feature set ptrs so they dont go out of scope before the kernel is called
|
||||||
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_sets;
|
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_sets;
|
||||||
@ -254,7 +260,7 @@ struct GPUHistMakerDevice {
|
|||||||
this->monitor.Stop(__func__);
|
this->monitor.Stop(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
void BuildHist(int nidx) {
|
void BuildHist(EllpackPageImpl const* page, int nidx) {
|
||||||
auto d_node_hist = hist.GetNodeHistogram(nidx);
|
auto d_node_hist = hist.GetNodeHistogram(nidx);
|
||||||
auto d_ridx = row_partitioner_->GetRows(nidx);
|
auto d_ridx = row_partitioner_->GetRows(nidx);
|
||||||
this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
|
this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
|
||||||
@ -272,9 +278,8 @@ struct GPUHistMakerDevice {
|
|||||||
auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
|
auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
|
||||||
auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
|
auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
|
||||||
|
|
||||||
dh::LaunchN(page->Cuts().TotalBins(), [=] __device__(size_t idx) {
|
dh::LaunchN(cuts_->TotalBins(), [=] __device__(size_t idx) {
|
||||||
d_node_hist_subtraction[idx] =
|
d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx];
|
||||||
d_node_hist_parent[idx] - d_node_hist_histogram[idx];
|
|
||||||
});
|
});
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -366,7 +371,8 @@ struct GPUHistMakerDevice {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void UpdatePosition(std::vector<GPUExpandEntry> const& candidates, RegTree* p_tree) {
|
void UpdatePosition(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
|
||||||
|
RegTree* p_tree) {
|
||||||
if (candidates.empty()) {
|
if (candidates.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -390,7 +396,8 @@ struct GPUHistMakerDevice {
|
|||||||
CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
|
CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto d_matrix = page->GetDeviceAccessor(ctx_->Device());
|
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
|
||||||
|
auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
|
||||||
|
|
||||||
if (info_.IsColumnSplit()) {
|
if (info_.IsColumnSplit()) {
|
||||||
UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
|
UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
|
||||||
@ -402,18 +409,20 @@ struct GPUHistMakerDevice {
|
|||||||
nidx, left_nidx, right_nidx, split_data,
|
nidx, left_nidx, right_nidx, split_data,
|
||||||
[=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
|
[=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
|
||||||
const NodeSplitData& data) { return go_left(ridx, data); });
|
const NodeSplitData& data) { return go_left(ridx, data); });
|
||||||
|
}
|
||||||
|
|
||||||
monitor.Stop(__func__);
|
monitor.Stop(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
// After tree update is finished, update the position of all training
|
// After tree update is finished, update the position of all training
|
||||||
// instances to their final leaf. This information is used later to update the
|
// instances to their final leaf. This information is used later to update the
|
||||||
// prediction cache
|
// prediction cache
|
||||||
void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat, ObjInfo task,
|
void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task, bst_idx_t n_samples,
|
||||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||||
if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
|
if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
|
||||||
LOG(FATAL) << "Current objective function can not be used with external memory.";
|
LOG(FATAL) << "Current objective function can not be used with external memory.";
|
||||||
}
|
}
|
||||||
if (p_fmat->Info().num_row_ != row_partitioner_->GetRows().size()) {
|
if (p_fmat->Info().num_row_ != n_samples) {
|
||||||
// Subsampling with external memory. Not supported.
|
// Subsampling with external memory. Not supported.
|
||||||
p_out_position->Resize(0);
|
p_out_position->Resize(0);
|
||||||
positions_.clear();
|
positions_.clear();
|
||||||
@ -438,11 +447,12 @@ struct GPUHistMakerDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dh::caching_device_vector<uint32_t> categories;
|
dh::caching_device_vector<uint32_t> categories;
|
||||||
dh::CopyToD(p_tree->GetSplitCategories(), &categories);
|
dh::CopyTo(p_tree->GetSplitCategories(), &categories);
|
||||||
auto const& cat_segments = p_tree->GetSplitCategoriesPtr();
|
auto const& cat_segments = p_tree->GetSplitCategoriesPtr();
|
||||||
auto d_categories = dh::ToSpan(categories);
|
auto d_categories = dh::ToSpan(categories);
|
||||||
|
|
||||||
auto d_matrix = page->GetDeviceAccessor(ctx_->Device());
|
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
|
||||||
|
auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
|
||||||
|
|
||||||
std::vector<NodeSplitData> split_data(p_tree->NumNodes());
|
std::vector<NodeSplitData> split_data(p_tree->NumNodes());
|
||||||
auto const& tree = *p_tree;
|
auto const& tree = *p_tree;
|
||||||
@ -469,6 +479,8 @@ struct GPUHistMakerDevice {
|
|||||||
}
|
}
|
||||||
return encode_op(row_id, nidx);
|
return encode_op(row_id, nidx);
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
dh::CopyTo(d_out_position, &positions_);
|
dh::CopyTo(d_out_position, &positions_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -508,7 +520,7 @@ struct GPUHistMakerDevice {
|
|||||||
auto rc = collective::GlobalSum(
|
auto rc = collective::GlobalSum(
|
||||||
ctx_, info_,
|
ctx_, info_,
|
||||||
linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist),
|
linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist),
|
||||||
page->Cuts().TotalBins() * 2 * num_histograms, ctx_->Device()));
|
cuts_->TotalBins() * 2 * num_histograms, ctx_->Device()));
|
||||||
SafeColl(rc);
|
SafeColl(rc);
|
||||||
|
|
||||||
monitor.Stop("AllReduce");
|
monitor.Stop("AllReduce");
|
||||||
@ -517,7 +529,8 @@ struct GPUHistMakerDevice {
  /**
   * \brief Build GPU local histograms for the left and right child of some parent node
   */
  void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, const RegTree& tree) {
  void BuildHistLeftRight(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
                          const RegTree& tree) {
    if (candidates.empty()) {
      return;
    }
@ -544,8 +557,10 @@ struct GPUHistMakerDevice {
    // Guaranteed contiguous memory
    hist.AllocateHistograms(ctx_, all_new);

    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
    for (auto nidx : hist_nidx) {
      this->BuildHist(nidx);
      this->BuildHist(page.Impl(), nidx);
    }
    }

    // Reduce all in one go
@ -560,7 +575,9 @@ struct GPUHistMakerDevice {

    if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
      // Calculate other histogram manually
      this->BuildHist(subtraction_trick_nidx);
      for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
        this->BuildHist(page.Impl(), subtraction_trick_nidx);
      }
      this->AllReduceHist(subtraction_trick_nidx, 1);
    }
  }
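The same pattern repeats across these hunks: because the sampler now hands back a full DMatrix, every histogram build iterates the Ellpack batches and accumulates per-page contributions instead of reading one cached page. A minimal sketch of the pattern, assuming BuildHist(page, nidx) adds the page's partial counts into the node histogram (names taken from the hunks above, surrounding class elided):

// Accumulate node histograms over all Ellpack batches, then reduce once,
// assuming the histograms for nidx_set are allocated contiguously.
void BuildNodeHistograms(DMatrix* p_fmat, std::vector<bst_node_t> const& nidx_set) {
  for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
    for (auto nidx : nidx_set) {
      this->BuildHist(page.Impl(), nidx);  // partial histogram for this page
    }
  }
  this->AllReduceHist(nidx_set.front(), nidx_set.size());  // single reduction
}

With a single in-core page the outer loop body runs once, so the in-core behaviour is unchanged.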
@ -595,7 +612,7 @@ struct GPUHistMakerDevice {
    std::vector<common::CatBitField::value_type> split_cats;

    auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
    auto n_bins_feature = page->Cuts().FeatureBins(candidate.split.findex);
    auto n_bins_feature = cuts_->FeatureBins(candidate.split.findex);
    split_cats.resize(common::CatBitField::ComputeStorageSize(n_bins_feature), 0);
    CHECK_LE(split_cats.size(), h_cats.size());
    std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
@ -618,7 +635,7 @@ struct GPUHistMakerDevice {
                        parent.RightChild());
  }

  GPUExpandEntry InitRoot(RegTree* p_tree) {
  GPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree) {
    constexpr bst_node_t kRootNIdx = 0;
    dh::XGBCachingDeviceAllocator<char> alloc;
    auto quantiser = *this->quantiser;
@ -635,7 +652,9 @@ struct GPUHistMakerDevice {
    collective::SafeColl(rc);

    hist.AllocateHistograms(ctx_, {kRootNIdx});
    this->BuildHist(kRootNIdx);
    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
      this->BuildHist(page.Impl(), kRootNIdx);
    }
    this->AllReduceHist(kRootNIdx, 1);

    // Remember root stats
@ -646,24 +665,25 @@ struct GPUHistMakerDevice {
    (*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);

    // Generate first split
    auto root_entry = this->EvaluateRootSplit(root_sum_quantised);
    auto root_entry = this->EvaluateRootSplit(p_fmat, root_sum_quantised);
    return root_entry;
  }

  void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
                  RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
    bool const is_single_block = p_fmat->SingleColBlock();
    bst_idx_t const n_samples = p_fmat->Info().num_row_;

    auto& tree = *p_tree;
    // Process maximum 32 nodes at a time
    Driver<GPUExpandEntry> driver(param, 32);

    monitor.Start("Reset");
    this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
    p_fmat = this->Reset(gpair_all, p_fmat);
    monitor.Stop("Reset");

    monitor.Start("InitRoot");
    driver.Push({this->InitRoot(p_tree)});
    driver.Push({this->InitRoot(p_fmat, p_tree)});
    monitor.Stop("InitRoot");

    // The set of leaves that can be expanded asynchronously
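This is the heart of the change on the updater side: Reset now returns the DMatrix to train on, so when gradient-based sampling is active the rest of UpdateTree transparently operates on the sampled matrix. A hedged sketch of the contract implied by the call site (the Reset body is not part of this excerpt):

// The returned pointer is either the caller's p_fmat (no sampling) or a
// DMatrix built by the GPU sampler around the sampled rows. Note that
// n_samples is captured before the swap, which is how FinalisePosition
// detects subsampling later.
DMatrix* Reset(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat);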
@ -683,11 +703,11 @@ struct GPUHistMakerDevice {
      // Update all the nodes if working with external memory, this saves us from working
      // with the finalize position call, which adds an additional iteration and requires
      // special handling for row index.
      this->UpdatePosition(is_single_block ? filtered_expand_set : expand_set, p_tree);
      this->UpdatePosition(p_fmat, is_single_block ? filtered_expand_set : expand_set, p_tree);

      this->BuildHistLeftRight(filtered_expand_set, tree);
      this->BuildHistLeftRight(p_fmat, filtered_expand_set, tree);

      this->EvaluateSplits(filtered_expand_set, *p_tree, new_candidates);
      this->EvaluateSplits(p_fmat, filtered_expand_set, *p_tree, new_candidates);
      dh::DefaultStream().Sync();

      driver.Push(new_candidates.begin(), new_candidates.end());
@ -701,7 +721,7 @@ struct GPUHistMakerDevice {
    if (is_single_block) {
      CHECK_GE(p_tree->NumNodes(), this->row_partitioner_->GetNumNodes());
    }
    this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
    this->FinalisePosition(p_fmat, p_tree, *task, n_samples, p_out_position);
  }
};
@ -750,9 +770,8 @@ class GPUHistMaker : public TreeUpdater {
    monitor_.Stop(__func__);
  }

  void InitDataOnce(TrainParam const* param, DMatrix* dmat) {
  void InitDataOnce(TrainParam const* param, DMatrix* p_fmat) {
    CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device";
    info_ = &dmat->Info();

    // Synchronise the column sampling seed
    uint32_t column_sampling_seed = common::GlobalRandom()();
@ -761,13 +780,19 @@ class GPUHistMaker : public TreeUpdater {
    SafeColl(rc);
    this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);

    dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
    info_->feature_types.SetDevice(ctx_->Device());
    maker = std::make_unique<GPUHistMakerDevice>(
        ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_,
        *param, column_sampler_, info_->num_col_, HistBatch(*param), dmat->Info());

    p_last_fmat_ = dmat;
    std::shared_ptr<common::HistogramCuts const> cuts;
    auto batch = HistBatch(*param);
    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, HistBatch(*param))) {
      cuts = page.Impl()->CutsShared();
    }

    dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
    p_fmat->Info().feature_types.SetDevice(ctx_->Device());
    maker = std::make_unique<GPUHistMakerDevice>(ctx_, cuts, !p_fmat->SingleColBlock(),
                                                 p_fmat->Info().feature_types.ConstDeviceSpan(),
                                                 *param, column_sampler_, batch, p_fmat->Info());

    p_last_fmat_ = p_fmat;
    initialised_ = true;
  }
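Both updaters now hand the device maker a shared pointer to the quantile cuts rather than row/column counts, and they obtain it by touching the Ellpack batches once. Every page of one DMatrix is built from the same quantile sketch, so keeping the last CutsShared() seen in the loop is sufficient. A minimal free-standing sketch of that idiom (the helper name is illustrative, not from this commit):

// Fetch the HistogramCuts shared by all Ellpack pages of a DMatrix.
std::shared_ptr<common::HistogramCuts const> GetCuts(Context const* ctx, DMatrix* p_fmat,
                                                     BatchParam const& batch) {
  std::shared_ptr<common::HistogramCuts const> cuts;
  for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx, batch)) {
    cuts = page.Impl()->CutsShared();  // identical across pages
  }
  return cuts;
}

Iterating all batches just to read the cuts also forces page generation up front, which is why the approx updater below can set batch.regen = false afterwards.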
@ -801,8 +826,6 @@ class GPUHistMaker : public TreeUpdater {
    return result;
  }

  MetaInfo* info_{}; // NOLINT

  std::unique_ptr<GPUHistMakerDevice> maker; // NOLINT

  [[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; }
@ -873,9 +896,15 @@ class GPUGlobalApproxMaker : public TreeUpdater {

    auto const& info = p_fmat->Info();
    info.feature_types.SetDevice(ctx_->Device());
    maker_ = std::make_unique<GPUHistMakerDevice>(
        ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_,
        *param, column_sampler_, info.num_col_, ApproxBatch(*param, hess, *task_), p_fmat->Info());
    std::shared_ptr<common::HistogramCuts const> cuts;
    auto batch = ApproxBatch(*param, hess, *task_);
    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, batch)) {
      cuts = page.Impl()->CutsShared();
    }
    batch.regen = false;  // Regen only at the beginning of the iteration.
    maker_ = std::make_unique<GPUHistMakerDevice>(ctx_, cuts, !p_fmat->SingleColBlock(),
                                                  info.feature_types.ConstDeviceSpan(), *param,
                                                  column_sampler_, batch, p_fmat->Info());

    std::size_t t_idx{0};
    for (xgboost::RegTree* tree : trees) {
@ -3,6 +3,7 @@
 */
#include <gtest/gtest.h>
#include <thrust/equal.h>                       // for equal
#include <thrust/iterator/constant_iterator.h>  // for make_constant_iterator
#include <thrust/sequence.h>                    // for sequence

#include "../../../src/common/cuda_context.cuh"
@ -83,6 +84,14 @@ void TestSlice() {
    }
  });
}

void TestWriteAccess(CUDAContext const* cuctx, linalg::TensorView<double, 3> t) {
  thrust::for_each(cuctx->CTP(), linalg::tbegin(t), linalg::tend(t),
                   [=] XGBOOST_DEVICE(double& v) { v = 0; });
  auto eq = thrust::equal(cuctx->CTP(), linalg::tcbegin(t), linalg::tcend(t),
                          thrust::make_constant_iterator<double>(0.0), thrust::equal_to<>{});
  ASSERT_TRUE(eq);
}
} // anonymous namespace

TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); }
@ -106,5 +115,7 @@ TEST(Linalg, GPUIter) {

  bool eq = thrust::equal(cuctx->CTP(), data.cbegin(), data.cend(), linalg::tcbegin(t));
  ASSERT_TRUE(eq);

  TestWriteAccess(cuctx, t);
}
} // namespace xgboost::linalg
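The new tbegin/tend pair gives mutable element access through the same transform-over-counting-iterator trick as the const variants: the functor maps a flat index through the view's (possibly strided) layout, so writes land in the tensor's storage even for non-contiguous views. A small hedged usage sketch mirroring the test above (tensor construction elided):

// Fill a 3-D view with ones on the device, then verify with the const iterators.
void FillOnes(CUDAContext const* cuctx, linalg::TensorView<double, 3> t) {
  thrust::for_each(cuctx->CTP(), linalg::tbegin(t), linalg::tend(t),
                   [=] XGBOOST_DEVICE(double& v) { v = 1.0; });
  bool ok = thrust::equal(cuctx->CTP(), linalg::tcbegin(t), linalg::tcend(t),
                          thrust::make_constant_iterator<double>(1.0));
  ASSERT_TRUE(ok);
}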
@ -1,5 +1,5 @@
/**
 * Copyright 2019-2023, XGBoost contributors
 * Copyright 2019-2024, XGBoost contributors
 */
#include <xgboost/base.h>
@ -15,7 +15,6 @@
#include "gtest/gtest.h"

namespace xgboost {

TEST(EllpackPage, EmptyDMatrix) {
  constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256;
  constexpr float kSparsity = 0;
@ -242,7 +241,7 @@ TEST(EllpackPage, Compact) {
namespace {
class EllpackPageTest : public testing::TestWithParam<float> {
 protected:
  void Run(float sparsity) {
  void TestFromGHistIndex(float sparsity) const {
    // Only testing with small sample size as the cuts might be different between host and
    // device.
    size_t n_samples{128}, n_features{13};
@ -273,9 +272,25 @@ class EllpackPageTest : public testing::TestWithParam<float> {
      }
    }
  }

  void TestNumNonMissing(float sparsity) const {
    size_t n_samples{1024}, n_features{13};
    auto ctx = MakeCUDACtx(0);
    auto p_fmat = RandomDataGenerator{n_samples, n_features, sparsity}.GenerateDMatrix(true);
    auto nnz = p_fmat->Info().num_nonzero_;
    for (auto const& page : p_fmat->GetBatches<EllpackPage>(
             &ctx, BatchParam{17, tree::TrainParam::DftSparseThreshold()})) {
      auto ellpack_nnz =
          page.Impl()->NumNonMissing(&ctx, p_fmat->Info().feature_types.ConstDeviceSpan());
      ASSERT_EQ(nnz, ellpack_nnz);
    }
  }
};
} // namespace

TEST_P(EllpackPageTest, FromGHistIndex) { this->Run(GetParam()); }
TEST_P(EllpackPageTest, FromGHistIndex) { this->TestFromGHistIndex(GetParam()); }

TEST_P(EllpackPageTest, NumNonMissing) { this->TestNumNonMissing(this->GetParam()); }

INSTANTIATE_TEST_SUITE_P(EllpackPage, EllpackPageTest, testing::Values(.0f, .2f, .4f, .8f));
} // namespace xgboost
@ -355,4 +355,70 @@ TEST(MetaInfo, HostExtend) {
}

TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(DeviceOrd::CPU()); }

namespace {
class TestMetaInfo : public ::testing::TestWithParam<std::tuple<bst_target_t, bool>> {
 public:
  void Run(Context const *ctx, bst_target_t n_targets) {
    MetaInfo info;
    info.num_row_ = 128;
    info.num_col_ = 3;
    info.feature_names.resize(info.num_col_, "a");
    info.labels.Reshape(info.num_row_, n_targets);

    HostDeviceVector<bst_idx_t> ridx(info.num_row_ / 2, 0);
    ridx.SetDevice(ctx->Device());
    auto h_ridx = ridx.HostSpan();
    for (std::size_t i = 0, j = 0; i < ridx.Size(); i++, j += 2) {
      h_ridx[i] = j;
    }

    {
      info.weights_.Resize(info.num_row_);
      auto h_w = info.weights_.HostSpan();
      std::iota(h_w.begin(), h_w.end(), 0);
    }

    auto out = info.Slice(ctx, ctx->IsCPU() ? h_ridx : ridx.ConstDeviceSpan(), /*nnz=*/256);

    ASSERT_EQ(info.labels.Device(), ctx->Device());
    auto h_y = info.labels.HostView();
    auto h_y_out = out.labels.HostView();
    ASSERT_EQ(h_y_out.Shape(0), ridx.Size());
    ASSERT_EQ(h_y_out.Shape(1), n_targets);

    auto h_w = info.weights_.ConstHostSpan();
    auto h_w_out = out.weights_.ConstHostSpan();
    ASSERT_EQ(h_w_out.size(), ridx.Size());

    for (std::size_t i = 0; i < ridx.Size(); ++i) {
      for (bst_target_t t = 0; t < n_targets; ++t) {
        ASSERT_EQ(h_y_out(i, t), h_y(h_ridx[i], t));
      }
      ASSERT_EQ(h_w_out[i], h_w[h_ridx[i]]);
    }

    for (auto v : info.feature_names) {
      ASSERT_EQ(v, "a");
    }
  }
};
} // anonymous namespace

TEST_P(TestMetaInfo, Slice) {
  Context ctx;
  auto [n_targets, is_cuda] = this->GetParam();
  if (is_cuda) {
    ctx = MakeCUDACtx(0);
  }
  this->Run(&ctx, n_targets);
}

INSTANTIATE_TEST_SUITE_P(Cpu, TestMetaInfo,
                         ::testing::Values(std::tuple{1u, false}, std::tuple{3u, false}));

#if defined(XGBOOST_USE_CUDA)
INSTANTIATE_TEST_SUITE_P(Gpu, TestMetaInfo,
                         ::testing::Values(std::tuple{1u, true}, std::tuple{3u, true}));
#endif  // defined(XGBOOST_USE_CUDA)
} // namespace xgboost
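The test pins down the contract of the reworked Slice: the row indices may live on host or device (the ctx decides which span the implementation reads), and the caller passes the post-slice nnz because the metadata alone cannot recompute it. A hedged usage sketch built from the calls exercised above (the helper name and nnz value are illustrative):

// Slice out every other row of a MetaInfo on whichever device ctx selects.
MetaInfo SliceEveryOtherRow(Context const* ctx, MetaInfo const& info, bst_idx_t nnz) {
  HostDeviceVector<bst_idx_t> ridx;
  for (bst_idx_t i = 0; i < info.num_row_; i += 2) {
    ridx.HostVector().push_back(i);
  }
  ridx.SetDevice(ctx->Device());
  auto idxs = ctx->IsCPU() ? ridx.ConstHostSpan() : ridx.ConstDeviceSpan();
  return info.Slice(ctx, idxs, nnz);  // nnz: caller-supplied non-missing count
}

This mirrors what the GPU sampler needs: after compacting rows on device, the metadata can be sliced with device-resident indices in the same pass.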
@ -1,5 +1,5 @@
/*!
/**
 * Copyright 2021 by XGBoost Contributors
 * Copyright 2021-2024, XGBoost Contributors
 */
#ifndef XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_
#define XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_
@ -11,7 +11,6 @@
#include <numeric>

#include "../../../src/common/linalg_op.h"
#include "../../../src/data/array_interface.h"

namespace xgboost {
inline void TestMetaInfoStridedData(DeviceOrd device) {
@ -39,11 +39,11 @@ void VerifySampling(size_t page_size, float subsample, int sampling_method,

  if (fixed_size_sampling) {
    EXPECT_EQ(sample.sample_rows, kRows);
    EXPECT_EQ(sample.page->n_rows, kRows);
    EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
    EXPECT_EQ(sample.gpair.size(), kRows);
  } else {
    EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.03);
    EXPECT_NEAR(sample.page->n_rows, sample_rows, kRows * 0.03f);
    EXPECT_NEAR(sample.p_fmat->Info().num_row_, sample_rows, kRows * 0.03f);
    EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.03f);
  }
@ -88,14 +88,16 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {

  GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
  auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
  auto sampled_page = sample.page;
  auto p_fmat = sample.p_fmat;
  EXPECT_EQ(sample.sample_rows, kRows);
  EXPECT_EQ(sample.gpair.size(), gpair.Size());
  EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
  EXPECT_EQ(sampled_page->n_rows, kRows);
  EXPECT_EQ(p_fmat->Info().num_row_, kRows);

  ASSERT_EQ(p_fmat->NumBatches(), 1);
  for (auto const& sampled_page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
  std::vector<common::CompressedByteT> h_gidx_buffer;
  auto h_accessor = sampled_page->GetHostAccessor(&ctx, &h_gidx_buffer);
  auto h_accessor = sampled_page.Impl()->GetHostAccessor(&ctx, &h_gidx_buffer);

  std::size_t offset = 0;
  for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
@ -109,6 +111,7 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
      offset += num_elements;
    }
  }
  }
}

TEST(GradientBasedSampler, UniformSampling) {
  constexpr size_t kPageSize = 0;
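The sampler tests show the headline change from the caller's perspective: GradientBasedSample now exposes p_fmat, a full DMatrix over the sampled (or original) rows, instead of a bare Ellpack page, so downstream code goes through the ordinary batch interface. A hedged sketch of a consumer, using only names visible in the tests above:

// Consume the sampler output through the DMatrix interface. Whether
// sampling compacted the rows or left them untouched, the Ellpack
// batches are reached the same way.
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
DMatrix* p_fmat = sample.p_fmat;
for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
  auto acc = page.Impl()->GetDeviceAccessor(ctx.Device());
  // ... build histograms / partition rows against acc ...
}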