Predict on Ellpack. (#5327)
* Unify GPU prediction node. * Add `PageExists`. * Dispatch prediction on input data for GPU Predictor.
This commit is contained in:
@@ -31,8 +31,8 @@ __global__ void CompressBinEllpackKernel(
|
||||
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
|
||||
const size_t* __restrict__ row_ptrs, // row offset of input data
|
||||
const Entry* __restrict__ entries, // One batch of input data
|
||||
const float* __restrict__ cuts, // HistogramCuts::cut
|
||||
const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
|
||||
const float* __restrict__ cuts, // HistogramCuts::cut_values_
|
||||
const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_
|
||||
size_t base_row, // batch_row_begin
|
||||
size_t n_rows,
|
||||
size_t row_stride,
|
||||
|
||||
@@ -76,6 +76,9 @@ struct EllpackInfo {
|
||||
size_t NumSymbols() const {
|
||||
return n_bins + 1;
|
||||
}
|
||||
size_t NumFeatures() const {
|
||||
return min_fvalue.size();
|
||||
}
|
||||
};
|
||||
|
||||
/** \brief Struct for accessing and manipulating an ellpack matrix on the
|
||||
@@ -89,7 +92,7 @@ struct EllpackMatrix {
|
||||
|
||||
// Get a matrix element, uses binary search for look up Return NaN if missing
|
||||
// Given a row index and a feature index, returns the corresponding cut value
|
||||
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
|
||||
__device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
|
||||
ridx -= base_rowid;
|
||||
auto row_begin = info.row_stride * ridx;
|
||||
auto row_end = row_begin + info.row_stride;
|
||||
@@ -103,6 +106,10 @@ struct EllpackMatrix {
|
||||
info.feature_segments[fidx],
|
||||
info.feature_segments[fidx + 1]);
|
||||
}
|
||||
return gidx;
|
||||
}
|
||||
__device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
|
||||
auto gidx = GetBinIndex(ridx, fidx);
|
||||
if (gidx == -1) {
|
||||
return nan("");
|
||||
}
|
||||
|
||||
@@ -61,11 +61,15 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
|
||||
}
|
||||
|
||||
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
|
||||
CHECK_GE(param.gpu_id, 0);
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
// ELLPACK page doesn't exist, generate it
|
||||
if (!ellpack_page_) {
|
||||
if (!(batch_param_ != BatchParam{})) {
|
||||
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
|
||||
}
|
||||
if (!ellpack_page_ || (batch_param_ != param && param != BatchParam{})) {
|
||||
CHECK_GE(param.gpu_id, 0);
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
ellpack_page_.reset(new EllpackPage(this, param));
|
||||
batch_param_ = param;
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
|
||||
|
||||
@@ -48,6 +48,14 @@ class SimpleDMatrix : public DMatrix {
|
||||
std::unique_ptr<CSCPage> column_page_;
|
||||
std::unique_ptr<SortedCSCPage> sorted_column_page_;
|
||||
std::unique_ptr<EllpackPage> ellpack_page_;
|
||||
BatchParam batch_param_;
|
||||
|
||||
bool EllpackExists() const override {
|
||||
return static_cast<bool>(ellpack_page_);
|
||||
}
|
||||
bool SparsePageExists() const override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* \file sparse_page_dmatrix.cc
|
||||
* \brief The external memory version of Page Iterator.
|
||||
* \author Tianqi Chen
|
||||
@@ -47,7 +47,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
|
||||
CHECK_GE(param.gpu_id, 0);
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
// Lazily instantiate
|
||||
if (!ellpack_source_ || batch_param_ != param) {
|
||||
if (!ellpack_source_ || (batch_param_ != param && param != BatchParam{})) {
|
||||
ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
|
||||
batch_param_ = param;
|
||||
}
|
||||
|
||||
@@ -58,6 +58,13 @@ class SparsePageDMatrix : public DMatrix {
|
||||
std::string cache_info_;
|
||||
// Store column densities to avoid recalculating
|
||||
std::vector<float> col_density_;
|
||||
|
||||
bool EllpackExists() const override {
|
||||
return static_cast<bool>(ellpack_source_);
|
||||
}
|
||||
bool SparsePageExists() const override {
|
||||
return static_cast<bool>(row_source_);
|
||||
}
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user