Simplify inplace-predict. (#7910)

Pass the input data `X` as part of the proxy `DMatrix` (`DMatrixProxy`) instead of as an independent `dmlc::any` adapter.
Jiaming Yuan 2022-05-18 17:52:00 +08:00 committed by GitHub
parent 19775ffe15
commit 765097d514
17 changed files with 317 additions and 297 deletions
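After this change the input data is attached to a `DMatrixProxy` and only that proxy is handed to `Learner::InplacePredict`; the separate type-erased `dmlc::any` adapter argument is removed. A minimal sketch of the new call path, assuming a trained `learner` and a dense array-interface JSON string `array_interface` (both placeholders; the other identifiers come from this diff):

// Sketch only: `learner` (xgboost::Learner*) and `array_interface`
// (std::string holding an array-interface JSON document) are assumed.
std::shared_ptr<xgboost::data::DMatrixProxy> proxy{new xgboost::data::DMatrixProxy{}};
proxy->SetArrayData(array_interface.c_str());  // the data now lives inside the proxy
xgboost::HostDeviceVector<float>* out_predt{nullptr};
learner->InplacePredict(proxy, xgboost::PredictionType::kValue,
                        std::numeric_limits<float>::quiet_NaN(),  // missing value
                        &out_predt, /*layer_begin=*/0, /*layer_end=*/0);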


@ -111,15 +111,14 @@ class GradientBooster : public Model, public Configurable {
/*!
* \brief Inplace prediction.
*
* \param x A type erased data adapter.
* \param p_fmat A proxy DMatrix that contains the data and related
* meta info.
* \param missing Missing value in the data.
* \param [in,out] out_preds The output preds.
* \param layer_begin (Optional) Beginning of boosted tree layer used for prediction.
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
*/
virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
PredictionCacheEntry*,
uint32_t,
virtual void InplacePredict(std::shared_ptr<DMatrix>, float, PredictionCacheEntry*, uint32_t,
uint32_t) const {
LOG(FATAL) << "Inplace predict is not supported by current booster.";
}


@ -139,21 +139,16 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
/*!
* \brief Inplace prediction.
*
* \param x A type erased data adapter.
* \param p_m An optional Proxy DMatrix object storing meta info like
* base margin. Can be nullptr.
* \param p_fmat A proxy DMatrix that contains the data and related meta info.
* \param type Prediction type.
* \param missing Missing value in the data.
* \param [in,out] out_preds Pointer to output prediction vector.
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
*/
virtual void InplacePredict(dmlc::any const &x,
std::shared_ptr<DMatrix> p_m,
PredictionType type,
float missing,
HostDeviceVector<bst_float> **out_preds,
uint32_t layer_begin, uint32_t layer_end) = 0;
virtual void InplacePredict(std::shared_ptr<DMatrix> p_m, PredictionType type, float missing,
HostDeviceVector<bst_float>** out_preds, uint32_t layer_begin,
uint32_t layer_end) = 0;
/*!
* \brief Calculate feature score. See doc in C API for outputs.


@ -145,7 +145,9 @@ class Predictor {
/**
* \brief Inplace prediction.
* \param x Type erased data adapter.
*
* \param p_fmat A proxy DMatrix that contains the data and related
* meta info.
* \param model The model to predict from.
* \param missing Missing value in the data.
* \param [in,out] out_preds The output preds.
@ -154,11 +156,9 @@ class Predictor {
*
* \return True if the data can be handled by current predictor, false otherwise.
*/
virtual bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds,
uint32_t tree_begin = 0,
uint32_t tree_end = 0) const = 0;
virtual bool InplacePredict(std::shared_ptr<DMatrix> p_fmat, const gbm::GBTreeModel& model,
float missing, PredictionCacheEntry* out_preds,
uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
/**
* \brief online prediction function, predict score for one instance at a time
* NOTE: use the batch prediction interface if possible, batch prediction is


@ -300,7 +300,7 @@ XGProxyDMatrixSetDataCudaArrayInterface(DMatrixHandle handle,
CHECK(p_m);
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
CHECK(m) << "Current DMatrix type does not support set data.";
m->SetData(c_interface_str);
m->SetCUDAArray(c_interface_str);
API_END();
}
@ -312,7 +312,7 @@ XGB_DLL int XGProxyDMatrixSetDataCudaColumnar(DMatrixHandle handle,
CHECK(p_m);
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
CHECK(m) << "Current DMatrix type does not support set data.";
m->SetData(c_interface_str);
m->SetCUDAArray(c_interface_str);
API_END();
}
@ -825,74 +825,69 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
API_END();
}
template <typename T>
void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
char const *c_json_config, Learner *learner,
size_t n_rows, size_t n_cols,
xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim, const float **out_result) {
void InplacePredictImpl(std::shared_ptr<DMatrix> p_m, char const *c_json_config, Learner *learner,
xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
const float **out_result) {
auto config = Json::Load(StringView{c_json_config});
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
HostDeviceVector<float>* p_predt { nullptr };
HostDeviceVector<float> *p_predt{nullptr};
auto type = PredictionType(RequiredArg<Integer>(config, "type", __func__));
float missing = GetMissing(config);
learner->InplacePredict(x, p_m, type, missing, &p_predt,
learner->InplacePredict(p_m, type, missing, &p_predt,
RequiredArg<Integer>(config, "iteration_begin", __func__),
RequiredArg<Integer>(config, "iteration_end", __func__));
CHECK(p_predt);
auto &shape = learner->GetThreadLocal().prediction_shape;
auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows;
auto const &info = p_m->Info();
auto n_samples = info.num_row_;
auto n_features = info.num_col_;
auto chunksize = n_samples == 0 ? 0 : p_predt->Size() / n_samples;
bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(),
CalcPredictShape(strict_shape, type, n_samples, n_features, chunksize, learner->Groups(),
learner->BoostedRounds(), &shape, out_dim);
*out_result = dmlc::BeginPtr(p_predt->HostVector());
*out_shape = dmlc::BeginPtr(shape);
}
// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle,
char const *array_interface,
char const *c_json_config,
DMatrixHandle m,
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *array_interface,
char const *c_json_config, DMatrixHandle m,
xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim,
const float **out_result) {
xgboost::bst_ulong *out_dim, const float **out_result) {
API_BEGIN();
CHECK_HANDLE();
std::shared_ptr<xgboost::data::ArrayAdapter> x{
new xgboost::data::ArrayAdapter(StringView{array_interface})};
std::shared_ptr<DMatrix> p_m {nullptr};
if (m) {
std::shared_ptr<DMatrix> p_m{nullptr};
if (!m) {
p_m.reset(new data::DMatrixProxy);
} else {
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
}
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
CHECK(proxy) << "Invalid input type for inplace predict.";
proxy->SetArrayData(array_interface);
auto *learner = static_cast<xgboost::Learner *>(handle);
InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(),
x->NumColumns(), out_shape, out_dim, out_result);
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
API_END();
}
// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
char const *indices, char const *data,
xgboost::bst_ulong cols,
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, char const *indices,
char const *data, xgboost::bst_ulong cols,
char const *c_json_config, DMatrixHandle m,
xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim,
const float **out_result) {
xgboost::bst_ulong *out_dim, const float **out_result) {
API_BEGIN();
CHECK_HANDLE();
std::shared_ptr<xgboost::data::CSRArrayAdapter> x{
new xgboost::data::CSRArrayAdapter{StringView{indptr},
StringView{indices}, StringView{data},
static_cast<size_t>(cols)}};
std::shared_ptr<DMatrix> p_m {nullptr};
if (m) {
std::shared_ptr<DMatrix> p_m{nullptr};
if (!m) {
p_m.reset(new data::DMatrixProxy);
} else {
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
}
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
CHECK(proxy) << "Invalid input type for inplace predict.";
proxy->SetCSRData(indptr, indices, data, cols, true);
auto *learner = static_cast<xgboost::Learner *>(handle);
InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(),
x->NumColumns(), out_shape, out_dim, out_result);
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
API_END();
}
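For orientation, a hedged sketch of how the dense entry point above can be driven by a client. The JSON keys mirror the ones read in this hunk (`type`, `iteration_begin`, `iteration_end`, `strict_shape`, `cache_id`, and `missing` via `GetMissing`); the booster handle, the array-interface string, and error handling are assumed:

// Sketch only: `booster` (BoosterHandle) and `array_interface` (char const*) are assumed.
char const *config =
    "{\"type\": 0, \"iteration_begin\": 0, \"iteration_end\": 0,"
    " \"strict_shape\": false, \"cache_id\": 0, \"missing\": NaN}";  // NaN marks missing values
xgboost::bst_ulong const *out_shape{nullptr};
xgboost::bst_ulong out_dim{0};
float const *out_result{nullptr};
// A null DMatrixHandle makes the API allocate a temporary DMatrixProxy, as shown above.
int rc = XGBoosterPredictFromDense(booster, array_interface, config, nullptr,
                                   &out_shape, &out_dim, &out_result);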


@ -1,10 +1,11 @@
// Copyright (c) 2019-2021 by Contributors
#include "xgboost/data.h"
#include "xgboost/c_api.h"
#include "xgboost/learner.h"
// Copyright (c) 2019-2022 by Contributors
#include "../data/device_adapter.cuh"
#include "../data/proxy_dmatrix.h"
#include "c_api_error.h"
#include "c_api_utils.h"
#include "../data/device_adapter.cuh"
#include "xgboost/c_api.h"
#include "xgboost/data.h"
#include "xgboost/learner.h"
namespace xgboost {
@ -85,62 +86,65 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
API_END();
}
template <typename T>
int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
char const *c_json_config,
std::shared_ptr<DMatrix> p_m,
xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim, const float **out_result) {
int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface,
char const *c_json_config, std::shared_ptr<DMatrix> p_m,
xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
const float **out_result) {
API_BEGIN();
CHECK_HANDLE();
if (!p_m) {
p_m.reset(new data::DMatrixProxy);
}
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
CHECK(proxy) << "Invalid input type for inplace predict.";
proxy->SetCUDAArray(c_array_interface);
auto config = Json::Load(StringView{c_json_config});
CHECK_EQ(get<Integer const>(config["cache_id"]), 0)
<< "Cache ID is not supported yet";
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
auto *learner = static_cast<Learner *>(handle);
std::string json_str{c_json_strs};
auto x = std::make_shared<T>(json_str);
HostDeviceVector<float> *p_predt{nullptr};
auto type = PredictionType(get<Integer const>(config["type"]));
auto type = PredictionType(RequiredArg<Integer>(config, "type", __func__));
float missing = GetMissing(config);
learner->InplacePredict(x, p_m, type, missing, &p_predt,
get<Integer const>(config["iteration_begin"]),
get<Integer const>(config["iteration_end"]));
learner->InplacePredict(p_m, type, missing, &p_predt,
RequiredArg<Integer>(config, "iteration_begin", __func__),
RequiredArg<Integer>(config, "iteration_end", __func__));
CHECK(p_predt);
CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead());
auto &shape = learner->GetThreadLocal().prediction_shape;
auto chunksize = x->NumRows() == 0 ? 0 : p_predt->Size() / x->NumRows();
bool strict_shape = get<Boolean const>(config["strict_shape"]);
CalcPredictShape(strict_shape, type, x->NumRows(), x->NumColumns(), chunksize,
learner->Groups(), learner->BoostedRounds(), &shape,
out_dim);
size_t n_samples = p_m->Info().num_row_;
auto chunksize = n_samples == 0 ? 0 : p_predt->Size() / n_samples;
bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
CalcPredictShape(strict_shape, type, n_samples, p_m->Info().num_col_, chunksize,
learner->Groups(), learner->BoostedRounds(), &shape, out_dim);
*out_shape = dmlc::BeginPtr(shape);
*out_result = p_predt->ConstDevicePointer();
API_END();
}
XGB_DLL int XGBoosterPredictFromCudaColumnar(
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
DMatrixHandle m, xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim, const float **out_result) {
std::shared_ptr<DMatrix> p_m {nullptr};
XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *c_json_strs,
char const *c_json_config, DMatrixHandle m,
xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim,
const float **out_result) {
std::shared_ptr<DMatrix> p_m{nullptr};
if (m) {
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
}
return InplacePreidctCuda<data::CudfAdapter>(
handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
return InplacePreidctCuda(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
out_result);
}
XGB_DLL int XGBoosterPredictFromCudaArray(
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
DMatrixHandle m, xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim, const float **out_result) {
std::shared_ptr<DMatrix> p_m {nullptr};
XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *c_json_strs,
char const *c_json_config, DMatrixHandle m,
xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim, const float **out_result) {
std::shared_ptr<DMatrix> p_m{nullptr};
if (m) {
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
}
return InplacePreidctCuda<data::CupyAdapter>(
handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
return InplacePreidctCuda(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
out_result);
}


@ -55,7 +55,7 @@ class DMatrixProxy : public DMatrix {
public:
int DeviceIdx() const { return ctx_.gpu_id; }
void SetData(char const* c_interface) {
void SetCUDAArray(char const* c_interface) {
common::AssertGPUSupport();
#if defined(XGBOOST_USE_CUDA)
std::string interface_str = c_interface;


@ -795,88 +795,75 @@ class Dart : public GBTree {
this->PredictBatchImpl(p_fmat, p_out_preds, training, layer_begin, layer_end);
}
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
float missing, PredictionCacheEntry *out_preds,
uint32_t layer_begin, unsigned layer_end) const override {
void InplacePredict(std::shared_ptr<DMatrix> p_fmat, float missing,
PredictionCacheEntry* p_out_preds, uint32_t layer_begin,
unsigned layer_end) const override {
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end);
std::vector<Predictor const *> predictors{
auto n_groups = model_.learner_model_param->num_output_group;
std::vector<Predictor const*> predictors {
cpu_predictor_.get(),
#if defined(XGBOOST_USE_CUDA)
gpu_predictor_.get()
#endif // defined(XGBOOST_USE_CUDA)
};
Predictor const * predictor {nullptr};
MetaInfo info;
Predictor const* predictor{nullptr};
StringView msg{"Unsupported data type for inplace predict."};
int32_t device = GenericParameter::kCpuId;
PredictionCacheEntry predts;
// Inplace predict is not used for training, so no need to drop tree.
for (size_t i = tree_begin; i < tree_end; ++i) {
if (ctx_->gpu_id != Context::kCpuId) {
predts.predictions.SetDevice(ctx_->gpu_id);
}
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
auto predict_impl = [&](size_t i) {
predts.predictions.Fill(0);
if (tparam_.predictor == PredictorType::kAuto) {
// Try both predictor implementations
bool success = false;
for (auto const &p : predictors) {
if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i,
i + 1)) {
for (auto const& p : predictors) {
if (p && p->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)) {
success = true;
predictor = p;
#if defined(XGBOOST_USE_CUDA)
device = predts.predictions.DeviceIdx();
#endif // defined(XGBOOST_USE_CUDA)
break;
}
}
CHECK(success) << msg;
} else {
// No base margin from meta info for each tree
predictor = this->GetPredictor().get();
bool success = predictor->InplacePredict(x, nullptr, model_, missing,
&predts, i, i + 1);
device = predts.predictions.DeviceIdx();
bool success = predictor->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
CHECK(success) << msg << std::endl
<< "Current Predictor: "
<< (tparam_.predictor == PredictorType::kCPUPredictor
? "cpu_predictor"
: "gpu_predictor");
<< (tparam_.predictor == PredictorType::kCPUPredictor ? "cpu_predictor"
: "gpu_predictor");
}
};
auto w = this->weight_drop_.at(i);
size_t n_groups = model_.learner_model_param->num_output_group;
auto n_rows = predts.predictions.Size() / n_groups;
// Inplace predict is not used for training, so no need to drop tree.
for (size_t i = tree_begin; i < tree_end; ++i) {
predict_impl(i);
if (i == tree_begin) {
// base margin is added here.
if (p_m) {
p_m->Info().num_row_ = n_rows;
predictor->InitOutPredictions(p_m->Info(), &out_preds->predictions,
model_);
} else {
info.num_row_ = n_rows;
predictor->InitOutPredictions(info, &out_preds->predictions, model_);
}
predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
}
// Multiple the tree weight
CHECK_EQ(predts.predictions.Size(), out_preds->predictions.Size());
auto w = this->weight_drop_.at(i);
auto group = model_.tree_info.at(i);
CHECK_EQ(predts.predictions.Size(), p_out_preds->predictions.Size());
if (device == GenericParameter::kCpuId) {
auto &h_predts = predts.predictions.HostVector();
auto &h_out_predts = out_preds->predictions.HostVector();
size_t n_rows = p_fmat->Info().num_row_;
if (predts.predictions.DeviceIdx() != Context::kCpuId) {
p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
GPUDartInplacePredictInc(p_out_preds->predictions.DeviceSpan(),
predts.predictions.DeviceSpan(), w, n_rows,
model_.learner_model_param->base_score, n_groups, group);
} else {
auto& h_predts = predts.predictions.HostVector();
auto& h_out_predts = p_out_preds->predictions.HostVector();
common::ParallelFor(n_rows, ctx_->Threads(), [&](auto ridx) {
const size_t offset = ridx * n_groups + group;
// Need to remove the base margin from individual tree.
h_out_predts[offset] += (h_predts[offset] - model_.learner_model_param->base_score) * w;
});
} else {
out_preds->predictions.SetDevice(device);
predts.predictions.SetDevice(device);
GPUDartInplacePredictInc(out_preds->predictions.DeviceSpan(),
predts.predictions.DeviceSpan(), w, n_rows,
model_.learner_model_param->base_score,
n_groups, group);
}
}
}


@ -261,8 +261,7 @@ class GBTree : public GradientBooster {
void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *out_preds,
bool training, unsigned layer_begin, unsigned layer_end) override;
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
float missing, PredictionCacheEntry *out_preds,
void InplacePredict(std::shared_ptr<DMatrix> p_m, float missing, PredictionCacheEntry* out_preds,
uint32_t layer_begin, unsigned layer_end) const override {
CHECK(configured_);
uint32_t tree_begin, tree_end;
@ -278,15 +277,14 @@ class GBTree : public GradientBooster {
if (tparam_.predictor == PredictorType::kAuto) {
// Try both predictor implementations
for (auto const &p : predictors) {
if (p && p->InplacePredict(x, p_m, model_, missing, out_preds,
tree_begin, tree_end)) {
if (p && p->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end)) {
return;
}
}
LOG(FATAL) << msg;
} else {
bool success = this->GetPredictor()->InplacePredict(
x, p_m, model_, missing, out_preds, tree_begin, tree_end);
bool success = this->GetPredictor()->InplacePredict(p_m, model_, missing, out_preds,
tree_begin, tree_end);
CHECK(success) << msg << std::endl
<< "Current Predictor: "
<< (tparam_.predictor == PredictorType::kCPUPredictor


@ -1277,15 +1277,12 @@ class LearnerImpl : public LearnerIO {
return (*LearnerAPIThreadLocalStore::Get())[this];
}
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
PredictionType type, float missing,
HostDeviceVector<bst_float> **out_preds,
uint32_t iteration_begin,
void InplacePredict(std::shared_ptr<DMatrix> p_m, PredictionType type, float missing,
HostDeviceVector<bst_float>** out_preds, uint32_t iteration_begin,
uint32_t iteration_end) override {
this->Configure();
auto& out_predictions = this->GetThreadLocal().prediction_entry;
this->gbm_->InplacePredict(x, p_m, missing, &out_predictions,
iteration_begin, iteration_end);
this->gbm_->InplacePredict(p_m, missing, &out_predictions, iteration_begin, iteration_end);
if (type == PredictionType::kValue) {
obj_->PredTransform(&out_predictions.predictions);
} else if (type == PredictionType::kMargin) {


@ -1,27 +1,27 @@
/*!
* Copyright by Contributors 2017-2021
*/
#include <dmlc/omp.h>
#include <dmlc/any.h>
#include <dmlc/omp.h>
#include <cstddef>
#include <limits>
#include <mutex>
#include "../common/categorical.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
#include "../data/adapter.h"
#include "../data/proxy_dmatrix.h"
#include "../gbm/gbtree_model.h"
#include "predict_fn.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/logging.h"
#include "xgboost/predictor.h"
#include "xgboost/tree_model.h"
#include "xgboost/tree_updater.h"
#include "xgboost/logging.h"
#include "xgboost/host_device_vector.h"
#include "predict_fn.h"
#include "../data/adapter.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
#include "../common/categorical.h"
#include "../gbm/gbtree_model.h"
namespace xgboost {
namespace predictor {
@ -327,22 +327,24 @@ class CPUPredictor : public Predictor {
&predictions, model, tree_begin, tree_end, &thread_temp, n_threads);
}
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
const gbm::GBTreeModel &model, float missing,
bool InplacePredict(std::shared_ptr<DMatrix> p_m, const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds, uint32_t tree_begin,
unsigned tree_end) const override {
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
auto x = proxy->Adapter();
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
this->DispatchedInplacePredict<data::DenseAdapter, kBlockOfRowsSize>(
x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
this->DispatchedInplacePredict<data::CSRAdapter, 1>(
x, p_m, model, missing, out_preds, tree_begin, tree_end);
this->DispatchedInplacePredict<data::CSRAdapter, 1>(x, p_m, model, missing, out_preds,
tree_begin, tree_end);
} else if (x.type() == typeid(std::shared_ptr<data::ArrayAdapter>)) {
this->DispatchedInplacePredict<data::ArrayAdapter, kBlockOfRowsSize> (
this->DispatchedInplacePredict<data::ArrayAdapter, kBlockOfRowsSize>(
x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(std::shared_ptr<data::CSRArrayAdapter>)) {
this->DispatchedInplacePredict<data::CSRArrayAdapter, 1> (
x, p_m, model, missing, out_preds, tree_begin, tree_end);
this->DispatchedInplacePredict<data::CSRArrayAdapter, 1>(x, p_m, model, missing, out_preds,
tree_begin, tree_end);
} else {
return false;
}


@ -1,28 +1,29 @@
/*!
* Copyright 2017-2021 by Contributors
*/
#include <GPUTreeShap/gpu_treeshap.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <thrust/host_vector.h>
#include <GPUTreeShap/gpu_treeshap.h>
#include <memory>
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/common.h"
#include "../common/device_helpers.cuh"
#include "../data/device_adapter.cuh"
#include "../data/ellpack_page.cuh"
#include "../data/proxy_dmatrix.h"
#include "../gbm/gbtree_model.h"
#include "predict_fn.h"
#include "xgboost/data.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/predictor.h"
#include "xgboost/tree_model.h"
#include "xgboost/tree_updater.h"
#include "xgboost/host_device_vector.h"
#include "predict_fn.h"
#include "../gbm/gbtree_model.h"
#include "../data/ellpack_page.cuh"
#include "../data/device_adapter.cuh"
#include "../common/common.h"
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace predictor {
@ -789,17 +790,19 @@ class GPUPredictor : public xgboost::Predictor {
m->NumRows(), entry_start, use_shared, output_groups, missing);
}
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds, uint32_t tree_begin,
bool InplacePredict(std::shared_ptr<DMatrix> p_m, const gbm::GBTreeModel& model, float missing,
PredictionCacheEntry* out_preds, uint32_t tree_begin,
unsigned tree_end) const override {
auto proxy = dynamic_cast<data::DMatrixProxy*>(p_m.get());
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
auto x = proxy->Adapter();
if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
this->DispatchedInplacePredict<
data::CupyAdapter, DeviceAdapterLoader<data::CupyAdapterBatch>>(
this->DispatchedInplacePredict<data::CupyAdapter,
DeviceAdapterLoader<data::CupyAdapterBatch>>(
x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(std::shared_ptr<data::CudfAdapter>)) {
this->DispatchedInplacePredict<
data::CudfAdapter, DeviceAdapterLoader<data::CudfAdapterBatch>>(
this->DispatchedInplacePredict<data::CudfAdapter,
DeviceAdapterLoader<data::CudfAdapterBatch>>(
x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else {
return false;


@ -19,7 +19,7 @@ TEST(ProxyDMatrix, DeviceData) {
.GenerateColumnarArrayInterface(&label_storage);
DMatrixProxy proxy;
proxy.SetData(data.c_str());
proxy.SetCUDAArray(data.c_str());
proxy.SetInfo("label", labels.c_str());
ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr<CupyAdapter>));
@ -34,7 +34,7 @@ TEST(ProxyDMatrix, DeviceData) {
data = RandomDataGenerator(kRows, kCols, 0)
.Device(0)
.GenerateColumnarArrayInterface(&columnar_storage);
proxy.SetData(data.c_str());
proxy.SetCUDAArray(data.c_str());
ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr<CudfAdapter>));
ASSERT_EQ(dmlc::get<std::shared_ptr<CudfAdapter>>(proxy.Adapter())->NumRows(),
kRows);


@ -1,16 +1,17 @@
/*!
* Copyright 2019-2022 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
#include <gtest/gtest.h>
#include <xgboost/generic_parameters.h>
#include "../../../src/data/adapter.h"
#include "../../../src/data/proxy_dmatrix.h"
#include "../../../src/gbm/gbtree.h"
#include "../helpers.h"
#include "xgboost/base.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
#include "../helpers.h"
#include "../../../src/gbm/gbtree.h"
#include "../../../src/data/adapter.h"
#include "xgboost/predictor.h"
namespace xgboost {
@ -246,53 +247,78 @@ TEST(Dart, JsonIO) {
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0ul);
}
TEST(Dart, Prediction) {
size_t constexpr kRows = 16, kCols = 10;
namespace {
class Dart : public testing::TestWithParam<char const*> {
public:
void Run(std::string predictor) {
size_t constexpr kRows = 16, kCols = 10;
HostDeviceVector<float> data;
auto array_str = RandomDataGenerator(kRows, kCols, 0).GenerateArrayInterface(&data);
auto p_mat = GetDMatrixFromData(data.HostVector(), kRows, kCols);
HostDeviceVector<float> data;
auto rng = RandomDataGenerator(kRows, kCols, 0);
if (predictor == "gpu_predictor") {
rng.Device(0);
}
auto array_str = rng.GenerateArrayInterface(&data);
auto p_mat = GetDMatrixFromData(data.HostVector(), kRows, kCols);
std::vector<bst_float> labels (kRows);
for (size_t i = 0; i < kRows; ++i) {
labels[i] = i % 2;
std::vector<bst_float> labels(kRows);
for (size_t i = 0; i < kRows; ++i) {
labels[i] = i % 2;
}
p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kRows);
auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
learner->SetParam("booster", "dart");
learner->SetParam("rate_drop", "0.5");
learner->Configure();
for (size_t i = 0; i < 16; ++i) {
learner->UpdateOneIter(i, p_mat);
}
learner->SetParam("predictor", predictor);
HostDeviceVector<float> predts_training;
learner->Predict(p_mat, false, &predts_training, 0, 0, true);
HostDeviceVector<float>* inplace_predts;
std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy{}};
if (predictor == "gpu_predictor") {
x->SetCUDAArray(array_str.c_str());
} else {
x->SetArrayData(array_str.c_str());
}
learner->InplacePredict(x, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
&inplace_predts, 0, 0);
CHECK(inplace_predts);
HostDeviceVector<float> predts_inference;
learner->Predict(p_mat, false, &predts_inference, 0, 0, false);
auto const& h_predts_training = predts_training.ConstHostVector();
auto const& h_predts_inference = predts_inference.ConstHostVector();
auto const& h_inplace_predts = inplace_predts->HostVector();
ASSERT_EQ(h_predts_training.size(), h_predts_inference.size());
ASSERT_EQ(h_inplace_predts.size(), h_predts_inference.size());
for (size_t i = 0; i < predts_inference.Size(); ++i) {
// Inference doesn't drop tree.
ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps * 10);
// Inplace prediction is inference.
ASSERT_LT(h_inplace_predts[i] - h_predts_inference[i], kRtEps / 10);
}
}
p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kRows);
};
} // anonymous namespace
auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
learner->SetParam("booster", "dart");
learner->SetParam("rate_drop", "0.5");
learner->Configure();
TEST_P(Dart, Prediction) { this->Run(GetParam()); }
for (size_t i = 0; i < 16; ++i) {
learner->UpdateOneIter(i, p_mat);
}
#if defined(XGBOOST_USE_CUDA)
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart,
testing::Values("auto", "cpu_predictor", "gpu_predictor"));
#else
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("auto", "cpu_predictor"));
#endif // defined(XGBOOST_USE_CUDA)
HostDeviceVector<float> predts_training;
learner->Predict(p_mat, false, &predts_training, 0, 0, true);
HostDeviceVector<float>* inplace_predts;
auto adapter = std::shared_ptr<data::ArrayAdapter>(new data::ArrayAdapter{StringView{array_str}});
learner->InplacePredict(adapter, nullptr, PredictionType::kValue,
std::numeric_limits<float>::quiet_NaN(),
&inplace_predts, 0, 0);
CHECK(inplace_predts);
HostDeviceVector<float> predts_inference;
learner->Predict(p_mat, false, &predts_inference, 0, 0, false);
auto const& h_predts_training = predts_training.ConstHostVector();
auto const& h_predts_inference = predts_inference.ConstHostVector();
auto const& h_inplace_predts = inplace_predts->HostVector();
ASSERT_EQ(h_predts_training.size(), h_predts_inference.size());
ASSERT_EQ(h_inplace_predts.size(), h_predts_inference.size());
for (size_t i = 0; i < predts_inference.Size(); ++i) {
// Inference doesn't drop tree.
ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps * 10);
// Inplace prediction is inference.
ASSERT_LT(h_inplace_predts[i] - h_predts_inference[i], kRtEps / 10);
}
}
std::pair<Json, Json> TestModelSlice(std::string booster) {
size_t constexpr kRows = 1000, kCols = 100, kForest = 2, kClasses = 3;
@ -485,19 +511,20 @@ TEST(GBTree, PredictRange) {
// inplace predict
HostDeviceVector<float> raw_storage;
auto raw = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateArrayInterface(&raw_storage);
std::shared_ptr<data::ArrayAdapter> x{new data::ArrayAdapter{StringView{raw}}};
std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy{}};
x->SetArrayData(raw.data());
HostDeviceVector<float>* out_predt;
learner->InplacePredict(x, nullptr, PredictionType::kValue,
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 2);
learner->InplacePredict(x, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
&out_predt, 0, 2);
auto h_out_predt = out_predt->HostVector();
learner->InplacePredict(x, nullptr, PredictionType::kValue,
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 0);
learner->InplacePredict(x, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
&out_predt, 0, 0);
auto h_out_predt_full = out_predt->HostVector();
ASSERT_TRUE(std::equal(h_out_predt.begin(), h_out_predt.end(), h_out_predt_full.begin()));
ASSERT_THROW(learner->InplacePredict(x, nullptr, PredictionType::kValue,
ASSERT_THROW(learner->InplacePredict(x, PredictionType::kValue,
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 3),
dmlc::Error);
}


@ -5,11 +5,12 @@
#include <gtest/gtest.h>
#include <xgboost/predictor.h>
#include "../../../src/data/adapter.h"
#include "../../../src/data/proxy_dmatrix.h"
#include "../../../src/gbm/gbtree.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../helpers.h"
#include "test_predictor.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../../../src/gbm/gbtree.h"
#include "../../../src/data/adapter.h"
namespace xgboost {
TEST(CpuPredictor, Basic) {
@ -172,8 +173,11 @@ TEST(CpuPredictor, InplacePredict) {
HostDeviceVector<float> data;
gen.GenerateDense(&data);
ASSERT_EQ(data.Size(), kRows * kCols);
std::shared_ptr<data::DenseAdapter> x{
new data::DenseAdapter(data.HostPointer(), kRows, kCols)};
std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy{}};
auto array_interface = GetArrayInterface(&data, kRows, kCols);
std::string arr_str;
Json::Dump(array_interface, &arr_str);
x->SetArrayData(arr_str.data());
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
}
@ -182,9 +186,15 @@ TEST(CpuPredictor, InplacePredict) {
HostDeviceVector<bst_row_t> rptrs;
HostDeviceVector<bst_feature_t> columns;
gen.GenerateCSR(&data, &rptrs, &columns);
std::shared_ptr<data::CSRAdapter> x{new data::CSRAdapter(
rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), kRows,
data.Size(), kCols)};
auto data_interface = GetArrayInterface(&data, kRows * kCols, 1);
auto rptr_interface = GetArrayInterface(&rptrs, kRows + 1, 1);
auto col_interface = GetArrayInterface(&columns, kRows * kCols, 1);
std::string data_str, rptr_str, col_str;
Json::Dump(data_interface, &data_str);
Json::Dump(rptr_interface, &rptr_str);
Json::Dump(col_interface, &col_str);
std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy};
x->SetCSRData(rptr_str.data(), col_str.data(), data_str.data(), kCols, true);
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
}
}
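The `GetArrayInterface`/`Json::Dump` helpers above produce the same kind of JSON string that the proxy setters consume throughout this commit. Roughly, for a 16x10 float32 buffer the string looks along these lines (pointer value hypothetical; the layout follows the NumPy `__array_interface__` convention):

{"data": [140201490713600, false], "shape": [16, 10], "typestr": "<f4", "version": 3}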


@ -1,17 +1,19 @@
/*!
* Copyright 2017-2020 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
#include <gtest/gtest.h>
#include <xgboost/c_api.h>
#include <xgboost/predictor.h>
#include <xgboost/logging.h>
#include <xgboost/learner.h>
#include <xgboost/logging.h>
#include <xgboost/predictor.h>
#include <string>
#include "../helpers.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../../../src/data/device_adapter.cuh"
#include "../../../src/data/proxy_dmatrix.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../helpers.h"
#include "test_predictor.h"
namespace xgboost {
@ -135,8 +137,9 @@ TEST(GPUPredictor, InplacePredictCupy) {
gen.Device(0);
HostDeviceVector<float> data;
std::string interface_str = gen.GenerateArrayInterface(&data);
auto x = std::make_shared<data::CupyAdapter>(interface_str);
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
std::shared_ptr<DMatrix> p_fmat{new data::DMatrixProxy};
dynamic_cast<data::DMatrixProxy*>(p_fmat.get())->SetCUDAArray(interface_str.c_str());
TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0);
}
TEST(GPUPredictor, InplacePredictCuDF) {
@ -145,8 +148,9 @@ TEST(GPUPredictor, InplacePredictCuDF) {
gen.Device(0);
std::vector<HostDeviceVector<float>> storage(kCols);
auto interface_str = gen.GenerateColumnarArrayInterface(&storage);
auto x = std::make_shared<data::CudfAdapter>(interface_str);
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
std::shared_ptr<DMatrix> p_fmat{new data::DMatrixProxy};
dynamic_cast<data::DMatrixProxy*>(p_fmat.get())->SetCUDAArray(interface_str.c_str());
TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0);
}
TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT
@ -160,10 +164,10 @@ TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT
gen.Device(1);
HostDeviceVector<float> data;
std::string interface_str = gen.GenerateArrayInterface(&data);
auto x = std::make_shared<data::CupyAdapter>(interface_str);
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 1);
EXPECT_THROW(TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0),
dmlc::Error);
std::shared_ptr<DMatrix> p_fmat{new data::DMatrixProxy};
dynamic_cast<data::DMatrixProxy*>(p_fmat.get())->SetCUDAArray(interface_str.c_str());
TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 1);
EXPECT_THROW(TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0), dmlc::Error);
}
TEST(GpuPredictor, LesserFeatures) {


@ -2,19 +2,20 @@
* Copyright 2020-2021 by Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/predictor.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/generic_parameters.h>
#include "test_predictor.h"
#include "../helpers.h"
#include "../../../src/data/adapter.h"
#include "../../../src/common/io.h"
#include "../../../src/common/categorical.h"
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/predictor.h>
#include "../../../src/common/bitfield.h"
#include "../../../src/common/categorical.h"
#include "../../../src/common/io.h"
#include "../../../src/data/adapter.h"
#include "../../../src/data/proxy_dmatrix.h"
#include "../helpers.h"
namespace xgboost {
TEST(Predictor, PredictionCache) {
@ -83,9 +84,8 @@ void TestTrainingPrediction(size_t rows, size_t bins,
train("gpu_predictor", &predictions_1);
}
void TestInplacePrediction(dmlc::any x, std::string predictor,
bst_row_t rows, bst_feature_t cols,
int32_t device) {
void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bst_row_t rows,
bst_feature_t cols, int32_t device) {
size_t constexpr kClasses { 4 };
auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(device);
std::shared_ptr<DMatrix> m = gen.GenerateDMatrix(true, false, kClasses);
@ -105,24 +105,21 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
}
HostDeviceVector<float> *p_out_predictions_0{nullptr};
learner->InplacePredict(x, nullptr, PredictionType::kMargin,
std::numeric_limits<float>::quiet_NaN(),
learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits<float>::quiet_NaN(),
&p_out_predictions_0, 0, 2);
CHECK(p_out_predictions_0);
HostDeviceVector<float> predict_0 (p_out_predictions_0->Size());
predict_0.Copy(*p_out_predictions_0);
HostDeviceVector<float> *p_out_predictions_1{nullptr};
learner->InplacePredict(x, nullptr, PredictionType::kMargin,
std::numeric_limits<float>::quiet_NaN(),
learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits<float>::quiet_NaN(),
&p_out_predictions_1, 2, 4);
CHECK(p_out_predictions_1);
HostDeviceVector<float> predict_1 (p_out_predictions_1->Size());
predict_1.Copy(*p_out_predictions_1);
HostDeviceVector<float>* p_out_predictions{nullptr};
learner->InplacePredict(x, nullptr, PredictionType::kMargin,
std::numeric_limits<float>::quiet_NaN(),
learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits<float>::quiet_NaN(),
&p_out_predictions, 0, 4);
auto& h_pred = p_out_predictions->HostVector();
@ -378,25 +375,28 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
learner->SetParam("predictor", predictor);
learner->Predict(Xy, false, &sparse_predt, 0, 0);
std::vector<float> with_nan(kRows * kCols, std::numeric_limits<float>::quiet_NaN());
for (auto const& page : Xy->GetBatches<SparsePage>()) {
HostDeviceVector<float> with_nan(kRows * kCols, std::numeric_limits<float>::quiet_NaN());
auto& h_with_nan = with_nan.HostVector();
for (auto const &page : Xy->GetBatches<SparsePage>()) {
auto batch = page.GetView();
for (size_t i = 0; i < batch.Size(); ++i) {
auto row = batch[i];
for (auto e : row) {
with_nan[i * kCols + e.index] = e.fvalue;
h_with_nan[i * kCols + e.index] = e.fvalue;
}
}
}
learner->SetParam("predictor", "cpu_predictor");
// Xcode_12.4 doesn't compile with `std::make_shared`.
auto dense = std::shared_ptr<data::DenseAdapter>(
new data::DenseAdapter(with_nan.data(), kRows, kCols));
auto dense = std::shared_ptr<DMatrix>(new data::DMatrixProxy{});
auto array_interface = GetArrayInterface(&with_nan, kRows, kCols);
std::string arr_str;
Json::Dump(array_interface, &arr_str);
dynamic_cast<data::DMatrixProxy *>(dense.get())->SetArrayData(arr_str.data());
HostDeviceVector<float> *p_dense_predt;
learner->InplacePredict(dmlc::any(dense), nullptr, PredictionType::kValue,
std::numeric_limits<float>::quiet_NaN(), &p_dense_predt,
0, 0);
learner->InplacePredict(dense, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
&p_dense_predt, 0, 0);
auto const& dense_predt = *p_dense_predt;
if (predictor == "cpu_predictor") {


@ -61,9 +61,8 @@ void TestTrainingPrediction(size_t rows, size_t bins, std::string tree_method,
std::shared_ptr<DMatrix> p_full,
std::shared_ptr<DMatrix> p_hist);
void TestInplacePrediction(dmlc::any x, std::string predictor,
bst_row_t rows, bst_feature_t cols,
int32_t device = -1);
void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bst_row_t rows,
bst_feature_t cols, int32_t device = -1);
void TestPredictionWithLesserFeatures(std::string preditor_name);