Simplify inplace-predict. (#7910)

Pass the `X` as part of Proxy DMatrix instead of an independent `dmlc::any`.
This commit is contained in:
Jiaming Yuan
2022-05-18 17:52:00 +08:00
committed by GitHub
parent 19775ffe15
commit 765097d514
17 changed files with 317 additions and 297 deletions

View File

@@ -1,28 +1,29 @@
/*!
* Copyright 2017-2021 by Contributors
*/
#include <GPUTreeShap/gpu_treeshap.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <thrust/host_vector.h>
#include <GPUTreeShap/gpu_treeshap.h>
#include <memory>
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/common.h"
#include "../common/device_helpers.cuh"
#include "../data/device_adapter.cuh"
#include "../data/ellpack_page.cuh"
#include "../data/proxy_dmatrix.h"
#include "../gbm/gbtree_model.h"
#include "predict_fn.h"
#include "xgboost/data.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/predictor.h"
#include "xgboost/tree_model.h"
#include "xgboost/tree_updater.h"
#include "xgboost/host_device_vector.h"
#include "predict_fn.h"
#include "../gbm/gbtree_model.h"
#include "../data/ellpack_page.cuh"
#include "../data/device_adapter.cuh"
#include "../common/common.h"
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace predictor {
@@ -789,17 +790,19 @@ class GPUPredictor : public xgboost::Predictor {
m->NumRows(), entry_start, use_shared, output_groups, missing);
}
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds, uint32_t tree_begin,
bool InplacePredict(std::shared_ptr<DMatrix> p_m, const gbm::GBTreeModel& model, float missing,
PredictionCacheEntry* out_preds, uint32_t tree_begin,
unsigned tree_end) const override {
auto proxy = dynamic_cast<data::DMatrixProxy*>(p_m.get());
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
auto x = proxy->Adapter();
if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
this->DispatchedInplacePredict<
data::CupyAdapter, DeviceAdapterLoader<data::CupyAdapterBatch>>(
this->DispatchedInplacePredict<data::CupyAdapter,
DeviceAdapterLoader<data::CupyAdapterBatch>>(
x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(std::shared_ptr<data::CudfAdapter>)) {
this->DispatchedInplacePredict<
data::CudfAdapter, DeviceAdapterLoader<data::CudfAdapterBatch>>(
this->DispatchedInplacePredict<data::CudfAdapter,
DeviceAdapterLoader<data::CudfAdapterBatch>>(
x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else {
return false;