[SYCL] Implement UpdatePredictionCache and connect updater with learner. (#10701)
---------
Co-authored-by: Dmitry Razdoburdin <>

This commit is contained in:
parent 9b88495840
commit 24d225c1ab
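For context, a minimal sketch of how the new SYCL updater is driven from the Python package, assuming an XGBoost build with the SYCL plugin enabled; parameter names follow the tests added in this commit, and the data is synthetic, purely for illustration:

import numpy as np
import xgboost as xgb

# Synthetic regression data; any dataset works for this sketch.
X = np.random.randn(128, 16)
y = np.random.randn(128)
dtrain = xgb.DMatrix(X, y)

# "device": "sycl" dispatches tree construction to grow_quantile_histmaker_sycl,
# which this commit extends with UpdatePredictionCache so cached predictions
# can be reused between boosting rounds.
params = {"device": "sycl", "tree_method": "hist", "max_depth": 6}
booster = xgb.train(params, dtrain, num_boost_round=10)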
@@ -307,6 +307,99 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
  builder_monitor_.Stop("ExpandWithLossGuide");
}

template <typename GradientSumT>
void HistUpdater<GradientSumT>::Update(
    xgboost::tree::TrainParam const *param,
    const common::GHistIndexMatrix &gmat,
    const USMVector<GradientPair, MemoryType::on_device>& gpair,
    DMatrix *p_fmat,
    xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
    RegTree *p_tree) {
  builder_monitor_.Start("Update");

  tree_evaluator_.Reset(qu_, param_, p_fmat->Info().num_col_);
  interaction_constraints_.Reset();

  this->InitData(gmat, gpair, *p_fmat, *p_tree);
  if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide) {
    ExpandWithLossGuide(gmat, p_tree, gpair);
  } else {
    ExpandWithDepthWise(gmat, p_tree, gpair);
  }

  for (int nid = 0; nid < p_tree->NumNodes(); ++nid) {
    p_tree->Stat(nid).loss_chg = snode_host_[nid].best.loss_chg;
    p_tree->Stat(nid).base_weight = snode_host_[nid].weight;
    p_tree->Stat(nid).sum_hess = static_cast<float>(snode_host_[nid].stats.GetHess());
  }

  builder_monitor_.Stop("Update");
}

template<typename GradientSumT>
bool HistUpdater<GradientSumT>::UpdatePredictionCache(
    const DMatrix* data,
    linalg::MatrixView<float> out_preds) {
  // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
  // conjunction with Update().
  if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
    return false;
  }
  builder_monitor_.Start("UpdatePredictionCache");
  CHECK_GT(out_preds.Size(), 0U);

  const size_t stride = out_preds.Stride(0);
  const bool is_first_group = (out_pred_ptr == nullptr);
  const size_t gid = out_pred_ptr == nullptr ? 0 : &out_preds(0) - out_pred_ptr;
  const bool is_last_group = (gid + 1 == stride);

  const int buffer_size = out_preds.Size() * stride;
  if (buffer_size == 0) return true;

  ::sycl::event event;
  if (is_first_group) {
    out_preds_buf_.ResizeNoCopy(&qu_, buffer_size);
    out_pred_ptr = &out_preds(0);
    event = qu_.memcpy(out_preds_buf_.Data(), out_pred_ptr, buffer_size * sizeof(bst_float), event);
  }
  auto* out_preds_buf_ptr = out_preds_buf_.Data();

  size_t n_nodes = row_set_collection_.Size();
  std::vector<::sycl::event> events(n_nodes);
  for (size_t node = 0; node < n_nodes; node++) {
    const common::RowSetCollection::Elem& rowset = row_set_collection_[node];
    if (rowset.begin != nullptr && rowset.end != nullptr && rowset.Size() != 0) {
      int nid = rowset.node_id;
      // if a node is marked as deleted by the pruner, traverse upward to locate
      // a non-deleted leaf.
      if ((*p_last_tree_)[nid].IsDeleted()) {
        while ((*p_last_tree_)[nid].IsDeleted()) {
          nid = (*p_last_tree_)[nid].Parent();
        }
        CHECK((*p_last_tree_)[nid].IsLeaf());
      }
      bst_float leaf_value = (*p_last_tree_)[nid].LeafValue();
      const size_t* rid = rowset.begin;
      const size_t num_rows = rowset.Size();

      events[node] = qu_.submit([&](::sycl::handler& cgh) {
        cgh.depends_on(event);
        cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) {
          out_preds_buf_ptr[rid[pid.get_id(0)] * stride + gid] += leaf_value;
        });
      });
    }
  }
  if (is_last_group) {
    qu_.memcpy(out_pred_ptr, out_preds_buf_ptr, buffer_size * sizeof(bst_float), events);
    out_pred_ptr = nullptr;
  }
  qu_.wait();

  builder_monitor_.Stop("UpdatePredictionCache");
  return true;
}

template<typename GradientSumT>
void HistUpdater<GradientSumT>::InitSampling(
      const USMVector<GradientPair, MemoryType::on_device> &gpair,
@@ -479,6 +572,8 @@ void HistUpdater<GradientSumT>::InitData(
    }
  }

  // store a pointer to the tree
  p_last_tree_ = &tree;
  column_sampler_->Init(ctx_, info.num_col_, info.feature_weights.ConstHostVector(),
                        param_.colsample_bynode, param_.colsample_bylevel,
                        param_.colsample_bytree);

@@ -11,10 +11,10 @@
#include <xgboost/tree_updater.h>
#pragma GCC diagnostic pop

#include <utility>
#include <vector>
#include <memory>
#include <queue>
#include <utility>

#include "../common/partition_builder.h"
#include "split_evaluator.h"
@@ -54,12 +54,10 @@ class HistUpdater {
  explicit HistUpdater(const Context* ctx,
                       ::sycl::queue qu,
                       const xgboost::tree::TrainParam& param,
                       std::unique_ptr<TreeUpdater> pruner,
                       FeatureInteractionConstraintHost int_constraints_,
                       DMatrix const* fmat)
    : ctx_(ctx), qu_(qu), param_(param),
      tree_evaluator_(qu, param, fmat->Info().num_col_),
      pruner_(std::move(pruner)),
      interaction_constraints_{std::move(int_constraints_)},
      p_last_tree_(nullptr), p_last_fmat_(fmat) {
    builder_monitor_.Init("SYCL::Quantile::HistUpdater");
@@ -73,6 +71,17 @@ class HistUpdater {
      sub_group_size_ = sub_group_sizes.back();
  }

  // update one tree, growing
  void Update(xgboost::tree::TrainParam const *param,
              const common::GHistIndexMatrix &gmat,
              const USMVector<GradientPair, MemoryType::on_device>& gpair,
              DMatrix *p_fmat,
              xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
              RegTree *p_tree);

  bool UpdatePredictionCache(const DMatrix* data,
                             linalg::MatrixView<float> p_out_preds);

  void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
  void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);

@@ -200,7 +209,6 @@ class HistUpdater {
  std::vector<SplitEntry<GradientSumT>> best_splits_host_;

  TreeEvaluator<GradientSumT> tree_evaluator_;
  std::unique_ptr<TreeUpdater> pruner_;
  FeatureInteractionConstraintHost interaction_constraints_;

  // back pointers to tree and data matrix
@@ -247,6 +255,9 @@ class HistUpdater {
  std::unique_ptr<HistSynchronizer<GradientSumT>> hist_synchronizer_;
  std::unique_ptr<HistRowsAdder<GradientSumT>> hist_rows_adder_;

  USMVector<bst_float, MemoryType::on_device> out_preds_buf_;
  bst_float* out_pred_ptr = nullptr;

  ::sycl::queue qu_;
};

@@ -3,6 +3,7 @@
 * \file updater_quantile_hist.cc
 */
#include <vector>
#include <memory>

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
@@ -29,6 +30,50 @@ void QuantileHistMaker::Configure(const Args& args) {

  param_.UpdateAllowUnknown(args);
  hist_maker_param_.UpdateAllowUnknown(args);

  bool has_fp64_support = qu_.get_device().has(::sycl::aspect::fp64);
  if (hist_maker_param_.single_precision_histogram || !has_fp64_support) {
    if (!hist_maker_param_.single_precision_histogram) {
      LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True";
    }
    hist_precision_ = HistPrecision::fp32;
  } else {
    hist_precision_ = HistPrecision::fp64;
  }
}

template<typename GradientSumT>
void QuantileHistMaker::SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>* pimpl,
                                 DMatrix *dmat) {
  pimpl->reset(new HistUpdater<GradientSumT>(
                 ctx_,
                 qu_,
                 param_,
                 int_constraint_, dmat));
  if (collective::IsDistributed()) {
    LOG(FATAL) << "Distributed mode is not yet upstreamed for sycl";
  } else {
    (*pimpl)->SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
    (*pimpl)->SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
  }
}

template<typename GradientSumT>
void QuantileHistMaker::CallUpdate(
      const std::unique_ptr<HistUpdater<GradientSumT>>& pimpl,
      xgboost::tree::TrainParam const *param,
      linalg::Matrix<GradientPair> *gpair,
      DMatrix *dmat,
      xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
      const std::vector<RegTree *> &trees) {
  const auto* gpair_h = gpair->Data();
  gpair_device_.Resize(&qu_, gpair_h->Size());
  qu_.memcpy(gpair_device_.Data(), gpair_h->HostPointer(), gpair_h->Size() * sizeof(GradientPair));
  qu_.wait();

  for (auto tree : trees) {
    pimpl->Update(param, gmat_, gpair_device_, dmat, out_position, tree);
  }
}

void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
@@ -36,12 +81,55 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
                               DMatrix *dmat,
                               xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
                               const std::vector<RegTree *> &trees) {
  LOG(FATAL) << "Not Implemented yet";
  if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
    updater_monitor_.Start("DeviceMatrixInitialization");
    sycl::DeviceMatrix dmat_device;
    dmat_device.Init(qu_, dmat);
    updater_monitor_.Stop("DeviceMatrixInitialization");
    updater_monitor_.Start("GmatInitialization");
    gmat_.Init(qu_, ctx_, dmat_device, static_cast<uint32_t>(param_.max_bin));
    updater_monitor_.Stop("GmatInitialization");
    is_gmat_initialized_ = true;
  }
  // rescale learning rate according to size of trees
  float lr = param_.learning_rate;
  param_.learning_rate = lr / trees.size();
  int_constraint_.Configure(param_, dmat->Info().num_col_);
  // build tree
  if (hist_precision_ == HistPrecision::fp32) {
    if (!pimpl_fp32) {
      SetPimpl(&pimpl_fp32, dmat);
    }
    CallUpdate(pimpl_fp32, param, gpair, dmat, out_position, trees);
  } else {
    if (!pimpl_fp64) {
      SetPimpl(&pimpl_fp64, dmat);
    }
    CallUpdate(pimpl_fp64, param, gpair, dmat, out_position, trees);
  }

  param_.learning_rate = lr;

  p_last_dmat_ = dmat;
}

bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
                                              linalg::MatrixView<float> out_preds) {
  LOG(FATAL) << "Not Implemented yet";
  if (param_.subsample < 1.0f) return false;

  if (hist_precision_ == HistPrecision::fp32) {
    if (pimpl_fp32) {
      return pimpl_fp32->UpdatePredictionCache(data, out_preds);
    } else {
      return false;
    }
  } else {
    if (pimpl_fp64) {
      return pimpl_fp64->UpdatePredictionCache(data, out_preds);
    } else {
      return false;
    }
  }
}

XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")

@@ -9,6 +9,7 @@
#include <xgboost/tree_updater.h>

#include <vector>
#include <memory>

#include "../data/gradient_index.h"
#include "../common/hist_util.h"
@@ -16,8 +17,9 @@
#include "../common/partition_builder.h"
#include "split_evaluator.h"
#include "../device_manager.h"

#include "hist_updater.h"
#include "xgboost/data.h"

#include "xgboost/json.h"
#include "../../src/tree/constraints.h"
#include "../../src/common/random.h"
@@ -75,12 +77,39 @@ class QuantileHistMaker: public TreeUpdater {
  HistMakerTrainParam hist_maker_param_;
  // training parameter
  xgboost::tree::TrainParam param_;
  // quantized data matrix
  common::GHistIndexMatrix gmat_;
  // (optional) data matrix with feature grouping
  // column accessor
  DMatrix const* p_last_dmat_ {nullptr};
  bool is_gmat_initialized_ {false};

  xgboost::common::Monitor updater_monitor_;

  template<typename GradientSumT>
  void SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>*, DMatrix *dmat);

  template<typename GradientSumT>
  void CallUpdate(const std::unique_ptr<HistUpdater<GradientSumT>>& builder,
                  xgboost::tree::TrainParam const *param,
                  linalg::Matrix<GradientPair> *gpair,
                  DMatrix *dmat,
                  xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
                  const std::vector<RegTree *> &trees);

  enum class HistPrecision {fp32, fp64};
  HistPrecision hist_precision_;

  std::unique_ptr<HistUpdater<float>> pimpl_fp32;
  std::unique_ptr<HistUpdater<double>> pimpl_fp64;

  FeatureInteractionConstraintHost int_constraint_;

  ::sycl::queue qu_;
  DeviceManager device_manager;
  ObjInfo const *task_{nullptr};

  USMVector<GradientPair, MemoryType::on_device> gpair_device_;
};

@@ -52,7 +52,8 @@ std::string MapTreeMethodToUpdaters(Context const* ctx, TreeMethod tree_method)
    case TreeMethod::kAuto:  // Use hist as default in 2.0
    case TreeMethod::kHist: {
      return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; },
                                 [] { return "grow_gpu_hist"; });
                                 [] { return "grow_gpu_hist"; },
                                 [] { return "grow_quantile_histmaker_sycl"; });
    }
    case TreeMethod::kApprox: {
      return ctx->DispatchDevice([] { return "grow_histmaker"; }, [] { return "grow_gpu_approx"; });

@@ -21,10 +21,8 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
  TestHistUpdater(const Context* ctx,
                  ::sycl::queue qu,
                  const xgboost::tree::TrainParam& param,
                  std::unique_ptr<TreeUpdater> pruner,
                  FeatureInteractionConstraintHost int_constraints_,
                  DMatrix const* fmat) : HistUpdater<GradientSumT>(ctx, qu, param,
                                                                   std::move(pruner),
                                                                   int_constraints_, fmat) {}

  void TestInitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
@@ -110,14 +108,12 @@ void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) {

  DeviceManager device_manager;
  auto qu = device_manager.GetQueue(ctx.Device());
  ObjInfo task{ObjInfo::kRegression};

  auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix();

  FeatureInteractionConstraintHost int_constraints;
  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};

  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());

  USMVector<size_t, MemoryType::on_device> row_indices_0(&qu, num_rows);
  USMVector<size_t, MemoryType::on_device> row_indices_1(&qu, num_rows);
@@ -165,14 +161,12 @@ void TestHistUpdaterInitData(const xgboost::tree::TrainParam& param, bool has_ne

  DeviceManager device_manager;
  auto qu = device_manager.GetQueue(ctx.Device());
  ObjInfo task{ObjInfo::kRegression};

  auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix();

  FeatureInteractionConstraintHost int_constraints;
  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};

  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());

  USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
  GenerateRandomGPairs(&qu, gpair.Data(), num_rows, has_neg_hess);
@@ -221,14 +215,12 @@ void TestHistUpdaterBuildHistogramsLossGuide(const xgboost::tree::TrainParam& pa

  DeviceManager device_manager;
  auto qu = device_manager.GetQueue(ctx.Device());
  ObjInfo task{ObjInfo::kRegression};

  auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();

  FeatureInteractionConstraintHost int_constraints;
  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};

  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());

@@ -285,14 +277,12 @@ void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sp

  DeviceManager device_manager;
  auto qu = device_manager.GetQueue(ctx.Device());
  ObjInfo task{ObjInfo::kRegression};

  auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();

  FeatureInteractionConstraintHost int_constraints;
  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};

  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());

@@ -345,14 +335,12 @@ void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {

  DeviceManager device_manager;
  auto qu = device_manager.GetQueue(ctx.Device());
  ObjInfo task{ObjInfo::kRegression};

  auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0f}.GenerateDMatrix();

  FeatureInteractionConstraintHost int_constraints;
  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};

  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());

@@ -423,8 +411,6 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
  DeviceManager device_manager;
  auto qu = device_manager.GetQueue(ctx.Device());

  ObjInfo task{ObjInfo::kRegression};

  auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
  sycl::DeviceMatrix dmat;
  dmat.Init(qu, p_fmat.get());
@@ -439,8 +425,7 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
  nodes.emplace_back(tree::ExpandEntry(0, tree.GetDepth(0)));

  FeatureInteractionConstraintHost int_constraints;
  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
  USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
  GenerateRandomGPairs(&qu, gpair.Data(), num_rows, false);

@@ -455,8 +440,7 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
  std::vector<size_t> row_indices_desired_host(num_rows);
  size_t n_left, n_right;
  {
    std::unique_ptr<TreeUpdater> pruner4verification{TreeUpdater::Create("prune", &ctx, &task)};
    TestHistUpdater<GradientSumT> updater4verification(&ctx, qu, param, std::move(pruner4verification), int_constraints, p_fmat.get());
    TestHistUpdater<GradientSumT> updater4verification(&ctx, qu, param, int_constraints, p_fmat.get());
    auto* row_set_collection4verification = updater4verification.TestInitData(gmat, gpair, *p_fmat, tree);

    size_t n_nodes = nodes.size();
@@ -526,9 +510,7 @@ void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param)

  RegTree tree;
  FeatureInteractionConstraintHost int_constraints;
  ObjInfo task{ObjInfo::kRegression};
  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
  auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
@@ -576,9 +558,7 @@ void TestHistUpdaterExpandWithDepthWise(const xgboost::tree::TrainParam& param)

  RegTree tree;
  FeatureInteractionConstraintHost int_constraints;
  ObjInfo task{ObjInfo::kRegression};
  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
  auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);

23 tests/cpp/plugin/test_sycl_prediction_cache.cc Normal file
@@ -0,0 +1,23 @@
/**
 * Copyright 2020-2024 by XGBoost contributors
 */
#include <gtest/gtest.h>

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include "../tree/test_prediction_cache.h"
#pragma GCC diagnostic pop

namespace xgboost::sycl::tree {

class SyclPredictionCache : public xgboost::TestPredictionCache {};

TEST_F(SyclPredictionCache, Hist) {
  Context ctx;
  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

  this->RunTest(&ctx, "grow_quantile_histmaker_sycl", "one_output_per_tree");
}

} // namespace xgboost::sycl::tree
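The shared RunTest/RunLearnerTest fixture below (moved out of the old test_prediction_cache.cc into tests/cpp/tree/test_prediction_cache.h so the SYCL test above can reuse it) checks that predictions served from the updater's prediction cache match predictions recomputed by a freshly loaded model. A rough Python analogue of that check, assuming a SYCL-enabled build; the file name and parameter values are illustrative:

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(0)
X, y = rng.randn(256, 8), rng.randn(256)
dtrain = xgb.DMatrix(X, y)

params = {"device": "sycl", "tree_method": "hist", "subsample": 1.0}
booster = xgb.train(params, dtrain, num_boost_round=8)

# Predictions on the training DMatrix are served from the prediction cache ...
cached = booster.predict(dtrain)

# ... and should agree with a model that is saved, reloaded, and re-evaluated.
booster.save_model("model.json")
fresh = xgb.Booster(model_file="model.json")
np.testing.assert_allclose(cached, fresh.predict(dtrain), rtol=1e-6)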
@@ -2,97 +2,10 @@
 * Copyright 2021-2023 by XGBoost contributors
 */
#include <gtest/gtest.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/tree_updater.h>

#include <memory>

#include "../../../src/tree/param.h"  // for TrainParam
#include "../helpers.h"
#include "xgboost/task.h"  // for ObjInfo
#include "test_prediction_cache.h"

namespace xgboost {

class TestPredictionCache : public ::testing::Test {
  std::shared_ptr<DMatrix> Xy_;
  std::size_t n_samples_{2048};

 protected:
  void SetUp() override {
    std::size_t n_features = 13;
    bst_target_t n_targets = 3;
    Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true);
  }

  void RunLearnerTest(Context const* ctx, std::string updater_name, float subsample,
                      std::string const& grow_policy, std::string const& strategy) {
    std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
    learner->SetParam("device", ctx->DeviceName());
    learner->SetParam("updater", updater_name);
    learner->SetParam("multi_strategy", strategy);
    learner->SetParam("grow_policy", grow_policy);
    learner->SetParam("subsample", std::to_string(subsample));
    learner->SetParam("nthread", "0");
    learner->Configure();

    for (size_t i = 0; i < 8; ++i) {
      learner->UpdateOneIter(i, Xy_);
    }

    HostDeviceVector<float> out_prediction_cached;
    learner->Predict(Xy_, false, &out_prediction_cached, 0, 0);

    Json model{Object()};
    learner->SaveModel(&model);

    HostDeviceVector<float> out_prediction;
    {
      std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
      learner->LoadModel(model);
      learner->Predict(Xy_, false, &out_prediction, 0, 0);
    }

    auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
    auto const h_predt = out_prediction.ConstHostSpan();

    ASSERT_EQ(h_predt.size(), h_predt_cached.size());
    for (size_t i = 0; i < h_predt.size(); ++i) {
      ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
    }
  }

  void RunTest(Context* ctx, std::string const& updater_name, std::string const& strategy) {
    {
      ctx->InitAllowUnknown(Args{{"nthread", "8"}});

      ObjInfo task{ObjInfo::kRegression};
      std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, ctx, &task)};
      RegTree tree;
      std::vector<RegTree*> trees{&tree};
      auto gpair = GenerateRandomGradients(ctx, n_samples_, 1);
      tree::TrainParam param;
      param.UpdateAllowUnknown(Args{{"max_bin", "64"}});

      updater->Configure(Args{});
      std::vector<HostDeviceVector<bst_node_t>> position(1);
      updater->Update(&param, &gpair, Xy_.get(), position, trees);
      HostDeviceVector<float> out_prediction_cached;
      out_prediction_cached.SetDevice(ctx->Device());
      out_prediction_cached.Resize(n_samples_);
      auto cache =
          linalg::MakeTensorView(ctx, &out_prediction_cached, out_prediction_cached.Size(), 1);
      ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache));
    }

    for (auto policy : {"depthwise", "lossguide"}) {
      for (auto subsample : {1.0f, 0.4f}) {
        this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
        this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
      }
    }
  }
};

TEST_F(TestPredictionCache, Approx) {
  Context ctx;
  this->RunTest(&ctx, "grow_histmaker", "one_output_per_tree");
@@ -119,4 +32,4 @@ TEST_F(TestPredictionCache, GpuApprox) {
  this->RunTest(&ctx, "grow_gpu_approx", "one_output_per_tree");
}
#endif  // defined(XGBOOST_USE_CUDA)
} // namespace xgboost
} // namespace xgboost
97 tests/cpp/tree/test_prediction_cache.h Normal file
@@ -0,0 +1,97 @@
/**
 * Copyright 2021-2024 by XGBoost contributors.
 */
#pragma once

#include <gtest/gtest.h>

#include <xgboost/host_device_vector.h>
#include <xgboost/tree_updater.h>

#include <memory>

#include "../../../src/tree/param.h"  // for TrainParam
#include "../helpers.h"
#include "xgboost/task.h"  // for ObjInfo

namespace xgboost {
class TestPredictionCache : public ::testing::Test {
  std::shared_ptr<DMatrix> Xy_;
  std::size_t n_samples_{2048};

 protected:
  void SetUp() override {
    std::size_t n_features = 13;
    bst_target_t n_targets = 3;
    Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true);
  }

  void RunLearnerTest(Context const* ctx, std::string updater_name, float subsample,
                      std::string const& grow_policy, std::string const& strategy) {
    std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
    learner->SetParam("device", ctx->DeviceName());
    learner->SetParam("updater", updater_name);
    learner->SetParam("multi_strategy", strategy);
    learner->SetParam("grow_policy", grow_policy);
    learner->SetParam("subsample", std::to_string(subsample));
    learner->SetParam("nthread", "0");
    learner->Configure();

    for (size_t i = 0; i < 8; ++i) {
      learner->UpdateOneIter(i, Xy_);
    }

    HostDeviceVector<float> out_prediction_cached;
    learner->Predict(Xy_, false, &out_prediction_cached, 0, 0);

    Json model{Object()};
    learner->SaveModel(&model);

    HostDeviceVector<float> out_prediction;
    {
      std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
      learner->LoadModel(model);
      learner->Predict(Xy_, false, &out_prediction, 0, 0);
    }

    auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
    auto const h_predt = out_prediction.ConstHostSpan();

    ASSERT_EQ(h_predt.size(), h_predt_cached.size());
    for (size_t i = 0; i < h_predt.size(); ++i) {
      ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
    }
  }

  void RunTest(Context* ctx, std::string const& updater_name, std::string const& strategy) {
    {
      ctx->InitAllowUnknown(Args{{"nthread", "8"}});

      ObjInfo task{ObjInfo::kRegression};
      std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, ctx, &task)};
      RegTree tree;
      std::vector<RegTree*> trees{&tree};
      auto gpair = GenerateRandomGradients(ctx, n_samples_, 1);
      tree::TrainParam param;
      param.UpdateAllowUnknown(Args{{"max_bin", "64"}});

      updater->Configure(Args{});
      std::vector<HostDeviceVector<bst_node_t>> position(1);
      updater->Update(&param, &gpair, Xy_.get(), position, trees);
      HostDeviceVector<float> out_prediction_cached;
      out_prediction_cached.SetDevice(ctx->Device());
      out_prediction_cached.Resize(n_samples_);
      auto cache =
          linalg::MakeTensorView(ctx, &out_prediction_cached, out_prediction_cached.Size(), 1);
      ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache));
    }

    for (auto policy : {"depthwise", "lossguide"}) {
      for (auto subsample : {1.0f, 0.4f}) {
        this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
        this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
      }
    }
  }
};
} // namespace xgboost
59 tests/python-sycl/test_sycl_training_continuation.py Normal file
@@ -0,0 +1,59 @@
import numpy as np
import xgboost as xgb
import json

rng = np.random.RandomState(1994)


class TestSYCLTrainingContinuation:
    def run_training_continuation(self, use_json):
        kRows = 64
        kCols = 32
        X = np.random.randn(kRows, kCols)
        y = np.random.randn(kRows)
        dtrain = xgb.DMatrix(X, y)
        params = {
            "device": "sycl",
            "max_depth": "2",
            "gamma": "0.1",
            "alpha": "0.01",
            "enable_experimental_json_serialization": use_json,
        }
        bst_0 = xgb.train(params, dtrain, num_boost_round=64)
        dump_0 = bst_0.get_dump(dump_format="json")

        bst_1 = xgb.train(params, dtrain, num_boost_round=32)
        bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
        dump_1 = bst_1.get_dump(dump_format="json")

        def recursive_compare(obj_0, obj_1):
            if isinstance(obj_0, float):
                assert np.isclose(obj_0, obj_1, atol=1e-6)
            elif isinstance(obj_0, str):
                assert obj_0 == obj_1
            elif isinstance(obj_0, int):
                assert obj_0 == obj_1
            elif isinstance(obj_0, dict):
                keys_0 = list(obj_0.keys())
                keys_1 = list(obj_1.keys())
                values_0 = list(obj_0.values())
                values_1 = list(obj_1.values())
                for i in range(len(obj_0.items())):
                    assert keys_0[i] == keys_1[i]
                    if list(obj_0.keys())[i] != "missing":
                        recursive_compare(values_0[i], values_1[i])
            else:
                for i in range(len(obj_0)):
                    recursive_compare(obj_0[i], obj_1[i])

        assert len(dump_0) == len(dump_1)
        for i in range(len(dump_0)):
            obj_0 = json.loads(dump_0[i])
            obj_1 = json.loads(dump_1[i])
            recursive_compare(obj_0, obj_1)

    def test_sycl_training_continuation_binary(self):
        self.run_training_continuation(False)

    def test_sycl_training_continuation_json(self):
        self.run_training_continuation(True)
80 tests/python-sycl/test_sycl_updaters.py Normal file
@@ -0,0 +1,80 @@
import numpy as np
import gc
import pytest
import xgboost as xgb
from hypothesis import given, strategies, assume, settings, note

import sys
import os

# sys.path.append("tests/python")
# import testing as tm
from xgboost import testing as tm

parameter_strategy = strategies.fixed_dictionaries(
    {
        "max_depth": strategies.integers(0, 11),
        "max_leaves": strategies.integers(0, 256),
        "max_bin": strategies.integers(2, 1024),
        "grow_policy": strategies.sampled_from(["lossguide", "depthwise"]),
        "single_precision_histogram": strategies.booleans(),
        "min_child_weight": strategies.floats(0.5, 2.0),
        "seed": strategies.integers(0, 10),
        # We cannot enable subsampling as the training loss can increase
        # 'subsample': strategies.floats(0.5, 1.0),
        "colsample_bytree": strategies.floats(0.5, 1.0),
        "colsample_bylevel": strategies.floats(0.5, 1.0),
    }
).filter(
    lambda x: (x["max_depth"] > 0 or x["max_leaves"] > 0)
    and (x["max_depth"] > 0 or x["grow_policy"] == "lossguide")
)


def train_result(param, dmat, num_rounds):
    result = {}
    xgb.train(
        param,
        dmat,
        num_rounds,
        [(dmat, "train")],
        verbose_eval=False,
        evals_result=result,
    )
    return result


class TestSYCLUpdaters:
    @given(parameter_strategy, strategies.integers(1, 5), tm.make_dataset_strategy())
    @settings(deadline=None)
    def test_sycl_hist(self, param, num_rounds, dataset):
        param["tree_method"] = "hist"
        param["device"] = "sycl"
        param["verbosity"] = 0
        param = dataset.set_params(param)
        result = train_result(param, dataset.get_dmat(), num_rounds)
        note(result)
        assert tm.non_increasing(result["train"][dataset.metric])

    @given(tm.make_dataset_strategy(), strategies.integers(0, 1))
    @settings(deadline=None)
    def test_specified_device_id_sycl_update(self, dataset, device_id):
        # Read the list of sycl devices
        sycl_ls = os.popen("sycl-ls").read()
        devices = sycl_ls.split("\n")

        # Test should launch only on gpu
        # Find gpus in the list of devices
        # and use the id in the list instead of device_id
        target_device_type = "opencl:gpu"
        found_devices = 0
        for idx in range(len(devices)):
            if len(devices[idx]) >= len(target_device_type):
                if devices[idx][1 : 1 + len(target_device_type)] == target_device_type:
                    if found_devices == device_id:
                        param = {"device": f"sycl:gpu:{idx}"}
                        param = dataset.set_params(param)
                        result = train_result(param, dataset.get_dmat(), 10)
                        assert tm.non_increasing(result["train"][dataset.metric])
                    else:
                        found_devices += 1