[SYCL] Add nodes initialisation (#10269)
--------- Co-authored-by: Dmitry Razdoburdin <> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
committed by
GitHub
parent
7a54ca41c9
commit
c7e7ce7569
@@ -2,11 +2,6 @@
|
||||
* Copyright 2017-2023 by Contributors
|
||||
* \file device_manager.cc
|
||||
*/
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include "../sycl/device_manager.h"
|
||||
|
||||
#include "../../src/collective/communicator-inl.h"
|
||||
|
||||
@@ -12,7 +12,11 @@
|
||||
|
||||
#include <CL/sycl.hpp>
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include "xgboost/context.h"
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
|
||||
@@ -3,19 +3,15 @@
|
||||
* \file multiclass_obj.cc
|
||||
* \brief Definition of multi-class classification objectives.
|
||||
*/
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/parameter.h"
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "../../src/common/math.h"
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <oneapi/dpl/random>
|
||||
|
||||
#include "../common/hist_util.h"
|
||||
#include "../../src/collective/allreduce.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
@@ -111,7 +112,6 @@ void HistUpdater<GradientSumT>::InitSampling(
|
||||
|
||||
template<typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::InitData(
|
||||
Context const * ctx,
|
||||
const common::GHistIndexMatrix& gmat,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||
const DMatrix& fmat,
|
||||
@@ -215,6 +215,101 @@ void HistUpdater<GradientSumT>::InitData(
|
||||
data_layout_ = kSparseData;
|
||||
}
|
||||
}
|
||||
|
||||
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
||||
/* specialized code for dense data:
|
||||
choose the column that has a least positive number of discrete bins.
|
||||
For dense data (with no missing value),
|
||||
the sum of gradient histogram is equal to snode[nid] */
|
||||
const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
|
||||
const auto nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
|
||||
uint32_t min_nbins_per_feature = 0;
|
||||
for (bst_uint i = 0; i < nfeature; ++i) {
|
||||
const uint32_t nbins = row_ptr[i + 1] - row_ptr[i];
|
||||
if (nbins > 0) {
|
||||
if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) {
|
||||
min_nbins_per_feature = nbins;
|
||||
fid_least_bins_ = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK_GT(min_nbins_per_feature, 0U);
|
||||
}
|
||||
|
||||
std::fill(snode_host_.begin(), snode_host_.end(), NodeEntry<GradientSumT>(param_));
|
||||
builder_monitor_.Stop("InitData");
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::InitNewNode(int nid,
|
||||
const common::GHistIndexMatrix& gmat,
|
||||
const USMVector<GradientPair,
|
||||
MemoryType::on_device> &gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
builder_monitor_.Start("InitNewNode");
|
||||
|
||||
snode_host_.resize(tree.NumNodes(), NodeEntry<GradientSumT>(param_));
|
||||
{
|
||||
if (tree[nid].IsRoot()) {
|
||||
GradStats<GradientSumT> grad_stat;
|
||||
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
||||
const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
|
||||
const uint32_t ibegin = row_ptr[fid_least_bins_];
|
||||
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
|
||||
const auto* hist = reinterpret_cast<GradStats<GradientSumT>*>(hist_[nid].Data());
|
||||
|
||||
std::vector<GradStats<GradientSumT>> ets(iend - ibegin);
|
||||
qu_.memcpy(ets.data(), hist + ibegin,
|
||||
(iend - ibegin) * sizeof(GradStats<GradientSumT>)).wait_and_throw();
|
||||
for (const auto& et : ets) {
|
||||
grad_stat += et;
|
||||
}
|
||||
} else {
|
||||
const common::RowSetCollection::Elem e = row_set_collection_[nid];
|
||||
const size_t* row_idxs = e.begin;
|
||||
const size_t size = e.Size();
|
||||
const GradientPair* gpair_ptr = gpair.DataConst();
|
||||
|
||||
::sycl::buffer<GradStats<GradientSumT>> buff(&grad_stat, 1);
|
||||
qu_.submit([&](::sycl::handler& cgh) {
|
||||
auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>());
|
||||
cgh.parallel_for<>(::sycl::range<1>(size), reduction,
|
||||
[=](::sycl::item<1> pid, auto& sum) {
|
||||
size_t i = pid.get_id(0);
|
||||
size_t row_idx = row_idxs[i];
|
||||
if constexpr (std::is_same<GradientPair::ValueT, GradientSumT>::value) {
|
||||
sum += gpair_ptr[row_idx];
|
||||
} else {
|
||||
sum += GradStats<GradientSumT>(gpair_ptr[row_idx].GetGrad(),
|
||||
gpair_ptr[row_idx].GetHess());
|
||||
}
|
||||
});
|
||||
}).wait_and_throw();
|
||||
}
|
||||
auto rc = collective::Allreduce(
|
||||
ctx_, linalg::MakeVec(reinterpret_cast<GradientSumT*>(&grad_stat), 2),
|
||||
collective::Op::kSum);
|
||||
SafeColl(rc);
|
||||
snode_host_[nid].stats = grad_stat;
|
||||
} else {
|
||||
int parent_id = tree[nid].Parent();
|
||||
if (tree[nid].IsLeftChild()) {
|
||||
snode_host_[nid].stats = snode_host_[parent_id].best.left_sum;
|
||||
} else {
|
||||
snode_host_[nid].stats = snode_host_[parent_id].best.right_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// calculating the weights
|
||||
{
|
||||
auto evaluator = tree_evaluator_.GetEvaluator();
|
||||
bst_uint parentid = tree[nid].Parent();
|
||||
snode_host_[nid].weight = evaluator.CalcWeight(parentid, snode_host_[nid].stats);
|
||||
snode_host_[nid].root_gain = evaluator.CalcGain(parentid, snode_host_[nid].stats);
|
||||
}
|
||||
builder_monitor_.Stop("InitNewNode");
|
||||
}
|
||||
|
||||
// Explicit instantiation of HistUpdater for GradientSumT = float.
template class HistUpdater<float>;
|
||||
|
||||
@@ -26,6 +26,22 @@ namespace xgboost {
|
||||
namespace sycl {
|
||||
namespace tree {
|
||||
|
||||
// data structure
|
||||
template<typename GradType>
|
||||
struct NodeEntry {
|
||||
/*! \brief statics for node entry */
|
||||
GradStats<GradType> stats;
|
||||
/*! \brief loss of this node, without split */
|
||||
GradType root_gain;
|
||||
/*! \brief weight calculated related to current data */
|
||||
GradType weight;
|
||||
/*! \brief current best solution */
|
||||
SplitEntry<GradType> best;
|
||||
// constructor
|
||||
explicit NodeEntry(const xgboost::tree::TrainParam& param)
|
||||
: root_gain(0.0f), weight(0.0f) {}
|
||||
};
|
||||
|
||||
template<typename GradientSumT>
|
||||
class HistUpdater {
|
||||
public:
|
||||
@@ -33,12 +49,13 @@ class HistUpdater {
|
||||
using GHistRowT = common::GHistRow<GradientSumT, memory_type>;
|
||||
using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
|
||||
|
||||
explicit HistUpdater(::sycl::queue qu,
|
||||
const xgboost::tree::TrainParam& param,
|
||||
std::unique_ptr<TreeUpdater> pruner,
|
||||
FeatureInteractionConstraintHost int_constraints_,
|
||||
DMatrix const* fmat)
|
||||
: qu_(qu), param_(param),
|
||||
explicit HistUpdater(const Context* ctx,
|
||||
::sycl::queue qu,
|
||||
const xgboost::tree::TrainParam& param,
|
||||
std::unique_ptr<TreeUpdater> pruner,
|
||||
FeatureInteractionConstraintHost int_constraints_,
|
||||
DMatrix const* fmat)
|
||||
: ctx_(ctx), qu_(qu), param_(param),
|
||||
tree_evaluator_(qu, param, fmat->Info().num_col_),
|
||||
pruner_(std::move(pruner)),
|
||||
interaction_constraints_{std::move(int_constraints_)},
|
||||
@@ -61,8 +78,7 @@ class HistUpdater {
|
||||
USMVector<size_t, MemoryType::on_device>* row_indices);
|
||||
|
||||
|
||||
void InitData(Context const * ctx,
|
||||
const common::GHistIndexMatrix& gmat,
|
||||
void InitData(const common::GHistIndexMatrix& gmat,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree);
|
||||
@@ -78,6 +94,12 @@ class HistUpdater {
|
||||
data_layout_ != kSparseData, hist_buffer, event_priv);
|
||||
}
|
||||
|
||||
void InitNewNode(int nid,
|
||||
const common::GHistIndexMatrix& gmat,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree);
|
||||
|
||||
void BuildLocalHistograms(const common::GHistIndexMatrix &gmat,
|
||||
RegTree *p_tree,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair);
|
||||
@@ -89,6 +111,7 @@ class HistUpdater {
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair);
|
||||
|
||||
// --data fields--
|
||||
const Context* ctx_;
|
||||
size_t sub_group_size_;
|
||||
|
||||
// the internal row sets
|
||||
@@ -113,9 +136,16 @@ class HistUpdater {
|
||||
/*! \brief cumulative histogram of gradients. */
|
||||
common::HistCollection<GradientSumT, MemoryType::on_device> hist_;
|
||||
|
||||
/*! \brief TreeNode Data: statistics for each constructed node */
|
||||
std::vector<NodeEntry<GradientSumT>> snode_host_;
|
||||
|
||||
xgboost::common::Monitor builder_monitor_;
|
||||
xgboost::common::Monitor kernel_monitor_;
|
||||
|
||||
/*! \brief feature with least # of bins. to be used for dense specialization
|
||||
of InitNewNode() */
|
||||
uint32_t fid_least_bins_;
|
||||
|
||||
uint64_t seed_ = 0;
|
||||
|
||||
// key is the node id which should be calculated by Subtraction Trick, value is the node which
|
||||
|
||||
@@ -49,6 +49,115 @@ struct TrainParam {
|
||||
/*! \brief Shorthand for xgboost::detail::GradientPairInternal instantiated with GradType. */
template <typename GradType>
|
||||
using GradStats = xgboost::detail::GradientPairInternal<GradType>;
|
||||
|
||||
/*!
|
||||
* \brief SYCL implementation of SplitEntryContainer for device compilation.
|
||||
* Original structure cannot be used due 'cat_bits' field of type std::vector<uint32_t>,
|
||||
* which is not device-copyable
|
||||
*/
|
||||
template<typename GradientT>
|
||||
struct SplitEntryContainer {
|
||||
/*! \brief loss change after split this node */
|
||||
bst_float loss_chg {0.0f};
|
||||
/*! \brief split index */
|
||||
bst_feature_t sindex{0};
|
||||
bst_float split_value{0.0f};
|
||||
|
||||
|
||||
GradientT left_sum;
|
||||
GradientT right_sum;
|
||||
|
||||
|
||||
SplitEntryContainer() = default;
|
||||
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
|
||||
os << "loss_chg: " << s.loss_chg << ", "
|
||||
<< "split index: " << s.SplitIndex() << ", "
|
||||
<< "split value: " << s.split_value << ", "
|
||||
<< "left_sum: " << s.left_sum << ", "
|
||||
<< "right_sum: " << s.right_sum;
|
||||
return os;
|
||||
}
|
||||
/*!\return feature index to split on */
|
||||
bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }
|
||||
/*!\return whether missing value goes to left branch */
|
||||
bool DefaultLeft() const { return (sindex >> 31) != 0; }
|
||||
/*!
|
||||
* \brief decides whether we can replace current entry with the given statistics
|
||||
*
|
||||
* This function gives better priority to lower index when loss_chg == new_loss_chg.
|
||||
* Not the best way, but helps to give consistent result during multi-thread
|
||||
* execution.
|
||||
*
|
||||
* \param new_loss_chg the loss reduction get through the split
|
||||
* \param split_index the feature index where the split is on
|
||||
*/
|
||||
inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
|
||||
if (::sycl::isinf(new_loss_chg)) { // in some cases new_loss_chg can be NaN or Inf,
|
||||
// for example when lambda = 0 & min_child_weight = 0
|
||||
// skip value in this case
|
||||
return false;
|
||||
} else if (this->SplitIndex() <= split_index) {
|
||||
return new_loss_chg > this->loss_chg;
|
||||
} else {
|
||||
return !(this->loss_chg > new_loss_chg);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief update the split entry, replace it if e is better
|
||||
* \param e candidate split solution
|
||||
* \return whether the proposed split is better and can replace current split
|
||||
*/
|
||||
inline bool Update(const SplitEntryContainer &e) {
|
||||
if (this->NeedReplace(e.loss_chg, e.SplitIndex())) {
|
||||
this->loss_chg = e.loss_chg;
|
||||
this->sindex = e.sindex;
|
||||
this->split_value = e.split_value;
|
||||
this->left_sum = e.left_sum;
|
||||
this->right_sum = e.right_sum;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief update the split entry, replace it if e is better
|
||||
* \param new_loss_chg loss reduction of new candidate
|
||||
* \param split_index feature index to split on
|
||||
* \param new_split_value the split point
|
||||
* \param default_left whether the missing value goes to left
|
||||
* \return whether the proposed split is better and can replace current split
|
||||
*/
|
||||
bool Update(bst_float new_loss_chg, unsigned split_index,
|
||||
bst_float new_split_value, bool default_left,
|
||||
const GradientT &left_sum,
|
||||
const GradientT &right_sum) {
|
||||
if (this->NeedReplace(new_loss_chg, split_index)) {
|
||||
this->loss_chg = new_loss_chg;
|
||||
if (default_left) {
|
||||
split_index |= (1U << 31);
|
||||
}
|
||||
this->sindex = split_index;
|
||||
this->split_value = new_split_value;
|
||||
this->left_sum = left_sum;
|
||||
this->right_sum = right_sum;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/*! \brief same as update, used by AllReduce*/
|
||||
inline static void Reduce(SplitEntryContainer &dst, // NOLINT(*)
|
||||
const SplitEntryContainer &src) { // NOLINT(*)
|
||||
dst.Update(src);
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief SplitEntryContainer whose left/right sums are stored as GradStats<GradType>. */
template<typename GradType>
|
||||
using SplitEntry = SplitEntryContainer<GradStats<GradType>>;
|
||||
|
||||
} // namespace tree
|
||||
} // namespace sycl
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user