[EM] Enable access to the number of batches. (#10691)

- Expose `NumBatches` in `DMatrix`.
- Small cleanup for removing legacy CUDA stream and ~force CUDA context initialization~.
- Purge old external memory data generation code.
This commit is contained in:
Jiaming Yuan
2024-08-17 02:59:45 +08:00
committed by GitHub
parent 033a666900
commit 8d7fe262d9
26 changed files with 169 additions and 352 deletions

View File

@@ -486,24 +486,20 @@ class TypedDiscard : public thrust::discard_iterator<T> {
} // namespace detail
// Pick the discard-iterator implementation appropriate for the installed
// Thrust minor version (detail::TypedDiscardCTK114 for >= 11.12, otherwise the
// legacy detail::TypedDiscard).
template <typename T>
using TypedDiscard = std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
                                        detail::TypedDiscard<T>>;
/**
 * @brief View a contiguous (device) vector as a Span.
 *
 * @param vec    The vector to create a view over.
 * @param offset Index of the first element in the view.
 * @param size   Number of elements in the view; the default sentinel
 *               (numeric_limits<size_t>::max()) means "everything from offset on".
 *
 * @return A span covering [offset, offset + size).
 */
template <typename VectorT, typename T = typename VectorT::value_type,
          typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(VectorT &vec, IndexT offset = 0,
                                IndexT size = std::numeric_limits<size_t>::max()) {
  // Resolve the "rest of the vector" sentinel into a concrete element count.
  size = size == std::numeric_limits<size_t>::max() ? vec.size() : size;
  CHECK_LE(offset + size, vec.size());
  // raw_pointer_cast works for any Thrust vector type, unlike the
  // device_ptr-specific .get() accessor.
  return {thrust::raw_pointer_cast(vec.data()) + offset, size};
}
// Convenience overload for thrust::device_vector with explicit offset and size.
template <typename T>
xgboost::common::Span<T> ToSpan(thrust::device_vector<T> &vec, size_t offset, size_t size) {
  // NOTE(review): an unqualified `ToSpan(vec, offset, size)` here selects this
  // very overload again (template partial ordering prefers the more specialized
  // device_vector<T>& parameter), recursing forever. Build the span directly.
  CHECK_LE(offset + size, vec.size());
  return {thrust::raw_pointer_cast(vec.data()) + offset, size};
}
@@ -874,13 +870,7 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
// Changing this has effect on prediction return, where we need to pass the pointer to
// third-party libraries like cuPy.
// Always uses the per-thread default stream; the legacy-stream branch that was
// previously selected via CUDA_API_PER_THREAD_DEFAULT_STREAM has been removed.
inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamPerThread}; }
class CUDAStream {
cudaStream_t stream_;

View File

@@ -74,6 +74,8 @@ void ExtMemQuantileDMatrix::InitFromCPU(
cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info);
ext_info.SetInfo(ctx, &this->info_);
this->n_batches_ = ext_info.n_batches;
/**
* Generate quantiles
*/

View File

@@ -33,7 +33,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
std::string cache, bst_bin_t max_bin, bool on_host);
~ExtMemQuantileDMatrix() override;
[[nodiscard]] bool SingleColBlock() const override { return false; }
[[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
private:
void InitFromCPU(
@@ -63,6 +63,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
std::string cache_prefix_;
bool on_host_;
BatchParam batch_;
bst_idx_t n_batches_{0};
using EllpackDiskPtr = std::shared_ptr<ExtEllpackPageSource>;
using EllpackHostPtr = std::shared_ptr<ExtEllpackPageHostSource>;

View File

@@ -57,8 +57,6 @@ class IterativeDMatrix : public QuantileDMatrix {
BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam &param) override;
BatchSet<ExtSparsePage> GetExtBatches(Context const *ctx, BatchParam const &param) override;
bool SingleColBlock() const override { return true; }
};
} // namespace data
} // namespace xgboost

View File

@@ -94,7 +94,6 @@ class DMatrixProxy : public DMatrix {
MetaInfo const& Info() const override { return info_; }
Context const* Ctx() const override { return &ctx_; }
bool SingleColBlock() const override { return false; }
bool EllpackExists() const override { return false; }
bool GHistIndexExists() const override { return false; }
bool SparsePageExists() const override { return false; }

View File

@@ -33,7 +33,6 @@ class SimpleDMatrix : public DMatrix {
const MetaInfo& Info() const override;
Context const* Ctx() const override { return &fmat_ctx_; }
bool SingleColBlock() const override { return true; }
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
DMatrix* SliceCol(int num_slices, int slice_id) override;

View File

@@ -90,8 +90,7 @@ class SparsePageDMatrix : public DMatrix {
[[nodiscard]] MetaInfo &Info() override;
[[nodiscard]] const MetaInfo &Info() const override;
[[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
// The only DMatrix implementation that returns false.
[[nodiscard]] bool SingleColBlock() const override { return false; }
[[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
DMatrix *Slice(common::Span<std::int32_t const>) override {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;

View File

@@ -3,10 +3,10 @@
*/
#include "sparse_page_source.h"
#include <filesystem> // for exists
#include <string> // for string
#include <cstdio> // for remove
#include <filesystem> // for exists
#include <numeric> // for partial_sum
#include <string> // for string
namespace xgboost::data {
void Cache::Commit() {
@@ -27,4 +27,8 @@ void TryDeleteCacheFile(const std::string& file) {
<< "; you may want to remove it manually";
}
}
#if !defined(XGBOOST_USE_CUDA)
// CPU-only build: a freshly spawned worker thread merely inherits the global
// configuration captured by the functor. A CUDA-aware definition is provided
// elsewhere when XGBOOST_USE_CUDA is set, hence the guard.
void InitNewThread::operator()() const { *GlobalConfigThreadLocalStore::Get() = config; }
#endif
} // namespace xgboost::data

View File

@@ -18,4 +18,14 @@ void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
cuda_impl::Dispatch(proxy,
[&](auto const &value) { CopyToSparsePage(value, device, missing, page); });
}
// Thread-pool worker initializer (CUDA build): propagate the captured global
// configuration into the new thread's thread-local store.
void InitNewThread::operator()() const {
  *GlobalConfigThreadLocalStore::Get() = config;
  // For CUDA 12.2, we need to force initialize the CUDA context by synchronizing the
  // stream when creating a new thread in the thread pool. While for CUDA 11.8, this
  // action might cause an insufficient driver version error for some reason. Lastly, it
  // should work with CUDA 12.5 without any action being taken.
  // dh::DefaultStream().Sync();
}
} // namespace xgboost::data

View File

@@ -210,6 +210,12 @@ class DefaultFormatPolicy {
}
};
/**
 * @brief Functor run at the start of each new thread spawned by a thread pool.
 *
 * Snapshots the global configuration of the constructing thread so that worker
 * threads observe the same settings as the thread that created the pool.
 */
struct InitNewThread {
  // Captured at construction time on the spawning thread.
  GlobalConfiguration config = *GlobalConfigThreadLocalStore::Get();
  void operator()() const;
};
/**
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
*
@@ -330,10 +336,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
public:
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
std::shared_ptr<Cache> cache)
: workers_{std::max(2, std::min(nthreads, 16)),
[config = *GlobalConfigThreadLocalStore::Get()] {
*GlobalConfigThreadLocalStore::Get() = config;
}},
: workers_{std::max(2, std::min(nthreads, 16)), InitNewThread{}},
missing_{missing},
nthreads_{nthreads},
n_features_{n_features},