Refactor for GHistIndex. (#7923)
* Pass sparse page as adapter, which prepares for quantile dmatrix. * Remove old external memory code like `rbegin` and extra `Init` function. * Simplify type dispatch.
This commit is contained in:
parent
d314680a15
commit
18a38f7ca0
@ -125,26 +125,25 @@ class HistogramCuts {
|
|||||||
/**
|
/**
|
||||||
* \brief Search the bin index for numerical feature.
|
* \brief Search the bin index for numerical feature.
|
||||||
*/
|
*/
|
||||||
bst_bin_t SearchBin(Entry const& e) const {
|
bst_bin_t SearchBin(Entry const& e) const { return SearchBin(e.fvalue, e.index); }
|
||||||
return SearchBin(e.fvalue, e.index);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Search the bin index for categorical feature.
|
* \brief Search the bin index for categorical feature.
|
||||||
*/
|
*/
|
||||||
bst_bin_t SearchCatBin(Entry const &e) const {
|
bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const {
|
||||||
auto const &ptrs = this->Ptrs();
|
auto const &ptrs = this->Ptrs();
|
||||||
auto const &vals = this->Values();
|
auto const &vals = this->Values();
|
||||||
auto end = ptrs.at(e.index + 1) + vals.cbegin();
|
auto end = ptrs.at(fidx + 1) + vals.cbegin();
|
||||||
auto beg = ptrs[e.index] + vals.cbegin();
|
auto beg = ptrs[fidx] + vals.cbegin();
|
||||||
// Truncates the value in case it's not perfectly rounded.
|
// Truncates the value in case it's not perfectly rounded.
|
||||||
auto v = static_cast<float>(common::AsCat(e.fvalue));
|
auto v = static_cast<float>(common::AsCat(value));
|
||||||
auto bin_idx = std::lower_bound(beg, end, v) - vals.cbegin();
|
auto bin_idx = std::lower_bound(beg, end, v) - vals.cbegin();
|
||||||
if (bin_idx == ptrs.at(e.index + 1)) {
|
if (bin_idx == ptrs.at(fidx + 1)) {
|
||||||
bin_idx -= 1;
|
bin_idx -= 1;
|
||||||
}
|
}
|
||||||
return bin_idx;
|
return bin_idx;
|
||||||
}
|
}
|
||||||
|
bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); }
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -17,85 +17,15 @@ namespace xgboost {
|
|||||||
|
|
||||||
GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique<common::ColumnMatrix>()} {}
|
GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique<common::ColumnMatrix>()} {}
|
||||||
|
|
||||||
GHistIndexMatrix::GHistIndexMatrix(DMatrix *x, int32_t max_bin, double sparse_thresh,
|
GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
||||||
bool sorted_sketch, int32_t n_threads,
|
double sparse_thresh, bool sorted_sketch, int32_t n_threads,
|
||||||
common::Span<float> hess) {
|
common::Span<float> hess) {
|
||||||
this->Init(x, max_bin, sparse_thresh, sorted_sketch, n_threads, hess);
|
CHECK(p_fmat->SingleColBlock());
|
||||||
}
|
|
||||||
|
|
||||||
GHistIndexMatrix::~GHistIndexMatrix() = default;
|
|
||||||
|
|
||||||
void GHistIndexMatrix::PushBatch(SparsePage const &batch,
|
|
||||||
common::Span<FeatureType const> ft,
|
|
||||||
size_t rbegin, size_t prev_sum, uint32_t nbins,
|
|
||||||
int32_t n_threads) {
|
|
||||||
auto page = batch.GetView();
|
|
||||||
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
|
|
||||||
common::PartialSum(n_threads, it, it + page.Size(), prev_sum, row_ptr.begin() + rbegin);
|
|
||||||
// The number of threads is pegged to the batch size. If the OMP block is parallelized
|
|
||||||
// on anything other than the batch/block size, it should be reassigned
|
|
||||||
const size_t batch_threads =
|
|
||||||
std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
|
|
||||||
|
|
||||||
const size_t n_index = row_ptr[rbegin + batch.Size()]; // number of entries in this page
|
|
||||||
ResizeIndex(n_index, isDense_);
|
|
||||||
|
|
||||||
CHECK_GT(cut.Values().size(), 0U);
|
|
||||||
|
|
||||||
if (isDense_) {
|
|
||||||
index.SetBinOffset(cut.Ptrs());
|
|
||||||
}
|
|
||||||
uint32_t const *offsets = index.Offset();
|
|
||||||
|
|
||||||
if (isDense_) {
|
|
||||||
// Inside the lambda functions, bin_idx is the index for cut value across all
|
|
||||||
// features. By subtracting it with starting pointer of each feature, we can reduce
|
|
||||||
// it to smaller value and compress it to smaller types.
|
|
||||||
common::BinTypeSize curent_bin_size = index.GetBinTypeSize();
|
|
||||||
if (curent_bin_size == common::kUint8BinsTypeSize) {
|
|
||||||
common::Span<uint8_t> index_data_span = {index.data<uint8_t>(), n_index};
|
|
||||||
SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
|
|
||||||
[offsets](auto bin_idx, auto fidx) {
|
|
||||||
return static_cast<uint8_t>(bin_idx - offsets[fidx]);
|
|
||||||
});
|
|
||||||
} else if (curent_bin_size == common::kUint16BinsTypeSize) {
|
|
||||||
common::Span<uint16_t> index_data_span = {index.data<uint16_t>(), n_index};
|
|
||||||
SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
|
|
||||||
[offsets](auto bin_idx, auto fidx) {
|
|
||||||
return static_cast<uint16_t>(bin_idx - offsets[fidx]);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize);
|
|
||||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
|
|
||||||
SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
|
|
||||||
[offsets](auto bin_idx, auto fidx) {
|
|
||||||
return static_cast<uint32_t>(bin_idx - offsets[fidx]);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
/* For sparse DMatrix we have to store index of feature for each bin
|
|
||||||
in index field to chose right offset. So offset is nullptr and index is
|
|
||||||
not reduced */
|
|
||||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
|
|
||||||
SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
|
|
||||||
[](auto idx, auto) { return idx; });
|
|
||||||
}
|
|
||||||
|
|
||||||
common::ParallelFor(nbins, n_threads, [&](bst_omp_uint idx) {
|
|
||||||
for (int32_t tid = 0; tid < n_threads; ++tid) {
|
|
||||||
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
|
|
||||||
hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh, bool sorted_sketch,
|
|
||||||
int32_t n_threads, common::Span<float> hess) {
|
|
||||||
// We use sorted sketching for approx tree method since it's more efficient in
|
// We use sorted sketching for approx tree method since it's more efficient in
|
||||||
// computation time (but higher memory usage).
|
// computation time (but higher memory usage).
|
||||||
cut = common::SketchOnDMatrix(p_fmat, max_bins, n_threads, sorted_sketch, hess);
|
cut = common::SketchOnDMatrix(p_fmat, max_bins_per_feat, n_threads, sorted_sketch, hess);
|
||||||
|
|
||||||
max_num_bins = max_bins;
|
max_num_bins = max_bins_per_feat;
|
||||||
const uint32_t nbins = cut.Ptrs().back();
|
const uint32_t nbins = cut.Ptrs().back();
|
||||||
hit_count.resize(nbins, 0);
|
hit_count.resize(nbins, 0);
|
||||||
hit_count_tloc_.resize(n_threads * nbins, 0);
|
hit_count_tloc_.resize(n_threads * nbins, 0);
|
||||||
@ -108,16 +38,12 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
|
|||||||
row_ptr.resize(new_size);
|
row_ptr.resize(new_size);
|
||||||
row_ptr[0] = 0;
|
row_ptr[0] = 0;
|
||||||
|
|
||||||
size_t rbegin = 0;
|
|
||||||
size_t prev_sum = 0;
|
|
||||||
const bool isDense = p_fmat->IsDense();
|
const bool isDense = p_fmat->IsDense();
|
||||||
this->isDense_ = isDense;
|
this->isDense_ = isDense;
|
||||||
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
||||||
|
|
||||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||||
this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
|
this->PushBatch(batch, ft, nbins, n_threads);
|
||||||
prev_sum = row_ptr[rbegin + batch.Size()];
|
|
||||||
rbegin += batch.Size();
|
|
||||||
}
|
}
|
||||||
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
||||||
|
|
||||||
@ -131,6 +57,59 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GHistIndexMatrix::~GHistIndexMatrix() = default;
|
||||||
|
|
||||||
|
void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
|
||||||
|
bst_bin_t n_total_bins, int32_t n_threads) {
|
||||||
|
auto page = batch.GetView();
|
||||||
|
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
|
||||||
|
common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
|
||||||
|
// The number of threads is pegged to the batch size. If the OMP block is parallelized
|
||||||
|
// on anything other than the batch/block size, it should be reassigned
|
||||||
|
const size_t batch_threads =
|
||||||
|
std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
|
||||||
|
|
||||||
|
const size_t n_index = row_ptr[batch.Size()]; // number of entries in this page
|
||||||
|
ResizeIndex(n_index, isDense_);
|
||||||
|
|
||||||
|
CHECK_GT(cut.Values().size(), 0U);
|
||||||
|
|
||||||
|
if (isDense_) {
|
||||||
|
index.SetBinOffset(cut.Ptrs());
|
||||||
|
}
|
||||||
|
uint32_t const *offsets = index.Offset();
|
||||||
|
|
||||||
|
auto n_bins_total = cut.TotalBins();
|
||||||
|
auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries
|
||||||
|
data::SparsePageAdapterBatch adapter_batch{page};
|
||||||
|
if (isDense_) {
|
||||||
|
// Inside the lambda functions, bin_idx is the index for cut value across all
|
||||||
|
// features. By subtracting it with starting pointer of each feature, we can reduce
|
||||||
|
// it to smaller value and compress it to smaller types.
|
||||||
|
common::DispatchBinType(index.GetBinTypeSize(), [&](auto dtype) {
|
||||||
|
using T = decltype(dtype);
|
||||||
|
common::Span<T> index_data_span = {index.data<T>(), index.Size()};
|
||||||
|
SetIndexData(
|
||||||
|
index_data_span, ft, batch_threads, adapter_batch, is_valid, n_bins_total,
|
||||||
|
[offsets](auto bin_idx, auto fidx) { return static_cast<T>(bin_idx - offsets[fidx]); });
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
/* For sparse DMatrix we have to store index of feature for each bin
|
||||||
|
in index field to chose right offset. So offset is nullptr and index is
|
||||||
|
not reduced */
|
||||||
|
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
|
||||||
|
SetIndexData(index_data_span, ft, batch_threads, adapter_batch, is_valid, n_bins_total,
|
||||||
|
[](auto idx, auto) { return idx; });
|
||||||
|
}
|
||||||
|
|
||||||
|
common::ParallelFor(n_total_bins, n_threads, [&](bst_omp_uint idx) {
|
||||||
|
for (int32_t tid = 0; tid < n_threads; ++tid) {
|
||||||
|
hit_count[idx] += hit_count_tloc_[tid * n_total_bins + idx];
|
||||||
|
hit_count_tloc_[tid * n_total_bins + idx] = 0; // reset for next batch
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType const> ft,
|
void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType const> ft,
|
||||||
common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
|
common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
|
||||||
bool isDense, double sparse_thresh, int32_t n_threads) {
|
bool isDense, double sparse_thresh, int32_t n_threads) {
|
||||||
@ -148,10 +127,7 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType co
|
|||||||
hit_count.resize(nbins, 0);
|
hit_count.resize(nbins, 0);
|
||||||
hit_count_tloc_.resize(n_threads * nbins, 0);
|
hit_count_tloc_.resize(n_threads * nbins, 0);
|
||||||
|
|
||||||
size_t rbegin = 0;
|
this->PushBatch(batch, ft, nbins, n_threads);
|
||||||
size_t prev_sum = 0;
|
|
||||||
|
|
||||||
this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
|
|
||||||
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
||||||
if (!std::isnan(sparse_thresh)) {
|
if (!std::isnan(sparse_thresh)) {
|
||||||
this->columns_->Init(batch, *this, sparse_thresh, n_threads);
|
this->columns_->Init(batch, *this, sparse_thresh, n_threads);
|
||||||
|
|||||||
@ -10,6 +10,7 @@
|
|||||||
#include "../common/categorical.h"
|
#include "../common/categorical.h"
|
||||||
#include "../common/hist_util.h"
|
#include "../common/hist_util.h"
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
|
#include "adapter.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
|
|
||||||
@ -32,8 +33,38 @@ class GHistIndexMatrix {
|
|||||||
* \param rbegin The beginning row index of current page. (total rows in previous pages)
|
* \param rbegin The beginning row index of current page. (total rows in previous pages)
|
||||||
* \param prev_sum Total number of entries in previous pages.
|
* \param prev_sum Total number of entries in previous pages.
|
||||||
*/
|
*/
|
||||||
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, size_t rbegin,
|
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft,
|
||||||
size_t prev_sum, uint32_t nbins, int32_t n_threads);
|
bst_bin_t n_total_bins, int32_t n_threads);
|
||||||
|
|
||||||
|
template <typename Batch, typename BinIdxType, typename GetOffset, typename IsValid>
|
||||||
|
void SetIndexData(common::Span<BinIdxType> index_data_span, common::Span<FeatureType const> ft,
|
||||||
|
size_t batch_threads, Batch const& batch, IsValid&& is_valid, size_t nbins,
|
||||||
|
GetOffset&& get_offset) {
|
||||||
|
auto batch_size = batch.Size();
|
||||||
|
BinIdxType* index_data = index_data_span.data();
|
||||||
|
auto const& ptrs = cut.Ptrs();
|
||||||
|
auto const& values = cut.Values();
|
||||||
|
common::ParallelFor(batch_size, batch_threads, [&](size_t i) {
|
||||||
|
auto line = batch.GetLine(i);
|
||||||
|
size_t ibegin = row_ptr[i]; // index of first entry for current block
|
||||||
|
size_t k = 0;
|
||||||
|
auto tid = omp_get_thread_num();
|
||||||
|
for (size_t j = 0; j < line.Size(); ++j) {
|
||||||
|
data::COOTuple elem = line.GetElement(j);
|
||||||
|
if (is_valid(elem)) {
|
||||||
|
bst_bin_t bin_idx{-1};
|
||||||
|
if (common::IsCat(ft, elem.column_idx)) {
|
||||||
|
bin_idx = cut.SearchCatBin(elem.value, elem.column_idx);
|
||||||
|
} else {
|
||||||
|
bin_idx = cut.SearchBin(elem.value, elem.column_idx, ptrs, values);
|
||||||
|
}
|
||||||
|
index_data[ibegin + k] = get_offset(bin_idx, j);
|
||||||
|
++hit_count_tloc_[tid * nbins + bin_idx];
|
||||||
|
++k;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/*! \brief row pointer to rows by element position */
|
/*! \brief row pointer to rows by element position */
|
||||||
@ -50,52 +81,15 @@ class GHistIndexMatrix {
|
|||||||
size_t base_rowid{0};
|
size_t base_rowid{0};
|
||||||
|
|
||||||
GHistIndexMatrix();
|
GHistIndexMatrix();
|
||||||
GHistIndexMatrix(DMatrix* x, int32_t max_bin, double sparse_thresh, bool sorted_sketch,
|
GHistIndexMatrix(DMatrix* x, bst_bin_t max_bins_per_feat, double sparse_thresh,
|
||||||
int32_t n_threads, common::Span<float> hess = {});
|
bool sorted_sketch, int32_t n_threads, common::Span<float> hess = {});
|
||||||
~GHistIndexMatrix();
|
~GHistIndexMatrix();
|
||||||
|
|
||||||
// Create a global histogram matrix, given cut
|
// Create a global histogram matrix, given cut. Used by external memory
|
||||||
void Init(DMatrix* p_fmat, int max_bins, double sparse_thresh, bool sorted_sketch,
|
|
||||||
int32_t n_threads, common::Span<float> hess);
|
|
||||||
void Init(SparsePage const& page, common::Span<FeatureType const> ft,
|
void Init(SparsePage const& page, common::Span<FeatureType const> ft,
|
||||||
common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
|
common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
|
||||||
double sparse_thresh, int32_t n_threads);
|
double sparse_thresh, int32_t n_threads);
|
||||||
|
|
||||||
// specific method for sparse data as no possibility to reduce allocated memory
|
|
||||||
template <typename BinIdxType, typename GetOffset>
|
|
||||||
void SetIndexData(common::Span<BinIdxType> index_data_span,
|
|
||||||
common::Span<FeatureType const> ft,
|
|
||||||
size_t batch_threads, const SparsePage &batch,
|
|
||||||
size_t rbegin, size_t nbins, GetOffset get_offset) {
|
|
||||||
const xgboost::Entry *data_ptr = batch.data.HostVector().data();
|
|
||||||
const std::vector<bst_row_t> &offset_vec = batch.offset.HostVector();
|
|
||||||
const size_t batch_size = batch.Size();
|
|
||||||
CHECK_LT(batch_size, offset_vec.size());
|
|
||||||
BinIdxType* index_data = index_data_span.data();
|
|
||||||
auto const& ptrs = cut.Ptrs();
|
|
||||||
auto const& values = cut.Values();
|
|
||||||
common::ParallelFor(batch_size, batch_threads, [&](omp_ulong ridx) {
|
|
||||||
const int tid = omp_get_thread_num();
|
|
||||||
size_t ibegin = row_ptr[rbegin + ridx]; // index of first entry for current block
|
|
||||||
size_t iend = row_ptr[rbegin + ridx + 1]; // first entry for next block
|
|
||||||
const size_t size = offset_vec[ridx + 1] - offset_vec[ridx];
|
|
||||||
SparsePage::Inst inst = {data_ptr + offset_vec[ridx], size};
|
|
||||||
CHECK_EQ(ibegin + inst.size(), iend);
|
|
||||||
for (bst_uint j = 0; j < inst.size(); ++j) {
|
|
||||||
auto e = inst[j];
|
|
||||||
if (common::IsCat(ft, e.index)) {
|
|
||||||
bst_bin_t bin_idx = cut.SearchCatBin(e);
|
|
||||||
index_data[ibegin + j] = get_offset(bin_idx, j);
|
|
||||||
++hit_count_tloc_[tid * nbins + bin_idx];
|
|
||||||
} else {
|
|
||||||
bst_bin_t bin_idx = cut.SearchBin(e.fvalue, e.index, ptrs, values);
|
|
||||||
index_data[ibegin + j] = get_offset(bin_idx, j);
|
|
||||||
++hit_count_tloc_[tid * nbins + bin_idx];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
void ResizeIndex(const size_t n_index, const bool isDense);
|
void ResizeIndex(const size_t n_index, const bool isDense);
|
||||||
|
|
||||||
void GetFeatureCounts(size_t* counts) const {
|
void GetFeatureCounts(size_t* counts) const {
|
||||||
|
|||||||
@ -107,22 +107,5 @@ TEST(DenseColumnWithMissing, Test) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestGHistIndexMatrixCreation(size_t nthreads) {
|
|
||||||
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
|
|
||||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
|
||||||
/* This should create multiple sparse pages */
|
|
||||||
std::unique_ptr<DMatrix> dmat{CreateSparsePageDMatrix(kEntries)};
|
|
||||||
GHistIndexMatrix gmat(dmat.get(), 256, 0.5f, false, common::OmpGetNumThreads(nthreads));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(HistIndexCreationWithExternalMemory, Test) {
|
|
||||||
// Vary the number of threads to make sure that the last batch
|
|
||||||
// is distributed properly to the available number of threads
|
|
||||||
// in the thread pool
|
|
||||||
TestGHistIndexMatrixCreation(20);
|
|
||||||
TestGHistIndexMatrixCreation(30);
|
|
||||||
TestGHistIndexMatrixCreation(40);
|
|
||||||
}
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -44,9 +44,7 @@ TEST(GradientIndex, FromCategoricalBasic) {
|
|||||||
h_ft.resize(kCols, FeatureType::kCategorical);
|
h_ft.resize(kCols, FeatureType::kCategorical);
|
||||||
|
|
||||||
BatchParam p(max_bins, 0.8);
|
BatchParam p(max_bins, 0.8);
|
||||||
GHistIndexMatrix gidx;
|
GHistIndexMatrix gidx(m.get(), max_bins, p.sparse_thresh, false, common::OmpGetNumThreads(0), {});
|
||||||
|
|
||||||
gidx.Init(m.get(), max_bins, p.sparse_thresh, false, common::OmpGetNumThreads(0), {});
|
|
||||||
|
|
||||||
auto x_copy = x;
|
auto x_copy = x;
|
||||||
std::sort(x_copy.begin(), x_copy.end());
|
std::sort(x_copy.begin(), x_copy.end());
|
||||||
|
|||||||
@ -413,10 +413,16 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
|
|||||||
|
|
||||||
single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false);
|
single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false);
|
||||||
SparsePage concat;
|
SparsePage concat;
|
||||||
GHistIndexMatrix gmat;
|
|
||||||
std::vector<float> hess(m->Info().num_row_, 1.0f);
|
std::vector<float> hess(m->Info().num_row_, 1.0f);
|
||||||
gmat.Init(m.get(), batch_param.max_bin, std::numeric_limits<double>::quiet_NaN(), false,
|
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||||
common::OmpGetNumThreads(0), hess);
|
concat.Push(page);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, common::OmpGetNumThreads(0),
|
||||||
|
false, hess);
|
||||||
|
GHistIndexMatrix gmat;
|
||||||
|
gmat.Init(concat, {}, cut, batch_param.max_bin, false, std::numeric_limits<double>::quiet_NaN(),
|
||||||
|
common::OmpGetNumThreads(0));
|
||||||
single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair);
|
single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair);
|
||||||
single_page = single_build.Histogram()[0];
|
single_page = single_build.Histogram()[0];
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user