diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 843afeec1..4def6940d 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -216,6 +216,16 @@ class GloablApproxBuilder { bst_node_t num_leaves = 1; auto expand_set = driver.Pop(); + /** + * Note for update position + * Root: + * Not applied: No need to update position as initialization has got all the rows ordered. + * Applied: Update position is run on applied nodes so the rows are partitioned. + * Non-root: + * Not applied: That node is root of the subtree, same rule as root. + * Applied: Ditto + */ + while (!expand_set.empty()) { // candidates that can be further splited. std::vector valid_candidates; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 9587c3b83..bb8ab78e4 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -692,6 +692,9 @@ struct GPUHistMakerDevice { if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) { monitor.Start("UpdatePosition"); + // Update position is only run when child is valid, instead of right after apply + // split (as in approx tree method). Hense we have the finalise position call + // in GPU Hist. this->UpdatePosition(candidate.nid, p_tree); monitor.Stop("UpdatePosition"); diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index 738d30d29..639768b5e 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -75,58 +75,5 @@ TEST(Approx, Partitioner) { } } } - -TEST(Approx, PredictionCache) { - size_t n_samples = 2048, n_features = 13; - auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); - - { - omp_set_num_threads(1); - GenericParameter ctx; - ctx.InitAllowUnknown(Args{{"nthread", "8"}}); - std::unique_ptr approx{ - TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})}; - RegTree tree; - std::vector trees{&tree}; - auto gpair = GenerateRandomGradients(n_samples); - approx->Configure(Args{{"max_bin", "64"}}); - approx->Update(&gpair, Xy.get(), trees); - HostDeviceVector out_prediction_cached; - out_prediction_cached.Resize(n_samples); - auto cache = linalg::VectorView{ - out_prediction_cached.HostSpan(), {out_prediction_cached.Size()}, GenericParameter::kCpuId}; - ASSERT_TRUE(approx->UpdatePredictionCache(Xy.get(), cache)); - } - - std::unique_ptr learner{Learner::Create({Xy})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("nthread", "0"); - learner->Configure(); - - for (size_t i = 0; i < 8; ++i) { - learner->UpdateOneIter(i, Xy); - } - - HostDeviceVector out_prediction_cached; - learner->Predict(Xy, false, &out_prediction_cached, 0, 0); - - Json model{Object()}; - learner->SaveModel(&model); - - HostDeviceVector out_prediction; - { - std::unique_ptr learner{Learner::Create({Xy})}; - learner->LoadModel(model); - learner->Predict(Xy, false, &out_prediction, 0, 0); - } - - auto const h_predt_cached = out_prediction_cached.ConstHostSpan(); - auto const h_predt = out_prediction.ConstHostSpan(); - - ASSERT_EQ(h_predt.size(), h_predt_cached.size()); - for (size_t i = 0; i < h_predt.size(); ++i) { - ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps); - } -} } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/test_prediction_cache.cc b/tests/cpp/tree/test_prediction_cache.cc new file mode 100644 index 000000000..ebe66cf57 --- /dev/null +++ b/tests/cpp/tree/test_prediction_cache.cc @@ -0,0 +1,108 @@ +/*! + * Copyright 2021-2022 by XGBoost contributors + */ +#include +#include +#include + +#include + +#include "../helpers.h" + +namespace xgboost { + +class TestPredictionCache : public ::testing::Test { + std::shared_ptr Xy_; + size_t n_samples_{2048}; + + protected: + void SetUp() override { + size_t n_features = 13; + Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.GenerateDMatrix(true); + } + + void RunLearnerTest(std::string updater_name, float subsample, std::string grow_policy) { + std::unique_ptr learner{Learner::Create({Xy_})}; + if (updater_name == "grow_gpu_hist") { + // gpu_id setup + learner->SetParam("tree_method", "gpu_hist"); + } else { + learner->SetParam("updater", updater_name); + } + learner->SetParam("grow_policy", grow_policy); + learner->SetParam("subsample", std::to_string(subsample)); + learner->SetParam("nthread", "0"); + learner->Configure(); + + for (size_t i = 0; i < 8; ++i) { + learner->UpdateOneIter(i, Xy_); + } + + HostDeviceVector out_prediction_cached; + learner->Predict(Xy_, false, &out_prediction_cached, 0, 0); + + Json model{Object()}; + learner->SaveModel(&model); + + HostDeviceVector out_prediction; + { + std::unique_ptr learner{Learner::Create({Xy_})}; + learner->LoadModel(model); + learner->Predict(Xy_, false, &out_prediction, 0, 0); + } + + auto const h_predt_cached = out_prediction_cached.ConstHostSpan(); + auto const h_predt = out_prediction.ConstHostSpan(); + + ASSERT_EQ(h_predt.size(), h_predt_cached.size()); + for (size_t i = 0; i < h_predt.size(); ++i) { + ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps); + } + } + + void RunTest(std::string updater_name) { + { + omp_set_num_threads(1); + GenericParameter ctx; + ctx.InitAllowUnknown(Args{{"nthread", "8"}}); + if (updater_name == "grow_gpu_hist") { + ctx.gpu_id = 0; + } else { + ctx.gpu_id = GenericParameter::kCpuId; + } + + std::unique_ptr updater{ + TreeUpdater::Create(updater_name, &ctx, ObjInfo{ObjInfo::kRegression})}; + RegTree tree; + std::vector trees{&tree}; + auto gpair = GenerateRandomGradients(n_samples_); + updater->Configure(Args{{"max_bin", "64"}}); + updater->Update(&gpair, Xy_.get(), trees); + HostDeviceVector out_prediction_cached; + out_prediction_cached.SetDevice(ctx.gpu_id); + out_prediction_cached.Resize(n_samples_); + auto cache = linalg::VectorView{ctx.gpu_id == GenericParameter::kCpuId + ? out_prediction_cached.HostSpan() + : out_prediction_cached.DeviceSpan(), + {out_prediction_cached.Size()}, + ctx.gpu_id}; + ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache)); + } + + for (auto policy : {"depthwise", "lossguide"}) { + for (auto subsample : {1.0f, 0.4f}) { + this->RunLearnerTest(updater_name, subsample, policy); + this->RunLearnerTest(updater_name, subsample, policy); + } + } + } +}; + +TEST_F(TestPredictionCache, Approx) { this->RunTest("grow_histmaker"); } + +TEST_F(TestPredictionCache, Hist) { this->RunTest("grow_quantile_histmaker"); } + +#if defined(XGBOOST_USE_CUDA) +TEST_F(TestPredictionCache, GpuHist) { this->RunTest("grow_gpu_hist"); } +#endif // defined(XGBOOST_USE_CUDA) +} // namespace xgboost diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 4aa3647e0..86200f335 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -26,10 +26,19 @@ parameter_strategy = strategies.fixed_dictionaries({ x['max_depth'] > 0 or x['grow_policy'] == 'lossguide')) -def train_result(param, dmat, num_rounds): - result = {} - xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False, - evals_result=result) +def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict: + result: xgb.callback.TrainingCallback.EvalsLog = {} + booster = xgb.train( + param, + dmat, + num_rounds, + [(dmat, "train")], + verbose_eval=False, + evals_result=result, + ) + assert booster.num_features() == dmat.num_col() + assert booster.num_boosted_rounds() == num_rounds + return result