Small refactor for hist builder. (#8698)
- Use span instead of vector as the parameter type. No perf change, as the builder works on pointers. - Use a const pointer for the reg tree.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2020 by Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
* \file hist_util.cc
|
||||
*/
|
||||
#include <dmlc/timer.h>
|
||||
@@ -193,9 +193,9 @@ class GHistBuildingManager {
|
||||
};
|
||||
|
||||
template <bool do_prefetch, class BuildingManager>
|
||||
void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist) {
|
||||
void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist) {
|
||||
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
|
||||
constexpr bool kFirstPage = BuildingManager::kFirstPage;
|
||||
using BinIdxType = typename BuildingManager::BinIdxType;
|
||||
@@ -262,9 +262,9 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
}
|
||||
|
||||
template <class BuildingManager>
|
||||
void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist) {
|
||||
void ColsWiseBuildHistKernel(Span<GradientPair const> gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist) {
|
||||
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
|
||||
constexpr bool kFirstPage = BuildingManager::kFirstPage;
|
||||
using BinIdxType = typename BuildingManager::BinIdxType;
|
||||
@@ -315,9 +315,8 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
}
|
||||
|
||||
template <class BuildingManager>
|
||||
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist) {
|
||||
void BuildHistDispatch(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix &gmat, GHistRow hist) {
|
||||
if (BuildingManager::kReadByColumn) {
|
||||
ColsWiseBuildHistKernel<BuildingManager>(gpair, row_indices, gmat, hist);
|
||||
} else {
|
||||
@@ -344,33 +343,31 @@ void BuildHistDispatch(const std::vector<GradientPair> &gpair,
|
||||
}
|
||||
|
||||
template <bool any_missing>
|
||||
void GHistBuilder::BuildHist(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix &gmat,
|
||||
void GHistBuilder::BuildHist(Span<GradientPair const> gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist, bool force_read_by_column) const {
|
||||
/* force_read_by_column is used for testing the columnwise building of histograms.
|
||||
* default force_read_by_column = false
|
||||
*/
|
||||
constexpr double kAdhocL2Size = 1024 * 1024 * 0.8;
|
||||
const bool hist_fit_to_l2 = kAdhocL2Size > 2*sizeof(float)*gmat.cut.Ptrs().back();
|
||||
const bool hist_fit_to_l2 = kAdhocL2Size > 2 * sizeof(float) * gmat.cut.Ptrs().back();
|
||||
bool first_page = gmat.base_rowid == 0;
|
||||
bool read_by_column = !hist_fit_to_l2 && !any_missing;
|
||||
auto bin_type_size = gmat.index.GetBinTypeSize();
|
||||
|
||||
GHistBuildingManager<any_missing>::DispatchAndExecute(
|
||||
{first_page, read_by_column || force_read_by_column, bin_type_size},
|
||||
[&](auto t) {
|
||||
using BuildingManager = decltype(t);
|
||||
BuildHistDispatch<BuildingManager>(gpair, row_indices, gmat, hist);
|
||||
});
|
||||
{first_page, read_by_column || force_read_by_column, bin_type_size}, [&](auto t) {
|
||||
using BuildingManager = decltype(t);
|
||||
BuildHistDispatch<BuildingManager>(gpair, row_indices, gmat, hist);
|
||||
});
|
||||
}
|
||||
|
||||
template void GHistBuilder::BuildHist<true>(const std::vector<GradientPair> &gpair,
|
||||
template void GHistBuilder::BuildHist<true>(Span<GradientPair const> gpair,
|
||||
const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix &gmat, GHistRow hist,
|
||||
bool force_read_by_column) const;
|
||||
|
||||
template void GHistBuilder::BuildHist<false>(const std::vector<GradientPair> &gpair,
|
||||
template void GHistBuilder::BuildHist<false>(Span<GradientPair const> gpair,
|
||||
const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix &gmat, GHistRow hist,
|
||||
bool force_read_by_column) const;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
* \file hist_util.h
|
||||
* \brief Utility for fast histogram aggregation
|
||||
* \author Philip Cho, Tianqi Chen
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "row_set.h"
|
||||
#include "threading_utils.h"
|
||||
#include "timer.h"
|
||||
#include "xgboost/base.h" // bst_feature_t, bst_bin_t
|
||||
|
||||
namespace xgboost {
|
||||
class GHistIndexMatrix;
|
||||
@@ -320,10 +321,10 @@ struct Index {
|
||||
};
|
||||
|
||||
template <typename GradientIndex>
|
||||
bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
|
||||
bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(std::size_t begin, std::size_t end,
|
||||
GradientIndex const& data,
|
||||
uint32_t const fidx_begin,
|
||||
uint32_t const fidx_end) {
|
||||
bst_feature_t const fidx_begin,
|
||||
bst_feature_t const fidx_end) {
|
||||
size_t previous_middle = std::numeric_limits<size_t>::max();
|
||||
while (end != begin) {
|
||||
size_t middle = begin + (end - begin) / 2;
|
||||
@@ -635,7 +636,7 @@ class GHistBuilder {
|
||||
|
||||
// construct a histogram via histogram aggregation
|
||||
template <bool any_missing>
|
||||
void BuildHist(const std::vector<GradientPair>& gpair, const RowSetCollection::Elem row_indices,
|
||||
void BuildHist(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices,
|
||||
const GHistIndexMatrix& gmat, GHistRow hist,
|
||||
bool force_read_by_column = false) const;
|
||||
uint32_t GetNumBins() const {
|
||||
|
||||
Reference in New Issue
Block a user