[EM] Allow staging ellpack on host for GPU external memory. (#10488)
- New parameter `on_host`. - Abstract format creation and stream creation into policy classes.
This commit is contained in:
@@ -298,13 +298,14 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
||||
auto missing = GetMissing(jconfig);
|
||||
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
|
||||
auto on_host = OptionalArg<Boolean, bool>(jconfig, "on_host", false);
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(next);
|
||||
xgboost_CHECK_C_ARG_PTR(reset);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache)};
|
||||
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache, on_host)};
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@@ -429,7 +429,7 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
|
||||
}
|
||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
XGBDefaultDeviceAllocatorImpl()
|
||||
: SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()) {}
|
||||
: SuperT(rmm::cuda_stream_per_thread, rmm::mr::get_current_device_resource()) {}
|
||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
};
|
||||
|
||||
@@ -484,8 +484,8 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
|
||||
}
|
||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
XGBCachingDeviceAllocatorImpl()
|
||||
: SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()),
|
||||
use_cub_allocator_(!xgboost::GlobalConfigThreadLocalStore::Get()->use_rmm) {}
|
||||
: SuperT(rmm::cuda_stream_per_thread, rmm::mr::get_current_device_resource()),
|
||||
use_cub_allocator_(!xgboost::GlobalConfigThreadLocalStore::Get()->use_rmm) {}
|
||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
XGBOOST_DEVICE void construct(T *) {} // NOLINT
|
||||
private:
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#ifndef XGBOOST_COMMON_ERROR_MSG_H_
|
||||
#define XGBOOST_COMMON_ERROR_MSG_H_
|
||||
|
||||
#include <cinttypes> // for uint64_t
|
||||
#include <cstdint> // for uint64_t
|
||||
#include <limits> // for numeric_limits
|
||||
#include <string> // for string
|
||||
|
||||
@@ -103,5 +103,11 @@ inline auto NoFederated() { return "XGBoost is not compiled with federated learn
|
||||
inline auto NoCategorical(std::string name) {
|
||||
return name + " doesn't support categorical features.";
|
||||
}
|
||||
|
||||
inline void NoOnHost(bool on_host) {
|
||||
if (on_host) {
|
||||
LOG(FATAL) << "Caching on host memory is only available for GPU.";
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::error
|
||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||
|
||||
@@ -163,7 +163,7 @@ class HistogramCuts {
|
||||
return vals[bin_idx - 1];
|
||||
}
|
||||
|
||||
void SetDevice(DeviceOrd d) const {
|
||||
void SetDevice(DeviceOrd d) {
|
||||
this->cut_ptrs_.SetDevice(d);
|
||||
this->cut_ptrs_.ConstDevicePointer();
|
||||
|
||||
|
||||
@@ -901,15 +901,12 @@ DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_p
|
||||
return new data::IterativeDMatrix(iter, proxy, ref, reset, next, missing, nthread, max_bin);
|
||||
}
|
||||
|
||||
template <typename DataIterHandle, typename DMatrixHandle,
|
||||
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
|
||||
DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing,
|
||||
int32_t n_threads,
|
||||
std::string cache) {
|
||||
return new data::SparsePageDMatrix(iter, proxy, reset, next, missing, n_threads,
|
||||
cache);
|
||||
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
|
||||
typename XGDMatrixCallbackNext>
|
||||
DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset,
|
||||
XGDMatrixCallbackNext* next, float missing, int32_t n_threads,
|
||||
std::string cache, bool on_host) {
|
||||
return new data::SparsePageDMatrix{iter, proxy, reset, next, missing, n_threads, cache, on_host};
|
||||
}
|
||||
|
||||
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
|
||||
@@ -919,10 +916,11 @@ template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCa
|
||||
XGDMatrixCallbackNext* next, float missing,
|
||||
int nthread, int max_bin);
|
||||
|
||||
template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
|
||||
DataIterResetCallback, XGDMatrixCallbackNext>(
|
||||
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing, int32_t n_threads, std::string);
|
||||
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
|
||||
XGDMatrixCallbackNext>(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterResetCallback* reset,
|
||||
XGDMatrixCallbackNext* next, float missing,
|
||||
int32_t n_threads, std::string, bool);
|
||||
|
||||
template <typename AdapterT>
|
||||
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
|
||||
|
||||
@@ -36,7 +36,7 @@ void EllpackPage::SetBaseRowId(std::size_t) {
|
||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||
"EllpackPage is required";
|
||||
}
|
||||
size_t EllpackPage::Size() const {
|
||||
bst_idx_t EllpackPage::Size() const {
|
||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||
"EllpackPage is required";
|
||||
return 0;
|
||||
|
||||
@@ -29,7 +29,7 @@ EllpackPage::~EllpackPage() = default;
|
||||
|
||||
EllpackPage::EllpackPage(EllpackPage&& that) { std::swap(impl_, that.impl_); }
|
||||
|
||||
size_t EllpackPage::Size() const { return impl_->Size(); }
|
||||
[[nodiscard]] bst_idx_t EllpackPage::Size() const { return impl_->Size(); }
|
||||
|
||||
void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id); }
|
||||
|
||||
@@ -91,13 +91,13 @@ __global__ void CompressBinEllpackKernel(
|
||||
// Construct an ELLPACK matrix with the given number of empty rows.
|
||||
EllpackPageImpl::EllpackPageImpl(DeviceOrd device,
|
||||
std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
|
||||
size_t row_stride, size_t n_rows)
|
||||
: is_dense(is_dense), cuts_(std::move(cuts)), row_stride(row_stride), n_rows(n_rows) {
|
||||
bst_idx_t row_stride, bst_idx_t n_rows)
|
||||
: is_dense(is_dense), cuts_(std::move(cuts)), row_stride{row_stride}, n_rows{n_rows} {
|
||||
monitor_.Init("ellpack_page");
|
||||
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
||||
|
||||
monitor_.Start("InitCompressedData");
|
||||
InitCompressedData(device);
|
||||
this->InitCompressedData(device);
|
||||
monitor_.Stop("InitCompressedData");
|
||||
}
|
||||
|
||||
@@ -403,7 +403,7 @@ struct CopyPage {
|
||||
// Copy the data from the given EllpackPage to the current page.
|
||||
size_t EllpackPageImpl::Copy(DeviceOrd device, EllpackPageImpl const* page, size_t offset) {
|
||||
monitor_.Start("Copy");
|
||||
size_t num_elements = page->n_rows * page->row_stride;
|
||||
bst_idx_t num_elements = page->n_rows * page->row_stride;
|
||||
CHECK_EQ(row_stride, page->row_stride);
|
||||
CHECK_EQ(NumSymbols(), page->NumSymbols());
|
||||
CHECK_GE(n_rows * row_stride, offset + num_elements);
|
||||
@@ -461,16 +461,17 @@ struct CompactPage {
|
||||
};
|
||||
|
||||
// Compacts the data from the given EllpackPage into the current page.
|
||||
void EllpackPageImpl::Compact(DeviceOrd device, EllpackPageImpl const* page,
|
||||
void EllpackPageImpl::Compact(Context const* ctx, EllpackPageImpl const* page,
|
||||
common::Span<size_t> row_indexes) {
|
||||
monitor_.Start("Compact");
|
||||
monitor_.Start(__func__);
|
||||
CHECK_EQ(row_stride, page->row_stride);
|
||||
CHECK_EQ(NumSymbols(), page->NumSymbols());
|
||||
CHECK_LE(page->base_rowid + page->n_rows, row_indexes.size());
|
||||
gidx_buffer.SetDevice(device);
|
||||
page->gidx_buffer.SetDevice(device);
|
||||
dh::LaunchN(page->n_rows, CompactPage(this, page, row_indexes));
|
||||
monitor_.Stop("Compact");
|
||||
gidx_buffer.SetDevice(ctx->Device());
|
||||
page->gidx_buffer.SetDevice(ctx->Device());
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
dh::LaunchN(page->n_rows, cuctx->Stream(), CompactPage(this, page, row_indexes));
|
||||
monitor_.Stop(__func__);
|
||||
}
|
||||
|
||||
// Initialize the buffer to stored compressed features.
|
||||
@@ -551,7 +552,7 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
|
||||
}
|
||||
|
||||
// Return the number of rows contained in this page.
|
||||
size_t EllpackPageImpl::Size() const { return n_rows; }
|
||||
[[nodiscard]] bst_idx_t EllpackPageImpl::Size() const { return n_rows; }
|
||||
|
||||
// Return the memory cost for storing the compressed features.
|
||||
size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride,
|
||||
|
||||
@@ -143,7 +143,7 @@ class EllpackPageImpl {
|
||||
* and the given number of rows.
|
||||
*/
|
||||
EllpackPageImpl(DeviceOrd device, std::shared_ptr<common::HistogramCuts const> cuts,
|
||||
bool is_dense, size_t row_stride, size_t n_rows);
|
||||
bool is_dense, bst_idx_t row_stride, bst_idx_t n_rows);
|
||||
/*!
|
||||
* \brief Constructor used for external memory.
|
||||
*/
|
||||
@@ -181,14 +181,14 @@ class EllpackPageImpl {
|
||||
|
||||
/*! \brief Compact the given ELLPACK page into the current page.
|
||||
*
|
||||
* @param device The GPU device to use.
|
||||
* @param context The GPU context.
|
||||
* @param page The ELLPACK page to compact from.
|
||||
* @param row_indexes Row indexes for the compacted page.
|
||||
*/
|
||||
void Compact(DeviceOrd device, EllpackPageImpl const* page, common::Span<size_t> row_indexes);
|
||||
void Compact(Context const* ctx, EllpackPageImpl const* page, common::Span<size_t> row_indexes);
|
||||
|
||||
/*! \return Number of instances in the page. */
|
||||
[[nodiscard]] size_t Size() const;
|
||||
[[nodiscard]] bst_idx_t Size() const;
|
||||
|
||||
/*! \brief Set the base row id for this page. */
|
||||
void SetBaseRowId(std::size_t row_id) {
|
||||
@@ -231,7 +231,7 @@ class EllpackPageImpl {
|
||||
/*! \brief Whether or not if the matrix is dense. */
|
||||
bool is_dense;
|
||||
/*! \brief Row length for ELLPACK. */
|
||||
size_t row_stride;
|
||||
bst_idx_t row_stride;
|
||||
bst_idx_t base_rowid{0};
|
||||
bst_idx_t n_rows{};
|
||||
/*! \brief global index of histogram, which is stored in ELLPACK format. */
|
||||
|
||||
@@ -41,7 +41,7 @@ class EllpackPage {
|
||||
EllpackPage(EllpackPage&& that);
|
||||
|
||||
/*! \return Number of instances in the page. */
|
||||
[[nodiscard]] size_t Size() const;
|
||||
[[nodiscard]] bst_idx_t Size() const;
|
||||
|
||||
/*! \brief Set the base row id for this page. */
|
||||
void SetBaseRowId(std::size_t row_id);
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "../common/ref_resource_view.h" // for ReadVec, WriteVec
|
||||
#include "ellpack_page.cuh" // for EllpackPage
|
||||
#include "ellpack_page_raw_format.h"
|
||||
#include "ellpack_page_source.h"
|
||||
|
||||
namespace xgboost::data {
|
||||
DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);
|
||||
@@ -32,7 +33,6 @@ template <typename T>
|
||||
return false;
|
||||
}
|
||||
|
||||
vec->SetDevice(DeviceOrd::CUDA(0));
|
||||
vec->Resize(n);
|
||||
auto d_vec = vec->DeviceSpan();
|
||||
dh::safe_cuda(
|
||||
@@ -54,6 +54,7 @@ template <typename T>
|
||||
if (!fi->Read(&impl->row_stride)) {
|
||||
return false;
|
||||
}
|
||||
impl->gidx_buffer.SetDevice(device_);
|
||||
if (!ReadDeviceVec(fi, &impl->gidx_buffer)) {
|
||||
return false;
|
||||
}
|
||||
@@ -73,6 +74,65 @@ template <typename T>
|
||||
CHECK(!impl->gidx_buffer.ConstHostVector().empty());
|
||||
bytes += common::WriteVec(fo, impl->gidx_buffer.HostVector());
|
||||
bytes += fo->Write(impl->base_rowid);
|
||||
dh::DefaultStream().Sync();
|
||||
return bytes;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page, EllpackHostCacheStream* fi) const {
|
||||
auto* impl = page->Impl();
|
||||
CHECK(this->cuts_->cut_values_.DeviceCanRead());
|
||||
impl->SetCuts(this->cuts_);
|
||||
if (!fi->Read(&impl->n_rows)) {
|
||||
return false;
|
||||
}
|
||||
if (!fi->Read(&impl->is_dense)) {
|
||||
return false;
|
||||
}
|
||||
if (!fi->Read(&impl->row_stride)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read vec
|
||||
bst_idx_t n{0};
|
||||
if (!fi->Read(&n)) {
|
||||
return false;
|
||||
}
|
||||
if (n != 0) {
|
||||
impl->gidx_buffer.SetDevice(device_);
|
||||
impl->gidx_buffer.Resize(n);
|
||||
auto span = impl->gidx_buffer.DeviceSpan();
|
||||
if (!fi->Read(span.data(), span.size_bytes())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!fi->Read(&impl->base_rowid)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
dh::DefaultStream().Sync();
|
||||
return true;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page,
|
||||
EllpackHostCacheStream* fo) const {
|
||||
bst_idx_t bytes{0};
|
||||
auto* impl = page.Impl();
|
||||
bytes += fo->Write(impl->n_rows);
|
||||
bytes += fo->Write(impl->is_dense);
|
||||
bytes += fo->Write(impl->row_stride);
|
||||
|
||||
// Write vector
|
||||
bst_idx_t n = impl->gidx_buffer.Size();
|
||||
bytes += fo->Write(n);
|
||||
|
||||
if (!impl->gidx_buffer.Empty()) {
|
||||
auto span = impl->gidx_buffer.ConstDeviceSpan();
|
||||
bytes += fo->Write(span.data(), span.size_bytes());
|
||||
}
|
||||
bytes += fo->Write(impl->base_rowid);
|
||||
|
||||
dh::DefaultStream().Sync();
|
||||
return bytes;
|
||||
}
|
||||
} // namespace xgboost::data
|
||||
|
||||
@@ -20,15 +20,22 @@ class HistogramCuts;
|
||||
}
|
||||
|
||||
namespace xgboost::data {
|
||||
|
||||
class EllpackHostCacheStream;
|
||||
|
||||
class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
|
||||
std::shared_ptr<common::HistogramCuts const> cuts_;
|
||||
DeviceOrd device_;
|
||||
|
||||
public:
|
||||
explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts)
|
||||
: cuts_{std::move(cuts)} {}
|
||||
explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device)
|
||||
: cuts_{std::move(cuts)}, device_{device} {}
|
||||
[[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override;
|
||||
[[nodiscard]] std::size_t Write(const EllpackPage& page,
|
||||
common::AlignedFileWriteStream* fo) override;
|
||||
|
||||
[[nodiscard]] bool Read(EllpackPage* page, EllpackHostCacheStream* fi) const;
|
||||
[[nodiscard]] std::size_t Write(const EllpackPage& page, EllpackHostCacheStream* fo) const;
|
||||
};
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
@@ -1,29 +1,161 @@
|
||||
/**
|
||||
* Copyright 2019-2024, XGBoost contributors
|
||||
*/
|
||||
#include <memory>
|
||||
#include <thrust/host_vector.h> // for host_vector
|
||||
|
||||
#include "ellpack_page.cuh"
|
||||
#include "ellpack_page.h" // for EllpackPage
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, uint64_t, uint32_t
|
||||
#include <memory> // for shared_ptr, make_unique, make_shared
|
||||
#include <utility> // for move
|
||||
|
||||
#include "../common/common.h" // for safe_cuda
|
||||
#include "../common/cuda_pinned_allocator.h" // for pinned_allocator
|
||||
#include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream
|
||||
#include "ellpack_page.cuh" // for EllpackPageImpl
|
||||
#include "ellpack_page.h" // for EllpackPage
|
||||
#include "ellpack_page_source.h"
|
||||
#include "xgboost/base.h" // for bst_idx_t
|
||||
|
||||
namespace xgboost::data {
|
||||
void EllpackPageSource::Fetch() {
|
||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
||||
struct EllpackHostCache {
|
||||
thrust::host_vector<std::int8_t, common::cuda::pinned_allocator<std::int8_t>> cache;
|
||||
|
||||
void Resize(std::size_t n, dh::CUDAStreamView stream) {
|
||||
stream.Sync(); // Prevent partial copy inside resize.
|
||||
cache.resize(n);
|
||||
}
|
||||
};
|
||||
|
||||
class EllpackHostCacheStreamImpl {
|
||||
std::shared_ptr<EllpackHostCache> cache_;
|
||||
bst_idx_t cur_ptr_{0};
|
||||
bst_idx_t bound_{0};
|
||||
|
||||
public:
|
||||
explicit EllpackHostCacheStreamImpl(std::shared_ptr<EllpackHostCache> cache)
|
||||
: cache_{std::move(cache)} {}
|
||||
|
||||
[[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes) {
|
||||
auto n = cur_ptr_ + n_bytes;
|
||||
if (n > cache_->cache.size()) {
|
||||
cache_->Resize(n, dh::DefaultStream());
|
||||
}
|
||||
dh::safe_cuda(cudaMemcpyAsync(cache_->cache.data() + cur_ptr_, ptr, n_bytes, cudaMemcpyDefault,
|
||||
dh::DefaultStream()));
|
||||
cur_ptr_ = n;
|
||||
return n_bytes;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes) {
|
||||
CHECK_LE(cur_ptr_ + n_bytes, bound_);
|
||||
dh::safe_cuda(cudaMemcpyAsync(ptr, cache_->cache.data() + cur_ptr_, n_bytes, cudaMemcpyDefault,
|
||||
dh::DefaultStream()));
|
||||
cur_ptr_ += n_bytes;
|
||||
return true;
|
||||
}
|
||||
|
||||
[[nodiscard]] bst_idx_t Tell() const { return cur_ptr_; }
|
||||
void Seek(bst_idx_t offset_bytes) { cur_ptr_ = offset_bytes; }
|
||||
void Bound(bst_idx_t offset_bytes) {
|
||||
CHECK_LE(offset_bytes, cache_->cache.size());
|
||||
this->bound_ = offset_bytes;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* EllpackHostCacheStream
|
||||
*/
|
||||
|
||||
EllpackHostCacheStream::EllpackHostCacheStream(std::shared_ptr<EllpackHostCache> cache)
|
||||
: p_impl_{std::make_unique<EllpackHostCacheStreamImpl>(std::move(cache))} {}
|
||||
|
||||
EllpackHostCacheStream::~EllpackHostCacheStream() = default;
|
||||
|
||||
[[nodiscard]] bst_idx_t EllpackHostCacheStream::Write(void const* ptr, bst_idx_t n_bytes) {
|
||||
return this->p_impl_->Write(ptr, n_bytes);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool EllpackHostCacheStream::Read(void* ptr, bst_idx_t n_bytes) {
|
||||
return this->p_impl_->Read(ptr, n_bytes);
|
||||
}
|
||||
|
||||
[[nodiscard]] bst_idx_t EllpackHostCacheStream::Tell() const { return this->p_impl_->Tell(); }
|
||||
|
||||
void EllpackHostCacheStream::Seek(bst_idx_t offset_bytes) { this->p_impl_->Seek(offset_bytes); }
|
||||
|
||||
void EllpackHostCacheStream::Bound(bst_idx_t offset_bytes) { this->p_impl_->Bound(offset_bytes); }
|
||||
|
||||
/**
|
||||
* EllpackFormatType
|
||||
*/
|
||||
|
||||
template <typename S, template <typename> typename F>
|
||||
EllpackFormatStreamPolicy<S, F>::EllpackFormatStreamPolicy()
|
||||
: p_cache_{std::make_shared<EllpackHostCache>()} {}
|
||||
|
||||
template <typename S, template <typename> typename F>
|
||||
[[nodiscard]] std::unique_ptr<typename EllpackFormatStreamPolicy<S, F>::WriterT>
|
||||
EllpackFormatStreamPolicy<S, F>::CreateWriter(StringView, std::uint32_t iter) {
|
||||
auto fo = std::make_unique<EllpackHostCacheStream>(this->p_cache_);
|
||||
if (iter == 0) {
|
||||
CHECK(this->p_cache_->cache.empty());
|
||||
} else {
|
||||
fo->Seek(this->p_cache_->cache.size());
|
||||
}
|
||||
return fo;
|
||||
}
|
||||
|
||||
template <typename S, template <typename> typename F>
|
||||
[[nodiscard]] std::unique_ptr<typename EllpackFormatStreamPolicy<S, F>::ReaderT>
|
||||
EllpackFormatStreamPolicy<S, F>::CreateReader(StringView, bst_idx_t offset,
|
||||
bst_idx_t length) const {
|
||||
auto fi = std::make_unique<ReaderT>(this->p_cache_);
|
||||
fi->Seek(offset);
|
||||
fi->Bound(offset + length);
|
||||
CHECK_EQ(fi->Tell(), offset);
|
||||
return fi;
|
||||
}
|
||||
|
||||
// Instantiation
|
||||
template EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::EllpackFormatStreamPolicy();
|
||||
|
||||
template std::unique_ptr<
|
||||
typename EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::WriterT>
|
||||
EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateWriter(StringView name,
|
||||
std::uint32_t iter);
|
||||
|
||||
template std::unique_ptr<
|
||||
typename EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::ReaderT>
|
||||
EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateReader(
|
||||
StringView name, std::uint64_t offset, std::uint64_t length) const;
|
||||
|
||||
/**
|
||||
* EllpackPageSourceImpl
|
||||
*/
|
||||
template <typename F>
|
||||
void EllpackPageSourceImpl<F>::Fetch() {
|
||||
dh::safe_cuda(cudaSetDevice(this->Device().ordinal));
|
||||
if (!this->ReadCache()) {
|
||||
if (count_ != 0 && !sync_) {
|
||||
if (this->count_ != 0 && !this->sync_) {
|
||||
// source is initialized to be the 0th page during construction, so when count_ is 0
|
||||
// there's no need to increment the source.
|
||||
++(*source_);
|
||||
++(*this->source_);
|
||||
}
|
||||
// This is not read from cache so we still need it to be synced with sparse page source.
|
||||
CHECK_EQ(count_, source_->Iter());
|
||||
auto const &csr = source_->Page();
|
||||
CHECK_EQ(this->count_, this->source_->Iter());
|
||||
auto const& csr = this->source_->Page();
|
||||
this->page_.reset(new EllpackPage{});
|
||||
auto *impl = this->page_->Impl();
|
||||
*impl = EllpackPageImpl(device_, cuts_, *csr, is_dense_, row_stride_, feature_types_);
|
||||
page_->SetBaseRowId(csr->base_rowid);
|
||||
auto* impl = this->page_->Impl();
|
||||
*impl = EllpackPageImpl{this->Device(), this->GetCuts(), *csr,
|
||||
is_dense_, row_stride_, feature_types_};
|
||||
this->page_->SetBaseRowId(csr->base_rowid);
|
||||
this->WriteCache();
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiation
|
||||
template void
|
||||
EllpackPageSourceImpl<DefaultFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
||||
template void
|
||||
EllpackPageSourceImpl<EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
||||
} // namespace xgboost::data
|
||||
|
||||
@@ -19,46 +19,127 @@
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::data {
|
||||
class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
|
||||
// We need to decouple the storage and the view of the storage so that we can implement
|
||||
// concurrent read.
|
||||
|
||||
// Dummy type to hide CUDA calls from the host compiler.
|
||||
struct EllpackHostCache;
|
||||
// Pimpl to hide CUDA calls from the host compiler.
|
||||
class EllpackHostCacheStreamImpl;
|
||||
|
||||
// A view onto the actual cache implemented by `EllpackHostCache`.
|
||||
class EllpackHostCacheStream {
|
||||
std::unique_ptr<EllpackHostCacheStreamImpl> p_impl_;
|
||||
|
||||
public:
|
||||
explicit EllpackHostCacheStream(std::shared_ptr<EllpackHostCache> cache);
|
||||
~EllpackHostCacheStream();
|
||||
|
||||
[[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes);
|
||||
template <typename T>
|
||||
[[nodiscard]] std::enable_if_t<std::is_pod_v<T>, bst_idx_t> Write(T const& v) {
|
||||
return this->Write(&v, sizeof(T));
|
||||
}
|
||||
|
||||
[[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes);
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] auto Read(T* ptr) -> std::enable_if_t<std::is_pod_v<T>, bool> {
|
||||
return this->Read(ptr, sizeof(T));
|
||||
}
|
||||
|
||||
[[nodiscard]] bst_idx_t Tell() const;
|
||||
void Seek(bst_idx_t offset_bytes);
|
||||
// Limit the size of read. offset_bytes is the maximum offset that this stream can read
|
||||
// to. An error is raised if the limited is exceeded.
|
||||
void Bound(bst_idx_t offset_bytes);
|
||||
};
|
||||
|
||||
template <typename S>
|
||||
class EllpackFormatPolicy {
|
||||
std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
|
||||
DeviceOrd device_;
|
||||
|
||||
public:
|
||||
using FormatT = EllpackPageRawFormat;
|
||||
|
||||
public:
|
||||
[[nodiscard]] auto CreatePageFormat() const {
|
||||
CHECK_EQ(cuts_->cut_values_.Device(), device_);
|
||||
std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_}};
|
||||
return fmt;
|
||||
}
|
||||
|
||||
void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device) {
|
||||
std::swap(cuts_, cuts);
|
||||
device_ = device;
|
||||
CHECK(this->device_.IsCUDA());
|
||||
}
|
||||
[[nodiscard]] auto GetCuts() {
|
||||
CHECK(cuts_);
|
||||
return cuts_;
|
||||
}
|
||||
[[nodiscard]] auto Device() const { return device_; }
|
||||
};
|
||||
|
||||
template <typename S, template <typename> typename F>
|
||||
class EllpackFormatStreamPolicy : public F<S> {
|
||||
std::shared_ptr<EllpackHostCache> p_cache_;
|
||||
|
||||
public:
|
||||
using WriterT = EllpackHostCacheStream;
|
||||
using ReaderT = EllpackHostCacheStream;
|
||||
|
||||
public:
|
||||
EllpackFormatStreamPolicy();
|
||||
[[nodiscard]] std::unique_ptr<WriterT> CreateWriter(StringView name, std::uint32_t iter);
|
||||
|
||||
[[nodiscard]] std::unique_ptr<ReaderT> CreateReader(StringView name, bst_idx_t offset,
|
||||
bst_idx_t length) const;
|
||||
};
|
||||
|
||||
template <typename F>
|
||||
class EllpackPageSourceImpl : public PageSourceIncMixIn<EllpackPage, F> {
|
||||
using Super = PageSourceIncMixIn<EllpackPage, F>;
|
||||
bool is_dense_;
|
||||
bst_idx_t row_stride_;
|
||||
BatchParam param_;
|
||||
common::Span<FeatureType const> feature_types_;
|
||||
std::shared_ptr<common::HistogramCuts const> cuts_;
|
||||
DeviceOrd device_;
|
||||
|
||||
protected:
|
||||
[[nodiscard]] SparsePageFormat<EllpackPage>* CreatePageFormat() const override {
|
||||
cuts_->SetDevice(this->device_);
|
||||
return new EllpackPageRawFormat{cuts_};
|
||||
}
|
||||
|
||||
public:
|
||||
EllpackPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
||||
size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
|
||||
std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
|
||||
bst_idx_t row_stride, common::Span<FeatureType const> feature_types,
|
||||
std::shared_ptr<SparsePageSource> source, DeviceOrd device)
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
|
||||
EllpackPageSourceImpl(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
||||
std::size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
|
||||
std::shared_ptr<common::HistogramCuts> cuts, bool is_dense,
|
||||
bst_idx_t row_stride, common::Span<FeatureType const> feature_types,
|
||||
std::shared_ptr<SparsePageSource> source, DeviceOrd device)
|
||||
: Super{missing, nthreads, n_features, n_batches, cache, false},
|
||||
is_dense_{is_dense},
|
||||
row_stride_{row_stride},
|
||||
param_{std::move(param)},
|
||||
feature_types_{feature_types},
|
||||
cuts_{std::move(cuts)},
|
||||
device_{device} {
|
||||
feature_types_{feature_types} {
|
||||
this->source_ = source;
|
||||
cuts->SetDevice(device);
|
||||
this->SetCuts(std::move(cuts), device);
|
||||
this->Fetch();
|
||||
}
|
||||
|
||||
void Fetch() final;
|
||||
};
|
||||
|
||||
// Cache to host
|
||||
using EllpackPageHostSource =
|
||||
EllpackPageSourceImpl<EllpackFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
|
||||
|
||||
// Cache to disk
|
||||
using EllpackPageSource =
|
||||
EllpackPageSourceImpl<DefaultFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline void EllpackPageSource::Fetch() {
|
||||
template <typename F>
|
||||
inline void EllpackPageSourceImpl<F>::Fetch() {
|
||||
// silent the warning about unused variables.
|
||||
(void)(row_stride_);
|
||||
(void)(is_dense_);
|
||||
(void)(device_);
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
@@ -17,20 +17,35 @@
|
||||
#include "xgboost/data.h" // for BatchParam, FeatureType
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
|
||||
namespace xgboost::data {
|
||||
/**
|
||||
* @brief Policy for creating ghist index format. The storage is default (disk).
|
||||
*/
|
||||
template <typename S>
|
||||
class GHistIndexFormatPolicy {
|
||||
protected:
|
||||
common::HistogramCuts cuts_;
|
||||
|
||||
public:
|
||||
using FormatT = SparsePageFormat<GHistIndexMatrix>;
|
||||
|
||||
public:
|
||||
[[nodiscard]] auto CreatePageFormat() const {
|
||||
std::unique_ptr<FormatT> fmt{new GHistIndexRawFormat{cuts_}};
|
||||
return fmt;
|
||||
}
|
||||
|
||||
void SetCuts(common::HistogramCuts cuts) { std::swap(cuts_, cuts); }
|
||||
};
|
||||
|
||||
class GradientIndexPageSource
|
||||
: public PageSourceIncMixIn<
|
||||
GHistIndexMatrix, DefaultFormatStreamPolicy<GHistIndexMatrix, GHistIndexFormatPolicy>> {
|
||||
bool is_dense_;
|
||||
std::int32_t max_bin_per_feat_;
|
||||
common::Span<FeatureType const> feature_types_;
|
||||
double sparse_thresh_;
|
||||
|
||||
protected:
|
||||
[[nodiscard]] SparsePageFormat<GHistIndexMatrix>* CreatePageFormat() const override {
|
||||
return new GHistIndexRawFormat{cuts_};
|
||||
}
|
||||
|
||||
public:
|
||||
GradientIndexPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
||||
size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
|
||||
@@ -39,17 +54,16 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
|
||||
std::shared_ptr<SparsePageSource> source)
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
|
||||
std::isnan(param.sparse_thresh)),
|
||||
cuts_{std::move(cuts)},
|
||||
is_dense_{is_dense},
|
||||
max_bin_per_feat_{param.max_bin},
|
||||
feature_types_{feature_types},
|
||||
sparse_thresh_{param.sparse_thresh} {
|
||||
this->source_ = source;
|
||||
this->SetCuts(std::move(cuts));
|
||||
this->Fetch();
|
||||
}
|
||||
|
||||
void Fetch() final;
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
#endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
|
||||
|
||||
@@ -38,13 +38,17 @@ std::size_t NFeaturesDevice(DMatrixProxy *) // NOLINT
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
|
||||
SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy_handle,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing,
|
||||
int32_t nthreads, std::string cache_prefix)
|
||||
: proxy_{proxy_handle}, iter_{iter_handle}, reset_{reset}, next_{next}, missing_{missing},
|
||||
cache_prefix_{std::move(cache_prefix)} {
|
||||
DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
|
||||
float missing, int32_t nthreads, std::string cache_prefix,
|
||||
bool on_host)
|
||||
: proxy_{proxy_handle},
|
||||
iter_{iter_handle},
|
||||
reset_{reset},
|
||||
next_{next},
|
||||
missing_{missing},
|
||||
cache_prefix_{std::move(cache_prefix)},
|
||||
on_host_{on_host} {
|
||||
Context ctx;
|
||||
ctx.nthread = nthreads;
|
||||
|
||||
@@ -103,8 +107,26 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
||||
fmat_ctx_ = ctx;
|
||||
}
|
||||
|
||||
SparsePageDMatrix::~SparsePageDMatrix() {
|
||||
// Clear out all resources before deleting the cache file.
|
||||
sparse_page_source_.reset();
|
||||
std::visit([](auto &&ptr) { ptr.reset(); }, ellpack_page_source_);
|
||||
column_source_.reset();
|
||||
sorted_column_source_.reset();
|
||||
ghist_index_source_.reset();
|
||||
|
||||
for (auto const &kv : cache_info_) {
|
||||
CHECK(kv.second);
|
||||
auto n = kv.second->ShardName();
|
||||
if (kv.second->OnHost()) {
|
||||
continue;
|
||||
}
|
||||
TryDeleteCacheFile(n);
|
||||
}
|
||||
}
|
||||
|
||||
void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
|
||||
auto id = MakeCache(this, ".row.page", cache_prefix_, &cache_info_);
|
||||
auto id = MakeCache(this, ".row.page", false, cache_prefix_, &cache_info_);
|
||||
// Don't use proxy DMatrix once this is already initialized, this allows users to
|
||||
// release the iterator and data.
|
||||
if (cache_info_.at(id)->written) {
|
||||
@@ -132,8 +154,9 @@ BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
|
||||
}
|
||||
|
||||
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
|
||||
auto id = MakeCache(this, ".col.page", cache_prefix_, &cache_info_);
|
||||
auto id = MakeCache(this, ".col.page", on_host_, cache_prefix_, &cache_info_);
|
||||
CHECK_NE(this->Info().num_col_, 0);
|
||||
error::NoOnHost(on_host_);
|
||||
this->InitializeSparsePage(ctx);
|
||||
if (!column_source_) {
|
||||
column_source_ =
|
||||
@@ -146,8 +169,9 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
|
||||
}
|
||||
|
||||
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) {
|
||||
auto id = MakeCache(this, ".sorted.col.page", cache_prefix_, &cache_info_);
|
||||
auto id = MakeCache(this, ".sorted.col.page", on_host_, cache_prefix_, &cache_info_);
|
||||
CHECK_NE(this->Info().num_col_, 0);
|
||||
error::NoOnHost(on_host_);
|
||||
this->InitializeSparsePage(ctx);
|
||||
if (!sorted_column_source_) {
|
||||
sorted_column_source_ = std::make_shared<SortedCSCPageSource>(
|
||||
@@ -165,11 +189,12 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
}
|
||||
detail::CheckEmpty(batch_param_, param);
|
||||
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
error::NoOnHost(on_host_);
|
||||
auto id = MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_);
|
||||
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
||||
this->InitializeSparsePage(ctx);
|
||||
cache_info_.erase(id);
|
||||
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_);
|
||||
LOG(INFO) << "Generating new Gradient Index.";
|
||||
// Use sorted sketch for approx.
|
||||
auto sorted_sketch = param.regen;
|
||||
@@ -193,7 +218,7 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const *, const BatchParam &) {
|
||||
common::AssertGPUSupport();
|
||||
return BatchSet{BatchIterator<EllpackPage>{this->ellpack_page_source_}};
|
||||
return BatchSet{BatchIterator<EllpackPage>{nullptr}};
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace xgboost::data
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
/**
|
||||
* Copyright 2021-2024, XGBoost contributors
|
||||
*/
|
||||
#include <memory> // for shared_ptr
|
||||
#include <memory> // for shared_ptr
|
||||
#include <utility> // for move
|
||||
#include <variant> // for visit
|
||||
|
||||
#include "../common/hist_util.cuh"
|
||||
#include "../common/hist_util.h" // for HistogramCuts
|
||||
@@ -19,13 +21,15 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
}
|
||||
detail::CheckEmpty(batch_param_, param);
|
||||
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||
size_t row_stride = 0;
|
||||
auto id = MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_);
|
||||
|
||||
bst_idx_t row_stride = 0;
|
||||
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
||||
this->InitializeSparsePage(ctx);
|
||||
// reinitialize the cache
|
||||
cache_info_.erase(id);
|
||||
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||
MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_);
|
||||
LOG(INFO) << "Generating new a Ellpack page.";
|
||||
std::shared_ptr<common::HistogramCuts> cuts;
|
||||
if (!param.hess.empty()) {
|
||||
cuts = std::make_shared<common::HistogramCuts>(
|
||||
@@ -41,17 +45,28 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
CHECK_NE(row_stride, 0);
|
||||
batch_param_ = param;
|
||||
|
||||
auto ft = this->info_.feature_types.ConstDeviceSpan();
|
||||
ellpack_page_source_.reset(); // make sure resource is released before making new ones.
|
||||
ellpack_page_source_ = std::make_shared<EllpackPageSource>(
|
||||
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
|
||||
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_,
|
||||
ctx->Device());
|
||||
auto ft = this->Info().feature_types.ConstDeviceSpan();
|
||||
if (on_host_ && std::get_if<EllpackHostPtr>(&ellpack_page_source_) == nullptr) {
|
||||
ellpack_page_source_.emplace<EllpackHostPtr>(nullptr);
|
||||
}
|
||||
std::visit(
|
||||
[&](auto&& ptr) {
|
||||
ptr.reset(); // make sure resource is released before making new ones.
|
||||
using SourceT = typename std::remove_reference_t<decltype(ptr)>::element_type;
|
||||
ptr = std::make_shared<SourceT>(this->missing_, ctx->Threads(), this->Info().num_col_,
|
||||
this->n_batches_, cache_info_.at(id), param,
|
||||
std::move(cuts), this->IsDense(), row_stride, ft,
|
||||
this->sparse_page_source_, ctx->Device());
|
||||
},
|
||||
ellpack_page_source_);
|
||||
} else {
|
||||
CHECK(sparse_page_source_);
|
||||
ellpack_page_source_->Reset();
|
||||
std::visit([&](auto&& ptr) { ptr->Reset(); }, this->ellpack_page_source_);
|
||||
}
|
||||
|
||||
return BatchSet{BatchIterator<EllpackPage>{this->ellpack_page_source_}};
|
||||
auto batch_set =
|
||||
std::visit([this](auto&& ptr) { return BatchSet{BatchIterator<EllpackPage>{ptr}}; },
|
||||
this->ellpack_page_source_);
|
||||
return batch_set;
|
||||
}
|
||||
} // namespace xgboost::data
|
||||
|
||||
@@ -7,16 +7,20 @@
|
||||
#ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
||||
#define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <cstdint> // for uint32_t, int32_t
|
||||
#include <map> // for map
|
||||
#include <memory> // for shared_ptr
|
||||
#include <sstream> // for stringstream
|
||||
#include <string> // for string
|
||||
#include <variant> // for variant, visit
|
||||
|
||||
#include "ellpack_page_source.h"
|
||||
#include "gradient_index_page_source.h"
|
||||
#include "sparse_page_source.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "ellpack_page_source.h" // for EllpackPageSource, EllpackPageHostSource
|
||||
#include "gradient_index_page_source.h" // for GradientIndexPageSource
|
||||
#include "sparse_page_source.h" // for SparsePageSource, Cache
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for DMatrix, MetaInfo
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::data {
|
||||
/**
|
||||
@@ -70,6 +74,7 @@ class SparsePageDMatrix : public DMatrix {
|
||||
float missing_;
|
||||
Context fmat_ctx_;
|
||||
std::string cache_prefix_;
|
||||
bool on_host_{false};
|
||||
std::uint32_t n_batches_{0};
|
||||
// sparse page is the source to other page types, we make a special member function.
|
||||
void InitializeSparsePage(Context const *ctx);
|
||||
@@ -79,29 +84,16 @@ class SparsePageDMatrix : public DMatrix {
|
||||
public:
|
||||
explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing, int32_t nthreads,
|
||||
std::string cache_prefix);
|
||||
std::string cache_prefix, bool on_host = false);
|
||||
|
||||
~SparsePageDMatrix() override {
|
||||
// Clear out all resources before deleting the cache file.
|
||||
sparse_page_source_.reset();
|
||||
ellpack_page_source_.reset();
|
||||
column_source_.reset();
|
||||
sorted_column_source_.reset();
|
||||
ghist_index_source_.reset();
|
||||
|
||||
for (auto const &kv : cache_info_) {
|
||||
CHECK(kv.second);
|
||||
auto n = kv.second->ShardName();
|
||||
TryDeleteCacheFile(n);
|
||||
}
|
||||
}
|
||||
~SparsePageDMatrix() override;
|
||||
|
||||
[[nodiscard]] MetaInfo &Info() override;
|
||||
[[nodiscard]] const MetaInfo &Info() const override;
|
||||
[[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
|
||||
// The only DMatrix implementation that returns false.
|
||||
[[nodiscard]] bool SingleColBlock() const override { return false; }
|
||||
DMatrix *Slice(common::Span<int32_t const>) override {
|
||||
DMatrix *Slice(common::Span<std::int32_t const>) override {
|
||||
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
|
||||
return nullptr;
|
||||
}
|
||||
@@ -111,7 +103,7 @@ class SparsePageDMatrix : public DMatrix {
|
||||
}
|
||||
|
||||
[[nodiscard]] bool EllpackExists() const override {
|
||||
return static_cast<bool>(ellpack_page_source_);
|
||||
return std::visit([](auto &&ptr) { return static_cast<bool>(ptr); }, ellpack_page_source_);
|
||||
}
|
||||
[[nodiscard]] bool GHistIndexExists() const override {
|
||||
return static_cast<bool>(ghist_index_source_);
|
||||
@@ -138,7 +130,9 @@ class SparsePageDMatrix : public DMatrix {
|
||||
private:
|
||||
// source data pointers.
|
||||
std::shared_ptr<SparsePageSource> sparse_page_source_;
|
||||
std::shared_ptr<EllpackPageSource> ellpack_page_source_;
|
||||
using EllpackDiskPtr = std::shared_ptr<EllpackPageSource>;
|
||||
using EllpackHostPtr = std::shared_ptr<EllpackPageHostSource>;
|
||||
std::variant<EllpackDiskPtr, EllpackHostPtr> ellpack_page_source_;
|
||||
std::shared_ptr<CSCPageSource> column_source_;
|
||||
std::shared_ptr<SortedCSCPageSource> sorted_column_source_;
|
||||
std::shared_ptr<GradientIndexPageSource> ghist_index_source_;
|
||||
@@ -153,15 +147,16 @@ class SparsePageDMatrix : public DMatrix {
|
||||
/**
|
||||
* @brief Make cache if it doesn't exist yet.
|
||||
*/
|
||||
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix,
|
||||
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, bool on_host,
|
||||
std::string prefix,
|
||||
std::map<std::string, std::shared_ptr<Cache>> *out) {
|
||||
auto &cache_info = *out;
|
||||
auto name = MakeId(prefix, ptr);
|
||||
auto id = name + format;
|
||||
auto it = cache_info.find(id);
|
||||
if (it == cache_info.cend()) {
|
||||
cache_info[id].reset(new Cache{false, name, format});
|
||||
LOG(INFO) << "Make cache:" << cache_info[id]->ShardName() << std::endl;
|
||||
cache_info[id].reset(new Cache{false, name, format, on_host});
|
||||
LOG(INFO) << "Make cache:" << cache_info[id]->ShardName();
|
||||
}
|
||||
return id;
|
||||
}
|
||||
|
||||
30
src/data/sparse_page_source.cc
Normal file
30
src/data/sparse_page_source.cc
Normal file
@@ -0,0 +1,30 @@
|
||||
/**
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "sparse_page_source.h"
|
||||
|
||||
#include <filesystem> // for exists
|
||||
#include <string> // for string
|
||||
#include <cstdio> // for remove
|
||||
#include <numeric> // for partial_sum
|
||||
|
||||
namespace xgboost::data {
|
||||
void Cache::Commit() {
|
||||
if (!written) {
|
||||
std::partial_sum(offset.begin(), offset.end(), offset.begin());
|
||||
written = true;
|
||||
}
|
||||
}
|
||||
|
||||
void TryDeleteCacheFile(const std::string& file) {
|
||||
// Don't throw, this is called in a destructor.
|
||||
auto exists = std::filesystem::exists(file);
|
||||
if (!exists) {
|
||||
LOG(WARNING) << "External memory cache file " << file << " is missing.";
|
||||
}
|
||||
if (std::remove(file.c_str()) != 0) {
|
||||
LOG(WARNING) << "Couldn't remove external memory cache file " << file
|
||||
<< "; you may want to remove it manually";
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::data
|
||||
@@ -8,11 +8,9 @@
|
||||
#include <algorithm> // for min
|
||||
#include <atomic> // for atomic
|
||||
#include <cstdint> // for uint64_t
|
||||
#include <cstdio> // for remove
|
||||
#include <future> // for future
|
||||
#include <memory> // for unique_ptr
|
||||
#include <mutex> // for mutex
|
||||
#include <numeric> // for partial_sum
|
||||
#include <string> // for string
|
||||
#include <utility> // for pair, move
|
||||
#include <vector> // for vector
|
||||
@@ -27,18 +25,12 @@
|
||||
#include "proxy_dmatrix.h" // for DMatrixProxy
|
||||
#include "sparse_page_writer.h" // for SparsePageFormat
|
||||
#include "xgboost/base.h" // for bst_feature_t
|
||||
#include "xgboost/data.h" // for SparsePage, CSCPage
|
||||
#include "xgboost/data.h" // for SparsePage, CSCPage, SortedCSCPage
|
||||
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
|
||||
#include "xgboost/logging.h" // for CHECK_EQ
|
||||
|
||||
namespace xgboost::data {
|
||||
inline void TryDeleteCacheFile(const std::string& file) {
|
||||
if (std::remove(file.c_str()) != 0) {
|
||||
// Don't throw, this is called in a destructor.
|
||||
LOG(WARNING) << "Couldn't remove external memory cache file " << file
|
||||
<< "; you may want to remove it manually";
|
||||
}
|
||||
}
|
||||
void TryDeleteCacheFile(const std::string& file);
|
||||
|
||||
/**
|
||||
* @brief Information about the cache including path and page offsets.
|
||||
@@ -46,13 +38,14 @@ inline void TryDeleteCacheFile(const std::string& file) {
|
||||
struct Cache {
|
||||
// whether the write to the cache is complete
|
||||
bool written;
|
||||
bool on_host;
|
||||
std::string name;
|
||||
std::string format;
|
||||
// offset into binary cache file.
|
||||
std::vector<std::uint64_t> offset;
|
||||
|
||||
Cache(bool w, std::string n, std::string fmt)
|
||||
: written{w}, name{std::move(n)}, format{std::move(fmt)} {
|
||||
Cache(bool w, std::string n, std::string fmt, bool on_host)
|
||||
: written{w}, on_host{on_host}, name{std::move(n)}, format{std::move(fmt)} {
|
||||
offset.push_back(0);
|
||||
}
|
||||
|
||||
@@ -64,6 +57,7 @@ struct Cache {
|
||||
[[nodiscard]] std::string ShardName() const {
|
||||
return ShardName(this->name, this->format);
|
||||
}
|
||||
[[nodiscard]] bool OnHost() const { return on_host; }
|
||||
/**
|
||||
* @brief Record a page with size of n_bytes.
|
||||
*/
|
||||
@@ -83,12 +77,7 @@ struct Cache {
|
||||
/**
|
||||
* @brief Call this once the write for the cache is complete.
|
||||
*/
|
||||
void Commit() {
|
||||
if (!written) {
|
||||
std::partial_sum(offset.begin(), offset.end(), offset.begin());
|
||||
written = true;
|
||||
}
|
||||
}
|
||||
void Commit();
|
||||
};
|
||||
|
||||
// Prevents multi-threaded call to `GetBatches`.
|
||||
@@ -146,10 +135,59 @@ class ExceHandler {
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
|
||||
* @brief Default implementation of the stream creater.
|
||||
*/
|
||||
template <typename S, template <typename> typename F>
|
||||
class DefaultFormatStreamPolicy : public F<S> {
|
||||
public:
|
||||
using WriterT = common::AlignedFileWriteStream;
|
||||
using ReaderT = common::AlignedResourceReadStream;
|
||||
|
||||
public:
|
||||
std::unique_ptr<WriterT> CreateWriter(StringView name, std::uint32_t iter) {
|
||||
std::unique_ptr<common::AlignedFileWriteStream> fo;
|
||||
if (iter == 0) {
|
||||
fo = std::make_unique<common::AlignedFileWriteStream>(name, "wb");
|
||||
} else {
|
||||
fo = std::make_unique<common::AlignedFileWriteStream>(name, "ab");
|
||||
}
|
||||
return fo;
|
||||
}
|
||||
|
||||
std::unique_ptr<ReaderT> CreateReader(StringView name, std::uint64_t offset,
|
||||
std::uint64_t length) const {
|
||||
return std::make_unique<common::PrivateMmapConstStream>(std::string{name}, offset, length);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Default implementatioin of the format creator.
|
||||
*/
|
||||
template <typename S>
|
||||
class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
class DefaultFormatPolicy {
|
||||
public:
|
||||
using FormatT = SparsePageFormat<S>;
|
||||
|
||||
public:
|
||||
auto CreatePageFormat() const {
|
||||
std::unique_ptr<FormatT> fmt{::xgboost::data::CreatePageFormat<S>("raw")};
|
||||
return fmt;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
|
||||
*
|
||||
* The interface to external storage is divided into two types. The first one is the
|
||||
* format, representing how to read and write the binary. The second part is where to
|
||||
* store the binary cache. These policies are implemented in the `FormatStreamPolicy`
|
||||
* policy class. The format policy controls how to create the format (the first part), and
|
||||
* the stream policy decides where the stream should read from and write to (the second
|
||||
* part). This way we can compose the polices and page types with ease.
|
||||
*/
|
||||
template <typename S,
|
||||
typename FormatStreamPolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>>
|
||||
class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPolicy {
|
||||
protected:
|
||||
// Prevents calling this iterator from multiple places(or threads).
|
||||
std::mutex single_threaded_;
|
||||
@@ -165,7 +203,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
// Index to the current page.
|
||||
std::uint32_t count_{0};
|
||||
// Total number of batches.
|
||||
std::uint32_t n_batches_{0};
|
||||
bst_idx_t n_batches_{0};
|
||||
|
||||
std::shared_ptr<Cache> cache_info_;
|
||||
|
||||
@@ -179,10 +217,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
ExceHandler exce_;
|
||||
common::Monitor monitor_;
|
||||
|
||||
[[nodiscard]] virtual SparsePageFormat<S>* CreatePageFormat() const {
|
||||
return ::xgboost::data::CreatePageFormat<S>("raw");
|
||||
}
|
||||
|
||||
[[nodiscard]] bool ReadCache() {
|
||||
CHECK(!at_end_);
|
||||
if (!cache_info_->written) {
|
||||
@@ -196,8 +230,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
std::int32_t kPrefetches = 3;
|
||||
std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
|
||||
n_prefetches = std::max(n_prefetches, 1);
|
||||
std::int32_t n_prefetch_batches =
|
||||
std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
|
||||
std::int32_t n_prefetch_batches = std::min(static_cast<bst_idx_t>(n_prefetches), n_batches_);
|
||||
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
|
||||
CHECK_LE(n_prefetch_batches, kPrefetches);
|
||||
std::size_t fetch_it = count_;
|
||||
@@ -216,10 +249,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
*GlobalConfigThreadLocalStore::Get() = config;
|
||||
auto page = std::make_shared<S>();
|
||||
this->exce_.Run([&] {
|
||||
std::unique_ptr<SparsePageFormat<S>> fmt{this->CreatePageFormat()};
|
||||
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{this->CreatePageFormat()};
|
||||
auto name = self->cache_info_->ShardName();
|
||||
auto [offset, length] = self->cache_info_->View(fetch_it);
|
||||
auto fi = std::make_unique<common::PrivateMmapConstStream>(name, offset, length);
|
||||
std::unique_ptr<typename FormatStreamPolicy::ReaderT> fi{
|
||||
this->CreateReader(name, offset, length)};
|
||||
CHECK(fmt->Read(page.get(), fi.get()));
|
||||
});
|
||||
return page;
|
||||
@@ -243,16 +277,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
CHECK(!cache_info_->written);
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
std::unique_ptr<SparsePageFormat<S>> fmt{this->CreatePageFormat()};
|
||||
auto fmt{this->CreatePageFormat()};
|
||||
|
||||
auto name = cache_info_->ShardName();
|
||||
std::unique_ptr<common::AlignedFileWriteStream> fo;
|
||||
if (this->Iter() == 0) {
|
||||
fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "wb");
|
||||
} else {
|
||||
fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "ab");
|
||||
}
|
||||
|
||||
std::unique_ptr<typename FormatStreamPolicy::WriterT> fo{
|
||||
this->CreateWriter(StringView{name}, this->Iter())};
|
||||
auto bytes = fmt->Write(*page_, fo.get());
|
||||
|
||||
timer.Stop();
|
||||
@@ -265,9 +294,9 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
virtual void Fetch() = 0;
|
||||
|
||||
public:
|
||||
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
|
||||
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
|
||||
std::shared_ptr<Cache> cache)
|
||||
: workers_{nthreads},
|
||||
: workers_{std::max(2, std::min(nthreads, 16))}, // Don't use too many threads.
|
||||
missing_{missing},
|
||||
nthreads_{nthreads},
|
||||
n_features_{n_features},
|
||||
@@ -403,18 +432,19 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
|
||||
};
|
||||
|
||||
// A mixin for advancing the iterator.
|
||||
template <typename S>
|
||||
class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
|
||||
template <typename S,
|
||||
typename FormatCreatePolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>>
|
||||
class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
|
||||
protected:
|
||||
std::shared_ptr<SparsePageSource> source_;
|
||||
using Super = SparsePageSourceImpl<S>;
|
||||
using Super = SparsePageSourceImpl<S, FormatCreatePolicy>;
|
||||
// synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page
|
||||
// so we avoid fetching it.
|
||||
bool sync_{true};
|
||||
|
||||
public:
|
||||
PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
||||
std::uint32_t n_batches, std::shared_ptr<Cache> cache, bool sync)
|
||||
bst_idx_t n_batches, std::shared_ptr<Cache> cache, bool sync)
|
||||
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
|
||||
|
||||
[[nodiscard]] PageSourceIncMixIn& operator++() final {
|
||||
|
||||
@@ -234,7 +234,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
|
||||
// Compact the ELLPACK pages into the single sample page.
|
||||
thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||
for (auto& batch : batch_iterator) {
|
||||
page_->Compact(ctx->Device(), batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||
page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||
}
|
||||
|
||||
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
||||
@@ -252,7 +252,7 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
size_t n_rows = dmat->Info().num_row_;
|
||||
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
||||
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
||||
ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
||||
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
|
||||
|
||||
@@ -279,21 +279,18 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
bst_idx_t n_rows = dmat->Info().num_row_;
|
||||
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
||||
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
||||
|
||||
ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
||||
// Perform Poisson sampling in place.
|
||||
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
|
||||
thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
|
||||
PoissonSampling(dh::ToSpan(threshold_), threshold_index,
|
||||
RandomWeight(common::GlobalRandom()())));
|
||||
|
||||
// Count the sampled rows.
|
||||
size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
|
||||
|
||||
size_t sample_rows =
|
||||
thrust::count_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
|
||||
// Compact gradient pairs.
|
||||
gpair_.resize(sample_rows);
|
||||
thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
|
||||
|
||||
// Index the sample rows.
|
||||
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
|
||||
IsNonZero());
|
||||
@@ -301,18 +298,16 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
|
||||
sample_row_index_.begin());
|
||||
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
|
||||
sample_row_index_.begin(), ClearEmptyRows());
|
||||
|
||||
auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
|
||||
auto first_page = (*batch_iterator.begin()).Impl();
|
||||
// Create a new ELLPACK page with empty rows.
|
||||
page_.reset(); // Release the device memory first before reallocating
|
||||
page_.reset(new EllpackPageImpl(ctx->Device(), first_page->CutsShared(), first_page->is_dense,
|
||||
page_.reset(new EllpackPageImpl(ctx->Device(), first_page->CutsShared(), dmat->IsDense(),
|
||||
first_page->row_stride, sample_rows));
|
||||
|
||||
// Compact the ELLPACK pages into the single sample page.
|
||||
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||
thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||
for (auto& batch : batch_iterator) {
|
||||
page_->Compact(ctx->Device(), batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||
page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||
}
|
||||
|
||||
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
||||
@@ -363,21 +358,24 @@ GradientBasedSample GradientBasedSampler::Sample(Context const* ctx,
|
||||
return sample;
|
||||
}
|
||||
|
||||
size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
|
||||
size_t GradientBasedSampler::CalculateThresholdIndex(Context const* ctx,
|
||||
common::Span<GradientPair> gpair,
|
||||
common::Span<float> threshold,
|
||||
common::Span<float> grad_sum,
|
||||
size_t sample_rows) {
|
||||
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits<float>::max());
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
|
||||
CombineGradientPair());
|
||||
thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
|
||||
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1,
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
thrust::fill(cuctx->CTP(), dh::tend(threshold) - 1, dh::tend(threshold),
|
||||
std::numeric_limits<float>::max());
|
||||
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
|
||||
CombineGradientPair{});
|
||||
thrust::sort(cuctx->TP(), dh::tbegin(threshold), dh::tend(threshold) - 1);
|
||||
thrust::inclusive_scan(cuctx->CTP(), dh::tbegin(threshold), dh::tend(threshold) - 1,
|
||||
dh::tbegin(grad_sum));
|
||||
thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum),
|
||||
thrust::transform(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum),
|
||||
thrust::counting_iterator<size_t>(0), dh::tbegin(grad_sum),
|
||||
SampleRateDelta(threshold, gpair.size(), sample_rows));
|
||||
thrust::device_ptr<float> min =
|
||||
thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
|
||||
thrust::min_element(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum));
|
||||
return thrust::distance(dh::tbegin(grad_sum), min) + 1;
|
||||
}
|
||||
}; // namespace tree
|
||||
|
||||
@@ -129,9 +129,8 @@ class GradientBasedSampler {
|
||||
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);
|
||||
|
||||
/*! \brief Calculate the threshold used to normalize sampling probabilities. */
|
||||
static size_t CalculateThresholdIndex(common::Span<GradientPair> gpair,
|
||||
common::Span<float> threshold,
|
||||
common::Span<float> grad_sum,
|
||||
static size_t CalculateThresholdIndex(Context const* ctx, common::Span<GradientPair> gpair,
|
||||
common::Span<float> threshold, common::Span<float> grad_sum,
|
||||
size_t sample_rows);
|
||||
|
||||
private:
|
||||
|
||||
Reference in New Issue
Block a user