[EM] Allow staging ellpack on host for GPU external memory. (#10488)
- New parameter `on_host`. - Abstract format creation and stream creation into policy classes.
This commit is contained in:
parent
824fba783e
commit
e8a962575a
@ -72,6 +72,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/data/gradient_index_page_source.o \
|
$(PKGROOT)/src/data/gradient_index_page_source.o \
|
||||||
$(PKGROOT)/src/data/gradient_index_format.o \
|
$(PKGROOT)/src/data/gradient_index_format.o \
|
||||||
$(PKGROOT)/src/data/sparse_page_dmatrix.o \
|
$(PKGROOT)/src/data/sparse_page_dmatrix.o \
|
||||||
|
$(PKGROOT)/src/data/sparse_page_source.o \
|
||||||
$(PKGROOT)/src/data/proxy_dmatrix.o \
|
$(PKGROOT)/src/data/proxy_dmatrix.o \
|
||||||
$(PKGROOT)/src/data/iterative_dmatrix.o \
|
$(PKGROOT)/src/data/iterative_dmatrix.o \
|
||||||
$(PKGROOT)/src/predictor/predictor.o \
|
$(PKGROOT)/src/predictor/predictor.o \
|
||||||
|
|||||||
@ -72,6 +72,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/data/gradient_index_page_source.o \
|
$(PKGROOT)/src/data/gradient_index_page_source.o \
|
||||||
$(PKGROOT)/src/data/gradient_index_format.o \
|
$(PKGROOT)/src/data/gradient_index_format.o \
|
||||||
$(PKGROOT)/src/data/sparse_page_dmatrix.o \
|
$(PKGROOT)/src/data/sparse_page_dmatrix.o \
|
||||||
|
$(PKGROOT)/src/data/sparse_page_source.o \
|
||||||
$(PKGROOT)/src/data/proxy_dmatrix.o \
|
$(PKGROOT)/src/data/proxy_dmatrix.o \
|
||||||
$(PKGROOT)/src/data/iterative_dmatrix.o \
|
$(PKGROOT)/src/data/iterative_dmatrix.o \
|
||||||
$(PKGROOT)/src/predictor/predictor.o \
|
$(PKGROOT)/src/predictor/predictor.o \
|
||||||
|
|||||||
@ -50,7 +50,7 @@ class MetaInfo {
|
|||||||
static constexpr uint64_t kNumField = 12;
|
static constexpr uint64_t kNumField = 12;
|
||||||
|
|
||||||
/*! \brief number of rows in the data */
|
/*! \brief number of rows in the data */
|
||||||
uint64_t num_row_{0}; // NOLINT
|
bst_idx_t num_row_{0}; // NOLINT
|
||||||
/*! \brief number of columns in the data */
|
/*! \brief number of columns in the data */
|
||||||
uint64_t num_col_{0}; // NOLINT
|
uint64_t num_col_{0}; // NOLINT
|
||||||
/*! \brief number of nonzero entries in the data */
|
/*! \brief number of nonzero entries in the data */
|
||||||
@ -535,10 +535,11 @@ class DMatrix {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
[[nodiscard]] bool PageExists() const;
|
[[nodiscard]] bool PageExists() const;
|
||||||
|
|
||||||
// the following are column meta data, should be able to answer them fast.
|
/**
|
||||||
/*! \return Whether the data columns single column block. */
|
* @return Whether the data columns single column block.
|
||||||
|
*/
|
||||||
[[nodiscard]] virtual bool SingleColBlock() const = 0;
|
[[nodiscard]] virtual bool SingleColBlock() const = 0;
|
||||||
/*! \brief virtual destructor */
|
|
||||||
virtual ~DMatrix();
|
virtual ~DMatrix();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -600,34 +601,34 @@ class DMatrix {
|
|||||||
int nthread, bst_bin_t max_bin);
|
int nthread, bst_bin_t max_bin);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Create an external memory DMatrix with callbacks.
|
* @brief Create an external memory DMatrix with callbacks.
|
||||||
*
|
*
|
||||||
* \tparam DataIterHandle External iterator type, defined in C API.
|
* @tparam DataIterHandle External iterator type, defined in C API.
|
||||||
* \tparam DMatrixHandle DMatrix handle, defined in C API.
|
* @tparam DMatrixHandle DMatrix handle, defined in C API.
|
||||||
* \tparam DataIterResetCallback Callback for reset, prototype defined in C API.
|
* @tparam DataIterResetCallback Callback for reset, prototype defined in C API.
|
||||||
* \tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
|
* @tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
|
||||||
*
|
*
|
||||||
* \param iter External data iterator
|
* @param iter External data iterator
|
||||||
* \param proxy A hanlde to ProxyDMatrix
|
* @param proxy A hanlde to ProxyDMatrix
|
||||||
* \param reset Callback for reset
|
* @param reset Callback for reset
|
||||||
* \param next Callback for next
|
* @param next Callback for next
|
||||||
* \param missing Value that should be treated as missing.
|
* @param missing Value that should be treated as missing.
|
||||||
* \param nthread number of threads used for initialization.
|
* @param nthread number of threads used for initialization.
|
||||||
* \param cache Prefix of cache file path.
|
* @param cache Prefix of cache file path.
|
||||||
|
* @param on_host Used for GPU, whether the data should be cached on host memory.
|
||||||
*
|
*
|
||||||
* \return A created external memory DMatrix.
|
* @return A created external memory DMatrix.
|
||||||
*/
|
*/
|
||||||
template <typename DataIterHandle, typename DMatrixHandle,
|
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
|
||||||
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
|
typename XGDMatrixCallbackNext>
|
||||||
static DMatrix *Create(DataIterHandle iter, DMatrixHandle proxy,
|
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset,
|
||||||
DataIterResetCallback *reset,
|
XGDMatrixCallbackNext* next, float missing, int32_t nthread,
|
||||||
XGDMatrixCallbackNext *next, float missing,
|
std::string cache, bool on_host);
|
||||||
int32_t nthread, std::string cache);
|
|
||||||
|
|
||||||
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
|
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Slice a DMatrix by columns.
|
* @brief Slice a DMatrix by columns.
|
||||||
*
|
*
|
||||||
* @param num_slices Total number of slices
|
* @param num_slices Total number of slices
|
||||||
* @param slice_id Index of the current slice
|
* @param slice_id Index of the current slice
|
||||||
|
|||||||
@ -503,18 +503,29 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
|||||||
----------
|
----------
|
||||||
cache_prefix :
|
cache_prefix :
|
||||||
Prefix to the cache files, only used in external memory.
|
Prefix to the cache files, only used in external memory.
|
||||||
|
|
||||||
release_data :
|
release_data :
|
||||||
Whether the iterator should release the data during iteration. Set it to True if
|
Whether the iterator should release the data during iteration. Set it to True if
|
||||||
the data transformation (converting data to np.float32 type) is memory
|
the data transformation (converting data to np.float32 type) is memory
|
||||||
intensive. Otherwise, if the transformation is computation intensive then we can
|
intensive. Otherwise, if the transformation is computation intensive then we can
|
||||||
keep the cache.
|
keep the cache.
|
||||||
|
|
||||||
|
on_host :
|
||||||
|
Whether the data should be cached on host memory instead of harddrive when using
|
||||||
|
GPU with external memory. If set to true, then the "external memory" would
|
||||||
|
simply be CPU (host) memory. This is still working in progress, not ready for
|
||||||
|
test yet.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, cache_prefix: Optional[str] = None, release_data: bool = True
|
self,
|
||||||
|
cache_prefix: Optional[str] = None,
|
||||||
|
release_data: bool = True,
|
||||||
|
on_host: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.cache_prefix = cache_prefix
|
self.cache_prefix = cache_prefix
|
||||||
|
self.on_host = on_host
|
||||||
|
|
||||||
self._handle = _ProxyDMatrix()
|
self._handle = _ProxyDMatrix()
|
||||||
self._exception: Optional[Exception] = None
|
self._exception: Optional[Exception] = None
|
||||||
@ -905,12 +916,12 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
|||||||
|
|
||||||
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
|
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
|
||||||
it = iterator
|
it = iterator
|
||||||
args = {
|
args = make_jcargs(
|
||||||
"missing": self.missing,
|
missing=self.missing,
|
||||||
"nthread": self.nthread,
|
nthread=self.nthread,
|
||||||
"cache_prefix": it.cache_prefix if it.cache_prefix else "",
|
cache_prefix=it.cache_prefix if it.cache_prefix else "",
|
||||||
}
|
on_host=it.on_host,
|
||||||
args_cstr = from_pystr_to_cstr(json.dumps(args))
|
)
|
||||||
handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
reset_callback, next_callback = it.get_callbacks(enable_categorical)
|
reset_callback, next_callback = it.get_callbacks(enable_categorical)
|
||||||
ret = _LIB.XGDMatrixCreateFromCallback(
|
ret = _LIB.XGDMatrixCreateFromCallback(
|
||||||
@ -918,7 +929,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
|||||||
it.proxy.handle,
|
it.proxy.handle,
|
||||||
reset_callback,
|
reset_callback,
|
||||||
next_callback,
|
next_callback,
|
||||||
args_cstr,
|
args,
|
||||||
ctypes.byref(handle),
|
ctypes.byref(handle),
|
||||||
)
|
)
|
||||||
it.reraise()
|
it.reraise()
|
||||||
|
|||||||
@ -198,19 +198,20 @@ def skip_win() -> PytestSkip:
|
|||||||
class IteratorForTest(xgb.core.DataIter):
|
class IteratorForTest(xgb.core.DataIter):
|
||||||
"""Iterator for testing streaming DMatrix. (external memory, quantile)"""
|
"""Iterator for testing streaming DMatrix. (external memory, quantile)"""
|
||||||
|
|
||||||
def __init__(
|
def __init__( # pylint: disable=too-many-arguments
|
||||||
self,
|
self,
|
||||||
X: Sequence,
|
X: Sequence,
|
||||||
y: Sequence,
|
y: Sequence,
|
||||||
w: Optional[Sequence],
|
w: Optional[Sequence],
|
||||||
cache: Optional[str],
|
cache: Optional[str],
|
||||||
|
on_host: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert len(X) == len(y)
|
assert len(X) == len(y)
|
||||||
self.X = X
|
self.X = X
|
||||||
self.y = y
|
self.y = y
|
||||||
self.w = w
|
self.w = w
|
||||||
self.it = 0
|
self.it = 0
|
||||||
super().__init__(cache_prefix=cache)
|
super().__init__(cache_prefix=cache, on_host=on_host)
|
||||||
|
|
||||||
def next(self, input_data: Callable) -> int:
|
def next(self, input_data: Callable) -> int:
|
||||||
if self.it == len(self.X):
|
if self.it == len(self.X):
|
||||||
@ -367,7 +368,11 @@ class TestDataset:
|
|||||||
weight.append(w)
|
weight.append(w)
|
||||||
|
|
||||||
it = IteratorForTest(
|
it = IteratorForTest(
|
||||||
predictor, response, weight if weight else None, cache="cache"
|
predictor,
|
||||||
|
response,
|
||||||
|
weight if weight else None,
|
||||||
|
cache="cache",
|
||||||
|
on_host=False,
|
||||||
)
|
)
|
||||||
return xgb.DMatrix(it)
|
return xgb.DMatrix(it)
|
||||||
|
|
||||||
|
|||||||
@ -22,7 +22,7 @@ def run_mixed_sparsity(device: str) -> None:
|
|||||||
|
|
||||||
X = [cp.array(batch) for batch in X]
|
X = [cp.array(batch) for batch in X]
|
||||||
|
|
||||||
it = tm.IteratorForTest(X, y, None, None)
|
it = tm.IteratorForTest(X, y, None, None, on_host=False)
|
||||||
Xy_0 = xgboost.QuantileDMatrix(it)
|
Xy_0 = xgboost.QuantileDMatrix(it)
|
||||||
|
|
||||||
X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, True)
|
X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, True)
|
||||||
|
|||||||
@ -207,6 +207,7 @@ def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None:
|
|||||||
it = tm.IteratorForTest(
|
it = tm.IteratorForTest(
|
||||||
*tm.make_batches(n_samples_per_batch, n_features, n_batches, use_cupy),
|
*tm.make_batches(n_samples_per_batch, n_features, n_batches, use_cupy),
|
||||||
cache="cache",
|
cache="cache",
|
||||||
|
on_host=False,
|
||||||
)
|
)
|
||||||
Xy: xgb.DMatrix = xgb.DMatrix(it)
|
Xy: xgb.DMatrix = xgb.DMatrix(it)
|
||||||
xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xyw)
|
xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xyw)
|
||||||
|
|||||||
@ -298,13 +298,14 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
|||||||
auto missing = GetMissing(jconfig);
|
auto missing = GetMissing(jconfig);
|
||||||
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
||||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
|
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
|
||||||
|
auto on_host = OptionalArg<Boolean, bool>(jconfig, "on_host", false);
|
||||||
|
|
||||||
xgboost_CHECK_C_ARG_PTR(next);
|
xgboost_CHECK_C_ARG_PTR(next);
|
||||||
xgboost_CHECK_C_ARG_PTR(reset);
|
xgboost_CHECK_C_ARG_PTR(reset);
|
||||||
xgboost_CHECK_C_ARG_PTR(out);
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
|
|
||||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||||
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache)};
|
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache, on_host)};
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -429,7 +429,7 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
|
|||||||
}
|
}
|
||||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||||
XGBDefaultDeviceAllocatorImpl()
|
XGBDefaultDeviceAllocatorImpl()
|
||||||
: SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()) {}
|
: SuperT(rmm::cuda_stream_per_thread, rmm::mr::get_current_device_resource()) {}
|
||||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -484,7 +484,7 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
|
|||||||
}
|
}
|
||||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||||
XGBCachingDeviceAllocatorImpl()
|
XGBCachingDeviceAllocatorImpl()
|
||||||
: SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()),
|
: SuperT(rmm::cuda_stream_per_thread, rmm::mr::get_current_device_resource()),
|
||||||
use_cub_allocator_(!xgboost::GlobalConfigThreadLocalStore::Get()->use_rmm) {}
|
use_cub_allocator_(!xgboost::GlobalConfigThreadLocalStore::Get()->use_rmm) {}
|
||||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||||
XGBOOST_DEVICE void construct(T *) {} // NOLINT
|
XGBOOST_DEVICE void construct(T *) {} // NOLINT
|
||||||
|
|||||||
@ -6,7 +6,7 @@
|
|||||||
#ifndef XGBOOST_COMMON_ERROR_MSG_H_
|
#ifndef XGBOOST_COMMON_ERROR_MSG_H_
|
||||||
#define XGBOOST_COMMON_ERROR_MSG_H_
|
#define XGBOOST_COMMON_ERROR_MSG_H_
|
||||||
|
|
||||||
#include <cinttypes> // for uint64_t
|
#include <cstdint> // for uint64_t
|
||||||
#include <limits> // for numeric_limits
|
#include <limits> // for numeric_limits
|
||||||
#include <string> // for string
|
#include <string> // for string
|
||||||
|
|
||||||
@ -103,5 +103,11 @@ inline auto NoFederated() { return "XGBoost is not compiled with federated learn
|
|||||||
inline auto NoCategorical(std::string name) {
|
inline auto NoCategorical(std::string name) {
|
||||||
return name + " doesn't support categorical features.";
|
return name + " doesn't support categorical features.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void NoOnHost(bool on_host) {
|
||||||
|
if (on_host) {
|
||||||
|
LOG(FATAL) << "Caching on host memory is only available for GPU.";
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace xgboost::error
|
} // namespace xgboost::error
|
||||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||||
|
|||||||
@ -163,7 +163,7 @@ class HistogramCuts {
|
|||||||
return vals[bin_idx - 1];
|
return vals[bin_idx - 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetDevice(DeviceOrd d) const {
|
void SetDevice(DeviceOrd d) {
|
||||||
this->cut_ptrs_.SetDevice(d);
|
this->cut_ptrs_.SetDevice(d);
|
||||||
this->cut_ptrs_.ConstDevicePointer();
|
this->cut_ptrs_.ConstDevicePointer();
|
||||||
|
|
||||||
|
|||||||
@ -901,15 +901,12 @@ DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_p
|
|||||||
return new data::IterativeDMatrix(iter, proxy, ref, reset, next, missing, nthread, max_bin);
|
return new data::IterativeDMatrix(iter, proxy, ref, reset, next, missing, nthread, max_bin);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DataIterHandle, typename DMatrixHandle,
|
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
|
||||||
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
|
typename XGDMatrixCallbackNext>
|
||||||
DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy,
|
DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset,
|
||||||
DataIterResetCallback *reset,
|
XGDMatrixCallbackNext* next, float missing, int32_t n_threads,
|
||||||
XGDMatrixCallbackNext *next, float missing,
|
std::string cache, bool on_host) {
|
||||||
int32_t n_threads,
|
return new data::SparsePageDMatrix{iter, proxy, reset, next, missing, n_threads, cache, on_host};
|
||||||
std::string cache) {
|
|
||||||
return new data::SparsePageDMatrix(iter, proxy, reset, next, missing, n_threads,
|
|
||||||
cache);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
|
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
|
||||||
@ -919,10 +916,11 @@ template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCa
|
|||||||
XGDMatrixCallbackNext* next, float missing,
|
XGDMatrixCallbackNext* next, float missing,
|
||||||
int nthread, int max_bin);
|
int nthread, int max_bin);
|
||||||
|
|
||||||
template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
|
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
|
||||||
DataIterResetCallback, XGDMatrixCallbackNext>(
|
XGDMatrixCallbackNext>(DataIterHandle iter, DMatrixHandle proxy,
|
||||||
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
|
DataIterResetCallback* reset,
|
||||||
XGDMatrixCallbackNext *next, float missing, int32_t n_threads, std::string);
|
XGDMatrixCallbackNext* next, float missing,
|
||||||
|
int32_t n_threads, std::string, bool);
|
||||||
|
|
||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
|
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
|
||||||
|
|||||||
@ -36,7 +36,7 @@ void EllpackPage::SetBaseRowId(std::size_t) {
|
|||||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||||
"EllpackPage is required";
|
"EllpackPage is required";
|
||||||
}
|
}
|
||||||
size_t EllpackPage::Size() const {
|
bst_idx_t EllpackPage::Size() const {
|
||||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||||
"EllpackPage is required";
|
"EllpackPage is required";
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
@ -29,7 +29,7 @@ EllpackPage::~EllpackPage() = default;
|
|||||||
|
|
||||||
EllpackPage::EllpackPage(EllpackPage&& that) { std::swap(impl_, that.impl_); }
|
EllpackPage::EllpackPage(EllpackPage&& that) { std::swap(impl_, that.impl_); }
|
||||||
|
|
||||||
size_t EllpackPage::Size() const { return impl_->Size(); }
|
[[nodiscard]] bst_idx_t EllpackPage::Size() const { return impl_->Size(); }
|
||||||
|
|
||||||
void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id); }
|
void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id); }
|
||||||
|
|
||||||
@ -91,13 +91,13 @@ __global__ void CompressBinEllpackKernel(
|
|||||||
// Construct an ELLPACK matrix with the given number of empty rows.
|
// Construct an ELLPACK matrix with the given number of empty rows.
|
||||||
EllpackPageImpl::EllpackPageImpl(DeviceOrd device,
|
EllpackPageImpl::EllpackPageImpl(DeviceOrd device,
|
||||||
std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
|
std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
|
||||||
size_t row_stride, size_t n_rows)
|
bst_idx_t row_stride, bst_idx_t n_rows)
|
||||||
: is_dense(is_dense), cuts_(std::move(cuts)), row_stride(row_stride), n_rows(n_rows) {
|
: is_dense(is_dense), cuts_(std::move(cuts)), row_stride{row_stride}, n_rows{n_rows} {
|
||||||
monitor_.Init("ellpack_page");
|
monitor_.Init("ellpack_page");
|
||||||
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
||||||
|
|
||||||
monitor_.Start("InitCompressedData");
|
monitor_.Start("InitCompressedData");
|
||||||
InitCompressedData(device);
|
this->InitCompressedData(device);
|
||||||
monitor_.Stop("InitCompressedData");
|
monitor_.Stop("InitCompressedData");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -403,7 +403,7 @@ struct CopyPage {
|
|||||||
// Copy the data from the given EllpackPage to the current page.
|
// Copy the data from the given EllpackPage to the current page.
|
||||||
size_t EllpackPageImpl::Copy(DeviceOrd device, EllpackPageImpl const* page, size_t offset) {
|
size_t EllpackPageImpl::Copy(DeviceOrd device, EllpackPageImpl const* page, size_t offset) {
|
||||||
monitor_.Start("Copy");
|
monitor_.Start("Copy");
|
||||||
size_t num_elements = page->n_rows * page->row_stride;
|
bst_idx_t num_elements = page->n_rows * page->row_stride;
|
||||||
CHECK_EQ(row_stride, page->row_stride);
|
CHECK_EQ(row_stride, page->row_stride);
|
||||||
CHECK_EQ(NumSymbols(), page->NumSymbols());
|
CHECK_EQ(NumSymbols(), page->NumSymbols());
|
||||||
CHECK_GE(n_rows * row_stride, offset + num_elements);
|
CHECK_GE(n_rows * row_stride, offset + num_elements);
|
||||||
@ -461,16 +461,17 @@ struct CompactPage {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Compacts the data from the given EllpackPage into the current page.
|
// Compacts the data from the given EllpackPage into the current page.
|
||||||
void EllpackPageImpl::Compact(DeviceOrd device, EllpackPageImpl const* page,
|
void EllpackPageImpl::Compact(Context const* ctx, EllpackPageImpl const* page,
|
||||||
common::Span<size_t> row_indexes) {
|
common::Span<size_t> row_indexes) {
|
||||||
monitor_.Start("Compact");
|
monitor_.Start(__func__);
|
||||||
CHECK_EQ(row_stride, page->row_stride);
|
CHECK_EQ(row_stride, page->row_stride);
|
||||||
CHECK_EQ(NumSymbols(), page->NumSymbols());
|
CHECK_EQ(NumSymbols(), page->NumSymbols());
|
||||||
CHECK_LE(page->base_rowid + page->n_rows, row_indexes.size());
|
CHECK_LE(page->base_rowid + page->n_rows, row_indexes.size());
|
||||||
gidx_buffer.SetDevice(device);
|
gidx_buffer.SetDevice(ctx->Device());
|
||||||
page->gidx_buffer.SetDevice(device);
|
page->gidx_buffer.SetDevice(ctx->Device());
|
||||||
dh::LaunchN(page->n_rows, CompactPage(this, page, row_indexes));
|
auto cuctx = ctx->CUDACtx();
|
||||||
monitor_.Stop("Compact");
|
dh::LaunchN(page->n_rows, cuctx->Stream(), CompactPage(this, page, row_indexes));
|
||||||
|
monitor_.Stop(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the buffer to stored compressed features.
|
// Initialize the buffer to stored compressed features.
|
||||||
@ -551,7 +552,7 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the number of rows contained in this page.
|
// Return the number of rows contained in this page.
|
||||||
size_t EllpackPageImpl::Size() const { return n_rows; }
|
[[nodiscard]] bst_idx_t EllpackPageImpl::Size() const { return n_rows; }
|
||||||
|
|
||||||
// Return the memory cost for storing the compressed features.
|
// Return the memory cost for storing the compressed features.
|
||||||
size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride,
|
size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride,
|
||||||
|
|||||||
@ -143,7 +143,7 @@ class EllpackPageImpl {
|
|||||||
* and the given number of rows.
|
* and the given number of rows.
|
||||||
*/
|
*/
|
||||||
EllpackPageImpl(DeviceOrd device, std::shared_ptr<common::HistogramCuts const> cuts,
|
EllpackPageImpl(DeviceOrd device, std::shared_ptr<common::HistogramCuts const> cuts,
|
||||||
bool is_dense, size_t row_stride, size_t n_rows);
|
bool is_dense, bst_idx_t row_stride, bst_idx_t n_rows);
|
||||||
/*!
|
/*!
|
||||||
* \brief Constructor used for external memory.
|
* \brief Constructor used for external memory.
|
||||||
*/
|
*/
|
||||||
@ -181,14 +181,14 @@ class EllpackPageImpl {
|
|||||||
|
|
||||||
/*! \brief Compact the given ELLPACK page into the current page.
|
/*! \brief Compact the given ELLPACK page into the current page.
|
||||||
*
|
*
|
||||||
* @param device The GPU device to use.
|
* @param context The GPU context.
|
||||||
* @param page The ELLPACK page to compact from.
|
* @param page The ELLPACK page to compact from.
|
||||||
* @param row_indexes Row indexes for the compacted page.
|
* @param row_indexes Row indexes for the compacted page.
|
||||||
*/
|
*/
|
||||||
void Compact(DeviceOrd device, EllpackPageImpl const* page, common::Span<size_t> row_indexes);
|
void Compact(Context const* ctx, EllpackPageImpl const* page, common::Span<size_t> row_indexes);
|
||||||
|
|
||||||
/*! \return Number of instances in the page. */
|
/*! \return Number of instances in the page. */
|
||||||
[[nodiscard]] size_t Size() const;
|
[[nodiscard]] bst_idx_t Size() const;
|
||||||
|
|
||||||
/*! \brief Set the base row id for this page. */
|
/*! \brief Set the base row id for this page. */
|
||||||
void SetBaseRowId(std::size_t row_id) {
|
void SetBaseRowId(std::size_t row_id) {
|
||||||
@ -231,7 +231,7 @@ class EllpackPageImpl {
|
|||||||
/*! \brief Whether or not if the matrix is dense. */
|
/*! \brief Whether or not if the matrix is dense. */
|
||||||
bool is_dense;
|
bool is_dense;
|
||||||
/*! \brief Row length for ELLPACK. */
|
/*! \brief Row length for ELLPACK. */
|
||||||
size_t row_stride;
|
bst_idx_t row_stride;
|
||||||
bst_idx_t base_rowid{0};
|
bst_idx_t base_rowid{0};
|
||||||
bst_idx_t n_rows{};
|
bst_idx_t n_rows{};
|
||||||
/*! \brief global index of histogram, which is stored in ELLPACK format. */
|
/*! \brief global index of histogram, which is stored in ELLPACK format. */
|
||||||
|
|||||||
@ -41,7 +41,7 @@ class EllpackPage {
|
|||||||
EllpackPage(EllpackPage&& that);
|
EllpackPage(EllpackPage&& that);
|
||||||
|
|
||||||
/*! \return Number of instances in the page. */
|
/*! \return Number of instances in the page. */
|
||||||
[[nodiscard]] size_t Size() const;
|
[[nodiscard]] bst_idx_t Size() const;
|
||||||
|
|
||||||
/*! \brief Set the base row id for this page. */
|
/*! \brief Set the base row id for this page. */
|
||||||
void SetBaseRowId(std::size_t row_id);
|
void SetBaseRowId(std::size_t row_id);
|
||||||
|
|||||||
@ -10,6 +10,7 @@
|
|||||||
#include "../common/ref_resource_view.h" // for ReadVec, WriteVec
|
#include "../common/ref_resource_view.h" // for ReadVec, WriteVec
|
||||||
#include "ellpack_page.cuh" // for EllpackPage
|
#include "ellpack_page.cuh" // for EllpackPage
|
||||||
#include "ellpack_page_raw_format.h"
|
#include "ellpack_page_raw_format.h"
|
||||||
|
#include "ellpack_page_source.h"
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);
|
DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);
|
||||||
@ -32,7 +33,6 @@ template <typename T>
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
vec->SetDevice(DeviceOrd::CUDA(0));
|
|
||||||
vec->Resize(n);
|
vec->Resize(n);
|
||||||
auto d_vec = vec->DeviceSpan();
|
auto d_vec = vec->DeviceSpan();
|
||||||
dh::safe_cuda(
|
dh::safe_cuda(
|
||||||
@ -54,6 +54,7 @@ template <typename T>
|
|||||||
if (!fi->Read(&impl->row_stride)) {
|
if (!fi->Read(&impl->row_stride)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
impl->gidx_buffer.SetDevice(device_);
|
||||||
if (!ReadDeviceVec(fi, &impl->gidx_buffer)) {
|
if (!ReadDeviceVec(fi, &impl->gidx_buffer)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -73,6 +74,65 @@ template <typename T>
|
|||||||
CHECK(!impl->gidx_buffer.ConstHostVector().empty());
|
CHECK(!impl->gidx_buffer.ConstHostVector().empty());
|
||||||
bytes += common::WriteVec(fo, impl->gidx_buffer.HostVector());
|
bytes += common::WriteVec(fo, impl->gidx_buffer.HostVector());
|
||||||
bytes += fo->Write(impl->base_rowid);
|
bytes += fo->Write(impl->base_rowid);
|
||||||
|
dh::DefaultStream().Sync();
|
||||||
|
return bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page, EllpackHostCacheStream* fi) const {
|
||||||
|
auto* impl = page->Impl();
|
||||||
|
CHECK(this->cuts_->cut_values_.DeviceCanRead());
|
||||||
|
impl->SetCuts(this->cuts_);
|
||||||
|
if (!fi->Read(&impl->n_rows)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!fi->Read(&impl->is_dense)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!fi->Read(&impl->row_stride)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read vec
|
||||||
|
bst_idx_t n{0};
|
||||||
|
if (!fi->Read(&n)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (n != 0) {
|
||||||
|
impl->gidx_buffer.SetDevice(device_);
|
||||||
|
impl->gidx_buffer.Resize(n);
|
||||||
|
auto span = impl->gidx_buffer.DeviceSpan();
|
||||||
|
if (!fi->Read(span.data(), span.size_bytes())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!fi->Read(&impl->base_rowid)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
dh::DefaultStream().Sync();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page,
|
||||||
|
EllpackHostCacheStream* fo) const {
|
||||||
|
bst_idx_t bytes{0};
|
||||||
|
auto* impl = page.Impl();
|
||||||
|
bytes += fo->Write(impl->n_rows);
|
||||||
|
bytes += fo->Write(impl->is_dense);
|
||||||
|
bytes += fo->Write(impl->row_stride);
|
||||||
|
|
||||||
|
// Write vector
|
||||||
|
bst_idx_t n = impl->gidx_buffer.Size();
|
||||||
|
bytes += fo->Write(n);
|
||||||
|
|
||||||
|
if (!impl->gidx_buffer.Empty()) {
|
||||||
|
auto span = impl->gidx_buffer.ConstDeviceSpan();
|
||||||
|
bytes += fo->Write(span.data(), span.size_bytes());
|
||||||
|
}
|
||||||
|
bytes += fo->Write(impl->base_rowid);
|
||||||
|
|
||||||
|
dh::DefaultStream().Sync();
|
||||||
return bytes;
|
return bytes;
|
||||||
}
|
}
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -20,15 +20,22 @@ class HistogramCuts;
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
|
|
||||||
|
class EllpackHostCacheStream;
|
||||||
|
|
||||||
class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
|
class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
|
||||||
std::shared_ptr<common::HistogramCuts const> cuts_;
|
std::shared_ptr<common::HistogramCuts const> cuts_;
|
||||||
|
DeviceOrd device_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts)
|
explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device)
|
||||||
: cuts_{std::move(cuts)} {}
|
: cuts_{std::move(cuts)}, device_{device} {}
|
||||||
[[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override;
|
[[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override;
|
||||||
[[nodiscard]] std::size_t Write(const EllpackPage& page,
|
[[nodiscard]] std::size_t Write(const EllpackPage& page,
|
||||||
common::AlignedFileWriteStream* fo) override;
|
common::AlignedFileWriteStream* fo) override;
|
||||||
|
|
||||||
|
[[nodiscard]] bool Read(EllpackPage* page, EllpackHostCacheStream* fi) const;
|
||||||
|
[[nodiscard]] std::size_t Write(const EllpackPage& page, EllpackHostCacheStream* fo) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#if !defined(XGBOOST_USE_CUDA)
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
|
|||||||
@ -1,29 +1,161 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2019-2024, XGBoost contributors
|
* Copyright 2019-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <memory>
|
#include <thrust/host_vector.h> // for host_vector
|
||||||
|
|
||||||
#include "ellpack_page.cuh"
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for int8_t, uint64_t, uint32_t
|
||||||
|
#include <memory> // for shared_ptr, make_unique, make_shared
|
||||||
|
#include <utility> // for move
|
||||||
|
|
||||||
|
#include "../common/common.h" // for safe_cuda
|
||||||
|
#include "../common/cuda_pinned_allocator.h" // for pinned_allocator
|
||||||
|
#include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream
|
||||||
|
#include "ellpack_page.cuh" // for EllpackPageImpl
|
||||||
#include "ellpack_page.h" // for EllpackPage
|
#include "ellpack_page.h" // for EllpackPage
|
||||||
#include "ellpack_page_source.h"
|
#include "ellpack_page_source.h"
|
||||||
|
#include "xgboost/base.h" // for bst_idx_t
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
void EllpackPageSource::Fetch() {
|
struct EllpackHostCache {
|
||||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
thrust::host_vector<std::int8_t, common::cuda::pinned_allocator<std::int8_t>> cache;
|
||||||
|
|
||||||
|
void Resize(std::size_t n, dh::CUDAStreamView stream) {
|
||||||
|
stream.Sync(); // Prevent partial copy inside resize.
|
||||||
|
cache.resize(n);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class EllpackHostCacheStreamImpl {
|
||||||
|
std::shared_ptr<EllpackHostCache> cache_;
|
||||||
|
bst_idx_t cur_ptr_{0};
|
||||||
|
bst_idx_t bound_{0};
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit EllpackHostCacheStreamImpl(std::shared_ptr<EllpackHostCache> cache)
|
||||||
|
: cache_{std::move(cache)} {}
|
||||||
|
|
||||||
|
[[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes) {
|
||||||
|
auto n = cur_ptr_ + n_bytes;
|
||||||
|
if (n > cache_->cache.size()) {
|
||||||
|
cache_->Resize(n, dh::DefaultStream());
|
||||||
|
}
|
||||||
|
dh::safe_cuda(cudaMemcpyAsync(cache_->cache.data() + cur_ptr_, ptr, n_bytes, cudaMemcpyDefault,
|
||||||
|
dh::DefaultStream()));
|
||||||
|
cur_ptr_ = n;
|
||||||
|
return n_bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes) {
|
||||||
|
CHECK_LE(cur_ptr_ + n_bytes, bound_);
|
||||||
|
dh::safe_cuda(cudaMemcpyAsync(ptr, cache_->cache.data() + cur_ptr_, n_bytes, cudaMemcpyDefault,
|
||||||
|
dh::DefaultStream()));
|
||||||
|
cur_ptr_ += n_bytes;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bst_idx_t Tell() const { return cur_ptr_; }
|
||||||
|
void Seek(bst_idx_t offset_bytes) { cur_ptr_ = offset_bytes; }
|
||||||
|
void Bound(bst_idx_t offset_bytes) {
|
||||||
|
CHECK_LE(offset_bytes, cache_->cache.size());
|
||||||
|
this->bound_ = offset_bytes;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* EllpackHostCacheStream
|
||||||
|
*/
|
||||||
|
|
||||||
|
EllpackHostCacheStream::EllpackHostCacheStream(std::shared_ptr<EllpackHostCache> cache)
|
||||||
|
: p_impl_{std::make_unique<EllpackHostCacheStreamImpl>(std::move(cache))} {}
|
||||||
|
|
||||||
|
EllpackHostCacheStream::~EllpackHostCacheStream() = default;
|
||||||
|
|
||||||
|
[[nodiscard]] bst_idx_t EllpackHostCacheStream::Write(void const* ptr, bst_idx_t n_bytes) {
|
||||||
|
return this->p_impl_->Write(ptr, n_bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool EllpackHostCacheStream::Read(void* ptr, bst_idx_t n_bytes) {
|
||||||
|
return this->p_impl_->Read(ptr, n_bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bst_idx_t EllpackHostCacheStream::Tell() const { return this->p_impl_->Tell(); }
|
||||||
|
|
||||||
|
void EllpackHostCacheStream::Seek(bst_idx_t offset_bytes) { this->p_impl_->Seek(offset_bytes); }
|
||||||
|
|
||||||
|
void EllpackHostCacheStream::Bound(bst_idx_t offset_bytes) { this->p_impl_->Bound(offset_bytes); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* EllpackFormatType
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <typename S, template <typename> typename F>
|
||||||
|
EllpackFormatStreamPolicy<S, F>::EllpackFormatStreamPolicy()
|
||||||
|
: p_cache_{std::make_shared<EllpackHostCache>()} {}
|
||||||
|
|
||||||
|
template <typename S, template <typename> typename F>
|
||||||
|
[[nodiscard]] std::unique_ptr<typename EllpackFormatStreamPolicy<S, F>::WriterT>
|
||||||
|
EllpackFormatStreamPolicy<S, F>::CreateWriter(StringView, std::uint32_t iter) {
|
||||||
|
auto fo = std::make_unique<EllpackHostCacheStream>(this->p_cache_);
|
||||||
|
if (iter == 0) {
|
||||||
|
CHECK(this->p_cache_->cache.empty());
|
||||||
|
} else {
|
||||||
|
fo->Seek(this->p_cache_->cache.size());
|
||||||
|
}
|
||||||
|
return fo;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename S, template <typename> typename F>
|
||||||
|
[[nodiscard]] std::unique_ptr<typename EllpackFormatStreamPolicy<S, F>::ReaderT>
|
||||||
|
EllpackFormatStreamPolicy<S, F>::CreateReader(StringView, bst_idx_t offset,
|
||||||
|
bst_idx_t length) const {
|
||||||
|
auto fi = std::make_unique<ReaderT>(this->p_cache_);
|
||||||
|
fi->Seek(offset);
|
||||||
|
fi->Bound(offset + length);
|
||||||
|
CHECK_EQ(fi->Tell(), offset);
|
||||||
|
return fi;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Instantiation
|
||||||
|
template EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::EllpackFormatStreamPolicy();
|
||||||
|
|
||||||
|
template std::unique_ptr<
|
||||||
|
typename EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::WriterT>
|
||||||
|
EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateWriter(StringView name,
|
||||||
|
std::uint32_t iter);
|
||||||
|
|
||||||
|
template std::unique_ptr<
|
||||||
|
typename EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::ReaderT>
|
||||||
|
EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateReader(
|
||||||
|
StringView name, std::uint64_t offset, std::uint64_t length) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* EllpackPageSourceImpl
|
||||||
|
*/
|
||||||
|
template <typename F>
|
||||||
|
void EllpackPageSourceImpl<F>::Fetch() {
|
||||||
|
dh::safe_cuda(cudaSetDevice(this->Device().ordinal));
|
||||||
if (!this->ReadCache()) {
|
if (!this->ReadCache()) {
|
||||||
if (count_ != 0 && !sync_) {
|
if (this->count_ != 0 && !this->sync_) {
|
||||||
// source is initialized to be the 0th page during construction, so when count_ is 0
|
// source is initialized to be the 0th page during construction, so when count_ is 0
|
||||||
// there's no need to increment the source.
|
// there's no need to increment the source.
|
||||||
++(*source_);
|
++(*this->source_);
|
||||||
}
|
}
|
||||||
// This is not read from cache so we still need it to be synced with sparse page source.
|
// This is not read from cache so we still need it to be synced with sparse page source.
|
||||||
CHECK_EQ(count_, source_->Iter());
|
CHECK_EQ(this->count_, this->source_->Iter());
|
||||||
auto const &csr = source_->Page();
|
auto const& csr = this->source_->Page();
|
||||||
this->page_.reset(new EllpackPage{});
|
this->page_.reset(new EllpackPage{});
|
||||||
auto* impl = this->page_->Impl();
|
auto* impl = this->page_->Impl();
|
||||||
*impl = EllpackPageImpl(device_, cuts_, *csr, is_dense_, row_stride_, feature_types_);
|
*impl = EllpackPageImpl{this->Device(), this->GetCuts(), *csr,
|
||||||
page_->SetBaseRowId(csr->base_rowid);
|
is_dense_, row_stride_, feature_types_};
|
||||||
|
this->page_->SetBaseRowId(csr->base_rowid);
|
||||||
this->WriteCache();
|
this->WriteCache();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Instantiation
|
||||||
|
template void
|
||||||
|
EllpackPageSourceImpl<DefaultFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
||||||
|
template void
|
||||||
|
EllpackPageSourceImpl<EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -19,46 +19,127 @@
|
|||||||
#include "xgboost/span.h" // for Span
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
|
// We need to decouple the storage and the view of the storage so that we can implement
|
||||||
|
// concurrent read.
|
||||||
|
|
||||||
|
// Dummy type to hide CUDA calls from the host compiler.
|
||||||
|
struct EllpackHostCache;
|
||||||
|
// Pimpl to hide CUDA calls from the host compiler.
|
||||||
|
class EllpackHostCacheStreamImpl;
|
||||||
|
|
||||||
|
// A view onto the actual cache implemented by `EllpackHostCache`.
|
||||||
|
class EllpackHostCacheStream {
|
||||||
|
std::unique_ptr<EllpackHostCacheStreamImpl> p_impl_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit EllpackHostCacheStream(std::shared_ptr<EllpackHostCache> cache);
|
||||||
|
~EllpackHostCacheStream();
|
||||||
|
|
||||||
|
[[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes);
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] std::enable_if_t<std::is_pod_v<T>, bst_idx_t> Write(T const& v) {
|
||||||
|
return this->Write(&v, sizeof(T));
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] auto Read(T* ptr) -> std::enable_if_t<std::is_pod_v<T>, bool> {
|
||||||
|
return this->Read(ptr, sizeof(T));
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bst_idx_t Tell() const;
|
||||||
|
void Seek(bst_idx_t offset_bytes);
|
||||||
|
// Limit the size of read. offset_bytes is the maximum offset that this stream can read
|
||||||
|
// to. An error is raised if the limited is exceeded.
|
||||||
|
void Bound(bst_idx_t offset_bytes);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename S>
|
||||||
|
class EllpackFormatPolicy {
|
||||||
|
std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
|
||||||
|
DeviceOrd device_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using FormatT = EllpackPageRawFormat;
|
||||||
|
|
||||||
|
public:
|
||||||
|
[[nodiscard]] auto CreatePageFormat() const {
|
||||||
|
CHECK_EQ(cuts_->cut_values_.Device(), device_);
|
||||||
|
std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_}};
|
||||||
|
return fmt;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device) {
|
||||||
|
std::swap(cuts_, cuts);
|
||||||
|
device_ = device;
|
||||||
|
CHECK(this->device_.IsCUDA());
|
||||||
|
}
|
||||||
|
[[nodiscard]] auto GetCuts() {
|
||||||
|
CHECK(cuts_);
|
||||||
|
return cuts_;
|
||||||
|
}
|
||||||
|
[[nodiscard]] auto Device() const { return device_; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename S, template <typename> typename F>
|
||||||
|
class EllpackFormatStreamPolicy : public F<S> {
|
||||||
|
std::shared_ptr<EllpackHostCache> p_cache_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using WriterT = EllpackHostCacheStream;
|
||||||
|
using ReaderT = EllpackHostCacheStream;
|
||||||
|
|
||||||
|
public:
|
||||||
|
EllpackFormatStreamPolicy();
|
||||||
|
[[nodiscard]] std::unique_ptr<WriterT> CreateWriter(StringView name, std::uint32_t iter);
|
||||||
|
|
||||||
|
[[nodiscard]] std::unique_ptr<ReaderT> CreateReader(StringView name, bst_idx_t offset,
|
||||||
|
bst_idx_t length) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
class EllpackPageSourceImpl : public PageSourceIncMixIn<EllpackPage, F> {
|
||||||
|
using Super = PageSourceIncMixIn<EllpackPage, F>;
|
||||||
bool is_dense_;
|
bool is_dense_;
|
||||||
bst_idx_t row_stride_;
|
bst_idx_t row_stride_;
|
||||||
BatchParam param_;
|
BatchParam param_;
|
||||||
common::Span<FeatureType const> feature_types_;
|
common::Span<FeatureType const> feature_types_;
|
||||||
std::shared_ptr<common::HistogramCuts const> cuts_;
|
|
||||||
DeviceOrd device_;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
[[nodiscard]] SparsePageFormat<EllpackPage>* CreatePageFormat() const override {
|
|
||||||
cuts_->SetDevice(this->device_);
|
|
||||||
return new EllpackPageRawFormat{cuts_};
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
EllpackPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
EllpackPageSourceImpl(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
||||||
size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
|
std::size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
|
||||||
std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
|
std::shared_ptr<common::HistogramCuts> cuts, bool is_dense,
|
||||||
bst_idx_t row_stride, common::Span<FeatureType const> feature_types,
|
bst_idx_t row_stride, common::Span<FeatureType const> feature_types,
|
||||||
std::shared_ptr<SparsePageSource> source, DeviceOrd device)
|
std::shared_ptr<SparsePageSource> source, DeviceOrd device)
|
||||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
|
: Super{missing, nthreads, n_features, n_batches, cache, false},
|
||||||
is_dense_{is_dense},
|
is_dense_{is_dense},
|
||||||
row_stride_{row_stride},
|
row_stride_{row_stride},
|
||||||
param_{std::move(param)},
|
param_{std::move(param)},
|
||||||
feature_types_{feature_types},
|
feature_types_{feature_types} {
|
||||||
cuts_{std::move(cuts)},
|
|
||||||
device_{device} {
|
|
||||||
this->source_ = source;
|
this->source_ = source;
|
||||||
|
cuts->SetDevice(device);
|
||||||
|
this->SetCuts(std::move(cuts), device);
|
||||||
this->Fetch();
|
this->Fetch();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Fetch() final;
|
void Fetch() final;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Cache to host
|
||||||
|
using EllpackPageHostSource =
|
||||||
|
EllpackPageSourceImpl<EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
|
||||||
|
|
||||||
|
// Cache to disk
|
||||||
|
using EllpackPageSource =
|
||||||
|
EllpackPageSourceImpl<DefaultFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
|
||||||
|
|
||||||
#if !defined(XGBOOST_USE_CUDA)
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
inline void EllpackPageSource::Fetch() {
|
template <typename F>
|
||||||
|
inline void EllpackPageSourceImpl<F>::Fetch() {
|
||||||
// silent the warning about unused variables.
|
// silent the warning about unused variables.
|
||||||
(void)(row_stride_);
|
(void)(row_stride_);
|
||||||
(void)(is_dense_);
|
(void)(is_dense_);
|
||||||
(void)(device_);
|
|
||||||
common::AssertGPUSupport();
|
common::AssertGPUSupport();
|
||||||
}
|
}
|
||||||
#endif // !defined(XGBOOST_USE_CUDA)
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
|
|||||||
@ -17,20 +17,35 @@
|
|||||||
#include "xgboost/data.h" // for BatchParam, FeatureType
|
#include "xgboost/data.h" // for BatchParam, FeatureType
|
||||||
#include "xgboost/span.h" // for Span
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::data {
|
||||||
namespace data {
|
/**
|
||||||
class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
|
* @brief Policy for creating ghist index format. The storage is default (disk).
|
||||||
|
*/
|
||||||
|
template <typename S>
|
||||||
|
class GHistIndexFormatPolicy {
|
||||||
|
protected:
|
||||||
common::HistogramCuts cuts_;
|
common::HistogramCuts cuts_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using FormatT = SparsePageFormat<GHistIndexMatrix>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
[[nodiscard]] auto CreatePageFormat() const {
|
||||||
|
std::unique_ptr<FormatT> fmt{new GHistIndexRawFormat{cuts_}};
|
||||||
|
return fmt;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetCuts(common::HistogramCuts cuts) { std::swap(cuts_, cuts); }
|
||||||
|
};
|
||||||
|
|
||||||
|
class GradientIndexPageSource
|
||||||
|
: public PageSourceIncMixIn<
|
||||||
|
GHistIndexMatrix, DefaultFormatStreamPolicy<GHistIndexMatrix, GHistIndexFormatPolicy>> {
|
||||||
bool is_dense_;
|
bool is_dense_;
|
||||||
std::int32_t max_bin_per_feat_;
|
std::int32_t max_bin_per_feat_;
|
||||||
common::Span<FeatureType const> feature_types_;
|
common::Span<FeatureType const> feature_types_;
|
||||||
double sparse_thresh_;
|
double sparse_thresh_;
|
||||||
|
|
||||||
protected:
|
|
||||||
[[nodiscard]] SparsePageFormat<GHistIndexMatrix>* CreatePageFormat() const override {
|
|
||||||
return new GHistIndexRawFormat{cuts_};
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
GradientIndexPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
GradientIndexPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
||||||
size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
|
size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
|
||||||
@ -39,17 +54,16 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
|
|||||||
std::shared_ptr<SparsePageSource> source)
|
std::shared_ptr<SparsePageSource> source)
|
||||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
|
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
|
||||||
std::isnan(param.sparse_thresh)),
|
std::isnan(param.sparse_thresh)),
|
||||||
cuts_{std::move(cuts)},
|
|
||||||
is_dense_{is_dense},
|
is_dense_{is_dense},
|
||||||
max_bin_per_feat_{param.max_bin},
|
max_bin_per_feat_{param.max_bin},
|
||||||
feature_types_{feature_types},
|
feature_types_{feature_types},
|
||||||
sparse_thresh_{param.sparse_thresh} {
|
sparse_thresh_{param.sparse_thresh} {
|
||||||
this->source_ = source;
|
this->source_ = source;
|
||||||
|
this->SetCuts(std::move(cuts));
|
||||||
this->Fetch();
|
this->Fetch();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Fetch() final;
|
void Fetch() final;
|
||||||
};
|
};
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
#endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
|
#endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
|
||||||
|
|||||||
@ -38,13 +38,17 @@ std::size_t NFeaturesDevice(DMatrixProxy *) // NOLINT
|
|||||||
#endif
|
#endif
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
|
|
||||||
SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy_handle,
|
SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy_handle,
|
||||||
DataIterResetCallback *reset,
|
DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
|
||||||
XGDMatrixCallbackNext *next, float missing,
|
float missing, int32_t nthreads, std::string cache_prefix,
|
||||||
int32_t nthreads, std::string cache_prefix)
|
bool on_host)
|
||||||
: proxy_{proxy_handle}, iter_{iter_handle}, reset_{reset}, next_{next}, missing_{missing},
|
: proxy_{proxy_handle},
|
||||||
cache_prefix_{std::move(cache_prefix)} {
|
iter_{iter_handle},
|
||||||
|
reset_{reset},
|
||||||
|
next_{next},
|
||||||
|
missing_{missing},
|
||||||
|
cache_prefix_{std::move(cache_prefix)},
|
||||||
|
on_host_{on_host} {
|
||||||
Context ctx;
|
Context ctx;
|
||||||
ctx.nthread = nthreads;
|
ctx.nthread = nthreads;
|
||||||
|
|
||||||
@ -103,8 +107,26 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
|||||||
fmat_ctx_ = ctx;
|
fmat_ctx_ = ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SparsePageDMatrix::~SparsePageDMatrix() {
|
||||||
|
// Clear out all resources before deleting the cache file.
|
||||||
|
sparse_page_source_.reset();
|
||||||
|
std::visit([](auto &&ptr) { ptr.reset(); }, ellpack_page_source_);
|
||||||
|
column_source_.reset();
|
||||||
|
sorted_column_source_.reset();
|
||||||
|
ghist_index_source_.reset();
|
||||||
|
|
||||||
|
for (auto const &kv : cache_info_) {
|
||||||
|
CHECK(kv.second);
|
||||||
|
auto n = kv.second->ShardName();
|
||||||
|
if (kv.second->OnHost()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
TryDeleteCacheFile(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
|
void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
|
||||||
auto id = MakeCache(this, ".row.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".row.page", false, cache_prefix_, &cache_info_);
|
||||||
// Don't use proxy DMatrix once this is already initialized, this allows users to
|
// Don't use proxy DMatrix once this is already initialized, this allows users to
|
||||||
// release the iterator and data.
|
// release the iterator and data.
|
||||||
if (cache_info_.at(id)->written) {
|
if (cache_info_.at(id)->written) {
|
||||||
@ -132,8 +154,9 @@ BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
|
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
|
||||||
auto id = MakeCache(this, ".col.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".col.page", on_host_, cache_prefix_, &cache_info_);
|
||||||
CHECK_NE(this->Info().num_col_, 0);
|
CHECK_NE(this->Info().num_col_, 0);
|
||||||
|
error::NoOnHost(on_host_);
|
||||||
this->InitializeSparsePage(ctx);
|
this->InitializeSparsePage(ctx);
|
||||||
if (!column_source_) {
|
if (!column_source_) {
|
||||||
column_source_ =
|
column_source_ =
|
||||||
@ -146,8 +169,9 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) {
|
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) {
|
||||||
auto id = MakeCache(this, ".sorted.col.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".sorted.col.page", on_host_, cache_prefix_, &cache_info_);
|
||||||
CHECK_NE(this->Info().num_col_, 0);
|
CHECK_NE(this->Info().num_col_, 0);
|
||||||
|
error::NoOnHost(on_host_);
|
||||||
this->InitializeSparsePage(ctx);
|
this->InitializeSparsePage(ctx);
|
||||||
if (!sorted_column_source_) {
|
if (!sorted_column_source_) {
|
||||||
sorted_column_source_ = std::make_shared<SortedCSCPageSource>(
|
sorted_column_source_ = std::make_shared<SortedCSCPageSource>(
|
||||||
@ -165,11 +189,12 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
|
|||||||
CHECK_GE(param.max_bin, 2);
|
CHECK_GE(param.max_bin, 2);
|
||||||
}
|
}
|
||||||
detail::CheckEmpty(batch_param_, param);
|
detail::CheckEmpty(batch_param_, param);
|
||||||
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
error::NoOnHost(on_host_);
|
||||||
|
auto id = MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_);
|
||||||
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
||||||
this->InitializeSparsePage(ctx);
|
this->InitializeSparsePage(ctx);
|
||||||
cache_info_.erase(id);
|
cache_info_.erase(id);
|
||||||
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_);
|
||||||
LOG(INFO) << "Generating new Gradient Index.";
|
LOG(INFO) << "Generating new Gradient Index.";
|
||||||
// Use sorted sketch for approx.
|
// Use sorted sketch for approx.
|
||||||
auto sorted_sketch = param.regen;
|
auto sorted_sketch = param.regen;
|
||||||
@ -193,7 +218,7 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
|
|||||||
#if !defined(XGBOOST_USE_CUDA)
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const *, const BatchParam &) {
|
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const *, const BatchParam &) {
|
||||||
common::AssertGPUSupport();
|
common::AssertGPUSupport();
|
||||||
return BatchSet{BatchIterator<EllpackPage>{this->ellpack_page_source_}};
|
return BatchSet{BatchIterator<EllpackPage>{nullptr}};
|
||||||
}
|
}
|
||||||
#endif // !defined(XGBOOST_USE_CUDA)
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
* Copyright 2021-2024, XGBoost contributors
|
* Copyright 2021-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <memory> // for shared_ptr
|
#include <memory> // for shared_ptr
|
||||||
|
#include <utility> // for move
|
||||||
|
#include <variant> // for visit
|
||||||
|
|
||||||
#include "../common/hist_util.cuh"
|
#include "../common/hist_util.cuh"
|
||||||
#include "../common/hist_util.h" // for HistogramCuts
|
#include "../common/hist_util.h" // for HistogramCuts
|
||||||
@ -19,13 +21,15 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
|||||||
CHECK_GE(param.max_bin, 2);
|
CHECK_GE(param.max_bin, 2);
|
||||||
}
|
}
|
||||||
detail::CheckEmpty(batch_param_, param);
|
detail::CheckEmpty(batch_param_, param);
|
||||||
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_);
|
||||||
size_t row_stride = 0;
|
|
||||||
|
bst_idx_t row_stride = 0;
|
||||||
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
||||||
this->InitializeSparsePage(ctx);
|
this->InitializeSparsePage(ctx);
|
||||||
// reinitialize the cache
|
// reinitialize the cache
|
||||||
cache_info_.erase(id);
|
cache_info_.erase(id);
|
||||||
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_);
|
||||||
|
LOG(INFO) << "Generating new a Ellpack page.";
|
||||||
std::shared_ptr<common::HistogramCuts> cuts;
|
std::shared_ptr<common::HistogramCuts> cuts;
|
||||||
if (!param.hess.empty()) {
|
if (!param.hess.empty()) {
|
||||||
cuts = std::make_shared<common::HistogramCuts>(
|
cuts = std::make_shared<common::HistogramCuts>(
|
||||||
@ -41,17 +45,28 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
|||||||
CHECK_NE(row_stride, 0);
|
CHECK_NE(row_stride, 0);
|
||||||
batch_param_ = param;
|
batch_param_ = param;
|
||||||
|
|
||||||
auto ft = this->info_.feature_types.ConstDeviceSpan();
|
auto ft = this->Info().feature_types.ConstDeviceSpan();
|
||||||
ellpack_page_source_.reset(); // make sure resource is released before making new ones.
|
if (on_host_ && std::get_if<EllpackHostPtr>(&ellpack_page_source_) == nullptr) {
|
||||||
ellpack_page_source_ = std::make_shared<EllpackPageSource>(
|
ellpack_page_source_.emplace<EllpackHostPtr>(nullptr);
|
||||||
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
|
}
|
||||||
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_,
|
std::visit(
|
||||||
ctx->Device());
|
[&](auto&& ptr) {
|
||||||
|
ptr.reset(); // make sure resource is released before making new ones.
|
||||||
|
using SourceT = typename std::remove_reference_t<decltype(ptr)>::element_type;
|
||||||
|
ptr = std::make_shared<SourceT>(this->missing_, ctx->Threads(), this->Info().num_col_,
|
||||||
|
this->n_batches_, cache_info_.at(id), param,
|
||||||
|
std::move(cuts), this->IsDense(), row_stride, ft,
|
||||||
|
this->sparse_page_source_, ctx->Device());
|
||||||
|
},
|
||||||
|
ellpack_page_source_);
|
||||||
} else {
|
} else {
|
||||||
CHECK(sparse_page_source_);
|
CHECK(sparse_page_source_);
|
||||||
ellpack_page_source_->Reset();
|
std::visit([&](auto&& ptr) { ptr->Reset(); }, this->ellpack_page_source_);
|
||||||
}
|
}
|
||||||
|
|
||||||
return BatchSet{BatchIterator<EllpackPage>{this->ellpack_page_source_}};
|
auto batch_set =
|
||||||
|
std::visit([this](auto&& ptr) { return BatchSet{BatchIterator<EllpackPage>{ptr}}; },
|
||||||
|
this->ellpack_page_source_);
|
||||||
|
return batch_set;
|
||||||
}
|
}
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -7,16 +7,20 @@
|
|||||||
#ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
#ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
||||||
#define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
#define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
||||||
|
|
||||||
#include <map>
|
#include <cstdint> // for uint32_t, int32_t
|
||||||
#include <memory>
|
#include <map> // for map
|
||||||
#include <string>
|
#include <memory> // for shared_ptr
|
||||||
#include <utility>
|
#include <sstream> // for stringstream
|
||||||
|
#include <string> // for string
|
||||||
|
#include <variant> // for variant, visit
|
||||||
|
|
||||||
#include "ellpack_page_source.h"
|
#include "ellpack_page_source.h" // for EllpackPageSource, EllpackPageHostSource
|
||||||
#include "gradient_index_page_source.h"
|
#include "gradient_index_page_source.h" // for GradientIndexPageSource
|
||||||
#include "sparse_page_source.h"
|
#include "sparse_page_source.h" // for SparsePageSource, Cache
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/context.h" // for Context
|
||||||
|
#include "xgboost/data.h" // for DMatrix, MetaInfo
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
/**
|
/**
|
||||||
@ -70,6 +74,7 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
float missing_;
|
float missing_;
|
||||||
Context fmat_ctx_;
|
Context fmat_ctx_;
|
||||||
std::string cache_prefix_;
|
std::string cache_prefix_;
|
||||||
|
bool on_host_{false};
|
||||||
std::uint32_t n_batches_{0};
|
std::uint32_t n_batches_{0};
|
||||||
// sparse page is the source to other page types, we make a special member function.
|
// sparse page is the source to other page types, we make a special member function.
|
||||||
void InitializeSparsePage(Context const *ctx);
|
void InitializeSparsePage(Context const *ctx);
|
||||||
@ -79,29 +84,16 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
public:
|
public:
|
||||||
explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
|
explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
|
||||||
XGDMatrixCallbackNext *next, float missing, int32_t nthreads,
|
XGDMatrixCallbackNext *next, float missing, int32_t nthreads,
|
||||||
std::string cache_prefix);
|
std::string cache_prefix, bool on_host = false);
|
||||||
|
|
||||||
~SparsePageDMatrix() override {
|
~SparsePageDMatrix() override;
|
||||||
// Clear out all resources before deleting the cache file.
|
|
||||||
sparse_page_source_.reset();
|
|
||||||
ellpack_page_source_.reset();
|
|
||||||
column_source_.reset();
|
|
||||||
sorted_column_source_.reset();
|
|
||||||
ghist_index_source_.reset();
|
|
||||||
|
|
||||||
for (auto const &kv : cache_info_) {
|
|
||||||
CHECK(kv.second);
|
|
||||||
auto n = kv.second->ShardName();
|
|
||||||
TryDeleteCacheFile(n);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
[[nodiscard]] MetaInfo &Info() override;
|
[[nodiscard]] MetaInfo &Info() override;
|
||||||
[[nodiscard]] const MetaInfo &Info() const override;
|
[[nodiscard]] const MetaInfo &Info() const override;
|
||||||
[[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
|
[[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
|
||||||
// The only DMatrix implementation that returns false.
|
// The only DMatrix implementation that returns false.
|
||||||
[[nodiscard]] bool SingleColBlock() const override { return false; }
|
[[nodiscard]] bool SingleColBlock() const override { return false; }
|
||||||
DMatrix *Slice(common::Span<int32_t const>) override {
|
DMatrix *Slice(common::Span<std::int32_t const>) override {
|
||||||
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
|
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -111,7 +103,7 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] bool EllpackExists() const override {
|
[[nodiscard]] bool EllpackExists() const override {
|
||||||
return static_cast<bool>(ellpack_page_source_);
|
return std::visit([](auto &&ptr) { return static_cast<bool>(ptr); }, ellpack_page_source_);
|
||||||
}
|
}
|
||||||
[[nodiscard]] bool GHistIndexExists() const override {
|
[[nodiscard]] bool GHistIndexExists() const override {
|
||||||
return static_cast<bool>(ghist_index_source_);
|
return static_cast<bool>(ghist_index_source_);
|
||||||
@ -138,7 +130,9 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
private:
|
private:
|
||||||
// source data pointers.
|
// source data pointers.
|
||||||
std::shared_ptr<SparsePageSource> sparse_page_source_;
|
std::shared_ptr<SparsePageSource> sparse_page_source_;
|
||||||
std::shared_ptr<EllpackPageSource> ellpack_page_source_;
|
using EllpackDiskPtr = std::shared_ptr<EllpackPageSource>;
|
||||||
|
using EllpackHostPtr = std::shared_ptr<EllpackPageHostSource>;
|
||||||
|
std::variant<EllpackDiskPtr, EllpackHostPtr> ellpack_page_source_;
|
||||||
std::shared_ptr<CSCPageSource> column_source_;
|
std::shared_ptr<CSCPageSource> column_source_;
|
||||||
std::shared_ptr<SortedCSCPageSource> sorted_column_source_;
|
std::shared_ptr<SortedCSCPageSource> sorted_column_source_;
|
||||||
std::shared_ptr<GradientIndexPageSource> ghist_index_source_;
|
std::shared_ptr<GradientIndexPageSource> ghist_index_source_;
|
||||||
@ -153,15 +147,16 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
/**
|
/**
|
||||||
* @brief Make cache if it doesn't exist yet.
|
* @brief Make cache if it doesn't exist yet.
|
||||||
*/
|
*/
|
||||||
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix,
|
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, bool on_host,
|
||||||
|
std::string prefix,
|
||||||
std::map<std::string, std::shared_ptr<Cache>> *out) {
|
std::map<std::string, std::shared_ptr<Cache>> *out) {
|
||||||
auto &cache_info = *out;
|
auto &cache_info = *out;
|
||||||
auto name = MakeId(prefix, ptr);
|
auto name = MakeId(prefix, ptr);
|
||||||
auto id = name + format;
|
auto id = name + format;
|
||||||
auto it = cache_info.find(id);
|
auto it = cache_info.find(id);
|
||||||
if (it == cache_info.cend()) {
|
if (it == cache_info.cend()) {
|
||||||
cache_info[id].reset(new Cache{false, name, format});
|
cache_info[id].reset(new Cache{false, name, format, on_host});
|
||||||
LOG(INFO) << "Make cache:" << cache_info[id]->ShardName() << std::endl;
|
LOG(INFO) << "Make cache:" << cache_info[id]->ShardName();
|
||||||
}
|
}
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|||||||
30
src/data/sparse_page_source.cc
Normal file
30
src/data/sparse_page_source.cc
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2021-2024, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include "sparse_page_source.h"
|
||||||
|
|
||||||
|
#include <filesystem> // for exists
|
||||||
|
#include <string> // for string
|
||||||
|
#include <cstdio> // for remove
|
||||||
|
#include <numeric> // for partial_sum
|
||||||
|
|
||||||
|
namespace xgboost::data {
|
||||||
|
void Cache::Commit() {
|
||||||
|
if (!written) {
|
||||||
|
std::partial_sum(offset.begin(), offset.end(), offset.begin());
|
||||||
|
written = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TryDeleteCacheFile(const std::string& file) {
|
||||||
|
// Don't throw, this is called in a destructor.
|
||||||
|
auto exists = std::filesystem::exists(file);
|
||||||
|
if (!exists) {
|
||||||
|
LOG(WARNING) << "External memory cache file " << file << " is missing.";
|
||||||
|
}
|
||||||
|
if (std::remove(file.c_str()) != 0) {
|
||||||
|
LOG(WARNING) << "Couldn't remove external memory cache file " << file
|
||||||
|
<< "; you may want to remove it manually";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace xgboost::data
|
||||||
@ -8,11 +8,9 @@
|
|||||||
#include <algorithm> // for min
|
#include <algorithm> // for min
|
||||||
#include <atomic> // for atomic
|
#include <atomic> // for atomic
|
||||||
#include <cstdint> // for uint64_t
|
#include <cstdint> // for uint64_t
|
||||||
#include <cstdio> // for remove
|
|
||||||
#include <future> // for future
|
#include <future> // for future
|
||||||
#include <memory> // for unique_ptr
|
#include <memory> // for unique_ptr
|
||||||
#include <mutex> // for mutex
|
#include <mutex> // for mutex
|
||||||
#include <numeric> // for partial_sum
|
|
||||||
#include <string> // for string
|
#include <string> // for string
|
||||||
#include <utility> // for pair, move
|
#include <utility> // for pair, move
|
||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
@ -27,18 +25,12 @@
|
|||||||
#include "proxy_dmatrix.h" // for DMatrixProxy
|
#include "proxy_dmatrix.h" // for DMatrixProxy
|
||||||
#include "sparse_page_writer.h" // for SparsePageFormat
|
#include "sparse_page_writer.h" // for SparsePageFormat
|
||||||
#include "xgboost/base.h" // for bst_feature_t
|
#include "xgboost/base.h" // for bst_feature_t
|
||||||
#include "xgboost/data.h" // for SparsePage, CSCPage
|
#include "xgboost/data.h" // for SparsePage, CSCPage, SortedCSCPage
|
||||||
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
|
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
|
||||||
#include "xgboost/logging.h" // for CHECK_EQ
|
#include "xgboost/logging.h" // for CHECK_EQ
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
inline void TryDeleteCacheFile(const std::string& file) {
|
void TryDeleteCacheFile(const std::string& file);
|
||||||
if (std::remove(file.c_str()) != 0) {
|
|
||||||
// Don't throw, this is called in a destructor.
|
|
||||||
LOG(WARNING) << "Couldn't remove external memory cache file " << file
|
|
||||||
<< "; you may want to remove it manually";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Information about the cache including path and page offsets.
|
* @brief Information about the cache including path and page offsets.
|
||||||
@ -46,13 +38,14 @@ inline void TryDeleteCacheFile(const std::string& file) {
|
|||||||
struct Cache {
|
struct Cache {
|
||||||
// whether the write to the cache is complete
|
// whether the write to the cache is complete
|
||||||
bool written;
|
bool written;
|
||||||
|
bool on_host;
|
||||||
std::string name;
|
std::string name;
|
||||||
std::string format;
|
std::string format;
|
||||||
// offset into binary cache file.
|
// offset into binary cache file.
|
||||||
std::vector<std::uint64_t> offset;
|
std::vector<std::uint64_t> offset;
|
||||||
|
|
||||||
Cache(bool w, std::string n, std::string fmt)
|
Cache(bool w, std::string n, std::string fmt, bool on_host)
|
||||||
: written{w}, name{std::move(n)}, format{std::move(fmt)} {
|
: written{w}, on_host{on_host}, name{std::move(n)}, format{std::move(fmt)} {
|
||||||
offset.push_back(0);
|
offset.push_back(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,6 +57,7 @@ struct Cache {
|
|||||||
[[nodiscard]] std::string ShardName() const {
|
[[nodiscard]] std::string ShardName() const {
|
||||||
return ShardName(this->name, this->format);
|
return ShardName(this->name, this->format);
|
||||||
}
|
}
|
||||||
|
[[nodiscard]] bool OnHost() const { return on_host; }
|
||||||
/**
|
/**
|
||||||
* @brief Record a page with size of n_bytes.
|
* @brief Record a page with size of n_bytes.
|
||||||
*/
|
*/
|
||||||
@ -83,12 +77,7 @@ struct Cache {
|
|||||||
/**
|
/**
|
||||||
* @brief Call this once the write for the cache is complete.
|
* @brief Call this once the write for the cache is complete.
|
||||||
*/
|
*/
|
||||||
void Commit() {
|
void Commit();
|
||||||
if (!written) {
|
|
||||||
std::partial_sum(offset.begin(), offset.end(), offset.begin());
|
|
||||||
written = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Prevents multi-threaded call to `GetBatches`.
|
// Prevents multi-threaded call to `GetBatches`.
|
||||||
@ -146,10 +135,59 @@ class ExceHandler {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
|
* @brief Default implementation of the stream creater.
|
||||||
|
*/
|
||||||
|
template <typename S, template <typename> typename F>
|
||||||
|
class DefaultFormatStreamPolicy : public F<S> {
|
||||||
|
public:
|
||||||
|
using WriterT = common::AlignedFileWriteStream;
|
||||||
|
using ReaderT = common::AlignedResourceReadStream;
|
||||||
|
|
||||||
|
public:
|
||||||
|
std::unique_ptr<WriterT> CreateWriter(StringView name, std::uint32_t iter) {
|
||||||
|
std::unique_ptr<common::AlignedFileWriteStream> fo;
|
||||||
|
if (iter == 0) {
|
||||||
|
fo = std::make_unique<common::AlignedFileWriteStream>(name, "wb");
|
||||||
|
} else {
|
||||||
|
fo = std::make_unique<common::AlignedFileWriteStream>(name, "ab");
|
||||||
|
}
|
||||||
|
return fo;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ReaderT> CreateReader(StringView name, std::uint64_t offset,
|
||||||
|
std::uint64_t length) const {
|
||||||
|
return std::make_unique<common::PrivateMmapConstStream>(std::string{name}, offset, length);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Default implementatioin of the format creator.
|
||||||
*/
|
*/
|
||||||
template <typename S>
|
template <typename S>
|
||||||
class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
class DefaultFormatPolicy {
|
||||||
|
public:
|
||||||
|
using FormatT = SparsePageFormat<S>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
auto CreatePageFormat() const {
|
||||||
|
std::unique_ptr<FormatT> fmt{::xgboost::data::CreatePageFormat<S>("raw")};
|
||||||
|
return fmt;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
|
||||||
|
*
|
||||||
|
* The interface to external storage is divided into two types. The first one is the
|
||||||
|
* format, representing how to read and write the binary. The second part is where to
|
||||||
|
* store the binary cache. These policies are implemented in the `FormatStreamPolicy`
|
||||||
|
* policy class. The format policy controls how to create the format (the first part), and
|
||||||
|
* the stream policy decides where the stream should read from and write to (the second
|
||||||
|
* part). This way we can compose the polices and page types with ease.
|
||||||
|
*/
|
||||||
|
template <typename S,
|
||||||
|
typename FormatStreamPolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>>
|
||||||
|
class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPolicy {
|
||||||
protected:
|
protected:
|
||||||
// Prevents calling this iterator from multiple places(or threads).
|
// Prevents calling this iterator from multiple places(or threads).
|
||||||
std::mutex single_threaded_;
|
std::mutex single_threaded_;
|
||||||
@ -165,7 +203,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
// Index to the current page.
|
// Index to the current page.
|
||||||
std::uint32_t count_{0};
|
std::uint32_t count_{0};
|
||||||
// Total number of batches.
|
// Total number of batches.
|
||||||
std::uint32_t n_batches_{0};
|
bst_idx_t n_batches_{0};
|
||||||
|
|
||||||
std::shared_ptr<Cache> cache_info_;
|
std::shared_ptr<Cache> cache_info_;
|
||||||
|
|
||||||
@ -179,10 +217,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
ExceHandler exce_;
|
ExceHandler exce_;
|
||||||
common::Monitor monitor_;
|
common::Monitor monitor_;
|
||||||
|
|
||||||
[[nodiscard]] virtual SparsePageFormat<S>* CreatePageFormat() const {
|
|
||||||
return ::xgboost::data::CreatePageFormat<S>("raw");
|
|
||||||
}
|
|
||||||
|
|
||||||
[[nodiscard]] bool ReadCache() {
|
[[nodiscard]] bool ReadCache() {
|
||||||
CHECK(!at_end_);
|
CHECK(!at_end_);
|
||||||
if (!cache_info_->written) {
|
if (!cache_info_->written) {
|
||||||
@ -196,8 +230,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
std::int32_t kPrefetches = 3;
|
std::int32_t kPrefetches = 3;
|
||||||
std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
|
std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
|
||||||
n_prefetches = std::max(n_prefetches, 1);
|
n_prefetches = std::max(n_prefetches, 1);
|
||||||
std::int32_t n_prefetch_batches =
|
std::int32_t n_prefetch_batches = std::min(static_cast<bst_idx_t>(n_prefetches), n_batches_);
|
||||||
std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
|
|
||||||
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
|
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
|
||||||
CHECK_LE(n_prefetch_batches, kPrefetches);
|
CHECK_LE(n_prefetch_batches, kPrefetches);
|
||||||
std::size_t fetch_it = count_;
|
std::size_t fetch_it = count_;
|
||||||
@ -216,10 +249,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
*GlobalConfigThreadLocalStore::Get() = config;
|
*GlobalConfigThreadLocalStore::Get() = config;
|
||||||
auto page = std::make_shared<S>();
|
auto page = std::make_shared<S>();
|
||||||
this->exce_.Run([&] {
|
this->exce_.Run([&] {
|
||||||
std::unique_ptr<SparsePageFormat<S>> fmt{this->CreatePageFormat()};
|
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{this->CreatePageFormat()};
|
||||||
auto name = self->cache_info_->ShardName();
|
auto name = self->cache_info_->ShardName();
|
||||||
auto [offset, length] = self->cache_info_->View(fetch_it);
|
auto [offset, length] = self->cache_info_->View(fetch_it);
|
||||||
auto fi = std::make_unique<common::PrivateMmapConstStream>(name, offset, length);
|
std::unique_ptr<typename FormatStreamPolicy::ReaderT> fi{
|
||||||
|
this->CreateReader(name, offset, length)};
|
||||||
CHECK(fmt->Read(page.get(), fi.get()));
|
CHECK(fmt->Read(page.get(), fi.get()));
|
||||||
});
|
});
|
||||||
return page;
|
return page;
|
||||||
@ -243,16 +277,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
CHECK(!cache_info_->written);
|
CHECK(!cache_info_->written);
|
||||||
common::Timer timer;
|
common::Timer timer;
|
||||||
timer.Start();
|
timer.Start();
|
||||||
std::unique_ptr<SparsePageFormat<S>> fmt{this->CreatePageFormat()};
|
auto fmt{this->CreatePageFormat()};
|
||||||
|
|
||||||
auto name = cache_info_->ShardName();
|
auto name = cache_info_->ShardName();
|
||||||
std::unique_ptr<common::AlignedFileWriteStream> fo;
|
std::unique_ptr<typename FormatStreamPolicy::WriterT> fo{
|
||||||
if (this->Iter() == 0) {
|
this->CreateWriter(StringView{name}, this->Iter())};
|
||||||
fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "wb");
|
|
||||||
} else {
|
|
||||||
fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "ab");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto bytes = fmt->Write(*page_, fo.get());
|
auto bytes = fmt->Write(*page_, fo.get());
|
||||||
|
|
||||||
timer.Stop();
|
timer.Stop();
|
||||||
@ -265,9 +294,9 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
virtual void Fetch() = 0;
|
virtual void Fetch() = 0;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
|
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
|
||||||
std::shared_ptr<Cache> cache)
|
std::shared_ptr<Cache> cache)
|
||||||
: workers_{nthreads},
|
: workers_{std::max(2, std::min(nthreads, 16))}, // Don't use too many threads.
|
||||||
missing_{missing},
|
missing_{missing},
|
||||||
nthreads_{nthreads},
|
nthreads_{nthreads},
|
||||||
n_features_{n_features},
|
n_features_{n_features},
|
||||||
@ -403,18 +432,19 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// A mixin for advancing the iterator.
|
// A mixin for advancing the iterator.
|
||||||
template <typename S>
|
template <typename S,
|
||||||
class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
|
typename FormatCreatePolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>>
|
||||||
|
class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
|
||||||
protected:
|
protected:
|
||||||
std::shared_ptr<SparsePageSource> source_;
|
std::shared_ptr<SparsePageSource> source_;
|
||||||
using Super = SparsePageSourceImpl<S>;
|
using Super = SparsePageSourceImpl<S, FormatCreatePolicy>;
|
||||||
// synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page
|
// synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page
|
||||||
// so we avoid fetching it.
|
// so we avoid fetching it.
|
||||||
bool sync_{true};
|
bool sync_{true};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
||||||
std::uint32_t n_batches, std::shared_ptr<Cache> cache, bool sync)
|
bst_idx_t n_batches, std::shared_ptr<Cache> cache, bool sync)
|
||||||
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
|
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
|
||||||
|
|
||||||
[[nodiscard]] PageSourceIncMixIn& operator++() final {
|
[[nodiscard]] PageSourceIncMixIn& operator++() final {
|
||||||
|
|||||||
@ -234,7 +234,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
|
|||||||
// Compact the ELLPACK pages into the single sample page.
|
// Compact the ELLPACK pages into the single sample page.
|
||||||
thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||||
for (auto& batch : batch_iterator) {
|
for (auto& batch : batch_iterator) {
|
||||||
page_->Compact(ctx->Device(), batch.Impl(), dh::ToSpan(sample_row_index_));
|
page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||||
}
|
}
|
||||||
|
|
||||||
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
||||||
@ -252,7 +252,7 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
|
|||||||
auto cuctx = ctx->CUDACtx();
|
auto cuctx = ctx->CUDACtx();
|
||||||
size_t n_rows = dmat->Info().num_row_;
|
size_t n_rows = dmat->Info().num_row_;
|
||||||
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
||||||
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
||||||
|
|
||||||
auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
|
auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
|
||||||
|
|
||||||
@ -279,21 +279,18 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
|
|||||||
auto cuctx = ctx->CUDACtx();
|
auto cuctx = ctx->CUDACtx();
|
||||||
bst_idx_t n_rows = dmat->Info().num_row_;
|
bst_idx_t n_rows = dmat->Info().num_row_;
|
||||||
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
||||||
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
||||||
|
|
||||||
// Perform Poisson sampling in place.
|
// Perform Poisson sampling in place.
|
||||||
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
|
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
|
||||||
thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
|
thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
|
||||||
PoissonSampling(dh::ToSpan(threshold_), threshold_index,
|
PoissonSampling(dh::ToSpan(threshold_), threshold_index,
|
||||||
RandomWeight(common::GlobalRandom()())));
|
RandomWeight(common::GlobalRandom()())));
|
||||||
|
|
||||||
// Count the sampled rows.
|
// Count the sampled rows.
|
||||||
size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
|
size_t sample_rows =
|
||||||
|
thrust::count_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
|
||||||
// Compact gradient pairs.
|
// Compact gradient pairs.
|
||||||
gpair_.resize(sample_rows);
|
gpair_.resize(sample_rows);
|
||||||
thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
|
thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
|
||||||
|
|
||||||
// Index the sample rows.
|
// Index the sample rows.
|
||||||
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
|
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
|
||||||
IsNonZero());
|
IsNonZero());
|
||||||
@ -301,18 +298,16 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
|
|||||||
sample_row_index_.begin());
|
sample_row_index_.begin());
|
||||||
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
|
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
|
||||||
sample_row_index_.begin(), ClearEmptyRows());
|
sample_row_index_.begin(), ClearEmptyRows());
|
||||||
|
|
||||||
auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
|
auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
|
||||||
auto first_page = (*batch_iterator.begin()).Impl();
|
auto first_page = (*batch_iterator.begin()).Impl();
|
||||||
// Create a new ELLPACK page with empty rows.
|
// Create a new ELLPACK page with empty rows.
|
||||||
page_.reset(); // Release the device memory first before reallocating
|
page_.reset(); // Release the device memory first before reallocating
|
||||||
page_.reset(new EllpackPageImpl(ctx->Device(), first_page->CutsShared(), first_page->is_dense,
|
page_.reset(new EllpackPageImpl(ctx->Device(), first_page->CutsShared(), dmat->IsDense(),
|
||||||
first_page->row_stride, sample_rows));
|
first_page->row_stride, sample_rows));
|
||||||
|
|
||||||
// Compact the ELLPACK pages into the single sample page.
|
// Compact the ELLPACK pages into the single sample page.
|
||||||
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||||
for (auto& batch : batch_iterator) {
|
for (auto& batch : batch_iterator) {
|
||||||
page_->Compact(ctx->Device(), batch.Impl(), dh::ToSpan(sample_row_index_));
|
page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||||
}
|
}
|
||||||
|
|
||||||
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
||||||
@ -363,21 +358,24 @@ GradientBasedSample GradientBasedSampler::Sample(Context const* ctx,
|
|||||||
return sample;
|
return sample;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
|
size_t GradientBasedSampler::CalculateThresholdIndex(Context const* ctx,
|
||||||
|
common::Span<GradientPair> gpair,
|
||||||
common::Span<float> threshold,
|
common::Span<float> threshold,
|
||||||
common::Span<float> grad_sum,
|
common::Span<float> grad_sum,
|
||||||
size_t sample_rows) {
|
size_t sample_rows) {
|
||||||
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits<float>::max());
|
auto cuctx = ctx->CUDACtx();
|
||||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
|
thrust::fill(cuctx->CTP(), dh::tend(threshold) - 1, dh::tend(threshold),
|
||||||
CombineGradientPair());
|
std::numeric_limits<float>::max());
|
||||||
thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
|
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
|
||||||
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1,
|
CombineGradientPair{});
|
||||||
|
thrust::sort(cuctx->TP(), dh::tbegin(threshold), dh::tend(threshold) - 1);
|
||||||
|
thrust::inclusive_scan(cuctx->CTP(), dh::tbegin(threshold), dh::tend(threshold) - 1,
|
||||||
dh::tbegin(grad_sum));
|
dh::tbegin(grad_sum));
|
||||||
thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum),
|
thrust::transform(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum),
|
||||||
thrust::counting_iterator<size_t>(0), dh::tbegin(grad_sum),
|
thrust::counting_iterator<size_t>(0), dh::tbegin(grad_sum),
|
||||||
SampleRateDelta(threshold, gpair.size(), sample_rows));
|
SampleRateDelta(threshold, gpair.size(), sample_rows));
|
||||||
thrust::device_ptr<float> min =
|
thrust::device_ptr<float> min =
|
||||||
thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
|
thrust::min_element(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum));
|
||||||
return thrust::distance(dh::tbegin(grad_sum), min) + 1;
|
return thrust::distance(dh::tbegin(grad_sum), min) + 1;
|
||||||
}
|
}
|
||||||
}; // namespace tree
|
}; // namespace tree
|
||||||
|
|||||||
@ -129,9 +129,8 @@ class GradientBasedSampler {
|
|||||||
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);
|
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);
|
||||||
|
|
||||||
/*! \brief Calculate the threshold used to normalize sampling probabilities. */
|
/*! \brief Calculate the threshold used to normalize sampling probabilities. */
|
||||||
static size_t CalculateThresholdIndex(common::Span<GradientPair> gpair,
|
static size_t CalculateThresholdIndex(Context const* ctx, common::Span<GradientPair> gpair,
|
||||||
common::Span<float> threshold,
|
common::Span<float> threshold, common::Span<float> grad_sum,
|
||||||
common::Span<float> grad_sum,
|
|
||||||
size_t sample_rows);
|
size_t sample_rows);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@ -200,7 +200,7 @@ TEST(EllpackPage, Compact) {
|
|||||||
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||||
|
|
||||||
// Create an empty result page.
|
// Create an empty result page.
|
||||||
EllpackPageImpl result(FstCU(), page->CutsShared(), page->is_dense, page->row_stride,
|
EllpackPageImpl result(ctx.Device(), page->CutsShared(), page->is_dense, page->row_stride,
|
||||||
kCompactedRows);
|
kCompactedRows);
|
||||||
|
|
||||||
// Compact batch pages into the result page.
|
// Compact batch pages into the result page.
|
||||||
@ -210,7 +210,7 @@ TEST(EllpackPage, Compact) {
|
|||||||
thrust::device_vector<size_t> row_indexes_d = row_indexes_h;
|
thrust::device_vector<size_t> row_indexes_d = row_indexes_h;
|
||||||
common::Span<size_t> row_indexes_span(row_indexes_d.data().get(), kRows);
|
common::Span<size_t> row_indexes_span(row_indexes_d.data().get(), kRows);
|
||||||
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
result.Compact(FstCU(), batch.Impl(), row_indexes_span);
|
result.Compact(&ctx, batch.Impl(), row_indexes_span);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t current_row = 0;
|
size_t current_row = 0;
|
||||||
|
|||||||
@ -4,15 +4,19 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
|
|
||||||
#include "../../../src/common/io.h" // for PrivateMmapConstStream, AlignedResourceReadStream...
|
#include "../../../src/data/ellpack_page.cuh" // for EllpackPage
|
||||||
#include "../../../src/data/ellpack_page.cuh"
|
|
||||||
#include "../../../src/data/ellpack_page_raw_format.h" // for EllpackPageRawFormat
|
#include "../../../src/data/ellpack_page_raw_format.h" // for EllpackPageRawFormat
|
||||||
#include "../../../src/tree/param.h" // TrainParam
|
#include "../../../src/data/ellpack_page_source.h" // for EllpackFormatStreamPolicy
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
TEST(EllpackPageRawFormat, IO) {
|
namespace {
|
||||||
|
template <typename FormatStreamPolicy>
|
||||||
|
void TestEllpackPageRawFormat() {
|
||||||
|
FormatStreamPolicy policy;
|
||||||
|
|
||||||
Context ctx{MakeCUDACtx(0)};
|
Context ctx{MakeCUDACtx(0)};
|
||||||
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
|
||||||
@ -25,20 +29,22 @@ TEST(EllpackPageRawFormat, IO) {
|
|||||||
cuts = page.Impl()->CutsShared();
|
cuts = page.Impl()->CutsShared();
|
||||||
}
|
}
|
||||||
|
|
||||||
cuts->SetDevice(ctx.Device());
|
ASSERT_EQ(cuts->cut_values_.Device(), ctx.Device());
|
||||||
auto format = std::make_unique<EllpackPageRawFormat>(cuts);
|
ASSERT_TRUE(cuts->cut_values_.DeviceCanRead());
|
||||||
|
policy.SetCuts(cuts, ctx.Device());
|
||||||
|
|
||||||
|
std::unique_ptr<EllpackPageRawFormat> format{policy.CreatePageFormat()};
|
||||||
|
|
||||||
std::size_t n_bytes{0};
|
std::size_t n_bytes{0};
|
||||||
{
|
{
|
||||||
auto fo = std::make_unique<common::AlignedFileWriteStream>(StringView{path}, "wb");
|
auto fo = policy.CreateWriter(StringView{path}, 0);
|
||||||
for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
|
for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
n_bytes += format->Write(ellpack, fo.get());
|
n_bytes += format->Write(ellpack, fo.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EllpackPage page;
|
EllpackPage page;
|
||||||
std::unique_ptr<common::AlignedResourceReadStream> fi{
|
auto fi = policy.CreateReader(StringView{path}, static_cast<bst_idx_t>(0), n_bytes);
|
||||||
std::make_unique<common::PrivateMmapConstStream>(path.c_str(), 0, n_bytes)};
|
|
||||||
ASSERT_TRUE(format->Read(&page, fi.get()));
|
ASSERT_TRUE(format->Read(&page, fi.get()));
|
||||||
|
|
||||||
for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
|
for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
@ -52,4 +58,13 @@ TEST(EllpackPageRawFormat, IO) {
|
|||||||
ASSERT_EQ(loaded->gidx_buffer.HostVector(), orig->gidx_buffer.HostVector());
|
ASSERT_EQ(loaded->gidx_buffer.HostVector(), orig->gidx_buffer.HostVector());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
TEST(EllpackPageRawFormat, DiskIO) {
|
||||||
|
TestEllpackPageRawFormat<DefaultFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(EllpackPageRawFormat, HostIO) {
|
||||||
|
TestEllpackPageRawFormat<EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>();
|
||||||
|
}
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2019-2023 by XGBoost Contributors
|
* Copyright 2019-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <xgboost/data.h> // for DMatrix
|
#include <xgboost/data.h> // for DMatrix
|
||||||
|
|
||||||
@ -29,13 +29,9 @@ TEST(SparsePageDMatrix, EllpackPage) {
|
|||||||
EXPECT_EQ(n, dmat->Info().num_row_);
|
EXPECT_EQ(n, dmat->Info().num_row_);
|
||||||
|
|
||||||
auto path =
|
auto path =
|
||||||
data::MakeId(tmp_file + ".cache",
|
data::MakeId(tmp_file + ".cache", dynamic_cast<data::SparsePageDMatrix*>(dmat)) + ".row.page";
|
||||||
dynamic_cast<data::SparsePageDMatrix *>(dmat)) +
|
|
||||||
".row.page";
|
|
||||||
EXPECT_TRUE(FileExists(path));
|
EXPECT_TRUE(FileExists(path));
|
||||||
path =
|
path = data::MakeId(tmp_file + ".cache", dynamic_cast<data::SparsePageDMatrix*>(dmat)) +
|
||||||
data::MakeId(tmp_file + ".cache",
|
|
||||||
dynamic_cast<data::SparsePageDMatrix *>(dmat)) +
|
|
||||||
".ellpack.page";
|
".ellpack.page";
|
||||||
EXPECT_TRUE(FileExists(path));
|
EXPECT_TRUE(FileExists(path));
|
||||||
|
|
||||||
@ -82,8 +78,8 @@ TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
|||||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, filename);
|
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, filename);
|
||||||
|
|
||||||
// Loop over the batches and count the records
|
// Loop over the batches and count the records
|
||||||
int64_t batch_count = 0;
|
std::int64_t batch_count = 0;
|
||||||
int64_t row_count = 0;
|
bst_idx_t row_count = 0;
|
||||||
for (const auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
for (const auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
||||||
batch_count++;
|
batch_count++;
|
||||||
@ -138,31 +134,40 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, EllpackPageContent) {
|
namespace {
|
||||||
|
// Test comparing external DMatrix with in-core DMatrix
|
||||||
|
class TestEllpackPageExt : public ::testing::TestWithParam<std::tuple<bool, bool>> {
|
||||||
|
protected:
|
||||||
|
void Run(bool on_host, bool is_dense) {
|
||||||
|
float sparsity = is_dense ? 0.0 : 0.2;
|
||||||
|
|
||||||
auto ctx = MakeCUDACtx(0);
|
auto ctx = MakeCUDACtx(0);
|
||||||
constexpr size_t kRows = 6;
|
constexpr bst_idx_t kRows = 64;
|
||||||
constexpr size_t kCols = 2;
|
constexpr size_t kCols = 2;
|
||||||
constexpr size_t kPageSize = 1;
|
|
||||||
|
|
||||||
// Create an in-memory DMatrix.
|
// Create an in-memory DMatrix.
|
||||||
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
|
auto p_fmat = RandomDataGenerator{kRows, kCols, sparsity}.GenerateDMatrix(true);
|
||||||
|
|
||||||
// Create a DMatrix with multiple batches.
|
// Create a DMatrix with multiple batches.
|
||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
std::unique_ptr<DMatrix>
|
auto prefix = tmpdir.path + "/cache";
|
||||||
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
|
||||||
|
auto p_ext_fmat = RandomDataGenerator{kRows, kCols, sparsity}
|
||||||
|
.Batches(4)
|
||||||
|
.OnHost(on_host)
|
||||||
|
.GenerateSparsePageDMatrix(prefix, true);
|
||||||
|
|
||||||
auto param = BatchParam{2, tree::TrainParam::DftSparseThreshold()};
|
auto param = BatchParam{2, tree::TrainParam::DftSparseThreshold()};
|
||||||
auto impl = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
auto impl = (*p_fmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||||
EXPECT_EQ(impl->base_rowid, 0);
|
ASSERT_EQ(impl->base_rowid, 0);
|
||||||
EXPECT_EQ(impl->n_rows, kRows);
|
ASSERT_EQ(impl->n_rows, kRows);
|
||||||
EXPECT_FALSE(impl->is_dense);
|
ASSERT_EQ(impl->is_dense, is_dense);
|
||||||
EXPECT_EQ(impl->row_stride, 2);
|
ASSERT_EQ(impl->row_stride, 2);
|
||||||
EXPECT_EQ(impl->Cuts().TotalBins(), 4);
|
ASSERT_EQ(impl->Cuts().TotalBins(), 4);
|
||||||
|
|
||||||
std::unique_ptr<EllpackPageImpl> impl_ext;
|
std::unique_ptr<EllpackPageImpl> impl_ext;
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& batch : dmat_ext->GetBatches<EllpackPage>(&ctx, param)) {
|
for (auto& batch : p_ext_fmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
if (!impl_ext) {
|
if (!impl_ext) {
|
||||||
impl_ext = std::make_unique<EllpackPageImpl>(
|
impl_ext = std::make_unique<EllpackPageImpl>(
|
||||||
batch.Impl()->gidx_buffer.Device(), batch.Impl()->CutsShared(), batch.Impl()->is_dense,
|
batch.Impl()->gidx_buffer.Device(), batch.Impl()->CutsShared(), batch.Impl()->is_dense,
|
||||||
@ -171,16 +176,42 @@ TEST(SparsePageDMatrix, EllpackPageContent) {
|
|||||||
auto n_elems = impl_ext->Copy(ctx.Device(), batch.Impl(), offset);
|
auto n_elems = impl_ext->Copy(ctx.Device(), batch.Impl(), offset);
|
||||||
offset += n_elems;
|
offset += n_elems;
|
||||||
}
|
}
|
||||||
EXPECT_EQ(impl_ext->base_rowid, 0);
|
ASSERT_EQ(impl_ext->base_rowid, 0);
|
||||||
EXPECT_EQ(impl_ext->n_rows, kRows);
|
ASSERT_EQ(impl_ext->n_rows, kRows);
|
||||||
EXPECT_FALSE(impl_ext->is_dense);
|
ASSERT_EQ(impl_ext->is_dense, is_dense);
|
||||||
EXPECT_EQ(impl_ext->row_stride, 2);
|
ASSERT_EQ(impl_ext->row_stride, 2);
|
||||||
EXPECT_EQ(impl_ext->Cuts().TotalBins(), 4);
|
ASSERT_EQ(impl_ext->Cuts().TotalBins(), 4);
|
||||||
|
|
||||||
std::vector<common::CompressedByteT> buffer(impl->gidx_buffer.HostVector());
|
std::vector<common::CompressedByteT> buffer(impl->gidx_buffer.HostVector());
|
||||||
std::vector<common::CompressedByteT> buffer_ext(impl_ext->gidx_buffer.HostVector());
|
std::vector<common::CompressedByteT> buffer_ext(impl_ext->gidx_buffer.HostVector());
|
||||||
EXPECT_EQ(buffer, buffer_ext);
|
ASSERT_EQ(buffer, buffer_ext);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
TEST_P(TestEllpackPageExt, Data) {
|
||||||
|
auto [on_host, is_dense] = this->GetParam();
|
||||||
|
this->Run(on_host, is_dense);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(EllpackPageExt, TestEllpackPageExt, ::testing::ValuesIn([]() {
|
||||||
|
std::vector<std::tuple<bool, bool>> values;
|
||||||
|
for (auto on_host : {true, false}) {
|
||||||
|
for (auto is_dense : {true, false}) {
|
||||||
|
values.emplace_back(on_host, is_dense);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values;
|
||||||
|
}()),
|
||||||
|
[](::testing::TestParamInfo<TestEllpackPageExt::ParamType> const& info) {
|
||||||
|
auto on_host = std::get<0>(info.param);
|
||||||
|
auto is_dense = std::get<1>(info.param);
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << (on_host ? "host" : "ext");
|
||||||
|
ss << "_";
|
||||||
|
ss << (is_dense ? "dense" : "sparse");
|
||||||
|
return ss.str();
|
||||||
|
});
|
||||||
|
|
||||||
struct ReadRowFunction {
|
struct ReadRowFunction {
|
||||||
EllpackDeviceAccessor matrix;
|
EllpackDeviceAccessor matrix;
|
||||||
|
|||||||
@ -437,9 +437,9 @@ void RandomDataGenerator::GenerateCSR(
|
|||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<DMatrix> dmat{
|
std::unique_ptr<DMatrix> dmat{DMatrix::Create(
|
||||||
DMatrix::Create(static_cast<DataIterHandle>(iter.get()), iter->Proxy(), Reset, Next,
|
static_cast<DataIterHandle>(iter.get()), iter->Proxy(), Reset, Next,
|
||||||
std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(), prefix)};
|
std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(), prefix, on_host_)};
|
||||||
|
|
||||||
auto row_page_path =
|
auto row_page_path =
|
||||||
data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
|
data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
|
||||||
@ -520,9 +520,9 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(bst_idx_t n_samples, bst_featur
|
|||||||
CHECK_GE(n_samples, n_batches);
|
CHECK_GE(n_samples, n_batches);
|
||||||
NumpyArrayIterForTest iter(0, n_samples, n_features, n_batches);
|
NumpyArrayIterForTest iter(0, n_samples, n_features, n_batches);
|
||||||
|
|
||||||
std::unique_ptr<DMatrix> dmat{
|
std::unique_ptr<DMatrix> dmat{DMatrix::Create(
|
||||||
DMatrix::Create(static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
|
static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
|
||||||
std::numeric_limits<float>::quiet_NaN(), omp_get_max_threads(), prefix)};
|
std::numeric_limits<float>::quiet_NaN(), omp_get_max_threads(), prefix, false)};
|
||||||
|
|
||||||
auto row_page_path =
|
auto row_page_path =
|
||||||
data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
|
data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
|
||||||
@ -549,7 +549,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries,
|
|||||||
|
|
||||||
std::unique_ptr<DMatrix> dmat{
|
std::unique_ptr<DMatrix> dmat{
|
||||||
DMatrix::Create(static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
|
DMatrix::Create(static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
|
||||||
std::numeric_limits<float>::quiet_NaN(), 0, prefix)};
|
std::numeric_limits<float>::quiet_NaN(), 0, prefix, false)};
|
||||||
auto row_page_path =
|
auto row_page_path =
|
||||||
data::MakeId(prefix,
|
data::MakeId(prefix,
|
||||||
dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
|
dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
|
||||||
@ -568,8 +568,8 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries,
|
|||||||
return dmat;
|
return dmat;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
|
std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
|
||||||
size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
|
size_t page_size, bool deterministic,
|
||||||
const dmlc::TemporaryDirectory& tempdir) {
|
const dmlc::TemporaryDirectory& tempdir) {
|
||||||
if (!n_rows || !n_cols) {
|
if (!n_rows || !n_cols) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|||||||
@ -241,6 +241,7 @@ class RandomDataGenerator {
|
|||||||
bst_bin_t bins_{0};
|
bst_bin_t bins_{0};
|
||||||
std::vector<FeatureType> ft_;
|
std::vector<FeatureType> ft_;
|
||||||
bst_cat_t max_cat_{32};
|
bst_cat_t max_cat_{32};
|
||||||
|
bool on_host_{false};
|
||||||
|
|
||||||
Json ArrayInterfaceImpl(HostDeviceVector<float>* storage, size_t rows, size_t cols) const;
|
Json ArrayInterfaceImpl(HostDeviceVector<float>* storage, size_t rows, size_t cols) const;
|
||||||
|
|
||||||
@ -266,6 +267,10 @@ class RandomDataGenerator {
|
|||||||
n_batches_ = n_batches;
|
n_batches_ = n_batches;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
RandomDataGenerator& OnHost(bool on_host) {
|
||||||
|
on_host_ = on_host;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
RandomDataGenerator& Seed(uint64_t s) {
|
RandomDataGenerator& Seed(uint64_t s) {
|
||||||
seed_ = s;
|
seed_ = s;
|
||||||
lcg_.Seed(seed_);
|
lcg_.Seed(seed_);
|
||||||
|
|||||||
@ -67,4 +67,30 @@ TEST(RandomDataGenerator, GenerateArrayInterfaceBatch) {
|
|||||||
CHECK_EQ(get<Integer>(j_array["shape"][0]), kRows);
|
CHECK_EQ(get<Integer>(j_array["shape"][0]), kRows);
|
||||||
CHECK_EQ(get<Integer>(j_array["shape"][1]), kCols);
|
CHECK_EQ(get<Integer>(j_array["shape"][1]), kCols);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(RandomDataGenerator, SparseDMatrix) {
|
||||||
|
bst_idx_t constexpr kCols{100}, kBatches{13};
|
||||||
|
bst_idx_t n_samples{kBatches * 128};
|
||||||
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
|
auto prefix = tmpdir.path + "/cache";
|
||||||
|
auto p_ext_fmat =
|
||||||
|
RandomDataGenerator{n_samples, kCols, 0.0}.Batches(kBatches).GenerateSparsePageDMatrix(prefix,
|
||||||
|
true);
|
||||||
|
|
||||||
|
auto p_fmat = RandomDataGenerator{n_samples, kCols, 0.0}.GenerateDMatrix(true);
|
||||||
|
|
||||||
|
SparsePage concat;
|
||||||
|
std::int32_t n_batches{0};
|
||||||
|
for (auto const& page : p_ext_fmat->GetBatches<SparsePage>()) {
|
||||||
|
concat.Push(page);
|
||||||
|
++n_batches;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(n_batches, kBatches);
|
||||||
|
ASSERT_EQ(concat.Size(), n_samples);
|
||||||
|
|
||||||
|
for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
|
||||||
|
ASSERT_EQ(page.data.ConstHostVector(), concat.data.ConstHostVector());
|
||||||
|
ASSERT_EQ(page.offset.ConstHostVector(), concat.offset.ConstHostVector());
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -21,20 +21,38 @@ def test_gpu_single_batch() -> None:
|
|||||||
strategies.integers(0, 8),
|
strategies.integers(0, 8),
|
||||||
strategies.booleans(),
|
strategies.booleans(),
|
||||||
strategies.booleans(),
|
strategies.booleans(),
|
||||||
|
strategies.booleans(),
|
||||||
)
|
)
|
||||||
@settings(deadline=None, max_examples=10, print_blob=True)
|
@settings(deadline=None, max_examples=16, print_blob=True)
|
||||||
def test_gpu_data_iterator(
|
def test_gpu_data_iterator(
|
||||||
n_samples_per_batch: int,
|
n_samples_per_batch: int,
|
||||||
n_features: int,
|
n_features: int,
|
||||||
n_batches: int,
|
n_batches: int,
|
||||||
subsample: bool,
|
subsample: bool,
|
||||||
use_cupy: bool,
|
use_cupy: bool,
|
||||||
|
on_host: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
run_data_iterator(
|
run_data_iterator(
|
||||||
n_samples_per_batch, n_features, n_batches, "gpu_hist", subsample, use_cupy
|
n_samples_per_batch,
|
||||||
|
n_features,
|
||||||
|
n_batches,
|
||||||
|
"hist",
|
||||||
|
subsample=subsample,
|
||||||
|
device="cuda",
|
||||||
|
use_cupy=use_cupy,
|
||||||
|
on_host=on_host,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_cpu_data_iterator() -> None:
|
def test_cpu_data_iterator() -> None:
|
||||||
"""Make sure CPU algorithm can handle GPU inputs"""
|
"""Make sure CPU algorithm can handle GPU inputs"""
|
||||||
run_data_iterator(1024, 2, 3, "approx", False, True)
|
run_data_iterator(
|
||||||
|
1024,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
"approx",
|
||||||
|
device="cuda",
|
||||||
|
subsample=False,
|
||||||
|
use_cupy=True,
|
||||||
|
on_host=False,
|
||||||
|
)
|
||||||
|
|||||||
@ -73,7 +73,9 @@ def run_data_iterator(
|
|||||||
n_batches: int,
|
n_batches: int,
|
||||||
tree_method: str,
|
tree_method: str,
|
||||||
subsample: bool,
|
subsample: bool,
|
||||||
|
device: str,
|
||||||
use_cupy: bool,
|
use_cupy: bool,
|
||||||
|
on_host: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
n_rounds = 2
|
n_rounds = 2
|
||||||
# The test is more difficult to pass if the subsample rate is smaller as the root_sum
|
# The test is more difficult to pass if the subsample rate is smaller as the root_sum
|
||||||
@ -83,7 +85,8 @@ def run_data_iterator(
|
|||||||
|
|
||||||
it = IteratorForTest(
|
it = IteratorForTest(
|
||||||
*make_batches(n_samples_per_batch, n_features, n_batches, use_cupy),
|
*make_batches(n_samples_per_batch, n_features, n_batches, use_cupy),
|
||||||
cache="cache"
|
cache="cache",
|
||||||
|
on_host=on_host,
|
||||||
)
|
)
|
||||||
if n_batches == 0:
|
if n_batches == 0:
|
||||||
with pytest.raises(ValueError, match="1 batch"):
|
with pytest.raises(ValueError, match="1 batch"):
|
||||||
@ -98,10 +101,11 @@ def run_data_iterator(
|
|||||||
"tree_method": tree_method,
|
"tree_method": tree_method,
|
||||||
"max_depth": 2,
|
"max_depth": 2,
|
||||||
"subsample": subsample_rate,
|
"subsample": subsample_rate,
|
||||||
|
"device": device,
|
||||||
"seed": 0,
|
"seed": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
if tree_method == "gpu_hist":
|
if device.find("cuda") != -1:
|
||||||
parameters["sampling_method"] = "gradient_based"
|
parameters["sampling_method"] = "gradient_based"
|
||||||
|
|
||||||
results_from_it: Dict[str, Dict[str, List[float]]] = {}
|
results_from_it: Dict[str, Dict[str, List[float]]] = {}
|
||||||
@ -167,10 +171,24 @@ def test_data_iterator(
|
|||||||
subsample: bool,
|
subsample: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
run_data_iterator(
|
run_data_iterator(
|
||||||
n_samples_per_batch, n_features, n_batches, "approx", subsample, False
|
n_samples_per_batch,
|
||||||
|
n_features,
|
||||||
|
n_batches,
|
||||||
|
"approx",
|
||||||
|
subsample,
|
||||||
|
"cpu",
|
||||||
|
False,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
run_data_iterator(
|
run_data_iterator(
|
||||||
n_samples_per_batch, n_features, n_batches, "hist", subsample, False
|
n_samples_per_batch,
|
||||||
|
n_features,
|
||||||
|
n_batches,
|
||||||
|
"hist",
|
||||||
|
subsample,
|
||||||
|
"cpu",
|
||||||
|
False,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -241,7 +259,7 @@ def test_cat_check() -> None:
|
|||||||
batches.append((X, y))
|
batches.append((X, y))
|
||||||
|
|
||||||
X, y = list(zip(*batches))
|
X, y = list(zip(*batches))
|
||||||
it = tm.IteratorForTest(X, y, None, cache=None)
|
it = tm.IteratorForTest(X, y, None, cache=None, on_host=False)
|
||||||
Xy: xgb.DMatrix = xgb.QuantileDMatrix(it, enable_categorical=True)
|
Xy: xgb.DMatrix = xgb.QuantileDMatrix(it, enable_categorical=True)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="categorical features"):
|
with pytest.raises(ValueError, match="categorical features"):
|
||||||
@ -254,7 +272,7 @@ def test_cat_check() -> None:
|
|||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
cache_path = os.path.join(tmpdir, "cache")
|
cache_path = os.path.join(tmpdir, "cache")
|
||||||
|
|
||||||
it = tm.IteratorForTest(X, y, None, cache=cache_path)
|
it = tm.IteratorForTest(X, y, None, cache=cache_path, on_host=False)
|
||||||
Xy = xgb.DMatrix(it, enable_categorical=True)
|
Xy = xgb.DMatrix(it, enable_categorical=True)
|
||||||
with pytest.raises(ValueError, match="categorical features"):
|
with pytest.raises(ValueError, match="categorical features"):
|
||||||
xgb.train({"booster": "gblinear"}, Xy)
|
xgb.train({"booster": "gblinear"}, Xy)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user