Copy data from Ellpack to GHist. (#8215)

Jiaming Yuan 2022-09-06 23:05:49 +08:00 committed by GitHub
parent 7ee10e3dbd
commit 441ffc017a
16 changed files with 466 additions and 112 deletions

src/common/algorithm.cuh (new file, +27)

@@ -0,0 +1,27 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#pragma once
#include <thrust/binary_search.h> // thrust::upper_bound
#include <thrust/execution_policy.h> // thrust::seq
#include "xgboost/base.h"
#include "xgboost/span.h"
namespace xgboost {
namespace common {
namespace cuda {
template <typename It>
size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) - 1 - first;
return segment_id;
}
template <typename T>
size_t XGBOOST_DEVICE SegmentId(Span<T> segments_ptr, size_t idx) {
return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
}
} // namespace cuda
} // namespace common
} // namespace xgboost

src/common/algorithm.h (new file, +16)

@@ -0,0 +1,16 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#pragma once
#include <algorithm> // std::upper_bound
#include <cinttypes> // std::size_t
namespace xgboost {
namespace common {
template <typename It, typename Idx>
auto SegmentId(It first, It last, Idx idx) {
std::size_t segment_id = std::upper_bound(first, last, idx) - 1 - first;
return segment_id;
}
} // namespace common
} // namespace xgboost
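For context, SegmentId maps a flat element index back to the segment that contains it, given sorted CSR-style offsets: upper_bound finds the first offset greater than the index, and stepping back one gives the owning segment. A minimal host-side sketch, with a local copy of the helper above so it compiles on its own:

#include <algorithm>  // std::upper_bound
#include <cassert>
#include <cstddef>
#include <vector>

// Local copy of the host-side SegmentId added above, so this sketch is self-contained.
template <typename It, typename Idx>
auto SegmentId(It first, It last, Idx idx) {
  std::size_t segment_id = std::upper_bound(first, last, idx) - 1 - first;
  return segment_id;
}

int main() {
  // CSR-style offsets for 3 segments: [0, 3), [3, 3) (empty), and [3, 7).
  std::vector<std::size_t> ptr{0, 3, 3, 7};
  assert(SegmentId(ptr.cbegin(), ptr.cend(), std::size_t{0}) == 0);
  assert(SegmentId(ptr.cbegin(), ptr.cend(), std::size_t{2}) == 0);
  assert(SegmentId(ptr.cbegin(), ptr.cend(), std::size_t{3}) == 2);  // empty segment 1 is skipped
  assert(SegmentId(ptr.cbegin(), ptr.cend(), std::size_t{6}) == 2);
  return 0;
}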

src/common/column_matrix.h

@@ -18,6 +18,7 @@
#include "../data/adapter.h"
#include "../data/gradient_index.h"
+#include "algorithm.h"
#include "hist_util.h"
namespace xgboost {
@@ -135,6 +136,22 @@ class DenseColumnIter : public Column<BinIdxT>
class ColumnMatrix {
void InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold);
template <typename ColumnBinT, typename BinT, typename RIdx>
void SetBinSparse(BinT bin_id, RIdx rid, bst_feature_t fid, ColumnBinT* local_index) {
if (type_[fid] == kDenseColumn) {
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
begin[rid] = bin_id - index_base_[fid];
// not thread-safe with bool vector. FIXME(jiamingy): We can directly assign
// kMissingId to the index to avoid missing flags.
missing_flags_[feature_offsets_[fid] + rid] = false;
} else {
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
begin[num_nonzeros_[fid]] = bin_id - index_base_[fid];
row_ind_[feature_offsets_[fid] + num_nonzeros_[fid]] = rid;
++num_nonzeros_[fid];
}
}
public:
// get number of features
bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }
@@ -144,27 +161,11 @@ class ColumnMatrix {
this->InitStorage(gmat, sparse_threshold);
}
-template <typename Batch>
-void PushBatch(int32_t n_threads, Batch const& batch, float missing, GHistIndexMatrix const& gmat,
-size_t base_rowid) {
-// pre-fill index_ for dense columns
-auto n_features = gmat.Features();
-if (!any_missing_) {
-missing_flags_.resize(feature_offsets_[n_features], false);
-// row index is compressed, we need to dispatch it.
-DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = batch.Size(), n_features = n_features,
-n_threads = n_threads](auto t) {
-using RowBinIdxT = decltype(t);
-SetIndexNoMissing(base_rowid, gmat.index.data<RowBinIdxT>(), size, n_features, n_threads);
-});
-} else {
-missing_flags_.resize(feature_offsets_[n_features], true);
-SetIndexMixedColumns(base_rowid, batch, gmat, n_features, missing);
-}
-}
-// construct column matrix from GHistIndexMatrix
-void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
+/**
+ * \brief Initialize ColumnMatrix from GHistIndexMatrix with reference to the original
+ * SparsePage.
+ */
+void InitFromSparse(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
int32_t n_threads) {
auto batch = data::SparsePageAdapterBatch{page.GetView()};
this->InitStorage(gmat, sparse_threshold);
@@ -172,6 +173,54 @@ class ColumnMatrix {
this->PushBatch(n_threads, batch, std::numeric_limits<float>::quiet_NaN(), gmat, 0);
}
/**
* \brief Initialize ColumnMatrix from GHistIndexMatrix without reference to actual
* data.
*
* This function requires a binary search for each bin to get back the feature index
* for those bins.
*/
void InitFromGHist(Context const* ctx, GHistIndexMatrix const& gmat) {
auto n_threads = ctx->Threads();
if (!any_missing_) {
// row index is compressed, we need to dispatch it.
DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = gmat.Size(), n_threads = n_threads,
n_features = gmat.Features()](auto t) {
using RowBinIdxT = decltype(t);
SetIndexNoMissing(gmat.base_rowid, gmat.index.data<RowBinIdxT>(), size, n_features,
n_threads);
});
} else {
SetIndexMixedColumns(gmat);
}
}
/**
* \brief Push batch of data for Quantile DMatrix support.
*
* \param batch Input data wrapped inside a adapter batch.
* \param gmat The row-major histogram index that contains index for ALL data.
* \param base_rowid The beginning row index for current batch.
*/
template <typename Batch>
void PushBatch(int32_t n_threads, Batch const& batch, float missing, GHistIndexMatrix const& gmat,
size_t base_rowid) {
// pre-fill index_ for dense columns
if (!any_missing_) {
// row index is compressed, we need to dispatch it.
// use base_rowid from input parameter as gmat is a single matrix that contains all
// the histogram index instead of being only a batch.
DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = batch.Size(), n_threads = n_threads,
n_features = gmat.Features()](auto t) {
using RowBinIdxT = decltype(t);
SetIndexNoMissing(base_rowid, gmat.index.data<RowBinIdxT>(), size, n_features, n_threads);
});
} else {
SetIndexMixedColumns(base_rowid, batch, gmat, missing);
}
}
/* Set the number of bytes based on numeric limit of maximum number of bins provided by user */
void SetTypeSize(size_t max_bin_per_feat) {
if ((max_bin_per_feat - 1) <= static_cast<int>(std::numeric_limits<uint8_t>::max())) {
@@ -210,6 +259,7 @@ class ColumnMatrix {
template <typename RowBinIdxT>
void SetIndexNoMissing(bst_row_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples,
const size_t n_features, int32_t n_threads) {
+missing_flags_.resize(feature_offsets_[n_features], false);
DispatchBinType(bins_type_size_, [&](auto t) {
using ColumnBinT = decltype(t);
auto column_index = Span<ColumnBinT>{reinterpret_cast<ColumnBinT*>(index_.data()),
@@ -232,29 +282,16 @@
*/
template <typename Batch>
void SetIndexMixedColumns(size_t base_rowid, Batch const& batch, const GHistIndexMatrix& gmat,
-size_t n_features, float missing) {
+float missing) {
+auto n_features = gmat.Features();
+missing_flags_.resize(feature_offsets_[n_features], true);
auto const* row_index = gmat.index.data<uint32_t>() + gmat.row_ptr[base_rowid];
+num_nonzeros_.resize(n_features, 0);
auto is_valid = data::IsValidFunctor{missing};
DispatchBinType(bins_type_size_, [&](auto t) {
using ColumnBinT = decltype(t);
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
-num_nonzeros_.resize(n_features, 0);
-auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
-if (type_[fid] == kDenseColumn) {
-ColumnBinT* begin = reinterpret_cast<ColumnBinT*>(&local_index[feature_offsets_[fid]]);
-begin[rid] = bin_id - index_base_[fid];
-// not thread-safe with bool vector. FIXME(jiamingy): We can directly assign
-// kMissingId to the index to avoid missing flags.
-missing_flags_[feature_offsets_[fid] + rid] = false;
-} else {
-ColumnBinT* begin = reinterpret_cast<ColumnBinT*>(&local_index[feature_offsets_[fid]]);
-begin[num_nonzeros_[fid]] = bin_id - index_base_[fid];
-row_ind_[feature_offsets_[fid] + num_nonzeros_[fid]] = rid;
-++num_nonzeros_[fid];
-}
-};
size_t const batch_size = batch.Size();
size_t k{0};
for (size_t rid = 0; rid < batch_size; ++rid) {
@@ -264,7 +301,7 @@ class ColumnMatrix {
if (is_valid(coo)) {
auto fid = coo.column_idx;
const uint32_t bin_id = row_index[k];
-get_bin_idx(bin_id, rid + base_rowid, fid);
+SetBinSparse(bin_id, rid + base_rowid, fid, local_index);
++k;
}
}
@@ -272,6 +309,40 @@
});
}
/**
* \brief Set column index for both dense and sparse columns, but with only GHistMatrix
* available and requires a search for each bin.
*/
void SetIndexMixedColumns(const GHistIndexMatrix& gmat) {
auto n_features = gmat.Features();
missing_flags_.resize(feature_offsets_[n_features], true);
auto const* row_index = gmat.index.data<uint32_t>() + gmat.row_ptr[gmat.base_rowid];
num_nonzeros_.resize(n_features, 0);
auto const& ptrs = gmat.cut.Ptrs();
DispatchBinType(bins_type_size_, [&](auto t) {
using ColumnBinT = decltype(t);
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
auto const batch_size = gmat.Size();
size_t k{0};
for (size_t ridx = 0; ridx < batch_size; ++ridx) {
auto r_beg = gmat.row_ptr[ridx];
auto r_end = gmat.row_ptr[ridx + 1];
bst_feature_t fidx{0};
for (size_t j = r_beg; j < r_end; ++j) {
const uint32_t bin_idx = row_index[k];
// find the feature index for current bin.
while (bin_idx >= ptrs[fidx + 1]) {
fidx++;
}
SetBinSparse(bin_idx, ridx, fidx, local_index);
++k;
}
}
});
}
BinTypeSize GetTypeSize() const { return bins_type_size_; }
auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; }
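The feature lookup in SetIndexMixedColumns(gmat) above works because the global bin index is strictly increasing across features: cut.Ptrs() stores the exclusive end of each feature's bin range, so scanning fidx forward until bin_idx < ptrs[fidx + 1] recovers the owning feature. A small standalone sketch of that lookup, using a hypothetical cut-pointer array:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical cut pointers for 3 features with 4, 2, and 3 bins respectively.
  std::vector<std::uint32_t> ptrs{0, 4, 6, 9};
  auto find_fidx = [&](std::uint32_t bin_idx) {
    std::uint32_t fidx = 0;
    while (bin_idx >= ptrs[fidx + 1]) {  // same forward scan as in SetIndexMixedColumns
      ++fidx;
    }
    return fidx;
  };
  std::cout << find_fidx(0) << " " << find_fidx(5) << " " << find_fidx(8) << "\n";  // prints: 0 1 2
  return 0;
}

In the member function the scan is not restarted for every bin: fidx is reset once per row and only moves forward, which is enough because the bins stored for one row appear in increasing feature order.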

src/common/device_helpers.cuh

@@ -35,6 +35,7 @@
#include "xgboost/global_config.h"
#include "common.h"
+#include "algorithm.cuh"
#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
@@ -1556,17 +1557,7 @@ XGBOOST_DEVICE thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIt
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}
-template <typename It>
-size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
-size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) -
-1 - first;
-return segment_id;
-}
-template <typename T>
-size_t XGBOOST_DEVICE SegmentId(xgboost::common::Span<T> segments_ptr, size_t idx) {
-return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
-}
+using xgboost::common::cuda::SegmentId;  // import it for compatibility
namespace detail {
template <typename Key, typename KeyOutIt>

src/common/hist_util.h

@@ -22,6 +22,7 @@
#include "row_set.h"
#include "threading_utils.h"
#include "timer.h"
+#include "algorithm.h" // SegmentId
namespace xgboost {
class GHistIndexMatrix;
@@ -130,9 +131,8 @@ class HistogramCuts {
/**
* \brief Search the bin index for categorical feature.
*/
-bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const {
-auto const &ptrs = this->Ptrs();
-auto const &vals = this->Values();
+bst_bin_t SearchCatBin(float value, bst_feature_t fidx, std::vector<uint32_t> const& ptrs,
+std::vector<float> const& vals) const {
auto end = ptrs.at(fidx + 1) + vals.cbegin();
auto beg = ptrs[fidx] + vals.cbegin();
// Truncates the value in case it's not perfectly rounded.
@@ -143,6 +143,11 @@ class HistogramCuts {
}
return bin_idx;
}
bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const {
auto const& ptrs = this->Ptrs();
auto const& vals = this->Values();
return this->SearchCatBin(value, fidx, ptrs, vals);
}
bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); }
};
@@ -189,6 +194,28 @@ auto DispatchBinType(BinTypeSize type, Fn&& fn) {
* storage class.
*/
struct Index {
// Inside the compressor, bin_idx is the index for cut value across all features. By
// subtracting it with starting pointer of each feature, we can reduce it to smaller
// value and store it with smaller types. Usable only with dense data.
//
// For sparse input we have to store an addition feature index (similar to sparse matrix
// formats like CSR) for each bin in index field to choose the right offset.
template <typename T>
struct CompressBin {
uint32_t const* offsets;
template <typename Bin, typename Feat>
auto operator()(Bin bin_idx, Feat fidx) const {
return static_cast<T>(bin_idx - offsets[fidx]);
}
};
template <typename T>
CompressBin<T> MakeCompressor() const {
uint32_t const* offsets = this->Offset();
return CompressBin<T>{offsets};
}
Index() { SetBinTypeSize(binTypeSize_); }
Index(const Index& i) = delete;
Index& operator=(Index i) = delete;
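To make the compression in CompressBin above concrete: for dense data every row stores one bin per feature, so subtracting the feature's starting cut offset leaves a small local index that fits in a narrow integer type. A minimal sketch with a hypothetical offsets array; the functor is reproduced here so the snippet is self-contained:

#include <cstdint>
#include <iostream>

// Same shape as Index::CompressBin in the diff above.
template <typename T>
struct CompressBin {
  std::uint32_t const* offsets;
  template <typename Bin, typename Feat>
  auto operator()(Bin bin_idx, Feat fidx) const {
    return static_cast<T>(bin_idx - offsets[fidx]);
  }
};

int main() {
  // Hypothetical cut offsets: feature 0 owns bins [0, 256), feature 1 owns [256, 300).
  std::uint32_t offsets[] = {0, 256, 300};
  CompressBin<std::uint8_t> compress{offsets};
  // Global bin 260 in feature 1 becomes local bin 4, which fits in a uint8_t.
  std::cout << static_cast<int>(compress(260u, 1)) << "\n";  // prints 4
  return 0;
}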

src/data/ellpack_page.cu

@@ -547,4 +547,15 @@ EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
NumSymbols()),
feature_types};
}
EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
common::Span<FeatureType const> feature_types) const {
return {Context::kCpuId,
cuts_,
is_dense,
row_stride,
base_rowid,
n_rows,
common::CompressedIterator<uint32_t>(gidx_buffer.ConstHostPointer(), NumSymbols()),
feature_types};
}
} // namespace xgboost
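The new GetHostAccessor mirrors GetDeviceAccessor but binds the cut spans to host memory, so the compressed ELLPACK index can be decoded on the CPU. A rough sketch of the access pattern (the same one GetRowPtrFromEllpack uses in src/data/gradient_index.cu); the include path and the CountNonMissing helper are assumptions for illustration, not part of the commit:

#include <cstddef>
#include <vector>
#include "../src/data/ellpack_page.cuh"  // assumption: built inside the XGBoost source tree

// Count non-missing entries per row through the host accessor.
std::vector<std::size_t> CountNonMissing(xgboost::EllpackPageImpl const* page) {
  auto accessor = page->GetHostAccessor();
  auto const null = static_cast<xgboost::bst_bin_t>(accessor.NullValue());
  std::vector<std::size_t> nnz(page->Size(), 0);
  for (std::size_t i = 0; i < page->Size(); ++i) {
    for (std::size_t j = 0; j < page->row_stride; ++j) {
      // gidx_iter decompresses one bin index at a time; null marks a missing entry.
      if (static_cast<xgboost::bst_bin_t>(accessor.gidx_iter[i * page->row_stride + j]) != null) {
        ++nnz[i];
      }
    }
  }
  return nnz;
}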

src/data/ellpack_page.cuh

@@ -43,6 +43,11 @@ struct EllpackDeviceAccessor {
base_rowid(base_rowid),
n_rows(n_rows) ,gidx_iter(gidx_iter),
feature_types{feature_types} {
if (device == Context::kCpuId) {
gidx_fvalue_map = cuts.cut_values_.ConstHostSpan();
feature_segments = cuts.cut_ptrs_.ConstHostSpan();
min_fvalue = cuts.min_vals_.ConstHostSpan();
} else {
cuts.cut_values_.SetDevice(device);
cuts.cut_ptrs_.SetDevice(device);
cuts.min_vals_.SetDevice(device);
@@ -50,6 +55,7 @@
feature_segments = cuts.cut_ptrs_.ConstDeviceSpan();
min_fvalue = cuts.min_vals_.ConstDeviceSpan();
}
}
// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
__device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
@@ -202,6 +208,7 @@ class EllpackPageImpl {
EllpackDeviceAccessor
GetDeviceAccessor(int device,
common::Span<FeatureType const> feature_types = {}) const;
EllpackDeviceAccessor GetHostAccessor(common::Span<FeatureType const> feature_types = {}) const;
private:
/*!

src/data/gradient_index.cc

@@ -53,7 +53,7 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
// hist
CHECK(!sorted_sketch);
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
-this->columns_->Init(page, *this, sparse_thresh, n_threads);
+this->columns_->InitFromSparse(page, *this, sparse_thresh, n_threads);
}
}
}
@@ -66,6 +66,12 @@ GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &
max_num_bins(max_bin_per_feat),
isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
#if !defined(XGBOOST_USE_CUDA)
GHistIndexMatrix::GHistIndexMatrix(Context const *, MetaInfo const &, EllpackPage const &,
BatchParam const &) {
common::AssertGPUSupport();
}
#endif // defined(XGBOOST_USE_CUDA)
GHistIndexMatrix::~GHistIndexMatrix() = default;
@@ -99,7 +105,7 @@ GHistIndexMatrix::GHistIndexMatrix(SparsePage const &batch, common::Span<Feature
this->PushBatch(batch, ft, n_threads);
this->columns_ = std::make_unique<common::ColumnMatrix>();
if (!std::isnan(sparse_thresh)) {
-this->columns_->Init(batch, *this, sparse_thresh, n_threads);
+this->columns_->InitFromSparse(batch, *this, sparse_thresh, n_threads);
}
}

src/data/gradient_index.cu (new file, +111)

@@ -0,0 +1,111 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#include <memory> // std::unique_ptr
#include "../common/column_matrix.h"
#include "../common/hist_util.h" // Index
#include "ellpack_page.cuh"
#include "gradient_index.h"
#include "xgboost/data.h"
namespace xgboost {
// Similar to GHistIndexMatrix::SetIndexData, but without the need for adaptor or bin
// searching. Is there a way to unify the code?
template <typename BinT, typename CompressOffset>
void SetIndexData(Context const* ctx, EllpackPageImpl const* page,
std::vector<size_t>* p_hit_count_tloc, CompressOffset&& get_offset,
GHistIndexMatrix* out) {
auto accessor = page->GetHostAccessor();
auto const kNull = static_cast<bst_bin_t>(accessor.NullValue());
common::Span<BinT> index_data_span = {out->index.data<BinT>(), out->index.Size()};
auto n_bins_total = page->Cuts().TotalBins();
auto& hit_count_tloc = *p_hit_count_tloc;
hit_count_tloc.clear();
hit_count_tloc.resize(ctx->Threads() * n_bins_total, 0);
common::ParallelFor(page->Size(), ctx->Threads(), [&](auto i) {
auto tid = omp_get_thread_num();
size_t in_rbegin = page->row_stride * i;
size_t out_rbegin = out->row_ptr[i];
auto r_size = out->row_ptr[i + 1] - out->row_ptr[i];
for (size_t j = 0; j < r_size; ++j) {
auto bin_idx = accessor.gidx_iter[in_rbegin + j];
assert(bin_idx != kNull);
index_data_span[out_rbegin + j] = get_offset(bin_idx, j);
++hit_count_tloc[tid * n_bins_total + bin_idx];
}
});
}
void GetRowPtrFromEllpack(Context const* ctx, EllpackPageImpl const* page,
std::vector<size_t>* p_out) {
auto& row_ptr = *p_out;
row_ptr.resize(page->Size() + 1, 0);
if (page->is_dense) {
std::fill(row_ptr.begin() + 1, row_ptr.end(), page->row_stride);
} else {
auto accessor = page->GetHostAccessor();
auto const kNull = static_cast<bst_bin_t>(accessor.NullValue());
common::ParallelFor(page->Size(), ctx->Threads(), [&](auto i) {
size_t ibegin = page->row_stride * i;
for (size_t j = 0; j < page->row_stride; ++j) {
bst_bin_t bin_idx = accessor.gidx_iter[ibegin + j];
if (bin_idx != kNull) {
row_ptr[i + 1]++;
}
}
});
}
std::partial_sum(row_ptr.begin(), row_ptr.end(), row_ptr.begin());
}
GHistIndexMatrix::GHistIndexMatrix(Context const* ctx, MetaInfo const& info,
EllpackPage const& in_page, BatchParam const& p)
: max_num_bins{p.max_bin} {
auto page = in_page.Impl();
isDense_ = page->is_dense;
CHECK_EQ(info.num_row_, in_page.Size());
this->cut = page->Cuts();
// pull to host early, prevent race condition
this->cut.Ptrs();
this->cut.Values();
this->cut.MinValues();
this->ResizeIndex(info.num_nonzero_, page->is_dense);
if (page->is_dense) {
this->index.SetBinOffset(page->Cuts().Ptrs());
}
auto n_bins_total = page->Cuts().TotalBins();
GetRowPtrFromEllpack(ctx, page, &this->row_ptr);
if (page->is_dense) {
common::DispatchBinType(this->index.GetBinTypeSize(), [&](auto dtype) {
using T = decltype(dtype);
::xgboost::SetIndexData<T>(ctx, page, &hit_count_tloc_, index.MakeCompressor<T>(), this);
});
} else {
// no compression
::xgboost::SetIndexData<uint32_t>(
ctx, page, &hit_count_tloc_, [&](auto bin_idx, auto) { return bin_idx; }, this);
}
this->hit_count.resize(n_bins_total, 0);
this->GatherHitCount(ctx->Threads(), n_bins_total);
// sanity checks
CHECK_EQ(this->Features(), info.num_col_);
CHECK_EQ(this->Size(), info.num_row_);
CHECK(this->cut.cut_ptrs_.HostCanRead());
CHECK(this->cut.cut_values_.HostCanRead());
CHECK(this->cut.min_vals_.HostCanRead());
this->columns_ = std::make_unique<common::ColumnMatrix>(*this, p.sparse_thresh);
this->columns_->InitFromGHist(ctx, *this);
}
} // namespace xgboost
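For intuition on GetRowPtrFromEllpack above: each row's non-missing count is first written to row_ptr[i + 1], and an inclusive prefix sum then turns the counts into CSR-style offsets. A minimal standalone sketch of that step with hypothetical counts:

#include <cstddef>
#include <iostream>
#include <numeric>   // std::partial_sum
#include <vector>

int main() {
  // Hypothetical per-row non-missing counts for 4 rows: 3, 0, 2, 5 (stored at index i + 1).
  std::vector<std::size_t> row_ptr{0, 3, 0, 2, 5};
  std::partial_sum(row_ptr.begin(), row_ptr.end(), row_ptr.begin());
  for (auto v : row_ptr) std::cout << v << " ";  // prints: 0 3 3 5 10
  std::cout << "\n";
  return 0;
}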

src/data/gradient_index.h

@@ -69,7 +69,7 @@ class GHistIndexMatrix {
if (is_valid(elem)) {
bst_bin_t bin_idx{-1};
if (common::IsCat(ft, elem.column_idx)) {
-bin_idx = cut.SearchCatBin(elem.value, elem.column_idx);
+bin_idx = cut.SearchCatBin(elem.value, elem.column_idx, ptrs, values);
} else {
bin_idx = cut.SearchBin(elem.value, elem.column_idx, ptrs, values);
}
@@ -81,6 +81,17 @@
});
}
// Gather hit_count from all threads
void GatherHitCount(int32_t n_threads, bst_bin_t n_bins_total) {
CHECK_EQ(hit_count.size(), n_bins_total);
common::ParallelFor(n_bins_total, n_threads, [&](bst_omp_uint idx) {
for (int32_t tid = 0; tid < n_threads; ++tid) {
hit_count[idx] += hit_count_tloc_[tid * n_bins_total + idx];
hit_count_tloc_[tid * n_bins_total + idx] = 0; // reset for next batch
}
});
}
template <typename Batch, typename IsValid>
void PushBatchImpl(int32_t n_threads, Batch const& batch, size_t rbegin, IsValid&& is_valid,
common::Span<FeatureType const> ft) {
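The GatherHitCount helper above works because hit_count_tloc_ is laid out as one private histogram per thread, indexed as tid * n_bins_total + bin, so the reduction over threads reads disjoint slices per bin. A small standalone sketch of the same reduction over a hypothetical two-thread layout:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::int32_t const n_threads = 2;
  std::int64_t const n_bins_total = 3;
  // Hypothetical thread-local counts: thread 0 -> {1, 0, 2}, thread 1 -> {0, 4, 1}.
  std::vector<std::size_t> hit_count_tloc{1, 0, 2, 0, 4, 1};
  std::vector<std::size_t> hit_count(n_bins_total, 0);
  for (std::int64_t idx = 0; idx < n_bins_total; ++idx) {
    for (std::int32_t tid = 0; tid < n_threads; ++tid) {
      hit_count[idx] += hit_count_tloc[tid * n_bins_total + idx];
      hit_count_tloc[tid * n_bins_total + idx] = 0;  // reset for the next batch
    }
  }
  for (auto v : hit_count) std::cout << v << " ";  // prints: 1 4 3
  std::cout << "\n";
  return 0;
}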
@@ -95,33 +106,20 @@
if (isDense_) {
index.SetBinOffset(cut.Ptrs());
}
-uint32_t const* offsets = index.Offset();
if (isDense_) {
-// Inside the lambda functions, bin_idx is the index for cut value across all
-// features. By subtracting it with starting pointer of each feature, we can reduce
-// it to smaller value and compress it to smaller types.
common::DispatchBinType(index.GetBinTypeSize(), [&](auto dtype) {
using T = decltype(dtype);
common::Span<T> index_data_span = {index.data<T>(), index.Size()};
-SetIndexData(
-index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
-[offsets](auto bin_idx, auto fidx) { return static_cast<T>(bin_idx - offsets[fidx]); });
+SetIndexData(index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
+index.MakeCompressor<T>());
});
} else {
-/* For sparse DMatrix we have to store index of feature for each bin
-in index field to chose right offset. So offset is nullptr and index is
-not reduced */
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
+// no compression
SetIndexData(index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
[](auto idx, auto) { return idx; });
}
+this->GatherHitCount(n_threads, n_bins_total);
-common::ParallelFor(n_bins_total, n_threads, [&](bst_omp_uint idx) {
-for (int32_t tid = 0; tid < n_threads; ++tid) {
-hit_count[idx] += hit_count_tloc_[tid * n_bins_total + idx];
-hit_count_tloc_[tid * n_bins_total + idx] = 0; // reset for next batch
-}
-});
}
public:
@@ -129,12 +127,12 @@
std::vector<size_t> row_ptr;
/*! \brief The index data */
common::Index index;
-/*! \brief hit count of each index */
+/*! \brief hit count of each index, used for constructing the ColumnMatrix */
std::vector<size_t> hit_count;
/*! \brief The corresponding cuts */
common::HistogramCuts cut;
/*! \brief max_bin for each feature. */
-size_t max_num_bins;
+bst_bin_t max_num_bins;
/*! \brief base row index for current page (used by external memory) */
size_t base_rowid{0};
@@ -149,6 +147,13 @@
* for push batch.
*/
GHistIndexMatrix(MetaInfo const& info, common::HistogramCuts&& cuts, bst_bin_t max_bin_per_feat);
/**
* \brief Constructor fro Iterative DMatrix where we might copy an existing ellpack page
* to host gradient index.
*/
GHistIndexMatrix(Context const* ctx, MetaInfo const& info, EllpackPage const& page,
BatchParam const& p);
/**
* \brief Constructor for external memory.
*/

src/data/iterative_dmatrix.cc

@@ -205,12 +205,11 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {
CheckParam(param);
-CHECK(ghist_) << R"(`QuantileDMatrix` is not initialized with CPU data but used for CPU training.
-Possible solutions:
-- Use `DMatrix` instead.
-- Use CPU input for `QuantileDMatrix`.
-- Run training on GPU.
-)";
+if (!ghist_) {
+CHECK(ellpack_);
+ghist_ = std::make_shared<GHistIndexMatrix>(&ctx_, Info(), *ellpack_, param);
+}
auto begin_iter =
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
return BatchSet<GHistIndexMatrix>(begin_iter);

src/data/iterative_dmatrix.h

@@ -29,20 +29,17 @@ namespace data {
* `QuantileDMatrix` is an intermediate storage for quantilization results including
* quantile cuts and histogram index. Quantilization is designed to be performed on stream
* of data (or batches of it). As a result, the `QuantileDMatrix` is also designed to work
-* with batches of data. During initializaion, it will walk through the data multiple
-* times iteratively in order to perform quantilization. This design can help us reduce
-* memory usage significantly by avoiding data concatenation along with removing the CSR
-* matrix `SparsePage`. However, it has its limitation (can be fixed if needed):
+* with batches of data. During initializaion, it walks through the data multiple times
+* iteratively in order to perform quantilization. This design helps us reduce memory
+* usage significantly by avoiding data concatenation along with removing the CSR matrix
+* `SparsePage`. However, it has its limitation (can be fixed if needed):
*
* - It's only supported by hist tree method (both CPU and GPU) since approx requires a
* re-calculation of quantiles for each iteration. We can fix this by retaining a
* reference to the callback if there are feature requests.
*
* - The CPU format and the GPU format are different, the former uses a CSR + CSC for
-* histogram index while the latter uses only Ellpack. This results into a design that
-* we can obtain the GPU format from CPU but the other way around is not yet
-* supported. We can search the bin value from ellpack to recover the feature index when
-* we support copying data from GPU to CPU.
+* histogram index while the latter uses only Ellpack.
*/
class IterativeDMatrix : public DMatrix {
MetaInfo info_;

tests/cpp/common/test_column_matrix.cc

@@ -23,7 +23,7 @@ TEST(DenseColumn, Test) {
common::OmpGetNumThreads(0)};
ColumnMatrix column_matrix;
for (auto const& page : dmat->GetBatches<SparsePage>()) {
-column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
+column_matrix.InitFromSparse(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
}
ASSERT_GE(column_matrix.GetTypeSize(), last);
ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize);
@@ -69,7 +69,7 @@ TEST(SparseColumn, Test) {
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)};
ColumnMatrix column_matrix;
for (auto const& page : dmat->GetBatches<SparsePage>()) {
-column_matrix.Init(page, gmat, 1.0, common::OmpGetNumThreads(0));
+column_matrix.InitFromSparse(page, gmat, 1.0, common::OmpGetNumThreads(0));
}
common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
using T = decltype(dtype);
@@ -97,7 +97,7 @@ TEST(DenseColumnWithMissing, Test) {
GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0));
ColumnMatrix column_matrix;
for (auto const& page : dmat->GetBatches<SparsePage>()) {
-column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0));
+column_matrix.InitFromSparse(page, gmat, 0.2, common::OmpGetNumThreads(0));
}
ASSERT_TRUE(column_matrix.AnyMissing());
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {

tests/cpp/data/test_gradient_index.cc

@@ -5,6 +5,7 @@
#include <xgboost/data.h>
#include "../../../src/common/column_matrix.h"
+#include "../../../src/common/io.h" // MemoryBufferStream
#include "../../../src/data/gradient_index.h"
#include "../helpers.h"
@@ -107,5 +108,81 @@ TEST(GradientIndex, PushBatch) {
test(0.5f);
test(0.9f);
}
#if defined(XGBOOST_USE_CUDA)
namespace {
class GHistIndexMatrixTest : public testing::TestWithParam<std::tuple<float, float>> {
protected:
void Run(float density, double threshold) {
// Only testing with small sample size as the cuts might be different between host and
// device.
size_t n_samples{128}, n_features{13};
Context ctx;
ctx.gpu_id = 0;
auto Xy = RandomDataGenerator{n_samples, n_features, 1 - density}.GenerateDMatrix(true);
std::unique_ptr<GHistIndexMatrix> from_ellpack;
ASSERT_TRUE(Xy->SingleColBlock());
bst_bin_t constexpr kBins{17};
auto p = BatchParam{kBins, threshold};
for (auto const &page : Xy->GetBatches<EllpackPage>(BatchParam{0, kBins})) {
from_ellpack.reset(new GHistIndexMatrix{&ctx, Xy->Info(), page, p});
}
for (auto const &from_sparse_page : Xy->GetBatches<GHistIndexMatrix>(p)) {
ASSERT_EQ(from_sparse_page.IsDense(), from_ellpack->IsDense());
ASSERT_EQ(from_sparse_page.base_rowid, 0);
ASSERT_EQ(from_sparse_page.base_rowid, from_ellpack->base_rowid);
ASSERT_EQ(from_sparse_page.Size(), from_ellpack->Size());
ASSERT_EQ(from_sparse_page.index.Size(), from_ellpack->index.Size());
auto const &gidx_from_sparse = from_sparse_page.index;
auto const &gidx_from_ellpack = from_ellpack->index;
for (size_t i = 0; i < gidx_from_sparse.Size(); ++i) {
ASSERT_EQ(gidx_from_sparse[i], gidx_from_ellpack[i]);
}
auto const &columns_from_sparse = from_sparse_page.Transpose();
auto const &columns_from_ellpack = from_ellpack->Transpose();
ASSERT_EQ(columns_from_sparse.AnyMissing(), columns_from_ellpack.AnyMissing());
ASSERT_EQ(columns_from_sparse.GetTypeSize(), columns_from_ellpack.GetTypeSize());
ASSERT_EQ(columns_from_sparse.GetNumFeature(), columns_from_ellpack.GetNumFeature());
for (size_t i = 0; i < n_features; ++i) {
ASSERT_EQ(columns_from_sparse.GetColumnType(i), columns_from_ellpack.GetColumnType(i));
}
std::string from_sparse_buf;
{
common::MemoryBufferStream fo{&from_sparse_buf};
columns_from_sparse.Write(&fo);
}
std::string from_ellpack_buf;
{
common::MemoryBufferStream fo{&from_ellpack_buf};
columns_from_sparse.Write(&fo);
}
ASSERT_EQ(from_sparse_buf, from_ellpack_buf);
}
}
};
} // anonymous namespace
TEST_P(GHistIndexMatrixTest, FromEllpack) {
float sparsity;
double thresh;
std::tie(sparsity, thresh) = GetParam();
this->Run(sparsity, thresh);
}
INSTANTIATE_TEST_SUITE_P(GHistIndexMatrix, GHistIndexMatrixTest,
testing::Values(std::make_tuple(1.f, .0), // no missing
std::make_tuple(.2f, .8), // sparse columns
std::make_tuple(.8f, .2), // dense columns
std::make_tuple(1.f, .2), // no missing
std::make_tuple(.5f, .6), // sparse columns
std::make_tuple(.6f, .4))); // dense columns
#endif // defined(XGBOOST_USE_CUDA)
} // namespace data
} // namespace xgboost

tests/cpp/tree/test_quantile_hist.cc

@@ -37,7 +37,7 @@ TEST(QuantileHist, Partitioner) {
GHistIndexMatrix gmat(page, {}, cuts, 64, true, 0.5, ctx.Threads());
bst_feature_t const split_ind = 0;
common::ColumnMatrix column_indices;
-column_indices.Init(page, gmat, 0.5, ctx.Threads());
+column_indices.InitFromSparse(page, gmat, 0.5, ctx.Threads());
{
auto min_value = gmat.cut.MinValues()[split_ind];
RegTree tree;

tests/python-gpu/test_device_quantile_dmatrix.py

@@ -32,32 +32,41 @@ class TestDeviceQuantileDMatrix:
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
@pytest.mark.skipif(**tm.no_cupy())
-def test_from_host(self) -> None:
+@pytest.mark.parametrize(
+"tree_method,max_bin", [
+("hist", 16), ("gpu_hist", 16), ("hist", 64), ("gpu_hist", 64)
+]
+)
+def test_interoperability(self, tree_method: str, max_bin: int) -> None:
import cupy as cp
n_samples = 64
n_features = 3
X, y, w = tm.make_batches(
n_samples, n_features=n_features, n_batches=1, use_cupy=False
)
-Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0])
-booster_0 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
+# from CPU
+Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin)
+booster_0 = xgb.train(
+{"tree_method": tree_method, "max_bin": max_bin}, Xy, num_boost_round=4
+)
X[0] = cp.array(X[0])
y[0] = cp.array(y[0])
w[0] = cp.array(w[0])
-Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0])
-booster_1 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
+# from GPU
+Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin)
+booster_1 = xgb.train(
+{"tree_method": tree_method, "max_bin": max_bin}, Xy, num_boost_round=4
+)
cp.testing.assert_allclose(
booster_0.inplace_predict(X[0]), booster_1.inplace_predict(X[0])
)
-with pytest.raises(ValueError, match="not initialized with CPU"):
-# Training on CPU with GPU data is not supported.
-xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4)
with pytest.raises(ValueError, match=r"Only.*hist.*"):
-xgb.train({"tree_method": "approx"}, Xy, num_boost_round=4)
+xgb.train(
+{"tree_method": "approx", "max_bin": max_bin}, Xy, num_boost_round=4
+)
@pytest.mark.skipif(**tm.no_cupy())
def test_metainfo(self) -> None: