Predict on Ellpack. (#5327)

* Unify GPU prediction node.
* Add `PageExists`.
* Dispatch prediction on input data for GPU Predictor.
This commit is contained in:
Jiaming Yuan 2020-02-23 06:27:03 +08:00 committed by GitHub
parent 70a91ec3ba
commit 655cf17b60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 320 additions and 134 deletions

View File

@ -168,12 +168,19 @@ struct BatchParam {
/*! \brief The GPU device to use. */
int gpu_id;
/*! \brief Maximum number of bins per feature for histograms. */
int max_bin;
int max_bin { 0 };
/*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
int gpu_batch_nrows;
/*! \brief Page size for external memory mode. */
size_t gpu_page_size;
BatchParam() = default;
BatchParam(int32_t device, int32_t max_bin, int32_t gpu_batch_nrows,
size_t gpu_page_size = 0) :
gpu_id{device},
max_bin{max_bin},
gpu_batch_nrows{gpu_batch_nrows},
gpu_page_size{gpu_page_size}
{}
inline bool operator!=(const BatchParam& other) const {
return gpu_id != other.gpu_id ||
max_bin != other.max_bin ||
@ -438,6 +445,9 @@ class DMatrix {
*/
template<typename T>
BatchSet<T> GetBatches(const BatchParam& param = {});
template <typename T>
bool PageExists() const;
// the following are column meta data, should be able to answer them fast.
/*! \return Whether the data columns single column block. */
virtual bool SingleColBlock() const = 0;
@ -493,6 +503,9 @@ class DMatrix {
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
virtual bool EllpackExists() const = 0;
virtual bool SparsePageExists() const = 0;
};
template<>
@ -500,6 +513,16 @@ inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
return GetRowBatches();
}
template<>
inline bool DMatrix::PageExists<EllpackPage>() const {
return this->EllpackExists();
}
template<>
inline bool DMatrix::PageExists<SparsePage>() const {
return this->SparsePageExists();
}
template<>
inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
return GetColumnBatches();

View File

@ -105,7 +105,7 @@ class RegTree : public Model {
/*! \brief tree node */
class Node {
public:
Node() {
XGBOOST_DEVICE Node() {
// assert compact alignment
static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
"Node: 64 bit align");
@ -422,7 +422,7 @@ class RegTree : public Model {
* \param i feature index.
* \return the i-th feature value
*/
bst_float Fvalue(size_t i) const;
bst_float GetFvalue(size_t i) const;
/*!
* \brief check whether i-th entry is missing
* \param i feature index.
@ -565,7 +565,7 @@ inline size_t RegTree::FVec::Size() const {
return data_.size();
}
inline bst_float RegTree::FVec::Fvalue(size_t i) const {
inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
return data_[i].fvalue;
}
@ -577,7 +577,7 @@ inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
unsigned split_index = (*this)[nid].SplitIndex();
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
}
return nid;
}

View File

@ -31,8 +31,8 @@ __global__ void CompressBinEllpackKernel(
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
const size_t* __restrict__ row_ptrs, // row offset of input data
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistogramCuts::cut
const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
const float* __restrict__ cuts, // HistogramCuts::cut_values_
const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,

View File

@ -76,6 +76,9 @@ struct EllpackInfo {
size_t NumSymbols() const {
return n_bins + 1;
}
size_t NumFeatures() const {
return min_fvalue.size();
}
};
/** \brief Struct for accessing and manipulating an ellpack matrix on the
@ -89,7 +92,7 @@ struct EllpackMatrix {
// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
__device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
ridx -= base_rowid;
auto row_begin = info.row_stride * ridx;
auto row_end = row_begin + info.row_stride;
@ -103,6 +106,10 @@ struct EllpackMatrix {
info.feature_segments[fidx],
info.feature_segments[fidx + 1]);
}
return gidx;
}
__device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
auto gidx = GetBinIndex(ridx, fidx);
if (gidx == -1) {
return nan("");
}

View File

@ -61,11 +61,15 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
}
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
// ELLPACK page doesn't exist, generate it
if (!ellpack_page_) {
if (!(batch_param_ != BatchParam{})) {
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
}
if (!ellpack_page_ || (batch_param_ != param && param != BatchParam{})) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
ellpack_page_.reset(new EllpackPage(this, param));
batch_param_ = param;
}
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));

View File

@ -48,6 +48,14 @@ class SimpleDMatrix : public DMatrix {
std::unique_ptr<CSCPage> column_page_;
std::unique_ptr<SortedCSCPage> sorted_column_page_;
std::unique_ptr<EllpackPage> ellpack_page_;
BatchParam batch_param_;
bool EllpackExists() const override {
return static_cast<bool>(ellpack_page_);
}
bool SparsePageExists() const override {
return true;
}
};
} // namespace data
} // namespace xgboost

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2014 by Contributors
* Copyright 2014-2020 by Contributors
* \file sparse_page_dmatrix.cc
* \brief The external memory version of Page Iterator.
* \author Tianqi Chen
@ -47,7 +47,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
// Lazily instantiate
if (!ellpack_source_ || batch_param_ != param) {
if (!ellpack_source_ || (batch_param_ != param && param != BatchParam{})) {
ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
batch_param_ = param;
}

View File

@ -58,6 +58,13 @@ class SparsePageDMatrix : public DMatrix {
std::string cache_info_;
// Store column densities to avoid recalculating
std::vector<float> col_density_;
bool EllpackExists() const override {
return static_cast<bool>(ellpack_source_);
}
bool SparsePageExists() const override {
return static_cast<bool>(row_source_);
}
};
} // namespace data
} // namespace xgboost

View File

@ -14,6 +14,7 @@
#include "xgboost/host_device_vector.h"
#include "../gbm/gbtree_model.h"
#include "../data/ellpack_page.cuh"
#include "../common/common.h"
#include "../common/device_helpers.cuh"
@ -22,78 +23,32 @@ namespace predictor {
DMLC_REGISTRY_FILE_TAG(gpu_predictor);
/**
* \struct DevicePredictionNode
*
* \brief Packed 16 byte representation of a tree node for use in device
* prediction
*/
struct DevicePredictionNode {
XGBOOST_DEVICE DevicePredictionNode()
: fidx{-1}, left_child_idx{-1}, right_child_idx{-1} {}
struct SparsePageView {
common::Span<const Entry> d_data;
common::Span<const bst_row_t> d_row_ptr;
union NodeValue {
float leaf_weight;
float fvalue;
};
int fidx;
int left_child_idx;
int right_child_idx;
NodeValue val{};
DevicePredictionNode(const RegTree::Node& n) { // NOLINT
static_assert(sizeof(DevicePredictionNode) == 16, "Size is not 16 bytes");
this->left_child_idx = n.LeftChild();
this->right_child_idx = n.RightChild();
this->fidx = n.SplitIndex();
if (n.DefaultLeft()) {
fidx |= (1U << 31);
}
if (n.IsLeaf()) {
this->val.leaf_weight = n.LeafValue();
} else {
this->val.fvalue = n.SplitCond();
}
}
XGBOOST_DEVICE bool IsLeaf() const { return left_child_idx == -1; }
XGBOOST_DEVICE int GetFidx() const { return fidx & ((1U << 31) - 1U); }
XGBOOST_DEVICE bool MissingLeft() const { return (fidx >> 31) != 0; }
XGBOOST_DEVICE int MissingIdx() const {
if (MissingLeft()) {
return this->left_child_idx;
} else {
return this->right_child_idx;
}
}
XGBOOST_DEVICE float GetFvalue() const { return val.fvalue; }
XGBOOST_DEVICE float GetWeight() const { return val.leaf_weight; }
XGBOOST_DEVICE SparsePageView(common::Span<const Entry> data,
common::Span<const bst_row_t> row_ptr) :
d_data{data}, d_row_ptr{row_ptr} {}
};
struct ElementLoader {
struct SparsePageLoader {
bool use_shared;
common::Span<const bst_row_t> d_row_ptr;
common::Span<const Entry> d_data;
int num_features;
bst_feature_t num_features;
float* smem;
size_t entry_start;
__device__ ElementLoader(bool use_shared, common::Span<const bst_row_t> row_ptr,
common::Span<const Entry> entry, int num_features,
float* smem, int num_rows, size_t entry_start)
__device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
bst_row_t num_rows, size_t entry_start)
: use_shared(use_shared),
d_row_ptr(row_ptr),
d_data(entry),
d_row_ptr(data.d_row_ptr),
d_data(data.d_data),
num_features(num_features),
smem(smem),
entry_start(entry_start) {
extern __shared__ float _smem[];
smem = _smem;
// Copy instances
if (use_shared) {
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
@ -111,7 +66,7 @@ struct ElementLoader {
__syncthreads();
}
}
__device__ float GetFvalue(int ridx, int fidx) {
__device__ float GetFvalue(int ridx, int fidx) const {
if (use_shared) {
return smem[threadIdx.x * num_features + fidx];
} else {
@ -141,52 +96,69 @@ struct ElementLoader {
}
};
__device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree,
ElementLoader* loader) {
DevicePredictionNode n = tree[0];
struct EllpackLoader {
EllpackMatrix const& matrix;
XGBOOST_DEVICE EllpackLoader(EllpackMatrix const& m, bool use_shared, bst_feature_t num_features,
bst_row_t num_rows, size_t entry_start) : matrix{m} {}
__device__ __forceinline__ float GetFvalue(int ridx, int fidx) const {
auto gidx = matrix.GetBinIndex(ridx, fidx);
if (gidx == -1) {
return nan("");
}
// The gradient index needs to be shifted by one as min values are not included in the
// cuts.
if (gidx == matrix.info.feature_segments[fidx]) {
return matrix.info.min_fvalue[fidx];
}
return matrix.info.gidx_fvalue_map[gidx - 1];
}
};
template <typename Loader>
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
Loader* loader) {
RegTree::Node n = tree[0];
while (!n.IsLeaf()) {
float fvalue = loader->GetFvalue(ridx, n.GetFidx());
float fvalue = loader->GetFvalue(ridx, n.SplitIndex());
// Missing value
if (isnan(fvalue)) {
n = tree[n.MissingIdx()];
n = tree[n.DefaultChild()];
} else {
if (fvalue < n.GetFvalue()) {
n = tree[n.left_child_idx];
if (fvalue < n.SplitCond()) {
n = tree[n.LeftChild()];
} else {
n = tree[n.right_child_idx];
n = tree[n.RightChild()];
}
}
}
return n.GetWeight();
return n.LeafValue();
}
template <int BLOCK_THREADS>
__global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
template <typename Loader, typename Data>
__global__ void PredictKernel(Data data,
common::Span<const RegTree::Node> d_nodes,
common::Span<float> d_out_predictions,
common::Span<size_t> d_tree_segments,
common::Span<int> d_tree_group,
common::Span<const bst_row_t> d_row_ptr,
common::Span<const Entry> d_data, size_t tree_begin,
size_t tree_end, size_t num_features,
size_t tree_begin, size_t tree_end, size_t num_features,
size_t num_rows, size_t entry_start,
bool use_shared, int num_group) {
extern __shared__ float smem[];
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
ElementLoader loader(use_shared, d_row_ptr, d_data, num_features, smem,
num_rows, entry_start);
Loader loader(data, use_shared, num_features, num_rows, entry_start);
if (global_idx >= num_rows) return;
if (num_group == 1) {
float sum = 0;
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
const DevicePredictionNode* d_tree =
const RegTree::Node* d_tree =
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
sum += GetLeafWeight(global_idx, d_tree, &loader);
float leaf = GetLeafWeight(global_idx, d_tree, &loader);
sum += leaf;
}
d_out_predictions[global_idx] += sum;
} else {
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
int tree_group = d_tree_group[tree_idx];
const DevicePredictionNode* d_tree =
const RegTree::Node* d_tree =
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
bst_uint out_prediction_idx = global_idx * num_group + tree_group;
d_out_predictions[out_prediction_idx] +=
@ -198,13 +170,13 @@ __global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
class GPUPredictor : public xgboost::Predictor {
private:
void InitModel(const gbm::GBTreeModel& model,
const thrust::host_vector<size_t>& h_tree_segments,
const thrust::host_vector<DevicePredictionNode>& h_nodes,
size_t tree_begin, size_t tree_end) {
const thrust::host_vector<size_t>& h_tree_segments,
const thrust::host_vector<RegTree::Node>& h_nodes,
size_t tree_begin, size_t tree_end) {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
nodes_.resize(h_nodes.size());
dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
sizeof(DevicePredictionNode) * h_nodes.size(),
sizeof(RegTree::Node) * h_nodes.size(),
cudaMemcpyHostToDevice));
tree_segments_.resize(h_tree_segments.size());
dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
@ -219,15 +191,11 @@ class GPUPredictor : public xgboost::Predictor {
this->num_group_ = model.learner_model_param_->num_output_group;
}
void PredictInternal(const SparsePage& batch,
size_t num_features,
void PredictInternal(const SparsePage& batch, size_t num_features,
HostDeviceVector<bst_float>* predictions,
size_t batch_offset) {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
predictions->SetDevice(generic_param_->gpu_id);
batch.data.SetDevice(generic_param_->gpu_id);
const uint32_t BLOCK_THREADS = 128;
size_t num_rows = batch.Size();
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
@ -240,12 +208,29 @@ class GPUPredictor : public xgboost::Predictor {
use_shared = false;
}
size_t entry_start = 0;
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan()};
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
PredictKernel<BLOCK_THREADS>,
PredictKernel<SparsePageLoader, SparsePageView>,
data,
dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
this->tree_begin_, this->tree_end_, num_features, num_rows,
entry_start, use_shared, this->num_group_);
}
void PredictInternal(EllpackMatrix const& batch, HostDeviceVector<bst_float>* out_preds,
size_t batch_offset) {
const uint32_t BLOCK_THREADS = 256;
size_t num_rows = batch.n_rows;
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
bool use_shared = false;
size_t entry_start = 0;
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
PredictKernel<EllpackLoader, EllpackMatrix>,
batch,
dh::ToSpan(nodes_), out_preds->DeviceSpan().subspan(batch_offset),
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
this->tree_begin_, this->tree_end_, batch.info.NumFeatures(), num_rows,
entry_start, use_shared, this->num_group_);
}
@ -261,7 +246,7 @@ class GPUPredictor : public xgboost::Predictor {
h_tree_segments.push_back(sum);
}
thrust::host_vector<DevicePredictionNode> h_nodes(h_tree_segments.back());
thrust::host_vector<RegTree::Node> h_nodes(h_tree_segments.back());
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
std::copy(src_nodes.begin(), src_nodes.end(),
@ -270,26 +255,31 @@ class GPUPredictor : public xgboost::Predictor {
InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
}
void DevicePredictInternal(DMatrix* dmat,
HostDeviceVector<bst_float>* out_preds,
void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
const gbm::GBTreeModel& model, size_t tree_begin,
size_t tree_end) {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
if (tree_end - tree_begin == 0) {
return;
}
monitor_.StartCuda("DevicePredictInternal");
InitModel(model, tree_begin, tree_end);
out_preds->SetDevice(generic_param_->gpu_id);
size_t batch_offset = 0;
for (auto &batch : dmat->GetBatches<SparsePage>()) {
batch.offset.SetDevice(generic_param_->gpu_id);
batch.data.SetDevice(generic_param_->gpu_id);
PredictInternal(batch, model.learner_model_param_->num_feature,
out_preds, batch_offset);
batch_offset += batch.Size() * model.learner_model_param_->num_output_group;
if (dmat->PageExists<EllpackPage>()) {
size_t batch_offset = 0;
for (auto const& page : dmat->GetBatches<EllpackPage>()) {
this->PredictInternal(page.Impl()->matrix, out_preds, batch_offset);
batch_offset += page.Impl()->matrix.n_rows;
}
} else {
size_t batch_offset = 0;
for (auto &batch : dmat->GetBatches<SparsePage>()) {
this->PredictInternal(batch, model.learner_model_param_->num_feature,
out_preds, batch_offset);
batch_offset += batch.Size() * model.learner_model_param_->num_output_group;
}
}
monitor_.StopCuda("DevicePredictInternal");
}
@ -418,7 +408,7 @@ class GPUPredictor : public xgboost::Predictor {
}
common::Monitor monitor_;
dh::device_vector<DevicePredictionNode> nodes_;
dh::device_vector<RegTree::Node> nodes_;
dh::device_vector<size_t> tree_segments_;
dh::device_vector<int> tree_group_;
size_t max_shared_memory_bytes_;

View File

@ -792,7 +792,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
split_index = (*this)[nid].SplitIndex();
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
bst_float new_value = this->node_mean_values_[nid];
// update feature weight
out_contribs[split_index] += new_value - node_value;
@ -924,7 +924,7 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
unsigned hot_index = 0;
if (feat.IsMissing(split_index)) {
hot_index = node.DefaultChild();
} else if (feat.Fvalue(split_index) < node.SplitCond()) {
} else if (feat.GetFvalue(split_index) < node.SplitCond()) {
hot_index = node.LeftChild();
} else {
hot_index = node.RightChild();

View File

@ -688,7 +688,7 @@ struct GPUHistMakerDevice {
[=] __device__(bst_uint ridx) {
// given a row index, returns the node id it belongs to
bst_float cut_value =
d_matrix.GetElement(ridx, split_node.SplitIndex());
d_matrix.GetFvalue(ridx, split_node.SplitIndex());
// Missing value
int new_position = 0;
if (isnan(cut_value)) {
@ -737,7 +737,7 @@ struct GPUHistMakerDevice {
auto node = d_nodes[position];
while (!node.IsLeaf()) {
bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
// Missing value
if (isnan(element)) {
position = node.DefaultChild();

View File

@ -119,7 +119,7 @@ class TreeRefresher: public TreeUpdater {
// tranverse tree
while (!tree[pid].IsLeaf()) {
unsigned split_index = tree[pid].SplitIndex();
pid = tree.GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
gstats[pid].Add(gpair[ridx]);
}
}

View File

@ -89,7 +89,7 @@ struct ReadRowFunction {
: matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}
__device__ void operator()(size_t col) {
auto value = matrix.GetElement(row, col);
auto value = matrix.GetFvalue(row, col);
if (isnan(value)) {
value = -1;
}

View File

@ -86,7 +86,7 @@ struct ReadRowFunction {
: matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}
__device__ void operator()(size_t col) {
auto value = matrix.GetElement(row, col);
auto value = matrix.GetFvalue(row, col);
if (isnan(value)) {
value = -1;
}

View File

@ -21,7 +21,6 @@
#include <xgboost/c_api.h>
#include "../../src/common/common.h"
#include "../../src/common/hist_util.h"
#include "../../src/gbm/gbtree_model.h"
#if defined(__CUDACC__)
#include "../../src/data/ellpack_page.cuh"

View File

@ -11,11 +11,12 @@
#include "gtest/gtest.h"
#include "../helpers.h"
#include "../../../src/gbm/gbtree_model.h"
#include "test_predictor.h"
namespace xgboost {
namespace predictor {
TEST(GpuPredictor, Basic) {
TEST(GPUPredictor, Basic) {
auto cpu_lparam = CreateEmptyGenericParam(-1);
auto gpu_lparam = CreateEmptyGenericParam(0);
@ -56,7 +57,20 @@ TEST(GpuPredictor, Basic) {
}
}
TEST(gpu_predictor, ExternalMemoryTest) {
TEST(GPUPredictor, EllpackBasic) {
for (size_t bins = 2; bins < 258; bins += 16) {
size_t rows = bins * 16;
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, bins);
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, bins);
}
}
TEST(GPUPredictor, EllpackTraining) {
size_t constexpr kRows { 128 };
TestTrainingPrediction(kRows, "gpu_hist");
}
TEST(GPUPredictor, ExternalMemoryTest) {
auto lparam = CreateEmptyGenericParam(0);
std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));

View File

@ -2,13 +2,16 @@
* Copyright 2020 by Contributors
*/
#include <cstddef>
#include <gtest/gtest.h>
#include <xgboost/predictor.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/generic_parameters.h>
#include "test_predictor.h"
#include "../helpers.h"
#include "xgboost/generic_parameters.h"
#include "../../../src/common/io.h"
namespace xgboost {
TEST(Predictor, PredictionCache) {
@ -30,4 +33,52 @@ TEST(Predictor, PredictionCache) {
add_cache();
EXPECT_ANY_THROW(container.Entry(m));
}
// Only run this test when CUDA is enabled.
void TestTrainingPrediction(size_t rows, std::string tree_method) {
size_t constexpr kCols = 16;
size_t constexpr kClasses = 3;
size_t constexpr kIters = 3;
std::unique_ptr<Learner> learner;
auto train = [&](std::string predictor, HostDeviceVector<float>* out) {
auto pp_m = CreateDMatrix(rows, kCols, 0);
auto p_m = *pp_m;
auto &h_label = p_m->Info().labels_.HostVector();
h_label.resize(rows);
for (size_t i = 0; i < rows; ++i) {
h_label[i] = i % kClasses;
}
learner.reset(Learner::Create({}));
learner->SetParam("tree_method", tree_method);
learner->SetParam("objective", "multi:softprob");
learner->SetParam("predictor", predictor);
learner->SetParam("num_feature", std::to_string(kCols));
learner->SetParam("num_class", std::to_string(kClasses));
learner->Configure();
for (size_t i = 0; i < kIters; ++i) {
learner->UpdateOneIter(i, p_m);
}
learner->Predict(p_m, false, out);
delete pp_m;
};
// Alternate the predictor, CPU predictor can not use ellpack while GPU predictor can
// not use CPU histogram index. So it's guaranteed one of the following is not
// predicting from histogram index. Note: As of writing only GPU supports predicting
// from gradient index, the test is written for future portability.
HostDeviceVector<float> predictions_0;
train("cpu_predictor", &predictions_0);
HostDeviceVector<float> predictions_1;
train("gpu_predictor", &predictions_1);
for (size_t i = 0; i < rows; ++i) {
EXPECT_NEAR(predictions_1.ConstHostVector()[i],
predictions_0.ConstHostVector()[i], kRtEps);
}
}
} // namespace xgboost

View File

@ -0,0 +1,70 @@
#ifndef XGBOOST_TEST_PREDICTOR_H_
#define XGBOOST_TEST_PREDICTOR_H_
#include <xgboost/predictor.h>
#include <string>
#include <cstddef>
#include "../helpers.h"
namespace xgboost {
template <typename Page>
void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins) {
constexpr size_t kCols { 8 }, kClasses { 3 };
LearnerModelParam param;
param.num_feature = kCols;
param.num_output_group = kClasses;
param.base_score = 0.5;
auto lparam = CreateEmptyGenericParam(0);
std::unique_ptr<Predictor> predictor =
std::unique_ptr<Predictor>(Predictor::Create(name, &lparam));
predictor->Configure({});
gbm::GBTreeModel model = CreateTestModel(&param, kClasses);
{
auto pp_ellpack = CreateDMatrix(rows, kCols, 0);
auto p_ellpack = *pp_ellpack;
// Use same number of bins as rows.
for (auto const &page DMLC_ATTRIBUTE_UNUSED :
p_ellpack->GetBatches<Page>({0, static_cast<int32_t>(bins), 0})) {
}
auto pp_precise = CreateDMatrix(rows, kCols, 0);
auto p_precise = *pp_precise;
PredictionCacheEntry approx_out_predictions;
predictor->PredictBatch(p_ellpack.get(), &approx_out_predictions, model, 0);
PredictionCacheEntry precise_out_predictions;
predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0);
for (size_t i = 0; i < rows; ++i) {
CHECK_EQ(approx_out_predictions.predictions.HostVector()[i],
precise_out_predictions.predictions.HostVector()[i]);
}
delete pp_precise;
delete pp_ellpack;
}
{
// Predictor should never try to create the histogram index by itself. As only
// histogram index from training data is valid and predictor doesn't known which
// matrix is used for training.
auto pp_dmat = CreateDMatrix(rows, kCols, 0);
auto p_dmat = *pp_dmat;
PredictionCacheEntry precise_out_predictions;
predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
ASSERT_FALSE(p_dmat->PageExists<Page>());
delete pp_dmat;
}
}
void TestTrainingPrediction(size_t rows, std::string tree_method);
} // namespace xgboost
#endif // XGBOOST_TEST_PREDICTOR_H_

View File

@ -25,6 +25,19 @@ class Dataset:
self.w = None
self.use_external_memory = use_external_memory
def __str__(self):
a = 'name: {name}\nobjective:{objective}, metric:{metric}, '.format(
name=self.name,
objective=self.objective,
metric=self.metric)
b = 'external memory:{use_external_memory}\n'.format(
use_external_memory=self.use_external_memory
)
return a + b
def __repr__(self):
return self.__str__()
def get_boston():
data = datasets.load_boston()