Use Predictor for dart. (#6693)

* Use normal predictor for dart booster.
* Implement `inplace_predict` for dart.
* Enable `dart` for dask interface now that it's thread-safe.
* Categorical data should now work out of the box for dart.

The implementation is not very efficient, as it has to pull back the data and
apply the weight for each tree, but it is still a significant improvement over the
previous implementation, since we no longer binary search for each sample.

* Fix output prediction shape on dataframe.
This commit is contained in:
Jiaming Yuan
2021-02-09 23:30:19 +08:00
committed by GitHub
parent dbf7e9d3cb
commit e8c5c53e2f
13 changed files with 246 additions and 180 deletions

View File

@@ -201,7 +201,7 @@ class CPUPredictor : public Predictor {
void InitOutPredictions(const MetaInfo& info,
HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model) const {
const gbm::GBTreeModel& model) const override {
CHECK_NE(model.learner_model_param->num_output_group, 0);
size_t n = model.learner_model_param->num_output_group * info.num_row_;
const auto& base_margin = info.base_margin_.HostVector();
@@ -234,26 +234,16 @@ class CPUPredictor : public Predictor {
public:
// Construct a CPU predictor; `generic_param` is forwarded unchanged to the
// Predictor base-class constructor.
explicit CPUPredictor(GenericParameter const* generic_param) :
Predictor::Predictor{generic_param} {}
// Run prediction for `dmat` using the trees of `model` selected by `tree_begin`
// and `tree_end`, handing the host vector of `predts->predictions` to
// PredictDMatrix.
//
// dmat       - input data; its MetaInfo is read for the row count and for
//              initialising the output.
// predts     - cache entry whose `predictions` buffer and `version` are used.
// model      - tree model to predict with.
// tree_begin - first tree of the range.
// tree_end   - end of the range; 0 means "use model.trees.size()".
//
// NOTE(review): this span comes from a diff hunk rendered without +/- markers,
// so some statements below may exist only in the pre-change version of the
// method - confirm against the actual file before relying on them.
void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
const gbm::GBTreeModel &model, uint32_t tree_begin,
uint32_t tree_end = 0) const override {
auto* out_preds = &predts->predictions;
// An empty prediction buffer for a non-empty matrix is only expected when the
// cache entry is still at version 0.
if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
CHECK_EQ(predts->version, 0);
}
// This is actually already handled in gbm, but a large number of tests rely on
// this behaviour.
if (tree_end == 0) {
tree_end = model.trees.size();
}
if (predts->version == 0) {
// out_preds->Size() can be non-zero as it's initialized here before any tree is
// built at the 0th iteration.
this->InitOutPredictions(dmat->Info(), out_preds, model);
}
// Empty tree range: nothing to accumulate, so return without calling
// PredictDMatrix.
if (tree_end - tree_begin == 0) {
return;
}
this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin,
tree_end);
}