Loop over copy_if (#6201)
* Loop over copy_if * Catch OOM. Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
@@ -35,24 +35,38 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
thrust::device_pointer_cast(offset.data()));
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
struct COOToEntryOp {
|
||||
AdapterBatchT batch;
|
||||
__device__ Entry operator()(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
return Entry(e.column_idx, e.value);
|
||||
}
|
||||
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterT>
|
||||
void CopyDataToDMatrix(AdapterT* adapter, common::Span<Entry> data,
|
||||
int device_idx, float missing,
|
||||
common::Span<size_t> row_ptr) {
|
||||
auto& batch = adapter->Value();
|
||||
auto transform_f = [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
return Entry(e.column_idx, e.value);
|
||||
}; // NOLINT
|
||||
float missing) {
|
||||
auto batch = adapter->Value();
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
thrust::transform_iterator<decltype(transform_f), decltype(counting), Entry>
|
||||
transform_iter(counting, transform_f);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::copy_if(
|
||||
thrust::cuda::par(alloc), transform_iter, transform_iter + batch.Size(),
|
||||
thrust::device_pointer_cast(data.data()), IsValidFunctor(missing));
|
||||
COOToEntryOp<decltype(batch)> transform_op{batch};
|
||||
thrust::transform_iterator<decltype(transform_op), decltype(counting)>
|
||||
transform_iter(counting, transform_op);
|
||||
// We loop over batches because thrust::copy_if cant deal with sizes > 2^31
|
||||
// See thrust issue #1302
|
||||
size_t max_copy_size = std::numeric_limits<int>::max() / 2;
|
||||
auto begin_output = thrust::device_pointer_cast(data.data());
|
||||
for (size_t offset = 0; offset < batch.Size(); offset += max_copy_size) {
|
||||
auto begin_input = transform_iter + offset;
|
||||
auto end_input =
|
||||
transform_iter + std::min(offset + max_copy_size, batch.Size());
|
||||
begin_output =
|
||||
thrust::copy_if(thrust::cuda::par(alloc), begin_input, end_input,
|
||||
begin_output, IsValidFunctor(missing));
|
||||
}
|
||||
}
|
||||
|
||||
// Does not currently support metainfo as no on-device data source contains this
|
||||
@@ -77,8 +91,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing);
|
||||
info_.num_nonzero_ = sparse_page_.offset.HostVector().back();
|
||||
sparse_page_.data.Resize(info_.num_nonzero_);
|
||||
CopyDataToDMatrix(adapter, sparse_page_.data.DeviceSpan(),
|
||||
adapter->DeviceIdx(), missing, s_offset);
|
||||
CopyDataToDMatrix(adapter, sparse_page_.data.DeviceSpan(), missing);
|
||||
|
||||
info_.num_col_ = adapter->NumColumns();
|
||||
info_.num_row_ = adapter->NumRows();
|
||||
|
||||
Reference in New Issue
Block a user