[EM] Allow staging ellpack on host for GPU external memory. (#10488)

- New parameter `on_host`.
- Abstract format creation and stream creation into policy classes.
This commit is contained in:
Jiaming Yuan
2024-06-28 04:42:18 +08:00
committed by GitHub
parent 824fba783e
commit e8a962575a
36 changed files with 842 additions and 317 deletions

View File

@@ -298,13 +298,14 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
auto missing = GetMissing(jconfig);
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
auto on_host = OptionalArg<Boolean, bool>(jconfig, "on_host", false);
xgboost_CHECK_C_ARG_PTR(next);
xgboost_CHECK_C_ARG_PTR(reset);
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache)};
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache, on_host)};
API_END();
}

View File

@@ -429,7 +429,7 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
}
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
XGBDefaultDeviceAllocatorImpl()
: SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()) {}
: SuperT(rmm::cuda_stream_per_thread, rmm::mr::get_current_device_resource()) {}
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
};
@@ -484,8 +484,8 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
}
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
XGBCachingDeviceAllocatorImpl()
: SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()),
use_cub_allocator_(!xgboost::GlobalConfigThreadLocalStore::Get()->use_rmm) {}
: SuperT(rmm::cuda_stream_per_thread, rmm::mr::get_current_device_resource()),
use_cub_allocator_(!xgboost::GlobalConfigThreadLocalStore::Get()->use_rmm) {}
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
XGBOOST_DEVICE void construct(T *) {} // NOLINT
private:

View File

@@ -6,7 +6,7 @@
#ifndef XGBOOST_COMMON_ERROR_MSG_H_
#define XGBOOST_COMMON_ERROR_MSG_H_
#include <cinttypes> // for uint64_t
#include <cstdint> // for uint64_t
#include <limits> // for numeric_limits
#include <string> // for string
@@ -103,5 +103,11 @@ inline auto NoFederated() { return "XGBoost is not compiled with federated learn
inline auto NoCategorical(std::string name) {
return name + " doesn't support categorical features.";
}
// Reject the `on_host` flag on code paths that cannot stage pages in host
// memory; only the GPU ellpack page supports host-based caching.
inline void NoOnHost(bool on_host) {
  if (!on_host) {
    return;
  }
  LOG(FATAL) << "Caching on host memory is only available for GPU.";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_

View File

@@ -163,7 +163,7 @@ class HistogramCuts {
return vals[bin_idx - 1];
}
void SetDevice(DeviceOrd d) const {
void SetDevice(DeviceOrd d) {
this->cut_ptrs_.SetDevice(d);
this->cut_ptrs_.ConstDevicePointer();

View File

@@ -901,15 +901,12 @@ DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_p
return new data::IterativeDMatrix(iter, proxy, ref, reset, next, missing, nthread, max_bin);
}
template <typename DataIterHandle, typename DMatrixHandle,
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int32_t n_threads,
std::string cache) {
return new data::SparsePageDMatrix(iter, proxy, reset, next, missing, n_threads,
cache);
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, int32_t n_threads,
std::string cache, bool on_host) {
return new data::SparsePageDMatrix{iter, proxy, reset, next, missing, n_threads, cache, on_host};
}
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
@@ -919,10 +916,11 @@ template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCa
XGDMatrixCallbackNext* next, float missing,
int nthread, int max_bin);
template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
DataIterResetCallback, XGDMatrixCallbackNext>(
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int32_t n_threads, std::string);
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
XGDMatrixCallbackNext>(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing,
int32_t n_threads, std::string, bool);
template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,

View File

@@ -36,7 +36,7 @@ void EllpackPage::SetBaseRowId(std::size_t) {
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
"EllpackPage is required";
}
size_t EllpackPage::Size() const {
bst_idx_t EllpackPage::Size() const {
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
"EllpackPage is required";
return 0;

View File

@@ -29,7 +29,7 @@ EllpackPage::~EllpackPage() = default;
EllpackPage::EllpackPage(EllpackPage&& that) { std::swap(impl_, that.impl_); }
size_t EllpackPage::Size() const { return impl_->Size(); }
[[nodiscard]] bst_idx_t EllpackPage::Size() const { return impl_->Size(); }
void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id); }
@@ -91,13 +91,13 @@ __global__ void CompressBinEllpackKernel(
// Construct an ELLPACK matrix with the given number of empty rows.
EllpackPageImpl::EllpackPageImpl(DeviceOrd device,
std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
size_t row_stride, size_t n_rows)
: is_dense(is_dense), cuts_(std::move(cuts)), row_stride(row_stride), n_rows(n_rows) {
bst_idx_t row_stride, bst_idx_t n_rows)
: is_dense(is_dense), cuts_(std::move(cuts)), row_stride{row_stride}, n_rows{n_rows} {
monitor_.Init("ellpack_page");
dh::safe_cuda(cudaSetDevice(device.ordinal));
monitor_.Start("InitCompressedData");
InitCompressedData(device);
this->InitCompressedData(device);
monitor_.Stop("InitCompressedData");
}
@@ -403,7 +403,7 @@ struct CopyPage {
// Copy the data from the given EllpackPage to the current page.
size_t EllpackPageImpl::Copy(DeviceOrd device, EllpackPageImpl const* page, size_t offset) {
monitor_.Start("Copy");
size_t num_elements = page->n_rows * page->row_stride;
bst_idx_t num_elements = page->n_rows * page->row_stride;
CHECK_EQ(row_stride, page->row_stride);
CHECK_EQ(NumSymbols(), page->NumSymbols());
CHECK_GE(n_rows * row_stride, offset + num_elements);
@@ -461,16 +461,17 @@ struct CompactPage {
};
// Compacts the data from the given EllpackPage into the current page.
void EllpackPageImpl::Compact(DeviceOrd device, EllpackPageImpl const* page,
void EllpackPageImpl::Compact(Context const* ctx, EllpackPageImpl const* page,
common::Span<size_t> row_indexes) {
monitor_.Start("Compact");
monitor_.Start(__func__);
CHECK_EQ(row_stride, page->row_stride);
CHECK_EQ(NumSymbols(), page->NumSymbols());
CHECK_LE(page->base_rowid + page->n_rows, row_indexes.size());
gidx_buffer.SetDevice(device);
page->gidx_buffer.SetDevice(device);
dh::LaunchN(page->n_rows, CompactPage(this, page, row_indexes));
monitor_.Stop("Compact");
gidx_buffer.SetDevice(ctx->Device());
page->gidx_buffer.SetDevice(ctx->Device());
auto cuctx = ctx->CUDACtx();
dh::LaunchN(page->n_rows, cuctx->Stream(), CompactPage(this, page, row_indexes));
monitor_.Stop(__func__);
}
// Initialize the buffer to stored compressed features.
@@ -551,7 +552,7 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
}
// Return the number of rows contained in this page.
size_t EllpackPageImpl::Size() const { return n_rows; }
[[nodiscard]] bst_idx_t EllpackPageImpl::Size() const { return n_rows; }
// Return the memory cost for storing the compressed features.
size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride,

View File

@@ -143,7 +143,7 @@ class EllpackPageImpl {
* and the given number of rows.
*/
EllpackPageImpl(DeviceOrd device, std::shared_ptr<common::HistogramCuts const> cuts,
bool is_dense, size_t row_stride, size_t n_rows);
bool is_dense, bst_idx_t row_stride, bst_idx_t n_rows);
/*!
* \brief Constructor used for external memory.
*/
@@ -181,14 +181,14 @@ class EllpackPageImpl {
/*! \brief Compact the given ELLPACK page into the current page.
*
* @param device The GPU device to use.
* @param context The GPU context.
* @param page The ELLPACK page to compact from.
* @param row_indexes Row indexes for the compacted page.
*/
void Compact(DeviceOrd device, EllpackPageImpl const* page, common::Span<size_t> row_indexes);
void Compact(Context const* ctx, EllpackPageImpl const* page, common::Span<size_t> row_indexes);
/*! \return Number of instances in the page. */
[[nodiscard]] size_t Size() const;
[[nodiscard]] bst_idx_t Size() const;
/*! \brief Set the base row id for this page. */
void SetBaseRowId(std::size_t row_id) {
@@ -231,7 +231,7 @@ class EllpackPageImpl {
/*! \brief Whether or not if the matrix is dense. */
bool is_dense;
/*! \brief Row length for ELLPACK. */
size_t row_stride;
bst_idx_t row_stride;
bst_idx_t base_rowid{0};
bst_idx_t n_rows{};
/*! \brief global index of histogram, which is stored in ELLPACK format. */

View File

@@ -41,7 +41,7 @@ class EllpackPage {
EllpackPage(EllpackPage&& that);
/*! \return Number of instances in the page. */
[[nodiscard]] size_t Size() const;
[[nodiscard]] bst_idx_t Size() const;
/*! \brief Set the base row id for this page. */
void SetBaseRowId(std::size_t row_id);

View File

@@ -10,6 +10,7 @@
#include "../common/ref_resource_view.h" // for ReadVec, WriteVec
#include "ellpack_page.cuh" // for EllpackPage
#include "ellpack_page_raw_format.h"
#include "ellpack_page_source.h"
namespace xgboost::data {
DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);
@@ -32,7 +33,6 @@ template <typename T>
return false;
}
vec->SetDevice(DeviceOrd::CUDA(0));
vec->Resize(n);
auto d_vec = vec->DeviceSpan();
dh::safe_cuda(
@@ -54,6 +54,7 @@ template <typename T>
if (!fi->Read(&impl->row_stride)) {
return false;
}
impl->gidx_buffer.SetDevice(device_);
if (!ReadDeviceVec(fi, &impl->gidx_buffer)) {
return false;
}
@@ -73,6 +74,65 @@ template <typename T>
CHECK(!impl->gidx_buffer.ConstHostVector().empty());
bytes += common::WriteVec(fo, impl->gidx_buffer.HostVector());
bytes += fo->Write(impl->base_rowid);
dh::DefaultStream().Sync();
return bytes;
}
// Deserialize an ellpack page from the in-memory (host) cache stream.
//
// Fields are read in the exact order `Write` emitted them: n_rows, is_dense,
// row_stride, the gidx buffer (size prefix followed by raw bytes), and finally
// base_rowid.  Returns false as soon as any read fails.
[[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page, EllpackHostCacheStream* fi) const {
  auto* impl = page->Impl();
  // The cuts must already be staged on the device before pages can be decoded.
  CHECK(this->cuts_->cut_values_.DeviceCanRead());
  impl->SetCuts(this->cuts_);

  if (!fi->Read(&impl->n_rows)) {
    return false;
  }
  if (!fi->Read(&impl->is_dense)) {
    return false;
  }
  if (!fi->Read(&impl->row_stride)) {
    return false;
  }

  // Read the compressed gradient-index buffer: a size prefix, then the bytes
  // (copied directly into device memory).
  bst_idx_t n{0};
  if (!fi->Read(&n)) {
    return false;
  }
  if (n != 0) {
    impl->gidx_buffer.SetDevice(device_);
    impl->gidx_buffer.Resize(n);
    auto span = impl->gidx_buffer.DeviceSpan();
    if (!fi->Read(span.data(), span.size_bytes())) {
      return false;
    }
  }

  if (!fi->Read(&impl->base_rowid)) {
    return false;
  }
  // The copies above are asynchronous; wait for them before handing the page out.
  dh::DefaultStream().Sync();
  return true;
}
// Serialize an ellpack page into the in-memory (host) cache stream and return
// the number of bytes written.  The field order must match `Read` above.
[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page,
                                                      EllpackHostCacheStream* fo) const {
  bst_idx_t bytes{0};
  auto* impl = page.Impl();
  bytes += fo->Write(impl->n_rows);
  bytes += fo->Write(impl->is_dense);
  bytes += fo->Write(impl->row_stride);
  // Write the gidx vector: size prefix followed by the raw device bytes (the
  // size is written even when the buffer is empty, so Read can mirror it).
  bst_idx_t n = impl->gidx_buffer.Size();
  bytes += fo->Write(n);
  if (!impl->gidx_buffer.Empty()) {
    auto span = impl->gidx_buffer.ConstDeviceSpan();
    bytes += fo->Write(span.data(), span.size_bytes());
  }
  bytes += fo->Write(impl->base_rowid);
  // Writes are asynchronous copies on the default stream; sync before returning.
  dh::DefaultStream().Sync();
  return bytes;
}
} // namespace xgboost::data

View File

@@ -20,15 +20,22 @@ class HistogramCuts;
}
namespace xgboost::data {
class EllpackHostCacheStream;
class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
std::shared_ptr<common::HistogramCuts const> cuts_;
DeviceOrd device_;
public:
explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts)
: cuts_{std::move(cuts)} {}
explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device)
: cuts_{std::move(cuts)}, device_{device} {}
[[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override;
[[nodiscard]] std::size_t Write(const EllpackPage& page,
common::AlignedFileWriteStream* fo) override;
[[nodiscard]] bool Read(EllpackPage* page, EllpackHostCacheStream* fi) const;
[[nodiscard]] std::size_t Write(const EllpackPage& page, EllpackHostCacheStream* fo) const;
};
#if !defined(XGBOOST_USE_CUDA)

View File

@@ -1,29 +1,161 @@
/**
* Copyright 2019-2024, XGBoost contributors
*/
#include <memory>
#include <thrust/host_vector.h> // for host_vector
#include "ellpack_page.cuh"
#include "ellpack_page.h" // for EllpackPage
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, uint64_t, uint32_t
#include <memory> // for shared_ptr, make_unique, make_shared
#include <utility> // for move
#include "../common/common.h" // for safe_cuda
#include "../common/cuda_pinned_allocator.h" // for pinned_allocator
#include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream
#include "ellpack_page.cuh" // for EllpackPageImpl
#include "ellpack_page.h" // for EllpackPage
#include "ellpack_page_source.h"
#include "xgboost/base.h" // for bst_idx_t
namespace xgboost::data {
void EllpackPageSource::Fetch() {
dh::safe_cuda(cudaSetDevice(device_.ordinal));
// Backing storage for the host-memory ellpack cache: a single growable buffer
// that serialized pages are appended to.
struct EllpackHostCache {
  // Pinned (page-locked) host memory so asynchronous device<->host copies are valid.
  thrust::host_vector<std::int8_t, common::cuda::pinned_allocator<std::int8_t>> cache;

  void Resize(std::size_t n, dh::CUDAStreamView stream) {
    stream.Sync();  // Prevent partial copy inside resize.
    cache.resize(n);
  }
};
// Implementation of the host cache stream: a byte cursor (`cur_ptr_`) over the
// shared pinned buffer, with an optional exclusive read limit (`bound_`).
class EllpackHostCacheStreamImpl {
  std::shared_ptr<EllpackHostCache> cache_;
  bst_idx_t cur_ptr_{0};  // current byte offset into the cache
  bst_idx_t bound_{0};    // exclusive upper limit for reads, set via Bound()

 public:
  explicit EllpackHostCacheStreamImpl(std::shared_ptr<EllpackHostCache> cache)
      : cache_{std::move(cache)} {}

  // Copy n_bytes from ptr into the cache at the cursor, growing the cache if
  // needed.  Returns the number of bytes written.
  [[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes) {
    auto n = cur_ptr_ + n_bytes;
    if (n > cache_->cache.size()) {
      cache_->Resize(n, dh::DefaultStream());
    }
    // Asynchronous copy on the default stream; callers sync before consuming.
    dh::safe_cuda(cudaMemcpyAsync(cache_->cache.data() + cur_ptr_, ptr, n_bytes, cudaMemcpyDefault,
                                  dh::DefaultStream()));
    cur_ptr_ = n;
    return n_bytes;
  }

  // Copy n_bytes from the cursor into ptr.  Fails a CHECK if the read would
  // cross the limit set by Bound().
  [[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes) {
    CHECK_LE(cur_ptr_ + n_bytes, bound_);
    dh::safe_cuda(cudaMemcpyAsync(ptr, cache_->cache.data() + cur_ptr_, n_bytes, cudaMemcpyDefault,
                                  dh::DefaultStream()));
    cur_ptr_ += n_bytes;
    return true;
  }

  // Current byte offset of the cursor.
  [[nodiscard]] bst_idx_t Tell() const { return cur_ptr_; }
  void Seek(bst_idx_t offset_bytes) { cur_ptr_ = offset_bytes; }

  // Set the exclusive upper limit for subsequent reads.
  void Bound(bst_idx_t offset_bytes) {
    CHECK_LE(offset_bytes, cache_->cache.size());
    this->bound_ = offset_bytes;
  }
};
/**
 * EllpackHostCacheStream
 *
 * Thin forwarding layer over the pimpl; see EllpackHostCacheStreamImpl for the
 * actual logic.
 */
EllpackHostCacheStream::EllpackHostCacheStream(std::shared_ptr<EllpackHostCache> cache)
    : p_impl_{std::make_unique<EllpackHostCacheStreamImpl>(std::move(cache))} {}

// Out-of-line so that EllpackHostCacheStreamImpl is complete at destruction.
EllpackHostCacheStream::~EllpackHostCacheStream() = default;

[[nodiscard]] bst_idx_t EllpackHostCacheStream::Write(void const* ptr, bst_idx_t n_bytes) {
  return this->p_impl_->Write(ptr, n_bytes);
}

[[nodiscard]] bool EllpackHostCacheStream::Read(void* ptr, bst_idx_t n_bytes) {
  return this->p_impl_->Read(ptr, n_bytes);
}

[[nodiscard]] bst_idx_t EllpackHostCacheStream::Tell() const { return this->p_impl_->Tell(); }

void EllpackHostCacheStream::Seek(bst_idx_t offset_bytes) { this->p_impl_->Seek(offset_bytes); }

void EllpackHostCacheStream::Bound(bst_idx_t offset_bytes) { this->p_impl_->Bound(offset_bytes); }
/**
 * EllpackFormatStreamPolicy
 */
template <typename S, template <typename> typename F>
EllpackFormatStreamPolicy<S, F>::EllpackFormatStreamPolicy()
    : p_cache_{std::make_shared<EllpackHostCache>()} {}

// Create a writer positioned at the end of the shared cache.  All iterations
// append to a single cache buffer; iteration 0 must start from an empty cache.
template <typename S, template <typename> typename F>
[[nodiscard]] std::unique_ptr<typename EllpackFormatStreamPolicy<S, F>::WriterT>
EllpackFormatStreamPolicy<S, F>::CreateWriter(StringView, std::uint32_t iter) {
  auto fo = std::make_unique<EllpackHostCacheStream>(this->p_cache_);
  if (iter == 0) {
    CHECK(this->p_cache_->cache.empty());
  } else {
    fo->Seek(this->p_cache_->cache.size());
  }
  return fo;
}

// Create a reader restricted to the [offset, offset + length) byte window of
// the shared cache.
template <typename S, template <typename> typename F>
[[nodiscard]] std::unique_ptr<typename EllpackFormatStreamPolicy<S, F>::ReaderT>
EllpackFormatStreamPolicy<S, F>::CreateReader(StringView, bst_idx_t offset,
                                              bst_idx_t length) const {
  auto fi = std::make_unique<ReaderT>(this->p_cache_);
  fi->Seek(offset);
  fi->Bound(offset + length);
  CHECK_EQ(fi->Tell(), offset);
  return fi;
}
// Instantiation
template EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::EllpackFormatStreamPolicy();
template std::unique_ptr<
typename EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::WriterT>
EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateWriter(StringView name,
std::uint32_t iter);
template std::unique_ptr<
typename EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::ReaderT>
EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateReader(
StringView name, std::uint64_t offset, std::uint64_t length) const;
/**
* EllpackPageSourceImpl
*/
template <typename F>
void EllpackPageSourceImpl<F>::Fetch() {
dh::safe_cuda(cudaSetDevice(this->Device().ordinal));
if (!this->ReadCache()) {
if (count_ != 0 && !sync_) {
if (this->count_ != 0 && !this->sync_) {
// source is initialized to be the 0th page during construction, so when count_ is 0
// there's no need to increment the source.
++(*source_);
++(*this->source_);
}
// This is not read from cache so we still need it to be synced with sparse page source.
CHECK_EQ(count_, source_->Iter());
auto const &csr = source_->Page();
CHECK_EQ(this->count_, this->source_->Iter());
auto const& csr = this->source_->Page();
this->page_.reset(new EllpackPage{});
auto *impl = this->page_->Impl();
*impl = EllpackPageImpl(device_, cuts_, *csr, is_dense_, row_stride_, feature_types_);
page_->SetBaseRowId(csr->base_rowid);
auto* impl = this->page_->Impl();
*impl = EllpackPageImpl{this->Device(), this->GetCuts(), *csr,
is_dense_, row_stride_, feature_types_};
this->page_->SetBaseRowId(csr->base_rowid);
this->WriteCache();
}
}
// Instantiation
template void
EllpackPageSourceImpl<DefaultFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
template void
EllpackPageSourceImpl<EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
} // namespace xgboost::data

View File

@@ -19,46 +19,127 @@
#include "xgboost/span.h" // for Span
namespace xgboost::data {
class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
// We need to decouple the storage and the view of the storage so that we can implement
// concurrent read.

// Dummy type to hide CUDA calls from the host compiler.
struct EllpackHostCache;
// Pimpl to hide CUDA calls from the host compiler.
class EllpackHostCacheStreamImpl;

/**
 * @brief A view onto the actual cache implemented by `EllpackHostCache`.
 *
 * Provides bounded, seekable read/write access to the shared host buffer.
 */
class EllpackHostCacheStream {
  std::unique_ptr<EllpackHostCacheStreamImpl> p_impl_;

 public:
  explicit EllpackHostCacheStream(std::shared_ptr<EllpackHostCache> cache);
  ~EllpackHostCacheStream();
  // Write a raw byte span at the current position; returns bytes written.
  [[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes);
  // Write a single trivially-copyable value.
  template <typename T>
  [[nodiscard]] std::enable_if_t<std::is_pod_v<T>, bst_idx_t> Write(T const& v) {
    return this->Write(&v, sizeof(T));
  }
  // Read a raw byte span at the current position; must stay within Bound().
  [[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes);
  // Read a single trivially-copyable value.
  template <typename T>
  [[nodiscard]] auto Read(T* ptr) -> std::enable_if_t<std::is_pod_v<T>, bool> {
    return this->Read(ptr, sizeof(T));
  }
  // Current byte offset into the cache.
  [[nodiscard]] bst_idx_t Tell() const;
  void Seek(bst_idx_t offset_bytes);
  // Limit the size of read. offset_bytes is the maximum offset that this stream can read
  // to. An error is raised if the limit is exceeded.
  void Bound(bst_idx_t offset_bytes);
};
/**
 * @brief Policy for creating the ellpack raw page format.  Holds the histogram cuts and
 *        the device ordinal the format needs for (de)serialization.
 */
template <typename S>
class EllpackFormatPolicy {
  std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
  DeviceOrd device_;

 public:
  using FormatT = EllpackPageRawFormat;

 public:
  [[nodiscard]] auto CreatePageFormat() const {
    // The cuts must already live on the same device the format will use.
    CHECK_EQ(cuts_->cut_values_.Device(), device_);
    std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_}};
    return fmt;
  }

  void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device) {
    std::swap(cuts_, cuts);
    device_ = device;
    // Ellpack pages are a GPU-only page type.
    CHECK(this->device_.IsCUDA());
  }
  [[nodiscard]] auto GetCuts() {
    CHECK(cuts_);
    return cuts_;
  }
  [[nodiscard]] auto Device() const { return device_; }
};
/**
 * @brief Stream policy that reads/writes ellpack pages to an in-memory host cache
 *        instead of disk.  `F` supplies the page-format policy.
 */
template <typename S, template <typename> typename F>
class EllpackFormatStreamPolicy : public F<S> {
  std::shared_ptr<EllpackHostCache> p_cache_;  // shared by all writers and readers

 public:
  using WriterT = EllpackHostCacheStream;
  using ReaderT = EllpackHostCacheStream;

 public:
  EllpackFormatStreamPolicy();
  // Writer appends to the shared cache; `iter` 0 expects an empty cache.
  [[nodiscard]] std::unique_ptr<WriterT> CreateWriter(StringView name, std::uint32_t iter);
  // Reader is bounded to the [offset, offset + length) window of the cache.
  [[nodiscard]] std::unique_ptr<ReaderT> CreateReader(StringView name, bst_idx_t offset,
                                                      bst_idx_t length) const;
};
template <typename F>
class EllpackPageSourceImpl : public PageSourceIncMixIn<EllpackPage, F> {
using Super = PageSourceIncMixIn<EllpackPage, F>;
bool is_dense_;
bst_idx_t row_stride_;
BatchParam param_;
common::Span<FeatureType const> feature_types_;
std::shared_ptr<common::HistogramCuts const> cuts_;
DeviceOrd device_;
protected:
[[nodiscard]] SparsePageFormat<EllpackPage>* CreatePageFormat() const override {
cuts_->SetDevice(this->device_);
return new EllpackPageRawFormat{cuts_};
}
public:
EllpackPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features,
size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
bst_idx_t row_stride, common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source, DeviceOrd device)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
EllpackPageSourceImpl(float missing, std::int32_t nthreads, bst_feature_t n_features,
std::size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
std::shared_ptr<common::HistogramCuts> cuts, bool is_dense,
bst_idx_t row_stride, common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source, DeviceOrd device)
: Super{missing, nthreads, n_features, n_batches, cache, false},
is_dense_{is_dense},
row_stride_{row_stride},
param_{std::move(param)},
feature_types_{feature_types},
cuts_{std::move(cuts)},
device_{device} {
feature_types_{feature_types} {
this->source_ = source;
cuts->SetDevice(device);
this->SetCuts(std::move(cuts), device);
this->Fetch();
}
void Fetch() final;
};
// Cache to host
using EllpackPageHostSource =
EllpackPageSourceImpl<EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
// Cache to disk
using EllpackPageSource =
EllpackPageSourceImpl<DefaultFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
#if !defined(XGBOOST_USE_CUDA)
inline void EllpackPageSource::Fetch() {
template <typename F>
inline void EllpackPageSourceImpl<F>::Fetch() {
// silent the warning about unused variables.
(void)(row_stride_);
(void)(is_dense_);
(void)(device_);
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)

View File

@@ -17,20 +17,35 @@
#include "xgboost/data.h" // for BatchParam, FeatureType
#include "xgboost/span.h" // for Span
namespace xgboost {
namespace data {
class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
namespace xgboost::data {
/**
 * @brief Policy for creating ghist index format. The storage is default (disk).
 */
template <typename S>
class GHistIndexFormatPolicy {
 protected:
  common::HistogramCuts cuts_;

 public:
  using FormatT = SparsePageFormat<GHistIndexMatrix>;

 public:
  [[nodiscard]] auto CreatePageFormat() const {
    std::unique_ptr<FormatT> fmt{new GHistIndexRawFormat{cuts_}};
    return fmt;
  }
  // Take ownership of the cuts used to construct the format.
  void SetCuts(common::HistogramCuts cuts) { std::swap(cuts_, cuts); }
};
class GradientIndexPageSource
: public PageSourceIncMixIn<
GHistIndexMatrix, DefaultFormatStreamPolicy<GHistIndexMatrix, GHistIndexFormatPolicy>> {
bool is_dense_;
std::int32_t max_bin_per_feat_;
common::Span<FeatureType const> feature_types_;
double sparse_thresh_;
protected:
[[nodiscard]] SparsePageFormat<GHistIndexMatrix>* CreatePageFormat() const override {
return new GHistIndexRawFormat{cuts_};
}
public:
GradientIndexPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features,
size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
@@ -39,17 +54,16 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
std::isnan(param.sparse_thresh)),
cuts_{std::move(cuts)},
is_dense_{is_dense},
max_bin_per_feat_{param.max_bin},
feature_types_{feature_types},
sparse_thresh_{param.sparse_thresh} {
this->source_ = source;
this->SetCuts(std::move(cuts));
this->Fetch();
}
void Fetch() final;
};
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
#endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_

View File

@@ -38,13 +38,17 @@ std::size_t NFeaturesDevice(DMatrixProxy *) // NOLINT
#endif
} // namespace detail
SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy_handle,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int32_t nthreads, std::string cache_prefix)
: proxy_{proxy_handle}, iter_{iter_handle}, reset_{reset}, next_{next}, missing_{missing},
cache_prefix_{std::move(cache_prefix)} {
DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
float missing, int32_t nthreads, std::string cache_prefix,
bool on_host)
: proxy_{proxy_handle},
iter_{iter_handle},
reset_{reset},
next_{next},
missing_{missing},
cache_prefix_{std::move(cache_prefix)},
on_host_{on_host} {
Context ctx;
ctx.nthread = nthreads;
@@ -103,8 +107,26 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
fmat_ctx_ = ctx;
}
SparsePageDMatrix::~SparsePageDMatrix() {
  // Clear out all resources before deleting the cache file.
  sparse_page_source_.reset();
  // The ellpack source is a variant (disk-backed or host-backed); reset whichever
  // alternative is active.
  std::visit([](auto &&ptr) { ptr.reset(); }, ellpack_page_source_);
  column_source_.reset();
  sorted_column_source_.reset();
  ghist_index_source_.reset();

  for (auto const &kv : cache_info_) {
    CHECK(kv.second);
    auto n = kv.second->ShardName();
    // Host-resident caches have no on-disk shard to remove.
    if (kv.second->OnHost()) {
      continue;
    }
    TryDeleteCacheFile(n);
  }
}
void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
auto id = MakeCache(this, ".row.page", cache_prefix_, &cache_info_);
auto id = MakeCache(this, ".row.page", false, cache_prefix_, &cache_info_);
// Don't use proxy DMatrix once this is already initialized, this allows users to
// release the iterator and data.
if (cache_info_.at(id)->written) {
@@ -132,8 +154,9 @@ BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
}
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
auto id = MakeCache(this, ".col.page", cache_prefix_, &cache_info_);
auto id = MakeCache(this, ".col.page", on_host_, cache_prefix_, &cache_info_);
CHECK_NE(this->Info().num_col_, 0);
error::NoOnHost(on_host_);
this->InitializeSparsePage(ctx);
if (!column_source_) {
column_source_ =
@@ -146,8 +169,9 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
}
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) {
auto id = MakeCache(this, ".sorted.col.page", cache_prefix_, &cache_info_);
auto id = MakeCache(this, ".sorted.col.page", on_host_, cache_prefix_, &cache_info_);
CHECK_NE(this->Info().num_col_, 0);
error::NoOnHost(on_host_);
this->InitializeSparsePage(ctx);
if (!sorted_column_source_) {
sorted_column_source_ = std::make_shared<SortedCSCPageSource>(
@@ -165,11 +189,12 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
CHECK_GE(param.max_bin, 2);
}
detail::CheckEmpty(batch_param_, param);
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
error::NoOnHost(on_host_);
auto id = MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_);
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
this->InitializeSparsePage(ctx);
cache_info_.erase(id);
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_);
LOG(INFO) << "Generating new Gradient Index.";
// Use sorted sketch for approx.
auto sorted_sketch = param.regen;
@@ -193,7 +218,7 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
#if !defined(XGBOOST_USE_CUDA)
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const *, const BatchParam &) {
common::AssertGPUSupport();
return BatchSet{BatchIterator<EllpackPage>{this->ellpack_page_source_}};
return BatchSet{BatchIterator<EllpackPage>{nullptr}};
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace xgboost::data

View File

@@ -1,7 +1,9 @@
/**
* Copyright 2021-2024, XGBoost contributors
*/
#include <memory> // for shared_ptr
#include <memory> // for shared_ptr
#include <utility> // for move
#include <variant> // for visit
#include "../common/hist_util.cuh"
#include "../common/hist_util.h" // for HistogramCuts
@@ -19,13 +21,15 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
CHECK_GE(param.max_bin, 2);
}
detail::CheckEmpty(batch_param_, param);
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
size_t row_stride = 0;
auto id = MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_);
bst_idx_t row_stride = 0;
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
this->InitializeSparsePage(ctx);
// reinitialize the cache
cache_info_.erase(id);
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_);
LOG(INFO) << "Generating new a Ellpack page.";
std::shared_ptr<common::HistogramCuts> cuts;
if (!param.hess.empty()) {
cuts = std::make_shared<common::HistogramCuts>(
@@ -41,17 +45,28 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
CHECK_NE(row_stride, 0);
batch_param_ = param;
auto ft = this->info_.feature_types.ConstDeviceSpan();
ellpack_page_source_.reset(); // make sure resource is released before making new ones.
ellpack_page_source_ = std::make_shared<EllpackPageSource>(
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_,
ctx->Device());
auto ft = this->Info().feature_types.ConstDeviceSpan();
if (on_host_ && std::get_if<EllpackHostPtr>(&ellpack_page_source_) == nullptr) {
ellpack_page_source_.emplace<EllpackHostPtr>(nullptr);
}
std::visit(
[&](auto&& ptr) {
ptr.reset(); // make sure resource is released before making new ones.
using SourceT = typename std::remove_reference_t<decltype(ptr)>::element_type;
ptr = std::make_shared<SourceT>(this->missing_, ctx->Threads(), this->Info().num_col_,
this->n_batches_, cache_info_.at(id), param,
std::move(cuts), this->IsDense(), row_stride, ft,
this->sparse_page_source_, ctx->Device());
},
ellpack_page_source_);
} else {
CHECK(sparse_page_source_);
ellpack_page_source_->Reset();
std::visit([&](auto&& ptr) { ptr->Reset(); }, this->ellpack_page_source_);
}
return BatchSet{BatchIterator<EllpackPage>{this->ellpack_page_source_}};
auto batch_set =
std::visit([this](auto&& ptr) { return BatchSet{BatchIterator<EllpackPage>{ptr}}; },
this->ellpack_page_source_);
return batch_set;
}
} // namespace xgboost::data

View File

@@ -7,16 +7,20 @@
#ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
#define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <cstdint> // for uint32_t, int32_t
#include <map> // for map
#include <memory> // for shared_ptr
#include <sstream> // for stringstream
#include <string> // for string
#include <variant> // for variant, visit
#include "ellpack_page_source.h"
#include "gradient_index_page_source.h"
#include "sparse_page_source.h"
#include "xgboost/data.h"
#include "ellpack_page_source.h" // for EllpackPageSource, EllpackPageHostSource
#include "gradient_index_page_source.h" // for GradientIndexPageSource
#include "sparse_page_source.h" // for SparsePageSource, Cache
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for DMatrix, MetaInfo
#include "xgboost/logging.h"
#include "xgboost/span.h" // for Span
namespace xgboost::data {
/**
@@ -70,6 +74,7 @@ class SparsePageDMatrix : public DMatrix {
float missing_;
Context fmat_ctx_;
std::string cache_prefix_;
bool on_host_{false};
std::uint32_t n_batches_{0};
// sparse page is the source to other page types, we make a special member function.
void InitializeSparsePage(Context const *ctx);
@@ -79,29 +84,16 @@ class SparsePageDMatrix : public DMatrix {
public:
explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int32_t nthreads,
std::string cache_prefix);
std::string cache_prefix, bool on_host = false);
~SparsePageDMatrix() override {
// Clear out all resources before deleting the cache file.
sparse_page_source_.reset();
ellpack_page_source_.reset();
column_source_.reset();
sorted_column_source_.reset();
ghist_index_source_.reset();
for (auto const &kv : cache_info_) {
CHECK(kv.second);
auto n = kv.second->ShardName();
TryDeleteCacheFile(n);
}
}
~SparsePageDMatrix() override;
[[nodiscard]] MetaInfo &Info() override;
[[nodiscard]] const MetaInfo &Info() const override;
[[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
// The only DMatrix implementation that returns false.
[[nodiscard]] bool SingleColBlock() const override { return false; }
DMatrix *Slice(common::Span<int32_t const>) override {
DMatrix *Slice(common::Span<std::int32_t const>) override {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;
}
@@ -111,7 +103,7 @@ class SparsePageDMatrix : public DMatrix {
}
[[nodiscard]] bool EllpackExists() const override {
  // The ellpack source is a variant over the disk-backed and host-memory source
  // types; the page exists if the active alternative holds a non-null pointer.
  return std::visit([](auto &&ptr) { return static_cast<bool>(ptr); }, ellpack_page_source_);
}
[[nodiscard]] bool GHistIndexExists() const override {
return static_cast<bool>(ghist_index_source_);
@@ -138,7 +130,9 @@ class SparsePageDMatrix : public DMatrix {
private:
// source data pointers.
std::shared_ptr<SparsePageSource> sparse_page_source_;
std::shared_ptr<EllpackPageSource> ellpack_page_source_;
using EllpackDiskPtr = std::shared_ptr<EllpackPageSource>;
using EllpackHostPtr = std::shared_ptr<EllpackPageHostSource>;
std::variant<EllpackDiskPtr, EllpackHostPtr> ellpack_page_source_;
std::shared_ptr<CSCPageSource> column_source_;
std::shared_ptr<SortedCSCPageSource> sorted_column_source_;
std::shared_ptr<GradientIndexPageSource> ghist_index_source_;
@@ -153,15 +147,16 @@ class SparsePageDMatrix : public DMatrix {
/**
* @brief Make cache if it doesn't exist yet.
*/
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix,
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, bool on_host,
std::string prefix,
std::map<std::string, std::shared_ptr<Cache>> *out) {
auto &cache_info = *out;
auto name = MakeId(prefix, ptr);
auto id = name + format;
auto it = cache_info.find(id);
if (it == cache_info.cend()) {
cache_info[id].reset(new Cache{false, name, format});
LOG(INFO) << "Make cache:" << cache_info[id]->ShardName() << std::endl;
cache_info[id].reset(new Cache{false, name, format, on_host});
LOG(INFO) << "Make cache:" << cache_info[id]->ShardName();
}
return id;
}

View File

@@ -0,0 +1,30 @@
/**
* Copyright 2021-2024, XGBoost Contributors
*/
#include "sparse_page_source.h"
#include <filesystem> // for exists
#include <string> // for string
#include <cstdio> // for remove
#include <numeric> // for partial_sum
namespace xgboost::data {
void Cache::Commit() {
  // Convert the per-page byte sizes into cumulative offsets exactly once;
  // subsequent calls are no-ops.
  if (written) {
    return;
  }
  std::partial_sum(offset.cbegin(), offset.cend(), offset.begin());
  written = true;
}
void TryDeleteCacheFile(const std::string& file) {
  // Don't throw, this is called in a destructor. Use the error_code overload of
  // std::filesystem::exists: the plain overload can throw filesystem_error on
  // underlying OS errors, which would terminate during stack unwinding.
  std::error_code ec;
  if (!std::filesystem::exists(file, ec)) {
    LOG(WARNING) << "External memory cache file " << file << " is missing.";
    // Nothing to remove; avoid a second, redundant warning from std::remove.
    return;
  }
  if (std::remove(file.c_str()) != 0) {
    LOG(WARNING) << "Couldn't remove external memory cache file " << file
                 << "; you may want to remove it manually";
  }
}
} // namespace xgboost::data

View File

@@ -8,11 +8,9 @@
#include <algorithm> // for min
#include <atomic> // for atomic
#include <cstdint> // for uint64_t
#include <cstdio> // for remove
#include <future> // for future
#include <memory> // for unique_ptr
#include <mutex> // for mutex
#include <numeric> // for partial_sum
#include <string> // for string
#include <utility> // for pair, move
#include <vector> // for vector
@@ -27,18 +25,12 @@
#include "proxy_dmatrix.h" // for DMatrixProxy
#include "sparse_page_writer.h" // for SparsePageFormat
#include "xgboost/base.h" // for bst_feature_t
#include "xgboost/data.h" // for SparsePage, CSCPage
#include "xgboost/data.h" // for SparsePage, CSCPage, SortedCSCPage
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
#include "xgboost/logging.h" // for CHECK_EQ
namespace xgboost::data {
inline void TryDeleteCacheFile(const std::string& file) {
if (std::remove(file.c_str()) != 0) {
// Don't throw, this is called in a destructor.
LOG(WARNING) << "Couldn't remove external memory cache file " << file
<< "; you may want to remove it manually";
}
}
void TryDeleteCacheFile(const std::string& file);
/**
* @brief Information about the cache including path and page offsets.
@@ -46,13 +38,14 @@ inline void TryDeleteCacheFile(const std::string& file) {
struct Cache {
// whether the write to the cache is complete
bool written;
bool on_host;
std::string name;
std::string format;
// offset into binary cache file.
std::vector<std::uint64_t> offset;
// `w`: whether the write to the cache is already complete.
// `on_host`: stage the cache in host memory instead of on disk.
Cache(bool w, std::string n, std::string fmt, bool on_host)
    : written{w}, on_host{on_host}, name{std::move(n)}, format{std::move(fmt)} {
  // Offsets are cumulative byte positions; seed with the zero origin.
  offset.push_back(0);
}
@@ -64,6 +57,7 @@ struct Cache {
[[nodiscard]] std::string ShardName() const {
return ShardName(this->name, this->format);
}
[[nodiscard]] bool OnHost() const { return on_host; }
/**
* @brief Record a page with size of n_bytes.
*/
@@ -83,12 +77,7 @@ struct Cache {
/**
* @brief Call this once the write for the cache is complete.
*/
void Commit() {
if (!written) {
std::partial_sum(offset.begin(), offset.end(), offset.begin());
written = true;
}
}
void Commit();
};
// Prevents multi-threaded call to `GetBatches`.
@@ -146,10 +135,59 @@ class ExceHandler {
};
/**
 * @brief Default implementation of the stream creator.
 */
template <typename S, template <typename> typename F>
class DefaultFormatStreamPolicy : public F<S> {
public:
using WriterT = common::AlignedFileWriteStream;
using ReaderT = common::AlignedResourceReadStream;
public:
std::unique_ptr<WriterT> CreateWriter(StringView name, std::uint32_t iter) {
std::unique_ptr<common::AlignedFileWriteStream> fo;
if (iter == 0) {
fo = std::make_unique<common::AlignedFileWriteStream>(name, "wb");
} else {
fo = std::make_unique<common::AlignedFileWriteStream>(name, "ab");
}
return fo;
}
std::unique_ptr<ReaderT> CreateReader(StringView name, std::uint64_t offset,
std::uint64_t length) const {
return std::make_unique<common::PrivateMmapConstStream>(std::string{name}, offset, length);
}
};
/**
 * @brief Default implementation of the format creator.
 */
template <typename S>
class SparsePageSourceImpl : public BatchIteratorImpl<S> {
class DefaultFormatPolicy {
public:
using FormatT = SparsePageFormat<S>;
public:
auto CreatePageFormat() const {
std::unique_ptr<FormatT> fmt{::xgboost::data::CreatePageFormat<S>("raw")};
return fmt;
}
};
/**
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
*
* The interface to external storage is divided into two types. The first one is the
* format, representing how to read and write the binary. The second part is where to
* store the binary cache. These policies are implemented in the `FormatStreamPolicy`
* policy class. The format policy controls how to create the format (the first part), and
* the stream policy decides where the stream should read from and write to (the second
* part). This way we can compose the polices and page types with ease.
*/
template <typename S,
typename FormatStreamPolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>>
class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPolicy {
protected:
// Prevents calling this iterator from multiple places(or threads).
std::mutex single_threaded_;
@@ -165,7 +203,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
// Index to the current page.
std::uint32_t count_{0};
// Total number of batches.
std::uint32_t n_batches_{0};
bst_idx_t n_batches_{0};
std::shared_ptr<Cache> cache_info_;
@@ -179,10 +217,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
ExceHandler exce_;
common::Monitor monitor_;
[[nodiscard]] virtual SparsePageFormat<S>* CreatePageFormat() const {
return ::xgboost::data::CreatePageFormat<S>("raw");
}
[[nodiscard]] bool ReadCache() {
CHECK(!at_end_);
if (!cache_info_->written) {
@@ -196,8 +230,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
std::int32_t kPrefetches = 3;
std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
n_prefetches = std::max(n_prefetches, 1);
std::int32_t n_prefetch_batches =
std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
std::int32_t n_prefetch_batches = std::min(static_cast<bst_idx_t>(n_prefetches), n_batches_);
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
CHECK_LE(n_prefetch_batches, kPrefetches);
std::size_t fetch_it = count_;
@@ -216,10 +249,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
*GlobalConfigThreadLocalStore::Get() = config;
auto page = std::make_shared<S>();
this->exce_.Run([&] {
std::unique_ptr<SparsePageFormat<S>> fmt{this->CreatePageFormat()};
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{this->CreatePageFormat()};
auto name = self->cache_info_->ShardName();
auto [offset, length] = self->cache_info_->View(fetch_it);
auto fi = std::make_unique<common::PrivateMmapConstStream>(name, offset, length);
std::unique_ptr<typename FormatStreamPolicy::ReaderT> fi{
this->CreateReader(name, offset, length)};
CHECK(fmt->Read(page.get(), fi.get()));
});
return page;
@@ -243,16 +277,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
CHECK(!cache_info_->written);
common::Timer timer;
timer.Start();
std::unique_ptr<SparsePageFormat<S>> fmt{this->CreatePageFormat()};
auto fmt{this->CreatePageFormat()};
auto name = cache_info_->ShardName();
std::unique_ptr<common::AlignedFileWriteStream> fo;
if (this->Iter() == 0) {
fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "wb");
} else {
fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "ab");
}
std::unique_ptr<typename FormatStreamPolicy::WriterT> fo{
this->CreateWriter(StringView{name}, this->Iter())};
auto bytes = fmt->Write(*page_, fo.get());
timer.Stop();
@@ -265,9 +294,9 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
virtual void Fetch() = 0;
public:
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
std::shared_ptr<Cache> cache)
: workers_{nthreads},
: workers_{std::max(2, std::min(nthreads, 16))}, // Don't use too many threads.
missing_{missing},
nthreads_{nthreads},
n_features_{n_features},
@@ -403,18 +432,19 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
};
// A mixin for advancing the iterator.
template <typename S>
class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
template <typename S,
typename FormatCreatePolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>>
class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
protected:
std::shared_ptr<SparsePageSource> source_;
using Super = SparsePageSourceImpl<S>;
using Super = SparsePageSourceImpl<S, FormatCreatePolicy>;
// synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page
// so we avoid fetching it.
bool sync_{true};
public:
PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features,
std::uint32_t n_batches, std::shared_ptr<Cache> cache, bool sync)
bst_idx_t n_batches, std::shared_ptr<Cache> cache, bool sync)
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
[[nodiscard]] PageSourceIncMixIn& operator++() final {

View File

@@ -234,7 +234,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
// Compact the ELLPACK pages into the single sample page.
thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : batch_iterator) {
page_->Compact(ctx->Device(), batch.Impl(), dh::ToSpan(sample_row_index_));
page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
}
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
@@ -252,7 +252,7 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
auto cuctx = ctx->CUDACtx();
size_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
@@ -279,21 +279,18 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
auto cuctx = ctx->CUDACtx();
bst_idx_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
// Perform Poisson sampling in place.
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
PoissonSampling(dh::ToSpan(threshold_), threshold_index,
RandomWeight(common::GlobalRandom()())));
// Count the sampled rows.
size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
size_t sample_rows =
thrust::count_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
// Compact gradient pairs.
gpair_.resize(sample_rows);
thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
// Index the sample rows.
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
IsNonZero());
@@ -301,18 +298,16 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
sample_row_index_.begin());
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
sample_row_index_.begin(), ClearEmptyRows());
auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
auto first_page = (*batch_iterator.begin()).Impl();
// Create a new ELLPACK page with empty rows.
page_.reset(); // Release the device memory first before reallocating
page_.reset(new EllpackPageImpl(ctx->Device(), first_page->CutsShared(), first_page->is_dense,
page_.reset(new EllpackPageImpl(ctx->Device(), first_page->CutsShared(), dmat->IsDense(),
first_page->row_stride, sample_rows));
// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : batch_iterator) {
page_->Compact(ctx->Device(), batch.Impl(), dh::ToSpan(sample_row_index_));
page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
}
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
@@ -363,21 +358,24 @@ GradientBasedSample GradientBasedSampler::Sample(Context const* ctx,
return sample;
}
size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
size_t GradientBasedSampler::CalculateThresholdIndex(Context const* ctx,
common::Span<GradientPair> gpair,
common::Span<float> threshold,
common::Span<float> grad_sum,
size_t sample_rows) {
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits<float>::max());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
CombineGradientPair());
thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1,
auto cuctx = ctx->CUDACtx();
thrust::fill(cuctx->CTP(), dh::tend(threshold) - 1, dh::tend(threshold),
std::numeric_limits<float>::max());
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
CombineGradientPair{});
thrust::sort(cuctx->TP(), dh::tbegin(threshold), dh::tend(threshold) - 1);
thrust::inclusive_scan(cuctx->CTP(), dh::tbegin(threshold), dh::tend(threshold) - 1,
dh::tbegin(grad_sum));
thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum),
thrust::transform(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum),
thrust::counting_iterator<size_t>(0), dh::tbegin(grad_sum),
SampleRateDelta(threshold, gpair.size(), sample_rows));
thrust::device_ptr<float> min =
thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
thrust::min_element(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum));
return thrust::distance(dh::tbegin(grad_sum), min) + 1;
}
}; // namespace tree

View File

@@ -129,9 +129,8 @@ class GradientBasedSampler {
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);
/*! \brief Calculate the threshold used to normalize sampling probabilities. */
static size_t CalculateThresholdIndex(common::Span<GradientPair> gpair,
common::Span<float> threshold,
common::Span<float> grad_sum,
static size_t CalculateThresholdIndex(Context const* ctx, common::Span<GradientPair> gpair,
common::Span<float> threshold, common::Span<float> grad_sum,
size_t sample_rows);
private: