diff --git a/plugin/sycl/tree/hist_updater.cc b/plugin/sycl/tree/hist_updater.cc index 7a53d8d1f..30bcef5ba 100644 --- a/plugin/sycl/tree/hist_updater.cc +++ b/plugin/sycl/tree/hist_updater.cc @@ -9,6 +9,8 @@ #include +#include "../../src/tree/common_row_partitioner.h" + #include "../common/hist_util.h" #include "../../src/collective/allreduce.h" @@ -250,6 +252,55 @@ void HistUpdater::InitData( builder_monitor_.Stop("InitData"); } +template +void HistUpdater::AddSplitsToRowSet( + const std::vector& nodes, + RegTree* p_tree) { + const size_t n_nodes = nodes.size(); + for (size_t i = 0; i < n_nodes; ++i) { + const int32_t nid = nodes[i].nid; + const size_t n_left = partition_builder_.GetNLeftElems(i); + const size_t n_right = partition_builder_.GetNRightElems(i); + + row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), + (*p_tree)[nid].RightChild(), n_left, n_right); + } +} + +template +void HistUpdater::ApplySplit( + const std::vector nodes, + const common::GHistIndexMatrix& gmat, + RegTree* p_tree) { + using CommonRowPartitioner = xgboost::tree::CommonRowPartitioner; + builder_monitor_.Start("ApplySplit"); + + const size_t n_nodes = nodes.size(); + std::vector split_conditions(n_nodes); + CommonRowPartitioner::FindSplitConditions(nodes, *p_tree, gmat, &split_conditions); + + partition_builder_.Init(&qu_, n_nodes, [&](size_t node_in_set) { + const int32_t nid = nodes[node_in_set].nid; + return row_set_collection_[nid].Size(); + }); + + ::sycl::event event; + partition_builder_.Partition(gmat, nodes, row_set_collection_, + split_conditions, p_tree, &event); + qu_.wait_and_throw(); + + for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) { + const int32_t nid = nodes[node_in_set].nid; + size_t* data_result = const_cast(row_set_collection_[nid].begin); + partition_builder_.MergeToArray(node_in_set, data_result, &event); + } + qu_.wait_and_throw(); + + AddSplitsToRowSet(nodes, p_tree); + + builder_monitor_.Stop("ApplySplit"); +} + template void 
HistUpdater::InitNewNode(int nid, const common::GHistIndexMatrix& gmat, diff --git a/plugin/sycl/tree/hist_updater.h b/plugin/sycl/tree/hist_updater.h index 4515b24a1..e4df09777 100644 --- a/plugin/sycl/tree/hist_updater.h +++ b/plugin/sycl/tree/hist_updater.h @@ -101,6 +101,12 @@ class HistUpdater { typename TreeEvaluator::SplitEvaluator const &evaluator, float min_child_weight); + void ApplySplit(std::vector nodes, + const common::GHistIndexMatrix& gmat, + RegTree* p_tree); + + void AddSplitsToRowSet(const std::vector& nodes, RegTree* p_tree); + void InitData(const common::GHistIndexMatrix& gmat, const USMVector &gpair, const DMatrix& fmat, @@ -179,6 +185,8 @@ class HistUpdater { uint64_t seed_ = 0; + common::PartitionBuilder partition_builder_; + // key is the node id which should be calculated by Subtraction Trick, value is the node which // provides the evidence for substracts std::vector nodes_for_subtraction_trick_; diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h index ff75000df..159be768c 100644 --- a/src/tree/common_row_partitioner.h +++ b/src/tree/common_row_partitioner.h @@ -144,9 +144,10 @@ class CommonRowPartitioner { } } - template - void FindSplitConditions(const std::vector& nodes, const RegTree& tree, - const GHistIndexMatrix& gmat, std::vector* split_conditions) { + /* Making GHistIndexMatrix_t a template parameter allows reusing this function for sycl-plugin */ + template + static void FindSplitConditions(const std::vector& nodes, const RegTree& tree, + const GHistIndexMatrix_t& gmat, std::vector* split_conditions) { auto const& ptrs = gmat.cut.Ptrs(); auto const& vals = gmat.cut.Values(); diff --git a/tests/cpp/plugin/test_sycl_hist_updater.cc b/tests/cpp/plugin/test_sycl_hist_updater.cc index 325769fe8..19a739308 100644 --- a/tests/cpp/plugin/test_sycl_hist_updater.cc +++ b/tests/cpp/plugin/test_sycl_hist_updater.cc @@ -8,6 +8,8 @@ #include "../../../plugin/sycl/tree/hist_updater.h" #include 
"../../../plugin/sycl/device_manager.h" +#include "../../../src/tree/common_row_partitioner.h" + #include "../helpers.h" namespace xgboost::sycl::tree { @@ -61,6 +63,12 @@ class TestHistUpdater : public HistUpdater { HistUpdater::EvaluateSplits(nodes_set, gmat, tree); return HistUpdater::snode_host_; } + + void TestApplySplit(const std::vector nodes, + const common::GHistIndexMatrix& gmat, + RegTree* p_tree) { + HistUpdater::ApplySplit(nodes, gmat, p_tree); + } }; void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) { @@ -131,7 +139,6 @@ void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) { ASSERT_NE(num_diffs, 0); } - } template @@ -392,6 +399,95 @@ void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) { ASSERT_NEAR(best_loss_chg_des[0], best_loss_chg, 1e-6); } +template +void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float sparsity, int max_bins) { + const size_t num_rows = 1024; + const size_t num_columns = 2; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + + ObjInfo task{ObjInfo::kRegression}; + + auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix(); + sycl::DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); + + common::GHistIndexMatrix gmat; + gmat.Init(qu, &ctx, dmat, max_bins); + + RegTree tree; + tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); + + std::vector nodes; + nodes.emplace_back(tree::ExpandEntry(0, tree.GetDepth(0))); + + FeatureInteractionConstraintHost int_constraints; + std::unique_ptr pruner{TreeUpdater::Create("prune", &ctx, &task)}; + TestHistUpdater updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get()); + USMVector gpair(&qu, num_rows); + GenerateRandomGPairs(&qu, gpair.Data(), num_rows, false); + + auto* row_set_collection = updater.TestInitData(gmat, gpair, 
*p_fmat, tree); + updater.TestApplySplit(nodes, gmat, &tree); + + // Copy indexes to host + std::vector row_indices_host(num_rows); + qu.memcpy(row_indices_host.data(), row_set_collection->Data().Data(), sizeof(size_t)*num_rows).wait(); + + // Reference Implementation + std::vector row_indices_desired_host(num_rows); + size_t n_left, n_right; + { + std::unique_ptr pruner4verification{TreeUpdater::Create("prune", &ctx, &task)}; + TestHistUpdater updater4verification(&ctx, qu, param, std::move(pruner4verification), int_constraints, p_fmat.get()); + auto* row_set_collection4verification = updater4verification.TestInitData(gmat, gpair, *p_fmat, tree); + + size_t n_nodes = nodes.size(); + std::vector split_conditions(n_nodes); + xgboost::tree::CommonRowPartitioner::FindSplitConditions(nodes, tree, gmat, &split_conditions); + + common::PartitionBuilder partition_builder; + partition_builder.Init(&qu, n_nodes, [&](size_t node_in_set) { + const int32_t nid = nodes[node_in_set].nid; + return (*row_set_collection4verification)[nid].Size(); + }); + + ::sycl::event event; + partition_builder.Partition(gmat, nodes, (*row_set_collection4verification), + split_conditions, &tree, &event); + qu.wait_and_throw(); + + for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) { + const int32_t nid = nodes[node_in_set].nid; + size_t* data_result = const_cast((*row_set_collection4verification)[nid].begin); + partition_builder.MergeToArray(node_in_set, data_result, &event); + } + qu.wait_and_throw(); + + const int32_t nid = nodes[0].nid; + n_left = partition_builder.GetNLeftElems(0); + n_right = partition_builder.GetNRightElems(0); + + row_set_collection4verification->AddSplit(nid, tree[nid].LeftChild(), + tree[nid].RightChild(), n_left, n_right); + + qu.memcpy(row_indices_desired_host.data(), row_set_collection4verification->Data().Data(), sizeof(size_t)*num_rows).wait(); + } + + std::sort(row_indices_desired_host.begin(), row_indices_desired_host.begin() + n_left); + 
std::sort(row_indices_host.begin(), row_indices_host.begin() + n_left); + std::sort(row_indices_desired_host.begin() + n_left, row_indices_desired_host.end()); + std::sort(row_indices_host.begin() + n_left, row_indices_host.end()); + + for (size_t row = 0; row < num_rows; ++row) { + ASSERT_EQ(row_indices_desired_host[row], row_indices_host[row]); + } +} + TEST(SyclHistUpdater, Sampling) { xgboost::tree::TrainParam param; param.UpdateAllowUnknown(Args{{"subsample", "0.7"}}); @@ -439,4 +535,24 @@ TEST(SyclHistUpdater, EvaluateSplits) { TestHistUpdaterEvaluateSplits(param); } +TEST(SyclHistUpdater, ApplySplitSparce) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"max_depth", "3"}}); + + TestHistUpdaterApplySplit(param, 0.3, 256); + TestHistUpdaterApplySplit(param, 0.3, 256); +} + +TEST(SyclHistUpdater, ApplySplitDence) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"max_depth", "3"}}); + + TestHistUpdaterApplySplit(param, 0.0, 256); + TestHistUpdaterApplySplit(param, 0.0, 256+1); + TestHistUpdaterApplySplit(param, 0.0, (1u << 16) + 1); + TestHistUpdaterApplySplit(param, 0.0, 256); + TestHistUpdaterApplySplit(param, 0.0, 256+1); + TestHistUpdaterApplySplit(param, 0.0, (1u << 16) + 1); +} + } // namespace xgboost::sycl::tree