diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 867fd2dc8..5f57cd353 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -70,11 +70,14 @@ class TreeUpdater : public Configurable { * the prediction cache. If true, the prediction cache will have been * updated by the time this function returns. */ - virtual bool UpdatePredictionCache(const DMatrix* data, - HostDeviceVector* out_preds) { - // Remove unused parameter compiler warning. - (void) data; - (void) out_preds; + virtual bool UpdatePredictionCache(const DMatrix* /*data*/, + HostDeviceVector* /*out_preds*/) { + return false; + } + + virtual bool UpdatePredictionCacheMulticlass(const DMatrix* /*data*/, + HostDeviceVector* /*out_preds*/, + const int /*gid*/, const int /*ngroup*/) { return false; } diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index a4fb6e28e..1706842e2 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -196,11 +196,17 @@ void GBTree::DoBoost(DMatrix* p_fmat, const int ngroup = model_.learner_model_param->num_output_group; ConfigureWithKnownData(this->cfg_, p_fmat); monitor_.Start("BoostNewTrees"); + auto* out = &predt->predictions; CHECK_NE(ngroup, 0); if (ngroup == 1) { std::vector > ret; BoostNewTrees(in_gpair, p_fmat, 0, &ret); + const size_t num_new_trees = ret.size(); new_trees.push_back(std::move(ret)); + if (updaters_.size() > 0 && num_new_trees == 1 && out->Size() > 0 && + updaters_.back()->UpdatePredictionCache(p_fmat, out)) { + predt->Update(1); + } } else { CHECK_EQ(in_gpair->Size() % ngroup, 0U) << "must have exactly ngroup * nrow gpairs"; @@ -210,6 +216,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, in_gpair->DeviceIdx()); const auto& gpair_h = in_gpair->ConstHostVector(); auto nsize = static_cast(tmp.Size()); + bool update_predict = true; for (int gid = 0; gid < ngroup; ++gid) { std::vector& tmp_h = tmp.HostVector(); #pragma omp parallel for schedule(static) @@ -218,7 +225,16 @@ void GBTree::DoBoost(DMatrix* p_fmat, } std::vector > ret; BoostNewTrees(&tmp, p_fmat, gid, &ret); + const size_t num_new_trees = ret.size(); new_trees.push_back(std::move(ret)); + auto* out = &predt->predictions; + if (!(updaters_.size() > 0 && out->Size() > 0 && num_new_trees == 1 && + updaters_.back()->UpdatePredictionCacheMulticlass(p_fmat, out, gid, ngroup))) { + update_predict = false; + } + } + if (update_predict) { + predt->Update(1); } } monitor_.Stop("BoostNewTrees"); @@ -314,20 +330,9 @@ void GBTree::CommitModel(std::vector>>&& ne DMatrix* m, PredictionCacheEntry* predts) { monitor_.Start("CommitModel"); - int num_new_trees = 0; for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) { - num_new_trees += new_trees[gid].size(); model_.CommitModel(std::move(new_trees[gid]), gid); } - auto* out = &predts->predictions; - if (model_.learner_model_param->num_output_group == 1 && - updaters_.size() > 0 && - num_new_trees == 1 && - out->Size() > 0 && - updaters_.back()->UpdatePredictionCache(m, out)) { - auto delta = num_new_trees / model_.learner_model_param->num_output_group; - predts->Update(delta); - } monitor_.Stop("CommitModel"); } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index b0fa98c85..7248eae80 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -110,8 +110,7 @@ void QuantileHistMaker::Update(HostDeviceVector *gpair, } bool QuantileHistMaker::UpdatePredictionCache( - const DMatrix* data, - HostDeviceVector* out_preds) { + const DMatrix* data, HostDeviceVector* out_preds) { if (param_.subsample < 1.0f) { return false; } else { @@ -125,6 +124,23 @@ bool QuantileHistMaker::UpdatePredictionCache( } } +bool QuantileHistMaker::UpdatePredictionCacheMulticlass( + const DMatrix* data, + HostDeviceVector* out_preds, const int gid, const int ngroup) { + if (param_.subsample < 1.0f) { + return false; + } else { + if (hist_maker_param_.single_precision_histogram && float_builder_) { + return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup); + } else if (double_builder_) { + return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup); + } else { + return false; + } + } +} + + template void BatchHistSynchronizer::SyncHistograms(BuilderT *builder, int, @@ -620,7 +636,7 @@ void QuantileHistMaker::Builder::Update( template bool QuantileHistMaker::Builder::UpdatePredictionCache( const DMatrix* data, - HostDeviceVector* p_out_preds) { + HostDeviceVector* p_out_preds, const int gid, const int ngroup) { // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in // conjunction with Update(). if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) { @@ -659,7 +675,7 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( leaf_value = (*p_last_tree_)[nid].LeafValue(); for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { - out_preds[*it] += leaf_value; + out_preds[*it * ngroup + gid] += leaf_value; } } }); diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index ed906cb90..d408b16b0 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -121,6 +121,9 @@ class QuantileHistMaker: public TreeUpdater { bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector* out_preds) override; + bool UpdatePredictionCacheMulticlass(const DMatrix* data, + HostDeviceVector* out_preds, + const int gid, const int ngroup) override; void LoadConfig(Json const& in) override { auto const& config = get(in); @@ -243,7 +246,9 @@ class QuantileHistMaker: public TreeUpdater { } bool UpdatePredictionCache(const DMatrix* data, - HostDeviceVector* p_out_preds); + HostDeviceVector* p_out_preds, + const int gid = 0, const int ngroup = 1); + void SetHistSynchronizer(HistSynchronizer* sync); void SetHistRowsAdder(HistRowsAdder* adder); diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 510ce073d..634747991 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -8,6 +8,7 @@ #include "../helpers.h" #include "test_predictor.h" #include "../../../src/gbm/gbtree_model.h" +#include "../../../src/gbm/gbtree.h" #include "../../../src/data/adapter.h" namespace xgboost { @@ -180,6 +181,49 @@ TEST(CpuPredictor, InplacePredict) { } } +TEST(CpuPredictor, UpdatePredictionCache) { + size_t constexpr kRows = 64, kCols = 16, kClasses = 4; + LearnerModelParam mparam; + mparam.num_feature = kCols; + mparam.num_output_group = kClasses; + mparam.base_score = 0; + + GenericParameter gparam; + gparam.Init(Args{}); + + std::unique_ptr gbm; + gbm.reset(static_cast(GradientBooster::Create("gbtree", &gparam, &mparam))); + std::map cfg; + cfg["tree_method"] = "hist"; + cfg["predictor"] = "cpu_predictor"; + Args args = {cfg.cbegin(), cfg.cend()}; + gbm->Configure(args); + + auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses); + + HostDeviceVector gpair; + auto& h_gpair = gpair.HostVector(); + h_gpair.resize(kRows*kClasses); + for (size_t i = 0; i < kRows*kClasses; ++i) { + h_gpair[i] = {static_cast(i), 1}; + } + + PredictionCacheEntry predtion_cache; + predtion_cache.predictions.Resize(kRows*kClasses, 0); + // after one training iteration predtion_cache is filled with cached in QuantileHistMaker::Builder prediction values + gbm->DoBoost(dmat.get(), &gpair, &predtion_cache); + + PredictionCacheEntry out_predictions; + // perform fair prediction on the same input data, should be equal to cached result + gbm->PredictBatch(dmat.get(), &out_predictions, false, 0); + + std::vector &out_predictions_h = out_predictions.predictions.HostVector(); + std::vector &predtion_cache_from_train = predtion_cache.predictions.HostVector(); + for (size_t i = 0; i < out_predictions_h.size(); ++i) { + ASSERT_NEAR(out_predictions_h[i], predtion_cache_from_train[i], kRtEps); + } +} + TEST(CpuPredictor, LesserFeatures) { TestPredictionWithLesserFeatures("cpu_predictor"); }