Use Booster context in DMatrix. (#8896)

- Pass context from booster to DMatrix.
- Use context instead of integer for `n_threads`.
- Check the consistency of the configuration for `max_bin`.
- Test for all combinations of initialization options.
This commit is contained in:
Jiaming Yuan
2023-04-28 21:47:14 +08:00
committed by GitHub
parent 1f9a57d17b
commit 08ce495b5d
67 changed files with 1283 additions and 935 deletions

View File

@@ -2,15 +2,18 @@
* Copyright 2017-2023 by XGBoost Contributors
* \file hist_util.cc
*/
#include "hist_util.h"
#include <dmlc/timer.h>
#include <vector>
#include "xgboost/base.h"
#include "../common/common.h"
#include "hist_util.h"
#include "column_matrix.h"
#include "quantile.h"
#include "xgboost/base.h"
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // SparsePage, SortedCSCPage
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
#include <xmmintrin.h>
@@ -28,10 +31,11 @@ HistogramCuts::HistogramCuts() {
cut_ptrs_.HostVector().emplace_back(0);
}
HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, bool use_sorted,
HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins, bool use_sorted,
Span<float> const hessian) {
HistogramCuts out;
auto const& info = m->Info();
auto const &info = m->Info();
auto n_threads = ctx->Threads();
std::vector<bst_row_t> reduced(info.num_col_, 0);
for (auto const &page : m->GetBatches<SparsePage>()) {
auto const &entries_per_column =
@@ -44,16 +48,19 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
}
if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info), n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
HostSketchContainer container(ctx, max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info));
for (auto const &page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
container.MakeCuts(m->Info(), &out);
} else {
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
SortedSketchContainer container{ctx,
max_bins,
m->Info().feature_types.ConstHostSpan(),
reduced,
HostSketchContainer::UseGroup(info)};
for (auto const &page : m->GetBatches<SortedCSCPage>(ctx)) {
container.PushColPage(page, info, hessian);
}
container.MakeCuts(m->Info(), &out);