Dispatcher for template parameters of BuildHist Kernels (#8259)
* Introducing Column Wise Hist Building * linting * more linting * bug fixing * Removing column sampling optimization for a while to simplify the review process. * linting * Removing unnecessary changes * Use DispatchBinType in hist_util.cc * Adding force_read_by_column flag to buildhist. Adding tests for column wise buildhist. * Introducing new dispatcher for compile time flags in hist building * fixing bug with usage of DispatchBinType * Fixing building * Merging with master branch Co-authored-by: dmitry.razdoburdin <drazdobu@jfldaal005.jf.intel.com> Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
8d4038da57
commit
c24e9d712c
@ -134,10 +134,72 @@ struct Prefetch {
|
||||
|
||||
constexpr size_t Prefetch::kNoPrefetchSize;
|
||||
|
||||
template <bool do_prefetch, typename BinIdxType, bool first_page, bool any_missing = true>
|
||||
// Run-time description of the histogram-building configuration.
// GHistBuildingManager::DispatchAndExecute compares each member against its
// compile-time counterpart (kFirstPage, kReadByColumn, sizeof(BinIdxType)).
// The struct is brace-initialised positionally by the caller, so the member
// order below is part of its interface.
struct RuntimeFlags {
|
||||
// Compared against kFirstPage; the caller passes `gmat.base_rowid == 0`.
const bool first_page;
|
||||
// Compared against kReadByColumn; selects the column-wise kernel over the
// row-wise one.
const bool read_by_column;
|
||||
// Compared against sizeof(BinIdxType); the caller passes
// `gmat.index.GetBinTypeSize()`.
const BinTypeSize bin_type_size;
|
||||
};
|
||||
|
||||
template <bool _any_missing,
|
||||
bool _first_page = false,
|
||||
bool _read_by_column = false,
|
||||
typename _BinIdxType = uint8_t>
|
||||
class GHistBuildingManager {
|
||||
public:
|
||||
constexpr static bool kAnyMissing = _any_missing;
|
||||
constexpr static bool kFirstPage = _first_page;
|
||||
constexpr static bool kReadByColumn = _read_by_column;
|
||||
using BinIdxType = _BinIdxType;
|
||||
|
||||
private:
|
||||
template<bool new_first_page>
|
||||
struct set_first_page {
|
||||
using type = GHistBuildingManager<kAnyMissing, new_first_page, kReadByColumn, BinIdxType>;
|
||||
};
|
||||
|
||||
template<bool new_read_by_column>
|
||||
struct set_read_by_column {
|
||||
using type = GHistBuildingManager<kAnyMissing, kFirstPage, new_read_by_column, BinIdxType>;
|
||||
};
|
||||
|
||||
template<typename new_bin_idx_type>
|
||||
struct set_bin_idx_type {
|
||||
using type = GHistBuildingManager<kAnyMissing, kFirstPage, kReadByColumn, new_bin_idx_type>;
|
||||
};
|
||||
|
||||
using type = GHistBuildingManager<kAnyMissing, kFirstPage, kReadByColumn, BinIdxType>;
|
||||
|
||||
public:
|
||||
/* Entry point to dispatcher
|
||||
* This function check matching run time flags to compile time flags.
|
||||
* In case of difference, it creates a Manager with different template parameters
|
||||
* and forward the call there.
|
||||
*/
|
||||
template <typename Fn>
|
||||
static void DispatchAndExecute(const RuntimeFlags& flags, Fn&& fn) {
|
||||
if (flags.first_page != kFirstPage) {
|
||||
set_first_page<true>::type::DispatchAndExecute(flags, std::forward<Fn>(fn));
|
||||
} else if (flags.read_by_column != kReadByColumn) {
|
||||
set_read_by_column<true>::type::DispatchAndExecute(flags, std::forward<Fn>(fn));
|
||||
} else if (flags.bin_type_size != sizeof(BinIdxType)) {
|
||||
DispatchBinType(flags.bin_type_size, [&](auto t) {
|
||||
using NewBinIdxType = decltype(t);
|
||||
set_bin_idx_type<NewBinIdxType>::type::DispatchAndExecute(flags, std::forward<Fn>(fn));
|
||||
});
|
||||
} else {
|
||||
fn(type());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <bool do_prefetch, class BuildingManager>
|
||||
void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist) {
|
||||
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
|
||||
constexpr bool kFirstPage = BuildingManager::kFirstPage;
|
||||
using BinIdxType = typename BuildingManager::BinIdxType;
|
||||
|
||||
const size_t size = row_indices.Size();
|
||||
const size_t *rid = row_indices.begin;
|
||||
auto const *pgh = reinterpret_cast<const float *>(gpair.data());
|
||||
@ -147,10 +209,10 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
auto base_rowid = gmat.base_rowid;
|
||||
const uint32_t *offsets = gmat.index.Offset();
|
||||
auto get_row_ptr = [&](size_t ridx) {
|
||||
return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
|
||||
return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
|
||||
};
|
||||
auto get_rid = [&](size_t ridx) {
|
||||
return first_page ? ridx : (ridx - base_rowid);
|
||||
return kFirstPage ? ridx : (ridx - base_rowid);
|
||||
};
|
||||
|
||||
const size_t n_features =
|
||||
@ -163,20 +225,20 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
const size_t icol_start =
|
||||
any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
|
||||
kAnyMissing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
|
||||
const size_t icol_end =
|
||||
any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
|
||||
kAnyMissing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
|
||||
|
||||
const size_t row_size = icol_end - icol_start;
|
||||
const size_t idx_gh = two * rid[i];
|
||||
|
||||
if (do_prefetch) {
|
||||
const size_t icol_start_prefetch =
|
||||
any_missing
|
||||
kAnyMissing
|
||||
? get_row_ptr(rid[i + Prefetch::kPrefetchOffset])
|
||||
: get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features;
|
||||
const size_t icol_end_prefetch =
|
||||
any_missing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
|
||||
kAnyMissing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
|
||||
: icol_start_prefetch + n_features;
|
||||
|
||||
PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]);
|
||||
@ -191,7 +253,7 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
const float pgh_t[] = {pgh[idx_gh], pgh[idx_gh + 1]};
|
||||
for (size_t j = 0; j < row_size; ++j) {
|
||||
const uint32_t idx_bin = two * (static_cast<uint32_t>(gr_index_local[j]) +
|
||||
(any_missing ? 0 : offsets[j]));
|
||||
(kAnyMissing ? 0 : offsets[j]));
|
||||
auto hist_local = hist_data + idx_bin;
|
||||
*(hist_local) += pgh_t[0];
|
||||
*(hist_local + 1) += pgh_t[1];
|
||||
@ -199,10 +261,13 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BinIdxType, bool first_page, bool any_missing>
|
||||
template <class BuildingManager>
|
||||
void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist) {
|
||||
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
|
||||
constexpr bool kFirstPage = BuildingManager::kFirstPage;
|
||||
using BinIdxType = typename BuildingManager::BinIdxType;
|
||||
const size_t size = row_indices.Size();
|
||||
const size_t *rid = row_indices.begin;
|
||||
auto const *pgh = reinterpret_cast<const float *>(gpair.data());
|
||||
@ -212,10 +277,10 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
auto base_rowid = gmat.base_rowid;
|
||||
const uint32_t *offsets = gmat.index.Offset();
|
||||
auto get_row_ptr = [&](size_t ridx) {
|
||||
return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
|
||||
return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
|
||||
};
|
||||
auto get_rid = [&](size_t ridx) {
|
||||
return first_page ? ridx : (ridx - base_rowid);
|
||||
return kFirstPage ? ridx : (ridx - base_rowid);
|
||||
};
|
||||
|
||||
const size_t n_features = gmat.cut.Ptrs().size() - 1;
|
||||
@ -226,13 +291,13 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
// So we need to multiply each row-index/bin-index by 2
|
||||
// to work with gradient pairs as a single row FP array
|
||||
for (size_t cid = 0; cid < n_columns; ++cid) {
|
||||
const uint32_t offset = any_missing ? 0 : offsets[cid];
|
||||
const uint32_t offset = kAnyMissing ? 0 : offsets[cid];
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
const size_t row_id = rid[i];
|
||||
const size_t icol_start =
|
||||
any_missing ? get_row_ptr(row_id) : get_rid(row_id) * n_features;
|
||||
kAnyMissing ? get_row_ptr(row_id) : get_rid(row_id) * n_features;
|
||||
const size_t icol_end =
|
||||
any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
|
||||
kAnyMissing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
|
||||
|
||||
if (cid < icol_end - icol_start) {
|
||||
const BinIdxType *gr_index_local = gradient_index + icol_start;
|
||||
@ -249,41 +314,13 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
}
|
||||
}
|
||||
|
||||
template <bool do_prefetch, typename BinIdxType, bool first_page,
|
||||
bool any_missing>
|
||||
void BuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist, bool read_by_column) {
|
||||
if (read_by_column) {
|
||||
ColsWiseBuildHistKernel<BinIdxType, first_page, any_missing>
|
||||
(gpair, row_indices, gmat, hist);
|
||||
} else {
|
||||
RowsWiseBuildHistKernel<do_prefetch, BinIdxType, first_page, any_missing>
|
||||
(gpair, row_indices, gmat, hist);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool do_prefetch, bool any_missing>
|
||||
template <class BuildingManager>
|
||||
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist, bool read_by_column) {
|
||||
auto first_page = gmat.base_rowid == 0;
|
||||
DispatchBinType(gmat.index.GetBinTypeSize(), [&](auto t) {
|
||||
using BinIdxType = decltype(t);
|
||||
if (first_page) {
|
||||
BuildHistKernel<do_prefetch, BinIdxType, true, any_missing>
|
||||
(gpair, row_indices, gmat, hist, read_by_column);
|
||||
GHistRow hist) {
|
||||
if (BuildingManager::kReadByColumn) {
|
||||
ColsWiseBuildHistKernel<BuildingManager>(gpair, row_indices, gmat, hist);
|
||||
} else {
|
||||
BuildHistKernel<do_prefetch, BinIdxType, false, any_missing>
|
||||
(gpair, row_indices, gmat, hist, read_by_column);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <bool any_missing>
|
||||
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
|
||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||
GHistRow hist, bool read_by_column) {
|
||||
const size_t nrows = row_indices.Size();
|
||||
const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);
|
||||
// if need to work with all rows from bin-matrix (e.g. root node)
|
||||
@ -292,16 +329,17 @@ void BuildHistDispatch(const std::vector<GradientPair> &gpair,
|
||||
|
||||
if (contiguousBlock) {
|
||||
// contiguous memory access, built-in HW prefetching is enough
|
||||
BuildHistDispatch<false, any_missing>(gpair, row_indices, gmat, hist, read_by_column);
|
||||
RowsWiseBuildHistKernel<false, BuildingManager>(gpair, row_indices, gmat, hist);
|
||||
} else {
|
||||
const RowSetCollection::Elem span1(row_indices.begin,
|
||||
row_indices.end - no_prefetch_size);
|
||||
const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size,
|
||||
row_indices.end);
|
||||
|
||||
BuildHistDispatch<true, any_missing>(gpair, span1, gmat, hist, read_by_column);
|
||||
RowsWiseBuildHistKernel<true, BuildingManager>(gpair, span1, gmat, hist);
|
||||
// no prefetching to avoid loading extra memory
|
||||
BuildHistDispatch<false, any_missing>(gpair, span2, gmat, hist, read_by_column);
|
||||
RowsWiseBuildHistKernel<false, BuildingManager>(gpair, span2, gmat, hist);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -315,10 +353,16 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair> &gpair,
|
||||
*/
|
||||
constexpr double kAdhocL2Size = 1024 * 1024 * 0.8;
|
||||
const bool hist_fit_to_l2 = kAdhocL2Size > 2*sizeof(float)*gmat.cut.Ptrs().back();
|
||||
const bool read_by_column = !hist_fit_to_l2 && !any_missing;
|
||||
bool first_page = gmat.base_rowid == 0;
|
||||
bool read_by_column = !hist_fit_to_l2 && !any_missing;
|
||||
auto bin_type_size = gmat.index.GetBinTypeSize();
|
||||
|
||||
BuildHistDispatch<any_missing>(gpair, row_indices, gmat, hist, read_by_column ||
|
||||
force_read_by_column);
|
||||
GHistBuildingManager<any_missing>::DispatchAndExecute(
|
||||
{first_page, read_by_column || force_read_by_column, bin_type_size},
|
||||
[&](auto t) {
|
||||
using BuildingManager = decltype(t);
|
||||
BuildHistDispatch<BuildingManager>(gpair, row_indices, gmat, hist);
|
||||
});
|
||||
}
|
||||
|
||||
template void GHistBuilder::BuildHist<true>(const std::vector<GradientPair> &gpair,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user