Predict on Ellpack. (#5327)
* Unify GPU prediction node. * Add `PageExists`. * Dispatch prediction on input data for GPU Predictor.
This commit is contained in:
@@ -168,12 +168,19 @@ struct BatchParam {
|
||||
/*! \brief The GPU device to use. */
|
||||
int gpu_id;
|
||||
/*! \brief Maximum number of bins per feature for histograms. */
|
||||
int max_bin;
|
||||
int max_bin { 0 };
|
||||
/*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
|
||||
int gpu_batch_nrows;
|
||||
/*! \brief Page size for external memory mode. */
|
||||
size_t gpu_page_size;
|
||||
|
||||
BatchParam() = default;
|
||||
BatchParam(int32_t device, int32_t max_bin, int32_t gpu_batch_nrows,
|
||||
size_t gpu_page_size = 0) :
|
||||
gpu_id{device},
|
||||
max_bin{max_bin},
|
||||
gpu_batch_nrows{gpu_batch_nrows},
|
||||
gpu_page_size{gpu_page_size}
|
||||
{}
|
||||
inline bool operator!=(const BatchParam& other) const {
|
||||
return gpu_id != other.gpu_id ||
|
||||
max_bin != other.max_bin ||
|
||||
@@ -438,6 +445,9 @@ class DMatrix {
|
||||
*/
|
||||
template<typename T>
|
||||
BatchSet<T> GetBatches(const BatchParam& param = {});
|
||||
template <typename T>
|
||||
bool PageExists() const;
|
||||
|
||||
// the following are column meta data, should be able to answer them fast.
|
||||
/*! \return Whether the data columns single column block. */
|
||||
virtual bool SingleColBlock() const = 0;
|
||||
@@ -493,6 +503,9 @@ class DMatrix {
|
||||
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
|
||||
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
|
||||
|
||||
virtual bool EllpackExists() const = 0;
|
||||
virtual bool SparsePageExists() const = 0;
|
||||
};
|
||||
|
||||
template<>
|
||||
@@ -500,6 +513,16 @@ inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
|
||||
return GetRowBatches();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline bool DMatrix::PageExists<EllpackPage>() const {
|
||||
return this->EllpackExists();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline bool DMatrix::PageExists<SparsePage>() const {
|
||||
return this->SparsePageExists();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
|
||||
return GetColumnBatches();
|
||||
|
||||
@@ -105,7 +105,7 @@ class RegTree : public Model {
|
||||
/*! \brief tree node */
|
||||
class Node {
|
||||
public:
|
||||
Node() {
|
||||
XGBOOST_DEVICE Node() {
|
||||
// assert compact alignment
|
||||
static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
|
||||
"Node: 64 bit align");
|
||||
@@ -422,7 +422,7 @@ class RegTree : public Model {
|
||||
* \param i feature index.
|
||||
* \return the i-th feature value
|
||||
*/
|
||||
bst_float Fvalue(size_t i) const;
|
||||
bst_float GetFvalue(size_t i) const;
|
||||
/*!
|
||||
* \brief check whether i-th entry is missing
|
||||
* \param i feature index.
|
||||
@@ -565,7 +565,7 @@ inline size_t RegTree::FVec::Size() const {
|
||||
return data_.size();
|
||||
}
|
||||
|
||||
inline bst_float RegTree::FVec::Fvalue(size_t i) const {
|
||||
inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
|
||||
return data_[i].fvalue;
|
||||
}
|
||||
|
||||
@@ -577,7 +577,7 @@ inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
|
||||
bst_node_t nid = 0;
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
unsigned split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
|
||||
}
|
||||
return nid;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user