Enhance inplace prediction. (#6653)
* Accept array interface for csr and array. * Accept an optional proxy dmatrix for metainfo. This constructs an explicit `_ProxyDMatrix` type in Python. * Remove unused doc. * Add strict output.
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
#include "xgboost/global_config.h"
|
||||
|
||||
#include "c_api_error.h"
|
||||
#include "c_api_utils.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/charconv.h"
|
||||
#include "../data/adapter.h"
|
||||
@@ -617,90 +618,92 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
|
||||
char const *c_json_config, Learner *learner,
|
||||
size_t n_rows, size_t n_cols,
|
||||
xgboost::bst_ulong const **out_shape,
|
||||
xgboost::bst_ulong *out_dim, const float **out_result) {
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
|
||||
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
auto type = PredictionType(get<Integer const>(config["type"]));
|
||||
learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
|
||||
&p_predt,
|
||||
get<Integer const>(config["iteration_begin"]),
|
||||
get<Integer const>(config["iteration_end"]));
|
||||
CHECK(p_predt);
|
||||
auto &shape = learner->GetThreadLocal().prediction_shape;
|
||||
auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows;
|
||||
bool strict_shape = get<Boolean const>(config["strict_shape"]);
|
||||
CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(),
|
||||
learner->BoostedRounds(), &shape, out_dim);
|
||||
*out_result = dmlc::BeginPtr(p_predt->HostVector());
|
||||
*out_shape = dmlc::BeginPtr(shape);
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, float *values,
|
||||
xgboost::bst_ulong n_rows,
|
||||
xgboost::bst_ulong n_cols,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle,
|
||||
char const *array_interface,
|
||||
char const *c_json_config,
|
||||
DMatrixHandle m,
|
||||
xgboost::bst_ulong const **out_shape,
|
||||
xgboost::bst_ulong *out_dim,
|
||||
const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
||||
std::shared_ptr<xgboost::data::ArrayAdapter> x{
|
||||
new xgboost::data::ArrayAdapter(StringView{array_interface})};
|
||||
std::shared_ptr<DMatrix> p_m {nullptr};
|
||||
if (m) {
|
||||
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
|
||||
}
|
||||
auto *learner = static_cast<xgboost::Learner *>(handle);
|
||||
|
||||
std::shared_ptr<xgboost::data::DenseAdapter> x{
|
||||
new xgboost::data::DenseAdapter(values, n_rows, n_cols)};
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
std::string type { c_type };
|
||||
learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end);
|
||||
CHECK(p_predt);
|
||||
|
||||
*out_result = dmlc::BeginPtr(p_predt->HostVector());
|
||||
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
|
||||
InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(),
|
||||
x->NumColumns(), out_shape, out_dim, out_result);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle,
|
||||
const size_t* indptr,
|
||||
const unsigned* indices,
|
||||
const bst_float* data,
|
||||
size_t nindptr,
|
||||
size_t nelem,
|
||||
size_t num_col,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const *c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
|
||||
char const *indices, char const *data,
|
||||
xgboost::bst_ulong cols,
|
||||
char const *c_json_config, DMatrixHandle m,
|
||||
xgboost::bst_ulong const **out_shape,
|
||||
xgboost::bst_ulong *out_dim,
|
||||
const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
||||
std::shared_ptr<xgboost::data::CSRArrayAdapter> x{
|
||||
new xgboost::data::CSRArrayAdapter{
|
||||
StringView{indptr}, StringView{indices}, StringView{data}, cols}};
|
||||
std::shared_ptr<DMatrix> p_m {nullptr};
|
||||
if (m) {
|
||||
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
|
||||
}
|
||||
auto *learner = static_cast<xgboost::Learner *>(handle);
|
||||
|
||||
std::shared_ptr<xgboost::data::CSRAdapter> x{
|
||||
new xgboost::data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col)};
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
std::string type { c_type };
|
||||
learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end);
|
||||
CHECK(p_predt);
|
||||
|
||||
*out_result = dmlc::BeginPtr(p_predt->HostVector());
|
||||
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
|
||||
InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(),
|
||||
x->NumColumns(), out_shape, out_dim, out_result);
|
||||
API_END();
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
|
||||
char const* c_json_strs,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
float const** out_result) {
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterface(
|
||||
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
|
||||
DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
|
||||
const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::AssertGPUSupport();
|
||||
API_END();
|
||||
}
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
|
||||
char const* c_json_strs,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
const float **out_result) {
|
||||
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
|
||||
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
|
||||
DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
|
||||
const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::AssertGPUSupport();
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
// Copyright (c) 2019-2020 by Contributors
|
||||
// Copyright (c) 2019-2021 by Contributors
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "c_api_error.h"
|
||||
#include "c_api_utils.h"
|
||||
#include "../data/device_adapter.cuh"
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
@@ -30,59 +31,63 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
|
||||
API_END();
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
|
||||
char const* c_json_strs,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
float const** out_result) {
|
||||
template <typename T>
|
||||
int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
|
||||
char const *c_json_config,
|
||||
std::shared_ptr<DMatrix> p_m,
|
||||
xgboost::bst_ulong const **out_shape,
|
||||
xgboost::bst_ulong *out_dim, const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
||||
auto *learner = static_cast<Learner*>(handle);
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
CHECK_EQ(get<Integer const>(config["cache_id"]), 0)
|
||||
<< "Cache ID is not supported yet";
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
|
||||
std::string json_str{c_json_strs};
|
||||
auto x = std::make_shared<data::CudfAdapter>(json_str);
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
std::string type { c_type };
|
||||
learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end);
|
||||
auto x = std::make_shared<T>(json_str);
|
||||
HostDeviceVector<float> *p_predt{nullptr};
|
||||
auto type = PredictionType(get<Integer const>(config["type"]));
|
||||
learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
|
||||
&p_predt,
|
||||
get<Integer const>(config["iteration_begin"]),
|
||||
get<Integer const>(config["iteration_end"]));
|
||||
CHECK(p_predt);
|
||||
CHECK(p_predt->DeviceCanRead());
|
||||
CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead());
|
||||
|
||||
auto &shape = learner->GetThreadLocal().prediction_shape;
|
||||
auto chunksize = x->NumRows() == 0 ? 0 : p_predt->Size() / x->NumRows();
|
||||
bool strict_shape = get<Boolean const>(config["strict_shape"]);
|
||||
CalcPredictShape(strict_shape, type, x->NumRows(), x->NumColumns(), chunksize,
|
||||
learner->Groups(), learner->BoostedRounds(), &shape,
|
||||
out_dim);
|
||||
*out_shape = dmlc::BeginPtr(shape);
|
||||
*out_result = p_predt->ConstDevicePointer();
|
||||
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
|
||||
|
||||
API_END();
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
|
||||
char const* c_json_strs,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
float const** out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
||||
auto *learner = static_cast<Learner*>(handle);
|
||||
|
||||
std::string json_str{c_json_strs};
|
||||
auto x = std::make_shared<data::CupyAdapter>(json_str);
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
std::string type { c_type };
|
||||
learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end);
|
||||
CHECK(p_predt);
|
||||
CHECK(p_predt->DeviceCanRead());
|
||||
|
||||
*out_result = p_predt->ConstDevicePointer();
|
||||
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
|
||||
|
||||
API_END();
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
|
||||
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
|
||||
DMatrixHandle m, xgboost::bst_ulong const **out_shape,
|
||||
xgboost::bst_ulong *out_dim, const float **out_result) {
|
||||
std::shared_ptr<DMatrix> p_m {nullptr};
|
||||
if (m) {
|
||||
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
|
||||
}
|
||||
return InplacePreidctCuda<data::CudfAdapter>(
|
||||
handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterface(
|
||||
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
|
||||
DMatrixHandle m, xgboost::bst_ulong const **out_shape,
|
||||
xgboost::bst_ulong *out_dim, const float **out_result) {
|
||||
std::shared_ptr<DMatrix> p_m {nullptr};
|
||||
if (m) {
|
||||
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
|
||||
}
|
||||
return InplacePreidctCuda<data::CupyAdapter>(
|
||||
handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
|
||||
}
|
||||
|
||||
114
src/c_api/c_api_utils.h
Normal file
114
src/c_api/c_api_utils.h
Normal file
@@ -0,0 +1,114 @@
|
||||
/*!
|
||||
* Copyright (c) 2021 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_C_API_C_API_UTILS_H_
|
||||
#define XGBOOST_C_API_C_API_UTILS_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/learner.h"
|
||||
|
||||
namespace xgboost {
|
||||
/* \brief Determine the output shape of prediction.
|
||||
*
|
||||
* \param strict_shape Whether should we reshape the output with consideration of groups
|
||||
* and forest.
|
||||
* \param type Prediction type
|
||||
* \param rows Input samples
|
||||
* \param cols Input features
|
||||
* \param chunksize Total elements of output / rows
|
||||
* \param groups Number of output groups from Learner
|
||||
* \param rounds end_iteration - beg_iteration
|
||||
* \param out_shape Output shape
|
||||
* \param out_dim Output dimension
|
||||
*/
|
||||
inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows, size_t cols,
|
||||
size_t chunksize, size_t groups, size_t rounds,
|
||||
std::vector<bst_ulong> *out_shape,
|
||||
xgboost::bst_ulong *out_dim) {
|
||||
auto &shape = *out_shape;
|
||||
if ((type == PredictionType::kMargin || type == PredictionType::kValue) &&
|
||||
rows != 0) {
|
||||
CHECK_EQ(chunksize, groups);
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case PredictionType::kValue:
|
||||
case PredictionType::kMargin: {
|
||||
if (chunksize == 1 && !strict_shape) {
|
||||
*out_dim = 1;
|
||||
shape.resize(*out_dim);
|
||||
shape.front() = rows;
|
||||
} else {
|
||||
*out_dim = 2;
|
||||
shape.resize(*out_dim);
|
||||
shape.front() = rows;
|
||||
shape.back() = groups;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case PredictionType::kApproxContribution:
|
||||
case PredictionType::kContribution: {
|
||||
auto groups = chunksize / (cols + 1);
|
||||
if (groups == 1 && !strict_shape) {
|
||||
*out_dim = 2;
|
||||
shape.resize(*out_dim);
|
||||
shape.front() = rows;
|
||||
shape.back() = cols + 1;
|
||||
} else {
|
||||
*out_dim = 3;
|
||||
shape.resize(*out_dim);
|
||||
shape[0] = rows;
|
||||
shape[1] = groups;
|
||||
shape[2] = cols + 1;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case PredictionType::kInteraction: {
|
||||
if (groups == 1 && !strict_shape) {
|
||||
*out_dim = 3;
|
||||
shape.resize(*out_dim);
|
||||
shape[0] = rows;
|
||||
shape[1] = cols + 1;
|
||||
shape[2] = cols + 1;
|
||||
} else {
|
||||
*out_dim = 4;
|
||||
shape.resize(*out_dim);
|
||||
shape[0] = rows;
|
||||
shape[1] = groups;
|
||||
shape[2] = cols + 1;
|
||||
shape[3] = cols + 1;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case PredictionType::kLeaf: {
|
||||
if (strict_shape) {
|
||||
shape.resize(4);
|
||||
shape[0] = rows;
|
||||
shape[1] = rounds;
|
||||
shape[2] = groups;
|
||||
auto forest = chunksize / (shape[1] * shape[2]);
|
||||
forest = std::max(static_cast<decltype(forest)>(1), forest);
|
||||
shape[3] = forest;
|
||||
*out_dim = shape.size();
|
||||
} else {
|
||||
*out_dim = 2;
|
||||
shape.resize(*out_dim);
|
||||
shape.front() = rows;
|
||||
shape.back() = chunksize;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
LOG(FATAL) << "Unknown prediction type:" << static_cast<int>(type);
|
||||
}
|
||||
}
|
||||
CHECK_EQ(
|
||||
std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<>{}),
|
||||
chunksize * rows);
|
||||
}
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_C_API_C_API_UTILS_H_
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2019~2020 by Contributors
|
||||
* Copyright (c) 2019~2021 by Contributors
|
||||
* \file adapter.h
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_ADAPTER_H_
|
||||
@@ -228,6 +228,128 @@ class DenseAdapter : public detail::SingleBatchDataIter<DenseAdapterBatch> {
|
||||
size_t num_columns_;
|
||||
};
|
||||
|
||||
class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface array_interface_;
|
||||
|
||||
class Line {
|
||||
ArrayInterface array_interface_;
|
||||
size_t ridx_;
|
||||
public:
|
||||
Line(ArrayInterface array_interface, size_t ridx)
|
||||
: array_interface_{std::move(array_interface)}, ridx_{ridx} {}
|
||||
|
||||
size_t Size() const { return array_interface_.num_cols; }
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, idx, array_interface_.GetElement(idx)};
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
ArrayAdapterBatch() = default;
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto line = array_interface_.SliceRow(idx);
|
||||
return Line{line, idx};
|
||||
}
|
||||
|
||||
explicit ArrayAdapterBatch(ArrayInterface array_interface)
|
||||
: array_interface_{std::move(array_interface)} {}
|
||||
};
|
||||
|
||||
/**
|
||||
* Adapter for dense array on host, in Python that's `numpy.ndarray`. This is similar to
|
||||
* `DenseAdapter`, but supports __array_interface__ instead of raw pointers. An
|
||||
* advantage is this can handle various data type without making a copy.
|
||||
*/
|
||||
class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {
|
||||
public:
|
||||
explicit ArrayAdapter(StringView array_interface) {
|
||||
auto j = Json::Load(array_interface);
|
||||
array_interface_ = ArrayInterface(get<Object const>(j));
|
||||
batch_ = ArrayAdapterBatch{array_interface_};
|
||||
}
|
||||
ArrayAdapterBatch const& Value() const override { return batch_; }
|
||||
size_t NumRows() const { return array_interface_.num_rows; }
|
||||
size_t NumColumns() const { return array_interface_.num_cols; }
|
||||
|
||||
private:
|
||||
ArrayAdapterBatch batch_;
|
||||
ArrayInterface array_interface_;
|
||||
};
|
||||
|
||||
class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface indptr_;
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
|
||||
class Line {
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
size_t ridx_;
|
||||
|
||||
public:
|
||||
Line(ArrayInterface indices, ArrayInterface values, size_t ridx)
|
||||
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {}
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)};
|
||||
}
|
||||
size_t Size() const {
|
||||
return values_.num_rows * values_.num_cols;
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CSRArrayAdapterBatch() = default;
|
||||
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
|
||||
ArrayInterface values)
|
||||
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
||||
values_{std::move(values)} {}
|
||||
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx);
|
||||
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
|
||||
auto indices = indices_.SliceOffset(begin_offset);
|
||||
auto values = values_.SliceOffset(begin_offset);
|
||||
values.num_cols = end_offset - begin_offset;
|
||||
values.num_rows = 1;
|
||||
indices.num_cols = values.num_cols;
|
||||
indices.num_rows = values.num_rows;
|
||||
return Line{indices, values, idx};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Adapter for CSR array on host, in Python that's `scipy.sparse.csr_matrix`. This is
|
||||
* similar to `CSRAdapter`, but supports __array_interface__ instead of raw pointers. An
|
||||
* advantage is this can handle various data type without making a copy.
|
||||
*/
|
||||
class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch> {
|
||||
public:
|
||||
CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
|
||||
size_t num_cols)
|
||||
: indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
|
||||
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_};
|
||||
}
|
||||
|
||||
CSRArrayAdapterBatch const& Value() const override {
|
||||
return batch_;
|
||||
}
|
||||
size_t NumRows() const {
|
||||
size_t size = indptr_.num_cols * indptr_.num_rows;
|
||||
size = size == 0 ? 0 : size - 1;
|
||||
return size;
|
||||
}
|
||||
size_t NumColumns() const { return num_cols_; }
|
||||
|
||||
private:
|
||||
CSRArrayAdapterBatch batch_;
|
||||
ArrayInterface indptr_;
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
size_t num_cols_;
|
||||
};
|
||||
|
||||
class CSCAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
CSCAdapterBatch(const size_t* col_ptr, const unsigned* row_idx,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019 by Contributors
|
||||
* Copyright 2019-2021 by Contributors
|
||||
* \file array_interface.h
|
||||
* \brief View of __array_interface__
|
||||
*/
|
||||
@@ -87,7 +87,7 @@ struct ArrayInterfaceErrors {
|
||||
}
|
||||
}
|
||||
|
||||
static std::string UnSupportedType(const char (&typestr)[3]) {
|
||||
static std::string UnSupportedType(StringView typestr) {
|
||||
return TypeStr(typestr[1]) + " is not supported.";
|
||||
}
|
||||
};
|
||||
@@ -210,6 +210,7 @@ class ArrayInterfaceHandler {
|
||||
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
|
||||
Validate(column);
|
||||
@@ -257,16 +258,24 @@ class ArrayInterface {
|
||||
}
|
||||
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
type[0] = typestr.at(0);
|
||||
type[1] = typestr.at(1);
|
||||
type[2] = typestr.at(2);
|
||||
this->CheckType();
|
||||
this->AssignType(StringView{typestr});
|
||||
}
|
||||
|
||||
public:
|
||||
enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
||||
|
||||
public:
|
||||
ArrayInterface() = default;
|
||||
explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
|
||||
auto jinterface = Json::Load({str.c_str(), str.size()});
|
||||
explicit ArrayInterface(std::string const &str, bool allow_mask = true)
|
||||
: ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {}
|
||||
|
||||
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
this->Initialize(column, allow_mask);
|
||||
}
|
||||
|
||||
explicit ArrayInterface(StringView str, bool allow_mask = true) {
|
||||
auto jinterface = Json::Load(str);
|
||||
if (IsA<Object>(jinterface)) {
|
||||
this->Initialize(get<Object const>(jinterface), allow_mask);
|
||||
return;
|
||||
@@ -279,71 +288,114 @@ class ArrayInterface {
|
||||
}
|
||||
}
|
||||
|
||||
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
this->Initialize(column, allow_mask);
|
||||
}
|
||||
|
||||
void CheckType() const {
|
||||
if (type[1] == 'f' && type[2] == '4') {
|
||||
return;
|
||||
} else if (type[1] == 'f' && type[2] == '8') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '1') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '2') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '4') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '8') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '1') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '2') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '4') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '8') {
|
||||
return;
|
||||
void AssignType(StringView typestr) {
|
||||
if (typestr[1] == 'f' && typestr[2] == '4') {
|
||||
type = kF4;
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '8') {
|
||||
type = kF8;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '1') {
|
||||
type = kI1;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '2') {
|
||||
type = kI2;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '4') {
|
||||
type = kI4;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '8') {
|
||||
type = kI8;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '1') {
|
||||
type = kU1;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '2') {
|
||||
type = kU2;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '4') {
|
||||
type = kU4;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '8') {
|
||||
type = kU8;
|
||||
} else {
|
||||
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(type);
|
||||
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE float GetElement(size_t idx) const {
|
||||
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
|
||||
void* p_values;
|
||||
switch (type) {
|
||||
case kF4:
|
||||
p_values = reinterpret_cast<float *>(data) + offset;
|
||||
break;
|
||||
case kF8:
|
||||
p_values = reinterpret_cast<double *>(data) + offset;
|
||||
break;
|
||||
case kI1:
|
||||
p_values = reinterpret_cast<int8_t *>(data) + offset;
|
||||
break;
|
||||
case kI2:
|
||||
p_values = reinterpret_cast<int16_t *>(data) + offset;
|
||||
break;
|
||||
case kI4:
|
||||
p_values = reinterpret_cast<int32_t *>(data) + offset;
|
||||
break;
|
||||
case kI8:
|
||||
p_values = reinterpret_cast<int64_t *>(data) + offset;
|
||||
break;
|
||||
case kU1:
|
||||
p_values = reinterpret_cast<uint8_t *>(data) + offset;
|
||||
break;
|
||||
case kU2:
|
||||
p_values = reinterpret_cast<uint16_t *>(data) + offset;
|
||||
break;
|
||||
case kU4:
|
||||
p_values = reinterpret_cast<uint32_t *>(data) + offset;
|
||||
break;
|
||||
case kU8:
|
||||
p_values = reinterpret_cast<uint64_t *>(data) + offset;
|
||||
break;
|
||||
}
|
||||
ArrayInterface ret = *this;
|
||||
ret.data = p_values;
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
|
||||
size_t offset = idx * num_cols;
|
||||
auto ret = this->SliceOffset(offset);
|
||||
ret.num_rows = 1;
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T = float>
|
||||
XGBOOST_DEVICE T GetElement(size_t idx) const {
|
||||
SPAN_CHECK(idx < num_cols * num_rows);
|
||||
if (type[1] == 'f' && type[2] == '4') {
|
||||
switch (type) {
|
||||
case kF4:
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
} else if (type[1] == 'f' && type[2] == '8') {
|
||||
case kF8:
|
||||
return reinterpret_cast<double*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '1') {
|
||||
case kI1:
|
||||
return reinterpret_cast<int8_t*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '2') {
|
||||
case kI2:
|
||||
return reinterpret_cast<int16_t*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '4') {
|
||||
case kI4:
|
||||
return reinterpret_cast<int32_t*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '8') {
|
||||
case kI8:
|
||||
return reinterpret_cast<int64_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '1') {
|
||||
case kU1:
|
||||
return reinterpret_cast<uint8_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '2') {
|
||||
case kU2:
|
||||
return reinterpret_cast<uint16_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '4') {
|
||||
case kU4:
|
||||
return reinterpret_cast<uint32_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '8') {
|
||||
case kU8:
|
||||
return reinterpret_cast<uint64_t*>(data)[idx];
|
||||
} else {
|
||||
SPAN_CHECK(false);
|
||||
return 0;
|
||||
}
|
||||
SPAN_CHECK(false);
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
bst_row_t num_rows;
|
||||
bst_feature_t num_cols;
|
||||
void* data;
|
||||
char type[3];
|
||||
|
||||
Type type;
|
||||
};
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019 by XGBoost Contributors
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
*
|
||||
* \file data.cu
|
||||
* \brief Handles setting metainfo from array interface.
|
||||
@@ -45,15 +45,15 @@ auto SetDeviceToPtr(void *ptr) {
|
||||
} // anonymous namespace
|
||||
|
||||
void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
|
||||
CHECK(column.type[1] == 'i' || column.type[1] == 'u')
|
||||
<< "Expected integer metainfo";
|
||||
CHECK(column.type != ArrayInterface::kF4 && column.type != ArrayInterface::kF8)
|
||||
<< "Expected integer for group info.";
|
||||
|
||||
auto ptr_device = SetDeviceToPtr(column.data);
|
||||
dh::TemporaryArray<bst_group_t> temp(column.num_rows);
|
||||
auto d_tmp = temp.data();
|
||||
|
||||
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
|
||||
d_tmp[idx] = column.GetElement(idx);
|
||||
d_tmp[idx] = column.GetElement<size_t>(idx);
|
||||
});
|
||||
auto length = column.num_rows;
|
||||
out->resize(length + 1);
|
||||
@@ -103,15 +103,15 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
auto it = dh::MakeTransformIterator<uint32_t>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[array_interface] __device__(size_t i) {
|
||||
return static_cast<uint32_t>(array_interface.GetElement(i));
|
||||
return array_interface.GetElement<uint32_t>(i);
|
||||
});
|
||||
dh::caching_device_vector<bool> flag(1);
|
||||
auto d_flag = dh::ToSpan(flag);
|
||||
auto d = SetDeviceToPtr(array_interface.data);
|
||||
dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
|
||||
dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
|
||||
if (static_cast<uint32_t>(array_interface.GetElement(i)) >
|
||||
static_cast<uint32_t>(array_interface.GetElement(i + 1))) {
|
||||
if (array_interface.GetElement<uint32_t>(i) >
|
||||
array_interface.GetElement<uint32_t>(i + 1)) {
|
||||
d_flag[0] = false;
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* Copyright 2014-2021 by Contributors
|
||||
* \file gbtree.cc
|
||||
* \brief gradient boosted tree implementation.
|
||||
* \author Tianqi Chen
|
||||
@@ -265,15 +265,34 @@ class GBTree : public GradientBooster {
|
||||
bool training,
|
||||
unsigned ntree_limit) override;
|
||||
|
||||
void InplacePredict(dmlc::any const &x, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t layer_begin,
|
||||
unsigned layer_end) const override {
|
||||
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t layer_begin, unsigned layer_end) const override {
|
||||
CHECK(configured_);
|
||||
uint32_t tree_begin, tree_end;
|
||||
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||
this->GetPredictor()->InplacePredict(x, model_, missing, out_preds,
|
||||
tree_begin, tree_end);
|
||||
std::tie(tree_begin, tree_end) =
|
||||
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||
std::vector<Predictor const *> predictors{
|
||||
cpu_predictor_.get(),
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
gpu_predictor_.get()
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
};
|
||||
StringView msg{"Unsupported data type for inplace predict."};
|
||||
if (tparam_.predictor == PredictorType::kAuto) {
|
||||
// Try both predictor implementations
|
||||
for (auto const &p : predictors) {
|
||||
if (p && p->InplacePredict(x, p_m, model_, missing, out_preds,
|
||||
tree_begin, tree_end)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << msg;
|
||||
} else {
|
||||
bool success = this->GetPredictor()->InplacePredict(
|
||||
x, p_m, model_, missing, out_preds, tree_begin, tree_end);
|
||||
CHECK(success) << msg;
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* Copyright 2014-2021 by Contributors
|
||||
* \file learner.cc
|
||||
* \brief Implementation of learning algorithm.
|
||||
* \author Tianqi Chen
|
||||
@@ -1110,23 +1110,30 @@ class LearnerImpl : public LearnerIO {
|
||||
CHECK(!this->need_configuration_);
|
||||
return this->gbm_->BoostedRounds();
|
||||
}
|
||||
uint32_t Groups() const override {
|
||||
CHECK(!this->need_configuration_);
|
||||
return this->learner_model_param_.num_output_group;
|
||||
}
|
||||
|
||||
XGBAPIThreadLocalEntry& GetThreadLocal() const override {
|
||||
return (*LearnerAPIThreadLocalStore::Get())[this];
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, std::string const &type,
|
||||
float missing, HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t layer_begin, uint32_t layer_end) override {
|
||||
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
PredictionType type, float missing,
|
||||
HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t iteration_begin,
|
||||
uint32_t iteration_end) override {
|
||||
this->Configure();
|
||||
auto& out_predictions = this->GetThreadLocal().prediction_entry;
|
||||
this->gbm_->InplacePredict(x, missing, &out_predictions, layer_begin,
|
||||
layer_end);
|
||||
if (type == "value") {
|
||||
this->gbm_->InplacePredict(x, p_m, missing, &out_predictions,
|
||||
iteration_begin, iteration_end);
|
||||
if (type == PredictionType::kValue) {
|
||||
obj_->PredTransform(&out_predictions.predictions);
|
||||
} else if (type == "margin") {
|
||||
} else if (type == PredictionType::kMargin) {
|
||||
// do nothing
|
||||
} else {
|
||||
LOG(FATAL) << "Unsupported prediction type:" << type;
|
||||
LOG(FATAL) << "Unsupported prediction type:" << static_cast<int>(type);
|
||||
}
|
||||
*out_preds = &out_predictions.predictions;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright by Contributors 2017-2020
|
||||
* Copyright by Contributors 2017-2021
|
||||
*/
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/any.h>
|
||||
@@ -287,7 +287,7 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
|
||||
template <typename Adapter>
|
||||
void DispatchedInplacePredict(dmlc::any const &x,
|
||||
void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, uint32_t tree_end) const {
|
||||
@@ -295,33 +295,44 @@ class CPUPredictor : public Predictor {
|
||||
auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
|
||||
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
|
||||
<< "Number of columns in data must equal to trained model.";
|
||||
MetaInfo info;
|
||||
info.num_col_ = m->NumColumns();
|
||||
info.num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
std::vector<Entry> workspace(info.num_col_ * 8 * threads);
|
||||
if (p_m) {
|
||||
p_m->Info().num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model);
|
||||
} else {
|
||||
MetaInfo info;
|
||||
info.num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
}
|
||||
std::vector<Entry> workspace(m->NumColumns() * 8 * threads);
|
||||
auto &predictions = out_preds->predictions.HostVector();
|
||||
std::vector<RegTree::FVec> thread_temp;
|
||||
InitThreadTemp(threads*kBlockOfRowsSize, model.learner_model_param->num_feature,
|
||||
&thread_temp);
|
||||
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>,
|
||||
kBlockOfRowsSize>(AdapterView<Adapter>(
|
||||
m.get(), missing, common::Span<Entry>{workspace}),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp);
|
||||
InitThreadTemp(threads * kBlockOfRowsSize,
|
||||
model.learner_model_param->num_feature, &thread_temp);
|
||||
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockOfRowsSize>(
|
||||
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp);
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, unsigned tree_end) const override {
|
||||
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds, uint32_t tree_begin,
|
||||
unsigned tree_end) const override {
|
||||
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::DenseAdapter>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CSRAdapter>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::ArrayAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::ArrayAdapter> (
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CSRArrayAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CSRArrayAdapter> (
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else {
|
||||
LOG(FATAL) << "Data type is not supported by CPU Predictor.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2020 by Contributors
|
||||
* Copyright 2017-2021 by Contributors
|
||||
*/
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
@@ -644,7 +644,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
template <typename Adapter, typename Loader>
|
||||
void DispatchedInplacePredict(dmlc::any const &x,
|
||||
void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, uint32_t tree_end) const {
|
||||
@@ -659,16 +659,20 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx())
|
||||
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
|
||||
<< "but data is on: " << m->DeviceIdx();
|
||||
MetaInfo info;
|
||||
info.num_col_ = m->NumColumns();
|
||||
info.num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
if (p_m) {
|
||||
p_m->Info().num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model);
|
||||
} else {
|
||||
MetaInfo info;
|
||||
info.num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
}
|
||||
|
||||
const uint32_t BLOCK_THREADS = 128;
|
||||
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
|
||||
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(m->NumRows(), BLOCK_THREADS));
|
||||
|
||||
size_t shared_memory_bytes =
|
||||
SharedMemoryBytes<BLOCK_THREADS>(info.num_col_, max_shared_memory_bytes);
|
||||
SharedMemoryBytes<BLOCK_THREADS>(m->NumColumns(), max_shared_memory_bytes);
|
||||
bool use_shared = shared_memory_bytes != 0;
|
||||
size_t entry_start = 0;
|
||||
|
||||
@@ -680,23 +684,25 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
d_model.categories_tree_segments.ConstDeviceSpan(),
|
||||
d_model.categories_node_segments.ConstDeviceSpan(),
|
||||
d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(),
|
||||
info.num_row_, entry_start, use_shared, output_groups);
|
||||
m->NumRows(), entry_start, use_shared, output_groups);
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, unsigned tree_end) const override {
|
||||
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds, uint32_t tree_begin,
|
||||
unsigned tree_end) const override {
|
||||
if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
|
||||
this->DispatchedInplacePredict<
|
||||
data::CupyAdapter, DeviceAdapterLoader<data::CupyAdapterBatch>>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CudfAdapter>)) {
|
||||
this->DispatchedInplacePredict<
|
||||
data::CudfAdapter, DeviceAdapterLoader<data::CudfAdapterBatch>>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else {
|
||||
LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void PredictContribution(DMatrix* p_fmat,
|
||||
|
||||
Reference in New Issue
Block a user