Check cub errors. (#10721)
- Make sure cuda error returned by cub scan is caught. - Avoid temporary buffer allocation in thrust device vector.
This commit is contained in:
@@ -49,9 +49,9 @@ void TestUpdatePositionBatch() {
|
||||
TEST(RowPartitioner, Batch) { TestUpdatePositionBatch(); }
|
||||
|
||||
void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Segment>& segments) {
|
||||
thrust::device_vector<uint32_t> ridx = ridx_in;
|
||||
thrust::device_vector<uint32_t> ridx_tmp(ridx_in.size());
|
||||
thrust::device_vector<bst_uint> counts(segments.size());
|
||||
thrust::device_vector<cuda_impl::RowIndexT> ridx = ridx_in;
|
||||
thrust::device_vector<cuda_impl::RowIndexT> ridx_tmp(ridx_in.size());
|
||||
thrust::device_vector<cuda_impl::RowIndexT> counts(segments.size());
|
||||
|
||||
auto op = [=] __device__(auto ridx, int split_index, int data) { return ridx % 2 == 0; };
|
||||
std::vector<int> op_data(segments.size());
|
||||
@@ -66,7 +66,7 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
|
||||
h_batch_info.size() * sizeof(PerNodeData<int>), cudaMemcpyDefault,
|
||||
nullptr));
|
||||
dh::device_vector<int8_t> tmp;
|
||||
dh::DeviceUVector<int8_t> tmp;
|
||||
SortPositionBatch<decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
|
||||
dh::ToSpan(ridx_tmp), dh::ToSpan(counts), total_rows, op,
|
||||
&tmp);
|
||||
@@ -91,5 +91,4 @@ TEST(GpuHist, SortPositionBatch) {
|
||||
TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{0, 6}});
|
||||
TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{3, 6}, {0, 2}});
|
||||
}
|
||||
|
||||
} // namespace xgboost::tree
|
||||
|
||||
Reference in New Issue
Block a user