/*!
 * Copyright 2017-2024 by Contributors
 * \file hist_updater.h
 */
#ifndef PLUGIN_SYCL_TREE_HIST_UPDATER_H_
#define PLUGIN_SYCL_TREE_HIST_UPDATER_H_

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <xgboost/tree_updater.h>
#pragma GCC diagnostic pop

#include <utility>
#include <vector>
#include <memory>
#include <queue>

#include "../common/partition_builder.h"
#include "split_evaluator.h"
#include "hist_synchronizer.h"
#include "hist_row_adder.h"

#include "../../src/common/random.h"
#include "../data.h"

namespace xgboost {
namespace sycl {
namespace tree {

// data structure
template <typename GradType>
struct NodeEntry {
  /*! \brief statistics for node entry */
  GradStats<GradType> stats;
  /*! \brief loss of this node, without split */
  GradType root_gain;
  /*! \brief weight calculated related to current data */
  GradType weight;
  /*! \brief current best solution */
  SplitEntry<GradType> best;
  // constructor
  explicit NodeEntry(const xgboost::tree::TrainParam& param)
      : root_gain(0.0f), weight(0.0f) {}
};

template <typename GradientSumT>
class HistUpdater {
 public:
  template <MemoryType memory_type = MemoryType::shared>
  using GHistRowT = common::GHistRow<GradientSumT, memory_type>;
  using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;

  explicit HistUpdater(const Context* ctx,
                       ::sycl::queue qu,
                       const xgboost::tree::TrainParam& param,
                       std::unique_ptr<TreeUpdater> pruner,
                       FeatureInteractionConstraintHost int_constraints_,
                       DMatrix const* fmat)
      : ctx_(ctx), qu_(qu), param_(param),
        tree_evaluator_(qu, param, fmat->Info().num_col_),
        pruner_(std::move(pruner)),
        interaction_constraints_{std::move(int_constraints_)},
        p_last_tree_(nullptr), p_last_fmat_(fmat) {
    builder_monitor_.Init("SYCL::Quantile::HistUpdater");
    kernel_monitor_.Init("SYCL::Quantile::HistUpdater");
    if (param.max_depth > 0) {
      snode_device_.Resize(&qu, 1u << (param.max_depth + 1));
    }
    const auto sub_group_sizes =
        qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
    sub_group_size_ = sub_group_sizes.back();
  }

  void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
  void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);

 protected:
  friend class BatchHistSynchronizer<GradientSumT>;
  friend class BatchHistRowsAdder<GradientSumT>;

  struct SplitQuery {
    bst_node_t nid;
    size_t fid;
    const GradientPairT* hist;
  };

  void InitSampling(const USMVector<GradientPair, MemoryType::on_device>& gpair,
                    USMVector<size_t, MemoryType::on_device>* row_indices);

  void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
                      const common::GHistIndexMatrix& gmat,
                      const RegTree& tree);

  // Enumerate the split values of a specific feature.
  // Accumulates the sum of gradients of the data points that contain a non-missing
  // value for the particular feature fid and writes the best split found to *p_best.
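  // The enumeration is carried out cooperatively by the work-items of the sub-group sg,
  // reading bin boundaries from cut_ptr/cut_val and per-bin gradient sums from hist_data;
  // min_child_weight is the minimum Hessian weight each child must retain for a
  // candidate split to be considered.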
  static void EnumerateSplit(const ::sycl::sub_group& sg,
                             const uint32_t* cut_ptr,
                             const bst_float* cut_val,
                             const GradientPairT* hist_data,
                             const NodeEntry<GradientSumT>& snode,
                             SplitEntry<GradientSumT>* p_best,
                             bst_uint fid,
                             bst_uint nodeID,
                             typename TreeEvaluator<GradientSumT>::SplitEvaluator const& evaluator,
                             float min_child_weight);

  void ApplySplit(std::vector<ExpandEntry> nodes,
                  const common::GHistIndexMatrix& gmat,
                  RegTree* p_tree);

  void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree* p_tree);

  void InitData(const common::GHistIndexMatrix& gmat,
                const USMVector<GradientPair, MemoryType::on_device>& gpair,
                const DMatrix& fmat,
                const RegTree& tree);

  inline ::sycl::event BuildHist(
      const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
      const common::RowSetCollection::Elem row_indices,
      const common::GHistIndexMatrix& gmat,
      GHistRowT<MemoryType::on_device>* hist,
      GHistRowT<MemoryType::on_device>* hist_buffer,
      ::sycl::event event_priv) {
    return hist_builder_.BuildHist(gpair_device, row_indices, gmat, hist,
                                   data_layout_ != kSparseData, hist_buffer, event_priv);
  }

  void InitNewNode(int nid,
                   const common::GHistIndexMatrix& gmat,
                   const USMVector<GradientPair, MemoryType::on_device>& gpair,
                   const RegTree& tree);

  void BuildLocalHistograms(const common::GHistIndexMatrix& gmat,
                            RegTree* p_tree,
                            const USMVector<GradientPair, MemoryType::on_device>& gpair);

  void BuildHistogramsLossGuide(ExpandEntry entry,
                                const common::GHistIndexMatrix& gmat,
                                RegTree* p_tree,
                                const USMVector<GradientPair, MemoryType::on_device>& gpair);

  void ExpandWithLossGuide(const common::GHistIndexMatrix& gmat,
                           RegTree* p_tree,
                           const USMVector<GradientPair, MemoryType::on_device>& gpair);

  inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
    if (lhs.GetLossChange() == rhs.GetLossChange()) {
      return lhs.GetNodeId() > rhs.GetNodeId();  // favor small timestamp
    } else {
      return lhs.GetLossChange() < rhs.GetLossChange();  // favor large loss_chg
    }
  }

  // --data fields--
  const Context* ctx_;
  size_t sub_group_size_;

  // the internal row sets
  common::RowSetCollection row_set_collection_;

  const xgboost::tree::TrainParam& param_;
  std::shared_ptr<xgboost::common::ColumnSampler> column_sampler_;

  std::vector<SplitQuery> split_queries_host_;
  USMVector<SplitQuery, MemoryType::on_device> split_queries_device_;

  USMVector<SplitEntry<GradientSumT>, MemoryType::on_device> best_splits_device_;
  std::vector<SplitEntry<GradientSumT>> best_splits_host_;

  TreeEvaluator<GradientSumT> tree_evaluator_;
  std::unique_ptr<TreeUpdater> pruner_;
  FeatureInteractionConstraintHost interaction_constraints_;

  // back pointers to tree and data matrix
  const RegTree* p_last_tree_;
  DMatrix const* const p_last_fmat_;

  using ExpandQueue =
      std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
                          std::function<bool(ExpandEntry, ExpandEntry)>>;
  std::unique_ptr<ExpandQueue> qexpand_loss_guided_;

  enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
  DataLayout data_layout_;

  constexpr static size_t kBufferSize = 2048;

  common::GHistBuilder<GradientSumT> hist_builder_;
  common::ParallelGHistBuilder<GradientSumT> hist_buffer_;
  /*! \brief cumulative histogram of gradients */
  common::HistCollection<GradientSumT, MemoryType::on_device> hist_;

  /*! \brief TreeNode Data: statistics for each constructed node */
  std::vector<NodeEntry<GradientSumT>> snode_host_;
  USMVector<NodeEntry<GradientSumT>, MemoryType::on_device> snode_device_;

  xgboost::common::Monitor builder_monitor_;
  xgboost::common::Monitor kernel_monitor_;

  /*! \brief feature with the least number of bins, used for the dense
             specialization of InitNewNode() */
  uint32_t fid_least_bins_;

  uint64_t seed_ = 0;

  common::PartitionBuilder partition_builder_;

  // key is the node id which should be calculated by the Subtraction Trick,
  // value is the node which provides the evidence for the subtraction
  std::vector<ExpandEntry> nodes_for_subtraction_trick_;
  // list of nodes whose histograms would be built explicitly
  std::vector<ExpandEntry> nodes_for_explicit_hist_build_;

  std::unique_ptr<HistSynchronizer<GradientSumT>> hist_synchronizer_;
  std::unique_ptr<HistRowsAdder<GradientSumT>> hist_rows_adder_;

  ::sycl::queue qu_;
};

}  // namespace tree
}  // namespace sycl
}  // namespace xgboost

#endif  // PLUGIN_SYCL_TREE_HIST_UPDATER_H_