Predict on Ellpack. (#5327)

* Unify GPU prediction node.
* Add `PageExists`.
* Dispatch prediction on input data for GPU Predictor.
This commit is contained in:
Jiaming Yuan
2020-02-23 06:27:03 +08:00
committed by GitHub
parent 70a91ec3ba
commit 655cf17b60
19 changed files with 320 additions and 134 deletions

View File

@@ -168,12 +168,19 @@ struct BatchParam {
/*! \brief The GPU device to use. */
int gpu_id;
/*! \brief Maximum number of bins per feature for histograms. */
int max_bin;
int max_bin { 0 };
/*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
int gpu_batch_nrows;
/*! \brief Page size for external memory mode. */
size_t gpu_page_size;
BatchParam() = default;
BatchParam(int32_t device, int32_t max_bin, int32_t gpu_batch_nrows,
size_t gpu_page_size = 0) :
gpu_id{device},
max_bin{max_bin},
gpu_batch_nrows{gpu_batch_nrows},
gpu_page_size{gpu_page_size}
{}
inline bool operator!=(const BatchParam& other) const {
return gpu_id != other.gpu_id ||
max_bin != other.max_bin ||
@@ -438,6 +445,9 @@ class DMatrix {
*/
template<typename T>
BatchSet<T> GetBatches(const BatchParam& param = {});
template <typename T>
bool PageExists() const;
// the following are column meta data, should be able to answer them fast.
/*! \return Whether the data columns single column block. */
virtual bool SingleColBlock() const = 0;
@@ -493,6 +503,9 @@ class DMatrix {
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
virtual bool EllpackExists() const = 0;
virtual bool SparsePageExists() const = 0;
};
template<>
@@ -500,6 +513,16 @@ inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
return GetRowBatches();
}
template<>
inline bool DMatrix::PageExists<EllpackPage>() const {
return this->EllpackExists();
}
template<>
inline bool DMatrix::PageExists<SparsePage>() const {
return this->SparsePageExists();
}
template<>
inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
return GetColumnBatches();

View File

@@ -105,7 +105,7 @@ class RegTree : public Model {
/*! \brief tree node */
class Node {
public:
Node() {
XGBOOST_DEVICE Node() {
// assert compact alignment
static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
"Node: 64 bit align");
@@ -422,7 +422,7 @@ class RegTree : public Model {
* \param i feature index.
* \return the i-th feature value
*/
bst_float Fvalue(size_t i) const;
bst_float GetFvalue(size_t i) const;
/*!
* \brief check whether i-th entry is missing
* \param i feature index.
@@ -565,7 +565,7 @@ inline size_t RegTree::FVec::Size() const {
return data_.size();
}
inline bst_float RegTree::FVec::Fvalue(size_t i) const {
inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
return data_[i].fvalue;
}
@@ -577,7 +577,7 @@ inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
unsigned split_index = (*this)[nid].SplitIndex();
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
}
return nid;
}