Move device dmatrix construction code into ellpack. (#5623)
This commit is contained in:
parent
33e052b1e5
commit
67d267f9da
@ -351,6 +351,8 @@ class EllpackPage {
|
||||
/*! \brief Destructor. */
|
||||
~EllpackPage();
|
||||
|
||||
EllpackPage(EllpackPage&& that);
|
||||
|
||||
/*! \return Number of instances in the page. */
|
||||
size_t Size() const;
|
||||
|
||||
|
||||
@ -212,6 +212,27 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
|
||||
int device_idx_;
|
||||
};
|
||||
|
||||
// Returns maximum row length
|
||||
template <typename AdapterBatchT>
|
||||
size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
|
||||
int device_idx, float missing) {
|
||||
IsValidFunctor is_valid(missing);
|
||||
// Count elements per row
|
||||
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
|
||||
auto element = batch.GetElement(idx);
|
||||
if (is_valid(element)) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&offset[element.row_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
}
|
||||
});
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
size_t row_stride = thrust::reduce(
|
||||
thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data()) + offset.size(), size_t(0),
|
||||
thrust::maximum<size_t>());
|
||||
return row_stride;
|
||||
}
|
||||
}; // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_DEVICE_ADAPTER_H_
|
||||
|
||||
@ -19,181 +19,6 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
// Returns maximum row length
|
||||
template <typename AdapterBatchT>
|
||||
size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
|
||||
int device_idx, float missing) {
|
||||
IsValidFunctor is_valid(missing);
|
||||
// Count elements per row
|
||||
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
|
||||
auto element = batch.GetElement(idx);
|
||||
if (is_valid(element)) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&offset[element.row_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
}
|
||||
});
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
size_t row_stride = thrust::reduce(
|
||||
thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data()) + offset.size(), size_t(0),
|
||||
thrust::maximum<size_t>());
|
||||
return row_stride;
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
struct WriteCompressedEllpackFunctor {
|
||||
WriteCompressedEllpackFunctor(common::CompressedByteT* buffer,
|
||||
const common::CompressedBufferWriter& writer,
|
||||
AdapterBatchT batch,
|
||||
EllpackDeviceAccessor accessor,
|
||||
const IsValidFunctor& is_valid)
|
||||
: d_buffer(buffer),
|
||||
writer(writer),
|
||||
batch(std::move(batch)),
|
||||
accessor(std::move(accessor)),
|
||||
is_valid(is_valid) {}
|
||||
|
||||
common::CompressedByteT* d_buffer;
|
||||
common::CompressedBufferWriter writer;
|
||||
AdapterBatchT batch;
|
||||
EllpackDeviceAccessor accessor;
|
||||
IsValidFunctor is_valid;
|
||||
|
||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||
__device__ size_t operator()(Tuple out) {
|
||||
auto e = batch.GetElement(out.get<2>());
|
||||
if (is_valid(e)) {
|
||||
// -1 because the scan is inclusive
|
||||
size_t output_position =
|
||||
accessor.row_stride * e.row_idx + out.get<1>() - 1;
|
||||
auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
|
||||
writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterBatchT>
|
||||
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
||||
int device_idx, float missing) {
|
||||
// Some witchcraft happens here
|
||||
// The goal is to copy valid elements out of the input to an ellpack matrix
|
||||
// with a given row stride, using no extra working memory Standard stream
|
||||
// compaction needs to be modified to do this, so we manually define a
|
||||
// segmented stream compaction via operators on an inclusive scan. The output
|
||||
// of this inclusive scan is fed to a custom function which works out the
|
||||
// correct output position
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
IsValidFunctor is_valid(missing);
|
||||
auto key_iter = dh::MakeTransformIterator<size_t>(
|
||||
counting,
|
||||
[=] __device__(size_t idx) { return batch.GetElement(idx).row_idx; });
|
||||
auto value_iter = dh::MakeTransformIterator<size_t>(
|
||||
counting, [=] __device__(size_t idx) -> size_t {
|
||||
return is_valid(batch.GetElement(idx));
|
||||
});
|
||||
|
||||
auto key_value_index_iter = thrust::make_zip_iterator(
|
||||
thrust::make_tuple(key_iter, value_iter, counting));
|
||||
|
||||
// Tuple[0] = The row index of the input, used as a key to define segments
|
||||
// Tuple[1] = Scanned flags of valid elements for each row
|
||||
// Tuple[2] = The index in the input data
|
||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||
|
||||
auto device_accessor = dst->GetDeviceAccessor(device_idx);
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
|
||||
// We redirect the scan output into this functor to do the actual writing
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
|
||||
d_compressed_buffer, writer, batch, device_accessor, is_valid);
|
||||
thrust::discard_iterator<size_t> discard;
|
||||
thrust::transform_output_iterator<
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
|
||||
out(discard, functor);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
|
||||
key_value_index_iter + batch.Size(), out,
|
||||
[=] __device__(Tuple a, Tuple b) {
|
||||
// Key equal
|
||||
if (a.get<0>() == b.get<0>()) {
|
||||
b.get<1>() += a.get<1>();
|
||||
return b;
|
||||
}
|
||||
// Not equal
|
||||
return b;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename AdapterT, typename AdapterBatchT>
|
||||
void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
|
||||
EllpackPageImpl* dst, float missing) {
|
||||
// Step 1: Get the sizes of the input columns
|
||||
dh::caching_device_vector<size_t> column_sizes(adapter->NumColumns(), 0);
|
||||
auto d_column_sizes = column_sizes.data().get();
|
||||
// Populate column sizes
|
||||
dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&d_column_sizes[e.column_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
});
|
||||
|
||||
thrust::host_vector<size_t> host_column_sizes = column_sizes;
|
||||
|
||||
// Step 2: Iterate over columns, place elements in correct row, increment
|
||||
// temporary row pointers
|
||||
dh::caching_device_vector<size_t> temp_row_ptr(adapter->NumRows(), 0);
|
||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
||||
auto row_stride = dst->row_stride;
|
||||
size_t begin = 0;
|
||||
auto device_accessor = dst->GetDeviceAccessor(adapter->DeviceIdx());
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
IsValidFunctor is_valid(missing);
|
||||
for (auto size : host_column_sizes) {
|
||||
size_t end = begin + size;
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
auto writer_non_const =
|
||||
writer; // For some reason this variable gets captured as const
|
||||
const auto& e = batch.GetElement(idx + begin);
|
||||
if (!is_valid(e)) return;
|
||||
size_t output_position =
|
||||
e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
|
||||
auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
|
||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
|
||||
output_position);
|
||||
d_temp_row_ptr[e.row_idx] += 1;
|
||||
});
|
||||
|
||||
begin = end;
|
||||
}
|
||||
}
|
||||
|
||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||
common::Span<size_t> row_counts) {
|
||||
// Write the null values
|
||||
auto device_accessor = dst->GetDeviceAccessor(device_idx);
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
auto row_stride = dst->row_stride;
|
||||
dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx) {
|
||||
auto writer_non_const =
|
||||
writer; // For some reason this variable gets captured as const
|
||||
size_t row_idx = idx / row_stride;
|
||||
size_t row_offset = idx % row_stride;
|
||||
if (row_offset >= row_counts[row_idx]) {
|
||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer,
|
||||
device_accessor.NullValue(), idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Does not currently support metainfo as no on-device data source contains this
|
||||
// Current implementation assumes a single batch. More batches can
|
||||
// be supported in future. Does not currently support inferring row/column size
|
||||
@ -210,30 +35,25 @@ DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int
|
||||
size_t row_stride =
|
||||
GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);
|
||||
|
||||
ellpack_page_.reset(new EllpackPage());
|
||||
*ellpack_page_->Impl() =
|
||||
EllpackPageImpl(adapter, missing, this->IsDense(), nthread, max_bin,
|
||||
row_counts_span, row_stride);
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
info_.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
|
||||
row_counts.begin(), row_counts.end());
|
||||
info_.num_col_ = adapter->NumColumns();
|
||||
info_.num_row_ = adapter->NumRows();
|
||||
ellpack_page_.reset(new EllpackPage());
|
||||
*ellpack_page_->Impl() =
|
||||
EllpackPageImpl(adapter->DeviceIdx(), cuts, this->IsDense(), row_stride,
|
||||
adapter->NumRows());
|
||||
if (adapter->IsRowMajor()) {
|
||||
CopyDataRowMajor(batch, ellpack_page_->Impl(), adapter->DeviceIdx(),
|
||||
missing);
|
||||
} else {
|
||||
CopyDataColumnMajor(adapter, batch, ellpack_page_->Impl(), missing);
|
||||
}
|
||||
|
||||
WriteNullValues(ellpack_page_->Impl(), adapter->DeviceIdx(), row_counts_span);
|
||||
|
||||
// Synchronise worker columns
|
||||
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||
}
|
||||
template DeviceDMatrix::DeviceDMatrix(CudfAdapter* adapter, float missing,
|
||||
int nthread, int max_bin);
|
||||
template DeviceDMatrix::DeviceDMatrix(CupyAdapter* adapter, float missing,
|
||||
|
||||
#define DEVICE_DMARIX_SPECIALIZATION(__ADAPTER_T) \
|
||||
template DeviceDMatrix::DeviceDMatrix(__ADAPTER_T* adapter, float missing, \
|
||||
int nthread, int max_bin);
|
||||
|
||||
DEVICE_DMARIX_SPECIALIZATION(CudfAdapter);
|
||||
DEVICE_DMARIX_SPECIALIZATION(CupyAdapter);
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@ -3,10 +3,12 @@
|
||||
*/
|
||||
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/iterator/transform_output_iterator.h>
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/random.h"
|
||||
#include "./ellpack_page.cuh"
|
||||
#include "device_adapter.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -17,6 +19,8 @@ EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param)
|
||||
|
||||
EllpackPage::~EllpackPage() = default;
|
||||
|
||||
EllpackPage::EllpackPage(EllpackPage&& that) { std::swap(impl_, that.impl_); }
|
||||
|
||||
size_t EllpackPage::Size() const { return impl_->Size(); }
|
||||
|
||||
void EllpackPage::SetBaseRowId(size_t row_id) { impl_->SetBaseRowId(row_id); }
|
||||
@ -74,22 +78,19 @@ EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
|
||||
monitor_.StartCuda("InitCompressedData");
|
||||
InitCompressedData(device);
|
||||
this->InitCompressedData(device);
|
||||
monitor_.StopCuda("InitCompressedData");
|
||||
}
|
||||
|
||||
size_t GetRowStride(DMatrix* dmat) {
|
||||
if (dmat->IsDense()) return dmat->Info().num_col_;
|
||||
|
||||
size_t row_stride = 0;
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
const auto& row_offset = batch.offset.ConstHostVector();
|
||||
for (auto i = 1ull; i < row_offset.size(); i++) {
|
||||
row_stride = std::max(
|
||||
row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
|
||||
}
|
||||
}
|
||||
return row_stride;
|
||||
EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
|
||||
const SparsePage& page, bool is_dense,
|
||||
size_t row_stride)
|
||||
: cuts_(std::move(cuts)),
|
||||
is_dense(is_dense),
|
||||
n_rows(page.Size()),
|
||||
row_stride(row_stride) {
|
||||
this->InitCompressedData(device);
|
||||
this->CreateHistIndices(device, page);
|
||||
}
|
||||
|
||||
// Construct an ELLPACK matrix in memory.
|
||||
@ -117,6 +118,190 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param)
|
||||
monitor_.StopCuda("BinningCompression");
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
struct WriteCompressedEllpackFunctor {
|
||||
WriteCompressedEllpackFunctor(common::CompressedByteT* buffer,
|
||||
const common::CompressedBufferWriter& writer,
|
||||
AdapterBatchT batch,
|
||||
EllpackDeviceAccessor accessor,
|
||||
const data::IsValidFunctor& is_valid)
|
||||
: d_buffer(buffer),
|
||||
writer(writer),
|
||||
batch(std::move(batch)),
|
||||
accessor(std::move(accessor)),
|
||||
is_valid(is_valid) {}
|
||||
|
||||
common::CompressedByteT* d_buffer;
|
||||
common::CompressedBufferWriter writer;
|
||||
AdapterBatchT batch;
|
||||
EllpackDeviceAccessor accessor;
|
||||
data::IsValidFunctor is_valid;
|
||||
|
||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||
__device__ size_t operator()(Tuple out) {
|
||||
auto e = batch.GetElement(out.get<2>());
|
||||
if (is_valid(e)) {
|
||||
// -1 because the scan is inclusive
|
||||
size_t output_position =
|
||||
accessor.row_stride * e.row_idx + out.get<1>() - 1;
|
||||
auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
|
||||
writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterBatchT>
|
||||
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
||||
int device_idx, float missing) {
|
||||
// Some witchcraft happens here
|
||||
// The goal is to copy valid elements out of the input to an ellpack matrix
|
||||
// with a given row stride, using no extra working memory Standard stream
|
||||
// compaction needs to be modified to do this, so we manually define a
|
||||
// segmented stream compaction via operators on an inclusive scan. The output
|
||||
// of this inclusive scan is fed to a custom function which works out the
|
||||
// correct output position
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
data::IsValidFunctor is_valid(missing);
|
||||
auto key_iter = dh::MakeTransformIterator<size_t>(
|
||||
counting,
|
||||
[=] __device__(size_t idx) {
|
||||
return batch.GetElement(idx).row_idx;
|
||||
});
|
||||
auto value_iter = dh::MakeTransformIterator<size_t>(
|
||||
counting,
|
||||
[=] __device__(size_t idx) -> size_t {
|
||||
return is_valid(batch.GetElement(idx));
|
||||
});
|
||||
|
||||
auto key_value_index_iter = thrust::make_zip_iterator(
|
||||
thrust::make_tuple(key_iter, value_iter, counting));
|
||||
|
||||
// Tuple[0] = The row index of the input, used as a key to define segments
|
||||
// Tuple[1] = Scanned flags of valid elements for each row
|
||||
// Tuple[2] = The index in the input data
|
||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||
|
||||
auto device_accessor = dst->GetDeviceAccessor(device_idx);
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
|
||||
// We redirect the scan output into this functor to do the actual writing
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
|
||||
d_compressed_buffer, writer, batch, device_accessor, is_valid);
|
||||
thrust::discard_iterator<size_t> discard;
|
||||
thrust::transform_output_iterator<
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
|
||||
out(discard, functor);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
|
||||
key_value_index_iter + batch.Size(), out,
|
||||
[=] __device__(Tuple a, Tuple b) {
|
||||
// Key equal
|
||||
if (a.get<0>() == b.get<0>()) {
|
||||
b.get<1>() += a.get<1>();
|
||||
return b;
|
||||
}
|
||||
// Not equal
|
||||
return b;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename AdapterT, typename AdapterBatchT>
|
||||
void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
|
||||
EllpackPageImpl* dst, float missing) {
|
||||
// Step 1: Get the sizes of the input columns
|
||||
dh::caching_device_vector<size_t> column_sizes(adapter->NumColumns(), 0);
|
||||
auto d_column_sizes = column_sizes.data().get();
|
||||
// Populate column sizes
|
||||
dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&d_column_sizes[e.column_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
});
|
||||
|
||||
thrust::host_vector<size_t> host_column_sizes = column_sizes;
|
||||
|
||||
// Step 2: Iterate over columns, place elements in correct row, increment
|
||||
// temporary row pointers
|
||||
dh::caching_device_vector<size_t> temp_row_ptr(adapter->NumRows(), 0);
|
||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
||||
auto row_stride = dst->row_stride;
|
||||
size_t begin = 0;
|
||||
auto device_accessor = dst->GetDeviceAccessor(adapter->DeviceIdx());
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
data::IsValidFunctor is_valid(missing);
|
||||
for (auto size : host_column_sizes) {
|
||||
size_t end = begin + size;
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
auto writer_non_const =
|
||||
writer; // For some reason this variable gets captured as const
|
||||
const auto& e = batch.GetElement(idx + begin);
|
||||
if (!is_valid(e)) return;
|
||||
size_t output_position =
|
||||
e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
|
||||
auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
|
||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
|
||||
output_position);
|
||||
d_temp_row_ptr[e.row_idx] += 1;
|
||||
});
|
||||
|
||||
begin = end;
|
||||
}
|
||||
}
|
||||
|
||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||
common::Span<size_t> row_counts) {
|
||||
// Write the null values
|
||||
auto device_accessor = dst->GetDeviceAccessor(device_idx);
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
auto row_stride = dst->row_stride;
|
||||
dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx) {
|
||||
auto writer_non_const =
|
||||
writer; // For some reason this variable gets captured as const
|
||||
size_t row_idx = idx / row_stride;
|
||||
size_t row_offset = idx % row_stride;
|
||||
if (row_offset >= row_counts[row_idx]) {
|
||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer,
|
||||
device_accessor.NullValue(), idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
|
||||
int max_bin, common::Span<size_t> row_counts_span,
|
||||
size_t row_stride) {
|
||||
common::HistogramCuts cuts =
|
||||
common::AdapterDeviceSketch(adapter, max_bin, missing);
|
||||
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
|
||||
auto& batch = adapter->Value();
|
||||
|
||||
*this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride,
|
||||
adapter->NumRows());
|
||||
if (adapter->IsRowMajor()) {
|
||||
CopyDataRowMajor(batch, this, adapter->DeviceIdx(), missing);
|
||||
} else {
|
||||
CopyDataColumnMajor(adapter, batch, this, missing);
|
||||
}
|
||||
|
||||
WriteNullValues(this, adapter->DeviceIdx(), row_counts_span);
|
||||
}
|
||||
|
||||
#define ELLPACK_SPECIALIZATION(__ADAPTER_T) \
|
||||
template EllpackPageImpl::EllpackPageImpl( \
|
||||
__ADAPTER_T* adapter, float missing, bool is_dense, int nthread, int max_bin, \
|
||||
common::Span<size_t> row_counts_span, \
|
||||
size_t row_stride);
|
||||
|
||||
ELLPACK_SPECIALIZATION(data::CudfAdapter)
|
||||
ELLPACK_SPECIALIZATION(data::CupyAdapter)
|
||||
|
||||
// A functor that copies the data from one EllpackPage to another.
|
||||
struct CopyPage {
|
||||
common::CompressedBufferWriter cbw;
|
||||
@ -295,15 +480,4 @@ EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(int device) const {
|
||||
common::CompressedIterator<uint32_t>(gidx_buffer.ConstDevicePointer(),
|
||||
NumSymbols()));
|
||||
}
|
||||
|
||||
EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
|
||||
const SparsePage& page, bool is_dense,
|
||||
size_t row_stride)
|
||||
: cuts_(std::move(cuts)),
|
||||
is_dense(is_dense),
|
||||
n_rows(page.Size()),
|
||||
row_stride(row_stride) {
|
||||
this->InitCompressedData(device);
|
||||
this->CreateHistIndices(device, page);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -159,6 +159,10 @@ class EllpackPageImpl {
|
||||
*/
|
||||
explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);
|
||||
|
||||
template <typename AdapterT>
|
||||
explicit EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
|
||||
int max_bin, common::Span<size_t> row_counts_span,
|
||||
size_t row_stride);
|
||||
/*! \brief Copy the elements of the given ELLPACK page into this page.
|
||||
*
|
||||
* @param device The GPU device to use.
|
||||
@ -229,6 +233,19 @@ public:
|
||||
common::Monitor monitor_;
|
||||
};
|
||||
|
||||
inline size_t GetRowStride(DMatrix* dmat) {
|
||||
if (dmat->IsDense()) return dmat->Info().num_col_;
|
||||
|
||||
size_t row_stride = 0;
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
const auto& row_offset = batch.offset.ConstHostVector();
|
||||
for (auto i = 1ull; i < row_offset.size(); i++) {
|
||||
row_stride = std::max(
|
||||
row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
|
||||
}
|
||||
}
|
||||
return row_stride;
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_DATA_ELLPACK_PAGE_H_
|
||||
|
||||
@ -13,20 +13,6 @@
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
size_t GetRowStride(DMatrix* dmat) {
|
||||
if (dmat->IsDense()) return dmat->Info().num_col_;
|
||||
|
||||
size_t row_stride = 0;
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
const auto& row_offset = batch.offset.ConstHostVector();
|
||||
for (auto i = 1ull; i < row_offset.size(); i++) {
|
||||
row_stride = std::max(
|
||||
row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
|
||||
}
|
||||
}
|
||||
return row_stride;
|
||||
}
|
||||
|
||||
// Build the quantile sketch across the whole input data, then use the histogram cuts to compress
|
||||
// each CSR page, and write the accumulated ELLPACK pages to disk.
|
||||
EllpackPageSource::EllpackPageSource(DMatrix* dmat,
|
||||
|
||||
@ -128,7 +128,8 @@ inline void CheckCacheFileExists(const std::string& file) {
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Given a set of cache files and page type, this object iterates over batches using prefetching for improved performance. Not thread safe.
|
||||
* \brief Given a set of cache files and page type, this object iterates over batches
|
||||
* using prefetching for improved performance. Not thread safe.
|
||||
*
|
||||
* \tparam PageT Type of the page t.
|
||||
*/
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user