Use Predictor for dart. (#6693)

* Use normal predictor for dart booster.
* Implement `inplace_predict` for dart.
* Enable `dart` for dask interface now that it's thread-safe.
* categorical data should be working out of box for dart now.

The implementation is not very efficient as it has to pull back the data and
apply weight for each tree, but still a significant improvement over previous
implementation as now we no longer binary search for each sample.

* Fix output prediction shape on dataframe.
This commit is contained in:
Jiaming Yuan
2021-02-09 23:30:19 +08:00
committed by GitHub
parent dbf7e9d3cb
commit e8c5c53e2f
13 changed files with 246 additions and 180 deletions

View File

@@ -119,6 +119,17 @@ class Predictor {
*/
virtual void Configure(const std::vector<std::pair<std::string, std::string>>&);
/**
* \brief Initialize output prediction
*
* \param info Meta info for the DMatrix object used for prediction.
* \param out_predt Prediction vector to be initialized.
* \param model Tree model used for prediction.
*/
virtual void InitOutPredictions(const MetaInfo &info,
HostDeviceVector<bst_float> *out_predt,
const gbm::GBTreeModel &model) const = 0;
/**
* \brief Generate batch predictions for a given feature matrix. May use
* cached predictions if available instead of calculating from scratch.