/**
 * Copyright 2020-2023 by XGBoost Contributors
 */
#include "test_predictor.h"

#include <gtest/gtest.h>
#include <xgboost/context.h>             // for Context
#include <xgboost/data.h>                // for DMatrix, BatchIterator, BatchSet, MetaInfo
#include <xgboost/host_device_vector.h>  // for HostDeviceVector
#include <xgboost/predictor.h>           // for PredictionCacheEntry, Predictor, Predic...
#include <xgboost/string_view.h>         // for StringView

#include <algorithm>      // for max
#include <limits>         // for numeric_limits
#include <memory>         // for shared_ptr, unique_ptr, make_unique
#include <string>         // for string, to_string
#include <unordered_map>  // for unordered_map
#include <vector>         // for vector

#include "../../../src/common/bitfield.h"         // for LBitField32
#include "../../../src/data/iterative_dmatrix.h"  // for IterativeDMatrix
#include "../../../src/data/proxy_dmatrix.h"      // for DMatrixProxy
#include "../helpers.h"                           // for GetDMatrixFromData, RandomDataGenerator
#include "xgboost/json.h"                         // for Json, Object, get, String
#include "xgboost/linalg.h"                       // for MakeVec, Tensor, TensorView, Vector
#include "xgboost/logging.h"                      // for CHECK
#include "xgboost/span.h"                         // for operator!=, SpanIterator, Span
#include "xgboost/tree_model.h"                   // for RegTree

namespace xgboost {
TEST(Predictor, PredictionCache) {
  size_t constexpr kRows = 16, kCols = 4;

  PredictionContainer container;
  DMatrix *m{nullptr};
  // Add a cache that is immediately expired: the only owner of the DMatrix is the local
  // shared_ptr, which goes out of scope when the lambda returns.
  auto add_cache = [&]() {
    auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
    container.Cache(p_dmat, Context::kCpuId);
    m = p_dmat.get();
  };

  add_cache();
  ASSERT_EQ(container.Container().size(), 0ul);
  add_cache();
  EXPECT_ANY_THROW(container.Entry(m));
}

/**
 * @brief Train a multi-class model on `p_hist`, then check that predicting on `p_full`
 *        and on `p_hist` gives the same result for every output.
 */
void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
                            std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist) {
  size_t constexpr kCols = 16;
  size_t constexpr kClasses = 3;
  size_t constexpr kIters = 3;

  std::unique_ptr<Learner> learner;

  p_hist->Info().labels.Reshape(rows, 1);
  auto &h_label = p_hist->Info().labels.Data()->HostVector();

  for (size_t i = 0; i < rows; ++i) {
    h_label[i] = i % kClasses;
  }

  learner.reset(Learner::Create({}));
  learner->SetParams(Args{{"objective", "multi:softprob"},
                          {"num_feature", std::to_string(kCols)},
                          {"num_class", std::to_string(kClasses)},
                          {"max_bin", std::to_string(bins)},
                          {"device", ctx->DeviceName()}});
  learner->Configure();

  for (size_t i = 0; i < kIters; ++i) {
    learner->UpdateOneIter(i, p_hist);
  }

  // Round-trip the model so prediction runs on a freshly loaded learner.
  Json model{Object{}};
  learner->SaveModel(&model);

  learner.reset(Learner::Create({}));
  learner->LoadModel(model);
  learner->SetParam("device", ctx->DeviceName());
  learner->Configure();

  HostDeviceVector<float> from_full;
  learner->Predict(p_full, false, &from_full, 0, 0);

  HostDeviceVector<float> from_hist;
  learner->Predict(p_hist, false, &from_hist, 0, 0);

  // multi:softprob produces one value per (row, class); compare all of them rather than only
  // the first `rows` entries.
  auto const &h_full = from_full.ConstHostVector();
  auto const &h_hist = from_hist.ConstHostVector();
  ASSERT_EQ(h_hist.size(), h_full.size());
  for (size_t i = 0; i < h_full.size(); ++i) {
    EXPECT_NEAR(h_hist[i], h_full[i], kRtEps);
  }
}

/**
 * @brief Check that in-place prediction over layers [0, 4) equals the sum of the predictions
 *        over [0, 2) and [2, 4), minus the base score that both partial results include.
 */
void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
                           bst_feature_t cols) {
  std::size_t constexpr kClasses{4};
  auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(ctx->gpu_id);
  std::shared_ptr<DMatrix> m = gen.GenerateDMatrix(true, false, kClasses);

  std::unique_ptr<Learner> learner{Learner::Create({m})};

  learner->SetParam("num_parallel_tree", "4");
  learner->SetParam("num_class", std::to_string(kClasses));
  learner->SetParam("seed", "0");
  learner->SetParam("subsample", "0.5");
  learner->SetParam("tree_method", "hist");
  for (int32_t it = 0; it < 4; ++it) {
    learner->UpdateOneIter(it, m);
  }

  learner->SetParam("device", ctx->DeviceName());
  learner->Configure();

  // The learner owns the returned buffer, so each partial result is copied out before the
  // next InplacePredict call.
  HostDeviceVector<float> *p_out_predictions_0{nullptr};
  learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits<float>::quiet_NaN(),
                          &p_out_predictions_0, 0, 2);
  CHECK(p_out_predictions_0);
  HostDeviceVector<float> predict_0(p_out_predictions_0->Size());
  predict_0.Copy(*p_out_predictions_0);

  HostDeviceVector<float> *p_out_predictions_1{nullptr};
  learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits<float>::quiet_NaN(),
                          &p_out_predictions_1, 2, 4);
  CHECK(p_out_predictions_1);
  HostDeviceVector<float> predict_1(p_out_predictions_1->Size());
  predict_1.Copy(*p_out_predictions_1);

  HostDeviceVector<float> *p_out_predictions{nullptr};
  learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits<float>::quiet_NaN(),
                          &p_out_predictions, 0, 4);
  CHECK(p_out_predictions);

  auto &h_pred = p_out_predictions->HostVector();
  auto &h_pred_0 = predict_0.HostVector();
  auto &h_pred_1 = predict_1.HostVector();

  ASSERT_EQ(h_pred.size(), rows * kClasses);
  ASSERT_EQ(h_pred.size(), h_pred_0.size());
  ASSERT_EQ(h_pred.size(), h_pred_1.size());
  for (size_t i = 0; i < h_pred.size(); ++i) {
    // Need to remove the global bias here.
    ASSERT_NEAR(h_pred[i], h_pred_0[i] + h_pred_1[i] - 0.5f, kRtEps);
  }

  learner->SetParam("device", "cpu");
  learner->Configure();
}

namespace {
/** @brief Train a learner on `dmat` for `iters` rounds with `forest` parallel trees per round. */
std::unique_ptr<Learner> LearnerForTest(Context const *ctx, std::shared_ptr<DMatrix> dmat,
                                        size_t iters, size_t forest = 1) {
  std::unique_ptr<Learner> learner{Learner::Create({dmat})};
  learner->SetParams(
      Args{{"num_parallel_tree", std::to_string(forest)}, {"device", ctx->DeviceName()}});
  for (size_t i = 0; i < iters; ++i) {
    learner->UpdateOneIter(i, dmat);
  }
  return learner;
}

/**
 * @brief Prediction on data with fewer features than the model was trained on must succeed,
 *        while data with more features must throw.
 */
void VerifyPredictionWithLesserFeatures(Learner *learner, bst_row_t rows,
                                        std::shared_ptr<DMatrix> m_test,
                                        std::shared_ptr<DMatrix> m_invalid) {
  HostDeviceVector<float> prediction;
  Json config{Object()};
  learner->SaveConfig(&config);

  learner->Predict(m_test, false, &prediction, 0, 0);
  ASSERT_EQ(prediction.Size(), rows);

  ASSERT_THROW({ learner->Predict(m_invalid, false, &prediction, 0, 0); }, dmlc::Error);
}

/** @brief Column-split variant: each worker verifies its own column slice. */
void VerifyPredictionWithLesserFeaturesColumnSplit(Learner *learner, size_t rows,
                                                   std::shared_ptr<DMatrix> m_test,
                                                   std::shared_ptr<DMatrix> m_invalid) {
  auto const world_size = collective::GetWorldSize();
  auto const rank = collective::GetRank();
  std::shared_ptr<DMatrix> sliced_test{m_test->SliceCol(world_size, rank)};
  std::shared_ptr<DMatrix> sliced_invalid{m_invalid->SliceCol(world_size, rank)};

  VerifyPredictionWithLesserFeatures(learner, rows, sliced_test, sliced_invalid);
}
}  // anonymous namespace

void TestPredictionWithLesserFeatures(Context const *ctx) {
  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
  auto learner = LearnerForTest(ctx, m_train, kIters);
  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
  VerifyPredictionWithLesserFeatures(learner.get(), kRows, m_test, m_invalid);
}

/**
 * @brief The output buffer must live on the host for CPU prediction and on the device for GPU
 *        prediction. When a GPU build is available, both results must also agree.
 */
void TestPredictionDeviceAccess() {
  Context ctx;
  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
  auto learner = LearnerForTest(&ctx, m_train, kIters);

  HostDeviceVector<float> from_cpu;
  {
    ASSERT_EQ(from_cpu.DeviceIdx(), Context::kCpuId);
    Context cpu_ctx;
    learner->SetParam("device", cpu_ctx.DeviceName());
    learner->Predict(m_test, false, &from_cpu, 0, 0);
    ASSERT_TRUE(from_cpu.HostCanWrite());
    ASSERT_FALSE(from_cpu.DeviceCanRead());
  }

#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
  HostDeviceVector<float> from_cuda;
  {
    Context cuda_ctx = MakeCUDACtx(0);
    learner->SetParam("device", cuda_ctx.DeviceName());
    learner->Predict(m_test, false, &from_cuda, 0, 0);
    ASSERT_EQ(from_cuda.DeviceIdx(), 0);
    ASSERT_TRUE(from_cuda.DeviceCanWrite());
    ASSERT_FALSE(from_cuda.HostCanRead());
  }

  auto const &h_cpu = from_cpu.ConstHostVector();
  auto const &h_gpu = from_cuda.ConstHostVector();
  ASSERT_EQ(h_cpu.size(), h_gpu.size());
  for (size_t i = 0; i < h_cpu.size(); ++i) {
    ASSERT_NEAR(h_cpu[i], h_gpu[i], kRtEps);
  }
#endif  // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
}

void TestPredictionWithLesserFeaturesColumnSplit(Context const *ctx) {
  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
  auto learner = LearnerForTest(ctx, m_train, kIters);
  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);

  auto constexpr kWorldSize = 2;
  RunWithInMemoryCommunicator(kWorldSize, VerifyPredictionWithLesserFeaturesColumnSplit,
                              learner.get(), kRows, m_test, m_invalid);
}

/**
 * @brief Populate `model` with a single tree whose root is a categorical split on feature
 *        `split_ind`, with `split_cat` as the only category in the split set.
 */
void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind, bst_cat_t split_cat,
                        float left_weight, float right_weight) {
  std::vector<std::unique_ptr<RegTree>> trees;
  trees.push_back(std::make_unique<RegTree>());
  auto &p_tree = trees.front();

  // Setting bit `split_cat` requires storage for `split_cat + 1` bits.
  std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(split_cat + 1));
  LBitField32 cats_bits(split_cats);
  cats_bits.Set(split_cat);

  p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f, left_weight, right_weight, 3.0f,
                            2.2f, 7.0f, 9.0f);
  model->CommitModelGroup(std::move(trees), 0);
}

/**
 * @brief A row whose category matches the split set goes to the right child, and any other
 *        category goes to the left child.
 */
void TestCategoricalPrediction(Context const *ctx, bool is_column_split) {
  size_t constexpr kCols = 10;
  PredictionCacheEntry out_predictions;

  LearnerModelParam mparam{MakeMP(kCols, .5, 1)};

  uint32_t split_ind = 3;
  bst_cat_t split_cat = 4;
  float left_weight = 1.3f;
  float right_weight = 1.7f;

  gbm::GBTreeModel model(&mparam, ctx);
  GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);

  std::unique_ptr<Predictor> predictor{CreatePredictorForTest(ctx)};

  std::vector<float> row(kCols);
  row[split_ind] = split_cat;
  auto m = GetDMatrixFromData(row, 1, kCols);
  std::vector<FeatureType> types(kCols, FeatureType::kCategorical);
  m->Info().feature_types.HostVector() = types;
  if (is_column_split) {
    m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
  }

  predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
  predictor->PredictBatch(m.get(), &out_predictions, model, 0);
  auto score = mparam.BaseScore(DeviceOrd::CPU())(0);
  ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
  ASSERT_EQ(out_predictions.predictions.HostVector()[0],
            right_weight + score);  // go to right for matching cat

  row[split_ind] = split_cat + 1;
  m = GetDMatrixFromData(row, 1, kCols);
  if (is_column_split) {
    m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
  }
  out_predictions.version = 0;
  predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
  predictor->PredictBatch(m.get(), &out_predictions, model, 0);
  ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + score);
}

void TestCategoricalPredictionColumnSplit(Context const *ctx) {
  auto constexpr kWorldSize = 2;
  RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, ctx, true);
}

/**
 * @brief Leaf prediction for a categorical split. Node 2 (the right child) is expected for a
 *        matching category, and node 1 (the left child) otherwise.
 */
void TestCategoricalPredictLeaf(Context const *ctx, bool is_column_split) {
  size_t constexpr kCols = 10;
  PredictionCacheEntry out_predictions;

  LearnerModelParam mparam{MakeMP(kCols, .5, 1)};

  uint32_t split_ind = 3;
  bst_cat_t split_cat = 4;
  float left_weight = 1.3f;
  float right_weight = 1.7f;

  gbm::GBTreeModel model(&mparam, ctx);
  GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);

  std::unique_ptr<Predictor> predictor{CreatePredictorForTest(ctx)};

  std::vector<float> row(kCols);
  row[split_ind] = split_cat;
  auto m = GetDMatrixFromData(row, 1, kCols);
  if (is_column_split) {
    m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
  }

  predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
  ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
  // go to left if it doesn't match the category, otherwise right.
  ASSERT_EQ(out_predictions.predictions.HostVector()[0], 2);

  row[split_ind] = split_cat + 1;
  m = GetDMatrixFromData(row, 1, kCols);
  if (is_column_split) {
    m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
  }
  out_predictions.version = 0;
  predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
  predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
  ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
}

void TestCategoricalPredictLeafColumnSplit(Context const *ctx) {
  auto constexpr kWorldSize = 2;
  RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, ctx, true);
}

/**
 * @brief A model sliced to layers [0, lend) must predict exactly what the full model predicts
 *        over the same layer range, for margin, SHAP, SHAP-interaction and leaf outputs.
 */
void TestIterationRange(Context const *ctx) {
  size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
  auto dmat = RandomDataGenerator(kRows, kCols, 0)
                  .Device(ctx->gpu_id)
                  .GenerateDMatrix(true, true, kClasses);
  auto learner = LearnerForTest(ctx, dmat, kIters, kForest);

  bool bound = false;
  bst_layer_t lend{3};
  std::unique_ptr<Learner> sliced{learner->Slice(0, lend, 1, &bound)};
  ASSERT_FALSE(bound);

  HostDeviceVector<float> out_predt_sliced;
  HostDeviceVector<float> out_predt_ranged;

  // Argument order after the layer range:
  // training, pred_leaf, pred_contribs, approx_contribs, pred_interactions.
  auto check_same = [&](char const *name, bool output_margin, bool pred_leaf, bool pred_contribs,
                        bool pred_interactions) {
    SCOPED_TRACE(name);
    sliced->Predict(dmat, output_margin, &out_predt_sliced, 0, 0, false, pred_leaf, pred_contribs,
                    false, pred_interactions);
    learner->Predict(dmat, output_margin, &out_predt_ranged, 0, lend, false, pred_leaf,
                     pred_contribs, false, pred_interactions);
    auto const &h_sliced = out_predt_sliced.HostVector();
    auto const &h_range = out_predt_ranged.HostVector();
    ASSERT_EQ(h_sliced.size(), h_range.size());
    ASSERT_EQ(h_sliced, h_range);
  };

  check_same("margin", true, false, false, false);
  check_same("SHAP", false, false, true, false);
  check_same("SHAP interaction", false, false, false, true);
  check_same("leaf", false, true, false, false);
}

namespace {
// End of the layer range shared by the column-split iteration-range test and its verifier.
constexpr bst_layer_t kColumnSplitLayerEnd{3};

/**
 * @brief Each worker predicts on its column slice of `dmat`. The results must match the
 *        expectations computed on the full, unsplit data.
 */
void VerifyIterationRangeColumnSplit(DMatrix *dmat, Learner *learner, Learner *sliced,
                                     std::vector<float> const &expected_margin_ranged,
                                     std::vector<float> const &expected_margin_sliced,
                                     std::vector<float> const &expected_leaf_ranged,
                                     std::vector<float> const &expected_leaf_sliced) {
  auto const world_size = collective::GetWorldSize();
  auto const rank = collective::GetRank();
  std::shared_ptr<DMatrix> Xy{dmat->SliceCol(world_size, rank)};

  HostDeviceVector<float> out_predt_sliced;
  HostDeviceVector<float> out_predt_ranged;

  // margin
  {
    sliced->Predict(Xy, true, &out_predt_sliced, 0, 0, false, false, false, false, false);
    learner->Predict(Xy, true, &out_predt_ranged, 0, kColumnSplitLayerEnd, false, false, false,
                     false, false);
    auto const &h_sliced = out_predt_sliced.HostVector();
    auto const &h_range = out_predt_ranged.HostVector();
    ASSERT_EQ(h_sliced.size(), expected_margin_sliced.size());
    ASSERT_EQ(h_sliced, expected_margin_sliced);
    ASSERT_EQ(h_range.size(), expected_margin_ranged.size());
    ASSERT_EQ(h_range, expected_margin_ranged);
  }

  // Leaf
  {
    sliced->Predict(Xy, false, &out_predt_sliced, 0, 0, false, true, false, false, false);
    learner->Predict(Xy, false, &out_predt_ranged, 0, kColumnSplitLayerEnd, false, true, false,
                     false, false);
    auto const &h_sliced = out_predt_sliced.HostVector();
    auto const &h_range = out_predt_ranged.HostVector();
    ASSERT_EQ(h_sliced.size(), expected_leaf_sliced.size());
    ASSERT_EQ(h_sliced, expected_leaf_sliced);
    ASSERT_EQ(h_range.size(), expected_leaf_ranged.size());
    ASSERT_EQ(h_range, expected_leaf_ranged);
  }
}
}  // anonymous namespace

void TestIterationRangeColumnSplit(Context const *ctx) {
  size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
  auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
  auto learner = LearnerForTest(ctx, dmat, kIters, kForest);
  learner->SetParam("device", ctx->DeviceName());

  bool bound = false;
  std::unique_ptr<Learner> sliced{learner->Slice(0, kColumnSplitLayerEnd, 1, &bound)};
  ASSERT_FALSE(bound);

  // margin
  HostDeviceVector<float> margin_predt_sliced;
  HostDeviceVector<float> margin_predt_ranged;
  sliced->Predict(dmat, true, &margin_predt_sliced, 0, 0, false, false, false, false, false);
  learner->Predict(dmat, true, &margin_predt_ranged, 0, kColumnSplitLayerEnd, false, false, false,
                   false, false);
  auto const &margin_sliced = margin_predt_sliced.HostVector();
  auto const &margin_ranged = margin_predt_ranged.HostVector();

  // Leaf
  HostDeviceVector<float> leaf_predt_sliced;
  HostDeviceVector<float> leaf_predt_ranged;
  sliced->Predict(dmat, false, &leaf_predt_sliced, 0, 0, false, true, false, false, false);
  learner->Predict(dmat, false, &leaf_predt_ranged, 0, kColumnSplitLayerEnd, false, true, false,
                   false, false);
  auto const &leaf_sliced = leaf_predt_sliced.HostVector();
  auto const &leaf_ranged = leaf_predt_ranged.HostVector();

  auto constexpr kWorldSize = 2;
  RunWithInMemoryCommunicator(kWorldSize, VerifyIterationRangeColumnSplit, dmat.get(),
                              learner.get(), sliced.get(), margin_ranged, margin_sliced,
                              leaf_ranged, leaf_sliced);
}

/**
 * @brief Prediction on a sparse DMatrix must match in-place prediction on a dense copy of the
 *        same data in which missing entries are NaN.
 */
void TestSparsePrediction(Context const *ctx, float sparsity) {
  size_t constexpr kRows = 512, kCols = 128, kIters = 4;
  auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
  auto learner = LearnerForTest(ctx, Xy, kIters);

  HostDeviceVector<float> sparse_predt;

  Json model{Object{}};
  learner->SaveModel(&model);

  learner.reset(Learner::Create({Xy}));
  learner->LoadModel(model);

  // Select the predictor via `device`, like the rest of this file, instead of the deprecated
  // `tree_method=gpu_hist` / `gpu_id` parameters.
  learner->SetParam("device", ctx->DeviceName());
  learner->Predict(Xy, false, &sparse_predt, 0, 0);

  HostDeviceVector<float> with_nan(kRows * kCols, std::numeric_limits<float>::quiet_NaN());
  auto &h_with_nan = with_nan.HostVector();
  for (auto const &page : Xy->GetBatches<SparsePage>()) {
    auto batch = page.GetView();
    for (size_t i = 0; i < batch.Size(); ++i) {
      // `i` is local to this page; offset by the page's first row for multi-batch inputs.
      auto const ridx = page.base_rowid + i;
      auto row = batch[i];
      for (auto e : row) {
        h_with_nan[ridx * kCols + e.index] = e.fvalue;
      }
    }
  }

  // The dense in-place prediction always runs on CPU.
  learner->SetParam("device", "cpu");
  // Xcode_12.4 doesn't compile with `std::make_shared`.
  auto dense = std::shared_ptr<DMatrix>(new data::DMatrixProxy{});
  auto array_interface = GetArrayInterface(&with_nan, kRows, kCols);
  std::string arr_str;
  Json::Dump(array_interface, &arr_str);
  dynamic_cast<data::DMatrixProxy *>(dense.get())->SetArrayData(arr_str.data());
  HostDeviceVector<float> *p_dense_predt{nullptr};
  learner->InplacePredict(dense, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
                          &p_dense_predt, 0, 0);
  CHECK(p_dense_predt);

  auto const &dense_predt = *p_dense_predt;
  if (ctx->IsCPU()) {
    ASSERT_EQ(dense_predt.HostVector(), sparse_predt.HostVector());
  } else {
    auto const &h_dense = dense_predt.HostVector();
    auto const &h_sparse = sparse_predt.HostVector();
    ASSERT_EQ(h_dense.size(), h_sparse.size());
    for (size_t i = 0; i < h_dense.size(); ++i) {
      ASSERT_FLOAT_EQ(h_dense[i], h_sparse[i]);
    }
  }
}

namespace {
/** @brief Each worker predicts on its column slice; results must match `expected_predt`. */
void VerifySparsePredictionColumnSplit(DMatrix *dmat, Learner *learner,
                                       std::vector<float> const &expected_predt) {
  std::shared_ptr<DMatrix> sliced{
      dmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
  HostDeviceVector<float> sparse_predt;
  learner->Predict(sliced, false, &sparse_predt, 0, 0);

  auto const &predt = sparse_predt.HostVector();
  ASSERT_EQ(predt.size(), expected_predt.size());
  for (size_t i = 0; i < predt.size(); ++i) {
    ASSERT_FLOAT_EQ(predt[i], expected_predt[i]);
  }
}
}  // anonymous namespace

void TestSparsePredictionColumnSplit(Context const *ctx, float sparsity) {
  size_t constexpr kRows = 512, kCols = 128, kIters = 4;
  auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
  auto learner = LearnerForTest(ctx, Xy, kIters);

  HostDeviceVector<float> sparse_predt;

  Json model{Object{}};
  learner->SaveModel(&model);

  learner.reset(Learner::Create({Xy}));
  learner->LoadModel(model);

  learner->SetParam("device", ctx->DeviceName());
  learner->Predict(Xy, false, &sparse_predt, 0, 0);

  auto constexpr kWorldSize = 2;
  RunWithInMemoryCommunicator(kWorldSize, VerifySparsePredictionColumnSplit, Xy.get(),
                              learner.get(), sparse_predt.HostVector());
}

/**
 * @brief Vector-leaf (multi-target) tree prediction through three paths: a plain DMatrix,
 *        in-place prediction on a proxy, and an IterativeDMatrix (ghist).
 */
void TestVectorLeafPrediction(Context const *ctx) {
  std::unique_ptr<Predictor> cpu_predictor =
      std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", ctx));

  size_t constexpr kRows = 5;
  size_t constexpr kCols = 5;

  LearnerModelParam mparam{static_cast<bst_feature_t>(kCols),
                           linalg::Vector<float>{{0.5}, {1}, Context::kCpuId}, 1, 3,
                           MultiStrategy::kMultiOutputTree};

  std::vector<std::unique_ptr<RegTree>> trees;
  trees.emplace_back(new RegTree{mparam.LeafLength(), mparam.num_feature});

  std::vector<float> p_w(mparam.LeafLength(), 0.0f);
  std::vector<float> l_w(mparam.LeafLength(), 1.0f);
  std::vector<float> r_w(mparam.LeafLength(), 2.0f);

  auto &tree = trees.front();
  tree->ExpandNode(0, static_cast<bst_feature_t>(1), 2.0, true,
                   linalg::MakeVec(p_w.data(), p_w.size()), linalg::MakeVec(l_w.data(), l_w.size()),
                   linalg::MakeVec(r_w.data(), r_w.size()));
  ASSERT_TRUE(tree->IsMultiTarget());
  ASSERT_TRUE(mparam.IsVectorLeaf());

  gbm::GBTreeModel model{&mparam, ctx};
  model.CommitModelGroup(std::move(trees), 0);

  auto run_test = [&](float expected, HostDeviceVector<float> *p_data) {
    {
      auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);
      PredictionCacheEntry predt_cache;
      cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
      ASSERT_EQ(predt_cache.predictions.Size(), kRows * mparam.LeafLength());
      cpu_predictor->PredictBatch(p_fmat.get(), &predt_cache, model, 0, 1);
      auto const &h_predt = predt_cache.predictions.HostVector();
      for (auto v : h_predt) {
        ASSERT_EQ(v, expected);
      }
    }

    {
      // inplace
      PredictionCacheEntry predt_cache;
      auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);
      cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
      auto arr = GetArrayInterface(p_data, kRows, kCols);
      std::string str;
      Json::Dump(arr, &str);
      auto proxy = std::shared_ptr<DMatrix>(new data::DMatrixProxy{});
      dynamic_cast<data::DMatrixProxy *>(proxy.get())->SetArrayData(str.data());
      cpu_predictor->InplacePredict(proxy, model, std::numeric_limits<float>::quiet_NaN(),
                                    &predt_cache, 0, 1);
      auto const &h_predt = predt_cache.predictions.HostVector();
      for (auto v : h_predt) {
        ASSERT_EQ(v, expected);
      }
    }

    {
      // ghist
      PredictionCacheEntry predt_cache;
      auto &h_data = p_data->HostVector();
      // Overwrite the first row (indices [0, kCols) of the row-major buffer).
      // give it at least two bins, otherwise the histogram cuts only have min and max values.
      for (std::size_t i = 0; i < kCols; ++i) {
        h_data[i] = 1.0;
      }

      auto iter = NumpyArrayIterForTest{ctx, *p_data, kRows, static_cast<bst_feature_t>(kCols),
                                        static_cast<std::size_t>(1)};
      std::shared_ptr<DMatrix> p_fmat = std::make_shared<data::IterativeDMatrix>(
          &iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0,
          256);

      cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
      cpu_predictor->PredictBatch(p_fmat.get(), &predt_cache, model, 0, 1);
      auto const &h_predt = predt_cache.predictions.HostVector();
      // the smallest v uses the min_value from histogram cuts, which leads to a left leaf
      // during prediction.
      for (std::size_t i = 5; i < h_predt.size(); ++i) {
        ASSERT_EQ(h_predt[i], expected) << i;
      }
    }
  };

  // go to right
  HostDeviceVector<float> data(kRows * kCols, model.trees.front()->SplitCond(RegTree::kRoot) + 1.0);
  run_test(2.5, &data);

  // go to left
  data.HostVector().assign(data.Size(), model.trees.front()->SplitCond(RegTree::kRoot) - 1.0);
  run_test(1.5, &data);
}
}  // namespace xgboost