Initial support for external memory in gradient index. (#7183)
* Add hessian to batch param in preparation of new approx impl. * Extract a push method for gradient index matrix. * Use span instead of vector ref for hessian in sketching. * Create a binary format for gradient index.
This commit is contained in:
parent
a0dcf6f5c1
commit
3515931305
@ -38,6 +38,8 @@
|
||||
#include "../src/data/sparse_page_raw_format.cc"
|
||||
#include "../src/data/ellpack_page.cc"
|
||||
#include "../src/data/gradient_index.cc"
|
||||
#include "../src/data/gradient_index_page_source.cc"
|
||||
#include "../src/data/gradient_index_format.cc"
|
||||
#include "../src/data/sparse_page_dmatrix.cc"
|
||||
#include "../src/data/proxy_dmatrix.cc"
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2015 by Contributors
|
||||
* Copyright (c) 2015-2021 by Contributors
|
||||
* \file data.h
|
||||
* \brief The input data structure of xgboost.
|
||||
* \author Tianqi Chen
|
||||
@ -214,12 +214,27 @@ struct BatchParam {
|
||||
int gpu_id;
|
||||
/*! \brief Maximum number of bins per feature for histograms. */
|
||||
int max_bin{0};
|
||||
/*! \brief Hessian, used for sketching with future approx implementation. */
|
||||
common::Span<float> hess;
|
||||
/*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
|
||||
bool regen {false};
|
||||
|
||||
BatchParam() = default;
|
||||
BatchParam(int32_t device, int32_t max_bin)
|
||||
: gpu_id{device}, max_bin{max_bin} {}
|
||||
/**
|
||||
* \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
|
||||
* the span is changed, so caller should keep the span for each iteration.
|
||||
*/
|
||||
BatchParam(int32_t device, int32_t max_bin, common::Span<float> hessian,
|
||||
bool regenerate = false)
|
||||
: gpu_id{device}, max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
|
||||
|
||||
bool operator!=(const BatchParam& other) const {
|
||||
return gpu_id != other.gpu_id || max_bin != other.max_bin;
|
||||
if (hess.empty() && other.hess.empty()) {
|
||||
return gpu_id != other.gpu_id || max_bin != other.max_bin;
|
||||
}
|
||||
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -111,7 +111,7 @@ class HistogramCuts {
|
||||
};
|
||||
|
||||
inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
|
||||
std::vector<float> const &hessian = {}) {
|
||||
Span<float> const hessian = {}) {
|
||||
HistogramCuts out;
|
||||
auto const& info = m->Info();
|
||||
const auto threads = omp_get_max_threads();
|
||||
@ -136,7 +136,7 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
|
||||
return out;
|
||||
}
|
||||
|
||||
enum BinTypeSize {
|
||||
enum BinTypeSize : uint32_t {
|
||||
kUint8BinsTypeSize = 1,
|
||||
kUint16BinsTypeSize = 2,
|
||||
kUint32BinsTypeSize = 4
|
||||
@ -207,6 +207,13 @@ struct Index {
|
||||
return data_.end();
|
||||
}
|
||||
|
||||
std::vector<uint8_t>::iterator begin() { // NOLINT
|
||||
return data_.begin();
|
||||
}
|
||||
std::vector<uint8_t>::iterator end() { // NOLINT
|
||||
return data_.end();
|
||||
}
|
||||
|
||||
private:
|
||||
static uint32_t GetValueFromUint8(void *t, size_t i) {
|
||||
return reinterpret_cast<uint8_t*>(t)[i];
|
||||
|
||||
@ -94,26 +94,26 @@ std::vector<bst_feature_t> HostSketchContainer::LoadBalance(
|
||||
namespace {
|
||||
// Function to merge hessian and sample weights
|
||||
std::vector<float> MergeWeights(MetaInfo const &info,
|
||||
std::vector<float> const &hessian,
|
||||
Span<float> const hessian,
|
||||
bool use_group, int32_t n_threads) {
|
||||
CHECK_EQ(hessian.size(), info.num_row_);
|
||||
std::vector<float> results(hessian.size());
|
||||
auto const &group_ptr = info.group_ptr_;
|
||||
auto const& weights = info.weights_.HostVector();
|
||||
auto get_weight = [&](size_t i) { return weights.empty() ? 1.0f : weights[i]; };
|
||||
if (use_group) {
|
||||
auto const &group_weights = info.weights_.HostVector();
|
||||
CHECK_GE(group_ptr.size(), 2);
|
||||
CHECK_EQ(group_ptr.back(), hessian.size());
|
||||
size_t cur_group = 0;
|
||||
for (size_t i = 0; i < hessian.size(); ++i) {
|
||||
results[i] = hessian[i] * group_weights[cur_group];
|
||||
results[i] = hessian[i] * get_weight(cur_group);
|
||||
if (i == group_ptr[cur_group + 1]) {
|
||||
cur_group++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto const &sample_weights = info.weights_.HostVector();
|
||||
ParallelFor(hessian.size(), n_threads, Sched::Auto(),
|
||||
[&](auto i) { results[i] = hessian[i] * sample_weights[i]; });
|
||||
[&](auto i) { results[i] = hessian[i] * get_weight(i); });
|
||||
}
|
||||
return results;
|
||||
}
|
||||
@ -141,7 +141,7 @@ std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
||||
} // anonymous namespace
|
||||
|
||||
void HostSketchContainer::PushRowPage(
|
||||
SparsePage const &page, MetaInfo const &info, std::vector<float> const &hessian) {
|
||||
SparsePage const &page, MetaInfo const &info, Span<float> hessian) {
|
||||
monitor_.Start(__func__);
|
||||
bst_feature_t n_columns = info.num_col_;
|
||||
auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
|
||||
|
||||
@ -760,7 +760,7 @@ class HostSketchContainer {
|
||||
|
||||
/* \brief Push a CSR matrix. */
|
||||
void PushRowPage(SparsePage const &page, MetaInfo const &info,
|
||||
std::vector<float> const &hessian = {});
|
||||
Span<float> const hessian = {});
|
||||
|
||||
void MakeCuts(HistogramCuts* cuts);
|
||||
};
|
||||
|
||||
@ -32,6 +32,7 @@ DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::CSCPage>);
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SortedCSCPage>);
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::EllpackPage>);
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::GHistIndexMatrix>);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace {
|
||||
@ -1089,5 +1090,6 @@ namespace data {
|
||||
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(sparse_page_raw_format);
|
||||
DMLC_REGISTRY_LINK_TAG(gradient_index_format);
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@ -4,8 +4,9 @@
|
||||
#include <xgboost/data.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include "./ellpack_page.cuh"
|
||||
#include "./sparse_page_writer.h"
|
||||
#include "ellpack_page.cuh"
|
||||
#include "sparse_page_writer.h"
|
||||
#include "histogram_cut_format.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@ -17,9 +18,9 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
|
||||
public:
|
||||
bool Read(EllpackPage* page, dmlc::SeekStream* fi) override {
|
||||
auto* impl = page->Impl();
|
||||
fi->Read(&impl->Cuts().cut_values_.HostVector());
|
||||
fi->Read(&impl->Cuts().cut_ptrs_.HostVector());
|
||||
fi->Read(&impl->Cuts().min_vals_.HostVector());
|
||||
if (!ReadHistogramCuts(&impl->Cuts(), fi)) {
|
||||
return false;
|
||||
}
|
||||
fi->Read(&impl->n_rows);
|
||||
fi->Read(&impl->is_dense);
|
||||
fi->Read(&impl->row_stride);
|
||||
@ -33,12 +34,7 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
|
||||
size_t Write(const EllpackPage& page, dmlc::Stream* fo) override {
|
||||
size_t bytes = 0;
|
||||
auto* impl = page.Impl();
|
||||
fo->Write(impl->Cuts().cut_values_.ConstHostVector());
|
||||
bytes += impl->Cuts().cut_values_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
|
||||
fo->Write(impl->Cuts().cut_ptrs_.ConstHostVector());
|
||||
bytes += impl->Cuts().cut_ptrs_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
|
||||
fo->Write(impl->Cuts().min_vals_.ConstHostVector());
|
||||
bytes += impl->Cuts().min_vals_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
|
||||
bytes += WriteHistogramCuts(impl->Cuts(), fo);
|
||||
fo->Write(impl->n_rows);
|
||||
bytes += sizeof(impl->n_rows);
|
||||
fo->Write(impl->is_dense);
|
||||
|
||||
@ -32,7 +32,7 @@ class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
|
||||
size_t row_stride, common::Span<FeatureType const> feature_types,
|
||||
std::shared_ptr<SparsePageSource> source)
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
|
||||
is_dense_{is_dense}, row_stride_{row_stride}, param_{param},
|
||||
is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)},
|
||||
feature_types_{feature_types}, cuts_{std::move(cuts)} {
|
||||
this->source_ = source;
|
||||
this->Fetch();
|
||||
|
||||
@ -8,8 +8,125 @@
|
||||
#include "../common/hist_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) {
|
||||
cut = common::SketchOnDMatrix(p_fmat, max_bins);
|
||||
|
||||
void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin,
|
||||
size_t prev_sum, uint32_t nbins,
|
||||
int32_t n_threads) {
|
||||
// The number of threads is pegged to the batch size. If the OMP
|
||||
// block is parallelized on anything other than the batch/block size,
|
||||
// it should be reassigned
|
||||
const size_t batch_threads =
|
||||
std::max(size_t(1), std::min(batch.Size(),
|
||||
static_cast<size_t>(n_threads)));
|
||||
auto page = batch.GetView();
|
||||
common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
|
||||
size_t *p_part = partial_sums.Get();
|
||||
|
||||
size_t block_size = batch.Size() / batch_threads;
|
||||
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel num_threads(batch_threads)
|
||||
{
|
||||
#pragma omp for
|
||||
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
|
||||
exc.Run([&]() {
|
||||
size_t ibegin = block_size * tid;
|
||||
size_t iend = (tid == (batch_threads - 1) ? batch.Size()
|
||||
: (block_size * (tid + 1)));
|
||||
|
||||
size_t sum = 0;
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
sum += page[i].size();
|
||||
row_ptr[rbegin + 1 + i] = sum;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#pragma omp single
|
||||
{
|
||||
exc.Run([&]() {
|
||||
p_part[0] = prev_sum;
|
||||
for (size_t i = 1; i < batch_threads; ++i) {
|
||||
p_part[i] = p_part[i - 1] + row_ptr[rbegin + i * block_size];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#pragma omp for
|
||||
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
|
||||
exc.Run([&]() {
|
||||
size_t ibegin = block_size * tid;
|
||||
size_t iend = (tid == (batch_threads - 1) ? batch.Size()
|
||||
: (block_size * (tid + 1)));
|
||||
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
row_ptr[rbegin + 1 + i] += p_part[tid];
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
exc.Rethrow();
|
||||
|
||||
const size_t n_offsets = cut.Ptrs().size() - 1;
|
||||
const size_t n_index = row_ptr[rbegin + batch.Size()];
|
||||
ResizeIndex(n_index, isDense_);
|
||||
|
||||
CHECK_GT(cut.Values().size(), 0U);
|
||||
|
||||
uint32_t *offsets = nullptr;
|
||||
if (isDense_) {
|
||||
index.ResizeOffset(n_offsets);
|
||||
offsets = index.Offset();
|
||||
for (size_t i = 0; i < n_offsets; ++i) {
|
||||
offsets[i] = cut.Ptrs()[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (isDense_) {
|
||||
common::BinTypeSize curent_bin_size = index.GetBinTypeSize();
|
||||
if (curent_bin_size == common::kUint8BinsTypeSize) {
|
||||
common::Span<uint8_t> index_data_span = {index.data<uint8_t>(), n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint8_t>(idx - offsets[j]);
|
||||
});
|
||||
|
||||
} else if (curent_bin_size == common::kUint16BinsTypeSize) {
|
||||
common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint16_t>(idx - offsets[j]);
|
||||
});
|
||||
} else {
|
||||
CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize);
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint32_t>(idx - offsets[j]);
|
||||
});
|
||||
}
|
||||
|
||||
/* For sparse DMatrix we have to store index of feature for each bin
|
||||
in index field to chose right offset. So offset is nullptr and index is
|
||||
not reduced */
|
||||
} else {
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[](auto idx, auto) { return idx; });
|
||||
}
|
||||
|
||||
common::ParallelFor(bst_omp_uint(nbins), n_threads, [&](bst_omp_uint idx) {
|
||||
for (int32_t tid = 0; tid < n_threads; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
|
||||
hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins, common::Span<float> hess) {
|
||||
cut = common::SketchOnDMatrix(p_fmat, max_bins, hess);
|
||||
|
||||
max_num_bins = max_bins;
|
||||
const int32_t nthread = omp_get_max_threads();
|
||||
@ -32,121 +149,35 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) {
|
||||
this->isDense_ = isDense;
|
||||
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// The number of threads is pegged to the batch size. If the OMP
|
||||
// block is parallelized on anything other than the batch/block size,
|
||||
// it should be reassigned
|
||||
const size_t batch_threads = std::max(
|
||||
size_t(1),
|
||||
std::min(batch.Size(), static_cast<size_t>(omp_get_max_threads())));
|
||||
auto page = batch.GetView();
|
||||
common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
|
||||
size_t* p_part = partial_sums.Get();
|
||||
|
||||
size_t block_size = batch.Size() / batch_threads;
|
||||
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel num_threads(batch_threads)
|
||||
{
|
||||
#pragma omp for
|
||||
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
|
||||
exc.Run([&]() {
|
||||
size_t ibegin = block_size * tid;
|
||||
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
|
||||
|
||||
size_t sum = 0;
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
sum += page[i].size();
|
||||
row_ptr[rbegin + 1 + i] = sum;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#pragma omp single
|
||||
{
|
||||
exc.Run([&]() {
|
||||
p_part[0] = prev_sum;
|
||||
for (size_t i = 1; i < batch_threads; ++i) {
|
||||
p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#pragma omp for
|
||||
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
|
||||
exc.Run([&]() {
|
||||
size_t ibegin = block_size * tid;
|
||||
size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
|
||||
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
row_ptr[rbegin + 1 + i] += p_part[tid];
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
exc.Rethrow();
|
||||
|
||||
const size_t n_offsets = cut.Ptrs().size() - 1;
|
||||
const size_t n_index = row_ptr[rbegin + batch.Size()];
|
||||
ResizeIndex(n_index, isDense);
|
||||
|
||||
CHECK_GT(cut.Values().size(), 0U);
|
||||
|
||||
uint32_t* offsets = nullptr;
|
||||
if (isDense) {
|
||||
index.ResizeOffset(n_offsets);
|
||||
offsets = index.Offset();
|
||||
for (size_t i = 0; i < n_offsets; ++i) {
|
||||
offsets[i] = cut.Ptrs()[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (isDense) {
|
||||
common::BinTypeSize curent_bin_size = index.GetBinTypeSize();
|
||||
if (curent_bin_size == common::kUint8BinsTypeSize) {
|
||||
common::Span<uint8_t> index_data_span = {index.data<uint8_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint8_t>(idx - offsets[j]);
|
||||
});
|
||||
|
||||
} else if (curent_bin_size == common::kUint16BinsTypeSize) {
|
||||
common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint16_t>(idx - offsets[j]);
|
||||
});
|
||||
} else {
|
||||
CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize);
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
|
||||
n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[offsets](auto idx, auto j) {
|
||||
return static_cast<uint32_t>(idx - offsets[j]);
|
||||
});
|
||||
}
|
||||
|
||||
/* For sparse DMatrix we have to store index of feature for each bin
|
||||
in index field to chose right offset. So offset is nullptr and index is not reduced */
|
||||
} else {
|
||||
common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
|
||||
SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
|
||||
[](auto idx, auto) { return idx; });
|
||||
}
|
||||
|
||||
common::ParallelFor(bst_omp_uint(nbins), nthread, [&](bst_omp_uint idx) {
|
||||
for (int32_t tid = 0; tid < nthread; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
|
||||
hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch
|
||||
}
|
||||
});
|
||||
|
||||
this->PushBatch(batch, rbegin, prev_sum, nbins, nthread);
|
||||
prev_sum = row_ptr[rbegin + batch.Size()];
|
||||
rbegin += batch.Size();
|
||||
}
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::Init(SparsePage const &batch,
|
||||
common::HistogramCuts const &cuts,
|
||||
int32_t max_bins_per_feat, bool isDense,
|
||||
int32_t n_threads) {
|
||||
CHECK_GE(n_threads, 1);
|
||||
base_rowid = batch.base_rowid;
|
||||
isDense_ = isDense;
|
||||
cut = cuts;
|
||||
max_num_bins = max_bins_per_feat;
|
||||
CHECK_EQ(row_ptr.size(), 0);
|
||||
// The number of threads is pegged to the batch size. If the OMP
|
||||
// block is parallelized on anything other than the batch/block size,
|
||||
// it should be reassigned
|
||||
row_ptr.resize(batch.Size() + 1, 0);
|
||||
const uint32_t nbins = cut.Ptrs().back();
|
||||
hit_count.resize(nbins, 0);
|
||||
hit_count_tloc_.resize(n_threads * nbins, 0);
|
||||
|
||||
size_t rbegin = 0;
|
||||
size_t prev_sum = 0;
|
||||
|
||||
this->PushBatch(batch, rbegin, prev_sum, nbins, n_threads);
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::ResizeIndex(const size_t n_index,
|
||||
const bool isDense) {
|
||||
|
||||
@ -18,6 +18,9 @@ namespace xgboost {
|
||||
* index for CPU histogram. On GPU ellpack page is used.
|
||||
*/
|
||||
class GHistIndexMatrix {
|
||||
void PushBatch(SparsePage const &batch, size_t rbegin, size_t prev_sum,
|
||||
uint32_t nbins, int32_t n_threads);
|
||||
|
||||
public:
|
||||
/*! \brief row pointer to rows by element position */
|
||||
std::vector<size_t> row_ptr;
|
||||
@ -29,12 +32,16 @@ class GHistIndexMatrix {
|
||||
common::HistogramCuts cut;
|
||||
DMatrix* p_fmat;
|
||||
size_t max_num_bins;
|
||||
size_t base_rowid{0};
|
||||
|
||||
GHistIndexMatrix(DMatrix* x, int32_t max_bin) {
|
||||
this->Init(x, max_bin);
|
||||
GHistIndexMatrix() = default;
|
||||
GHistIndexMatrix(DMatrix* x, int32_t max_bin, common::Span<float> hess = {}) {
|
||||
this->Init(x, max_bin, hess);
|
||||
}
|
||||
// Create a global histogram matrix, given cut
|
||||
void Init(DMatrix* p_fmat, int max_num_bins);
|
||||
void Init(DMatrix* p_fmat, int max_num_bins, common::Span<float> hess);
|
||||
void Init(SparsePage const &page, common::HistogramCuts const &cuts,
|
||||
int32_t max_bins_per_feat, bool is_dense, int32_t n_threads);
|
||||
|
||||
// specific method for sparse data as no possibility to reduce allocated memory
|
||||
template <typename BinIdxType, typename GetOffset>
|
||||
@ -77,6 +84,11 @@ class GHistIndexMatrix {
|
||||
inline bool IsDense() const {
|
||||
return isDense_;
|
||||
}
|
||||
void SetDense(bool is_dense) { isDense_ = is_dense; }
|
||||
|
||||
bst_row_t Size() const {
|
||||
return row_ptr.empty() ? 0 : row_ptr.size() - 1;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> hit_count_tloc_;
|
||||
|
||||
107
src/data/gradient_index_format.cc
Normal file
107
src/data/gradient_index_format.cc
Normal file
@ -0,0 +1,107 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
*/
|
||||
#include "sparse_page_writer.h"
|
||||
#include "gradient_index.h"
|
||||
#include "histogram_cut_format.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
|
||||
public:
|
||||
bool Read(GHistIndexMatrix* page, dmlc::SeekStream* fi) override {
|
||||
if (!ReadHistogramCuts(&page->cut, fi)) {
|
||||
return false;
|
||||
}
|
||||
// indptr
|
||||
fi->Read(&page->row_ptr);
|
||||
// offset
|
||||
using OffsetT = std::iterator_traits<decltype(page->index.Offset())>::value_type;
|
||||
std::vector<OffsetT> offset;
|
||||
if (!fi->Read(&offset)) {
|
||||
return false;
|
||||
}
|
||||
page->index.ResizeOffset(offset.size());
|
||||
std::copy(offset.begin(), offset.end(), page->index.Offset());
|
||||
// data
|
||||
std::vector<uint8_t> data;
|
||||
if (!fi->Read(&data)) {
|
||||
return false;
|
||||
}
|
||||
page->index.Resize(data.size());
|
||||
std::copy(data.cbegin(), data.cend(), page->index.begin());
|
||||
// bin type
|
||||
// Old gcc doesn't support reading from enum.
|
||||
std::underlying_type_t<common::BinTypeSize> uint_bin_type{0};
|
||||
if (!fi->Read(&uint_bin_type)) {
|
||||
return false;
|
||||
}
|
||||
common::BinTypeSize size_type =
|
||||
static_cast<common::BinTypeSize>(uint_bin_type);
|
||||
page->index.SetBinTypeSize(size_type);
|
||||
// hit count
|
||||
if (!fi->Read(&page->hit_count)) {
|
||||
return false;
|
||||
}
|
||||
if (!fi->Read(&page->max_num_bins)) {
|
||||
return false;
|
||||
}
|
||||
if (!fi->Read(&page->base_rowid)) {
|
||||
return false;
|
||||
}
|
||||
bool is_dense = false;
|
||||
if (!fi->Read(&is_dense)) {
|
||||
return false;
|
||||
}
|
||||
page->SetDense(is_dense);
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t Write(GHistIndexMatrix const &page, dmlc::Stream *fo) override {
|
||||
size_t bytes = 0;
|
||||
bytes += WriteHistogramCuts(page.cut, fo);
|
||||
// indptr
|
||||
fo->Write(page.row_ptr);
|
||||
bytes += page.row_ptr.size() * sizeof(decltype(page.row_ptr)::value_type) +
|
||||
sizeof(uint64_t);
|
||||
// offset
|
||||
using OffsetT = std::iterator_traits<decltype(page.index.Offset())>::value_type;
|
||||
std::vector<OffsetT> offset(page.index.OffsetSize());
|
||||
std::copy(page.index.Offset(),
|
||||
page.index.Offset() + page.index.OffsetSize(), offset.begin());
|
||||
fo->Write(offset);
|
||||
bytes += page.index.OffsetSize() * sizeof(OffsetT) + sizeof(uint64_t);
|
||||
// data
|
||||
std::vector<uint8_t> data(page.index.begin(), page.index.end());
|
||||
fo->Write(data);
|
||||
bytes += data.size() * sizeof(decltype(data)::value_type) + sizeof(uint64_t);
|
||||
// bin type
|
||||
std::underlying_type_t<common::BinTypeSize> uint_bin_type =
|
||||
page.index.GetBinTypeSize();
|
||||
fo->Write(uint_bin_type);
|
||||
bytes += sizeof(page.index.GetBinTypeSize());
|
||||
// hit count
|
||||
fo->Write(page.hit_count);
|
||||
bytes +=
|
||||
page.hit_count.size() * sizeof(decltype(page.hit_count)::value_type) +
|
||||
sizeof(uint64_t);
|
||||
// max_bins, base row, is_dense
|
||||
fo->Write(page.max_num_bins);
|
||||
bytes += sizeof(page.max_num_bins);
|
||||
fo->Write(page.base_rowid);
|
||||
bytes += sizeof(page.base_rowid);
|
||||
fo->Write(page.IsDense());
|
||||
bytes += sizeof(page.IsDense());
|
||||
return bytes;
|
||||
}
|
||||
};
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gradient_index_format);
|
||||
|
||||
XGBOOST_REGISTER_GHIST_INDEX_PAGE_FORMAT(raw)
|
||||
.describe("Raw GHistIndex binary data format.")
|
||||
.set_body([]() { return new GHistIndexRawFormat(); });
|
||||
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
18
src/data/gradient_index_page_source.cc
Normal file
18
src/data/gradient_index_page_source.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#include "gradient_index_page_source.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
void GradientIndexPageSource::Fetch() {
|
||||
if (!this->ReadCache()) {
|
||||
auto const& csr = source_->Page();
|
||||
this->page_.reset(new GHistIndexMatrix());
|
||||
CHECK_NE(cuts_.Values().size(), 0);
|
||||
this->page_->Init(*csr, cuts_, max_bin_per_feat_, is_dense_, nthreads_);
|
||||
this->WriteCache();
|
||||
}
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
37
src/data/gradient_index_page_source.h
Normal file
37
src/data/gradient_index_page_source.h
Normal file
@ -0,0 +1,37 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
|
||||
#define XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "sparse_page_source.h"
|
||||
#include "gradient_index.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
|
||||
common::HistogramCuts cuts_;
|
||||
bool is_dense_;
|
||||
int32_t max_bin_per_feat_;
|
||||
|
||||
public:
|
||||
GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features,
|
||||
size_t n_batches, std::shared_ptr<Cache> cache,
|
||||
BatchParam param, common::HistogramCuts cuts,
|
||||
bool is_dense, int32_t max_bin_per_feat,
|
||||
std::shared_ptr<SparsePageSource> source)
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
|
||||
cuts_{std::move(cuts)}, is_dense_{is_dense}, max_bin_per_feat_{
|
||||
max_bin_per_feat} {
|
||||
this->source_ = source;
|
||||
this->Fetch();
|
||||
}
|
||||
|
||||
void Fetch() final;
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
|
||||
36
src/data/histogram_cut_format.h
Normal file
36
src/data/histogram_cut_format.h
Normal file
@ -0,0 +1,36 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
|
||||
#define XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
|
||||
|
||||
#include "../common/hist_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
inline bool ReadHistogramCuts(common::HistogramCuts *cuts, dmlc::SeekStream *fi) {
|
||||
if (!fi->Read(&cuts->cut_values_.HostVector())) {
|
||||
return false;
|
||||
}
|
||||
if (!fi->Read(&cuts->cut_ptrs_.HostVector())) {
|
||||
return false;
|
||||
}
|
||||
if (!fi->Read(&cuts->min_vals_.HostVector())) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline size_t WriteHistogramCuts(common::HistogramCuts const &cuts, dmlc::Stream *fo) {
|
||||
size_t bytes = 0;
|
||||
fo->Write(cuts.cut_values_.ConstHostVector());
|
||||
bytes += cuts.cut_values_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
|
||||
fo->Write(cuts.cut_ptrs_.ConstHostVector());
|
||||
bytes += cuts.cut_ptrs_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
|
||||
fo->Write(cuts.min_vals_.ConstHostVector());
|
||||
bytes += cuts.min_vals_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
|
||||
return bytes;
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
|
||||
@ -94,10 +94,12 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
|
||||
if (!(batch_param_ != BatchParam{})) {
|
||||
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
|
||||
}
|
||||
if (!gradient_index_ || (batch_param_ != param && param != BatchParam{})) {
|
||||
if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin));
|
||||
CHECK_EQ(param.gpu_id, -1);
|
||||
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, param.hess));
|
||||
batch_param_ = param;
|
||||
CHECK_EQ(batch_param_.hess.data(), param.hess.data());
|
||||
}
|
||||
auto begin_iter = BatchIterator<GHistIndexMatrix>(
|
||||
new SimpleBatchIteratorImpl<GHistIndexMatrix>(gradient_index_));
|
||||
|
||||
@ -43,7 +43,8 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
||||
XGDMatrixCallbackNext *next, float missing,
|
||||
int32_t nthreads, std::string cache_prefix)
|
||||
: proxy_{proxy_handle}, iter_{iter_handle}, reset_{reset}, next_{next}, missing_{missing},
|
||||
nthreads_{nthreads}, cache_prefix_{std::move(cache_prefix)} {
|
||||
cache_prefix_{std::move(cache_prefix)} {
|
||||
ctx_.nthread = nthreads;
|
||||
cache_prefix_ = cache_prefix_.empty() ? "DMatrix" : cache_prefix_;
|
||||
if (rabit::IsDistributed()) {
|
||||
cache_prefix_ += ("-r" + std::to_string(rabit::GetRank()));
|
||||
@ -112,7 +113,7 @@ void SparsePageDMatrix::InitializeSparsePage() {
|
||||
DMatrixProxy *proxy = MakeProxy(proxy_);
|
||||
sparse_page_source_.reset(); // clear before creating new one to prevent conflicts.
|
||||
sparse_page_source_ = std::make_shared<SparsePageSource>(
|
||||
iter, proxy, this->missing_, this->nthreads_, this->info_.num_col_,
|
||||
iter, proxy, this->missing_, this->ctx_.Threads(), this->info_.num_col_,
|
||||
this->n_batches_, cache_info_.at(id));
|
||||
}
|
||||
|
||||
@ -132,7 +133,7 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches() {
|
||||
this->InitializeSparsePage();
|
||||
if (!column_source_) {
|
||||
column_source_ = std::make_shared<CSCPageSource>(
|
||||
this->missing_, this->nthreads_, this->Info().num_col_,
|
||||
this->missing_, this->ctx_.Threads(), this->Info().num_col_,
|
||||
this->n_batches_, cache_info_.at(id), sparse_page_source_);
|
||||
} else {
|
||||
column_source_->Reset();
|
||||
@ -147,7 +148,7 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
|
||||
this->InitializeSparsePage();
|
||||
if (!sorted_column_source_) {
|
||||
sorted_column_source_ = std::make_shared<SortedCSCPageSource>(
|
||||
this->missing_, this->nthreads_, this->Info().num_col_,
|
||||
this->missing_, this->ctx_.Threads(), this->Info().num_col_,
|
||||
this->n_batches_, cache_info_.at(id), sparse_page_source_);
|
||||
} else {
|
||||
sorted_column_source_->Reset();
|
||||
@ -158,16 +159,41 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
|
||||
|
||||
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam& param) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
// External memory is not support
|
||||
if (!ghist_index_source_ || (param != batch_param_ && param != BatchParam{})) {
|
||||
this->InitializeSparsePage();
|
||||
ghist_index_source_.reset(new GHistIndexMatrix{this, param.max_bin});
|
||||
batch_param_ = param;
|
||||
if (param.hess.empty()) {
|
||||
// hist method doesn't support full external memory implementation, so we concatenate
|
||||
// all index here.
|
||||
if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
|
||||
this->InitializeSparsePage();
|
||||
ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin});
|
||||
this->InitializeSparsePage();
|
||||
batch_param_ = param;
|
||||
}
|
||||
auto begin_iter = BatchIterator<GHistIndexMatrix>(
|
||||
new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_index_page_));
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
}
|
||||
|
||||
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
this->InitializeSparsePage();
|
||||
auto begin_iter = BatchIterator<GHistIndexMatrix>(
|
||||
new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_index_source_));
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{})) {
|
||||
cache_info_.erase(id);
|
||||
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
auto cuts = common::SketchOnDMatrix(this, param.max_bin, param.hess);
|
||||
this->InitializeSparsePage(); // reset after use.
|
||||
|
||||
batch_param_ = param;
|
||||
ghist_index_source_.reset();
|
||||
CHECK_NE(cuts.Values().size(), 0);
|
||||
ghist_index_source_.reset(new GradientIndexPageSource(
|
||||
this->missing_, this->ctx_.Threads(), this->Info().num_col_,
|
||||
this->n_batches_, cache_info_.at(id), param, std::move(cuts),
|
||||
this->IsDense(), param.max_bin, sparse_page_source_));
|
||||
} else {
|
||||
CHECK(ghist_index_source_);
|
||||
ghist_index_source_->Reset();
|
||||
}
|
||||
auto begin_iter = BatchIterator<GHistIndexMatrix>(ghist_index_source_);
|
||||
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(begin_iter));
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
@ -31,7 +31,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
|
||||
auto ft = this->info_.feature_types.ConstDeviceSpan();
|
||||
ellpack_page_source_.reset(); // release resources.
|
||||
ellpack_page_source_.reset(new EllpackPageSource(
|
||||
this->missing_, this->nthreads_, this->Info().num_col_,
|
||||
this->missing_, this->ctx_.Threads(), this->Info().num_col_,
|
||||
this->n_batches_, cache_info_.at(id), param, std::move(cuts),
|
||||
this->IsDense(), row_stride, ft, sparse_page_source_));
|
||||
} else {
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
||||
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@ -16,6 +17,7 @@
|
||||
#include <map>
|
||||
|
||||
#include "ellpack_page_source.h"
|
||||
#include "gradient_index_page_source.h"
|
||||
#include "sparse_page_source.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -67,7 +69,7 @@ class SparsePageDMatrix : public DMatrix {
|
||||
XGDMatrixCallbackNext *next_;
|
||||
|
||||
float missing_;
|
||||
int nthreads_;
|
||||
GenericParameter ctx_;
|
||||
std::string cache_prefix_;
|
||||
uint32_t n_batches_ {0};
|
||||
// sparse page is the source to other page types, we make a special member function.
|
||||
@ -118,7 +120,8 @@ class SparsePageDMatrix : public DMatrix {
|
||||
std::shared_ptr<EllpackPageSource> ellpack_page_source_;
|
||||
std::shared_ptr<CSCPageSource> column_source_;
|
||||
std::shared_ptr<SortedCSCPageSource> sorted_column_source_;
|
||||
std::shared_ptr<GHistIndexMatrix> ghist_index_source_;
|
||||
std::shared_ptr<GHistIndexMatrix> ghist_index_page_; // hist
|
||||
std::shared_ptr<GradientIndexPageSource> ghist_index_source_;
|
||||
|
||||
bool EllpackExists() const override {
|
||||
return static_cast<bool>(ellpack_page_source_);
|
||||
@ -143,6 +146,7 @@ MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix,
|
||||
auto it = cache_info.find(id);
|
||||
if (it == cache_info.cend()) {
|
||||
cache_info[id].reset(new Cache{false, name, format});
|
||||
LOG(INFO) << "Make cache:" << name << std::endl;
|
||||
}
|
||||
return id;
|
||||
}
|
||||
|
||||
@ -98,7 +98,12 @@ struct SparsePageFormatReg
|
||||
|
||||
#define EllpackPageFmt SparsePageFormat<EllpackPage>
|
||||
#define XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(Name) \
|
||||
DMLC_REGISTRY_REGISTER(SparsePageFormatReg<EllpackPage>, EllpackPageFm, Name)
|
||||
DMLC_REGISTRY_REGISTER(SparsePageFormatReg<EllpackPage>, EllpackPageFmt, Name)
|
||||
|
||||
#define GHistIndexPageFmt SparsePageFormat<GHistIndexMatrix>
|
||||
#define XGBOOST_REGISTER_GHIST_INDEX_PAGE_FORMAT(Name) \
|
||||
DMLC_REGISTRY_REGISTER(SparsePageFormatReg<GHistIndexMatrix>, \
|
||||
GHistIndexPageFmt, Name)
|
||||
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "../../common/compressed_iterator.h"
|
||||
#include "../../common/random.h"
|
||||
@ -185,10 +186,10 @@ GradientBasedSample UniformSampling::Sample(common::Span<GradientPair> gpair, DM
|
||||
|
||||
ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl const* page,
|
||||
size_t n_rows,
|
||||
const BatchParam& batch_param,
|
||||
BatchParam batch_param,
|
||||
float subsample)
|
||||
: original_page_(page),
|
||||
batch_param_(batch_param),
|
||||
batch_param_(std::move(batch_param)),
|
||||
subsample_(subsample),
|
||||
sample_row_index_(n_rows) {}
|
||||
|
||||
@ -259,10 +260,10 @@ GradientBasedSample GradientBasedSampling::Sample(common::Span<GradientPair> gpa
|
||||
ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
|
||||
EllpackPageImpl const* page,
|
||||
size_t n_rows,
|
||||
const BatchParam& batch_param,
|
||||
BatchParam batch_param,
|
||||
float subsample)
|
||||
: original_page_(page),
|
||||
batch_param_(batch_param),
|
||||
batch_param_(std::move(batch_param)),
|
||||
subsample_(subsample),
|
||||
threshold_(n_rows + 1, 0.0f),
|
||||
grad_sum_(n_rows, 0.0f),
|
||||
|
||||
@ -68,7 +68,7 @@ class ExternalMemoryUniformSampling : public SamplingStrategy {
|
||||
public:
|
||||
ExternalMemoryUniformSampling(EllpackPageImpl const* page,
|
||||
size_t n_rows,
|
||||
const BatchParam& batch_param,
|
||||
BatchParam batch_param,
|
||||
float subsample);
|
||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
||||
|
||||
@ -102,7 +102,7 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
|
||||
public:
|
||||
ExternalMemoryGradientBasedSampling(EllpackPageImpl const* page,
|
||||
size_t n_rows,
|
||||
const BatchParam& batch_param,
|
||||
BatchParam batch_param,
|
||||
float subsample);
|
||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
||||
|
||||
|
||||
@ -209,7 +209,7 @@ struct GPUHistMakerDevice {
|
||||
tree_evaluator(param, n_features, _device_id),
|
||||
column_sampler(column_sampler_seed),
|
||||
interaction_constraints(param, n_features),
|
||||
batch_param(_batch_param) {
|
||||
batch_param(std::move(_batch_param)) {
|
||||
sampler.reset(new GradientBasedSampler(
|
||||
page, _n_rows, batch_param, param.subsample, param.sampling_method));
|
||||
if (!param.monotone_constraints.empty()) {
|
||||
|
||||
@ -69,13 +69,13 @@ void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr<Builder<Gradient
|
||||
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *dmat,
|
||||
const std::vector<RegTree *> &trees) {
|
||||
auto const &gmat =
|
||||
*(dmat->GetBatches<GHistIndexMatrix>(
|
||||
BatchParam{GenericParameter::kCpuId, param_.max_bin})
|
||||
.begin());
|
||||
auto it = dmat->GetBatches<GHistIndexMatrix>(
|
||||
BatchParam{GenericParameter::kCpuId, param_.max_bin})
|
||||
.begin();
|
||||
auto p_gmat = it.Page();
|
||||
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
|
||||
updater_monitor_.Start("GmatInitialization");
|
||||
column_matrix_.Init(gmat, param_.sparse_threshold);
|
||||
column_matrix_.Init(*p_gmat, param_.sparse_threshold);
|
||||
updater_monitor_.Stop("GmatInitialization");
|
||||
// A proper solution is puting cut matrix in DMatrix, see:
|
||||
// https://github.com/dmlc/xgboost/issues/5143
|
||||
@ -91,12 +91,12 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
if (!float_builder_) {
|
||||
this->SetBuilder(n_trees, &float_builder_, dmat);
|
||||
}
|
||||
CallBuilderUpdate(float_builder_, gpair, dmat, gmat, trees);
|
||||
CallBuilderUpdate(float_builder_, gpair, dmat, *p_gmat, trees);
|
||||
} else {
|
||||
if (!double_builder_) {
|
||||
SetBuilder(n_trees, &double_builder_, dmat);
|
||||
}
|
||||
CallBuilderUpdate(double_builder_, gpair, dmat, gmat, trees);
|
||||
CallBuilderUpdate(double_builder_, gpair, dmat, *p_gmat, trees);
|
||||
}
|
||||
|
||||
param_.learning_rate = lr;
|
||||
|
||||
26
tests/cpp/data/test_gradient_index.cc
Normal file
26
tests/cpp/data/test_gradient_index.cc
Normal file
@ -0,0 +1,26 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/data/gradient_index.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
TEST(GradientIndex, ExternalMemory) {
|
||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
|
||||
std::vector<size_t> base_rowids;
|
||||
std::vector<float> hessian(dmat->Info().num_row_, 1);
|
||||
for (auto const& page : dmat->GetBatches<GHistIndexMatrix>({0, 64, hessian})) {
|
||||
base_rowids.push_back(page.base_rowid);
|
||||
}
|
||||
size_t i = 0;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
ASSERT_EQ(base_rowids[i], page.base_rowid);
|
||||
++i;
|
||||
}
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
48
tests/cpp/data/test_gradient_index_page_raw_format.cc
Normal file
48
tests/cpp/data/test_gradient_index_page_raw_format.cc
Normal file
@ -0,0 +1,48 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/data/gradient_index.h"
|
||||
#include "../../../src/data/sparse_page_source.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
TEST(GHistIndexPageRawFormat, IO) {
|
||||
std::unique_ptr<SparsePageFormat<GHistIndexMatrix>> format{
|
||||
CreatePageFormat<GHistIndexMatrix>("raw")};
|
||||
auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string path = tmpdir.path + "/ghistindex.page";
|
||||
|
||||
{
|
||||
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
|
||||
for (auto const &index :
|
||||
m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 256})) {
|
||||
format->Write(index, fo.get());
|
||||
}
|
||||
}
|
||||
|
||||
GHistIndexMatrix page;
|
||||
std::unique_ptr<dmlc::SeekStream> fi{
|
||||
dmlc::SeekStream::CreateForRead(path.c_str())};
|
||||
format->Read(&page, fi.get());
|
||||
|
||||
for (auto const &gidx :
|
||||
m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 256})) {
|
||||
auto const &loaded = gidx;
|
||||
ASSERT_EQ(loaded.cut.Ptrs(), page.cut.Ptrs());
|
||||
ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues());
|
||||
ASSERT_EQ(loaded.cut.Values(), page.cut.Values());
|
||||
ASSERT_EQ(loaded.base_rowid, page.base_rowid);
|
||||
ASSERT_EQ(loaded.IsDense(), page.IsDense());
|
||||
ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(),
|
||||
page.index.begin()));
|
||||
ASSERT_TRUE(std::equal(loaded.index.Offset(),
|
||||
loaded.index.Offset() + loaded.index.OffsetSize(),
|
||||
page.index.Offset()));
|
||||
}
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
@ -370,7 +370,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries,
|
||||
|
||||
std::unique_ptr<DMatrix> dmat{DMatrix::Create(
|
||||
static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), 1, prefix)};
|
||||
std::numeric_limits<float>::quiet_NaN(), omp_get_max_threads(), prefix)};
|
||||
auto row_page_path =
|
||||
data::MakeId(prefix,
|
||||
dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user