xgboost/src/data/ellpack_page_source.h
Jiaming Yuan e8a962575a
[EM] Allow staging ellpack on host for GPU external memory. (#10488)
- New parameter `on_host`.
- Abstract format creation and stream creation into policy classes.
2024-06-28 04:42:18 +08:00

149 lines
4.8 KiB
C++

/**
* Copyright 2019-2024, XGBoost Contributors
*/
#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
#define XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
#include <cstdint>      // for int32_t
#include <memory>       // for shared_ptr
#include <type_traits>  // for enable_if_t
#include <utility>      // for move

#include "../common/hist_util.h"      // for HistogramCuts
#include "ellpack_page.h"             // for EllpackPage
#include "ellpack_page_raw_format.h"  // for EllpackPageRawFormat
#include "sparse_page_source.h"       // for PageSourceIncMixIn
#include "xgboost/base.h"             // for bst_idx_t
#include "xgboost/context.h"          // for DeviceOrd
#include "xgboost/data.h"             // for BatchParam
#include "xgboost/span.h"             // for Span
namespace xgboost::data {
// We need to decouple the storage and the view of the storage so that we can implement
// concurrent read.
//
// Dummy type to hide CUDA calls from the host compiler. Only forward-declared here; this
// header refers to it exclusively through `std::shared_ptr<EllpackHostCache>`.
struct EllpackHostCache;
// Pimpl to hide CUDA calls from the host compiler. Only forward-declared here; it is held
// by `EllpackHostCacheStream` through a `std::unique_ptr`.
class EllpackHostCacheStreamImpl;
// A view onto the actual cache implemented by `EllpackHostCache`.
/**
 * @brief A view onto the actual cache implemented by `EllpackHostCache`.
 *
 * All non-template operations are declared here and defined elsewhere; they are reached
 * through the `EllpackHostCacheStreamImpl` pimpl so that this header stays free of CUDA
 * calls. The destructor is declared out-of-line because the pimpl type is incomplete in
 * this header.
 */
class EllpackHostCacheStream {
  std::unique_ptr<EllpackHostCacheStreamImpl> p_impl_;

  // Enables an overload only for POD types and yields `R` as the return type. Spelled as
  // `is_trivial && is_standard_layout`, which is the definition of POD, because
  // `std::is_pod` itself is deprecated since C++20.
  template <typename T, typename R>
  using EnableIfPod = std::enable_if_t<std::is_trivial_v<T> && std::is_standard_layout_v<T>, R>;

 public:
  /**
   * @brief Create a stream over the given cache.
   *
   * @param cache The host cache this stream views.
   */
  explicit EllpackHostCacheStream(std::shared_ptr<EllpackHostCache> cache);
  ~EllpackHostCacheStream();

  /**
   * @brief Write `n_bytes` bytes starting at `ptr` to the stream.
   *
   * @return A byte count as reported by the implementation (presumably the number of bytes
   *         written -- confirm against the definition).
   */
  [[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes);
  /**
   * @brief Write a single POD value as `sizeof(T)` raw bytes.
   */
  template <typename T>
  [[nodiscard]] EnableIfPod<T, bst_idx_t> Write(T const& v) {
    return this->Write(&v, sizeof(T));
  }

  /**
   * @brief Read `n_bytes` bytes from the stream into `ptr`.
   *
   * @return Status as reported by the implementation (presumably true on success -- confirm
   *         against the definition).
   */
  [[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes);
  /**
   * @brief Read a single POD value of `sizeof(T)` raw bytes into `*ptr`.
   */
  template <typename T>
  [[nodiscard]] EnableIfPod<T, bool> Read(T* ptr) {
    return this->Read(ptr, sizeof(T));
  }

  /** @brief Current position of the stream. */
  [[nodiscard]] bst_idx_t Tell() const;
  /** @brief Move the stream position to `offset_bytes`. */
  void Seek(bst_idx_t offset_bytes);
  // Limit the size of read. offset_bytes is the maximum offset that this stream can read
  // to. An error is raised if the limit is exceeded.
  void Bound(bst_idx_t offset_bytes);
};
/**
 * @brief Format policy that creates `EllpackPageRawFormat` objects from a shared set of
 *        histogram cuts and a device ordinal.
 *
 * `SetCuts` must be called before `CreatePageFormat` or `GetCuts`; both verify this.
 *
 * @tparam S Not referenced inside this policy.
 */
template <typename S>
class EllpackFormatPolicy {
  std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
  DeviceOrd device_;

 public:
  using FormatT = EllpackPageRawFormat;

 public:
  /**
   * @brief Create a new page format bound to the stored cuts and device.
   *
   * @return An owning pointer to the newly created format.
   */
  [[nodiscard]] auto CreatePageFormat() const {
    // Report a proper error instead of dereferencing a null pointer when the cuts have not
    // been set yet.
    CHECK(cuts_);
    // The cut values must already be associated with the device this policy targets.
    CHECK_EQ(cuts_->cut_values_.Device(), device_);
    return std::make_unique<FormatT>(cuts_, device_);
  }
  /**
   * @brief Store the histogram cuts and the device used for subsequently created formats.
   *
   * @param cuts   Shared histogram cuts.
   * @param device Target device; it is required to be a CUDA device.
   */
  void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device) {
    cuts_ = std::move(cuts);
    device_ = device;
    CHECK(this->device_.IsCUDA());
  }
  /**
   * @brief Access the stored cuts. Raises an error if they have not been set.
   */
  [[nodiscard]] auto GetCuts() const {
    CHECK(cuts_);
    return cuts_;
  }
  /** @brief The device passed to the last `SetCuts` call. */
  [[nodiscard]] auto Device() const { return device_; }
};
/**
 * @brief Stream policy whose reader and writer are both `EllpackHostCacheStream`.
 *
 * The policy derives from the supplied format policy `F<S>` and keeps a shared handle to
 * an `EllpackHostCache`. All member functions are only declared in this header; their
 * definitions are not visible here.
 *
 * @tparam S Page type forwarded to the format policy.
 * @tparam F Format policy template that this class inherits from.
 */
template <typename S, template <typename> typename F>
class EllpackFormatStreamPolicy : public F<S> {
  std::shared_ptr<EllpackHostCache> p_cache_;

 public:
  using WriterT = EllpackHostCacheStream;
  using ReaderT = EllpackHostCacheStream;

  EllpackFormatStreamPolicy();

  /**
   * @brief Create a stream used for writing.
   *
   * NOTE(review): how `name` and `iter` are interpreted is determined by the definition,
   * which is outside this header.
   */
  [[nodiscard]] std::unique_ptr<WriterT> CreateWriter(StringView name, std::uint32_t iter);

  /**
   * @brief Create a stream used for reading.
   *
   * NOTE(review): `offset` and `length` presumably select the byte range to read --
   * confirm against the definition.
   */
  [[nodiscard]] std::unique_ptr<ReaderT> CreateReader(StringView name, bst_idx_t offset,
                                                      bst_idx_t length) const;
};
/**
 * @brief Page source producing `EllpackPage`s, parameterized by a format/stream policy.
 *
 * The constructor stores its configuration, attaches the cuts to the target device, hands
 * them to the policy via `SetCuts`, and immediately calls `Fetch`.
 *
 * @tparam F Policy type forwarded to `PageSourceIncMixIn`.
 */
template <typename F>
class EllpackPageSourceImpl : public PageSourceIncMixIn<EllpackPage, F> {
  using Super = PageSourceIncMixIn<EllpackPage, F>;
  bool is_dense_;
  bst_idx_t row_stride_;
  BatchParam param_;
  common::Span<FeatureType const> feature_types_;

 public:
  /**
   * @param missing, nthreads, n_features, n_batches, cache  Forwarded unchanged to the
   *        base class, followed by a literal `false` whose meaning is defined there.
   * @param param         Batch parameter, stored in this object.
   * @param cuts          Histogram cuts; must not be null. Its device is set to `device`
   *                      before it is passed to the policy.
   * @param is_dense      Stored in this object.
   * @param row_stride    Stored in this object.
   * @param feature_types Non-owning view, stored in this object.
   * @param source        Sparse page source, assigned to the inherited `source_`.
   * @param device        Device used for the cuts and the policy.
   */
  EllpackPageSourceImpl(float missing, std::int32_t nthreads, bst_feature_t n_features,
                        std::size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
                        std::shared_ptr<common::HistogramCuts> cuts, bool is_dense,
                        bst_idx_t row_stride, common::Span<FeatureType const> feature_types,
                        std::shared_ptr<SparsePageSource> source, DeviceOrd device)
      // `cache` is a by-value parameter that is not used again, move it instead of paying
      // for another reference count update.
      : Super{missing, nthreads, n_features, n_batches, std::move(cache), false},
        is_dense_{is_dense},
        row_stride_{row_stride},
        param_{std::move(param)},
        feature_types_{feature_types} {
    // Fail with a readable error rather than dereferencing a null pointer below.
    CHECK(cuts);
    this->source_ = std::move(source);
    // The device has to be set on the cuts before they are handed to the policy.
    cuts->SetDevice(device);
    this->SetCuts(std::move(cuts), device);
    this->Fetch();
  }

  void Fetch() final;
};
// Cache to host: uses `EllpackFormatStreamPolicy`, whose reader and writer are both
// `EllpackHostCacheStream`.
using EllpackPageHostSource =
EllpackPageSourceImpl<EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
// Cache to disk: uses `DefaultFormatStreamPolicy`, which is not defined in this header.
using EllpackPageSource =
EllpackPageSourceImpl<DefaultFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
#if !defined(XGBOOST_USE_CUDA)
// Definition of `Fetch` for builds where `XGBOOST_USE_CUDA` is not defined: it does no
// work of its own and only calls `common::AssertGPUSupport()`.
template <typename F>
inline void EllpackPageSourceImpl<F>::Fetch() {
  // Silence the warnings about unused member variables.
  static_cast<void>(is_dense_);
  static_cast<void>(row_stride_);
  common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace xgboost::data
#endif // XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_