[sycl] add loss guided hist building (#10251)

Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
Dmitry Razdoburdin
2024-05-10 16:35:13 +02:00
committed by GitHub
parent 9b465052ce
commit f588252481
7 changed files with 459 additions and 30 deletions

View File

@@ -46,6 +46,93 @@ template<typename GradientSumT>
const GHistRow<GradientSumT, MemoryType::on_device>& src2,
size_t size, ::sycl::event event_priv);
/*!
* \brief Histograms of gradient statistics for multiple nodes
*/
template<typename GradientSumT, MemoryType memory_type = MemoryType::shared>
class HistCollection {
public:
using GHistRowT = GHistRow<GradientSumT, memory_type>;
// Access histogram for i-th node
GHistRowT& operator[](bst_uint nid) {
return *(data_.at(nid));
}
const GHistRowT& operator[](bst_uint nid) const {
return *(data_.at(nid));
}
// Initialize histogram collection
void Init(::sycl::queue qu, uint32_t nbins) {
qu_ = qu;
if (nbins_ != nbins) {
nbins_ = nbins;
data_.clear();
}
}
// Create an empty histogram for i-th node
::sycl::event AddHistRow(bst_uint nid) {
::sycl::event event;
if (data_.count(nid) == 0) {
data_[nid] =
std::make_shared<GHistRowT>(&qu_, nbins_,
xgboost::detail::GradientPairInternal<GradientSumT>(0, 0),
&event);
} else {
data_[nid]->Resize(&qu_, nbins_,
xgboost::detail::GradientPairInternal<GradientSumT>(0, 0),
&event);
}
return event;
}
private:
/*! \brief Number of all bins over all features */
uint32_t nbins_ = 0;
std::unordered_map<uint32_t, std::shared_ptr<GHistRowT>> data_;
::sycl::queue qu_;
};
/*!
* \brief Stores temporary histograms to compute them in parallel
*/
template<typename GradientSumT>
class ParallelGHistBuilder {
public:
using GHistRowT = GHistRow<GradientSumT, MemoryType::on_device>;
void Init(::sycl::queue qu, size_t nbins) {
qu_ = qu;
if (nbins != nbins_) {
hist_buffer_.Init(qu_, nbins);
nbins_ = nbins;
}
}
void Reset(size_t nblocks) {
hist_device_buffer_.Resize(&qu_, nblocks * nbins_ * 2);
}
GHistRowT& GetDeviceBuffer() {
return hist_device_buffer_;
}
protected:
/*! \brief Number of bins in each histogram */
size_t nbins_ = 0;
/*! \brief Buffers for histograms for all nodes processed */
HistCollection<GradientSumT> hist_buffer_;
/*! \brief Buffer for additional histograms for Parallel processing */
GHistRowT hist_device_buffer_;
::sycl::queue qu_;
};
/*!
* \brief Builder for histograms of gradient statistics
*/