Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of an integer for `n_threads`. - Check that the `max_bin` configuration is consistent. - Test all combinations of initialization options.
This commit is contained in:
@@ -28,5 +28,10 @@ constexpr StringView InfInData() {
|
||||
/**
 * \brief Message text stating that 128-bit floating point is not supported on the
 *        current platform.
 *
 * \return A compile-time StringView holding the message.
 */
constexpr StringView NoF128() {
|
||||
return "128-bit floating point is not supported on current platform.";
|
||||
}
|
||||
|
||||
/**
 * \brief Message text for a `max_bin` mismatch.
 *
 * As the message itself states, `max_bin` is expected to be the same across different
 * QuantileDMatrix objects and consistent with the Booster being trained.
 *
 * \return A compile-time StringView holding the message (two adjacent string literals
 *         concatenated by the compiler).
 */
constexpr StringView InconsistentMaxBin() {
|
||||
return "Inconsistent `max_bin`. `max_bin` should be the same across different QuantileDMatrix, "
|
||||
"and consistent with the Booster being trained.";
|
||||
}
|
||||
} // namespace xgboost::error
|
||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||
|
||||
@@ -2,15 +2,18 @@
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
* \file hist_util.cc
|
||||
*/
|
||||
#include "hist_util.h"
|
||||
|
||||
#include <dmlc/timer.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "../common/common.h"
|
||||
#include "hist_util.h"
|
||||
#include "column_matrix.h"
|
||||
#include "quantile.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/data.h" // SparsePage, SortedCSCPage
|
||||
|
||||
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
|
||||
#include <xmmintrin.h>
|
||||
@@ -28,10 +31,11 @@ HistogramCuts::HistogramCuts() {
|
||||
cut_ptrs_.HostVector().emplace_back(0);
|
||||
}
|
||||
|
||||
HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, bool use_sorted,
|
||||
HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins, bool use_sorted,
|
||||
Span<float> const hessian) {
|
||||
HistogramCuts out;
|
||||
auto const& info = m->Info();
|
||||
auto const &info = m->Info();
|
||||
auto n_threads = ctx->Threads();
|
||||
std::vector<bst_row_t> reduced(info.num_col_, 0);
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
auto const &entries_per_column =
|
||||
@@ -44,16 +48,19 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
|
||||
}
|
||||
|
||||
if (!use_sorted) {
|
||||
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
||||
HostSketchContainer::UseGroup(info), n_threads);
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
HostSketchContainer container(ctx, max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
||||
HostSketchContainer::UseGroup(info));
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
container.PushRowPage(page, info, hessian);
|
||||
}
|
||||
container.MakeCuts(m->Info(), &out);
|
||||
} else {
|
||||
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
||||
HostSketchContainer::UseGroup(info), n_threads};
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
SortedSketchContainer container{ctx,
|
||||
max_bins,
|
||||
m->Info().feature_types.ConstHostSpan(),
|
||||
reduced,
|
||||
HostSketchContainer::UseGroup(info)};
|
||||
for (auto const &page : m->GetBatches<SortedCSCPage>(ctx)) {
|
||||
container.PushColPage(page, info, hessian);
|
||||
}
|
||||
container.MakeCuts(m->Info(), &out);
|
||||
|
||||
@@ -170,7 +170,7 @@ class HistogramCuts {
|
||||
* \param use_sorted Whether we should use SortedCSC for sketching; it's more efficient
|
||||
* but consumes more memory.
|
||||
*/
|
||||
HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_threads,
|
||||
HistogramCuts SketchOnDMatrix(Context const* ctx, DMatrix* m, bst_bin_t max_bins,
|
||||
bool use_sorted = false, Span<float> const hessian = {});
|
||||
|
||||
enum BinTypeSize : uint8_t {
|
||||
|
||||
@@ -16,16 +16,16 @@ namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
template <typename WQSketch>
|
||||
SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> columns_size,
|
||||
SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
|
||||
std::vector<bst_row_t> columns_size,
|
||||
int32_t max_bins,
|
||||
Span<FeatureType const> feature_types,
|
||||
bool use_group,
|
||||
int32_t n_threads)
|
||||
bool use_group)
|
||||
: feature_types_(feature_types.cbegin(), feature_types.cend()),
|
||||
columns_size_{std::move(columns_size)},
|
||||
max_bins_{max_bins},
|
||||
use_group_ind_{use_group},
|
||||
n_threads_{n_threads} {
|
||||
n_threads_{ctx->Threads()} {
|
||||
monitor_.Init(__func__);
|
||||
CHECK_NE(columns_size_.size(), 0);
|
||||
sketches_.resize(columns_size_.size());
|
||||
@@ -380,13 +380,13 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
|
||||
}
|
||||
|
||||
template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const& info, HistogramCuts* cuts) {
|
||||
void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const &info, HistogramCuts *p_cuts) {
|
||||
monitor_.Start(__func__);
|
||||
std::vector<typename WQSketch::SummaryContainer> reduced;
|
||||
std::vector<int32_t> num_cuts;
|
||||
this->AllReduce(info, &reduced, &num_cuts);
|
||||
|
||||
cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
|
||||
p_cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
|
||||
std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
|
||||
|
||||
ParallelFor(reduced.size(), n_threads_, Sched::Guided(), [&](size_t fidx) {
|
||||
@@ -401,48 +401,48 @@ void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const& info, HistogramCuts
|
||||
a.SetPrune(reduced[fidx], max_num_bins + 1);
|
||||
CHECK(a.data && reduced[fidx].data);
|
||||
const bst_float mval = a.data[0].value;
|
||||
cuts->min_vals_.HostVector()[fidx] = mval - fabs(mval) - 1e-5f;
|
||||
p_cuts->min_vals_.HostVector()[fidx] = mval - fabs(mval) - 1e-5f;
|
||||
} else {
|
||||
// Empty column.
|
||||
const float mval = 1e-5f;
|
||||
cuts->min_vals_.HostVector()[fidx] = mval;
|
||||
p_cuts->min_vals_.HostVector()[fidx] = mval;
|
||||
}
|
||||
});
|
||||
|
||||
float max_cat{-1.f};
|
||||
for (size_t fid = 0; fid < reduced.size(); ++fid) {
|
||||
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
|
||||
typename WQSketch::SummaryContainer const& a = final_summaries[fid];
|
||||
typename WQSketch::SummaryContainer const &a = final_summaries[fid];
|
||||
if (IsCat(feature_types_, fid)) {
|
||||
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), cuts));
|
||||
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
|
||||
} else {
|
||||
AddCutPoint<WQSketch>(a, max_num_bins, cuts);
|
||||
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts);
|
||||
// push a value that is greater than anything
|
||||
const bst_float cpt =
|
||||
(a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid];
|
||||
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
|
||||
// this must be bigger than last value in a scale
|
||||
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
|
||||
cuts->cut_values_.HostVector().push_back(last);
|
||||
p_cuts->cut_values_.HostVector().push_back(last);
|
||||
}
|
||||
|
||||
// Ensure that every feature gets at least one quantile point
|
||||
CHECK_LE(cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
|
||||
auto cut_size = static_cast<uint32_t>(cuts->cut_values_.HostVector().size());
|
||||
CHECK_GT(cut_size, cuts->cut_ptrs_.HostVector().back());
|
||||
cuts->cut_ptrs_.HostVector().push_back(cut_size);
|
||||
CHECK_LE(p_cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
|
||||
auto cut_size = static_cast<uint32_t>(p_cuts->cut_values_.HostVector().size());
|
||||
CHECK_GT(cut_size, p_cuts->cut_ptrs_.HostVector().back());
|
||||
p_cuts->cut_ptrs_.HostVector().push_back(cut_size);
|
||||
}
|
||||
|
||||
cuts->SetCategorical(this->has_categorical_, max_cat);
|
||||
p_cuts->SetCategorical(this->has_categorical_, max_cat);
|
||||
monitor_.Stop(__func__);
|
||||
}
|
||||
|
||||
template class SketchContainerImpl<WQuantileSketch<float, float>>;
|
||||
template class SketchContainerImpl<WXQuantileSketch<float, float>>;
|
||||
|
||||
HostSketchContainer::HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group,
|
||||
int32_t n_threads)
|
||||
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
|
||||
HostSketchContainer::HostSketchContainer(Context const *ctx, bst_bin_t max_bins,
|
||||
common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group)
|
||||
: SketchContainerImpl{ctx, columns_size, max_bins, ft, use_group} {
|
||||
monitor_.Init(__func__);
|
||||
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
|
||||
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
|
||||
|
||||
@@ -800,9 +800,8 @@ class SketchContainerImpl {
|
||||
* \param max_bins maximum number of bins for each feature.
|
||||
* \param use_group whether data instances are assigned to groups.
|
||||
*/
|
||||
SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
common::Span<FeatureType const> feature_types, bool use_group,
|
||||
int32_t n_threads);
|
||||
SketchContainerImpl(Context const *ctx, std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
common::Span<FeatureType const> feature_types, bool use_group);
|
||||
|
||||
static bool UseGroup(MetaInfo const &info) {
|
||||
size_t const num_groups =
|
||||
@@ -894,8 +893,8 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl
|
||||
using WQSketch = WQuantileSketch<float, float>;
|
||||
|
||||
public:
|
||||
HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group, int32_t n_threads);
|
||||
HostSketchContainer(Context const *ctx, bst_bin_t max_bins, common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group);
|
||||
|
||||
template <typename Batch>
|
||||
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
|
||||
@@ -990,10 +989,10 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,
|
||||
using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
|
||||
|
||||
public:
|
||||
explicit SortedSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group,
|
||||
int32_t n_threads)
|
||||
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
|
||||
explicit SortedSketchContainer(Context const *ctx, int32_t max_bins,
|
||||
common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group)
|
||||
: SketchContainerImpl{ctx, columns_size, max_bins, ft, use_group} {
|
||||
monitor_.Init(__func__);
|
||||
sketches_.resize(columns_size.size());
|
||||
size_t i = 0;
|
||||
|
||||
Reference in New Issue
Block a user