diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc new file mode 100644 index 000000000..98a42c3c8 --- /dev/null +++ b/plugin/sycl/tree/updater_quantile_hist.cc @@ -0,0 +1,55 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file updater_quantile_hist.cc + */ +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include "xgboost/tree_updater.h" +#pragma GCC diagnostic pop + +#include "xgboost/logging.h" + +#include "updater_quantile_hist.h" +#include "../data.h" + +namespace xgboost { +namespace sycl { +namespace tree { + +DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl); + +DMLC_REGISTER_PARAMETER(HistMakerTrainParam); + +void QuantileHistMaker::Configure(const Args& args) { + const DeviceOrd device_spec = ctx_->Device(); + qu_ = device_manager.GetQueue(device_spec); + + param_.UpdateAllowUnknown(args); + hist_maker_param_.UpdateAllowUnknown(args); +} + +void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, + linalg::Matrix* gpair, + DMatrix *dmat, + xgboost::common::Span> out_position, + const std::vector &trees) { + LOG(FATAL) << "Not Implemented yet"; +} + +bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView out_preds) { + LOG(FATAL) << "Not Implemented yet"; +} + +XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl") +.describe("Grow tree using quantized histogram with SYCL.") +.set_body( + [](Context const* ctx, ObjInfo const * task) { + return new QuantileHistMaker(ctx, task); + }); +} // namespace tree +} // namespace sycl +} // namespace xgboost diff --git a/plugin/sycl/tree/updater_quantile_hist.h b/plugin/sycl/tree/updater_quantile_hist.h new file mode 100644 index 000000000..93a50de3e --- /dev/null +++ b/plugin/sycl/tree/updater_quantile_hist.h @@ -0,0 +1,91 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file updater_quantile_hist.h + */ +#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ +#define PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ + +#include +#include + +#include + +#include "../data/gradient_index.h" +#include "../common/hist_util.h" +#include "../common/row_set.h" +#include "../common/partition_builder.h" +#include "split_evaluator.h" +#include "../device_manager.h" + +#include "xgboost/data.h" +#include "xgboost/json.h" +#include "../../src/tree/constraints.h" +#include "../../src/common/random.h" + +namespace xgboost { +namespace sycl { +namespace tree { + +// training parameters specific to this algorithm +struct HistMakerTrainParam + : public XGBoostParameter { + bool single_precision_histogram = false; + // declare parameters + DMLC_DECLARE_PARAMETER(HistMakerTrainParam) { + DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( + "Use single precision to build histograms."); + } +}; + +/*! \brief construct a tree using quantized feature values with SYCL backend*/ +class QuantileHistMaker: public TreeUpdater { + public: + QuantileHistMaker(Context const* ctx, ObjInfo const * task) : + TreeUpdater(ctx), task_{task} { + updater_monitor_.Init("SYCLQuantileHistMaker"); + } + void Configure(const Args& args) override; + + void Update(xgboost::tree::TrainParam const *param, + linalg::Matrix* gpair, + DMatrix* dmat, + xgboost::common::Span> out_position, + const std::vector& trees) override; + + bool UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView out_preds) override; + + void LoadConfig(Json const& in) override { + auto const& config = get(in); + FromJson(config.at("train_param"), &this->param_); + FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_); + } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["train_param"] = ToJson(param_); + out["sycl_hist_train_param"] = ToJson(hist_maker_param_); + } + + char const* Name() const override { + return "grow_quantile_histmaker_sycl"; + } + + protected: + HistMakerTrainParam hist_maker_param_; + // training parameter + xgboost::tree::TrainParam param_; + + xgboost::common::Monitor updater_monitor_; + + ::sycl::queue qu_; + DeviceManager device_manager; + ObjInfo const *task_{nullptr}; +}; + + +} // namespace tree +} // namespace sycl +} // namespace xgboost + +#endif // PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ diff --git a/tests/cpp/plugin/test_sycl_quantile_hist_builder.cc b/tests/cpp/plugin/test_sycl_quantile_hist_builder.cc new file mode 100644 index 000000000..4bf7bd962 --- /dev/null +++ b/tests/cpp/plugin/test_sycl_quantile_hist_builder.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020-2024 by XGBoost contributors + */ +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include +#include +#include "../../../plugin/sycl/tree/updater_quantile_hist.h" // for QuantileHistMaker +#pragma GCC diagnostic pop + +namespace xgboost::sycl::tree { +TEST(SyclQuantileHistMaker, Basic) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + ObjInfo task{ObjInfo::kRegression}; + std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; + + ASSERT_EQ(updater->Name(), "grow_quantile_histmaker_sycl"); +} + +TEST(SyclQuantileHistMaker, JsonIO) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + ObjInfo task{ObjInfo::kRegression}; + Json config {Object()}; + { + std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; + updater->Configure({{"max_depth", std::to_string(42)}}); + updater->Configure({{"single_precision_histogram", std::to_string(true)}}); + updater->SaveConfig(&config); + } + + { + std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; + updater->LoadConfig(config); + + Json new_config {Object()}; + updater->SaveConfig(&new_config); + + ASSERT_EQ(config, new_config); + + auto max_depth = atoi(get(new_config["train_param"]["max_depth"]).c_str()); + ASSERT_EQ(max_depth, 42); + + auto single_precision_histogram = atoi(get(new_config["sycl_hist_train_param"]["single_precision_histogram"]).c_str()); + ASSERT_EQ(single_precision_histogram, 1); + } + +} +} // namespace xgboost::sycl::tree