Dmitry Razdoburdin c7e7ce7569
[SYCL] Add nodes initialisation (#10269)
---------

Co-authored-by: Dmitry Razdoburdin <>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
2024-05-21 23:38:52 +08:00

165 lines
5.0 KiB
C++

/*!
* Copyright 2014-2024 by Contributors
*/
#ifndef PLUGIN_SYCL_TREE_PARAM_H_
#define PLUGIN_SYCL_TREE_PARAM_H_
#include <cmath>
#include <cstring>
#include <limits>
#include <string>
#include <vector>
#include "xgboost/parameter.h"
#include "xgboost/data.h"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#include "../src/tree/param.h"
#pragma GCC diagnostic pop
#include <CL/sycl.hpp>
namespace xgboost {
namespace sycl {
namespace tree {
/*! \brief Wrapper for necessary training parameters for regression tree to access on device */
/* The original structure xgboost::tree::TrainParam can't be used,
* since std::vector is not copyable to SYCL devices.
*/
/*!
 * \brief Device-side subset of the training parameters.
 *
 * Holds only plain float fields, each copied from the field of the same name
 * in xgboost::tree::TrainParam by the converting constructor.
 * A default-constructed object has every field set to zero, so it never
 * exposes indeterminate values.
 */
struct TrainParam {
  /*! \brief copy of xgboost::tree::TrainParam::min_child_weight */
  float min_child_weight{0.0f};
  /*! \brief copy of xgboost::tree::TrainParam::reg_lambda */
  float reg_lambda{0.0f};
  /*! \brief copy of xgboost::tree::TrainParam::reg_alpha */
  float reg_alpha{0.0f};
  /*! \brief copy of xgboost::tree::TrainParam::max_delta_step */
  float max_delta_step{0.0f};

  /*! \brief Creates a parameter set with all fields zero-initialised. */
  TrainParam() = default;

  /*!
   * \brief Extracts the device-relevant fields from the host-side parameters.
   * \param param host-side training parameters to copy from
   */
  explicit TrainParam(const xgboost::tree::TrainParam& param)
      : min_child_weight(param.min_child_weight),
        reg_lambda(param.reg_lambda),
        reg_alpha(param.reg_alpha),
        max_delta_step(param.max_delta_step) {}
};
/*!
 * \brief Shorthand for xgboost::detail::GradientPairInternal instantiated with GradType.
 * \tparam GradType type forwarded as the template argument of GradientPairInternal
 */
template <typename GradType>
using GradStats = xgboost::detail::GradientPairInternal<GradType>;
/*!
* \brief SYCL implementation of SplitEntryContainer for device compilation.
* The original structure cannot be used due to its 'cat_bits' field of type
* std::vector<uint32_t>, which is not device-copyable.
*/
/*!
 * \brief Candidate split of a tree node together with the gradient statistics
 *        of its two children.
 * \tparam GradientT type used to store the left/right gradient sums; it must be
 *         default-constructible, copy-assignable and streamable via operator<<
 */
template<typename GradientT>
struct SplitEntryContainer {
  /*! \brief loss change after split this node */
  bst_float loss_chg {0.0f};
  /*!
   * \brief split index; the low 31 bits hold the feature index and
   *        bit 31 holds the "missing values go left" flag
   */
  bst_feature_t sindex{0};
  /*! \brief split point */
  bst_float split_value{0.0f};
  /*! \brief gradient statistics stored for the left child */
  GradientT left_sum;
  /*! \brief gradient statistics stored for the right child */
  GradientT right_sum;

  SplitEntryContainer() = default;

  /*! \brief writes a human-readable description of the entry to the stream */
  friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
    os << "loss_chg: " << s.loss_chg << ", "
       << "split index: " << s.SplitIndex() << ", "
       << "split value: " << s.split_value << ", "
       << "left_sum: " << s.left_sum << ", "
       << "right_sum: " << s.right_sum;
    return os;
  }

  /*!\return feature index to split on (sindex with bit 31 masked out) */
  bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }

  /*!\return whether missing value goes to left branch (bit 31 of sindex) */
  bool DefaultLeft() const { return (sindex >> 31) != 0; }

  /*!
   * \brief decides whether we can replace current entry with the given statistics
   *
   * This function gives better priority to lower index when loss_chg == new_loss_chg.
   * Not the best way, but helps to give consistent result during multi-thread
   * execution.
   *
   * Candidates whose loss change is NaN or +/-Inf are always rejected.
   *
   * \param new_loss_chg the loss reduction get through the split
   * \param split_index the feature index where the split is on
   * \return true if the candidate should replace the current entry
   */
  inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
    // In some cases new_loss_chg can be NaN or Inf, for example when
    // lambda = 0 & min_child_weight = 0; skip the value in this case.
    // NaN has to be tested explicitly: every ordered comparison with NaN is
    // false, so the negated comparison in the last branch below would
    // otherwise accept a NaN candidate.
    if (::sycl::isnan(new_loss_chg) || ::sycl::isinf(new_loss_chg)) {
      return false;
    }
    if (this->SplitIndex() <= split_index) {
      // Current entry has the lower (or same) feature index: it wins ties.
      return new_loss_chg > this->loss_chg;
    }
    // Candidate has the lower feature index: it wins ties.
    return !(this->loss_chg > new_loss_chg);
  }

  /*!
   * \brief update the split entry, replace it if e is better
   * \param e candidate split solution
   * \return whether the proposed split is better and can replace current split
   */
  inline bool Update(const SplitEntryContainer &e) {
    if (!this->NeedReplace(e.loss_chg, e.SplitIndex())) {
      return false;
    }
    this->loss_chg = e.loss_chg;
    // Copy the raw sindex so that the default-left bit is carried over as well.
    this->sindex = e.sindex;
    this->split_value = e.split_value;
    this->left_sum = e.left_sum;
    this->right_sum = e.right_sum;
    return true;
  }

  /*!
   * \brief update the split entry, replace it if e is better
   * \param new_loss_chg loss reduction of new candidate
   * \param split_index feature index to split on
   * \param new_split_value the split point
   * \param default_left whether the missing value goes to left
   * \param left_sum gradient statistics to store for the left child
   * \param right_sum gradient statistics to store for the right child
   * \return whether the proposed split is better and can replace current split
   */
  bool Update(bst_float new_loss_chg, unsigned split_index,
              bst_float new_split_value, bool default_left,
              const GradientT &left_sum,
              const GradientT &right_sum) {
    if (!this->NeedReplace(new_loss_chg, split_index)) {
      return false;
    }
    this->loss_chg = new_loss_chg;
    if (default_left) {
      // Encode the default direction in the top bit of the stored index.
      split_index |= (1U << 31);
    }
    this->sindex = split_index;
    this->split_value = new_split_value;
    this->left_sum = left_sum;
    this->right_sum = right_sum;
    return true;
  }

  /*! \brief same as update, used by AllReduce*/
  inline static void Reduce(SplitEntryContainer &dst,  // NOLINT(*)
                            const SplitEntryContainer &src) {  // NOLINT(*)
    dst.Update(src);
  }
};
/*!
 * \brief SplitEntryContainer whose left/right sums are stored as GradStats<GradType>.
 * \tparam GradType type forwarded to GradStats
 */
template<typename GradType>
using SplitEntry = SplitEntryContainer<GradStats<GradType>>;
} // namespace tree
} // namespace sycl
} // namespace xgboost
#endif // PLUGIN_SYCL_TREE_PARAM_H_