[EM] Refactor ellpack construction. (#10810)
- Remove the calculation of n_symbols in the accessor. - Pack initialization steps into the parameter list. - Pass the context into various ctors. - Specialization for dense data to prepare for further compression.
This commit is contained in:
parent
c69c4adb58
commit
5f7f31d464
@ -77,13 +77,11 @@ class CompressedBufferWriter {
|
||||
static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) {
|
||||
constexpr int kBitsPerByte = 8;
|
||||
size_t compressed_size = static_cast<size_t>(std::ceil(
|
||||
static_cast<double>(detail::SymbolBits(num_symbols) * num_elements) /
|
||||
kBitsPerByte));
|
||||
static_cast<double>(detail::SymbolBits(num_symbols) * num_elements) / kBitsPerByte));
|
||||
// Handle atomicOr where input must be unsigned int, hence 4 bytes aligned.
|
||||
size_t ret =
|
||||
std::ceil(static_cast<double>(compressed_size + detail::kPadding) /
|
||||
static_cast<double>(sizeof(unsigned int))) *
|
||||
sizeof(unsigned int);
|
||||
size_t ret = std::ceil(static_cast<double>(compressed_size + detail::kPadding) /
|
||||
static_cast<double>(sizeof(std::uint32_t))) *
|
||||
sizeof(std::uint32_t);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@ -11,9 +11,10 @@
|
||||
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/hist_util.cuh"
|
||||
#include "../common/cuda_rt_utils.h" // for SetDevice
|
||||
#include "../common/hist_util.cuh" // for HistogramCuts
|
||||
#include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
|
||||
#include "device_adapter.cuh" // for NoInfInData
|
||||
#include "ellpack_page.cuh"
|
||||
#include "ellpack_page.h"
|
||||
@ -91,13 +92,23 @@ __global__ void CompressBinEllpackKernel(
|
||||
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
|
||||
}
|
||||
|
||||
[[nodiscard]] std::size_t CalcNumSymbols(Context const*, bool /*is_dense*/,
|
||||
std::shared_ptr<common::HistogramCuts const> cuts) {
|
||||
// Return the total number of symbols (total number of bins plus 1 for not found)
|
||||
return cuts->cut_values_.Size() + 1;
|
||||
}
|
||||
|
||||
// Construct an ELLPACK matrix with the given number of empty rows.
|
||||
EllpackPageImpl::EllpackPageImpl(Context const* ctx,
|
||||
std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
|
||||
bst_idx_t row_stride, bst_idx_t n_rows)
|
||||
: is_dense(is_dense), cuts_(std::move(cuts)), row_stride{row_stride}, n_rows{n_rows} {
|
||||
: is_dense(is_dense),
|
||||
cuts_(std::move(cuts)),
|
||||
row_stride{row_stride},
|
||||
n_rows{n_rows},
|
||||
n_symbols_{CalcNumSymbols(ctx, this->is_dense, this->cuts_)} {
|
||||
monitor_.Init("ellpack_page");
|
||||
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||
common::SetDevice(ctx->Ordinal());
|
||||
|
||||
this->InitCompressedData(ctx);
|
||||
}
|
||||
@ -106,56 +117,55 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx,
|
||||
std::shared_ptr<common::HistogramCuts const> cuts,
|
||||
const SparsePage& page, bool is_dense, size_t row_stride,
|
||||
common::Span<FeatureType const> feature_types)
|
||||
: cuts_(std::move(cuts)), is_dense(is_dense), n_rows(page.Size()), row_stride(row_stride) {
|
||||
: cuts_(std::move(cuts)),
|
||||
is_dense(is_dense),
|
||||
n_rows(page.Size()),
|
||||
row_stride(row_stride),
|
||||
n_symbols_(CalcNumSymbols(ctx, this->is_dense, this->cuts_)) {
|
||||
this->InitCompressedData(ctx);
|
||||
this->CreateHistIndices(ctx->Device(), page, feature_types);
|
||||
this->CreateHistIndices(ctx, page, feature_types);
|
||||
}
|
||||
|
||||
// Construct an ELLPACK matrix in memory.
|
||||
EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchParam& param)
|
||||
: is_dense(dmat->IsDense()) {
|
||||
EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* p_fmat, const BatchParam& param)
|
||||
: is_dense{p_fmat->IsDense()},
|
||||
n_rows{p_fmat->Info().num_row_},
|
||||
row_stride{GetRowStride(p_fmat)},
|
||||
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
|
||||
cuts_{param.hess.empty()
|
||||
? std::make_shared<common::HistogramCuts>(
|
||||
common::DeviceSketch(ctx, p_fmat, param.max_bin))
|
||||
: std::make_shared<common::HistogramCuts>(
|
||||
common::DeviceSketchWithHessian(ctx, p_fmat, param.max_bin, param.hess))},
|
||||
n_symbols_{CalcNumSymbols(ctx, this->is_dense, this->cuts_)} {
|
||||
monitor_.Init("ellpack_page");
|
||||
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||
|
||||
n_rows = dmat->Info().num_row_;
|
||||
|
||||
monitor_.Start("Quantiles");
|
||||
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
|
||||
row_stride = GetRowStride(dmat);
|
||||
if (!param.hess.empty()) {
|
||||
cuts_ = std::make_shared<common::HistogramCuts>(
|
||||
common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess));
|
||||
} else {
|
||||
cuts_ = std::make_shared<common::HistogramCuts>(common::DeviceSketch(ctx, dmat, param.max_bin));
|
||||
}
|
||||
monitor_.Stop("Quantiles");
|
||||
common::SetDevice(ctx->Ordinal());
|
||||
|
||||
this->InitCompressedData(ctx);
|
||||
|
||||
dmat->Info().feature_types.SetDevice(ctx->Device());
|
||||
auto ft = dmat->Info().feature_types.ConstDeviceSpan();
|
||||
p_fmat->Info().feature_types.SetDevice(ctx->Device());
|
||||
auto ft = p_fmat->Info().feature_types.ConstDeviceSpan();
|
||||
monitor_.Start("BinningCompression");
|
||||
CHECK(dmat->SingleColBlock());
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
CreateHistIndices(ctx->Device(), batch, ft);
|
||||
CHECK(p_fmat->SingleColBlock());
|
||||
for (const auto& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CreateHistIndices(ctx, batch, ft);
|
||||
}
|
||||
monitor_.Stop("BinningCompression");
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
template <typename AdapterBatchT, bool kIsDense>
|
||||
struct WriteCompressedEllpackFunctor {
|
||||
WriteCompressedEllpackFunctor(common::CompressedByteT* buffer,
|
||||
const common::CompressedBufferWriter& writer,
|
||||
AdapterBatchT batch,
|
||||
const common::CompressedBufferWriter& writer, AdapterBatchT batch,
|
||||
EllpackDeviceAccessor accessor,
|
||||
common::Span<FeatureType const> feature_types,
|
||||
const data::IsValidFunctor& is_valid)
|
||||
: d_buffer(buffer),
|
||||
writer(writer),
|
||||
batch(std::move(batch)),
|
||||
accessor(std::move(accessor)),
|
||||
feature_types(std::move(feature_types)),
|
||||
is_valid(is_valid) {}
|
||||
writer(writer),
|
||||
batch(std::move(batch)),
|
||||
accessor(std::move(accessor)),
|
||||
feature_types(std::move(feature_types)),
|
||||
is_valid(is_valid) {}
|
||||
|
||||
common::CompressedByteT* d_buffer;
|
||||
common::CompressedBufferWriter writer;
|
||||
@ -197,9 +207,10 @@ struct TupleScanOp {
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterBatchT>
|
||||
void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType const> feature_types,
|
||||
EllpackPageImpl* dst, DeviceOrd device, float missing) {
|
||||
template <bool kIsDense, typename AdapterBatchT>
|
||||
void CopyDataToEllpack(Context const* ctx, const AdapterBatchT& batch,
|
||||
common::Span<FeatureType const> feature_types, EllpackPageImpl* dst,
|
||||
float missing) {
|
||||
// Some witchcraft happens here
|
||||
// The goal is to copy valid elements out of the input to an ELLPACK matrix
|
||||
// with a given row stride, using no extra working memory Standard stream
|
||||
@ -223,36 +234,35 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
|
||||
return is_valid(batch.GetElement(idx));
|
||||
});
|
||||
|
||||
auto key_value_index_iter = thrust::make_zip_iterator(
|
||||
thrust::make_tuple(key_iter, value_iter, counting));
|
||||
auto key_value_index_iter =
|
||||
thrust::make_zip_iterator(thrust::make_tuple(key_iter, value_iter, counting));
|
||||
|
||||
// Tuple[0] = The row index of the input, used as a key to define segments
|
||||
// Tuple[1] = Scanned flags of valid elements for each row
|
||||
// Tuple[2] = The index in the input data
|
||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||
using Tuple = thrust::tuple<bst_idx_t, bst_idx_t, bst_idx_t>;
|
||||
|
||||
auto device_accessor = dst->GetDeviceAccessor(device);
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto device_accessor = dst->GetDeviceAccessor(ctx);
|
||||
auto n_symbols = dst->NumSymbols();
|
||||
|
||||
common::CompressedBufferWriter writer{n_symbols};
|
||||
auto d_compressed_buffer = dst->gidx_buffer.data();
|
||||
|
||||
// We redirect the scan output into this functor to do the actual writing
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
|
||||
d_compressed_buffer, writer, batch, device_accessor, feature_types,
|
||||
is_valid);
|
||||
dh::TypedDiscard<Tuple> discard;
|
||||
thrust::transform_output_iterator<
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
|
||||
out(discard, functor);
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT, kIsDense> functor{
|
||||
d_compressed_buffer, writer, batch, device_accessor, feature_types, is_valid};
|
||||
thrust::transform_output_iterator<decltype(functor), decltype(discard)> out(discard, functor);
|
||||
|
||||
// Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
|
||||
// So we don't crash on n > 2^31
|
||||
size_t temp_storage_bytes = 0;
|
||||
using DispatchScan =
|
||||
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
||||
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
||||
using DispatchScan = cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
||||
TupleScanOp<Tuple>, cub::NullType, std::int64_t>;
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
dh::safe_cuda(DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
nullptr));
|
||||
ctx->CUDACtx()->Stream()));
|
||||
#else
|
||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
@ -262,7 +272,7 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
dh::safe_cuda(DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
cub::NullType(), batch.Size(), nullptr));
|
||||
cub::NullType(), batch.Size(), ctx->CUDACtx()->Stream()));
|
||||
#else
|
||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
@ -270,20 +280,19 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
|
||||
#endif
|
||||
}
|
||||
|
||||
void WriteNullValues(EllpackPageImpl* dst, DeviceOrd device, common::Span<size_t> row_counts) {
|
||||
void WriteNullValues(Context const* ctx, EllpackPageImpl* dst, common::Span<size_t> row_counts) {
|
||||
// Write the null values
|
||||
auto device_accessor = dst->GetDeviceAccessor(device);
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto device_accessor = dst->GetDeviceAccessor(ctx);
|
||||
common::CompressedBufferWriter writer(dst->NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.data();
|
||||
auto row_stride = dst->row_stride;
|
||||
dh::LaunchN(row_stride * dst->n_rows, [=] __device__(size_t idx) {
|
||||
dh::LaunchN(row_stride * dst->n_rows, ctx->CUDACtx()->Stream(), [=] __device__(bst_idx_t idx) {
|
||||
// For some reason this variable got captured as const
|
||||
auto writer_non_const = writer;
|
||||
size_t row_idx = idx / row_stride;
|
||||
size_t row_offset = idx % row_stride;
|
||||
if (row_offset >= row_counts[row_idx]) {
|
||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer,
|
||||
device_accessor.NullValue(), idx);
|
||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, device_accessor.NullValue(), idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -292,12 +301,18 @@ template <typename AdapterBatch>
|
||||
EllpackPageImpl::EllpackPageImpl(Context const* ctx, AdapterBatch batch, float missing,
|
||||
bool is_dense, common::Span<size_t> row_counts_span,
|
||||
common::Span<FeatureType const> feature_types, size_t row_stride,
|
||||
size_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts) {
|
||||
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||
bst_idx_t n_rows,
|
||||
std::shared_ptr<common::HistogramCuts const> cuts)
|
||||
: EllpackPageImpl{ctx, cuts, is_dense, row_stride, n_rows} {
|
||||
common::SetDevice(ctx->Ordinal());
|
||||
|
||||
*this = EllpackPageImpl(ctx, cuts, is_dense, row_stride, n_rows);
|
||||
CopyDataToEllpack(batch, feature_types, this, ctx->Device(), missing);
|
||||
WriteNullValues(this, ctx->Device(), row_counts_span);
|
||||
if (this->IsDense()) {
|
||||
CopyDataToEllpack<true>(ctx, batch, feature_types, this, missing);
|
||||
} else {
|
||||
CopyDataToEllpack<false>(ctx, batch, feature_types, this, missing);
|
||||
}
|
||||
|
||||
WriteNullValues(ctx, this, row_counts_span);
|
||||
}
|
||||
|
||||
#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T) \
|
||||
@ -358,7 +373,8 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
|
||||
: is_dense{page.IsDense()},
|
||||
base_rowid{page.base_rowid},
|
||||
n_rows{page.Size()},
|
||||
cuts_{std::make_shared<common::HistogramCuts>(page.cut)} {
|
||||
cuts_{std::make_shared<common::HistogramCuts>(page.cut)},
|
||||
n_symbols_{CalcNumSymbols(ctx, page.IsDense(), cuts_)} {
|
||||
auto it = common::MakeIndexTransformIter(
|
||||
[&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
|
||||
row_stride = *std::max_element(it, it + page.Size());
|
||||
@ -373,7 +389,7 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_row_ptr.data(), page.row_ptr.data(), d_row_ptr.size_bytes(),
|
||||
cudaMemcpyHostToDevice, ctx->CUDACtx()->Stream()));
|
||||
|
||||
auto accessor = this->GetDeviceAccessor(ctx->Device(), ft);
|
||||
auto accessor = this->GetDeviceAccessor(ctx, ft);
|
||||
auto null = accessor.NullValue();
|
||||
this->monitor_.Start("CopyGHistToEllpack");
|
||||
CopyGHistToEllpack(ctx, page, d_row_ptr, row_stride, d_compressed_buffer, null);
|
||||
@ -469,11 +485,14 @@ void EllpackPageImpl::Compact(Context const* ctx, EllpackPageImpl const* page,
|
||||
monitor_.Stop(__func__);
|
||||
}
|
||||
|
||||
void EllpackPageImpl::SetCuts(std::shared_ptr<common::HistogramCuts const> cuts) {
|
||||
cuts_ = std::move(cuts);
|
||||
}
|
||||
|
||||
// Initialize the buffer to stored compressed features.
|
||||
void EllpackPageImpl::InitCompressedData(Context const* ctx) {
|
||||
monitor_.Start(__func__);
|
||||
auto num_symbols = NumSymbols();
|
||||
|
||||
auto num_symbols = this->NumSymbols();
|
||||
// Required buffer size for storing data matrix in ELLPack format.
|
||||
std::size_t compressed_size_bytes =
|
||||
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows, num_symbols);
|
||||
@ -483,7 +502,7 @@ void EllpackPageImpl::InitCompressedData(Context const* ctx) {
|
||||
}
|
||||
|
||||
// Compress a CSR page into ELLPACK.
|
||||
void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
|
||||
void EllpackPageImpl::CreateHistIndices(Context const* ctx,
|
||||
const SparsePage& row_batch,
|
||||
common::Span<FeatureType const> feature_types) {
|
||||
if (row_batch.Size() == 0) return;
|
||||
@ -493,7 +512,7 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
|
||||
|
||||
// bin and compress entries in batches of rows
|
||||
size_t gpu_batch_nrows =
|
||||
std::min(dh::TotalMemory(device.ordinal) / (16 * row_stride * sizeof(Entry)),
|
||||
std::min(dh::TotalMemory(ctx->Ordinal()) / (16 * row_stride * sizeof(Entry)),
|
||||
static_cast<size_t>(row_batch.Size()));
|
||||
|
||||
size_t gpu_nbatches = common::DivRoundUp(row_batch.Size(), gpu_batch_nrows);
|
||||
@ -531,7 +550,7 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
|
||||
const dim3 block3(32, 8, 1); // 256 threads
|
||||
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
|
||||
common::DivRoundUp(row_stride, block3.y), 1);
|
||||
auto device_accessor = GetDeviceAccessor(device);
|
||||
auto device_accessor = this->GetDeviceAccessor(ctx);
|
||||
dh::LaunchKernel{grid3, block3}( // NOLINT
|
||||
CompressBinEllpackKernel, common::CompressedBufferWriter(NumSymbols()), gidx_buffer.data(),
|
||||
row_ptrs.data().get(), entries_d.data().get(), device_accessor.gidx_fvalue_map.data(),
|
||||
@ -545,18 +564,18 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
|
||||
|
||||
std::size_t EllpackPageImpl::MemCostBytes() const {
|
||||
return this->gidx_buffer.size_bytes() + sizeof(this->n_rows) + sizeof(this->is_dense) +
|
||||
sizeof(this->row_stride) + sizeof(this->base_rowid);
|
||||
sizeof(this->row_stride) + sizeof(this->base_rowid) + sizeof(this->n_symbols_);
|
||||
}
|
||||
|
||||
EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
|
||||
DeviceOrd device, common::Span<FeatureType const> feature_types) const {
|
||||
return {device,
|
||||
Context const* ctx, common::Span<FeatureType const> feature_types) const {
|
||||
return {ctx,
|
||||
cuts_,
|
||||
is_dense,
|
||||
row_stride,
|
||||
base_rowid,
|
||||
n_rows,
|
||||
common::CompressedIterator<uint32_t>(gidx_buffer.data(), NumSymbols()),
|
||||
common::CompressedIterator<uint32_t>(gidx_buffer.data(), this->NumSymbols()),
|
||||
feature_types};
|
||||
}
|
||||
|
||||
@ -568,19 +587,20 @@ EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
|
||||
CHECK_NE(gidx_buffer.size(), 0);
|
||||
dh::safe_cuda(cudaMemcpyAsync(h_gidx_buffer->data(), gidx_buffer.data(), gidx_buffer.size_bytes(),
|
||||
cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
|
||||
return {DeviceOrd::CPU(),
|
||||
Context cpu_ctx;
|
||||
return {ctx->IsCPU() ? ctx : &cpu_ctx,
|
||||
cuts_,
|
||||
is_dense,
|
||||
row_stride,
|
||||
base_rowid,
|
||||
n_rows,
|
||||
common::CompressedIterator<uint32_t>(h_gidx_buffer->data(), NumSymbols()),
|
||||
common::CompressedIterator<uint32_t>(h_gidx_buffer->data(), this->NumSymbols()),
|
||||
feature_types};
|
||||
}
|
||||
|
||||
[[nodiscard]] bst_idx_t EllpackPageImpl::NumNonMissing(
|
||||
Context const* ctx, common::Span<FeatureType const> feature_types) const {
|
||||
auto d_acc = this->GetDeviceAccessor(ctx->Device(), feature_types);
|
||||
auto d_acc = this->GetDeviceAccessor(ctx, feature_types);
|
||||
using T = typename decltype(d_acc.gidx_iter)::value_type;
|
||||
auto it = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0ull),
|
||||
|
||||
@ -43,20 +43,20 @@ struct EllpackDeviceAccessor {
|
||||
common::Span<const FeatureType> feature_types;
|
||||
|
||||
EllpackDeviceAccessor() = delete;
|
||||
EllpackDeviceAccessor(DeviceOrd device, std::shared_ptr<const common::HistogramCuts> cuts,
|
||||
bool is_dense, size_t row_stride, size_t base_rowid, size_t n_rows,
|
||||
EllpackDeviceAccessor(Context const* ctx, std::shared_ptr<const common::HistogramCuts> cuts,
|
||||
bool is_dense, bst_idx_t row_stride, bst_idx_t base_rowid, bst_idx_t n_rows,
|
||||
common::CompressedIterator<uint32_t> gidx_iter,
|
||||
common::Span<FeatureType const> feature_types)
|
||||
: is_dense(is_dense),
|
||||
row_stride(row_stride),
|
||||
base_rowid(base_rowid),
|
||||
n_rows(n_rows),
|
||||
gidx_iter(gidx_iter),
|
||||
: is_dense{is_dense},
|
||||
row_stride{row_stride},
|
||||
base_rowid{base_rowid},
|
||||
n_rows{n_rows},
|
||||
gidx_iter{gidx_iter},
|
||||
feature_types{feature_types} {
|
||||
if (device.IsCUDA()) {
|
||||
cuts->cut_values_.SetDevice(device);
|
||||
cuts->cut_ptrs_.SetDevice(device);
|
||||
cuts->min_vals_.SetDevice(device);
|
||||
if (ctx->IsCUDA()) {
|
||||
cuts->cut_values_.SetDevice(ctx->Device());
|
||||
cuts->cut_ptrs_.SetDevice(ctx->Device());
|
||||
cuts->min_vals_.SetDevice(ctx->Device());
|
||||
gidx_fvalue_map = cuts->cut_values_.ConstDeviceSpan();
|
||||
feature_segments = cuts->cut_ptrs_.ConstDeviceSpan();
|
||||
min_fvalue = cuts->min_vals_.ConstDeviceSpan();
|
||||
@ -127,9 +127,6 @@ struct EllpackDeviceAccessor {
|
||||
[[nodiscard]] __device__ bool IsInRange(size_t row_id) const {
|
||||
return row_id >= base_rowid && row_id < base_rowid + n_rows;
|
||||
}
|
||||
/*! \brief Return the total number of symbols (total number of bins plus 1 for
|
||||
* not found). */
|
||||
[[nodiscard]] XGBOOST_DEVICE size_t NumSymbols() const { return gidx_fvalue_map.size() + 1; }
|
||||
|
||||
[[nodiscard]] XGBOOST_DEVICE size_t NullValue() const { return this->NumBins(); }
|
||||
|
||||
@ -160,7 +157,7 @@ class EllpackPageImpl {
|
||||
EllpackPageImpl(Context const* ctx, std::shared_ptr<common::HistogramCuts const> cuts,
|
||||
bool is_dense, bst_idx_t row_stride, bst_idx_t n_rows);
|
||||
/**
|
||||
* @brief Constructor used for external memory.
|
||||
* @brief Constructor used for external memory with DMatrix.
|
||||
*/
|
||||
EllpackPageImpl(Context const* ctx, std::shared_ptr<common::HistogramCuts const> cuts,
|
||||
const SparsePage& page, bool is_dense, size_t row_stride,
|
||||
@ -173,12 +170,14 @@ class EllpackPageImpl {
|
||||
* in CSR format.
|
||||
*/
|
||||
explicit EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchParam& parm);
|
||||
|
||||
/**
|
||||
* @brief Constructor for Quantile DMatrix using an adapter.
|
||||
*/
|
||||
template <typename AdapterBatch>
|
||||
explicit EllpackPageImpl(Context const* ctx, AdapterBatch batch, float missing, bool is_dense,
|
||||
common::Span<size_t> row_counts_span,
|
||||
common::Span<FeatureType const> feature_types, size_t row_stride,
|
||||
size_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts);
|
||||
bst_idx_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts);
|
||||
/**
|
||||
* @brief Constructor from an existing CPU gradient index.
|
||||
*/
|
||||
@ -214,7 +213,7 @@ class EllpackPageImpl {
|
||||
|
||||
[[nodiscard]] common::HistogramCuts const& Cuts() const { return *cuts_; }
|
||||
[[nodiscard]] std::shared_ptr<common::HistogramCuts const> CutsShared() const { return cuts_; }
|
||||
void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts) { cuts_ = cuts; }
|
||||
void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts);
|
||||
|
||||
[[nodiscard]] bool IsDense() const { return is_dense; }
|
||||
/** @return Estimation of memory cost of this page. */
|
||||
@ -224,12 +223,14 @@ class EllpackPageImpl {
|
||||
* @brief Return the total number of symbols (total number of bins plus 1 for not
|
||||
* found).
|
||||
*/
|
||||
[[nodiscard]] std::size_t NumSymbols() const { return cuts_->TotalBins() + 1; }
|
||||
[[nodiscard]] std::size_t NumSymbols() const { return this->n_symbols_; }
|
||||
void SetNumSymbols(bst_idx_t n_symbols) { this->n_symbols_ = n_symbols; }
|
||||
|
||||
/**
|
||||
* @brief Get an accessor that can be passed into CUDA kernels.
|
||||
*/
|
||||
[[nodiscard]] EllpackDeviceAccessor GetDeviceAccessor(
|
||||
DeviceOrd device, common::Span<FeatureType const> feature_types = {}) const;
|
||||
Context const* ctx, common::Span<FeatureType const> feature_types = {}) const;
|
||||
/**
|
||||
* @brief Get an accessor for host code.
|
||||
*/
|
||||
@ -246,10 +247,9 @@ class EllpackPageImpl {
|
||||
/**
|
||||
* @brief Compress a single page of CSR data into ELLPACK.
|
||||
*
|
||||
* @param device The GPU device to use.
|
||||
* @param row_batch The CSR page.
|
||||
*/
|
||||
void CreateHistIndices(DeviceOrd device, const SparsePage& row_batch,
|
||||
void CreateHistIndices(Context const* ctx, const SparsePage& row_batch,
|
||||
common::Span<FeatureType const> feature_types);
|
||||
/**
|
||||
* @brief Initialize the buffer to store compressed features.
|
||||
@ -272,6 +272,7 @@ class EllpackPageImpl {
|
||||
|
||||
private:
|
||||
std::shared_ptr<common::HistogramCuts const> cuts_;
|
||||
bst_idx_t n_symbols_{0};
|
||||
common::Monitor monitor_;
|
||||
};
|
||||
|
||||
|
||||
@ -55,7 +55,6 @@ template <typename T>
|
||||
xgboost_NVTX_FN_RANGE();
|
||||
auto* impl = page->Impl();
|
||||
|
||||
impl->SetCuts(this->cuts_);
|
||||
RET_IF_NOT(fi->Read(&impl->n_rows));
|
||||
RET_IF_NOT(fi->Read(&impl->is_dense));
|
||||
RET_IF_NOT(fi->Read(&impl->row_stride));
|
||||
@ -66,6 +65,12 @@ template <typename T>
|
||||
RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
|
||||
}
|
||||
RET_IF_NOT(fi->Read(&impl->base_rowid));
|
||||
bst_idx_t n_symbols{0};
|
||||
RET_IF_NOT(fi->Read(&n_symbols));
|
||||
impl->SetNumSymbols(n_symbols);
|
||||
|
||||
impl->SetCuts(this->cuts_);
|
||||
|
||||
dh::DefaultStream().Sync();
|
||||
return true;
|
||||
}
|
||||
@ -84,6 +89,8 @@ template <typename T>
|
||||
[[maybe_unused]] auto h_accessor = impl->GetHostAccessor(&ctx, &h_gidx_buffer);
|
||||
bytes += common::WriteVec(fo, h_gidx_buffer);
|
||||
bytes += fo->Write(impl->base_rowid);
|
||||
bytes += fo->Write(impl->NumSymbols());
|
||||
|
||||
dh::DefaultStream().Sync();
|
||||
return bytes;
|
||||
}
|
||||
@ -93,9 +100,10 @@ template <typename T>
|
||||
|
||||
auto* impl = page->Impl();
|
||||
CHECK(this->cuts_->cut_values_.DeviceCanRead());
|
||||
impl->SetCuts(this->cuts_);
|
||||
|
||||
fi->Read(page, this->param_.prefetch_copy || !this->has_hmm_ats_);
|
||||
impl->SetCuts(this->cuts_);
|
||||
|
||||
dh::DefaultStream().Sync();
|
||||
|
||||
return true;
|
||||
@ -108,8 +116,7 @@ template <typename T>
|
||||
fo->Write(page);
|
||||
dh::DefaultStream().Sync();
|
||||
|
||||
auto* impl = page.Impl();
|
||||
return impl->MemCostBytes();
|
||||
return page.Impl()->MemCostBytes();
|
||||
}
|
||||
|
||||
#undef RET_IF_NOT
|
||||
|
||||
@ -81,6 +81,7 @@ class EllpackHostCacheStreamImpl {
|
||||
new_impl->is_dense = impl->IsDense();
|
||||
new_impl->row_stride = impl->row_stride;
|
||||
new_impl->base_rowid = impl->base_rowid;
|
||||
new_impl->SetNumSymbols(impl->NumSymbols());
|
||||
|
||||
dh::safe_cuda(cudaMemcpyAsync(new_impl->gidx_buffer.data(), impl->gidx_buffer.data(),
|
||||
impl->gidx_buffer.size_bytes(), cudaMemcpyDefault));
|
||||
@ -108,6 +109,7 @@ class EllpackHostCacheStreamImpl {
|
||||
impl->is_dense = page->IsDense();
|
||||
impl->row_stride = page->row_stride;
|
||||
impl->base_rowid = page->base_rowid;
|
||||
impl->SetNumSymbols(page->NumSymbols());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -58,9 +58,9 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
|
||||
/**
|
||||
* Generate gradient index.
|
||||
*/
|
||||
size_t offset = 0;
|
||||
bst_idx_t offset = 0;
|
||||
iter.Reset();
|
||||
size_t n_batches_for_verification = 0;
|
||||
bst_idx_t n_batches_for_verification = 0;
|
||||
while (iter.Next()) {
|
||||
init_page();
|
||||
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
|
||||
@ -75,10 +75,11 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
|
||||
proxy->Info().feature_types.SetDevice(dh::GetDevice(ctx));
|
||||
auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan();
|
||||
auto new_impl = cuda_impl::Dispatch(proxy, [&](auto const& value) {
|
||||
return EllpackPageImpl(&fmat_ctx_, value, missing, is_dense, row_counts_span, d_feature_types,
|
||||
ext_info.row_stride, rows, cuts);
|
||||
return EllpackPageImpl{
|
||||
&fmat_ctx_, value, missing, is_dense, row_counts_span, d_feature_types,
|
||||
ext_info.row_stride, rows, cuts};
|
||||
});
|
||||
std::size_t num_elements = ellpack_->Impl()->Copy(&fmat_ctx_, &new_impl, offset);
|
||||
bst_idx_t num_elements = ellpack_->Impl()->Copy(&fmat_ctx_, &new_impl, offset);
|
||||
offset += num_elements;
|
||||
|
||||
proxy->Info().num_row_ = BatchSamples(proxy);
|
||||
|
||||
@ -927,8 +927,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
for (auto const& page : dmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
|
||||
dmat->Info().feature_types.SetDevice(ctx_->Device());
|
||||
auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
|
||||
this->PredictInternal(page.Impl()->GetDeviceAccessor(ctx_->Device(), feature_types),
|
||||
d_model, out_preds, batch_offset);
|
||||
this->PredictInternal(page.Impl()->GetDeviceAccessor(ctx_, feature_types), d_model,
|
||||
out_preds, batch_offset);
|
||||
batch_offset += page.Size() * model.learner_model_param->OutputLength();
|
||||
}
|
||||
}
|
||||
@ -1068,7 +1068,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
} else {
|
||||
for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
|
||||
EllpackDeviceAccessor acc{batch.Impl()->GetDeviceAccessor(ctx_->Device())};
|
||||
EllpackDeviceAccessor acc{batch.Impl()->GetDeviceAccessor(ctx_)};
|
||||
auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(),
|
||||
std::numeric_limits<float>::quiet_NaN()};
|
||||
auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size;
|
||||
@ -1139,8 +1139,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
} else {
|
||||
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
|
||||
auto impl = batch.Impl();
|
||||
auto acc =
|
||||
impl->GetDeviceAccessor(ctx_->Device(), p_fmat->Info().feature_types.ConstDeviceSpan());
|
||||
auto acc = impl->GetDeviceAccessor(ctx_, p_fmat->Info().feature_types.ConstDeviceSpan());
|
||||
auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size;
|
||||
auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(),
|
||||
std::numeric_limits<float>::quiet_NaN()};
|
||||
@ -1225,7 +1224,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
} else {
|
||||
bst_idx_t batch_offset = 0;
|
||||
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
|
||||
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->Device())};
|
||||
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_)};
|
||||
auto grid = static_cast<std::uint32_t>(common::DivRoundUp(batch.Size(), kBlockThreads));
|
||||
launch(PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, grid, data, batch_offset);
|
||||
batch_offset += batch.Size();
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
#include <cstddef> // std::size_t
|
||||
|
||||
#include "../collective/aggregator.cuh" // for GlobalSum
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
|
||||
#include "fit_stump.h"
|
||||
#include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
|
||||
@ -39,9 +40,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
auto d_sum = sum.View(ctx->Device());
|
||||
CHECK(d_sum.CContiguous());
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto policy = thrust::cuda::par(alloc);
|
||||
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
|
||||
thrust::reduce_by_key(ctx->CUDACtx()->CTP(), key_it, key_it + gpair.Size(), grad_it,
|
||||
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
|
||||
|
||||
auto rc = collective::GlobalSum(ctx, info,
|
||||
@ -49,7 +48,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
d_sum.Size() * 2, ctx->Device()));
|
||||
SafeColl(rc);
|
||||
|
||||
thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
|
||||
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), n_targets,
|
||||
[=] XGBOOST_DEVICE(std::size_t i) mutable {
|
||||
out(i) = static_cast<float>(
|
||||
CalcUnregularizedWeight(d_sum(i).GetGrad(), d_sum(i).GetHess()));
|
||||
|
||||
@ -186,7 +186,7 @@ class HistogramAgent {
|
||||
// Increases the throughput of this kernel significantly
|
||||
__device__ void ProcessFullTileShared(std::size_t offset) {
|
||||
std::size_t idx[kItemsPerThread];
|
||||
int ridx[kItemsPerThread];
|
||||
Idx ridx[kItemsPerThread];
|
||||
int gidx[kItemsPerThread];
|
||||
GradientPair gpair[kItemsPerThread];
|
||||
#pragma unroll
|
||||
|
||||
@ -338,7 +338,7 @@ struct GPUHistMakerDevice {
|
||||
monitor.Start(__func__);
|
||||
auto d_node_hist = histogram_.GetNodeHistogram(nidx);
|
||||
auto batch = page.Impl();
|
||||
auto acc = batch->GetDeviceAccessor(ctx_->Device());
|
||||
auto acc = batch->GetDeviceAccessor(ctx_);
|
||||
|
||||
auto d_ridx = partitioners_.at(k)->GetRows(nidx);
|
||||
this->histogram_.BuildHistogram(ctx_->CUDACtx(), acc,
|
||||
@ -497,7 +497,7 @@ struct GPUHistMakerDevice {
|
||||
|
||||
std::int32_t k{0};
|
||||
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
|
||||
auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
|
||||
auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_);
|
||||
auto go_left = GoLeftOp{d_matrix};
|
||||
|
||||
// Partition histogram.
|
||||
@ -567,9 +567,10 @@ struct GPUHistMakerDevice {
|
||||
dh::CopyTo(p_tree->GetSplitCategories(), &categories);
|
||||
auto const& cat_segments = p_tree->GetSplitCategoriesPtr();
|
||||
auto d_categories = dh::ToSpan(categories);
|
||||
auto ft = p_fmat->Info().feature_types.ConstDeviceSpan();
|
||||
|
||||
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
|
||||
auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
|
||||
auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_, ft);
|
||||
|
||||
std::vector<NodeSplitData> split_data(p_tree->NumNodes());
|
||||
auto const& tree = *p_tree;
|
||||
|
||||
@ -203,7 +203,8 @@ class BaseMGPUTest : public ::testing::Test {
|
||||
* available.
|
||||
*/
|
||||
template <typename Fn>
|
||||
auto DoTest(Fn&& fn, bool is_federated, [[maybe_unused]] bool emulate_if_single = false) const {
|
||||
auto DoTest([[maybe_unused]] Fn&& fn, bool is_federated,
|
||||
[[maybe_unused]] bool emulate_if_single = false) const {
|
||||
auto n_gpus = common::AllVisibleGPUs();
|
||||
if (is_federated) {
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
|
||||
@ -19,7 +19,7 @@ TEST(EllpackPage, EmptyDMatrix) {
|
||||
constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256;
|
||||
constexpr float kSparsity = 0;
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, kSparsity).GenerateDMatrix();
|
||||
Context ctx{MakeCUDACtx(0)};
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto& page = *dmat->GetBatches<EllpackPage>(
|
||||
&ctx, BatchParam{kMaxBin, tree::TrainParam::DftSparseThreshold()})
|
||||
.begin();
|
||||
@ -94,7 +94,7 @@ TEST(EllpackPage, FromCategoricalBasic) {
|
||||
Context ctx{MakeCUDACtx(0)};
|
||||
auto p = BatchParam{max_bins, tree::TrainParam::DftSparseThreshold()};
|
||||
auto ellpack = EllpackPage(&ctx, m.get(), p);
|
||||
auto accessor = ellpack.Impl()->GetDeviceAccessor(ctx.Device());
|
||||
auto accessor = ellpack.Impl()->GetDeviceAccessor(&ctx);
|
||||
ASSERT_EQ(kCats, accessor.NumBins());
|
||||
|
||||
auto x_copy = x;
|
||||
@ -167,11 +167,11 @@ TEST(EllpackPage, Copy) {
|
||||
EXPECT_EQ(impl->base_rowid, current_row);
|
||||
|
||||
for (size_t i = 0; i < impl->Size(); i++) {
|
||||
dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(ctx.Device()), current_row,
|
||||
row_d.data().get()));
|
||||
dh::LaunchN(kCols,
|
||||
ReadRowFunction(impl->GetDeviceAccessor(&ctx), current_row, row_d.data().get()));
|
||||
thrust::copy(row_d.begin(), row_d.end(), row.begin());
|
||||
|
||||
dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(ctx.Device()), current_row,
|
||||
dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(&ctx), current_row,
|
||||
row_result_d.data().get()));
|
||||
thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin());
|
||||
|
||||
@ -223,12 +223,12 @@ TEST(EllpackPage, Compact) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(ctx.Device()), current_row,
|
||||
row_d.data().get()));
|
||||
dh::LaunchN(kCols,
|
||||
ReadRowFunction(impl->GetDeviceAccessor(&ctx), current_row, row_d.data().get()));
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
thrust::copy(row_d.begin(), row_d.end(), row.begin());
|
||||
|
||||
dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(ctx.Device()), compacted_row,
|
||||
dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(&ctx), compacted_row,
|
||||
row_result_d.data().get()));
|
||||
thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin());
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ void TestEquivalent(float sparsity) {
|
||||
size_t num_elements = page_concatenated->Copy(&ctx, page, offset);
|
||||
offset += num_elements;
|
||||
}
|
||||
auto from_iter = page_concatenated->GetDeviceAccessor(ctx.Device());
|
||||
auto from_iter = page_concatenated->GetDeviceAccessor(&ctx);
|
||||
ASSERT_EQ(m.Info().num_col_, CudaArrayIterForTest::Cols());
|
||||
ASSERT_EQ(m.Info().num_row_, CudaArrayIterForTest::Rows());
|
||||
|
||||
@ -37,7 +37,7 @@ void TestEquivalent(float sparsity) {
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 0)};
|
||||
auto bp = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||
for (auto& ellpack : dm->GetBatches<EllpackPage>(&ctx, bp)) {
|
||||
auto from_data = ellpack.Impl()->GetDeviceAccessor(ctx.Device());
|
||||
auto from_data = ellpack.Impl()->GetDeviceAccessor(&ctx);
|
||||
|
||||
std::vector<float> cuts_from_iter(from_iter.gidx_fvalue_map.size());
|
||||
std::vector<float> min_fvalues_iter(from_iter.min_fvalue.size());
|
||||
@ -71,7 +71,7 @@ void TestEquivalent(float sparsity) {
|
||||
auto data_buf = ellpack.Impl()->GetHostAccessor(&ctx, &buffer_from_data);
|
||||
ASSERT_NE(buffer_from_data.size(), 0);
|
||||
ASSERT_NE(buffer_from_iter.size(), 0);
|
||||
CHECK_EQ(from_data.NumSymbols(), from_iter.NumSymbols());
|
||||
CHECK_EQ(ellpack.Impl()->NumSymbols(), page_concatenated->NumSymbols());
|
||||
CHECK_EQ(from_data.n_rows * from_data.row_stride, from_data.n_rows * from_iter.row_stride);
|
||||
for (size_t i = 0; i < from_data.n_rows * from_data.row_stride; ++i) {
|
||||
CHECK_EQ(data_buf.gidx_iter[i], data_iter.gidx_iter[i]);
|
||||
@ -146,10 +146,10 @@ TEST(IterativeDeviceDMatrix, RowMajorMissing) {
|
||||
auto impl = ellpack.Impl();
|
||||
std::vector<common::CompressedByteT> h_gidx;
|
||||
auto h_accessor = impl->GetHostAccessor(&ctx, &h_gidx);
|
||||
EXPECT_EQ(h_accessor.gidx_iter[1], impl->GetDeviceAccessor(ctx.Device()).NullValue());
|
||||
EXPECT_EQ(h_accessor.gidx_iter[5], impl->GetDeviceAccessor(ctx.Device()).NullValue());
|
||||
EXPECT_EQ(h_accessor.gidx_iter[1], impl->GetDeviceAccessor(&ctx).NullValue());
|
||||
EXPECT_EQ(h_accessor.gidx_iter[5], impl->GetDeviceAccessor(&ctx).NullValue());
|
||||
// null values get placed after valid values in a row
|
||||
EXPECT_EQ(h_accessor.gidx_iter[7], impl->GetDeviceAccessor(ctx.Device()).NullValue());
|
||||
EXPECT_EQ(h_accessor.gidx_iter[7], impl->GetDeviceAccessor(&ctx).NullValue());
|
||||
EXPECT_EQ(m.Info().num_col_, cols);
|
||||
EXPECT_EQ(m.Info().num_row_, rows);
|
||||
EXPECT_EQ(m.Info().num_nonzero_, rows* cols - 3);
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
namespace xgboost {
|
||||
|
||||
TEST(SparsePageDMatrix, EllpackPage) {
|
||||
Context ctx{MakeCUDACtx(0)};
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
@ -301,11 +301,11 @@ TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
|
||||
EXPECT_EQ(impl_ext->base_rowid, current_row);
|
||||
|
||||
for (size_t i = 0; i < impl_ext->Size(); i++) {
|
||||
dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(ctx.Device()), current_row,
|
||||
row_d.data().get()));
|
||||
dh::LaunchN(kCols,
|
||||
ReadRowFunction(impl->GetDeviceAccessor(&ctx), current_row, row_d.data().get()));
|
||||
thrust::copy(row_d.begin(), row_d.end(), row.begin());
|
||||
|
||||
dh::LaunchN(kCols, ReadRowFunction(impl_ext->GetDeviceAccessor(ctx.Device()), current_row,
|
||||
dh::LaunchN(kCols, ReadRowFunction(impl_ext->GetDeviceAccessor(&ctx), current_row,
|
||||
row_ext_d.data().get()));
|
||||
thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin());
|
||||
|
||||
|
||||
@ -136,7 +136,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
feature_groups.DeviceAccessor(ctx.Device()), page->Cuts().TotalBins(),
|
||||
!use_shared_memory_histograms);
|
||||
builder.AllocateHistograms(&ctx, {0});
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
|
||||
row_partitioner->GetRows(0), builder.GetNodeHistogram(0), *quantiser);
|
||||
|
||||
@ -189,7 +189,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
|
||||
d_histogram, quantiser);
|
||||
|
||||
@ -205,7 +205,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
|
||||
d_new_histogram, quantiser);
|
||||
|
||||
@ -230,7 +230,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
single_group.DeviceAccessor(ctx.Device()), num_bins, force_global);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
|
||||
dh::ToSpan(baseline), quantiser);
|
||||
|
||||
@ -298,7 +298,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
single_group.DeviceAccessor(ctx.Device()), num_categories, false);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
|
||||
dh::ToSpan(cat_hist), quantiser);
|
||||
}
|
||||
@ -315,7 +315,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
single_group.DeviceAccessor(ctx.Device()), encode_hist.size(), false);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
|
||||
dh::ToSpan(encode_hist), quantiser);
|
||||
}
|
||||
@ -449,7 +449,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
|
||||
auto impl = page.Impl();
|
||||
if (k == 0) {
|
||||
// Initialization
|
||||
auto d_matrix = impl->GetDeviceAccessor(ctx.Device());
|
||||
auto d_matrix = impl->GetDeviceAccessor(&ctx);
|
||||
fg = std::make_unique<FeatureGroups>(impl->Cuts());
|
||||
auto init = GradientPairInt64{0, 0};
|
||||
multi_hist = decltype(multi_hist)(impl->Cuts().TotalBins(), init);
|
||||
@ -465,7 +465,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
fg->DeviceAccessor(ctx.Device()), d_histogram.size(), force_global);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(&ctx),
|
||||
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
|
||||
d_histogram, quantiser);
|
||||
++k;
|
||||
@ -491,7 +491,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), fg->DeviceAccessor(ctx.Device()),
|
||||
d_histogram.size(), force_global);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(&ctx),
|
||||
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
|
||||
d_histogram, quantiser);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user