Thread safe, inplace prediction. (#5389)
Normal prediction with DMatrix is now thread safe with locks. Added inplace prediction is lock free thread safe. When data is on device (cupy, cudf), the returned data is also on device. * Implementation for numpy, csr, cudf and cupy. * Implementation for dask. * Remove sync in simple dmatrix.
This commit is contained in:
parent
7f980e9f83
commit
6601a641d7
@ -9,6 +9,7 @@
|
||||
#define XGBOOST_GBM_H_
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <dmlc/any.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
@ -92,6 +93,22 @@ class GradientBooster : public Model, public Configurable {
|
||||
PredictionCacheEntry* out_preds,
|
||||
bool training,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
|
||||
/*!
|
||||
* \brief Inplace prediction.
|
||||
*
|
||||
* \param x A type erased data adapter.
|
||||
* \param missing Missing value in the data.
|
||||
* \param [in,out] out_preds The output preds.
|
||||
* \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
|
||||
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
|
||||
*/
|
||||
virtual void InplacePredict(dmlc::any const &x, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t layer_begin = 0,
|
||||
uint32_t layer_end = 0) const {
|
||||
LOG(FATAL) << "Inplace predict is not supported by current booster.";
|
||||
}
|
||||
/*!
|
||||
* \brief online prediction function, predict score for one instance at a time
|
||||
* NOTE: use the batch prediction interface if possible, batch prediction is usually
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#ifndef XGBOOST_LEARNER_H_
|
||||
#define XGBOOST_LEARNER_H_
|
||||
|
||||
#include <dmlc/any.h>
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/feature_map.h>
|
||||
@ -120,6 +121,21 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
|
||||
bool approx_contribs = false,
|
||||
bool pred_interactions = false) = 0;
|
||||
|
||||
/*!
|
||||
* \brief Inplace prediction.
|
||||
*
|
||||
* \param x A type erased data adapter.
|
||||
* \param type Prediction type.
|
||||
* \param missing Missing value in the data.
|
||||
* \param [in,out] out_preds Pointer to output prediction vector.
|
||||
* \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
|
||||
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
|
||||
*/
|
||||
virtual void InplacePredict(dmlc::any const& x, std::string const& type,
|
||||
float missing,
|
||||
HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t layer_begin = 0, uint32_t layer_end = 0) = 0;
|
||||
|
||||
void LoadModel(Json const& in) override = 0;
|
||||
void SaveModel(Json* out) const override = 0;
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
|
||||
// Forward declarations
|
||||
namespace xgboost {
|
||||
@ -54,6 +55,7 @@ struct PredictionCacheEntry {
|
||||
class PredictionContainer {
|
||||
std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
|
||||
void ClearExpiredEntries();
|
||||
std::mutex cache_lock_;
|
||||
|
||||
public:
|
||||
PredictionContainer() = default;
|
||||
@ -133,6 +135,18 @@ class Predictor {
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
uint32_t const ntree_limit = 0) = 0;
|
||||
|
||||
/**
|
||||
* \brief Inplace prediction.
|
||||
* \param x Type erased data adapter.
|
||||
* \param model The model to predict from.
|
||||
* \param missing Missing value in the data.
|
||||
* \param [in,out] out_preds The output preds.
|
||||
* \param tree_begin (Optional) Begining of boosted trees used for prediction.
|
||||
* \param tree_end (Optional) End of booster trees. 0 means do not limit trees.
|
||||
*/
|
||||
virtual void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
|
||||
/**
|
||||
* \brief online prediction function, predict score for one instance at a time
|
||||
* NOTE: use the batch prediction interface if possible, batch prediction is
|
||||
|
||||
@ -86,7 +86,7 @@ class CMakeExtension(Extension): # pylint: disable=too-few-public-methods
|
||||
super().__init__(name=name, sources=[])
|
||||
|
||||
|
||||
class BuildExt(build_ext.build_ext):
|
||||
class BuildExt(build_ext.build_ext): # pylint: disable=too-many-ancestors
|
||||
'''Custom build_ext command using CMake.'''
|
||||
|
||||
logger = logging.getLogger('XGBoost build_ext')
|
||||
|
||||
@ -207,6 +207,19 @@ def ctypes2numpy(cptr, length, dtype):
|
||||
return res
|
||||
|
||||
|
||||
def ctypes2cupy(cptr, length, dtype):
|
||||
"""Convert a ctypes pointer array to a cupy array."""
|
||||
import cupy # pylint: disable=import-error
|
||||
mem = cupy.zeros(length.value, dtype=dtype, order='C')
|
||||
addr = ctypes.cast(cptr, ctypes.c_void_p).value
|
||||
# pylint: disable=c-extension-no-member,no-member
|
||||
cupy.cuda.runtime.memcpy(
|
||||
mem.__cuda_array_interface__['data'][0], addr,
|
||||
length.value * ctypes.sizeof(ctypes.c_float),
|
||||
cupy.cuda.runtime.memcpyDeviceToDevice)
|
||||
return mem
|
||||
|
||||
|
||||
def ctypes2buffer(cptr, length):
|
||||
"""Convert ctypes pointer to buffer type."""
|
||||
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
|
||||
@ -474,6 +487,7 @@ class DMatrix(object):
|
||||
data, feature_names, feature_types = _convert_dataframes(
|
||||
data, feature_names, feature_types
|
||||
)
|
||||
missing = np.nan if missing is None else missing
|
||||
|
||||
if isinstance(data, (STRING_TYPES, os_PathLike)):
|
||||
handle = ctypes.c_void_p()
|
||||
@ -1428,12 +1442,17 @@ class Booster(object):
|
||||
training=False):
|
||||
"""Predict with data.
|
||||
|
||||
.. note:: This function is not thread safe.
|
||||
.. note:: This function is not thread safe except for ``gbtree``
|
||||
booster.
|
||||
|
||||
For each booster object, predict can only be called from one thread.
|
||||
If you want to run prediction using multiple thread, call
|
||||
``bst.copy()`` to make copies of model object and then call
|
||||
``predict()``.
|
||||
For ``gbtree`` booster, the thread safety is guaranteed by locks.
|
||||
For lock free prediction use ``inplace_predict`` instead. Also, the
|
||||
safety does not hold when used in conjunction with other methods.
|
||||
|
||||
When using booster other than ``gbtree``, predict can only be called
|
||||
from one thread. If you want to run prediction using multiple
|
||||
thread, call ``bst.copy()`` to make copies of model object and then
|
||||
call ``predict()``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -1547,6 +1566,146 @@ class Booster(object):
|
||||
preds = preds.reshape(nrow, chunk_size)
|
||||
return preds
|
||||
|
||||
def inplace_predict(self, data, iteration_range=(0, 0),
|
||||
predict_type='value', missing=np.nan):
|
||||
'''Run prediction in-place, Unlike ``predict`` method, inplace prediction does
|
||||
not cache the prediction result.
|
||||
|
||||
Calling only ``inplace_predict`` in multiple threads is safe and lock
|
||||
free. But the safety does not hold when used in conjunction with other
|
||||
methods. E.g. you can't train the booster in one thread and perform
|
||||
prediction in the other.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
booster.set_param({'predictor': 'gpu_predictor'})
|
||||
booster.inplace_predict(cupy_array)
|
||||
|
||||
booster.set_param({'predictor': 'cpu_predictor})
|
||||
booster.inplace_predict(numpy_array)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : numpy.ndarray/scipy.sparse.csr_matrix/cupy.ndarray/
|
||||
cudf.DataFrame/pd.DataFrame
|
||||
The input data, must not be a view for numpy array. Set
|
||||
``predictor`` to ``gpu_predictor`` for running prediction on CuPy
|
||||
array or CuDF DataFrame.
|
||||
iteration_range : tuple
|
||||
Specifies which layer of trees are used in prediction. For
|
||||
example, if a random forest is trained with 100 rounds. Specifying
|
||||
`iteration_range=(10, 20)`, then only the forests built during [10,
|
||||
20) (open set) rounds are used in this prediction.
|
||||
predict_type : str
|
||||
* `value` Output model prediction values.
|
||||
* `margin` Output the raw untransformed margin value.
|
||||
missing : float
|
||||
Value in the input data which needs to be present as a missing
|
||||
value.
|
||||
|
||||
Returns
|
||||
-------
|
||||
prediction : numpy.ndarray/cupy.ndarray
|
||||
The prediction result. When input data is on GPU, prediction
|
||||
result is stored in a cupy array.
|
||||
|
||||
'''
|
||||
|
||||
def reshape_output(predt, rows):
|
||||
'''Reshape for multi-output prediction.'''
|
||||
if predt.size != rows and predt.size % rows == 0:
|
||||
cols = int(predt.size / rows)
|
||||
predt = predt.reshape(rows, cols)
|
||||
return predt
|
||||
return predt
|
||||
|
||||
length = c_bst_ulong()
|
||||
preds = ctypes.POINTER(ctypes.c_float)()
|
||||
iteration_range = (ctypes.c_uint(iteration_range[0]),
|
||||
ctypes.c_uint(iteration_range[1]))
|
||||
|
||||
# once caching is supported, we can pass id(data) as cache id.
|
||||
if isinstance(data, DataFrame):
|
||||
data = data.values
|
||||
if isinstance(data, np.ndarray):
|
||||
assert data.flags.c_contiguous
|
||||
arr = np.array(data.reshape(data.size), copy=False,
|
||||
dtype=np.float32)
|
||||
_check_call(_LIB.XGBoosterPredictFromDense(
|
||||
self.handle,
|
||||
arr.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||
c_bst_ulong(data.shape[0]),
|
||||
c_bst_ulong(data.shape[1]),
|
||||
ctypes.c_float(missing),
|
||||
iteration_range[0],
|
||||
iteration_range[1],
|
||||
c_str(predict_type),
|
||||
c_bst_ulong(0),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(preds)
|
||||
))
|
||||
preds = ctypes2numpy(preds, length.value, np.float32)
|
||||
rows = data.shape[0]
|
||||
return reshape_output(preds, rows)
|
||||
if isinstance(data, scipy.sparse.csr_matrix):
|
||||
csr = data
|
||||
_check_call(_LIB.XGBoosterPredictFromCSR(
|
||||
self.handle,
|
||||
c_array(ctypes.c_size_t, csr.indptr),
|
||||
c_array(ctypes.c_uint, csr.indices),
|
||||
c_array(ctypes.c_float, csr.data),
|
||||
ctypes.c_size_t(len(csr.indptr)),
|
||||
ctypes.c_size_t(len(csr.data)),
|
||||
ctypes.c_size_t(csr.shape[1]),
|
||||
ctypes.c_float(missing),
|
||||
iteration_range[0],
|
||||
iteration_range[1],
|
||||
c_str(predict_type),
|
||||
c_bst_ulong(0),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(preds)))
|
||||
preds = ctypes2numpy(preds, length.value, np.float32)
|
||||
rows = data.shape[0]
|
||||
return reshape_output(preds, rows)
|
||||
if lazy_isinstance(data, 'cupy.core.core', 'ndarray'):
|
||||
assert data.flags.c_contiguous
|
||||
interface = data.__cuda_array_interface__
|
||||
if 'mask' in interface:
|
||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
||||
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
|
||||
_check_call(_LIB.XGBoosterPredictFromArrayInterface(
|
||||
self.handle,
|
||||
interface_str,
|
||||
ctypes.c_float(missing),
|
||||
iteration_range[0],
|
||||
iteration_range[1],
|
||||
c_str(predict_type),
|
||||
c_bst_ulong(0),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(preds)))
|
||||
mem = ctypes2cupy(preds, length, np.float32)
|
||||
rows = data.shape[0]
|
||||
return reshape_output(mem, rows)
|
||||
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
|
||||
interfaces_str = _cudf_array_interfaces(data)
|
||||
_check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns(
|
||||
self.handle,
|
||||
interfaces_str,
|
||||
ctypes.c_float(missing),
|
||||
iteration_range[0],
|
||||
iteration_range[1],
|
||||
c_str(predict_type),
|
||||
c_bst_ulong(0),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(preds)))
|
||||
mem = ctypes2cupy(preds, length, np.float32)
|
||||
rows = data.shape[0]
|
||||
predt = reshape_output(mem, rows)
|
||||
return predt
|
||||
|
||||
raise TypeError('Data type:' + str(type(data)) +
|
||||
' not supported by inplace prediction.')
|
||||
|
||||
def save_model(self, fname):
|
||||
"""Save the model to a file.
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ from .compat import da, dd, delayed, get_client
|
||||
from .compat import sparse, scipy_sparse
|
||||
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
|
||||
from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
|
||||
from .compat import lazy_isinstance
|
||||
|
||||
from .core import DMatrix, Booster, _expect
|
||||
from .training import train as worker_train
|
||||
@ -86,7 +87,7 @@ class RabitContext:
|
||||
LOGGER.debug('--------------- rabit say bye ------------------')
|
||||
|
||||
|
||||
def concat(value):
|
||||
def concat(value): # pylint: disable=too-many-return-statements
|
||||
'''To be replaced with dask builtin.'''
|
||||
if isinstance(value[0], numpy.ndarray):
|
||||
return numpy.concatenate(value, axis=0)
|
||||
@ -98,6 +99,9 @@ def concat(value):
|
||||
return pandas_concat(value, axis=0)
|
||||
if CUDF_INSTALLED and isinstance(value[0], (CUDF_DataFrame, CUDF_Series)):
|
||||
return CUDF_concat(value, axis=0)
|
||||
if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
|
||||
import cupy # pylint: disable=import-error
|
||||
return cupy.concatenate(value, axis=0)
|
||||
return dd.multi.concat(list(value), axis=0)
|
||||
|
||||
|
||||
@ -370,8 +374,9 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
||||
Specify the dask client used for training. Use default client
|
||||
returned from dask if it's set to None.
|
||||
\\*\\*kwargs:
|
||||
Other parameters are the same as `xgboost.train` except for `evals_result`,
|
||||
which is returned as part of function return value instead of argument.
|
||||
Other parameters are the same as `xgboost.train` except for
|
||||
`evals_result`, which is returned as part of function return value
|
||||
instead of argument.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -500,11 +505,10 @@ def predict(client, model, data, *args, missing=numpy.nan):
|
||||
).result()
|
||||
return predictions
|
||||
if isinstance(data, dd.DataFrame):
|
||||
import dask
|
||||
predictions = client.submit(
|
||||
dd.map_partitions,
|
||||
mapped_predict, data, True,
|
||||
meta=dask.dataframe.utils.make_meta({'prediction': 'f4'})
|
||||
meta=dd.utils.make_meta({'prediction': 'f4'})
|
||||
).result()
|
||||
return predictions.iloc[:, 0]
|
||||
|
||||
@ -572,6 +576,79 @@ def predict(client, model, data, *args, missing=numpy.nan):
|
||||
return predictions
|
||||
|
||||
|
||||
def inplace_predict(client, model, data,
|
||||
iteration_range=(0, 0),
|
||||
predict_type='value',
|
||||
missing=numpy.nan):
|
||||
'''Inplace prediction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client: dask.distributed.Client
|
||||
Specify the dask client used for training. Use default client
|
||||
returned from dask if it's set to None.
|
||||
model: Booster/dict
|
||||
The trained model.
|
||||
iteration_range: tuple
|
||||
Specify the range of trees used for prediction.
|
||||
predict_type: str
|
||||
* 'value': Normal prediction result.
|
||||
* 'margin': Output the raw untransformed margin value.
|
||||
missing: float
|
||||
Value in the input data which needs to be present as a missing
|
||||
value. If None, defaults to np.nan.
|
||||
Returns
|
||||
-------
|
||||
prediction: dask.array.Array
|
||||
'''
|
||||
_assert_dask_support()
|
||||
client = _xgb_get_client(client)
|
||||
if isinstance(model, Booster):
|
||||
booster = model
|
||||
elif isinstance(model, dict):
|
||||
booster = model['booster']
|
||||
else:
|
||||
raise TypeError(_expect([Booster, dict], type(model)))
|
||||
if not isinstance(data, (da.Array, dd.DataFrame)):
|
||||
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
|
||||
|
||||
def mapped_predict(data, is_df):
|
||||
worker = distributed_get_worker()
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
prediction = booster.inplace_predict(
|
||||
data,
|
||||
iteration_range=iteration_range,
|
||||
predict_type=predict_type,
|
||||
missing=missing)
|
||||
if is_df:
|
||||
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
|
||||
import cudf # pylint: disable=import-error
|
||||
# There's an error with cudf saying `concat_cudf` got an
|
||||
# expected argument `ignore_index`. So this is not yet working.
|
||||
prediction = cudf.DataFrame({'prediction': prediction},
|
||||
dtype=numpy.float32)
|
||||
else:
|
||||
# If it's from pandas, the partition is a numpy array
|
||||
prediction = DataFrame(prediction, columns=['prediction'],
|
||||
dtype=numpy.float32)
|
||||
return prediction
|
||||
|
||||
if isinstance(data, da.Array):
|
||||
predictions = client.submit(
|
||||
da.map_blocks,
|
||||
mapped_predict, data, False, drop_axis=1,
|
||||
dtype=numpy.float32
|
||||
).result()
|
||||
return predictions
|
||||
if isinstance(data, dd.DataFrame):
|
||||
predictions = client.submit(
|
||||
dd.map_partitions,
|
||||
mapped_predict, data, True,
|
||||
meta=dd.utils.make_meta({'prediction': 'f4'})
|
||||
).result()
|
||||
return predictions.iloc[:, 0]
|
||||
|
||||
|
||||
def _evaluation_matrices(client, validation_set, sample_weights, missing):
|
||||
'''
|
||||
Parameters
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/logging.h"
|
||||
@ -450,6 +451,95 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, float *values,
|
||||
xgboost::bst_ulong n_rows,
|
||||
xgboost::bst_ulong n_cols,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
||||
auto *learner = static_cast<xgboost::Learner *>(handle);
|
||||
|
||||
auto x = xgboost::data::DenseAdapter(values, n_rows, n_cols);
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
std::string type { c_type };
|
||||
learner->InplacePredict(x, type, missing, &p_predt);
|
||||
CHECK(p_predt);
|
||||
|
||||
*out_result = dmlc::BeginPtr(p_predt->HostVector());
|
||||
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
|
||||
API_END();
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle,
|
||||
const size_t* indptr,
|
||||
const unsigned* indices,
|
||||
const bst_float* data,
|
||||
size_t nindptr,
|
||||
size_t nelem,
|
||||
size_t num_col,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const *c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
||||
auto *learner = static_cast<xgboost::Learner *>(handle);
|
||||
|
||||
auto x = data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col);
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
std::string type { c_type };
|
||||
learner->InplacePredict(x, type, missing, &p_predt);
|
||||
CHECK(p_predt);
|
||||
|
||||
*out_result = dmlc::BeginPtr(p_predt->HostVector());
|
||||
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
|
||||
API_END();
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
|
||||
char const* c_json_strs,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
float const** out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
LOG(FATAL) << "XGBoost not compiled with CUDA.";
|
||||
API_END();
|
||||
}
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
|
||||
char const* c_json_strs,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
LOG(FATAL) << "XGBoost not compiled with CUDA.";
|
||||
API_END();
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
|
||||
@ -52,3 +52,60 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterface(char const* c_json_s
|
||||
new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
|
||||
API_END();
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
|
||||
char const* c_json_strs,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
float const** out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
||||
auto *learner = static_cast<Learner*>(handle);
|
||||
|
||||
std::string json_str{c_json_strs};
|
||||
auto x = data::CudfAdapter(json_str);
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
std::string type { c_type };
|
||||
learner->InplacePredict(x, type, missing, &p_predt);
|
||||
CHECK(p_predt);
|
||||
CHECK(p_predt->DeviceCanRead());
|
||||
|
||||
*out_result = p_predt->ConstDevicePointer();
|
||||
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
|
||||
|
||||
API_END();
|
||||
}
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
|
||||
char const* c_json_strs,
|
||||
float missing,
|
||||
unsigned iteration_begin,
|
||||
unsigned iteration_end,
|
||||
char const* c_type,
|
||||
xgboost::bst_ulong cache_id,
|
||||
xgboost::bst_ulong *out_len,
|
||||
float const** out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
||||
auto *learner = static_cast<Learner*>(handle);
|
||||
|
||||
std::string json_str{c_json_strs};
|
||||
auto x = data::CupyAdapter(json_str);
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
std::string type { c_type };
|
||||
learner->InplacePredict(x, type, missing, &p_predt);
|
||||
CHECK(p_predt);
|
||||
CHECK(p_predt->DeviceCanRead());
|
||||
|
||||
*out_result = p_predt->ConstDevicePointer();
|
||||
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
|
||||
|
||||
API_END();
|
||||
}
|
||||
|
||||
@ -52,6 +52,13 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||
: std::numeric_limits<float>::quiet_NaN();
|
||||
return COOTuple(row_idx, column_idx, value);
|
||||
}
|
||||
__device__ float GetValue(size_t ridx, bst_feature_t fidx) const {
|
||||
auto const& column = columns_[fidx];
|
||||
float value = column.valid.Data() == nullptr || column.valid.Check(ridx)
|
||||
? column.GetElement(ridx)
|
||||
: std::numeric_limits<float>::quiet_NaN();
|
||||
return value;
|
||||
}
|
||||
|
||||
private:
|
||||
common::Span<ArrayInterface> columns_;
|
||||
@ -129,6 +136,7 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
|
||||
for (auto& json_col : json_columns) {
|
||||
auto column = ArrayInterface(get<Object const>(json_col));
|
||||
columns.push_back(column);
|
||||
CHECK_EQ(column.num_cols, 1);
|
||||
column_ptr.emplace_back(column_ptr.back() + column.num_rows);
|
||||
num_rows_ = std::max(num_rows_, size_t(column.num_rows));
|
||||
CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
|
||||
|
||||
@ -122,8 +122,6 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
CopyDataColumnMajor(adapter, sparse_page_.data.DeviceSpan(),
|
||||
adapter->DeviceIdx(), missing, s_offset);
|
||||
}
|
||||
// Sync
|
||||
sparse_page_.data.HostVector();
|
||||
|
||||
info.num_col_ = adapter->NumColumns();
|
||||
info.num_row_ = adapter->NumRows();
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2019 by Contributors
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* \file gbtree.cc
|
||||
* \brief gradient boosted tree implementation.
|
||||
* \author Tianqi Chen
|
||||
@ -16,6 +16,7 @@
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/gbm.h"
|
||||
@ -203,6 +204,22 @@ class GBTree : public GradientBooster {
|
||||
bool training,
|
||||
unsigned ntree_limit) override;
|
||||
|
||||
void InplacePredict(dmlc::any const &x, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t layer_begin = 0,
|
||||
unsigned layer_end = 0) const override {
|
||||
CHECK(configured_);
|
||||
// From here on, layer becomes concrete trees.
|
||||
bst_group_t groups = model_.learner_model_param_->num_output_group;
|
||||
uint32_t tree_begin = layer_begin * groups * tparam_.num_parallel_tree;
|
||||
uint32_t tree_end = layer_end * groups * tparam_.num_parallel_tree;
|
||||
if (tree_end == 0 || tree_end > model_.trees.size()) {
|
||||
tree_end = static_cast<uint32_t>(model_.trees.size());
|
||||
}
|
||||
this->GetPredictor()->InplacePredict(x, model_, missing, out_preds,
|
||||
tree_begin, tree_end);
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
|
||||
@ -8,6 +8,8 @@
|
||||
#include <dmlc/parameter.h>
|
||||
#include <dmlc/thread_local.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
@ -18,6 +20,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "dmlc/any.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/model.h"
|
||||
@ -205,7 +208,7 @@ class LearnerConfiguration : public Learner {
|
||||
PredictionContainer cache_;
|
||||
|
||||
protected:
|
||||
bool need_configuration_;
|
||||
std::atomic<bool> need_configuration_;
|
||||
std::map<std::string, std::string> cfg_;
|
||||
// Stores information like best-iteration for early stopping.
|
||||
std::map<std::string, std::string> attributes_;
|
||||
@ -214,6 +217,7 @@ class LearnerConfiguration : public Learner {
|
||||
LearnerModelParam learner_model_param_;
|
||||
LearnerTrainParam tparam_;
|
||||
std::vector<std::string> metric_names_;
|
||||
std::mutex config_lock_;
|
||||
|
||||
public:
|
||||
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
|
||||
@ -226,6 +230,9 @@ class LearnerConfiguration : public Learner {
|
||||
// Configuration before data is known.
|
||||
|
||||
void Configure() override {
|
||||
// Varient of double checked lock
|
||||
if (!this->need_configuration_) { return; }
|
||||
std::lock_guard<std::mutex> gard(config_lock_);
|
||||
if (!this->need_configuration_) { return; }
|
||||
|
||||
monitor_.Start("Configure");
|
||||
@ -1003,6 +1010,23 @@ class LearnerImpl : public LearnerIO {
|
||||
XGBAPIThreadLocalEntry& GetThreadLocal() const override {
|
||||
return (*XGBAPIThreadLocalStore::Get())[this];
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, std::string const &type,
|
||||
float missing, HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t layer_begin = 0, uint32_t layer_end = 0) override {
|
||||
this->Configure();
|
||||
auto& out_predictions = this->GetThreadLocal().prediction_entry;
|
||||
this->gbm_->InplacePredict(x, missing, &out_predictions, layer_begin,
|
||||
layer_end);
|
||||
if (type == "value") {
|
||||
obj_->PredTransform(&out_predictions.predictions);
|
||||
} else if (type == "margin") {
|
||||
} else {
|
||||
LOG(FATAL) << "Unsupported prediction type:" << type;
|
||||
}
|
||||
*out_preds = &out_predictions.predictions;
|
||||
}
|
||||
|
||||
const std::map<std::string, std::string>& GetConfigurationArguments() const override {
|
||||
return cfg_;
|
||||
}
|
||||
|
||||
@ -2,13 +2,22 @@
|
||||
* Copyright by Contributors 2017-2020
|
||||
*/
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/any.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/predictor.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "../data/adapter.h"
|
||||
#include "../common/math.h"
|
||||
#include "../gbm/gbtree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -16,89 +25,156 @@ namespace predictor {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(cpu_predictor);
|
||||
|
||||
bst_float PredValue(const SparsePage::Inst &inst,
|
||||
const std::vector<std::unique_ptr<RegTree>> &trees,
|
||||
const std::vector<int> &tree_info, int bst_group,
|
||||
RegTree::FVec *p_feats, unsigned tree_begin,
|
||||
unsigned tree_end) {
|
||||
bst_float psum = 0.0f;
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||
if (tree_info[i] == bst_group) {
|
||||
int tid = trees[i]->GetLeafIndex(*p_feats);
|
||||
psum += (*trees[i])[tid].LeafValue();
|
||||
}
|
||||
}
|
||||
p_feats->Drop(inst);
|
||||
return psum;
|
||||
}
|
||||
|
||||
template <size_t kUnrollLen = 8>
|
||||
struct SparsePageView {
|
||||
SparsePage const* page;
|
||||
bst_row_t base_rowid;
|
||||
static size_t constexpr kUnroll = kUnrollLen;
|
||||
|
||||
explicit SparsePageView(SparsePage const *p)
|
||||
: page{p}, base_rowid{page->base_rowid} {
|
||||
// Pull to host before entering omp block, as this is not thread safe.
|
||||
page->data.HostVector();
|
||||
page->offset.HostVector();
|
||||
}
|
||||
SparsePage::Inst operator[](size_t i) { return (*page)[i]; }
|
||||
size_t Size() const { return page->Size(); }
|
||||
};
|
||||
|
||||
template <typename Adapter, size_t kUnrollLen = 8>
|
||||
class AdapterView {
|
||||
Adapter* adapter_;
|
||||
float missing_;
|
||||
common::Span<Entry> workspace_;
|
||||
std::vector<size_t> current_unroll_;
|
||||
|
||||
public:
|
||||
static size_t constexpr kUnroll = kUnrollLen;
|
||||
|
||||
public:
|
||||
explicit AdapterView(Adapter *adapter, float missing,
|
||||
common::Span<Entry> workplace)
|
||||
: adapter_{adapter}, missing_{missing}, workspace_{workplace},
|
||||
current_unroll_(omp_get_max_threads() > 0 ? omp_get_max_threads() : 1, 0) {}
|
||||
SparsePage::Inst operator[](size_t i) {
|
||||
bst_feature_t columns = adapter_->NumColumns();
|
||||
auto const &batch = adapter_->Value();
|
||||
auto row = batch.GetLine(i);
|
||||
auto t = omp_get_thread_num();
|
||||
auto const beg = (columns * kUnroll * t) + (current_unroll_[t] * columns);
|
||||
size_t non_missing {beg};
|
||||
for (size_t c = 0; c < row.Size(); ++c) {
|
||||
auto e = row.GetElement(c);
|
||||
if (missing_ != e.value && !common::CheckNAN(e.value)) {
|
||||
workspace_[non_missing] =
|
||||
Entry{static_cast<bst_feature_t>(e.column_idx), e.value};
|
||||
++non_missing;
|
||||
}
|
||||
}
|
||||
auto ret = workspace_.subspan(beg, non_missing - beg);
|
||||
current_unroll_[t]++;
|
||||
if (current_unroll_[t] == kUnroll) {
|
||||
current_unroll_[t] = 0;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
size_t Size() const { return adapter_->NumRows(); }
|
||||
|
||||
bst_row_t const static base_rowid = 0; // NOLINT
|
||||
};
|
||||
|
||||
template <typename DataView>
|
||||
void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end,
|
||||
std::vector<RegTree::FVec> *p_thread_temp) {
|
||||
auto& thread_temp = *p_thread_temp;
|
||||
int32_t const num_group = model.learner_model_param_->num_output_group;
|
||||
|
||||
std::vector<bst_float> &preds = *out_preds;
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0)
|
||||
<< "size_leaf_vector is enforced to 0 so far";
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
auto constexpr kUnroll = DataView::kUnroll;
|
||||
const bst_omp_uint rest = nsize % kUnroll;
|
||||
if (nsize >= kUnroll) {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec &feats = thread_temp[tid];
|
||||
int64_t ridx[kUnroll];
|
||||
SparsePage::Inst inst[kUnroll];
|
||||
for (size_t k = 0; k < kUnroll; ++k) {
|
||||
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
|
||||
}
|
||||
for (size_t k = 0; k < kUnroll; ++k) {
|
||||
inst[k] = batch[i + k];
|
||||
}
|
||||
for (size_t k = 0; k < kUnroll; ++k) {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] += PredValue(inst[k], model.trees, model.tree_info, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
|
||||
RegTree::FVec &feats = thread_temp[0];
|
||||
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||
auto inst = batch[i];
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx * num_group + gid;
|
||||
preds[offset] += PredValue(inst, model.trees, model.tree_info, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class CPUPredictor : public Predictor {
|
||||
protected:
|
||||
static bst_float PredValue(const SparsePage::Inst& inst,
|
||||
const std::vector<std::unique_ptr<RegTree>>& trees,
|
||||
const std::vector<int>& tree_info, int bst_group,
|
||||
RegTree::FVec* p_feats,
|
||||
unsigned tree_begin, unsigned tree_end) {
|
||||
bst_float psum = 0.0f;
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||
if (tree_info[i] == bst_group) {
|
||||
int tid = trees[i]->GetLeafIndex(*p_feats);
|
||||
psum += (*trees[i])[tid].LeafValue();
|
||||
}
|
||||
}
|
||||
p_feats->Drop(inst);
|
||||
return psum;
|
||||
}
|
||||
|
||||
// init thread buffers
|
||||
inline void InitThreadTemp(int nthread, int num_feature) {
|
||||
int prev_thread_temp_size = thread_temp.size();
|
||||
static void InitThreadTemp(int nthread, int num_feature, std::vector<RegTree::FVec>* out) {
|
||||
int prev_thread_temp_size = out->size();
|
||||
if (prev_thread_temp_size < nthread) {
|
||||
thread_temp.resize(nthread, RegTree::FVec());
|
||||
out->resize(nthread, RegTree::FVec());
|
||||
for (int i = prev_thread_temp_size; i < nthread; ++i) {
|
||||
thread_temp[i].Init(num_feature);
|
||||
(*out)[i].Init(num_feature);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PredInternal(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end) {
|
||||
int32_t const num_group = model.learner_model_param_->num_output_group;
|
||||
const int nthread = omp_get_max_threads();
|
||||
InitThreadTemp(nthread, model.learner_model_param_->num_feature);
|
||||
std::vector<bst_float>& preds = *out_preds;
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0)
|
||||
<< "size_leaf_vector is enforced to 0 so far";
|
||||
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
|
||||
// start collecting the prediction
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// parallel over local batch
|
||||
constexpr int kUnroll = 8;
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
const bst_omp_uint rest = nsize % kUnroll;
|
||||
// Pull to host before entering omp block, as this is not thread safe.
|
||||
batch.data.HostVector();
|
||||
batch.offset.HostVector();
|
||||
if (nsize >= kUnroll) {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec& feats = thread_temp[tid];
|
||||
int64_t ridx[kUnroll];
|
||||
SparsePage::Inst inst[kUnroll];
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
inst[k] = batch[i + k];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] += this->PredValue(
|
||||
inst[k], model.trees, model.tree_info, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
|
||||
RegTree::FVec& feats = thread_temp[0];
|
||||
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||
auto inst = batch[i];
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx * num_group + gid;
|
||||
preds[offset] +=
|
||||
this->PredValue(inst, model.trees, model.tree_info, gid,
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end) {
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
const int threads = omp_get_max_threads();
|
||||
InitThreadTemp(threads, model.learner_model_param_->num_feature, &this->thread_temp_);
|
||||
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(),
|
||||
p_fmat->Info().num_row_ * model.learner_model_param_->num_output_group);
|
||||
size_t constexpr kUnroll = 8;
|
||||
PredictBatchKernel(SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
|
||||
tree_end, &thread_temp_);
|
||||
}
|
||||
}
|
||||
|
||||
@ -175,9 +251,9 @@ class CPUPredictor : public Predictor {
|
||||
CHECK_LE(beg_version, end_version);
|
||||
|
||||
if (beg_version < end_version) {
|
||||
this->PredInternal(dmat, &out_preds->HostVector(), model,
|
||||
beg_version * output_groups,
|
||||
end_version * output_groups);
|
||||
this->PredictDMatrix(dmat, &out_preds->HostVector(), model,
|
||||
beg_version * output_groups,
|
||||
end_version * output_groups);
|
||||
}
|
||||
|
||||
// delta means {size of forest} * {number of newly accumulated layers}
|
||||
@ -189,12 +265,49 @@ class CPUPredictor : public Predictor {
|
||||
out_preds->Size() == dmat->Info().num_row_);
|
||||
}
|
||||
|
||||
template <typename Adapter>
|
||||
void DispatchedInplacePredict(dmlc::any const &x,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, uint32_t tree_end) const {
|
||||
auto threads = omp_get_max_threads();
|
||||
auto m = dmlc::get<Adapter>(x);
|
||||
CHECK_EQ(m.NumColumns(), model.learner_model_param_->num_feature)
|
||||
<< "Number of columns in data must equal to trained model.";
|
||||
MetaInfo info;
|
||||
info.num_col_ = m.NumColumns();
|
||||
info.num_row_ = m.NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
std::vector<Entry> workspace(info.num_col_ * 8 * threads);
|
||||
auto &predictions = out_preds->predictions.HostVector();
|
||||
std::vector<RegTree::FVec> thread_temp;
|
||||
InitThreadTemp(threads, model.learner_model_param_->num_feature, &thread_temp);
|
||||
size_t constexpr kUnroll = 8;
|
||||
PredictBatchKernel(AdapterView<Adapter, kUnroll>(
|
||||
&m, missing, common::Span<Entry>{workspace}),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp);
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, unsigned tree_end) const override {
|
||||
if (x.type() == typeid(data::DenseAdapter)) {
|
||||
this->DispatchedInplacePredict<data::DenseAdapter>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(data::CSRAdapter)) {
|
||||
this->DispatchedInplacePredict<data::CSRAdapter>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else {
|
||||
LOG(FATAL) << "Data type is not supported by CPU Predictor.";
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
||||
if (thread_temp.size() == 0) {
|
||||
thread_temp.resize(1, RegTree::FVec());
|
||||
thread_temp[0].Init(model.learner_model_param_->num_feature);
|
||||
if (thread_temp_.size() == 0) {
|
||||
thread_temp_.resize(1, RegTree::FVec());
|
||||
thread_temp_[0].Init(model.learner_model_param_->num_feature);
|
||||
}
|
||||
ntree_limit *= model.learner_model_param_->num_output_group;
|
||||
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||
@ -204,16 +317,16 @@ class CPUPredictor : public Predictor {
|
||||
(model.param.size_leaf_vector + 1));
|
||||
// loop over output groups
|
||||
for (uint32_t gid = 0; gid < model.learner_model_param_->num_output_group; ++gid) {
|
||||
(*out_preds)[gid] =
|
||||
PredValue(inst, model.trees, model.tree_info, gid,
|
||||
&thread_temp[0], 0, ntree_limit) +
|
||||
model.learner_model_param_->base_score;
|
||||
(*out_preds)[gid] = PredValue(inst, model.trees, model.tree_info, gid,
|
||||
&thread_temp_[0], 0, ntree_limit) +
|
||||
model.learner_model_param_->base_score;
|
||||
}
|
||||
}
|
||||
|
||||
void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
||||
const int nthread = omp_get_max_threads();
|
||||
InitThreadTemp(nthread, model.learner_model_param_->num_feature);
|
||||
InitThreadTemp(nthread, model.learner_model_param_->num_feature, &this->thread_temp_);
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
// number of valid trees
|
||||
ntree_limit *= model.learner_model_param_->num_output_group;
|
||||
@ -230,7 +343,7 @@ class CPUPredictor : public Predictor {
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
auto ridx = static_cast<size_t>(batch.base_rowid + i);
|
||||
RegTree::FVec& feats = thread_temp[tid];
|
||||
RegTree::FVec &feats = thread_temp_[tid];
|
||||
feats.Fill(batch[i]);
|
||||
for (unsigned j = 0; j < ntree_limit; ++j) {
|
||||
int tid = model.trees[j]->GetLeafIndex(feats);
|
||||
@ -247,7 +360,7 @@ class CPUPredictor : public Predictor {
|
||||
bool approximate, int condition,
|
||||
unsigned condition_feature) override {
|
||||
const int nthread = omp_get_max_threads();
|
||||
InitThreadTemp(nthread, model.learner_model_param_->num_feature);
|
||||
InitThreadTemp(nthread, model.learner_model_param_->num_feature, &this->thread_temp_);
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
// number of valid trees
|
||||
ntree_limit *= model.learner_model_param_->num_output_group;
|
||||
@ -277,7 +390,7 @@ class CPUPredictor : public Predictor {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
auto row_idx = static_cast<size_t>(batch.base_rowid + i);
|
||||
RegTree::FVec& feats = thread_temp[omp_get_thread_num()];
|
||||
RegTree::FVec &feats = thread_temp_[omp_get_thread_num()];
|
||||
std::vector<bst_float> this_tree_contribs(ncolumns);
|
||||
// loop over all classes
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
@ -359,7 +472,10 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<RegTree::FVec> thread_temp;
|
||||
|
||||
private:
|
||||
std::mutex lock_;
|
||||
std::vector<RegTree::FVec> thread_temp_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor")
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
|
||||
#include "../gbm/gbtree_model.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../data/device_adapter.cuh"
|
||||
#include "../common/common.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
|
||||
@ -116,6 +117,76 @@ struct EllpackLoader {
|
||||
}
|
||||
};
|
||||
|
||||
struct CuPyAdapterLoader {
|
||||
data::CupyAdapterBatch batch;
|
||||
bst_feature_t columns;
|
||||
float* smem;
|
||||
bool use_shared;
|
||||
|
||||
DEV_INLINE CuPyAdapterLoader(data::CupyAdapterBatch const batch, bool use_shared,
|
||||
bst_feature_t num_features, bst_row_t num_rows, size_t entry_start) :
|
||||
batch{batch},
|
||||
columns{num_features},
|
||||
use_shared{use_shared} {
|
||||
extern __shared__ float _smem[];
|
||||
smem = _smem;
|
||||
if (use_shared) {
|
||||
uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
size_t shared_elements = blockDim.x * num_features;
|
||||
dh::BlockFill(smem, shared_elements, nanf(""));
|
||||
__syncthreads();
|
||||
if (global_idx < num_rows) {
|
||||
auto beg = global_idx * columns;
|
||||
auto end = (global_idx + 1) * columns;
|
||||
for (size_t i = beg; i < end; ++i) {
|
||||
smem[threadIdx.x * num_features + (i - beg)] = batch.GetElement(i).value;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
|
||||
if (use_shared) {
|
||||
return smem[threadIdx.x * columns + fidx];
|
||||
}
|
||||
return batch.GetElement(ridx * columns + fidx).value;
|
||||
}
|
||||
};
|
||||
|
||||
struct CuDFAdapterLoader {
|
||||
data::CudfAdapterBatch batch;
|
||||
bst_feature_t columns;
|
||||
float* smem;
|
||||
bool use_shared;
|
||||
|
||||
DEV_INLINE CuDFAdapterLoader(data::CudfAdapterBatch const batch, bool use_shared,
|
||||
bst_feature_t num_features,
|
||||
bst_row_t num_rows, size_t entry_start)
|
||||
: batch{batch}, columns{num_features}, use_shared{use_shared} {
|
||||
extern __shared__ float _smem[];
|
||||
smem = _smem;
|
||||
if (use_shared) {
|
||||
uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
size_t shared_elements = blockDim.x * num_features;
|
||||
dh::BlockFill(smem, shared_elements, nanf(""));
|
||||
__syncthreads();
|
||||
if (global_idx < num_rows) {
|
||||
for (size_t i = 0; i < columns; ++i) {
|
||||
smem[threadIdx.x * columns + i] = batch.GetValue(global_idx, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
|
||||
if (use_shared) {
|
||||
return smem[threadIdx.x * columns + fidx];
|
||||
}
|
||||
return batch.GetValue(ridx, fidx);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Loader>
|
||||
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
|
||||
Loader* loader) {
|
||||
@ -169,30 +240,61 @@ __global__ void PredictKernel(Data data,
|
||||
}
|
||||
}
|
||||
|
||||
class GPUPredictor : public xgboost::Predictor {
|
||||
private:
|
||||
void InitModel(const gbm::GBTreeModel& model,
|
||||
class DeviceModel {
|
||||
public:
|
||||
dh::device_vector<RegTree::Node> nodes;
|
||||
dh::device_vector<size_t> tree_segments;
|
||||
dh::device_vector<int> tree_group;
|
||||
size_t tree_beg_; // NOLINT
|
||||
size_t tree_end_; // NOLINT
|
||||
int num_group;
|
||||
|
||||
void CopyModel(const gbm::GBTreeModel& model,
|
||||
const thrust::host_vector<size_t>& h_tree_segments,
|
||||
const thrust::host_vector<RegTree::Node>& h_nodes,
|
||||
size_t tree_begin, size_t tree_end) {
|
||||
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||
nodes_.resize(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
|
||||
nodes.resize(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
|
||||
sizeof(RegTree::Node) * h_nodes.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_segments_.resize(h_tree_segments.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
|
||||
tree_segments.resize(h_tree_segments.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_segments.data().get(), h_tree_segments.data(),
|
||||
sizeof(size_t) * h_tree_segments.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_group_.resize(model.tree_info.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(),
|
||||
tree_group.resize(model.tree_info.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_group.data().get(), model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
this->tree_begin_ = tree_begin;
|
||||
this->tree_beg_ = tree_begin;
|
||||
this->tree_end_ = tree_end;
|
||||
this->num_group_ = model.learner_model_param_->num_output_group;
|
||||
this->num_group = model.learner_model_param_->num_output_group;
|
||||
}
|
||||
|
||||
void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, int32_t gpu_id) {
|
||||
dh::safe_cuda(cudaSetDevice(gpu_id));
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0);
|
||||
// Copy decision trees to device
|
||||
thrust::host_vector<size_t> h_tree_segments{};
|
||||
h_tree_segments.reserve((tree_end - tree_begin) + 1);
|
||||
size_t sum = 0;
|
||||
h_tree_segments.push_back(sum);
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
sum += model.trees.at(tree_idx)->GetNodes().size();
|
||||
h_tree_segments.push_back(sum);
|
||||
}
|
||||
|
||||
thrust::host_vector<RegTree::Node> h_nodes(h_tree_segments.back());
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
|
||||
std::copy(src_nodes.begin(), src_nodes.end(),
|
||||
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
|
||||
}
|
||||
CopyModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
|
||||
}
|
||||
};
|
||||
|
||||
class GPUPredictor : public xgboost::Predictor {
|
||||
private:
|
||||
void PredictInternal(const SparsePage& batch, size_t num_features,
|
||||
HostDeviceVector<bst_float>* predictions,
|
||||
size_t batch_offset) {
|
||||
@ -214,10 +316,10 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
PredictKernel<SparsePageLoader, SparsePageView>,
|
||||
data,
|
||||
dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
|
||||
this->tree_begin_, this->tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, this->num_group_);
|
||||
dh::ToSpan(model_.nodes), predictions->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(model_.tree_segments), dh::ToSpan(model_.tree_group),
|
||||
model_.tree_beg_, model_.tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, model_.num_group);
|
||||
}
|
||||
void PredictInternal(EllpackDeviceAccessor const& batch, HostDeviceVector<bst_float>* out_preds,
|
||||
size_t batch_offset) {
|
||||
@ -230,31 +332,10 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
|
||||
PredictKernel<EllpackLoader, EllpackDeviceAccessor>,
|
||||
batch,
|
||||
dh::ToSpan(nodes_), out_preds->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
|
||||
this->tree_begin_, this->tree_end_, batch.NumFeatures(), num_rows,
|
||||
entry_start, use_shared, this->num_group_);
|
||||
}
|
||||
|
||||
void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0);
|
||||
// Copy decision trees to device
|
||||
thrust::host_vector<size_t> h_tree_segments{};
|
||||
h_tree_segments.reserve((tree_end - tree_begin) + 1);
|
||||
size_t sum = 0;
|
||||
h_tree_segments.push_back(sum);
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
sum += model.trees.at(tree_idx)->GetNodes().size();
|
||||
h_tree_segments.push_back(sum);
|
||||
}
|
||||
|
||||
thrust::host_vector<RegTree::Node> h_nodes(h_tree_segments.back());
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
|
||||
std::copy(src_nodes.begin(), src_nodes.end(),
|
||||
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
|
||||
}
|
||||
InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
|
||||
dh::ToSpan(model_.nodes), out_preds->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(model_.tree_segments), dh::ToSpan(model_.tree_group),
|
||||
model_.tree_beg_, model_.tree_end_, batch.NumFeatures(), num_rows,
|
||||
entry_start, use_shared, model_.num_group);
|
||||
}
|
||||
|
||||
void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
|
||||
@ -264,8 +345,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
if (tree_end - tree_begin == 0) {
|
||||
return;
|
||||
}
|
||||
monitor_.StartCuda("DevicePredictInternal");
|
||||
InitModel(model, tree_begin, tree_end);
|
||||
model_.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
|
||||
out_preds->SetDevice(generic_param_->gpu_id);
|
||||
|
||||
if (dmat->PageExists<EllpackPage>()) {
|
||||
@ -284,7 +364,6 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
batch_offset += batch.Size() * model.learner_model_param_->num_output_group;
|
||||
}
|
||||
}
|
||||
monitor_.StopCuda("DevicePredictInternal");
|
||||
}
|
||||
|
||||
public:
|
||||
@ -302,6 +381,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
unsigned ntree_limit = 0) override {
|
||||
// This function is duplicated with CPU predictor PredictBatch, see comments in there.
|
||||
// FIXME(trivialfis): Remove the duplication.
|
||||
std::lock_guard<std::mutex> const guard(lock_);
|
||||
int device = generic_param_->gpu_id;
|
||||
CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
|
||||
ConfigureDevice(device);
|
||||
@ -348,6 +428,63 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
out_preds->Size() == dmat->Info().num_row_);
|
||||
}
|
||||
|
||||
template <typename Adapter, typename Loader, typename Batch>
|
||||
void DispatchedInplacePredict(dmlc::any const &x,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, uint32_t tree_end) const {
|
||||
auto max_shared_memory_bytes = dh::MaxSharedMemory(this->generic_param_->gpu_id);
|
||||
uint32_t const output_groups = model.learner_model_param_->num_output_group;
|
||||
DeviceModel d_model;
|
||||
d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id);
|
||||
|
||||
auto m = dmlc::get<Adapter>(x);
|
||||
CHECK_EQ(m.NumColumns(), model.learner_model_param_->num_feature)
|
||||
<< "Number of columns in data must equal to trained model.";
|
||||
CHECK_EQ(this->generic_param_->gpu_id, m.DeviceIdx())
|
||||
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
|
||||
<< "but data is on: " << m.DeviceIdx();
|
||||
MetaInfo info;
|
||||
info.num_col_ = m.NumColumns();
|
||||
info.num_row_ = m.NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
|
||||
const uint32_t BLOCK_THREADS = 128;
|
||||
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
|
||||
|
||||
auto shared_memory_bytes =
|
||||
static_cast<size_t>(sizeof(float) * m.NumColumns() * BLOCK_THREADS);
|
||||
bool use_shared = true;
|
||||
if (shared_memory_bytes > max_shared_memory_bytes) {
|
||||
shared_memory_bytes = 0;
|
||||
use_shared = false;
|
||||
}
|
||||
size_t entry_start = 0;
|
||||
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
PredictKernel<Loader, Batch>,
|
||||
m.Value(),
|
||||
dh::ToSpan(d_model.nodes), out_preds->predictions.DeviceSpan(),
|
||||
dh::ToSpan(d_model.tree_segments), dh::ToSpan(d_model.tree_group),
|
||||
tree_begin, tree_end, m.NumColumns(), info.num_row_,
|
||||
entry_start, use_shared, output_groups);
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, unsigned tree_end) const override {
|
||||
auto max_shared_memory_bytes = dh::MaxSharedMemory(this->generic_param_->gpu_id);
|
||||
if (x.type() == typeid(data::CupyAdapter)) {
|
||||
this->DispatchedInplacePredict<data::CupyAdapter, CuPyAdapterLoader, data::CupyAdapterBatch>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(data::CudfAdapter)) {
|
||||
this->DispatchedInplacePredict<data::CudfAdapter, CuDFAdapterLoader, data::CudfAdapterBatch>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else {
|
||||
LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor.";
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitOutPredictions(const MetaInfo& info,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
@ -411,14 +548,9 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
}
|
||||
|
||||
common::Monitor monitor_;
|
||||
dh::device_vector<RegTree::Node> nodes_;
|
||||
dh::device_vector<size_t> tree_segments_;
|
||||
dh::device_vector<int> tree_group_;
|
||||
std::mutex lock_;
|
||||
DeviceModel model_;
|
||||
size_t max_shared_memory_bytes_;
|
||||
size_t tree_begin_;
|
||||
size_t tree_end_;
|
||||
int num_group_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
|
||||
|
||||
@ -2,8 +2,9 @@
|
||||
* Copyright 2017-2020 by Contributors
|
||||
*/
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <mutex>
|
||||
|
||||
#include "xgboost/predictor.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/generic_parameters.h"
|
||||
|
||||
@ -25,6 +26,7 @@ void PredictionContainer::ClearExpiredEntries() {
|
||||
}
|
||||
|
||||
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
||||
std::lock_guard<std::mutex> guard { cache_lock_ };
|
||||
this->ClearExpiredEntries();
|
||||
container_[m.get()].ref = m;
|
||||
if (device != GenericParameter::kCpuId) {
|
||||
|
||||
@ -177,9 +177,8 @@ void RandomDataGenerator::GenerateDense(HostDeviceVector<float> *out) const {
|
||||
}
|
||||
}
|
||||
|
||||
void RandomDataGenerator::GenerateArrayInterface(
|
||||
HostDeviceVector<float> *storage, std::string *out) const {
|
||||
CHECK(out);
|
||||
Json RandomDataGenerator::ArrayInterfaceImpl(HostDeviceVector<float> *storage,
|
||||
size_t rows, size_t cols) const {
|
||||
this->GenerateDense(storage);
|
||||
Json array_interface {Object()};
|
||||
array_interface["data"] = std::vector<Json>(2);
|
||||
@ -187,13 +186,37 @@ void RandomDataGenerator::GenerateArrayInterface(
|
||||
array_interface["data"][1] = Boolean(false);
|
||||
|
||||
array_interface["shape"] = std::vector<Json>(2);
|
||||
array_interface["shape"][0] = rows_;
|
||||
array_interface["shape"][1] = cols_;
|
||||
array_interface["shape"][0] = rows;
|
||||
array_interface["shape"][1] = cols;
|
||||
|
||||
array_interface["typestr"] = String("<f4");
|
||||
array_interface["version"] = 1;
|
||||
return array_interface;
|
||||
}
|
||||
|
||||
Json::Dump(array_interface, out);
|
||||
std::string RandomDataGenerator::GenerateArrayInterface(
|
||||
HostDeviceVector<float> *storage) const {
|
||||
auto array_interface = this->ArrayInterfaceImpl(storage, rows_, cols_);
|
||||
std::string out;
|
||||
Json::Dump(array_interface, &out);
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::string RandomDataGenerator::GenerateColumnarArrayInterface(
|
||||
std::vector<HostDeviceVector<float>> *data) const {
|
||||
CHECK(data);
|
||||
CHECK_EQ(data->size(), cols_);
|
||||
auto& storage = *data;
|
||||
Json arr { Array() };
|
||||
for (size_t i = 0; i < cols_; ++i) {
|
||||
auto column = this->ArrayInterfaceImpl(&storage[i], rows_, 1);
|
||||
get<Array>(arr).emplace_back(column);
|
||||
}
|
||||
std::string out;
|
||||
Json::Dump(arr, &out);
|
||||
return out;
|
||||
}
|
||||
|
||||
void RandomDataGenerator::GenerateCSR(
|
||||
|
||||
@ -181,6 +181,9 @@ class RandomDataGenerator {
|
||||
int32_t device_;
|
||||
int32_t seed_;
|
||||
|
||||
Json ArrayInterfaceImpl(HostDeviceVector<float> *storage, size_t rows,
|
||||
size_t cols) const;
|
||||
|
||||
public:
|
||||
RandomDataGenerator(bst_row_t rows, size_t cols, float sparsity)
|
||||
: rows_{rows}, cols_{cols}, sparsity_{sparsity}, lower_{0.0f}, upper_{1.0f},
|
||||
@ -204,7 +207,9 @@ class RandomDataGenerator {
|
||||
}
|
||||
|
||||
void GenerateDense(HostDeviceVector<float>* out) const;
|
||||
void GenerateArrayInterface(HostDeviceVector<float>* storage, std::string* out) const;
|
||||
std::string GenerateArrayInterface(HostDeviceVector<float>* storage) const;
|
||||
std::string GenerateColumnarArrayInterface(
|
||||
std::vector<HostDeviceVector<float>> *data) const;
|
||||
void GenerateCSR(HostDeviceVector<float>* value, HostDeviceVector<bst_row_t>* row_ptr,
|
||||
HostDeviceVector<bst_feature_t>* columns) const;
|
||||
|
||||
|
||||
@ -6,7 +6,9 @@
|
||||
#include <xgboost/predictor.h>
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "test_predictor.h"
|
||||
#include "../../../src/gbm/gbtree_model.h"
|
||||
#include "../../../src/data/adapter.h"
|
||||
|
||||
namespace xgboost {
|
||||
TEST(CpuPredictor, Basic) {
|
||||
@ -138,4 +140,27 @@ TEST(CpuPredictor, ExternalMemory) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, InplacePredict) {
|
||||
bst_row_t constexpr kRows{128};
|
||||
bst_feature_t constexpr kCols{64};
|
||||
auto gen = RandomDataGenerator{kRows, kCols, 0.5}.Device(-1);
|
||||
{
|
||||
HostDeviceVector<float> data;
|
||||
gen.GenerateDense(&data);
|
||||
ASSERT_EQ(data.Size(), kRows * kCols);
|
||||
data::DenseAdapter x{data.HostPointer(), kRows, kCols};
|
||||
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
|
||||
}
|
||||
|
||||
{
|
||||
HostDeviceVector<float> data;
|
||||
HostDeviceVector<bst_row_t> rptrs;
|
||||
HostDeviceVector<bst_feature_t> columns;
|
||||
gen.GenerateCSR(&data, &rptrs, &columns);
|
||||
data::CSRAdapter x(rptrs.HostPointer(), columns.HostPointer(),
|
||||
data.HostPointer(), kRows, data.Size(), kCols);
|
||||
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,16 +1,17 @@
|
||||
/*!
|
||||
* Copyright 2017-2020 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/learner.h>
|
||||
|
||||
#include <string>
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/gbm/gbtree_model.h"
|
||||
#include "../../../src/data/device_adapter.cuh"
|
||||
#include "test_predictor.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -104,5 +105,43 @@ TEST(GPUPredictor, ExternalMemoryTest) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, InplacePredictCupy) {
|
||||
size_t constexpr kRows{128}, kCols{64};
|
||||
RandomDataGenerator gen(kRows, kCols, 0.5);
|
||||
gen.Device(0);
|
||||
HostDeviceVector<float> data;
|
||||
std::string interface_str = gen.GenerateArrayInterface(&data);
|
||||
data::CupyAdapter x{interface_str};
|
||||
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, InplacePredictCuDF) {
|
||||
size_t constexpr kRows{128}, kCols{64};
|
||||
RandomDataGenerator gen(kRows, kCols, 0.5);
|
||||
gen.Device(0);
|
||||
std::vector<HostDeviceVector<float>> storage(kCols);
|
||||
auto interface_str = gen.GenerateColumnarArrayInterface(&storage);
|
||||
data::CudfAdapter x {interface_str};
|
||||
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, MGPU_InplacePredict) {
|
||||
int32_t n_gpus = xgboost::common::AllVisibleGPUs();
|
||||
if (n_gpus <= 1) {
|
||||
LOG(WARNING) << "GPUPredictor.MGPU_InplacePredict is skipped.";
|
||||
return;
|
||||
}
|
||||
size_t constexpr kRows{128}, kCols{64};
|
||||
RandomDataGenerator gen(kRows, kCols, 0.5);
|
||||
gen.Device(1);
|
||||
HostDeviceVector<float> data;
|
||||
std::string interface_str = gen.GenerateArrayInterface(&data);
|
||||
data::CupyAdapter x{interface_str};
|
||||
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 1);
|
||||
EXPECT_THROW(TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0),
|
||||
dmlc::Error);
|
||||
}
|
||||
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
|
||||
@ -77,4 +77,59 @@ void TestTrainingPrediction(size_t rows, std::string tree_method) {
|
||||
predictions_0.ConstHostVector()[i], kRtEps);
|
||||
}
|
||||
}
|
||||
|
||||
void TestInplacePrediction(dmlc::any x, std::string predictor,
|
||||
bst_row_t rows, bst_feature_t cols,
|
||||
int32_t device) {
|
||||
size_t constexpr kClasses { 4 };
|
||||
auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(device);
|
||||
std::shared_ptr<DMatrix> m = gen.GenerateDMatix(true, false, kClasses);
|
||||
|
||||
std::unique_ptr<Learner> learner {
|
||||
Learner::Create({m})
|
||||
};
|
||||
|
||||
learner->SetParam("num_parallel_tree", "4");
|
||||
learner->SetParam("num_class", std::to_string(kClasses));
|
||||
learner->SetParam("seed", "0");
|
||||
learner->SetParam("subsample", "0.5");
|
||||
learner->SetParam("gpu_id", std::to_string(device));
|
||||
learner->SetParam("predictor", predictor);
|
||||
for (int32_t it = 0; it < 4; ++it) {
|
||||
learner->UpdateOneIter(it, m);
|
||||
}
|
||||
|
||||
HostDeviceVector<float> *p_out_predictions_0{nullptr};
|
||||
learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(),
|
||||
&p_out_predictions_0, 0, 2);
|
||||
CHECK(p_out_predictions_0);
|
||||
HostDeviceVector<float> predict_0 (p_out_predictions_0->Size());
|
||||
predict_0.Copy(*p_out_predictions_0);
|
||||
|
||||
HostDeviceVector<float> *p_out_predictions_1{nullptr};
|
||||
learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(),
|
||||
&p_out_predictions_1, 2, 4);
|
||||
CHECK(p_out_predictions_1);
|
||||
HostDeviceVector<float> predict_1 (p_out_predictions_1->Size());
|
||||
predict_1.Copy(*p_out_predictions_1);
|
||||
|
||||
HostDeviceVector<float>* p_out_predictions{nullptr};
|
||||
learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(),
|
||||
&p_out_predictions, 0, 4);
|
||||
|
||||
auto& h_pred = p_out_predictions->HostVector();
|
||||
auto& h_pred_0 = predict_0.HostVector();
|
||||
auto& h_pred_1 = predict_1.HostVector();
|
||||
|
||||
ASSERT_EQ(h_pred.size(), rows * kClasses);
|
||||
ASSERT_EQ(h_pred.size(), h_pred_0.size());
|
||||
ASSERT_EQ(h_pred.size(), h_pred_1.size());
|
||||
for (size_t i = 0; i < h_pred.size(); ++i) {
|
||||
// Need to remove the global bias here.
|
||||
ASSERT_NEAR(h_pred[i], h_pred_0[i] + h_pred_1[i] - 0.5f, kRtEps);
|
||||
}
|
||||
|
||||
learner->SetParam("gpu_id", "-1");
|
||||
learner->Configure();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -58,6 +58,9 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins
|
||||
|
||||
void TestTrainingPrediction(size_t rows, std::string tree_method);
|
||||
|
||||
void TestInplacePrediction(dmlc::any x, std::string predictor,
|
||||
bst_row_t rows, bst_feature_t cols,
|
||||
int32_t device = -1);
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_TEST_PREDICTOR_H_
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
from __future__ import print_function
|
||||
import sys
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import xgboost as xgb
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
from test_predict import run_threaded_predict # noqa
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
@ -111,3 +115,65 @@ class TestGPUPredict(unittest.TestCase):
|
||||
|
||||
assert np.allclose(cpu_train_score, gpu_train_score)
|
||||
assert np.allclose(cpu_test_score, gpu_test_score)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_inplace_predict_cupy(self):
|
||||
import cupy as cp
|
||||
rows = 1000
|
||||
cols = 10
|
||||
cp_rng = cp.random.RandomState(1994)
|
||||
cp.random.set_random_state(cp_rng)
|
||||
X = cp.random.randn(rows, cols)
|
||||
y = cp.random.randn(rows)
|
||||
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
|
||||
booster = xgb.train({'tree_method': 'gpu_hist'},
|
||||
dtrain, num_boost_round=10)
|
||||
test = xgb.DMatrix(X[:10, ...])
|
||||
predt_from_array = booster.inplace_predict(X[:10, ...])
|
||||
predt_from_dmatrix = booster.predict(test)
|
||||
|
||||
cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix)
|
||||
|
||||
def predict_dense(x):
|
||||
inplace_predt = booster.inplace_predict(x)
|
||||
d = xgb.DMatrix(x)
|
||||
copied_predt = cp.array(booster.predict(d))
|
||||
return cp.all(copied_predt == inplace_predt)
|
||||
|
||||
for i in range(10):
|
||||
run_threaded_predict(X, rows, predict_dense)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_inplace_predict_cudf(self):
|
||||
import cupy as cp
|
||||
import cudf
|
||||
import pandas as pd
|
||||
rows = 1000
|
||||
cols = 10
|
||||
rng = np.random.RandomState(1994)
|
||||
X = rng.randn(rows, cols)
|
||||
X = pd.DataFrame(X)
|
||||
y = rng.randn(rows)
|
||||
|
||||
X = cudf.from_pandas(X)
|
||||
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
|
||||
booster = xgb.train({'tree_method': 'gpu_hist'},
|
||||
dtrain, num_boost_round=10)
|
||||
test = xgb.DMatrix(X)
|
||||
predt_from_array = booster.inplace_predict(X)
|
||||
predt_from_dmatrix = booster.predict(test)
|
||||
|
||||
cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix)
|
||||
|
||||
def predict_df(x):
|
||||
inplace_predt = booster.inplace_predict(x)
|
||||
d = xgb.DMatrix(x)
|
||||
copied_predt = cp.array(booster.predict(d))
|
||||
return cp.all(copied_predt == inplace_predt)
|
||||
|
||||
for i in range(10):
|
||||
run_threaded_predict(X, rows, predict_df)
|
||||
|
||||
@ -2,6 +2,7 @@ import sys
|
||||
import pytest
|
||||
import numpy as np
|
||||
import unittest
|
||||
import xgboost
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
@ -29,6 +30,7 @@ class TestDistributedGPU(unittest.TestCase):
|
||||
def test_dask_dataframe(self):
|
||||
with LocalCUDACluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
import cupy
|
||||
X, y = generate_array()
|
||||
|
||||
X = dd.from_dask_array(X)
|
||||
@ -49,6 +51,42 @@ class TestDistributedGPU(unittest.TestCase):
|
||||
predictions = dxgb.predict(client, out, dtrain).compute()
|
||||
assert isinstance(predictions, np.ndarray)
|
||||
|
||||
# There's an error with cudf saying `concat_cudf` got an
|
||||
# expected argument `ignore_index`. So the test here is just
|
||||
# place holder.
|
||||
|
||||
# series_predictions = dxgb.inplace_predict(client, out, X)
|
||||
# assert isinstance(series_predictions, dd.Series)
|
||||
|
||||
single_node = out['booster'].predict(
|
||||
xgboost.DMatrix(X.compute()))
|
||||
cupy.testing.assert_allclose(single_node, predictions)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_dask_array(self):
|
||||
with LocalCUDACluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
import cupy
|
||||
X, y = generate_array()
|
||||
|
||||
X = X.map_blocks(cupy.asarray)
|
||||
y = y.map_blocks(cupy.asarray)
|
||||
dtrain = dxgb.DaskDMatrix(client, X, y)
|
||||
out = dxgb.train(client, {'tree_method': 'gpu_hist'},
|
||||
dtrain=dtrain,
|
||||
evals=[(dtrain, 'X')],
|
||||
num_boost_round=2)
|
||||
from_dmatrix = dxgb.predict(client, out, dtrain).compute()
|
||||
inplace_predictions = dxgb.inplace_predict(
|
||||
client, out, X).compute()
|
||||
single_node = out['booster'].predict(
|
||||
xgboost.DMatrix(X.compute()))
|
||||
np.testing.assert_allclose(single_node, from_dmatrix)
|
||||
cupy.testing.assert_allclose(
|
||||
cupy.array(single_node),
|
||||
inplace_predictions)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
@pytest.mark.mgpu
|
||||
|
||||
63
tests/python/test_predict.py
Normal file
63
tests/python/test_predict.py
Normal file
@ -0,0 +1,63 @@
|
||||
'''Tests for running inplace prediction.'''
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import numpy as np
|
||||
from scipy import sparse
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
|
||||
def run_threaded_predict(X, rows, predict_func):
|
||||
results = []
|
||||
per_thread = 20
|
||||
with ThreadPoolExecutor(max_workers=10) as e:
|
||||
for i in range(0, rows, int(rows / per_thread)):
|
||||
try:
|
||||
predictor = X[i:i+per_thread, ...]
|
||||
except TypeError:
|
||||
predictor = X.iloc[i:i+per_thread, ...]
|
||||
f = e.submit(predict_func, predictor)
|
||||
results.append(f)
|
||||
|
||||
for f in results:
|
||||
assert f.result()
|
||||
|
||||
|
||||
class TestInplacePredict(unittest.TestCase):
|
||||
'''Tests for running inplace prediction'''
|
||||
def test_predict(self):
|
||||
rows = 1000
|
||||
cols = 10
|
||||
|
||||
np.random.seed(1994)
|
||||
|
||||
X = np.random.randn(rows, cols)
|
||||
y = np.random.randn(rows)
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
|
||||
booster = xgb.train({'tree_method': 'hist'},
|
||||
dtrain, num_boost_round=10)
|
||||
|
||||
test = xgb.DMatrix(X[:10, ...])
|
||||
predt_from_array = booster.inplace_predict(X[:10, ...])
|
||||
predt_from_dmatrix = booster.predict(test)
|
||||
|
||||
np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)
|
||||
|
||||
def predict_dense(x):
|
||||
inplace_predt = booster.inplace_predict(x)
|
||||
d = xgb.DMatrix(x)
|
||||
copied_predt = booster.predict(d)
|
||||
return np.all(copied_predt == inplace_predt)
|
||||
|
||||
for i in range(10):
|
||||
run_threaded_predict(X, rows, predict_dense)
|
||||
|
||||
def predict_csr(x):
|
||||
inplace_predt = booster.inplace_predict(sparse.csr_matrix(x))
|
||||
d = xgb.DMatrix(x)
|
||||
copied_predt = booster.predict(d)
|
||||
return np.all(copied_predt == inplace_predt)
|
||||
|
||||
for i in range(10):
|
||||
run_threaded_predict(X, rows, predict_csr)
|
||||
@ -63,8 +63,14 @@ def test_from_dask_dataframe():
|
||||
from_df = prediction.compute()
|
||||
|
||||
assert isinstance(prediction, dd.Series)
|
||||
assert np.all(prediction.compute().values == from_dmatrix)
|
||||
assert np.all(from_dmatrix == from_df.to_numpy())
|
||||
|
||||
series_predictions = xgb.dask.inplace_predict(client, booster, X)
|
||||
assert isinstance(series_predictions, dd.Series)
|
||||
np.testing.assert_allclose(series_predictions.compute().values,
|
||||
from_dmatrix)
|
||||
|
||||
|
||||
def test_from_dask_array():
|
||||
with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user