Remove omp_get_max_threads in data. (#7588)
This commit is contained in:
parent
f84291c1e1
commit
5817840858
@ -300,7 +300,7 @@ class SparsePage {
|
||||
base_rowid = row_id;
|
||||
}
|
||||
|
||||
SparsePage GetTranspose(int num_columns) const;
|
||||
SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
|
||||
|
||||
void SortRows() {
|
||||
auto ncol = static_cast<bst_omp_uint>(this->Size());
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
* Copyright 2017-2022 by XGBoost Contributors
|
||||
* \file hist_util.h
|
||||
* \brief Utility for fast histogram aggregation
|
||||
* \author Philip Cho, Tianqi Chen
|
||||
@ -137,19 +137,18 @@ class HistogramCuts {
|
||||
* \param use_sorted Whether should we use SortedCSC for sketching, it's more efficient
|
||||
* but consumes more memory.
|
||||
*/
|
||||
inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sorted = false,
|
||||
Span<float> const hessian = {}) {
|
||||
inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_threads,
|
||||
bool use_sorted = false, Span<float> const hessian = {}) {
|
||||
HistogramCuts out;
|
||||
auto const& info = m->Info();
|
||||
const auto threads = omp_get_max_threads();
|
||||
std::vector<std::vector<bst_row_t>> column_sizes(threads);
|
||||
std::vector<std::vector<bst_row_t>> column_sizes(n_threads);
|
||||
for (auto& column : column_sizes) {
|
||||
column.resize(info.num_col_, 0);
|
||||
}
|
||||
std::vector<bst_row_t> reduced(info.num_col_, 0);
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
auto const &entries_per_column =
|
||||
HostSketchContainer::CalcColumnSize(page, info.num_col_, threads);
|
||||
auto const& entries_per_column =
|
||||
HostSketchContainer::CalcColumnSize(page, info.num_col_, n_threads);
|
||||
for (size_t i = 0; i < entries_per_column.size(); ++i) {
|
||||
reduced[i] += entries_per_column[i];
|
||||
}
|
||||
@ -157,14 +156,14 @@ inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sort
|
||||
|
||||
if (!use_sorted) {
|
||||
HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
|
||||
hessian, threads);
|
||||
hessian, n_threads);
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
container.PushRowPage(page, info, hessian);
|
||||
}
|
||||
container.MakeCuts(&out);
|
||||
} else {
|
||||
SortedSketchContainer container{
|
||||
max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, threads};
|
||||
max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, n_threads};
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
container.PushColPage(page, info, hessian);
|
||||
}
|
||||
|
||||
@ -263,17 +263,6 @@ inline int32_t OmpSetNumThreads(int32_t* p_threads) {
|
||||
return nthread_original;
|
||||
}
|
||||
|
||||
inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {
|
||||
auto& threads = *p_threads;
|
||||
int32_t nthread_original = omp_get_max_threads();
|
||||
if (threads <= 0) {
|
||||
threads = nthread_original;
|
||||
}
|
||||
threads = std::min(threads, OmpGetThreadLimit());
|
||||
omp_set_num_threads(threads);
|
||||
return nthread_original;
|
||||
}
|
||||
|
||||
inline int32_t OmpGetNumThreads(int32_t n_threads) {
|
||||
if (n_threads <= 0) {
|
||||
n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2021 by Contributors
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
* \file data.cc
|
||||
*/
|
||||
#include <dmlc/registry.h>
|
||||
@ -1001,15 +1001,14 @@ DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
||||
XGBoostBatchCSR> *adapter,
|
||||
float missing, int nthread, const std::string &cache_prefix);
|
||||
|
||||
SparsePage SparsePage::GetTranspose(int num_columns) const {
|
||||
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
|
||||
SparsePage transpose;
|
||||
common::ParallelGroupBuilder<Entry, bst_row_t> builder(&transpose.offset.HostVector(),
|
||||
&transpose.data.HostVector());
|
||||
const int nthread = omp_get_max_threads();
|
||||
builder.InitBudget(num_columns, nthread);
|
||||
builder.InitBudget(num_columns, n_threads);
|
||||
long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
|
||||
auto page = this->GetView();
|
||||
common::ParallelFor(batch_size, [&](long i) { // NOLINT(*)
|
||||
common::ParallelFor(batch_size, n_threads, [&](long i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
auto inst = page[i];
|
||||
for (const auto& entry : inst) {
|
||||
@ -1017,7 +1016,7 @@ SparsePage SparsePage::GetTranspose(int num_columns) const {
|
||||
}
|
||||
});
|
||||
builder.InitStorage();
|
||||
common::ParallelFor(batch_size, [&](long i) { // NOLINT(*)
|
||||
common::ParallelFor(batch_size, n_threads, [&](long i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
auto inst = page[i];
|
||||
for (const auto& entry : inst) {
|
||||
@ -1059,8 +1058,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
constexpr bool kIsRowMajor = AdapterBatchT::kIsRowMajor;
|
||||
// Allow threading only for row-major case as column-major requires O(nthread*batch_size) memory
|
||||
nthread = kIsRowMajor ? nthread : 1;
|
||||
// Set number of threads but keep old value so we can reset it after
|
||||
int nthread_original = common::OmpSetNumThreadsWithoutHT(&nthread);
|
||||
if (!kIsRowMajor) {
|
||||
CHECK_EQ(nthread, 1);
|
||||
}
|
||||
@ -1085,7 +1082,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
expected_rows = kIsRowMajor ? batch_size : expected_rows;
|
||||
uint64_t max_columns = 0;
|
||||
if (batch_size == 0) {
|
||||
omp_set_num_threads(nthread_original);
|
||||
return max_columns;
|
||||
}
|
||||
const size_t thread_size = batch_size / nthread;
|
||||
@ -1154,7 +1150,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
});
|
||||
}
|
||||
exec.Rethrow();
|
||||
omp_set_num_threads(nthread_original);
|
||||
|
||||
return max_columns;
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
* Copyright 2017-2022 by XGBoost Contributors
|
||||
* \brief Data type for fast histogram aggregation.
|
||||
*/
|
||||
#include <algorithm>
|
||||
@ -126,17 +126,16 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
|
||||
});
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch,
|
||||
void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch, int32_t n_threads,
|
||||
common::Span<float> hess) {
|
||||
// We use sorted sketching for approx tree method since it's more efficient in
|
||||
// computation time (but higher memory usage).
|
||||
cut = common::SketchOnDMatrix(p_fmat, max_bins, sorted_sketch, hess);
|
||||
cut = common::SketchOnDMatrix(p_fmat, max_bins, n_threads, sorted_sketch, hess);
|
||||
|
||||
max_num_bins = max_bins;
|
||||
const int32_t nthread = omp_get_max_threads();
|
||||
const uint32_t nbins = cut.Ptrs().back();
|
||||
hit_count.resize(nbins, 0);
|
||||
hit_count_tloc_.resize(nthread * nbins, 0);
|
||||
hit_count_tloc_.resize(n_threads * nbins, 0);
|
||||
|
||||
this->p_fmat = p_fmat;
|
||||
size_t new_size = 1;
|
||||
@ -154,7 +153,7 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch,
|
||||
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
||||
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
this->PushBatch(batch, ft, rbegin, prev_sum, nbins, nthread);
|
||||
this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
|
||||
prev_sum = row_ptr[rbegin + batch.Size()];
|
||||
rbegin += batch.Size();
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
* Copyright 2017-2022 by XGBoost Contributors
|
||||
* \brief Data type for fast histogram aggregation.
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
@ -19,9 +19,8 @@ namespace xgboost {
|
||||
* index for CPU histogram. On GPU ellpack page is used.
|
||||
*/
|
||||
class GHistIndexMatrix {
|
||||
void PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
|
||||
size_t rbegin, size_t prev_sum, uint32_t nbins,
|
||||
int32_t n_threads);
|
||||
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, size_t rbegin,
|
||||
size_t prev_sum, uint32_t nbins, int32_t n_threads);
|
||||
|
||||
public:
|
||||
/*! \brief row pointer to rows by element position */
|
||||
@ -37,11 +36,13 @@ class GHistIndexMatrix {
|
||||
size_t base_rowid{0};
|
||||
|
||||
GHistIndexMatrix() = default;
|
||||
GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, common::Span<float> hess = {}) {
|
||||
this->Init(x, max_bin, sorted_sketch, hess);
|
||||
GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, int32_t n_threads,
|
||||
common::Span<float> hess = {}) {
|
||||
this->Init(x, max_bin, sorted_sketch, n_threads, hess);
|
||||
}
|
||||
// Create a global histogram matrix, given cut
|
||||
void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, common::Span<float> hess);
|
||||
void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, int32_t n_threads,
|
||||
common::Span<float> hess);
|
||||
void Init(SparsePage const& page, common::Span<FeatureType const> ft,
|
||||
common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
|
||||
int32_t n_threads);
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014~2021 by Contributors
|
||||
* Copyright 2014~2022 by XGBoost Contributors
|
||||
* \file simple_dmatrix.cc
|
||||
* \brief the input data structure for gradient boosting
|
||||
* \author Tianqi Chen
|
||||
@ -55,7 +55,7 @@ BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
||||
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
|
||||
// column page doesn't exist, generate it
|
||||
if (!column_page_) {
|
||||
column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_)));
|
||||
column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads())));
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_));
|
||||
@ -66,7 +66,7 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
|
||||
// Sorted column page doesn't exist, generate it
|
||||
if (!sorted_column_page_) {
|
||||
sorted_column_page_.reset(
|
||||
new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_)));
|
||||
new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads())));
|
||||
sorted_column_page_->SortRows();
|
||||
}
|
||||
auto begin_iter = BatchIterator<SortedCSCPage>(
|
||||
@ -99,7 +99,8 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
|
||||
CHECK_EQ(param.gpu_id, -1);
|
||||
// Used only by approx.
|
||||
auto sorted_sketch = param.regen;
|
||||
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, sorted_sketch, param.hess));
|
||||
gradient_index_.reset(
|
||||
new GHistIndexMatrix(this, param.max_bin, sorted_sketch, this->ctx_.Threads(), param.hess));
|
||||
batch_param_ = param;
|
||||
CHECK_EQ(batch_param_.hess.data(), param.hess.data());
|
||||
}
|
||||
@ -110,6 +111,8 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
|
||||
|
||||
template <typename AdapterT>
|
||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
this->ctx_.nthread = nthread;
|
||||
|
||||
std::vector<uint64_t> qids;
|
||||
uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||
uint64_t last_group_id = default_max;
|
||||
@ -124,7 +127,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
// Iterate over batches of input data
|
||||
while (adapter->Next()) {
|
||||
auto& batch = adapter->Value();
|
||||
auto batch_max_columns = sparse_page_->Push(batch, missing, nthread);
|
||||
auto batch_max_columns = sparse_page_->Push(batch, missing, ctx_.Threads());
|
||||
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
|
||||
total_batch_size += batch.Size();
|
||||
// Append meta information if available
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2021 by Contributors
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
* \file simple_dmatrix.h
|
||||
* \brief In-memory version of DMatrix.
|
||||
* \author Tianqi Chen
|
||||
@ -61,6 +61,9 @@ class SimpleDMatrix : public DMatrix {
|
||||
bool SparsePageExists() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
GenericParameter ctx_;
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2021 by Contributors
|
||||
* Copyright 2014-2022 by Contributors
|
||||
* \file sparse_page_dmatrix.cc
|
||||
* \brief The external memory version of Page Iterator.
|
||||
* \author Tianqi Chen
|
||||
@ -164,7 +164,8 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam&
|
||||
// all index here.
|
||||
if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
|
||||
this->InitializeSparsePage();
|
||||
ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.regen});
|
||||
ghist_index_page_.reset(
|
||||
new GHistIndexMatrix{this, param.max_bin, param.regen, ctx_.Threads()});
|
||||
this->InitializeSparsePage();
|
||||
batch_param_ = param;
|
||||
}
|
||||
@ -181,7 +182,8 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam&
|
||||
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
// Use sorted sketch for approx.
|
||||
auto sorted_sketch = param.regen;
|
||||
auto cuts = common::SketchOnDMatrix(this, param.max_bin, sorted_sketch, param.hess);
|
||||
auto cuts =
|
||||
common::SketchOnDMatrix(this, param.max_bin, ctx_.Threads(), sorted_sketch, param.hess);
|
||||
this->InitializeSparsePage(); // reset after use.
|
||||
|
||||
batch_param_ = param;
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2014-2021 by Contributors
|
||||
* Copyright 2014-2022 by XGBoost Contributors
|
||||
* \file sparse_page_source.h
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
||||
@ -311,7 +311,7 @@ class CSCPageSource : public PageSourceIncMixIn<CSCPage> {
|
||||
auto const &csr = source_->Page();
|
||||
this->page_.reset(new CSCPage{});
|
||||
// we might be able to optimize this by merging transpose and pushcsc
|
||||
this->page_->PushCSC(csr->GetTranspose(n_features_));
|
||||
this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_));
|
||||
page_->SetBaseRowId(csr->base_rowid);
|
||||
this->WriteCache();
|
||||
}
|
||||
@ -336,7 +336,7 @@ class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
|
||||
auto const &csr = this->source_->Page();
|
||||
this->page_.reset(new SortedCSCPage{});
|
||||
// we might be able to optimize this by merging transpose and pushcsc
|
||||
this->page_->PushCSC(csr->GetTranspose(n_features_));
|
||||
this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_));
|
||||
CHECK_EQ(this->page_->Size(), n_features_);
|
||||
CHECK_EQ(this->page_->data.Size(), csr->data.Size());
|
||||
this->page_->SortRows();
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2018-2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@ -14,7 +17,7 @@ TEST(DenseColumn, Test) {
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
for (size_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false);
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
|
||||
ColumnMatrix column_matrix;
|
||||
column_matrix.Init(gmat, 0.2);
|
||||
|
||||
@ -61,7 +64,7 @@ TEST(SparseColumn, Test) {
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
for (size_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false);
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
|
||||
ColumnMatrix column_matrix;
|
||||
column_matrix.Init(gmat, 0.5);
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
@ -101,7 +104,7 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
|
||||
for (size_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false);
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
|
||||
ColumnMatrix column_matrix;
|
||||
column_matrix.Init(gmat, 0.2);
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
@ -130,7 +133,7 @@ void TestGHistIndexMatrixCreation(size_t nthreads) {
|
||||
/* This should create multiple sparse pages */
|
||||
std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries) };
|
||||
omp_set_num_threads(nthreads);
|
||||
GHistIndexMatrix gmat(dmat.get(), 256, false);
|
||||
GHistIndexMatrix gmat(dmat.get(), 256, false, common::OmpGetNumThreads(0));
|
||||
}
|
||||
|
||||
TEST(HistIndexCreationWithExternalMemory, Test) {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
* Copyright 2019-2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
@ -188,7 +188,7 @@ TEST(HistUtil, DenseCutsCategorical) {
|
||||
std::vector<float> x_sorted(x);
|
||||
std::sort(x_sorted.begin(), x_sorted.end());
|
||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins);
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
|
||||
auto cuts_from_sketch = cuts.Values();
|
||||
EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
|
||||
EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
|
||||
@ -207,7 +207,7 @@ TEST(HistUtil, DenseCutsAccuracyTest) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins);
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -224,11 +224,13 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
|
||||
dmat->Info().weights_.HostVector() = w;
|
||||
for (auto num_bins : bin_sizes) {
|
||||
{
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, true);
|
||||
HistogramCuts cuts =
|
||||
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), true);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
{
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, false);
|
||||
HistogramCuts cuts =
|
||||
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), false);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -249,13 +251,15 @@ void TestQuantileWithHessian(bool use_sorted) {
|
||||
dmat->Info().weights_.HostVector() = w;
|
||||
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, use_sorted, hessian);
|
||||
HistogramCuts cuts_hess =
|
||||
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), use_sorted, hessian);
|
||||
for (size_t i = 0; i < w.size(); ++i) {
|
||||
dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i];
|
||||
}
|
||||
ValidateCuts(cuts_hess, dmat.get(), num_bins);
|
||||
|
||||
HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins, use_sorted);
|
||||
HistogramCuts cuts_wh =
|
||||
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), use_sorted);
|
||||
ValidateCuts(cuts_wh, dmat.get(), num_bins);
|
||||
|
||||
ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
|
||||
@ -283,7 +287,7 @@ TEST(HistUtil, DenseCutsExternalMemory) {
|
||||
auto dmat =
|
||||
GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins);
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -303,7 +307,7 @@ TEST(HistUtil, IndexBinBound) {
|
||||
for (auto max_bin : bin_sizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, false);
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, false, common::OmpGetNumThreads(0));
|
||||
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
||||
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
|
||||
}
|
||||
@ -326,7 +330,7 @@ TEST(HistUtil, IndexBinData) {
|
||||
|
||||
for (auto max_bin : kBinSizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, false);
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, false, common::OmpGetNumThreads(0));
|
||||
uint32_t* offsets = hmat.index.Offset();
|
||||
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
||||
switch (max_bin) {
|
||||
@ -351,7 +355,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
size_t constexpr kGroups = 10;
|
||||
auto m =
|
||||
RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix();
|
||||
common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins);
|
||||
common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||
|
||||
MetaInfo info;
|
||||
auto& h_weights = info.weights_.HostVector();
|
||||
@ -385,7 +389,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
ValidateCuts(cuts, m.get(), kBins);
|
||||
|
||||
if (with_group) {
|
||||
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins);
|
||||
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
||||
}
|
||||
@ -404,14 +408,12 @@ TEST(HistUtil, SketchFromWeights) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, SketchCategoricalFeatures) {
|
||||
TestCategoricalSketch(1000, 256, 32, false,
|
||||
[](DMatrix *p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(p_fmat, num_bins);
|
||||
});
|
||||
TestCategoricalSketch(1000, 256, 32, true,
|
||||
[](DMatrix *p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(p_fmat, num_bins);
|
||||
});
|
||||
TestCategoricalSketch(1000, 256, 32, false, [](DMatrix* p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(p_fmat, num_bins, common::OmpGetNumThreads(0));
|
||||
});
|
||||
TestCategoricalSketch(1000, 256, 32, true, [](DMatrix* p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(p_fmat, num_bins, common::OmpGetNumThreads(0));
|
||||
});
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
* Copyright 2019-2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <gtest/gtest.h>
|
||||
@ -28,7 +28,7 @@ namespace common {
|
||||
template <typename AdapterT>
|
||||
HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
|
||||
data::SimpleDMatrix dmat(adapter, missing, 1);
|
||||
HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins);
|
||||
HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins, common::OmpGetNumThreads(0));
|
||||
return cuts;
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ TEST(HistUtil, DeviceSketch) {
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins);
|
||||
HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
|
||||
|
||||
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
||||
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2020-2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_quantile.h"
|
||||
#include "../../../src/common/quantile.h"
|
||||
@ -201,7 +204,7 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
.MaxCategory(17)
|
||||
.Seed(rank + seed)
|
||||
.GenerateDMatrix();
|
||||
auto cuts = SketchOnDMatrix(m.get(), n_bins);
|
||||
auto cuts = SketchOnDMatrix(m.get(), n_bins, common::OmpGetNumThreads(0));
|
||||
std::vector<float> cut_values(cuts.Values().size() * world, 0);
|
||||
std::vector<
|
||||
typename std::remove_reference_t<decltype(cuts.Ptrs())>::value_type>
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2019-2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <fstream>
|
||||
@ -66,7 +69,7 @@ TEST(SparsePage, PushCSCAfterTranspose) {
|
||||
SparsePage page; // Consolidated sparse page
|
||||
for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
|
||||
// Transpose each batch and push
|
||||
SparsePage tmp = batch.GetTranspose(ncols);
|
||||
SparsePage tmp = batch.GetTranspose(ncols, common::OmpGetNumThreads(0));
|
||||
page.PushCSC(tmp);
|
||||
}
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
* Copyright 2021-2022 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
@ -36,7 +36,7 @@ TEST(GradientIndex, FromCategoricalBasic) {
|
||||
BatchParam p(0, max_bins);
|
||||
GHistIndexMatrix gidx;
|
||||
|
||||
gidx.Init(m.get(), max_bins, false, {});
|
||||
gidx.Init(m.get(), max_bins, false, common::OmpGetNumThreads(0), {});
|
||||
|
||||
auto x_copy = x;
|
||||
std::sort(x_copy.begin(), x_copy.end());
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2021-2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/base.h>
|
||||
#include "../../../../src/tree/hist/evaluate_splits.h"
|
||||
@ -29,7 +32,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
|
||||
size_t constexpr kMaxBins = 4;
|
||||
// dense, no missing values
|
||||
|
||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins, false);
|
||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0));
|
||||
common::RowSetCollection row_set_collection;
|
||||
std::vector<size_t> &row_indices = *row_set_collection.Data();
|
||||
row_indices.resize(kRows);
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2018-2021 by Contributors
|
||||
* Copyright 2018-2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
@ -162,7 +162,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
// kNRows samples with kNCols features
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
|
||||
|
||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins, false);
|
||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0));
|
||||
ColumnMatrix cm;
|
||||
|
||||
// treat everything as dense, as this is what we intend to test here
|
||||
@ -253,7 +253,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
void TestInitData() {
|
||||
size_t constexpr kMaxBins = 4;
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false);
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0));
|
||||
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
@ -270,7 +270,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
void TestInitDataSampling() {
|
||||
size_t constexpr kMaxBins = 4;
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false);
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0));
|
||||
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user