diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index cce92d367..a731bfac8 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -111,15 +111,14 @@ class GradientBooster : public Model, public Configurable { /*! * \brief Inplace prediction. * - * \param x A type erased data adapter. + * \param p_fmat A proxy DMatrix that contains the data and related + * meta info. * \param missing Missing value in the data. * \param [in,out] out_preds The output preds. * \param layer_begin (Optional) Beginning of boosted tree layer used for prediction. * \param layer_end (Optional) End of booster layer. 0 means do not limit trees. */ - virtual void InplacePredict(dmlc::any const &, std::shared_ptr, float, - PredictionCacheEntry*, - uint32_t, + virtual void InplacePredict(std::shared_ptr, float, PredictionCacheEntry*, uint32_t, uint32_t) const { LOG(FATAL) << "Inplace predict is not supported by current booster."; } diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 80004e6a8..b16ea67ec 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -139,21 +139,16 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { /*! * \brief Inplace prediction. * - * \param x A type erased data adapter. - * \param p_m An optional Proxy DMatrix object storing meta info like - * base margin. Can be nullptr. + * \param p_fmat A proxy DMatrix that contains the data and related meta info. * \param type Prediction type. * \param missing Missing value in the data. * \param [in,out] out_preds Pointer to output prediction vector. * \param layer_begin Beginning of boosted tree layer used for prediction. * \param layer_end End of booster layer. 0 means do not limit trees. */ - virtual void InplacePredict(dmlc::any const &x, - std::shared_ptr p_m, - PredictionType type, - float missing, - HostDeviceVector **out_preds, - uint32_t layer_begin, uint32_t layer_end) = 0; + virtual void InplacePredict(std::shared_ptr p_m, PredictionType type, float missing, + HostDeviceVector** out_preds, uint32_t layer_begin, + uint32_t layer_end) = 0; /*! * \brief Calculate feature score. See doc in C API for outputs. diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 506392261..33c695bc1 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -145,7 +145,9 @@ class Predictor { /** * \brief Inplace prediction. - * \param x Type erased data adapter. + * + * \param p_fmat A proxy DMatrix that contains the data and related + * meta info. * \param model The model to predict from. * \param missing Missing value in the data. * \param [in,out] out_preds The output preds. @@ -154,11 +156,9 @@ class Predictor { * * \return True if the data can be handled by current predictor, false otherwise. */ - virtual bool InplacePredict(dmlc::any const &x, std::shared_ptr p_m, - const gbm::GBTreeModel &model, float missing, - PredictionCacheEntry *out_preds, - uint32_t tree_begin = 0, - uint32_t tree_end = 0) const = 0; + virtual bool InplacePredict(std::shared_ptr p_fmat, const gbm::GBTreeModel& model, + float missing, PredictionCacheEntry* out_preds, + uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0; /** * \brief online prediction function, predict score for one instance at a time * NOTE: use the batch prediction interface if possible, batch prediction is diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 3c7c53980..d72eb077b 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -300,7 +300,7 @@ XGProxyDMatrixSetDataCudaArrayInterface(DMatrixHandle handle, CHECK(p_m); auto m = static_cast(p_m->get()); CHECK(m) << "Current DMatrix type does not support set data."; - m->SetData(c_interface_str); + m->SetCUDAArray(c_interface_str); API_END(); } @@ -312,7 +312,7 @@ XGB_DLL int XGProxyDMatrixSetDataCudaColumnar(DMatrixHandle handle, CHECK(p_m); auto m = static_cast(p_m->get()); CHECK(m) << "Current DMatrix type does not support set data."; - m->SetData(c_interface_str); + m->SetCUDAArray(c_interface_str); API_END(); } @@ -825,74 +825,69 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, API_END(); } -template -void InplacePredictImpl(std::shared_ptr x, std::shared_ptr p_m, - char const *c_json_config, Learner *learner, - size_t n_rows, size_t n_cols, - xgboost::bst_ulong const **out_shape, - xgboost::bst_ulong *out_dim, const float **out_result) { +void InplacePredictImpl(std::shared_ptr p_m, char const *c_json_config, Learner *learner, + xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim, + const float **out_result) { auto config = Json::Load(StringView{c_json_config}); CHECK_EQ(get(config["cache_id"]), 0) << "Cache ID is not supported yet"; - HostDeviceVector* p_predt { nullptr }; + HostDeviceVector *p_predt{nullptr}; auto type = PredictionType(RequiredArg(config, "type", __func__)); float missing = GetMissing(config); - learner->InplacePredict(x, p_m, type, missing, &p_predt, + learner->InplacePredict(p_m, type, missing, &p_predt, RequiredArg(config, "iteration_begin", __func__), RequiredArg(config, "iteration_end", __func__)); CHECK(p_predt); auto &shape = learner->GetThreadLocal().prediction_shape; - auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows; + auto const &info = p_m->Info(); + auto n_samples = info.num_row_; + auto n_features = info.num_col_; + auto chunksize = n_samples == 0 ? 0 : p_predt->Size() / n_samples; bool strict_shape = RequiredArg(config, "strict_shape", __func__); - CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(), + CalcPredictShape(strict_shape, type, n_samples, n_features, chunksize, learner->Groups(), learner->BoostedRounds(), &shape, out_dim); *out_result = dmlc::BeginPtr(p_predt->HostVector()); *out_shape = dmlc::BeginPtr(shape); } -// A hidden API as cache id is not being supported yet. -XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, - char const *array_interface, - char const *c_json_config, - DMatrixHandle m, +XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *array_interface, + char const *c_json_config, DMatrixHandle m, xgboost::bst_ulong const **out_shape, - xgboost::bst_ulong *out_dim, - const float **out_result) { + xgboost::bst_ulong *out_dim, const float **out_result) { API_BEGIN(); CHECK_HANDLE(); - std::shared_ptr x{ - new xgboost::data::ArrayAdapter(StringView{array_interface})}; - std::shared_ptr p_m {nullptr}; - if (m) { + std::shared_ptr p_m{nullptr}; + if (!m) { + p_m.reset(new data::DMatrixProxy); + } else { p_m = *static_cast *>(m); } + auto proxy = dynamic_cast(p_m.get()); + CHECK(proxy) << "Invalid input type for inplace predict."; + proxy->SetArrayData(array_interface); auto *learner = static_cast(handle); - InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(), - x->NumColumns(), out_shape, out_dim, out_result); + InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result); API_END(); } -// A hidden API as cache id is not being supported yet. -XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, - char const *indices, char const *data, - xgboost::bst_ulong cols, +XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, char const *indices, + char const *data, xgboost::bst_ulong cols, char const *c_json_config, DMatrixHandle m, xgboost::bst_ulong const **out_shape, - xgboost::bst_ulong *out_dim, - const float **out_result) { + xgboost::bst_ulong *out_dim, const float **out_result) { API_BEGIN(); CHECK_HANDLE(); - std::shared_ptr x{ - new xgboost::data::CSRArrayAdapter{StringView{indptr}, - StringView{indices}, StringView{data}, - static_cast(cols)}}; - std::shared_ptr p_m {nullptr}; - if (m) { + std::shared_ptr p_m{nullptr}; + if (!m) { + p_m.reset(new data::DMatrixProxy); + } else { p_m = *static_cast *>(m); } + auto proxy = dynamic_cast(p_m.get()); + CHECK(proxy) << "Invalid input type for inplace predict."; + proxy->SetCSRData(indptr, indices, data, cols, true); auto *learner = static_cast(handle); - InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(), - x->NumColumns(), out_shape, out_dim, out_result); + InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result); API_END(); } diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index 80408ba46..c3b303fa4 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -1,10 +1,11 @@ -// Copyright (c) 2019-2021 by Contributors -#include "xgboost/data.h" -#include "xgboost/c_api.h" -#include "xgboost/learner.h" +// Copyright (c) 2019-2022 by Contributors +#include "../data/device_adapter.cuh" +#include "../data/proxy_dmatrix.h" #include "c_api_error.h" #include "c_api_utils.h" -#include "../data/device_adapter.cuh" +#include "xgboost/c_api.h" +#include "xgboost/data.h" +#include "xgboost/learner.h" namespace xgboost { @@ -85,62 +86,65 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, API_END(); } -template -int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs, - char const *c_json_config, - std::shared_ptr p_m, - xgboost::bst_ulong const **out_shape, - xgboost::bst_ulong *out_dim, const float **out_result) { +int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface, + char const *c_json_config, std::shared_ptr p_m, + xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim, + const float **out_result) { API_BEGIN(); CHECK_HANDLE(); + if (!p_m) { + p_m.reset(new data::DMatrixProxy); + } + auto proxy = dynamic_cast(p_m.get()); + CHECK(proxy) << "Invalid input type for inplace predict."; + proxy->SetCUDAArray(c_array_interface); + auto config = Json::Load(StringView{c_json_config}); - CHECK_EQ(get(config["cache_id"]), 0) - << "Cache ID is not supported yet"; + CHECK_EQ(get(config["cache_id"]), 0) << "Cache ID is not supported yet"; auto *learner = static_cast(handle); - std::string json_str{c_json_strs}; - auto x = std::make_shared(json_str); HostDeviceVector *p_predt{nullptr}; - auto type = PredictionType(get(config["type"])); + auto type = PredictionType(RequiredArg(config, "type", __func__)); float missing = GetMissing(config); - learner->InplacePredict(x, p_m, type, missing, &p_predt, - get(config["iteration_begin"]), - get(config["iteration_end"])); + learner->InplacePredict(p_m, type, missing, &p_predt, + RequiredArg(config, "iteration_begin", __func__), + RequiredArg(config, "iteration_end", __func__)); CHECK(p_predt); CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead()); auto &shape = learner->GetThreadLocal().prediction_shape; - auto chunksize = x->NumRows() == 0 ? 0 : p_predt->Size() / x->NumRows(); - bool strict_shape = get(config["strict_shape"]); - CalcPredictShape(strict_shape, type, x->NumRows(), x->NumColumns(), chunksize, - learner->Groups(), learner->BoostedRounds(), &shape, - out_dim); + size_t n_samples = p_m->Info().num_row_; + auto chunksize = n_samples == 0 ? 0 : p_predt->Size() / n_samples; + bool strict_shape = RequiredArg(config, "strict_shape", __func__); + CalcPredictShape(strict_shape, type, n_samples, p_m->Info().num_col_, chunksize, + learner->Groups(), learner->BoostedRounds(), &shape, out_dim); *out_shape = dmlc::BeginPtr(shape); *out_result = p_predt->ConstDevicePointer(); API_END(); } -XGB_DLL int XGBoosterPredictFromCudaColumnar( - BoosterHandle handle, char const *c_json_strs, char const *c_json_config, - DMatrixHandle m, xgboost::bst_ulong const **out_shape, - xgboost::bst_ulong *out_dim, const float **out_result) { - std::shared_ptr p_m {nullptr}; +XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *c_json_strs, + char const *c_json_config, DMatrixHandle m, + xgboost::bst_ulong const **out_shape, + xgboost::bst_ulong *out_dim, + const float **out_result) { + std::shared_ptr p_m{nullptr}; if (m) { p_m = *static_cast *>(m); } - return InplacePreidctCuda( - handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result); + return InplacePreidctCuda(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, + out_result); } -XGB_DLL int XGBoosterPredictFromCudaArray( - BoosterHandle handle, char const *c_json_strs, char const *c_json_config, - DMatrixHandle m, xgboost::bst_ulong const **out_shape, - xgboost::bst_ulong *out_dim, const float **out_result) { - std::shared_ptr p_m {nullptr}; +XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *c_json_strs, + char const *c_json_config, DMatrixHandle m, + xgboost::bst_ulong const **out_shape, + xgboost::bst_ulong *out_dim, const float **out_result) { + std::shared_ptr p_m{nullptr}; if (m) { p_m = *static_cast *>(m); } - return InplacePreidctCuda( - handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result); + return InplacePreidctCuda(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, + out_result); } diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index 8a6f67f14..8744bbf77 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -55,7 +55,7 @@ class DMatrixProxy : public DMatrix { public: int DeviceIdx() const { return ctx_.gpu_id; } - void SetData(char const* c_interface) { + void SetCUDAArray(char const* c_interface) { common::AssertGPUSupport(); #if defined(XGBOOST_USE_CUDA) std::string interface_str = c_interface; diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index bb7c341f8..8f8facc53 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -795,88 +795,75 @@ class Dart : public GBTree { this->PredictBatchImpl(p_fmat, p_out_preds, training, layer_begin, layer_end); } - void InplacePredict(dmlc::any const &x, std::shared_ptr p_m, - float missing, PredictionCacheEntry *out_preds, - uint32_t layer_begin, unsigned layer_end) const override { + void InplacePredict(std::shared_ptr p_fmat, float missing, + PredictionCacheEntry* p_out_preds, uint32_t layer_begin, + unsigned layer_end) const override { uint32_t tree_begin, tree_end; std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); - std::vector predictors{ + auto n_groups = model_.learner_model_param->num_output_group; + + std::vector predictors { cpu_predictor_.get(), #if defined(XGBOOST_USE_CUDA) gpu_predictor_.get() #endif // defined(XGBOOST_USE_CUDA) }; - Predictor const * predictor {nullptr}; - - MetaInfo info; + Predictor const* predictor{nullptr}; StringView msg{"Unsupported data type for inplace predict."}; - int32_t device = GenericParameter::kCpuId; + PredictionCacheEntry predts; - // Inplace predict is not used for training, so no need to drop tree. - for (size_t i = tree_begin; i < tree_end; ++i) { + if (ctx_->gpu_id != Context::kCpuId) { + predts.predictions.SetDevice(ctx_->gpu_id); + } + predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0); + + auto predict_impl = [&](size_t i) { + predts.predictions.Fill(0); if (tparam_.predictor == PredictorType::kAuto) { // Try both predictor implementations bool success = false; - for (auto const &p : predictors) { - if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i, - i + 1)) { + for (auto const& p : predictors) { + if (p && p->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)) { success = true; predictor = p; -#if defined(XGBOOST_USE_CUDA) - device = predts.predictions.DeviceIdx(); -#endif // defined(XGBOOST_USE_CUDA) break; } } CHECK(success) << msg; } else { - // No base margin from meta info for each tree predictor = this->GetPredictor().get(); - bool success = predictor->InplacePredict(x, nullptr, model_, missing, - &predts, i, i + 1); - device = predts.predictions.DeviceIdx(); + bool success = predictor->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1); CHECK(success) << msg << std::endl << "Current Predictor: " - << (tparam_.predictor == PredictorType::kCPUPredictor - ? "cpu_predictor" - : "gpu_predictor"); + << (tparam_.predictor == PredictorType::kCPUPredictor ? "cpu_predictor" + : "gpu_predictor"); } + }; - auto w = this->weight_drop_.at(i); - size_t n_groups = model_.learner_model_param->num_output_group; - auto n_rows = predts.predictions.Size() / n_groups; - + // Inplace predict is not used for training, so no need to drop tree. + for (size_t i = tree_begin; i < tree_end; ++i) { + predict_impl(i); if (i == tree_begin) { - // base margin is added here. - if (p_m) { - p_m->Info().num_row_ = n_rows; - predictor->InitOutPredictions(p_m->Info(), &out_preds->predictions, - model_); - } else { - info.num_row_ = n_rows; - predictor->InitOutPredictions(info, &out_preds->predictions, model_); - } + predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_); } - // Multiple the tree weight - CHECK_EQ(predts.predictions.Size(), out_preds->predictions.Size()); + auto w = this->weight_drop_.at(i); auto group = model_.tree_info.at(i); + CHECK_EQ(predts.predictions.Size(), p_out_preds->predictions.Size()); - if (device == GenericParameter::kCpuId) { - auto &h_predts = predts.predictions.HostVector(); - auto &h_out_predts = out_preds->predictions.HostVector(); + size_t n_rows = p_fmat->Info().num_row_; + if (predts.predictions.DeviceIdx() != Context::kCpuId) { + p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx()); + GPUDartInplacePredictInc(p_out_preds->predictions.DeviceSpan(), + predts.predictions.DeviceSpan(), w, n_rows, + model_.learner_model_param->base_score, n_groups, group); + } else { + auto& h_predts = predts.predictions.HostVector(); + auto& h_out_predts = p_out_preds->predictions.HostVector(); common::ParallelFor(n_rows, ctx_->Threads(), [&](auto ridx) { const size_t offset = ridx * n_groups + group; - // Need to remove the base margin from individual tree. h_out_predts[offset] += (h_predts[offset] - model_.learner_model_param->base_score) * w; }); - } else { - out_preds->predictions.SetDevice(device); - predts.predictions.SetDevice(device); - GPUDartInplacePredictInc(out_preds->predictions.DeviceSpan(), - predts.predictions.DeviceSpan(), w, n_rows, - model_.learner_model_param->base_score, - n_groups, group); } } } diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 020b7d0cb..0d2d025e5 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -261,8 +261,7 @@ class GBTree : public GradientBooster { void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *out_preds, bool training, unsigned layer_begin, unsigned layer_end) override; - void InplacePredict(dmlc::any const &x, std::shared_ptr p_m, - float missing, PredictionCacheEntry *out_preds, + void InplacePredict(std::shared_ptr p_m, float missing, PredictionCacheEntry* out_preds, uint32_t layer_begin, unsigned layer_end) const override { CHECK(configured_); uint32_t tree_begin, tree_end; @@ -278,15 +277,14 @@ class GBTree : public GradientBooster { if (tparam_.predictor == PredictorType::kAuto) { // Try both predictor implementations for (auto const &p : predictors) { - if (p && p->InplacePredict(x, p_m, model_, missing, out_preds, - tree_begin, tree_end)) { + if (p && p->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end)) { return; } } LOG(FATAL) << msg; } else { - bool success = this->GetPredictor()->InplacePredict( - x, p_m, model_, missing, out_preds, tree_begin, tree_end); + bool success = this->GetPredictor()->InplacePredict(p_m, model_, missing, out_preds, + tree_begin, tree_end); CHECK(success) << msg << std::endl << "Current Predictor: " << (tparam_.predictor == PredictorType::kCPUPredictor diff --git a/src/learner.cc b/src/learner.cc index 568cfc680..5d7d067e7 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1277,15 +1277,12 @@ class LearnerImpl : public LearnerIO { return (*LearnerAPIThreadLocalStore::Get())[this]; } - void InplacePredict(dmlc::any const &x, std::shared_ptr p_m, - PredictionType type, float missing, - HostDeviceVector **out_preds, - uint32_t iteration_begin, + void InplacePredict(std::shared_ptr p_m, PredictionType type, float missing, + HostDeviceVector** out_preds, uint32_t iteration_begin, uint32_t iteration_end) override { this->Configure(); auto& out_predictions = this->GetThreadLocal().prediction_entry; - this->gbm_->InplacePredict(x, p_m, missing, &out_predictions, - iteration_begin, iteration_end); + this->gbm_->InplacePredict(p_m, missing, &out_predictions, iteration_begin, iteration_end); if (type == PredictionType::kValue) { obj_->PredTransform(&out_predictions.predictions); } else if (type == PredictionType::kMargin) { diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 892c95631..b5dd9b4af 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -1,27 +1,27 @@ /*! * Copyright by Contributors 2017-2021 */ -#include #include +#include #include #include #include +#include "../common/categorical.h" +#include "../common/math.h" +#include "../common/threading_utils.h" +#include "../data/adapter.h" +#include "../data/proxy_dmatrix.h" +#include "../gbm/gbtree_model.h" +#include "predict_fn.h" #include "xgboost/base.h" #include "xgboost/data.h" +#include "xgboost/host_device_vector.h" +#include "xgboost/logging.h" #include "xgboost/predictor.h" #include "xgboost/tree_model.h" #include "xgboost/tree_updater.h" -#include "xgboost/logging.h" -#include "xgboost/host_device_vector.h" - -#include "predict_fn.h" -#include "../data/adapter.h" -#include "../common/math.h" -#include "../common/threading_utils.h" -#include "../common/categorical.h" -#include "../gbm/gbtree_model.h" namespace xgboost { namespace predictor { @@ -327,22 +327,24 @@ class CPUPredictor : public Predictor { &predictions, model, tree_begin, tree_end, &thread_temp, n_threads); } - bool InplacePredict(dmlc::any const &x, std::shared_ptr p_m, - const gbm::GBTreeModel &model, float missing, + bool InplacePredict(std::shared_ptr p_m, const gbm::GBTreeModel &model, float missing, PredictionCacheEntry *out_preds, uint32_t tree_begin, unsigned tree_end) const override { + auto proxy = dynamic_cast(p_m.get()); + CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input."; + auto x = proxy->Adapter(); if (x.type() == typeid(std::shared_ptr)) { this->DispatchedInplacePredict( x, p_m, model, missing, out_preds, tree_begin, tree_end); } else if (x.type() == typeid(std::shared_ptr)) { - this->DispatchedInplacePredict( - x, p_m, model, missing, out_preds, tree_begin, tree_end); + this->DispatchedInplacePredict(x, p_m, model, missing, out_preds, + tree_begin, tree_end); } else if (x.type() == typeid(std::shared_ptr)) { - this->DispatchedInplacePredict ( + this->DispatchedInplacePredict( x, p_m, model, missing, out_preds, tree_begin, tree_end); } else if (x.type() == typeid(std::shared_ptr)) { - this->DispatchedInplacePredict ( - x, p_m, model, missing, out_preds, tree_begin, tree_end); + this->DispatchedInplacePredict(x, p_m, model, missing, out_preds, + tree_begin, tree_end); } else { return false; } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 0a09dc255..d20918cf2 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -1,28 +1,29 @@ /*! * Copyright 2017-2021 by Contributors */ +#include #include #include #include #include #include -#include + #include +#include "../common/bitfield.h" +#include "../common/categorical.h" +#include "../common/common.h" +#include "../common/device_helpers.cuh" +#include "../data/device_adapter.cuh" +#include "../data/ellpack_page.cuh" +#include "../data/proxy_dmatrix.h" +#include "../gbm/gbtree_model.h" +#include "predict_fn.h" #include "xgboost/data.h" +#include "xgboost/host_device_vector.h" #include "xgboost/predictor.h" #include "xgboost/tree_model.h" #include "xgboost/tree_updater.h" -#include "xgboost/host_device_vector.h" - -#include "predict_fn.h" -#include "../gbm/gbtree_model.h" -#include "../data/ellpack_page.cuh" -#include "../data/device_adapter.cuh" -#include "../common/common.h" -#include "../common/bitfield.h" -#include "../common/categorical.h" -#include "../common/device_helpers.cuh" namespace xgboost { namespace predictor { @@ -789,17 +790,19 @@ class GPUPredictor : public xgboost::Predictor { m->NumRows(), entry_start, use_shared, output_groups, missing); } - bool InplacePredict(dmlc::any const &x, std::shared_ptr p_m, - const gbm::GBTreeModel &model, float missing, - PredictionCacheEntry *out_preds, uint32_t tree_begin, + bool InplacePredict(std::shared_ptr p_m, const gbm::GBTreeModel& model, float missing, + PredictionCacheEntry* out_preds, uint32_t tree_begin, unsigned tree_end) const override { + auto proxy = dynamic_cast(p_m.get()); + CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input."; + auto x = proxy->Adapter(); if (x.type() == typeid(std::shared_ptr)) { - this->DispatchedInplacePredict< - data::CupyAdapter, DeviceAdapterLoader>( + this->DispatchedInplacePredict>( x, p_m, model, missing, out_preds, tree_begin, tree_end); } else if (x.type() == typeid(std::shared_ptr)) { - this->DispatchedInplacePredict< - data::CudfAdapter, DeviceAdapterLoader>( + this->DispatchedInplacePredict>( x, p_m, model, missing, out_preds, tree_begin, tree_end); } else { return false; diff --git a/tests/cpp/data/test_proxy_dmatrix.cu b/tests/cpp/data/test_proxy_dmatrix.cu index d9f315a8f..a599ada6d 100644 --- a/tests/cpp/data/test_proxy_dmatrix.cu +++ b/tests/cpp/data/test_proxy_dmatrix.cu @@ -19,7 +19,7 @@ TEST(ProxyDMatrix, DeviceData) { .GenerateColumnarArrayInterface(&label_storage); DMatrixProxy proxy; - proxy.SetData(data.c_str()); + proxy.SetCUDAArray(data.c_str()); proxy.SetInfo("label", labels.c_str()); ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr)); @@ -34,7 +34,7 @@ TEST(ProxyDMatrix, DeviceData) { data = RandomDataGenerator(kRows, kCols, 0) .Device(0) .GenerateColumnarArrayInterface(&columnar_storage); - proxy.SetData(data.c_str()); + proxy.SetCUDAArray(data.c_str()); ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr)); ASSERT_EQ(dmlc::get>(proxy.Adapter())->NumRows(), kRows); diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index f9fe7d386..00201769b 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -1,16 +1,17 @@ /*! * Copyright 2019-2022 XGBoost contributors */ -#include #include +#include #include +#include "../../../src/data/adapter.h" +#include "../../../src/data/proxy_dmatrix.h" +#include "../../../src/gbm/gbtree.h" +#include "../helpers.h" #include "xgboost/base.h" #include "xgboost/host_device_vector.h" #include "xgboost/learner.h" -#include "../helpers.h" -#include "../../../src/gbm/gbtree.h" -#include "../../../src/data/adapter.h" #include "xgboost/predictor.h" namespace xgboost { @@ -246,53 +247,78 @@ TEST(Dart, JsonIO) { ASSERT_NE(get(model["model"]["weight_drop"]).size(), 0ul); } -TEST(Dart, Prediction) { - size_t constexpr kRows = 16, kCols = 10; +namespace { +class Dart : public testing::TestWithParam { + public: + void Run(std::string predictor) { + size_t constexpr kRows = 16, kCols = 10; - HostDeviceVector data; - auto array_str = RandomDataGenerator(kRows, kCols, 0).GenerateArrayInterface(&data); - auto p_mat = GetDMatrixFromData(data.HostVector(), kRows, kCols); + HostDeviceVector data; + auto rng = RandomDataGenerator(kRows, kCols, 0); + if (predictor == "gpu_predictor") { + rng.Device(0); + } + auto array_str = rng.GenerateArrayInterface(&data); + auto p_mat = GetDMatrixFromData(data.HostVector(), kRows, kCols); - std::vector labels (kRows); - for (size_t i = 0; i < kRows; ++i) { - labels[i] = i % 2; + std::vector labels(kRows); + for (size_t i = 0; i < kRows; ++i) { + labels[i] = i % 2; + } + p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kRows); + + auto learner = std::unique_ptr(Learner::Create({p_mat})); + learner->SetParam("booster", "dart"); + learner->SetParam("rate_drop", "0.5"); + learner->Configure(); + + for (size_t i = 0; i < 16; ++i) { + learner->UpdateOneIter(i, p_mat); + } + + learner->SetParam("predictor", predictor); + + HostDeviceVector predts_training; + learner->Predict(p_mat, false, &predts_training, 0, 0, true); + + HostDeviceVector* inplace_predts; + std::shared_ptr x{new data::DMatrixProxy{}}; + if (predictor == "gpu_predictor") { + x->SetCUDAArray(array_str.c_str()); + } else { + x->SetArrayData(array_str.c_str()); + } + learner->InplacePredict(x, PredictionType::kValue, std::numeric_limits::quiet_NaN(), + &inplace_predts, 0, 0); + CHECK(inplace_predts); + + HostDeviceVector predts_inference; + learner->Predict(p_mat, false, &predts_inference, 0, 0, false); + + auto const& h_predts_training = predts_training.ConstHostVector(); + auto const& h_predts_inference = predts_inference.ConstHostVector(); + auto const& h_inplace_predts = inplace_predts->HostVector(); + ASSERT_EQ(h_predts_training.size(), h_predts_inference.size()); + ASSERT_EQ(h_inplace_predts.size(), h_predts_inference.size()); + for (size_t i = 0; i < predts_inference.Size(); ++i) { + // Inference doesn't drop tree. + ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps * 10); + // Inplace prediction is inference. + ASSERT_LT(h_inplace_predts[i] - h_predts_inference[i], kRtEps / 10); + } } - p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kRows); +}; +} // anonymous namespace - auto learner = std::unique_ptr(Learner::Create({p_mat})); - learner->SetParam("booster", "dart"); - learner->SetParam("rate_drop", "0.5"); - learner->Configure(); +TEST_P(Dart, Prediction) { this->Run(GetParam()); } - for (size_t i = 0; i < 16; ++i) { - learner->UpdateOneIter(i, p_mat); - } +#if defined(XGBOOST_USE_CUDA) +INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, + testing::Values("auto", "cpu_predictor", "gpu_predictor")); +#else +INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("auto", "cpu_predictor")); +#endif // defined(XGBOOST_USE_CUDA) - HostDeviceVector predts_training; - learner->Predict(p_mat, false, &predts_training, 0, 0, true); - - HostDeviceVector* inplace_predts; - auto adapter = std::shared_ptr(new data::ArrayAdapter{StringView{array_str}}); - learner->InplacePredict(adapter, nullptr, PredictionType::kValue, - std::numeric_limits::quiet_NaN(), - &inplace_predts, 0, 0); - CHECK(inplace_predts); - - HostDeviceVector predts_inference; - learner->Predict(p_mat, false, &predts_inference, 0, 0, false); - - auto const& h_predts_training = predts_training.ConstHostVector(); - auto const& h_predts_inference = predts_inference.ConstHostVector(); - auto const& h_inplace_predts = inplace_predts->HostVector(); - ASSERT_EQ(h_predts_training.size(), h_predts_inference.size()); - ASSERT_EQ(h_inplace_predts.size(), h_predts_inference.size()); - for (size_t i = 0; i < predts_inference.Size(); ++i) { - // Inference doesn't drop tree. - ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps * 10); - // Inplace prediction is inference. - ASSERT_LT(h_inplace_predts[i] - h_predts_inference[i], kRtEps / 10); - } -} std::pair TestModelSlice(std::string booster) { size_t constexpr kRows = 1000, kCols = 100, kForest = 2, kClasses = 3; @@ -485,19 +511,20 @@ TEST(GBTree, PredictRange) { // inplace predict HostDeviceVector raw_storage; auto raw = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateArrayInterface(&raw_storage); - std::shared_ptr x{new data::ArrayAdapter{StringView{raw}}}; + std::shared_ptr x{new data::DMatrixProxy{}}; + x->SetArrayData(raw.data()); HostDeviceVector* out_predt; - learner->InplacePredict(x, nullptr, PredictionType::kValue, - std::numeric_limits::quiet_NaN(), &out_predt, 0, 2); + learner->InplacePredict(x, PredictionType::kValue, std::numeric_limits::quiet_NaN(), + &out_predt, 0, 2); auto h_out_predt = out_predt->HostVector(); - learner->InplacePredict(x, nullptr, PredictionType::kValue, - std::numeric_limits::quiet_NaN(), &out_predt, 0, 0); + learner->InplacePredict(x, PredictionType::kValue, std::numeric_limits::quiet_NaN(), + &out_predt, 0, 0); auto h_out_predt_full = out_predt->HostVector(); ASSERT_TRUE(std::equal(h_out_predt.begin(), h_out_predt.end(), h_out_predt_full.begin())); - ASSERT_THROW(learner->InplacePredict(x, nullptr, PredictionType::kValue, + ASSERT_THROW(learner->InplacePredict(x, PredictionType::kValue, std::numeric_limits::quiet_NaN(), &out_predt, 0, 3), dmlc::Error); } diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index f43747abd..5b03f31d8 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -5,11 +5,12 @@ #include #include +#include "../../../src/data/adapter.h" +#include "../../../src/data/proxy_dmatrix.h" +#include "../../../src/gbm/gbtree.h" +#include "../../../src/gbm/gbtree_model.h" #include "../helpers.h" #include "test_predictor.h" -#include "../../../src/gbm/gbtree_model.h" -#include "../../../src/gbm/gbtree.h" -#include "../../../src/data/adapter.h" namespace xgboost { TEST(CpuPredictor, Basic) { @@ -172,8 +173,11 @@ TEST(CpuPredictor, InplacePredict) { HostDeviceVector data; gen.GenerateDense(&data); ASSERT_EQ(data.Size(), kRows * kCols); - std::shared_ptr x{ - new data::DenseAdapter(data.HostPointer(), kRows, kCols)}; + std::shared_ptr x{new data::DMatrixProxy{}}; + auto array_interface = GetArrayInterface(&data, kRows, kCols); + std::string arr_str; + Json::Dump(array_interface, &arr_str); + x->SetArrayData(arr_str.data()); TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1); } @@ -182,9 +186,15 @@ TEST(CpuPredictor, InplacePredict) { HostDeviceVector rptrs; HostDeviceVector columns; gen.GenerateCSR(&data, &rptrs, &columns); - std::shared_ptr x{new data::CSRAdapter( - rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), kRows, - data.Size(), kCols)}; + auto data_interface = GetArrayInterface(&data, kRows * kCols, 1); + auto rptr_interface = GetArrayInterface(&rptrs, kRows + 1, 1); + auto col_interface = GetArrayInterface(&columns, kRows * kCols, 1); + std::string data_str, rptr_str, col_str; + Json::Dump(data_interface, &data_str); + Json::Dump(rptr_interface, &rptr_str); + Json::Dump(col_interface, &col_str); + std::shared_ptr x{new data::DMatrixProxy}; + x->SetCSRData(rptr_str.data(), col_str.data(), data_str.data(), kCols, true); TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1); } } diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 3113bc62b..0dbbc8d45 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -1,17 +1,19 @@ /*! * Copyright 2017-2020 XGBoost contributors */ -#include #include +#include #include -#include -#include #include +#include +#include + #include -#include "../helpers.h" -#include "../../../src/gbm/gbtree_model.h" #include "../../../src/data/device_adapter.cuh" +#include "../../../src/data/proxy_dmatrix.h" +#include "../../../src/gbm/gbtree_model.h" +#include "../helpers.h" #include "test_predictor.h" namespace xgboost { @@ -135,8 +137,9 @@ TEST(GPUPredictor, InplacePredictCupy) { gen.Device(0); HostDeviceVector data; std::string interface_str = gen.GenerateArrayInterface(&data); - auto x = std::make_shared(interface_str); - TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0); + std::shared_ptr p_fmat{new data::DMatrixProxy}; + dynamic_cast(p_fmat.get())->SetCUDAArray(interface_str.c_str()); + TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0); } TEST(GPUPredictor, InplacePredictCuDF) { @@ -145,8 +148,9 @@ TEST(GPUPredictor, InplacePredictCuDF) { gen.Device(0); std::vector> storage(kCols); auto interface_str = gen.GenerateColumnarArrayInterface(&storage); - auto x = std::make_shared(interface_str); - TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0); + std::shared_ptr p_fmat{new data::DMatrixProxy}; + dynamic_cast(p_fmat.get())->SetCUDAArray(interface_str.c_str()); + TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0); } TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT @@ -160,10 +164,10 @@ TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT gen.Device(1); HostDeviceVector data; std::string interface_str = gen.GenerateArrayInterface(&data); - auto x = std::make_shared(interface_str); - TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 1); - EXPECT_THROW(TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0), - dmlc::Error); + std::shared_ptr p_fmat{new data::DMatrixProxy}; + dynamic_cast(p_fmat.get())->SetCUDAArray(interface_str.c_str()); + TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 1); + EXPECT_THROW(TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0), dmlc::Error); } TEST(GpuPredictor, LesserFeatures) { diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index e1d8b096a..832d2cf4c 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -2,19 +2,20 @@ * Copyright 2020-2021 by Contributors */ -#include -#include -#include -#include -#include - #include "test_predictor.h" -#include "../helpers.h" -#include "../../../src/data/adapter.h" -#include "../../../src/common/io.h" -#include "../../../src/common/categorical.h" +#include +#include +#include +#include +#include + #include "../../../src/common/bitfield.h" +#include "../../../src/common/categorical.h" +#include "../../../src/common/io.h" +#include "../../../src/data/adapter.h" +#include "../../../src/data/proxy_dmatrix.h" +#include "../helpers.h" namespace xgboost { TEST(Predictor, PredictionCache) { @@ -83,9 +84,8 @@ void TestTrainingPrediction(size_t rows, size_t bins, train("gpu_predictor", &predictions_1); } -void TestInplacePrediction(dmlc::any x, std::string predictor, - bst_row_t rows, bst_feature_t cols, - int32_t device) { +void TestInplacePrediction(std::shared_ptr x, std::string predictor, bst_row_t rows, + bst_feature_t cols, int32_t device) { size_t constexpr kClasses { 4 }; auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(device); std::shared_ptr m = gen.GenerateDMatrix(true, false, kClasses); @@ -105,24 +105,21 @@ void TestInplacePrediction(dmlc::any x, std::string predictor, } HostDeviceVector *p_out_predictions_0{nullptr}; - learner->InplacePredict(x, nullptr, PredictionType::kMargin, - std::numeric_limits::quiet_NaN(), + learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits::quiet_NaN(), &p_out_predictions_0, 0, 2); CHECK(p_out_predictions_0); HostDeviceVector predict_0 (p_out_predictions_0->Size()); predict_0.Copy(*p_out_predictions_0); HostDeviceVector *p_out_predictions_1{nullptr}; - learner->InplacePredict(x, nullptr, PredictionType::kMargin, - std::numeric_limits::quiet_NaN(), + learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits::quiet_NaN(), &p_out_predictions_1, 2, 4); CHECK(p_out_predictions_1); HostDeviceVector predict_1 (p_out_predictions_1->Size()); predict_1.Copy(*p_out_predictions_1); HostDeviceVector* p_out_predictions{nullptr}; - learner->InplacePredict(x, nullptr, PredictionType::kMargin, - std::numeric_limits::quiet_NaN(), + learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits::quiet_NaN(), &p_out_predictions, 0, 4); auto& h_pred = p_out_predictions->HostVector(); @@ -378,25 +375,28 @@ void TestSparsePrediction(float sparsity, std::string predictor) { learner->SetParam("predictor", predictor); learner->Predict(Xy, false, &sparse_predt, 0, 0); - std::vector with_nan(kRows * kCols, std::numeric_limits::quiet_NaN()); - for (auto const& page : Xy->GetBatches()) { + HostDeviceVector with_nan(kRows * kCols, std::numeric_limits::quiet_NaN()); + auto& h_with_nan = with_nan.HostVector(); + for (auto const &page : Xy->GetBatches()) { auto batch = page.GetView(); for (size_t i = 0; i < batch.Size(); ++i) { auto row = batch[i]; for (auto e : row) { - with_nan[i * kCols + e.index] = e.fvalue; + h_with_nan[i * kCols + e.index] = e.fvalue; } } } learner->SetParam("predictor", "cpu_predictor"); // Xcode_12.4 doesn't compile with `std::make_shared`. - auto dense = std::shared_ptr( - new data::DenseAdapter(with_nan.data(), kRows, kCols)); + auto dense = std::shared_ptr(new data::DMatrixProxy{}); + auto array_interface = GetArrayInterface(&with_nan, kRows, kCols); + std::string arr_str; + Json::Dump(array_interface, &arr_str); + dynamic_cast(dense.get())->SetArrayData(arr_str.data()); HostDeviceVector *p_dense_predt; - learner->InplacePredict(dmlc::any(dense), nullptr, PredictionType::kValue, - std::numeric_limits::quiet_NaN(), &p_dense_predt, - 0, 0); + learner->InplacePredict(dense, PredictionType::kValue, std::numeric_limits::quiet_NaN(), + &p_dense_predt, 0, 0); auto const& dense_predt = *p_dense_predt; if (predictor == "cpu_predictor") { diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index 9c5d99afe..1ff96096c 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -61,9 +61,8 @@ void TestTrainingPrediction(size_t rows, size_t bins, std::string tree_method, std::shared_ptr p_full, std::shared_ptr p_hist); -void TestInplacePrediction(dmlc::any x, std::string predictor, - bst_row_t rows, bst_feature_t cols, - int32_t device = -1); +void TestInplacePrediction(std::shared_ptr x, std::string predictor, bst_row_t rows, + bst_feature_t cols, int32_t device = -1); void TestPredictionWithLesserFeatures(std::string preditor_name);