xgboost/plugin/sycl/common/hist_util.h
Dmitry Razdoburdin 6a7c6a8ae6
add sycl reaslisation of ghist builder (#10138)
Co-authored-by: Dmitry Razdoburdin <>
2024-03-23 12:55:25 +08:00

90 lines
2.7 KiB
C++

/*!
* Copyright 2017-2023 by Contributors
* \file hist_util.h
*/
#ifndef PLUGIN_SYCL_COMMON_HIST_UTIL_H_
#define PLUGIN_SYCL_COMMON_HIST_UTIL_H_
#include <vector>
#include <unordered_map>
#include <memory>
#include "../data.h"
#include "row_set.h"
#include "../../src/common/hist_util.h"
#include "../data/gradient_index.h"
#include <CL/sycl.hpp>
namespace xgboost {
namespace sycl {
namespace common {
template<typename GradientSumT, MemoryType memory_type = MemoryType::shared>
using GHistRow = USMVector<xgboost::detail::GradientPairInternal<GradientSumT>, memory_type>;
using BinTypeSize = ::xgboost::common::BinTypeSize;
class ColumnMatrix;
/*!
* \brief Fill histogram with zeroes
*/
template<typename GradientSumT>
void InitHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* hist,
size_t size, ::sycl::event* event);
/*!
* \brief Compute subtraction: dst = src1 - src2
*/
template<typename GradientSumT>
::sycl::event SubtractionHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* dst,
const GHistRow<GradientSumT, MemoryType::on_device>& src1,
const GHistRow<GradientSumT, MemoryType::on_device>& src2,
size_t size, ::sycl::event event_priv);
/*!
* \brief Builder for histograms of gradient statistics
*/
template<typename GradientSumT>
class GHistBuilder {
public:
template<MemoryType memory_type = MemoryType::shared>
using GHistRowT = GHistRow<GradientSumT, memory_type>;
GHistBuilder() = default;
GHistBuilder(::sycl::queue qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {}
// Construct a histogram via histogram aggregation
::sycl::event BuildHist(const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRowT<MemoryType::on_device>* HistCollection,
bool isDense,
GHistRowT<MemoryType::on_device>* hist_buffer,
::sycl::event event,
bool force_atomic_use = false);
// Construct a histogram via subtraction trick
void SubtractionTrick(GHistRowT<MemoryType::on_device>* self,
const GHistRowT<MemoryType::on_device>& sibling,
const GHistRowT<MemoryType::on_device>& parent);
uint32_t GetNumBins() const {
return nbins_;
}
private:
/*! \brief Number of all bins over all features */
uint32_t nbins_ { 0 };
::sycl::queue qu_;
};
} // namespace common
} // namespace sycl
} // namespace xgboost
#endif // PLUGIN_SYCL_COMMON_HIST_UTIL_H_