Cleanup to prepare for using mmap pointer in external memory. (#9317)

- Update SparseDMatrix comment.
- Use a pointer in the bitfield. We will replace the `std::vector<bool>` in `ColumnMatrix` with a bitfield.
- Clean up the page source. The timer is removed as it's inaccurate once we swap the mmap pointer into the page.
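
For context, the end goal is to let a page (or a bit field) view bytes that live in an mmap-ed cache file instead of an owning container such as `std::vector<bool>`. Below is a minimal POSIX sketch of that kind of non-owning mapping; the file name and error handling are illustrative only, and XGBoost wraps this logic in `PrivateMmapConstStream` (seen later in the diff) rather than calling mmap at the call site.

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <cstddef>
#include <cstdint>
#include <iostream>

// Map a cache file read-only and expose it as a raw pointer + length. A page (or a
// bit field) that stores only a pointer can view this region directly, without copying
// the bytes into an owning std::vector first.
int main() {
  int fd = ::open("cache.bin", O_RDONLY);  // hypothetical cache file
  if (fd < 0) {
    return 1;
  }
  struct stat st;
  if (::fstat(fd, &st) != 0 || st.st_size == 0) {
    ::close(fd);
    return 1;
  }
  auto len = static_cast<std::size_t>(st.st_size);
  void* ptr = ::mmap(nullptr, len, PROT_READ, MAP_PRIVATE, fd, 0);
  if (ptr == MAP_FAILED) {
    ::close(fd);
    return 1;
  }
  // A non-owning view over the mapped bytes; nothing is copied.
  auto const* data = static_cast<std::uint8_t const*>(ptr);
  std::cout << "first byte: " << static_cast<int>(data[0]) << ", length: " << len << "\n";
  ::munmap(ptr, len);
  ::close(fd);
  return 0;
}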
Jiaming Yuan 2023-06-22 06:43:11 +08:00 committed by GitHub
parent 4066d68261
commit 54da4b3185
18 changed files with 220 additions and 171 deletions


@ -70,7 +70,7 @@ NcclDeviceCommunicator::~NcclDeviceCommunicator() {
namespace {
ncclDataType_t GetNcclDataType(DataType const &data_type) {
ncclDataType_t result;
ncclDataType_t result{ncclInt8};
switch (data_type) {
case DataType::kInt8:
result = ncclInt8;
@ -108,7 +108,7 @@ bool IsBitwiseOp(Operation const &op) {
}
ncclRedOp_t GetNcclRedOp(Operation const &op) {
ncclRedOp_t result;
ncclRedOp_t result{ncclMax};
switch (op) {
case Operation::kMax:
result = ncclMax;


@ -1,5 +1,5 @@
/*!
* Copyright 2019 by Contributors
/**
* Copyright 2019-2023, XGBoost Contributors
* \file bitfield.h
*/
#ifndef XGBOOST_COMMON_BITFIELD_H_
@ -50,14 +50,17 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr
}
#endif // defined(__CUDACC__)
/*!
* \brief A non-owning type with auxiliary methods defined for manipulating bits.
/**
* @brief A non-owning type with auxiliary methods defined for manipulating bits.
*
* \tparam Direction Whether the bits start from left or from right.
* @tparam VT Underlying value type, must be an unsigned integer.
* @tparam Direction Whether the bits start from left or from right.
* @tparam IsConst Whether the view is const.
*/
template <typename VT, typename Direction, bool IsConst = false>
struct BitFieldContainer {
using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT
using size_type = size_t; // NOLINT
using index_type = size_t; // NOLINT
using pointer = value_type*; // NOLINT
@ -70,8 +73,9 @@ struct BitFieldContainer {
};
private:
common::Span<value_type> bits_;
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");
value_type* bits_{nullptr};
size_type n_values_{0};
static_assert(!std::is_signed<VT>::value, "Must use an unsiged type as the underlying storage.");
public:
XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
@ -86,13 +90,15 @@ struct BitFieldContainer {
public:
BitFieldContainer() = default;
XGBOOST_DEVICE explicit BitFieldContainer(common::Span<value_type> bits) : bits_{bits} {}
XGBOOST_DEVICE BitFieldContainer(BitFieldContainer const& other) : bits_{other.bits_} {}
XGBOOST_DEVICE explicit BitFieldContainer(common::Span<value_type> bits)
: bits_{bits.data()}, n_values_{bits.size()} {}
BitFieldContainer(BitFieldContainer const& other) = default;
BitFieldContainer(BitFieldContainer&& other) = default;
BitFieldContainer &operator=(BitFieldContainer const &that) = default;
BitFieldContainer &operator=(BitFieldContainer &&that) = default;
XGBOOST_DEVICE common::Span<value_type> Bits() { return bits_; }
XGBOOST_DEVICE common::Span<value_type const> Bits() const { return bits_; }
XGBOOST_DEVICE auto Bits() { return common::Span<value_type>{bits_, NumValues()}; }
XGBOOST_DEVICE auto Bits() const { return common::Span<value_type const>{bits_, NumValues()}; }
/*\brief Compute the size of needed memory allocation. The returned value is in terms
* of number of elements with `BitFieldContainer::value_type'.
@ -103,17 +109,17 @@ struct BitFieldContainer {
#if defined(__CUDA_ARCH__)
__device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
size_t min_size = min(bits_.size(), rhs.bits_.size());
size_t min_size = min(NumValues(), rhs.NumValues());
if (tid < min_size) {
bits_[tid] |= rhs.bits_[tid];
Data()[tid] |= rhs.Data()[tid];
}
return *this;
}
#else
BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
size_t min_size = std::min(bits_.size(), rhs.bits_.size());
size_t min_size = std::min(NumValues(), rhs.NumValues());
for (size_t i = 0; i < min_size; ++i) {
bits_[i] |= rhs.bits_[i];
Data()[i] |= rhs.Data()[i];
}
return *this;
}
@ -121,75 +127,85 @@ struct BitFieldContainer {
#if defined(__CUDA_ARCH__)
__device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
size_t min_size = min(bits_.size(), rhs.bits_.size());
size_t min_size = min(NumValues(), rhs.NumValues());
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < min_size) {
bits_[tid] &= rhs.bits_[tid];
Data()[tid] &= rhs.Data()[tid];
}
return *this;
}
#else
BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
size_t min_size = std::min(bits_.size(), rhs.bits_.size());
size_t min_size = std::min(NumValues(), rhs.NumValues());
for (size_t i = 0; i < min_size; ++i) {
bits_[i] &= rhs.bits_[i];
Data()[i] &= rhs.Data()[i];
}
return *this;
}
#endif // defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__)
__device__ auto Set(index_type pos) {
__device__ auto Set(index_type pos) noexcept(true) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type& value = Data()[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos;
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicOr(reinterpret_cast<Type *>(&value), set_bit);
}
__device__ void Clear(index_type pos) {
__device__ void Clear(index_type pos) noexcept(true) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type& value = Data()[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos);
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicAnd(reinterpret_cast<Type *>(&value), clear_bit);
}
#else
void Set(index_type pos) {
void Set(index_type pos) noexcept(true) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type& value = Data()[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos;
value |= set_bit;
}
void Clear(index_type pos) {
void Clear(index_type pos) noexcept(true) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type& value = Data()[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos);
value &= clear_bit;
}
#endif // defined(__CUDA_ARCH__)
XGBOOST_DEVICE bool Check(Pos pos_v) const {
XGBOOST_DEVICE bool Check(Pos pos_v) const noexcept(true) {
pos_v = Direction::Shift(pos_v);
SPAN_LT(pos_v.int_pos, bits_.size());
value_type const value = bits_[pos_v.int_pos];
assert(pos_v.int_pos < NumValues());
value_type const value = Data()[pos_v.int_pos];
value_type const test_bit = kOne << pos_v.bit_pos;
value_type result = test_bit & value;
return static_cast<bool>(result);
}
XGBOOST_DEVICE bool Check(index_type pos) const {
[[nodiscard]] XGBOOST_DEVICE bool Check(index_type pos) const noexcept(true) {
Pos pos_v = ToBitPos(pos);
return Check(pos_v);
}
/**
* @brief Returns the total number of bits that can be viewed. This is equal to or
* larger than the actual number of valid bits.
*/
[[nodiscard]] XGBOOST_DEVICE size_type Capacity() const noexcept(true) {
return kValueSize * NumValues();
}
/**
* @brief Number of storage units used in this bit field.
*/
[[nodiscard]] XGBOOST_DEVICE size_type NumValues() const noexcept(true) { return n_values_; }
XGBOOST_DEVICE size_t Size() const { return kValueSize * bits_.size(); }
XGBOOST_DEVICE pointer Data() const noexcept(true) { return bits_; }
XGBOOST_DEVICE pointer Data() const { return bits_.data(); }
inline friend std::ostream &
operator<<(std::ostream &os, BitFieldContainer<VT, Direction, IsConst> field) {
os << "Bits " << "storage size: " << field.bits_.size() << "\n";
for (typename common::Span<value_type>::index_type i = 0; i < field.bits_.size(); ++i) {
std::bitset<BitFieldContainer<VT, Direction, IsConst>::kValueSize> bset(field.bits_[i]);
inline friend std::ostream& operator<<(std::ostream& os,
BitFieldContainer<VT, Direction, IsConst> field) {
os << "Bits "
<< "storage size: " << field.NumValues() << "\n";
for (typename common::Span<value_type>::index_type i = 0; i < field.NumValues(); ++i) {
std::bitset<BitFieldContainer<VT, Direction, IsConst>::kValueSize> bset(field.Data()[i]);
os << bset << "\n";
}
return os;
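
To make the storage change above concrete, here is a small standalone sketch of a pointer-backed bit view with the same flavour of API (ComputeStorageSize, Set/Check, Capacity). The names and layout are illustrative rather than the actual BitFieldContainer; the point is that the view owns nothing, so the underlying words can come from a std::vector today and from an mmap-ed region later.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Non-owning bit view over raw storage words. Capacity() reports viewable bits, which is
// rounded up to a multiple of the word size and may exceed the number of valid bits.
struct BitView {
  using value_type = std::uint64_t;
  static constexpr std::size_t kValueSize = sizeof(value_type) * 8;

  value_type* bits{nullptr};  // non-owning pointer to the storage words
  std::size_t n_values{0};    // number of storage words

  // Number of storage words needed to hold n_bits bits.
  static std::size_t ComputeStorageSize(std::size_t n_bits) {
    return (n_bits + kValueSize - 1) / kValueSize;
  }
  void Set(std::size_t i) { bits[i / kValueSize] |= value_type{1} << (i % kValueSize); }
  bool Check(std::size_t i) const {
    return (bits[i / kValueSize] >> (i % kValueSize)) & value_type{1};
  }
  std::size_t Capacity() const { return kValueSize * n_values; }
};

int main() {
  std::vector<BitView::value_type> storage(BitView::ComputeStorageSize(100), 0);
  BitView view{storage.data(), storage.size()};
  view.Set(3);
  assert(view.Check(3) && !view.Check(4));
  assert(view.Capacity() == 128);  // 100 bits requested, rounded up to 2 * 64
  return 0;
}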


@ -1,5 +1,5 @@
/*!
* Copyright 2020-2022 by XGBoost Contributors
/**
* Copyright 2020-2023, XGBoost Contributors
* \file categorical.h
*/
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
@ -10,7 +10,6 @@
#include "bitfield.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/parameter.h"
#include "xgboost/span.h"
namespace xgboost {


@ -84,7 +84,7 @@ class HistogramCuts {
return *this;
}
uint32_t FeatureBins(bst_feature_t feature) const {
[[nodiscard]] bst_bin_t FeatureBins(bst_feature_t feature) const {
return cut_ptrs_.ConstHostVector().at(feature + 1) - cut_ptrs_.ConstHostVector()[feature];
}
@ -92,8 +92,8 @@ class HistogramCuts {
std::vector<float> const& Values() const { return cut_values_.ConstHostVector(); }
std::vector<float> const& MinValues() const { return min_vals_.ConstHostVector(); }
bool HasCategorical() const { return has_categorical_; }
float MaxCategory() const { return max_cat_; }
[[nodiscard]] bool HasCategorical() const { return has_categorical_; }
[[nodiscard]] float MaxCategory() const { return max_cat_; }
/**
* \brief Set meta info about categorical features.
*
@ -105,12 +105,13 @@ class HistogramCuts {
max_cat_ = max_cat;
}
size_t TotalBins() const { return cut_ptrs_.ConstHostVector().back(); }
[[nodiscard]] bst_bin_t TotalBins() const { return cut_ptrs_.ConstHostVector().back(); }
// Return the index of a cut point that is strictly greater than the input
// value, or the last available index if none exists
bst_bin_t SearchBin(float value, bst_feature_t column_id, std::vector<uint32_t> const& ptrs,
std::vector<float> const& values) const {
[[nodiscard]] bst_bin_t SearchBin(float value, bst_feature_t column_id,
std::vector<uint32_t> const& ptrs,
std::vector<float> const& values) const {
auto end = ptrs[column_id + 1];
auto beg = ptrs[column_id];
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
@ -119,20 +120,20 @@ class HistogramCuts {
return idx;
}
bst_bin_t SearchBin(float value, bst_feature_t column_id) const {
[[nodiscard]] bst_bin_t SearchBin(float value, bst_feature_t column_id) const {
return this->SearchBin(value, column_id, Ptrs(), Values());
}
/**
* \brief Search the bin index for numerical feature.
*/
bst_bin_t SearchBin(Entry const& e) const { return SearchBin(e.fvalue, e.index); }
[[nodiscard]] bst_bin_t SearchBin(Entry const& e) const { return SearchBin(e.fvalue, e.index); }
/**
* \brief Search the bin index for categorical feature.
*/
bst_bin_t SearchCatBin(float value, bst_feature_t fidx, std::vector<uint32_t> const& ptrs,
std::vector<float> const& vals) const {
[[nodiscard]] bst_bin_t SearchCatBin(float value, bst_feature_t fidx,
std::vector<uint32_t> const& ptrs,
std::vector<float> const& vals) const {
auto end = ptrs.at(fidx + 1) + vals.cbegin();
auto beg = ptrs[fidx] + vals.cbegin();
// Truncates the value in case it's not perfectly rounded.
@ -143,12 +144,14 @@ class HistogramCuts {
}
return bin_idx;
}
bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const {
[[nodiscard]] bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const {
auto const& ptrs = this->Ptrs();
auto const& vals = this->Values();
return this->SearchCatBin(value, fidx, ptrs, vals);
}
bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); }
[[nodiscard]] bst_bin_t SearchCatBin(Entry const& e) const {
return SearchCatBin(e.fvalue, e.index);
}
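
The bin search described in the comment above (the first cut strictly greater than the value, falling back to the last available bin) can be sketched with std::upper_bound on a single feature's cut values. This is a simplified standalone illustration, not the HistogramCuts implementation, which searches a global value array through per-feature pointers.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Return the bin index for `value` within one feature's cut values: the index of the
// first cut strictly greater than `value`, clamped to the last bin if none exists.
std::int32_t SearchBin(std::vector<float> const& cuts, float value) {
  auto it = std::upper_bound(cuts.cbegin(), cuts.cend(), value);
  if (it == cuts.cend()) {
    it = cuts.cend() - 1;  // value beyond the last cut: use the last available bin
  }
  return static_cast<std::int32_t>(it - cuts.cbegin());
}

int main() {
  std::vector<float> cuts{0.5f, 1.5f, 2.5f};  // hypothetical cuts for one feature
  assert(SearchBin(cuts, 0.2f) == 0);  // 0.2 < 0.5
  assert(SearchBin(cuts, 1.0f) == 1);  // 0.5 <= 1.0 < 1.5
  assert(SearchBin(cuts, 9.0f) == 2);  // beyond the last cut: clamp to the last bin
  return 0;
}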
/**
* \brief Return numerical bin value given bin index.


@ -590,7 +590,7 @@ class ArrayInterface {
template <std::int32_t D, typename Fn>
void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) {
// Only used for cuDF at the moment.
CHECK_EQ(array.valid.Size(), 0);
CHECK_EQ(array.valid.Capacity(), 0);
auto dispatch = [&](auto t) {
using T = std::remove_const_t<decltype(t)> const;
// Set the data size to max as we don't know the original size of a sliced array:


@ -416,7 +416,8 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
p_out->Reshape(array.shape);
return;
}
CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value.";
CHECK_EQ(array.valid.Capacity(), 0)
<< "Meta info like label or weight can not have missing value.";
if (array.is_contiguous && array.type == ToDType<T>::kType) {
// Handle contiguous
p_out->ModifyInplace([&](HostDeviceVector<T>* data, common::Span<size_t, D> shape) {


@ -33,7 +33,8 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens
p_out->Reshape(array.shape);
return;
}
CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value.";
CHECK_EQ(array.valid.Capacity(), 0)
<< "Meta info like label or weight can not have missing value.";
auto ptr_device = SetDeviceToPtr(array.data);
p_out->SetDevice(ptr_device);


@ -5,6 +5,7 @@
#include <thrust/iterator/transform_output_iterator.h>
#include "../common/categorical.h"
#include "../common/cuda_context.cuh"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
@ -313,7 +314,8 @@ void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span<size_t const>
auto d_csc_indptr = dh::ToSpan(csc_indptr);
auto bin_type = page.index.GetBinTypeSize();
common::CompressedBufferWriter writer{page.cut.TotalBins() + 1}; // +1 for null value
common::CompressedBufferWriter writer{page.cut.TotalBins() +
static_cast<std::size_t>(1)}; // +1 for null value
dh::LaunchN(row_stride * page.Size(), [=] __device__(size_t idx) mutable {
auto ridx = idx / row_stride;
@ -357,8 +359,10 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
// copy gidx
common::CompressedByteT* d_compressed_buffer = gidx_buffer.DevicePointer();
dh::device_vector<size_t> row_ptr(page.row_ptr);
dh::device_vector<size_t> row_ptr(page.row_ptr.size());
auto d_row_ptr = dh::ToSpan(row_ptr);
dh::safe_cuda(cudaMemcpyAsync(d_row_ptr.data(), page.row_ptr.data(), d_row_ptr.size_bytes(),
cudaMemcpyHostToDevice, ctx->CUDACtx()->Stream()));
auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft);
auto null = accessor.NullValue();


@ -7,9 +7,6 @@
#ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
#define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
#include <xgboost/data.h>
#include <xgboost/logging.h>
#include <algorithm>
#include <map>
#include <memory>
@ -20,35 +17,33 @@
#include "ellpack_page_source.h"
#include "gradient_index_page_source.h"
#include "sparse_page_source.h"
#include "xgboost/data.h"
#include "xgboost/logging.h"
namespace xgboost {
namespace data {
namespace xgboost::data {
/**
* \brief DMatrix used for external memory.
*
* The external memory is created for controlling memory usage by splitting up data into
* multiple batches. However that doesn't mean we will actually process exact 1 batch at
* a time, which would be terribly slow considering that we have to loop through the
* whole dataset for every tree split. So we use async pre-fetch and let caller to decide
* how many batches it wants to process by returning data as shared pointer. The caller
* can use async function to process the data or just stage those batches, making the
* decision is out of the scope for sparse page dmatrix. These 2 optimizations might
* defeat the purpose of splitting up dataset since if you load all the batches then the
* memory usage is even worse than using a single batch. Essentially we need to control
* how many batches can be in memory at the same time.
* multiple batches. However that doesn't mean we will actually process exactly 1 batch
* at a time, which would be terribly slow considering that we have to loop through the
* whole dataset for every tree split. So we use async to pre-fetch pages and let the
* caller decide how many batches it wants to process by returning data as a shared
* pointer. The caller can use an async function to process the data or just stage those
* batches based on its use cases. These two optimizations might defeat the purpose of
* splitting up the dataset since if you stage all the batches then the memory usage might
* be even worse than using a single batch. As a result, we must control how many batches
* can be in memory at any given time.
*
* Right now the write to the cache is sequential operation and is blocking, reading from
* cache is async but with a hard coded limit of 4 pages as an heuristic. So by sparse
* dmatrix itself there can be only 9 pages in main memory (might be of different types)
* at the same time: 1 page pending for write, 4 pre-fetched sparse pages, 4 pre-fetched
* dependent pages. If the caller stops iteration at the middle and start again, then the
* number of pages in memory can hit 16 due to pre-fetching, but this should be a bug in
* caller's code (XGBoost doesn't discard a large portion of data at the end, there's not
* sampling algo that samples only the first portion of data).
* Right now the write to the cache is a sequential operation and is blocking. Reading
* from the cache, on the other hand, is async but with a hard-coded limit of 3 pages as
* a heuristic. So by the sparse dmatrix itself there can be only 7 pages in main memory
* (might be of different types) at the same time: 1 page pending for write, 3 pre-fetched
* sparse pages, 3 pre-fetched dependent pages.
*
* Of course if the caller decides to retain some batches to perform parallel processing,
* then we might load all pages in memory, which is also considered as a bug in caller's
* code. So if the algo supports external memory, it must be careful that the queue for
* async calls has an upper limit.
*
* Another assumption we make is that the data must be immutable so caller should never
@ -101,7 +96,7 @@ class SparsePageDMatrix : public DMatrix {
MetaInfo &Info() override;
const MetaInfo &Info() const override;
Context const *Ctx() const override { return &fmat_ctx_; }
// The only DMatrix implementation that returns false.
bool SingleColBlock() const override { return false; }
DMatrix *Slice(common::Span<int32_t const>) override {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
@ -153,6 +148,5 @@ inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, std::st
}
return id;
}
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
#endif // XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
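
As the comment stresses, a caller that processes batches asynchronously must bound how many pages it keeps alive at once. A standalone sketch of that pattern with a toy Page type follows; it is illustrative only, and a real consumer would iterate SparsePage batches produced by the DMatrix instead.

#include <cstddef>
#include <future>
#include <iostream>
#include <memory>
#include <vector>

// Toy "page" standing in for a SparsePage handed out as a shared pointer.
struct Page {
  std::vector<float> values;
};

double Process(std::shared_ptr<Page const> page) {
  double sum = 0.0;
  for (float v : page->values) {
    sum += v;
  }
  return sum;
}

int main() {
  // Pretend these arrive one at a time from the external-memory iterator.
  std::vector<std::shared_ptr<Page const>> stream(8, std::make_shared<Page>(Page{{1.0f, 2.0f}}));

  std::size_t constexpr kMaxInflight = 2;  // upper bound on pages retained by the caller
  std::vector<std::future<double>> inflight;
  double total = 0.0;
  for (auto const& page : stream) {
    if (inflight.size() == kMaxInflight) {
      total += inflight.front().get();  // block before retaining yet another page
      inflight.erase(inflight.begin());
    }
    inflight.emplace_back(std::async(std::launch::async, Process, page));
  }
  for (auto& fut : inflight) {
    total += fut.get();
  }
  std::cout << total << "\n";  // 8 pages * 3.0 = 24
  return 0;
}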


@ -6,39 +6,43 @@
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
#include <algorithm> // for min
#include <future> // async
#include <future> // for async
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <utility> // for pair, move
#include <vector>
#include "../common/common.h"
#include "../common/io.h" // for PrivateMmapStream, PadPageForMMAP
#include "../common/io.h" // for PrivateMmapConstStream
#include "../common/timer.h" // for Monitor, Timer
#include "adapter.h"
#include "dmlc/common.h" // OMPException
#include "proxy_dmatrix.h"
#include "sparse_page_writer.h"
#include "dmlc/common.h" // for OMPException
#include "proxy_dmatrix.h" // for DMatrixProxy
#include "sparse_page_writer.h" // for SparsePageFormat
#include "xgboost/base.h"
#include "xgboost/data.h"
namespace xgboost::data {
inline void TryDeleteCacheFile(const std::string& file) {
if (std::remove(file.c_str()) != 0) {
// Don't throw, this is called in a destructor.
LOG(WARNING) << "Couldn't remove external memory cache file " << file
<< "; you may want to remove it manually";
}
}
/**
* @brief Information about the cache including path and page offsets.
*/
struct Cache {
// whether the write to the cache is complete
bool written;
std::string name;
std::string format;
// offset into binary cache file.
std::vector<size_t> offset;
std::vector<std::uint64_t> offset;
Cache(bool w, std::string n, std::string fmt)
: written{w}, name{std::move(n)}, format{std::move(fmt)} {
@ -50,14 +54,24 @@ struct Cache {
return name + format;
}
std::string ShardName() {
[[nodiscard]] std::string ShardName() const {
return ShardName(this->name, this->format);
}
void Push(std::size_t n_bytes) {
offset.push_back(n_bytes);
/**
* @brief Record a page of size n_bytes.
*/
void Push(std::size_t n_bytes) { offset.push_back(n_bytes); }
/**
* @brief Returns the view start and length for the i^th page.
*/
[[nodiscard]] auto View(std::size_t i) const {
std::uint64_t off = offset.at(i);
std::uint64_t len = offset.at(i + 1) - offset[i];
return std::pair{off, len};
}
// The write is completed.
/**
* @brief Call this once the write for the cache is complete.
*/
void Commit() {
if (!written) {
std::partial_sum(offset.begin(), offset.end(), offset.begin());
@ -66,7 +80,7 @@ struct Cache {
}
};
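
The offset bookkeeping above is compact but easy to misread: Push records raw page sizes, Commit turns them into cumulative file offsets, and View(i) then yields the byte range of the i-th page. A simplified standalone sketch of the same arithmetic follows; the leading zero used to seed the prefix sum is an assumption of this sketch, since the constructor body is not shown in the diff.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Push records page sizes; Commit converts them in place to cumulative offsets; View(i)
// returns {offset, length} for the i-th page within the single binary cache file.
struct CacheOffsets {
  std::vector<std::uint64_t> offset{0};  // seeded with 0 so page i spans [offset[i], offset[i + 1])

  void Push(std::uint64_t n_bytes) { offset.push_back(n_bytes); }
  void Commit() { std::partial_sum(offset.begin(), offset.end(), offset.begin()); }
  std::pair<std::uint64_t, std::uint64_t> View(std::size_t i) const {
    return {offset.at(i), offset.at(i + 1) - offset[i]};
  }
};

int main() {
  CacheOffsets cache;
  cache.Push(100);  // page 0 takes 100 bytes
  cache.Push(250);  // page 1 takes 250 bytes
  cache.Commit();   // offsets become {0, 100, 350}
  auto [off, len] = cache.View(1);
  assert(off == 100 && len == 250);
  return 0;
}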
// Prevents multi-threaded call.
// Prevents multi-threaded call to `GetBatches`.
class TryLockGuard {
std::mutex& lock_;
@ -79,22 +93,25 @@ class TryLockGuard {
}
};
/**
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
*/
template <typename S>
class SparsePageSourceImpl : public BatchIteratorImpl<S> {
protected:
// Prevents calling this iterator from multiple places (or threads).
std::mutex single_threaded_;
// The current page.
std::shared_ptr<S> page_;
bool at_end_ {false};
float missing_;
int nthreads_;
std::int32_t nthreads_;
bst_feature_t n_features_;
uint32_t count_{0};
uint32_t n_batches_ {0};
// Index to the current page.
std::uint32_t count_{0};
// Total number of batches.
std::uint32_t n_batches_{0};
std::shared_ptr<Cache> cache_info_;
@ -102,6 +119,9 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
// A ring storing futures to data. Since the DMatrix iterator is forward-only, we can
// pre-fetch data in a ring.
std::unique_ptr<Ring> ring_{new Ring};
// Catch exceptions in pre-fetch threads to prevent segfaults. This doesn't always work,
// though: an OOM error can be delayed due to lazy commit. On the bright side, if mmap is
// used then an OOM error should be rare.
dmlc::OMPException exec_;
common::Monitor monitor_;
@ -123,7 +143,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
exec_.Rethrow();
monitor_.Start("launch");
for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) {
@ -134,33 +153,25 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
auto page = std::make_shared<S>();
this->exec_.Run([&] {
common::Timer timer;
timer.Start();
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
auto n = self->cache_info_->ShardName();
std::uint64_t offset = self->cache_info_->offset.at(fetch_it);
std::uint64_t length = self->cache_info_->offset.at(fetch_it + 1) - offset;
auto fi = std::make_unique<common::PrivateMmapConstStream>(n, offset, length);
auto name = self->cache_info_->ShardName();
auto [offset, length] = self->cache_info_->View(fetch_it);
auto fi = std::make_unique<common::PrivateMmapConstStream>(name, offset, length);
CHECK(fmt->Read(page.get(), fi.get()));
timer.Stop();
LOG(INFO) << "Read a page `" << typeid(S).name() << "` in " << timer.ElapsedSeconds()
<< " seconds.";
});
return page;
});
}
monitor_.Stop("launch");
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
n_prefetch_batches)
<< "Sparse DMatrix assumes forward iteration.";
monitor_.Start("Wait");
page_ = (*ring_)[count_].get();
monitor_.Stop("Wait");
CHECK(!(*ring_)[count_].valid());
monitor_.Stop("Wait");
exec_.Rethrow();
return true;
@ -183,6 +194,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
auto bytes = fmt->Write(*page_, fo.get());
timer.Stop();
// Not entirely accurate, the kernel doesn't have to flush the data.
LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in "
<< timer.ElapsedSeconds() << " seconds.";
cache_info_->Push(bytes);
@ -204,6 +216,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
SparsePageSourceImpl(SparsePageSourceImpl const &that) = delete;
~SparsePageSourceImpl() override {
// Don't orphan the threads.
for (auto& fu : *ring_) {
if (fu.valid()) {
fu.get();
@ -211,18 +224,18 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
}
}
uint32_t Iter() const { return count_; }
[[nodiscard]] uint32_t Iter() const { return count_; }
const S &operator*() const override {
CHECK(page_);
return *page_;
}
std::shared_ptr<S const> Page() const override {
[[nodiscard]] std::shared_ptr<S const> Page() const override {
return page_;
}
bool AtEnd() const override {
[[nodiscard]] bool AtEnd() const override {
return at_end_;
}
@ -230,20 +243,23 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
TryLockGuard guard{single_threaded_};
at_end_ = false;
count_ = 0;
// Pre-fetch for the next round of iterations.
this->Fetch();
}
};
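
The ring of futures mentioned in the comments above can be illustrated with a small self-contained loop: while one batch is being consumed, the next few are already loading in the background. The loader below is a stub standing in for a page read from the cache, and the prefetch depth is an arbitrary constant rather than the class's actual n_prefetch_batches.

#include <algorithm>
#include <cstddef>
#include <future>
#include <iostream>
#include <vector>

// Forward-only prefetching over a ring of futures: block only on the batch that is
// needed right now while the following batches load asynchronously.
int main() {
  std::size_t constexpr kBatches = 8;
  std::size_t constexpr kPrefetch = 3;  // prefetch depth (a small hard-coded heuristic)

  auto load = [](std::size_t i) { return static_cast<int>(i) * 10; };  // stand-in page read

  std::vector<std::future<int>> ring(kBatches);
  for (std::size_t count = 0; count < kBatches; ++count) {
    // Launch any missing fetches for the next kPrefetch batches.
    for (std::size_t fetch = count; fetch < std::min(count + kPrefetch, kBatches); ++fetch) {
      if (!ring[fetch].valid()) {
        ring[fetch] = std::async(std::launch::async, load, fetch);
      }
    }
    int page = ring[count].get();  // waits only if this batch has not finished loading
    std::cout << "batch " << count << " -> " << page << "\n";
  }
  return 0;
}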
#if defined(XGBOOST_USE_CUDA)
// Push data from CUDA.
void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page);
#else
inline void DevicePush(DMatrixProxy*, float, SparsePage*) { common::AssertGPUSupport(); }
#endif
class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
// This is the source from the user.
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> iter_;
DMatrixProxy* proxy_;
size_t base_row_id_ {0};
std::size_t base_row_id_{0};
void Fetch() final {
page_ = std::make_shared<SparsePage>();


@ -439,7 +439,7 @@ struct ShapSplitCondition {
if (isnan(x)) {
return is_missing_branch;
}
if (categories.Size() != 0) {
if (categories.Capacity() != 0) {
auto cat = static_cast<uint32_t>(x);
return categories.Check(cat);
} else {
@ -454,7 +454,7 @@ struct ShapSplitCondition {
if (l.Data() == r.Data()) {
return l;
}
if (l.Size() > r.Size()) {
if (l.Capacity() > r.Capacity()) {
thrust::swap(l, r);
}
for (size_t i = 0; i < r.Bits().size(); ++i) {
@ -466,7 +466,7 @@ struct ShapSplitCondition {
// Combine two split conditions on the same feature
XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
// Combine duplicate features
if (categories.Size() != 0 || other.categories.Size() != 0) {
if (categories.Capacity() != 0 || other.categories.Capacity() != 0) {
categories = Intersect(categories, other.categories);
} else {
feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);


@ -1,5 +1,5 @@
/*!
* Copyright 2019 XGBoost contributors
/**
* Copyright 2019-2023, XGBoost contributors
*/
#include <thrust/copy.h>
#include <thrust/device_vector.h>
@ -140,20 +140,20 @@ void FeatureInteractionConstraintDevice::Reset() {
__global__ void ClearBuffersKernel(
LBitField64 result_buffer_output, LBitField64 result_buffer_input) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < result_buffer_output.Size()) {
if (tid < result_buffer_output.Capacity()) {
result_buffer_output.Clear(tid);
}
if (tid < result_buffer_input.Size()) {
if (tid < result_buffer_input.Capacity()) {
result_buffer_input.Clear(tid);
}
}
void FeatureInteractionConstraintDevice::ClearBuffers() {
CHECK_EQ(output_buffer_bits_.Size(), input_buffer_bits_.Size());
CHECK_LE(feature_buffer_.Size(), output_buffer_bits_.Size());
CHECK_EQ(output_buffer_bits_.Capacity(), input_buffer_bits_.Capacity());
CHECK_LE(feature_buffer_.Capacity(), output_buffer_bits_.Capacity());
uint32_t constexpr kBlockThreads = 256;
auto const n_grids = static_cast<uint32_t>(
common::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
common::DivRoundUp(input_buffer_bits_.Capacity(), kBlockThreads));
dh::LaunchKernel {n_grids, kBlockThreads} (
ClearBuffersKernel,
output_buffer_bits_, input_buffer_bits_);
@ -207,11 +207,11 @@ common::Span<bst_feature_t> FeatureInteractionConstraintDevice::Query(
ClearBuffers();
LBitField64 node_constraints = s_node_constraints_[nid];
CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size());
CHECK_EQ(input_buffer_bits_.Capacity(), output_buffer_bits_.Capacity());
uint32_t constexpr kBlockThreads = 256;
auto n_grids = static_cast<uint32_t>(
common::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
common::DivRoundUp(output_buffer_bits_.Capacity(), kBlockThreads));
dh::LaunchKernel {n_grids, kBlockThreads} (
SetInputBufferKernel,
feature_list, input_buffer_bits_);
@ -274,13 +274,13 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature,
LBitField64 left,
LBitField64 right) {
auto tid = threadIdx.x + blockDim.x * blockIdx.x;
if (tid > node.Size()) {
if (tid > node.Capacity()) {
return;
}
// enable constraints from feature
node |= feature;
// clear the buffer after use
if (tid < feature.Size()) {
if (tid < feature.Capacity()) {
feature.Clear(tid);
}
@ -323,7 +323,7 @@ void FeatureInteractionConstraintDevice::Split(
s_sets_, s_sets_ptr_);
uint32_t constexpr kBlockThreads = 256;
auto n_grids = static_cast<uint32_t>(common::DivRoundUp(node.Size(), kBlockThreads));
auto n_grids = static_cast<uint32_t>(common::DivRoundUp(node.Capacity(), kBlockThreads));
dh::LaunchKernel {n_grids, kBlockThreads} (
InteractionConstraintSplitKernel,


@ -213,7 +213,7 @@ std::vector<bst_cat_t> GetSplitCategories(RegTree const &tree, int32_t nidx) {
auto split = common::KCatBitField{csr.categories.subspan(seg.beg, seg.size)};
std::vector<bst_cat_t> cats;
for (size_t i = 0; i < split.Size(); ++i) {
for (size_t i = 0; i < split.Capacity(); ++i) {
if (split.Check(i)) {
cats.push_back(static_cast<bst_cat_t>(i));
}
@ -1004,7 +1004,7 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
auto segment = split_categories_segments_[i];
auto node_categories = this->GetSplitCategories().subspan(segment.beg, segment.size);
common::KCatBitField const cat_bits(node_categories);
for (size_t i = 0; i < cat_bits.Size(); ++i) {
for (size_t i = 0; i < cat_bits.Capacity(); ++i) {
if (cat_bits.Check(i)) {
categories.GetArray().emplace_back(i);
}


@ -1,5 +1,5 @@
/*!
* Copyright 2019 XGBoost contributors
/**
* Copyright 2019-2023, XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../../../src/common/bitfield.h"
@ -14,7 +14,7 @@ TEST(BitField, Check) {
static_cast<typename common::Span<LBitField64::value_type>::index_type>(
storage.size())});
size_t true_bit = 190;
for (size_t i = true_bit + 1; i < bits.Size(); ++i) {
for (size_t i = true_bit + 1; i < bits.Capacity(); ++i) {
ASSERT_FALSE(bits.Check(i));
}
ASSERT_TRUE(bits.Check(true_bit));
@ -34,7 +34,7 @@ TEST(BitField, Check) {
ASSERT_FALSE(bits.Check(i));
}
ASSERT_TRUE(bits.Check(true_bit));
for (size_t i = true_bit + 1; i < bits.Size(); ++i) {
for (size_t i = true_bit + 1; i < bits.Capacity(); ++i) {
ASSERT_FALSE(bits.Check(i));
}
}


@ -1,5 +1,5 @@
/*!
* Copyright 2019 XGBoost contributors
/**
* Copyright 2019-2023, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <thrust/copy.h>
@ -12,7 +12,7 @@ namespace xgboost {
__global__ void TestSetKernel(LBitField64 bits) {
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < bits.Size()) {
if (tid < bits.Capacity()) {
bits.Set(tid);
}
}
@ -36,20 +36,16 @@ TEST(BitField, GPUSet) {
std::vector<LBitField64::value_type> h_storage(storage.size());
thrust::copy(storage.begin(), storage.end(), h_storage.begin());
LBitField64 outputs {
common::Span<LBitField64::value_type>{h_storage.data(),
h_storage.data() + h_storage.size()}};
LBitField64 outputs{
common::Span<LBitField64::value_type>{h_storage.data(), h_storage.data() + h_storage.size()}};
for (size_t i = 0; i < kBits; ++i) {
ASSERT_TRUE(outputs.Check(i));
}
}
__global__ void TestOrKernel(LBitField64 lhs, LBitField64 rhs) {
lhs |= rhs;
}
TEST(BitField, GPUAnd) {
namespace {
template <bool is_and, typename Op>
void TestGPULogic(Op op) {
uint32_t constexpr kBits = 128;
dh::device_vector<LBitField64::value_type> lhs_storage(kBits);
dh::device_vector<LBitField64::value_type> rhs_storage(kBits);
@ -57,13 +53,32 @@ TEST(BitField, GPUAnd) {
auto rhs = LBitField64(dh::ToSpan(rhs_storage));
thrust::fill(lhs_storage.begin(), lhs_storage.end(), 0UL);
thrust::fill(rhs_storage.begin(), rhs_storage.end(), ~static_cast<LBitField64::value_type>(0UL));
TestOrKernel<<<1, kBits>>>(lhs, rhs);
dh::LaunchN(kBits, [=] __device__(auto) mutable { op(lhs, rhs); });
std::vector<LBitField64::value_type> h_storage(lhs_storage.size());
thrust::copy(lhs_storage.begin(), lhs_storage.end(), h_storage.begin());
LBitField64 outputs {{h_storage.data(), h_storage.data() + h_storage.size()}};
for (size_t i = 0; i < kBits; ++i) {
ASSERT_TRUE(outputs.Check(i));
LBitField64 outputs{{h_storage.data(), h_storage.data() + h_storage.size()}};
if (is_and) {
for (size_t i = 0; i < kBits; ++i) {
ASSERT_FALSE(outputs.Check(i));
}
} else {
for (size_t i = 0; i < kBits; ++i) {
ASSERT_TRUE(outputs.Check(i));
}
}
}
} // namespace xgboost
void TestGPUAnd() {
TestGPULogic<true>([] XGBOOST_DEVICE(LBitField64 & lhs, LBitField64 const& rhs) { lhs &= rhs; });
}
void TestGPUOr() {
TestGPULogic<false>([] XGBOOST_DEVICE(LBitField64 & lhs, LBitField64 const& rhs) { lhs |= rhs; });
}
} // namespace
TEST(BitField, GPUAnd) { TestGPUAnd(); }
TEST(BitField, GPUOr) { TestGPUOr(); }
} // namespace xgboost


@ -83,7 +83,9 @@ template <typename BinIdxType>
void CheckColumWithMissingValue(const DenseColumnIter<BinIdxType, true>& col,
const GHistIndexMatrix& gmat) {
for (auto i = 0ull; i < col.Size(); i++) {
if (col.IsMissing(i)) continue;
if (col.IsMissing(i)) {
continue;
}
EXPECT_EQ(gmat.index[gmat.row_ptr[i]], col.GetGlobalBinIdx(i));
}
}


@ -285,8 +285,6 @@ TEST(GpuHist, PartitionTwoNodes) {
dh::ToSpan(feature_histogram_b)};
thrust::device_vector<GPUExpandEntry> results(2);
evaluator.EvaluateSplits({0, 1}, 1, dh::ToSpan(inputs), shared_inputs, dh::ToSpan(results));
GPUExpandEntry result_a = results[0];
GPUExpandEntry result_b = results[1];
EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(0)[0]),
std::bitset<32>("10000000000000000000000000000000"));
EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(1)[0]),


@ -1,5 +1,5 @@
/*!
* Copyright 2019 XGBoost contributors
/**
* Copyright 2019-2023, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <thrust/copy.h>
@ -53,7 +53,7 @@ void CompareBitField(LBitField64 d_field, std::set<uint32_t> positions) {
LBitField64 h_field{ {h_field_storage.data(),
h_field_storage.data() + h_field_storage.size()} };
for (size_t i = 0; i < h_field.Size(); ++i) {
for (size_t i = 0; i < h_field.Capacity(); ++i) {
if (positions.find(i) != positions.cend()) {
ASSERT_TRUE(h_field.Check(i));
} else {
@ -82,7 +82,7 @@ TEST(GPUFeatureInteractionConstraint, Init) {
{h_node_storage.data(), h_node_storage.data() + h_node_storage.size()}
};
// no feature is attached to node.
for (size_t i = 0; i < h_node.Size(); ++i) {
for (size_t i = 0; i < h_node.Capacity(); ++i) {
ASSERT_FALSE(h_node.Check(i));
}
}