[SYCL] Add splits evaluation (#10605)
--------- Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
parent
6d9fcb771e
commit
f6cae4da85
@ -139,6 +139,17 @@ class USMVector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Resize without keeping the data*/
|
||||||
|
void ResizeNoCopy(::sycl::queue* qu, size_t size_new) {
|
||||||
|
if (size_new <= capacity_) {
|
||||||
|
size_ = size_new;
|
||||||
|
} else {
|
||||||
|
size_ = size_new;
|
||||||
|
capacity_ = size_new;
|
||||||
|
data_ = allocate_memory_(qu, size_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Resize(::sycl::queue* qu, size_t size_new, T v) {
|
void Resize(::sycl::queue* qu, size_t size_new, T v) {
|
||||||
if (size_new <= size_) {
|
if (size_new <= size_) {
|
||||||
size_ = size_new;
|
size_ = size_new;
|
||||||
|
|||||||
@ -7,6 +7,8 @@
|
|||||||
|
|
||||||
#include <oneapi/dpl/random>
|
#include <oneapi/dpl/random>
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
#include "../common/hist_util.h"
|
#include "../common/hist_util.h"
|
||||||
#include "../../src/collective/allreduce.h"
|
#include "../../src/collective/allreduce.h"
|
||||||
|
|
||||||
@ -14,6 +16,10 @@ namespace xgboost {
|
|||||||
namespace sycl {
|
namespace sycl {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|
||||||
|
using ::sycl::ext::oneapi::plus;
|
||||||
|
using ::sycl::ext::oneapi::minimum;
|
||||||
|
using ::sycl::ext::oneapi::maximum;
|
||||||
|
|
||||||
template <typename GradientSumT>
|
template <typename GradientSumT>
|
||||||
void HistUpdater<GradientSumT>::SetHistSynchronizer(
|
void HistUpdater<GradientSumT>::SetHistSynchronizer(
|
||||||
HistSynchronizer<GradientSumT> *sync) {
|
HistSynchronizer<GradientSumT> *sync) {
|
||||||
@ -126,6 +132,10 @@ void HistUpdater<GradientSumT>::InitData(
|
|||||||
builder_monitor_.Start("InitData");
|
builder_monitor_.Start("InitData");
|
||||||
const auto& info = fmat.Info();
|
const auto& info = fmat.Info();
|
||||||
|
|
||||||
|
if (!column_sampler_) {
|
||||||
|
column_sampler_ = xgboost::common::MakeColumnSampler(ctx_);
|
||||||
|
}
|
||||||
|
|
||||||
// initialize the row set
|
// initialize the row set
|
||||||
{
|
{
|
||||||
row_set_collection_.Clear();
|
row_set_collection_.Clear();
|
||||||
@ -213,6 +223,9 @@ void HistUpdater<GradientSumT>::InitData(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
column_sampler_->Init(ctx_, info.num_col_, info.feature_weights.ConstHostVector(),
|
||||||
|
param_.colsample_bynode, param_.colsample_bylevel,
|
||||||
|
param_.colsample_bytree);
|
||||||
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
||||||
/* specialized code for dense data:
|
/* specialized code for dense data:
|
||||||
choose the column that has a least positive number of discrete bins.
|
choose the column that has a least positive number of discrete bins.
|
||||||
@ -309,6 +322,148 @@ void HistUpdater<GradientSumT>::InitNewNode(int nid,
|
|||||||
builder_monitor_.Stop("InitNewNode");
|
builder_monitor_.Stop("InitNewNode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nodes_set - set of nodes to be processed in parallel
|
||||||
|
template<typename GradientSumT>
|
||||||
|
void HistUpdater<GradientSumT>::EvaluateSplits(
|
||||||
|
const std::vector<ExpandEntry>& nodes_set,
|
||||||
|
const common::GHistIndexMatrix& gmat,
|
||||||
|
const RegTree& tree) {
|
||||||
|
builder_monitor_.Start("EvaluateSplits");
|
||||||
|
|
||||||
|
const size_t n_nodes_in_set = nodes_set.size();
|
||||||
|
|
||||||
|
using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
|
||||||
|
|
||||||
|
// Generate feature set for each tree node
|
||||||
|
size_t pos = 0;
|
||||||
|
for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
|
||||||
|
const bst_node_t nid = nodes_set[nid_in_set].nid;
|
||||||
|
FeatureSetType features_set = column_sampler_->GetFeatureSet(tree.GetDepth(nid));
|
||||||
|
for (size_t idx = 0; idx < features_set->Size(); idx++) {
|
||||||
|
const size_t fid = features_set->ConstHostVector()[idx];
|
||||||
|
if (interaction_constraints_.Query(nid, fid)) {
|
||||||
|
auto this_hist = hist_[nid].DataConst();
|
||||||
|
if (pos < split_queries_host_.size()) {
|
||||||
|
split_queries_host_[pos] = SplitQuery{nid, fid, this_hist};
|
||||||
|
} else {
|
||||||
|
split_queries_host_.push_back({nid, fid, this_hist});
|
||||||
|
}
|
||||||
|
++pos;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const size_t total_features = pos;
|
||||||
|
|
||||||
|
split_queries_device_.Resize(&qu_, total_features);
|
||||||
|
auto event = qu_.memcpy(split_queries_device_.Data(), split_queries_host_.data(),
|
||||||
|
total_features * sizeof(SplitQuery));
|
||||||
|
|
||||||
|
auto evaluator = tree_evaluator_.GetEvaluator();
|
||||||
|
SplitQuery* split_queries_device = split_queries_device_.Data();
|
||||||
|
const uint32_t* cut_ptr = gmat.cut_device.Ptrs().DataConst();
|
||||||
|
const bst_float* cut_val = gmat.cut_device.Values().DataConst();
|
||||||
|
const bst_float* cut_minval = gmat.cut_device.MinValues().DataConst();
|
||||||
|
|
||||||
|
snode_device_.ResizeNoCopy(&qu_, snode_host_.size());
|
||||||
|
event = qu_.memcpy(snode_device_.Data(), snode_host_.data(),
|
||||||
|
snode_host_.size() * sizeof(NodeEntry<GradientSumT>), event);
|
||||||
|
const NodeEntry<GradientSumT>* snode = snode_device_.Data();
|
||||||
|
|
||||||
|
const float min_child_weight = param_.min_child_weight;
|
||||||
|
|
||||||
|
best_splits_device_.ResizeNoCopy(&qu_, total_features);
|
||||||
|
if (best_splits_host_.size() < total_features) best_splits_host_.resize(total_features);
|
||||||
|
SplitEntry<GradientSumT>* best_splits = best_splits_device_.Data();
|
||||||
|
|
||||||
|
event = qu_.submit([&](::sycl::handler& cgh) {
|
||||||
|
cgh.depends_on(event);
|
||||||
|
cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(total_features, sub_group_size_),
|
||||||
|
::sycl::range<2>(1, sub_group_size_)),
|
||||||
|
[=](::sycl::nd_item<2> pid) {
|
||||||
|
int i = pid.get_global_id(0);
|
||||||
|
auto sg = pid.get_sub_group();
|
||||||
|
int nid = split_queries_device[i].nid;
|
||||||
|
int fid = split_queries_device[i].fid;
|
||||||
|
const GradientPairT* hist_data = split_queries_device[i].hist;
|
||||||
|
|
||||||
|
best_splits[i] = snode[nid].best;
|
||||||
|
EnumerateSplit(sg, cut_ptr, cut_val, hist_data, snode[nid],
|
||||||
|
&(best_splits[i]), fid, nid, evaluator, min_child_weight);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
event = qu_.memcpy(best_splits_host_.data(), best_splits,
|
||||||
|
total_features * sizeof(SplitEntry<GradientSumT>), event);
|
||||||
|
|
||||||
|
qu_.wait();
|
||||||
|
for (size_t i = 0; i < total_features; i++) {
|
||||||
|
int nid = split_queries_host_[i].nid;
|
||||||
|
snode_host_[nid].best.Update(best_splits_host_[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
builder_monitor_.Stop("EvaluateSplits");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enumerate the split values of specific feature.
|
||||||
|
// Returns the sum of gradients corresponding to the data points that contains a non-missing value
|
||||||
|
// for the particular feature fid.
|
||||||
|
template <typename GradientSumT>
|
||||||
|
void HistUpdater<GradientSumT>::EnumerateSplit(
|
||||||
|
const ::sycl::sub_group& sg,
|
||||||
|
const uint32_t* cut_ptr,
|
||||||
|
const bst_float* cut_val,
|
||||||
|
const GradientPairT* hist_data,
|
||||||
|
const NodeEntry<GradientSumT>& snode,
|
||||||
|
SplitEntry<GradientSumT>* p_best,
|
||||||
|
bst_uint fid,
|
||||||
|
bst_uint nodeID,
|
||||||
|
typename TreeEvaluator<GradientSumT>::SplitEvaluator const &evaluator,
|
||||||
|
float min_child_weight) {
|
||||||
|
SplitEntry<GradientSumT> best;
|
||||||
|
|
||||||
|
int32_t ibegin = static_cast<int32_t>(cut_ptr[fid]);
|
||||||
|
int32_t iend = static_cast<int32_t>(cut_ptr[fid + 1]);
|
||||||
|
|
||||||
|
GradStats<GradientSumT> sum(0, 0);
|
||||||
|
|
||||||
|
int32_t sub_group_size = sg.get_local_range().size();
|
||||||
|
const size_t local_id = sg.get_local_id()[0];
|
||||||
|
|
||||||
|
/* TODO(razdoburdin)
|
||||||
|
* Currently the first additions are fast and the last are slow.
|
||||||
|
* Maybe calculating of reduce overgroup in seprate kernel and reusing it here can be faster
|
||||||
|
*/
|
||||||
|
for (int32_t i = ibegin + local_id; i < iend; i += sub_group_size) {
|
||||||
|
sum.Add(::sycl::inclusive_scan_over_group(sg, hist_data[i].GetGrad(), std::plus<>()),
|
||||||
|
::sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>()));
|
||||||
|
|
||||||
|
if (sum.GetHess() >= min_child_weight) {
|
||||||
|
GradStats<GradientSumT> c = snode.stats - sum;
|
||||||
|
if (c.GetHess() >= min_child_weight) {
|
||||||
|
bst_float loss_chg = evaluator.CalcSplitGain(nodeID, fid, sum, c) - snode.root_gain;
|
||||||
|
bst_float split_pt = cut_val[i];
|
||||||
|
best.Update(loss_chg, fid, split_pt, false, sum, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool last_iter = i + sub_group_size >= iend;
|
||||||
|
if (!last_iter) {
|
||||||
|
size_t end = i - local_id + sub_group_size;
|
||||||
|
if (end > iend) end = iend;
|
||||||
|
for (size_t j = i + 1; j < end; ++j) {
|
||||||
|
sum.Add(hist_data[j].GetGrad(), hist_data[j].GetHess());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_float total_loss_chg = ::sycl::reduce_over_group(sg, best.loss_chg, maximum<>());
|
||||||
|
bst_feature_t total_split_index = ::sycl::reduce_over_group(sg,
|
||||||
|
best.loss_chg == total_loss_chg ?
|
||||||
|
best.SplitIndex() :
|
||||||
|
(1U << 31) - 1U, minimum<>());
|
||||||
|
if (best.loss_chg == total_loss_chg &&
|
||||||
|
best.SplitIndex() == total_split_index) p_best->Update(best);
|
||||||
|
}
|
||||||
|
|
||||||
template class HistUpdater<float>;
|
template class HistUpdater<float>;
|
||||||
template class HistUpdater<double>;
|
template class HistUpdater<double>;
|
||||||
|
|
||||||
|
|||||||
@ -20,6 +20,7 @@
|
|||||||
#include "hist_synchronizer.h"
|
#include "hist_synchronizer.h"
|
||||||
#include "hist_row_adder.h"
|
#include "hist_row_adder.h"
|
||||||
|
|
||||||
|
#include "../../src/common/random.h"
|
||||||
#include "../data.h"
|
#include "../data.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -62,6 +63,9 @@ class HistUpdater {
|
|||||||
p_last_tree_(nullptr), p_last_fmat_(fmat) {
|
p_last_tree_(nullptr), p_last_fmat_(fmat) {
|
||||||
builder_monitor_.Init("SYCL::Quantile::HistUpdater");
|
builder_monitor_.Init("SYCL::Quantile::HistUpdater");
|
||||||
kernel_monitor_.Init("SYCL::Quantile::HistUpdater");
|
kernel_monitor_.Init("SYCL::Quantile::HistUpdater");
|
||||||
|
if (param.max_depth > 0) {
|
||||||
|
snode_device_.Resize(&qu, 1u << (param.max_depth + 1));
|
||||||
|
}
|
||||||
const auto sub_group_sizes =
|
const auto sub_group_sizes =
|
||||||
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
|
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
|
||||||
sub_group_size_ = sub_group_sizes.back();
|
sub_group_size_ = sub_group_sizes.back();
|
||||||
@ -74,9 +78,28 @@ class HistUpdater {
|
|||||||
friend class BatchHistSynchronizer<GradientSumT>;
|
friend class BatchHistSynchronizer<GradientSumT>;
|
||||||
friend class BatchHistRowsAdder<GradientSumT>;
|
friend class BatchHistRowsAdder<GradientSumT>;
|
||||||
|
|
||||||
|
struct SplitQuery {
|
||||||
|
bst_node_t nid;
|
||||||
|
size_t fid;
|
||||||
|
const GradientPairT* hist;
|
||||||
|
};
|
||||||
|
|
||||||
void InitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
void InitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||||
USMVector<size_t, MemoryType::on_device>* row_indices);
|
USMVector<size_t, MemoryType::on_device>* row_indices);
|
||||||
|
|
||||||
|
void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
|
||||||
|
const common::GHistIndexMatrix& gmat,
|
||||||
|
const RegTree& tree);
|
||||||
|
|
||||||
|
// Enumerate the split values of specific feature
|
||||||
|
// Returns the sum of gradients corresponding to the data points that contains a non-missing
|
||||||
|
// value for the particular feature fid.
|
||||||
|
static void EnumerateSplit(const ::sycl::sub_group& sg,
|
||||||
|
const uint32_t* cut_ptr, const bst_float* cut_val, const GradientPairT* hist_data,
|
||||||
|
const NodeEntry<GradientSumT> &snode, SplitEntry<GradientSumT>* p_best, bst_uint fid,
|
||||||
|
bst_uint nodeID,
|
||||||
|
typename TreeEvaluator<GradientSumT>::SplitEvaluator const &evaluator,
|
||||||
|
float min_child_weight);
|
||||||
|
|
||||||
void InitData(const common::GHistIndexMatrix& gmat,
|
void InitData(const common::GHistIndexMatrix& gmat,
|
||||||
const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||||
@ -118,6 +141,14 @@ class HistUpdater {
|
|||||||
common::RowSetCollection row_set_collection_;
|
common::RowSetCollection row_set_collection_;
|
||||||
|
|
||||||
const xgboost::tree::TrainParam& param_;
|
const xgboost::tree::TrainParam& param_;
|
||||||
|
std::shared_ptr<xgboost::common::ColumnSampler> column_sampler_;
|
||||||
|
|
||||||
|
std::vector<SplitQuery> split_queries_host_;
|
||||||
|
USMVector<SplitQuery, MemoryType::on_device> split_queries_device_;
|
||||||
|
|
||||||
|
USMVector<SplitEntry<GradientSumT>, MemoryType::on_device> best_splits_device_;
|
||||||
|
std::vector<SplitEntry<GradientSumT>> best_splits_host_;
|
||||||
|
|
||||||
TreeEvaluator<GradientSumT> tree_evaluator_;
|
TreeEvaluator<GradientSumT> tree_evaluator_;
|
||||||
std::unique_ptr<TreeUpdater> pruner_;
|
std::unique_ptr<TreeUpdater> pruner_;
|
||||||
FeatureInteractionConstraintHost interaction_constraints_;
|
FeatureInteractionConstraintHost interaction_constraints_;
|
||||||
@ -137,6 +168,7 @@ class HistUpdater {
|
|||||||
|
|
||||||
/*! \brief TreeNode Data: statistics for each constructed node */
|
/*! \brief TreeNode Data: statistics for each constructed node */
|
||||||
std::vector<NodeEntry<GradientSumT>> snode_host_;
|
std::vector<NodeEntry<GradientSumT>> snode_host_;
|
||||||
|
USMVector<NodeEntry<GradientSumT>, MemoryType::on_device> snode_device_;
|
||||||
|
|
||||||
xgboost::common::Monitor builder_monitor_;
|
xgboost::common::Monitor builder_monitor_;
|
||||||
xgboost::common::Monitor kernel_monitor_;
|
xgboost::common::Monitor kernel_monitor_;
|
||||||
|
|||||||
@ -54,6 +54,13 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
|
|||||||
HistUpdater<GradientSumT>::InitNewNode(nid, gmat, gpair, fmat, tree);
|
HistUpdater<GradientSumT>::InitNewNode(nid, gmat, gpair, fmat, tree);
|
||||||
return HistUpdater<GradientSumT>::snode_host_[nid];
|
return HistUpdater<GradientSumT>::snode_host_[nid];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto TestEvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
|
||||||
|
const common::GHistIndexMatrix& gmat,
|
||||||
|
const RegTree& tree) {
|
||||||
|
HistUpdater<GradientSumT>::EvaluateSplits(nodes_set, gmat, tree);
|
||||||
|
return HistUpdater<GradientSumT>::snode_host_;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
|
void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
|
||||||
@ -307,6 +314,84 @@ void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sp
|
|||||||
EXPECT_NEAR(snode.stats.GetHess(), grad_stat.GetHess(), 1e-6 * grad_stat.GetHess());
|
EXPECT_NEAR(snode.stats.GetHess(), grad_stat.GetHess(), 1e-6 * grad_stat.GetHess());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename GradientSumT>
|
||||||
|
void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
|
||||||
|
const size_t num_rows = 1u << 8;
|
||||||
|
const size_t num_columns = 2;
|
||||||
|
const size_t n_bins = 32;
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
|
||||||
|
|
||||||
|
DeviceManager device_manager;
|
||||||
|
auto qu = device_manager.GetQueue(ctx.Device());
|
||||||
|
ObjInfo task{ObjInfo::kRegression};
|
||||||
|
|
||||||
|
auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0f}.GenerateDMatrix();
|
||||||
|
|
||||||
|
FeatureInteractionConstraintHost int_constraints;
|
||||||
|
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
|
||||||
|
|
||||||
|
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
|
||||||
|
updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
|
||||||
|
updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
|
||||||
|
|
||||||
|
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
|
||||||
|
auto* gpair_ptr = gpair.Data();
|
||||||
|
GenerateRandomGPairs(&qu, gpair_ptr, num_rows, false);
|
||||||
|
|
||||||
|
DeviceMatrix dmat;
|
||||||
|
dmat.Init(qu, p_fmat.get());
|
||||||
|
common::GHistIndexMatrix gmat;
|
||||||
|
gmat.Init(qu, &ctx, dmat, n_bins);
|
||||||
|
|
||||||
|
RegTree tree;
|
||||||
|
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||||
|
ExpandEntry node(ExpandEntry::kRootNid, tree.GetDepth(ExpandEntry::kRootNid));
|
||||||
|
|
||||||
|
auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
|
||||||
|
auto& row_idxs = row_set_collection->Data();
|
||||||
|
const size_t* row_idxs_ptr = row_idxs.DataConst();
|
||||||
|
const auto* hist = updater.TestBuildHistogramsLossGuide(node, gmat, &tree, gpair);
|
||||||
|
const auto snode_init = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, tree);
|
||||||
|
|
||||||
|
const auto snode_updated = updater.TestEvaluateSplits({node}, gmat, tree);
|
||||||
|
auto best_loss_chg = snode_updated[0].best.loss_chg;
|
||||||
|
auto stats = snode_init.stats;
|
||||||
|
auto root_gain = snode_init.root_gain;
|
||||||
|
|
||||||
|
// Check all splits manually. Save the best one and compare with the ans
|
||||||
|
TreeEvaluator<GradientSumT> tree_evaluator(qu, param, num_columns);
|
||||||
|
auto evaluator = tree_evaluator.GetEvaluator();
|
||||||
|
const uint32_t* cut_ptr = gmat.cut_device.Ptrs().DataConst();
|
||||||
|
const size_t size = gmat.cut_device.Ptrs().Size();
|
||||||
|
int n_better_splits = 0;
|
||||||
|
const auto* hist_ptr = (*hist)[0].DataConst();
|
||||||
|
std::vector<bst_float> best_loss_chg_des(1, -1);
|
||||||
|
{
|
||||||
|
::sycl::buffer<bst_float> best_loss_chg_buff(best_loss_chg_des.data(), 1);
|
||||||
|
qu.submit([&](::sycl::handler& cgh) {
|
||||||
|
auto best_loss_chg_acc = best_loss_chg_buff.template get_access<::sycl::access::mode::read_write>(cgh);
|
||||||
|
cgh.single_task<>([=]() {
|
||||||
|
for (size_t i = 1; i < size; ++i) {
|
||||||
|
GradStats<GradientSumT> left(0, 0);
|
||||||
|
GradStats<GradientSumT> right = stats - left;
|
||||||
|
for (size_t j = cut_ptr[i-1]; j < cut_ptr[i]; ++j) {
|
||||||
|
auto loss_change = evaluator.CalcSplitGain(0, i - 1, left, right) - root_gain;
|
||||||
|
if (loss_change > best_loss_chg_acc[0]) {
|
||||||
|
best_loss_chg_acc[0] = loss_change;
|
||||||
|
}
|
||||||
|
left.Add(hist_ptr[j].GetGrad(), hist_ptr[j].GetHess());
|
||||||
|
right = stats - left;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}).wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_NEAR(best_loss_chg_des[0], best_loss_chg, 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(SyclHistUpdater, Sampling) {
|
TEST(SyclHistUpdater, Sampling) {
|
||||||
xgboost::tree::TrainParam param;
|
xgboost::tree::TrainParam param;
|
||||||
param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
|
param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
|
||||||
@ -346,4 +431,12 @@ TEST(SyclHistUpdater, InitNewNode) {
|
|||||||
TestHistUpdaterInitNewNode<double>(param, 0.5);
|
TestHistUpdaterInitNewNode<double>(param, 0.5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SyclHistUpdater, EvaluateSplits) {
|
||||||
|
xgboost::tree::TrainParam param;
|
||||||
|
param.UpdateAllowUnknown(Args{{"max_depth", "3"}});
|
||||||
|
|
||||||
|
TestHistUpdaterEvaluateSplits<float>(param);
|
||||||
|
TestHistUpdaterEvaluateSplits<double>(param);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xgboost::sycl::tree
|
} // namespace xgboost::sycl::tree
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user