Predict on Ellpack. (#5327)

* Unify GPU prediction node.
* Add `PageExists`.
* Dispatch prediction on input data for GPU Predictor.
This commit is contained in:
Jiaming Yuan
2020-02-23 06:27:03 +08:00
committed by GitHub
parent 70a91ec3ba
commit 655cf17b60
19 changed files with 320 additions and 134 deletions

View File

@@ -31,8 +31,8 @@ __global__ void CompressBinEllpackKernel(
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
const size_t* __restrict__ row_ptrs, // row offset of input data
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistogramCuts::cut
const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
const float* __restrict__ cuts, // HistogramCuts::cut_values_
const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,

View File

@@ -76,6 +76,9 @@ struct EllpackInfo {
/*! \brief Number of distinct symbols: one more than the number of bins.
 *  NOTE(review): the extra symbol is presumably reserved for a missing/null
 *  entry -- confirm against the code that writes the compressed buffer. */
size_t NumSymbols() const {
  const size_t symbol_count = n_bins + 1;
  return symbol_count;
}
/*! \brief Number of features, taken as the number of entries in min_fvalue.
 *  NOTE(review): assumes min_fvalue holds exactly one entry per feature --
 *  confirm where min_fvalue is populated. */
size_t NumFeatures() const {
  const size_t feature_count = min_fvalue.size();
  return feature_count;
}
};
/** \brief Struct for accessing and manipulating an ellpack matrix on the
@@ -89,7 +92,7 @@ struct EllpackMatrix {
// Get a matrix element, using binary search for lookup. Returns NaN if missing.
// Given a row index and a feature index, returns the corresponding cut value
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
__device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
ridx -= base_rowid;
auto row_begin = info.row_stride * ridx;
auto row_end = row_begin + info.row_stride;
@@ -103,6 +106,10 @@ struct EllpackMatrix {
info.feature_segments[fidx],
info.feature_segments[fidx + 1]);
}
return gidx;
}
__device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
auto gidx = GetBinIndex(ridx, fidx);
if (gidx == -1) {
return nan("");
}

View File

@@ -61,11 +61,15 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
}
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
// ELLPACK page doesn't exist, generate it
if (!ellpack_page_) {
if (!(batch_param_ != BatchParam{})) {
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
}
if (!ellpack_page_ || (batch_param_ != param && param != BatchParam{})) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
ellpack_page_.reset(new EllpackPage(this, param));
batch_param_ = param;
}
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));

View File

@@ -48,6 +48,14 @@ class SimpleDMatrix : public DMatrix {
std::unique_ptr<CSCPage> column_page_;
std::unique_ptr<SortedCSCPage> sorted_column_page_;
std::unique_ptr<EllpackPage> ellpack_page_;
BatchParam batch_param_;
/*! \brief Whether an ELLPACK page is currently held by this matrix,
 *  i.e. ellpack_page_ is non-null. */
bool EllpackExists() const override {
  return ellpack_page_ != nullptr;
}
/*! \brief Whether the sparse page exists; this implementation reports true
 *  unconditionally. */
bool SparsePageExists() const override { return true; }
};
} // namespace data
} // namespace xgboost

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2014 by Contributors
* Copyright 2014-2020 by Contributors
* \file sparse_page_dmatrix.cc
* \brief The external memory version of Page Iterator.
* \author Tianqi Chen
@@ -47,7 +47,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
// Lazily instantiate
if (!ellpack_source_ || batch_param_ != param) {
if (!ellpack_source_ || (batch_param_ != param && param != BatchParam{})) {
ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
batch_param_ = param;
}

View File

@@ -58,6 +58,13 @@ class SparsePageDMatrix : public DMatrix {
std::string cache_info_;
// Store column densities to avoid recalculating
std::vector<float> col_density_;
/*! \brief Whether the ELLPACK source has been instantiated, i.e.
 *  ellpack_source_ converts to true. */
bool EllpackExists() const override {
  if (ellpack_source_) {
    return true;
  }
  return false;
}
/*! \brief Whether the sparse-page row source has been instantiated, i.e.
 *  row_source_ converts to true. */
bool SparsePageExists() const override {
  if (row_source_) {
    return true;
  }
  return false;
}
};
} // namespace data
} // namespace xgboost