Use Predictor for dart. (#6693)

* Use normal predictor for dart booster.
* Implement `inplace_predict` for dart.
* Enable `dart` for dask interface now that it's thread-safe.
* categorical data should be working out of box for dart now.

The implementation is not very efficient as it has to pull back the data and
apply weight for each tree, but still a significant improvement over previous
implementation as now we no longer binary search for each sample.

* Fix output prediction shape on dataframe.
This commit is contained in:
Jiaming Yuan
2021-02-09 23:30:19 +08:00
committed by GitHub
parent dbf7e9d3cb
commit e8c5c53e2f
13 changed files with 246 additions and 180 deletions

View File

@@ -31,6 +31,7 @@ TEST(CpuPredictor, Basic) {
// Test predict batch
PredictionCacheEntry out_predictions;
cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
@@ -107,6 +108,7 @@ TEST(CpuPredictor, ExternalMemory) {
// Test predict batch
PredictionCacheEntry out_predictions;
cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
ASSERT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_);