Remove column major specialization. (#5755)
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -35,51 +35,12 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
thrust::device_pointer_cast(offset.data()));
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
int device_idx, float missing,
|
||||
common::Span<size_t> row_ptr) {
|
||||
// Step 1: Get the sizes of the input columns
|
||||
dh::device_vector<size_t> column_sizes(adapter->NumColumns());
|
||||
auto d_column_sizes = column_sizes.data().get();
|
||||
auto& batch = adapter->Value();
|
||||
// Populate column sizes
|
||||
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&d_column_sizes[e.column_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
});
|
||||
|
||||
thrust::host_vector<size_t> host_column_sizes = column_sizes;
|
||||
|
||||
// Step 2: Iterate over columns, place elements in correct row, increment
|
||||
// temporary row pointers
|
||||
dh::device_vector<size_t> temp_row_ptr(
|
||||
thrust::device_pointer_cast(row_ptr.data()),
|
||||
thrust::device_pointer_cast(row_ptr.data() + row_ptr.size()));
|
||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
||||
size_t begin = 0;
|
||||
IsValidFunctor is_valid(missing);
|
||||
for (auto size : host_column_sizes) {
|
||||
size_t end = begin + size;
|
||||
dh::LaunchN(device_idx, end - begin, [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx + begin);
|
||||
if (!is_valid(e)) return;
|
||||
data[d_temp_row_ptr[e.row_idx]] = Entry(e.column_idx, e.value);
|
||||
d_temp_row_ptr[e.row_idx] += 1;
|
||||
});
|
||||
|
||||
begin = end;
|
||||
}
|
||||
}
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterT>
|
||||
void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
int device_idx, float missing,
|
||||
common::Span<size_t> row_ptr) {
|
||||
void CopyDataToDMatrix(AdapterT* adapter, common::Span<Entry> data,
|
||||
int device_idx, float missing,
|
||||
common::Span<size_t> row_ptr) {
|
||||
auto& batch = adapter->Value();
|
||||
auto transform_f = [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
@@ -116,13 +77,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing);
|
||||
info_.num_nonzero_ = sparse_page_.offset.HostVector().back();
|
||||
sparse_page_.data.Resize(info_.num_nonzero_);
|
||||
if (adapter->IsRowMajor()) {
|
||||
CopyDataRowMajor(adapter, sparse_page_.data.DeviceSpan(),
|
||||
adapter->DeviceIdx(), missing, s_offset);
|
||||
} else {
|
||||
CopyDataColumnMajor(adapter, sparse_page_.data.DeviceSpan(),
|
||||
adapter->DeviceIdx(), missing, s_offset);
|
||||
}
|
||||
CopyDataToDMatrix(adapter, sparse_page_.data.DeviceSpan(),
|
||||
adapter->DeviceIdx(), missing, s_offset);
|
||||
|
||||
info_.num_col_ = adapter->NumColumns();
|
||||
info_.num_row_ = adapter->NumRows();
|
||||
|
||||
Reference in New Issue
Block a user