Use Booster context in DMatrix. (#8896)

- Pass context from booster to DMatrix.
- Use context instead of integer for `n_threads`.
- Check that the `max_bin` configuration is consistent.
- Test for all combinations of initialization options.
This commit is contained in:
Jiaming Yuan
2023-04-28 21:47:14 +08:00
committed by GitHub
parent 1f9a57d17b
commit 08ce495b5d
67 changed files with 1283 additions and 935 deletions

View File

@@ -50,7 +50,19 @@ struct Context : public XGBoostParameter<Context> {
/** \brief True when `gpu_id` equals the CPU sentinel `kCpuId`. */
bool IsCPU() const {
  return gpu_id == kCpuId;
}
/** \brief True whenever the context is not a CPU context. */
bool IsCUDA() const {
  bool const on_cpu = IsCPU();
  return !on_cpu;
}
CUDAContext const* CUDACtx() const;
/**
 * \brief Make a CUDA context based on the current context.
 *
 * \param device Value stored in `gpu_id` of the returned copy, defaults to 0.
 *
 * \return A copy of this context with `gpu_id` replaced by `device`; the current context
 *         is left untouched.
 */
Context MakeCUDA(std::int32_t device = 0) const {
  auto cuda_ctx = *this;
  cuda_ctx.gpu_id = device;
  return cuda_ctx;
}
/**
 * \brief Make a CPU context based on the current context.
 *
 * \return A copy of this context with `gpu_id` set to `kCpuId`; the current context is
 *         left untouched.
 */
Context MakeCPU() const {
  auto cpu_ctx = *this;
  cpu_ctx.gpu_id = kCpuId;
  return cpu_ctx;
}
// declare parameters
DMLC_DECLARE_PARAMETER(Context) {

View File

@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2015-2022 by XGBoost Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file data.h
* \brief The input data structure of xgboost.
* \author Tianqi Chen
@@ -238,44 +238,72 @@ struct Entry {
}
};
/*!
* \brief Parameters for constructing batches.
/**
* \brief Parameters for constructing histogram index batches.
*/
struct BatchParam {
/*! \brief The GPU device to use. */
int gpu_id {-1};
/*! \brief Maximum number of bins per feature for histograms. */
/**
* \brief Maximum number of bins per feature for histograms.
*/
bst_bin_t max_bin{0};
/*! \brief Hessian, used for sketching with future approx implementation. */
/**
* \brief Hessian, used for sketching with future approx implementation.
*/
common::Span<float> hess;
/*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
bool regen {false};
/*! \brief Parameter used to generate column matrix for hist. */
/**
* \brief Whether should we force DMatrix to regenerate the batch. Only used for
* GHistIndex.
*/
bool regen{false};
/**
* \brief Forbid regenerating the gradient index. Used for internal validation.
*/
bool forbid_regen{false};
/**
* \brief Parameter used to generate column matrix for hist.
*/
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
/**
 * \brief Default parameter, for exact or other tree methods that don't need histogram.
 *
 * Every member keeps its in-class initializer, so `max_bin` stays 0 and `Initialized()`
 * reports false for a default-constructed parameter.
 */
BatchParam() = default;
// GPU Hist
BatchParam(int32_t device, bst_bin_t max_bin)
: gpu_id{device}, max_bin{max_bin} {}
/**
 * \brief Used by the hist tree method.
 *
 * \param max_bin       Maximum number of bins per feature for histograms.
 * \param sparse_thresh Parameter used to generate column matrix for hist.
 */
BatchParam(bst_bin_t max_bin, double sparse_thresh)
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
/**
 * \brief Used by the approx tree method.
 *
 * Get batch with sketch weighted by hessian. The batch will be regenerated if the
 * span is changed, so caller should keep the span for each iteration.
 *
 * \param max_bin    Maximum number of bins per feature for histograms.
 * \param hessian    Hessian used to weight the sketch; stored in `hess`.
 * \param regenerate Whether to force DMatrix to regenerate the batch; stored in `regen`.
 */
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
bool operator!=(BatchParam const& other) const {
if (hess.empty() && other.hess.empty()) {
return gpu_id != other.gpu_id || max_bin != other.max_bin;
}
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
/**
 * \brief Whether the parameters differ between this and `other`.
 *
 * Only `max_bin` and `sparse_thresh` take part in the comparison. Two NaN thresholds are
 * treated as equal, a NaN against a number is treated as a change.
 *
 * \param other The parameter to compare against.
 *
 * \return true if either `max_bin` or `sparse_thresh` differs.
 */
bool ParamNotEqual(BatchParam const& other) const {
  // Non-floating parameters can be compared directly.
  if (max_bin != other.max_bin) {
    return true;
  }
  // NaN never compares equal to itself, so the unset case is handled before the plain
  // floating point comparison.
  bool const lhs_unset = std::isnan(sparse_thresh);
  bool const rhs_unset = std::isnan(other.sparse_thresh);
  if (lhs_unset || rhs_unset) {
    return lhs_unset != rhs_unset;
  }
  return sparse_thresh != other.sparse_thresh;
}
bool operator==(BatchParam const& other) const {
return !(*this != other);
/**
 * \brief Whether `max_bin` has been set to something other than its default of 0.
 */
bool Initialized() const {
  return max_bin != 0;
}
/**
 * \brief Copy this parameter so that DMatrix can record how its existing index was
 *        generated.
 *
 * \return A copy of self with `regen` and `forbid_regen` cleared.
 */
BatchParam MakeCache() const {
  BatchParam cached = *this;
  // The regeneration flags have nothing to do with how the gradient index was generated
  // in the first place, hence they are dropped from the cached copy.
  cached.forbid_regen = false;
  cached.regen = false;
  return cached;
}
};
@@ -435,7 +463,7 @@ class EllpackPage {
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
* in CSR format.
*/
explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
/*! \brief Destructor. */
~EllpackPage();
@@ -551,7 +579,9 @@ class DMatrix {
template <typename T>
BatchSet<T> GetBatches();
template <typename T>
BatchSet<T> GetBatches(const BatchParam& param);
BatchSet<T> GetBatches(Context const* ctx);
template <typename T>
BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
template <typename T>
bool PageExists() const;
@@ -658,18 +688,19 @@ class DMatrix {
protected:
virtual BatchSet<SparsePage> GetRowBatches() = 0;
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
virtual BatchSet<CSCPage> GetColumnBatches(Context const* ctx) = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const* ctx) = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, BatchParam const& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
BatchParam const& param) = 0;
virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
virtual bool EllpackExists() const = 0;
virtual bool GHistIndexExists() const = 0;
virtual bool SparsePageExists() const = 0;
};
template<>
/** \brief Specialization for SparsePage, forwards to GetRowBatches(). */
template <>
inline BatchSet<SparsePage> DMatrix::GetBatches() {
  return this->GetRowBatches();
}
@@ -684,34 +715,39 @@ inline bool DMatrix::PageExists<GHistIndexMatrix>() const {
return this->GHistIndexExists();
}
template<>
/** \brief Specialization for SparsePage, forwards to SparsePageExists(). */
template <>
inline bool DMatrix::PageExists<SparsePage>() const {
  return SparsePageExists();
}
template<>
inline BatchSet<CSCPage> DMatrix::GetBatches() {
return GetColumnBatches();
}
template<>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
return GetSortedColumnBatches();
}
template<>
inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
return GetEllpackBatches(param);
/**
 * \brief Specialization for SparsePage taking a context. The context argument is unused,
 *        the call forwards to GetRowBatches().
 */
template <>
inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) {
  return this->GetRowBatches();
}
template <>
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
return GetGradientIndex(param);
inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
return GetColumnBatches(ctx);
}
template <>
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
return GetExtBatches(BatchParam{});
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
return GetSortedColumnBatches(ctx);
}
/** \brief Specialization for EllpackPage, forwards both arguments to GetEllpackBatches(). */
template <>
inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
  return this->GetEllpackBatches(ctx, param);
}
/** \brief Specialization for GHistIndexMatrix, forwards both arguments to GetGradientIndex(). */
template <>
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
  return this->GetGradientIndex(ctx, param);
}
/** \brief Specialization for ExtSparsePage, forwards both arguments to GetExtBatches(). */
template <>
inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
  return this->GetExtBatches(ctx, param);
}
} // namespace xgboost