Predict on Ellpack. (#5327)

* Unify GPU prediction node.
* Add `PageExists`.
* Dispatch prediction on input data for GPU Predictor.

parent 70a91ec3ba
commit 655cf17b60
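In short: the bespoke `DevicePredictionNode` is dropped in favour of reusing `RegTree::Node` directly on the device, `DMatrix` gains a `PageExists<T>()` query, and the GPU predictor dispatches on whichever page type is already materialized. A minimal sketch of the resulting dispatch, simplified from `DevicePredictInternal` in the diff below (offsets and error handling omitted):

    if (dmat->PageExists<EllpackPage>()) {
      // A compressed histogram-index page is already on device: predict from it.
      for (auto const& page : dmat->GetBatches<EllpackPage>()) { /* ... */ }
    } else {
      // Otherwise fall back to copying SparsePage batches to the device.
      for (auto& batch : dmat->GetBatches<SparsePage>()) { /* ... */ }
    }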
@@ -168,12 +168,19 @@ struct BatchParam {
   /*! \brief The GPU device to use. */
   int gpu_id;
   /*! \brief Maximum number of bins per feature for histograms. */
-  int max_bin;
+  int max_bin { 0 };
   /*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
   int gpu_batch_nrows;
   /*! \brief Page size for external memory mode. */
   size_t gpu_page_size;
+
+  BatchParam() = default;
+  BatchParam(int32_t device, int32_t max_bin, int32_t gpu_batch_nrows,
+             size_t gpu_page_size = 0) :
+      gpu_id{device},
+      max_bin{max_bin},
+      gpu_batch_nrows{gpu_batch_nrows},
+      gpu_page_size{gpu_page_size}
+  {}
   inline bool operator!=(const BatchParam& other) const {
     return gpu_id != other.gpu_id ||
            max_bin != other.max_bin ||
@@ -438,6 +445,9 @@ class DMatrix {
    */
   template<typename T>
   BatchSet<T> GetBatches(const BatchParam& param = {});
+  template <typename T>
+  bool PageExists() const;
+
   // the following are column meta data, should be able to answer them fast.
   /*! \return Whether the data columns single column block. */
   virtual bool SingleColBlock() const = 0;
@@ -493,6 +503,9 @@ class DMatrix {
   virtual BatchSet<CSCPage> GetColumnBatches() = 0;
   virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
   virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
+
+  virtual bool EllpackExists() const = 0;
+  virtual bool SparsePageExists() const = 0;
 };

 template<>
@@ -500,6 +513,16 @@ inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
   return GetRowBatches();
 }
+
+template<>
+inline bool DMatrix::PageExists<EllpackPage>() const {
+  return this->EllpackExists();
+}
+
+template<>
+inline bool DMatrix::PageExists<SparsePage>() const {
+  return this->SparsePageExists();
+}

 template<>
 inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
   return GetColumnBatches();
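These specializations let a caller ask whether a given page type has already been materialized without forcing its construction. A hypothetical helper showing the intended use:

    // Sketch: returns true only if an ELLPACK page was already built (e.g. during
    // training with gpu_hist); it never triggers page construction itself.
    bool HasEllpack(DMatrix* dmat) {
      return dmat->PageExists<EllpackPage>();
    }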
@@ -105,7 +105,7 @@ class RegTree : public Model {
   /*! \brief tree node */
   class Node {
    public:
-    Node() {
+    XGBOOST_DEVICE Node() {
       // assert compact alignment
       static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
                     "Node: 64 bit align");
@@ -422,7 +422,7 @@ class RegTree : public Model {
     * \param i feature index.
     * \return the i-th feature value
     */
-    bst_float Fvalue(size_t i) const;
+    bst_float GetFvalue(size_t i) const;
    /*!
     * \brief check whether i-th entry is missing
     * \param i feature index.
@@ -565,7 +565,7 @@ inline size_t RegTree::FVec::Size() const {
   return data_.size();
 }

-inline bst_float RegTree::FVec::Fvalue(size_t i) const {
+inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
   return data_[i].fvalue;
 }

@@ -577,7 +577,7 @@ inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
   bst_node_t nid = 0;
   while (!(*this)[nid].IsLeaf()) {
     unsigned split_index = (*this)[nid].SplitIndex();
-    nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
+    nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
   }
   return nid;
 }
@@ -31,8 +31,8 @@ __global__ void CompressBinEllpackKernel(
     common::CompressedByteT* __restrict__ buffer,  // gidx_buffer
     const size_t* __restrict__ row_ptrs,           // row offset of input data
     const Entry* __restrict__ entries,             // One batch of input data
-    const float* __restrict__ cuts,                // HistogramCuts::cut
-    const uint32_t* __restrict__ cut_rows,         // HistogramCuts::row_ptrs
+    const float* __restrict__ cuts,                // HistogramCuts::cut_values_
+    const uint32_t* __restrict__ cut_rows,         // HistogramCuts::cut_ptrs_
     size_t base_row,                               // batch_row_begin
     size_t n_rows,
     size_t row_stride,
@@ -76,6 +76,9 @@ struct EllpackInfo {
   size_t NumSymbols() const {
     return n_bins + 1;
   }
+  size_t NumFeatures() const {
+    return min_fvalue.size();
+  }
 };

 /** \brief Struct for accessing and manipulating an ellpack matrix on the
@@ -89,7 +92,7 @@ struct EllpackMatrix {

   // Get a matrix element, uses binary search for look up Return NaN if missing
   // Given a row index and a feature index, returns the corresponding cut value
-  __device__ bst_float GetElement(size_t ridx, size_t fidx) const {
+  __device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
     ridx -= base_rowid;
     auto row_begin = info.row_stride * ridx;
     auto row_end = row_begin + info.row_stride;
@@ -103,6 +106,10 @@ struct EllpackMatrix {
                                  info.feature_segments[fidx],
                                  info.feature_segments[fidx + 1]);
     }
+    return gidx;
+  }
+  __device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
+    auto gidx = GetBinIndex(ridx, fidx);
     if (gidx == -1) {
       return nan("");
     }
@@ -61,11 +61,15 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
 }

 BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
-  CHECK_GE(param.gpu_id, 0);
-  CHECK_GE(param.max_bin, 2);
   // ELLPACK page doesn't exist, generate it
-  if (!ellpack_page_) {
+  if (!(batch_param_ != BatchParam{})) {
+    CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
+  }
+  if (!ellpack_page_ || (batch_param_ != param && param != BatchParam{})) {
+    CHECK_GE(param.gpu_id, 0);
+    CHECK_GE(param.max_bin, 2);
     ellpack_page_.reset(new EllpackPage(this, param));
+    batch_param_ = param;
   }
   auto begin_iter =
       BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
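The guard above means the page is rebuilt only when none exists yet, or when the caller passes a non-default parameter that differs from the cached one; a default `BatchParam{}` always reuses the cached page. In caller terms (a sketch with illustrative values):

    dmat->GetBatches<EllpackPage>({/*device=*/0, /*max_bin=*/256, /*gpu_batch_nrows=*/0});  // builds the page
    dmat->GetBatches<EllpackPage>({});                                                      // reuses the cached page
    dmat->GetBatches<EllpackPage>({/*device=*/0, /*max_bin=*/64, /*gpu_batch_nrows=*/0});   // parameters differ: rebuilds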
@@ -48,6 +48,14 @@ class SimpleDMatrix : public DMatrix {
   std::unique_ptr<CSCPage> column_page_;
   std::unique_ptr<SortedCSCPage> sorted_column_page_;
   std::unique_ptr<EllpackPage> ellpack_page_;
+  BatchParam batch_param_;
+
+  bool EllpackExists() const override {
+    return static_cast<bool>(ellpack_page_);
+  }
+  bool SparsePageExists() const override {
+    return true;
+  }
 };
 }  // namespace data
 }  // namespace xgboost
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014 by Contributors
+ * Copyright 2014-2020 by Contributors
 * \file sparse_page_dmatrix.cc
 * \brief The external memory version of Page Iterator.
 * \author Tianqi Chen
@@ -47,7 +47,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
   CHECK_GE(param.gpu_id, 0);
   CHECK_GE(param.max_bin, 2);
   // Lazily instantiate
-  if (!ellpack_source_ || batch_param_ != param) {
+  if (!ellpack_source_ || (batch_param_ != param && param != BatchParam{})) {
     ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
     batch_param_ = param;
   }
@@ -58,6 +58,13 @@ class SparsePageDMatrix : public DMatrix {
   std::string cache_info_;
   // Store column densities to avoid recalculating
   std::vector<float> col_density_;
+
+  bool EllpackExists() const override {
+    return static_cast<bool>(ellpack_source_);
+  }
+  bool SparsePageExists() const override {
+    return static_cast<bool>(row_source_);
+  }
 };
 }  // namespace data
 }  // namespace xgboost
@@ -14,6 +14,7 @@
 #include "xgboost/host_device_vector.h"

 #include "../gbm/gbtree_model.h"
+#include "../data/ellpack_page.cuh"
 #include "../common/common.h"
 #include "../common/device_helpers.cuh"

@@ -22,78 +23,32 @@ namespace predictor {

 DMLC_REGISTRY_FILE_TAG(gpu_predictor);

-/**
- * \struct  DevicePredictionNode
- *
- * \brief Packed 16 byte representation of a tree node for use in device
- * prediction
- */
-struct DevicePredictionNode {
-  XGBOOST_DEVICE DevicePredictionNode()
-      : fidx{-1}, left_child_idx{-1}, right_child_idx{-1} {}
-
-  union NodeValue {
-    float leaf_weight;
-    float fvalue;
-  };
-
-  int fidx;
-  int left_child_idx;
-  int right_child_idx;
-  NodeValue val{};
-
-  DevicePredictionNode(const RegTree::Node& n) {  // NOLINT
-    static_assert(sizeof(DevicePredictionNode) == 16, "Size is not 16 bytes");
-    this->left_child_idx = n.LeftChild();
-    this->right_child_idx = n.RightChild();
-    this->fidx = n.SplitIndex();
-    if (n.DefaultLeft()) {
-      fidx |= (1U << 31);
-    }
-
-    if (n.IsLeaf()) {
-      this->val.leaf_weight = n.LeafValue();
-    } else {
-      this->val.fvalue = n.SplitCond();
-    }
-  }
-
-  XGBOOST_DEVICE bool IsLeaf() const { return left_child_idx == -1; }
-
-  XGBOOST_DEVICE int GetFidx() const { return fidx & ((1U << 31) - 1U); }
-
-  XGBOOST_DEVICE bool MissingLeft() const { return (fidx >> 31) != 0; }
-
-  XGBOOST_DEVICE int MissingIdx() const {
-    if (MissingLeft()) {
-      return this->left_child_idx;
-    } else {
-      return this->right_child_idx;
-    }
-  }
-
-  XGBOOST_DEVICE float GetFvalue() const { return val.fvalue; }
-
-  XGBOOST_DEVICE float GetWeight() const { return val.leaf_weight; }
+struct SparsePageView {
+  common::Span<const Entry> d_data;
+  common::Span<const bst_row_t> d_row_ptr;
+
+  XGBOOST_DEVICE SparsePageView(common::Span<const Entry> data,
+                                common::Span<const bst_row_t> row_ptr) :
+      d_data{data}, d_row_ptr{row_ptr} {}
 };

-struct ElementLoader {
+struct SparsePageLoader {
   bool use_shared;
   common::Span<const bst_row_t> d_row_ptr;
   common::Span<const Entry> d_data;
-  int num_features;
+  bst_feature_t num_features;
   float* smem;
   size_t entry_start;

-  __device__ ElementLoader(bool use_shared, common::Span<const bst_row_t> row_ptr,
-                           common::Span<const Entry> entry, int num_features,
-                           float* smem, int num_rows, size_t entry_start)
+  __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
+                              bst_row_t num_rows, size_t entry_start)
       : use_shared(use_shared),
-        d_row_ptr(row_ptr),
-        d_data(entry),
+        d_row_ptr(data.d_row_ptr),
+        d_data(data.d_data),
         num_features(num_features),
-        smem(smem),
         entry_start(entry_start) {
+    extern __shared__ float _smem[];
+    smem = _smem;
     // Copy instances
     if (use_shared) {
       bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -111,7 +66,7 @@ struct ElementLoader {
       __syncthreads();
     }
   }
-  __device__ float GetFvalue(int ridx, int fidx) {
+  __device__ float GetFvalue(int ridx, int fidx) const {
     if (use_shared) {
       return smem[threadIdx.x * num_features + fidx];
     } else {
@@ -141,52 +96,69 @@ struct ElementLoader {
   }
 };

-__device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree,
-                               ElementLoader* loader) {
-  DevicePredictionNode n = tree[0];
+struct EllpackLoader {
+  EllpackMatrix const& matrix;
+  XGBOOST_DEVICE EllpackLoader(EllpackMatrix const& m, bool use_shared, bst_feature_t num_features,
+                               bst_row_t num_rows, size_t entry_start) : matrix{m} {}
+  __device__ __forceinline__ float GetFvalue(int ridx, int fidx) const {
+    auto gidx = matrix.GetBinIndex(ridx, fidx);
+    if (gidx == -1) {
+      return nan("");
+    }
+    // The gradient index needs to be shifted by one as min values are not included in the
+    // cuts.
+    if (gidx == matrix.info.feature_segments[fidx]) {
+      return matrix.info.min_fvalue[fidx];
+    }
+    return matrix.info.gidx_fvalue_map[gidx - 1];
+  }
+};
+
+template <typename Loader>
+__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
+                               Loader* loader) {
+  RegTree::Node n = tree[0];
   while (!n.IsLeaf()) {
-    float fvalue = loader->GetFvalue(ridx, n.GetFidx());
+    float fvalue = loader->GetFvalue(ridx, n.SplitIndex());
     // Missing value
     if (isnan(fvalue)) {
-      n = tree[n.MissingIdx()];
+      n = tree[n.DefaultChild()];
     } else {
-      if (fvalue < n.GetFvalue()) {
-        n = tree[n.left_child_idx];
+      if (fvalue < n.SplitCond()) {
+        n = tree[n.LeftChild()];
       } else {
-        n = tree[n.right_child_idx];
+        n = tree[n.RightChild()];
       }
     }
   }
-  return n.GetWeight();
+  return n.LeafValue();
 }

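`EllpackLoader::GetFvalue` recovers a representative feature value from a bin index: `feature_segments[fidx]` marks the first bin belonging to feature `fidx`, per-feature minimums live in `min_fvalue` rather than in the cut values, and every other bin maps to `gidx_fvalue_map[gidx - 1]`, hence the shift by one. Schematically, with made-up numbers for one feature:

    // Illustrative only. Suppose feature 2 owns bins starting at
    // feature_segments[2] == 7, with min_fvalue[2] == 0.1 and cut value
    // gidx_fvalue_map[7] == 0.5. Then:
    //   gidx == -1  ->  NaN                         (missing entry)
    //   gidx ==  7  ->  0.1                         (first bin: the feature minimum)
    //   gidx ==  8  ->  gidx_fvalue_map[7] == 0.5   (shifted by one)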
-template <int BLOCK_THREADS>
-__global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
+template <typename Loader, typename Data>
+__global__ void PredictKernel(Data data,
+                              common::Span<const RegTree::Node> d_nodes,
                               common::Span<float> d_out_predictions,
                               common::Span<size_t> d_tree_segments,
                               common::Span<int> d_tree_group,
-                              common::Span<const bst_row_t> d_row_ptr,
-                              common::Span<const Entry> d_data, size_t tree_begin,
-                              size_t tree_end, size_t num_features,
+                              size_t tree_begin, size_t tree_end, size_t num_features,
                               size_t num_rows, size_t entry_start,
                               bool use_shared, int num_group) {
-  extern __shared__ float smem[];
   bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  ElementLoader loader(use_shared, d_row_ptr, d_data, num_features, smem,
-                       num_rows, entry_start);
+  Loader loader(data, use_shared, num_features, num_rows, entry_start);
   if (global_idx >= num_rows) return;
   if (num_group == 1) {
     float sum = 0;
     for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
-      const DevicePredictionNode* d_tree =
+      const RegTree::Node* d_tree =
           &d_nodes[d_tree_segments[tree_idx - tree_begin]];
-      sum += GetLeafWeight(global_idx, d_tree, &loader);
+      float leaf = GetLeafWeight(global_idx, d_tree, &loader);
+      sum += leaf;
     }
     d_out_predictions[global_idx] += sum;
   } else {
     for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       int tree_group = d_tree_group[tree_idx];
-      const DevicePredictionNode* d_tree =
+      const RegTree::Node* d_tree =
           &d_nodes[d_tree_segments[tree_idx - tree_begin]];
       bst_uint out_prediction_idx = global_idx * num_group + tree_group;
       d_out_predictions[out_prediction_idx] +=
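With the loader type lifted into a template parameter, one kernel body serves both input formats; the only contract a `Loader` must meet is the five-argument constructor shape and a `GetFvalue(ridx, fidx)` that returns NaN for missing entries. A toy illustration of that contract (not code from this commit):

    // Hypothetical loader satisfying the same interface as
    // SparsePageLoader / EllpackLoader, reading a dense row-major matrix
    // in which NaN marks missing values.
    struct DenseLoader {
      const float* values;
      size_t num_features;
      __device__ DenseLoader(const float* data, bool /*use_shared*/, size_t num_features,
                             size_t /*num_rows*/, size_t /*entry_start*/)
          : values{data}, num_features{num_features} {}
      __device__ float GetFvalue(int ridx, int fidx) const {
        return values[ridx * num_features + fidx];
      }
    };
    // Would then be launched as PredictKernel<DenseLoader, const float*>.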
@@ -198,13 +170,13 @@ __global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
 class GPUPredictor : public xgboost::Predictor {
  private:
   void InitModel(const gbm::GBTreeModel& model,
                  const thrust::host_vector<size_t>& h_tree_segments,
-                 const thrust::host_vector<DevicePredictionNode>& h_nodes,
+                 const thrust::host_vector<RegTree::Node>& h_nodes,
                  size_t tree_begin, size_t tree_end) {
     dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
     nodes_.resize(h_nodes.size());
     dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
-                                  sizeof(DevicePredictionNode) * h_nodes.size(),
+                                  sizeof(RegTree::Node) * h_nodes.size(),
                                   cudaMemcpyHostToDevice));
     tree_segments_.resize(h_tree_segments.size());
     dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
@@ -219,15 +191,11 @@ class GPUPredictor : public xgboost::Predictor {
     this->num_group_ = model.learner_model_param_->num_output_group;
   }

-  void PredictInternal(const SparsePage& batch,
-                       size_t num_features,
+  void PredictInternal(const SparsePage& batch, size_t num_features,
                        HostDeviceVector<bst_float>* predictions,
                        size_t batch_offset) {
-    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
-    batch.data.SetDevice(generic_param_->gpu_id);
     batch.offset.SetDevice(generic_param_->gpu_id);
-    predictions->SetDevice(generic_param_->gpu_id);
+    batch.data.SetDevice(generic_param_->gpu_id);

     const uint32_t BLOCK_THREADS = 128;
     size_t num_rows = batch.Size();
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
@@ -240,12 +208,29 @@ class GPUPredictor : public xgboost::Predictor {
       use_shared = false;
     }
     size_t entry_start = 0;
+    SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan()};
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
-        PredictKernel<BLOCK_THREADS>,
+        PredictKernel<SparsePageLoader, SparsePageView>,
+        data,
         dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
-        dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
-        batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
+        dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
+        this->tree_begin_, this->tree_end_, num_features, num_rows,
+        entry_start, use_shared, this->num_group_);
+  }
+
+  void PredictInternal(EllpackMatrix const& batch, HostDeviceVector<bst_float>* out_preds,
+                       size_t batch_offset) {
+    const uint32_t BLOCK_THREADS = 256;
+    size_t num_rows = batch.n_rows;
+    auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
+
+    bool use_shared = false;
+    size_t entry_start = 0;
+    dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
+        PredictKernel<EllpackLoader, EllpackMatrix>,
+        batch,
+        dh::ToSpan(nodes_), out_preds->DeviceSpan().subspan(batch_offset),
+        dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
+        this->tree_begin_, this->tree_end_, batch.info.NumFeatures(), num_rows,
         entry_start, use_shared, this->num_group_);
   }

@@ -261,7 +246,7 @@ class GPUPredictor : public xgboost::Predictor {
       h_tree_segments.push_back(sum);
     }

-    thrust::host_vector<DevicePredictionNode> h_nodes(h_tree_segments.back());
+    thrust::host_vector<RegTree::Node> h_nodes(h_tree_segments.back());
     for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
       std::copy(src_nodes.begin(), src_nodes.end(),
@@ -270,26 +255,31 @@ class GPUPredictor : public xgboost::Predictor {
     InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
   }

-  void DevicePredictInternal(DMatrix* dmat,
-                             HostDeviceVector<bst_float>* out_preds,
+  void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
                              const gbm::GBTreeModel& model, size_t tree_begin,
                              size_t tree_end) {
+    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
     if (tree_end - tree_begin == 0) {
       return;
     }
     monitor_.StartCuda("DevicePredictInternal");

     InitModel(model, tree_begin, tree_end);
+    out_preds->SetDevice(generic_param_->gpu_id);

-    size_t batch_offset = 0;
-    for (auto &batch : dmat->GetBatches<SparsePage>()) {
-      batch.offset.SetDevice(generic_param_->gpu_id);
-      batch.data.SetDevice(generic_param_->gpu_id);
-      PredictInternal(batch, model.learner_model_param_->num_feature,
-                      out_preds, batch_offset);
-      batch_offset += batch.Size() * model.learner_model_param_->num_output_group;
+    if (dmat->PageExists<EllpackPage>()) {
+      size_t batch_offset = 0;
+      for (auto const& page : dmat->GetBatches<EllpackPage>()) {
+        this->PredictInternal(page.Impl()->matrix, out_preds, batch_offset);
+        batch_offset += page.Impl()->matrix.n_rows;
+      }
+    } else {
+      size_t batch_offset = 0;
+      for (auto &batch : dmat->GetBatches<SparsePage>()) {
+        this->PredictInternal(batch, model.learner_model_param_->num_feature,
+                              out_preds, batch_offset);
+        batch_offset += batch.Size() * model.learner_model_param_->num_output_group;
+      }
     }

     monitor_.StopCuda("DevicePredictInternal");
   }

@@ -418,7 +408,7 @@ class GPUPredictor : public xgboost::Predictor {
   }

   common::Monitor monitor_;
-  dh::device_vector<DevicePredictionNode> nodes_;
+  dh::device_vector<RegTree::Node> nodes_;
   dh::device_vector<size_t> tree_segments_;
   dh::device_vector<int> tree_group_;
   size_t max_shared_memory_bytes_;
@@ -792,7 +792,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
   bst_node_t nid = 0;
   while (!(*this)[nid].IsLeaf()) {
     split_index = (*this)[nid].SplitIndex();
-    nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
+    nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
     bst_float new_value = this->node_mean_values_[nid];
     // update feature weight
     out_contribs[split_index] += new_value - node_value;
@@ -924,7 +924,7 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
   unsigned hot_index = 0;
   if (feat.IsMissing(split_index)) {
     hot_index = node.DefaultChild();
-  } else if (feat.Fvalue(split_index) < node.SplitCond()) {
+  } else if (feat.GetFvalue(split_index) < node.SplitCond()) {
     hot_index = node.LeftChild();
   } else {
     hot_index = node.RightChild();
@@ -688,7 +688,7 @@ struct GPUHistMakerDevice {
         [=] __device__(bst_uint ridx) {
           // given a row index, returns the node id it belongs to
           bst_float cut_value =
-              d_matrix.GetElement(ridx, split_node.SplitIndex());
+              d_matrix.GetFvalue(ridx, split_node.SplitIndex());
           // Missing value
           int new_position = 0;
           if (isnan(cut_value)) {
@@ -737,7 +737,7 @@ struct GPUHistMakerDevice {
       auto node = d_nodes[position];

       while (!node.IsLeaf()) {
-        bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
+        bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
         // Missing value
         if (isnan(element)) {
           position = node.DefaultChild();
@@ -119,7 +119,7 @@ class TreeRefresher: public TreeUpdater {
       // tranverse tree
       while (!tree[pid].IsLeaf()) {
         unsigned split_index = tree[pid].SplitIndex();
-        pid = tree.GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
+        pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
         gstats[pid].Add(gpair[ridx]);
       }
     }
@@ -89,7 +89,7 @@ struct ReadRowFunction {
       : matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}

   __device__ void operator()(size_t col) {
-    auto value = matrix.GetElement(row, col);
+    auto value = matrix.GetFvalue(row, col);
     if (isnan(value)) {
       value = -1;
     }
@@ -86,7 +86,7 @@ struct ReadRowFunction {
       : matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}

   __device__ void operator()(size_t col) {
-    auto value = matrix.GetElement(row, col);
+    auto value = matrix.GetFvalue(row, col);
     if (isnan(value)) {
       value = -1;
     }
@@ -21,7 +21,6 @@
 #include <xgboost/c_api.h>

 #include "../../src/common/common.h"
-#include "../../src/common/hist_util.h"
 #include "../../src/gbm/gbtree_model.h"
 #if defined(__CUDACC__)
 #include "../../src/data/ellpack_page.cuh"
@@ -11,11 +11,12 @@
 #include "gtest/gtest.h"
 #include "../helpers.h"
 #include "../../../src/gbm/gbtree_model.h"
+#include "test_predictor.h"

 namespace xgboost {
 namespace predictor {

-TEST(GpuPredictor, Basic) {
+TEST(GPUPredictor, Basic) {
   auto cpu_lparam = CreateEmptyGenericParam(-1);
   auto gpu_lparam = CreateEmptyGenericParam(0);

@@ -56,7 +57,20 @@ TEST(GpuPredictor, Basic) {
   }
 }

-TEST(gpu_predictor, ExternalMemoryTest) {
+TEST(GPUPredictor, EllpackBasic) {
+  for (size_t bins = 2; bins < 258; bins += 16) {
+    size_t rows = bins * 16;
+    TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, bins);
+    TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, bins);
+  }
+}
+
+TEST(GPUPredictor, EllpackTraining) {
+  size_t constexpr kRows { 128 };
+  TestTrainingPrediction(kRows, "gpu_hist");
+}
+
+TEST(GPUPredictor, ExternalMemoryTest) {
   auto lparam = CreateEmptyGenericParam(0);
   std::unique_ptr<Predictor> gpu_predictor =
       std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
@@ -2,13 +2,16 @@
  * Copyright 2020 by Contributors
  */

-#include <cstddef>
 #include <gtest/gtest.h>
 #include <xgboost/predictor.h>
 #include <xgboost/data.h>
+#include <xgboost/host_device_vector.h>
+#include <xgboost/generic_parameters.h>
+
+#include "test_predictor.h"

 #include "../helpers.h"
-#include "xgboost/generic_parameters.h"
+#include "../../../src/common/io.h"

 namespace xgboost {
 TEST(Predictor, PredictionCache) {
@@ -30,4 +33,52 @@ TEST(Predictor, PredictionCache) {
   add_cache();
   EXPECT_ANY_THROW(container.Entry(m));
 }
+
+// Only run this test when CUDA is enabled.
+void TestTrainingPrediction(size_t rows, std::string tree_method) {
+  size_t constexpr kCols = 16;
+  size_t constexpr kClasses = 3;
+  size_t constexpr kIters = 3;
+
+  std::unique_ptr<Learner> learner;
+  auto train = [&](std::string predictor, HostDeviceVector<float>* out) {
+    auto pp_m = CreateDMatrix(rows, kCols, 0);
+    auto p_m = *pp_m;
+
+    auto &h_label = p_m->Info().labels_.HostVector();
+    h_label.resize(rows);
+
+    for (size_t i = 0; i < rows; ++i) {
+      h_label[i] = i % kClasses;
+    }
+
+    learner.reset(Learner::Create({}));
+    learner->SetParam("tree_method", tree_method);
+    learner->SetParam("objective", "multi:softprob");
+    learner->SetParam("predictor", predictor);
+    learner->SetParam("num_feature", std::to_string(kCols));
+    learner->SetParam("num_class", std::to_string(kClasses));
+    learner->Configure();
+
+    for (size_t i = 0; i < kIters; ++i) {
+      learner->UpdateOneIter(i, p_m);
+    }
+    learner->Predict(p_m, false, out);
+    delete pp_m;
+  };
+  // Alternate the predictor, CPU predictor can not use ellpack while GPU predictor can
+  // not use CPU histogram index. So it's guaranteed one of the following is not
+  // predicting from histogram index. Note: As of writing only GPU supports predicting
+  // from gradient index, the test is written for future portability.
+  HostDeviceVector<float> predictions_0;
+  train("cpu_predictor", &predictions_0);
+
+  HostDeviceVector<float> predictions_1;
+  train("gpu_predictor", &predictions_1);
+
+  for (size_t i = 0; i < rows; ++i) {
+    EXPECT_NEAR(predictions_1.ConstHostVector()[i],
+                predictions_0.ConstHostVector()[i], kRtEps);
+  }
+}
 } // namespace xgboost
tests/cpp/predictor/test_predictor.h (new file, 70 lines)
@@ -0,0 +1,70 @@
+#ifndef XGBOOST_TEST_PREDICTOR_H_
+#define XGBOOST_TEST_PREDICTOR_H_
+
+#include <xgboost/predictor.h>
+#include <string>
+#include <cstddef>
+#include "../helpers.h"
+
+namespace xgboost {
+template <typename Page>
+void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins) {
+  constexpr size_t kCols { 8 }, kClasses { 3 };
+
+  LearnerModelParam param;
+  param.num_feature = kCols;
+  param.num_output_group = kClasses;
+  param.base_score = 0.5;
+
+  auto lparam = CreateEmptyGenericParam(0);
+
+  std::unique_ptr<Predictor> predictor =
+      std::unique_ptr<Predictor>(Predictor::Create(name, &lparam));
+  predictor->Configure({});
+
+  gbm::GBTreeModel model = CreateTestModel(&param, kClasses);
+
+  {
+    auto pp_ellpack = CreateDMatrix(rows, kCols, 0);
+    auto p_ellpack = *pp_ellpack;
+    // Use same number of bins as rows.
+    for (auto const &page DMLC_ATTRIBUTE_UNUSED :
+         p_ellpack->GetBatches<Page>({0, static_cast<int32_t>(bins), 0})) {
+    }
+
+    auto pp_precise = CreateDMatrix(rows, kCols, 0);
+    auto p_precise = *pp_precise;
+
+    PredictionCacheEntry approx_out_predictions;
+    predictor->PredictBatch(p_ellpack.get(), &approx_out_predictions, model, 0);
+
+    PredictionCacheEntry precise_out_predictions;
+    predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0);
+
+    for (size_t i = 0; i < rows; ++i) {
+      CHECK_EQ(approx_out_predictions.predictions.HostVector()[i],
+               precise_out_predictions.predictions.HostVector()[i]);
+    }
+
+    delete pp_precise;
+    delete pp_ellpack;
+  }
+
+  {
+    // Predictor should never try to create the histogram index by itself. As only
+    // histogram index from training data is valid and predictor doesn't know which
+    // matrix is used for training.
+    auto pp_dmat = CreateDMatrix(rows, kCols, 0);
+    auto p_dmat = *pp_dmat;
+    PredictionCacheEntry precise_out_predictions;
+    predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
+    ASSERT_FALSE(p_dmat->PageExists<Page>());
+    delete pp_dmat;
+  }
+}
+
+void TestTrainingPrediction(size_t rows, std::string tree_method);
+
+} // namespace xgboost
+
+#endif  // XGBOOST_TEST_PREDICTOR_H_
@@ -25,6 +25,19 @@ class Dataset:
         self.w = None
         self.use_external_memory = use_external_memory

+    def __str__(self):
+        a = 'name: {name}\nobjective:{objective}, metric:{metric}, '.format(
+            name=self.name,
+            objective=self.objective,
+            metric=self.metric)
+        b = 'external memory:{use_external_memory}\n'.format(
+            use_external_memory=self.use_external_memory
+        )
+        return a + b
+
+    def __repr__(self):
+        return self.__str__()
+
+
 def get_boston():
     data = datasets.load_boston()