Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -50,7 +50,19 @@ struct Context : public XGBoostParameter<Context> {
|
||||
|
||||
bool IsCPU() const { return gpu_id == kCpuId; }
|
||||
bool IsCUDA() const { return !IsCPU(); }
|
||||
|
||||
CUDAContext const* CUDACtx() const;
|
||||
// Make a CUDA context based on the current context.
|
||||
Context MakeCUDA(std::int32_t device = 0) const {
|
||||
Context ctx = *this;
|
||||
ctx.gpu_id = device;
|
||||
return ctx;
|
||||
}
|
||||
Context MakeCPU() const {
|
||||
Context ctx = *this;
|
||||
ctx.gpu_id = kCpuId;
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(Context) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2015-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file data.h
|
||||
* \brief The input data structure of xgboost.
|
||||
* \author Tianqi Chen
|
||||
@@ -238,44 +238,72 @@ struct Entry {
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Parameters for constructing batches.
|
||||
/**
|
||||
* \brief Parameters for constructing histogram index batches.
|
||||
*/
|
||||
struct BatchParam {
|
||||
/*! \brief The GPU device to use. */
|
||||
int gpu_id {-1};
|
||||
/*! \brief Maximum number of bins per feature for histograms. */
|
||||
/**
|
||||
* \brief Maximum number of bins per feature for histograms.
|
||||
*/
|
||||
bst_bin_t max_bin{0};
|
||||
/*! \brief Hessian, used for sketching with future approx implementation. */
|
||||
/**
|
||||
* \brief Hessian, used for sketching with future approx implementation.
|
||||
*/
|
||||
common::Span<float> hess;
|
||||
/*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
|
||||
bool regen {false};
|
||||
/*! \brief Parameter used to generate column matrix for hist. */
|
||||
/**
|
||||
* \brief Whether should we force DMatrix to regenerate the batch. Only used for
|
||||
* GHistIndex.
|
||||
*/
|
||||
bool regen{false};
|
||||
/**
|
||||
* \brief Forbid regenerating the gradient index. Used for internal validation.
|
||||
*/
|
||||
bool forbid_regen{false};
|
||||
/**
|
||||
* \brief Parameter used to generate column matrix for hist.
|
||||
*/
|
||||
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
|
||||
|
||||
/**
|
||||
* \brief Exact or others that don't need histogram.
|
||||
*/
|
||||
BatchParam() = default;
|
||||
// GPU Hist
|
||||
BatchParam(int32_t device, bst_bin_t max_bin)
|
||||
: gpu_id{device}, max_bin{max_bin} {}
|
||||
// Hist
|
||||
/**
|
||||
* \brief Used by the hist tree method.
|
||||
*/
|
||||
BatchParam(bst_bin_t max_bin, double sparse_thresh)
|
||||
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
|
||||
// Approx
|
||||
/**
|
||||
* \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
|
||||
* the span is changed, so caller should keep the span for each iteration.
|
||||
* \brief Used by the approx tree method.
|
||||
*
|
||||
* Get batch with sketch weighted by hessian. The batch will be regenerated if the
|
||||
* span is changed, so caller should keep the span for each iteration.
|
||||
*/
|
||||
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
|
||||
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
|
||||
|
||||
bool operator!=(BatchParam const& other) const {
|
||||
if (hess.empty() && other.hess.empty()) {
|
||||
return gpu_id != other.gpu_id || max_bin != other.max_bin;
|
||||
}
|
||||
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
|
||||
bool ParamNotEqual(BatchParam const& other) const {
|
||||
// Check non-floating parameters.
|
||||
bool cond = max_bin != other.max_bin;
|
||||
// Check sparse thresh.
|
||||
bool l_nan = std::isnan(sparse_thresh);
|
||||
bool r_nan = std::isnan(other.sparse_thresh);
|
||||
bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (sparse_thresh != other.sparse_thresh));
|
||||
cond |= st_chg;
|
||||
|
||||
return cond;
|
||||
}
|
||||
bool operator==(BatchParam const& other) const {
|
||||
return !(*this != other);
|
||||
bool Initialized() const { return max_bin != 0; }
|
||||
/**
|
||||
* \brief Make a copy of self for DMatrix to describe how its existing index was generated.
|
||||
*/
|
||||
BatchParam MakeCache() const {
|
||||
auto p = *this;
|
||||
// These parameters have nothing to do with how the gradient index was generated in the
|
||||
// first place.
|
||||
p.regen = false;
|
||||
p.forbid_regen = false;
|
||||
return p;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -435,7 +463,7 @@ class EllpackPage {
|
||||
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
|
||||
* in CSR format.
|
||||
*/
|
||||
explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
|
||||
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
|
||||
|
||||
/*! \brief Destructor. */
|
||||
~EllpackPage();
|
||||
@@ -551,7 +579,9 @@ class DMatrix {
|
||||
template <typename T>
|
||||
BatchSet<T> GetBatches();
|
||||
template <typename T>
|
||||
BatchSet<T> GetBatches(const BatchParam& param);
|
||||
BatchSet<T> GetBatches(Context const* ctx);
|
||||
template <typename T>
|
||||
BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
|
||||
template <typename T>
|
||||
bool PageExists() const;
|
||||
|
||||
@@ -658,18 +688,19 @@ class DMatrix {
|
||||
|
||||
protected:
|
||||
virtual BatchSet<SparsePage> GetRowBatches() = 0;
|
||||
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
|
||||
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
|
||||
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
|
||||
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
|
||||
virtual BatchSet<CSCPage> GetColumnBatches(Context const* ctx) = 0;
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const* ctx) = 0;
|
||||
virtual BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, BatchParam const& param) = 0;
|
||||
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
|
||||
BatchParam const& param) = 0;
|
||||
virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
|
||||
|
||||
virtual bool EllpackExists() const = 0;
|
||||
virtual bool GHistIndexExists() const = 0;
|
||||
virtual bool SparsePageExists() const = 0;
|
||||
};
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline BatchSet<SparsePage> DMatrix::GetBatches() {
|
||||
return GetRowBatches();
|
||||
}
|
||||
@@ -684,34 +715,39 @@ inline bool DMatrix::PageExists<GHistIndexMatrix>() const {
|
||||
return this->GHistIndexExists();
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline bool DMatrix::PageExists<SparsePage>() const {
|
||||
return this->SparsePageExists();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline BatchSet<CSCPage> DMatrix::GetBatches() {
|
||||
return GetColumnBatches();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
|
||||
return GetSortedColumnBatches();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
|
||||
return GetEllpackBatches(param);
|
||||
template <>
|
||||
inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) {
|
||||
return GetRowBatches();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
|
||||
return GetGradientIndex(param);
|
||||
inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
|
||||
return GetColumnBatches(ctx);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
|
||||
return GetExtBatches(BatchParam{});
|
||||
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
|
||||
return GetSortedColumnBatches(ctx);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
|
||||
return GetEllpackBatches(ctx, param);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
|
||||
return GetGradientIndex(ctx, param);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
|
||||
return GetExtBatches(ctx, param);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
Reference in New Issue
Block a user