[sycl] add split applications and tests (#10636)

Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
Dmitry Razdoburdin
2024-07-26 09:25:49 +02:00
committed by GitHub
parent 384983ed27
commit 7720272870
4 changed files with 180 additions and 4 deletions

View File

@@ -9,6 +9,8 @@
#include <functional>
#include "../../src/tree/common_row_partitioner.h"
#include "../common/hist_util.h"
#include "../../src/collective/allreduce.h"
@@ -250,6 +252,55 @@ void HistUpdater<GradientSumT>::InitData(
builder_monitor_.Stop("InitData");
}
/*!
 * \brief Register the outcome of a partitioning pass in the row-set collection.
 *
 * For the entry at position `pos` of `nodes`, the left/right element counts are
 * read from `partition_builder_` at that same position, and the split of node
 * `nid` into its left and right children (taken from the tree) is added to
 * `row_set_collection_`.
 *
 * \param nodes  Expanded nodes; the position of an entry in this vector is the
 *               index used to query `partition_builder_`.
 * \param p_tree Tree from which the child ids of each node are read.
 */
template <typename GradientSumT>
void HistUpdater<GradientSumT>::AddSplitsToRowSet(
    const std::vector<ExpandEntry>& nodes,
    RegTree* p_tree) {
  for (size_t pos = 0; pos < nodes.size(); ++pos) {
    const int32_t nid = nodes[pos].nid;
    // Look the tree node up once and read both child ids from it.
    auto&& split_node = (*p_tree)[nid];
    const size_t left_count = partition_builder_.GetNLeftElems(pos);
    const size_t right_count = partition_builder_.GetNRightElems(pos);
    row_set_collection_.AddSplit(nid,
                                 split_node.LeftChild(),
                                 split_node.RightChild(),
                                 left_count,
                                 right_count);
  }
}
/*!
 * \brief Apply the splits chosen for `nodes`: partition the rows of every node
 *        and record the resulting child row sets.
 *
 * Steps, in order:
 *   1. compute one split condition per node,
 *   2. size the partition builder from the current row count of each node,
 *   3. run the partitioning and wait for the queue,
 *   4. merge each node's partition result back into that node's row-index
 *      storage and wait for the queue again,
 *   5. register the new splits via AddSplitsToRowSet.
 * The whole call is timed under the "ApplySplit" monitor label.
 *
 * \param nodes  Nodes to split (received by value, matching the declaration).
 * \param gmat   Quantised feature matrix used to derive split conditions and
 *               to partition rows.
 * \param p_tree Tree that holds the split information of the nodes.
 */
template <typename GradientSumT>
void HistUpdater<GradientSumT>::ApplySplit(
    const std::vector<ExpandEntry> nodes,
    const common::GHistIndexMatrix& gmat,
    RegTree* p_tree) {
  using CommonRowPartitioner = xgboost::tree::CommonRowPartitioner;

  builder_monitor_.Start("ApplySplit");

  const size_t num_split_nodes = nodes.size();

  // One split condition per node that is being expanded.
  std::vector<int32_t> conditions(num_split_nodes);
  CommonRowPartitioner::FindSplitConditions(nodes, *p_tree, gmat, &conditions);

  // The callback reports how many rows currently belong to each node.
  partition_builder_.Init(&qu_, num_split_nodes, [&](size_t idx) {
    const int32_t node_id = nodes[idx].nid;
    return row_set_collection_[node_id].Size();
  });

  ::sycl::event last_event;
  partition_builder_.Partition(gmat, nodes, row_set_collection_,
                               conditions, p_tree, &last_event);
  // Partitioning must be finished before its results are merged.
  qu_.wait_and_throw();

  for (size_t idx = 0; idx < num_split_nodes; ++idx) {
    const int32_t node_id = nodes[idx].nid;
    // The row set exposes its storage as read-only; the merge writes the
    // partitioned row indices back into that same storage.
    size_t* destination = const_cast<size_t*>(row_set_collection_[node_id].begin);
    partition_builder_.MergeToArray(idx, destination, &last_event);
  }
  // All merges must be complete before the row sets are split.
  qu_.wait_and_throw();

  AddSplitsToRowSet(nodes, p_tree);

  builder_monitor_.Stop("ApplySplit");
}
template <typename GradientSumT>
void HistUpdater<GradientSumT>::InitNewNode(int nid,
const common::GHistIndexMatrix& gmat,

View File

@@ -101,6 +101,12 @@ class HistUpdater {
typename TreeEvaluator<GradientSumT>::SplitEvaluator const &evaluator,
float min_child_weight);
/*!
 * \brief Partition the rows of every node in `nodes` between its children and
 *        register the resulting splits in the row-set collection.
 * \param nodes  Nodes to split.
 * \param gmat   Quantised feature matrix used for the split conditions and the
 *               partitioning.
 * \param p_tree Tree holding the split information of the nodes.
 *
 * NOTE(review): `nodes` is taken by value, so each call copies the vector.
 * Consider `const std::vector<ExpandEntry>&`; the declaration and the
 * definition would have to change together.
 */
void ApplySplit(std::vector<ExpandEntry> nodes,
const common::GHistIndexMatrix& gmat,
RegTree* p_tree);
/*!
 * \brief For each entry of `nodes`, add the split of node `nid` into its left
 *        and right children to the row-set collection, using the left/right
 *        element counts held by the partition builder.
 * \param nodes  Expanded nodes; an entry's position in the vector is the index
 *               used to query the partition builder (presumably the same
 *               vector that was just partitioned -- confirm against callers).
 * \param p_tree Tree from which the child ids are read.
 */
void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree* p_tree);
void InitData(const common::GHistIndexMatrix& gmat,
const USMVector<GradientPair, MemoryType::on_device> &gpair,
const DMatrix& fmat,
@@ -179,6 +185,8 @@ class HistUpdater {
uint64_t seed_ = 0;
// Helper that partitions the rows of split nodes and reports the number of
// elements that went to the left / right child (used by ApplySplit and
// AddSplitsToRowSet).
common::PartitionBuilder partition_builder_;
// key is the node id which should be calculated by Subtraction Trick, value is the node which
// provides the evidence for subtraction
// NOTE(review): the member below is a std::vector, not a map -- confirm that the
// key/value wording above is still accurate.
std::vector<ExpandEntry> nodes_for_subtraction_trick_;