[SYCL] Add basic features for QuantileHistMaker (#10174)
--------- Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
parent
882f4136e0
commit
6e5c335cea
55
plugin/sycl/tree/updater_quantile_hist.cc
Normal file
55
plugin/sycl/tree/updater_quantile_hist.cc
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2017-2024 by Contributors
|
||||||
|
* \file updater_quantile_hist.cc
|
||||||
|
*/
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#pragma GCC diagnostic push
|
||||||
|
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||||
|
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||||
|
#include "xgboost/tree_updater.h"
|
||||||
|
#pragma GCC diagnostic pop
|
||||||
|
|
||||||
|
#include "xgboost/logging.h"
|
||||||
|
|
||||||
|
#include "updater_quantile_hist.h"
|
||||||
|
#include "../data.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace sycl {
|
||||||
|
namespace tree {
|
||||||
|
|
||||||
|
DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl);
|
||||||
|
|
||||||
|
DMLC_REGISTER_PARAMETER(HistMakerTrainParam);
|
||||||
|
|
||||||
|
void QuantileHistMaker::Configure(const Args& args) {
|
||||||
|
const DeviceOrd device_spec = ctx_->Device();
|
||||||
|
qu_ = device_manager.GetQueue(device_spec);
|
||||||
|
|
||||||
|
param_.UpdateAllowUnknown(args);
|
||||||
|
hist_maker_param_.UpdateAllowUnknown(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
|
||||||
|
linalg::Matrix<GradientPair>* gpair,
|
||||||
|
DMatrix *dmat,
|
||||||
|
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||||
|
const std::vector<RegTree *> &trees) {
|
||||||
|
LOG(FATAL) << "Not Implemented yet";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
|
||||||
|
linalg::MatrixView<float> out_preds) {
|
||||||
|
LOG(FATAL) << "Not Implemented yet";
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")
|
||||||
|
.describe("Grow tree using quantized histogram with SYCL.")
|
||||||
|
.set_body(
|
||||||
|
[](Context const* ctx, ObjInfo const * task) {
|
||||||
|
return new QuantileHistMaker(ctx, task);
|
||||||
|
});
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace sycl
|
||||||
|
} // namespace xgboost
|
||||||
91
plugin/sycl/tree/updater_quantile_hist.h
Normal file
91
plugin/sycl/tree/updater_quantile_hist.h
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2017-2024 by Contributors
|
||||||
|
* \file updater_quantile_hist.h
|
||||||
|
*/
|
||||||
|
#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
|
||||||
|
#define PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
|
||||||
|
|
||||||
|
#include <dmlc/timer.h>
|
||||||
|
#include <xgboost/tree_updater.h>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../data/gradient_index.h"
|
||||||
|
#include "../common/hist_util.h"
|
||||||
|
#include "../common/row_set.h"
|
||||||
|
#include "../common/partition_builder.h"
|
||||||
|
#include "split_evaluator.h"
|
||||||
|
#include "../device_manager.h"
|
||||||
|
|
||||||
|
#include "xgboost/data.h"
|
||||||
|
#include "xgboost/json.h"
|
||||||
|
#include "../../src/tree/constraints.h"
|
||||||
|
#include "../../src/common/random.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace sycl {
|
||||||
|
namespace tree {
|
||||||
|
|
||||||
|
// training parameters specific to this algorithm
|
||||||
|
struct HistMakerTrainParam
|
||||||
|
: public XGBoostParameter<HistMakerTrainParam> {
|
||||||
|
bool single_precision_histogram = false;
|
||||||
|
// declare parameters
|
||||||
|
DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
|
||||||
|
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
|
||||||
|
"Use single precision to build histograms.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief construct a tree using quantized feature values with SYCL backend*/
|
||||||
|
class QuantileHistMaker: public TreeUpdater {
|
||||||
|
public:
|
||||||
|
QuantileHistMaker(Context const* ctx, ObjInfo const * task) :
|
||||||
|
TreeUpdater(ctx), task_{task} {
|
||||||
|
updater_monitor_.Init("SYCLQuantileHistMaker");
|
||||||
|
}
|
||||||
|
void Configure(const Args& args) override;
|
||||||
|
|
||||||
|
void Update(xgboost::tree::TrainParam const *param,
|
||||||
|
linalg::Matrix<GradientPair>* gpair,
|
||||||
|
DMatrix* dmat,
|
||||||
|
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||||
|
const std::vector<RegTree*>& trees) override;
|
||||||
|
|
||||||
|
bool UpdatePredictionCache(const DMatrix* data,
|
||||||
|
linalg::MatrixView<float> out_preds) override;
|
||||||
|
|
||||||
|
void LoadConfig(Json const& in) override {
|
||||||
|
auto const& config = get<Object const>(in);
|
||||||
|
FromJson(config.at("train_param"), &this->param_);
|
||||||
|
FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SaveConfig(Json* p_out) const override {
|
||||||
|
auto& out = *p_out;
|
||||||
|
out["train_param"] = ToJson(param_);
|
||||||
|
out["sycl_hist_train_param"] = ToJson(hist_maker_param_);
|
||||||
|
}
|
||||||
|
|
||||||
|
char const* Name() const override {
|
||||||
|
return "grow_quantile_histmaker_sycl";
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
HistMakerTrainParam hist_maker_param_;
|
||||||
|
// training parameter
|
||||||
|
xgboost::tree::TrainParam param_;
|
||||||
|
|
||||||
|
xgboost::common::Monitor updater_monitor_;
|
||||||
|
|
||||||
|
::sycl::queue qu_;
|
||||||
|
DeviceManager device_manager;
|
||||||
|
ObjInfo const *task_{nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace sycl
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
#endif // PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
|
||||||
55
tests/cpp/plugin/test_sycl_quantile_hist_builder.cc
Normal file
55
tests/cpp/plugin/test_sycl_quantile_hist_builder.cc
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020-2024 by XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#pragma GCC diagnostic push
|
||||||
|
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||||
|
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||||
|
#include <xgboost/json.h>
|
||||||
|
#include <xgboost/task.h>
|
||||||
|
#include "../../../plugin/sycl/tree/updater_quantile_hist.h" // for QuantileHistMaker
|
||||||
|
#pragma GCC diagnostic pop
|
||||||
|
|
||||||
|
namespace xgboost::sycl::tree {
|
||||||
|
TEST(SyclQuantileHistMaker, Basic) {
|
||||||
|
Context ctx;
|
||||||
|
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
|
||||||
|
|
||||||
|
ObjInfo task{ObjInfo::kRegression};
|
||||||
|
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)};
|
||||||
|
|
||||||
|
ASSERT_EQ(updater->Name(), "grow_quantile_histmaker_sycl");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SyclQuantileHistMaker, JsonIO) {
|
||||||
|
Context ctx;
|
||||||
|
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
|
||||||
|
|
||||||
|
ObjInfo task{ObjInfo::kRegression};
|
||||||
|
Json config {Object()};
|
||||||
|
{
|
||||||
|
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)};
|
||||||
|
updater->Configure({{"max_depth", std::to_string(42)}});
|
||||||
|
updater->Configure({{"single_precision_histogram", std::to_string(true)}});
|
||||||
|
updater->SaveConfig(&config);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)};
|
||||||
|
updater->LoadConfig(config);
|
||||||
|
|
||||||
|
Json new_config {Object()};
|
||||||
|
updater->SaveConfig(&new_config);
|
||||||
|
|
||||||
|
ASSERT_EQ(config, new_config);
|
||||||
|
|
||||||
|
auto max_depth = atoi(get<String const>(new_config["train_param"]["max_depth"]).c_str());
|
||||||
|
ASSERT_EQ(max_depth, 42);
|
||||||
|
|
||||||
|
auto single_precision_histogram = atoi(get<String const>(new_config["sycl_hist_train_param"]["single_precision_histogram"]).c_str());
|
||||||
|
ASSERT_EQ(single_precision_histogram, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
} // namespace xgboost::sycl::tree
|
||||||
Loading…
x
Reference in New Issue
Block a user