Dispatcher for template parameters of BuildHist Kernels (#8259)
* Intoducing Column Wise Hist Building * linting * more linting * bug fixing * Removing column samping optimization for a while to simplify the review process. * linting * Removing unnecessary changes * Use DispatchBinType in hist_util.cc * Adding force_read_by column flag to buildhist. Adding tests for column wise buiilhist. * Introducing new dispatcher for compile time flags in hist building * fixing bug with using of DispatchBinType * Fixing building * Merging with master branch Co-authored-by: dmitry.razdoburdin <drazdobu@jfldaal005.jf.intel.com> Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
8d4038da57
commit
c24e9d712c
@ -134,10 +134,72 @@ struct Prefetch {
|
|||||||
|
|
||||||
constexpr size_t Prefetch::kNoPrefetchSize;
|
constexpr size_t Prefetch::kNoPrefetchSize;
|
||||||
|
|
||||||
template <bool do_prefetch, typename BinIdxType, bool first_page, bool any_missing = true>
|
struct RuntimeFlags {
|
||||||
|
const bool first_page;
|
||||||
|
const bool read_by_column;
|
||||||
|
const BinTypeSize bin_type_size;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <bool _any_missing,
|
||||||
|
bool _first_page = false,
|
||||||
|
bool _read_by_column = false,
|
||||||
|
typename _BinIdxType = uint8_t>
|
||||||
|
class GHistBuildingManager {
|
||||||
|
public:
|
||||||
|
constexpr static bool kAnyMissing = _any_missing;
|
||||||
|
constexpr static bool kFirstPage = _first_page;
|
||||||
|
constexpr static bool kReadByColumn = _read_by_column;
|
||||||
|
using BinIdxType = _BinIdxType;
|
||||||
|
|
||||||
|
private:
|
||||||
|
template<bool new_first_page>
|
||||||
|
struct set_first_page {
|
||||||
|
using type = GHistBuildingManager<kAnyMissing, new_first_page, kReadByColumn, BinIdxType>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<bool new_read_by_column>
|
||||||
|
struct set_read_by_column {
|
||||||
|
using type = GHistBuildingManager<kAnyMissing, kFirstPage, new_read_by_column, BinIdxType>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename new_bin_idx_type>
|
||||||
|
struct set_bin_idx_type {
|
||||||
|
using type = GHistBuildingManager<kAnyMissing, kFirstPage, kReadByColumn, new_bin_idx_type>;
|
||||||
|
};
|
||||||
|
|
||||||
|
using type = GHistBuildingManager<kAnyMissing, kFirstPage, kReadByColumn, BinIdxType>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/* Entry point to dispatcher
|
||||||
|
* This function check matching run time flags to compile time flags.
|
||||||
|
* In case of difference, it creates a Manager with different template parameters
|
||||||
|
* and forward the call there.
|
||||||
|
*/
|
||||||
|
template <typename Fn>
|
||||||
|
static void DispatchAndExecute(const RuntimeFlags& flags, Fn&& fn) {
|
||||||
|
if (flags.first_page != kFirstPage) {
|
||||||
|
set_first_page<true>::type::DispatchAndExecute(flags, std::forward<Fn>(fn));
|
||||||
|
} else if (flags.read_by_column != kReadByColumn) {
|
||||||
|
set_read_by_column<true>::type::DispatchAndExecute(flags, std::forward<Fn>(fn));
|
||||||
|
} else if (flags.bin_type_size != sizeof(BinIdxType)) {
|
||||||
|
DispatchBinType(flags.bin_type_size, [&](auto t) {
|
||||||
|
using NewBinIdxType = decltype(t);
|
||||||
|
set_bin_idx_type<NewBinIdxType>::type::DispatchAndExecute(flags, std::forward<Fn>(fn));
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
fn(type());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <bool do_prefetch, class BuildingManager>
|
||||||
void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||||
GHistRow hist) {
|
GHistRow hist) {
|
||||||
|
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
|
||||||
|
constexpr bool kFirstPage = BuildingManager::kFirstPage;
|
||||||
|
using BinIdxType = typename BuildingManager::BinIdxType;
|
||||||
|
|
||||||
const size_t size = row_indices.Size();
|
const size_t size = row_indices.Size();
|
||||||
const size_t *rid = row_indices.begin;
|
const size_t *rid = row_indices.begin;
|
||||||
auto const *pgh = reinterpret_cast<const float *>(gpair.data());
|
auto const *pgh = reinterpret_cast<const float *>(gpair.data());
|
||||||
@ -147,10 +209,10 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
|||||||
auto base_rowid = gmat.base_rowid;
|
auto base_rowid = gmat.base_rowid;
|
||||||
const uint32_t *offsets = gmat.index.Offset();
|
const uint32_t *offsets = gmat.index.Offset();
|
||||||
auto get_row_ptr = [&](size_t ridx) {
|
auto get_row_ptr = [&](size_t ridx) {
|
||||||
return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
|
return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
|
||||||
};
|
};
|
||||||
auto get_rid = [&](size_t ridx) {
|
auto get_rid = [&](size_t ridx) {
|
||||||
return first_page ? ridx : (ridx - base_rowid);
|
return kFirstPage ? ridx : (ridx - base_rowid);
|
||||||
};
|
};
|
||||||
|
|
||||||
const size_t n_features =
|
const size_t n_features =
|
||||||
@ -163,20 +225,20 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
|||||||
|
|
||||||
for (size_t i = 0; i < size; ++i) {
|
for (size_t i = 0; i < size; ++i) {
|
||||||
const size_t icol_start =
|
const size_t icol_start =
|
||||||
any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
|
kAnyMissing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
|
||||||
const size_t icol_end =
|
const size_t icol_end =
|
||||||
any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
|
kAnyMissing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
|
||||||
|
|
||||||
const size_t row_size = icol_end - icol_start;
|
const size_t row_size = icol_end - icol_start;
|
||||||
const size_t idx_gh = two * rid[i];
|
const size_t idx_gh = two * rid[i];
|
||||||
|
|
||||||
if (do_prefetch) {
|
if (do_prefetch) {
|
||||||
const size_t icol_start_prefetch =
|
const size_t icol_start_prefetch =
|
||||||
any_missing
|
kAnyMissing
|
||||||
? get_row_ptr(rid[i + Prefetch::kPrefetchOffset])
|
? get_row_ptr(rid[i + Prefetch::kPrefetchOffset])
|
||||||
: get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features;
|
: get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features;
|
||||||
const size_t icol_end_prefetch =
|
const size_t icol_end_prefetch =
|
||||||
any_missing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
|
kAnyMissing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
|
||||||
: icol_start_prefetch + n_features;
|
: icol_start_prefetch + n_features;
|
||||||
|
|
||||||
PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]);
|
PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]);
|
||||||
@ -191,7 +253,7 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
|||||||
const float pgh_t[] = {pgh[idx_gh], pgh[idx_gh + 1]};
|
const float pgh_t[] = {pgh[idx_gh], pgh[idx_gh + 1]};
|
||||||
for (size_t j = 0; j < row_size; ++j) {
|
for (size_t j = 0; j < row_size; ++j) {
|
||||||
const uint32_t idx_bin = two * (static_cast<uint32_t>(gr_index_local[j]) +
|
const uint32_t idx_bin = two * (static_cast<uint32_t>(gr_index_local[j]) +
|
||||||
(any_missing ? 0 : offsets[j]));
|
(kAnyMissing ? 0 : offsets[j]));
|
||||||
auto hist_local = hist_data + idx_bin;
|
auto hist_local = hist_data + idx_bin;
|
||||||
*(hist_local) += pgh_t[0];
|
*(hist_local) += pgh_t[0];
|
||||||
*(hist_local + 1) += pgh_t[1];
|
*(hist_local + 1) += pgh_t[1];
|
||||||
@ -199,10 +261,13 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename BinIdxType, bool first_page, bool any_missing>
|
template <class BuildingManager>
|
||||||
void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
||||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||||
GHistRow hist) {
|
GHistRow hist) {
|
||||||
|
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
|
||||||
|
constexpr bool kFirstPage = BuildingManager::kFirstPage;
|
||||||
|
using BinIdxType = typename BuildingManager::BinIdxType;
|
||||||
const size_t size = row_indices.Size();
|
const size_t size = row_indices.Size();
|
||||||
const size_t *rid = row_indices.begin;
|
const size_t *rid = row_indices.begin;
|
||||||
auto const *pgh = reinterpret_cast<const float *>(gpair.data());
|
auto const *pgh = reinterpret_cast<const float *>(gpair.data());
|
||||||
@ -212,10 +277,10 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
|||||||
auto base_rowid = gmat.base_rowid;
|
auto base_rowid = gmat.base_rowid;
|
||||||
const uint32_t *offsets = gmat.index.Offset();
|
const uint32_t *offsets = gmat.index.Offset();
|
||||||
auto get_row_ptr = [&](size_t ridx) {
|
auto get_row_ptr = [&](size_t ridx) {
|
||||||
return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
|
return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
|
||||||
};
|
};
|
||||||
auto get_rid = [&](size_t ridx) {
|
auto get_rid = [&](size_t ridx) {
|
||||||
return first_page ? ridx : (ridx - base_rowid);
|
return kFirstPage ? ridx : (ridx - base_rowid);
|
||||||
};
|
};
|
||||||
|
|
||||||
const size_t n_features = gmat.cut.Ptrs().size() - 1;
|
const size_t n_features = gmat.cut.Ptrs().size() - 1;
|
||||||
@ -226,13 +291,13 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
|||||||
// So we need to multiply each row-index/bin-index by 2
|
// So we need to multiply each row-index/bin-index by 2
|
||||||
// to work with gradient pairs as a singe row FP array
|
// to work with gradient pairs as a singe row FP array
|
||||||
for (size_t cid = 0; cid < n_columns; ++cid) {
|
for (size_t cid = 0; cid < n_columns; ++cid) {
|
||||||
const uint32_t offset = any_missing ? 0 : offsets[cid];
|
const uint32_t offset = kAnyMissing ? 0 : offsets[cid];
|
||||||
for (size_t i = 0; i < size; ++i) {
|
for (size_t i = 0; i < size; ++i) {
|
||||||
const size_t row_id = rid[i];
|
const size_t row_id = rid[i];
|
||||||
const size_t icol_start =
|
const size_t icol_start =
|
||||||
any_missing ? get_row_ptr(row_id) : get_rid(row_id) * n_features;
|
kAnyMissing ? get_row_ptr(row_id) : get_rid(row_id) * n_features;
|
||||||
const size_t icol_end =
|
const size_t icol_end =
|
||||||
any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
|
kAnyMissing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
|
||||||
|
|
||||||
if (cid < icol_end - icol_start) {
|
if (cid < icol_end - icol_start) {
|
||||||
const BinIdxType *gr_index_local = gradient_index + icol_start;
|
const BinIdxType *gr_index_local = gradient_index + icol_start;
|
||||||
@ -249,59 +314,32 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <bool do_prefetch, typename BinIdxType, bool first_page,
|
template <class BuildingManager>
|
||||||
bool any_missing>
|
|
||||||
void BuildHistKernel(const std::vector<GradientPair> &gpair,
|
|
||||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
|
||||||
GHistRow hist, bool read_by_column) {
|
|
||||||
if (read_by_column) {
|
|
||||||
ColsWiseBuildHistKernel<BinIdxType, first_page, any_missing>
|
|
||||||
(gpair, row_indices, gmat, hist);
|
|
||||||
} else {
|
|
||||||
RowsWiseBuildHistKernel<do_prefetch, BinIdxType, first_page, any_missing>
|
|
||||||
(gpair, row_indices, gmat, hist);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <bool do_prefetch, bool any_missing>
|
|
||||||
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
|
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
|
||||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||||
GHistRow hist, bool read_by_column) {
|
GHistRow hist) {
|
||||||
auto first_page = gmat.base_rowid == 0;
|
if (BuildingManager::kReadByColumn) {
|
||||||
DispatchBinType(gmat.index.GetBinTypeSize(), [&](auto t) {
|
ColsWiseBuildHistKernel<BuildingManager>(gpair, row_indices, gmat, hist);
|
||||||
using BinIdxType = decltype(t);
|
} else {
|
||||||
if (first_page) {
|
const size_t nrows = row_indices.Size();
|
||||||
BuildHistKernel<do_prefetch, BinIdxType, true, any_missing>
|
const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);
|
||||||
(gpair, row_indices, gmat, hist, read_by_column);
|
// if need to work with all rows from bin-matrix (e.g. root node)
|
||||||
|
const bool contiguousBlock =
|
||||||
|
(row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1);
|
||||||
|
|
||||||
|
if (contiguousBlock) {
|
||||||
|
// contiguous memory access, built-in HW prefetching is enough
|
||||||
|
RowsWiseBuildHistKernel<false, BuildingManager>(gpair, row_indices, gmat, hist);
|
||||||
} else {
|
} else {
|
||||||
BuildHistKernel<do_prefetch, BinIdxType, false, any_missing>
|
const RowSetCollection::Elem span1(row_indices.begin,
|
||||||
(gpair, row_indices, gmat, hist, read_by_column);
|
row_indices.end - no_prefetch_size);
|
||||||
|
const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size,
|
||||||
|
row_indices.end);
|
||||||
|
|
||||||
|
RowsWiseBuildHistKernel<true, BuildingManager>(gpair, span1, gmat, hist);
|
||||||
|
// no prefetching to avoid loading extra memory
|
||||||
|
RowsWiseBuildHistKernel<false, BuildingManager>(gpair, span2, gmat, hist);
|
||||||
}
|
}
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <bool any_missing>
|
|
||||||
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
|
|
||||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
|
||||||
GHistRow hist, bool read_by_column) {
|
|
||||||
const size_t nrows = row_indices.Size();
|
|
||||||
const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);
|
|
||||||
// if need to work with all rows from bin-matrix (e.g. root node)
|
|
||||||
const bool contiguousBlock =
|
|
||||||
(row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1);
|
|
||||||
|
|
||||||
if (contiguousBlock) {
|
|
||||||
// contiguous memory access, built-in HW prefetching is enough
|
|
||||||
BuildHistDispatch<false, any_missing>(gpair, row_indices, gmat, hist, read_by_column);
|
|
||||||
} else {
|
|
||||||
const RowSetCollection::Elem span1(row_indices.begin,
|
|
||||||
row_indices.end - no_prefetch_size);
|
|
||||||
const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size,
|
|
||||||
row_indices.end);
|
|
||||||
|
|
||||||
BuildHistDispatch<true, any_missing>(gpair, span1, gmat, hist, read_by_column);
|
|
||||||
// no prefetching to avoid loading extra memory
|
|
||||||
BuildHistDispatch<false, any_missing>(gpair, span2, gmat, hist, read_by_column);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -315,10 +353,16 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair> &gpair,
|
|||||||
*/
|
*/
|
||||||
constexpr double kAdhocL2Size = 1024 * 1024 * 0.8;
|
constexpr double kAdhocL2Size = 1024 * 1024 * 0.8;
|
||||||
const bool hist_fit_to_l2 = kAdhocL2Size > 2*sizeof(float)*gmat.cut.Ptrs().back();
|
const bool hist_fit_to_l2 = kAdhocL2Size > 2*sizeof(float)*gmat.cut.Ptrs().back();
|
||||||
const bool read_by_column = !hist_fit_to_l2 && !any_missing;
|
bool first_page = gmat.base_rowid == 0;
|
||||||
|
bool read_by_column = !hist_fit_to_l2 && !any_missing;
|
||||||
|
auto bin_type_size = gmat.index.GetBinTypeSize();
|
||||||
|
|
||||||
BuildHistDispatch<any_missing>(gpair, row_indices, gmat, hist, read_by_column ||
|
GHistBuildingManager<any_missing>::DispatchAndExecute(
|
||||||
force_read_by_column);
|
{first_page, read_by_column || force_read_by_column, bin_type_size},
|
||||||
|
[&](auto t) {
|
||||||
|
using BuildingManager = decltype(t);
|
||||||
|
BuildHistDispatch<BuildingManager>(gpair, row_indices, gmat, hist);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template void GHistBuilder::BuildHist<true>(const std::vector<GradientPair> &gpair,
|
template void GHistBuilder::BuildHist<true>(const std::vector<GradientPair> &gpair,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user