Predict on Ellpack. (#5327)
* Unify GPU prediction node. * Add `PageExists`. * Dispatch prediction on input data for GPU Predictor.
This commit is contained in:
parent
70a91ec3ba
commit
655cf17b60
@ -168,12 +168,19 @@ struct BatchParam {
|
||||
/*! \brief The GPU device to use. */
|
||||
int gpu_id;
|
||||
/*! \brief Maximum number of bins per feature for histograms. */
|
||||
int max_bin;
|
||||
int max_bin { 0 };
|
||||
/*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
|
||||
int gpu_batch_nrows;
|
||||
/*! \brief Page size for external memory mode. */
|
||||
size_t gpu_page_size;
|
||||
|
||||
BatchParam() = default;
|
||||
BatchParam(int32_t device, int32_t max_bin, int32_t gpu_batch_nrows,
|
||||
size_t gpu_page_size = 0) :
|
||||
gpu_id{device},
|
||||
max_bin{max_bin},
|
||||
gpu_batch_nrows{gpu_batch_nrows},
|
||||
gpu_page_size{gpu_page_size}
|
||||
{}
|
||||
inline bool operator!=(const BatchParam& other) const {
|
||||
return gpu_id != other.gpu_id ||
|
||||
max_bin != other.max_bin ||
|
||||
@ -438,6 +445,9 @@ class DMatrix {
|
||||
*/
|
||||
template<typename T>
|
||||
BatchSet<T> GetBatches(const BatchParam& param = {});
|
||||
template <typename T>
|
||||
bool PageExists() const;
|
||||
|
||||
// the following are column meta data, should be able to answer them fast.
|
||||
/*! \return Whether the data columns single column block. */
|
||||
virtual bool SingleColBlock() const = 0;
|
||||
@ -493,6 +503,9 @@ class DMatrix {
|
||||
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
|
||||
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
|
||||
|
||||
virtual bool EllpackExists() const = 0;
|
||||
virtual bool SparsePageExists() const = 0;
|
||||
};
|
||||
|
||||
template<>
|
||||
@ -500,6 +513,16 @@ inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
|
||||
return GetRowBatches();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline bool DMatrix::PageExists<EllpackPage>() const {
|
||||
return this->EllpackExists();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline bool DMatrix::PageExists<SparsePage>() const {
|
||||
return this->SparsePageExists();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
|
||||
return GetColumnBatches();
|
||||
|
||||
@ -105,7 +105,7 @@ class RegTree : public Model {
|
||||
/*! \brief tree node */
|
||||
class Node {
|
||||
public:
|
||||
Node() {
|
||||
XGBOOST_DEVICE Node() {
|
||||
// assert compact alignment
|
||||
static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
|
||||
"Node: 64 bit align");
|
||||
@ -422,7 +422,7 @@ class RegTree : public Model {
|
||||
* \param i feature index.
|
||||
* \return the i-th feature value
|
||||
*/
|
||||
bst_float Fvalue(size_t i) const;
|
||||
bst_float GetFvalue(size_t i) const;
|
||||
/*!
|
||||
* \brief check whether i-th entry is missing
|
||||
* \param i feature index.
|
||||
@ -565,7 +565,7 @@ inline size_t RegTree::FVec::Size() const {
|
||||
return data_.size();
|
||||
}
|
||||
|
||||
inline bst_float RegTree::FVec::Fvalue(size_t i) const {
|
||||
inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
|
||||
return data_[i].fvalue;
|
||||
}
|
||||
|
||||
@ -577,7 +577,7 @@ inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
|
||||
bst_node_t nid = 0;
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
unsigned split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
|
||||
}
|
||||
return nid;
|
||||
}
|
||||
|
||||
@ -31,8 +31,8 @@ __global__ void CompressBinEllpackKernel(
|
||||
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
|
||||
const size_t* __restrict__ row_ptrs, // row offset of input data
|
||||
const Entry* __restrict__ entries, // One batch of input data
|
||||
const float* __restrict__ cuts, // HistogramCuts::cut
|
||||
const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
|
||||
const float* __restrict__ cuts, // HistogramCuts::cut_values_
|
||||
const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_
|
||||
size_t base_row, // batch_row_begin
|
||||
size_t n_rows,
|
||||
size_t row_stride,
|
||||
|
||||
@ -76,6 +76,9 @@ struct EllpackInfo {
|
||||
size_t NumSymbols() const {
|
||||
return n_bins + 1;
|
||||
}
|
||||
size_t NumFeatures() const {
|
||||
return min_fvalue.size();
|
||||
}
|
||||
};
|
||||
|
||||
/** \brief Struct for accessing and manipulating an ellpack matrix on the
|
||||
@ -89,7 +92,7 @@ struct EllpackMatrix {
|
||||
|
||||
// Get a matrix element, uses binary search for look up Return NaN if missing
|
||||
// Given a row index and a feature index, returns the corresponding cut value
|
||||
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
|
||||
__device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
|
||||
ridx -= base_rowid;
|
||||
auto row_begin = info.row_stride * ridx;
|
||||
auto row_end = row_begin + info.row_stride;
|
||||
@ -103,6 +106,10 @@ struct EllpackMatrix {
|
||||
info.feature_segments[fidx],
|
||||
info.feature_segments[fidx + 1]);
|
||||
}
|
||||
return gidx;
|
||||
}
|
||||
__device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
|
||||
auto gidx = GetBinIndex(ridx, fidx);
|
||||
if (gidx == -1) {
|
||||
return nan("");
|
||||
}
|
||||
|
||||
@ -61,11 +61,15 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
|
||||
}
|
||||
|
||||
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
|
||||
// ELLPACK page doesn't exist, generate it
|
||||
if (!(batch_param_ != BatchParam{})) {
|
||||
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
|
||||
}
|
||||
if (!ellpack_page_ || (batch_param_ != param && param != BatchParam{})) {
|
||||
CHECK_GE(param.gpu_id, 0);
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
// ELLPACK page doesn't exist, generate it
|
||||
if (!ellpack_page_) {
|
||||
ellpack_page_.reset(new EllpackPage(this, param));
|
||||
batch_param_ = param;
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
|
||||
|
||||
@ -48,6 +48,14 @@ class SimpleDMatrix : public DMatrix {
|
||||
std::unique_ptr<CSCPage> column_page_;
|
||||
std::unique_ptr<SortedCSCPage> sorted_column_page_;
|
||||
std::unique_ptr<EllpackPage> ellpack_page_;
|
||||
BatchParam batch_param_;
|
||||
|
||||
bool EllpackExists() const override {
|
||||
return static_cast<bool>(ellpack_page_);
|
||||
}
|
||||
bool SparsePageExists() const override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* \file sparse_page_dmatrix.cc
|
||||
* \brief The external memory version of Page Iterator.
|
||||
* \author Tianqi Chen
|
||||
@ -47,7 +47,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
|
||||
CHECK_GE(param.gpu_id, 0);
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
// Lazily instantiate
|
||||
if (!ellpack_source_ || batch_param_ != param) {
|
||||
if (!ellpack_source_ || (batch_param_ != param && param != BatchParam{})) {
|
||||
ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
|
||||
batch_param_ = param;
|
||||
}
|
||||
|
||||
@ -58,6 +58,13 @@ class SparsePageDMatrix : public DMatrix {
|
||||
std::string cache_info_;
|
||||
// Store column densities to avoid recalculating
|
||||
std::vector<float> col_density_;
|
||||
|
||||
bool EllpackExists() const override {
|
||||
return static_cast<bool>(ellpack_source_);
|
||||
}
|
||||
bool SparsePageExists() const override {
|
||||
return static_cast<bool>(row_source_);
|
||||
}
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "../gbm/gbtree_model.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../common/common.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
|
||||
@ -22,78 +23,32 @@ namespace predictor {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gpu_predictor);
|
||||
|
||||
/**
|
||||
* \struct DevicePredictionNode
|
||||
*
|
||||
* \brief Packed 16 byte representation of a tree node for use in device
|
||||
* prediction
|
||||
*/
|
||||
struct DevicePredictionNode {
|
||||
XGBOOST_DEVICE DevicePredictionNode()
|
||||
: fidx{-1}, left_child_idx{-1}, right_child_idx{-1} {}
|
||||
struct SparsePageView {
|
||||
common::Span<const Entry> d_data;
|
||||
common::Span<const bst_row_t> d_row_ptr;
|
||||
|
||||
union NodeValue {
|
||||
float leaf_weight;
|
||||
float fvalue;
|
||||
XGBOOST_DEVICE SparsePageView(common::Span<const Entry> data,
|
||||
common::Span<const bst_row_t> row_ptr) :
|
||||
d_data{data}, d_row_ptr{row_ptr} {}
|
||||
};
|
||||
|
||||
int fidx;
|
||||
int left_child_idx;
|
||||
int right_child_idx;
|
||||
NodeValue val{};
|
||||
|
||||
DevicePredictionNode(const RegTree::Node& n) { // NOLINT
|
||||
static_assert(sizeof(DevicePredictionNode) == 16, "Size is not 16 bytes");
|
||||
this->left_child_idx = n.LeftChild();
|
||||
this->right_child_idx = n.RightChild();
|
||||
this->fidx = n.SplitIndex();
|
||||
if (n.DefaultLeft()) {
|
||||
fidx |= (1U << 31);
|
||||
}
|
||||
|
||||
if (n.IsLeaf()) {
|
||||
this->val.leaf_weight = n.LeafValue();
|
||||
} else {
|
||||
this->val.fvalue = n.SplitCond();
|
||||
}
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bool IsLeaf() const { return left_child_idx == -1; }
|
||||
|
||||
XGBOOST_DEVICE int GetFidx() const { return fidx & ((1U << 31) - 1U); }
|
||||
|
||||
XGBOOST_DEVICE bool MissingLeft() const { return (fidx >> 31) != 0; }
|
||||
|
||||
XGBOOST_DEVICE int MissingIdx() const {
|
||||
if (MissingLeft()) {
|
||||
return this->left_child_idx;
|
||||
} else {
|
||||
return this->right_child_idx;
|
||||
}
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE float GetFvalue() const { return val.fvalue; }
|
||||
|
||||
XGBOOST_DEVICE float GetWeight() const { return val.leaf_weight; }
|
||||
};
|
||||
|
||||
struct ElementLoader {
|
||||
struct SparsePageLoader {
|
||||
bool use_shared;
|
||||
common::Span<const bst_row_t> d_row_ptr;
|
||||
common::Span<const Entry> d_data;
|
||||
int num_features;
|
||||
bst_feature_t num_features;
|
||||
float* smem;
|
||||
size_t entry_start;
|
||||
|
||||
__device__ ElementLoader(bool use_shared, common::Span<const bst_row_t> row_ptr,
|
||||
common::Span<const Entry> entry, int num_features,
|
||||
float* smem, int num_rows, size_t entry_start)
|
||||
__device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
|
||||
bst_row_t num_rows, size_t entry_start)
|
||||
: use_shared(use_shared),
|
||||
d_row_ptr(row_ptr),
|
||||
d_data(entry),
|
||||
d_row_ptr(data.d_row_ptr),
|
||||
d_data(data.d_data),
|
||||
num_features(num_features),
|
||||
smem(smem),
|
||||
entry_start(entry_start) {
|
||||
extern __shared__ float _smem[];
|
||||
smem = _smem;
|
||||
// Copy instances
|
||||
if (use_shared) {
|
||||
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
@ -111,7 +66,7 @@ struct ElementLoader {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
__device__ float GetFvalue(int ridx, int fidx) {
|
||||
__device__ float GetFvalue(int ridx, int fidx) const {
|
||||
if (use_shared) {
|
||||
return smem[threadIdx.x * num_features + fidx];
|
||||
} else {
|
||||
@ -141,52 +96,69 @@ struct ElementLoader {
|
||||
}
|
||||
};
|
||||
|
||||
__device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree,
|
||||
ElementLoader* loader) {
|
||||
DevicePredictionNode n = tree[0];
|
||||
struct EllpackLoader {
|
||||
EllpackMatrix const& matrix;
|
||||
XGBOOST_DEVICE EllpackLoader(EllpackMatrix const& m, bool use_shared, bst_feature_t num_features,
|
||||
bst_row_t num_rows, size_t entry_start) : matrix{m} {}
|
||||
__device__ __forceinline__ float GetFvalue(int ridx, int fidx) const {
|
||||
auto gidx = matrix.GetBinIndex(ridx, fidx);
|
||||
if (gidx == -1) {
|
||||
return nan("");
|
||||
}
|
||||
// The gradient index needs to be shifted by one as min values are not included in the
|
||||
// cuts.
|
||||
if (gidx == matrix.info.feature_segments[fidx]) {
|
||||
return matrix.info.min_fvalue[fidx];
|
||||
}
|
||||
return matrix.info.gidx_fvalue_map[gidx - 1];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Loader>
|
||||
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
|
||||
Loader* loader) {
|
||||
RegTree::Node n = tree[0];
|
||||
while (!n.IsLeaf()) {
|
||||
float fvalue = loader->GetFvalue(ridx, n.GetFidx());
|
||||
float fvalue = loader->GetFvalue(ridx, n.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(fvalue)) {
|
||||
n = tree[n.MissingIdx()];
|
||||
n = tree[n.DefaultChild()];
|
||||
} else {
|
||||
if (fvalue < n.GetFvalue()) {
|
||||
n = tree[n.left_child_idx];
|
||||
if (fvalue < n.SplitCond()) {
|
||||
n = tree[n.LeftChild()];
|
||||
} else {
|
||||
n = tree[n.right_child_idx];
|
||||
n = tree[n.RightChild()];
|
||||
}
|
||||
}
|
||||
}
|
||||
return n.GetWeight();
|
||||
return n.LeafValue();
|
||||
}
|
||||
|
||||
template <int BLOCK_THREADS>
|
||||
__global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
|
||||
template <typename Loader, typename Data>
|
||||
__global__ void PredictKernel(Data data,
|
||||
common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<float> d_out_predictions,
|
||||
common::Span<size_t> d_tree_segments,
|
||||
common::Span<int> d_tree_group,
|
||||
common::Span<const bst_row_t> d_row_ptr,
|
||||
common::Span<const Entry> d_data, size_t tree_begin,
|
||||
size_t tree_end, size_t num_features,
|
||||
size_t tree_begin, size_t tree_end, size_t num_features,
|
||||
size_t num_rows, size_t entry_start,
|
||||
bool use_shared, int num_group) {
|
||||
extern __shared__ float smem[];
|
||||
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
ElementLoader loader(use_shared, d_row_ptr, d_data, num_features, smem,
|
||||
num_rows, entry_start);
|
||||
Loader loader(data, use_shared, num_features, num_rows, entry_start);
|
||||
if (global_idx >= num_rows) return;
|
||||
if (num_group == 1) {
|
||||
float sum = 0;
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
const DevicePredictionNode* d_tree =
|
||||
const RegTree::Node* d_tree =
|
||||
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
sum += GetLeafWeight(global_idx, d_tree, &loader);
|
||||
float leaf = GetLeafWeight(global_idx, d_tree, &loader);
|
||||
sum += leaf;
|
||||
}
|
||||
d_out_predictions[global_idx] += sum;
|
||||
} else {
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
int tree_group = d_tree_group[tree_idx];
|
||||
const DevicePredictionNode* d_tree =
|
||||
const RegTree::Node* d_tree =
|
||||
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
bst_uint out_prediction_idx = global_idx * num_group + tree_group;
|
||||
d_out_predictions[out_prediction_idx] +=
|
||||
@ -199,12 +171,12 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
private:
|
||||
void InitModel(const gbm::GBTreeModel& model,
|
||||
const thrust::host_vector<size_t>& h_tree_segments,
|
||||
const thrust::host_vector<DevicePredictionNode>& h_nodes,
|
||||
const thrust::host_vector<RegTree::Node>& h_nodes,
|
||||
size_t tree_begin, size_t tree_end) {
|
||||
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||
nodes_.resize(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
|
||||
sizeof(DevicePredictionNode) * h_nodes.size(),
|
||||
sizeof(RegTree::Node) * h_nodes.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_segments_.resize(h_tree_segments.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
|
||||
@ -219,15 +191,11 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
this->num_group_ = model.learner_model_param_->num_output_group;
|
||||
}
|
||||
|
||||
void PredictInternal(const SparsePage& batch,
|
||||
size_t num_features,
|
||||
void PredictInternal(const SparsePage& batch, size_t num_features,
|
||||
HostDeviceVector<bst_float>* predictions,
|
||||
size_t batch_offset) {
|
||||
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
predictions->SetDevice(generic_param_->gpu_id);
|
||||
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
const uint32_t BLOCK_THREADS = 128;
|
||||
size_t num_rows = batch.Size();
|
||||
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
|
||||
@ -240,12 +208,29 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
use_shared = false;
|
||||
}
|
||||
size_t entry_start = 0;
|
||||
|
||||
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan()};
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
PredictKernel<BLOCK_THREADS>,
|
||||
PredictKernel<SparsePageLoader, SparsePageView>,
|
||||
data,
|
||||
dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
|
||||
batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
|
||||
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
|
||||
this->tree_begin_, this->tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, this->num_group_);
|
||||
}
|
||||
void PredictInternal(EllpackMatrix const& batch, HostDeviceVector<bst_float>* out_preds,
|
||||
size_t batch_offset) {
|
||||
const uint32_t BLOCK_THREADS = 256;
|
||||
size_t num_rows = batch.n_rows;
|
||||
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
|
||||
|
||||
bool use_shared = false;
|
||||
size_t entry_start = 0;
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
|
||||
PredictKernel<EllpackLoader, EllpackMatrix>,
|
||||
batch,
|
||||
dh::ToSpan(nodes_), out_preds->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
|
||||
this->tree_begin_, this->tree_end_, batch.info.NumFeatures(), num_rows,
|
||||
entry_start, use_shared, this->num_group_);
|
||||
}
|
||||
|
||||
@ -261,7 +246,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
h_tree_segments.push_back(sum);
|
||||
}
|
||||
|
||||
thrust::host_vector<DevicePredictionNode> h_nodes(h_tree_segments.back());
|
||||
thrust::host_vector<RegTree::Node> h_nodes(h_tree_segments.back());
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
|
||||
std::copy(src_nodes.begin(), src_nodes.end(),
|
||||
@ -270,26 +255,31 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
|
||||
}
|
||||
|
||||
void DevicePredictInternal(DMatrix* dmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
|
||||
const gbm::GBTreeModel& model, size_t tree_begin,
|
||||
size_t tree_end) {
|
||||
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||
if (tree_end - tree_begin == 0) {
|
||||
return;
|
||||
}
|
||||
monitor_.StartCuda("DevicePredictInternal");
|
||||
|
||||
InitModel(model, tree_begin, tree_end);
|
||||
out_preds->SetDevice(generic_param_->gpu_id);
|
||||
|
||||
if (dmat->PageExists<EllpackPage>()) {
|
||||
size_t batch_offset = 0;
|
||||
for (auto const& page : dmat->GetBatches<EllpackPage>()) {
|
||||
this->PredictInternal(page.Impl()->matrix, out_preds, batch_offset);
|
||||
batch_offset += page.Impl()->matrix.n_rows;
|
||||
}
|
||||
} else {
|
||||
size_t batch_offset = 0;
|
||||
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
PredictInternal(batch, model.learner_model_param_->num_feature,
|
||||
this->PredictInternal(batch, model.learner_model_param_->num_feature,
|
||||
out_preds, batch_offset);
|
||||
batch_offset += batch.Size() * model.learner_model_param_->num_output_group;
|
||||
}
|
||||
|
||||
}
|
||||
monitor_.StopCuda("DevicePredictInternal");
|
||||
}
|
||||
|
||||
@ -418,7 +408,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
common::Monitor monitor_;
|
||||
dh::device_vector<DevicePredictionNode> nodes_;
|
||||
dh::device_vector<RegTree::Node> nodes_;
|
||||
dh::device_vector<size_t> tree_segments_;
|
||||
dh::device_vector<int> tree_group_;
|
||||
size_t max_shared_memory_bytes_;
|
||||
|
||||
@ -792,7 +792,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
||||
bst_node_t nid = 0;
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
|
||||
bst_float new_value = this->node_mean_values_[nid];
|
||||
// update feature weight
|
||||
out_contribs[split_index] += new_value - node_value;
|
||||
@ -924,7 +924,7 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
|
||||
unsigned hot_index = 0;
|
||||
if (feat.IsMissing(split_index)) {
|
||||
hot_index = node.DefaultChild();
|
||||
} else if (feat.Fvalue(split_index) < node.SplitCond()) {
|
||||
} else if (feat.GetFvalue(split_index) < node.SplitCond()) {
|
||||
hot_index = node.LeftChild();
|
||||
} else {
|
||||
hot_index = node.RightChild();
|
||||
|
||||
@ -688,7 +688,7 @@ struct GPUHistMakerDevice {
|
||||
[=] __device__(bst_uint ridx) {
|
||||
// given a row index, returns the node id it belongs to
|
||||
bst_float cut_value =
|
||||
d_matrix.GetElement(ridx, split_node.SplitIndex());
|
||||
d_matrix.GetFvalue(ridx, split_node.SplitIndex());
|
||||
// Missing value
|
||||
int new_position = 0;
|
||||
if (isnan(cut_value)) {
|
||||
@ -737,7 +737,7 @@ struct GPUHistMakerDevice {
|
||||
auto node = d_nodes[position];
|
||||
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
|
||||
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
|
||||
@ -119,7 +119,7 @@ class TreeRefresher: public TreeUpdater {
|
||||
// tranverse tree
|
||||
while (!tree[pid].IsLeaf()) {
|
||||
unsigned split_index = tree[pid].SplitIndex();
|
||||
pid = tree.GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
|
||||
gstats[pid].Add(gpair[ridx]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,7 +89,7 @@ struct ReadRowFunction {
|
||||
: matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}
|
||||
|
||||
__device__ void operator()(size_t col) {
|
||||
auto value = matrix.GetElement(row, col);
|
||||
auto value = matrix.GetFvalue(row, col);
|
||||
if (isnan(value)) {
|
||||
value = -1;
|
||||
}
|
||||
|
||||
@ -86,7 +86,7 @@ struct ReadRowFunction {
|
||||
: matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}
|
||||
|
||||
__device__ void operator()(size_t col) {
|
||||
auto value = matrix.GetElement(row, col);
|
||||
auto value = matrix.GetFvalue(row, col);
|
||||
if (isnan(value)) {
|
||||
value = -1;
|
||||
}
|
||||
|
||||
@ -21,7 +21,6 @@
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
#include "../../src/common/common.h"
|
||||
#include "../../src/common/hist_util.h"
|
||||
#include "../../src/gbm/gbtree_model.h"
|
||||
#if defined(__CUDACC__)
|
||||
#include "../../src/data/ellpack_page.cuh"
|
||||
|
||||
@ -11,11 +11,12 @@
|
||||
#include "gtest/gtest.h"
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/gbm/gbtree_model.h"
|
||||
#include "test_predictor.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
|
||||
TEST(GpuPredictor, Basic) {
|
||||
TEST(GPUPredictor, Basic) {
|
||||
auto cpu_lparam = CreateEmptyGenericParam(-1);
|
||||
auto gpu_lparam = CreateEmptyGenericParam(0);
|
||||
|
||||
@ -56,7 +57,20 @@ TEST(GpuPredictor, Basic) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(gpu_predictor, ExternalMemoryTest) {
|
||||
TEST(GPUPredictor, EllpackBasic) {
|
||||
for (size_t bins = 2; bins < 258; bins += 16) {
|
||||
size_t rows = bins * 16;
|
||||
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, bins);
|
||||
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, bins);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, EllpackTraining) {
|
||||
size_t constexpr kRows { 128 };
|
||||
TestTrainingPrediction(kRows, "gpu_hist");
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, ExternalMemoryTest) {
|
||||
auto lparam = CreateEmptyGenericParam(0);
|
||||
std::unique_ptr<Predictor> gpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
|
||||
|
||||
@ -2,13 +2,16 @@
|
||||
* Copyright 2020 by Contributors
|
||||
*/
|
||||
|
||||
#include <cstddef>
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
|
||||
#include "test_predictor.h"
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "../../../src/common/io.h"
|
||||
|
||||
namespace xgboost {
|
||||
TEST(Predictor, PredictionCache) {
|
||||
@ -30,4 +33,52 @@ TEST(Predictor, PredictionCache) {
|
||||
add_cache();
|
||||
EXPECT_ANY_THROW(container.Entry(m));
|
||||
}
|
||||
|
||||
// Only run this test when CUDA is enabled.
|
||||
void TestTrainingPrediction(size_t rows, std::string tree_method) {
|
||||
size_t constexpr kCols = 16;
|
||||
size_t constexpr kClasses = 3;
|
||||
size_t constexpr kIters = 3;
|
||||
|
||||
std::unique_ptr<Learner> learner;
|
||||
auto train = [&](std::string predictor, HostDeviceVector<float>* out) {
|
||||
auto pp_m = CreateDMatrix(rows, kCols, 0);
|
||||
auto p_m = *pp_m;
|
||||
|
||||
auto &h_label = p_m->Info().labels_.HostVector();
|
||||
h_label.resize(rows);
|
||||
|
||||
for (size_t i = 0; i < rows; ++i) {
|
||||
h_label[i] = i % kClasses;
|
||||
}
|
||||
|
||||
learner.reset(Learner::Create({}));
|
||||
learner->SetParam("tree_method", tree_method);
|
||||
learner->SetParam("objective", "multi:softprob");
|
||||
learner->SetParam("predictor", predictor);
|
||||
learner->SetParam("num_feature", std::to_string(kCols));
|
||||
learner->SetParam("num_class", std::to_string(kClasses));
|
||||
learner->Configure();
|
||||
|
||||
for (size_t i = 0; i < kIters; ++i) {
|
||||
learner->UpdateOneIter(i, p_m);
|
||||
}
|
||||
learner->Predict(p_m, false, out);
|
||||
delete pp_m;
|
||||
};
|
||||
// Alternate the predictor, CPU predictor can not use ellpack while GPU predictor can
|
||||
// not use CPU histogram index. So it's guaranteed one of the following is not
|
||||
// predicting from histogram index. Note: As of writing only GPU supports predicting
|
||||
// from gradient index, the test is written for future portability.
|
||||
HostDeviceVector<float> predictions_0;
|
||||
train("cpu_predictor", &predictions_0);
|
||||
|
||||
HostDeviceVector<float> predictions_1;
|
||||
train("gpu_predictor", &predictions_1);
|
||||
|
||||
for (size_t i = 0; i < rows; ++i) {
|
||||
EXPECT_NEAR(predictions_1.ConstHostVector()[i],
|
||||
predictions_0.ConstHostVector()[i], kRtEps);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
70
tests/cpp/predictor/test_predictor.h
Normal file
70
tests/cpp/predictor/test_predictor.h
Normal file
@ -0,0 +1,70 @@
|
||||
#ifndef XGBOOST_TEST_PREDICTOR_H_
|
||||
#define XGBOOST_TEST_PREDICTOR_H_
|
||||
|
||||
#include <xgboost/predictor.h>
|
||||
#include <string>
|
||||
#include <cstddef>
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
template <typename Page>
|
||||
void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins) {
|
||||
constexpr size_t kCols { 8 }, kClasses { 3 };
|
||||
|
||||
LearnerModelParam param;
|
||||
param.num_feature = kCols;
|
||||
param.num_output_group = kClasses;
|
||||
param.base_score = 0.5;
|
||||
|
||||
auto lparam = CreateEmptyGenericParam(0);
|
||||
|
||||
std::unique_ptr<Predictor> predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create(name, &lparam));
|
||||
predictor->Configure({});
|
||||
|
||||
gbm::GBTreeModel model = CreateTestModel(¶m, kClasses);
|
||||
|
||||
{
|
||||
auto pp_ellpack = CreateDMatrix(rows, kCols, 0);
|
||||
auto p_ellpack = *pp_ellpack;
|
||||
// Use same number of bins as rows.
|
||||
for (auto const &page DMLC_ATTRIBUTE_UNUSED :
|
||||
p_ellpack->GetBatches<Page>({0, static_cast<int32_t>(bins), 0})) {
|
||||
}
|
||||
|
||||
auto pp_precise = CreateDMatrix(rows, kCols, 0);
|
||||
auto p_precise = *pp_precise;
|
||||
|
||||
PredictionCacheEntry approx_out_predictions;
|
||||
predictor->PredictBatch(p_ellpack.get(), &approx_out_predictions, model, 0);
|
||||
|
||||
PredictionCacheEntry precise_out_predictions;
|
||||
predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0);
|
||||
|
||||
for (size_t i = 0; i < rows; ++i) {
|
||||
CHECK_EQ(approx_out_predictions.predictions.HostVector()[i],
|
||||
precise_out_predictions.predictions.HostVector()[i]);
|
||||
}
|
||||
|
||||
delete pp_precise;
|
||||
delete pp_ellpack;
|
||||
}
|
||||
|
||||
{
|
||||
// Predictor should never try to create the histogram index by itself. As only
|
||||
// histogram index from training data is valid and predictor doesn't known which
|
||||
// matrix is used for training.
|
||||
auto pp_dmat = CreateDMatrix(rows, kCols, 0);
|
||||
auto p_dmat = *pp_dmat;
|
||||
PredictionCacheEntry precise_out_predictions;
|
||||
predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
|
||||
ASSERT_FALSE(p_dmat->PageExists<Page>());
|
||||
delete pp_dmat;
|
||||
}
|
||||
}
|
||||
|
||||
void TestTrainingPrediction(size_t rows, std::string tree_method);
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_TEST_PREDICTOR_H_
|
||||
@ -25,6 +25,19 @@ class Dataset:
|
||||
self.w = None
|
||||
self.use_external_memory = use_external_memory
|
||||
|
||||
def __str__(self):
|
||||
a = 'name: {name}\nobjective:{objective}, metric:{metric}, '.format(
|
||||
name=self.name,
|
||||
objective=self.objective,
|
||||
metric=self.metric)
|
||||
b = 'external memory:{use_external_memory}\n'.format(
|
||||
use_external_memory=self.use_external_memory
|
||||
)
|
||||
return a + b
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def get_boston():
|
||||
data = datasets.load_boston()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user