Prepare external memory support for hist. (#7638)

This PR prepares the GHistIndexMatrix to host the column matrix which is used by the hist tree method by accepting sparse_threshold parameter.

Some cleanups are made to ensure the correct batch param is being passed into DMatrix along with some additional tests for correctness of SimpleDMatrix.
This commit is contained in:
Jiaming Yuan
2022-02-10 16:58:02 +08:00
committed by GitHub
parent 87c01f49d8
commit 2775c2a1ab
24 changed files with 368 additions and 201 deletions

View File

@@ -18,6 +18,7 @@
#include <xgboost/string_view.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
@@ -217,24 +218,33 @@ struct BatchParam {
common::Span<float> hess;
/*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
bool regen {false};
/*! \brief Parameter used to generate column matrix for hist. */
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
BatchParam() = default;
// GPU Hist
BatchParam(int32_t device, int32_t max_bin)
: gpu_id{device}, max_bin{max_bin} {}
// Hist
BatchParam(int32_t max_bin, double sparse_thresh)
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
// Approx
/**
* \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
* the span is changed, so caller should keep the span for each iteration.
*/
BatchParam(int32_t device, int32_t max_bin, common::Span<float> hessian,
bool regenerate = false)
: gpu_id{device}, max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
BatchParam(int32_t max_bin, common::Span<float> hessian, bool regenerate)
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
bool operator!=(const BatchParam& other) const {
bool operator!=(BatchParam const& other) const {
if (hess.empty() && other.hess.empty()) {
return gpu_id != other.gpu_id || max_bin != other.max_bin;
}
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
}
bool operator==(BatchParam const& other) const {
return !(*this != other);
}
};
struct HostSparsePageView {
@@ -477,8 +487,10 @@ class DMatrix {
/**
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
*/
template<typename T>
BatchSet<T> GetBatches(const BatchParam& param = {});
template <typename T>
BatchSet<T> GetBatches();
template <typename T>
BatchSet<T> GetBatches(const BatchParam& param);
template <typename T>
bool PageExists() const;
@@ -592,7 +604,7 @@ class DMatrix {
};
template<>
inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
inline BatchSet<SparsePage> DMatrix::GetBatches() {
return GetRowBatches();
}
@@ -607,12 +619,12 @@ inline bool DMatrix::PageExists<SparsePage>() const {
}
template<>
inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
inline BatchSet<CSCPage> DMatrix::GetBatches() {
return GetColumnBatches();
}
template<>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(const BatchParam&) {
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
return GetSortedColumnBatches();
}