Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix.
- Use context instead of integer for `n_threads`.
- Check the consistency of the `max_bin` configuration.
- Test for all combinations of initialization options.
This commit is contained in:
parent
1f9a57d17b
commit
08ce495b5d
@ -50,7 +50,19 @@ struct Context : public XGBoostParameter<Context> {
|
|||||||
|
|
||||||
bool IsCPU() const { return gpu_id == kCpuId; }
|
bool IsCPU() const { return gpu_id == kCpuId; }
|
||||||
bool IsCUDA() const { return !IsCPU(); }
|
bool IsCUDA() const { return !IsCPU(); }
|
||||||
|
|
||||||
CUDAContext const* CUDACtx() const;
|
CUDAContext const* CUDACtx() const;
|
||||||
|
// Make a CUDA context based on the current context.
|
||||||
|
Context MakeCUDA(std::int32_t device = 0) const {
|
||||||
|
Context ctx = *this;
|
||||||
|
ctx.gpu_id = device;
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
Context MakeCPU() const {
|
||||||
|
Context ctx = *this;
|
||||||
|
ctx.gpu_id = kCpuId;
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
|
||||||
// declare parameters
|
// declare parameters
|
||||||
DMLC_DECLARE_PARAMETER(Context) {
|
DMLC_DECLARE_PARAMETER(Context) {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright (c) 2015-2022 by XGBoost Contributors
|
* Copyright 2015-2023 by XGBoost Contributors
|
||||||
* \file data.h
|
* \file data.h
|
||||||
* \brief The input data structure of xgboost.
|
* \brief The input data structure of xgboost.
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@ -238,44 +238,72 @@ struct Entry {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/**
|
||||||
* \brief Parameters for constructing batches.
|
* \brief Parameters for constructing histogram index batches.
|
||||||
*/
|
*/
|
||||||
struct BatchParam {
|
struct BatchParam {
|
||||||
/*! \brief The GPU device to use. */
|
/**
|
||||||
int gpu_id {-1};
|
* \brief Maximum number of bins per feature for histograms.
|
||||||
/*! \brief Maximum number of bins per feature for histograms. */
|
*/
|
||||||
bst_bin_t max_bin{0};
|
bst_bin_t max_bin{0};
|
||||||
/*! \brief Hessian, used for sketching with future approx implementation. */
|
/**
|
||||||
|
* \brief Hessian, used for sketching with future approx implementation.
|
||||||
|
*/
|
||||||
common::Span<float> hess;
|
common::Span<float> hess;
|
||||||
/*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
|
/**
|
||||||
bool regen {false};
|
* \brief Whether we should force DMatrix to regenerate the batch. Only used for
|
||||||
/*! \brief Parameter used to generate column matrix for hist. */
|
* GHistIndex.
|
||||||
|
*/
|
||||||
|
bool regen{false};
|
||||||
|
/**
|
||||||
|
* \brief Forbid regenerating the gradient index. Used for internal validation.
|
||||||
|
*/
|
||||||
|
bool forbid_regen{false};
|
||||||
|
/**
|
||||||
|
* \brief Parameter used to generate column matrix for hist.
|
||||||
|
*/
|
||||||
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
|
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Exact or others that don't need histogram.
|
||||||
|
*/
|
||||||
BatchParam() = default;
|
BatchParam() = default;
|
||||||
// GPU Hist
|
/**
|
||||||
BatchParam(int32_t device, bst_bin_t max_bin)
|
* \brief Used by the hist tree method.
|
||||||
: gpu_id{device}, max_bin{max_bin} {}
|
*/
|
||||||
// Hist
|
|
||||||
BatchParam(bst_bin_t max_bin, double sparse_thresh)
|
BatchParam(bst_bin_t max_bin, double sparse_thresh)
|
||||||
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
|
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
|
||||||
// Approx
|
|
||||||
/**
|
/**
|
||||||
* \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
|
* \brief Used by the approx tree method.
|
||||||
* the span is changed, so caller should keep the span for each iteration.
|
*
|
||||||
|
* Get batch with sketch weighted by hessian. The batch will be regenerated if the
|
||||||
|
* span is changed, so caller should keep the span for each iteration.
|
||||||
*/
|
*/
|
||||||
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
|
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
|
||||||
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
|
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
|
||||||
|
|
||||||
bool operator!=(BatchParam const& other) const {
|
bool ParamNotEqual(BatchParam const& other) const {
|
||||||
if (hess.empty() && other.hess.empty()) {
|
// Check non-floating parameters.
|
||||||
return gpu_id != other.gpu_id || max_bin != other.max_bin;
|
bool cond = max_bin != other.max_bin;
|
||||||
}
|
// Check sparse thresh.
|
||||||
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
|
bool l_nan = std::isnan(sparse_thresh);
|
||||||
|
bool r_nan = std::isnan(other.sparse_thresh);
|
||||||
|
bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (sparse_thresh != other.sparse_thresh));
|
||||||
|
cond |= st_chg;
|
||||||
|
|
||||||
|
return cond;
|
||||||
}
|
}
|
||||||
bool operator==(BatchParam const& other) const {
|
bool Initialized() const { return max_bin != 0; }
|
||||||
return !(*this != other);
|
/**
|
||||||
|
* \brief Make a copy of self for DMatrix to describe how its existing index was generated.
|
||||||
|
*/
|
||||||
|
BatchParam MakeCache() const {
|
||||||
|
auto p = *this;
|
||||||
|
// These parameters have nothing to do with how the gradient index was generated in the
|
||||||
|
// first place.
|
||||||
|
p.regen = false;
|
||||||
|
p.forbid_regen = false;
|
||||||
|
return p;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -435,7 +463,7 @@ class EllpackPage {
|
|||||||
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
|
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
|
||||||
* in CSR format.
|
* in CSR format.
|
||||||
*/
|
*/
|
||||||
explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
|
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
|
||||||
|
|
||||||
/*! \brief Destructor. */
|
/*! \brief Destructor. */
|
||||||
~EllpackPage();
|
~EllpackPage();
|
||||||
@ -551,7 +579,9 @@ class DMatrix {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
BatchSet<T> GetBatches();
|
BatchSet<T> GetBatches();
|
||||||
template <typename T>
|
template <typename T>
|
||||||
BatchSet<T> GetBatches(const BatchParam& param);
|
BatchSet<T> GetBatches(Context const* ctx);
|
||||||
|
template <typename T>
|
||||||
|
BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool PageExists() const;
|
bool PageExists() const;
|
||||||
|
|
||||||
@ -658,18 +688,19 @@ class DMatrix {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual BatchSet<SparsePage> GetRowBatches() = 0;
|
virtual BatchSet<SparsePage> GetRowBatches() = 0;
|
||||||
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
|
virtual BatchSet<CSCPage> GetColumnBatches(Context const* ctx) = 0;
|
||||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
|
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const* ctx) = 0;
|
||||||
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
|
virtual BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, BatchParam const& param) = 0;
|
||||||
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
|
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
|
||||||
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
|
BatchParam const& param) = 0;
|
||||||
|
virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
|
||||||
|
|
||||||
virtual bool EllpackExists() const = 0;
|
virtual bool EllpackExists() const = 0;
|
||||||
virtual bool GHistIndexExists() const = 0;
|
virtual bool GHistIndexExists() const = 0;
|
||||||
virtual bool SparsePageExists() const = 0;
|
virtual bool SparsePageExists() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<>
|
template <>
|
||||||
inline BatchSet<SparsePage> DMatrix::GetBatches() {
|
inline BatchSet<SparsePage> DMatrix::GetBatches() {
|
||||||
return GetRowBatches();
|
return GetRowBatches();
|
||||||
}
|
}
|
||||||
@ -684,34 +715,39 @@ inline bool DMatrix::PageExists<GHistIndexMatrix>() const {
|
|||||||
return this->GHistIndexExists();
|
return this->GHistIndexExists();
|
||||||
}
|
}
|
||||||
|
|
||||||
template<>
|
template <>
|
||||||
inline bool DMatrix::PageExists<SparsePage>() const {
|
inline bool DMatrix::PageExists<SparsePage>() const {
|
||||||
return this->SparsePageExists();
|
return this->SparsePageExists();
|
||||||
}
|
}
|
||||||
|
|
||||||
template<>
|
template <>
|
||||||
inline BatchSet<CSCPage> DMatrix::GetBatches() {
|
inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) {
|
||||||
return GetColumnBatches();
|
return GetRowBatches();
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
|
|
||||||
return GetSortedColumnBatches();
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
|
|
||||||
return GetEllpackBatches(param);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
|
inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
|
||||||
return GetGradientIndex(param);
|
return GetColumnBatches(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
|
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
|
||||||
return GetExtBatches(BatchParam{});
|
return GetSortedColumnBatches(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
|
||||||
|
return GetEllpackBatches(ctx, param);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
|
||||||
|
return GetGradientIndex(ctx, param);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
|
||||||
|
return GetExtBatches(ctx, param);
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
@ -317,13 +317,15 @@ class TestDataset:
|
|||||||
enable_categorical=True,
|
enable_categorical=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_device_dmat(self) -> xgb.QuantileDMatrix:
|
def get_device_dmat(self, max_bin: Optional[int]) -> xgb.QuantileDMatrix:
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
|
|
||||||
w = None if self.w is None else cp.array(self.w)
|
w = None if self.w is None else cp.array(self.w)
|
||||||
X = cp.array(self.X, dtype=np.float32)
|
X = cp.array(self.X, dtype=np.float32)
|
||||||
y = cp.array(self.y, dtype=np.float32)
|
y = cp.array(self.y, dtype=np.float32)
|
||||||
return xgb.QuantileDMatrix(X, y, weight=w, base_margin=self.margin)
|
return xgb.QuantileDMatrix(
|
||||||
|
X, y, weight=w, base_margin=self.margin, max_bin=max_bin
|
||||||
|
)
|
||||||
|
|
||||||
def get_external_dmat(self) -> xgb.DMatrix:
|
def get_external_dmat(self) -> xgb.DMatrix:
|
||||||
n_samples = self.X.shape[0]
|
n_samples = self.X.shape[0]
|
||||||
|
|||||||
@ -3,30 +3,50 @@
|
|||||||
*/
|
*/
|
||||||
#include "xgboost/c_api.h"
|
#include "xgboost/c_api.h"
|
||||||
|
|
||||||
#include <rabit/c_api.h>
|
#include <algorithm> // for copy
|
||||||
|
#include <cinttypes> // for strtoimax
|
||||||
|
#include <cmath> // for nan
|
||||||
|
#include <cstring> // for strcmp
|
||||||
|
#include <fstream> // for operator<<, basic_ostream, ios, stringstream
|
||||||
|
#include <functional> // for less
|
||||||
|
#include <limits> // for numeric_limits
|
||||||
|
#include <map> // for operator!=, _Rb_tree_const_iterator, _Rb_tre...
|
||||||
|
#include <memory> // for shared_ptr, allocator, __shared_ptr_access
|
||||||
|
#include <string> // for char_traits, basic_string, operator==, string
|
||||||
|
#include <system_error> // for errc
|
||||||
|
#include <utility> // for pair
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include <cstring>
|
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
|
||||||
#include <fstream>
|
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||||
#include <memory>
|
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
|
||||||
#include <string>
|
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
|
||||||
#include <vector>
|
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
|
||||||
|
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
|
||||||
#include "../collective/communicator-inl.h"
|
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
|
||||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
|
||||||
#include "../common/charconv.h"
|
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
|
||||||
#include "../common/io.h"
|
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
|
||||||
#include "../data/adapter.h"
|
#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED
|
||||||
#include "../data/simple_dmatrix.h"
|
#include "dmlc/io.h" // for Stream
|
||||||
#include "c_api_utils.h"
|
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
|
||||||
#include "xgboost/base.h"
|
#include "dmlc/thread_local.h" // for ThreadLocalStore
|
||||||
#include "xgboost/data.h"
|
#include "rabit/c_api.h" // for RabitLinkTag
|
||||||
#include "xgboost/global_config.h"
|
#include "rabit/rabit.h" // for CheckPoint, LoadCheckPoint
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/context.h" // for Context
|
||||||
#include "xgboost/learner.h"
|
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/feature_map.h" // for FeatureMap
|
||||||
#include "xgboost/string_view.h" // StringView
|
#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal...
|
||||||
#include "xgboost/version_config.h"
|
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||||
|
#include "xgboost/intrusive_ptr.h" // for xgboost
|
||||||
|
#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String
|
||||||
|
#include "xgboost/learner.h" // for Learner, PredictionType
|
||||||
|
#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ
|
||||||
|
#include "xgboost/predictor.h" // for PredictionCacheEntry
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
#include "xgboost/string_view.h" // for StringView, operator<<
|
||||||
|
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_FEDERATED)
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
#include "../../plugin/federated/federated_server.h"
|
#include "../../plugin/federated/federated_server.h"
|
||||||
@ -341,10 +361,10 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) {
|
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
xgboost_CHECK_C_ARG_PTR(out);
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);;
|
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -746,7 +766,7 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config
|
|||||||
|
|
||||||
CHECK_LE(p_m->Info().num_col_, std::numeric_limits<unsigned>::max());
|
CHECK_LE(p_m->Info().num_col_, std::numeric_limits<unsigned>::max());
|
||||||
|
|
||||||
for (auto const &page : p_m->GetBatches<ExtSparsePage>()) {
|
for (auto const &page : p_m->GetBatches<ExtSparsePage>(p_m->Ctx(), BatchParam{})) {
|
||||||
CHECK(page.page);
|
CHECK(page.page);
|
||||||
auto const &h_offset = page.page->offset.ConstHostVector();
|
auto const &h_offset = page.page->offset.ConstHostVector();
|
||||||
std::copy(h_offset.cbegin(), h_offset.cend(), out_indptr);
|
std::copy(h_offset.cbegin(), h_offset.cend(), out_indptr);
|
||||||
|
|||||||
@ -28,5 +28,10 @@ constexpr StringView InfInData() {
|
|||||||
constexpr StringView NoF128() {
|
constexpr StringView NoF128() {
|
||||||
return "128-bit floating point is not supported on current platform.";
|
return "128-bit floating point is not supported on current platform.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constexpr StringView InconsistentMaxBin() {
|
||||||
|
return "Inconsistent `max_bin`. `max_bin` should be the same across different QuantileDMatrix, "
|
||||||
|
"and consistent with the Booster being trained.";
|
||||||
|
}
|
||||||
} // namespace xgboost::error
|
} // namespace xgboost::error
|
||||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||||
|
|||||||
@ -2,15 +2,18 @@
|
|||||||
* Copyright 2017-2023 by XGBoost Contributors
|
* Copyright 2017-2023 by XGBoost Contributors
|
||||||
* \file hist_util.cc
|
* \file hist_util.cc
|
||||||
*/
|
*/
|
||||||
|
#include "hist_util.h"
|
||||||
|
|
||||||
#include <dmlc/timer.h>
|
#include <dmlc/timer.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "xgboost/base.h"
|
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
#include "hist_util.h"
|
|
||||||
#include "column_matrix.h"
|
#include "column_matrix.h"
|
||||||
#include "quantile.h"
|
#include "quantile.h"
|
||||||
|
#include "xgboost/base.h"
|
||||||
|
#include "xgboost/context.h" // Context
|
||||||
|
#include "xgboost/data.h" // SparsePage, SortedCSCPage
|
||||||
|
|
||||||
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
|
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
|
||||||
#include <xmmintrin.h>
|
#include <xmmintrin.h>
|
||||||
@ -28,10 +31,11 @@ HistogramCuts::HistogramCuts() {
|
|||||||
cut_ptrs_.HostVector().emplace_back(0);
|
cut_ptrs_.HostVector().emplace_back(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, bool use_sorted,
|
HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins, bool use_sorted,
|
||||||
Span<float> const hessian) {
|
Span<float> const hessian) {
|
||||||
HistogramCuts out;
|
HistogramCuts out;
|
||||||
auto const& info = m->Info();
|
auto const &info = m->Info();
|
||||||
|
auto n_threads = ctx->Threads();
|
||||||
std::vector<bst_row_t> reduced(info.num_col_, 0);
|
std::vector<bst_row_t> reduced(info.num_col_, 0);
|
||||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||||
auto const &entries_per_column =
|
auto const &entries_per_column =
|
||||||
@ -44,16 +48,19 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!use_sorted) {
|
if (!use_sorted) {
|
||||||
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
HostSketchContainer container(ctx, max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
||||||
HostSketchContainer::UseGroup(info), n_threads);
|
HostSketchContainer::UseGroup(info));
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||||
container.PushRowPage(page, info, hessian);
|
container.PushRowPage(page, info, hessian);
|
||||||
}
|
}
|
||||||
container.MakeCuts(m->Info(), &out);
|
container.MakeCuts(m->Info(), &out);
|
||||||
} else {
|
} else {
|
||||||
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
SortedSketchContainer container{ctx,
|
||||||
HostSketchContainer::UseGroup(info), n_threads};
|
max_bins,
|
||||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
m->Info().feature_types.ConstHostSpan(),
|
||||||
|
reduced,
|
||||||
|
HostSketchContainer::UseGroup(info)};
|
||||||
|
for (auto const &page : m->GetBatches<SortedCSCPage>(ctx)) {
|
||||||
container.PushColPage(page, info, hessian);
|
container.PushColPage(page, info, hessian);
|
||||||
}
|
}
|
||||||
container.MakeCuts(m->Info(), &out);
|
container.MakeCuts(m->Info(), &out);
|
||||||
|
|||||||
@ -170,7 +170,7 @@ class HistogramCuts {
|
|||||||
* \param use_sorted Whether we should use SortedCSC for sketching; it's more efficient
|
* \param use_sorted Whether we should use SortedCSC for sketching; it's more efficient
|
||||||
* but consumes more memory.
|
* but consumes more memory.
|
||||||
*/
|
*/
|
||||||
HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_threads,
|
HistogramCuts SketchOnDMatrix(Context const* ctx, DMatrix* m, bst_bin_t max_bins,
|
||||||
bool use_sorted = false, Span<float> const hessian = {});
|
bool use_sorted = false, Span<float> const hessian = {});
|
||||||
|
|
||||||
enum BinTypeSize : uint8_t {
|
enum BinTypeSize : uint8_t {
|
||||||
|
|||||||
@ -16,16 +16,16 @@ namespace xgboost {
|
|||||||
namespace common {
|
namespace common {
|
||||||
|
|
||||||
template <typename WQSketch>
|
template <typename WQSketch>
|
||||||
SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> columns_size,
|
SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
|
||||||
|
std::vector<bst_row_t> columns_size,
|
||||||
int32_t max_bins,
|
int32_t max_bins,
|
||||||
Span<FeatureType const> feature_types,
|
Span<FeatureType const> feature_types,
|
||||||
bool use_group,
|
bool use_group)
|
||||||
int32_t n_threads)
|
|
||||||
: feature_types_(feature_types.cbegin(), feature_types.cend()),
|
: feature_types_(feature_types.cbegin(), feature_types.cend()),
|
||||||
columns_size_{std::move(columns_size)},
|
columns_size_{std::move(columns_size)},
|
||||||
max_bins_{max_bins},
|
max_bins_{max_bins},
|
||||||
use_group_ind_{use_group},
|
use_group_ind_{use_group},
|
||||||
n_threads_{n_threads} {
|
n_threads_{ctx->Threads()} {
|
||||||
monitor_.Init(__func__);
|
monitor_.Init(__func__);
|
||||||
CHECK_NE(columns_size_.size(), 0);
|
CHECK_NE(columns_size_.size(), 0);
|
||||||
sketches_.resize(columns_size_.size());
|
sketches_.resize(columns_size_.size());
|
||||||
@ -380,13 +380,13 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename WQSketch>
|
template <typename WQSketch>
|
||||||
void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const& info, HistogramCuts* cuts) {
|
void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const &info, HistogramCuts *p_cuts) {
|
||||||
monitor_.Start(__func__);
|
monitor_.Start(__func__);
|
||||||
std::vector<typename WQSketch::SummaryContainer> reduced;
|
std::vector<typename WQSketch::SummaryContainer> reduced;
|
||||||
std::vector<int32_t> num_cuts;
|
std::vector<int32_t> num_cuts;
|
||||||
this->AllReduce(info, &reduced, &num_cuts);
|
this->AllReduce(info, &reduced, &num_cuts);
|
||||||
|
|
||||||
cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
|
p_cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
|
||||||
std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
|
std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
|
||||||
|
|
||||||
ParallelFor(reduced.size(), n_threads_, Sched::Guided(), [&](size_t fidx) {
|
ParallelFor(reduced.size(), n_threads_, Sched::Guided(), [&](size_t fidx) {
|
||||||
@ -401,48 +401,48 @@ void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const& info, HistogramCuts
|
|||||||
a.SetPrune(reduced[fidx], max_num_bins + 1);
|
a.SetPrune(reduced[fidx], max_num_bins + 1);
|
||||||
CHECK(a.data && reduced[fidx].data);
|
CHECK(a.data && reduced[fidx].data);
|
||||||
const bst_float mval = a.data[0].value;
|
const bst_float mval = a.data[0].value;
|
||||||
cuts->min_vals_.HostVector()[fidx] = mval - fabs(mval) - 1e-5f;
|
p_cuts->min_vals_.HostVector()[fidx] = mval - fabs(mval) - 1e-5f;
|
||||||
} else {
|
} else {
|
||||||
// Empty column.
|
// Empty column.
|
||||||
const float mval = 1e-5f;
|
const float mval = 1e-5f;
|
||||||
cuts->min_vals_.HostVector()[fidx] = mval;
|
p_cuts->min_vals_.HostVector()[fidx] = mval;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
float max_cat{-1.f};
|
float max_cat{-1.f};
|
||||||
for (size_t fid = 0; fid < reduced.size(); ++fid) {
|
for (size_t fid = 0; fid < reduced.size(); ++fid) {
|
||||||
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
|
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
|
||||||
typename WQSketch::SummaryContainer const& a = final_summaries[fid];
|
typename WQSketch::SummaryContainer const &a = final_summaries[fid];
|
||||||
if (IsCat(feature_types_, fid)) {
|
if (IsCat(feature_types_, fid)) {
|
||||||
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), cuts));
|
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
|
||||||
} else {
|
} else {
|
||||||
AddCutPoint<WQSketch>(a, max_num_bins, cuts);
|
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts);
|
||||||
// push a value that is greater than anything
|
// push a value that is greater than anything
|
||||||
const bst_float cpt =
|
const bst_float cpt =
|
||||||
(a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid];
|
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
|
||||||
// this must be bigger than last value in a scale
|
// this must be bigger than last value in a scale
|
||||||
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
|
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
|
||||||
cuts->cut_values_.HostVector().push_back(last);
|
p_cuts->cut_values_.HostVector().push_back(last);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure that every feature gets at least one quantile point
|
// Ensure that every feature gets at least one quantile point
|
||||||
CHECK_LE(cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
|
CHECK_LE(p_cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
|
||||||
auto cut_size = static_cast<uint32_t>(cuts->cut_values_.HostVector().size());
|
auto cut_size = static_cast<uint32_t>(p_cuts->cut_values_.HostVector().size());
|
||||||
CHECK_GT(cut_size, cuts->cut_ptrs_.HostVector().back());
|
CHECK_GT(cut_size, p_cuts->cut_ptrs_.HostVector().back());
|
||||||
cuts->cut_ptrs_.HostVector().push_back(cut_size);
|
p_cuts->cut_ptrs_.HostVector().push_back(cut_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
cuts->SetCategorical(this->has_categorical_, max_cat);
|
p_cuts->SetCategorical(this->has_categorical_, max_cat);
|
||||||
monitor_.Stop(__func__);
|
monitor_.Stop(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
template class SketchContainerImpl<WQuantileSketch<float, float>>;
|
template class SketchContainerImpl<WQuantileSketch<float, float>>;
|
||||||
template class SketchContainerImpl<WXQuantileSketch<float, float>>;
|
template class SketchContainerImpl<WXQuantileSketch<float, float>>;
|
||||||
|
|
||||||
HostSketchContainer::HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
|
HostSketchContainer::HostSketchContainer(Context const *ctx, bst_bin_t max_bins,
|
||||||
std::vector<size_t> columns_size, bool use_group,
|
common::Span<FeatureType const> ft,
|
||||||
int32_t n_threads)
|
std::vector<size_t> columns_size, bool use_group)
|
||||||
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
|
: SketchContainerImpl{ctx, columns_size, max_bins, ft, use_group} {
|
||||||
monitor_.Init(__func__);
|
monitor_.Init(__func__);
|
||||||
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
|
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
|
||||||
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
|
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
|
||||||
|
|||||||
@ -800,9 +800,8 @@ class SketchContainerImpl {
|
|||||||
* \param max_bins maximum number of bins for each feature.
|
* \param max_bins maximum number of bins for each feature.
|
||||||
* \param use_group whether is assigned to group to data instance.
|
* \param use_group whether is assigned to group to data instance.
|
||||||
*/
|
*/
|
||||||
SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
|
SketchContainerImpl(Context const *ctx, std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||||
common::Span<FeatureType const> feature_types, bool use_group,
|
common::Span<FeatureType const> feature_types, bool use_group);
|
||||||
int32_t n_threads);
|
|
||||||
|
|
||||||
static bool UseGroup(MetaInfo const &info) {
|
static bool UseGroup(MetaInfo const &info) {
|
||||||
size_t const num_groups =
|
size_t const num_groups =
|
||||||
@ -894,8 +893,8 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl
|
|||||||
using WQSketch = WQuantileSketch<float, float>;
|
using WQSketch = WQuantileSketch<float, float>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
|
HostSketchContainer(Context const *ctx, bst_bin_t max_bins, common::Span<FeatureType const> ft,
|
||||||
std::vector<size_t> columns_size, bool use_group, int32_t n_threads);
|
std::vector<size_t> columns_size, bool use_group);
|
||||||
|
|
||||||
template <typename Batch>
|
template <typename Batch>
|
||||||
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
|
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
|
||||||
@ -990,10 +989,10 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,
|
|||||||
using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
|
using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit SortedSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
|
explicit SortedSketchContainer(Context const *ctx, int32_t max_bins,
|
||||||
std::vector<size_t> columns_size, bool use_group,
|
common::Span<FeatureType const> ft,
|
||||||
int32_t n_threads)
|
std::vector<size_t> columns_size, bool use_group)
|
||||||
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
|
: SketchContainerImpl{ctx, columns_size, max_bins, ft, use_group} {
|
||||||
monitor_.Init(__func__);
|
monitor_.Init(__func__);
|
||||||
sketches_.resize(columns_size.size());
|
sketches_.resize(columns_size.size());
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
|
|||||||
33
src/data/batch_utils.h
Normal file
33
src/data/batch_utils.h
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_DATA_BATCH_UTILS_H_
|
||||||
|
#define XGBOOST_DATA_BATCH_UTILS_H_
|
||||||
|
|
||||||
|
#include "xgboost/data.h" // for BatchParam
|
||||||
|
|
||||||
|
namespace xgboost::data::detail {
|
||||||
|
// At least one batch parameter is initialized.
|
||||||
|
inline void CheckEmpty(BatchParam const& l, BatchParam const& r) {
|
||||||
|
if (!l.Initialized()) {
|
||||||
|
CHECK(r.Initialized()) << "Batch parameter is not initialized.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Should we regenerate the gradient index?
|
||||||
|
*
|
||||||
|
* \param old Parameter stored in DMatrix.
|
||||||
|
* \param p New parameter passed in by caller.
|
||||||
|
*/
|
||||||
|
inline bool RegenGHist(BatchParam old, BatchParam p) {
|
||||||
|
// Parameter is renewed or caller requests a regen
|
||||||
|
if (!p.Initialized()) {
|
||||||
|
// Empty parameter is passed in, don't regenerate so that we can use gindex in
|
||||||
|
// predictor, which doesn't have any training parameter.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return p.regen || old.ParamNotEqual(p);
|
||||||
|
}
|
||||||
|
} // namespace xgboost::data::detail
|
||||||
|
#endif // XGBOOST_DATA_BATCH_UTILS_H_
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019 XGBoost contributors
|
* Copyright 2019-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_USE_CUDA
|
#ifndef XGBOOST_USE_CUDA
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ class EllpackPageImpl {};
|
|||||||
|
|
||||||
EllpackPage::EllpackPage() = default;
|
EllpackPage::EllpackPage() = default;
|
||||||
|
|
||||||
EllpackPage::EllpackPage(DMatrix*, const BatchParam&) {
|
EllpackPage::EllpackPage(Context const*, DMatrix*, const BatchParam&) {
|
||||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||||
"EllpackPage is required";
|
"EllpackPage is required";
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,8 +17,8 @@ namespace xgboost {
|
|||||||
|
|
||||||
EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {}
|
EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {}
|
||||||
|
|
||||||
EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param)
|
EllpackPage::EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param)
|
||||||
: impl_{new EllpackPageImpl(dmat, param)} {}
|
: impl_{new EllpackPageImpl{ctx, dmat, param}} {}
|
||||||
|
|
||||||
EllpackPage::~EllpackPage() = default;
|
EllpackPage::~EllpackPage() = default;
|
||||||
|
|
||||||
@ -105,29 +105,29 @@ EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Construct an ELLPACK matrix in memory.
|
// Construct an ELLPACK matrix in memory.
|
||||||
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param)
|
EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchParam& param)
|
||||||
: is_dense(dmat->IsDense()) {
|
: is_dense(dmat->IsDense()) {
|
||||||
monitor_.Init("ellpack_page");
|
monitor_.Init("ellpack_page");
|
||||||
dh::safe_cuda(cudaSetDevice(param.gpu_id));
|
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
|
||||||
|
|
||||||
n_rows = dmat->Info().num_row_;
|
n_rows = dmat->Info().num_row_;
|
||||||
|
|
||||||
monitor_.Start("Quantiles");
|
monitor_.Start("Quantiles");
|
||||||
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
|
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
|
||||||
row_stride = GetRowStride(dmat);
|
row_stride = GetRowStride(dmat);
|
||||||
cuts_ = common::DeviceSketch(param.gpu_id, dmat, param.max_bin);
|
cuts_ = common::DeviceSketch(ctx->gpu_id, dmat, param.max_bin);
|
||||||
monitor_.Stop("Quantiles");
|
monitor_.Stop("Quantiles");
|
||||||
|
|
||||||
monitor_.Start("InitCompressedData");
|
monitor_.Start("InitCompressedData");
|
||||||
this->InitCompressedData(param.gpu_id);
|
this->InitCompressedData(ctx->gpu_id);
|
||||||
monitor_.Stop("InitCompressedData");
|
monitor_.Stop("InitCompressedData");
|
||||||
|
|
||||||
dmat->Info().feature_types.SetDevice(param.gpu_id);
|
dmat->Info().feature_types.SetDevice(ctx->gpu_id);
|
||||||
auto ft = dmat->Info().feature_types.ConstDeviceSpan();
|
auto ft = dmat->Info().feature_types.ConstDeviceSpan();
|
||||||
monitor_.Start("BinningCompression");
|
monitor_.Start("BinningCompression");
|
||||||
CHECK(dmat->SingleColBlock());
|
CHECK(dmat->SingleColBlock());
|
||||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||||
CreateHistIndices(param.gpu_id, batch, ft);
|
CreateHistIndices(ctx->gpu_id, batch, ft);
|
||||||
}
|
}
|
||||||
monitor_.Stop("BinningCompression");
|
monitor_.Stop("BinningCompression");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -155,7 +155,7 @@ class EllpackPageImpl {
|
|||||||
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
|
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
|
||||||
* in CSR format.
|
* in CSR format.
|
||||||
*/
|
*/
|
||||||
explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);
|
explicit EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchParam& parm);
|
||||||
|
|
||||||
template <typename AdapterBatch>
|
template <typename AdapterBatch>
|
||||||
explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense,
|
explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense,
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019-2022 XGBoost contributors
|
* Copyright 2019-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -10,7 +10,7 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
void EllpackPageSource::Fetch() {
|
void EllpackPageSource::Fetch() {
|
||||||
dh::safe_cuda(cudaSetDevice(param_.gpu_id));
|
dh::safe_cuda(cudaSetDevice(device_));
|
||||||
if (!this->ReadCache()) {
|
if (!this->ReadCache()) {
|
||||||
if (count_ != 0 && !sync_) {
|
if (count_ != 0 && !sync_) {
|
||||||
// source is initialized to be the 0th page during construction, so when count_ is 0
|
// source is initialized to be the 0th page during construction, so when count_ is 0
|
||||||
@ -22,8 +22,7 @@ void EllpackPageSource::Fetch() {
|
|||||||
auto const &csr = source_->Page();
|
auto const &csr = source_->Page();
|
||||||
this->page_.reset(new EllpackPage{});
|
this->page_.reset(new EllpackPage{});
|
||||||
auto *impl = this->page_->Impl();
|
auto *impl = this->page_->Impl();
|
||||||
*impl = EllpackPageImpl(param_.gpu_id, *cuts_, *csr, is_dense_, row_stride_,
|
*impl = EllpackPageImpl(device_, *cuts_, *csr, is_dense_, row_stride_, feature_types_);
|
||||||
feature_types_);
|
|
||||||
page_->SetBaseRowId(csr->base_rowid);
|
page_->SetBaseRowId(csr->base_rowid);
|
||||||
this->WriteCache();
|
this->WriteCache();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019-2022 by XGBoost Contributors
|
* Copyright 2019-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
|
#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
|
||||||
@ -23,19 +23,21 @@ class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
|
|||||||
BatchParam param_;
|
BatchParam param_;
|
||||||
common::Span<FeatureType const> feature_types_;
|
common::Span<FeatureType const> feature_types_;
|
||||||
std::unique_ptr<common::HistogramCuts> cuts_;
|
std::unique_ptr<common::HistogramCuts> cuts_;
|
||||||
|
std::int32_t device_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
|
EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
|
||||||
std::shared_ptr<Cache> cache, BatchParam param,
|
std::shared_ptr<Cache> cache, BatchParam param,
|
||||||
std::unique_ptr<common::HistogramCuts> cuts, bool is_dense, size_t row_stride,
|
std::unique_ptr<common::HistogramCuts> cuts, bool is_dense, size_t row_stride,
|
||||||
common::Span<FeatureType const> feature_types,
|
common::Span<FeatureType const> feature_types,
|
||||||
std::shared_ptr<SparsePageSource> source)
|
std::shared_ptr<SparsePageSource> source, std::int32_t device)
|
||||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
|
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
|
||||||
is_dense_{is_dense},
|
is_dense_{is_dense},
|
||||||
row_stride_{row_stride},
|
row_stride_{row_stride},
|
||||||
param_{std::move(param)},
|
param_{std::move(param)},
|
||||||
feature_types_{feature_types},
|
feature_types_{feature_types},
|
||||||
cuts_{std::move(cuts)} {
|
cuts_{std::move(cuts)},
|
||||||
|
device_{device} {
|
||||||
this->source_ = source;
|
this->source_ = source;
|
||||||
this->Fetch();
|
this->Fetch();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 by XGBoost Contributors
|
* Copyright 2017-2023, XGBoost Contributors
|
||||||
* \brief Data type for fast histogram aggregation.
|
* \brief Data type for fast histogram aggregation.
|
||||||
*/
|
*/
|
||||||
#include "gradient_index.h"
|
#include "gradient_index.h"
|
||||||
@ -19,18 +19,18 @@ namespace xgboost {
|
|||||||
|
|
||||||
GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique<common::ColumnMatrix>()} {}
|
GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique<common::ColumnMatrix>()} {}
|
||||||
|
|
||||||
GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
||||||
double sparse_thresh, bool sorted_sketch, int32_t n_threads,
|
double sparse_thresh, bool sorted_sketch,
|
||||||
common::Span<float> hess)
|
common::Span<float> hess)
|
||||||
: max_numeric_bins_per_feat{max_bins_per_feat} {
|
: max_numeric_bins_per_feat{max_bins_per_feat} {
|
||||||
CHECK(p_fmat->SingleColBlock());
|
CHECK(p_fmat->SingleColBlock());
|
||||||
// We use sorted sketching for approx tree method since it's more efficient in
|
// We use sorted sketching for approx tree method since it's more efficient in
|
||||||
// computation time (but higher memory usage).
|
// computation time (but higher memory usage).
|
||||||
cut = common::SketchOnDMatrix(p_fmat, max_bins_per_feat, n_threads, sorted_sketch, hess);
|
cut = common::SketchOnDMatrix(ctx, p_fmat, max_bins_per_feat, sorted_sketch, hess);
|
||||||
|
|
||||||
const uint32_t nbins = cut.Ptrs().back();
|
const uint32_t nbins = cut.Ptrs().back();
|
||||||
hit_count.resize(nbins, 0);
|
hit_count.resize(nbins, 0);
|
||||||
hit_count_tloc_.resize(n_threads * nbins, 0);
|
hit_count_tloc_.resize(ctx->Threads() * nbins, 0);
|
||||||
|
|
||||||
size_t new_size = 1;
|
size_t new_size = 1;
|
||||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||||
@ -45,7 +45,7 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
|||||||
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
||||||
|
|
||||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||||
this->PushBatch(batch, ft, n_threads);
|
this->PushBatch(batch, ft, ctx->Threads());
|
||||||
}
|
}
|
||||||
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
|||||||
// hist
|
// hist
|
||||||
CHECK(!sorted_sketch);
|
CHECK(!sorted_sketch);
|
||||||
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
|
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
|
||||||
this->columns_->InitFromSparse(page, *this, sparse_thresh, n_threads);
|
this->columns_->InitFromSparse(page, *this, sparse_thresh, ctx->Threads());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,7 +19,6 @@
|
|||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
|
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
|
||||||
#include "adapter.h"
|
#include "adapter.h"
|
||||||
#include "proxy_dmatrix.h"
|
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
|
|
||||||
@ -155,8 +154,8 @@ class GHistIndexMatrix {
|
|||||||
/**
|
/**
|
||||||
* \brief Constructor for SimpleDMatrix.
|
* \brief Constructor for SimpleDMatrix.
|
||||||
*/
|
*/
|
||||||
GHistIndexMatrix(DMatrix* x, bst_bin_t max_bins_per_feat, double sparse_thresh,
|
GHistIndexMatrix(Context const* ctx, DMatrix* x, bst_bin_t max_bins_per_feat,
|
||||||
bool sorted_sketch, int32_t n_threads, common::Span<float> hess = {});
|
double sparse_thresh, bool sorted_sketch, common::Span<float> hess = {});
|
||||||
/**
|
/**
|
||||||
* \brief Constructor for Iterative DMatrix. Initialize basic information and prepare
|
* \brief Constructor for Iterative DMatrix. Initialize basic information and prepare
|
||||||
* for push batch.
|
* for push batch.
|
||||||
@ -295,28 +294,5 @@ void AssignColumnBinIndex(GHistIndexMatrix const& page, Fn&& assign) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* \brief Should we regenerate the gradient index?
|
|
||||||
*
|
|
||||||
* \param old Parameter stored in DMatrix.
|
|
||||||
* \param p New parameter passed in by caller.
|
|
||||||
*/
|
|
||||||
inline bool RegenGHist(BatchParam old, BatchParam p) {
|
|
||||||
// parameter is renewed or caller requests a regen
|
|
||||||
if (p == BatchParam{}) {
|
|
||||||
// empty parameter is passed in, don't regenerate so that we can use gindex in
|
|
||||||
// predictor, which doesn't have any training parameter.
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Avoid comparing nan values.
|
|
||||||
bool l_nan = std::isnan(old.sparse_thresh);
|
|
||||||
bool r_nan = std::isnan(p.sparse_thresh);
|
|
||||||
// regenerate if parameter is changed.
|
|
||||||
bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (old.sparse_thresh != p.sparse_thresh));
|
|
||||||
bool param_chg = old.gpu_id != p.gpu_id || old.max_bin != p.max_bin;
|
|
||||||
return p.regen || param_chg || st_chg;
|
|
||||||
}
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_DATA_GRADIENT_INDEX_H_
|
#endif // XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||||
|
|||||||
@ -1,25 +1,26 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2022 XGBoost contributors
|
* Copyright 2022-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include "iterative_dmatrix.h"
|
#include "iterative_dmatrix.h"
|
||||||
|
|
||||||
#include <algorithm> // std::copy
|
#include <algorithm> // for copy
|
||||||
#include <cstddef> // std::size_t
|
#include <cstddef> // for size_t
|
||||||
#include <type_traits> // std::underlying_type_t
|
#include <memory> // for shared_ptr
|
||||||
#include <vector> // std::vector
|
#include <type_traits> // for underlying_type_t
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../collective/communicator-inl.h"
|
#include "../collective/communicator-inl.h"
|
||||||
#include "../common/categorical.h" // common::IsCat
|
#include "../common/categorical.h" // common::IsCat
|
||||||
#include "../common/column_matrix.h"
|
#include "../common/column_matrix.h"
|
||||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||||
|
#include "batch_utils.h" // for RegenGHist
|
||||||
#include "gradient_index.h"
|
#include "gradient_index.h"
|
||||||
#include "proxy_dmatrix.h"
|
#include "proxy_dmatrix.h"
|
||||||
#include "simple_batch_iterator.h"
|
#include "simple_batch_iterator.h"
|
||||||
#include "xgboost/data.h" // FeatureType
|
#include "xgboost/data.h" // for FeatureType, DMatrix
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::data {
|
||||||
namespace data {
|
|
||||||
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
||||||
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
|
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
|
||||||
XGDMatrixCallbackNext* next, float missing, int nthread,
|
XGDMatrixCallbackNext* next, float missing, int nthread,
|
||||||
@ -34,60 +35,61 @@ IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle pro
|
|||||||
|
|
||||||
auto d = MakeProxy(proxy_)->DeviceIdx();
|
auto d = MakeProxy(proxy_)->DeviceIdx();
|
||||||
|
|
||||||
StringView msg{"All batch should be on the same device."};
|
Context ctx;
|
||||||
if (batch_param_.gpu_id != Context::kCpuId) {
|
ctx.UpdateAllowUnknown(Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
|
||||||
CHECK_EQ(d, batch_param_.gpu_id) << msg;
|
|
||||||
}
|
|
||||||
|
|
||||||
batch_param_ = BatchParam{d, max_bin};
|
|
||||||
// hardcoded parameter.
|
// hardcoded parameter.
|
||||||
batch_param_.sparse_thresh = tree::TrainParam::DftSparseThreshold();
|
BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
|
||||||
ctx_.UpdateAllowUnknown(
|
if (ctx.IsCPU()) {
|
||||||
Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
|
this->InitFromCPU(&ctx, p, iter_handle, missing, ref);
|
||||||
if (ctx_.IsCPU()) {
|
|
||||||
this->InitFromCPU(iter_handle, missing, ref);
|
|
||||||
} else {
|
} else {
|
||||||
this->InitFromCUDA(iter_handle, missing, ref);
|
this->InitFromCUDA(&ctx, p, iter_handle, missing, ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this->fmat_ctx_ = ctx;
|
||||||
|
this->batch_ = p;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
|
void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_feature_t n_features,
|
||||||
common::HistogramCuts* p_cuts) {
|
BatchParam p, common::HistogramCuts* p_cuts) {
|
||||||
CHECK(ref_);
|
CHECK(ref);
|
||||||
CHECK(p_cuts);
|
CHECK(p_cuts);
|
||||||
auto csr = [&]() {
|
p.forbid_regen = true;
|
||||||
for (auto const& page : ref_->GetBatches<GHistIndexMatrix>(p)) {
|
// Fetch cuts from GIDX
|
||||||
|
auto csr = [&] {
|
||||||
|
for (auto const& page : ref->GetBatches<GHistIndexMatrix>(ctx, p)) {
|
||||||
*p_cuts = page.cut;
|
*p_cuts = page.cut;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
auto ellpack = [&]() {
|
// Fetch cuts from Ellpack.
|
||||||
// workaround ellpack being initialized from CPU.
|
auto ellpack = [&] {
|
||||||
if (p.gpu_id == Context::kCpuId) {
|
for (auto const& page : ref->GetBatches<EllpackPage>(ctx, p)) {
|
||||||
p.gpu_id = ref_->Ctx()->gpu_id;
|
|
||||||
}
|
|
||||||
if (p.gpu_id == Context::kCpuId) {
|
|
||||||
p.gpu_id = 0;
|
|
||||||
}
|
|
||||||
for (auto const& page : ref_->GetBatches<EllpackPage>(p)) {
|
|
||||||
GetCutsFromEllpack(page, p_cuts);
|
GetCutsFromEllpack(page, p_cuts);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (ref_->PageExists<GHistIndexMatrix>()) {
|
if (ref->PageExists<GHistIndexMatrix>() && ref->PageExists<EllpackPage>()) {
|
||||||
|
// Both exist
|
||||||
|
if (ctx->IsCPU()) {
|
||||||
|
csr();
|
||||||
|
} else {
|
||||||
|
ellpack();
|
||||||
|
}
|
||||||
|
} else if (ref->PageExists<GHistIndexMatrix>()) {
|
||||||
csr();
|
csr();
|
||||||
} else if (ref_->PageExists<EllpackPage>()) {
|
} else if (ref->PageExists<EllpackPage>()) {
|
||||||
ellpack();
|
ellpack();
|
||||||
} else {
|
} else {
|
||||||
if (p.gpu_id == Context::kCpuId) {
|
// None exist
|
||||||
|
if (ctx->IsCPU()) {
|
||||||
csr();
|
csr();
|
||||||
} else {
|
} else {
|
||||||
ellpack();
|
ellpack();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CHECK_EQ(ref_->Info().num_col_, n_features)
|
CHECK_EQ(ref->Info().num_col_, n_features)
|
||||||
<< "Invalid ref DMatrix, different number of features.";
|
<< "Invalid ref DMatrix, different number of features.";
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,7 +114,8 @@ void SyncFeatureType(std::vector<FeatureType>* p_h_ft) {
|
|||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
|
||||||
|
DataIterHandle iter_handle, float missing,
|
||||||
std::shared_ptr<DMatrix> ref) {
|
std::shared_ptr<DMatrix> ref) {
|
||||||
DMatrixProxy* proxy = MakeProxy(proxy_);
|
DMatrixProxy* proxy = MakeProxy(proxy_);
|
||||||
CHECK(proxy);
|
CHECK(proxy);
|
||||||
@ -133,7 +136,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
|||||||
auto const is_valid = data::IsValidFunctor{missing};
|
auto const is_valid = data::IsValidFunctor{missing};
|
||||||
auto nnz_cnt = [&]() {
|
auto nnz_cnt = [&]() {
|
||||||
return HostAdapterDispatch(proxy, [&](auto const& value) {
|
return HostAdapterDispatch(proxy, [&](auto const& value) {
|
||||||
size_t n_threads = ctx_.Threads();
|
size_t n_threads = ctx->Threads();
|
||||||
size_t n_features = column_sizes.size();
|
size_t n_features = column_sizes.size();
|
||||||
linalg::Tensor<std::size_t, 2> column_sizes_tloc({n_threads, n_features}, Context::kCpuId);
|
linalg::Tensor<std::size_t, 2> column_sizes_tloc({n_threads, n_features}, Context::kCpuId);
|
||||||
column_sizes_tloc.Data()->Fill(0ul);
|
column_sizes_tloc.Data()->Fill(0ul);
|
||||||
@ -158,10 +161,10 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t n_features = 0;
|
std::uint64_t n_features = 0;
|
||||||
size_t n_batches = 0;
|
std::size_t n_batches = 0;
|
||||||
size_t accumulated_rows{0};
|
std::uint64_t accumulated_rows{0};
|
||||||
size_t nnz{0};
|
std::uint64_t nnz{0};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* CPU impl needs an additional loop for accumulating the column size.
|
* CPU impl needs an additional loop for accumulating the column size.
|
||||||
@ -203,7 +206,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
|||||||
accumulated_rows = 0;
|
accumulated_rows = 0;
|
||||||
std::vector<FeatureType> h_ft;
|
std::vector<FeatureType> h_ft;
|
||||||
if (ref) {
|
if (ref) {
|
||||||
GetCutsFromRef(ref, Info().num_col_, batch_param_, &cuts);
|
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
|
||||||
h_ft = ref->Info().feature_types.HostVector();
|
h_ft = ref->Info().feature_types.HostVector();
|
||||||
} else {
|
} else {
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
@ -211,9 +214,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
|||||||
if (!p_sketch) {
|
if (!p_sketch) {
|
||||||
h_ft = proxy->Info().feature_types.ConstHostVector();
|
h_ft = proxy->Info().feature_types.ConstHostVector();
|
||||||
SyncFeatureType(&h_ft);
|
SyncFeatureType(&h_ft);
|
||||||
p_sketch.reset(new common::HostSketchContainer{
|
p_sketch.reset(new common::HostSketchContainer{ctx, p.max_bin, h_ft, column_sizes,
|
||||||
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
|
!proxy->Info().group_ptr_.empty()});
|
||||||
ctx_.Threads()});
|
|
||||||
}
|
}
|
||||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||||
proxy->Info().num_nonzero_ = batch_nnz[i];
|
proxy->Info().num_nonzero_ = batch_nnz[i];
|
||||||
@ -237,15 +239,15 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
|||||||
/**
|
/**
|
||||||
* Generate gradient index.
|
* Generate gradient index.
|
||||||
*/
|
*/
|
||||||
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), batch_param_.max_bin);
|
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), p.max_bin);
|
||||||
size_t rbegin = 0;
|
size_t rbegin = 0;
|
||||||
size_t prev_sum = 0;
|
size_t prev_sum = 0;
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
while (iter.Next()) {
|
while (iter.Next()) {
|
||||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||||
proxy->Info().num_nonzero_ = batch_nnz[i];
|
proxy->Info().num_nonzero_ = batch_nnz[i];
|
||||||
this->ghist_->PushAdapterBatch(&ctx_, rbegin, prev_sum, batch, missing, h_ft,
|
this->ghist_->PushAdapterBatch(ctx, rbegin, prev_sum, batch, missing, h_ft, p.sparse_thresh,
|
||||||
batch_param_.sparse_thresh, Info().num_row_);
|
Info().num_row_);
|
||||||
});
|
});
|
||||||
if (n_batches != 1) {
|
if (n_batches != 1) {
|
||||||
this->info_.Extend(std::move(proxy->Info()), false, true);
|
this->info_.Extend(std::move(proxy->Info()), false, true);
|
||||||
@ -265,7 +267,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
|||||||
accumulated_rows = 0;
|
accumulated_rows = 0;
|
||||||
while (iter.Next()) {
|
while (iter.Next()) {
|
||||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||||
this->ghist_->PushAdapterBatchColumns(&ctx_, batch, missing, accumulated_rows);
|
this->ghist_->PushAdapterBatchColumns(ctx, batch, missing, accumulated_rows);
|
||||||
});
|
});
|
||||||
accumulated_rows += num_rows();
|
accumulated_rows += num_rows();
|
||||||
}
|
}
|
||||||
@ -282,11 +284,27 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
|||||||
Info().feature_types.HostVector() = h_ft;
|
Info().feature_types.HostVector() = h_ft;
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {
|
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(Context const* ctx,
|
||||||
CheckParam(param);
|
BatchParam const& param) {
|
||||||
|
if (param.Initialized()) {
|
||||||
|
CheckParam(param);
|
||||||
|
CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin();
|
||||||
|
}
|
||||||
|
if (!ellpack_ && !ghist_) {
|
||||||
|
LOG(FATAL) << "`QuantileDMatrix` not initialized.";
|
||||||
|
}
|
||||||
|
|
||||||
if (!ghist_) {
|
if (!ghist_) {
|
||||||
CHECK(ellpack_);
|
if (ctx->IsCPU()) {
|
||||||
ghist_ = std::make_shared<GHistIndexMatrix>(&ctx_, Info(), *ellpack_, param);
|
ghist_ = std::make_shared<GHistIndexMatrix>(ctx, Info(), *ellpack_, param);
|
||||||
|
} else if (fmat_ctx_.IsCPU()) {
|
||||||
|
ghist_ = std::make_shared<GHistIndexMatrix>(&fmat_ctx_, Info(), *ellpack_, param);
|
||||||
|
} else {
|
||||||
|
// Can happen when QDM is initialized on GPU, but a CPU version is queried by a different QDM
|
||||||
|
// for cut reference.
|
||||||
|
auto cpu_ctx = ctx->MakeCPU();
|
||||||
|
ghist_ = std::make_shared<GHistIndexMatrix>(&cpu_ctx, Info(), *ellpack_, param);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!std::isnan(param.sparse_thresh) &&
|
if (!std::isnan(param.sparse_thresh) &&
|
||||||
@ -300,8 +318,9 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const&
|
|||||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param) {
|
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(Context const* ctx,
|
||||||
for (auto const& page : this->GetGradientIndex(param)) {
|
BatchParam const& param) {
|
||||||
|
for (auto const& page : this->GetGradientIndex(ctx, param)) {
|
||||||
auto p_out = std::make_shared<SparsePage>();
|
auto p_out = std::make_shared<SparsePage>();
|
||||||
p_out->data.Resize(this->Info().num_nonzero_);
|
p_out->data.Resize(this->Info().num_nonzero_);
|
||||||
p_out->offset.Resize(this->Info().num_row_ + 1);
|
p_out->offset.Resize(this->Info().num_row_ + 1);
|
||||||
@ -336,5 +355,26 @@ BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param)
|
|||||||
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
|
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
|
||||||
return BatchSet<ExtSparsePage>(begin_iter);
|
return BatchSet<ExtSparsePage>(begin_iter);
|
||||||
}
|
}
|
||||||
} // namespace data
|
|
||||||
} // namespace xgboost
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
|
inline void IterativeDMatrix::InitFromCUDA(Context const*, BatchParam const&, DataIterHandle, float,
|
||||||
|
std::shared_ptr<DMatrix>) {
|
||||||
|
// silence the warning about unused variables.
|
||||||
|
(void)(proxy_);
|
||||||
|
(void)(reset_);
|
||||||
|
(void)(next_);
|
||||||
|
common::AssertGPUSupport();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
|
||||||
|
BatchParam const& param) {
|
||||||
|
common::AssertGPUSupport();
|
||||||
|
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
|
||||||
|
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void GetCutsFromEllpack(EllpackPage const&, common::HistogramCuts*) {
|
||||||
|
common::AssertGPUSupport();
|
||||||
|
}
|
||||||
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -1,22 +1,24 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2020-2022 XGBoost contributors
|
* Copyright 2020-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
#include "../common/hist_util.cuh"
|
#include "../common/hist_util.cuh"
|
||||||
|
#include "batch_utils.h" // for RegenGHist
|
||||||
#include "device_adapter.cuh"
|
#include "device_adapter.cuh"
|
||||||
#include "ellpack_page.cuh"
|
#include "ellpack_page.cuh"
|
||||||
|
#include "gradient_index.h"
|
||||||
#include "iterative_dmatrix.h"
|
#include "iterative_dmatrix.h"
|
||||||
#include "proxy_dmatrix.cuh"
|
#include "proxy_dmatrix.cuh"
|
||||||
#include "proxy_dmatrix.h"
|
#include "proxy_dmatrix.h"
|
||||||
#include "simple_batch_iterator.h"
|
#include "simple_batch_iterator.h"
|
||||||
#include "sparse_page_source.h"
|
#include "sparse_page_source.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::data {
|
||||||
namespace data {
|
void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
|
||||||
void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
DataIterHandle iter_handle, float missing,
|
||||||
std::shared_ptr<DMatrix> ref) {
|
std::shared_ptr<DMatrix> ref) {
|
||||||
// A handle passed to external iterator.
|
// A handle passed to external iterator.
|
||||||
DMatrixProxy* proxy = MakeProxy(proxy_);
|
DMatrixProxy* proxy = MakeProxy(proxy_);
|
||||||
@ -46,7 +48,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
|||||||
int32_t current_device;
|
int32_t current_device;
|
||||||
dh::safe_cuda(cudaGetDevice(&current_device));
|
dh::safe_cuda(cudaGetDevice(&current_device));
|
||||||
auto get_device = [&]() -> int32_t {
|
auto get_device = [&]() -> int32_t {
|
||||||
int32_t d = (ctx_.gpu_id == Context::kCpuId) ? current_device : ctx_.gpu_id;
|
std::int32_t d = (ctx->gpu_id == Context::kCpuId) ? current_device : ctx->gpu_id;
|
||||||
CHECK_NE(d, Context::kCpuId);
|
CHECK_NE(d, Context::kCpuId);
|
||||||
return d;
|
return d;
|
||||||
};
|
};
|
||||||
@ -57,8 +59,8 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
|||||||
common::HistogramCuts cuts;
|
common::HistogramCuts cuts;
|
||||||
do {
|
do {
|
||||||
// We use do while here as the first batch is fetched in ctor
|
// We use do while here as the first batch is fetched in ctor
|
||||||
ctx_.gpu_id = proxy->DeviceIdx();
|
// ctx_.gpu_id = proxy->DeviceIdx();
|
||||||
CHECK_LT(ctx_.gpu_id, common::AllVisibleGPUs());
|
CHECK_LT(ctx->gpu_id, common::AllVisibleGPUs());
|
||||||
dh::safe_cuda(cudaSetDevice(get_device()));
|
dh::safe_cuda(cudaSetDevice(get_device()));
|
||||||
if (cols == 0) {
|
if (cols == 0) {
|
||||||
cols = num_cols();
|
cols = num_cols();
|
||||||
@ -68,12 +70,12 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
|||||||
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
|
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
|
||||||
}
|
}
|
||||||
if (!ref) {
|
if (!ref) {
|
||||||
sketch_containers.emplace_back(proxy->Info().feature_types, batch_param_.max_bin, cols,
|
sketch_containers.emplace_back(proxy->Info().feature_types, p.max_bin, cols, num_rows(),
|
||||||
num_rows(), get_device());
|
get_device());
|
||||||
auto* p_sketch = &sketch_containers.back();
|
auto* p_sketch = &sketch_containers.back();
|
||||||
proxy->Info().weights_.SetDevice(get_device());
|
proxy->Info().weights_.SetDevice(get_device());
|
||||||
Dispatch(proxy, [&](auto const& value) {
|
Dispatch(proxy, [&](auto const& value) {
|
||||||
common::AdapterDeviceSketch(value, batch_param_.max_bin, proxy->Info(), missing, p_sketch);
|
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
auto batch_rows = num_rows();
|
auto batch_rows = num_rows();
|
||||||
@ -95,8 +97,8 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
|||||||
if (!ref) {
|
if (!ref) {
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
common::SketchContainer final_sketch(
|
common::SketchContainer final_sketch(
|
||||||
sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(),
|
sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), p.max_bin, cols,
|
||||||
batch_param_.max_bin, cols, accumulated_rows, get_device());
|
accumulated_rows, get_device());
|
||||||
for (auto const& sketch : sketch_containers) {
|
for (auto const& sketch : sketch_containers) {
|
||||||
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
|
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
|
||||||
final_sketch.FixError();
|
final_sketch.FixError();
|
||||||
@ -106,7 +108,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
|||||||
|
|
||||||
final_sketch.MakeCuts(&cuts);
|
final_sketch.MakeCuts(&cuts);
|
||||||
} else {
|
} else {
|
||||||
GetCutsFromRef(ref, Info().num_col_, batch_param_, &cuts);
|
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
|
||||||
}
|
}
|
||||||
|
|
||||||
this->info_.num_row_ = accumulated_rows;
|
this->info_.num_row_ = accumulated_rows;
|
||||||
@ -169,24 +171,34 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
|||||||
info_.SynchronizeNumberOfColumns();
|
info_.SynchronizeNumberOfColumns();
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {
|
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
|
||||||
CheckParam(param);
|
BatchParam const& param) {
|
||||||
|
if (param.Initialized()) {
|
||||||
|
CheckParam(param);
|
||||||
|
CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin();
|
||||||
|
}
|
||||||
if (!ellpack_ && !ghist_) {
|
if (!ellpack_ && !ghist_) {
|
||||||
LOG(FATAL) << "`QuantileDMatrix` not initialized.";
|
LOG(FATAL) << "`QuantileDMatrix` not initialized.";
|
||||||
}
|
}
|
||||||
if (!ellpack_ && ghist_) {
|
|
||||||
|
if (!ellpack_) {
|
||||||
ellpack_.reset(new EllpackPage());
|
ellpack_.reset(new EllpackPage());
|
||||||
// Evaluation QuantileDMatrix initialized from CPU data might not have the correct GPU
|
if (ctx->IsCUDA()) {
|
||||||
// ID.
|
this->Info().feature_types.SetDevice(ctx->gpu_id);
|
||||||
if (this->ctx_.IsCPU()) {
|
*ellpack_->Impl() =
|
||||||
this->ctx_.gpu_id = param.gpu_id;
|
EllpackPageImpl(ctx, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
|
||||||
|
} else if (fmat_ctx_.IsCUDA()) {
|
||||||
|
this->Info().feature_types.SetDevice(fmat_ctx_.gpu_id);
|
||||||
|
*ellpack_->Impl() =
|
||||||
|
EllpackPageImpl(&fmat_ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
|
||||||
|
} else {
|
||||||
|
// Can happen when QDM is initialized on CPU, but a GPU version is queried by a different QDM
|
||||||
|
// for cut reference.
|
||||||
|
auto cuda_ctx = ctx->MakeCUDA();
|
||||||
|
this->Info().feature_types.SetDevice(cuda_ctx.gpu_id);
|
||||||
|
*ellpack_->Impl() =
|
||||||
|
EllpackPageImpl(&cuda_ctx, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
|
||||||
}
|
}
|
||||||
if (this->ctx_.IsCPU()) {
|
|
||||||
this->ctx_.gpu_id = dh::CurrentDevice();
|
|
||||||
}
|
|
||||||
this->Info().feature_types.SetDevice(this->ctx_.gpu_id);
|
|
||||||
*ellpack_->Impl() =
|
|
||||||
EllpackPageImpl(&ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
|
|
||||||
}
|
}
|
||||||
CHECK(ellpack_);
|
CHECK(ellpack_);
|
||||||
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
|
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
|
||||||
@ -196,5 +208,4 @@ BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& para
|
|||||||
void GetCutsFromEllpack(EllpackPage const& page, common::HistogramCuts* cuts) {
|
void GetCutsFromEllpack(EllpackPage const& page, common::HistogramCuts* cuts) {
|
||||||
*cuts = page.Impl()->Cuts();
|
*cuts = page.Impl()->Cuts();
|
||||||
}
|
}
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2020-2022 by Contributors
|
* Copyright 2020-2023 by XGBoost Contributors
|
||||||
* \file iterative_dmatrix.h
|
* \file iterative_dmatrix.h
|
||||||
|
*
|
||||||
|
* \brief Implementation of the higher-level `QuantileDMatrix`.
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_DATA_ITERATIVE_DMATRIX_H_
|
#ifndef XGBOOST_DATA_ITERATIVE_DMATRIX_H_
|
||||||
#define XGBOOST_DATA_ITERATIVE_DMATRIX_H_
|
#define XGBOOST_DATA_ITERATIVE_DMATRIX_H_
|
||||||
@ -10,10 +12,12 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../common/error_msg.h"
|
||||||
#include "proxy_dmatrix.h"
|
#include "proxy_dmatrix.h"
|
||||||
#include "simple_batch_iterator.h"
|
#include "simple_batch_iterator.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/c_api.h"
|
#include "xgboost/c_api.h"
|
||||||
|
#include "xgboost/context.h" // for Context
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -43,21 +47,17 @@ namespace data {
|
|||||||
*/
|
*/
|
||||||
class IterativeDMatrix : public DMatrix {
|
class IterativeDMatrix : public DMatrix {
|
||||||
MetaInfo info_;
|
MetaInfo info_;
|
||||||
Context ctx_;
|
|
||||||
BatchParam batch_param_;
|
|
||||||
std::shared_ptr<EllpackPage> ellpack_;
|
std::shared_ptr<EllpackPage> ellpack_;
|
||||||
std::shared_ptr<GHistIndexMatrix> ghist_;
|
std::shared_ptr<GHistIndexMatrix> ghist_;
|
||||||
|
BatchParam batch_;
|
||||||
|
|
||||||
DMatrixHandle proxy_;
|
DMatrixHandle proxy_;
|
||||||
DataIterResetCallback *reset_;
|
DataIterResetCallback *reset_;
|
||||||
XGDMatrixCallbackNext *next_;
|
XGDMatrixCallbackNext *next_;
|
||||||
|
Context fmat_ctx_;
|
||||||
|
|
||||||
void CheckParam(BatchParam const &param) {
|
void CheckParam(BatchParam const &param) {
|
||||||
// FIXME(Jiamingy): https://github.com/dmlc/xgboost/issues/7976
|
CHECK_EQ(param.max_bin, batch_.max_bin) << error::InconsistentMaxBin();
|
||||||
if (param.max_bin != batch_param_.max_bin && param.max_bin != 0) {
|
|
||||||
LOG(WARNING) << "Inconsistent max_bin between Quantile DMatrix and Booster:" << param.max_bin
|
|
||||||
<< " vs. " << batch_param_.max_bin;
|
|
||||||
}
|
|
||||||
CHECK(!param.regen && param.hess.empty())
|
CHECK(!param.regen && param.hess.empty())
|
||||||
<< "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`.";
|
<< "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`.";
|
||||||
}
|
}
|
||||||
@ -68,8 +68,10 @@ class IterativeDMatrix : public DMatrix {
|
|||||||
return BatchSet<Page>(BatchIterator<Page>(nullptr));
|
return BatchSet<Page>(BatchIterator<Page>(nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitFromCUDA(DataIterHandle iter, float missing, std::shared_ptr<DMatrix> ref);
|
void InitFromCUDA(Context const *ctx, BatchParam const &p, DataIterHandle iter_handle,
|
||||||
void InitFromCPU(DataIterHandle iter_handle, float missing, std::shared_ptr<DMatrix> ref);
|
float missing, std::shared_ptr<DMatrix> ref);
|
||||||
|
void InitFromCPU(Context const *ctx, BatchParam const &p, DataIterHandle iter_handle,
|
||||||
|
float missing, std::shared_ptr<DMatrix> ref);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
explicit IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
||||||
@ -94,51 +96,40 @@ class IterativeDMatrix : public DMatrix {
|
|||||||
LOG(FATAL) << "Not implemented.";
|
LOG(FATAL) << "Not implemented.";
|
||||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
||||||
}
|
}
|
||||||
BatchSet<CSCPage> GetColumnBatches() override { return InvalidTreeMethod<CSCPage>(); }
|
BatchSet<CSCPage> GetColumnBatches(Context const *) override {
|
||||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
|
return InvalidTreeMethod<CSCPage>();
|
||||||
|
}
|
||||||
|
BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const *) override {
|
||||||
return InvalidTreeMethod<SortedCSCPage>();
|
return InvalidTreeMethod<SortedCSCPage>();
|
||||||
}
|
}
|
||||||
BatchSet<GHistIndexMatrix> GetGradientIndex(BatchParam const &param) override;
|
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const *ctx, BatchParam const &param) override;
|
||||||
|
|
||||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam &param) override;
|
BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam &param) override;
|
||||||
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) override;
|
BatchSet<ExtSparsePage> GetExtBatches(Context const *ctx, BatchParam const &param) override;
|
||||||
|
|
||||||
bool SingleColBlock() const override { return true; }
|
bool SingleColBlock() const override { return true; }
|
||||||
|
|
||||||
MetaInfo &Info() override { return info_; }
|
MetaInfo &Info() override { return info_; }
|
||||||
MetaInfo const &Info() const override { return info_; }
|
MetaInfo const &Info() const override { return info_; }
|
||||||
|
|
||||||
Context const *Ctx() const override { return &ctx_; }
|
Context const *Ctx() const override { return &fmat_ctx_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Get quantile cuts from reference Quantile DMatrix.
|
* \brief Get quantile cuts from reference (Quantile)DMatrix.
|
||||||
|
*
|
||||||
|
* \param ctx The context of the new DMatrix.
|
||||||
|
* \param ref The reference DMatrix.
|
||||||
|
* \param n_features Number of features, used for validation only.
|
||||||
|
* \param p Batch parameter for the new DMatrix.
|
||||||
|
* \param p_cuts Output quantile cuts.
|
||||||
*/
|
*/
|
||||||
void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
|
void GetCutsFromRef(Context const *ctx, std::shared_ptr<DMatrix> ref, bst_feature_t n_features,
|
||||||
common::HistogramCuts *p_cuts);
|
BatchParam p, common::HistogramCuts *p_cuts);
|
||||||
/**
|
/**
|
||||||
* \brief Get quantile cuts from ellpack page.
|
* \brief Get quantile cuts from ellpack page.
|
||||||
*/
|
*/
|
||||||
void GetCutsFromEllpack(EllpackPage const &page, common::HistogramCuts *cuts);
|
void GetCutsFromEllpack(EllpackPage const &page, common::HistogramCuts *cuts);
|
||||||
|
|
||||||
#if !defined(XGBOOST_USE_CUDA)
|
|
||||||
inline void IterativeDMatrix::InitFromCUDA(DataIterHandle, float, std::shared_ptr<DMatrix>) {
|
|
||||||
// silent the warning about unused variables.
|
|
||||||
(void)(proxy_);
|
|
||||||
(void)(reset_);
|
|
||||||
(void)(next_);
|
|
||||||
common::AssertGPUSupport();
|
|
||||||
}
|
|
||||||
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(const BatchParam &) {
|
|
||||||
common::AssertGPUSupport();
|
|
||||||
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
|
|
||||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void GetCutsFromEllpack(EllpackPage const &, common::HistogramCuts *) {
|
|
||||||
common::AssertGPUSupport();
|
|
||||||
}
|
|
||||||
#endif // !defined(XGBOOST_USE_CUDA)
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
@ -25,16 +25,11 @@ class DataIterProxy {
|
|||||||
NextFn* next_;
|
NextFn* next_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
DataIterProxy(DataIterHandle iter, ResetFn* reset, NextFn* next) :
|
DataIterProxy(DataIterHandle iter, ResetFn* reset, NextFn* next)
|
||||||
iter_{iter},
|
: iter_{iter}, reset_{reset}, next_{next} {}
|
||||||
reset_{reset}, next_{next} {}
|
|
||||||
|
|
||||||
bool Next() {
|
bool Next() { return next_(iter_); }
|
||||||
return next_(iter_);
|
void Reset() { reset_(iter_); }
|
||||||
}
|
|
||||||
void Reset() {
|
|
||||||
reset_(iter_);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -68,9 +63,8 @@ class DMatrixProxy : public DMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SetArrayData(char const* c_interface);
|
void SetArrayData(char const* c_interface);
|
||||||
void SetCSRData(char const *c_indptr, char const *c_indices,
|
void SetCSRData(char const* c_indptr, char const* c_indices, char const* c_values,
|
||||||
char const *c_values, bst_feature_t n_features,
|
bst_feature_t n_features, bool on_host);
|
||||||
bool on_host);
|
|
||||||
|
|
||||||
MetaInfo& Info() override { return info_; }
|
MetaInfo& Info() override { return info_; }
|
||||||
MetaInfo const& Info() const override { return info_; }
|
MetaInfo const& Info() const override { return info_; }
|
||||||
@ -81,6 +75,12 @@ class DMatrixProxy : public DMatrix {
|
|||||||
bool GHistIndexExists() const override { return false; }
|
bool GHistIndexExists() const override { return false; }
|
||||||
bool SparsePageExists() const override { return false; }
|
bool SparsePageExists() const override { return false; }
|
||||||
|
|
||||||
|
template <typename Page>
|
||||||
|
BatchSet<Page> NoBatch() {
|
||||||
|
LOG(FATAL) << "Proxy DMatrix cannot return data batch.";
|
||||||
|
return BatchSet<Page>(BatchIterator<Page>(nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
DMatrix* Slice(common::Span<int32_t const> /*ridxs*/) override {
|
DMatrix* Slice(common::Span<int32_t const> /*ridxs*/) override {
|
||||||
LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
|
LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -89,29 +89,19 @@ class DMatrixProxy : public DMatrix {
|
|||||||
LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
|
LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
BatchSet<SparsePage> GetRowBatches() override {
|
BatchSet<SparsePage> GetRowBatches() override { return NoBatch<SparsePage>(); }
|
||||||
LOG(FATAL) << "Not implemented.";
|
BatchSet<CSCPage> GetColumnBatches(Context const*) override { return NoBatch<CSCPage>(); }
|
||||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const*) override {
|
||||||
|
return NoBatch<SortedCSCPage>();
|
||||||
}
|
}
|
||||||
BatchSet<CSCPage> GetColumnBatches() override {
|
BatchSet<EllpackPage> GetEllpackBatches(Context const*, BatchParam const&) override {
|
||||||
LOG(FATAL) << "Not implemented.";
|
return NoBatch<EllpackPage>();
|
||||||
return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
|
|
||||||
}
|
}
|
||||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
|
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const*, BatchParam const&) override {
|
||||||
LOG(FATAL) << "Not implemented.";
|
return NoBatch<GHistIndexMatrix>();
|
||||||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
|
|
||||||
}
|
}
|
||||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam&) override {
|
BatchSet<ExtSparsePage> GetExtBatches(Context const*, BatchParam const&) override {
|
||||||
LOG(FATAL) << "Not implemented.";
|
return NoBatch<ExtSparsePage>();
|
||||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(nullptr));
|
|
||||||
}
|
|
||||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
|
|
||||||
LOG(FATAL) << "Not implemented.";
|
|
||||||
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
|
|
||||||
}
|
|
||||||
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const&) override {
|
|
||||||
LOG(FATAL) << "Not implemented.";
|
|
||||||
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
|
|
||||||
}
|
}
|
||||||
std::any Adapter() const { return batch_; }
|
std::any Adapter() const { return batch_; }
|
||||||
};
|
};
|
||||||
@ -144,8 +134,7 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_
|
|||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
|
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
|
||||||
}
|
}
|
||||||
return std::result_of_t<Fn(
|
return std::result_of_t<Fn(decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
|
||||||
decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -11,10 +11,12 @@
|
|||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../common/error_msg.h" // for InconsistentMaxBin
|
||||||
#include "../common/random.h"
|
#include "../common/random.h"
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
#include "./simple_batch_iterator.h"
|
#include "./simple_batch_iterator.h"
|
||||||
#include "adapter.h"
|
#include "adapter.h"
|
||||||
|
#include "batch_utils.h" // for CheckEmpty, RegenGHist
|
||||||
#include "gradient_index.h"
|
#include "gradient_index.h"
|
||||||
#include "xgboost/c_api.h"
|
#include "xgboost/c_api.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
@ -28,7 +30,7 @@ const MetaInfo& SimpleDMatrix::Info() const { return info_; }
|
|||||||
DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
|
DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
|
||||||
auto out = new SimpleDMatrix;
|
auto out = new SimpleDMatrix;
|
||||||
SparsePage& out_page = *out->sparse_page_;
|
SparsePage& out_page = *out->sparse_page_;
|
||||||
for (auto const &page : this->GetBatches<SparsePage>()) {
|
for (auto const& page : this->GetBatches<SparsePage>()) {
|
||||||
auto batch = page.GetView();
|
auto batch = page.GetView();
|
||||||
auto& h_data = out_page.data.HostVector();
|
auto& h_data = out_page.data.HostVector();
|
||||||
auto& h_offset = out_page.offset.HostVector();
|
auto& h_offset = out_page.offset.HostVector();
|
||||||
@ -42,7 +44,7 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
|
|||||||
out->Info() = this->Info().Slice(ridxs);
|
out->Info() = this->Info().Slice(ridxs);
|
||||||
out->Info().num_nonzero_ = h_offset.back();
|
out->Info().num_nonzero_ = h_offset.back();
|
||||||
}
|
}
|
||||||
out->ctx_ = this->ctx_;
|
out->fmat_ctx_ = this->fmat_ctx_;
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +54,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
|||||||
auto const slice_size = info_.num_col_ / num_slices;
|
auto const slice_size = info_.num_col_ / num_slices;
|
||||||
auto const slice_start = slice_size * slice_id;
|
auto const slice_start = slice_size * slice_id;
|
||||||
auto const slice_end = (slice_id == num_slices - 1) ? info_.num_col_ : slice_start + slice_size;
|
auto const slice_end = (slice_id == num_slices - 1) ? info_.num_col_ : slice_start + slice_size;
|
||||||
for (auto const &page : this->GetBatches<SparsePage>()) {
|
for (auto const& page : this->GetBatches<SparsePage>()) {
|
||||||
auto batch = page.GetView();
|
auto batch = page.GetView();
|
||||||
auto& h_data = out_page.data.HostVector();
|
auto& h_data = out_page.data.HostVector();
|
||||||
auto& h_offset = out_page.offset.HostVector();
|
auto& h_offset = out_page.offset.HostVector();
|
||||||
@ -60,9 +62,8 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
|||||||
for (bst_row_t i = 0; i < this->Info().num_row_; i++) {
|
for (bst_row_t i = 0; i < this->Info().num_row_; i++) {
|
||||||
auto inst = batch[i];
|
auto inst = batch[i];
|
||||||
auto prev_size = h_data.size();
|
auto prev_size = h_data.size();
|
||||||
std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data), [&](Entry e) {
|
std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data),
|
||||||
return e.index >= slice_start && e.index < slice_end;
|
[&](Entry e) { return e.index >= slice_start && e.index < slice_end; });
|
||||||
});
|
|
||||||
rptr += h_data.size() - prev_size;
|
rptr += h_data.size() - prev_size;
|
||||||
h_offset.emplace_back(rptr);
|
h_offset.emplace_back(rptr);
|
||||||
}
|
}
|
||||||
@ -73,7 +74,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SimpleDMatrix::ReindexFeatures() {
|
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
|
||||||
if (info_.IsVerticalFederated()) {
|
if (info_.IsVerticalFederated()) {
|
||||||
std::vector<uint64_t> buffer(collective::GetWorldSize());
|
std::vector<uint64_t> buffer(collective::GetWorldSize());
|
||||||
buffer[collective::GetRank()] = info_.num_col_;
|
buffer[collective::GetRank()] = info_.num_col_;
|
||||||
@ -82,72 +83,115 @@ void SimpleDMatrix::ReindexFeatures() {
|
|||||||
if (offset == 0) {
|
if (offset == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
sparse_page_->Reindex(offset, ctx_.Threads());
|
sparse_page_->Reindex(offset, ctx->Threads());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
||||||
// since csr is the default data structure so `source_` is always available.
|
// since csr is the default data structure so `source_` is always available.
|
||||||
auto begin_iter = BatchIterator<SparsePage>(
|
auto begin_iter =
|
||||||
new SimpleBatchIteratorImpl<SparsePage>(sparse_page_));
|
BatchIterator<SparsePage>(new SimpleBatchIteratorImpl<SparsePage>(sparse_page_));
|
||||||
return BatchSet<SparsePage>(begin_iter);
|
return BatchSet<SparsePage>(begin_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
|
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches(Context const* ctx) {
|
||||||
// column page doesn't exist, generate it
|
// column page doesn't exist, generate it
|
||||||
if (!column_page_) {
|
if (!column_page_) {
|
||||||
column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads())));
|
column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx->Threads())));
|
||||||
}
|
}
|
||||||
auto begin_iter =
|
auto begin_iter = BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_));
|
||||||
BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_));
|
|
||||||
return BatchSet<CSCPage>(begin_iter);
|
return BatchSet<CSCPage>(begin_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
|
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches(Context const* ctx) {
|
||||||
// Sorted column page doesn't exist, generate it
|
// Sorted column page doesn't exist, generate it
|
||||||
if (!sorted_column_page_) {
|
if (!sorted_column_page_) {
|
||||||
sorted_column_page_.reset(
|
sorted_column_page_.reset(
|
||||||
new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads())));
|
new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx->Threads())));
|
||||||
sorted_column_page_->SortRows(ctx_.Threads());
|
sorted_column_page_->SortRows(ctx->Threads());
|
||||||
}
|
}
|
||||||
auto begin_iter = BatchIterator<SortedCSCPage>(
|
auto begin_iter =
|
||||||
new SimpleBatchIteratorImpl<SortedCSCPage>(sorted_column_page_));
|
BatchIterator<SortedCSCPage>(new SimpleBatchIteratorImpl<SortedCSCPage>(sorted_column_page_));
|
||||||
return BatchSet<SortedCSCPage>(begin_iter);
|
return BatchSet<SortedCSCPage>(begin_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(Context const* ctx,
|
||||||
void CheckEmpty(BatchParam const& l, BatchParam const& r) {
|
const BatchParam& param) {
|
||||||
if (l == BatchParam{}) {
|
detail::CheckEmpty(batch_param_, param);
|
||||||
CHECK(r != BatchParam{}) << "Batch parameter is not initialized.";
|
if (ellpack_page_ && param.Initialized() && param.forbid_regen) {
|
||||||
|
if (detail::RegenGHist(batch_param_, param)) {
|
||||||
|
CHECK_EQ(batch_param_.max_bin, param.max_bin) << error::InconsistentMaxBin();
|
||||||
|
}
|
||||||
|
CHECK(!detail::RegenGHist(batch_param_, param));
|
||||||
}
|
}
|
||||||
}
|
if (!ellpack_page_ || detail::RegenGHist(batch_param_, param)) {
|
||||||
} // anonymous namespace
|
// ELLPACK page doesn't exist, generate it
|
||||||
|
LOG(INFO) << "Generating new Ellpack page.";
|
||||||
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
|
// These places can ask for a ellpack page:
|
||||||
// ELLPACK page doesn't exist, generate it
|
// - GPU hist: the ctx must be on CUDA.
|
||||||
CheckEmpty(batch_param_, param);
|
// - IterativeDMatrix::InitFromCUDA: The ctx must be on CUDA.
|
||||||
if (!ellpack_page_ || RegenGHist(batch_param_, param)) {
|
// - IterativeDMatrix::InitFromCPU: It asks for ellpack only if it exists. It should
|
||||||
CHECK_GE(param.gpu_id, 0);
|
// not regen, otherwise it indicates a mismatched parameter like max_bin.
|
||||||
CHECK_GE(param.max_bin, 2);
|
CHECK_GE(param.max_bin, 2);
|
||||||
ellpack_page_.reset(new EllpackPage(this, param));
|
if (ctx->IsCUDA()) {
|
||||||
batch_param_ = param;
|
// The context passed in is on GPU, we pick it first since we prioritize the context
|
||||||
|
// in Booster.
|
||||||
|
ellpack_page_.reset(new EllpackPage(ctx, this, param));
|
||||||
|
} else if (fmat_ctx_.IsCUDA()) {
|
||||||
|
// DMatrix was initialized on GPU, we use the context from initialization.
|
||||||
|
ellpack_page_.reset(new EllpackPage(&fmat_ctx_, this, param));
|
||||||
|
} else {
|
||||||
|
// Mismatched parameter, user set a new max_bin during training.
|
||||||
|
auto cuda_ctx = ctx->MakeCUDA();
|
||||||
|
ellpack_page_.reset(new EllpackPage(&cuda_ctx, this, param));
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_param_ = param.MakeCache();
|
||||||
}
|
}
|
||||||
auto begin_iter =
|
auto begin_iter =
|
||||||
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_));
|
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_));
|
||||||
return BatchSet<EllpackPage>(begin_iter);
|
return BatchSet<EllpackPage>(begin_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& param) {
|
BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(Context const* ctx,
|
||||||
CheckEmpty(batch_param_, param);
|
const BatchParam& param) {
|
||||||
if (!gradient_index_ || RegenGHist(batch_param_, param)) {
|
detail::CheckEmpty(batch_param_, param);
|
||||||
|
// Check whether we can regenerate the gradient index. This is to keep the consistency
|
||||||
|
// between evaluation data and training data.
|
||||||
|
if (gradient_index_ && param.Initialized() && param.forbid_regen) {
|
||||||
|
if (detail::RegenGHist(batch_param_, param)) {
|
||||||
|
CHECK_EQ(batch_param_.max_bin, param.max_bin) << error::InconsistentMaxBin();
|
||||||
|
}
|
||||||
|
CHECK(!detail::RegenGHist(batch_param_, param)) << "Inconsistent sparse threshold.";
|
||||||
|
}
|
||||||
|
if (!gradient_index_ || detail::RegenGHist(batch_param_, param)) {
|
||||||
|
// GIDX page doesn't exist, generate it
|
||||||
LOG(INFO) << "Generating new Gradient Index.";
|
LOG(INFO) << "Generating new Gradient Index.";
|
||||||
|
// These places can ask for a CSR gidx:
|
||||||
|
// - CPU Hist: the ctx must be on CPU.
|
||||||
|
// - IterativeDMatrix::InitFromCPU: The ctx must be on CPU.
|
||||||
|
// - IterativeDMatrix::InitFromCUDA: It asks for gidx only if it exists. It should not
|
||||||
|
// regen, otherwise it indicates a mismatched parameter like max_bin.
|
||||||
CHECK_GE(param.max_bin, 2);
|
CHECK_GE(param.max_bin, 2);
|
||||||
CHECK_EQ(param.gpu_id, -1);
|
|
||||||
// Used only by approx.
|
// Used only by approx.
|
||||||
auto sorted_sketch = param.regen;
|
auto sorted_sketch = param.regen;
|
||||||
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, param.sparse_thresh,
|
if (ctx->IsCPU()) {
|
||||||
sorted_sketch, this->ctx_.Threads(), param.hess));
|
// The context passed in is on CPU, we pick it first since we prioritize the context
|
||||||
batch_param_ = param;
|
// in Booster.
|
||||||
|
gradient_index_.reset(new GHistIndexMatrix{ctx, this, param.max_bin, param.sparse_thresh,
|
||||||
|
sorted_sketch, param.hess});
|
||||||
|
} else if (fmat_ctx_.IsCPU()) {
|
||||||
|
// DMatrix was initialized on CPU, we use the context from initialization.
|
||||||
|
gradient_index_.reset(new GHistIndexMatrix{&fmat_ctx_, this, param.max_bin,
|
||||||
|
param.sparse_thresh, sorted_sketch, param.hess});
|
||||||
|
} else {
|
||||||
|
// Mismatched parameter, user set a new max_bin during training.
|
||||||
|
auto cpu_ctx = ctx->MakeCPU();
|
||||||
|
gradient_index_.reset(new GHistIndexMatrix{&cpu_ctx, this, param.max_bin, param.sparse_thresh,
|
||||||
|
sorted_sketch, param.hess});
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_param_ = param.MakeCache();
|
||||||
CHECK_EQ(batch_param_.hess.data(), param.hess.data());
|
CHECK_EQ(batch_param_.hess.data(), param.hess.data());
|
||||||
}
|
}
|
||||||
auto begin_iter = BatchIterator<GHistIndexMatrix>(
|
auto begin_iter = BatchIterator<GHistIndexMatrix>(
|
||||||
@ -155,7 +199,7 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
|
|||||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(BatchParam const&) {
|
BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(Context const*, BatchParam const&) {
|
||||||
auto casted = std::make_shared<ExtSparsePage>(sparse_page_);
|
auto casted = std::make_shared<ExtSparsePage>(sparse_page_);
|
||||||
CHECK(casted);
|
CHECK(casted);
|
||||||
auto begin_iter =
|
auto begin_iter =
|
||||||
@ -166,7 +210,8 @@ BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(BatchParam const&) {
|
|||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
|
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
|
||||||
DataSplitMode data_split_mode) {
|
DataSplitMode data_split_mode) {
|
||||||
this->ctx_.nthread = nthread;
|
Context ctx;
|
||||||
|
ctx.Init(Args{{"nthread", std::to_string(nthread)}});
|
||||||
|
|
||||||
std::vector<uint64_t> qids;
|
std::vector<uint64_t> qids;
|
||||||
uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||||
@ -176,13 +221,13 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
|
|||||||
auto& data_vec = sparse_page_->data.HostVector();
|
auto& data_vec = sparse_page_->data.HostVector();
|
||||||
uint64_t inferred_num_columns = 0;
|
uint64_t inferred_num_columns = 0;
|
||||||
uint64_t total_batch_size = 0;
|
uint64_t total_batch_size = 0;
|
||||||
// batch_size is either number of rows or cols, depending on data layout
|
// batch_size is either number of rows or cols, depending on data layout
|
||||||
|
|
||||||
adapter->BeforeFirst();
|
adapter->BeforeFirst();
|
||||||
// Iterate over batches of input data
|
// Iterate over batches of input data
|
||||||
while (adapter->Next()) {
|
while (adapter->Next()) {
|
||||||
auto& batch = adapter->Value();
|
auto& batch = adapter->Value();
|
||||||
auto batch_max_columns = sparse_page_->Push(batch, missing, ctx_.Threads());
|
auto batch_max_columns = sparse_page_->Push(batch, missing, ctx.Threads());
|
||||||
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
|
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
|
||||||
total_batch_size += batch.Size();
|
total_batch_size += batch.Size();
|
||||||
// Append meta information if available
|
// Append meta information if available
|
||||||
@ -229,19 +274,18 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
|
|||||||
info_.num_col_ = adapter->NumColumns();
|
info_.num_col_ = adapter->NumColumns();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Synchronise worker columns
|
// Synchronise worker columns
|
||||||
info_.data_split_mode = data_split_mode;
|
info_.data_split_mode = data_split_mode;
|
||||||
ReindexFeatures();
|
ReindexFeatures(&ctx);
|
||||||
info_.SynchronizeNumberOfColumns();
|
info_.SynchronizeNumberOfColumns();
|
||||||
|
|
||||||
if (adapter->NumRows() == kAdapterUnknownSize) {
|
if (adapter->NumRows() == kAdapterUnknownSize) {
|
||||||
using IteratorAdapterT
|
using IteratorAdapterT =
|
||||||
= IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>;
|
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>;
|
||||||
// If AdapterT is either IteratorAdapter or FileAdapter type, use the total batch size to
|
// If AdapterT is either IteratorAdapter or FileAdapter type, use the total batch size to
|
||||||
// determine the correct number of rows, as offset_vec may be too short
|
// determine the correct number of rows, as offset_vec may be too short
|
||||||
if (std::is_same<AdapterT, IteratorAdapterT>::value
|
if (std::is_same<AdapterT, IteratorAdapterT>::value ||
|
||||||
|| std::is_same<AdapterT, FileAdapter>::value) {
|
std::is_same<AdapterT, FileAdapter>::value) {
|
||||||
info_.num_row_ = total_batch_size;
|
info_.num_row_ = total_batch_size;
|
||||||
// Ensure offset_vec.size() - 1 == [number of rows]
|
// Ensure offset_vec.size() - 1 == [number of rows]
|
||||||
while (offset_vec.size() - 1 < total_batch_size) {
|
while (offset_vec.size() - 1 < total_batch_size) {
|
||||||
@ -265,9 +309,11 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
|
|||||||
info_.num_nonzero_ = data_vec.size();
|
info_.num_nonzero_ = data_vec.size();
|
||||||
|
|
||||||
// Sort the index for row partitioners used by variuos tree methods.
|
// Sort the index for row partitioners used by variuos tree methods.
|
||||||
if (!sparse_page_->IsIndicesSorted(this->ctx_.Threads())) {
|
if (!sparse_page_->IsIndicesSorted(ctx.Threads())) {
|
||||||
sparse_page_->SortIndices(this->ctx_.Threads());
|
sparse_page_->SortIndices(ctx.Threads());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this->fmat_ctx_ = ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
|
SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
|
||||||
@ -280,12 +326,12 @@ SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
|
void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
|
||||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
||||||
int tmagic = kMagic;
|
int tmagic = kMagic;
|
||||||
fo->Write(tmagic);
|
fo->Write(tmagic);
|
||||||
info_.SaveBinary(fo.get());
|
info_.SaveBinary(fo.get());
|
||||||
fo->Write(sparse_page_->offset.HostVector());
|
fo->Write(sparse_page_->offset.HostVector());
|
||||||
fo->Write(sparse_page_->data.HostVector());
|
fo->Write(sparse_page_->data.HostVector());
|
||||||
}
|
}
|
||||||
|
|
||||||
template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, int nthread,
|
template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, int nthread,
|
||||||
@ -305,14 +351,14 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
|
|||||||
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread,
|
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread,
|
||||||
DataSplitMode data_split_mode);
|
DataSplitMode data_split_mode);
|
||||||
template SimpleDMatrix::SimpleDMatrix(
|
template SimpleDMatrix::SimpleDMatrix(
|
||||||
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>
|
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
|
||||||
*adapter,
|
|
||||||
float missing, int nthread, DataSplitMode data_split_mode);
|
float missing, int nthread, DataSplitMode data_split_mode);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread,
|
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread,
|
||||||
DataSplitMode data_split_mode) {
|
DataSplitMode data_split_mode) {
|
||||||
ctx_.nthread = nthread;
|
Context ctx;
|
||||||
|
ctx.nthread = nthread;
|
||||||
|
|
||||||
auto& offset_vec = sparse_page_->offset.HostVector();
|
auto& offset_vec = sparse_page_->offset.HostVector();
|
||||||
auto& data_vec = sparse_page_->data.HostVector();
|
auto& data_vec = sparse_page_->data.HostVector();
|
||||||
@ -326,7 +372,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
|
|||||||
size_t num_elements = 0;
|
size_t num_elements = 0;
|
||||||
size_t num_rows = 0;
|
size_t num_rows = 0;
|
||||||
// Import Arrow RecordBatches
|
// Import Arrow RecordBatches
|
||||||
#pragma omp parallel for reduction(+ : num_elements, num_rows) num_threads(ctx_.Threads())
|
#pragma omp parallel for reduction(+ : num_elements, num_rows) num_threads(ctx.Threads())
|
||||||
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
|
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
|
||||||
num_elements += batches[i]->Import(missing);
|
num_elements += batches[i]->Import(missing);
|
||||||
num_rows += batches[i]->Size();
|
num_rows += batches[i]->Size();
|
||||||
@ -348,7 +394,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
|
|||||||
data_vec.resize(total_elements);
|
data_vec.resize(total_elements);
|
||||||
offset_vec.resize(total_batch_size + 1);
|
offset_vec.resize(total_batch_size + 1);
|
||||||
// Copy data into DMatrix
|
// Copy data into DMatrix
|
||||||
#pragma omp parallel num_threads(ctx_.Threads())
|
#pragma omp parallel num_threads(ctx.Threads())
|
||||||
{
|
{
|
||||||
#pragma omp for nowait
|
#pragma omp for nowait
|
||||||
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
|
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
|
||||||
@ -372,12 +418,14 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
|
|||||||
// Synchronise worker columns
|
// Synchronise worker columns
|
||||||
info_.num_col_ = adapter->NumColumns();
|
info_.num_col_ = adapter->NumColumns();
|
||||||
info_.data_split_mode = data_split_mode;
|
info_.data_split_mode = data_split_mode;
|
||||||
ReindexFeatures();
|
ReindexFeatures(&ctx);
|
||||||
info_.SynchronizeNumberOfColumns();
|
info_.SynchronizeNumberOfColumns();
|
||||||
|
|
||||||
info_.num_row_ = total_batch_size;
|
info_.num_row_ = total_batch_size;
|
||||||
info_.num_nonzero_ = data_vec.size();
|
info_.num_nonzero_ = data_vec.size();
|
||||||
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
|
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
|
||||||
|
|
||||||
|
fmat_ctx_ = ctx;
|
||||||
}
|
}
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019-2021 by XGBoost Contributors
|
* Copyright 2019-2023, XGBoost Contributors
|
||||||
* \file simple_dmatrix.cu
|
* \file simple_dmatrix.cu
|
||||||
*/
|
*/
|
||||||
#include <thrust/copy.h>
|
#include <thrust/copy.h>
|
||||||
#include <xgboost/data.h>
|
|
||||||
|
#include "device_adapter.cuh" // for CurrentDevice
|
||||||
#include "simple_dmatrix.cuh"
|
#include "simple_dmatrix.cuh"
|
||||||
#include "simple_dmatrix.h"
|
#include "simple_dmatrix.h"
|
||||||
#include "device_adapter.cuh"
|
#include "xgboost/context.h" // for Context
|
||||||
|
#include "xgboost/data.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -15,7 +17,7 @@ namespace data {
|
|||||||
// Current implementation assumes a single batch. More batches can
|
// Current implementation assumes a single batch. More batches can
|
||||||
// be supported in future. Does not currently support inferring row/column size
|
// be supported in future. Does not currently support inferring row/column size
|
||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread*/,
|
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthread,
|
||||||
DataSplitMode data_split_mode) {
|
DataSplitMode data_split_mode) {
|
||||||
CHECK(data_split_mode != DataSplitMode::kCol)
|
CHECK(data_split_mode != DataSplitMode::kCol)
|
||||||
<< "Column-wise data split is currently not supported on the GPU.";
|
<< "Column-wise data split is currently not supported on the GPU.";
|
||||||
@ -24,6 +26,9 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
|
|||||||
CHECK_GE(device, 0);
|
CHECK_GE(device, 0);
|
||||||
dh::safe_cuda(cudaSetDevice(device));
|
dh::safe_cuda(cudaSetDevice(device));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
ctx.Init(Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(device)}});
|
||||||
|
|
||||||
CHECK(adapter->NumRows() != kAdapterUnknownSize);
|
CHECK(adapter->NumRows() != kAdapterUnknownSize);
|
||||||
CHECK(adapter->NumColumns() != kAdapterUnknownSize);
|
CHECK(adapter->NumColumns() != kAdapterUnknownSize);
|
||||||
|
|
||||||
@ -33,13 +38,14 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
|
|||||||
// Enforce single batch
|
// Enforce single batch
|
||||||
CHECK(!adapter->Next());
|
CHECK(!adapter->Next());
|
||||||
|
|
||||||
info_.num_nonzero_ =
|
info_.num_nonzero_ = CopyToSparsePage(adapter->Value(), device, missing, sparse_page_.get());
|
||||||
CopyToSparsePage(adapter->Value(), device, missing, sparse_page_.get());
|
|
||||||
info_.num_col_ = adapter->NumColumns();
|
info_.num_col_ = adapter->NumColumns();
|
||||||
info_.num_row_ = adapter->NumRows();
|
info_.num_row_ = adapter->NumRows();
|
||||||
// Synchronise worker columns
|
// Synchronise worker columns
|
||||||
info_.data_split_mode = data_split_mode;
|
info_.data_split_mode = data_split_mode;
|
||||||
info_.SynchronizeNumberOfColumns();
|
info_.SynchronizeNumberOfColumns();
|
||||||
|
|
||||||
|
this->fmat_ctx_ = ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
|
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class SimpleDMatrix : public DMatrix {
|
|||||||
|
|
||||||
MetaInfo& Info() override;
|
MetaInfo& Info() override;
|
||||||
const MetaInfo& Info() const override;
|
const MetaInfo& Info() const override;
|
||||||
Context const* Ctx() const override { return &ctx_; }
|
Context const* Ctx() const override { return &fmat_ctx_; }
|
||||||
|
|
||||||
bool SingleColBlock() const override { return true; }
|
bool SingleColBlock() const override { return true; }
|
||||||
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
|
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
|
||||||
@ -43,11 +43,11 @@ class SimpleDMatrix : public DMatrix {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
BatchSet<SparsePage> GetRowBatches() override;
|
BatchSet<SparsePage> GetRowBatches() override;
|
||||||
BatchSet<CSCPage> GetColumnBatches() override;
|
BatchSet<CSCPage> GetColumnBatches(Context const* ctx) override;
|
||||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
|
BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const* ctx) override;
|
||||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, const BatchParam& param) override;
|
||||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override;
|
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx, const BatchParam& param) override;
|
||||||
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) override;
|
BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) override;
|
||||||
|
|
||||||
MetaInfo info_;
|
MetaInfo info_;
|
||||||
// Primary storage type
|
// Primary storage type
|
||||||
@ -69,10 +69,11 @@ class SimpleDMatrix : public DMatrix {
|
|||||||
* starting from 0. However, all the algorithms assume the features are globally indexed, so we
|
* starting from 0. However, all the algorithms assume the features are globally indexed, so we
|
||||||
* reindex the features based on the offset needed to obtain the global view.
|
* reindex the features based on the offset needed to obtain the global view.
|
||||||
*/
|
*/
|
||||||
void ReindexFeatures();
|
void ReindexFeatures(Context const* ctx);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Context ctx_;
|
// Context used only for DMatrix initialization.
|
||||||
|
Context fmat_ctx_;
|
||||||
};
|
};
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2014-2022 by Contributors
|
* Copyright 2014-2023 by XGBoost Contributors
|
||||||
* \file sparse_page_dmatrix.cc
|
* \file sparse_page_dmatrix.cc
|
||||||
|
*
|
||||||
* \brief The external memory version of Page Iterator.
|
* \brief The external memory version of Page Iterator.
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
@ -8,11 +9,10 @@
|
|||||||
|
|
||||||
#include "../collective/communicator-inl.h"
|
#include "../collective/communicator-inl.h"
|
||||||
#include "./simple_batch_iterator.h"
|
#include "./simple_batch_iterator.h"
|
||||||
|
#include "batch_utils.h" // for RegenGHist
|
||||||
#include "gradient_index.h"
|
#include "gradient_index.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::data {
|
||||||
namespace data {
|
|
||||||
|
|
||||||
MetaInfo &SparsePageDMatrix::Info() { return info_; }
|
MetaInfo &SparsePageDMatrix::Info() { return info_; }
|
||||||
|
|
||||||
const MetaInfo &SparsePageDMatrix::Info() const { return info_; }
|
const MetaInfo &SparsePageDMatrix::Info() const { return info_; }
|
||||||
@ -46,7 +46,9 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
|||||||
int32_t nthreads, std::string cache_prefix)
|
int32_t nthreads, std::string cache_prefix)
|
||||||
: proxy_{proxy_handle}, iter_{iter_handle}, reset_{reset}, next_{next}, missing_{missing},
|
: proxy_{proxy_handle}, iter_{iter_handle}, reset_{reset}, next_{next}, missing_{missing},
|
||||||
cache_prefix_{std::move(cache_prefix)} {
|
cache_prefix_{std::move(cache_prefix)} {
|
||||||
ctx_.nthread = nthreads;
|
Context ctx;
|
||||||
|
ctx.nthread = nthreads;
|
||||||
|
|
||||||
cache_prefix_ = cache_prefix_.empty() ? "DMatrix" : cache_prefix_;
|
cache_prefix_ = cache_prefix_.empty() ? "DMatrix" : cache_prefix_;
|
||||||
if (collective::IsDistributed()) {
|
if (collective::IsDistributed()) {
|
||||||
cache_prefix_ += ("-r" + std::to_string(collective::GetRank()));
|
cache_prefix_ += ("-r" + std::to_string(collective::GetRank()));
|
||||||
@ -81,7 +83,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
|||||||
|
|
||||||
// the proxy is iterated together with the sparse page source so we can obtain all
|
// the proxy is iterated together with the sparse page source so we can obtain all
|
||||||
// information in 1 pass.
|
// information in 1 pass.
|
||||||
for (auto const &page : this->GetRowBatchesImpl()) {
|
for (auto const &page : this->GetRowBatchesImpl(&ctx)) {
|
||||||
this->info_.Extend(std::move(proxy->Info()), false, false);
|
this->info_.Extend(std::move(proxy->Info()), false, false);
|
||||||
n_features = std::max(n_features, num_cols());
|
n_features = std::max(n_features, num_cols());
|
||||||
n_samples += num_rows();
|
n_samples += num_rows();
|
||||||
@ -98,9 +100,11 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
|||||||
|
|
||||||
info_.SynchronizeNumberOfColumns();
|
info_.SynchronizeNumberOfColumns();
|
||||||
CHECK_NE(info_.num_col_, 0);
|
CHECK_NE(info_.num_col_, 0);
|
||||||
|
|
||||||
|
fmat_ctx_ = ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SparsePageDMatrix::InitializeSparsePage() {
|
void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
|
||||||
auto id = MakeCache(this, ".row.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".row.page", cache_prefix_, &cache_info_);
|
||||||
// Don't use proxy DMatrix once this is already initialized, this allows users to
|
// Don't use proxy DMatrix once this is already initialized, this allows users to
|
||||||
// release the iterator and data.
|
// release the iterator and data.
|
||||||
@ -110,33 +114,33 @@ void SparsePageDMatrix::InitializeSparsePage() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
|
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_, reset_, next_};
|
||||||
iter_, reset_, next_};
|
|
||||||
DMatrixProxy *proxy = MakeProxy(proxy_);
|
DMatrixProxy *proxy = MakeProxy(proxy_);
|
||||||
sparse_page_source_.reset(); // clear before creating new one to prevent conflicts.
|
sparse_page_source_.reset(); // clear before creating new one to prevent conflicts.
|
||||||
sparse_page_source_ = std::make_shared<SparsePageSource>(
|
sparse_page_source_ = std::make_shared<SparsePageSource>(iter, proxy, this->missing_,
|
||||||
iter, proxy, this->missing_, this->ctx_.Threads(), this->info_.num_col_,
|
ctx->Threads(), this->info_.num_col_,
|
||||||
this->n_batches_, cache_info_.at(id));
|
this->n_batches_, cache_info_.at(id));
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatchesImpl() {
|
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatchesImpl(Context const* ctx) {
|
||||||
this->InitializeSparsePage();
|
this->InitializeSparsePage(ctx);
|
||||||
auto begin_iter = BatchIterator<SparsePage>(sparse_page_source_);
|
auto begin_iter = BatchIterator<SparsePage>(sparse_page_source_);
|
||||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(begin_iter));
|
return BatchSet<SparsePage>(BatchIterator<SparsePage>(begin_iter));
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
|
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
|
||||||
return this->GetRowBatchesImpl();
|
// Use context from initialization for the default row page.
|
||||||
|
return this->GetRowBatchesImpl(&fmat_ctx_);
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches() {
|
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
|
||||||
auto id = MakeCache(this, ".col.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".col.page", cache_prefix_, &cache_info_);
|
||||||
CHECK_NE(this->Info().num_col_, 0);
|
CHECK_NE(this->Info().num_col_, 0);
|
||||||
this->InitializeSparsePage();
|
this->InitializeSparsePage(ctx);
|
||||||
if (!column_source_) {
|
if (!column_source_) {
|
||||||
column_source_ = std::make_shared<CSCPageSource>(
|
column_source_ =
|
||||||
this->missing_, this->ctx_.Threads(), this->Info().num_col_,
|
std::make_shared<CSCPageSource>(this->missing_, ctx->Threads(), this->Info().num_col_,
|
||||||
this->n_batches_, cache_info_.at(id), sparse_page_source_);
|
this->n_batches_, cache_info_.at(id), sparse_page_source_);
|
||||||
} else {
|
} else {
|
||||||
column_source_->Reset();
|
column_source_->Reset();
|
||||||
}
|
}
|
||||||
@ -144,14 +148,14 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches() {
|
|||||||
return BatchSet<CSCPage>(BatchIterator<CSCPage>(begin_iter));
|
return BatchSet<CSCPage>(BatchIterator<CSCPage>(begin_iter));
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
|
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) {
|
||||||
auto id = MakeCache(this, ".sorted.col.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".sorted.col.page", cache_prefix_, &cache_info_);
|
||||||
CHECK_NE(this->Info().num_col_, 0);
|
CHECK_NE(this->Info().num_col_, 0);
|
||||||
this->InitializeSparsePage();
|
this->InitializeSparsePage(ctx);
|
||||||
if (!sorted_column_source_) {
|
if (!sorted_column_source_) {
|
||||||
sorted_column_source_ = std::make_shared<SortedCSCPageSource>(
|
sorted_column_source_ = std::make_shared<SortedCSCPageSource>(
|
||||||
this->missing_, this->ctx_.Threads(), this->Info().num_col_,
|
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
|
||||||
this->n_batches_, cache_info_.at(id), sparse_page_source_);
|
sparse_page_source_);
|
||||||
} else {
|
} else {
|
||||||
sorted_column_source_->Reset();
|
sorted_column_source_->Reset();
|
||||||
}
|
}
|
||||||
@ -159,27 +163,27 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
|
|||||||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(begin_iter));
|
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(begin_iter));
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam &param) {
|
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ctx,
|
||||||
|
const BatchParam &param) {
|
||||||
CHECK_GE(param.max_bin, 2);
|
CHECK_GE(param.max_bin, 2);
|
||||||
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||||
this->InitializeSparsePage();
|
this->InitializeSparsePage(ctx);
|
||||||
if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) {
|
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
||||||
cache_info_.erase(id);
|
cache_info_.erase(id);
|
||||||
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||||
LOG(INFO) << "Generating new Gradient Index.";
|
LOG(INFO) << "Generating new Gradient Index.";
|
||||||
// Use sorted sketch for approx.
|
// Use sorted sketch for approx.
|
||||||
auto sorted_sketch = param.regen;
|
auto sorted_sketch = param.regen;
|
||||||
auto cuts =
|
auto cuts = common::SketchOnDMatrix(ctx, this, param.max_bin, sorted_sketch, param.hess);
|
||||||
common::SketchOnDMatrix(this, param.max_bin, ctx_.Threads(), sorted_sketch, param.hess);
|
this->InitializeSparsePage(ctx); // reset after use.
|
||||||
this->InitializeSparsePage(); // reset after use.
|
|
||||||
|
|
||||||
batch_param_ = param;
|
batch_param_ = param;
|
||||||
ghist_index_source_.reset();
|
ghist_index_source_.reset();
|
||||||
CHECK_NE(cuts.Values().size(), 0);
|
CHECK_NE(cuts.Values().size(), 0);
|
||||||
auto ft = this->info_.feature_types.ConstHostSpan();
|
auto ft = this->info_.feature_types.ConstHostSpan();
|
||||||
ghist_index_source_.reset(new GradientIndexPageSource(
|
ghist_index_source_.reset(new GradientIndexPageSource(
|
||||||
this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_,
|
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
|
||||||
cache_info_.at(id), param, std::move(cuts), this->IsDense(), ft, sparse_page_source_));
|
param, std::move(cuts), this->IsDense(), ft, sparse_page_source_));
|
||||||
} else {
|
} else {
|
||||||
CHECK(ghist_index_source_);
|
CHECK(ghist_index_source_);
|
||||||
ghist_index_source_->Reset();
|
ghist_index_source_->Reset();
|
||||||
@ -189,11 +193,10 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(XGBOOST_USE_CUDA)
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam &) {
|
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const *, const BatchParam &) {
|
||||||
common::AssertGPUSupport();
|
common::AssertGPUSupport();
|
||||||
auto begin_iter = BatchIterator<EllpackPage>(ellpack_page_source_);
|
auto begin_iter = BatchIterator<EllpackPage>(ellpack_page_source_);
|
||||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
||||||
}
|
}
|
||||||
#endif // !defined(XGBOOST_USE_CUDA)
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,42 +1,40 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021 XGBoost contributors
|
* Copyright 2021-2023 by XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include "sparse_page_source.h"
|
|
||||||
#include "../common/hist_util.cuh"
|
#include "../common/hist_util.cuh"
|
||||||
|
#include "batch_utils.h" // for CheckEmpty, RegenGHist
|
||||||
#include "ellpack_page.cuh"
|
#include "ellpack_page.cuh"
|
||||||
#include "sparse_page_dmatrix.h"
|
#include "sparse_page_dmatrix.h"
|
||||||
|
#include "sparse_page_source.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::data {
|
||||||
namespace data {
|
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
||||||
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
|
const BatchParam& param) {
|
||||||
CHECK_GE(param.gpu_id, 0);
|
CHECK(ctx->IsCUDA());
|
||||||
CHECK_GE(param.max_bin, 2);
|
CHECK_GE(param.max_bin, 2);
|
||||||
if (!(batch_param_ != BatchParam{})) {
|
detail::CheckEmpty(batch_param_, param);
|
||||||
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
|
|
||||||
}
|
|
||||||
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||||
size_t row_stride = 0;
|
size_t row_stride = 0;
|
||||||
this->InitializeSparsePage();
|
this->InitializeSparsePage(ctx);
|
||||||
if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) {
|
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
||||||
// reinitialize the cache
|
// reinitialize the cache
|
||||||
cache_info_.erase(id);
|
cache_info_.erase(id);
|
||||||
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||||
std::unique_ptr<common::HistogramCuts> cuts;
|
std::unique_ptr<common::HistogramCuts> cuts;
|
||||||
cuts.reset(new common::HistogramCuts{
|
cuts.reset(
|
||||||
common::DeviceSketch(param.gpu_id, this, param.max_bin, 0)});
|
new common::HistogramCuts{common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0)});
|
||||||
this->InitializeSparsePage(); // reset after use.
|
this->InitializeSparsePage(ctx); // reset after use.
|
||||||
|
|
||||||
row_stride = GetRowStride(this);
|
row_stride = GetRowStride(this);
|
||||||
this->InitializeSparsePage(); // reset after use.
|
this->InitializeSparsePage(ctx); // reset after use.
|
||||||
CHECK_NE(row_stride, 0);
|
CHECK_NE(row_stride, 0);
|
||||||
batch_param_ = param;
|
batch_param_ = param;
|
||||||
|
|
||||||
auto ft = this->info_.feature_types.ConstDeviceSpan();
|
auto ft = this->info_.feature_types.ConstDeviceSpan();
|
||||||
ellpack_page_source_.reset(); // release resources.
|
ellpack_page_source_.reset(); // release resources.
|
||||||
ellpack_page_source_.reset(new EllpackPageSource(
|
ellpack_page_source_.reset(new EllpackPageSource(
|
||||||
this->missing_, this->ctx_.Threads(), this->Info().num_col_,
|
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
|
||||||
this->n_batches_, cache_info_.at(id), param, std::move(cuts),
|
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id));
|
||||||
this->IsDense(), row_stride, ft, sparse_page_source_));
|
|
||||||
} else {
|
} else {
|
||||||
CHECK(sparse_page_source_);
|
CHECK(sparse_page_source_);
|
||||||
ellpack_page_source_->Reset();
|
ellpack_page_source_->Reset();
|
||||||
@ -45,5 +43,4 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
|
|||||||
auto begin_iter = BatchIterator<EllpackPage>(ellpack_page_source_);
|
auto begin_iter = BatchIterator<EllpackPage>(ellpack_page_source_);
|
||||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
||||||
}
|
}
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2015-2021 by Contributors
|
* Copyright 2015-2023, XGBoost Contributors
|
||||||
* \file sparse_page_dmatrix.h
|
* \file sparse_page_dmatrix.h
|
||||||
* \brief External-memory version of DMatrix.
|
* \brief External-memory version of DMatrix.
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@ -9,12 +9,13 @@
|
|||||||
|
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
#include <xgboost/logging.h>
|
#include <xgboost/logging.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#include "ellpack_page_source.h"
|
#include "ellpack_page_source.h"
|
||||||
#include "gradient_index_page_source.h"
|
#include "gradient_index_page_source.h"
|
||||||
@ -69,19 +70,18 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
XGDMatrixCallbackNext *next_;
|
XGDMatrixCallbackNext *next_;
|
||||||
|
|
||||||
float missing_;
|
float missing_;
|
||||||
Context ctx_;
|
Context fmat_ctx_;
|
||||||
std::string cache_prefix_;
|
std::string cache_prefix_;
|
||||||
uint32_t n_batches_ {0};
|
uint32_t n_batches_{0};
|
||||||
// sparse page is the source to other page types, we make a special member function.
|
// sparse page is the source to other page types, we make a special member function.
|
||||||
void InitializeSparsePage();
|
void InitializeSparsePage(Context const *ctx);
|
||||||
// Non-virtual version that can be used in constructor
|
// Non-virtual version that can be used in constructor
|
||||||
BatchSet<SparsePage> GetRowBatchesImpl();
|
BatchSet<SparsePage> GetRowBatchesImpl(Context const *ctx);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy,
|
explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
|
||||||
DataIterResetCallback *reset,
|
XGDMatrixCallbackNext *next, float missing, int32_t nthreads,
|
||||||
XGDMatrixCallbackNext *next, float missing,
|
std::string cache_prefix);
|
||||||
int32_t nthreads, std::string cache_prefix);
|
|
||||||
|
|
||||||
~SparsePageDMatrix() override {
|
~SparsePageDMatrix() override {
|
||||||
// Clear out all resources before deleting the cache file.
|
// Clear out all resources before deleting the cache file.
|
||||||
@ -98,9 +98,9 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MetaInfo& Info() override;
|
MetaInfo &Info() override;
|
||||||
const MetaInfo& Info() const override;
|
const MetaInfo &Info() const override;
|
||||||
Context const* Ctx() const override { return &ctx_; }
|
Context const *Ctx() const override { return &fmat_ctx_; }
|
||||||
|
|
||||||
bool SingleColBlock() const override { return false; }
|
bool SingleColBlock() const override { return false; }
|
||||||
DMatrix *Slice(common::Span<int32_t const>) override {
|
DMatrix *Slice(common::Span<int32_t const>) override {
|
||||||
@ -114,11 +114,11 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
BatchSet<SparsePage> GetRowBatches() override;
|
BatchSet<SparsePage> GetRowBatches() override;
|
||||||
BatchSet<CSCPage> GetColumnBatches() override;
|
BatchSet<CSCPage> GetColumnBatches(Context const *ctx) override;
|
||||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
|
BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const *ctx) override;
|
||||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam &param) override;
|
||||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override;
|
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const *ctx, const BatchParam &) override;
|
||||||
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const &) override {
|
BatchSet<ExtSparsePage> GetExtBatches(Context const *, BatchParam const &) override {
|
||||||
LOG(FATAL) << "Can not obtain a single CSR page for external memory DMatrix";
|
LOG(FATAL) << "Can not obtain a single CSR page for external memory DMatrix";
|
||||||
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
|
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
|
||||||
}
|
}
|
||||||
@ -141,9 +141,8 @@ inline std::string MakeId(std::string prefix, SparsePageDMatrix *ptr) {
|
|||||||
return prefix + "-" + ss.str();
|
return prefix + "-" + ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::string
|
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix,
|
||||||
MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix,
|
std::map<std::string, std::shared_ptr<Cache>> *out) {
|
||||||
std::map<std::string, std::shared_ptr<Cache>> *out) {
|
|
||||||
auto &cache_info = *out;
|
auto &cache_info = *out;
|
||||||
auto name = MakeId(prefix, ptr);
|
auto name = MakeId(prefix, ptr);
|
||||||
auto id = name + format;
|
auto id = name + format;
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2018 by Contributors
|
* Copyright 2018-2023 by XGBoost Contributors
|
||||||
* \author Rory Mitchell
|
* \author Rory Mitchell
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
@ -78,11 +78,12 @@ inline double CoordinateDeltaBias(double sum_grad, double sum_hess) {
|
|||||||
*
|
*
|
||||||
* \return The gradient and diagonal Hessian entry for a given feature.
|
* \return The gradient and diagonal Hessian entry for a given feature.
|
||||||
*/
|
*/
|
||||||
inline std::pair<double, double> GetGradient(int group_idx, int num_group, int fidx,
|
inline std::pair<double, double> GetGradient(Context const *ctx, int group_idx, int num_group,
|
||||||
const std::vector<GradientPair> &gpair,
|
bst_feature_t fidx,
|
||||||
|
std::vector<GradientPair> const &gpair,
|
||||||
DMatrix *p_fmat) {
|
DMatrix *p_fmat) {
|
||||||
double sum_grad = 0.0, sum_hess = 0.0;
|
double sum_grad = 0.0, sum_hess = 0.0;
|
||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>(ctx)) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
auto col = page[fidx];
|
auto col = page[fidx];
|
||||||
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
||||||
@ -115,7 +116,7 @@ inline std::pair<double, double> GetGradientParallel(Context const *ctx, int gro
|
|||||||
std::vector<double> sum_grad_tloc(ctx->Threads(), 0.0);
|
std::vector<double> sum_grad_tloc(ctx->Threads(), 0.0);
|
||||||
std::vector<double> sum_hess_tloc(ctx->Threads(), 0.0);
|
std::vector<double> sum_hess_tloc(ctx->Threads(), 0.0);
|
||||||
|
|
||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>(ctx)) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
auto col = page[fidx];
|
auto col = page[fidx];
|
||||||
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
||||||
@ -177,16 +178,16 @@ inline std::pair<double, double> GetBiasGradientParallel(int group_idx, int num_
|
|||||||
* \param in_gpair The gradient vector to be updated.
|
* \param in_gpair The gradient vector to be updated.
|
||||||
* \param p_fmat The input feature matrix.
|
* \param p_fmat The input feature matrix.
|
||||||
*/
|
*/
|
||||||
inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
|
inline void UpdateResidualParallel(Context const *ctx, bst_feature_t fidx, int group_idx,
|
||||||
float dw, std::vector<GradientPair> *in_gpair,
|
int num_group, float dw, std::vector<GradientPair> *in_gpair,
|
||||||
DMatrix *p_fmat, int32_t n_threads) {
|
DMatrix *p_fmat) {
|
||||||
if (dw == 0.0f) return;
|
if (dw == 0.0f) return;
|
||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>(ctx)) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
auto col = page[fidx];
|
auto col = page[fidx];
|
||||||
// update grad value
|
// update grad value
|
||||||
const auto num_row = static_cast<bst_omp_uint>(col.size());
|
const auto num_row = static_cast<bst_omp_uint>(col.size());
|
||||||
common::ParallelFor(num_row, n_threads, [&](auto j) {
|
common::ParallelFor(num_row, ctx->Threads(), [&](auto j) {
|
||||||
GradientPair &p = (*in_gpair)[col[j].index * num_group + group_idx];
|
GradientPair &p = (*in_gpair)[col[j].index * num_group + group_idx];
|
||||||
if (p.GetHess() < 0.0f) return;
|
if (p.GetHess() < 0.0f) return;
|
||||||
p += GradientPair(p.GetHess() * col[j].fvalue * dw, 0);
|
p += GradientPair(p.GetHess() * col[j].fvalue * dw, 0);
|
||||||
@ -203,12 +204,12 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
|
|||||||
* \param in_gpair The gradient vector to be updated.
|
* \param in_gpair The gradient vector to be updated.
|
||||||
* \param p_fmat The input feature matrix.
|
* \param p_fmat The input feature matrix.
|
||||||
*/
|
*/
|
||||||
inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias,
|
inline void UpdateBiasResidualParallel(Context const *ctx, int group_idx, int num_group,
|
||||||
std::vector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
float dbias, std::vector<GradientPair> *in_gpair,
|
||||||
int32_t n_threads) {
|
DMatrix *p_fmat) {
|
||||||
if (dbias == 0.0f) return;
|
if (dbias == 0.0f) return;
|
||||||
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
|
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
|
||||||
common::ParallelFor(ndata, n_threads, [&](auto i) {
|
common::ParallelFor(ndata, ctx->Threads(), [&](auto i) {
|
||||||
GradientPair &g = (*in_gpair)[i * num_group + group_idx];
|
GradientPair &g = (*in_gpair)[i * num_group + group_idx];
|
||||||
if (g.GetHess() < 0.0f) return;
|
if (g.GetHess() < 0.0f) return;
|
||||||
g += GradientPair(g.GetHess() * dbias, 0);
|
g += GradientPair(g.GetHess() * dbias, 0);
|
||||||
@ -220,18 +221,16 @@ inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias
|
|||||||
* in coordinate descent algorithms.
|
* in coordinate descent algorithms.
|
||||||
*/
|
*/
|
||||||
class FeatureSelector {
|
class FeatureSelector {
|
||||||
protected:
|
|
||||||
int32_t n_threads_{-1};
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit FeatureSelector(int32_t n_threads) : n_threads_{n_threads} {}
|
FeatureSelector() = default;
|
||||||
/*! \brief factory method */
|
/*! \brief factory method */
|
||||||
static FeatureSelector *Create(int choice, int32_t n_threads);
|
static FeatureSelector *Create(int choice);
|
||||||
/*! \brief virtual destructor */
|
/*! \brief virtual destructor */
|
||||||
virtual ~FeatureSelector() = default;
|
virtual ~FeatureSelector() = default;
|
||||||
/**
|
/**
|
||||||
* \brief Setting up the selector state prior to looping through features.
|
* \brief Setting up the selector state prior to looping through features.
|
||||||
*
|
*
|
||||||
|
* \param ctx The booster context.
|
||||||
* \param model The model.
|
* \param model The model.
|
||||||
* \param gpair The gpair.
|
* \param gpair The gpair.
|
||||||
* \param p_fmat The feature matrix.
|
* \param p_fmat The feature matrix.
|
||||||
@ -239,13 +238,12 @@ class FeatureSelector {
|
|||||||
* \param lambda Regularisation lambda.
|
* \param lambda Regularisation lambda.
|
||||||
* \param param A parameter with algorithm-dependent use.
|
* \param param A parameter with algorithm-dependent use.
|
||||||
*/
|
*/
|
||||||
virtual void Setup(const gbm::GBLinearModel &,
|
virtual void Setup(Context const *, const gbm::GBLinearModel &,
|
||||||
const std::vector<GradientPair> &,
|
const std::vector<GradientPair> &, DMatrix *, float, float, int) {}
|
||||||
DMatrix *,
|
|
||||||
float , float , int ) {}
|
|
||||||
/**
|
/**
|
||||||
* \brief Select next coordinate to update.
|
* \brief Select next coordinate to update.
|
||||||
*
|
*
|
||||||
|
* \param ctx Booster context
|
||||||
* \param iteration The iteration in a loop through features
|
* \param iteration The iteration in a loop through features
|
||||||
* \param model The model.
|
* \param model The model.
|
||||||
* \param group_idx Zero-based index of the group.
|
* \param group_idx Zero-based index of the group.
|
||||||
@ -256,11 +254,9 @@ class FeatureSelector {
|
|||||||
*
|
*
|
||||||
* \return The index of the selected feature. -1 indicates none selected.
|
* \return The index of the selected feature. -1 indicates none selected.
|
||||||
*/
|
*/
|
||||||
virtual int NextFeature(int iteration,
|
virtual int NextFeature(Context const *ctx, int iteration, const gbm::GBLinearModel &model,
|
||||||
const gbm::GBLinearModel &model,
|
int group_idx, const std::vector<GradientPair> &gpair, DMatrix *p_fmat,
|
||||||
int group_idx,
|
float alpha, float lambda) = 0;
|
||||||
const std::vector<GradientPair> &gpair,
|
|
||||||
DMatrix *p_fmat, float alpha, float lambda) = 0;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -269,9 +265,8 @@ class FeatureSelector {
|
|||||||
class CyclicFeatureSelector : public FeatureSelector {
|
class CyclicFeatureSelector : public FeatureSelector {
|
||||||
public:
|
public:
|
||||||
using FeatureSelector::FeatureSelector;
|
using FeatureSelector::FeatureSelector;
|
||||||
int NextFeature(int iteration, const gbm::GBLinearModel &model,
|
int NextFeature(Context const *, int iteration, const gbm::GBLinearModel &model, int,
|
||||||
int , const std::vector<GradientPair> &,
|
const std::vector<GradientPair> &, DMatrix *, float, float) override {
|
||||||
DMatrix *, float, float) override {
|
|
||||||
return iteration % model.learner_model_param->num_feature;
|
return iteration % model.learner_model_param->num_feature;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -283,8 +278,7 @@ class CyclicFeatureSelector : public FeatureSelector {
|
|||||||
class ShuffleFeatureSelector : public FeatureSelector {
|
class ShuffleFeatureSelector : public FeatureSelector {
|
||||||
public:
|
public:
|
||||||
using FeatureSelector::FeatureSelector;
|
using FeatureSelector::FeatureSelector;
|
||||||
void Setup(const gbm::GBLinearModel &model,
|
void Setup(Context const *, const gbm::GBLinearModel &model, const std::vector<GradientPair> &,
|
||||||
const std::vector<GradientPair>&,
|
|
||||||
DMatrix *, float, float, int) override {
|
DMatrix *, float, float, int) override {
|
||||||
if (feat_index_.size() == 0) {
|
if (feat_index_.size() == 0) {
|
||||||
feat_index_.resize(model.learner_model_param->num_feature);
|
feat_index_.resize(model.learner_model_param->num_feature);
|
||||||
@ -293,9 +287,8 @@ class ShuffleFeatureSelector : public FeatureSelector {
|
|||||||
std::shuffle(feat_index_.begin(), feat_index_.end(), common::GlobalRandom());
|
std::shuffle(feat_index_.begin(), feat_index_.end(), common::GlobalRandom());
|
||||||
}
|
}
|
||||||
|
|
||||||
int NextFeature(int iteration, const gbm::GBLinearModel &model,
|
int NextFeature(Context const *, int iteration, const gbm::GBLinearModel &model, int,
|
||||||
int, const std::vector<GradientPair> &,
|
const std::vector<GradientPair> &, DMatrix *, float, float) override {
|
||||||
DMatrix *, float, float) override {
|
|
||||||
return feat_index_[iteration % model.learner_model_param->num_feature];
|
return feat_index_[iteration % model.learner_model_param->num_feature];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -310,9 +303,8 @@ class ShuffleFeatureSelector : public FeatureSelector {
|
|||||||
class RandomFeatureSelector : public FeatureSelector {
|
class RandomFeatureSelector : public FeatureSelector {
|
||||||
public:
|
public:
|
||||||
using FeatureSelector::FeatureSelector;
|
using FeatureSelector::FeatureSelector;
|
||||||
int NextFeature(int, const gbm::GBLinearModel &model,
|
int NextFeature(Context const *, int, const gbm::GBLinearModel &model, int,
|
||||||
int, const std::vector<GradientPair> &,
|
const std::vector<GradientPair> &, DMatrix *, float, float) override {
|
||||||
DMatrix *, float, float) override {
|
|
||||||
return common::GlobalRandom()() % model.learner_model_param->num_feature;
|
return common::GlobalRandom()() % model.learner_model_param->num_feature;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -329,8 +321,7 @@ class RandomFeatureSelector : public FeatureSelector {
|
|||||||
class GreedyFeatureSelector : public FeatureSelector {
|
class GreedyFeatureSelector : public FeatureSelector {
|
||||||
public:
|
public:
|
||||||
using FeatureSelector::FeatureSelector;
|
using FeatureSelector::FeatureSelector;
|
||||||
void Setup(const gbm::GBLinearModel &model,
|
void Setup(Context const *, const gbm::GBLinearModel &model, const std::vector<GradientPair> &,
|
||||||
const std::vector<GradientPair> &,
|
|
||||||
DMatrix *, float, float, int param) override {
|
DMatrix *, float, float, int param) override {
|
||||||
top_k_ = static_cast<bst_uint>(param);
|
top_k_ = static_cast<bst_uint>(param);
|
||||||
const bst_uint ngroup = model.learner_model_param->num_output_group;
|
const bst_uint ngroup = model.learner_model_param->num_output_group;
|
||||||
@ -344,7 +335,7 @@ class GreedyFeatureSelector : public FeatureSelector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int NextFeature(int, const gbm::GBLinearModel &model,
|
int NextFeature(Context const* ctx, int, const gbm::GBLinearModel &model,
|
||||||
int group_idx, const std::vector<GradientPair> &gpair,
|
int group_idx, const std::vector<GradientPair> &gpair,
|
||||||
DMatrix *p_fmat, float alpha, float lambda) override {
|
DMatrix *p_fmat, float alpha, float lambda) override {
|
||||||
// k-th selected feature for a group
|
// k-th selected feature for a group
|
||||||
@ -356,9 +347,9 @@ class GreedyFeatureSelector : public FeatureSelector {
|
|||||||
const bst_omp_uint nfeat = model.learner_model_param->num_feature;
|
const bst_omp_uint nfeat = model.learner_model_param->num_feature;
|
||||||
// Calculate univariate gradient sums
|
// Calculate univariate gradient sums
|
||||||
std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
|
std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
|
||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>(ctx)) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
common::ParallelFor(nfeat, this->n_threads_, [&](bst_omp_uint i) {
|
common::ParallelFor(nfeat, ctx->Threads(), [&](bst_omp_uint i) {
|
||||||
const auto col = page[i];
|
const auto col = page[i];
|
||||||
const bst_uint ndata = col.size();
|
const bst_uint ndata = col.size();
|
||||||
auto &sums = gpair_sums_[group_idx * nfeat + i];
|
auto &sums = gpair_sums_[group_idx * nfeat + i];
|
||||||
@ -406,9 +397,10 @@ class GreedyFeatureSelector : public FeatureSelector {
|
|||||||
class ThriftyFeatureSelector : public FeatureSelector {
|
class ThriftyFeatureSelector : public FeatureSelector {
|
||||||
public:
|
public:
|
||||||
using FeatureSelector::FeatureSelector;
|
using FeatureSelector::FeatureSelector;
|
||||||
void Setup(const gbm::GBLinearModel &model,
|
|
||||||
const std::vector<GradientPair> &gpair,
|
void Setup(Context const *ctx, const gbm::GBLinearModel &model,
|
||||||
DMatrix *p_fmat, float alpha, float lambda, int param) override {
|
const std::vector<GradientPair> &gpair, DMatrix *p_fmat, float alpha, float lambda,
|
||||||
|
int param) override {
|
||||||
top_k_ = static_cast<bst_uint>(param);
|
top_k_ = static_cast<bst_uint>(param);
|
||||||
if (param <= 0) top_k_ = std::numeric_limits<bst_uint>::max();
|
if (param <= 0) top_k_ = std::numeric_limits<bst_uint>::max();
|
||||||
const bst_uint ngroup = model.learner_model_param->num_output_group;
|
const bst_uint ngroup = model.learner_model_param->num_output_group;
|
||||||
@ -422,10 +414,10 @@ class ThriftyFeatureSelector : public FeatureSelector {
|
|||||||
}
|
}
|
||||||
// Calculate univariate gradient sums
|
// Calculate univariate gradient sums
|
||||||
std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
|
std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
|
||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>(ctx)) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
// column-parallel is usually faster than row-parallel
|
// column-parallel is usually faster than row-parallel
|
||||||
common::ParallelFor(nfeat, this->n_threads_, [&](auto i) {
|
common::ParallelFor(nfeat, ctx->Threads(), [&](auto i) {
|
||||||
const auto col = page[i];
|
const auto col = page[i];
|
||||||
const bst_uint ndata = col.size();
|
const bst_uint ndata = col.size();
|
||||||
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
|
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
|
||||||
@ -462,9 +454,8 @@ class ThriftyFeatureSelector : public FeatureSelector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int NextFeature(int, const gbm::GBLinearModel &model,
|
int NextFeature(Context const *, int, const gbm::GBLinearModel &model, int group_idx,
|
||||||
int group_idx, const std::vector<GradientPair> &,
|
const std::vector<GradientPair> &, DMatrix *, float, float) override {
|
||||||
DMatrix *, float, float) override {
|
|
||||||
// k-th selected feature for a group
|
// k-th selected feature for a group
|
||||||
auto k = counter_[group_idx]++;
|
auto k = counter_[group_idx]++;
|
||||||
// stop after either reaching top-N or going through all the features in a group
|
// stop after either reaching top-N or going through all the features in a group
|
||||||
@ -482,18 +473,18 @@ class ThriftyFeatureSelector : public FeatureSelector {
|
|||||||
std::vector<std::pair<double, double>> gpair_sums_;
|
std::vector<std::pair<double, double>> gpair_sums_;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline FeatureSelector *FeatureSelector::Create(int choice, int32_t n_threads) {
|
inline FeatureSelector *FeatureSelector::Create(int choice) {
|
||||||
switch (choice) {
|
switch (choice) {
|
||||||
case kCyclic:
|
case kCyclic:
|
||||||
return new CyclicFeatureSelector(n_threads);
|
return new CyclicFeatureSelector;
|
||||||
case kShuffle:
|
case kShuffle:
|
||||||
return new ShuffleFeatureSelector(n_threads);
|
return new ShuffleFeatureSelector;
|
||||||
case kThrifty:
|
case kThrifty:
|
||||||
return new ThriftyFeatureSelector(n_threads);
|
return new ThriftyFeatureSelector;
|
||||||
case kGreedy:
|
case kGreedy:
|
||||||
return new GreedyFeatureSelector(n_threads);
|
return new GreedyFeatureSelector;
|
||||||
case kRandom:
|
case kRandom:
|
||||||
return new RandomFeatureSelector(n_threads);
|
return new RandomFeatureSelector;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "unknown coordinate selector: " << choice;
|
LOG(FATAL) << "unknown coordinate selector: " << choice;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2018 by Contributors
|
* Copyright 2018-2023 by XGBoost Contributors
|
||||||
* \author Rory Mitchell
|
* \author Rory Mitchell
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ class CoordinateUpdater : public LinearUpdater {
|
|||||||
tparam_.UpdateAllowUnknown(args)
|
tparam_.UpdateAllowUnknown(args)
|
||||||
};
|
};
|
||||||
cparam_.UpdateAllowUnknown(rest);
|
cparam_.UpdateAllowUnknown(rest);
|
||||||
selector_.reset(FeatureSelector::Create(tparam_.feature_selector, ctx_->Threads()));
|
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
|
||||||
monitor_.Init("CoordinateUpdater");
|
monitor_.Init("CoordinateUpdater");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,19 +56,17 @@ class CoordinateUpdater : public LinearUpdater {
|
|||||||
auto dbias = static_cast<float>(tparam_.learning_rate *
|
auto dbias = static_cast<float>(tparam_.learning_rate *
|
||||||
CoordinateDeltaBias(grad.first, grad.second));
|
CoordinateDeltaBias(grad.first, grad.second));
|
||||||
model->Bias()[group_idx] += dbias;
|
model->Bias()[group_idx] += dbias;
|
||||||
UpdateBiasResidualParallel(group_idx, ngroup, dbias, &in_gpair->HostVector(), p_fmat,
|
UpdateBiasResidualParallel(ctx_, group_idx, ngroup, dbias, &in_gpair->HostVector(), p_fmat);
|
||||||
ctx_->Threads());
|
|
||||||
}
|
}
|
||||||
// prepare for updating the weights
|
// prepare for updating the weights
|
||||||
selector_->Setup(*model, in_gpair->ConstHostVector(), p_fmat,
|
selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm,
|
||||||
tparam_.reg_alpha_denorm,
|
tparam_.reg_lambda_denorm, cparam_.top_k);
|
||||||
tparam_.reg_lambda_denorm, cparam_.top_k);
|
|
||||||
// update weights
|
// update weights
|
||||||
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
|
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
|
||||||
for (unsigned i = 0U; i < model->learner_model_param->num_feature; i++) {
|
for (unsigned i = 0U; i < model->learner_model_param->num_feature; i++) {
|
||||||
int fidx = selector_->NextFeature
|
int fidx =
|
||||||
(i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat,
|
selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat,
|
||||||
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm);
|
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm);
|
||||||
if (fidx < 0) break;
|
if (fidx < 0) break;
|
||||||
this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), p_fmat, model);
|
this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), p_fmat, model);
|
||||||
}
|
}
|
||||||
@ -76,8 +74,8 @@ class CoordinateUpdater : public LinearUpdater {
|
|||||||
monitor_.Stop("UpdateFeature");
|
monitor_.Stop("UpdateFeature");
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void UpdateFeature(int fidx, int group_idx, std::vector<GradientPair> *in_gpair,
|
void UpdateFeature(int fidx, int group_idx, std::vector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
||||||
DMatrix *p_fmat, gbm::GBLinearModel *model) {
|
gbm::GBLinearModel *model) {
|
||||||
const int ngroup = model->learner_model_param->num_output_group;
|
const int ngroup = model->learner_model_param->num_output_group;
|
||||||
bst_float &w = (*model)[fidx][group_idx];
|
bst_float &w = (*model)[fidx][group_idx];
|
||||||
auto gradient = GetGradientParallel(ctx_, group_idx, ngroup, fidx,
|
auto gradient = GetGradientParallel(ctx_, group_idx, ngroup, fidx,
|
||||||
@ -87,8 +85,7 @@ class CoordinateUpdater : public LinearUpdater {
|
|||||||
CoordinateDelta(gradient.first, gradient.second, w, tparam_.reg_alpha_denorm,
|
CoordinateDelta(gradient.first, gradient.second, w, tparam_.reg_alpha_denorm,
|
||||||
tparam_.reg_lambda_denorm));
|
tparam_.reg_lambda_denorm));
|
||||||
w += dw;
|
w += dw;
|
||||||
UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat,
|
UpdateResidualParallel(ctx_, fidx, group_idx, ngroup, dw, in_gpair, p_fmat);
|
||||||
ctx_->Threads());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
void Configure(Args const& args) override {
|
void Configure(Args const& args) override {
|
||||||
tparam_.UpdateAllowUnknown(args);
|
tparam_.UpdateAllowUnknown(args);
|
||||||
coord_param_.UpdateAllowUnknown(args);
|
coord_param_.UpdateAllowUnknown(args);
|
||||||
selector_.reset(FeatureSelector::Create(tparam_.feature_selector, ctx_->Threads()));
|
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
|
||||||
monitor_.Init("GPUCoordinateUpdater");
|
monitor_.Init("GPUCoordinateUpdater");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
num_row_ = static_cast<size_t>(p_fmat->Info().num_row_);
|
num_row_ = static_cast<size_t>(p_fmat->Info().num_row_);
|
||||||
|
|
||||||
CHECK(p_fmat->SingleColBlock());
|
CHECK(p_fmat->SingleColBlock());
|
||||||
SparsePage const& batch = *(p_fmat->GetBatches<CSCPage>().begin());
|
SparsePage const &batch = *(p_fmat->GetBatches<CSCPage>(ctx_).begin());
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
|
|
||||||
if (IsEmpty()) {
|
if (IsEmpty()) {
|
||||||
@ -112,16 +112,15 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
this->UpdateBias(model);
|
this->UpdateBias(model);
|
||||||
monitor_.Stop("UpdateBias");
|
monitor_.Stop("UpdateBias");
|
||||||
// prepare for updating the weights
|
// prepare for updating the weights
|
||||||
selector_->Setup(*model, in_gpair->ConstHostVector(), p_fmat,
|
selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm,
|
||||||
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm,
|
tparam_.reg_lambda_denorm, coord_param_.top_k);
|
||||||
coord_param_.top_k);
|
|
||||||
monitor_.Start("UpdateFeature");
|
monitor_.Start("UpdateFeature");
|
||||||
for (uint32_t group_idx = 0; group_idx < model->learner_model_param->num_output_group;
|
for (uint32_t group_idx = 0; group_idx < model->learner_model_param->num_output_group;
|
||||||
++group_idx) {
|
++group_idx) {
|
||||||
for (auto i = 0U; i < model->learner_model_param->num_feature; i++) {
|
for (auto i = 0U; i < model->learner_model_param->num_feature; i++) {
|
||||||
auto fidx = selector_->NextFeature(
|
auto fidx =
|
||||||
i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat,
|
selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat,
|
||||||
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm);
|
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm);
|
||||||
if (fidx < 0) break;
|
if (fidx < 0) break;
|
||||||
this->UpdateFeature(fidx, group_idx, model);
|
this->UpdateFeature(fidx, group_idx, model);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2018 by Contributors
|
* Copyright 2018-2023 by XGBoost Contributors
|
||||||
* \author Tianqi Chen, Rory Mitchell
|
* \author Tianqi Chen, Rory Mitchell
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ class ShotgunUpdater : public LinearUpdater {
|
|||||||
LOG(FATAL) << "Unsupported feature selector for shotgun updater.\n"
|
LOG(FATAL) << "Unsupported feature selector for shotgun updater.\n"
|
||||||
<< "Supported options are: {cyclic, shuffle}";
|
<< "Supported options are: {cyclic, shuffle}";
|
||||||
}
|
}
|
||||||
selector_.reset(FeatureSelector::Create(param_.feature_selector, ctx_->Threads()));
|
selector_.reset(FeatureSelector::Create(param_.feature_selector));
|
||||||
}
|
}
|
||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const& in) override {
|
||||||
auto const& config = get<Object const>(in);
|
auto const& config = get<Object const>(in);
|
||||||
@ -45,18 +45,17 @@ class ShotgunUpdater : public LinearUpdater {
|
|||||||
auto dbias = static_cast<bst_float>(param_.learning_rate *
|
auto dbias = static_cast<bst_float>(param_.learning_rate *
|
||||||
CoordinateDeltaBias(grad.first, grad.second));
|
CoordinateDeltaBias(grad.first, grad.second));
|
||||||
model->Bias()[gid] += dbias;
|
model->Bias()[gid] += dbias;
|
||||||
UpdateBiasResidualParallel(gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat,
|
UpdateBiasResidualParallel(ctx_, gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat);
|
||||||
ctx_->Threads());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock-free parallel updates of weights
|
// lock-free parallel updates of weights
|
||||||
selector_->Setup(*model, in_gpair->ConstHostVector(), p_fmat,
|
selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm,
|
||||||
param_.reg_alpha_denorm, param_.reg_lambda_denorm, 0);
|
param_.reg_lambda_denorm, 0);
|
||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>(ctx_)) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
const auto nfeat = static_cast<bst_omp_uint>(batch.Size());
|
const auto nfeat = static_cast<bst_omp_uint>(batch.Size());
|
||||||
common::ParallelFor(nfeat, ctx_->Threads(), [&](auto i) {
|
common::ParallelFor(nfeat, ctx_->Threads(), [&](auto i) {
|
||||||
int ii = selector_->NextFeature(i, *model, 0, in_gpair->ConstHostVector(), p_fmat,
|
int ii = selector_->NextFeature(ctx_, i, *model, 0, in_gpair->ConstHostVector(), p_fmat,
|
||||||
param_.reg_alpha_denorm, param_.reg_lambda_denorm);
|
param_.reg_alpha_denorm, param_.reg_lambda_denorm);
|
||||||
if (ii < 0) return;
|
if (ii < 0) return;
|
||||||
const bst_uint fid = ii;
|
const bst_uint fid = ii;
|
||||||
|
|||||||
@ -634,7 +634,7 @@ class CPUPredictor : public Predictor {
|
|||||||
if (!p_fmat->PageExists<SparsePage>()) {
|
if (!p_fmat->PageExists<SparsePage>()) {
|
||||||
std::vector<Entry> workspace(p_fmat->Info().num_col_ * kUnroll * n_threads);
|
std::vector<Entry> workspace(p_fmat->Info().num_col_ * kUnroll * n_threads);
|
||||||
auto ft = p_fmat->Info().feature_types.ConstHostVector();
|
auto ft = p_fmat->Info().feature_types.ConstHostVector();
|
||||||
for (auto const &batch : p_fmat->GetBatches<GHistIndexMatrix>({})) {
|
for (auto const &batch : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, {})) {
|
||||||
if (blocked) {
|
if (blocked) {
|
||||||
PredictBatchByBlockOfRowsKernel<GHistIndexMatrixView, kBlockOfRowsSize>(
|
PredictBatchByBlockOfRowsKernel<GHistIndexMatrixView, kBlockOfRowsSize>(
|
||||||
GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, model,
|
GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, model,
|
||||||
|
|||||||
@ -706,7 +706,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
size_t batch_offset = 0;
|
size_t batch_offset = 0;
|
||||||
for (auto const& page : dmat->GetBatches<EllpackPage>(BatchParam{})) {
|
for (auto const& page : dmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
|
||||||
dmat->Info().feature_types.SetDevice(ctx_->gpu_id);
|
dmat->Info().feature_types.SetDevice(ctx_->gpu_id);
|
||||||
auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
|
auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
|
||||||
this->PredictInternal(
|
this->PredictInternal(
|
||||||
@ -983,7 +983,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
batch_offset += batch.Size();
|
batch_offset += batch.Size();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(BatchParam{})) {
|
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
|
||||||
bst_row_t batch_offset = 0;
|
bst_row_t batch_offset = 0;
|
||||||
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->gpu_id)};
|
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->gpu_id)};
|
||||||
size_t num_rows = batch.Size();
|
size_t num_rows = batch.Size();
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019-2021 by XGBoost Contributors
|
* Copyright 2019-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <thrust/functional.h>
|
#include <thrust/functional.h>
|
||||||
#include <thrust/random.h>
|
#include <thrust/random.h>
|
||||||
@ -12,6 +12,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "../../common/compressed_iterator.h"
|
#include "../../common/compressed_iterator.h"
|
||||||
|
#include "../../common/cuda_context.cuh" // for CUDAContext
|
||||||
#include "../../common/random.h"
|
#include "../../common/random.h"
|
||||||
#include "../param.h"
|
#include "../param.h"
|
||||||
#include "gradient_based_sampler.cuh"
|
#include "gradient_based_sampler.cuh"
|
||||||
@ -147,25 +148,26 @@ class PoissonSampling : public thrust::binary_function<GradientPair, size_t, Gra
|
|||||||
|
|
||||||
NoSampling::NoSampling(EllpackPageImpl const* page) : page_(page) {}
|
NoSampling::NoSampling(EllpackPageImpl const* page) : page_(page) {}
|
||||||
|
|
||||||
GradientBasedSample NoSampling::Sample(common::Span<GradientPair> gpair, DMatrix* dmat) {
|
GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair> gpair,
|
||||||
|
DMatrix* dmat) {
|
||||||
return {dmat->Info().num_row_, page_, gpair};
|
return {dmat->Info().num_row_, page_, gpair};
|
||||||
}
|
}
|
||||||
|
|
||||||
ExternalMemoryNoSampling::ExternalMemoryNoSampling(EllpackPageImpl const* page,
|
ExternalMemoryNoSampling::ExternalMemoryNoSampling(Context const* ctx, EllpackPageImpl const* page,
|
||||||
size_t n_rows,
|
size_t n_rows, BatchParam batch_param)
|
||||||
const BatchParam& batch_param)
|
: batch_param_{std::move(batch_param)},
|
||||||
: batch_param_(batch_param),
|
page_(new EllpackPageImpl(ctx->gpu_id, page->Cuts(), page->is_dense, page->row_stride,
|
||||||
page_(new EllpackPageImpl(batch_param.gpu_id, page->Cuts(), page->is_dense,
|
n_rows)) {}
|
||||||
page->row_stride, n_rows)) {}
|
|
||||||
|
|
||||||
GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span<GradientPair> gpair,
|
GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
|
||||||
|
common::Span<GradientPair> gpair,
|
||||||
DMatrix* dmat) {
|
DMatrix* dmat) {
|
||||||
if (!page_concatenated_) {
|
if (!page_concatenated_) {
|
||||||
// Concatenate all the external memory ELLPACK pages into a single in-memory page.
|
// Concatenate all the external memory ELLPACK pages into a single in-memory page.
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
|
for (auto& batch : dmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
|
||||||
auto page = batch.Impl();
|
auto page = batch.Impl();
|
||||||
size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset);
|
size_t num_elements = page_->Copy(ctx->gpu_id, page, offset);
|
||||||
offset += num_elements;
|
offset += num_elements;
|
||||||
}
|
}
|
||||||
page_concatenated_ = true;
|
page_concatenated_ = true;
|
||||||
@ -176,12 +178,13 @@ GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span<GradientPair>
|
|||||||
UniformSampling::UniformSampling(EllpackPageImpl const* page, float subsample)
|
UniformSampling::UniformSampling(EllpackPageImpl const* page, float subsample)
|
||||||
: page_(page), subsample_(subsample) {}
|
: page_(page), subsample_(subsample) {}
|
||||||
|
|
||||||
GradientBasedSample UniformSampling::Sample(common::Span<GradientPair> gpair, DMatrix* dmat) {
|
GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
|
||||||
|
DMatrix* dmat) {
|
||||||
// Set gradient pair to 0 with p = 1 - subsample
|
// Set gradient pair to 0 with p = 1 - subsample
|
||||||
thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair),
|
auto cuctx = ctx->CUDACtx();
|
||||||
thrust::counting_iterator<size_t>(0),
|
thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
|
||||||
BernoulliTrial(common::GlobalRandom()(), subsample_),
|
thrust::counting_iterator<std::size_t>(0),
|
||||||
GradientPair());
|
BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair());
|
||||||
return {dmat->Info().num_row_, page_, gpair};
|
return {dmat->Info().num_row_, page_, gpair};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -192,7 +195,8 @@ ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
|
|||||||
subsample_(subsample),
|
subsample_(subsample),
|
||||||
sample_row_index_(n_rows) {}
|
sample_row_index_(n_rows) {}
|
||||||
|
|
||||||
GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientPair> gpair,
|
GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
|
||||||
|
common::Span<GradientPair> gpair,
|
||||||
DMatrix* dmat) {
|
DMatrix* dmat) {
|
||||||
// Set gradient pair to 0 with p = 1 - subsample
|
// Set gradient pair to 0 with p = 1 - subsample
|
||||||
thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair),
|
thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair),
|
||||||
@ -216,18 +220,17 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
|
|||||||
sample_row_index_.begin(),
|
sample_row_index_.begin(),
|
||||||
ClearEmptyRows());
|
ClearEmptyRows());
|
||||||
|
|
||||||
auto batch_iterator = dmat->GetBatches<EllpackPage>(batch_param_);
|
auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
|
||||||
auto first_page = (*batch_iterator.begin()).Impl();
|
auto first_page = (*batch_iterator.begin()).Impl();
|
||||||
// Create a new ELLPACK page with empty rows.
|
// Create a new ELLPACK page with empty rows.
|
||||||
page_.reset(); // Release the device memory first before reallocating
|
page_.reset(); // Release the device memory first before reallocating
|
||||||
page_.reset(new EllpackPageImpl(
|
page_.reset(new EllpackPageImpl(ctx->gpu_id, first_page->Cuts(), first_page->is_dense,
|
||||||
batch_param_.gpu_id, first_page->Cuts(), first_page->is_dense,
|
first_page->row_stride, sample_rows));
|
||||||
first_page->row_stride, sample_rows));
|
|
||||||
|
|
||||||
// Compact the ELLPACK pages into the single sample page.
|
// Compact the ELLPACK pages into the single sample page.
|
||||||
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||||
for (auto& batch : batch_iterator) {
|
for (auto& batch : batch_iterator) {
|
||||||
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
|
page_->Compact(ctx->gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||||
}
|
}
|
||||||
|
|
||||||
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
||||||
@ -242,18 +245,17 @@ GradientBasedSampling::GradientBasedSampling(EllpackPageImpl const* page,
|
|||||||
threshold_(n_rows + 1, 0.0f),
|
threshold_(n_rows + 1, 0.0f),
|
||||||
grad_sum_(n_rows, 0.0f) {}
|
grad_sum_(n_rows, 0.0f) {}
|
||||||
|
|
||||||
GradientBasedSample GradientBasedSampling::Sample(common::Span<GradientPair> gpair,
|
GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
|
||||||
DMatrix* dmat) {
|
common::Span<GradientPair> gpair, DMatrix* dmat) {
|
||||||
|
auto cuctx = ctx->CUDACtx();
|
||||||
size_t n_rows = dmat->Info().num_row_;
|
size_t n_rows = dmat->Info().num_row_;
|
||||||
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
||||||
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
||||||
|
|
||||||
// Perform Poisson sampling in place.
|
// Perform Poisson sampling in place.
|
||||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
|
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
|
||||||
thrust::counting_iterator<size_t>(0),
|
thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
|
||||||
dh::tbegin(gpair),
|
PoissonSampling(dh::ToSpan(threshold_), threshold_index,
|
||||||
PoissonSampling(dh::ToSpan(threshold_),
|
|
||||||
threshold_index,
|
|
||||||
RandomWeight(common::GlobalRandom()())));
|
RandomWeight(common::GlobalRandom()())));
|
||||||
return {n_rows, page_, gpair};
|
return {n_rows, page_, gpair};
|
||||||
}
|
}
|
||||||
@ -268,7 +270,8 @@ ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
|
|||||||
grad_sum_(n_rows, 0.0f),
|
grad_sum_(n_rows, 0.0f),
|
||||||
sample_row_index_(n_rows) {}
|
sample_row_index_(n_rows) {}
|
||||||
|
|
||||||
GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<GradientPair> gpair,
|
GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* ctx,
|
||||||
|
common::Span<GradientPair> gpair,
|
||||||
DMatrix* dmat) {
|
DMatrix* dmat) {
|
||||||
size_t n_rows = dmat->Info().num_row_;
|
size_t n_rows = dmat->Info().num_row_;
|
||||||
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
||||||
@ -298,28 +301,25 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<Gra
|
|||||||
sample_row_index_.begin(),
|
sample_row_index_.begin(),
|
||||||
ClearEmptyRows());
|
ClearEmptyRows());
|
||||||
|
|
||||||
auto batch_iterator = dmat->GetBatches<EllpackPage>(batch_param_);
|
auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
|
||||||
auto first_page = (*batch_iterator.begin()).Impl();
|
auto first_page = (*batch_iterator.begin()).Impl();
|
||||||
// Create a new ELLPACK page with empty rows.
|
// Create a new ELLPACK page with empty rows.
|
||||||
page_.reset(); // Release the device memory first before reallocating
|
page_.reset(); // Release the device memory first before reallocating
|
||||||
page_.reset(new EllpackPageImpl(batch_param_.gpu_id, first_page->Cuts(),
|
page_.reset(new EllpackPageImpl(ctx->gpu_id, first_page->Cuts(), first_page->is_dense,
|
||||||
first_page->is_dense,
|
|
||||||
first_page->row_stride, sample_rows));
|
first_page->row_stride, sample_rows));
|
||||||
|
|
||||||
// Compact the ELLPACK pages into the single sample page.
|
// Compact the ELLPACK pages into the single sample page.
|
||||||
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||||
for (auto& batch : batch_iterator) {
|
for (auto& batch : batch_iterator) {
|
||||||
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
|
page_->Compact(ctx->gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||||
}
|
}
|
||||||
|
|
||||||
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
||||||
}
|
}
|
||||||
|
|
||||||
GradientBasedSampler::GradientBasedSampler(EllpackPageImpl const* page,
|
GradientBasedSampler::GradientBasedSampler(Context const* ctx, EllpackPageImpl const* page,
|
||||||
size_t n_rows,
|
size_t n_rows, const BatchParam& batch_param,
|
||||||
const BatchParam& batch_param,
|
float subsample, int sampling_method) {
|
||||||
float subsample,
|
|
||||||
int sampling_method) {
|
|
||||||
monitor_.Init("gradient_based_sampler");
|
monitor_.Init("gradient_based_sampler");
|
||||||
|
|
||||||
bool is_sampling = subsample < 1.0;
|
bool is_sampling = subsample < 1.0;
|
||||||
@ -346,7 +346,7 @@ GradientBasedSampler::GradientBasedSampler(EllpackPageImpl const* page,
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (is_external_memory) {
|
if (is_external_memory) {
|
||||||
strategy_.reset(new ExternalMemoryNoSampling(page, n_rows, batch_param));
|
strategy_.reset(new ExternalMemoryNoSampling(ctx, page, n_rows, batch_param));
|
||||||
} else {
|
} else {
|
||||||
strategy_.reset(new NoSampling(page));
|
strategy_.reset(new NoSampling(page));
|
||||||
}
|
}
|
||||||
@ -354,10 +354,10 @@ GradientBasedSampler::GradientBasedSampler(EllpackPageImpl const* page,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sample a DMatrix based on the given gradient pairs.
|
// Sample a DMatrix based on the given gradient pairs.
|
||||||
GradientBasedSample GradientBasedSampler::Sample(common::Span<GradientPair> gpair,
|
GradientBasedSample GradientBasedSampler::Sample(Context const* ctx,
|
||||||
DMatrix* dmat) {
|
common::Span<GradientPair> gpair, DMatrix* dmat) {
|
||||||
monitor_.Start("Sample");
|
monitor_.Start("Sample");
|
||||||
GradientBasedSample sample = strategy_->Sample(gpair, dmat);
|
GradientBasedSample sample = strategy_->Sample(ctx, gpair, dmat);
|
||||||
monitor_.Stop("Sample");
|
monitor_.Stop("Sample");
|
||||||
return sample;
|
return sample;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -24,7 +24,8 @@ struct GradientBasedSample {
|
|||||||
class SamplingStrategy {
|
class SamplingStrategy {
|
||||||
public:
|
public:
|
||||||
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
|
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
|
||||||
virtual GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) = 0;
|
virtual GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
|
||||||
|
DMatrix* dmat) = 0;
|
||||||
virtual ~SamplingStrategy() = default;
|
virtual ~SamplingStrategy() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -32,7 +33,8 @@ class SamplingStrategy {
|
|||||||
class NoSampling : public SamplingStrategy {
|
class NoSampling : public SamplingStrategy {
|
||||||
public:
|
public:
|
||||||
explicit NoSampling(EllpackPageImpl const* page);
|
explicit NoSampling(EllpackPageImpl const* page);
|
||||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
|
||||||
|
DMatrix* dmat) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
EllpackPageImpl const* page_;
|
EllpackPageImpl const* page_;
|
||||||
@ -41,10 +43,10 @@ class NoSampling : public SamplingStrategy {
|
|||||||
/*! \brief No sampling in external memory mode. */
|
/*! \brief No sampling in external memory mode. */
|
||||||
class ExternalMemoryNoSampling : public SamplingStrategy {
|
class ExternalMemoryNoSampling : public SamplingStrategy {
|
||||||
public:
|
public:
|
||||||
ExternalMemoryNoSampling(EllpackPageImpl const* page,
|
ExternalMemoryNoSampling(Context const* ctx, EllpackPageImpl const* page, size_t n_rows,
|
||||||
size_t n_rows,
|
BatchParam batch_param);
|
||||||
const BatchParam& batch_param);
|
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
|
||||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
DMatrix* dmat) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
BatchParam batch_param_;
|
BatchParam batch_param_;
|
||||||
@ -56,7 +58,8 @@ class ExternalMemoryNoSampling : public SamplingStrategy {
|
|||||||
class UniformSampling : public SamplingStrategy {
|
class UniformSampling : public SamplingStrategy {
|
||||||
public:
|
public:
|
||||||
UniformSampling(EllpackPageImpl const* page, float subsample);
|
UniformSampling(EllpackPageImpl const* page, float subsample);
|
||||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
|
||||||
|
DMatrix* dmat) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
EllpackPageImpl const* page_;
|
EllpackPageImpl const* page_;
|
||||||
@ -66,10 +69,9 @@ class UniformSampling : public SamplingStrategy {
|
|||||||
/*! \brief No sampling in external memory mode. */
|
/*! \brief No sampling in external memory mode. */
|
||||||
class ExternalMemoryUniformSampling : public SamplingStrategy {
|
class ExternalMemoryUniformSampling : public SamplingStrategy {
|
||||||
public:
|
public:
|
||||||
ExternalMemoryUniformSampling(size_t n_rows,
|
ExternalMemoryUniformSampling(size_t n_rows, BatchParam batch_param, float subsample);
|
||||||
BatchParam batch_param,
|
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
|
||||||
float subsample);
|
DMatrix* dmat) override;
|
||||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
BatchParam batch_param_;
|
BatchParam batch_param_;
|
||||||
@ -82,11 +84,10 @@ class ExternalMemoryUniformSampling : public SamplingStrategy {
|
|||||||
/*! \brief Gradient-based sampling in in-memory mode.. */
|
/*! \brief Gradient-based sampling in in-memory mode.. */
|
||||||
class GradientBasedSampling : public SamplingStrategy {
|
class GradientBasedSampling : public SamplingStrategy {
|
||||||
public:
|
public:
|
||||||
GradientBasedSampling(EllpackPageImpl const* page,
|
GradientBasedSampling(EllpackPageImpl const* page, size_t n_rows, const BatchParam& batch_param,
|
||||||
size_t n_rows,
|
|
||||||
const BatchParam& batch_param,
|
|
||||||
float subsample);
|
float subsample);
|
||||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
|
||||||
|
DMatrix* dmat) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
EllpackPageImpl const* page_;
|
EllpackPageImpl const* page_;
|
||||||
@ -98,10 +99,9 @@ class GradientBasedSampling : public SamplingStrategy {
|
|||||||
/*! \brief Gradient-based sampling in external memory mode.. */
|
/*! \brief Gradient-based sampling in external memory mode.. */
|
||||||
class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
|
class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
|
||||||
public:
|
public:
|
||||||
ExternalMemoryGradientBasedSampling(size_t n_rows,
|
ExternalMemoryGradientBasedSampling(size_t n_rows, BatchParam batch_param, float subsample);
|
||||||
BatchParam batch_param,
|
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
|
||||||
float subsample);
|
DMatrix* dmat) override;
|
||||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
BatchParam batch_param_;
|
BatchParam batch_param_;
|
||||||
@ -124,14 +124,11 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
|
|||||||
*/
|
*/
|
||||||
class GradientBasedSampler {
|
class GradientBasedSampler {
|
||||||
public:
|
public:
|
||||||
GradientBasedSampler(EllpackPageImpl const* page,
|
GradientBasedSampler(Context const* ctx, EllpackPageImpl const* page, size_t n_rows,
|
||||||
size_t n_rows,
|
const BatchParam& batch_param, float subsample, int sampling_method);
|
||||||
const BatchParam& batch_param,
|
|
||||||
float subsample,
|
|
||||||
int sampling_method);
|
|
||||||
|
|
||||||
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
|
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
|
||||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat);
|
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);
|
||||||
|
|
||||||
/*! \brief Calculate the threshold used to normalize sampling probabilities. */
|
/*! \brief Calculate the threshold used to normalize sampling probabilities. */
|
||||||
static size_t CalculateThresholdIndex(common::Span<GradientPair> gpair,
|
static size_t CalculateThresholdIndex(common::Span<GradientPair> gpair,
|
||||||
|
|||||||
@ -66,7 +66,7 @@ class GloablApproxBuilder {
|
|||||||
partitioner_.clear();
|
partitioner_.clear();
|
||||||
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
|
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
|
||||||
for (auto const &page :
|
for (auto const &page :
|
||||||
p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, *task_))) {
|
p_fmat->GetBatches<GHistIndexMatrix>(ctx_, BatchSpec(*param_, hess, *task_))) {
|
||||||
if (n_total_bins == 0) {
|
if (n_total_bins == 0) {
|
||||||
n_total_bins = page.cut.TotalBins();
|
n_total_bins = page.cut.TotalBins();
|
||||||
feature_values_ = page.cut;
|
feature_values_ = page.cut;
|
||||||
@ -97,7 +97,7 @@ class GloablApproxBuilder {
|
|||||||
std::vector<CPUExpandEntry> nodes{best};
|
std::vector<CPUExpandEntry> nodes{best};
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
auto space = ConstructHistSpace(partitioner_, nodes);
|
auto space = ConstructHistSpace(partitioner_, nodes);
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, BatchSpec(*param_, hess))) {
|
||||||
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
|
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
|
||||||
{}, gpair);
|
{}, gpair);
|
||||||
i++;
|
i++;
|
||||||
@ -148,7 +148,7 @@ class GloablApproxBuilder {
|
|||||||
|
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
|
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, BatchSpec(*param_, hess))) {
|
||||||
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
|
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
|
||||||
nodes_to_build, nodes_to_sub, gpair);
|
nodes_to_build, nodes_to_sub, gpair);
|
||||||
i++;
|
i++;
|
||||||
@ -214,7 +214,8 @@ class GloablApproxBuilder {
|
|||||||
|
|
||||||
monitor_->Start("UpdatePosition");
|
monitor_->Start("UpdatePosition");
|
||||||
size_t page_id = 0;
|
size_t page_id = 0;
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess))) {
|
for (auto const &page :
|
||||||
|
p_fmat->GetBatches<GHistIndexMatrix>(ctx_, BatchSpec(*param_, hess))) {
|
||||||
partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree);
|
partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree);
|
||||||
page_id++;
|
page_id++;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -76,7 +76,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
// Finds densities if we don't already have them
|
// Finds densities if we don't already have them
|
||||||
if (column_densities_.empty()) {
|
if (column_densities_.empty()) {
|
||||||
std::vector<size_t> column_size(dmat->Info().num_col_);
|
std::vector<size_t> column_size(dmat->Info().num_col_);
|
||||||
for (const auto &batch : dmat->GetBatches<SortedCSCPage>()) {
|
for (const auto &batch : dmat->GetBatches<SortedCSCPage>(ctx_)) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
for (auto i = 0u; i < batch.Size(); i++) {
|
for (auto i = 0u; i < batch.Size(); i++) {
|
||||||
column_size[i] += page[i].size();
|
column_size[i] += page[i].size();
|
||||||
@ -467,7 +467,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
auto evaluator = tree_evaluator_.GetEvaluator();
|
auto evaluator = tree_evaluator_.GetEvaluator();
|
||||||
|
|
||||||
auto feat_set = column_sampler_.GetFeatureSet(depth);
|
auto feat_set = column_sampler_.GetFeatureSet(depth);
|
||||||
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>(ctx_)) {
|
||||||
this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat);
|
this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat);
|
||||||
}
|
}
|
||||||
// after this each thread's stemp will get the best candidates, aggregate results
|
// after this each thread's stemp will get the best candidates, aggregate results
|
||||||
@ -546,7 +546,7 @@ class ColMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
std::sort(fsplits.begin(), fsplits.end());
|
std::sort(fsplits.begin(), fsplits.end());
|
||||||
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||||
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>(ctx_)) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
for (auto fid : fsplits) {
|
for (auto fid : fsplits) {
|
||||||
auto col = page[fid];
|
auto col = page[fid];
|
||||||
|
|||||||
@ -218,7 +218,7 @@ struct GPUHistMakerDevice {
|
|||||||
column_sampler(column_sampler_seed),
|
column_sampler(column_sampler_seed),
|
||||||
interaction_constraints(param, n_features),
|
interaction_constraints(param, n_features),
|
||||||
batch_param(std::move(_batch_param)) {
|
batch_param(std::move(_batch_param)) {
|
||||||
sampler.reset(new GradientBasedSampler(page, _n_rows, batch_param, param.subsample,
|
sampler.reset(new GradientBasedSampler(ctx, page, _n_rows, batch_param, param.subsample,
|
||||||
param.sampling_method));
|
param.sampling_method));
|
||||||
if (!param.monotone_constraints.empty()) {
|
if (!param.monotone_constraints.empty()) {
|
||||||
// Copy assigning an empty vector causes an exception in MSVC debug builds
|
// Copy assigning an empty vector causes an exception in MSVC debug builds
|
||||||
@ -258,7 +258,7 @@ struct GPUHistMakerDevice {
|
|||||||
dh::safe_cuda(cudaMemcpyAsync(
|
dh::safe_cuda(cudaMemcpyAsync(
|
||||||
d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
|
d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
|
||||||
dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice));
|
dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice));
|
||||||
auto sample = sampler->Sample(dh::ToSpan(d_gpair), dmat);
|
auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), dmat);
|
||||||
page = sample.page;
|
page = sample.page;
|
||||||
gpair = sample.gpair;
|
gpair = sample.gpair;
|
||||||
|
|
||||||
@ -808,11 +808,8 @@ class GPUHistMaker : public TreeUpdater {
|
|||||||
uint32_t column_sampling_seed = common::GlobalRandom()();
|
uint32_t column_sampling_seed = common::GlobalRandom()();
|
||||||
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
|
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
|
||||||
|
|
||||||
BatchParam batch_param{
|
auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
|
||||||
ctx_->gpu_id,
|
auto page = (*dmat->GetBatches<EllpackPage>(ctx_, batch_param).begin()).Impl();
|
||||||
param->max_bin,
|
|
||||||
};
|
|
||||||
auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
|
|
||||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||||
info_->feature_types.SetDevice(ctx_->gpu_id);
|
info_->feature_types.SetDevice(ctx_->gpu_id);
|
||||||
maker.reset(new GPUHistMakerDevice<GradientSumT>(
|
maker.reset(new GPUHistMakerDevice<GradientSumT>(
|
||||||
|
|||||||
@ -134,7 +134,7 @@ class MultiTargetHistBuilder {
|
|||||||
std::vector<MultiExpandEntry> const &applied) {
|
std::vector<MultiExpandEntry> const &applied) {
|
||||||
monitor_->Start(__func__);
|
monitor_->Start(__func__);
|
||||||
std::size_t page_id{0};
|
std::size_t page_id{0};
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(this->param_))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(this->param_))) {
|
||||||
this->partitioner_.at(page_id).UpdatePosition(this->ctx_, page, applied, p_tree);
|
this->partitioner_.at(page_id).UpdatePosition(this->ctx_, page, applied, p_tree);
|
||||||
page_id++;
|
page_id++;
|
||||||
}
|
}
|
||||||
@ -152,7 +152,7 @@ class MultiTargetHistBuilder {
|
|||||||
std::size_t page_id = 0;
|
std::size_t page_id = 0;
|
||||||
bst_bin_t n_total_bins = 0;
|
bst_bin_t n_total_bins = 0;
|
||||||
partitioner_.clear();
|
partitioner_.clear();
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
if (n_total_bins == 0) {
|
if (n_total_bins == 0) {
|
||||||
n_total_bins = page.cut.TotalBins();
|
n_total_bins = page.cut.TotalBins();
|
||||||
} else {
|
} else {
|
||||||
@ -206,7 +206,7 @@ class MultiTargetHistBuilder {
|
|||||||
std::vector<MultiExpandEntry> nodes{best};
|
std::vector<MultiExpandEntry> nodes{best};
|
||||||
std::size_t i = 0;
|
std::size_t i = 0;
|
||||||
auto space = ConstructHistSpace(partitioner_, nodes);
|
auto space = ConstructHistSpace(partitioner_, nodes);
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
for (bst_target_t t{0}; t < n_targets; ++t) {
|
for (bst_target_t t{0}; t < n_targets; ++t) {
|
||||||
auto t_gpair = gpair.Slice(linalg::All(), t);
|
auto t_gpair = gpair.Slice(linalg::All(), t);
|
||||||
histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
|
histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
|
||||||
@ -225,7 +225,7 @@ class MultiTargetHistBuilder {
|
|||||||
for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
|
for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
|
||||||
hists.push_back(&histogram_builder_[t].Histogram());
|
hists.push_back(&histogram_builder_[t].Histogram());
|
||||||
}
|
}
|
||||||
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, &nodes);
|
evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, &nodes);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -263,7 +263,7 @@ class MultiTargetHistBuilder {
|
|||||||
|
|
||||||
std::size_t i = 0;
|
std::size_t i = 0;
|
||||||
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
|
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
for (std::size_t t = 0; t < p_tree->NumTargets(); ++t) {
|
for (std::size_t t = 0; t < p_tree->NumTargets(); ++t) {
|
||||||
auto t_gpair = gpair.Slice(linalg::All(), t);
|
auto t_gpair = gpair.Slice(linalg::All(), t);
|
||||||
// Make sure the gradient matrix is f-order.
|
// Make sure the gradient matrix is f-order.
|
||||||
@ -283,7 +283,7 @@ class MultiTargetHistBuilder {
|
|||||||
for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
|
for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
|
||||||
hists.push_back(&histogram_builder_[t].Histogram());
|
hists.push_back(&histogram_builder_[t].Histogram());
|
||||||
}
|
}
|
||||||
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, best_splits);
|
evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, best_splits);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -383,7 +383,7 @@ class HistBuilder {
|
|||||||
std::size_t page_id{0};
|
std::size_t page_id{0};
|
||||||
bst_bin_t n_total_bins{0};
|
bst_bin_t n_total_bins{0};
|
||||||
partitioner_.clear();
|
partitioner_.clear();
|
||||||
for (auto const &page : fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &page : fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
if (n_total_bins == 0) {
|
if (n_total_bins == 0) {
|
||||||
n_total_bins = page.cut.TotalBins();
|
n_total_bins = page.cut.TotalBins();
|
||||||
} else {
|
} else {
|
||||||
@ -406,7 +406,7 @@ class HistBuilder {
|
|||||||
monitor_->Start(__func__);
|
monitor_->Start(__func__);
|
||||||
auto const &histograms = histogram_builder_->Histogram();
|
auto const &histograms = histogram_builder_->Histogram();
|
||||||
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
||||||
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
evaluator_->EvaluateSplits(histograms, gmat.cut, ft, *p_tree, best_splits);
|
evaluator_->EvaluateSplits(histograms, gmat.cut, ft, *p_tree, best_splits);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -423,7 +423,7 @@ class HistBuilder {
|
|||||||
|
|
||||||
std::size_t page_id = 0;
|
std::size_t page_id = 0;
|
||||||
auto space = ConstructHistSpace(partitioner_, {node});
|
auto space = ConstructHistSpace(partitioner_, {node});
|
||||||
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
std::vector<CPUExpandEntry> nodes_to_build{node};
|
std::vector<CPUExpandEntry> nodes_to_build{node};
|
||||||
std::vector<CPUExpandEntry> nodes_to_sub;
|
std::vector<CPUExpandEntry> nodes_to_sub;
|
||||||
this->histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
|
this->histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
|
||||||
@ -439,7 +439,7 @@ class HistBuilder {
|
|||||||
* Specialized code for dense data: For dense data (with no missing value), the sum
|
* Specialized code for dense data: For dense data (with no missing value), the sum
|
||||||
* of gradient histogram is equal to snode[nid]
|
* of gradient histogram is equal to snode[nid]
|
||||||
*/
|
*/
|
||||||
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin());
|
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_)).begin());
|
||||||
std::vector<std::uint32_t> const &row_ptr = gmat.cut.Ptrs();
|
std::vector<std::uint32_t> const &row_ptr = gmat.cut.Ptrs();
|
||||||
CHECK_GE(row_ptr.size(), 2);
|
CHECK_GE(row_ptr.size(), 2);
|
||||||
std::uint32_t const ibegin = row_ptr[0];
|
std::uint32_t const ibegin = row_ptr[0];
|
||||||
@ -467,7 +467,7 @@ class HistBuilder {
|
|||||||
std::vector<CPUExpandEntry> entries{node};
|
std::vector<CPUExpandEntry> entries{node};
|
||||||
monitor_->Start("EvaluateSplits");
|
monitor_->Start("EvaluateSplits");
|
||||||
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
||||||
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree,
|
evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree,
|
||||||
&entries);
|
&entries);
|
||||||
break;
|
break;
|
||||||
@ -503,7 +503,7 @@ class HistBuilder {
|
|||||||
|
|
||||||
std::size_t page_id{0};
|
std::size_t page_id{0};
|
||||||
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
|
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
|
||||||
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
|
histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
|
||||||
partitioner_.at(page_id).Partitions(), nodes_to_build,
|
partitioner_.at(page_id).Partitions(), nodes_to_build,
|
||||||
nodes_to_sub, gpair.Values());
|
nodes_to_sub, gpair.Values());
|
||||||
@ -515,7 +515,7 @@ class HistBuilder {
|
|||||||
std::vector<CPUExpandEntry> const &applied) {
|
std::vector<CPUExpandEntry> const &applied) {
|
||||||
monitor_->Start(__func__);
|
monitor_->Start(__func__);
|
||||||
std::size_t page_id{0};
|
std::size_t page_id{0};
|
||||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(this->param_))) {
|
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
|
||||||
this->partitioner_.at(page_id).UpdatePosition(this->ctx_, page, applied, p_tree);
|
this->partitioner_.at(page_id).UpdatePosition(this->ctx_, page, applied, p_tree);
|
||||||
page_id++;
|
page_id++;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,11 +14,12 @@ TEST(DenseColumn, Test) {
|
|||||||
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
BinTypeSize last{kUint8BinsTypeSize};
|
BinTypeSize last{kUint8BinsTypeSize};
|
||||||
for (int32_t max_num_bin : max_num_bins) {
|
for (int32_t max_num_bin : max_num_bins) {
|
||||||
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
|
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
|
||||||
auto sparse_thresh = 0.2;
|
auto sparse_thresh = 0.2;
|
||||||
GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false, AllThreadsForTest()};
|
GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, sparse_thresh, false};
|
||||||
ColumnMatrix column_matrix;
|
ColumnMatrix column_matrix;
|
||||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||||
column_matrix.InitFromSparse(page, gmat, sparse_thresh, AllThreadsForTest());
|
column_matrix.InitFromSparse(page, gmat, sparse_thresh, AllThreadsForTest());
|
||||||
@ -62,9 +63,10 @@ TEST(SparseColumn, Test) {
|
|||||||
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
for (int32_t max_num_bin : max_num_bins) {
|
for (int32_t max_num_bin : max_num_bins) {
|
||||||
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
|
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
|
||||||
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, AllThreadsForTest()};
|
GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, 0.5f, false};
|
||||||
ColumnMatrix column_matrix;
|
ColumnMatrix column_matrix;
|
||||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||||
column_matrix.InitFromSparse(page, gmat, 1.0, AllThreadsForTest());
|
column_matrix.InitFromSparse(page, gmat, 1.0, AllThreadsForTest());
|
||||||
@ -90,9 +92,10 @@ TEST(DenseColumnWithMissing, Test) {
|
|||||||
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
for (int32_t max_num_bin : max_num_bins) {
|
for (int32_t max_num_bin : max_num_bins) {
|
||||||
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
|
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
|
||||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, AllThreadsForTest());
|
GHistIndexMatrix gmat(&ctx, dmat.get(), max_num_bin, 0.2, false);
|
||||||
ColumnMatrix column_matrix;
|
ColumnMatrix column_matrix;
|
||||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||||
column_matrix.InitFromSparse(page, gmat, 0.2, AllThreadsForTest());
|
column_matrix.InitFromSparse(page, gmat, 0.2, AllThreadsForTest());
|
||||||
|
|||||||
@ -156,6 +156,7 @@ TEST(CutsBuilder, SearchGroupInd) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(HistUtil, DenseCutsCategorical) {
|
TEST(HistUtil, DenseCutsCategorical) {
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
int categorical_sizes[] = {2, 6, 8, 12};
|
int categorical_sizes[] = {2, 6, 8, 12};
|
||||||
int num_bins = 256;
|
int num_bins = 256;
|
||||||
int sizes[] = {25, 100, 1000};
|
int sizes[] = {25, 100, 1000};
|
||||||
@ -165,7 +166,7 @@ TEST(HistUtil, DenseCutsCategorical) {
|
|||||||
std::vector<float> x_sorted(x);
|
std::vector<float> x_sorted(x);
|
||||||
std::sort(x_sorted.begin(), x_sorted.end());
|
std::sort(x_sorted.begin(), x_sorted.end());
|
||||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
HistogramCuts cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins);
|
||||||
auto cuts_from_sketch = cuts.Values();
|
auto cuts_from_sketch = cuts.Values();
|
||||||
EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
|
EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
|
||||||
EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
|
EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
|
||||||
@ -176,6 +177,7 @@ TEST(HistUtil, DenseCutsCategorical) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(HistUtil, DenseCutsAccuracyTest) {
|
TEST(HistUtil, DenseCutsAccuracyTest) {
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
int bin_sizes[] = {2, 16, 256, 512};
|
int bin_sizes[] = {2, 16, 256, 512};
|
||||||
int sizes[] = {100};
|
int sizes[] = {100};
|
||||||
int num_columns = 5;
|
int num_columns = 5;
|
||||||
@ -183,7 +185,7 @@ TEST(HistUtil, DenseCutsAccuracyTest) {
|
|||||||
auto x = GenerateRandom(num_rows, num_columns);
|
auto x = GenerateRandom(num_rows, num_columns);
|
||||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||||
for (auto num_bins : bin_sizes) {
|
for (auto num_bins : bin_sizes) {
|
||||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
HistogramCuts cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins);
|
||||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -193,6 +195,7 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
|
|||||||
int bin_sizes[] = {2, 16, 256, 512};
|
int bin_sizes[] = {2, 16, 256, 512};
|
||||||
int sizes[] = {100, 1000, 1500};
|
int sizes[] = {100, 1000, 1500};
|
||||||
int num_columns = 5;
|
int num_columns = 5;
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
for (auto num_rows : sizes) {
|
for (auto num_rows : sizes) {
|
||||||
auto x = GenerateRandom(num_rows, num_columns);
|
auto x = GenerateRandom(num_rows, num_columns);
|
||||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||||
@ -200,11 +203,11 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
|
|||||||
dmat->Info().weights_.HostVector() = w;
|
dmat->Info().weights_.HostVector() = w;
|
||||||
for (auto num_bins : bin_sizes) {
|
for (auto num_bins : bin_sizes) {
|
||||||
{
|
{
|
||||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), true);
|
HistogramCuts cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins, true);
|
||||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), false);
|
HistogramCuts cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins, false);
|
||||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -215,6 +218,7 @@ void TestQuantileWithHessian(bool use_sorted) {
|
|||||||
int bin_sizes[] = {2, 16, 256, 512};
|
int bin_sizes[] = {2, 16, 256, 512};
|
||||||
int sizes[] = {1000, 1500};
|
int sizes[] = {1000, 1500};
|
||||||
int num_columns = 5;
|
int num_columns = 5;
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
for (auto num_rows : sizes) {
|
for (auto num_rows : sizes) {
|
||||||
auto x = GenerateRandom(num_rows, num_columns);
|
auto x = GenerateRandom(num_rows, num_columns);
|
||||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||||
@ -225,15 +229,13 @@ void TestQuantileWithHessian(bool use_sorted) {
|
|||||||
dmat->Info().weights_.HostVector() = w;
|
dmat->Info().weights_.HostVector() = w;
|
||||||
|
|
||||||
for (auto num_bins : bin_sizes) {
|
for (auto num_bins : bin_sizes) {
|
||||||
HistogramCuts cuts_hess =
|
HistogramCuts cuts_hess = SketchOnDMatrix(&ctx, dmat.get(), num_bins, use_sorted, hessian);
|
||||||
SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), use_sorted, hessian);
|
|
||||||
for (size_t i = 0; i < w.size(); ++i) {
|
for (size_t i = 0; i < w.size(); ++i) {
|
||||||
dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i];
|
dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i];
|
||||||
}
|
}
|
||||||
ValidateCuts(cuts_hess, dmat.get(), num_bins);
|
ValidateCuts(cuts_hess, dmat.get(), num_bins);
|
||||||
|
|
||||||
HistogramCuts cuts_wh =
|
HistogramCuts cuts_wh = SketchOnDMatrix(&ctx, dmat.get(), num_bins, use_sorted);
|
||||||
SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), use_sorted);
|
|
||||||
ValidateCuts(cuts_wh, dmat.get(), num_bins);
|
ValidateCuts(cuts_wh, dmat.get(), num_bins);
|
||||||
|
|
||||||
ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
|
ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
|
||||||
@ -255,12 +257,13 @@ TEST(HistUtil, DenseCutsExternalMemory) {
|
|||||||
int bin_sizes[] = {2, 16, 256, 512};
|
int bin_sizes[] = {2, 16, 256, 512};
|
||||||
int sizes[] = {100, 1000, 1500};
|
int sizes[] = {100, 1000, 1500};
|
||||||
int num_columns = 5;
|
int num_columns = 5;
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
for (auto num_rows : sizes) {
|
for (auto num_rows : sizes) {
|
||||||
auto x = GenerateRandom(num_rows, num_columns);
|
auto x = GenerateRandom(num_rows, num_columns);
|
||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, tmpdir);
|
auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, tmpdir);
|
||||||
for (auto num_bins : bin_sizes) {
|
for (auto num_bins : bin_sizes) {
|
||||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
HistogramCuts cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins);
|
||||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -275,12 +278,12 @@ TEST(HistUtil, IndexBinBound) {
|
|||||||
kUint32BinsTypeSize};
|
kUint32BinsTypeSize};
|
||||||
size_t constexpr kRows = 100;
|
size_t constexpr kRows = 100;
|
||||||
size_t constexpr kCols = 10;
|
size_t constexpr kCols = 10;
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
size_t bin_id = 0;
|
size_t bin_id = 0;
|
||||||
for (auto max_bin : bin_sizes) {
|
for (auto max_bin : bin_sizes) {
|
||||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||||
|
|
||||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, AllThreadsForTest());
|
GHistIndexMatrix hmat(&ctx, p_fmat.get(), max_bin, 0.5, false);
|
||||||
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
||||||
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
|
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
|
||||||
}
|
}
|
||||||
@ -300,10 +303,11 @@ TEST(HistUtil, IndexBinData) {
|
|||||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
|
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
|
||||||
size_t constexpr kRows = 100;
|
size_t constexpr kRows = 100;
|
||||||
size_t constexpr kCols = 10;
|
size_t constexpr kCols = 10;
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
|
||||||
for (auto max_bin : kBinSizes) {
|
for (auto max_bin : kBinSizes) {
|
||||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, AllThreadsForTest());
|
GHistIndexMatrix hmat(&ctx, p_fmat.get(), max_bin, 0.5, false);
|
||||||
uint32_t const* offsets = hmat.index.Offset();
|
uint32_t const* offsets = hmat.index.Offset();
|
||||||
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
||||||
switch (max_bin) {
|
switch (max_bin) {
|
||||||
@ -327,10 +331,10 @@ void TestSketchFromWeights(bool with_group) {
|
|||||||
size_t constexpr kRows = 300, kCols = 20, kBins = 256;
|
size_t constexpr kRows = 300, kCols = 20, kBins = 256;
|
||||||
size_t constexpr kGroups = 10;
|
size_t constexpr kGroups = 10;
|
||||||
auto m = RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix();
|
auto m = RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix();
|
||||||
common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest());
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
common::HistogramCuts cuts = SketchOnDMatrix(&ctx, m.get(), kBins);
|
||||||
|
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
Context ctx;
|
|
||||||
auto& h_weights = info.weights_.HostVector();
|
auto& h_weights = info.weights_.HostVector();
|
||||||
if (with_group) {
|
if (with_group) {
|
||||||
h_weights.resize(kGroups);
|
h_weights.resize(kGroups);
|
||||||
@ -363,7 +367,7 @@ void TestSketchFromWeights(bool with_group) {
|
|||||||
|
|
||||||
if (with_group) {
|
if (with_group) {
|
||||||
m->Info().weights_ = decltype(m->Info().weights_)(); // remove weight
|
m->Info().weights_ = decltype(m->Info().weights_)(); // remove weight
|
||||||
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest());
|
HistogramCuts non_weighted = SketchOnDMatrix(&ctx, m.get(), kBins);
|
||||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||||
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
||||||
}
|
}
|
||||||
@ -382,7 +386,7 @@ void TestSketchFromWeights(bool with_group) {
|
|||||||
for (size_t i = 0; i < h_weights.size(); ++i) {
|
for (size_t i = 0; i < h_weights.size(); ++i) {
|
||||||
h_weights[i] = static_cast<float>(i + 1) / static_cast<float>(kGroups);
|
h_weights[i] = static_cast<float>(i + 1) / static_cast<float>(kGroups);
|
||||||
}
|
}
|
||||||
HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest());
|
HistogramCuts weighted = SketchOnDMatrix(&ctx, m.get(), kBins);
|
||||||
ValidateCuts(weighted, m.get(), kBins);
|
ValidateCuts(weighted, m.get(), kBins);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -393,11 +397,12 @@ TEST(HistUtil, SketchFromWeights) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(HistUtil, SketchCategoricalFeatures) {
|
TEST(HistUtil, SketchCategoricalFeatures) {
|
||||||
TestCategoricalSketch(1000, 256, 32, false, [](DMatrix* p_fmat, int32_t num_bins) {
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
return SketchOnDMatrix(p_fmat, num_bins, AllThreadsForTest());
|
TestCategoricalSketch(1000, 256, 32, false, [&ctx](DMatrix* p_fmat, int32_t num_bins) {
|
||||||
|
return SketchOnDMatrix(&ctx, p_fmat, num_bins);
|
||||||
});
|
});
|
||||||
TestCategoricalSketch(1000, 256, 32, true, [](DMatrix* p_fmat, int32_t num_bins) {
|
TestCategoricalSketch(1000, 256, 32, true, [&ctx](DMatrix* p_fmat, int32_t num_bins) {
|
||||||
return SketchOnDMatrix(p_fmat, num_bins, AllThreadsForTest());
|
return SketchOnDMatrix(&ctx, p_fmat, num_bins);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
|
|||||||
@ -25,9 +25,9 @@ namespace xgboost {
|
|||||||
namespace common {
|
namespace common {
|
||||||
|
|
||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
|
HistogramCuts GetHostCuts(Context const* ctx, AdapterT* adapter, int num_bins, float missing) {
|
||||||
data::SimpleDMatrix dmat(adapter, missing, 1);
|
data::SimpleDMatrix dmat(adapter, missing, 1);
|
||||||
HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins, AllThreadsForTest());
|
HistogramCuts cuts = SketchOnDMatrix(ctx, &dmat, num_bins);
|
||||||
return cuts;
|
return cuts;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,7 +39,9 @@ TEST(HistUtil, DeviceSketch) {
|
|||||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||||
|
|
||||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||||
HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
|
||||||
|
Context ctx;
|
||||||
|
HistogramCuts host_cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins);
|
||||||
|
|
||||||
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
||||||
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
|
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
|
||||||
@ -308,7 +310,8 @@ TEST(HistUtil, AdapterDeviceSketch) {
|
|||||||
data::CupyAdapter adapter(str);
|
data::CupyAdapter adapter(str);
|
||||||
|
|
||||||
auto device_cuts = MakeUnweightedCutsForTest(adapter, num_bins, missing);
|
auto device_cuts = MakeUnweightedCutsForTest(adapter, num_bins, missing);
|
||||||
auto host_cuts = GetHostCuts(&adapter, num_bins, missing);
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
auto host_cuts = GetHostCuts(&ctx, &adapter, num_bins, missing);
|
||||||
|
|
||||||
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
||||||
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
|
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
|
||||||
|
|||||||
@ -16,7 +16,8 @@ TEST(Quantile, LoadBalance) {
|
|||||||
size_t constexpr kRows = 1000, kCols = 100;
|
size_t constexpr kRows = 1000, kCols = 100;
|
||||||
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
||||||
std::vector<bst_feature_t> cols_ptr;
|
std::vector<bst_feature_t> cols_ptr;
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
Context ctx;
|
||||||
|
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||||
data::SparsePageAdapterBatch adapter{page.GetView()};
|
data::SparsePageAdapterBatch adapter{page.GetView()};
|
||||||
cols_ptr = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; });
|
cols_ptr = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; });
|
||||||
}
|
}
|
||||||
@ -43,6 +44,7 @@ void PushPage(HostSketchContainer* container, SparsePage const& page, MetaInfo c
|
|||||||
|
|
||||||
template <bool use_column>
|
template <bool use_column>
|
||||||
void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||||
|
Context ctx;
|
||||||
auto const world = collective::GetWorldSize();
|
auto const world = collective::GetWorldSize();
|
||||||
std::vector<MetaInfo> infos(2);
|
std::vector<MetaInfo> infos(2);
|
||||||
auto& h_weights = infos.front().weights_.HostVector();
|
auto& h_weights = infos.front().weights_.HostVector();
|
||||||
@ -51,7 +53,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
|||||||
SimpleRealUniformDistribution<float> dist(3, 1000);
|
SimpleRealUniformDistribution<float> dist(3, 1000);
|
||||||
std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
|
std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
|
||||||
std::vector<bst_row_t> column_size(cols, rows);
|
std::vector<bst_row_t> column_size(cols, rows);
|
||||||
size_t n_bins = 64;
|
bst_bin_t n_bins = 64;
|
||||||
|
|
||||||
// Generate cuts for distributed environment.
|
// Generate cuts for distributed environment.
|
||||||
auto sparsity = 0.5f;
|
auto sparsity = 0.5f;
|
||||||
@ -72,15 +74,15 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
|||||||
std::vector<float> hessian(rows, 1.0);
|
std::vector<float> hessian(rows, 1.0);
|
||||||
auto hess = Span<float const>{hessian};
|
auto hess = Span<float const>{hessian};
|
||||||
|
|
||||||
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
|
ContainerType<use_column> sketch_distributed(
|
||||||
column_size, false, AllThreadsForTest());
|
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||||
|
|
||||||
if (use_column) {
|
if (use_column) {
|
||||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -93,8 +95,8 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
|||||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||||
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
||||||
m->Info().num_row_ = world * rows;
|
m->Info().num_row_ = world * rows;
|
||||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
|
ContainerType<use_column> sketch_on_single_node(
|
||||||
column_size, false, AllThreadsForTest());
|
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||||
m->Info().num_row_ = rows;
|
m->Info().num_row_ = rows;
|
||||||
|
|
||||||
for (auto rank = 0; rank < world; ++rank) {
|
for (auto rank = 0; rank < world; ++rank) {
|
||||||
@ -106,7 +108,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
|||||||
.Upper(1.0f)
|
.Upper(1.0f)
|
||||||
.GenerateDMatrix();
|
.GenerateDMatrix();
|
||||||
if (use_column) {
|
if (use_column) {
|
||||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -172,6 +174,7 @@ TEST(Quantile, SortedDistributed) {
|
|||||||
namespace {
|
namespace {
|
||||||
template <bool use_column>
|
template <bool use_column>
|
||||||
void DoTestColSplitQuantile(size_t rows, size_t cols) {
|
void DoTestColSplitQuantile(size_t rows, size_t cols) {
|
||||||
|
Context ctx;
|
||||||
auto const world = collective::GetWorldSize();
|
auto const world = collective::GetWorldSize();
|
||||||
auto const rank = collective::GetRank();
|
auto const rank = collective::GetRank();
|
||||||
|
|
||||||
@ -204,17 +207,17 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
|
|||||||
// Generate cuts for distributed environment.
|
// Generate cuts for distributed environment.
|
||||||
HistogramCuts distributed_cuts;
|
HistogramCuts distributed_cuts;
|
||||||
{
|
{
|
||||||
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
|
ContainerType<use_column> sketch_distributed(
|
||||||
column_size, false, AllThreadsForTest());
|
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||||
|
|
||||||
std::vector<float> hessian(rows, 1.0);
|
std::vector<float> hessian(rows, 1.0);
|
||||||
auto hess = Span<float const>{hessian};
|
auto hess = Span<float const>{hessian};
|
||||||
if (use_column) {
|
if (use_column) {
|
||||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -227,17 +230,17 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
|
|||||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||||
HistogramCuts single_node_cuts;
|
HistogramCuts single_node_cuts;
|
||||||
{
|
{
|
||||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
|
ContainerType<use_column> sketch_on_single_node(
|
||||||
column_size, false, AllThreadsForTest());
|
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);
|
||||||
|
|
||||||
std::vector<float> hessian(rows, 1.0);
|
std::vector<float> hessian(rows, 1.0);
|
||||||
auto hess = Span<float const>{hessian};
|
auto hess = Span<float const>{hessian};
|
||||||
if (use_column) {
|
if (use_column) {
|
||||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
|
||||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
|
||||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -299,8 +302,10 @@ namespace {
|
|||||||
void TestSameOnAllWorkers() {
|
void TestSameOnAllWorkers() {
|
||||||
auto const world = collective::GetWorldSize();
|
auto const world = collective::GetWorldSize();
|
||||||
constexpr size_t kRows = 1000, kCols = 100;
|
constexpr size_t kRows = 1000, kCols = 100;
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
|
||||||
RunWithSeedsAndBins(
|
RunWithSeedsAndBins(
|
||||||
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) {
|
kRows, [=, &ctx](int32_t seed, size_t n_bins, MetaInfo const&) {
|
||||||
auto rank = collective::GetRank();
|
auto rank = collective::GetRank();
|
||||||
HostDeviceVector<float> storage;
|
HostDeviceVector<float> storage;
|
||||||
std::vector<FeatureType> ft(kCols);
|
std::vector<FeatureType> ft(kCols);
|
||||||
@ -314,7 +319,7 @@ void TestSameOnAllWorkers() {
|
|||||||
.MaxCategory(17)
|
.MaxCategory(17)
|
||||||
.Seed(rank + seed)
|
.Seed(rank + seed)
|
||||||
.GenerateDMatrix();
|
.GenerateDMatrix();
|
||||||
auto cuts = SketchOnDMatrix(m.get(), n_bins, AllThreadsForTest());
|
auto cuts = SketchOnDMatrix(&ctx, m.get(), n_bins);
|
||||||
std::vector<float> cut_values(cuts.Values().size() * world, 0);
|
std::vector<float> cut_values(cuts.Values().size() * world, 0);
|
||||||
std::vector<
|
std::vector<
|
||||||
typename std::remove_reference_t<decltype(cuts.Ptrs())>::value_type>
|
typename std::remove_reference_t<decltype(cuts.Ptrs())>::value_type>
|
||||||
|
|||||||
@ -1,17 +1,17 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019-2020 XGBoost contributors
|
* Copyright 2019-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "../helpers.h"
|
|
||||||
#include "../histogram_helpers.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
|
|
||||||
#include "../../../src/common/categorical.h"
|
#include "../../../src/common/categorical.h"
|
||||||
#include "../../../src/common/hist_util.h"
|
#include "../../../src/common/hist_util.h"
|
||||||
#include "../../../src/data/ellpack_page.cuh"
|
#include "../../../src/data/ellpack_page.cuh"
|
||||||
|
#include "../../../src/tree/param.h" // TrainParam
|
||||||
|
#include "../helpers.h"
|
||||||
|
#include "../histogram_helpers.h"
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
@ -19,7 +19,10 @@ TEST(EllpackPage, EmptyDMatrix) {
|
|||||||
constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256;
|
constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256;
|
||||||
constexpr float kSparsity = 0;
|
constexpr float kSparsity = 0;
|
||||||
auto dmat = RandomDataGenerator(kNRows, kNCols, kSparsity).GenerateDMatrix();
|
auto dmat = RandomDataGenerator(kNRows, kNCols, kSparsity).GenerateDMatrix();
|
||||||
auto& page = *dmat->GetBatches<EllpackPage>({0, kMaxBin}).begin();
|
Context ctx{MakeCUDACtx(0)};
|
||||||
|
auto& page = *dmat->GetBatches<EllpackPage>(
|
||||||
|
&ctx, BatchParam{kMaxBin, tree::TrainParam::DftSparseThreshold()})
|
||||||
|
.begin();
|
||||||
auto impl = page.Impl();
|
auto impl = page.Impl();
|
||||||
ASSERT_EQ(impl->row_stride, 0);
|
ASSERT_EQ(impl->row_stride, 0);
|
||||||
ASSERT_EQ(impl->Cuts().TotalBins(), 0);
|
ASSERT_EQ(impl->Cuts().TotalBins(), 0);
|
||||||
@ -87,8 +90,9 @@ TEST(EllpackPage, FromCategoricalBasic) {
|
|||||||
auto& h_ft = m->Info().feature_types.HostVector();
|
auto& h_ft = m->Info().feature_types.HostVector();
|
||||||
h_ft.resize(kCols, FeatureType::kCategorical);
|
h_ft.resize(kCols, FeatureType::kCategorical);
|
||||||
|
|
||||||
BatchParam p{0, max_bins};
|
Context ctx{MakeCUDACtx(0)};
|
||||||
auto ellpack = EllpackPage(m.get(), p);
|
auto p = BatchParam{max_bins, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
auto ellpack = EllpackPage(&ctx, m.get(), p);
|
||||||
auto accessor = ellpack.Impl()->GetDeviceAccessor(0);
|
auto accessor = ellpack.Impl()->GetDeviceAccessor(0);
|
||||||
ASSERT_EQ(kCats, accessor.NumBins());
|
ASSERT_EQ(kCats, accessor.NumBins());
|
||||||
|
|
||||||
@ -142,8 +146,9 @@ TEST(EllpackPage, Copy) {
|
|||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
std::unique_ptr<DMatrix>
|
std::unique_ptr<DMatrix>
|
||||||
dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||||
BatchParam param{0, 256};
|
Context ctx{MakeCUDACtx(0)};
|
||||||
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||||
|
|
||||||
// Create an empty result page.
|
// Create an empty result page.
|
||||||
EllpackPageImpl result(0, page->Cuts(), page->is_dense, page->row_stride,
|
EllpackPageImpl result(0, page->Cuts(), page->is_dense, page->row_stride,
|
||||||
@ -151,7 +156,7 @@ TEST(EllpackPage, Copy) {
|
|||||||
|
|
||||||
// Copy batch pages into the result page.
|
// Copy batch pages into the result page.
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& batch : dmat->GetBatches<EllpackPage>(param)) {
|
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
size_t num_elements = result.Copy(0, batch.Impl(), offset);
|
size_t num_elements = result.Copy(0, batch.Impl(), offset);
|
||||||
offset += num_elements;
|
offset += num_elements;
|
||||||
}
|
}
|
||||||
@ -161,7 +166,7 @@ TEST(EllpackPage, Copy) {
|
|||||||
thrust::device_vector<bst_float> row_result_d(kCols);
|
thrust::device_vector<bst_float> row_result_d(kCols);
|
||||||
std::vector<bst_float> row(kCols);
|
std::vector<bst_float> row(kCols);
|
||||||
std::vector<bst_float> row_result(kCols);
|
std::vector<bst_float> row_result(kCols);
|
||||||
for (auto& page : dmat->GetBatches<EllpackPage>(param)) {
|
for (auto& page : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
auto impl = page.Impl();
|
auto impl = page.Impl();
|
||||||
EXPECT_EQ(impl->base_rowid, current_row);
|
EXPECT_EQ(impl->base_rowid, current_row);
|
||||||
|
|
||||||
@ -186,10 +191,11 @@ TEST(EllpackPage, Compact) {
|
|||||||
|
|
||||||
// Create a DMatrix with multiple batches.
|
// Create a DMatrix with multiple batches.
|
||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
std::unique_ptr<DMatrix>
|
std::unique_ptr<DMatrix> dmat(
|
||||||
dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||||
BatchParam param{0, 256};
|
Context ctx{MakeCUDACtx(0)};
|
||||||
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||||
|
|
||||||
// Create an empty result page.
|
// Create an empty result page.
|
||||||
EllpackPageImpl result(0, page->Cuts(), page->is_dense, page->row_stride,
|
EllpackPageImpl result(0, page->Cuts(), page->is_dense, page->row_stride,
|
||||||
@ -201,7 +207,7 @@ TEST(EllpackPage, Compact) {
|
|||||||
SIZE_MAX};
|
SIZE_MAX};
|
||||||
thrust::device_vector<size_t> row_indexes_d = row_indexes_h;
|
thrust::device_vector<size_t> row_indexes_d = row_indexes_h;
|
||||||
common::Span<size_t> row_indexes_span(row_indexes_d.data().get(), kRows);
|
common::Span<size_t> row_indexes_span(row_indexes_d.data().get(), kRows);
|
||||||
for (auto& batch : dmat->GetBatches<EllpackPage>(param)) {
|
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
result.Compact(0, batch.Impl(), row_indexes_span);
|
result.Compact(0, batch.Impl(), row_indexes_span);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -210,7 +216,7 @@ TEST(EllpackPage, Compact) {
|
|||||||
thrust::device_vector<bst_float> row_result_d(kCols);
|
thrust::device_vector<bst_float> row_result_d(kCols);
|
||||||
std::vector<bst_float> row(kCols);
|
std::vector<bst_float> row(kCols);
|
||||||
std::vector<bst_float> row_result(kCols);
|
std::vector<bst_float> row_result(kCols);
|
||||||
for (auto& page : dmat->GetBatches<EllpackPage>(param)) {
|
for (auto& page : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
auto impl = page.Impl();
|
auto impl = page.Impl();
|
||||||
ASSERT_EQ(impl->base_rowid, current_row);
|
ASSERT_EQ(impl->base_rowid, current_row);
|
||||||
|
|
||||||
@ -245,15 +251,17 @@ class EllpackPageTest : public testing::TestWithParam<float> {
|
|||||||
// device.
|
// device.
|
||||||
size_t n_samples{128}, n_features{13};
|
size_t n_samples{128}, n_features{13};
|
||||||
Context ctx;
|
Context ctx;
|
||||||
ctx.gpu_id = 0;
|
Context gpu_ctx{MakeCUDACtx(0)};
|
||||||
auto Xy = RandomDataGenerator{n_samples, n_features, sparsity}.GenerateDMatrix(true);
|
auto Xy = RandomDataGenerator{n_samples, n_features, sparsity}.GenerateDMatrix(true);
|
||||||
std::unique_ptr<EllpackPageImpl> from_ghist;
|
std::unique_ptr<EllpackPageImpl> from_ghist;
|
||||||
ASSERT_TRUE(Xy->SingleColBlock());
|
ASSERT_TRUE(Xy->SingleColBlock());
|
||||||
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(BatchParam{17, 0.6})) {
|
|
||||||
from_ghist.reset(new EllpackPageImpl{&ctx, page, {}});
|
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{17, 0.6})) {
|
||||||
|
from_ghist.reset(new EllpackPageImpl{&gpu_ctx, page, {}});
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto const& page : Xy->GetBatches<EllpackPage>(BatchParam{0, 17})) {
|
for (auto const& page : Xy->GetBatches<EllpackPage>(
|
||||||
|
&gpu_ctx, BatchParam{17, tree::TrainParam::DftSparseThreshold()})) {
|
||||||
auto from_sparse_page = page.Impl();
|
auto from_sparse_page = page.Impl();
|
||||||
ASSERT_EQ(from_sparse_page->is_dense, from_ghist->is_dense);
|
ASSERT_EQ(from_sparse_page->is_dense, from_ghist->is_dense);
|
||||||
ASSERT_EQ(from_sparse_page->base_rowid, 0);
|
ASSERT_EQ(from_sparse_page->base_rowid, 0);
|
||||||
|
|||||||
@ -1,17 +1,21 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021 XGBoost contributors
|
* Copyright 2021-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
|
|
||||||
#include "../../../src/data/ellpack_page.cuh"
|
#include "../../../src/data/ellpack_page.cuh"
|
||||||
#include "../../../src/data/sparse_page_source.h"
|
#include "../../../src/data/sparse_page_source.h"
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../../../src/tree/param.h" // TrainParam
|
||||||
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
TEST(EllpackPageRawFormat, IO) {
|
TEST(EllpackPageRawFormat, IO) {
|
||||||
|
Context ctx{MakeCUDACtx(0)};
|
||||||
|
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
|
||||||
std::unique_ptr<SparsePageFormat<EllpackPage>> format{CreatePageFormat<EllpackPage>("raw")};
|
std::unique_ptr<SparsePageFormat<EllpackPage>> format{CreatePageFormat<EllpackPage>("raw")};
|
||||||
|
|
||||||
auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
|
auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
|
||||||
@ -20,7 +24,7 @@ TEST(EllpackPageRawFormat, IO) {
|
|||||||
|
|
||||||
{
|
{
|
||||||
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
|
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
|
||||||
for (auto const &ellpack : m->GetBatches<EllpackPage>({0, 256})) {
|
for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
format->Write(ellpack, fo.get());
|
format->Write(ellpack, fo.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -29,7 +33,7 @@ TEST(EllpackPageRawFormat, IO) {
|
|||||||
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
|
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
|
||||||
format->Read(&page, fi.get());
|
format->Read(&page, fi.get());
|
||||||
|
|
||||||
for (auto const &ellpack : m->GetBatches<EllpackPage>({0, 256})) {
|
for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
auto loaded = page.Impl();
|
auto loaded = page.Impl();
|
||||||
auto orig = ellpack.Impl();
|
auto orig = ellpack.Impl();
|
||||||
ASSERT_EQ(loaded->Cuts().Ptrs(), orig->Cuts().Ptrs());
|
ASSERT_EQ(loaded->Cuts().Ptrs(), orig->Cuts().Ptrs());
|
||||||
|
|||||||
@ -2,20 +2,38 @@
|
|||||||
* Copyright 2021-2023 by XGBoost contributors
|
* Copyright 2021-2023 by XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h> // for BatchIterator, BatchSet, DMatrix, BatchParam
|
||||||
|
|
||||||
#include "../../../src/common/column_matrix.h"
|
#include <algorithm> // for sort, unique
|
||||||
#include "../../../src/common/io.h" // MemoryBufferStream
|
#include <cmath> // for isnan
|
||||||
#include "../../../src/data/gradient_index.h"
|
#include <cstddef> // for size_t
|
||||||
#include "../helpers.h"
|
#include <limits> // for numeric_limits
|
||||||
|
#include <memory> // for shared_ptr, __shared_ptr_access, unique_ptr
|
||||||
|
#include <string> // for string
|
||||||
|
#include <tuple> // for make_tuple, tie, tuple
|
||||||
|
#include <utility> // for move
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../../../src/common/categorical.h" // for AsCat
|
||||||
|
#include "../../../src/common/column_matrix.h" // for ColumnMatrix
|
||||||
|
#include "../../../src/common/hist_util.h" // for Index, HistogramCuts, SketchOnDMatrix
|
||||||
|
#include "../../../src/common/io.h" // for MemoryBufferStream
|
||||||
|
#include "../../../src/data/adapter.h" // for SparsePageAdapterBatch
|
||||||
|
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
|
#include "../helpers.h" // for CreateEmptyGenericParam, GenerateRandomCa...
|
||||||
|
#include "xgboost/base.h" // for bst_bin_t
|
||||||
|
#include "xgboost/context.h" // for Context
|
||||||
|
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
TEST(GradientIndex, ExternalMemory) {
|
TEST(GradientIndex, ExternalMemory) {
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
|
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
|
||||||
std::vector<size_t> base_rowids;
|
std::vector<size_t> base_rowids;
|
||||||
std::vector<float> hessian(dmat->Info().num_row_, 1);
|
std::vector<float> hessian(dmat->Info().num_row_, 1);
|
||||||
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>({64, hessian, true})) {
|
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, true})) {
|
||||||
base_rowids.push_back(page.base_rowid);
|
base_rowids.push_back(page.base_rowid);
|
||||||
}
|
}
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
@ -24,9 +42,8 @@ TEST(GradientIndex, ExternalMemory) {
|
|||||||
++i;
|
++i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
base_rowids.clear();
|
base_rowids.clear();
|
||||||
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>({64, hessian, false})) {
|
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, false})) {
|
||||||
base_rowids.push_back(page.base_rowid);
|
base_rowids.push_back(page.base_rowid);
|
||||||
}
|
}
|
||||||
i = 0;
|
i = 0;
|
||||||
@ -41,12 +58,13 @@ TEST(GradientIndex, FromCategoricalBasic) {
|
|||||||
size_t max_bins = 8;
|
size_t max_bins = 8;
|
||||||
auto x = GenerateRandomCategoricalSingleColumn(kRows, kCats);
|
auto x = GenerateRandomCategoricalSingleColumn(kRows, kCats);
|
||||||
auto m = GetDMatrixFromData(x, kRows, 1);
|
auto m = GetDMatrixFromData(x, kRows, 1);
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
|
||||||
auto &h_ft = m->Info().feature_types.HostVector();
|
auto &h_ft = m->Info().feature_types.HostVector();
|
||||||
h_ft.resize(kCols, FeatureType::kCategorical);
|
h_ft.resize(kCols, FeatureType::kCategorical);
|
||||||
|
|
||||||
BatchParam p(max_bins, 0.8);
|
BatchParam p(max_bins, 0.8);
|
||||||
GHistIndexMatrix gidx(m.get(), max_bins, p.sparse_thresh, false, AllThreadsForTest(), {});
|
GHistIndexMatrix gidx(&ctx, m.get(), max_bins, p.sparse_thresh, false, {});
|
||||||
|
|
||||||
auto x_copy = x;
|
auto x_copy = x;
|
||||||
std::sort(x_copy.begin(), x_copy.end());
|
std::sort(x_copy.begin(), x_copy.end());
|
||||||
@ -80,11 +98,11 @@ TEST(GradientIndex, FromCategoricalLarge) {
|
|||||||
|
|
||||||
BatchParam p{max_bins, 0.8};
|
BatchParam p{max_bins, 0.8};
|
||||||
{
|
{
|
||||||
GHistIndexMatrix gidx(m.get(), max_bins, p.sparse_thresh, false, AllThreadsForTest(), {});
|
GHistIndexMatrix gidx{&ctx, m.get(), max_bins, p.sparse_thresh, false, {}};
|
||||||
ASSERT_TRUE(gidx.index.GetBinTypeSize() == common::kUint16BinsTypeSize);
|
ASSERT_TRUE(gidx.index.GetBinTypeSize() == common::kUint16BinsTypeSize);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(p)) {
|
for (auto const &page : m->GetBatches<GHistIndexMatrix>(&ctx, p)) {
|
||||||
common::HistogramCuts cut = page.cut;
|
common::HistogramCuts cut = page.cut;
|
||||||
GHistIndexMatrix gidx{m->Info(), std::move(cut), max_bins};
|
GHistIndexMatrix gidx{m->Info(), std::move(cut), max_bins};
|
||||||
ASSERT_EQ(gidx.MaxNumBinPerFeat(), kCats);
|
ASSERT_EQ(gidx.MaxNumBinPerFeat(), kCats);
|
||||||
@ -96,10 +114,11 @@ TEST(GradientIndex, PushBatch) {
|
|||||||
size_t constexpr kRows = 64, kCols = 4;
|
size_t constexpr kRows = 64, kCols = 4;
|
||||||
bst_bin_t max_bins = 64;
|
bst_bin_t max_bins = 64;
|
||||||
float st = 0.5;
|
float st = 0.5;
|
||||||
|
Context ctx;
|
||||||
|
|
||||||
auto test = [&](float sparisty) {
|
auto test = [&](float sparisty) {
|
||||||
auto m = RandomDataGenerator{kRows, kCols, sparisty}.GenerateDMatrix(true);
|
auto m = RandomDataGenerator{kRows, kCols, sparisty}.GenerateDMatrix(true);
|
||||||
auto cuts = common::SketchOnDMatrix(m.get(), max_bins, AllThreadsForTest(), false, {});
|
auto cuts = common::SketchOnDMatrix(&ctx, m.get(), max_bins, false, {});
|
||||||
common::HistogramCuts copy_cuts = cuts;
|
common::HistogramCuts copy_cuts = cuts;
|
||||||
|
|
||||||
ASSERT_EQ(m->Info().num_row_, kRows);
|
ASSERT_EQ(m->Info().num_row_, kRows);
|
||||||
@ -112,7 +131,7 @@ TEST(GradientIndex, PushBatch) {
|
|||||||
m->Info().num_row_);
|
m->Info().num_row_);
|
||||||
gmat.PushAdapterBatchColumns(m->Ctx(), batch, std::numeric_limits<float>::quiet_NaN(), 0);
|
gmat.PushAdapterBatchColumns(m->Ctx(), batch, std::numeric_limits<float>::quiet_NaN(), 0);
|
||||||
}
|
}
|
||||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(BatchParam{max_bins, st})) {
|
for (auto const &page : m->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{max_bins, st})) {
|
||||||
for (size_t i = 0; i < kRows; ++i) {
|
for (size_t i = 0; i < kRows; ++i) {
|
||||||
for (size_t j = 0; j < kCols; ++j) {
|
for (size_t j = 0; j < kCols; ++j) {
|
||||||
auto v0 = gmat.GetFvalue(i, j, false);
|
auto v0 = gmat.GetFvalue(i, j, false);
|
||||||
@ -143,17 +162,19 @@ class GHistIndexMatrixTest : public testing::TestWithParam<std::tuple<float, flo
|
|||||||
// device.
|
// device.
|
||||||
size_t n_samples{128}, n_features{13};
|
size_t n_samples{128}, n_features{13};
|
||||||
Context ctx;
|
Context ctx;
|
||||||
ctx.gpu_id = 0;
|
|
||||||
auto Xy = RandomDataGenerator{n_samples, n_features, 1 - density}.GenerateDMatrix(true);
|
auto Xy = RandomDataGenerator{n_samples, n_features, 1 - density}.GenerateDMatrix(true);
|
||||||
std::unique_ptr<GHistIndexMatrix> from_ellpack;
|
std::unique_ptr<GHistIndexMatrix> from_ellpack;
|
||||||
ASSERT_TRUE(Xy->SingleColBlock());
|
ASSERT_TRUE(Xy->SingleColBlock());
|
||||||
bst_bin_t constexpr kBins{17};
|
bst_bin_t constexpr kBins{17};
|
||||||
auto p = BatchParam{kBins, threshold};
|
auto p = BatchParam{kBins, threshold};
|
||||||
for (auto const &page : Xy->GetBatches<EllpackPage>(BatchParam{0, kBins})) {
|
Context gpu_ctx;
|
||||||
|
gpu_ctx.gpu_id = 0;
|
||||||
|
for (auto const &page : Xy->GetBatches<EllpackPage>(
|
||||||
|
&gpu_ctx, BatchParam{kBins, tree::TrainParam::DftSparseThreshold()})) {
|
||||||
from_ellpack.reset(new GHistIndexMatrix{&ctx, Xy->Info(), page, p});
|
from_ellpack.reset(new GHistIndexMatrix{&ctx, Xy->Info(), page, p});
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto const &from_sparse_page : Xy->GetBatches<GHistIndexMatrix>(p)) {
|
for (auto const &from_sparse_page : Xy->GetBatches<GHistIndexMatrix>(&ctx, p)) {
|
||||||
ASSERT_EQ(from_sparse_page.IsDense(), from_ellpack->IsDense());
|
ASSERT_EQ(from_sparse_page.IsDense(), from_ellpack->IsDense());
|
||||||
ASSERT_EQ(from_sparse_page.base_rowid, 0);
|
ASSERT_EQ(from_sparse_page.base_rowid, 0);
|
||||||
ASSERT_EQ(from_sparse_page.base_rowid, from_ellpack->base_rowid);
|
ASSERT_EQ(from_sparse_page.base_rowid, from_ellpack->base_rowid);
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021 XGBoost contributors
|
* Copyright 2021-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
@ -11,6 +11,8 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
TEST(GHistIndexPageRawFormat, IO) {
|
TEST(GHistIndexPageRawFormat, IO) {
|
||||||
|
Context ctx;
|
||||||
|
|
||||||
std::unique_ptr<SparsePageFormat<GHistIndexMatrix>> format{
|
std::unique_ptr<SparsePageFormat<GHistIndexMatrix>> format{
|
||||||
CreatePageFormat<GHistIndexMatrix>("raw")};
|
CreatePageFormat<GHistIndexMatrix>("raw")};
|
||||||
auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
|
auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
|
||||||
@ -20,7 +22,7 @@ TEST(GHistIndexPageRawFormat, IO) {
|
|||||||
|
|
||||||
{
|
{
|
||||||
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
|
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
|
||||||
for (auto const &index : m->GetBatches<GHistIndexMatrix>(batch)) {
|
for (auto const &index : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
|
||||||
format->Write(index, fo.get());
|
format->Write(index, fo.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -29,7 +31,7 @@ TEST(GHistIndexPageRawFormat, IO) {
|
|||||||
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
|
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
|
||||||
format->Read(&page, fi.get());
|
format->Read(&page, fi.get());
|
||||||
|
|
||||||
for (auto const &gidx : m->GetBatches<GHistIndexMatrix>(batch)) {
|
for (auto const &gidx : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
|
||||||
auto const &loaded = gidx;
|
auto const &loaded = gidx;
|
||||||
ASSERT_EQ(loaded.cut.Ptrs(), page.cut.Ptrs());
|
ASSERT_EQ(loaded.cut.Ptrs(), page.cut.Ptrs());
|
||||||
ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues());
|
ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues());
|
||||||
@ -43,5 +45,5 @@ TEST(GHistIndexPageRawFormat, IO) {
|
|||||||
ASSERT_EQ(loaded.Transpose().GetTypeSize(), loaded.Transpose().GetTypeSize());
|
ASSERT_EQ(loaded.Transpose().GetTypeSize(), loaded.Transpose().GetTypeSize());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -15,8 +15,9 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
TEST(IterativeDMatrix, Ref) {
|
TEST(IterativeDMatrix, Ref) {
|
||||||
|
Context ctx;
|
||||||
TestRefDMatrix<GHistIndexMatrix, NumpyArrayIterForTest>(
|
TestRefDMatrix<GHistIndexMatrix, NumpyArrayIterForTest>(
|
||||||
[&](GHistIndexMatrix const& page) { return page.cut; });
|
&ctx, [&](GHistIndexMatrix const& page) { return page.cut; });
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(IterativeDMatrix, IsDense) {
|
TEST(IterativeDMatrix, IsDense) {
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2020-2022 XGBoost contributors
|
* Copyright 2020-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include "../../../src/data/device_adapter.cuh"
|
#include "../../../src/data/device_adapter.cuh"
|
||||||
#include "../../../src/data/ellpack_page.cuh"
|
#include "../../../src/data/ellpack_page.cuh"
|
||||||
#include "../../../src/data/iterative_dmatrix.h"
|
#include "../../../src/data/iterative_dmatrix.h"
|
||||||
|
#include "../../../src/tree/param.h" // TrainParam
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
#include "test_iterative_dmatrix.h"
|
#include "test_iterative_dmatrix.h"
|
||||||
|
|
||||||
@ -13,15 +14,17 @@ namespace xgboost {
|
|||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
void TestEquivalent(float sparsity) {
|
void TestEquivalent(float sparsity) {
|
||||||
|
Context ctx{MakeCUDACtx(0)};
|
||||||
|
|
||||||
CudaArrayIterForTest iter{sparsity};
|
CudaArrayIterForTest iter{sparsity};
|
||||||
IterativeDMatrix m(&iter, iter.Proxy(), nullptr, Reset, Next,
|
IterativeDMatrix m(&iter, iter.Proxy(), nullptr, Reset, Next,
|
||||||
std::numeric_limits<float>::quiet_NaN(), 0, 256);
|
std::numeric_limits<float>::quiet_NaN(), 0, 256);
|
||||||
size_t offset = 0;
|
std::size_t offset = 0;
|
||||||
auto first = (*m.GetEllpackBatches({}).begin()).Impl();
|
auto first = (*m.GetEllpackBatches(&ctx, {}).begin()).Impl();
|
||||||
std::unique_ptr<EllpackPageImpl> page_concatenated {
|
std::unique_ptr<EllpackPageImpl> page_concatenated {
|
||||||
new EllpackPageImpl(0, first->Cuts(), first->is_dense,
|
new EllpackPageImpl(0, first->Cuts(), first->is_dense,
|
||||||
first->row_stride, 1000 * 100)};
|
first->row_stride, 1000 * 100)};
|
||||||
for (auto& batch : m.GetBatches<EllpackPage>({})) {
|
for (auto& batch : m.GetBatches<EllpackPage>(&ctx, {})) {
|
||||||
auto page = batch.Impl();
|
auto page = batch.Impl();
|
||||||
size_t num_elements = page_concatenated->Copy(0, page, offset);
|
size_t num_elements = page_concatenated->Copy(0, page, offset);
|
||||||
offset += num_elements;
|
offset += num_elements;
|
||||||
@ -34,8 +37,8 @@ void TestEquivalent(float sparsity) {
|
|||||||
auto adapter = CupyAdapter(interface_str);
|
auto adapter = CupyAdapter(interface_str);
|
||||||
std::unique_ptr<DMatrix> dm{
|
std::unique_ptr<DMatrix> dm{
|
||||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 0)};
|
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 0)};
|
||||||
BatchParam bp {0, 256};
|
auto bp = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
for (auto& ellpack : dm->GetBatches<EllpackPage>(bp)) {
|
for (auto& ellpack : dm->GetBatches<EllpackPage>(&ctx, bp)) {
|
||||||
auto from_data = ellpack.Impl()->GetDeviceAccessor(0);
|
auto from_data = ellpack.Impl()->GetDeviceAccessor(0);
|
||||||
|
|
||||||
std::vector<float> cuts_from_iter(from_iter.gidx_fvalue_map.size());
|
std::vector<float> cuts_from_iter(from_iter.gidx_fvalue_map.size());
|
||||||
@ -92,7 +95,8 @@ TEST(IterativeDeviceDMatrix, RowMajor) {
|
|||||||
std::numeric_limits<float>::quiet_NaN(), 0, 256);
|
std::numeric_limits<float>::quiet_NaN(), 0, 256);
|
||||||
size_t n_batches = 0;
|
size_t n_batches = 0;
|
||||||
std::string interface_str = iter.AsArray();
|
std::string interface_str = iter.AsArray();
|
||||||
for (auto& ellpack : m.GetBatches<EllpackPage>({})) {
|
Context ctx{MakeCUDACtx(0)};
|
||||||
|
for (auto& ellpack : m.GetBatches<EllpackPage>(&ctx, {})) {
|
||||||
n_batches ++;
|
n_batches ++;
|
||||||
auto impl = ellpack.Impl();
|
auto impl = ellpack.Impl();
|
||||||
common::CompressedIterator<uint32_t> iterator(
|
common::CompressedIterator<uint32_t> iterator(
|
||||||
@ -140,7 +144,10 @@ TEST(IterativeDeviceDMatrix, RowMajorMissing) {
|
|||||||
|
|
||||||
IterativeDMatrix m(&iter, iter.Proxy(), nullptr, Reset, Next,
|
IterativeDMatrix m(&iter, iter.Proxy(), nullptr, Reset, Next,
|
||||||
std::numeric_limits<float>::quiet_NaN(), 0, 256);
|
std::numeric_limits<float>::quiet_NaN(), 0, 256);
|
||||||
auto &ellpack = *m.GetBatches<EllpackPage>({0, 256}).begin();
|
auto ctx = MakeCUDACtx(0);
|
||||||
|
auto& ellpack =
|
||||||
|
*m.GetBatches<EllpackPage>(&ctx, BatchParam{256, tree::TrainParam::DftSparseThreshold()})
|
||||||
|
.begin();
|
||||||
auto impl = ellpack.Impl();
|
auto impl = ellpack.Impl();
|
||||||
common::CompressedIterator<uint32_t> iterator(
|
common::CompressedIterator<uint32_t> iterator(
|
||||||
impl->gidx_buffer.HostVector().data(), impl->NumSymbols());
|
impl->gidx_buffer.HostVector().data(), impl->NumSymbols());
|
||||||
@ -171,8 +178,9 @@ TEST(IterativeDeviceDMatrix, IsDense) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(IterativeDeviceDMatrix, Ref) {
|
TEST(IterativeDeviceDMatrix, Ref) {
|
||||||
|
Context ctx{MakeCUDACtx(0)};
|
||||||
TestRefDMatrix<EllpackPage, CudaArrayIterForTest>(
|
TestRefDMatrix<EllpackPage, CudaArrayIterForTest>(
|
||||||
[](EllpackPage const& page) { return page.Impl()->Cuts(); });
|
&ctx, [](EllpackPage const& page) { return page.Impl()->Cuts(); });
|
||||||
}
|
}
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2022 XGBoost contributors
|
* Copyright 2022-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <memory> // std::make_shared
|
#include <xgboost/context.h> // for Context
|
||||||
|
|
||||||
|
#include <limits> // for numeric_limits
|
||||||
|
#include <memory> // for make_shared
|
||||||
|
|
||||||
#include "../../../src/data/iterative_dmatrix.h"
|
#include "../../../src/data/iterative_dmatrix.h"
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
@ -10,7 +13,7 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
template <typename Page, typename Iter, typename Cuts>
|
template <typename Page, typename Iter, typename Cuts>
|
||||||
void TestRefDMatrix(Cuts&& get_cuts) {
|
void TestRefDMatrix(Context const* ctx, Cuts&& get_cuts) {
|
||||||
int n_bins = 256;
|
int n_bins = 256;
|
||||||
Iter iter(0.3, 2048);
|
Iter iter(0.3, 2048);
|
||||||
auto m = std::make_shared<IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
|
auto m = std::make_shared<IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
|
||||||
@ -20,8 +23,8 @@ void TestRefDMatrix(Cuts&& get_cuts) {
|
|||||||
auto m_1 = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), m, Reset, Next,
|
auto m_1 = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), m, Reset, Next,
|
||||||
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
||||||
|
|
||||||
for (auto const& page_0 : m->template GetBatches<Page>({})) {
|
for (auto const& page_0 : m->template GetBatches<Page>(ctx, {})) {
|
||||||
for (auto const& page_1 : m_1->template GetBatches<Page>({})) {
|
for (auto const& page_1 : m_1->template GetBatches<Page>(ctx, {})) {
|
||||||
auto const& cuts_0 = get_cuts(page_0);
|
auto const& cuts_0 = get_cuts(page_0);
|
||||||
auto const& cuts_1 = get_cuts(page_1);
|
auto const& cuts_1 = get_cuts(page_1);
|
||||||
ASSERT_EQ(cuts_0.Values(), cuts_1.Values());
|
ASSERT_EQ(cuts_0.Values(), cuts_1.Values());
|
||||||
@ -32,8 +35,8 @@ void TestRefDMatrix(Cuts&& get_cuts) {
|
|||||||
|
|
||||||
m_1 = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), nullptr, Reset, Next,
|
m_1 = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), nullptr, Reset, Next,
|
||||||
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
||||||
for (auto const& page_0 : m->template GetBatches<Page>({})) {
|
for (auto const& page_0 : m->template GetBatches<Page>(ctx, {})) {
|
||||||
for (auto const& page_1 : m_1->template GetBatches<Page>({})) {
|
for (auto const& page_1 : m_1->template GetBatches<Page>(ctx, {})) {
|
||||||
auto const& cuts_0 = get_cuts(page_0);
|
auto const& cuts_0 = get_cuts(page_0);
|
||||||
auto const& cuts_1 = get_cuts(page_1);
|
auto const& cuts_1 = get_cuts(page_1);
|
||||||
ASSERT_NE(cuts_0.Values(), cuts_1.Values());
|
ASSERT_NE(cuts_0.Values(), cuts_1.Values());
|
||||||
@ -45,8 +48,8 @@ void TestRefDMatrix(Cuts&& get_cuts) {
|
|||||||
auto dm = RandomDataGenerator(2048, Iter::Cols(), 0.5).GenerateDMatrix(true);
|
auto dm = RandomDataGenerator(2048, Iter::Cols(), 0.5).GenerateDMatrix(true);
|
||||||
auto dqm = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), dm, Reset, Next,
|
auto dqm = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), dm, Reset, Next,
|
||||||
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
||||||
for (auto const& page_0 : dm->template GetBatches<Page>({})) {
|
for (auto const& page_0 : dm->template GetBatches<Page>(ctx, {})) {
|
||||||
for (auto const& page_1 : dqm->template GetBatches<Page>({})) {
|
for (auto const& page_1 : dqm->template GetBatches<Page>(ctx, {})) {
|
||||||
auto const& cuts_0 = get_cuts(page_0);
|
auto const& cuts_0 = get_cuts(page_0);
|
||||||
auto const& cuts_1 = get_cuts(page_1);
|
auto const& cuts_1 = get_cuts(page_1);
|
||||||
ASSERT_EQ(cuts_0.Values(), cuts_1.Values());
|
ASSERT_EQ(cuts_0.Values(), cuts_1.Values());
|
||||||
|
|||||||
@ -61,6 +61,7 @@ TEST(SimpleDMatrix, RowAccess) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(SimpleDMatrix, ColAccessWithoutBatches) {
|
TEST(SimpleDMatrix, ColAccessWithoutBatches) {
|
||||||
|
Context ctx;
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||||
CreateSimpleTestData(tmp_file);
|
CreateSimpleTestData(tmp_file);
|
||||||
@ -70,7 +71,7 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
|
|||||||
|
|
||||||
// Loop over the batches and assert the data is as expected
|
// Loop over the batches and assert the data is as expected
|
||||||
int64_t num_col_batch = 0;
|
int64_t num_col_batch = 0;
|
||||||
for (const auto &batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
|
for (const auto &batch : dmat->GetBatches<xgboost::SortedCSCPage>(&ctx)) {
|
||||||
num_col_batch += 1;
|
num_col_batch += 1;
|
||||||
EXPECT_EQ(batch.Size(), dmat->Info().num_col_)
|
EXPECT_EQ(batch.Size(), dmat->Info().num_col_)
|
||||||
<< "Expected batch size = number of cells as #batches is 1.";
|
<< "Expected batch size = number of cells as #batches is 1.";
|
||||||
|
|||||||
@ -23,7 +23,7 @@ std::string UriSVM(std::string name, std::string cache) {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
template <typename Page>
|
template <typename Page>
|
||||||
void TestSparseDMatrixLoadFile() {
|
void TestSparseDMatrixLoadFile(Context const* ctx) {
|
||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
auto opath = tmpdir.path + "/1-based.svm";
|
auto opath = tmpdir.path + "/1-based.svm";
|
||||||
CreateBigTestData(opath, 3 * 64, false);
|
CreateBigTestData(opath, 3 * 64, false);
|
||||||
@ -48,7 +48,7 @@ void TestSparseDMatrixLoadFile() {
|
|||||||
data::SimpleDMatrix simple{&adapter, std::numeric_limits<float>::quiet_NaN(),
|
data::SimpleDMatrix simple{&adapter, std::numeric_limits<float>::quiet_NaN(),
|
||||||
1};
|
1};
|
||||||
Page out;
|
Page out;
|
||||||
for (auto const& page : m.GetBatches<Page>()) {
|
for (auto const &page : m.GetBatches<Page>(ctx)) {
|
||||||
if (std::is_same<Page, SparsePage>::value) {
|
if (std::is_same<Page, SparsePage>::value) {
|
||||||
out.Push(page);
|
out.Push(page);
|
||||||
} else {
|
} else {
|
||||||
@ -58,7 +58,7 @@ void TestSparseDMatrixLoadFile() {
|
|||||||
ASSERT_EQ(m.Info().num_col_, simple.Info().num_col_);
|
ASSERT_EQ(m.Info().num_col_, simple.Info().num_col_);
|
||||||
ASSERT_EQ(m.Info().num_row_, simple.Info().num_row_);
|
ASSERT_EQ(m.Info().num_row_, simple.Info().num_row_);
|
||||||
|
|
||||||
for (auto const& page : simple.GetBatches<Page>()) {
|
for (auto const& page : simple.GetBatches<Page>(ctx)) {
|
||||||
ASSERT_EQ(page.offset.HostVector(), out.offset.HostVector());
|
ASSERT_EQ(page.offset.HostVector(), out.offset.HostVector());
|
||||||
for (size_t i = 0; i < page.data.Size(); ++i) {
|
for (size_t i = 0; i < page.data.Size(); ++i) {
|
||||||
ASSERT_EQ(page.data.HostVector()[i].fvalue, out.data.HostVector()[i].fvalue);
|
ASSERT_EQ(page.data.HostVector()[i].fvalue, out.data.HostVector()[i].fvalue);
|
||||||
@ -67,16 +67,18 @@ void TestSparseDMatrixLoadFile() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, LoadFile) {
|
TEST(SparsePageDMatrix, LoadFile) {
|
||||||
TestSparseDMatrixLoadFile<SparsePage>();
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
TestSparseDMatrixLoadFile<CSCPage>();
|
TestSparseDMatrixLoadFile<SparsePage>(&ctx);
|
||||||
TestSparseDMatrixLoadFile<SortedCSCPage>();
|
TestSparseDMatrixLoadFile<CSCPage>(&ctx);
|
||||||
|
TestSparseDMatrixLoadFile<SortedCSCPage>(&ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
// allow caller to retain pages so they can process multiple pages at the same time.
|
// allow caller to retain pages so they can process multiple pages at the same time.
|
||||||
template <typename Page>
|
template <typename Page>
|
||||||
void TestRetainPage() {
|
void TestRetainPage() {
|
||||||
auto m = CreateSparsePageDMatrix(10000);
|
auto m = CreateSparsePageDMatrix(10000);
|
||||||
auto batches = m->GetBatches<Page>();
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
auto batches = m->GetBatches<Page>(&ctx);
|
||||||
auto begin = batches.begin();
|
auto begin = batches.begin();
|
||||||
auto end = batches.end();
|
auto end = batches.end();
|
||||||
|
|
||||||
@ -100,7 +102,7 @@ void TestRetainPage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// make sure it's const and the caller can not modify the content of page.
|
// make sure it's const and the caller can not modify the content of page.
|
||||||
for (auto& page : m->GetBatches<Page>()) {
|
for (auto &page : m->GetBatches<Page>({&ctx})) {
|
||||||
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
|
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -143,10 +145,11 @@ TEST(SparsePageDMatrix, ColAccess) {
|
|||||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||||
CreateSimpleTestData(tmp_file);
|
CreateSimpleTestData(tmp_file);
|
||||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(UriSVM(tmp_file, tmp_file));
|
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(UriSVM(tmp_file, tmp_file));
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
|
||||||
// Loop over the batches and assert the data is as expected
|
// Loop over the batches and assert the data is as expected
|
||||||
size_t iter = 0;
|
size_t iter = 0;
|
||||||
for (auto const &col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
|
for (auto const &col_batch : dmat->GetBatches<xgboost::SortedCSCPage>(&ctx)) {
|
||||||
auto col_page = col_batch.GetView();
|
auto col_page = col_batch.GetView();
|
||||||
ASSERT_EQ(col_page.Size(), dmat->Info().num_col_);
|
ASSERT_EQ(col_page.Size(), dmat->Info().num_col_);
|
||||||
if (iter == 1) {
|
if (iter == 1) {
|
||||||
@ -164,7 +167,7 @@ TEST(SparsePageDMatrix, ColAccess) {
|
|||||||
|
|
||||||
// Loop over the batches and assert the data is as expected
|
// Loop over the batches and assert the data is as expected
|
||||||
iter = 0;
|
iter = 0;
|
||||||
for (auto const &col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
|
for (auto const &col_batch : dmat->GetBatches<xgboost::CSCPage>(&ctx)) {
|
||||||
auto col_page = col_batch.GetView();
|
auto col_page = col_batch.GetView();
|
||||||
EXPECT_EQ(col_page.Size(), dmat->Info().num_col_);
|
EXPECT_EQ(col_page.Size(), dmat->Info().num_col_);
|
||||||
if (iter == 0) {
|
if (iter == 0) {
|
||||||
@ -182,9 +185,9 @@ TEST(SparsePageDMatrix, ColAccess) {
|
|||||||
TEST(SparsePageDMatrix, ThreadSafetyException) {
|
TEST(SparsePageDMatrix, ThreadSafetyException) {
|
||||||
size_t constexpr kEntriesPerCol = 3;
|
size_t constexpr kEntriesPerCol = 3;
|
||||||
size_t constexpr kEntries = 64 * kEntriesPerCol * 2;
|
size_t constexpr kEntries = 64 * kEntriesPerCol * 2;
|
||||||
|
Context ctx;
|
||||||
|
|
||||||
std::unique_ptr<xgboost::DMatrix> dmat =
|
std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(kEntries);
|
||||||
xgboost::CreateSparsePageDMatrix(kEntries);
|
|
||||||
|
|
||||||
int threads = 1000;
|
int threads = 1000;
|
||||||
|
|
||||||
@ -221,7 +224,8 @@ TEST(SparsePageDMatrix, ColAccessBatches) {
|
|||||||
// Create multiple sparse pages
|
// Create multiple sparse pages
|
||||||
std::unique_ptr<xgboost::DMatrix> dmat{xgboost::CreateSparsePageDMatrix(kEntries)};
|
std::unique_ptr<xgboost::DMatrix> dmat{xgboost::CreateSparsePageDMatrix(kEntries)};
|
||||||
ASSERT_EQ(dmat->Ctx()->Threads(), AllThreadsForTest());
|
ASSERT_EQ(dmat->Ctx()->Threads(), AllThreadsForTest());
|
||||||
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) {
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>(&ctx)) {
|
||||||
ASSERT_EQ(dmat->Info().num_col_, page.Size());
|
ASSERT_EQ(dmat->Info().num_col_, page.Size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,15 +1,20 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2019-2023 by XGBoost Contributors
|
* Copyright 2019-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
#include <xgboost/data.h> // for DMatrix
|
||||||
|
|
||||||
#include "../../../src/common/compressed_iterator.h"
|
#include "../../../src/common/compressed_iterator.h"
|
||||||
#include "../../../src/data/ellpack_page.cuh"
|
#include "../../../src/data/ellpack_page.cuh"
|
||||||
#include "../../../src/data/sparse_page_dmatrix.h"
|
#include "../../../src/data/sparse_page_dmatrix.h"
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../../../src/tree/param.h" // TrainParam
|
||||||
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, EllpackPage) {
|
TEST(SparsePageDMatrix, EllpackPage) {
|
||||||
|
Context ctx{MakeCUDACtx(0)};
|
||||||
|
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||||
CreateSimpleTestData(tmp_file);
|
CreateSimpleTestData(tmp_file);
|
||||||
@ -17,7 +22,7 @@ TEST(SparsePageDMatrix, EllpackPage) {
|
|||||||
|
|
||||||
// Loop over the batches and assert the data is as expected
|
// Loop over the batches and assert the data is as expected
|
||||||
size_t n = 0;
|
size_t n = 0;
|
||||||
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, 256})) {
|
for (const auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
n += batch.Size();
|
n += batch.Size();
|
||||||
}
|
}
|
||||||
EXPECT_EQ(n, dmat->Info().num_row_);
|
EXPECT_EQ(n, dmat->Info().num_row_);
|
||||||
@ -37,6 +42,8 @@ TEST(SparsePageDMatrix, EllpackPage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
||||||
|
Context ctx{MakeCUDACtx(0)};
|
||||||
|
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
std::string filename = tmpdir.path + "/big.libsvm";
|
std::string filename = tmpdir.path + "/big.libsvm";
|
||||||
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
|
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
|
||||||
@ -46,7 +53,7 @@ TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
|||||||
// Loop over the batches and count the records
|
// Loop over the batches and count the records
|
||||||
int64_t batch_count = 0;
|
int64_t batch_count = 0;
|
||||||
int64_t row_count = 0;
|
int64_t row_count = 0;
|
||||||
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, 256})) {
|
for (const auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
||||||
batch_count++;
|
batch_count++;
|
||||||
row_count += batch.Size();
|
row_count += batch.Size();
|
||||||
@ -61,8 +68,11 @@ TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, RetainEllpackPage) {
|
TEST(SparsePageDMatrix, RetainEllpackPage) {
|
||||||
|
Context ctx{MakeCUDACtx(0)};
|
||||||
|
auto param = BatchParam{32, tree::TrainParam::DftSparseThreshold()};
|
||||||
auto m = CreateSparsePageDMatrix(10000);
|
auto m = CreateSparsePageDMatrix(10000);
|
||||||
auto batches = m->GetBatches<EllpackPage>({0, 32});
|
|
||||||
|
auto batches = m->GetBatches<EllpackPage>(&ctx, param);
|
||||||
auto begin = batches.begin();
|
auto begin = batches.begin();
|
||||||
auto end = batches.end();
|
auto end = batches.end();
|
||||||
|
|
||||||
@ -87,7 +97,7 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// make sure it's const and the caller can not modify the content of page.
|
// make sure it's const and the caller can not modify the content of page.
|
||||||
for (auto& page : m->GetBatches<EllpackPage>({0, 32})) {
|
for (auto& page : m->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
|
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,6 +108,7 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, EllpackPageContent) {
|
TEST(SparsePageDMatrix, EllpackPageContent) {
|
||||||
|
auto ctx = CreateEmptyGenericParam(0);
|
||||||
constexpr size_t kRows = 6;
|
constexpr size_t kRows = 6;
|
||||||
constexpr size_t kCols = 2;
|
constexpr size_t kCols = 2;
|
||||||
constexpr size_t kPageSize = 1;
|
constexpr size_t kPageSize = 1;
|
||||||
@ -110,8 +121,8 @@ TEST(SparsePageDMatrix, EllpackPageContent) {
|
|||||||
std::unique_ptr<DMatrix>
|
std::unique_ptr<DMatrix>
|
||||||
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||||
|
|
||||||
BatchParam param{0, 2};
|
auto param = BatchParam{2, tree::TrainParam::DftSparseThreshold()};
|
||||||
auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
auto impl = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||||
EXPECT_EQ(impl->base_rowid, 0);
|
EXPECT_EQ(impl->base_rowid, 0);
|
||||||
EXPECT_EQ(impl->n_rows, kRows);
|
EXPECT_EQ(impl->n_rows, kRows);
|
||||||
EXPECT_FALSE(impl->is_dense);
|
EXPECT_FALSE(impl->is_dense);
|
||||||
@ -120,7 +131,7 @@ TEST(SparsePageDMatrix, EllpackPageContent) {
|
|||||||
|
|
||||||
std::unique_ptr<EllpackPageImpl> impl_ext;
|
std::unique_ptr<EllpackPageImpl> impl_ext;
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& batch : dmat_ext->GetBatches<EllpackPage>(param)) {
|
for (auto& batch : dmat_ext->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
if (!impl_ext) {
|
if (!impl_ext) {
|
||||||
impl_ext.reset(new EllpackPageImpl(
|
impl_ext.reset(new EllpackPageImpl(
|
||||||
batch.Impl()->gidx_buffer.DeviceIdx(), batch.Impl()->Cuts(),
|
batch.Impl()->gidx_buffer.DeviceIdx(), batch.Impl()->Cuts(),
|
||||||
@ -170,8 +181,9 @@ TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
|
|||||||
std::unique_ptr<DMatrix>
|
std::unique_ptr<DMatrix>
|
||||||
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||||
|
|
||||||
BatchParam param{0, kMaxBins};
|
Context ctx{MakeCUDACtx(0)};
|
||||||
auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
auto impl = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||||
EXPECT_EQ(impl->base_rowid, 0);
|
EXPECT_EQ(impl->base_rowid, 0);
|
||||||
EXPECT_EQ(impl->n_rows, kRows);
|
EXPECT_EQ(impl->n_rows, kRows);
|
||||||
|
|
||||||
@ -180,7 +192,7 @@ TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
|
|||||||
thrust::device_vector<bst_float> row_ext_d(kCols);
|
thrust::device_vector<bst_float> row_ext_d(kCols);
|
||||||
std::vector<bst_float> row(kCols);
|
std::vector<bst_float> row(kCols);
|
||||||
std::vector<bst_float> row_ext(kCols);
|
std::vector<bst_float> row_ext(kCols);
|
||||||
for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
|
for (auto& page : dmat_ext->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
auto impl_ext = page.Impl();
|
auto impl_ext = page.Impl();
|
||||||
EXPECT_EQ(impl_ext->base_rowid, current_row);
|
EXPECT_EQ(impl_ext->base_rowid, current_row);
|
||||||
|
|
||||||
@ -211,10 +223,11 @@ TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
|
|||||||
std::unique_ptr<DMatrix>
|
std::unique_ptr<DMatrix>
|
||||||
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||||
|
|
||||||
BatchParam param{0, kMaxBins};
|
Context ctx{MakeCUDACtx(0)};
|
||||||
|
auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
|
||||||
size_t current_row = 0;
|
size_t current_row = 0;
|
||||||
for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
|
for (auto& page : dmat_ext->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
auto impl_ext = page.Impl();
|
auto impl_ext = page.Impl();
|
||||||
EXPECT_EQ(impl_ext->base_rowid, current_row);
|
EXPECT_EQ(impl_ext->base_rowid, current_row);
|
||||||
current_row += impl_ext->n_rows;
|
current_row += impl_ext->n_rows;
|
||||||
|
|||||||
@ -1,17 +1,24 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021 XGBoost contributors
|
* Copyright 2021-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h> // for CSCPage, SortedCSCPage, SparsePage
|
||||||
|
|
||||||
#include "../../../src/data/sparse_page_source.h"
|
#include <memory> // for allocator, unique_ptr, __shared_ptr_ac...
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include <string> // for char_traits, operator+, basic_string
|
||||||
#include "../helpers.h"
|
|
||||||
|
#include "../../../src/data/sparse_page_writer.h" // for CreatePageFormat
|
||||||
|
#include "../helpers.h" // for RandomDataGenerator
|
||||||
|
#include "dmlc/filesystem.h" // for TemporaryDirectory
|
||||||
|
#include "dmlc/io.h" // for SeekStream, Stream
|
||||||
|
#include "gtest/gtest_pred_impl.h" // for Test, AssertionResult, ASSERT_EQ, TEST
|
||||||
|
#include "xgboost/context.h" // for Context
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
template <typename S> void TestSparsePageRawFormat() {
|
template <typename S> void TestSparsePageRawFormat() {
|
||||||
std::unique_ptr<SparsePageFormat<S>> format{CreatePageFormat<S>("raw")};
|
std::unique_ptr<SparsePageFormat<S>> format{CreatePageFormat<S>("raw")};
|
||||||
|
Context ctx;
|
||||||
|
|
||||||
auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
|
auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
|
||||||
ASSERT_TRUE(m->SingleColBlock());
|
ASSERT_TRUE(m->SingleColBlock());
|
||||||
@ -21,7 +28,7 @@ template <typename S> void TestSparsePageRawFormat() {
|
|||||||
{
|
{
|
||||||
// block code to flush the stream
|
// block code to flush the stream
|
||||||
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
|
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
|
||||||
for (auto const &page : m->GetBatches<S>()) {
|
for (auto const &page : m->GetBatches<S>(&ctx)) {
|
||||||
orig.Push(page);
|
orig.Push(page);
|
||||||
format->Write(page, fo.get());
|
format->Write(page, fo.get());
|
||||||
}
|
}
|
||||||
|
|||||||
@ -388,6 +388,11 @@ inline Context CreateEmptyGenericParam(int gpu_id) {
|
|||||||
return tparam;
|
return tparam;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Make a context that uses CUDA.
|
||||||
|
*/
|
||||||
|
inline Context MakeCUDACtx(std::int32_t device) { return Context{}.MakeCUDA(device); }
|
||||||
|
|
||||||
inline HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows,
|
inline HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows,
|
||||||
float lower= 0.0f, float upper = 1.0f) {
|
float lower= 0.0f, float upper = 1.0f) {
|
||||||
xgboost::SimpleLCG gen;
|
xgboost::SimpleLCG gen;
|
||||||
|
|||||||
@ -203,7 +203,11 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
|
|||||||
learner->Save(&mem_out);
|
learner->Save(&mem_out);
|
||||||
ASSERT_EQ(model_at_kiter, serialised_model_tmp);
|
ASSERT_EQ(model_at_kiter, serialised_model_tmp);
|
||||||
|
|
||||||
learner->SetParam("gpu_id", "0");
|
for (auto const& [key, value] : args) {
|
||||||
|
if (key == "tree_method" && value == "gpu_hist") {
|
||||||
|
learner->SetParam("gpu_id", "0");
|
||||||
|
}
|
||||||
|
}
|
||||||
// Pull data to device
|
// Pull data to device
|
||||||
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
|
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
|
||||||
batch.data.SetDevice(0);
|
batch.data.SetDevice(0);
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2020-2021 by XGBoost Contributors
|
* Copyright 2020-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include "../../../../src/data/ellpack_page.cuh"
|
#include "../../../../src/data/ellpack_page.cuh"
|
||||||
#include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh"
|
#include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh"
|
||||||
#include "../../../../src/tree/param.h"
|
#include "../../../../src/tree/param.h"
|
||||||
#include "../../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../../../../src/tree/param.h" // TrainParam
|
||||||
|
#include "../../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../../helpers.h"
|
#include "../../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -31,14 +32,15 @@ void VerifySampling(size_t page_size,
|
|||||||
}
|
}
|
||||||
gpair.SetDevice(0);
|
gpair.SetDevice(0);
|
||||||
|
|
||||||
BatchParam param{0, 256};
|
Context ctx{MakeCUDACtx(0)};
|
||||||
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||||
if (page_size != 0) {
|
if (page_size != 0) {
|
||||||
EXPECT_NE(page->n_rows, kRows);
|
EXPECT_NE(page->n_rows, kRows);
|
||||||
}
|
}
|
||||||
|
|
||||||
GradientBasedSampler sampler(page, kRows, param, subsample, sampling_method);
|
GradientBasedSampler sampler(&ctx, page, kRows, param, subsample, sampling_method);
|
||||||
auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get());
|
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
|
||||||
|
|
||||||
if (fixed_size_sampling) {
|
if (fixed_size_sampling) {
|
||||||
EXPECT_EQ(sample.sample_rows, kRows);
|
EXPECT_EQ(sample.sample_rows, kRows);
|
||||||
@ -86,12 +88,13 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
|
|||||||
auto gpair = GenerateRandomGradients(kRows);
|
auto gpair = GenerateRandomGradients(kRows);
|
||||||
gpair.SetDevice(0);
|
gpair.SetDevice(0);
|
||||||
|
|
||||||
BatchParam param{0, 256};
|
Context ctx{MakeCUDACtx(0)};
|
||||||
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||||
EXPECT_NE(page->n_rows, kRows);
|
EXPECT_NE(page->n_rows, kRows);
|
||||||
|
|
||||||
GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform);
|
GradientBasedSampler sampler(&ctx, page, kRows, param, kSubsample, TrainParam::kUniform);
|
||||||
auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get());
|
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
|
||||||
auto sampled_page = sample.page;
|
auto sampled_page = sample.page;
|
||||||
EXPECT_EQ(sample.sample_rows, kRows);
|
EXPECT_EQ(sample.sample_rows, kRows);
|
||||||
EXPECT_EQ(sample.gpair.size(), gpair.Size());
|
EXPECT_EQ(sample.gpair.size(), gpair.Size());
|
||||||
@ -103,7 +106,7 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
|
|||||||
ci(buffer.data(), sampled_page->NumSymbols());
|
ci(buffer.data(), sampled_page->NumSymbols());
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& batch : dmat->GetBatches<EllpackPage>(param)) {
|
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
auto page = batch.Impl();
|
auto page = batch.Impl();
|
||||||
std::vector<common::CompressedByteT> page_buffer(page->gidx_buffer.HostVector());
|
std::vector<common::CompressedByteT> page_buffer(page->gidx_buffer.HostVector());
|
||||||
common::CompressedIterator<common::CompressedByteT>
|
common::CompressedIterator<common::CompressedByteT>
|
||||||
|
|||||||
@ -1,9 +1,14 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020-2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../../../../src/common/categorical.h"
|
#include "../../../../src/common/categorical.h"
|
||||||
#include "../../../../src/tree/gpu_hist/histogram.cuh"
|
#include "../../../../src/tree/gpu_hist/histogram.cuh"
|
||||||
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
|
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
|
||||||
|
#include "../../../../src/tree/param.h" // TrainParam
|
||||||
#include "../../categorical_helpers.h"
|
#include "../../categorical_helpers.h"
|
||||||
#include "../../helpers.h"
|
#include "../../helpers.h"
|
||||||
|
|
||||||
@ -11,15 +16,15 @@ namespace xgboost {
|
|||||||
namespace tree {
|
namespace tree {
|
||||||
|
|
||||||
void TestDeterministicHistogram(bool is_dense, int shm_size) {
|
void TestDeterministicHistogram(bool is_dense, int shm_size) {
|
||||||
Context ctx = CreateEmptyGenericParam(0);
|
Context ctx = MakeCUDACtx(0);
|
||||||
size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
|
size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
|
||||||
float constexpr kLower = -1e-2, kUpper = 1e2;
|
float constexpr kLower = -1e-2, kUpper = 1e2;
|
||||||
|
|
||||||
float sparsity = is_dense ? 0.0f : 0.5f;
|
float sparsity = is_dense ? 0.0f : 0.5f;
|
||||||
auto matrix = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix();
|
auto matrix = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix();
|
||||||
BatchParam batch_param{0, static_cast<int32_t>(kBins)};
|
auto batch_param = BatchParam{kBins, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
|
||||||
for (auto const& batch : matrix->GetBatches<EllpackPage>(batch_param)) {
|
for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||||
auto* page = batch.Impl();
|
auto* page = batch.Impl();
|
||||||
|
|
||||||
tree::RowPartitioner row_partitioner(0, kRows);
|
tree::RowPartitioner row_partitioner(0, kRows);
|
||||||
@ -114,13 +119,13 @@ void ValidateCategoricalHistogram(size_t n_categories, common::Span<GradientPair
|
|||||||
|
|
||||||
// Test 1 vs rest categorical histogram is equivalent to one hot encoded data.
|
// Test 1 vs rest categorical histogram is equivalent to one hot encoded data.
|
||||||
void TestGPUHistogramCategorical(size_t num_categories) {
|
void TestGPUHistogramCategorical(size_t num_categories) {
|
||||||
auto ctx = CreateEmptyGenericParam(0);
|
auto ctx = MakeCUDACtx(0);
|
||||||
size_t constexpr kRows = 340;
|
size_t constexpr kRows = 340;
|
||||||
size_t constexpr kBins = 256;
|
size_t constexpr kBins = 256;
|
||||||
auto x = GenerateRandomCategoricalSingleColumn(kRows, num_categories);
|
auto x = GenerateRandomCategoricalSingleColumn(kRows, num_categories);
|
||||||
auto cat_m = GetDMatrixFromData(x, kRows, 1);
|
auto cat_m = GetDMatrixFromData(x, kRows, 1);
|
||||||
cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||||
BatchParam batch_param{0, static_cast<int32_t>(kBins)};
|
auto batch_param = BatchParam{kBins, tree::TrainParam::DftSparseThreshold()};
|
||||||
tree::RowPartitioner row_partitioner(0, kRows);
|
tree::RowPartitioner row_partitioner(0, kRows);
|
||||||
auto ridx = row_partitioner.GetRows(0);
|
auto ridx = row_partitioner.GetRows(0);
|
||||||
dh::device_vector<GradientPairInt64> cat_hist(num_categories);
|
dh::device_vector<GradientPairInt64> cat_hist(num_categories);
|
||||||
@ -130,7 +135,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
|||||||
/**
|
/**
|
||||||
* Generate hist with cat data.
|
* Generate hist with cat data.
|
||||||
*/
|
*/
|
||||||
for (auto const &batch : cat_m->GetBatches<EllpackPage>(batch_param)) {
|
for (auto const &batch : cat_m->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||||
auto* page = batch.Impl();
|
auto* page = batch.Impl();
|
||||||
FeatureGroups single_group(page->Cuts());
|
FeatureGroups single_group(page->Cuts());
|
||||||
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
|
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
|
||||||
@ -144,7 +149,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
|||||||
auto x_encoded = OneHotEncodeFeature(x, num_categories);
|
auto x_encoded = OneHotEncodeFeature(x, num_categories);
|
||||||
auto encode_m = GetDMatrixFromData(x_encoded, kRows, num_categories);
|
auto encode_m = GetDMatrixFromData(x_encoded, kRows, num_categories);
|
||||||
dh::device_vector<GradientPairInt64> encode_hist(2 * num_categories);
|
dh::device_vector<GradientPairInt64> encode_hist(2 * num_categories);
|
||||||
for (auto const &batch : encode_m->GetBatches<EllpackPage>(batch_param)) {
|
for (auto const &batch : encode_m->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||||
auto* page = batch.Impl();
|
auto* page = batch.Impl();
|
||||||
FeatureGroups single_group(page->Cuts());
|
FeatureGroups single_group(page->Cuts());
|
||||||
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
|
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
|
||||||
|
|||||||
@ -41,7 +41,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
|
|||||||
|
|
||||||
size_t constexpr kMaxBins = 4;
|
size_t constexpr kMaxBins = 4;
|
||||||
// dense, no missing values
|
// dense, no missing values
|
||||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins, 0.5, false, AllThreadsForTest());
|
GHistIndexMatrix gmat(&ctx, dmat.get(), kMaxBins, 0.5, false);
|
||||||
common::RowSetCollection row_set_collection;
|
common::RowSetCollection row_set_collection;
|
||||||
std::vector<size_t> &row_indices = *row_set_collection.Data();
|
std::vector<size_t> &row_indices = *row_set_collection.Data();
|
||||||
row_indices.resize(kRows);
|
row_indices.resize(kRows);
|
||||||
@ -228,7 +228,7 @@ auto CompareOneHotAndPartition(bool onehot) {
|
|||||||
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, &param, dmat->Info(), sampler};
|
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, &param, dmat->Info(), sampler};
|
||||||
std::vector<CPUExpandEntry> entries(1);
|
std::vector<CPUExpandEntry> entries(1);
|
||||||
|
|
||||||
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
|
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>(&ctx, {32, param.sparse_threshold})) {
|
||||||
common::HistCollection hist;
|
common::HistCollection hist;
|
||||||
|
|
||||||
entries.front().nid = 0;
|
entries.front().nid = 0;
|
||||||
|
|||||||
@ -25,6 +25,7 @@ void InitRowPartitionForTest(common::RowSetCollection *row_set, size_t n_samples
|
|||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
void TestAddHistRows(bool is_distributed) {
|
void TestAddHistRows(bool is_distributed) {
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
|
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
|
||||||
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
|
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
|
||||||
int starting_index = std::numeric_limits<int>::max();
|
int starting_index = std::numeric_limits<int>::max();
|
||||||
@ -32,9 +33,9 @@ void TestAddHistRows(bool is_distributed) {
|
|||||||
|
|
||||||
size_t constexpr kNRows = 8, kNCols = 16;
|
size_t constexpr kNRows = 8, kNCols = 16;
|
||||||
int32_t constexpr kMaxBins = 4;
|
int32_t constexpr kMaxBins = 4;
|
||||||
auto p_fmat =
|
auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||||
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
auto const &gmat =
|
||||||
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
|
*(p_fmat->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{kMaxBins, 0.5}).begin());
|
||||||
|
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
|
|
||||||
@ -73,6 +74,7 @@ TEST(CPUHistogram, AddRows) {
|
|||||||
void TestSyncHist(bool is_distributed) {
|
void TestSyncHist(bool is_distributed) {
|
||||||
size_t constexpr kNRows = 8, kNCols = 16;
|
size_t constexpr kNRows = 8, kNCols = 16;
|
||||||
int32_t constexpr kMaxBins = 4;
|
int32_t constexpr kMaxBins = 4;
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
|
||||||
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
|
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
|
||||||
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
|
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
|
||||||
@ -80,9 +82,9 @@ void TestSyncHist(bool is_distributed) {
|
|||||||
int sync_count = 0;
|
int sync_count = 0;
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
|
|
||||||
auto p_fmat =
|
auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||||
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
auto const &gmat =
|
||||||
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
|
*(p_fmat->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{kMaxBins, 0.5}).begin());
|
||||||
|
|
||||||
HistogramBuilder<CPUExpandEntry> histogram;
|
HistogramBuilder<CPUExpandEntry> histogram;
|
||||||
uint32_t total_bins = gmat.cut.Ptrs().back();
|
uint32_t total_bins = gmat.cut.Ptrs().back();
|
||||||
@ -227,12 +229,15 @@ TEST(CPUHistogram, SyncHist) {
|
|||||||
void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_col_split) {
|
void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_col_split) {
|
||||||
size_t constexpr kNRows = 8, kNCols = 16;
|
size_t constexpr kNRows = 8, kNCols = 16;
|
||||||
int32_t constexpr kMaxBins = 4;
|
int32_t constexpr kMaxBins = 4;
|
||||||
auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
auto p_fmat =
|
||||||
|
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||||
if (is_col_split) {
|
if (is_col_split) {
|
||||||
p_fmat = std::shared_ptr<DMatrix>{
|
p_fmat = std::shared_ptr<DMatrix>{
|
||||||
p_fmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
p_fmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
}
|
}
|
||||||
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
|
auto const &gmat =
|
||||||
|
*(p_fmat->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{kMaxBins, 0.5}).begin());
|
||||||
uint32_t total_bins = gmat.cut.Ptrs().back();
|
uint32_t total_bins = gmat.cut.Ptrs().back();
|
||||||
|
|
||||||
static double constexpr kEps = 1e-6;
|
static double constexpr kEps = 1e-6;
|
||||||
@ -257,9 +262,9 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
|
|||||||
CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)};
|
CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)};
|
||||||
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
|
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
|
||||||
nodes_for_explicit_hist_build.push_back(node);
|
nodes_for_explicit_hist_build.push_back(node);
|
||||||
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>({kMaxBins, 0.5})) {
|
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, {kMaxBins, 0.5})) {
|
||||||
histogram.BuildHist(0, gidx, &tree, row_set_collection,
|
histogram.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
|
||||||
nodes_for_explicit_hist_build, {}, gpair, force_read_by_column);
|
gpair, force_read_by_column);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if number of histogram bins is correct
|
// Check if number of histogram bins is correct
|
||||||
@ -325,6 +330,8 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
|||||||
auto x = GenerateRandomCategoricalSingleColumn(kRows, n_categories);
|
auto x = GenerateRandomCategoricalSingleColumn(kRows, n_categories);
|
||||||
auto cat_m = GetDMatrixFromData(x, kRows, 1);
|
auto cat_m = GetDMatrixFromData(x, kRows, 1);
|
||||||
cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||||
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
|
|
||||||
BatchParam batch_param{0, static_cast<int32_t>(kBins)};
|
BatchParam batch_param{0, static_cast<int32_t>(kBins)};
|
||||||
|
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
@ -345,12 +352,11 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
|||||||
* Generate hist with cat data.
|
* Generate hist with cat data.
|
||||||
*/
|
*/
|
||||||
HistogramBuilder<CPUExpandEntry> cat_hist;
|
HistogramBuilder<CPUExpandEntry> cat_hist;
|
||||||
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
|
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>(&ctx, {kBins, 0.5})) {
|
||||||
auto total_bins = gidx.cut.TotalBins();
|
auto total_bins = gidx.cut.TotalBins();
|
||||||
cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
|
cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
|
||||||
cat_hist.BuildHist(0, gidx, &tree, row_set_collection,
|
cat_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
|
||||||
nodes_for_explicit_hist_build, {}, gpair.HostVector(),
|
gpair.HostVector(), force_read_by_column);
|
||||||
force_read_by_column);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -359,12 +365,11 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
|||||||
auto x_encoded = OneHotEncodeFeature(x, n_categories);
|
auto x_encoded = OneHotEncodeFeature(x, n_categories);
|
||||||
auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories);
|
auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories);
|
||||||
HistogramBuilder<CPUExpandEntry> onehot_hist;
|
HistogramBuilder<CPUExpandEntry> onehot_hist;
|
||||||
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
|
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>(&ctx, {kBins, 0.5})) {
|
||||||
auto total_bins = gidx.cut.TotalBins();
|
auto total_bins = gidx.cut.TotalBins();
|
||||||
onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
|
onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
|
||||||
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
|
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
|
||||||
gpair.HostVector(),
|
gpair.HostVector(), force_read_by_column);
|
||||||
force_read_by_column);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto cat = cat_hist.Histogram()[0];
|
auto cat = cat_hist.Histogram()[0];
|
||||||
@ -382,8 +387,8 @@ TEST(CPUHistogram, Categorical) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
namespace {
|
namespace {
|
||||||
void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool force_read_by_column) {
|
void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, bool is_approx,
|
||||||
Context ctx;
|
bool force_read_by_column) {
|
||||||
size_t constexpr kEntries = 1 << 16;
|
size_t constexpr kEntries = 1 << 16;
|
||||||
auto m = CreateSparsePageDMatrix(kEntries, "cache");
|
auto m = CreateSparsePageDMatrix(kEntries, "cache");
|
||||||
|
|
||||||
@ -410,7 +415,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
|||||||
* Multi page
|
* Multi page
|
||||||
*/
|
*/
|
||||||
std::vector<common::RowSetCollection> rows_set;
|
std::vector<common::RowSetCollection> rows_set;
|
||||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
|
for (auto const &page : m->GetBatches<GHistIndexMatrix>(ctx, batch_param)) {
|
||||||
CHECK_LT(page.base_rowid, m->Info().num_row_);
|
CHECK_LT(page.base_rowid, m->Info().num_row_);
|
||||||
auto n_rows_in_node = page.Size();
|
auto n_rows_in_node = page.Size();
|
||||||
partition_size[0] = std::max(partition_size[0], n_rows_in_node);
|
partition_size[0] = std::max(partition_size[0], n_rows_in_node);
|
||||||
@ -426,12 +431,12 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
|||||||
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
|
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
|
||||||
256};
|
256};
|
||||||
|
|
||||||
multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false, false);
|
multi_build.Reset(total_bins, batch_param, ctx->Threads(), rows_set.size(), false, false);
|
||||||
|
|
||||||
size_t page_idx{0};
|
size_t page_idx{0};
|
||||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
|
for (auto const &page : m->GetBatches<GHistIndexMatrix>(ctx, batch_param)) {
|
||||||
multi_build.BuildHist(page_idx, space, page, &tree, rows_set.at(page_idx), nodes, {},
|
multi_build.BuildHist(page_idx, space, page, &tree, rows_set.at(page_idx), nodes, {}, h_gpair,
|
||||||
h_gpair, force_read_by_column);
|
force_read_by_column);
|
||||||
++page_idx;
|
++page_idx;
|
||||||
}
|
}
|
||||||
ASSERT_EQ(page_idx, 2);
|
ASSERT_EQ(page_idx, 2);
|
||||||
@ -447,16 +452,16 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
|||||||
common::RowSetCollection row_set_collection;
|
common::RowSetCollection row_set_collection;
|
||||||
InitRowPartitionForTest(&row_set_collection, n_samples);
|
InitRowPartitionForTest(&row_set_collection, n_samples);
|
||||||
|
|
||||||
single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false, false);
|
single_build.Reset(total_bins, batch_param, ctx->Threads(), 1, false, false);
|
||||||
SparsePage concat;
|
SparsePage concat;
|
||||||
std::vector<float> hess(m->Info().num_row_, 1.0f);
|
std::vector<float> hess(m->Info().num_row_, 1.0f);
|
||||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||||
concat.Push(page);
|
concat.Push(page);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, ctx.Threads(), false, hess);
|
auto cut = common::SketchOnDMatrix(ctx, m.get(), batch_param.max_bin, false, hess);
|
||||||
GHistIndexMatrix gmat(concat, {}, cut, batch_param.max_bin, false,
|
GHistIndexMatrix gmat(concat, {}, cut, batch_param.max_bin, false,
|
||||||
std::numeric_limits<double>::quiet_NaN(), ctx.Threads());
|
std::numeric_limits<double>::quiet_NaN(), ctx->Threads());
|
||||||
single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair, force_read_by_column);
|
single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair, force_read_by_column);
|
||||||
single_page = single_build.Histogram()[0];
|
single_page = single_build.Histogram()[0];
|
||||||
}
|
}
|
||||||
@ -470,16 +475,17 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
|||||||
|
|
||||||
TEST(CPUHistogram, ExternalMemory) {
|
TEST(CPUHistogram, ExternalMemory) {
|
||||||
int32_t constexpr kBins = 256;
|
int32_t constexpr kBins = 256;
|
||||||
TestHistogramExternalMemory(BatchParam{kBins, common::Span<float>{}, false}, true, false);
|
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||||
TestHistogramExternalMemory(BatchParam{kBins, common::Span<float>{}, false}, true, true);
|
|
||||||
|
TestHistogramExternalMemory(&ctx, BatchParam{kBins, common::Span<float>{}, false}, true, false);
|
||||||
|
TestHistogramExternalMemory(&ctx, BatchParam{kBins, common::Span<float>{}, false}, true, true);
|
||||||
|
|
||||||
float sparse_thresh{0.5};
|
float sparse_thresh{0.5};
|
||||||
TestHistogramExternalMemory({kBins, sparse_thresh}, false, false);
|
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, false);
|
||||||
TestHistogramExternalMemory({kBins, sparse_thresh}, false, true);
|
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true);
|
||||||
sparse_thresh = std::numeric_limits<float>::quiet_NaN();
|
sparse_thresh = std::numeric_limits<float>::quiet_NaN();
|
||||||
TestHistogramExternalMemory({kBins, sparse_thresh}, false, false);
|
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, false);
|
||||||
TestHistogramExternalMemory({kBins, sparse_thresh}, false, true);
|
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true);
|
||||||
|
|
||||||
}
|
}
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -34,7 +34,7 @@ TEST(Approx, Partitioner) {
|
|||||||
std::vector<CPUExpandEntry> candidates{{0, 0}};
|
std::vector<CPUExpandEntry> candidates{{0, 0}};
|
||||||
candidates.front().split.loss_chg = 0.4;
|
candidates.front().split.loss_chg = 0.4;
|
||||||
|
|
||||||
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, hess, true})) {
|
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(&ctx, {64, hess, true})) {
|
||||||
bst_feature_t const split_ind = 0;
|
bst_feature_t const split_ind = 0;
|
||||||
{
|
{
|
||||||
auto min_value = page.cut.MinValues()[split_ind];
|
auto min_value = page.cut.MinValues()[split_ind];
|
||||||
@ -84,7 +84,7 @@ void TestColumnSplitPartitioner(size_t n_samples, size_t base_rowid, std::shared
|
|||||||
|
|
||||||
Context ctx;
|
Context ctx;
|
||||||
ctx.InitAllowUnknown(Args{});
|
ctx.InitAllowUnknown(Args{});
|
||||||
for (auto const& page : dmat->GetBatches<GHistIndexMatrix>({64, *hess, true})) {
|
for (auto const& page : dmat->GetBatches<GHistIndexMatrix>(&ctx, {64, *hess, true})) {
|
||||||
{
|
{
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, true};
|
CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, true};
|
||||||
@ -133,7 +133,7 @@ TEST(Approx, PartitionerColSplit) {
|
|||||||
Context ctx;
|
Context ctx;
|
||||||
ctx.InitAllowUnknown(Args{});
|
ctx.InitAllowUnknown(Args{});
|
||||||
CommonRowPartitioner mid_partitioner{&ctx, n_samples, base_rowid, false};
|
CommonRowPartitioner mid_partitioner{&ctx, n_samples, base_rowid, false};
|
||||||
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, hess, true})) {
|
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(&ctx, {64, hess, true})) {
|
||||||
bst_feature_t const split_ind = 0;
|
bst_feature_t const split_ind = 0;
|
||||||
min_value = page.cut.MinValues()[split_ind];
|
min_value = page.cut.MinValues()[split_ind];
|
||||||
|
|
||||||
|
|||||||
@ -43,7 +43,7 @@ void TestLeafPartition(size_t n_samples) {
|
|||||||
|
|
||||||
std::vector<size_t> h_nptr;
|
std::vector<size_t> h_nptr;
|
||||||
float split_value{0};
|
float split_value{0};
|
||||||
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({Context::kCpuId, 64})) {
|
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{64, 0.2})) {
|
||||||
bst_feature_t const split_ind = 0;
|
bst_feature_t const split_ind = 0;
|
||||||
auto ptr = page.cut.Ptrs()[split_ind + 1];
|
auto ptr = page.cut.Ptrs()[split_ind + 1];
|
||||||
split_value = page.cut.Values().at(ptr / 2);
|
split_value = page.cut.Values().at(ptr / 2);
|
||||||
|
|||||||
@ -208,17 +208,16 @@ TEST(GpuHist, TestHistogramIndex) {
|
|||||||
TestHistogramIndexImpl();
|
TestHistogramIndexImpl();
|
||||||
}
|
}
|
||||||
|
|
||||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
void UpdateTree(Context const* ctx, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||||
size_t gpu_page_size, RegTree* tree,
|
size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds,
|
||||||
HostDeviceVector<bst_float>* preds, float subsample = 1.0f,
|
float subsample = 1.0f, const std::string& sampling_method = "uniform",
|
||||||
const std::string& sampling_method = "uniform",
|
|
||||||
int max_bin = 2) {
|
int max_bin = 2) {
|
||||||
|
|
||||||
if (gpu_page_size > 0) {
|
if (gpu_page_size > 0) {
|
||||||
// Loop over the batches and count the records
|
// Loop over the batches and count the records
|
||||||
int64_t batch_count = 0;
|
int64_t batch_count = 0;
|
||||||
int64_t row_count = 0;
|
int64_t row_count = 0;
|
||||||
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, max_bin})) {
|
for (const auto& batch : dmat->GetBatches<EllpackPage>(
|
||||||
|
ctx, BatchParam{max_bin, TrainParam::DftSparseThreshold()})) {
|
||||||
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
||||||
batch_count++;
|
batch_count++;
|
||||||
row_count += batch.Size();
|
row_count += batch.Size();
|
||||||
@ -239,14 +238,13 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
|||||||
TrainParam param;
|
TrainParam param;
|
||||||
param.UpdateAllowUnknown(args);
|
param.UpdateAllowUnknown(args);
|
||||||
|
|
||||||
Context ctx(CreateEmptyGenericParam(0));
|
|
||||||
ObjInfo task{ObjInfo::kRegression};
|
ObjInfo task{ObjInfo::kRegression};
|
||||||
tree::GPUHistMaker hist_maker{&ctx, &task};
|
tree::GPUHistMaker hist_maker{ctx, &task};
|
||||||
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
hist_maker.Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
hist_maker.Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
||||||
{tree});
|
{tree});
|
||||||
auto cache = linalg::MakeTensorView(&ctx, preds->DeviceSpan(), preds->Size(), 1);
|
auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
|
||||||
hist_maker.UpdatePredictionCache(dmat, cache);
|
hist_maker.UpdatePredictionCache(dmat, cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -264,12 +262,13 @@ TEST(GpuHist, UniformSampling) {
|
|||||||
// Build a tree using the in-memory DMatrix.
|
// Build a tree using the in-memory DMatrix.
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
||||||
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
Context ctx(CreateEmptyGenericParam(0));
|
||||||
|
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||||
// Build another tree using sampling.
|
// Build another tree using sampling.
|
||||||
RegTree tree_sampling;
|
RegTree tree_sampling;
|
||||||
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
|
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
|
||||||
UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
|
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, "uniform",
|
||||||
"uniform", kRows);
|
kRows);
|
||||||
|
|
||||||
// Make sure the predictions are the same.
|
// Make sure the predictions are the same.
|
||||||
auto preds_h = preds.ConstHostVector();
|
auto preds_h = preds.ConstHostVector();
|
||||||
@ -293,12 +292,13 @@ TEST(GpuHist, GradientBasedSampling) {
|
|||||||
// Build a tree using the in-memory DMatrix.
|
// Build a tree using the in-memory DMatrix.
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
||||||
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
Context ctx(CreateEmptyGenericParam(0));
|
||||||
|
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||||
|
|
||||||
// Build another tree using sampling.
|
// Build another tree using sampling.
|
||||||
RegTree tree_sampling;
|
RegTree tree_sampling;
|
||||||
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
|
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
|
||||||
UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
|
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
|
||||||
"gradient_based", kRows);
|
"gradient_based", kRows);
|
||||||
|
|
||||||
// Make sure the predictions are the same.
|
// Make sure the predictions are the same.
|
||||||
@ -327,12 +327,13 @@ TEST(GpuHist, ExternalMemory) {
|
|||||||
|
|
||||||
// Build a tree using the in-memory DMatrix.
|
// Build a tree using the in-memory DMatrix.
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
|
Context ctx(CreateEmptyGenericParam(0));
|
||||||
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
||||||
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||||
// Build another tree using multiple ELLPACK pages.
|
// Build another tree using multiple ELLPACK pages.
|
||||||
RegTree tree_ext;
|
RegTree tree_ext;
|
||||||
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
|
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
|
||||||
UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
|
UpdateTree(&ctx, &gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
|
||||||
|
|
||||||
// Make sure the predictions are the same.
|
// Make sure the predictions are the same.
|
||||||
auto preds_h = preds.ConstHostVector();
|
auto preds_h = preds.ConstHostVector();
|
||||||
@ -364,17 +365,17 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
|
|||||||
// Build a tree using the in-memory DMatrix.
|
// Build a tree using the in-memory DMatrix.
|
||||||
auto rng = common::GlobalRandom();
|
auto rng = common::GlobalRandom();
|
||||||
|
|
||||||
|
Context ctx(CreateEmptyGenericParam(0));
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
||||||
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod,
|
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod, kRows);
|
||||||
kRows);
|
|
||||||
|
|
||||||
// Build another tree using multiple ELLPACK pages.
|
// Build another tree using multiple ELLPACK pages.
|
||||||
common::GlobalRandom() = rng;
|
common::GlobalRandom() = rng;
|
||||||
RegTree tree_ext;
|
RegTree tree_ext;
|
||||||
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
|
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
|
||||||
UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext,
|
UpdateTree(&ctx, &gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample,
|
||||||
kSubsample, kSamplingMethod, kRows);
|
kSamplingMethod, kRows);
|
||||||
|
|
||||||
// Make sure the predictions are the same.
|
// Make sure the predictions are the same.
|
||||||
auto preds_h = preds.ConstHostVector();
|
auto preds_h = preds.ConstHostVector();
|
||||||
|
|||||||
@ -36,7 +36,7 @@ void TestPartitioner(bst_target_t n_targets) {
|
|||||||
std::vector<ExpandEntry> candidates{{0, 0}};
|
std::vector<ExpandEntry> candidates{{0, 0}};
|
||||||
candidates.front().split.loss_chg = 0.4;
|
candidates.front().split.loss_chg = 0.4;
|
||||||
|
|
||||||
auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads());
|
auto cuts = common::SketchOnDMatrix(&ctx, Xy.get(), 64);
|
||||||
|
|
||||||
for (auto const& page : Xy->GetBatches<SparsePage>()) {
|
for (auto const& page : Xy->GetBatches<SparsePage>()) {
|
||||||
GHistIndexMatrix gmat(page, {}, cuts, 64, true, 0.5, ctx.Threads());
|
GHistIndexMatrix gmat(page, {}, cuts, 64, true, 0.5, ctx.Threads());
|
||||||
|
|||||||
@ -15,16 +15,17 @@ class DMatrixForTest : public data::SimpleDMatrix {
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
using SimpleDMatrix::SimpleDMatrix;
|
using SimpleDMatrix::SimpleDMatrix;
|
||||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override {
|
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
|
||||||
|
const BatchParam& param) override {
|
||||||
auto backup = this->gradient_index_;
|
auto backup = this->gradient_index_;
|
||||||
auto iter = SimpleDMatrix::GetGradientIndex(param);
|
auto iter = SimpleDMatrix::GetGradientIndex(ctx, param);
|
||||||
n_regen_ += (backup != this->gradient_index_);
|
n_regen_ += (backup != this->gradient_index_);
|
||||||
return iter;
|
return iter;
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override {
|
BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, const BatchParam& param) override {
|
||||||
auto backup = this->ellpack_page_;
|
auto backup = this->ellpack_page_;
|
||||||
auto iter = SimpleDMatrix::GetEllpackBatches(param);
|
auto iter = SimpleDMatrix::GetEllpackBatches(ctx, param);
|
||||||
n_regen_ += (backup != this->ellpack_page_);
|
n_regen_ += (backup != this->ellpack_page_);
|
||||||
return iter;
|
return iter;
|
||||||
}
|
}
|
||||||
@ -50,8 +51,8 @@ class RegenTest : public ::testing::Test {
|
|||||||
HostDeviceVector<float> storage;
|
HostDeviceVector<float> storage;
|
||||||
auto dense = RandomDataGenerator{kRows, kCols, 0.5}.GenerateArrayInterface(&storage);
|
auto dense = RandomDataGenerator{kRows, kCols, 0.5}.GenerateArrayInterface(&storage);
|
||||||
auto adapter = data::ArrayAdapter(StringView{dense});
|
auto adapter = data::ArrayAdapter(StringView{dense});
|
||||||
p_fmat_ = std::shared_ptr<DMatrix>(new DMatrixForTest{
|
p_fmat_ = std::shared_ptr<DMatrix>(
|
||||||
&adapter, std::numeric_limits<float>::quiet_NaN(), AllThreadsForTest()});
|
new DMatrixForTest{&adapter, std::numeric_limits<float>::quiet_NaN(), AllThreadsForTest()});
|
||||||
|
|
||||||
p_fmat_->Info().labels.Reshape(256, 1);
|
p_fmat_->Info().labels.Reshape(256, 1);
|
||||||
auto labels = p_fmat_->Info().labels.Data();
|
auto labels = p_fmat_->Info().labels.Data();
|
||||||
@ -74,7 +75,7 @@ class RegenTest : public ::testing::Test {
|
|||||||
auto for_test = dynamic_cast<DMatrixForTest*>(p_fmat_.get());
|
auto for_test = dynamic_cast<DMatrixForTest*>(p_fmat_.get());
|
||||||
CHECK(for_test);
|
CHECK(for_test);
|
||||||
auto backup = for_test->NumRegen();
|
auto backup = for_test->NumRegen();
|
||||||
for_test->GetBatches<Page>(BatchParam{});
|
for_test->GetBatches<Page>(p_fmat_->Ctx(), BatchParam{});
|
||||||
CHECK_EQ(for_test->NumRegen(), backup);
|
CHECK_EQ(for_test->NumRegen(), backup);
|
||||||
|
|
||||||
if (reset) {
|
if (reset) {
|
||||||
|
|||||||
@ -18,6 +18,7 @@ class TestQuantileDMatrix:
|
|||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_dmatrix_feature_weights(self) -> None:
|
def test_dmatrix_feature_weights(self) -> None:
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
|
|
||||||
rng = cp.random.RandomState(1994)
|
rng = cp.random.RandomState(1994)
|
||||||
data = rng.randn(5, 5)
|
data = rng.randn(5, 5)
|
||||||
m = xgb.DMatrix(data)
|
m = xgb.DMatrix(data)
|
||||||
@ -26,23 +27,91 @@ class TestQuantileDMatrix:
|
|||||||
m.set_info(feature_weights=feature_weights)
|
m.set_info(feature_weights=feature_weights)
|
||||||
|
|
||||||
cp.testing.assert_array_equal(
|
cp.testing.assert_array_equal(
|
||||||
cp.array(m.get_float_info('feature_weights')),
|
cp.array(m.get_float_info("feature_weights")),
|
||||||
feature_weights.astype(np.float32))
|
feature_weights.astype(np.float32),
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_dmatrix_cupy_init(self) -> None:
|
def test_dmatrix_cupy_init(self) -> None:
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
|
|
||||||
data = cp.random.randn(5, 5)
|
data = cp.random.randn(5, 5)
|
||||||
xgb.QuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
xgb.QuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"on_device,tree_method",
|
||||||
|
[(True, "hist"), (False, "gpu_hist"), (False, "hist"), (True, "gpu_hist")],
|
||||||
|
)
|
||||||
|
def test_initialization(self, on_device: bool, tree_method: str) -> None:
|
||||||
|
n_samples, n_features, max_bin = 64, 3, 16
|
||||||
|
X, y, w = tm.make_batches(
|
||||||
|
n_samples,
|
||||||
|
n_features=n_features,
|
||||||
|
n_batches=1,
|
||||||
|
use_cupy=on_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Init SparsePage
|
||||||
|
Xy = xgb.DMatrix(X[0], y[0], weight=w[0])
|
||||||
|
# Init GIDX/Ellpack
|
||||||
|
xgb.train(
|
||||||
|
{"tree_method": tree_method, "max_bin": max_bin},
|
||||||
|
Xy,
|
||||||
|
num_boost_round=1,
|
||||||
|
)
|
||||||
|
# query cuts from GIDX/Ellpack
|
||||||
|
qXy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin, ref=Xy)
|
||||||
|
tm.predictor_equal(Xy, qXy)
|
||||||
|
with pytest.raises(ValueError, match="Inconsistent"):
|
||||||
|
# max_bin changed.
|
||||||
|
xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin - 1, ref=Xy)
|
||||||
|
|
||||||
|
# No error, DMatrix can be modified for different training session.
|
||||||
|
xgb.train(
|
||||||
|
{"tree_method": tree_method, "max_bin": max_bin - 1},
|
||||||
|
Xy,
|
||||||
|
num_boost_round=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Init Ellpack/GIDX
|
||||||
|
Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin)
|
||||||
|
# Init GIDX/Ellpack
|
||||||
|
xgb.train(
|
||||||
|
{"tree_method": tree_method, "max_bin": max_bin},
|
||||||
|
Xy,
|
||||||
|
num_boost_round=1,
|
||||||
|
)
|
||||||
|
# query cuts from GIDX/Ellpack
|
||||||
|
qXy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin, ref=Xy)
|
||||||
|
tm.predictor_equal(Xy, qXy)
|
||||||
|
with pytest.raises(ValueError, match="Inconsistent"):
|
||||||
|
# max_bin changed.
|
||||||
|
xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin - 1, ref=Xy)
|
||||||
|
|
||||||
|
Xy = xgb.DMatrix(X[0], y[0], weight=w[0])
|
||||||
|
booster0 = xgb.train(
|
||||||
|
{"tree_method": "hist", "max_bin": max_bin, "max_depth": 4},
|
||||||
|
Xy,
|
||||||
|
num_boost_round=1,
|
||||||
|
)
|
||||||
|
booster1 = xgb.train(
|
||||||
|
{"tree_method": "gpu_hist", "max_bin": max_bin, "max_depth": 4},
|
||||||
|
Xy,
|
||||||
|
num_boost_round=1,
|
||||||
|
)
|
||||||
|
qXy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin, ref=Xy)
|
||||||
|
predt0 = booster0.predict(qXy)
|
||||||
|
predt1 = booster1.predict(qXy)
|
||||||
|
np.testing.assert_allclose(predt0, predt1)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"tree_method,max_bin", [
|
"tree_method,max_bin",
|
||||||
("hist", 16), ("gpu_hist", 16), ("hist", 64), ("gpu_hist", 64)
|
[("hist", 16), ("gpu_hist", 16), ("hist", 64), ("gpu_hist", 64)],
|
||||||
]
|
|
||||||
)
|
)
|
||||||
def test_interoperability(self, tree_method: str, max_bin: int) -> None:
|
def test_interoperability(self, tree_method: str, max_bin: int) -> None:
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
|
|
||||||
n_samples = 64
|
n_samples = 64
|
||||||
n_features = 3
|
n_features = 3
|
||||||
X, y, w = tm.make_batches(
|
X, y, w = tm.make_batches(
|
||||||
@ -75,6 +144,7 @@ class TestQuantileDMatrix:
|
|||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_metainfo(self) -> None:
|
def test_metainfo(self) -> None:
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
|
|
||||||
rng = cp.random.RandomState(1994)
|
rng = cp.random.RandomState(1994)
|
||||||
|
|
||||||
rows = 10
|
rows = 10
|
||||||
@ -98,6 +168,7 @@ class TestQuantileDMatrix:
|
|||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_ref_dmatrix(self) -> None:
|
def test_ref_dmatrix(self) -> None:
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
|
|
||||||
rng = cp.random.RandomState(1994)
|
rng = cp.random.RandomState(1994)
|
||||||
self.cputest.run_ref_dmatrix(rng, "gpu_hist", False)
|
self.cputest.run_ref_dmatrix(rng, "gpu_hist", False)
|
||||||
|
|
||||||
@ -158,5 +229,6 @@ class TestQuantileDMatrix:
|
|||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_check_inf(self) -> None:
|
def test_check_inf(self) -> None:
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
|
|
||||||
rng = cp.random.default_rng(1994)
|
rng = cp.random.default_rng(1994)
|
||||||
check_inf(rng)
|
check_inf(rng)
|
||||||
|
|||||||
@ -153,12 +153,18 @@ class TestGPUUpdaters:
|
|||||||
tm.dataset_strategy
|
tm.dataset_strategy
|
||||||
)
|
)
|
||||||
@settings(deadline=None, max_examples=20, print_blob=True)
|
@settings(deadline=None, max_examples=20, print_blob=True)
|
||||||
def test_gpu_hist_device_dmatrix(self, param, num_rounds, dataset):
|
def test_gpu_hist_device_dmatrix(
|
||||||
|
self, param: dict, num_rounds: int, dataset: tm.TestDataset
|
||||||
|
) -> None:
|
||||||
# We cannot handle empty dataset yet
|
# We cannot handle empty dataset yet
|
||||||
assume(len(dataset.y) > 0)
|
assume(len(dataset.y) > 0)
|
||||||
param['tree_method'] = 'gpu_hist'
|
param['tree_method'] = 'gpu_hist'
|
||||||
param = dataset.set_params(param)
|
param = dataset.set_params(param)
|
||||||
result = train_result(param, dataset.get_device_dmat(), num_rounds)
|
result = train_result(
|
||||||
|
param,
|
||||||
|
dataset.get_device_dmat(max_bin=param.get("max_bin", None)),
|
||||||
|
num_rounds
|
||||||
|
)
|
||||||
note(result)
|
note(result)
|
||||||
assert tm.non_increasing(result['train'][dataset.metric], tolerance=1e-3)
|
assert tm.non_increasing(result['train'][dataset.metric], tolerance=1e-3)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user