diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 4664ada3e..5ccc05a12 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -119,6 +119,17 @@ class Predictor { */ virtual void Configure(const std::vector>&); + /** + * \brief Initialize output prediction + * + * \param info Meta info for the DMatrix object used for prediction. + * \param out_predt Prediction vector to be initialized. + * \param model Tree model used for prediction. + */ + virtual void InitOutPredictions(const MetaInfo &info, + HostDeviceVector *out_predt, + const gbm::GBTreeModel &model) const = 0; + /** * \brief Generate batch predictions for a given feature matrix. May use * cached predictions if available instead of calculating from scratch. diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 4ce2a3245..18067c889 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -804,7 +804,7 @@ async def _train_async( workers = list(_get_workers_from_data(dtrain, evals)) _rabit_args = await _get_rabit_args(len(workers), client) - if params.get("booster", None) is not None and params["booster"] != "gbtree": + if params.get("booster", None) == "gblinear": raise NotImplementedError( f"booster `{params['booster']}` is not yet supported for dask." ) @@ -949,6 +949,15 @@ async def _direct_predict_impl( meta: Dict[int, str], ) -> _DaskCollection: columns = list(meta.keys()) + if len(output_shape) >= 3 and isinstance(data, dd.DataFrame): + # Without this check, dask will finish the prediction silently even if output + # dimension is greater than 3. But during map_partitions, dask passes a + # `dd.DataFrame` as local input to xgboost, which is converted to csr_matrix by + # `_convert_unknown_data` since dd.DataFrame is not known to xgboost native + # binding. + raise ValueError( + "Use `da.Array` or `DaskDMatrix` when output has more than 2 dimensions." + ) if _can_output_df(isinstance(data, dd.DataFrame), output_shape): if base_margin is not None and isinstance(base_margin, da.Array): # Easier for map_partitions @@ -1012,6 +1021,7 @@ def _infer_predict_output( if kwargs.pop("predict_type") == "margin": kwargs["output_margin"] = True m = DMatrix(test_sample) + # generated DMatrix doesn't have feature name, so no validation. test_predt = booster.predict(m, validate_features=False, **kwargs) n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1 meta: Dict[int, str] = {} @@ -1098,6 +1108,7 @@ async def _predict_async( pred_contribs=pred_contribs, approx_contribs=approx_contribs, pred_interactions=pred_interactions, + strict_shape=strict_shape, ) ) return await _direct_predict_impl( @@ -1116,6 +1127,7 @@ async def _predict_async( pred_contribs=pred_contribs, approx_contribs=approx_contribs, pred_interactions=pred_interactions, + strict_shape=strict_shape, ) ) # Prediction on dask DMatrix. @@ -1206,10 +1218,9 @@ def predict( # pylint: disable=unused-argument .. note:: Using ``inplace_predict`` might be faster when some features are not needed. See - :py:meth:`xgboost.Booster.predict` for details on various parameters. When using - ``pred_interactions`` with mutli-class model, input should be ``da.Array`` or - ``DaskDMatrix`` due to limitation in ``da.map_blocks``. - + :py:meth:`xgboost.Booster.predict` for details on various parameters. When output + has more than 2 dimensions (shap value, leaf with strict_shape), input should be + ``da.Array`` or ``DaskDMatrix``. .. versionadded:: 1.0.0 @@ -1233,8 +1244,8 @@ def predict( # pylint: disable=unused-argument prediction: dask.array.Array/dask.dataframe.Series When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an array, when input data is ``dask.dataframe.DataFrame``, return value can be - ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``, - depending on the output shape. + ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output + shape. ''' _assert_dask_support() @@ -1297,6 +1308,7 @@ async def _inplace_predict_async( # pylint: disable=too-many-branches inplace=True, predict_type=predict_type, iteration_range=iteration_range, + strict_shape=strict_shape, ) ) return await _direct_predict_impl( @@ -1352,8 +1364,9 @@ def inplace_predict( # pylint: disable=unused-argument prediction : When input data is ``dask.array.Array``, the return value is an array, when input data is ``dask.dataframe.DataFrame``, return value can be - ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``, - depending on the output shape. + ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output + shape. + """ _assert_dask_support() client = _xgb_get_client(client) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index fcb18319d..be97b9b4c 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -754,10 +754,7 @@ class XGBModel(XGBModelBase): # Inplace predict doesn't handle as many data types as DMatrix, but it's # sufficient for dask interface where input is simpiler. params = self.get_params() - booster = self.booster - if params.get("predictor", None) is None and ( - booster is None or booster == "gbtree" - ): + if params.get("predictor", None) is None and self.booster != "gblinear": return True return False diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 30732dbd8..cbdbd2bb0 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -455,12 +455,22 @@ void GBTree::PredictBatch(DMatrix* p_fmat, // When begin layer is not 0, the cache is not useful. reset = true; } + if (out_preds->predictions.Size() == 0 && p_fmat->Info().num_row_ != 0) { + CHECK_EQ(out_preds->version, 0); + } + + auto const& predictor = GetPredictor(&out_preds->predictions, p_fmat); + if (out_preds->version == 0) { + // out_preds->Size() can be non-zero as it's initialized here before any + // tree is built at the 0^th iterator. + predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions, + model_); + } uint32_t tree_begin, tree_end; std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); - GetPredictor(&out_preds->predictions, p_fmat) - ->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end); + predictor->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end); if (reset) { out_preds->version = 0; } else { @@ -625,54 +635,124 @@ class Dart : public GBTree { out["dart_train_param"] = ToJson(dparam_); } + // An independent const function to make sure it's thread safe. + void PredictBatchImpl(DMatrix *p_fmat, PredictionCacheEntry *p_out_preds, + bool training, unsigned layer_begin, + unsigned layer_end) const { + auto &predictor = this->GetPredictor(&p_out_preds->predictions, p_fmat); + CHECK(predictor); + predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, + model_); + p_out_preds->version = 0; + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); + for (size_t i = tree_begin; i < tree_end; i += 1) { + if (training && + std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) { + continue; + } + + CHECK_GE(i, p_out_preds->version); + auto version = i / this->LayerTrees(); + p_out_preds->version = version; + + auto n_groups = model_.learner_model_param->num_output_group; + PredictionCacheEntry predts; + predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0); + predictor->PredictBatch(p_fmat, &predts, model_, i, i + 1); + + // Multiple the weight to output prediction. + auto w = this->weight_drop_.at(i); + auto &h_predts = predts.predictions.HostVector(); + auto group = model_.tree_info.at(i); + auto &h_out_predts = p_out_preds->predictions.HostVector(); + CHECK_EQ(h_out_predts.size(), h_predts.size()); + for (size_t ridx = 0; ridx < p_fmat->Info().num_row_; ++ridx) { + const size_t offset = ridx * n_groups + group; + h_out_predts[offset] += (h_predts[offset] * w); + } + } + } + void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* p_out_preds, bool training, unsigned layer_begin, unsigned layer_end) override { DropTrees(training); - int num_group = model_.learner_model_param->num_output_group; - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = - detail::LayerToTree(model_, tparam_, layer_begin, layer_end); + this->PredictBatchImpl(p_fmat, p_out_preds, training, layer_begin, layer_end); + } - size_t n = num_group * p_fmat->Info().num_row_; - const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector(); - auto& out_preds = p_out_preds->predictions.HostVector(); - out_preds.resize(n); - if (base_margin.size() != 0) { - CHECK_EQ(out_preds.size(), n); - std::copy(base_margin.begin(), base_margin.end(), out_preds.begin()); - } else { - std::fill(out_preds.begin(), out_preds.end(), - model_.learner_model_param->base_score); + void InplacePredict(dmlc::any const &x, std::shared_ptr p_m, + float missing, PredictionCacheEntry *out_preds, + uint32_t layer_begin, unsigned layer_end) const override { + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); + std::vector predictors{ + cpu_predictor_.get(), +#if defined(XGBOOST_USE_CUDA) + gpu_predictor_.get() +#endif // defined(XGBOOST_USE_CUDA) + }; + + MetaInfo info; + StringView msg{"Unsupported data type for inplace predict."}; + // Inplace predict is not used for training, so no need to drop tree. + for (size_t i = tree_begin; i < tree_end; ++i) { + PredictionCacheEntry predts; + if (tparam_.predictor == PredictorType::kAuto) { + // Try both predictor implementations + bool success = false; + for (auto const &p : predictors) { + if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i, + i + 1)) { + success = true; + break; + } + } + CHECK(success) << msg; + } else { + // No base margin for each tree + bool success = this->GetPredictor()->InplacePredict( + x, nullptr, model_, missing, &predts, tree_begin, tree_end); + CHECK(success) << msg; + } + + auto w = this->weight_drop_.at(i); + auto &h_predts = predts.predictions.HostVector(); + auto &h_out_predts = out_preds->predictions.HostVector(); + if (h_out_predts.empty()) { + auto n_rows = + h_predts.size() / model_.learner_model_param->num_output_group; + if (p_m) { + p_m->Info().num_row_ = n_rows; + cpu_predictor_->InitOutPredictions(p_m->Info(), + &out_preds->predictions, model_); + } else { + info.num_row_ = n_rows; + cpu_predictor_->InitOutPredictions(info, &out_preds->predictions, + model_); + } + } + + // Multiple the tree weight + CHECK_EQ(h_predts.size(), h_out_predts.size()); + for (size_t i = 0; i < h_out_predts.size(); ++i) { + // Need to remove the base margin from indiviual tree. + h_out_predts[i] += + (h_predts[i] - model_.learner_model_param->base_score) * w; + } } - const int nthread = omp_get_max_threads(); - InitThreadTemp(nthread); - PredLoopSpecalize(p_fmat, &out_preds, num_group, tree_begin, tree_end); } void PredictInstance(const SparsePage::Inst &inst, std::vector *out_preds, unsigned layer_begin, unsigned layer_end) override { DropTrees(false); - if (thread_temp_.size() == 0) { - thread_temp_.resize(1, RegTree::FVec()); - thread_temp_[0].Init(model_.learner_model_param->num_feature); - } - out_preds->resize(model_.learner_model_param->num_output_group); - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); - // loop over output groups - for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) { - (*out_preds)[gid] = - PredValue(inst, gid, &thread_temp_[0], 0, tree_end) + - model_.learner_model_param->base_score; - } - } - - bool UseGPU() const override { - return GBTree::UseGPU(); + auto &predictor = this->GetPredictor(); + uint32_t _, tree_end; + std::tie(_, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); + predictor->PredictInstance(inst, out_preds, model_, tree_end); } void PredictContribution(DMatrix* p_fmat, @@ -697,60 +777,6 @@ class Dart : public GBTree { } protected: - inline void PredLoopSpecalize( - DMatrix* p_fmat, - std::vector* out_preds, - int num_group, - unsigned tree_begin, - unsigned tree_end) { - CHECK_EQ(num_group, model_.learner_model_param->num_output_group); - std::vector& preds = *out_preds; - CHECK_EQ(model_.param.size_leaf_vector, 0) - << "size_leaf_vector is enforced to 0 so far"; - CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group); - // start collecting the prediction - for (const auto &batch : p_fmat->GetBatches()) { - auto page = batch.GetView(); - constexpr int kUnroll = 8; - const auto nsize = static_cast(batch.Size()); - const bst_omp_uint rest = nsize % kUnroll; - if (nsize >= kUnroll) { -#pragma omp parallel for schedule(static) - for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) { - const int tid = omp_get_thread_num(); - RegTree::FVec& feats = thread_temp_[tid]; - int64_t ridx[kUnroll]; - SparsePage::Inst inst[kUnroll]; - for (int k = 0; k < kUnroll; ++k) { - ridx[k] = static_cast(batch.base_rowid + i + k); - } - for (int k = 0; k < kUnroll; ++k) { - inst[k] = page[i + k]; - } - for (int k = 0; k < kUnroll; ++k) { - for (int gid = 0; gid < num_group; ++gid) { - const size_t offset = ridx[k] * num_group + gid; - preds[offset] += - this->PredValue(inst[k], gid, &feats, tree_begin, tree_end); - } - } - } - } - - for (bst_omp_uint i = nsize - rest; i < nsize; ++i) { - RegTree::FVec& feats = thread_temp_[0]; - const auto ridx = static_cast(batch.base_rowid + i); - const SparsePage::Inst inst = page[i]; - for (int gid = 0; gid < num_group; ++gid) { - const size_t offset = ridx * num_group + gid; - preds[offset] += - this->PredValue(inst, gid, - &feats, tree_begin, tree_end); - } - } - } - } - // commit new trees all at once void CommitModel(std::vector>>&& new_trees, @@ -765,32 +791,13 @@ class Dart : public GBTree { << "weight = " << weight_drop_.back(); } - // predict the leaf scores without dropped trees - bst_float PredValue(const SparsePage::Inst &inst, int bst_group, - RegTree::FVec *p_feats, unsigned tree_begin, - unsigned tree_end) const { - bst_float psum = 0.0f; - p_feats->Fill(inst); - for (size_t i = tree_begin; i < tree_end; ++i) { - if (model_.tree_info[i] == bst_group) { - bool drop = std::binary_search(idx_drop_.begin(), idx_drop_.end(), i); - if (!drop) { - int tid = model_.trees[i]->GetLeafIndex(*p_feats); - psum += weight_drop_[i] * (*model_.trees[i])[tid].LeafValue(); - } - } - } - p_feats->Drop(inst); - return psum; - } - - // select which trees to drop - // passing clear=True will clear selection + // Select which trees to drop. inline void DropTrees(bool is_training) { - idx_drop_.clear(); if (!is_training) { + // This function should be thread safe when it's not training. return; } + idx_drop_.clear(); std::uniform_real_distribution<> runif(0.0, 1.0); auto& rnd = common::GlobalRandom(); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 338f24afc..2704521d7 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -201,7 +201,7 @@ class CPUPredictor : public Predictor { void InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, - const gbm::GBTreeModel& model) const { + const gbm::GBTreeModel& model) const override { CHECK_NE(model.learner_model_param->num_output_group, 0); size_t n = model.learner_model_param->num_output_group * info.num_row_; const auto& base_margin = info.base_margin_.HostVector(); @@ -234,26 +234,16 @@ class CPUPredictor : public Predictor { public: explicit CPUPredictor(GenericParameter const* generic_param) : Predictor::Predictor{generic_param} {} + void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, const gbm::GBTreeModel &model, uint32_t tree_begin, uint32_t tree_end = 0) const override { auto* out_preds = &predts->predictions; - if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) { - CHECK_EQ(predts->version, 0); - } // This is actually already handled in gbm, but large amount of tests rely on the // behaviour. if (tree_end == 0) { tree_end = model.trees.size(); } - if (predts->version == 0) { - // out_preds->Size() can be non-zero as it's initialized here before any tree is - // built at the 0^th iterator. - this->InitOutPredictions(dmat->Info(), out_preds, model); - } - if (tree_end - tree_begin == 0) { - return; - } this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin, tree_end); } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 03b9e1652..81961681d 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -599,21 +599,9 @@ class GPUPredictor : public xgboost::Predictor { int device = generic_param_->gpu_id; CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data."; auto* out_preds = &predts->predictions; - - if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) { - CHECK_EQ(predts->version, 0); - } if (tree_end == 0) { tree_end = model.trees.size(); } - if (predts->version == 0) { - // out_preds->Size() can be non-zero as it's initialized here before any tree is - // built at the 0^th iterator. - this->InitOutPredictions(dmat->Info(), out_preds, model); - } - if (tree_end - tree_begin == 0) { - return; - } this->DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end); } @@ -788,7 +776,7 @@ class GPUPredictor : public xgboost::Predictor { protected: void InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, - const gbm::GBTreeModel& model) const { + const gbm::GBTreeModel& model) const override { size_t n_classes = model.learner_model_param->num_output_group; size_t n = n_classes * info.num_row_; const HostDeviceVector& base_margin = info.base_margin_; diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 2fbbab27f..8a3650bdf 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -10,6 +10,7 @@ #include "xgboost/learner.h" #include "../helpers.h" #include "../../../src/gbm/gbtree.h" +#include "../../../src/data/adapter.h" #include "xgboost/predictor.h" namespace xgboost { @@ -247,7 +248,9 @@ TEST(Dart, JsonIO) { TEST(Dart, Prediction) { size_t constexpr kRows = 16, kCols = 10; - auto p_mat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); + HostDeviceVector data; + auto array_str = RandomDataGenerator(kRows, kCols, 0).GenerateArrayInterface(&data); + auto p_mat = GetDMatrixFromData(data.HostVector(), kRows, kCols); std::vector labels (kRows); for (size_t i = 0; i < kRows; ++i) { @@ -265,16 +268,28 @@ TEST(Dart, Prediction) { } HostDeviceVector predts_training; - learner->Predict(p_mat, false, &predts_training, 0, true); - HostDeviceVector predts_inference; - learner->Predict(p_mat, false, &predts_inference, 0, false); + learner->Predict(p_mat, false, &predts_training, 0, 0, true); - auto& h_predts_training = predts_training.ConstHostVector(); - auto& h_predts_inference = predts_inference.ConstHostVector(); + HostDeviceVector* inplace_predts; + auto adapter = std::shared_ptr(new data::ArrayAdapter{StringView{array_str}}); + learner->InplacePredict(adapter, nullptr, PredictionType::kValue, + std::numeric_limits::quiet_NaN(), + &inplace_predts, 0, 0); + CHECK(inplace_predts); + + HostDeviceVector predts_inference; + learner->Predict(p_mat, false, &predts_inference, 0, 0, false); + + auto const& h_predts_training = predts_training.ConstHostVector(); + auto const& h_predts_inference = predts_inference.ConstHostVector(); + auto const& h_inplace_predts = inplace_predts->HostVector(); ASSERT_EQ(h_predts_training.size(), h_predts_inference.size()); + ASSERT_EQ(h_inplace_predts.size(), h_predts_inference.size()); for (size_t i = 0; i < predts_inference.Size(); ++i) { // Inference doesn't drop tree. - ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps); + ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps * 10); + // Inplace prediction is inference. + ASSERT_LT(h_inplace_predts[i] - h_predts_inference[i], kRtEps / 10); } } diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index c5ee0b2e2..9a62d9ba3 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -31,6 +31,7 @@ TEST(CpuPredictor, Basic) { // Test predict batch PredictionCacheEntry out_predictions; + cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); std::vector& out_predictions_h = out_predictions.predictions.HostVector(); @@ -107,6 +108,7 @@ TEST(CpuPredictor, ExternalMemory) { // Test predict batch PredictionCacheEntry out_predictions; + cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); std::vector &out_predictions_h = out_predictions.predictions.HostVector(); ASSERT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_); diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 6d38aec29..79ea0c8cf 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -44,7 +44,9 @@ TEST(GPUPredictor, Basic) { PredictionCacheEntry gpu_out_predictions; PredictionCacheEntry cpu_out_predictions; + gpu_predictor->InitOutPredictions(dmat->Info(), &gpu_out_predictions.predictions, model); gpu_predictor->PredictBatch(dmat.get(), &gpu_out_predictions, model, 0); + cpu_predictor->InitOutPredictions(dmat->Info(), &cpu_out_predictions.predictions, model); cpu_predictor->PredictBatch(dmat.get(), &cpu_out_predictions, model, 0); std::vector& gpu_out_predictions_h = gpu_out_predictions.predictions.HostVector(); @@ -111,6 +113,7 @@ TEST(GPUPredictor, ExternalMemoryTest) { for (const auto& dmat: dmats) { dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5); PredictionCacheEntry out_predictions; + gpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); EXPECT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_ * n_classes); const std::vector &host_vector = out_predictions.predictions.ConstHostVector(); diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 8df9d72d2..388a59cb8 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -218,6 +218,7 @@ void TestCategoricalPrediction(std::string name) { row[split_ind] = split_cat; auto m = GetDMatrixFromData(row, 1, kCols); + predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model); predictor->PredictBatch(m.get(), &out_predictions, model, 0); ASSERT_EQ(out_predictions.predictions.Size(), 1ul); ASSERT_EQ(out_predictions.predictions.HostVector()[0], @@ -226,6 +227,7 @@ void TestCategoricalPrediction(std::string name) { row[split_ind] = split_cat + 1; m = GetDMatrixFromData(row, 1, kCols); out_predictions.version = 0; + predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model); predictor->PredictBatch(m.get(), &out_predictions, model, 0); ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + param.base_score); diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index 68e034e0a..296d532d6 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -29,9 +29,11 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols, auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix(); PredictionCacheEntry approx_out_predictions; + predictor->InitOutPredictions(p_hist->Info(), &approx_out_predictions.predictions, model); predictor->PredictBatch(p_hist.get(), &approx_out_predictions, model, 0); PredictionCacheEntry precise_out_predictions; + predictor->InitOutPredictions(p_precise->Info(), &precise_out_predictions.predictions, model); predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0); for (size_t i = 0; i < rows; ++i) { @@ -46,6 +48,7 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols, // matrix is used for training. auto p_dmat = RandomDataGenerator(rows, cols, 0).GenerateDMatrix(); PredictionCacheEntry precise_out_predictions; + predictor->InitOutPredictions(p_dmat->Info(), &precise_out_predictions.predictions, model); predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0); ASSERT_FALSE(p_dmat->PageExists()); } diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index 174a4a13e..7502619f2 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -22,6 +22,16 @@ def run_threaded_predict(X, rows, predict_func): assert f.result() +def verify_leaf_output(leaf: np.ndarray, num_parallel_tree: int): + for i in range(leaf.shape[0]): # n_samples + for j in range(leaf.shape[1]): # n_rounds + for k in range(leaf.shape[2]): # n_classes + tree_group = leaf[i, j, k, :] + assert tree_group.shape[0] == num_parallel_tree + # No sampling, all trees within forest are the same + assert np.all(tree_group == tree_group[0]) + + def run_predict_leaf(predictor): rows = 100 cols = 4 @@ -53,13 +63,7 @@ def run_predict_leaf(predictor): assert leaf.shape[2] == classes assert leaf.shape[3] == num_parallel_tree - for i in range(rows): - for j in range(num_boost_round): - for k in range(classes): - tree_group = leaf[i, j, k, :] - assert tree_group.shape[0] == num_parallel_tree - # No sampling, all trees within forest are the same - assert np.all(tree_group == tree_group[0]) + verify_leaf_output(leaf, num_parallel_tree) ntree_limit = 2 sliced = booster.predict( diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index f70b8ea36..1d531b085 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -18,6 +18,7 @@ import hypothesis from hypothesis import given, settings, note, HealthCheck from test_updaters import hist_parameter_strategy, exact_parameter_strategy from test_with_sklearn import run_feature_weights, run_data_initialization +from test_predict import verify_leaf_output if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows", allow_module_level=True) @@ -748,9 +749,9 @@ def test_dask_ranking(client: "Client") -> None: d = d.toarray() d[d == 0] = np.nan d[np.isinf(d)] = 0 - data.append(da.from_array(d)) + data.append(dd.from_array(d, chunksize=32)) else: - data.append(da.from_array(d)) + data.append(dd.from_array(d, chunksize=32)) ( x_train, @@ -782,6 +783,39 @@ def test_dask_ranking(client: "Client") -> None: assert rank.best_score > 0.98 +@pytest.mark.parametrize("booster", ["dart", "gbtree"]) +def test_dask_predict_leaf(booster: str, client: "Client") -> None: + from sklearn.datasets import load_digits + + X_, y_ = load_digits(return_X_y=True) + num_parallel_tree = 4 + X, y = dd.from_array(X_, chunksize=32), dd.from_array(y_, chunksize=32) + rounds = 4 + cls = xgb.dask.DaskXGBClassifier( + n_estimators=rounds, num_parallel_tree=num_parallel_tree, booster=booster + ) + cls.client = client + cls.fit(X, y) + leaf = xgb.dask.predict( + client, + cls.get_booster(), + X.to_dask_array(), # we can't map_blocks on dataframe when output is 4-dim. + pred_leaf=True, + strict_shape=True, + validate_features=False, + ).compute() + + assert leaf.shape[0] == X_.shape[0] + assert leaf.shape[1] == rounds + assert leaf.shape[2] == cls.n_classes_ + assert leaf.shape[3] == num_parallel_tree + + leaf_from_apply = cls.apply(X).reshape(leaf.shape).compute() + np.testing.assert_allclose(leaf_from_apply, leaf) + + verify_leaf_output(leaf, num_parallel_tree) + + class TestWithDask: def test_global_config(self, client: "Client") -> None: X, y, _ = generate_array() @@ -1101,15 +1135,16 @@ class TestWithDask: assert_shape(shap.shape) assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5) - X = dd.from_dask_array(X).repartition(npartitions=32) - y = dd.from_dask_array(y).repartition(npartitions=32) - shap_df = xgb.dask.predict( - client, booster, X, pred_contribs=True, validate_features=False - ).compute() - assert_shape(shap_df.shape) - assert np.allclose( - np.sum(shap_df, axis=len(shap_df.shape) - 1), margin, 1e-5, 1e-5 - ) + if "num_class" not in params.keys(): + X = dd.from_dask_array(X).repartition(npartitions=32) + y = dd.from_dask_array(y).repartition(npartitions=32) + shap_df = xgb.dask.predict( + client, booster, X, pred_contribs=True, validate_features=False + ).compute() + assert_shape(shap_df.shape) + assert np.allclose( + np.sum(shap_df, axis=len(shap_df.shape) - 1), margin, 1e-5, 1e-5 + ) def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None: X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32) @@ -1218,17 +1253,13 @@ class TestWithDask: np.testing.assert_allclose(predt_0.compute(), predt_3) -def test_unsupported_features(client: "Client"): +def test_dask_unsupported_features(client: "Client") -> None: X, y, _ = generate_array() # gblinear doesn't support distributed training. with pytest.raises(NotImplementedError, match="gblinear"): xgb.dask.train( client, {"booster": "gblinear"}, xgb.dask.DaskDMatrix(client, X, y) ) - # dart prediction is not thread safe, running predict with each partition will have - # race. - with pytest.raises(NotImplementedError, match="dart"): - xgb.dask.train(client, {"booster": "dart"}, xgb.dask.DaskDMatrix(client, X, y)) class TestDaskCallbacks: