Prepare gradient index for Quantile DMatrix. (#8103)
* Prepare gradient index for Quantile DMatrix. - Implement push batch with adapter batch. - Implement `GetFvalue` for prediction.
This commit is contained in:
parent
1be09848a7
commit
4a4e5c7c18
@ -7,6 +7,7 @@
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility> // std::forward
|
||||
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/hist_util.h"
|
||||
@ -43,7 +44,7 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
||||
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
||||
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
this->PushBatch(batch, ft, nbins, n_threads);
|
||||
this->PushBatch(batch, ft, n_threads);
|
||||
}
|
||||
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
||||
|
||||
@ -57,60 +58,27 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
||||
}
|
||||
}
|
||||
|
||||
GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &&cuts,
|
||||
bst_bin_t max_bin_per_feat)
|
||||
: row_ptr(info.num_row_ + 1, 0),
|
||||
hit_count(cuts.TotalBins(), 0),
|
||||
cut{std::forward<common::HistogramCuts>(cuts)},
|
||||
max_num_bins(max_bin_per_feat),
|
||||
isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
|
||||
|
||||
GHistIndexMatrix::~GHistIndexMatrix() = default;
|
||||
|
||||
void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
|
||||
bst_bin_t n_total_bins, int32_t n_threads) {
|
||||
int32_t n_threads) {
|
||||
auto page = batch.GetView();
|
||||
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
|
||||
common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
|
||||
// The number of threads is pegged to the batch size. If the OMP block is parallelized
|
||||
// on anything other than the batch/block size, it should be reassigned
|
||||
const size_t batch_threads =
|
||||
std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
|
||||
|
||||
const size_t n_index = row_ptr[batch.Size()]; // number of entries in this page
|
||||
ResizeIndex(n_index, isDense_);
|
||||
|
||||
CHECK_GT(cut.Values().size(), 0U);
|
||||
|
||||
if (isDense_) {
|
||||
index.SetBinOffset(cut.Ptrs());
|
||||
}
|
||||
uint32_t const *offsets = index.Offset();
|
||||
|
||||
auto n_bins_total = cut.TotalBins();
|
||||
auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries
|
||||
data::SparsePageAdapterBatch adapter_batch{page};
|
||||
if (isDense_) {
|
||||
// Inside the lambda functions, bin_idx is the index for cut value across all
|
||||
// features. By subtracting it with starting pointer of each feature, we can reduce
|
||||
// it to smaller value and compress it to smaller types.
|
||||
common::DispatchBinType(index.GetBinTypeSize(), [&](auto dtype) {
|
||||
using T = decltype(dtype);
|
||||
common::Span<T> index_data_span = {index.data<T>(), index.Size()};
|
||||
SetIndexData(
|
||||
index_data_span, ft, batch_threads, adapter_batch, is_valid, n_bins_total,
|
||||
[offsets](auto bin_idx, auto fidx) { return static_cast<T>(bin_idx - offsets[fidx]); });
|
||||
});
|
||||
} else {
|
||||
/* For sparse DMatrix we have to store index of feature for each bin
|
||||
in index field to chose right offset. So offset is nullptr and index is
|
||||
not reduced */
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
|
||||
SetIndexData(index_data_span, ft, batch_threads, adapter_batch, is_valid, n_bins_total,
|
||||
[](auto idx, auto) { return idx; });
|
||||
}
|
||||
|
||||
common::ParallelFor(n_total_bins, n_threads, [&](bst_omp_uint idx) {
|
||||
for (int32_t tid = 0; tid < n_threads; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * n_total_bins + idx];
|
||||
hit_count_tloc_[tid * n_total_bins + idx] = 0; // reset for next batch
|
||||
}
|
||||
});
|
||||
auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries
|
||||
PushBatchImpl(n_threads, adapter_batch, 0, is_valid, ft);
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType const> ft,
|
||||
GHistIndexMatrix::GHistIndexMatrix(SparsePage const &batch, common::Span<FeatureType const> ft,
|
||||
common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
|
||||
bool isDense, double sparse_thresh, int32_t n_threads) {
|
||||
CHECK_GE(n_threads, 1);
|
||||
@ -127,13 +95,30 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType co
|
||||
hit_count.resize(nbins, 0);
|
||||
hit_count_tloc_.resize(n_threads * nbins, 0);
|
||||
|
||||
this->PushBatch(batch, ft, nbins, n_threads);
|
||||
this->PushBatch(batch, ft, n_threads);
|
||||
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
||||
if (!std::isnan(sparse_thresh)) {
|
||||
this->columns_->Init(batch, *this, sparse_thresh, n_threads);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Batch>
|
||||
void GHistIndexMatrix::PushAdapterBatchColumns(Context const *ctx, Batch const &batch,
|
||||
float missing, size_t rbegin) {
|
||||
CHECK(columns_);
|
||||
this->columns_->PushBatch(ctx->Threads(), batch, missing, *this, rbegin);
|
||||
}
|
||||
|
||||
#define INSTANTIATION_PUSH(BatchT) \
|
||||
template void GHistIndexMatrix::PushAdapterBatchColumns<BatchT>( \
|
||||
Context const *ctx, BatchT const &batch, float missing, size_t rbegin);
|
||||
|
||||
INSTANTIATION_PUSH(data::CSRArrayAdapterBatch)
|
||||
INSTANTIATION_PUSH(data::ArrayAdapterBatch)
|
||||
INSTANTIATION_PUSH(data::SparsePageAdapterBatch)
|
||||
|
||||
#undef INSTANTIATION_PUSH
|
||||
|
||||
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
|
||||
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
|
||||
// compress dense index to uint8
|
||||
@ -156,6 +141,57 @@ common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
|
||||
return *columns_;
|
||||
}
|
||||
|
||||
float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
|
||||
auto const &values = cut.Values();
|
||||
auto const &mins = cut.MinValues();
|
||||
auto const &ptrs = cut.Ptrs();
|
||||
if (is_cat) {
|
||||
auto f_begin = ptrs[fidx];
|
||||
auto f_end = ptrs[fidx + 1];
|
||||
auto begin = RowIdx(ridx);
|
||||
auto end = RowIdx(ridx + 1);
|
||||
auto gidx = BinarySearchBin(begin, end, index, f_begin, f_end);
|
||||
if (gidx == -1) {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
return values[gidx];
|
||||
}
|
||||
|
||||
auto lower = static_cast<bst_bin_t>(cut.Ptrs()[fidx]);
|
||||
auto get_bin_idx = [&](auto &column) {
|
||||
auto bin_idx = column[ridx];
|
||||
if (bin_idx == common::DenseColumnIter<uint8_t, true>::kMissingId) {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
if (bin_idx == lower) {
|
||||
return mins[fidx];
|
||||
}
|
||||
return values[bin_idx - 1];
|
||||
};
|
||||
|
||||
if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
|
||||
if (columns_->AnyMissing()) {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
|
||||
return get_bin_idx(column);
|
||||
});
|
||||
} else {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
|
||||
return get_bin_idx(column);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
|
||||
return get_bin_idx(column);
|
||||
});
|
||||
}
|
||||
|
||||
SPAN_CHECK(false);
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
|
||||
bool GHistIndexMatrix::ReadColumnPage(dmlc::SeekStream *fi) {
|
||||
return this->columns_->Read(fi, this->cut.Ptrs().data());
|
||||
}
|
||||
|
||||
@ -4,13 +4,17 @@
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
#define XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
|
||||
#include <algorithm> // std::min
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "adapter.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
@ -18,7 +22,6 @@ namespace xgboost {
|
||||
namespace common {
|
||||
class ColumnMatrix;
|
||||
} // namespace common
|
||||
|
||||
/*!
|
||||
* \brief preprocessed global index matrix, in CSR format
|
||||
*
|
||||
@ -26,24 +29,39 @@ class ColumnMatrix;
|
||||
* index for CPU histogram. On GPU ellpack page is used.
|
||||
*/
|
||||
class GHistIndexMatrix {
|
||||
// Get the size of each row
|
||||
template <typename AdapterBatchT>
|
||||
auto GetRowCounts(AdapterBatchT const& batch, float missing, int32_t n_threads) {
|
||||
std::vector<size_t> valid_counts(batch.Size(), 0);
|
||||
common::ParallelFor(batch.Size(), n_threads, [&](size_t i) {
|
||||
auto line = batch.GetLine(i);
|
||||
for (size_t j = 0; j < line.Size(); ++j) {
|
||||
data::COOTuple elem = line.GetElement(j);
|
||||
if (data::IsValidFunctor {missing}(elem)) {
|
||||
valid_counts[i]++;
|
||||
}
|
||||
}
|
||||
});
|
||||
return valid_counts;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Push a page into index matrix, the function is only necessary because hist has
|
||||
* partial support for external memory.
|
||||
*/
|
||||
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft,
|
||||
bst_bin_t n_total_bins, int32_t n_threads);
|
||||
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, int32_t n_threads);
|
||||
|
||||
template <typename Batch, typename BinIdxType, typename GetOffset, typename IsValid>
|
||||
void SetIndexData(common::Span<BinIdxType> index_data_span, common::Span<FeatureType const> ft,
|
||||
size_t batch_threads, Batch const& batch, IsValid&& is_valid, size_t nbins,
|
||||
GetOffset&& get_offset) {
|
||||
void SetIndexData(common::Span<BinIdxType> index_data_span, size_t rbegin,
|
||||
common::Span<FeatureType const> ft, size_t batch_threads, Batch const& batch,
|
||||
IsValid&& is_valid, size_t nbins, GetOffset&& get_offset) {
|
||||
auto batch_size = batch.Size();
|
||||
BinIdxType* index_data = index_data_span.data();
|
||||
auto const& ptrs = cut.Ptrs();
|
||||
auto const& values = cut.Values();
|
||||
common::ParallelFor(batch_size, batch_threads, [&](size_t i) {
|
||||
auto line = batch.GetLine(i);
|
||||
size_t ibegin = row_ptr[i]; // index of first entry for current block
|
||||
size_t ibegin = row_ptr[rbegin + i]; // index of first entry for current block
|
||||
size_t k = 0;
|
||||
auto tid = omp_get_thread_num();
|
||||
for (size_t j = 0; j < line.Size(); ++j) {
|
||||
@ -63,6 +81,49 @@ class GHistIndexMatrix {
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Batch, typename IsValid>
|
||||
void PushBatchImpl(int32_t n_threads, Batch const& batch, size_t rbegin, IsValid&& is_valid,
|
||||
common::Span<FeatureType const> ft) {
|
||||
// The number of threads is pegged to the batch size. If the OMP block is parallelized
|
||||
// on anything other than the batch/block size, it should be reassigned
|
||||
size_t batch_threads =
|
||||
std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
|
||||
|
||||
auto n_bins_total = cut.TotalBins();
|
||||
const size_t n_index = row_ptr[rbegin + batch.Size()]; // number of entries in this page
|
||||
ResizeIndex(n_index, isDense_);
|
||||
if (isDense_) {
|
||||
index.SetBinOffset(cut.Ptrs());
|
||||
}
|
||||
uint32_t const* offsets = index.Offset();
|
||||
if (isDense_) {
|
||||
// Inside the lambda functions, bin_idx is the index for cut value across all
|
||||
// features. By subtracting it with starting pointer of each feature, we can reduce
|
||||
// it to smaller value and compress it to smaller types.
|
||||
common::DispatchBinType(index.GetBinTypeSize(), [&](auto dtype) {
|
||||
using T = decltype(dtype);
|
||||
common::Span<T> index_data_span = {index.data<T>(), index.Size()};
|
||||
SetIndexData(
|
||||
index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
|
||||
[offsets](auto bin_idx, auto fidx) { return static_cast<T>(bin_idx - offsets[fidx]); });
|
||||
});
|
||||
} else {
|
||||
/* For sparse DMatrix we have to store index of feature for each bin
|
||||
in index field to chose right offset. So offset is nullptr and index is
|
||||
not reduced */
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
|
||||
SetIndexData(index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
|
||||
[](auto idx, auto) { return idx; });
|
||||
}
|
||||
|
||||
common::ParallelFor(n_bins_total, n_threads, [&](bst_omp_uint idx) {
|
||||
for (int32_t tid = 0; tid < n_threads; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * n_bins_total + idx];
|
||||
hit_count_tloc_[tid * n_bins_total + idx] = 0; // reset for next batch
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
/*! \brief row pointer to rows by element position */
|
||||
std::vector<size_t> row_ptr;
|
||||
@ -77,15 +138,53 @@ class GHistIndexMatrix {
|
||||
/*! \brief base row index for current page (used by external memory) */
|
||||
size_t base_rowid{0};
|
||||
|
||||
GHistIndexMatrix();
|
||||
~GHistIndexMatrix();
|
||||
/**
|
||||
* \brief Constrcutor for SimpleDMatrix.
|
||||
*/
|
||||
GHistIndexMatrix(DMatrix* x, bst_bin_t max_bins_per_feat, double sparse_thresh,
|
||||
bool sorted_sketch, int32_t n_threads, common::Span<float> hess = {});
|
||||
~GHistIndexMatrix();
|
||||
|
||||
// Create a global histogram matrix, given cut. Used by external memory
|
||||
void Init(SparsePage const& page, common::Span<FeatureType const> ft,
|
||||
/**
|
||||
* \brief Constructor for Iterative DMatrix. Initialize basic information and prepare
|
||||
* for push batch.
|
||||
*/
|
||||
GHistIndexMatrix(MetaInfo const& info, common::HistogramCuts&& cuts, bst_bin_t max_bin_per_feat);
|
||||
/**
|
||||
* \brief Constructor for external memory.
|
||||
*/
|
||||
GHistIndexMatrix(SparsePage const& page, common::Span<FeatureType const> ft,
|
||||
common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
|
||||
double sparse_thresh, int32_t n_threads);
|
||||
GHistIndexMatrix(); // also for ext mem, empty ctor so that we can read the cache back.
|
||||
|
||||
template <typename Batch>
|
||||
void PushAdapterBatch(Context const* ctx, size_t rbegin, size_t prev_sum, Batch const& batch,
|
||||
float missing, common::Span<FeatureType const> ft, double sparse_thresh,
|
||||
size_t n_samples_total) {
|
||||
auto n_bins_total = cut.TotalBins();
|
||||
hit_count_tloc_.clear();
|
||||
hit_count_tloc_.resize(ctx->Threads() * n_bins_total, 0);
|
||||
|
||||
auto n_threads = ctx->Threads();
|
||||
auto valid_counts = GetRowCounts(batch, missing, n_threads);
|
||||
|
||||
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return valid_counts[ridx]; });
|
||||
common::PartialSum(n_threads, it, it + batch.Size(), prev_sum, row_ptr.begin() + rbegin);
|
||||
auto is_valid = data::IsValidFunctor{missing};
|
||||
|
||||
PushBatchImpl(ctx->Threads(), batch, rbegin, is_valid, ft);
|
||||
|
||||
if (rbegin + batch.Size() == n_samples_total) {
|
||||
// finished
|
||||
CHECK(!std::isnan(sparse_thresh));
|
||||
this->columns_ = std::make_unique<common::ColumnMatrix>(*this, sparse_thresh);
|
||||
}
|
||||
}
|
||||
|
||||
// Call ColumnMatrix::PushBatch
|
||||
template <typename Batch>
|
||||
void PushAdapterBatchColumns(Context const* ctx, Batch const& batch, float missing,
|
||||
size_t rbegin);
|
||||
|
||||
void ResizeIndex(const size_t n_index, const bool isDense);
|
||||
|
||||
@ -117,6 +216,8 @@ class GHistIndexMatrix {
|
||||
|
||||
common::ColumnMatrix const& Transpose() const;
|
||||
|
||||
float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<common::ColumnMatrix> columns_;
|
||||
std::vector<size_t> hit_count_tloc_;
|
||||
|
||||
@ -15,10 +15,9 @@ void GradientIndexPageSource::Fetch() {
|
||||
// This is not read from cache so we still need it to be synced with sparse page source.
|
||||
CHECK_EQ(count_, source_->Iter());
|
||||
auto const& csr = source_->Page();
|
||||
this->page_.reset(new GHistIndexMatrix());
|
||||
CHECK_NE(cuts_.Values().size(), 0);
|
||||
this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_, sparse_thresh_,
|
||||
nthreads_);
|
||||
this->page_.reset(new GHistIndexMatrix(*csr, feature_types_, cuts_, max_bin_per_feat_,
|
||||
is_dense_, sparse_thresh_, nthreads_));
|
||||
this->WriteCache();
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include "../../../src/common/column_matrix.h"
|
||||
#include "../../../src/data/gradient_index.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
@ -65,5 +66,46 @@ TEST(GradientIndex, FromCategoricalBasic) {
|
||||
ASSERT_EQ(common::AsCat(x[i]), common::AsCat(bin_value));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientIndex, PushBatch) {
|
||||
size_t constexpr kRows = 64, kCols = 4;
|
||||
bst_bin_t max_bins = 64;
|
||||
float st = 0.5;
|
||||
|
||||
auto test = [&](float sparisty) {
|
||||
auto m = RandomDataGenerator{kRows, kCols, sparisty}.GenerateDMatrix(true);
|
||||
auto cuts = common::SketchOnDMatrix(m.get(), max_bins, common::OmpGetNumThreads(0), false, {});
|
||||
common::HistogramCuts copy_cuts = cuts;
|
||||
|
||||
ASSERT_EQ(m->Info().num_row_, kRows);
|
||||
ASSERT_EQ(m->Info().num_col_, kCols);
|
||||
GHistIndexMatrix gmat{m->Info(), std::move(copy_cuts), max_bins};
|
||||
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
SparsePageAdapterBatch batch{page.GetView()};
|
||||
gmat.PushAdapterBatch(m->Ctx(), 0, 0, batch, std::numeric_limits<float>::quiet_NaN(), {}, st,
|
||||
m->Info().num_row_);
|
||||
gmat.PushAdapterBatchColumns(m->Ctx(), batch, std::numeric_limits<float>::quiet_NaN(), 0);
|
||||
}
|
||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(BatchParam{max_bins, st})) {
|
||||
for (size_t i = 0; i < kRows; ++i) {
|
||||
for (size_t j = 0; j < kCols; ++j) {
|
||||
auto v0 = gmat.GetFvalue(i, j, false);
|
||||
auto v1 = page.GetFvalue(i, j, false);
|
||||
if (sparisty == 0.0) {
|
||||
ASSERT_FALSE(std::isnan(v0));
|
||||
}
|
||||
if (!std::isnan(v0)) {
|
||||
ASSERT_EQ(v0, v1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
test(0.0f);
|
||||
test(0.5f);
|
||||
test(0.9f);
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@ -66,6 +66,14 @@ void TestTrainingPrediction(size_t rows, size_t bins,
|
||||
learner->UpdateOneIter(i, p_hist);
|
||||
}
|
||||
|
||||
Json model{Object{}};
|
||||
learner->SaveModel(&model);
|
||||
|
||||
learner.reset(Learner::Create({}));
|
||||
learner->LoadModel(model);
|
||||
learner->SetParam("predictor", predictor);
|
||||
learner->Configure();
|
||||
|
||||
HostDeviceVector<float> from_full;
|
||||
learner->Predict(p_full, false, &from_full, 0, 0);
|
||||
|
||||
|
||||
@ -419,9 +419,8 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
|
||||
|
||||
auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, common::OmpGetNumThreads(0),
|
||||
false, hess);
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init(concat, {}, cut, batch_param.max_bin, false, std::numeric_limits<double>::quiet_NaN(),
|
||||
common::OmpGetNumThreads(0));
|
||||
GHistIndexMatrix gmat(concat, {}, cut, batch_param.max_bin, false,
|
||||
std::numeric_limits<double>::quiet_NaN(), common::OmpGetNumThreads(0));
|
||||
single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair);
|
||||
single_page = single_build.Histogram()[0];
|
||||
}
|
||||
|
||||
@ -34,8 +34,7 @@ TEST(QuantileHist, Partitioner) {
|
||||
auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads());
|
||||
|
||||
for (auto const& page : Xy->GetBatches<SparsePage>()) {
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init(page, {}, cuts, 64, true, 0.5, ctx.Threads());
|
||||
GHistIndexMatrix gmat(page, {}, cuts, 64, true, 0.5, ctx.Threads());
|
||||
bst_feature_t const split_ind = 0;
|
||||
common::ColumnMatrix column_indices;
|
||||
column_indices.Init(page, gmat, 0.5, ctx.Threads());
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user