[EM] Refactor ellpack construction. (#10810)

- Remove the calculation of n_symbols in the accessor.
- Pack initialization steps into the parameter list.
- Pass the context into various ctors.
- Add a specialization for dense data to prepare for further compression.
Jiaming Yuan, 2024-09-09 14:10:10 +08:00, committed by GitHub
parent c69c4adb58 · commit 5f7f31d464
15 changed files with 187 additions and 158 deletions

View File

@@ -77,13 +77,11 @@ class CompressedBufferWriter {
   static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) {
     constexpr int kBitsPerByte = 8;
     size_t compressed_size = static_cast<size_t>(std::ceil(
-        static_cast<double>(detail::SymbolBits(num_symbols) * num_elements) /
-        kBitsPerByte));
+        static_cast<double>(detail::SymbolBits(num_symbols) * num_elements) / kBitsPerByte));
     // Handle atomicOr where input must be unsigned int, hence 4 bytes aligned.
-    size_t ret =
-        std::ceil(static_cast<double>(compressed_size + detail::kPadding) /
-                  static_cast<double>(sizeof(unsigned int))) *
-        sizeof(unsigned int);
+    size_t ret = std::ceil(static_cast<double>(compressed_size + detail::kPadding) /
+                           static_cast<double>(sizeof(std::uint32_t))) *
+                 sizeof(std::uint32_t);
     return ret;
   }
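For orientation, here is a minimal standalone sketch of the arithmetic in this hunk (assumptions: detail::SymbolBits(n) returns the bit width needed to address n distinct symbols, and kPadding stands in for detail::kPadding, whose real value is defined elsewhere):

    #include <cmath>
    #include <cstddef>
    #include <cstdint>

    // Bit width needed to address n distinct symbols (assumed SymbolBits behavior).
    std::size_t SymbolBits(std::size_t n) {
      std::size_t bits = 0;
      while ((std::size_t{1} << bits) < n) { ++bits; }
      return bits;
    }

    std::size_t CalculateBufferSize(std::size_t num_elements, std::size_t num_symbols) {
      constexpr std::size_t kBitsPerByte = 8;
      constexpr std::size_t kPadding = 4;  // placeholder for detail::kPadding
      auto compressed_size = static_cast<std::size_t>(std::ceil(
          static_cast<double>(SymbolBits(num_symbols) * num_elements) / kBitsPerByte));
      // Round up to a multiple of 4 bytes so device-side atomicOr on uint32_t stays in bounds.
      return static_cast<std::size_t>(std::ceil(static_cast<double>(compressed_size + kPadding) /
                                                sizeof(std::uint32_t))) *
             sizeof(std::uint32_t);
    }

For example, 1000 elements drawn from 769 symbols take 10 bits each: 1250 payload bytes, which pad and align up to 1256.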

View File

@@ -11,9 +11,10 @@
 #include "../common/categorical.h"
 #include "../common/cuda_context.cuh"
-#include "../common/hist_util.cuh"
+#include "../common/cuda_rt_utils.h"        // for SetDevice
+#include "../common/hist_util.cuh"          // for HistogramCuts
 #include "../common/ref_resource_view.cuh"  // for MakeFixedVecWithCudaMalloc
-#include "../common/transform_iterator.h"  // MakeIndexTransformIter
+#include "../common/transform_iterator.h"   // for MakeIndexTransformIter
 #include "device_adapter.cuh"  // for NoInfInData
 #include "ellpack_page.cuh"
 #include "ellpack_page.h"
@@ -91,13 +92,23 @@ __global__ void CompressBinEllpackKernel(
   wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
 }
 
+[[nodiscard]] std::size_t CalcNumSymbols(Context const*, bool /*is_dense*/,
+                                         std::shared_ptr<common::HistogramCuts const> cuts) {
+  // Return the total number of symbols (total number of bins plus 1 for not found)
+  return cuts->cut_values_.Size() + 1;
+}
+
 // Construct an ELLPACK matrix with the given number of empty rows.
 EllpackPageImpl::EllpackPageImpl(Context const* ctx,
                                  std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
                                  bst_idx_t row_stride, bst_idx_t n_rows)
-    : is_dense(is_dense), cuts_(std::move(cuts)), row_stride{row_stride}, n_rows{n_rows} {
+    : is_dense(is_dense),
+      cuts_(std::move(cuts)),
+      row_stride{row_stride},
+      n_rows{n_rows},
+      n_symbols_{CalcNumSymbols(ctx, this->is_dense, this->cuts_)} {
   monitor_.Init("ellpack_page");
-  dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
+  common::SetDevice(ctx->Ordinal());
   this->InitCompressedData(ctx);
 }
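A small worked example of what CalcNumSymbols yields (a sketch assuming cut_values_.Size() is the total bin count over all features):

    // Hypothetical numbers, for illustration only.
    std::size_t n_features = 3, bins_per_feature = 256;
    std::size_t total_bins = n_features * bins_per_feature;  // cuts->cut_values_.Size() == 768
    std::size_t n_symbols = total_bins + 1;                  // 769: bin ids 0..767 plus one null symbol
    // The extra symbol is the accessor's NullValue(), written for missing entries.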
@@ -106,56 +117,55 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx,
                                  std::shared_ptr<common::HistogramCuts const> cuts,
                                  const SparsePage& page, bool is_dense, size_t row_stride,
                                  common::Span<FeatureType const> feature_types)
-    : cuts_(std::move(cuts)), is_dense(is_dense), n_rows(page.Size()), row_stride(row_stride) {
+    : cuts_(std::move(cuts)),
+      is_dense(is_dense),
+      n_rows(page.Size()),
+      row_stride(row_stride),
+      n_symbols_(CalcNumSymbols(ctx, this->is_dense, this->cuts_)) {
   this->InitCompressedData(ctx);
-  this->CreateHistIndices(ctx->Device(), page, feature_types);
+  this->CreateHistIndices(ctx, page, feature_types);
 }
 
 // Construct an ELLPACK matrix in memory.
-EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchParam& param)
-    : is_dense(dmat->IsDense()) {
+EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* p_fmat, const BatchParam& param)
+    : is_dense{p_fmat->IsDense()},
+      n_rows{p_fmat->Info().num_row_},
+      row_stride{GetRowStride(p_fmat)},
+      // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
+      cuts_{param.hess.empty()
+                ? std::make_shared<common::HistogramCuts>(
+                      common::DeviceSketch(ctx, p_fmat, param.max_bin))
+                : std::make_shared<common::HistogramCuts>(
+                      common::DeviceSketchWithHessian(ctx, p_fmat, param.max_bin, param.hess))},
+      n_symbols_{CalcNumSymbols(ctx, this->is_dense, this->cuts_)} {
   monitor_.Init("ellpack_page");
-  dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
-  n_rows = dmat->Info().num_row_;
-  monitor_.Start("Quantiles");
-  // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
-  row_stride = GetRowStride(dmat);
-  if (!param.hess.empty()) {
-    cuts_ = std::make_shared<common::HistogramCuts>(
-        common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess));
-  } else {
-    cuts_ = std::make_shared<common::HistogramCuts>(common::DeviceSketch(ctx, dmat, param.max_bin));
-  }
-  monitor_.Stop("Quantiles");
+  common::SetDevice(ctx->Ordinal());
 
   this->InitCompressedData(ctx);
 
-  dmat->Info().feature_types.SetDevice(ctx->Device());
-  auto ft = dmat->Info().feature_types.ConstDeviceSpan();
+  p_fmat->Info().feature_types.SetDevice(ctx->Device());
+  auto ft = p_fmat->Info().feature_types.ConstDeviceSpan();
   monitor_.Start("BinningCompression");
-  CHECK(dmat->SingleColBlock());
-  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
-    CreateHistIndices(ctx->Device(), batch, ft);
+  CHECK(p_fmat->SingleColBlock());
+  for (const auto& batch : p_fmat->GetBatches<SparsePage>()) {
+    CreateHistIndices(ctx, batch, ft);
   }
   monitor_.Stop("BinningCompression");
 }
 
-template <typename AdapterBatchT>
+template <typename AdapterBatchT, bool kIsDense>
 struct WriteCompressedEllpackFunctor {
   WriteCompressedEllpackFunctor(common::CompressedByteT* buffer,
-                                const common::CompressedBufferWriter& writer,
-                                AdapterBatchT batch,
+                                const common::CompressedBufferWriter& writer, AdapterBatchT batch,
                                 EllpackDeviceAccessor accessor,
                                 common::Span<FeatureType const> feature_types,
                                 const data::IsValidFunctor& is_valid)
       : d_buffer(buffer),
         writer(writer),
         batch(std::move(batch)),
         accessor(std::move(accessor)),
         feature_types(std::move(feature_types)),
         is_valid(is_valid) {}
 
   common::CompressedByteT* d_buffer;
   common::CompressedBufferWriter writer;
@@ -197,9 +207,10 @@ struct TupleScanOp {
 // Here the data is already correctly ordered and simply needs to be compacted
 // to remove missing data
-template <typename AdapterBatchT>
-void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType const> feature_types,
-                       EllpackPageImpl* dst, DeviceOrd device, float missing) {
+template <bool kIsDense, typename AdapterBatchT>
+void CopyDataToEllpack(Context const* ctx, const AdapterBatchT& batch,
+                       common::Span<FeatureType const> feature_types, EllpackPageImpl* dst,
+                       float missing) {
   // Some witchcraft happens here
   // The goal is to copy valid elements out of the input to an ELLPACK matrix
   // with a given row stride, using no extra working memory Standard stream
@@ -223,36 +234,35 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
         return is_valid(batch.GetElement(idx));
       });
-  auto key_value_index_iter = thrust::make_zip_iterator(
-      thrust::make_tuple(key_iter, value_iter, counting));
+  auto key_value_index_iter =
+      thrust::make_zip_iterator(thrust::make_tuple(key_iter, value_iter, counting));
 
   // Tuple[0] = The row index of the input, used as a key to define segments
   // Tuple[1] = Scanned flags of valid elements for each row
   // Tuple[2] = The index in the input data
-  using Tuple = thrust::tuple<size_t, size_t, size_t>;
+  using Tuple = thrust::tuple<bst_idx_t, bst_idx_t, bst_idx_t>;
 
-  auto device_accessor = dst->GetDeviceAccessor(device);
-  common::CompressedBufferWriter writer(device_accessor.NumSymbols());
+  auto device_accessor = dst->GetDeviceAccessor(ctx);
+  auto n_symbols = dst->NumSymbols();
+  common::CompressedBufferWriter writer{n_symbols};
   auto d_compressed_buffer = dst->gidx_buffer.data();
 
   // We redirect the scan output into this functor to do the actual writing
-  WriteCompressedEllpackFunctor<AdapterBatchT> functor(
-      d_compressed_buffer, writer, batch, device_accessor, feature_types,
-      is_valid);
   dh::TypedDiscard<Tuple> discard;
-  thrust::transform_output_iterator<
-      WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
-      out(discard, functor);
+  WriteCompressedEllpackFunctor<AdapterBatchT, kIsDense> functor{
+      d_compressed_buffer, writer, batch, device_accessor, feature_types, is_valid};
+  thrust::transform_output_iterator<decltype(functor), decltype(discard)> out(discard, functor);
 
   // Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
   // So we don't crash on n > 2^31
   size_t temp_storage_bytes = 0;
-  using DispatchScan =
-      cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
-                        TupleScanOp<Tuple>, cub::NullType, int64_t>;
+  using DispatchScan = cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
+                                         TupleScanOp<Tuple>, cub::NullType, std::int64_t>;
 #if THRUST_MAJOR_VERSION >= 2
   dh::safe_cuda(DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
                                        TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
-                                       nullptr));
+                                       ctx->CUDACtx()->Stream()));
 #else
   DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
                          TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
@@ -262,7 +272,7 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
 #if THRUST_MAJOR_VERSION >= 2
   dh::safe_cuda(DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
                                        key_value_index_iter, out, TupleScanOp<Tuple>(),
-                                       cub::NullType(), batch.Size(), nullptr));
+                                       cub::NullType(), batch.Size(), ctx->CUDACtx()->Stream()));
 #else
   DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
                          key_value_index_iter, out, TupleScanOp<Tuple>(),
@@ -270,20 +280,19 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
 #endif
 }
 
-void WriteNullValues(EllpackPageImpl* dst, DeviceOrd device, common::Span<size_t> row_counts) {
+void WriteNullValues(Context const* ctx, EllpackPageImpl* dst, common::Span<size_t> row_counts) {
   // Write the null values
-  auto device_accessor = dst->GetDeviceAccessor(device);
-  common::CompressedBufferWriter writer(device_accessor.NumSymbols());
+  auto device_accessor = dst->GetDeviceAccessor(ctx);
+  common::CompressedBufferWriter writer(dst->NumSymbols());
   auto d_compressed_buffer = dst->gidx_buffer.data();
   auto row_stride = dst->row_stride;
-  dh::LaunchN(row_stride * dst->n_rows, [=] __device__(size_t idx) {
+  dh::LaunchN(row_stride * dst->n_rows, ctx->CUDACtx()->Stream(), [=] __device__(bst_idx_t idx) {
     // For some reason this variable got captured as const
     auto writer_non_const = writer;
     size_t row_idx = idx / row_stride;
     size_t row_offset = idx % row_stride;
     if (row_offset >= row_counts[row_idx]) {
-      writer_non_const.AtomicWriteSymbol(d_compressed_buffer,
-                                         device_accessor.NullValue(), idx);
+      writer_non_const.AtomicWriteSymbol(d_compressed_buffer, device_accessor.NullValue(), idx);
     }
   });
 }
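To make the padding step concrete, here is a host-side sketch (illustrative only, not the device kernel above) of the layout WriteNullValues produces: every row occupies exactly row_stride slots, and slots past the row's valid count hold the null symbol.

    #include <cstddef>
    #include <vector>

    // Pad each ELLPACK row out to row_stride with the null symbol (= total bin count).
    std::vector<std::size_t> PadRows(std::vector<std::vector<std::size_t>> const& rows,
                                     std::size_t row_stride, std::size_t null_symbol) {
      std::vector<std::size_t> out;
      for (auto const& row : rows) {
        for (std::size_t j = 0; j < row_stride; ++j) {
          // Valid entries first, then null padding, mirroring row_offset >= row_counts[row_idx].
          out.push_back(j < row.size() ? row[j] : null_symbol);
        }
      }
      return out;
    }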
@@ -292,12 +301,18 @@ template <typename AdapterBatch>
 EllpackPageImpl::EllpackPageImpl(Context const* ctx, AdapterBatch batch, float missing,
                                  bool is_dense, common::Span<size_t> row_counts_span,
                                  common::Span<FeatureType const> feature_types, size_t row_stride,
-                                 size_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts) {
-  dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
+                                 bst_idx_t n_rows,
+                                 std::shared_ptr<common::HistogramCuts const> cuts)
+    : EllpackPageImpl{ctx, cuts, is_dense, row_stride, n_rows} {
+  common::SetDevice(ctx->Ordinal());
 
-  *this = EllpackPageImpl(ctx, cuts, is_dense, row_stride, n_rows);
-  CopyDataToEllpack(batch, feature_types, this, ctx->Device(), missing);
-  WriteNullValues(this, ctx->Device(), row_counts_span);
+  if (this->IsDense()) {
+    CopyDataToEllpack<true>(ctx, batch, feature_types, this, missing);
+  } else {
+    CopyDataToEllpack<false>(ctx, batch, feature_types, this, missing);
+  }
+  WriteNullValues(ctx, this, row_counts_span);
 }
 
 #define ELLPACK_BATCH_SPECIALIZE(__BATCH_T) \
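The dense/sparse branch above turns a runtime flag into a compile-time template parameter, which is what lets the dense path later drop per-element validity handling. A generic sketch of the pattern (hypothetical names, not the XGBoost API):

    #include <cstdio>

    // The body can use `if constexpr` to compile out sparse-only work entirely.
    template <bool kIsDense>
    void ProcessBatch(int n) {
      for (int i = 0; i < n; ++i) {
        if constexpr (!kIsDense) {
          // Sparse-only bookkeeping (e.g. validity checks), absent from the dense binary.
        }
        // Common work for both specializations goes here.
      }
      std::printf("processed %d elements, dense=%d\n", n, kIsDense);
    }

    void Process(bool is_dense, int n) {
      // One runtime branch at the call site selects the specialization.
      if (is_dense) {
        ProcessBatch<true>(n);
      } else {
        ProcessBatch<false>(n);
      }
    }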
@@ -358,7 +373,8 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
     : is_dense{page.IsDense()},
       base_rowid{page.base_rowid},
       n_rows{page.Size()},
-      cuts_{std::make_shared<common::HistogramCuts>(page.cut)} {
+      cuts_{std::make_shared<common::HistogramCuts>(page.cut)},
+      n_symbols_{CalcNumSymbols(ctx, page.IsDense(), cuts_)} {
   auto it = common::MakeIndexTransformIter(
       [&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
   row_stride = *std::max_element(it, it + page.Size());
@@ -373,7 +389,7 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
   dh::safe_cuda(cudaMemcpyAsync(d_row_ptr.data(), page.row_ptr.data(), d_row_ptr.size_bytes(),
                                 cudaMemcpyHostToDevice, ctx->CUDACtx()->Stream()));
 
-  auto accessor = this->GetDeviceAccessor(ctx->Device(), ft);
+  auto accessor = this->GetDeviceAccessor(ctx, ft);
   auto null = accessor.NullValue();
   this->monitor_.Start("CopyGHistToEllpack");
   CopyGHistToEllpack(ctx, page, d_row_ptr, row_stride, d_compressed_buffer, null);
@@ -469,11 +485,14 @@ void EllpackPageImpl::Compact(Context const* ctx, EllpackPageImpl const* page,
   monitor_.Stop(__func__);
 }
 
+void EllpackPageImpl::SetCuts(std::shared_ptr<common::HistogramCuts const> cuts) {
+  cuts_ = std::move(cuts);
+}
+
 // Initialize the buffer to stored compressed features.
 void EllpackPageImpl::InitCompressedData(Context const* ctx) {
   monitor_.Start(__func__);
-  auto num_symbols = NumSymbols();
+  auto num_symbols = this->NumSymbols();
   // Required buffer size for storing data matrix in ELLPack format.
   std::size_t compressed_size_bytes =
       common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows, num_symbols);
@@ -483,7 +502,7 @@ void EllpackPageImpl::InitCompressedData(Context const* ctx) {
 }
 
 // Compress a CSR page into ELLPACK.
-void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
+void EllpackPageImpl::CreateHistIndices(Context const* ctx,
                                         const SparsePage& row_batch,
                                         common::Span<FeatureType const> feature_types) {
   if (row_batch.Size() == 0) return;
@@ -493,7 +512,7 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
   // bin and compress entries in batches of rows
   size_t gpu_batch_nrows =
-      std::min(dh::TotalMemory(device.ordinal) / (16 * row_stride * sizeof(Entry)),
+      std::min(dh::TotalMemory(ctx->Ordinal()) / (16 * row_stride * sizeof(Entry)),
                static_cast<size_t>(row_batch.Size()));
   size_t gpu_nbatches = common::DivRoundUp(row_batch.Size(), gpu_batch_nrows);
@@ -531,7 +550,7 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
     const dim3 block3(32, 8, 1);  // 256 threads
     const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
                      common::DivRoundUp(row_stride, block3.y), 1);
-    auto device_accessor = GetDeviceAccessor(device);
+    auto device_accessor = this->GetDeviceAccessor(ctx);
     dh::LaunchKernel{grid3, block3}(  // NOLINT
         CompressBinEllpackKernel, common::CompressedBufferWriter(NumSymbols()), gidx_buffer.data(),
         row_ptrs.data().get(), entries_d.data().get(), device_accessor.gidx_fvalue_map.data(),
@@ -545,18 +564,18 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
 std::size_t EllpackPageImpl::MemCostBytes() const {
   return this->gidx_buffer.size_bytes() + sizeof(this->n_rows) + sizeof(this->is_dense) +
-         sizeof(this->row_stride) + sizeof(this->base_rowid);
+         sizeof(this->row_stride) + sizeof(this->base_rowid) + sizeof(this->n_symbols_);
 }
 
 EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
-    DeviceOrd device, common::Span<FeatureType const> feature_types) const {
-  return {device,
+    Context const* ctx, common::Span<FeatureType const> feature_types) const {
+  return {ctx,
           cuts_,
           is_dense,
           row_stride,
           base_rowid,
           n_rows,
-          common::CompressedIterator<uint32_t>(gidx_buffer.data(), NumSymbols()),
+          common::CompressedIterator<uint32_t>(gidx_buffer.data(), this->NumSymbols()),
           feature_types};
 }
@@ -568,19 +587,20 @@ EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
   CHECK_NE(gidx_buffer.size(), 0);
   dh::safe_cuda(cudaMemcpyAsync(h_gidx_buffer->data(), gidx_buffer.data(), gidx_buffer.size_bytes(),
                                 cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
-  return {DeviceOrd::CPU(),
+  Context cpu_ctx;
+  return {ctx->IsCPU() ? ctx : &cpu_ctx,
           cuts_,
           is_dense,
           row_stride,
           base_rowid,
           n_rows,
-          common::CompressedIterator<uint32_t>(h_gidx_buffer->data(), NumSymbols()),
+          common::CompressedIterator<uint32_t>(h_gidx_buffer->data(), this->NumSymbols()),
          feature_types};
 }
 
 [[nodiscard]] bst_idx_t EllpackPageImpl::NumNonMissing(
     Context const* ctx, common::Span<FeatureType const> feature_types) const {
-  auto d_acc = this->GetDeviceAccessor(ctx->Device(), feature_types);
+  auto d_acc = this->GetDeviceAccessor(ctx, feature_types);
   using T = typename decltype(d_acc.gidx_iter)::value_type;
   auto it = thrust::make_transform_iterator(
       thrust::make_counting_iterator(0ull),
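A design note on the accessor change that recurs through this file: the accessors now take a Context instead of a bare DeviceOrd, so a single argument carries both the device ordinal and the CUDA stream that work should be enqueued on. A rough illustration of the shape of such a context (hypothetical struct, not XGBoost's actual Context class):

    #include <cuda_runtime.h>

    struct MiniContext {
      int ordinal{-1};               // which GPU; -1 means CPU
      cudaStream_t stream{nullptr};  // stream for kernel launches and async copies
      bool IsCUDA() const { return ordinal >= 0; }
    };
    // Threading one such object through ctors and accessors replaces ad-hoc
    // (device, stream) argument pairs and implicit default-stream launches.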

View File

@@ -43,20 +43,20 @@ struct EllpackDeviceAccessor {
   common::Span<const FeatureType> feature_types;
 
   EllpackDeviceAccessor() = delete;
-  EllpackDeviceAccessor(DeviceOrd device, std::shared_ptr<const common::HistogramCuts> cuts,
-                        bool is_dense, size_t row_stride, size_t base_rowid, size_t n_rows,
+  EllpackDeviceAccessor(Context const* ctx, std::shared_ptr<const common::HistogramCuts> cuts,
+                        bool is_dense, bst_idx_t row_stride, bst_idx_t base_rowid, bst_idx_t n_rows,
                         common::CompressedIterator<uint32_t> gidx_iter,
                         common::Span<FeatureType const> feature_types)
-      : is_dense(is_dense),
-        row_stride(row_stride),
-        base_rowid(base_rowid),
-        n_rows(n_rows),
-        gidx_iter(gidx_iter),
+      : is_dense{is_dense},
+        row_stride{row_stride},
+        base_rowid{base_rowid},
+        n_rows{n_rows},
+        gidx_iter{gidx_iter},
         feature_types{feature_types} {
-    if (device.IsCUDA()) {
-      cuts->cut_values_.SetDevice(device);
-      cuts->cut_ptrs_.SetDevice(device);
-      cuts->min_vals_.SetDevice(device);
+    if (ctx->IsCUDA()) {
+      cuts->cut_values_.SetDevice(ctx->Device());
+      cuts->cut_ptrs_.SetDevice(ctx->Device());
+      cuts->min_vals_.SetDevice(ctx->Device());
       gidx_fvalue_map = cuts->cut_values_.ConstDeviceSpan();
       feature_segments = cuts->cut_ptrs_.ConstDeviceSpan();
       min_fvalue = cuts->min_vals_.ConstDeviceSpan();
@@ -127,9 +127,6 @@ struct EllpackDeviceAccessor {
   [[nodiscard]] __device__ bool IsInRange(size_t row_id) const {
     return row_id >= base_rowid && row_id < base_rowid + n_rows;
   }
-  /*! \brief Return the total number of symbols (total number of bins plus 1 for
-   * not found). */
-  [[nodiscard]] XGBOOST_DEVICE size_t NumSymbols() const { return gidx_fvalue_map.size() + 1; }
 
   [[nodiscard]] XGBOOST_DEVICE size_t NullValue() const { return this->NumBins(); }
@@ -160,7 +157,7 @@ class EllpackPageImpl {
   EllpackPageImpl(Context const* ctx, std::shared_ptr<common::HistogramCuts const> cuts,
                   bool is_dense, bst_idx_t row_stride, bst_idx_t n_rows);
   /**
-   * @brief Constructor used for external memory.
+   * @brief Constructor used for external memory with DMatrix.
   */
   EllpackPageImpl(Context const* ctx, std::shared_ptr<common::HistogramCuts const> cuts,
                   const SparsePage& page, bool is_dense, size_t row_stride,
@@ -173,12 +170,14 @@ class EllpackPageImpl {
   * in CSR format.
   */
  explicit EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchParam& parm);
+  /**
+   * @brief Constructor for Quantile DMatrix using an adapter.
+   */
  template <typename AdapterBatch>
  explicit EllpackPageImpl(Context const* ctx, AdapterBatch batch, float missing, bool is_dense,
                           common::Span<size_t> row_counts_span,
                           common::Span<FeatureType const> feature_types, size_t row_stride,
-                           size_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts);
+                           bst_idx_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts);
  /**
   * @brief Constructor from an existing CPU gradient index.
   */
@@ -214,7 +213,7 @@ class EllpackPageImpl {
   [[nodiscard]] common::HistogramCuts const& Cuts() const { return *cuts_; }
   [[nodiscard]] std::shared_ptr<common::HistogramCuts const> CutsShared() const { return cuts_; }
-  void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts) { cuts_ = cuts; }
+  void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts);
 
   [[nodiscard]] bool IsDense() const { return is_dense; }
   /** @return Estimation of memory cost of this page. */
@@ -224,12 +223,14 @@ class EllpackPageImpl {
   * @brief Return the total number of symbols (total number of bins plus 1 for not
   * found).
   */
-  [[nodiscard]] std::size_t NumSymbols() const { return cuts_->TotalBins() + 1; }
+  [[nodiscard]] std::size_t NumSymbols() const { return this->n_symbols_; }
+  void SetNumSymbols(bst_idx_t n_symbols) { this->n_symbols_ = n_symbols; }
  /**
   * @brief Get an accessor that can be passed into CUDA kernels.
   */
   [[nodiscard]] EllpackDeviceAccessor GetDeviceAccessor(
-      DeviceOrd device, common::Span<FeatureType const> feature_types = {}) const;
+      Context const* ctx, common::Span<FeatureType const> feature_types = {}) const;
  /**
   * @brief Get an accessor for host code.
   */
@@ -246,10 +247,9 @@ class EllpackPageImpl {
  /**
   * @brief Compress a single page of CSR data into ELLPACK.
   *
-   * @param device The GPU device to use.
   * @param row_batch The CSR page.
   */
-  void CreateHistIndices(DeviceOrd device, const SparsePage& row_batch,
+  void CreateHistIndices(Context const* ctx, const SparsePage& row_batch,
                         common::Span<FeatureType const> feature_types);
  /**
   * @brief Initialize the buffer to store compressed features.
@@ -272,6 +272,7 @@ class EllpackPageImpl {
  private:
   std::shared_ptr<common::HistogramCuts const> cuts_;
+  bst_idx_t n_symbols_{0};
   common::Monitor monitor_;
 };
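One detail the header relies on: n_symbols_ is declared after cuts_, and C++ initializes non-static data members in declaration order, not in initializer-list order. The constructors in this commit depend on that, since CalcNumSymbols(ctx, is_dense, cuts_) reads the already-initialized cuts_. A minimal sketch of the rule (hypothetical class, not XGBoost code):

    #include <cstddef>
    #include <memory>
    #include <vector>

    struct Page {
      // Declaration order is initialization order.
      std::shared_ptr<std::vector<float>> cuts_;
      std::size_t n_symbols_{0};

      explicit Page(std::shared_ptr<std::vector<float>> cuts)
          : cuts_{std::move(cuts)},
            // Safe only because cuts_ is declared before n_symbols_.
            n_symbols_{cuts_->size() + 1} {}
    };
    // Were n_symbols_ declared first, its initializer would read a not-yet-
    // initialized cuts_, no matter how the initializer list is written.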

View File

@@ -55,7 +55,6 @@ template <typename T>
   xgboost_NVTX_FN_RANGE();
 
   auto* impl = page->Impl();
-  impl->SetCuts(this->cuts_);
   RET_IF_NOT(fi->Read(&impl->n_rows));
   RET_IF_NOT(fi->Read(&impl->is_dense));
   RET_IF_NOT(fi->Read(&impl->row_stride));
@@ -66,6 +65,12 @@ template <typename T>
     RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
   }
   RET_IF_NOT(fi->Read(&impl->base_rowid));
+  bst_idx_t n_symbols{0};
+  RET_IF_NOT(fi->Read(&n_symbols));
+  impl->SetNumSymbols(n_symbols);
+
+  impl->SetCuts(this->cuts_);
+
   dh::DefaultStream().Sync();
   return true;
 }
@@ -84,6 +89,8 @@ template <typename T>
   [[maybe_unused]] auto h_accessor = impl->GetHostAccessor(&ctx, &h_gidx_buffer);
   bytes += common::WriteVec(fo, h_gidx_buffer);
   bytes += fo->Write(impl->base_rowid);
+  bytes += fo->Write(impl->NumSymbols());
+
   dh::DefaultStream().Sync();
   return bytes;
 }
@@ -93,9 +100,10 @@ template <typename T>
   auto* impl = page->Impl();
   CHECK(this->cuts_->cut_values_.DeviceCanRead());
-  impl->SetCuts(this->cuts_);
 
   fi->Read(page, this->param_.prefetch_copy || !this->has_hmm_ats_);
+  impl->SetCuts(this->cuts_);
+
   dh::DefaultStream().Sync();
 
   return true;
@@ -108,8 +116,7 @@ template <typename T>
   fo->Write(page);
   dh::DefaultStream().Sync();
 
-  auto* impl = page.Impl();
-  return impl->MemCostBytes();
+  return page.Impl()->MemCostBytes();
 }
 
 #undef RET_IF_NOT
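Because NumSymbols() no longer falls out of the cuts, the symbol count becomes an explicit field of the on-disk page: it is written right after base_rowid, and on load it is restored before the shared cuts are attached. Sketched field order (illustrative names; not an actual format definition in the code base):

    #include <cstdint>

    // Serialized EllpackPage layout implied by the read/write paths above:
    //   n_rows | is_dense | row_stride | gidx_buffer | base_rowid | n_symbols
    struct EllpackPageHeaderSketch {
      std::uint64_t n_rows;
      bool is_dense;
      std::uint64_t row_stride;
      // ...compressed gidx buffer follows in the stream...
      std::uint64_t base_rowid;
      std::uint64_t n_symbols;  // the field added by this commit
    };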

View File

@@ -81,6 +81,7 @@ class EllpackHostCacheStreamImpl {
     new_impl->is_dense = impl->IsDense();
     new_impl->row_stride = impl->row_stride;
     new_impl->base_rowid = impl->base_rowid;
+    new_impl->SetNumSymbols(impl->NumSymbols());
 
     dh::safe_cuda(cudaMemcpyAsync(new_impl->gidx_buffer.data(), impl->gidx_buffer.data(),
                                   impl->gidx_buffer.size_bytes(), cudaMemcpyDefault));
@@ -108,6 +109,7 @@ class EllpackHostCacheStreamImpl {
     impl->is_dense = page->IsDense();
     impl->row_stride = page->row_stride;
     impl->base_rowid = page->base_rowid;
+    impl->SetNumSymbols(page->NumSymbols());
   }
 };

View File

@@ -58,9 +58,9 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
   /**
    * Generate gradient index.
    */
-  size_t offset = 0;
+  bst_idx_t offset = 0;
   iter.Reset();
-  size_t n_batches_for_verification = 0;
+  bst_idx_t n_batches_for_verification = 0;
   while (iter.Next()) {
     init_page();
     dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
@@ -75,10 +75,11 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
     proxy->Info().feature_types.SetDevice(dh::GetDevice(ctx));
     auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan();
     auto new_impl = cuda_impl::Dispatch(proxy, [&](auto const& value) {
-      return EllpackPageImpl(&fmat_ctx_, value, missing, is_dense, row_counts_span, d_feature_types,
-                             ext_info.row_stride, rows, cuts);
+      return EllpackPageImpl{
+          &fmat_ctx_, value, missing, is_dense, row_counts_span, d_feature_types,
+          ext_info.row_stride, rows, cuts};
     });
-    std::size_t num_elements = ellpack_->Impl()->Copy(&fmat_ctx_, &new_impl, offset);
+    bst_idx_t num_elements = ellpack_->Impl()->Copy(&fmat_ctx_, &new_impl, offset);
     offset += num_elements;
 
     proxy->Info().num_row_ = BatchSamples(proxy);

View File

@@ -927,8 +927,8 @@ class GPUPredictor : public xgboost::Predictor {
     for (auto const& page : dmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
       dmat->Info().feature_types.SetDevice(ctx_->Device());
       auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
-      this->PredictInternal(page.Impl()->GetDeviceAccessor(ctx_->Device(), feature_types),
-                            d_model, out_preds, batch_offset);
+      this->PredictInternal(page.Impl()->GetDeviceAccessor(ctx_, feature_types), d_model,
+                            out_preds, batch_offset);
       batch_offset += page.Size() * model.learner_model_param->OutputLength();
     }
   }
@@ -1068,7 +1068,7 @@ class GPUPredictor : public xgboost::Predictor {
       }
     } else {
       for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
-        EllpackDeviceAccessor acc{batch.Impl()->GetDeviceAccessor(ctx_->Device())};
+        EllpackDeviceAccessor acc{batch.Impl()->GetDeviceAccessor(ctx_)};
         auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(),
                                std::numeric_limits<float>::quiet_NaN()};
         auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size;
@@ -1139,8 +1139,7 @@ class GPUPredictor : public xgboost::Predictor {
     } else {
       for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
         auto impl = batch.Impl();
-        auto acc =
-            impl->GetDeviceAccessor(ctx_->Device(), p_fmat->Info().feature_types.ConstDeviceSpan());
+        auto acc = impl->GetDeviceAccessor(ctx_, p_fmat->Info().feature_types.ConstDeviceSpan());
         auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size;
         auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(),
                                std::numeric_limits<float>::quiet_NaN()};
@@ -1225,7 +1224,7 @@ class GPUPredictor : public xgboost::Predictor {
     } else {
       bst_idx_t batch_offset = 0;
       for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
-        EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->Device())};
+        EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_)};
         auto grid = static_cast<std::uint32_t>(common::DivRoundUp(batch.Size(), kBlockThreads));
         launch(PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, grid, data, batch_offset);
         batch_offset += batch.Size();

View File

@@ -9,6 +9,7 @@
 #include <cstddef>  // std::size_t
 
 #include "../collective/aggregator.cuh"  // for GlobalSum
+#include "../common/cuda_context.cuh"
 #include "../common/device_helpers.cuh"  // dh::MakeTransformIterator
 #include "fit_stump.h"
 #include "xgboost/base.h"  // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
@@ -39,9 +40,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
   auto d_sum = sum.View(ctx->Device());
   CHECK(d_sum.CContiguous());
 
-  dh::XGBCachingDeviceAllocator<char> alloc;
-  auto policy = thrust::cuda::par(alloc);
-  thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
+  thrust::reduce_by_key(ctx->CUDACtx()->CTP(), key_it, key_it + gpair.Size(), grad_it,
                         thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
 
   auto rc = collective::GlobalSum(ctx, info,
@@ -49,7 +48,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
                                   d_sum.Size() * 2, ctx->Device()));
   SafeColl(rc);
 
-  thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
+  thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), n_targets,
                      [=] XGBOOST_DEVICE(std::size_t i) mutable {
                        out(i) = static_cast<float>(
                            CalcUnregularizedWeight(d_sum(i).GetGrad(), d_sum(i).GetHess()));
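The change above swaps an ad-hoc allocator-backed policy for the context's cached Thrust policy (CUDACtx()->CTP()), so both Thrust calls run on the context's stream rather than the default one. A generic illustration of stream-bound execution with plain Thrust (this sketch uses only standard Thrust APIs, not the XGBoost wrappers):

    #include <thrust/device_vector.h>
    #include <thrust/execution_policy.h>
    #include <thrust/reduce.h>
    #include <cuda_runtime.h>

    float SumOnStream(thrust::device_vector<float> const& v, cudaStream_t stream) {
      // thrust::cuda::par.on(stream) enqueues the reduction on `stream`
      // instead of the legacy default stream.
      return thrust::reduce(thrust::cuda::par.on(stream), v.begin(), v.end(), 0.0f);
    }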

View File

@@ -186,7 +186,7 @@ class HistogramAgent {
   // Increases the throughput of this kernel significantly
   __device__ void ProcessFullTileShared(std::size_t offset) {
     std::size_t idx[kItemsPerThread];
-    int ridx[kItemsPerThread];
+    Idx ridx[kItemsPerThread];
     int gidx[kItemsPerThread];
     GradientPair gpair[kItemsPerThread];
 #pragma unroll

View File

@@ -338,7 +338,7 @@ struct GPUHistMakerDevice {
     monitor.Start(__func__);
     auto d_node_hist = histogram_.GetNodeHistogram(nidx);
     auto batch = page.Impl();
-    auto acc = batch->GetDeviceAccessor(ctx_->Device());
+    auto acc = batch->GetDeviceAccessor(ctx_);
 
     auto d_ridx = partitioners_.at(k)->GetRows(nidx);
     this->histogram_.BuildHistogram(ctx_->CUDACtx(), acc,
@@ -497,7 +497,7 @@ struct GPUHistMakerDevice {
     std::int32_t k{0};
     for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
-      auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
+      auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_);
       auto go_left = GoLeftOp{d_matrix};
 
       // Partition histogram.
@@ -567,9 +567,10 @@ struct GPUHistMakerDevice {
     dh::CopyTo(p_tree->GetSplitCategories(), &categories);
     auto const& cat_segments = p_tree->GetSplitCategoriesPtr();
     auto d_categories = dh::ToSpan(categories);
+    auto ft = p_fmat->Info().feature_types.ConstDeviceSpan();
 
     for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
-      auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
+      auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_, ft);
 
       std::vector<NodeSplitData> split_data(p_tree->NumNodes());
       auto const& tree = *p_tree;

View File

@@ -203,7 +203,8 @@ class BaseMGPUTest : public ::testing::Test {
    * available.
    */
   template <typename Fn>
-  auto DoTest(Fn&& fn, bool is_federated, [[maybe_unused]] bool emulate_if_single = false) const {
+  auto DoTest([[maybe_unused]] Fn&& fn, bool is_federated,
+              [[maybe_unused]] bool emulate_if_single = false) const {
     auto n_gpus = common::AllVisibleGPUs();
     if (is_federated) {
 #if defined(XGBOOST_USE_FEDERATED)

View File

@@ -19,7 +19,7 @@ TEST(EllpackPage, EmptyDMatrix) {
   constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256;
   constexpr float kSparsity = 0;
   auto dmat = RandomDataGenerator(kNRows, kNCols, kSparsity).GenerateDMatrix();
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto& page = *dmat->GetBatches<EllpackPage>(
                     &ctx, BatchParam{kMaxBin, tree::TrainParam::DftSparseThreshold()})
                     .begin();
@@ -94,7 +94,7 @@ TEST(EllpackPage, FromCategoricalBasic) {
   Context ctx{MakeCUDACtx(0)};
   auto p = BatchParam{max_bins, tree::TrainParam::DftSparseThreshold()};
   auto ellpack = EllpackPage(&ctx, m.get(), p);
-  auto accessor = ellpack.Impl()->GetDeviceAccessor(ctx.Device());
+  auto accessor = ellpack.Impl()->GetDeviceAccessor(&ctx);
   ASSERT_EQ(kCats, accessor.NumBins());
 
   auto x_copy = x;
@@ -167,11 +167,11 @@ TEST(EllpackPage, Copy) {
     EXPECT_EQ(impl->base_rowid, current_row);
 
     for (size_t i = 0; i < impl->Size(); i++) {
-      dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(ctx.Device()), current_row,
-                                         row_d.data().get()));
+      dh::LaunchN(kCols,
+                  ReadRowFunction(impl->GetDeviceAccessor(&ctx), current_row, row_d.data().get()));
       thrust::copy(row_d.begin(), row_d.end(), row.begin());
 
-      dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(ctx.Device()), current_row,
+      dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(&ctx), current_row,
                                          row_result_d.data().get()));
       thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin());
@@ -223,12 +223,12 @@ TEST(EllpackPage, Compact) {
       continue;
     }
 
-    dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(ctx.Device()), current_row,
-                                       row_d.data().get()));
+    dh::LaunchN(kCols,
+                ReadRowFunction(impl->GetDeviceAccessor(&ctx), current_row, row_d.data().get()));
     dh::safe_cuda(cudaDeviceSynchronize());
     thrust::copy(row_d.begin(), row_d.end(), row.begin());
 
-    dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(ctx.Device()), compacted_row,
+    dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(&ctx), compacted_row,
                                        row_result_d.data().get()));
     thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin());

View File

@@ -27,7 +27,7 @@ void TestEquivalent(float sparsity) {
       size_t num_elements = page_concatenated->Copy(&ctx, page, offset);
       offset += num_elements;
     }
-    auto from_iter = page_concatenated->GetDeviceAccessor(ctx.Device());
+    auto from_iter = page_concatenated->GetDeviceAccessor(&ctx);
     ASSERT_EQ(m.Info().num_col_, CudaArrayIterForTest::Cols());
     ASSERT_EQ(m.Info().num_row_, CudaArrayIterForTest::Rows());
@@ -37,7 +37,7 @@ void TestEquivalent(float sparsity) {
         DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 0)};
     auto bp = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
     for (auto& ellpack : dm->GetBatches<EllpackPage>(&ctx, bp)) {
-      auto from_data = ellpack.Impl()->GetDeviceAccessor(ctx.Device());
+      auto from_data = ellpack.Impl()->GetDeviceAccessor(&ctx);
 
       std::vector<float> cuts_from_iter(from_iter.gidx_fvalue_map.size());
       std::vector<float> min_fvalues_iter(from_iter.min_fvalue.size());
@@ -71,7 +71,7 @@ void TestEquivalent(float sparsity) {
       auto data_buf = ellpack.Impl()->GetHostAccessor(&ctx, &buffer_from_data);
       ASSERT_NE(buffer_from_data.size(), 0);
       ASSERT_NE(buffer_from_iter.size(), 0);
-      CHECK_EQ(from_data.NumSymbols(), from_iter.NumSymbols());
+      CHECK_EQ(ellpack.Impl()->NumSymbols(), page_concatenated->NumSymbols());
       CHECK_EQ(from_data.n_rows * from_data.row_stride, from_data.n_rows * from_iter.row_stride);
       for (size_t i = 0; i < from_data.n_rows * from_data.row_stride; ++i) {
         CHECK_EQ(data_buf.gidx_iter[i], data_iter.gidx_iter[i]);
@@ -146,10 +146,10 @@ TEST(IterativeDeviceDMatrix, RowMajorMissing) {
   auto impl = ellpack.Impl();
   std::vector<common::CompressedByteT> h_gidx;
   auto h_accessor = impl->GetHostAccessor(&ctx, &h_gidx);
-  EXPECT_EQ(h_accessor.gidx_iter[1], impl->GetDeviceAccessor(ctx.Device()).NullValue());
-  EXPECT_EQ(h_accessor.gidx_iter[5], impl->GetDeviceAccessor(ctx.Device()).NullValue());
+  EXPECT_EQ(h_accessor.gidx_iter[1], impl->GetDeviceAccessor(&ctx).NullValue());
+  EXPECT_EQ(h_accessor.gidx_iter[5], impl->GetDeviceAccessor(&ctx).NullValue());
   // null values get placed after valid values in a row
-  EXPECT_EQ(h_accessor.gidx_iter[7], impl->GetDeviceAccessor(ctx.Device()).NullValue());
+  EXPECT_EQ(h_accessor.gidx_iter[7], impl->GetDeviceAccessor(&ctx).NullValue());
   EXPECT_EQ(m.Info().num_col_, cols);
   EXPECT_EQ(m.Info().num_row_, rows);
   EXPECT_EQ(m.Info().num_nonzero_, rows* cols - 3);

View File

@@ -14,7 +14,7 @@
 namespace xgboost {
 
 TEST(SparsePageDMatrix, EllpackPage) {
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
   dmlc::TemporaryDirectory tempdir;
   const std::string tmp_file = tempdir.path + "/simple.libsvm";
@@ -301,11 +301,11 @@ TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
     EXPECT_EQ(impl_ext->base_rowid, current_row);
 
     for (size_t i = 0; i < impl_ext->Size(); i++) {
-      dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(ctx.Device()), current_row,
-                                         row_d.data().get()));
+      dh::LaunchN(kCols,
+                  ReadRowFunction(impl->GetDeviceAccessor(&ctx), current_row, row_d.data().get()));
      thrust::copy(row_d.begin(), row_d.end(), row.begin());
-      dh::LaunchN(kCols, ReadRowFunction(impl_ext->GetDeviceAccessor(ctx.Device()), current_row,
+      dh::LaunchN(kCols, ReadRowFunction(impl_ext->GetDeviceAccessor(&ctx), current_row,
                                          row_ext_d.data().get()));
      thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin());

View File

@@ -136,7 +136,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
                 feature_groups.DeviceAccessor(ctx.Device()), page->Cuts().TotalBins(),
                 !use_shared_memory_histograms);
   builder.AllocateHistograms(&ctx, {0});
-  builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+  builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
                          feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
                          row_partitioner->GetRows(0), builder.GetNodeHistogram(0), *quantiser);
@@ -189,7 +189,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
       DeviceHistogramBuilder builder;
       builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
                     feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
-      builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+      builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
                              feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                              d_histogram, quantiser);
@@ -205,7 +205,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
         DeviceHistogramBuilder builder;
         builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
                       feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
-        builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+        builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
                                feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                                d_new_histogram, quantiser);
@@ -230,7 +230,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
       DeviceHistogramBuilder builder;
       builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
                     single_group.DeviceAccessor(ctx.Device()), num_bins, force_global);
-      builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+      builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
                              single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                              dh::ToSpan(baseline), quantiser);
@@ -298,7 +298,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
     DeviceHistogramBuilder builder;
     builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
                   single_group.DeviceAccessor(ctx.Device()), num_categories, false);
-    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
                            single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                            dh::ToSpan(cat_hist), quantiser);
   }
@@ -315,7 +315,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
     DeviceHistogramBuilder builder;
     builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
                   single_group.DeviceAccessor(ctx.Device()), encode_hist.size(), false);
-    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
                            single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                            dh::ToSpan(encode_hist), quantiser);
   }
@@ -449,7 +449,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
       auto impl = page.Impl();
       if (k == 0) {
         // Initialization
-        auto d_matrix = impl->GetDeviceAccessor(ctx.Device());
+        auto d_matrix = impl->GetDeviceAccessor(&ctx);
 
         fg = std::make_unique<FeatureGroups>(impl->Cuts());
         auto init = GradientPairInt64{0, 0};
         multi_hist = decltype(multi_hist)(impl->Cuts().TotalBins(), init);
@@ -465,7 +465,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
       DeviceHistogramBuilder builder;
       builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
                     fg->DeviceAccessor(ctx.Device()), d_histogram.size(), force_global);
-      builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(ctx.Device()),
+      builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(&ctx),
                              fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
                              d_histogram, quantiser);
       ++k;
@@ -491,7 +491,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
     DeviceHistogramBuilder builder;
     builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), fg->DeviceAccessor(ctx.Device()),
                   d_histogram.size(), force_global);
-    builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(ctx.Device()),
+    builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(&ctx),
                            fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
                            d_histogram, quantiser);
   }