Copy data from Ellpack to GHist. (#8215)
This commit is contained in:
parent
7ee10e3dbd
commit
441ffc017a
27
src/common/algorithm.cuh
Normal file
27
src/common/algorithm.cuh
Normal file
@ -0,0 +1,27 @@
|
||||
/*!
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <thrust/binary_search.h> // thrust::upper_bound
|
||||
#include <thrust/execution_policy.h> // thrust::seq
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/span.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace cuda {
|
||||
template <typename It>
|
||||
size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
|
||||
size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) - 1 - first;
|
||||
return segment_id;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t XGBOOST_DEVICE SegmentId(Span<T> segments_ptr, size_t idx) {
|
||||
return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
|
||||
}
|
||||
} // namespace cuda
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
16
src/common/algorithm.h
Normal file
16
src/common/algorithm.h
Normal file
@ -0,0 +1,16 @@
|
||||
/*!
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <algorithm> // std::upper_bound
|
||||
#include <cinttypes> // std::size_t
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
template <typename It, typename Idx>
|
||||
auto SegmentId(It first, It last, Idx idx) {
|
||||
std::size_t segment_id = std::upper_bound(first, last, idx) - 1 - first;
|
||||
return segment_id;
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/gradient_index.h"
|
||||
#include "algorithm.h"
|
||||
#include "hist_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -135,6 +136,22 @@ class DenseColumnIter : public Column<BinIdxT> {
|
||||
class ColumnMatrix {
|
||||
void InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold);
|
||||
|
||||
template <typename ColumnBinT, typename BinT, typename RIdx>
|
||||
void SetBinSparse(BinT bin_id, RIdx rid, bst_feature_t fid, ColumnBinT* local_index) {
|
||||
if (type_[fid] == kDenseColumn) {
|
||||
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
|
||||
begin[rid] = bin_id - index_base_[fid];
|
||||
// not thread-safe with bool vector. FIXME(jiamingy): We can directly assign
|
||||
// kMissingId to the index to avoid missing flags.
|
||||
missing_flags_[feature_offsets_[fid] + rid] = false;
|
||||
} else {
|
||||
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
|
||||
begin[num_nonzeros_[fid]] = bin_id - index_base_[fid];
|
||||
row_ind_[feature_offsets_[fid] + num_nonzeros_[fid]] = rid;
|
||||
++num_nonzeros_[fid];
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// get number of features
|
||||
bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }
|
||||
@ -144,27 +161,11 @@ class ColumnMatrix {
|
||||
this->InitStorage(gmat, sparse_threshold);
|
||||
}
|
||||
|
||||
template <typename Batch>
|
||||
void PushBatch(int32_t n_threads, Batch const& batch, float missing, GHistIndexMatrix const& gmat,
|
||||
size_t base_rowid) {
|
||||
// pre-fill index_ for dense columns
|
||||
auto n_features = gmat.Features();
|
||||
if (!any_missing_) {
|
||||
missing_flags_.resize(feature_offsets_[n_features], false);
|
||||
// row index is compressed, we need to dispatch it.
|
||||
DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = batch.Size(), n_features = n_features,
|
||||
n_threads = n_threads](auto t) {
|
||||
using RowBinIdxT = decltype(t);
|
||||
SetIndexNoMissing(base_rowid, gmat.index.data<RowBinIdxT>(), size, n_features, n_threads);
|
||||
});
|
||||
} else {
|
||||
missing_flags_.resize(feature_offsets_[n_features], true);
|
||||
SetIndexMixedColumns(base_rowid, batch, gmat, n_features, missing);
|
||||
}
|
||||
}
|
||||
|
||||
// construct column matrix from GHistIndexMatrix
|
||||
void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
|
||||
/**
|
||||
* \brief Initialize ColumnMatrix from GHistIndexMatrix with reference to the original
|
||||
* SparsePage.
|
||||
*/
|
||||
void InitFromSparse(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
|
||||
int32_t n_threads) {
|
||||
auto batch = data::SparsePageAdapterBatch{page.GetView()};
|
||||
this->InitStorage(gmat, sparse_threshold);
|
||||
@ -172,6 +173,54 @@ class ColumnMatrix {
|
||||
this->PushBatch(n_threads, batch, std::numeric_limits<float>::quiet_NaN(), gmat, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Initialize ColumnMatrix from GHistIndexMatrix without reference to actual
|
||||
* data.
|
||||
*
|
||||
* This function requires a binary search for each bin to get back the feature index
|
||||
* for those bins.
|
||||
*/
|
||||
void InitFromGHist(Context const* ctx, GHistIndexMatrix const& gmat) {
|
||||
auto n_threads = ctx->Threads();
|
||||
if (!any_missing_) {
|
||||
// row index is compressed, we need to dispatch it.
|
||||
DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = gmat.Size(), n_threads = n_threads,
|
||||
n_features = gmat.Features()](auto t) {
|
||||
using RowBinIdxT = decltype(t);
|
||||
SetIndexNoMissing(gmat.base_rowid, gmat.index.data<RowBinIdxT>(), size, n_features,
|
||||
n_threads);
|
||||
});
|
||||
} else {
|
||||
SetIndexMixedColumns(gmat);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Push batch of data for Quantile DMatrix support.
|
||||
*
|
||||
* \param batch Input data wrapped inside a adapter batch.
|
||||
* \param gmat The row-major histogram index that contains index for ALL data.
|
||||
* \param base_rowid The beginning row index for current batch.
|
||||
*/
|
||||
template <typename Batch>
|
||||
void PushBatch(int32_t n_threads, Batch const& batch, float missing, GHistIndexMatrix const& gmat,
|
||||
size_t base_rowid) {
|
||||
// pre-fill index_ for dense columns
|
||||
if (!any_missing_) {
|
||||
// row index is compressed, we need to dispatch it.
|
||||
|
||||
// use base_rowid from input parameter as gmat is a single matrix that contains all
|
||||
// the histogram index instead of being only a batch.
|
||||
DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = batch.Size(), n_threads = n_threads,
|
||||
n_features = gmat.Features()](auto t) {
|
||||
using RowBinIdxT = decltype(t);
|
||||
SetIndexNoMissing(base_rowid, gmat.index.data<RowBinIdxT>(), size, n_features, n_threads);
|
||||
});
|
||||
} else {
|
||||
SetIndexMixedColumns(base_rowid, batch, gmat, missing);
|
||||
}
|
||||
}
|
||||
|
||||
/* Set the number of bytes based on numeric limit of maximum number of bins provided by user */
|
||||
void SetTypeSize(size_t max_bin_per_feat) {
|
||||
if ((max_bin_per_feat - 1) <= static_cast<int>(std::numeric_limits<uint8_t>::max())) {
|
||||
@ -210,6 +259,7 @@ class ColumnMatrix {
|
||||
template <typename RowBinIdxT>
|
||||
void SetIndexNoMissing(bst_row_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples,
|
||||
const size_t n_features, int32_t n_threads) {
|
||||
missing_flags_.resize(feature_offsets_[n_features], false);
|
||||
DispatchBinType(bins_type_size_, [&](auto t) {
|
||||
using ColumnBinT = decltype(t);
|
||||
auto column_index = Span<ColumnBinT>{reinterpret_cast<ColumnBinT*>(index_.data()),
|
||||
@ -232,29 +282,16 @@ class ColumnMatrix {
|
||||
*/
|
||||
template <typename Batch>
|
||||
void SetIndexMixedColumns(size_t base_rowid, Batch const& batch, const GHistIndexMatrix& gmat,
|
||||
size_t n_features, float missing) {
|
||||
float missing) {
|
||||
auto n_features = gmat.Features();
|
||||
missing_flags_.resize(feature_offsets_[n_features], true);
|
||||
auto const* row_index = gmat.index.data<uint32_t>() + gmat.row_ptr[base_rowid];
|
||||
auto is_valid = data::IsValidFunctor {missing};
|
||||
num_nonzeros_.resize(n_features, 0);
|
||||
auto is_valid = data::IsValidFunctor{missing};
|
||||
|
||||
DispatchBinType(bins_type_size_, [&](auto t) {
|
||||
using ColumnBinT = decltype(t);
|
||||
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
|
||||
num_nonzeros_.resize(n_features, 0);
|
||||
auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
|
||||
if (type_[fid] == kDenseColumn) {
|
||||
ColumnBinT* begin = reinterpret_cast<ColumnBinT*>(&local_index[feature_offsets_[fid]]);
|
||||
begin[rid] = bin_id - index_base_[fid];
|
||||
// not thread-safe with bool vector. FIXME(jiamingy): We can directly assign
|
||||
// kMissingId to the index to avoid missing flags.
|
||||
missing_flags_[feature_offsets_[fid] + rid] = false;
|
||||
} else {
|
||||
ColumnBinT* begin = reinterpret_cast<ColumnBinT*>(&local_index[feature_offsets_[fid]]);
|
||||
begin[num_nonzeros_[fid]] = bin_id - index_base_[fid];
|
||||
row_ind_[feature_offsets_[fid] + num_nonzeros_[fid]] = rid;
|
||||
++num_nonzeros_[fid];
|
||||
}
|
||||
};
|
||||
|
||||
size_t const batch_size = batch.Size();
|
||||
size_t k{0};
|
||||
for (size_t rid = 0; rid < batch_size; ++rid) {
|
||||
@ -264,7 +301,7 @@ class ColumnMatrix {
|
||||
if (is_valid(coo)) {
|
||||
auto fid = coo.column_idx;
|
||||
const uint32_t bin_id = row_index[k];
|
||||
get_bin_idx(bin_id, rid + base_rowid, fid);
|
||||
SetBinSparse(bin_id, rid + base_rowid, fid, local_index);
|
||||
++k;
|
||||
}
|
||||
}
|
||||
@ -272,6 +309,40 @@ class ColumnMatrix {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Set column index for both dense and sparse columns, but with only GHistMatrix
|
||||
* available and requires a search for each bin.
|
||||
*/
|
||||
void SetIndexMixedColumns(const GHistIndexMatrix& gmat) {
|
||||
auto n_features = gmat.Features();
|
||||
missing_flags_.resize(feature_offsets_[n_features], true);
|
||||
auto const* row_index = gmat.index.data<uint32_t>() + gmat.row_ptr[gmat.base_rowid];
|
||||
num_nonzeros_.resize(n_features, 0);
|
||||
auto const& ptrs = gmat.cut.Ptrs();
|
||||
|
||||
DispatchBinType(bins_type_size_, [&](auto t) {
|
||||
using ColumnBinT = decltype(t);
|
||||
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
|
||||
auto const batch_size = gmat.Size();
|
||||
size_t k{0};
|
||||
|
||||
for (size_t ridx = 0; ridx < batch_size; ++ridx) {
|
||||
auto r_beg = gmat.row_ptr[ridx];
|
||||
auto r_end = gmat.row_ptr[ridx + 1];
|
||||
bst_feature_t fidx{0};
|
||||
for (size_t j = r_beg; j < r_end; ++j) {
|
||||
const uint32_t bin_idx = row_index[k];
|
||||
// find the feature index for current bin.
|
||||
while (bin_idx >= ptrs[fidx + 1]) {
|
||||
fidx++;
|
||||
}
|
||||
SetBinSparse(bin_idx, ridx, fidx, local_index);
|
||||
++k;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
BinTypeSize GetTypeSize() const { return bins_type_size_; }
|
||||
auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; }
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@
|
||||
#include "xgboost/global_config.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "algorithm.cuh"
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
@ -1556,17 +1557,7 @@ XGBOOST_DEVICE thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIt
|
||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||
}
|
||||
|
||||
template <typename It>
|
||||
size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
|
||||
size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) -
|
||||
1 - first;
|
||||
return segment_id;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t XGBOOST_DEVICE SegmentId(xgboost::common::Span<T> segments_ptr, size_t idx) {
|
||||
return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
|
||||
}
|
||||
using xgboost::common::cuda::SegmentId; // import it for compatibility
|
||||
|
||||
namespace detail {
|
||||
template <typename Key, typename KeyOutIt>
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#include "row_set.h"
|
||||
#include "threading_utils.h"
|
||||
#include "timer.h"
|
||||
#include "algorithm.h" // SegmentId
|
||||
|
||||
namespace xgboost {
|
||||
class GHistIndexMatrix;
|
||||
@ -130,9 +131,8 @@ class HistogramCuts {
|
||||
/**
|
||||
* \brief Search the bin index for categorical feature.
|
||||
*/
|
||||
bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const {
|
||||
auto const &ptrs = this->Ptrs();
|
||||
auto const &vals = this->Values();
|
||||
bst_bin_t SearchCatBin(float value, bst_feature_t fidx, std::vector<uint32_t> const& ptrs,
|
||||
std::vector<float> const& vals) const {
|
||||
auto end = ptrs.at(fidx + 1) + vals.cbegin();
|
||||
auto beg = ptrs[fidx] + vals.cbegin();
|
||||
// Truncates the value in case it's not perfectly rounded.
|
||||
@ -143,6 +143,11 @@ class HistogramCuts {
|
||||
}
|
||||
return bin_idx;
|
||||
}
|
||||
bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const {
|
||||
auto const& ptrs = this->Ptrs();
|
||||
auto const& vals = this->Values();
|
||||
return this->SearchCatBin(value, fidx, ptrs, vals);
|
||||
}
|
||||
bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); }
|
||||
};
|
||||
|
||||
@ -189,6 +194,28 @@ auto DispatchBinType(BinTypeSize type, Fn&& fn) {
|
||||
* storage class.
|
||||
*/
|
||||
struct Index {
|
||||
// Inside the compressor, bin_idx is the index for cut value across all features. By
|
||||
// subtracting it with starting pointer of each feature, we can reduce it to smaller
|
||||
// value and store it with smaller types. Usable only with dense data.
|
||||
//
|
||||
// For sparse input we have to store an addition feature index (similar to sparse matrix
|
||||
// formats like CSR) for each bin in index field to choose the right offset.
|
||||
template <typename T>
|
||||
struct CompressBin {
|
||||
uint32_t const* offsets;
|
||||
|
||||
template <typename Bin, typename Feat>
|
||||
auto operator()(Bin bin_idx, Feat fidx) const {
|
||||
return static_cast<T>(bin_idx - offsets[fidx]);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CompressBin<T> MakeCompressor() const {
|
||||
uint32_t const* offsets = this->Offset();
|
||||
return CompressBin<T>{offsets};
|
||||
}
|
||||
|
||||
Index() { SetBinTypeSize(binTypeSize_); }
|
||||
Index(const Index& i) = delete;
|
||||
Index& operator=(Index i) = delete;
|
||||
|
||||
@ -547,4 +547,15 @@ EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
|
||||
NumSymbols()),
|
||||
feature_types};
|
||||
}
|
||||
EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
|
||||
common::Span<FeatureType const> feature_types) const {
|
||||
return {Context::kCpuId,
|
||||
cuts_,
|
||||
is_dense,
|
||||
row_stride,
|
||||
base_rowid,
|
||||
n_rows,
|
||||
common::CompressedIterator<uint32_t>(gidx_buffer.ConstHostPointer(), NumSymbols()),
|
||||
feature_types};
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -43,6 +43,11 @@ struct EllpackDeviceAccessor {
|
||||
base_rowid(base_rowid),
|
||||
n_rows(n_rows) ,gidx_iter(gidx_iter),
|
||||
feature_types{feature_types} {
|
||||
if (device == Context::kCpuId) {
|
||||
gidx_fvalue_map = cuts.cut_values_.ConstHostSpan();
|
||||
feature_segments = cuts.cut_ptrs_.ConstHostSpan();
|
||||
min_fvalue = cuts.min_vals_.ConstHostSpan();
|
||||
} else {
|
||||
cuts.cut_values_.SetDevice(device);
|
||||
cuts.cut_ptrs_.SetDevice(device);
|
||||
cuts.min_vals_.SetDevice(device);
|
||||
@ -50,6 +55,7 @@ struct EllpackDeviceAccessor {
|
||||
feature_segments = cuts.cut_ptrs_.ConstDeviceSpan();
|
||||
min_fvalue = cuts.min_vals_.ConstDeviceSpan();
|
||||
}
|
||||
}
|
||||
// Get a matrix element, uses binary search for look up Return NaN if missing
|
||||
// Given a row index and a feature index, returns the corresponding cut value
|
||||
__device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
|
||||
@ -202,6 +208,7 @@ class EllpackPageImpl {
|
||||
EllpackDeviceAccessor
|
||||
GetDeviceAccessor(int device,
|
||||
common::Span<FeatureType const> feature_types = {}) const;
|
||||
EllpackDeviceAccessor GetHostAccessor(common::Span<FeatureType const> feature_types = {}) const;
|
||||
|
||||
private:
|
||||
/*!
|
||||
|
||||
@ -53,7 +53,7 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
||||
// hist
|
||||
CHECK(!sorted_sketch);
|
||||
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
|
||||
this->columns_->Init(page, *this, sparse_thresh, n_threads);
|
||||
this->columns_->InitFromSparse(page, *this, sparse_thresh, n_threads);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -66,6 +66,12 @@ GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &
|
||||
max_num_bins(max_bin_per_feat),
|
||||
isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
GHistIndexMatrix::GHistIndexMatrix(Context const *, MetaInfo const &, EllpackPage const &,
|
||||
BatchParam const &) {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
GHistIndexMatrix::~GHistIndexMatrix() = default;
|
||||
|
||||
@ -99,7 +105,7 @@ GHistIndexMatrix::GHistIndexMatrix(SparsePage const &batch, common::Span<Feature
|
||||
this->PushBatch(batch, ft, n_threads);
|
||||
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
||||
if (!std::isnan(sparse_thresh)) {
|
||||
this->columns_->Init(batch, *this, sparse_thresh, n_threads);
|
||||
this->columns_->InitFromSparse(batch, *this, sparse_thresh, n_threads);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
111
src/data/gradient_index.cu
Normal file
111
src/data/gradient_index.cu
Normal file
@ -0,0 +1,111 @@
|
||||
/*!
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <memory> // std::unique_ptr
|
||||
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/hist_util.h" // Index
|
||||
#include "ellpack_page.cuh"
|
||||
#include "gradient_index.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
namespace xgboost {
|
||||
// Similar to GHistIndexMatrix::SetIndexData, but without the need for adaptor or bin
|
||||
// searching. Is there a way to unify the code?
|
||||
template <typename BinT, typename CompressOffset>
|
||||
void SetIndexData(Context const* ctx, EllpackPageImpl const* page,
|
||||
std::vector<size_t>* p_hit_count_tloc, CompressOffset&& get_offset,
|
||||
GHistIndexMatrix* out) {
|
||||
auto accessor = page->GetHostAccessor();
|
||||
auto const kNull = static_cast<bst_bin_t>(accessor.NullValue());
|
||||
|
||||
common::Span<BinT> index_data_span = {out->index.data<BinT>(), out->index.Size()};
|
||||
auto n_bins_total = page->Cuts().TotalBins();
|
||||
|
||||
auto& hit_count_tloc = *p_hit_count_tloc;
|
||||
hit_count_tloc.clear();
|
||||
hit_count_tloc.resize(ctx->Threads() * n_bins_total, 0);
|
||||
|
||||
common::ParallelFor(page->Size(), ctx->Threads(), [&](auto i) {
|
||||
auto tid = omp_get_thread_num();
|
||||
size_t in_rbegin = page->row_stride * i;
|
||||
size_t out_rbegin = out->row_ptr[i];
|
||||
auto r_size = out->row_ptr[i + 1] - out->row_ptr[i];
|
||||
for (size_t j = 0; j < r_size; ++j) {
|
||||
auto bin_idx = accessor.gidx_iter[in_rbegin + j];
|
||||
assert(bin_idx != kNull);
|
||||
index_data_span[out_rbegin + j] = get_offset(bin_idx, j);
|
||||
++hit_count_tloc[tid * n_bins_total + bin_idx];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void GetRowPtrFromEllpack(Context const* ctx, EllpackPageImpl const* page,
|
||||
std::vector<size_t>* p_out) {
|
||||
auto& row_ptr = *p_out;
|
||||
row_ptr.resize(page->Size() + 1, 0);
|
||||
if (page->is_dense) {
|
||||
std::fill(row_ptr.begin() + 1, row_ptr.end(), page->row_stride);
|
||||
} else {
|
||||
auto accessor = page->GetHostAccessor();
|
||||
auto const kNull = static_cast<bst_bin_t>(accessor.NullValue());
|
||||
|
||||
common::ParallelFor(page->Size(), ctx->Threads(), [&](auto i) {
|
||||
size_t ibegin = page->row_stride * i;
|
||||
for (size_t j = 0; j < page->row_stride; ++j) {
|
||||
bst_bin_t bin_idx = accessor.gidx_iter[ibegin + j];
|
||||
if (bin_idx != kNull) {
|
||||
row_ptr[i + 1]++;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
std::partial_sum(row_ptr.begin(), row_ptr.end(), row_ptr.begin());
|
||||
}
|
||||
|
||||
GHistIndexMatrix::GHistIndexMatrix(Context const* ctx, MetaInfo const& info,
|
||||
EllpackPage const& in_page, BatchParam const& p)
|
||||
: max_num_bins{p.max_bin} {
|
||||
auto page = in_page.Impl();
|
||||
isDense_ = page->is_dense;
|
||||
|
||||
CHECK_EQ(info.num_row_, in_page.Size());
|
||||
|
||||
this->cut = page->Cuts();
|
||||
// pull to host early, prevent race condition
|
||||
this->cut.Ptrs();
|
||||
this->cut.Values();
|
||||
this->cut.MinValues();
|
||||
|
||||
this->ResizeIndex(info.num_nonzero_, page->is_dense);
|
||||
if (page->is_dense) {
|
||||
this->index.SetBinOffset(page->Cuts().Ptrs());
|
||||
}
|
||||
|
||||
auto n_bins_total = page->Cuts().TotalBins();
|
||||
GetRowPtrFromEllpack(ctx, page, &this->row_ptr);
|
||||
if (page->is_dense) {
|
||||
common::DispatchBinType(this->index.GetBinTypeSize(), [&](auto dtype) {
|
||||
using T = decltype(dtype);
|
||||
::xgboost::SetIndexData<T>(ctx, page, &hit_count_tloc_, index.MakeCompressor<T>(), this);
|
||||
});
|
||||
} else {
|
||||
// no compression
|
||||
::xgboost::SetIndexData<uint32_t>(
|
||||
ctx, page, &hit_count_tloc_, [&](auto bin_idx, auto) { return bin_idx; }, this);
|
||||
}
|
||||
|
||||
this->hit_count.resize(n_bins_total, 0);
|
||||
this->GatherHitCount(ctx->Threads(), n_bins_total);
|
||||
|
||||
// sanity checks
|
||||
CHECK_EQ(this->Features(), info.num_col_);
|
||||
CHECK_EQ(this->Size(), info.num_row_);
|
||||
CHECK(this->cut.cut_ptrs_.HostCanRead());
|
||||
CHECK(this->cut.cut_values_.HostCanRead());
|
||||
CHECK(this->cut.min_vals_.HostCanRead());
|
||||
|
||||
this->columns_ = std::make_unique<common::ColumnMatrix>(*this, p.sparse_thresh);
|
||||
this->columns_->InitFromGHist(ctx, *this);
|
||||
}
|
||||
} // namespace xgboost
|
||||
@ -69,7 +69,7 @@ class GHistIndexMatrix {
|
||||
if (is_valid(elem)) {
|
||||
bst_bin_t bin_idx{-1};
|
||||
if (common::IsCat(ft, elem.column_idx)) {
|
||||
bin_idx = cut.SearchCatBin(elem.value, elem.column_idx);
|
||||
bin_idx = cut.SearchCatBin(elem.value, elem.column_idx, ptrs, values);
|
||||
} else {
|
||||
bin_idx = cut.SearchBin(elem.value, elem.column_idx, ptrs, values);
|
||||
}
|
||||
@ -81,6 +81,17 @@ class GHistIndexMatrix {
|
||||
});
|
||||
}
|
||||
|
||||
// Gather hit_count from all threads
|
||||
void GatherHitCount(int32_t n_threads, bst_bin_t n_bins_total) {
|
||||
CHECK_EQ(hit_count.size(), n_bins_total);
|
||||
common::ParallelFor(n_bins_total, n_threads, [&](bst_omp_uint idx) {
|
||||
for (int32_t tid = 0; tid < n_threads; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * n_bins_total + idx];
|
||||
hit_count_tloc_[tid * n_bins_total + idx] = 0; // reset for next batch
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Batch, typename IsValid>
|
||||
void PushBatchImpl(int32_t n_threads, Batch const& batch, size_t rbegin, IsValid&& is_valid,
|
||||
common::Span<FeatureType const> ft) {
|
||||
@ -95,33 +106,20 @@ class GHistIndexMatrix {
|
||||
if (isDense_) {
|
||||
index.SetBinOffset(cut.Ptrs());
|
||||
}
|
||||
uint32_t const* offsets = index.Offset();
|
||||
if (isDense_) {
|
||||
// Inside the lambda functions, bin_idx is the index for cut value across all
|
||||
// features. By subtracting it with starting pointer of each feature, we can reduce
|
||||
// it to smaller value and compress it to smaller types.
|
||||
common::DispatchBinType(index.GetBinTypeSize(), [&](auto dtype) {
|
||||
using T = decltype(dtype);
|
||||
common::Span<T> index_data_span = {index.data<T>(), index.Size()};
|
||||
SetIndexData(
|
||||
index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
|
||||
[offsets](auto bin_idx, auto fidx) { return static_cast<T>(bin_idx - offsets[fidx]); });
|
||||
SetIndexData(index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
|
||||
index.MakeCompressor<T>());
|
||||
});
|
||||
} else {
|
||||
/* For sparse DMatrix we have to store index of feature for each bin
|
||||
in index field to chose right offset. So offset is nullptr and index is
|
||||
not reduced */
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
|
||||
// no compression
|
||||
SetIndexData(index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
|
||||
[](auto idx, auto) { return idx; });
|
||||
}
|
||||
|
||||
common::ParallelFor(n_bins_total, n_threads, [&](bst_omp_uint idx) {
|
||||
for (int32_t tid = 0; tid < n_threads; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * n_bins_total + idx];
|
||||
hit_count_tloc_[tid * n_bins_total + idx] = 0; // reset for next batch
|
||||
}
|
||||
});
|
||||
this->GatherHitCount(n_threads, n_bins_total);
|
||||
}
|
||||
|
||||
public:
|
||||
@ -129,12 +127,12 @@ class GHistIndexMatrix {
|
||||
std::vector<size_t> row_ptr;
|
||||
/*! \brief The index data */
|
||||
common::Index index;
|
||||
/*! \brief hit count of each index */
|
||||
/*! \brief hit count of each index, used for constructing the ColumnMatrix */
|
||||
std::vector<size_t> hit_count;
|
||||
/*! \brief The corresponding cuts */
|
||||
common::HistogramCuts cut;
|
||||
/*! \brief max_bin for each feature. */
|
||||
size_t max_num_bins;
|
||||
bst_bin_t max_num_bins;
|
||||
/*! \brief base row index for current page (used by external memory) */
|
||||
size_t base_rowid{0};
|
||||
|
||||
@ -149,6 +147,13 @@ class GHistIndexMatrix {
|
||||
* for push batch.
|
||||
*/
|
||||
GHistIndexMatrix(MetaInfo const& info, common::HistogramCuts&& cuts, bst_bin_t max_bin_per_feat);
|
||||
/**
|
||||
* \brief Constructor fro Iterative DMatrix where we might copy an existing ellpack page
|
||||
* to host gradient index.
|
||||
*/
|
||||
GHistIndexMatrix(Context const* ctx, MetaInfo const& info, EllpackPage const& page,
|
||||
BatchParam const& p);
|
||||
|
||||
/**
|
||||
* \brief Constructor for external memory.
|
||||
*/
|
||||
|
||||
@ -205,12 +205,11 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
|
||||
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {
|
||||
CheckParam(param);
|
||||
CHECK(ghist_) << R"(`QuantileDMatrix` is not initialized with CPU data but used for CPU training.
|
||||
Possible solutions:
|
||||
- Use `DMatrix` instead.
|
||||
- Use CPU input for `QuantileDMatrix`.
|
||||
- Run training on GPU.
|
||||
)";
|
||||
if (!ghist_) {
|
||||
CHECK(ellpack_);
|
||||
ghist_ = std::make_shared<GHistIndexMatrix>(&ctx_, Info(), *ellpack_, param);
|
||||
}
|
||||
|
||||
auto begin_iter =
|
||||
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
|
||||
@ -29,20 +29,17 @@ namespace data {
|
||||
* `QuantileDMatrix` is an intermediate storage for quantilization results including
|
||||
* quantile cuts and histogram index. Quantilization is designed to be performed on stream
|
||||
* of data (or batches of it). As a result, the `QuantileDMatrix` is also designed to work
|
||||
* with batches of data. During initializaion, it will walk through the data multiple
|
||||
* times iteratively in order to perform quantilization. This design can help us reduce
|
||||
* memory usage significantly by avoiding data concatenation along with removing the CSR
|
||||
* matrix `SparsePage`. However, it has its limitation (can be fixed if needed):
|
||||
* with batches of data. During initializaion, it walks through the data multiple times
|
||||
* iteratively in order to perform quantilization. This design helps us reduce memory
|
||||
* usage significantly by avoiding data concatenation along with removing the CSR matrix
|
||||
* `SparsePage`. However, it has its limitation (can be fixed if needed):
|
||||
*
|
||||
* - It's only supported by hist tree method (both CPU and GPU) since approx requires a
|
||||
* re-calculation of quantiles for each iteration. We can fix this by retaining a
|
||||
* reference to the callback if there are feature requests.
|
||||
*
|
||||
* - The CPU format and the GPU format are different, the former uses a CSR + CSC for
|
||||
* histogram index while the latter uses only Ellpack. This results into a design that
|
||||
* we can obtain the GPU format from CPU but the other way around is not yet
|
||||
* supported. We can search the bin value from ellpack to recover the feature index when
|
||||
* we support copying data from GPU to CPU.
|
||||
* histogram index while the latter uses only Ellpack.
|
||||
*/
|
||||
class IterativeDMatrix : public DMatrix {
|
||||
MetaInfo info_;
|
||||
|
||||
@ -23,7 +23,7 @@ TEST(DenseColumn, Test) {
|
||||
common::OmpGetNumThreads(0)};
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
|
||||
column_matrix.InitFromSparse(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
|
||||
}
|
||||
ASSERT_GE(column_matrix.GetTypeSize(), last);
|
||||
ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize);
|
||||
@ -69,7 +69,7 @@ TEST(SparseColumn, Test) {
|
||||
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)};
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.Init(page, gmat, 1.0, common::OmpGetNumThreads(0));
|
||||
column_matrix.InitFromSparse(page, gmat, 1.0, common::OmpGetNumThreads(0));
|
||||
}
|
||||
common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
|
||||
using T = decltype(dtype);
|
||||
@ -97,7 +97,7 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0));
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0));
|
||||
column_matrix.InitFromSparse(page, gmat, 0.2, common::OmpGetNumThreads(0));
|
||||
}
|
||||
ASSERT_TRUE(column_matrix.AnyMissing());
|
||||
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include "../../../src/common/column_matrix.h"
|
||||
#include "../../../src/common/io.h" // MemoryBufferStream
|
||||
#include "../../../src/data/gradient_index.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
@ -107,5 +108,81 @@ TEST(GradientIndex, PushBatch) {
|
||||
test(0.5f);
|
||||
test(0.9f);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
|
||||
namespace {
|
||||
class GHistIndexMatrixTest : public testing::TestWithParam<std::tuple<float, float>> {
|
||||
protected:
|
||||
void Run(float density, double threshold) {
|
||||
// Only testing with small sample size as the cuts might be different between host and
|
||||
// device.
|
||||
size_t n_samples{128}, n_features{13};
|
||||
Context ctx;
|
||||
ctx.gpu_id = 0;
|
||||
auto Xy = RandomDataGenerator{n_samples, n_features, 1 - density}.GenerateDMatrix(true);
|
||||
std::unique_ptr<GHistIndexMatrix> from_ellpack;
|
||||
ASSERT_TRUE(Xy->SingleColBlock());
|
||||
bst_bin_t constexpr kBins{17};
|
||||
auto p = BatchParam{kBins, threshold};
|
||||
for (auto const &page : Xy->GetBatches<EllpackPage>(BatchParam{0, kBins})) {
|
||||
from_ellpack.reset(new GHistIndexMatrix{&ctx, Xy->Info(), page, p});
|
||||
}
|
||||
|
||||
for (auto const &from_sparse_page : Xy->GetBatches<GHistIndexMatrix>(p)) {
|
||||
ASSERT_EQ(from_sparse_page.IsDense(), from_ellpack->IsDense());
|
||||
ASSERT_EQ(from_sparse_page.base_rowid, 0);
|
||||
ASSERT_EQ(from_sparse_page.base_rowid, from_ellpack->base_rowid);
|
||||
ASSERT_EQ(from_sparse_page.Size(), from_ellpack->Size());
|
||||
ASSERT_EQ(from_sparse_page.index.Size(), from_ellpack->index.Size());
|
||||
|
||||
auto const &gidx_from_sparse = from_sparse_page.index;
|
||||
auto const &gidx_from_ellpack = from_ellpack->index;
|
||||
|
||||
for (size_t i = 0; i < gidx_from_sparse.Size(); ++i) {
|
||||
ASSERT_EQ(gidx_from_sparse[i], gidx_from_ellpack[i]);
|
||||
}
|
||||
|
||||
auto const &columns_from_sparse = from_sparse_page.Transpose();
|
||||
auto const &columns_from_ellpack = from_ellpack->Transpose();
|
||||
ASSERT_EQ(columns_from_sparse.AnyMissing(), columns_from_ellpack.AnyMissing());
|
||||
ASSERT_EQ(columns_from_sparse.GetTypeSize(), columns_from_ellpack.GetTypeSize());
|
||||
ASSERT_EQ(columns_from_sparse.GetNumFeature(), columns_from_ellpack.GetNumFeature());
|
||||
for (size_t i = 0; i < n_features; ++i) {
|
||||
ASSERT_EQ(columns_from_sparse.GetColumnType(i), columns_from_ellpack.GetColumnType(i));
|
||||
}
|
||||
|
||||
std::string from_sparse_buf;
|
||||
{
|
||||
common::MemoryBufferStream fo{&from_sparse_buf};
|
||||
columns_from_sparse.Write(&fo);
|
||||
}
|
||||
std::string from_ellpack_buf;
|
||||
{
|
||||
common::MemoryBufferStream fo{&from_ellpack_buf};
|
||||
columns_from_sparse.Write(&fo);
|
||||
}
|
||||
ASSERT_EQ(from_sparse_buf, from_ellpack_buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
TEST_P(GHistIndexMatrixTest, FromEllpack) {
|
||||
float sparsity;
|
||||
double thresh;
|
||||
std::tie(sparsity, thresh) = GetParam();
|
||||
this->Run(sparsity, thresh);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(GHistIndexMatrix, GHistIndexMatrixTest,
|
||||
testing::Values(std::make_tuple(1.f, .0), // no missing
|
||||
std::make_tuple(.2f, .8), // sparse columns
|
||||
std::make_tuple(.8f, .2), // dense columns
|
||||
std::make_tuple(1.f, .2), // no missing
|
||||
std::make_tuple(.5f, .6), // sparse columns
|
||||
std::make_tuple(.6f, .4))); // dense columns
|
||||
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@ -37,7 +37,7 @@ TEST(QuantileHist, Partitioner) {
|
||||
GHistIndexMatrix gmat(page, {}, cuts, 64, true, 0.5, ctx.Threads());
|
||||
bst_feature_t const split_ind = 0;
|
||||
common::ColumnMatrix column_indices;
|
||||
column_indices.Init(page, gmat, 0.5, ctx.Threads());
|
||||
column_indices.InitFromSparse(page, gmat, 0.5, ctx.Threads());
|
||||
{
|
||||
auto min_value = gmat.cut.MinValues()[split_ind];
|
||||
RegTree tree;
|
||||
|
||||
@ -32,32 +32,41 @@ class TestDeviceQuantileDMatrix:
|
||||
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_from_host(self) -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"tree_method,max_bin", [
|
||||
("hist", 16), ("gpu_hist", 16), ("hist", 64), ("gpu_hist", 64)
|
||||
]
|
||||
)
|
||||
def test_interoperability(self, tree_method: str, max_bin: int) -> None:
|
||||
import cupy as cp
|
||||
n_samples = 64
|
||||
n_features = 3
|
||||
X, y, w = tm.make_batches(
|
||||
n_samples, n_features=n_features, n_batches=1, use_cupy=False
|
||||
)
|
||||
Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0])
|
||||
booster_0 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
|
||||
# from CPU
|
||||
Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin)
|
||||
booster_0 = xgb.train(
|
||||
{"tree_method": tree_method, "max_bin": max_bin}, Xy, num_boost_round=4
|
||||
)
|
||||
|
||||
X[0] = cp.array(X[0])
|
||||
y[0] = cp.array(y[0])
|
||||
w[0] = cp.array(w[0])
|
||||
|
||||
Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0])
|
||||
booster_1 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
|
||||
# from GPU
|
||||
Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin)
|
||||
booster_1 = xgb.train(
|
||||
{"tree_method": tree_method, "max_bin": max_bin}, Xy, num_boost_round=4
|
||||
)
|
||||
cp.testing.assert_allclose(
|
||||
booster_0.inplace_predict(X[0]), booster_1.inplace_predict(X[0])
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="not initialized with CPU"):
|
||||
# Training on CPU with GPU data is not supported.
|
||||
xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4)
|
||||
|
||||
with pytest.raises(ValueError, match=r"Only.*hist.*"):
|
||||
xgb.train({"tree_method": "approx"}, Xy, num_boost_round=4)
|
||||
xgb.train(
|
||||
{"tree_method": "approx", "max_bin": max_bin}, Xy, num_boost_round=4
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_metainfo(self) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user