From 24d225c1abe1c97fa039c7718c4f294a59a531c3 Mon Sep 17 00:00:00 2001
From: Dmitry Razdoburdin
Date: Wed, 21 Aug 2024 20:07:44 +0200
Subject: [PATCH] [SYCL] Implement UpdatePredictionCache and connect updater
 with learner. (#10701)

---------

Co-authored-by: Dmitry Razdoburdin <>
---
 plugin/sycl/tree/hist_updater.cc             | 95 ++++++++++++++++++
 plugin/sycl/tree/hist_updater.h              | 19 +++-
 plugin/sycl/tree/updater_quantile_hist.cc    | 92 +++++++++++++++++-
 plugin/sycl/tree/updater_quantile_hist.h     | 31 +++++-
 src/gbm/gbtree.cc                            |  3 +-
 tests/cpp/plugin/test_sycl_hist_updater.cc   | 38 ++------
 .../cpp/plugin/test_sycl_prediction_cache.cc | 23 +++++
 tests/cpp/tree/test_prediction_cache.cc      | 91 +----------------
 tests/cpp/tree/test_prediction_cache.h       | 97 +++++++++++++++++++
 .../test_sycl_training_continuation.py       | 59 +++++++++++
 tests/python-sycl/test_sycl_updaters.py      | 80 +++++++++++++++
 11 files changed, 502 insertions(+), 126 deletions(-)
 create mode 100644 tests/cpp/plugin/test_sycl_prediction_cache.cc
 create mode 100644 tests/cpp/tree/test_prediction_cache.h
 create mode 100644 tests/python-sycl/test_sycl_training_continuation.py
 create mode 100644 tests/python-sycl/test_sycl_updaters.py

diff --git a/plugin/sycl/tree/hist_updater.cc b/plugin/sycl/tree/hist_updater.cc
index efaddafdb..097e2da73 100644
--- a/plugin/sycl/tree/hist_updater.cc
+++ b/plugin/sycl/tree/hist_updater.cc
@@ -307,6 +307,99 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
   builder_monitor_.Stop("ExpandWithLossGuide");
 }
 
+template <typename GradientSumT>
+void HistUpdater<GradientSumT>::Update(
+    xgboost::tree::TrainParam const *param,
+    const common::GHistIndexMatrix &gmat,
+    const USMVector<GradientPair, MemoryType::on_device>& gpair,
+    DMatrix *p_fmat,
+    xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+    RegTree *p_tree) {
+  builder_monitor_.Start("Update");
+
+  tree_evaluator_.Reset(qu_, param_, p_fmat->Info().num_col_);
+  interaction_constraints_.Reset();
+
+  this->InitData(gmat, gpair, *p_fmat, *p_tree);
+  if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide) {
+    ExpandWithLossGuide(gmat, p_tree, gpair);
+  } else {
+    ExpandWithDepthWise(gmat, p_tree, gpair);
+  }
+
+  for (int nid = 0; nid < p_tree->NumNodes(); ++nid) {
+    p_tree->Stat(nid).loss_chg = snode_host_[nid].best.loss_chg;
+    p_tree->Stat(nid).base_weight = snode_host_[nid].weight;
+    p_tree->Stat(nid).sum_hess = static_cast<float>(snode_host_[nid].stats.GetHess());
+  }
+
+  builder_monitor_.Stop("Update");
+}
+
+template <typename GradientSumT>
+bool HistUpdater<GradientSumT>::UpdatePredictionCache(
+    const DMatrix* data,
+    linalg::MatrixView<float> out_preds) {
+  // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
+  // conjunction with Update().
+  if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
+    return false;
+  }
+  builder_monitor_.Start("UpdatePredictionCache");
+  CHECK_GT(out_preds.Size(), 0U);
+
+  const size_t stride = out_preds.Stride(0);
+  const bool is_first_group = (out_pred_ptr == nullptr);
+  const size_t gid = out_pred_ptr == nullptr ? 0 : &out_preds(0) - out_pred_ptr;
+  const bool is_last_group = (gid + 1 == stride);
+
+  const int buffer_size = out_preds.Size() * stride;
+  if (buffer_size == 0) return true;
+
+  ::sycl::event event;
+  if (is_first_group) {
+    out_preds_buf_.ResizeNoCopy(&qu_, buffer_size);
+    out_pred_ptr = &out_preds(0);
+    event = qu_.memcpy(out_preds_buf_.Data(), out_pred_ptr, buffer_size * sizeof(bst_float), event);
+  }
+  auto* out_preds_buf_ptr = out_preds_buf_.Data();
+
+  size_t n_nodes = row_set_collection_.Size();
+  std::vector<::sycl::event> events(n_nodes);
+  for (size_t node = 0; node < n_nodes; node++) {
+    const common::RowSetCollection::Elem& rowset = row_set_collection_[node];
+    if (rowset.begin != nullptr && rowset.end != nullptr && rowset.Size() != 0) {
+      int nid = rowset.node_id;
+      // if a node is marked as deleted by the pruner, traverse upward to locate
+      // a non-deleted leaf.
+      if ((*p_last_tree_)[nid].IsDeleted()) {
+        while ((*p_last_tree_)[nid].IsDeleted()) {
+          nid = (*p_last_tree_)[nid].Parent();
+        }
+        CHECK((*p_last_tree_)[nid].IsLeaf());
+      }
+      bst_float leaf_value = (*p_last_tree_)[nid].LeafValue();
+      const size_t* rid = rowset.begin;
+      const size_t num_rows = rowset.Size();
+
+      events[node] = qu_.submit([&](::sycl::handler& cgh) {
+        cgh.depends_on(event);
+        cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) {
+          out_preds_buf_ptr[rid[pid.get_id(0)] * stride + gid] += leaf_value;
+        });
+      });
+    }
+  }
+  if (is_last_group) {
+    qu_.memcpy(out_pred_ptr, out_preds_buf_ptr, buffer_size * sizeof(bst_float), events);
+    out_pred_ptr = nullptr;
+  }
+  qu_.wait();
+
+  builder_monitor_.Stop("UpdatePredictionCache");
+  return true;
+}
+
 template <typename GradientSumT>
 void HistUpdater<GradientSumT>::InitSampling(
     const USMVector<GradientPair, MemoryType::on_device> &gpair,
@@ -479,6 +572,8 @@ void HistUpdater<GradientSumT>::InitData(
     }
   }
 
+  // store a pointer to the tree
+  p_last_tree_ = &tree;
   column_sampler_->Init(ctx_, info.num_col_, info.feature_weights.ConstHostVector(),
                         param_.colsample_bynode, param_.colsample_bylevel,
                         param_.colsample_bytree);
diff --git a/plugin/sycl/tree/hist_updater.h b/plugin/sycl/tree/hist_updater.h
index 5e0ca6645..fd5fdda94 100644
--- a/plugin/sycl/tree/hist_updater.h
+++ b/plugin/sycl/tree/hist_updater.h
@@ -11,10 +11,10 @@
 #include 
 #pragma GCC diagnostic pop
 
-#include 
 #include 
 #include 
 #include 
+#include 
 
 #include "../common/partition_builder.h"
 #include "split_evaluator.h"
@@ -54,12 +54,10 @@ class HistUpdater {
   explicit HistUpdater(const Context* ctx,
                        ::sycl::queue qu,
                        const xgboost::tree::TrainParam& param,
-                       std::unique_ptr<TreeUpdater> pruner,
                        FeatureInteractionConstraintHost int_constraints_,
                        DMatrix const* fmat)
     : ctx_(ctx), qu_(qu), param_(param),
       tree_evaluator_(qu, param, fmat->Info().num_col_),
-      pruner_(std::move(pruner)),
       interaction_constraints_{std::move(int_constraints_)},
       p_last_tree_(nullptr), p_last_fmat_(fmat) {
     builder_monitor_.Init("SYCL::Quantile::HistUpdater");
@@ -73,6 +71,17 @@ class HistUpdater {
     sub_group_size_ = sub_group_sizes.back();
   }
 
+  // update one tree, growing
+  void Update(xgboost::tree::TrainParam const *param,
+              const common::GHistIndexMatrix &gmat,
+              const USMVector<GradientPair, MemoryType::on_device>& gpair,
+              DMatrix *p_fmat,
+              xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+              RegTree *p_tree);
+
+  bool UpdatePredictionCache(const DMatrix* data,
+                             linalg::MatrixView<float> p_out_preds);
+
   void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
   void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);
@@ -200,7 +209,6 @@ class HistUpdater {
   std::vector<SplitEntry<GradientSumT>> best_splits_host_;
 
   TreeEvaluator<GradientSumT> tree_evaluator_;
-  std::unique_ptr<TreeUpdater> pruner_;
   FeatureInteractionConstraintHost interaction_constraints_;
 
   // back pointers to tree and data matrix
@@ -247,6 +255,9 @@ class HistUpdater {
   std::unique_ptr<HistSynchronizer<GradientSumT>> hist_synchronizer_;
   std::unique_ptr<HistRowsAdder<GradientSumT>> hist_rows_adder_;
 
+  USMVector<bst_float, MemoryType::on_device> out_preds_buf_;
+  bst_float* out_pred_ptr = nullptr;
+
   ::sycl::queue qu_;
 };
diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc
index 98a42c3c8..ee7a7ad0f 100644
--- a/plugin/sycl/tree/updater_quantile_hist.cc
+++ b/plugin/sycl/tree/updater_quantile_hist.cc
@@ -3,6 +3,7 @@
  * \file updater_quantile_hist.cc
  */
 #include 
+#include 
 
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wtautological-constant-compare"
@@ -29,6 +30,50 @@ void QuantileHistMaker::Configure(const Args& args) {
 
   param_.UpdateAllowUnknown(args);
   hist_maker_param_.UpdateAllowUnknown(args);
+
+  bool has_fp64_support = qu_.get_device().has(::sycl::aspect::fp64);
+  if (hist_maker_param_.single_precision_histogram || !has_fp64_support) {
+    if (!hist_maker_param_.single_precision_histogram) {
+      LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True";
+    }
+    hist_precision_ = HistPrecision::fp32;
+  } else {
+    hist_precision_ = HistPrecision::fp64;
+  }
+}
+
+template <typename GradientSumT>
+void QuantileHistMaker::SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>* pimpl,
+                                 DMatrix *dmat) {
+  pimpl->reset(new HistUpdater<GradientSumT>(
+      ctx_,
+      qu_,
+      param_,
+      int_constraint_, dmat));
+  if (collective::IsDistributed()) {
+    LOG(FATAL) << "Distributed mode is not yet upstreamed for sycl";
+  } else {
+    (*pimpl)->SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
+    (*pimpl)->SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
+  }
+}
+
+template <typename GradientSumT>
+void QuantileHistMaker::CallUpdate(
+      const std::unique_ptr<HistUpdater<GradientSumT>>& pimpl,
+      xgboost::tree::TrainParam const *param,
+      linalg::Matrix<GradientPair> *gpair,
+      DMatrix *dmat,
+      xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+      const std::vector<RegTree*> &trees) {
+  const auto* gpair_h = gpair->Data();
+  gpair_device_.Resize(&qu_, gpair_h->Size());
+  qu_.memcpy(gpair_device_.Data(), gpair_h->HostPointer(), gpair_h->Size() * sizeof(GradientPair));
+  qu_.wait();
+
+  for (auto tree : trees) {
+    pimpl->Update(param, gmat_, gpair_device_, dmat, out_position, tree);
+  }
 }
 
 void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
@@ -36,12 +81,55 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
                                linalg::Matrix<GradientPair>* gpair,
                                DMatrix *dmat,
                                xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
                                const std::vector<RegTree*> &trees) {
-  LOG(FATAL) << "Not Implemented yet";
+  if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
+    updater_monitor_.Start("DeviceMatrixInitialization");
+    sycl::DeviceMatrix dmat_device;
+    dmat_device.Init(qu_, dmat);
+    updater_monitor_.Stop("DeviceMatrixInitialization");
+    updater_monitor_.Start("GmatInitialization");
+    gmat_.Init(qu_, ctx_, dmat_device, static_cast<uint32_t>(param_.max_bin));
+    updater_monitor_.Stop("GmatInitialization");
+    is_gmat_initialized_ = true;
+  }
+  // rescale learning rate according to size of trees
+  float lr = param_.learning_rate;
+  param_.learning_rate = lr / trees.size();
+  int_constraint_.Configure(param_, dmat->Info().num_col_);
+  // build tree
+  if (hist_precision_ == HistPrecision::fp32) {
+    if (!pimpl_fp32) {
+      SetPimpl(&pimpl_fp32, dmat);
+    }
+    CallUpdate(pimpl_fp32, param, gpair, dmat, out_position, trees);
+  } else {
+    if (!pimpl_fp64) {
+      SetPimpl(&pimpl_fp64, dmat);
+    }
+    CallUpdate(pimpl_fp64, param, gpair, dmat, out_position, trees);
+  }
+
+  param_.learning_rate = lr;
+
+  p_last_dmat_ = dmat;
 }
 
 bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
                                               linalg::MatrixView<float> out_preds) {
-  LOG(FATAL) << "Not Implemented yet";
+  if (param_.subsample < 1.0f) return false;
+
+  if (hist_precision_ == HistPrecision::fp32) {
+    if (pimpl_fp32) {
+      return pimpl_fp32->UpdatePredictionCache(data, out_preds);
+    } else {
+      return false;
+    }
+  } else {
+    if (pimpl_fp64) {
+      return pimpl_fp64->UpdatePredictionCache(data, out_preds);
+    } else {
+      return false;
+    }
+  }
 }
 
 XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")
diff --git a/plugin/sycl/tree/updater_quantile_hist.h b/plugin/sycl/tree/updater_quantile_hist.h
index 93a50de3e..693255b26 100644
--- a/plugin/sycl/tree/updater_quantile_hist.h
+++ b/plugin/sycl/tree/updater_quantile_hist.h
@@ -9,6 +9,7 @@
 
 #include 
 #include 
+#include 
 
 #include "../data/gradient_index.h"
 #include "../common/hist_util.h"
@@ -16,8 +17,9 @@
 #include "../common/partition_builder.h"
 #include "split_evaluator.h"
 #include "../device_manager.h"
-
+#include "hist_updater.h"
 #include "xgboost/data.h"
+
 #include "xgboost/json.h"
 #include "../../src/tree/constraints.h"
 #include "../../src/common/random.h"
@@ -75,12 +77,39 @@ class QuantileHistMaker: public TreeUpdater {
   HistMakerTrainParam hist_maker_param_;
   // training parameter
   xgboost::tree::TrainParam param_;
+  // quantized data matrix
+  common::GHistIndexMatrix gmat_;
+  // (optional) data matrix with feature grouping
+  // column accessor
+  DMatrix const* p_last_dmat_ {nullptr};
+  bool is_gmat_initialized_ {false};
 
   xgboost::common::Monitor updater_monitor_;
 
+  template <typename GradientSumT>
+  void SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>*, DMatrix *dmat);
+
+  template <typename GradientSumT>
+  void CallUpdate(const std::unique_ptr<HistUpdater<GradientSumT>>& builder,
+                  xgboost::tree::TrainParam const *param,
+                  linalg::Matrix<GradientPair> *gpair,
+                  DMatrix *dmat,
+                  xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+                  const std::vector<RegTree*> &trees);
+
+  enum class HistPrecision {fp32, fp64};
+  HistPrecision hist_precision_;
+
+  std::unique_ptr<HistUpdater<float>> pimpl_fp32;
+  std::unique_ptr<HistUpdater<double>> pimpl_fp64;
+
+  FeatureInteractionConstraintHost int_constraint_;
+
   ::sycl::queue qu_;
   DeviceManager device_manager;
   ObjInfo const *task_{nullptr};
+
+  USMVector<GradientPair, MemoryType::on_device> gpair_device_;
 };
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 26c768faf..fe640ee00 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -52,7 +52,8 @@ std::string MapTreeMethodToUpdaters(Context const* ctx, TreeMethod tree_method)
     case TreeMethod::kAuto:  // Use hist as default in 2.0
     case TreeMethod::kHist: {
       return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; },
-                                 [] { return "grow_gpu_hist"; });
+                                 [] { return "grow_gpu_hist"; },
+                                 [] { return "grow_quantile_histmaker_sycl"; });
     }
     case TreeMethod::kApprox: {
       return ctx->DispatchDevice([] { return "grow_histmaker"; }, [] { return "grow_gpu_approx"; });
diff --git a/tests/cpp/plugin/test_sycl_hist_updater.cc b/tests/cpp/plugin/test_sycl_hist_updater.cc
index 7789b4438..a341f4645 100644
--- a/tests/cpp/plugin/test_sycl_hist_updater.cc
+++ b/tests/cpp/plugin/test_sycl_hist_updater.cc
@@ -21,10 +21,8 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
   TestHistUpdater(const Context* ctx,
                   ::sycl::queue qu,
                   const xgboost::tree::TrainParam& param,
-                  std::unique_ptr<TreeUpdater> pruner,
                   FeatureInteractionConstraintHost int_constraints_,
                   DMatrix const* fmat) : HistUpdater<GradientSumT>(ctx, qu, param,
-                                           std::move(pruner),
                                           int_constraints_, fmat) {}
 
   void TestInitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
@@ -110,14 +108,12 @@ void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) {
   DeviceManager device_manager;
   auto qu = device_manager.GetQueue(ctx.Device());
 
-  ObjInfo task{ObjInfo::kRegression};
   auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix();
 
   FeatureInteractionConstraintHost int_constraints;
-  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
 
-  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
 
   USMVector<size_t, MemoryType::on_device> row_indices_0(&qu, num_rows);
   USMVector<size_t, MemoryType::on_device> row_indices_1(&qu, num_rows);
@@ -165,14 +161,12 @@ void TestHistUpdaterInitData(const xgboost::tree::TrainParam& param, bool has_ne
   DeviceManager device_manager;
   auto qu = device_manager.GetQueue(ctx.Device());
 
-  ObjInfo task{ObjInfo::kRegression};
   auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix();
 
   FeatureInteractionConstraintHost int_constraints;
-  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
 
-  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
 
   USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
   GenerateRandomGPairs(&qu, gpair.Data(), num_rows, has_neg_hess);
@@ -221,14 +215,12 @@ void TestHistUpdaterBuildHistogramsLossGuide(const xgboost::tree::TrainParam& pa
   DeviceManager device_manager;
   auto qu = device_manager.GetQueue(ctx.Device());
 
-  ObjInfo task{ObjInfo::kRegression};
   auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
 
   FeatureInteractionConstraintHost int_constraints;
-  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
 
-  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
   updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
   updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
@@ -285,14 +277,12 @@ void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sp
   DeviceManager device_manager;
   auto qu = device_manager.GetQueue(ctx.Device());
 
-  ObjInfo task{ObjInfo::kRegression};
   auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
 
   FeatureInteractionConstraintHost int_constraints;
-  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
 
-  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
   updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
   updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
@@ -345,14 +335,12 @@ void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
   DeviceManager device_manager;
   auto qu = device_manager.GetQueue(ctx.Device());
 
-  ObjInfo task{ObjInfo::kRegression};
   auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0f}.GenerateDMatrix();
 
   FeatureInteractionConstraintHost int_constraints;
-  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
 
-  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
   updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
   updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
@@ -423,8 +411,6 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
   DeviceManager device_manager;
   auto qu = device_manager.GetQueue(ctx.Device());
 
-  ObjInfo task{ObjInfo::kRegression};
-
   auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
   sycl::DeviceMatrix dmat;
   dmat.Init(qu, p_fmat.get());
@@ -439,8 +425,7 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
   nodes.emplace_back(tree::ExpandEntry(0, tree.GetDepth(0)));
 
   FeatureInteractionConstraintHost int_constraints;
-  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
-  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
 
   USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
   GenerateRandomGPairs(&qu, gpair.Data(), num_rows, false);
@@ -455,8 +440,7 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
   std::vector<size_t> row_indices_desired_host(num_rows);
   size_t n_left, n_right;
   {
-    std::unique_ptr<TreeUpdater> pruner4verification{TreeUpdater::Create("prune", &ctx, &task)};
-    TestHistUpdater<GradientSumT> updater4verification(&ctx, qu, param, std::move(pruner4verification), int_constraints, p_fmat.get());
+    TestHistUpdater<GradientSumT> updater4verification(&ctx, qu, param, int_constraints, p_fmat.get());
     auto* row_set_collection4verification = updater4verification.TestInitData(gmat, gpair, *p_fmat, tree);
 
     size_t n_nodes = nodes.size();
@@ -526,9 +510,7 @@ void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param)
   RegTree tree;
 
   FeatureInteractionConstraintHost int_constraints;
-  ObjInfo task{ObjInfo::kRegression};
-  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
-  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
   updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
   updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
   auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
@@ -576,9 +558,7 @@ void TestHistUpdaterExpandWithDepthWise(const xgboost::tree::TrainParam& param)
   RegTree tree;
 
   FeatureInteractionConstraintHost int_constraints;
-  ObjInfo task{ObjInfo::kRegression};
-  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
-  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
   updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
   updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
   auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
diff --git a/tests/cpp/plugin/test_sycl_prediction_cache.cc b/tests/cpp/plugin/test_sycl_prediction_cache.cc
new file mode 100644
index 000000000..43f99dc63
--- /dev/null
+++ b/tests/cpp/plugin/test_sycl_prediction_cache.cc
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020-2024 by XGBoost contributors
+ */
+#include 
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
+#pragma GCC diagnostic ignored "-W#pragma-messages"
+#include "../tree/test_prediction_cache.h"
+#pragma GCC diagnostic pop
+
+namespace xgboost::sycl::tree {
+
+class SyclPredictionCache : public xgboost::TestPredictionCache {};
+
+TEST_F(SyclPredictionCache, Hist) {
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
+
+  this->RunTest(&ctx, "grow_quantile_histmaker_sycl", "one_output_per_tree");
+}
+
+}  // namespace xgboost::sycl::tree
diff --git a/tests/cpp/tree/test_prediction_cache.cc b/tests/cpp/tree/test_prediction_cache.cc
index fc1d05087..5c22ace41 100644
--- a/tests/cpp/tree/test_prediction_cache.cc
+++ b/tests/cpp/tree/test_prediction_cache.cc
@@ -2,97 +2,10 @@
  * Copyright 2021-2023 by XGBoost contributors
  */
 #include 
-#include 
-#include 
-#include 
-
-#include "../../../src/tree/param.h"  // for TrainParam
-#include "../helpers.h"
-#include "xgboost/task.h"  // for ObjInfo
+#include "test_prediction_cache.h"
 
 namespace xgboost {
-
-class TestPredictionCache : public ::testing::Test {
-  std::shared_ptr<DMatrix> Xy_;
-  std::size_t n_samples_{2048};
-
- protected:
-  void SetUp() override {
-    std::size_t n_features = 13;
-    bst_target_t n_targets = 3;
-    Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true);
-  }
-
-  void RunLearnerTest(Context const* ctx, std::string updater_name, float subsample,
-                      std::string const& grow_policy, std::string const& strategy) {
-    std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
-    learner->SetParam("device", ctx->DeviceName());
-    learner->SetParam("updater", updater_name);
-    learner->SetParam("multi_strategy", strategy);
-    learner->SetParam("grow_policy", grow_policy);
-    learner->SetParam("subsample", std::to_string(subsample));
-    learner->SetParam("nthread", "0");
-    learner->Configure();
-
-    for (size_t i = 0; i < 8; ++i) {
-      learner->UpdateOneIter(i, Xy_);
-    }
-
-    HostDeviceVector<float> out_prediction_cached;
-    learner->Predict(Xy_, false, &out_prediction_cached, 0, 0);
-
-    Json model{Object()};
-    learner->SaveModel(&model);
-
-    HostDeviceVector<float> out_prediction;
-    {
-      std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
-      learner->LoadModel(model);
-      learner->Predict(Xy_, false, &out_prediction, 0, 0);
-    }
-
-    auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
-    auto const h_predt = out_prediction.ConstHostSpan();
-
-    ASSERT_EQ(h_predt.size(), h_predt_cached.size());
-    for (size_t i = 0; i < h_predt.size(); ++i) {
-      ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
-    }
-  }
-
-  void RunTest(Context* ctx, std::string const& updater_name, std::string const& strategy) {
-    {
-      ctx->InitAllowUnknown(Args{{"nthread", "8"}});
-
-      ObjInfo task{ObjInfo::kRegression};
-      std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, ctx, &task)};
-      RegTree tree;
-      std::vector<RegTree*> trees{&tree};
-      auto gpair = GenerateRandomGradients(ctx, n_samples_, 1);
-      tree::TrainParam param;
-      param.UpdateAllowUnknown(Args{{"max_bin", "64"}});
-
-      updater->Configure(Args{});
-      std::vector<HostDeviceVector<bst_node_t>> position(1);
-      updater->Update(&param, &gpair, Xy_.get(), position, trees);
-      HostDeviceVector<float> out_prediction_cached;
-      out_prediction_cached.SetDevice(ctx->Device());
-      out_prediction_cached.Resize(n_samples_);
-      auto cache =
-          linalg::MakeTensorView(ctx, &out_prediction_cached, out_prediction_cached.Size(), 1);
-      ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache));
-    }
-
-    for (auto policy : {"depthwise", "lossguide"}) {
-      for (auto subsample : {1.0f, 0.4f}) {
-        this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
-        this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
-      }
-    }
-  }
-};
-
 TEST_F(TestPredictionCache, Approx) {
   Context ctx;
   this->RunTest(&ctx, "grow_histmaker", "one_output_per_tree");
@@ -119,4 +32,4 @@ TEST_F(TestPredictionCache, GpuApprox) {
   this->RunTest(&ctx, "grow_gpu_approx", "one_output_per_tree");
 }
 #endif  // defined(XGBOOST_USE_CUDA)
-}  // namespace xgboost
+}  // namespace xgboost
\ No newline at end of file
diff --git a/tests/cpp/tree/test_prediction_cache.h b/tests/cpp/tree/test_prediction_cache.h
new file mode 100644
index 000000000..a92c30237
--- /dev/null
+++ b/tests/cpp/tree/test_prediction_cache.h
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2021-2024 by XGBoost contributors.
+ */
+#pragma once
+
+#include 
+
+#include 
+#include 
+
+#include 
+
+#include "../../../src/tree/param.h"  // for TrainParam
+#include "../helpers.h"
+#include "xgboost/task.h"  // for ObjInfo
+
+namespace xgboost {
+class TestPredictionCache : public ::testing::Test {
+  std::shared_ptr<DMatrix> Xy_;
+  std::size_t n_samples_{2048};
+
+ protected:
+  void SetUp() override {
+    std::size_t n_features = 13;
+    bst_target_t n_targets = 3;
+    Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true);
+  }
+
+  void RunLearnerTest(Context const* ctx, std::string updater_name, float subsample,
+                      std::string const& grow_policy, std::string const& strategy) {
+    std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
+    learner->SetParam("device", ctx->DeviceName());
+    learner->SetParam("updater", updater_name);
+    learner->SetParam("multi_strategy", strategy);
+    learner->SetParam("grow_policy", grow_policy);
+    learner->SetParam("subsample", std::to_string(subsample));
+    learner->SetParam("nthread", "0");
+    learner->Configure();
+
+    for (size_t i = 0; i < 8; ++i) {
+      learner->UpdateOneIter(i, Xy_);
+    }
+
+    HostDeviceVector<float> out_prediction_cached;
+    learner->Predict(Xy_, false, &out_prediction_cached, 0, 0);
+
+    Json model{Object()};
+    learner->SaveModel(&model);
+
+    HostDeviceVector<float> out_prediction;
+    {
+      std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
+      learner->LoadModel(model);
+      learner->Predict(Xy_, false, &out_prediction, 0, 0);
+    }
+
+    auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
+    auto const h_predt = out_prediction.ConstHostSpan();
+
+    ASSERT_EQ(h_predt.size(), h_predt_cached.size());
+    for (size_t i = 0; i < h_predt.size(); ++i) {
+      ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
+    }
+  }
+
+  void RunTest(Context* ctx, std::string const& updater_name, std::string const& strategy) {
+    {
+      ctx->InitAllowUnknown(Args{{"nthread", "8"}});
+
+      ObjInfo task{ObjInfo::kRegression};
+      std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, ctx, &task)};
+      RegTree tree;
+      std::vector<RegTree*> trees{&tree};
+      auto gpair = GenerateRandomGradients(ctx, n_samples_, 1);
+      tree::TrainParam param;
+      param.UpdateAllowUnknown(Args{{"max_bin", "64"}});
+
+      updater->Configure(Args{});
+      std::vector<HostDeviceVector<bst_node_t>> position(1);
+      updater->Update(&param, &gpair, Xy_.get(), position, trees);
+      HostDeviceVector<float> out_prediction_cached;
+      out_prediction_cached.SetDevice(ctx->Device());
+      out_prediction_cached.Resize(n_samples_);
+      auto cache =
+          linalg::MakeTensorView(ctx, &out_prediction_cached, out_prediction_cached.Size(), 1);
+      ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache));
+    }
+
+    for (auto policy : {"depthwise", "lossguide"}) {
+      for (auto subsample : {1.0f, 0.4f}) {
+        this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
+        this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
+      }
+    }
+  }
+};
+}  // namespace xgboost
diff --git a/tests/python-sycl/test_sycl_training_continuation.py b/tests/python-sycl/test_sycl_training_continuation.py
new file mode 100644
index 000000000..e2a11c987
--- /dev/null
+++ b/tests/python-sycl/test_sycl_training_continuation.py
@@ -0,0 +1,59 @@
+import numpy as np
+import xgboost as xgb
+import json
+
+rng = np.random.RandomState(1994)
+
+
+class TestSYCLTrainingContinuation:
+    def run_training_continuation(self, use_json):
+        kRows = 64
+        kCols = 32
+        X = rng.randn(kRows, kCols)
+        y = rng.randn(kRows)
+        dtrain = xgb.DMatrix(X, y)
+        params = {
+            "device": "sycl",
+            "max_depth": "2",
+            "gamma": "0.1",
+            "alpha": "0.01",
"enable_experimental_json_serialization": use_json, + } + bst_0 = xgb.train(params, dtrain, num_boost_round=64) + dump_0 = bst_0.get_dump(dump_format="json") + + bst_1 = xgb.train(params, dtrain, num_boost_round=32) + bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1) + dump_1 = bst_1.get_dump(dump_format="json") + + def recursive_compare(obj_0, obj_1): + if isinstance(obj_0, float): + assert np.isclose(obj_0, obj_1, atol=1e-6) + elif isinstance(obj_0, str): + assert obj_0 == obj_1 + elif isinstance(obj_0, int): + assert obj_0 == obj_1 + elif isinstance(obj_0, dict): + keys_0 = list(obj_0.keys()) + keys_1 = list(obj_1.keys()) + values_0 = list(obj_0.values()) + values_1 = list(obj_1.values()) + for i in range(len(obj_0.items())): + assert keys_0[i] == keys_1[i] + if list(obj_0.keys())[i] != "missing": + recursive_compare(values_0[i], values_1[i]) + else: + for i in range(len(obj_0)): + recursive_compare(obj_0[i], obj_1[i]) + + assert len(dump_0) == len(dump_1) + for i in range(len(dump_0)): + obj_0 = json.loads(dump_0[i]) + obj_1 = json.loads(dump_1[i]) + recursive_compare(obj_0, obj_1) + + def test_sycl_training_continuation_binary(self): + self.run_training_continuation(False) + + def test_sycl_training_continuation_json(self): + self.run_training_continuation(True) diff --git a/tests/python-sycl/test_sycl_updaters.py b/tests/python-sycl/test_sycl_updaters.py new file mode 100644 index 000000000..57ca8d783 --- /dev/null +++ b/tests/python-sycl/test_sycl_updaters.py @@ -0,0 +1,80 @@ +import numpy as np +import gc +import pytest +import xgboost as xgb +from hypothesis import given, strategies, assume, settings, note + +import sys +import os + +# sys.path.append("tests/python") +# import testing as tm +from xgboost import testing as tm + +parameter_strategy = strategies.fixed_dictionaries( + { + "max_depth": strategies.integers(0, 11), + "max_leaves": strategies.integers(0, 256), + "max_bin": strategies.integers(2, 1024), + "grow_policy": strategies.sampled_from(["lossguide", "depthwise"]), + "single_precision_histogram": strategies.booleans(), + "min_child_weight": strategies.floats(0.5, 2.0), + "seed": strategies.integers(0, 10), + # We cannot enable subsampling as the training loss can increase + # 'subsample': strategies.floats(0.5, 1.0), + "colsample_bytree": strategies.floats(0.5, 1.0), + "colsample_bylevel": strategies.floats(0.5, 1.0), + } +).filter( + lambda x: (x["max_depth"] > 0 or x["max_leaves"] > 0) + and (x["max_depth"] > 0 or x["grow_policy"] == "lossguide") +) + + +def train_result(param, dmat, num_rounds): + result = {} + xgb.train( + param, + dmat, + num_rounds, + [(dmat, "train")], + verbose_eval=False, + evals_result=result, + ) + return result + + +class TestSYCLUpdaters: + @given(parameter_strategy, strategies.integers(1, 5), tm.make_dataset_strategy()) + @settings(deadline=None) + def test_sycl_hist(self, param, num_rounds, dataset): + param["tree_method"] = "hist" + param["device"] = "sycl" + param["verbosity"] = 0 + param = dataset.set_params(param) + result = train_result(param, dataset.get_dmat(), num_rounds) + note(result) + assert tm.non_increasing(result["train"][dataset.metric]) + + @given(tm.make_dataset_strategy(), strategies.integers(0, 1)) + @settings(deadline=None) + def test_specified_device_id_sycl_update(self, dataset, device_id): + # Read the list of sycl-devicese + sycl_ls = os.popen("sycl-ls").read() + devices = sycl_ls.split("\n") + + # Test should launch only on gpu + # Find gpus in the list of devices + # and use the id in 
the list insteard of device_id + target_device_type = "opencl:gpu" + found_devices = 0 + for idx in range(len(devices)): + if len(devices[idx]) >= len(target_device_type): + if devices[idx][1 : 1 + len(target_device_type)] == target_device_type: + if found_devices == device_id: + param = {"device": f"sycl:gpu:{idx}"} + param = dataset.set_params(param) + result = train_result(param, dataset.get_dmat(), 10) + assert tm.non_increasing(result["train"][dataset.metric]) + else: + found_devices += 1
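
For reference, a minimal usage sketch (not part of the patch), assuming an XGBoost build with the SYCL plugin enabled and a visible SYCL device. With device="sycl", MapTreeMethodToUpdaters() in gbtree.cc now dispatches tree_method="hist" to grow_quantile_histmaker_sycl, and after each boosting iteration the learner refreshes predictions from the updater's cache via UpdatePredictionCache() instead of re-running full prediction; note the updater declines the cache when subsample < 1.0.

import numpy as np
import xgboost as xgb

# Synthetic regression data; any DMatrix works here.
rng = np.random.RandomState(0)
X = rng.randn(256, 16)
y = rng.randn(256)
dtrain = xgb.DMatrix(X, y)

# device="sycl" routes tree_method="hist" to the SYCL quantile-hist updater.
params = {"device": "sycl", "tree_method": "hist", "max_depth": 4}

# Each boosting round grows a tree and then reuses the updater's cached
# predictions, exercising the UpdatePredictionCache() path added above.
booster = xgb.train(params, dtrain, num_boost_round=16)
print(booster.eval(dtrain))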