Remove column major specialization. (#5755)
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -154,8 +154,8 @@ struct WriteCompressedEllpackFunctor {
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterBatchT>
|
||||
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
||||
int device_idx, float missing) {
|
||||
void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
||||
int device_idx, float missing) {
|
||||
// Some witchcraft happens here
|
||||
// The goal is to copy valid elements out of the input to an ellpack matrix
|
||||
// with a given row stride, using no extra working memory Standard stream
|
||||
@@ -209,51 +209,6 @@ void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
||||
});
|
||||
}
|
||||
|
||||
template <typename AdapterT, typename AdapterBatchT>
|
||||
void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
|
||||
EllpackPageImpl* dst, float missing) {
|
||||
// Step 1: Get the sizes of the input columns
|
||||
dh::caching_device_vector<size_t> column_sizes(adapter->NumColumns(), 0);
|
||||
auto d_column_sizes = column_sizes.data().get();
|
||||
// Populate column sizes
|
||||
dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&d_column_sizes[e.column_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
});
|
||||
|
||||
thrust::host_vector<size_t> host_column_sizes = column_sizes;
|
||||
|
||||
// Step 2: Iterate over columns, place elements in correct row, increment
|
||||
// temporary row pointers
|
||||
dh::caching_device_vector<size_t> temp_row_ptr(adapter->NumRows(), 0);
|
||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
||||
auto row_stride = dst->row_stride;
|
||||
size_t begin = 0;
|
||||
auto device_accessor = dst->GetDeviceAccessor(adapter->DeviceIdx());
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
data::IsValidFunctor is_valid(missing);
|
||||
for (auto size : host_column_sizes) {
|
||||
size_t end = begin + size;
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
auto writer_non_const =
|
||||
writer; // For some reason this variable gets captured as const
|
||||
const auto& e = batch.GetElement(idx + begin);
|
||||
if (!is_valid(e)) return;
|
||||
size_t output_position =
|
||||
e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
|
||||
auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
|
||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
|
||||
output_position);
|
||||
d_temp_row_ptr[e.row_idx] += 1;
|
||||
});
|
||||
|
||||
begin = end;
|
||||
}
|
||||
}
|
||||
|
||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||
common::Span<size_t> row_counts) {
|
||||
// Write the null values
|
||||
@@ -284,12 +239,7 @@ EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense
|
||||
|
||||
*this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride,
|
||||
adapter->NumRows());
|
||||
if (adapter->IsRowMajor()) {
|
||||
CopyDataRowMajor(batch, this, adapter->DeviceIdx(), missing);
|
||||
} else {
|
||||
CopyDataColumnMajor(adapter, batch, this, missing);
|
||||
}
|
||||
|
||||
CopyDataToEllpack(batch, this, adapter->DeviceIdx(), missing);
|
||||
WriteNullValues(this, adapter->DeviceIdx(), row_counts_span);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user