Check cub errors. (#10721)

- Make sure the CUDA error returned by the CUB scan is caught.
- Avoid temporary buffer allocation in the Thrust device vector.
This commit is contained in:
Jiaming Yuan
2024-08-21 02:50:26 +08:00
committed by GitHub
parent b949a4bf7b
commit 508ac13243
5 changed files with 27 additions and 21 deletions

View File

@@ -49,9 +49,9 @@ void TestUpdatePositionBatch() {
TEST(RowPartitioner, Batch) { TestUpdatePositionBatch(); }
void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Segment>& segments) {
-thrust::device_vector<uint32_t> ridx = ridx_in;
-thrust::device_vector<uint32_t> ridx_tmp(ridx_in.size());
-thrust::device_vector<bst_uint> counts(segments.size());
+thrust::device_vector<cuda_impl::RowIndexT> ridx = ridx_in;
+thrust::device_vector<cuda_impl::RowIndexT> ridx_tmp(ridx_in.size());
+thrust::device_vector<cuda_impl::RowIndexT> counts(segments.size());
auto op = [=] __device__(auto ridx, int split_index, int data) { return ridx % 2 == 0; };
std::vector<int> op_data(segments.size());
@@ -66,7 +66,7 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
h_batch_info.size() * sizeof(PerNodeData<int>), cudaMemcpyDefault,
nullptr));
-dh::device_vector<int8_t> tmp;
+dh::DeviceUVector<int8_t> tmp;
SortPositionBatch<decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
dh::ToSpan(ridx_tmp), dh::ToSpan(counts), total_rows, op,
&tmp);
@@ -91,5 +91,4 @@ TEST(GpuHist, SortPositionBatch) {
TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{0, 6}});
TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{3, 6}, {0, 2}});
}
} // namespace xgboost::tree