[sycl] add loss guided hist building (#10251)
Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
committed by
GitHub
parent
9b465052ce
commit
f588252481
@@ -46,6 +46,93 @@ template<typename GradientSumT>
|
||||
const GHistRow<GradientSumT, MemoryType::on_device>& src2,
|
||||
size_t size, ::sycl::event event_priv);
|
||||
|
||||
/*!
|
||||
* \brief Histograms of gradient statistics for multiple nodes
|
||||
*/
|
||||
template<typename GradientSumT, MemoryType memory_type = MemoryType::shared>
|
||||
class HistCollection {
|
||||
public:
|
||||
using GHistRowT = GHistRow<GradientSumT, memory_type>;
|
||||
|
||||
// Access histogram for i-th node
|
||||
GHistRowT& operator[](bst_uint nid) {
|
||||
return *(data_.at(nid));
|
||||
}
|
||||
|
||||
const GHistRowT& operator[](bst_uint nid) const {
|
||||
return *(data_.at(nid));
|
||||
}
|
||||
|
||||
// Initialize histogram collection
|
||||
void Init(::sycl::queue qu, uint32_t nbins) {
|
||||
qu_ = qu;
|
||||
if (nbins_ != nbins) {
|
||||
nbins_ = nbins;
|
||||
data_.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Create an empty histogram for i-th node
|
||||
::sycl::event AddHistRow(bst_uint nid) {
|
||||
::sycl::event event;
|
||||
if (data_.count(nid) == 0) {
|
||||
data_[nid] =
|
||||
std::make_shared<GHistRowT>(&qu_, nbins_,
|
||||
xgboost::detail::GradientPairInternal<GradientSumT>(0, 0),
|
||||
&event);
|
||||
} else {
|
||||
data_[nid]->Resize(&qu_, nbins_,
|
||||
xgboost::detail::GradientPairInternal<GradientSumT>(0, 0),
|
||||
&event);
|
||||
}
|
||||
return event;
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief Number of all bins over all features */
|
||||
uint32_t nbins_ = 0;
|
||||
|
||||
std::unordered_map<uint32_t, std::shared_ptr<GHistRowT>> data_;
|
||||
|
||||
::sycl::queue qu_;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Stores temporary histograms to compute them in parallel
|
||||
*/
|
||||
template<typename GradientSumT>
|
||||
class ParallelGHistBuilder {
|
||||
public:
|
||||
using GHistRowT = GHistRow<GradientSumT, MemoryType::on_device>;
|
||||
|
||||
void Init(::sycl::queue qu, size_t nbins) {
|
||||
qu_ = qu;
|
||||
if (nbins != nbins_) {
|
||||
hist_buffer_.Init(qu_, nbins);
|
||||
nbins_ = nbins;
|
||||
}
|
||||
}
|
||||
|
||||
void Reset(size_t nblocks) {
|
||||
hist_device_buffer_.Resize(&qu_, nblocks * nbins_ * 2);
|
||||
}
|
||||
|
||||
GHistRowT& GetDeviceBuffer() {
|
||||
return hist_device_buffer_;
|
||||
}
|
||||
|
||||
protected:
|
||||
/*! \brief Number of bins in each histogram */
|
||||
size_t nbins_ = 0;
|
||||
/*! \brief Buffers for histograms for all nodes processed */
|
||||
HistCollection<GradientSumT> hist_buffer_;
|
||||
|
||||
/*! \brief Buffer for additional histograms for Parallel processing */
|
||||
GHistRowT hist_device_buffer_;
|
||||
|
||||
::sycl::queue qu_;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Builder for histograms of gradient statistics
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user