[SYCL] Implement UpdatePredictionCache and connect updater with learner. (#10701)

---------

Co-authored-by: Dmitry Razdoburdin <>
Dmitry Razdoburdin 2024-08-21 20:07:44 +02:00 committed by GitHub
parent 9b88495840
commit 24d225c1ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 502 additions and 126 deletions

View File

@@ -307,6 +307,99 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
builder_monitor_.Stop("ExpandWithLossGuide");
}
template <typename GradientSumT>
void HistUpdater<GradientSumT>::Update(
xgboost::tree::TrainParam const *param,
const common::GHistIndexMatrix &gmat,
const USMVector<GradientPair, MemoryType::on_device>& gpair,
DMatrix *p_fmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
RegTree *p_tree) {
builder_monitor_.Start("Update");
tree_evaluator_.Reset(qu_, param_, p_fmat->Info().num_col_);
interaction_constraints_.Reset();
this->InitData(gmat, gpair, *p_fmat, *p_tree);
if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide) {
ExpandWithLossGuide(gmat, p_tree, gpair);
} else {
ExpandWithDepthWise(gmat, p_tree, gpair);
}
for (int nid = 0; nid < p_tree->NumNodes(); ++nid) {
p_tree->Stat(nid).loss_chg = snode_host_[nid].best.loss_chg;
p_tree->Stat(nid).base_weight = snode_host_[nid].weight;
p_tree->Stat(nid).sum_hess = static_cast<float>(snode_host_[nid].stats.GetHess());
}
builder_monitor_.Stop("Update");
}
template<typename GradientSumT>
bool HistUpdater<GradientSumT>::UpdatePredictionCache(
const DMatrix* data,
linalg::MatrixView<float> out_preds) {
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update().
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
return false;
}
builder_monitor_.Start("UpdatePredictionCache");
CHECK_GT(out_preds.Size(), 0U);
const size_t stride = out_preds.Stride(0);
const bool is_first_group = (out_pred_ptr == nullptr);
const size_t gid = out_pred_ptr == nullptr ? 0 : &out_preds(0) - out_pred_ptr;
const bool is_last_group = (gid + 1 == stride);
const int buffer_size = out_preds.Size() *stride;
if (buffer_size == 0) return true;
::sycl::event event;
if (is_first_group) {
out_preds_buf_.ResizeNoCopy(&qu_, buffer_size);
out_pred_ptr = &out_preds(0);
event = qu_.memcpy(out_preds_buf_.Data(), out_pred_ptr, buffer_size * sizeof(bst_float), event);
}
auto* out_preds_buf_ptr = out_preds_buf_.Data();
size_t n_nodes = row_set_collection_.Size();
std::vector<::sycl::event> events(n_nodes);
for (size_t node = 0; node < n_nodes; node++) {
const common::RowSetCollection::Elem& rowset = row_set_collection_[node];
if (rowset.begin != nullptr && rowset.end != nullptr && rowset.Size() != 0) {
int nid = rowset.node_id;
// if a node is marked as deleted by the pruner, traverse upward to locate
// a non-deleted leaf.
if ((*p_last_tree_)[nid].IsDeleted()) {
while ((*p_last_tree_)[nid].IsDeleted()) {
nid = (*p_last_tree_)[nid].Parent();
}
CHECK((*p_last_tree_)[nid].IsLeaf());
}
bst_float leaf_value = (*p_last_tree_)[nid].LeafValue();
const size_t* rid = rowset.begin;
const size_t num_rows = rowset.Size();
events[node] = qu_.submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) {
out_preds_buf_ptr[rid[pid.get_id(0)]*stride + gid] += leaf_value;
});
});
}
}
if (is_last_group) {
qu_.memcpy(out_pred_ptr, out_preds_buf_ptr, buffer_size * sizeof(bst_float), events);
out_pred_ptr = nullptr;
}
qu_.wait();
builder_monitor_.Stop("UpdatePredictionCache");
return true;
}
template<typename GradientSumT>
void HistUpdater<GradientSumT>::InitSampling(
const USMVector<GradientPair, MemoryType::on_device> &gpair,
@@ -479,6 +572,8 @@ void HistUpdater<GradientSumT>::InitData(
}
}
// store a pointer to the tree
p_last_tree_ = &tree;
column_sampler_->Init(ctx_, info.num_col_, info.feature_weights.ConstHostVector(),
param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree);

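The HistUpdater::UpdatePredictionCache added above reuses the row partitions left behind by the last Update() call to refresh training predictions directly on the device, so the learner does not need to rerun the predictor over the training matrix between boosting iterations; the cache is only consulted when it is queried with the same DMatrix that was used for the last Update(). A minimal Python sketch of the behaviour checked later in this commit by the prediction cache tests (assumes a SYCL-enabled build; file name and tolerance are illustrative, not part of this change):

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(2024)
X = rng.randn(512, 16)
y = rng.randn(512)
dtrain = xgb.DMatrix(X, y)

# Train with the SYCL hist updater; training predictions are kept in the cache.
bst = xgb.train({"device": "sycl", "tree_method": "hist", "max_depth": 3}, dtrain, num_boost_round=8)
cached = bst.predict(dtrain)

# A model restored from disk has no cache and must recompute identical values.
bst.save_model("model.json")
restored = xgb.Booster()
restored.load_model("model.json")
recomputed = restored.predict(xgb.DMatrix(X))
np.testing.assert_allclose(cached, recomputed, atol=1e-6)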
View File

@@ -11,10 +11,10 @@
#include <xgboost/tree_updater.h>
#pragma GCC diagnostic pop
#include <utility>
#include <vector>
#include <memory>
#include <queue>
#include <utility>
#include "../common/partition_builder.h"
#include "split_evaluator.h"
@@ -54,12 +54,10 @@ class HistUpdater {
explicit HistUpdater(const Context* ctx,
::sycl::queue qu,
const xgboost::tree::TrainParam& param,
std::unique_ptr<TreeUpdater> pruner,
FeatureInteractionConstraintHost int_constraints_,
DMatrix const* fmat)
: ctx_(ctx), qu_(qu), param_(param),
tree_evaluator_(qu, param, fmat->Info().num_col_),
pruner_(std::move(pruner)),
interaction_constraints_{std::move(int_constraints_)},
p_last_tree_(nullptr), p_last_fmat_(fmat) {
builder_monitor_.Init("SYCL::Quantile::HistUpdater");
@@ -73,6 +71,17 @@ class HistUpdater {
sub_group_size_ = sub_group_sizes.back();
}
// update one tree, growing
void Update(xgboost::tree::TrainParam const *param,
const common::GHistIndexMatrix &gmat,
const USMVector<GradientPair, MemoryType::on_device>& gpair,
DMatrix *p_fmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
RegTree *p_tree);
bool UpdatePredictionCache(const DMatrix* data,
linalg::MatrixView<float> p_out_preds);
void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);
@@ -200,7 +209,6 @@ class HistUpdater {
std::vector<SplitEntry<GradientSumT>> best_splits_host_;
TreeEvaluator<GradientSumT> tree_evaluator_;
std::unique_ptr<TreeUpdater> pruner_;
FeatureInteractionConstraintHost interaction_constraints_;
// back pointers to tree and data matrix
@@ -247,6 +255,9 @@ class HistUpdater {
std::unique_ptr<HistSynchronizer<GradientSumT>> hist_synchronizer_;
std::unique_ptr<HistRowsAdder<GradientSumT>> hist_rows_adder_;
USMVector<bst_float, MemoryType::on_device> out_preds_buf_;
bst_float* out_pred_ptr = nullptr;
::sycl::queue qu_;
};

View File

@@ -3,6 +3,7 @@
* \file updater_quantile_hist.cc
*/
#include <vector>
#include <memory>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
@@ -29,6 +30,50 @@ void QuantileHistMaker::Configure(const Args& args) {
param_.UpdateAllowUnknown(args);
hist_maker_param_.UpdateAllowUnknown(args);
bool has_fp64_support = qu_.get_device().has(::sycl::aspect::fp64);
if (hist_maker_param_.single_precision_histogram || !has_fp64_support) {
if (!hist_maker_param_.single_precision_histogram) {
LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True";
}
hist_precision_ = HistPrecision::fp32;
} else {
hist_precision_ = HistPrecision::fp64;
}
}
template<typename GradientSumT>
void QuantileHistMaker::SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>* pimpl,
DMatrix *dmat) {
pimpl->reset(new HistUpdater<GradientSumT>(
ctx_,
qu_,
param_,
int_constraint_, dmat));
if (collective::IsDistributed()) {
LOG(FATAL) << "Distributed mode is not yet upstreamed for sycl";
} else {
(*pimpl)->SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
(*pimpl)->SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
}
}
template<typename GradientSumT>
void QuantileHistMaker::CallUpdate(
const std::unique_ptr<HistUpdater<GradientSumT>>& pimpl,
xgboost::tree::TrainParam const *param,
linalg::Matrix<GradientPair> *gpair,
DMatrix *dmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) {
const auto* gpair_h = gpair->Data();
gpair_device_.Resize(&qu_, gpair_h->Size());
qu_.memcpy(gpair_device_.Data(), gpair_h->HostPointer(), gpair_h->Size() * sizeof(GradientPair));
qu_.wait();
for (auto tree : trees) {
pimpl->Update(param, gmat_, gpair_device_, dmat, out_position, tree);
}
}
void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
@@ -36,12 +81,55 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
DMatrix *dmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) {
LOG(FATAL) << "Not Implemented yet";
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
updater_monitor_.Start("DeviceMatrixInitialization");
sycl::DeviceMatrix dmat_device;
dmat_device.Init(qu_, dmat);
updater_monitor_.Stop("DeviceMatrixInitialization");
updater_monitor_.Start("GmatInitialization");
gmat_.Init(qu_, ctx_, dmat_device, static_cast<uint32_t>(param_.max_bin));
updater_monitor_.Stop("GmatInitialization");
is_gmat_initialized_ = true;
}
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
int_constraint_.Configure(param_, dmat->Info().num_col_);
// build tree
if (hist_precision_ == HistPrecision::fp32) {
if (!pimpl_fp32) {
SetPimpl(&pimpl_fp32, dmat);
}
CallUpdate(pimpl_fp32, param, gpair, dmat, out_position, trees);
} else {
if (!pimpl_fp64) {
SetPimpl(&pimpl_fp64, dmat);
}
CallUpdate(pimpl_fp64, param, gpair, dmat, out_position, trees);
}
param_.learning_rate = lr;
p_last_dmat_ = dmat;
}
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
linalg::MatrixView<float> out_preds) {
LOG(FATAL) << "Not Implemented yet";
if (param_.subsample < 1.0f) return false;
if (hist_precision_ == HistPrecision::fp32) {
if (pimpl_fp32) {
return pimpl_fp32->UpdatePredictionCache(data, out_preds);
} else {
return false;
}
} else {
if (pimpl_fp64) {
return pimpl_fp64->UpdatePredictionCache(data, out_preds);
} else {
return false;
}
}
}
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")

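QuantileHistMaker::Configure above selects fp64 histograms when the device reports fp64 support and falls back to fp32 otherwise (or when single_precision_histogram is set), while UpdatePredictionCache declines to fill the cache whenever subsample < 1. A hedged Python sketch of the corresponding user-facing parameters (illustrative values only, assuming a SYCL-enabled build):

import numpy as np
import xgboost as xgb

X = np.random.randn(256, 8)
y = np.random.randn(256)
dtrain = xgb.DMatrix(X, y)

params = {
    "device": "sycl",
    "tree_method": "hist",
    # Force fp32 histograms; on devices without fp64 support the updater
    # falls back to fp32 on its own and logs a warning.
    "single_precision_histogram": True,
    # With subsample < 1.0, UpdatePredictionCache() returns false and the
    # learner recomputes training predictions instead of using the cache.
    "subsample": 1.0,
}
bst = xgb.train(params, dtrain, num_boost_round=4)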
View File

@@ -9,6 +9,7 @@
#include <xgboost/tree_updater.h>
#include <vector>
#include <memory>
#include "../data/gradient_index.h"
#include "../common/hist_util.h"
@@ -16,8 +17,9 @@
#include "../common/partition_builder.h"
#include "split_evaluator.h"
#include "../device_manager.h"
#include "hist_updater.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "../../src/tree/constraints.h"
#include "../../src/common/random.h"
@@ -75,12 +77,39 @@ class QuantileHistMaker: public TreeUpdater {
HistMakerTrainParam hist_maker_param_;
// training parameter
xgboost::tree::TrainParam param_;
// quantized data matrix
common::GHistIndexMatrix gmat_;
// (optional) data matrix with feature grouping
// column accessor
DMatrix const* p_last_dmat_ {nullptr};
bool is_gmat_initialized_ {false};
xgboost::common::Monitor updater_monitor_;
template<typename GradientSumT>
void SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>*, DMatrix *dmat);
template<typename GradientSumT>
void CallUpdate(const std::unique_ptr<HistUpdater<GradientSumT>>& builder,
xgboost::tree::TrainParam const *param,
linalg::Matrix<GradientPair> *gpair,
DMatrix *dmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees);
enum class HistPrecision {fp32, fp64};
HistPrecision hist_precision_;
std::unique_ptr<HistUpdater<float>> pimpl_fp32;
std::unique_ptr<HistUpdater<double>> pimpl_fp64;
FeatureInteractionConstraintHost int_constraint_;
::sycl::queue qu_;
DeviceManager device_manager;
ObjInfo const *task_{nullptr};
USMVector<GradientPair, MemoryType::on_device> gpair_device_;
};

View File

@@ -52,7 +52,8 @@ std::string MapTreeMethodToUpdaters(Context const* ctx, TreeMethod tree_method)
case TreeMethod::kAuto: // Use hist as default in 2.0
case TreeMethod::kHist: {
return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; },
[] { return "grow_gpu_hist"; });
[] { return "grow_gpu_hist"; },
[] { return "grow_quantile_histmaker_sycl"; });
}
case TreeMethod::kApprox: {
return ctx->DispatchDevice([] { return "grow_histmaker"; }, [] { return "grow_gpu_approx"; });

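With the extra SYCL branch in MapTreeMethodToUpdaters, tree_method=hist (and auto) on a SYCL device now resolves to the registered grow_quantile_histmaker_sycl updater, so users do not have to name the updater explicitly. A small hedged example (assumes a SYCL device is available; data shapes are arbitrary):

import numpy as np
import xgboost as xgb

X = np.random.randn(128, 8)
y = np.random.randn(128)
dtrain = xgb.DMatrix(X, y)
# "grow_quantile_histmaker_sycl" is selected implicitly by the dispatch above.
bst = xgb.train({"device": "sycl", "tree_method": "hist"}, dtrain, num_boost_round=4)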
View File

@@ -21,10 +21,8 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
TestHistUpdater(const Context* ctx,
::sycl::queue qu,
const xgboost::tree::TrainParam& param,
std::unique_ptr<TreeUpdater> pruner,
FeatureInteractionConstraintHost int_constraints_,
DMatrix const* fmat) : HistUpdater<GradientSumT>(ctx, qu, param,
std::move(pruner),
int_constraints_, fmat) {}
void TestInitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
@@ -110,14 +108,12 @@ void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) {
DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());
ObjInfo task{ObjInfo::kRegression};
auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix();
FeatureInteractionConstraintHost int_constraints;
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
USMVector<size_t, MemoryType::on_device> row_indices_0(&qu, num_rows);
USMVector<size_t, MemoryType::on_device> row_indices_1(&qu, num_rows);
@@ -165,14 +161,12 @@ void TestHistUpdaterInitData(const xgboost::tree::TrainParam& param, bool has_ne
DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());
ObjInfo task{ObjInfo::kRegression};
auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix();
FeatureInteractionConstraintHost int_constraints;
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
GenerateRandomGPairs(&qu, gpair.Data(), num_rows, has_neg_hess);
@@ -221,14 +215,12 @@ void TestHistUpdaterBuildHistogramsLossGuide(const xgboost::tree::TrainParam& pa
DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());
ObjInfo task{ObjInfo::kRegression};
auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
FeatureInteractionConstraintHost int_constraints;
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
@@ -285,14 +277,12 @@ void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sp
DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());
ObjInfo task{ObjInfo::kRegression};
auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
FeatureInteractionConstraintHost int_constraints;
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
@@ -345,14 +335,12 @@ void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());
ObjInfo task{ObjInfo::kRegression};
auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0f}.GenerateDMatrix();
FeatureInteractionConstraintHost int_constraints;
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
@@ -423,8 +411,6 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());
ObjInfo task{ObjInfo::kRegression};
auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
sycl::DeviceMatrix dmat;
dmat.Init(qu, p_fmat.get());
@@ -439,8 +425,7 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
nodes.emplace_back(tree::ExpandEntry(0, tree.GetDepth(0)));
FeatureInteractionConstraintHost int_constraints;
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
GenerateRandomGPairs(&qu, gpair.Data(), num_rows, false);
@@ -455,8 +440,7 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
std::vector<size_t> row_indices_desired_host(num_rows);
size_t n_left, n_right;
{
std::unique_ptr<TreeUpdater> pruner4verification{TreeUpdater::Create("prune", &ctx, &task)};
TestHistUpdater<GradientSumT> updater4verification(&ctx, qu, param, std::move(pruner4verification), int_constraints, p_fmat.get());
TestHistUpdater<GradientSumT> updater4verification(&ctx, qu, param, int_constraints, p_fmat.get());
auto* row_set_collection4verification = updater4verification.TestInitData(gmat, gpair, *p_fmat, tree);
size_t n_nodes = nodes.size();
@@ -526,9 +510,7 @@ void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param)
RegTree tree;
FeatureInteractionConstraintHost int_constraints;
ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
@@ -576,9 +558,7 @@ void TestHistUpdaterExpandWithDepthWise(const xgboost::tree::TrainParam& param)
RegTree tree;
FeatureInteractionConstraintHost int_constraints;
ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);

View File

@@ -0,0 +1,23 @@
/**
* Copyright 2020-2024 by XGBoost contributors
*/
#include <gtest/gtest.h>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include "../tree/test_prediction_cache.h"
#pragma GCC diagnostic pop
namespace xgboost::sycl::tree {
class SyclPredictionCache : public xgboost::TestPredictionCache {};
TEST_F(SyclPredictionCache, Hist) {
Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
this->RunTest(&ctx, "grow_quantile_histmaker_sycl", "one_output_per_tree");
}
} // namespace xgboost::sycl::tree

View File

@@ -2,97 +2,10 @@
* Copyright 2021-2023 by XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/tree_updater.h>
#include <memory>
#include "test_prediction_cache.h"
#include "../../../src/tree/param.h" // for TrainParam
#include "../helpers.h"
#include "xgboost/task.h" // for ObjInfo
namespace xgboost {
class TestPredictionCache : public ::testing::Test {
std::shared_ptr<DMatrix> Xy_;
std::size_t n_samples_{2048};
protected:
void SetUp() override {
std::size_t n_features = 13;
bst_target_t n_targets = 3;
Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true);
}
void RunLearnerTest(Context const* ctx, std::string updater_name, float subsample,
std::string const& grow_policy, std::string const& strategy) {
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
learner->SetParam("device", ctx->DeviceName());
learner->SetParam("updater", updater_name);
learner->SetParam("multi_strategy", strategy);
learner->SetParam("grow_policy", grow_policy);
learner->SetParam("subsample", std::to_string(subsample));
learner->SetParam("nthread", "0");
learner->Configure();
for (size_t i = 0; i < 8; ++i) {
learner->UpdateOneIter(i, Xy_);
}
HostDeviceVector<float> out_prediction_cached;
learner->Predict(Xy_, false, &out_prediction_cached, 0, 0);
Json model{Object()};
learner->SaveModel(&model);
HostDeviceVector<float> out_prediction;
{
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
learner->LoadModel(model);
learner->Predict(Xy_, false, &out_prediction, 0, 0);
}
auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
auto const h_predt = out_prediction.ConstHostSpan();
ASSERT_EQ(h_predt.size(), h_predt_cached.size());
for (size_t i = 0; i < h_predt.size(); ++i) {
ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
}
}
void RunTest(Context* ctx, std::string const& updater_name, std::string const& strategy) {
{
ctx->InitAllowUnknown(Args{{"nthread", "8"}});
ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, ctx, &task)};
RegTree tree;
std::vector<RegTree*> trees{&tree};
auto gpair = GenerateRandomGradients(ctx, n_samples_, 1);
tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"max_bin", "64"}});
updater->Configure(Args{});
std::vector<HostDeviceVector<bst_node_t>> position(1);
updater->Update(&param, &gpair, Xy_.get(), position, trees);
HostDeviceVector<float> out_prediction_cached;
out_prediction_cached.SetDevice(ctx->Device());
out_prediction_cached.Resize(n_samples_);
auto cache =
linalg::MakeTensorView(ctx, &out_prediction_cached, out_prediction_cached.Size(), 1);
ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache));
}
for (auto policy : {"depthwise", "lossguide"}) {
for (auto subsample : {1.0f, 0.4f}) {
this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
}
}
}
};
TEST_F(TestPredictionCache, Approx) {
Context ctx;
this->RunTest(&ctx, "grow_histmaker", "one_output_per_tree");

View File

@@ -0,0 +1,97 @@
/**
* Copyright 2021-2024 by XGBoost contributors.
*/
#pragma once
#include <gtest/gtest.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/tree_updater.h>
#include <memory>
#include "../../../src/tree/param.h" // for TrainParam
#include "../helpers.h"
#include "xgboost/task.h" // for ObjInfo
namespace xgboost {
class TestPredictionCache : public ::testing::Test {
std::shared_ptr<DMatrix> Xy_;
std::size_t n_samples_{2048};
protected:
void SetUp() override {
std::size_t n_features = 13;
bst_target_t n_targets = 3;
Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true);
}
void RunLearnerTest(Context const* ctx, std::string updater_name, float subsample,
std::string const& grow_policy, std::string const& strategy) {
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
learner->SetParam("device", ctx->DeviceName());
learner->SetParam("updater", updater_name);
learner->SetParam("multi_strategy", strategy);
learner->SetParam("grow_policy", grow_policy);
learner->SetParam("subsample", std::to_string(subsample));
learner->SetParam("nthread", "0");
learner->Configure();
for (size_t i = 0; i < 8; ++i) {
learner->UpdateOneIter(i, Xy_);
}
HostDeviceVector<float> out_prediction_cached;
learner->Predict(Xy_, false, &out_prediction_cached, 0, 0);
Json model{Object()};
learner->SaveModel(&model);
HostDeviceVector<float> out_prediction;
{
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
learner->LoadModel(model);
learner->Predict(Xy_, false, &out_prediction, 0, 0);
}
auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
auto const h_predt = out_prediction.ConstHostSpan();
ASSERT_EQ(h_predt.size(), h_predt_cached.size());
for (size_t i = 0; i < h_predt.size(); ++i) {
ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
}
}
void RunTest(Context* ctx, std::string const& updater_name, std::string const& strategy) {
{
ctx->InitAllowUnknown(Args{{"nthread", "8"}});
ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, ctx, &task)};
RegTree tree;
std::vector<RegTree*> trees{&tree};
auto gpair = GenerateRandomGradients(ctx, n_samples_, 1);
tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"max_bin", "64"}});
updater->Configure(Args{});
std::vector<HostDeviceVector<bst_node_t>> position(1);
updater->Update(&param, &gpair, Xy_.get(), position, trees);
HostDeviceVector<float> out_prediction_cached;
out_prediction_cached.SetDevice(ctx->Device());
out_prediction_cached.Resize(n_samples_);
auto cache =
linalg::MakeTensorView(ctx, &out_prediction_cached, out_prediction_cached.Size(), 1);
ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache));
}
for (auto policy : {"depthwise", "lossguide"}) {
for (auto subsample : {1.0f, 0.4f}) {
this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
}
}
}
};
} // namespace xgboost

View File

@@ -0,0 +1,59 @@
import numpy as np
import xgboost as xgb
import json
rng = np.random.RandomState(1994)
class TestSYCLTrainingContinuation:
def run_training_continuation(self, use_json):
kRows = 64
kCols = 32
X = np.random.randn(kRows, kCols)
y = np.random.randn(kRows)
dtrain = xgb.DMatrix(X, y)
params = {
"device": "sycl",
"max_depth": "2",
"gamma": "0.1",
"alpha": "0.01",
"enable_experimental_json_serialization": use_json,
}
bst_0 = xgb.train(params, dtrain, num_boost_round=64)
dump_0 = bst_0.get_dump(dump_format="json")
bst_1 = xgb.train(params, dtrain, num_boost_round=32)
bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
dump_1 = bst_1.get_dump(dump_format="json")
def recursive_compare(obj_0, obj_1):
if isinstance(obj_0, float):
assert np.isclose(obj_0, obj_1, atol=1e-6)
elif isinstance(obj_0, str):
assert obj_0 == obj_1
elif isinstance(obj_0, int):
assert obj_0 == obj_1
elif isinstance(obj_0, dict):
keys_0 = list(obj_0.keys())
keys_1 = list(obj_1.keys())
values_0 = list(obj_0.values())
values_1 = list(obj_1.values())
for i in range(len(obj_0.items())):
assert keys_0[i] == keys_1[i]
if list(obj_0.keys())[i] != "missing":
recursive_compare(values_0[i], values_1[i])
else:
for i in range(len(obj_0)):
recursive_compare(obj_0[i], obj_1[i])
assert len(dump_0) == len(dump_1)
for i in range(len(dump_0)):
obj_0 = json.loads(dump_0[i])
obj_1 = json.loads(dump_1[i])
recursive_compare(obj_0, obj_1)
def test_sycl_training_continuation_binary(self):
self.run_training_continuation(False)
def test_sycl_training_continuation_json(self):
self.run_training_continuation(True)

View File

@@ -0,0 +1,80 @@
import numpy as np
import gc
import pytest
import xgboost as xgb
from hypothesis import given, strategies, assume, settings, note
import sys
import os
# sys.path.append("tests/python")
# import testing as tm
from xgboost import testing as tm
parameter_strategy = strategies.fixed_dictionaries(
{
"max_depth": strategies.integers(0, 11),
"max_leaves": strategies.integers(0, 256),
"max_bin": strategies.integers(2, 1024),
"grow_policy": strategies.sampled_from(["lossguide", "depthwise"]),
"single_precision_histogram": strategies.booleans(),
"min_child_weight": strategies.floats(0.5, 2.0),
"seed": strategies.integers(0, 10),
# We cannot enable subsampling as the training loss can increase
# 'subsample': strategies.floats(0.5, 1.0),
"colsample_bytree": strategies.floats(0.5, 1.0),
"colsample_bylevel": strategies.floats(0.5, 1.0),
}
).filter(
lambda x: (x["max_depth"] > 0 or x["max_leaves"] > 0)
and (x["max_depth"] > 0 or x["grow_policy"] == "lossguide")
)
def train_result(param, dmat, num_rounds):
result = {}
xgb.train(
param,
dmat,
num_rounds,
[(dmat, "train")],
verbose_eval=False,
evals_result=result,
)
return result
class TestSYCLUpdaters:
@given(parameter_strategy, strategies.integers(1, 5), tm.make_dataset_strategy())
@settings(deadline=None)
def test_sycl_hist(self, param, num_rounds, dataset):
param["tree_method"] = "hist"
param["device"] = "sycl"
param["verbosity"] = 0
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(result)
assert tm.non_increasing(result["train"][dataset.metric])
@given(tm.make_dataset_strategy(), strategies.integers(0, 1))
@settings(deadline=None)
def test_specified_device_id_sycl_update(self, dataset, device_id):
# Read the list of sycl devices
sycl_ls = os.popen("sycl-ls").read()
devices = sycl_ls.split("\n")
# Test should launch only on gpu
# Find gpus in the list of devices
# and use the id in the list instead of device_id
target_device_type = "opencl:gpu"
found_devices = 0
for idx in range(len(devices)):
if len(devices[idx]) >= len(target_device_type):
if devices[idx][1 : 1 + len(target_device_type)] == target_device_type:
if found_devices == device_id:
param = {"device": f"sycl:gpu:{idx}"}
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), 10)
assert tm.non_increasing(result["train"][dataset.metric])
else:
found_devices += 1