enable rocm, fix hist_util.cuh

This commit is contained in:
amdsc21 2023-03-08 06:36:15 +01:00
parent d3be67ad8e
commit ba9e00d911

View File

@ -76,11 +76,20 @@ void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feat
column_sizes_scan->begin(), [=] __device__(size_t column_size) {
return thrust::min(num_cuts_per_feature, column_size);
});
#if defined(XGBOOST_USE_HIP)
thrust::exclusive_scan(thrust::hip::par(alloc), cut_ptr_it,
cut_ptr_it + column_sizes_scan->size(),
cuts_ptr->DevicePointer());
thrust::exclusive_scan(thrust::hip::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
#else
thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
cut_ptr_it + column_sizes_scan->size(),
cuts_ptr->DevicePointer());
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
#endif
}
inline size_t constexpr BytesPerElement(bool has_weight) {
@ -179,8 +188,14 @@ void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
&column_sizes_scan,
&sorted_entries);
dh::XGBDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_HIP)
thrust::sort(thrust::hip::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
#else
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
#endif
if (sketch_container->HasCategorical()) {
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
@ -205,7 +220,13 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
size_t columns, size_t begin, size_t end,
SketchContainer *sketch_container) {
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device));
#else
dh::safe_cuda(cudaSetDevice(device));
#endif
info.weights_.SetDevice(device);
auto weights = info.weights_.ConstDeviceSpan();
@ -238,11 +259,21 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
return weights[group_idx];
});
#if defined(XGBOOST_USE_HIP)
auto retit = thrust::copy_if(thrust::hip::par(alloc),
weight_iter + begin, weight_iter + end,
batch_iter + begin,
d_temp_weights.data(), // output
is_valid);
#else
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
weight_iter + begin, weight_iter + end,
batch_iter + begin,
d_temp_weights.data(), // output
is_valid);
#endif
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
} else {
CHECK_EQ(batch.NumRows(), weights.size());
@ -251,11 +282,21 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
[=]__device__(size_t idx) -> float {
return weights[batch.GetElement(idx).row_idx];
});
#if defined(XGBOOST_USE_HIP)
auto retit = thrust::copy_if(thrust::hip::par(alloc),
weight_iter + begin, weight_iter + end,
batch_iter + begin,
d_temp_weights.data(), // output
is_valid);
#else
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
weight_iter + begin, weight_iter + end,
batch_iter + begin,
d_temp_weights.data(), // output
is_valid);
#endif
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
}