Enhance inplace prediction. (#6653)
* Accept the array interface for CSR and dense array inputs. * Accept an optional proxy DMatrix for meta info. This constructs an explicit `_ProxyDMatrix` type in Python. * Remove unused doc. * Add strict output shape.
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
#include "xgboost/global_config.h"
|
||||
|
||||
#include "c_api_error.h"
|
||||
#include "c_api_utils.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/charconv.h"
|
||||
#include "../data/adapter.h"
|
||||
@@ -617,90 +618,92 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
|
||||
char const *c_json_config, Learner *learner,
|
||||
size_t n_rows, size_t n_cols,
|
||||
xgboost::bst_ulong const **out_shape,
|
||||
xgboost::bst_ulong *out_dim, const float **out_result) {
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
|
||||
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
auto type = PredictionType(get<Integer const>(config["type"]));
|
||||
learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
|
||||
&p_predt,
|
||||
get<Integer const>(config["iteration_begin"]),
|
||||
get<Integer const>(config["iteration_end"]));
|
||||
CHECK(p_predt);
|
||||
auto &shape = learner->GetThreadLocal().prediction_shape;
|
||||
auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows;
|
||||
bool strict_shape = get<Boolean const>(config["strict_shape"]);
|
||||
CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(),
|
||||
learner->BoostedRounds(), &shape, out_dim);
|
||||
*out_result = dmlc::BeginPtr(p_predt->HostVector());
|
||||
*out_shape = dmlc::BeginPtr(shape);
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
// In-place prediction from a dense array described by an array-interface JSON string.
//
// handle          : booster handle; cast to xgboost::Learner.
// array_interface : array-interface JSON string describing the dense input.
// c_json_config   : JSON prediction configuration, parsed by InplacePredictImpl.
// m               : optional DMatrix handle (may be null); when non-null it is
//                   dereferenced as a std::shared_ptr<DMatrix> and forwarded.
// out_shape, out_dim, out_result : outputs filled by InplacePredictImpl.
//
// The stale pre-change parameter list (values/n_rows/n_cols/missing/...) and the
// body that used it (DenseAdapter, cache_id check, out_len) are dropped: they
// referenced names that are not parameters of this signature and redefined `x`.
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle,
                                      char const *array_interface,
                                      char const *c_json_config,
                                      DMatrixHandle m,
                                      xgboost::bst_ulong const **out_shape,
                                      xgboost::bst_ulong *out_dim,
                                      const float **out_result) {
  API_BEGIN();
  CHECK_HANDLE();
  auto x = std::make_shared<xgboost::data::ArrayAdapter>(StringView{array_interface});
  std::shared_ptr<DMatrix> p_m{nullptr};
  if (m) {
    p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
  }
  auto *learner = static_cast<xgboost::Learner *>(handle);
  InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(),
                     x->NumColumns(), out_shape, out_dim, out_result);
  API_END();
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
// In-place prediction from CSR data whose three arrays are each described by an
// array-interface JSON string.
//
// handle        : booster handle; cast to xgboost::Learner.
// indptr        : array-interface JSON string for the row-pointer array.
// indices       : array-interface JSON string for the column-index array.
// data          : array-interface JSON string for the value array.
// cols          : number of columns, passed to the CSR adapter.
// c_json_config : JSON prediction configuration, parsed by InplacePredictImpl.
// m             : optional DMatrix handle (may be null); when non-null it is
//                 dereferenced as a std::shared_ptr<DMatrix> and forwarded.
// out_shape, out_dim, out_result : outputs filled by InplacePredictImpl.
//
// The stale pre-change parameter list (raw pointers, nindptr/nelem/num_col, ...)
// and the body that used it (CSRAdapter, cache_id check, out_len) are dropped:
// they referenced names that are not parameters of this signature and redefined `x`.
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
                                    char const *indices, char const *data,
                                    xgboost::bst_ulong cols,
                                    char const *c_json_config, DMatrixHandle m,
                                    xgboost::bst_ulong const **out_shape,
                                    xgboost::bst_ulong *out_dim,
                                    const float **out_result) {
  API_BEGIN();
  CHECK_HANDLE();
  auto x = std::make_shared<xgboost::data::CSRArrayAdapter>(
      StringView{indptr}, StringView{indices}, StringView{data}, cols);
  std::shared_ptr<DMatrix> p_m{nullptr};
  if (m) {
    p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
  }
  auto *learner = static_cast<xgboost::Learner *>(handle);
  InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(),
                     x->NumColumns(), out_shape, out_dim, out_result);
  API_END();
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
// Stub compiled when XGBOOST_USE_CUDA is not defined: validates the handle and
// then calls common::AssertGPUSupport(). None of the output parameters are
// written by this stub.
//
// The stale pre-change signature head that preceded this definition is dropped;
// it had no body of its own and left the function with two parameter lists.
XGB_DLL int XGBoosterPredictFromArrayInterface(
    BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
    DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
    const float **out_result) {
  API_BEGIN();
  CHECK_HANDLE();
  common::AssertGPUSupport();
  API_END();
}
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
|
||||
char const* c_json_strs,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
const float **out_result) {
|
||||
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
|
||||
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
|
||||
DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
|
||||
const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::AssertGPUSupport();
|
||||
|
||||
Reference in New Issue
Block a user