[sycl] add split applications and tests (#10636)
Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
committed by
GitHub
parent
384983ed27
commit
7720272870
@@ -9,6 +9,8 @@
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "../../src/tree/common_row_partitioner.h"
|
||||
|
||||
#include "../common/hist_util.h"
|
||||
#include "../../src/collective/allreduce.h"
|
||||
|
||||
@@ -250,6 +252,55 @@ void HistUpdater<GradientSumT>::InitData(
|
||||
builder_monitor_.Stop("InitData");
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::AddSplitsToRowSet(
|
||||
const std::vector<ExpandEntry>& nodes,
|
||||
RegTree* p_tree) {
|
||||
const size_t n_nodes = nodes.size();
|
||||
for (size_t i = 0; i < n_nodes; ++i) {
|
||||
const int32_t nid = nodes[i].nid;
|
||||
const size_t n_left = partition_builder_.GetNLeftElems(i);
|
||||
const size_t n_right = partition_builder_.GetNRightElems(i);
|
||||
|
||||
row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(),
|
||||
(*p_tree)[nid].RightChild(), n_left, n_right);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::ApplySplit(
|
||||
const std::vector<ExpandEntry> nodes,
|
||||
const common::GHistIndexMatrix& gmat,
|
||||
RegTree* p_tree) {
|
||||
using CommonRowPartitioner = xgboost::tree::CommonRowPartitioner;
|
||||
builder_monitor_.Start("ApplySplit");
|
||||
|
||||
const size_t n_nodes = nodes.size();
|
||||
std::vector<int32_t> split_conditions(n_nodes);
|
||||
CommonRowPartitioner::FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
|
||||
|
||||
partition_builder_.Init(&qu_, n_nodes, [&](size_t node_in_set) {
|
||||
const int32_t nid = nodes[node_in_set].nid;
|
||||
return row_set_collection_[nid].Size();
|
||||
});
|
||||
|
||||
::sycl::event event;
|
||||
partition_builder_.Partition(gmat, nodes, row_set_collection_,
|
||||
split_conditions, p_tree, &event);
|
||||
qu_.wait_and_throw();
|
||||
|
||||
for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) {
|
||||
const int32_t nid = nodes[node_in_set].nid;
|
||||
size_t* data_result = const_cast<size_t*>(row_set_collection_[nid].begin);
|
||||
partition_builder_.MergeToArray(node_in_set, data_result, &event);
|
||||
}
|
||||
qu_.wait_and_throw();
|
||||
|
||||
AddSplitsToRowSet(nodes, p_tree);
|
||||
|
||||
builder_monitor_.Stop("ApplySplit");
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::InitNewNode(int nid,
|
||||
const common::GHistIndexMatrix& gmat,
|
||||
|
||||
@@ -101,6 +101,12 @@ class HistUpdater {
|
||||
typename TreeEvaluator<GradientSumT>::SplitEvaluator const &evaluator,
|
||||
float min_child_weight);
|
||||
|
||||
void ApplySplit(std::vector<ExpandEntry> nodes,
|
||||
const common::GHistIndexMatrix& gmat,
|
||||
RegTree* p_tree);
|
||||
|
||||
void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree* p_tree);
|
||||
|
||||
void InitData(const common::GHistIndexMatrix& gmat,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||
const DMatrix& fmat,
|
||||
@@ -179,6 +185,8 @@ class HistUpdater {
|
||||
|
||||
uint64_t seed_ = 0;
|
||||
|
||||
common::PartitionBuilder partition_builder_;
|
||||
|
||||
// key is the node id which should be calculated by Subtraction Trick, value is the node which
|
||||
// provides the evidence for substracts
|
||||
std::vector<ExpandEntry> nodes_for_subtraction_trick_;
|
||||
|
||||
Reference in New Issue
Block a user