Predictor for vector leaf. (#8898)
This commit is contained in:
parent
8be6095ece
commit
c400fa1e8d
@ -1,52 +1,64 @@
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <dmlc/omp.h>
|
||||
|
||||
#include <algorithm> // for max, fill, min
|
||||
#include <any> // for any, any_cast
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
#include <cassert> // for assert
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for uint32_t, int32_t, uint64_t
|
||||
#include <memory> // for unique_ptr, shared_ptr
|
||||
#include <ostream> // for char_traits, operator<<, basic_ostream
|
||||
#include <typeinfo> // for type_info
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/gradient_index.h"
|
||||
#include "../gbm/gbtree_model.h"
|
||||
#include "cpu_treeshap.h" // CalculateContributions
|
||||
#include "predict_fn.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/predictor.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "../collective/communicator-inl.h" // for Allreduce, IsDistributed
|
||||
#include "../collective/communicator.h" // for Operation
|
||||
#include "../common/bitfield.h" // for RBitField8
|
||||
#include "../common/categorical.h" // for IsCat, Decision
|
||||
#include "../common/common.h" // for DivRoundUp
|
||||
#include "../common/math.h" // for CheckNAN
|
||||
#include "../common/threading_utils.h" // for ParallelFor
|
||||
#include "../data/adapter.h" // for ArrayAdapter, CSRAdapter, CSRArrayAdapter
|
||||
#include "../data/gradient_index.h" // for GHistIndexMatrix
|
||||
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
|
||||
#include "../gbm/gbtree_model.h" // for GBTreeModel, GBTreeModelParam
|
||||
#include "cpu_treeshap.h" // for CalculateContributions
|
||||
#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG
|
||||
#include "predict_fn.h" // for GetNextNode, GetNextNodeMulti
|
||||
#include "xgboost/base.h" // for bst_float, bst_node_t, bst_omp_uint, bst_fe...
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for Entry, DMatrix, MetaInfo, SparsePage, Batch...
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
#include "xgboost/learner.h" // for LearnerModelParam
|
||||
#include "xgboost/linalg.h" // for TensorView, All, VectorView, Tensor
|
||||
#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_NE
|
||||
#include "xgboost/multi_target_tree_model.h" // for MultiTargetTree
|
||||
#include "xgboost/predictor.h" // for PredictionCacheEntry, Predictor, PredictorReg
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "xgboost/tree_model.h" // for RegTree, MTNotImplemented, RTreeNodeStat
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
namespace xgboost::predictor {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(cpu_predictor);
|
||||
|
||||
namespace scalar {
|
||||
template <bool has_missing, bool has_categorical>
|
||||
bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat,
|
||||
RegTree::CategoricalSplitMatrix const& cats) {
|
||||
bst_node_t nid = 0;
|
||||
while (!tree[nid].IsLeaf()) {
|
||||
unsigned split_index = tree[nid].SplitIndex();
|
||||
RegTree::CategoricalSplitMatrix const &cats) {
|
||||
bst_node_t nidx{0};
|
||||
while (!tree[nidx].IsLeaf()) {
|
||||
bst_feature_t split_index = tree[nidx].SplitIndex();
|
||||
auto fvalue = feat.GetFvalue(split_index);
|
||||
nid = GetNextNode<has_missing, has_categorical>(
|
||||
tree[nid], nid, fvalue, has_missing && feat.IsMissing(split_index), cats);
|
||||
nidx = GetNextNode<has_missing, has_categorical>(
|
||||
tree[nidx], nidx, fvalue, has_missing && feat.IsMissing(split_index), cats);
|
||||
}
|
||||
return nid;
|
||||
return nidx;
|
||||
}
|
||||
|
||||
bst_float PredValue(const SparsePage::Inst &inst,
|
||||
const std::vector<std::unique_ptr<RegTree>> &trees,
|
||||
const std::vector<int> &tree_info, int bst_group,
|
||||
RegTree::FVec *p_feats, unsigned tree_begin,
|
||||
unsigned tree_end) {
|
||||
const std::vector<int> &tree_info, std::int32_t bst_group,
|
||||
RegTree::FVec *p_feats, std::uint32_t tree_begin, std::uint32_t tree_end) {
|
||||
bst_float psum = 0.0f;
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||
@ -68,40 +80,92 @@ bst_float PredValue(const SparsePage::Inst &inst,
|
||||
}
|
||||
|
||||
template <bool has_categorical>
|
||||
bst_float
|
||||
PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
|
||||
RegTree::CategoricalSplitMatrix const& cats) {
|
||||
const bst_node_t leaf = p_feats.HasMissing() ?
|
||||
GetLeafIndex<true, has_categorical>(tree, p_feats, cats) :
|
||||
GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
|
||||
bst_float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
|
||||
RegTree::CategoricalSplitMatrix const &cats) {
|
||||
const bst_node_t leaf = p_feats.HasMissing()
|
||||
? GetLeafIndex<true, has_categorical>(tree, p_feats, cats)
|
||||
: GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
|
||||
return tree[leaf].LeafValue();
|
||||
}
|
||||
|
||||
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
|
||||
const size_t tree_end, std::vector<bst_float> *out_preds,
|
||||
const size_t predict_offset, const size_t num_group,
|
||||
const std::vector<RegTree::FVec> &thread_temp,
|
||||
const size_t offset, const size_t block_size) {
|
||||
std::vector<bst_float> &preds = *out_preds;
|
||||
const size_t tree_end, const size_t predict_offset,
|
||||
const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
|
||||
const size_t block_size, linalg::TensorView<float, 2> out_predt) {
|
||||
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
|
||||
const size_t gid = model.tree_info[tree_id];
|
||||
auto const &tree = *model.trees[tree_id];
|
||||
auto const& cats = tree.GetCategoriesMatrix();
|
||||
auto const &cats = tree.GetCategoriesMatrix();
|
||||
auto has_categorical = tree.HasCategoricalSplit();
|
||||
|
||||
if (has_categorical) {
|
||||
for (size_t i = 0; i < block_size; ++i) {
|
||||
preds[(predict_offset + i) * num_group + gid] +=
|
||||
for (std::size_t i = 0; i < block_size; ++i) {
|
||||
out_predt(predict_offset + i, gid) +=
|
||||
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < block_size; ++i) {
|
||||
preds[(predict_offset + i) * num_group + gid] +=
|
||||
PredValueByOneTree<false>(thread_temp[offset + i], tree, cats);
|
||||
for (std::size_t i = 0; i < block_size; ++i) {
|
||||
out_predt(predict_offset + i, gid) +=
|
||||
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace scalar
|
||||
|
||||
namespace multi {
|
||||
template <bool has_missing, bool has_categorical>
|
||||
bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat,
|
||||
RegTree::CategoricalSplitMatrix const &cats) {
|
||||
bst_node_t nidx{0};
|
||||
while (!tree.IsLeaf(nidx)) {
|
||||
unsigned split_index = tree.SplitIndex(nidx);
|
||||
auto fvalue = feat.GetFvalue(split_index);
|
||||
nidx = GetNextNodeMulti<has_missing, has_categorical>(
|
||||
tree, nidx, fvalue, has_missing && feat.IsMissing(split_index), cats);
|
||||
}
|
||||
return nidx;
|
||||
}
|
||||
|
||||
template <bool has_categorical>
|
||||
void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tree,
|
||||
RegTree::CategoricalSplitMatrix const &cats,
|
||||
linalg::VectorView<float> out_predt) {
|
||||
bst_node_t const leaf = p_feats.HasMissing()
|
||||
? GetLeafIndex<true, has_categorical>(tree, p_feats, cats)
|
||||
: GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
|
||||
auto leaf_value = tree.LeafValue(leaf);
|
||||
assert(out_predt.Shape(0) == leaf_value.Shape(0) && "shape mismatch.");
|
||||
for (size_t i = 0; i < leaf_value.Size(); ++i) {
|
||||
out_predt(i) += leaf_value(i);
|
||||
}
|
||||
}
|
||||
|
||||
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
|
||||
const size_t tree_end, const size_t predict_offset,
|
||||
const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
|
||||
const size_t block_size, linalg::TensorView<float, 2> out_predt) {
|
||||
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
|
||||
auto const &tree = *model.trees.at(tree_id);
|
||||
auto cats = tree.GetCategoriesMatrix();
|
||||
bool has_categorical = tree.HasCategoricalSplit();
|
||||
|
||||
if (has_categorical) {
|
||||
for (std::size_t i = 0; i < block_size; ++i) {
|
||||
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
|
||||
PredValueByOneTree<true>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
|
||||
t_predts);
|
||||
}
|
||||
} else {
|
||||
for (std::size_t i = 0; i < block_size; ++i) {
|
||||
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
|
||||
PredValueByOneTree<false>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
|
||||
t_predts);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace multi
|
||||
|
||||
template <typename DataView>
|
||||
void FVecFill(const size_t block_size, const size_t batch_offset, const int num_feature,
|
||||
@ -127,7 +191,7 @@ void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batc
|
||||
}
|
||||
|
||||
namespace {
|
||||
static size_t constexpr kUnroll = 8;
|
||||
static std::size_t constexpr kUnroll = 8;
|
||||
} // anonymous namespace
|
||||
|
||||
struct SparsePageView {
|
||||
@ -227,15 +291,13 @@ class AdapterView {
|
||||
};
|
||||
|
||||
template <typename DataView, size_t block_of_rows_size>
|
||||
void PredictBatchByBlockOfRowsKernel(
|
||||
DataView batch, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
|
||||
std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads) {
|
||||
void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &model,
|
||||
int32_t tree_begin, int32_t tree_end,
|
||||
std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads,
|
||||
linalg::TensorView<float, 2> out_predt) {
|
||||
auto &thread_temp = *p_thread_temp;
|
||||
int32_t const num_group = model.learner_model_param->num_output_group;
|
||||
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0)
|
||||
<< "size_leaf_vector is enforced to 0 so far";
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0) << "size_leaf_vector is enforced to 0 so far";
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
@ -243,16 +305,19 @@ void PredictBatchByBlockOfRowsKernel(
|
||||
|
||||
common::ParallelFor(n_blocks, n_threads, [&](bst_omp_uint block_id) {
|
||||
const size_t batch_offset = block_id * block_of_rows_size;
|
||||
const size_t block_size =
|
||||
std::min(nsize - batch_offset, block_of_rows_size);
|
||||
const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size);
|
||||
const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size;
|
||||
|
||||
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
|
||||
p_thread_temp);
|
||||
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
|
||||
// process block of rows through all trees to keep cache locality
|
||||
PredictByAllTrees(model, tree_begin, tree_end, out_preds,
|
||||
batch_offset + batch.base_rowid, num_group, thread_temp,
|
||||
fvec_offset, block_size);
|
||||
if (model.learner_model_param->IsVectorLeaf()) {
|
||||
multi::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
|
||||
thread_temp, fvec_offset, block_size, out_predt);
|
||||
} else {
|
||||
scalar::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
|
||||
thread_temp, fvec_offset, block_size, out_predt);
|
||||
}
|
||||
|
||||
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
|
||||
});
|
||||
}
|
||||
@ -557,33 +622,6 @@ class ColumnSplitHelper {
|
||||
|
||||
class CPUPredictor : public Predictor {
|
||||
protected:
|
||||
void PredictGHistIndex(DMatrix *p_fmat, gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end, std::vector<bst_float> *out_preds) const {
|
||||
auto const n_threads = this->ctx_->Threads();
|
||||
|
||||
constexpr double kDensityThresh = .5;
|
||||
size_t total =
|
||||
std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_, static_cast<uint64_t>(1));
|
||||
double density = static_cast<double>(p_fmat->Info().num_nonzero_) / static_cast<double>(total);
|
||||
bool blocked = density > kDensityThresh;
|
||||
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
InitThreadTemp(n_threads * (blocked ? kBlockOfRowsSize : 1), &feat_vecs);
|
||||
std::vector<Entry> workspace(p_fmat->Info().num_col_ * kUnroll * n_threads);
|
||||
auto ft = p_fmat->Info().feature_types.ConstHostVector();
|
||||
for (auto const &batch : p_fmat->GetBatches<GHistIndexMatrix>({})) {
|
||||
if (blocked) {
|
||||
PredictBatchByBlockOfRowsKernel<GHistIndexMatrixView, kBlockOfRowsSize>(
|
||||
GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads},
|
||||
out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads);
|
||||
} else {
|
||||
PredictBatchByBlockOfRowsKernel<GHistIndexMatrixView, 1>(
|
||||
GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads},
|
||||
out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
|
||||
if (p_fmat->IsColumnSplit()) {
|
||||
@ -592,11 +630,6 @@ class CPUPredictor : public Predictor {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!p_fmat->PageExists<SparsePage>()) {
|
||||
this->PredictGHistIndex(p_fmat, model, tree_begin, tree_end, out_preds);
|
||||
return;
|
||||
}
|
||||
|
||||
auto const n_threads = this->ctx_->Threads();
|
||||
constexpr double kDensityThresh = .5;
|
||||
size_t total =
|
||||
@ -606,16 +639,38 @@ class CPUPredictor : public Predictor {
|
||||
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
InitThreadTemp(n_threads * (blocked ? kBlockOfRowsSize : 1), &feat_vecs);
|
||||
|
||||
std::size_t n_samples = p_fmat->Info().num_row_;
|
||||
std::size_t n_groups = model.learner_model_param->OutputLength();
|
||||
CHECK_EQ(out_preds->size(), n_samples * n_groups);
|
||||
linalg::TensorView<float, 2> out_predt{*out_preds, {n_samples, n_groups}, ctx_->gpu_id};
|
||||
|
||||
if (!p_fmat->PageExists<SparsePage>()) {
|
||||
std::vector<Entry> workspace(p_fmat->Info().num_col_ * kUnroll * n_threads);
|
||||
auto ft = p_fmat->Info().feature_types.ConstHostVector();
|
||||
for (auto const &batch : p_fmat->GetBatches<GHistIndexMatrix>({})) {
|
||||
if (blocked) {
|
||||
PredictBatchByBlockOfRowsKernel<GHistIndexMatrixView, kBlockOfRowsSize>(
|
||||
GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, model,
|
||||
tree_begin, tree_end, &feat_vecs, n_threads, out_predt);
|
||||
} else {
|
||||
PredictBatchByBlockOfRowsKernel<GHistIndexMatrixView, 1>(
|
||||
GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, model,
|
||||
tree_begin, tree_end, &feat_vecs, n_threads, out_predt);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(),
|
||||
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group);
|
||||
if (blocked) {
|
||||
PredictBatchByBlockOfRowsKernel<SparsePageView, kBlockOfRowsSize>(
|
||||
SparsePageView{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads);
|
||||
SparsePageView{&batch}, model, tree_begin, tree_end, &feat_vecs, n_threads,
|
||||
out_predt);
|
||||
|
||||
} else {
|
||||
PredictBatchByBlockOfRowsKernel<SparsePageView, 1>(
|
||||
SparsePageView{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads);
|
||||
PredictBatchByBlockOfRowsKernel<SparsePageView, 1>(SparsePageView{&batch}, model,
|
||||
tree_begin, tree_end, &feat_vecs,
|
||||
n_threads, out_predt);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -623,17 +678,15 @@ class CPUPredictor : public Predictor {
|
||||
public:
|
||||
explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {}
|
||||
|
||||
void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
|
||||
const gbm::GBTreeModel &model, uint32_t tree_begin,
|
||||
uint32_t tree_end = 0) const override {
|
||||
auto* out_preds = &predts->predictions;
|
||||
void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, const gbm::GBTreeModel &model,
|
||||
uint32_t tree_begin, uint32_t tree_end = 0) const override {
|
||||
auto *out_preds = &predts->predictions;
|
||||
// This is actually already handled in gbm, but large amount of tests rely on the
|
||||
// behaviour.
|
||||
if (tree_end == 0) {
|
||||
tree_end = model.trees.size();
|
||||
}
|
||||
this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin,
|
||||
tree_end);
|
||||
this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin, tree_end);
|
||||
}
|
||||
|
||||
template <typename Adapter, size_t kBlockSize>
|
||||
@ -653,13 +706,16 @@ class CPUPredictor : public Predictor {
|
||||
info.num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
}
|
||||
|
||||
std::vector<Entry> workspace(m->NumColumns() * kUnroll * n_threads);
|
||||
auto &predictions = out_preds->predictions.HostVector();
|
||||
std::vector<RegTree::FVec> thread_temp;
|
||||
InitThreadTemp(n_threads * kBlockSize, &thread_temp);
|
||||
std::size_t n_groups = model.learner_model_param->OutputLength();
|
||||
linalg::TensorView<float, 2> out_predt{predictions, {m->NumRows(), n_groups}, Context::kCpuId};
|
||||
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>(
|
||||
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}, n_threads),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp, n_threads);
|
||||
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}, n_threads), model,
|
||||
tree_begin, tree_end, &thread_temp, n_threads, out_predt);
|
||||
}
|
||||
|
||||
bool InplacePredict(std::shared_ptr<DMatrix> p_m, const gbm::GBTreeModel &model, float missing,
|
||||
@ -689,6 +745,7 @@ class CPUPredictor : public Predictor {
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
|
||||
CHECK(!model.learner_model_param->IsVectorLeaf()) << "predict instance" << MTNotImplemented();
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
feat_vecs.resize(1, RegTree::FVec());
|
||||
feat_vecs[0].Init(model.learner_model_param->num_feature);
|
||||
@ -701,31 +758,30 @@ class CPUPredictor : public Predictor {
|
||||
auto base_score = model.learner_model_param->BaseScore(ctx_)(0);
|
||||
// loop over output groups
|
||||
for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) {
|
||||
(*out_preds)[gid] =
|
||||
PredValue(inst, model.trees, model.tree_info, gid, &feat_vecs[0], 0, ntree_limit) +
|
||||
(*out_preds)[gid] = scalar::PredValue(inst, model.trees, model.tree_info, gid, &feat_vecs[0],
|
||||
0, ntree_limit) +
|
||||
base_score;
|
||||
}
|
||||
}
|
||||
|
||||
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
|
||||
void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *out_preds,
|
||||
const gbm::GBTreeModel &model, unsigned ntree_limit) const override {
|
||||
auto const n_threads = this->ctx_->Threads();
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
InitThreadTemp(n_threads, &feat_vecs);
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
const MetaInfo &info = p_fmat->Info();
|
||||
// number of valid trees
|
||||
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||
ntree_limit = static_cast<unsigned>(model.trees.size());
|
||||
}
|
||||
std::vector<bst_float>& preds = out_preds->HostVector();
|
||||
std::vector<bst_float> &preds = out_preds->HostVector();
|
||||
preds.resize(info.num_row_ * ntree_limit);
|
||||
// start collecting the prediction
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// parallel over local batch
|
||||
auto page = batch.GetView();
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
|
||||
common::ParallelFor(page.Size(), n_threads, [&](auto i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
auto ridx = static_cast<size_t>(batch.base_rowid + i);
|
||||
RegTree::FVec &feats = feat_vecs[tid];
|
||||
@ -733,23 +789,28 @@ class CPUPredictor : public Predictor {
|
||||
feats.Init(num_feature);
|
||||
}
|
||||
feats.Fill(page[i]);
|
||||
for (unsigned j = 0; j < ntree_limit; ++j) {
|
||||
auto const& tree = *model.trees[j];
|
||||
auto const& cats = tree.GetCategoriesMatrix();
|
||||
bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
|
||||
preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
|
||||
for (std::uint32_t j = 0; j < ntree_limit; ++j) {
|
||||
auto const &tree = *model.trees[j];
|
||||
auto const &cats = tree.GetCategoriesMatrix();
|
||||
bst_node_t nidx;
|
||||
if (tree.IsMultiTarget()) {
|
||||
nidx = multi::GetLeafIndex<true, true>(*tree.GetMultiTargetTree(), feats, cats);
|
||||
} else {
|
||||
nidx = scalar::GetLeafIndex<true, true>(tree, feats, cats);
|
||||
}
|
||||
preds[ridx * ntree_limit + j] = static_cast<bst_float>(nidx);
|
||||
}
|
||||
feats.Drop(page[i]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void PredictContribution(DMatrix *p_fmat,
|
||||
HostDeviceVector<float> *out_contribs,
|
||||
void PredictContribution(DMatrix *p_fmat, HostDeviceVector<float> *out_contribs,
|
||||
const gbm::GBTreeModel &model, uint32_t ntree_limit,
|
||||
std::vector<bst_float> const *tree_weights,
|
||||
bool approximate, int condition,
|
||||
unsigned condition_feature) const override {
|
||||
std::vector<bst_float> const *tree_weights, bool approximate,
|
||||
int condition, unsigned condition_feature) const override {
|
||||
CHECK(!model.learner_model_param->IsVectorLeaf())
|
||||
<< "Predict contribution" << MTNotImplemented();
|
||||
auto const n_threads = this->ctx_->Threads();
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
@ -825,11 +886,12 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInteractionContributions(
|
||||
DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
|
||||
void PredictInteractionContributions(DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
|
||||
const gbm::GBTreeModel &model, unsigned ntree_limit,
|
||||
std::vector<bst_float> const *tree_weights,
|
||||
bool approximate) const override {
|
||||
CHECK(!model.learner_model_param->IsVectorLeaf())
|
||||
<< "Predict interaction contribution" << MTNotImplemented();
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
const int ngroup = model.learner_model_param->num_output_group;
|
||||
size_t const ncolumns = model.learner_model_param->num_feature;
|
||||
@ -884,5 +946,4 @@ class CPUPredictor : public Predictor {
|
||||
XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor")
|
||||
.describe("Make predictions using CPU.")
|
||||
.set_body([](Context const *ctx) { return new CPUPredictor(ctx); });
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::predictor
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_PREDICTOR_PREDICT_FN_H_
|
||||
#define XGBOOST_PREDICTOR_PREDICT_FN_H_
|
||||
#include "../common/categorical.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
namespace xgboost::predictor {
|
||||
template <bool has_missing, bool has_categorical>
|
||||
inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
|
||||
float fvalue, bool is_missing,
|
||||
@ -24,6 +23,25 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
|
||||
template <bool has_missing, bool has_categorical>
|
||||
inline XGBOOST_DEVICE bst_node_t GetNextNodeMulti(MultiTargetTree const &tree,
|
||||
bst_node_t const nidx, float fvalue,
|
||||
bool is_missing,
|
||||
RegTree::CategoricalSplitMatrix const &cats) {
|
||||
if (has_missing && is_missing) {
|
||||
return tree.DefaultChild(nidx);
|
||||
} else {
|
||||
if (has_categorical && common::IsCat(cats.split_type, nidx)) {
|
||||
auto node_categories =
|
||||
cats.categories.subspan(cats.node_ptr[nidx].beg, cats.node_ptr[nidx].size);
|
||||
return common::Decision(node_categories, fvalue) ? tree.LeftChild(nidx)
|
||||
: tree.RightChild(nidx);
|
||||
} else {
|
||||
return tree.LeftChild(nidx) + !(fvalue < tree.SplitCond(nidx));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost::predictor
|
||||
#endif // XGBOOST_PREDICTOR_PREDICT_FN_H_
|
||||
|
||||
@ -224,19 +224,18 @@ std::string RandomDataGenerator::GenerateArrayInterface(
|
||||
return out;
|
||||
}
|
||||
|
||||
std::pair<std::vector<std::string>, std::string>
|
||||
RandomDataGenerator::GenerateArrayInterfaceBatch(
|
||||
HostDeviceVector<float> *storage, size_t batches) const {
|
||||
this->GenerateDense(storage);
|
||||
std::pair<std::vector<std::string>, std::string> MakeArrayInterfaceBatch(
|
||||
HostDeviceVector<float> const* storage, std::size_t n_samples, bst_feature_t n_features,
|
||||
std::size_t batches, std::int32_t device) {
|
||||
std::vector<std::string> result(batches);
|
||||
std::vector<Json> objects;
|
||||
|
||||
size_t const rows_per_batch = rows_ / batches;
|
||||
size_t const rows_per_batch = n_samples / batches;
|
||||
|
||||
auto make_interface = [storage, this](size_t offset, size_t rows) {
|
||||
auto make_interface = [storage, device, n_features](std::size_t offset, std::size_t rows) {
|
||||
Json array_interface{Object()};
|
||||
array_interface["data"] = std::vector<Json>(2);
|
||||
if (device_ >= 0) {
|
||||
if (device >= 0) {
|
||||
array_interface["data"][0] =
|
||||
Integer(reinterpret_cast<int64_t>(storage->DevicePointer() + offset));
|
||||
array_interface["stream"] = Null{};
|
||||
@ -249,22 +248,22 @@ RandomDataGenerator::GenerateArrayInterfaceBatch(
|
||||
|
||||
array_interface["shape"] = std::vector<Json>(2);
|
||||
array_interface["shape"][0] = rows;
|
||||
array_interface["shape"][1] = cols_;
|
||||
array_interface["shape"][1] = n_features;
|
||||
|
||||
array_interface["typestr"] = String("<f4");
|
||||
array_interface["version"] = 3;
|
||||
return array_interface;
|
||||
};
|
||||
|
||||
auto j_interface = make_interface(0, rows_);
|
||||
auto j_interface = make_interface(0, n_samples);
|
||||
size_t offset = 0;
|
||||
for (size_t i = 0; i < batches - 1; ++i) {
|
||||
objects.emplace_back(make_interface(offset, rows_per_batch));
|
||||
offset += rows_per_batch * cols_;
|
||||
offset += rows_per_batch * n_features;
|
||||
}
|
||||
|
||||
size_t const remaining = rows_ - offset / cols_;
|
||||
CHECK_LE(offset, rows_ * cols_);
|
||||
size_t const remaining = n_samples - offset / n_features;
|
||||
CHECK_LE(offset, n_samples * n_features);
|
||||
objects.emplace_back(make_interface(offset, remaining));
|
||||
|
||||
for (size_t i = 0; i < batches; ++i) {
|
||||
@ -276,6 +275,12 @@ RandomDataGenerator::GenerateArrayInterfaceBatch(
|
||||
return {result, interface_str};
|
||||
}
|
||||
|
||||
std::pair<std::vector<std::string>, std::string> RandomDataGenerator::GenerateArrayInterfaceBatch(
|
||||
HostDeviceVector<float>* storage, size_t batches) const {
|
||||
this->GenerateDense(storage);
|
||||
return MakeArrayInterfaceBatch(storage, rows_, cols_, batches, device_);
|
||||
}
|
||||
|
||||
std::string RandomDataGenerator::GenerateColumnarArrayInterface(
|
||||
std::vector<HostDeviceVector<float>> *data) const {
|
||||
CHECK(data);
|
||||
@ -400,11 +405,14 @@ int NumpyArrayIterForTest::Next() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix>
|
||||
GetDMatrixFromData(const std::vector<float> &x, int num_rows, int num_columns){
|
||||
std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::size_t num_rows,
|
||||
bst_feature_t num_columns) {
|
||||
data::DenseAdapter adapter(x.data(), num_rows, num_columns);
|
||||
return std::shared_ptr<DMatrix>(new data::SimpleDMatrix(
|
||||
&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
|
||||
auto p_fmat = std::shared_ptr<DMatrix>(
|
||||
new data::SimpleDMatrix(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
|
||||
CHECK_EQ(p_fmat->Info().num_row_, num_rows);
|
||||
CHECK_EQ(p_fmat->Info().num_col_, num_columns);
|
||||
return p_fmat;
|
||||
}
|
||||
|
||||
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(bst_row_t n_samples, bst_feature_t n_features,
|
||||
@ -572,12 +580,23 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs,
|
||||
return gbm;
|
||||
}
|
||||
|
||||
ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols,
|
||||
size_t batches) : rows_{rows}, cols_{cols}, n_batches_{batches} {
|
||||
ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches)
|
||||
: rows_{rows}, cols_{cols}, n_batches_{batches} {
|
||||
XGProxyDMatrixCreate(&proxy_);
|
||||
rng_.reset(new RandomDataGenerator{rows_, cols_, sparsity});
|
||||
std::tie(batches_, interface_) = rng_->GenerateArrayInterfaceBatch(&data_, n_batches_);
|
||||
}
|
||||
|
||||
ArrayIterForTest::ArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
|
||||
std::size_t n_samples, bst_feature_t n_features,
|
||||
std::size_t n_batches)
|
||||
: rows_{n_samples}, cols_{n_features}, n_batches_{n_batches} {
|
||||
XGProxyDMatrixCreate(&proxy_);
|
||||
this->data_.Resize(data.Size());
|
||||
CHECK_EQ(this->data_.Size(), rows_ * cols_ * n_batches);
|
||||
this->data_.Copy(data);
|
||||
std::tie(batches_, interface_) =
|
||||
rng_->GenerateArrayInterfaceBatch(&data_, n_batches_);
|
||||
MakeArrayInterfaceBatch(&data_, rows_, cols_, n_batches_, ctx->gpu_id);
|
||||
}
|
||||
|
||||
ArrayIterForTest::~ArrayIterForTest() { XGDMatrixFree(proxy_); }
|
||||
|
||||
@ -188,7 +188,7 @@ class SimpleRealUniformDistribution {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Json GetArrayInterface(HostDeviceVector<T> *storage, size_t rows, size_t cols) {
|
||||
Json GetArrayInterface(HostDeviceVector<T> const* storage, size_t rows, size_t cols) {
|
||||
Json array_interface{Object()};
|
||||
array_interface["data"] = std::vector<Json>(2);
|
||||
if (storage->DeviceCanRead()) {
|
||||
@ -318,8 +318,8 @@ GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
|
||||
return x;
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float> &x,
|
||||
int num_rows, int num_columns);
|
||||
std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::size_t num_rows,
|
||||
bst_feature_t num_columns);
|
||||
|
||||
/**
|
||||
* \brief Create Sparse Page using data iterator.
|
||||
@ -394,7 +394,7 @@ typedef void *DMatrixHandle; // NOLINT(*);
|
||||
class ArrayIterForTest {
|
||||
protected:
|
||||
HostDeviceVector<float> data_;
|
||||
size_t iter_ {0};
|
||||
size_t iter_{0};
|
||||
DMatrixHandle proxy_;
|
||||
std::unique_ptr<RandomDataGenerator> rng_;
|
||||
|
||||
@ -418,6 +418,11 @@ class ArrayIterForTest {
|
||||
auto Proxy() -> decltype(proxy_) { return proxy_; }
|
||||
|
||||
explicit ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches);
|
||||
/**
|
||||
* \brief Create iterator with user provided data.
|
||||
*/
|
||||
explicit ArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
|
||||
std::size_t n_samples, bst_feature_t n_features, std::size_t n_batches);
|
||||
virtual ~ArrayIterForTest();
|
||||
};
|
||||
|
||||
@ -433,6 +438,10 @@ class NumpyArrayIterForTest : public ArrayIterForTest {
|
||||
public:
|
||||
explicit NumpyArrayIterForTest(float sparsity, size_t rows = Rows(), size_t cols = Cols(),
|
||||
size_t batches = Batches());
|
||||
explicit NumpyArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
|
||||
std::size_t n_samples, bst_feature_t n_features,
|
||||
std::size_t n_batches)
|
||||
: ArrayIterForTest{ctx, data, n_samples, n_features, n_batches} {}
|
||||
int Next() override;
|
||||
~NumpyArrayIterForTest() override = default;
|
||||
};
|
||||
|
||||
@ -305,4 +305,10 @@ TEST(CpuPredictor, Sparse) {
|
||||
TestSparsePrediction(0.2, "cpu_predictor");
|
||||
TestSparsePrediction(0.8, "cpu_predictor");
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, Multi) {
|
||||
Context ctx;
|
||||
ctx.nthread = 1;
|
||||
TestVectorLeafPrediction(&ctx);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,28 +1,34 @@
|
||||
/*!
|
||||
* Copyright 2020-2021 by Contributors
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
*/
|
||||
|
||||
#include "test_predictor.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/data.h> // for DMatrix, BatchIterator, BatchSet, MetaInfo
|
||||
#include <xgboost/host_device_vector.h> // for HostDeviceVector
|
||||
#include <xgboost/predictor.h> // for PredictionCacheEntry, Predictor, Predic...
|
||||
|
||||
#include "../../../src/common/bitfield.h"
|
||||
#include "../../../src/common/categorical.h"
|
||||
#include "../../../src/common/io.h"
|
||||
#include "../../../src/data/adapter.h"
|
||||
#include "../../../src/data/proxy_dmatrix.h"
|
||||
#include "../helpers.h"
|
||||
#include <algorithm> // for max
|
||||
#include <limits> // for numeric_limits
|
||||
#include <unordered_map> // for unordered_map
|
||||
|
||||
#include "../../../src/common/bitfield.h" // for LBitField32
|
||||
#include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix
|
||||
#include "../../../src/data/proxy_dmatrix.h" // for DMatrixProxy
|
||||
#include "../helpers.h" // for GetDMatrixFromData, RandomDataGenerator
|
||||
#include "xgboost/json.h" // for Json, Object, get, String
|
||||
#include "xgboost/linalg.h" // for MakeVec, Tensor, TensorView, Vector
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
#include "xgboost/span.h" // for operator!=, SpanIterator, Span
|
||||
#include "xgboost/tree_model.h" // for RegTree
|
||||
|
||||
namespace xgboost {
|
||||
TEST(Predictor, PredictionCache) {
|
||||
size_t constexpr kRows = 16, kCols = 4;
|
||||
|
||||
PredictionContainer container;
|
||||
DMatrix* m;
|
||||
DMatrix *m;
|
||||
// Add a cache that is immediately expired.
|
||||
auto add_cache = [&]() {
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
@ -412,4 +418,101 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestVectorLeafPrediction(Context const *ctx) {
|
||||
std::unique_ptr<Predictor> cpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", ctx));
|
||||
|
||||
size_t constexpr kRows = 5;
|
||||
size_t constexpr kCols = 5;
|
||||
|
||||
LearnerModelParam mparam{static_cast<bst_feature_t>(kCols),
|
||||
linalg::Vector<float>{{0.5}, {1}, Context::kCpuId}, 1, 3,
|
||||
MultiStrategy::kMonolithic};
|
||||
|
||||
std::vector<std::unique_ptr<RegTree>> trees;
|
||||
trees.emplace_back(new RegTree{mparam.LeafLength(), mparam.num_feature});
|
||||
|
||||
std::vector<float> p_w(mparam.LeafLength(), 0.0f);
|
||||
std::vector<float> l_w(mparam.LeafLength(), 1.0f);
|
||||
std::vector<float> r_w(mparam.LeafLength(), 2.0f);
|
||||
|
||||
auto &tree = trees.front();
|
||||
tree->ExpandNode(0, static_cast<bst_feature_t>(1), 2.0, true,
|
||||
linalg::MakeVec(p_w.data(), p_w.size()), linalg::MakeVec(l_w.data(), l_w.size()),
|
||||
linalg::MakeVec(r_w.data(), r_w.size()));
|
||||
ASSERT_TRUE(tree->IsMultiTarget());
|
||||
ASSERT_TRUE(mparam.IsVectorLeaf());
|
||||
|
||||
gbm::GBTreeModel model{&mparam, ctx};
|
||||
model.CommitModel(std::move(trees), 0);
|
||||
|
||||
auto run_test = [&](float expected, HostDeviceVector<float> *p_data) {
|
||||
{
|
||||
auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);
|
||||
PredictionCacheEntry predt_cache;
|
||||
cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
|
||||
ASSERT_EQ(predt_cache.predictions.Size(), kRows * mparam.LeafLength());
|
||||
cpu_predictor->PredictBatch(p_fmat.get(), &predt_cache, model, 0, 1);
|
||||
auto const &h_predt = predt_cache.predictions.HostVector();
|
||||
for (auto v : h_predt) {
|
||||
ASSERT_EQ(v, expected);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// inplace
|
||||
PredictionCacheEntry predt_cache;
|
||||
auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);
|
||||
cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
|
||||
auto arr = GetArrayInterface(p_data, kRows, kCols);
|
||||
std::string str;
|
||||
Json::Dump(arr, &str);
|
||||
auto proxy = std::shared_ptr<DMatrix>(new data::DMatrixProxy{});
|
||||
dynamic_cast<data::DMatrixProxy *>(proxy.get())->SetArrayData(str.data());
|
||||
cpu_predictor->InplacePredict(proxy, model, std::numeric_limits<float>::quiet_NaN(),
|
||||
&predt_cache, 0, 1);
|
||||
auto const &h_predt = predt_cache.predictions.HostVector();
|
||||
for (auto v : h_predt) {
|
||||
ASSERT_EQ(v, expected);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// ghist
|
||||
PredictionCacheEntry predt_cache;
|
||||
auto &h_data = p_data->HostVector();
|
||||
// give it at least two bins, otherwise the histogram cuts only have min and max values.
|
||||
for (std::size_t i = 0; i < 5; ++i) {
|
||||
h_data[i] = 1.0;
|
||||
}
|
||||
auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);
|
||||
|
||||
cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
|
||||
|
||||
auto iter = NumpyArrayIterForTest{ctx, *p_data, kRows, static_cast<bst_feature_t>(kCols),
|
||||
static_cast<std::size_t>(1)};
|
||||
p_fmat =
|
||||
std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), 0, 256);
|
||||
|
||||
cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
|
||||
cpu_predictor->PredictBatch(p_fmat.get(), &predt_cache, model, 0, 1);
|
||||
auto const &h_predt = predt_cache.predictions.HostVector();
|
||||
// the smallest v uses the min_value from histogram cuts, which leads to a left leaf
|
||||
// during prediction.
|
||||
for (std::size_t i = 5; i < h_predt.size(); ++i) {
|
||||
ASSERT_EQ(h_predt[i], expected) << i;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// go to right
|
||||
HostDeviceVector<float> data(kRows * kCols, model.trees.front()->SplitCond(RegTree::kRoot) + 1.0);
|
||||
run_test(2.5, &data);
|
||||
|
||||
// go to left
|
||||
data.HostVector().assign(data.Size(), model.trees.front()->SplitCond(RegTree::kRoot) - 1.0);
|
||||
run_test(1.5, &data);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,9 +1,16 @@
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_TEST_PREDICTOR_H_
|
||||
#define XGBOOST_TEST_PREDICTOR_H_
|
||||
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/predictor.h>
|
||||
#include <string>
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
||||
#include "../../../src/gbm/gbtree_model.h" // for GBTreeModel
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -48,7 +55,7 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
|
||||
PredictionCacheEntry precise_out_predictions;
|
||||
predictor->InitOutPredictions(p_dmat->Info(), &precise_out_predictions.predictions, model);
|
||||
predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
|
||||
ASSERT_FALSE(p_dmat->PageExists<Page>());
|
||||
CHECK(!p_dmat->PageExists<Page>());
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,6 +76,8 @@ void TestCategoricalPredictLeaf(StringView name);
|
||||
void TestIterationRange(std::string name);
|
||||
|
||||
void TestSparsePrediction(float sparsity, std::string predictor);
|
||||
|
||||
void TestVectorLeafPrediction(Context const* ctx);
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_TEST_PREDICTOR_H_
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user