enable rocm, fix hist_util.cuh
This commit is contained in:
parent
d3be67ad8e
commit
ba9e00d911
@ -76,11 +76,20 @@ void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feat
|
||||
column_sizes_scan->begin(), [=] __device__(size_t column_size) {
|
||||
return thrust::min(num_cuts_per_feature, column_size);
|
||||
});
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
thrust::exclusive_scan(thrust::hip::par(alloc), cut_ptr_it,
|
||||
cut_ptr_it + column_sizes_scan->size(),
|
||||
cuts_ptr->DevicePointer());
|
||||
thrust::exclusive_scan(thrust::hip::par(alloc), column_sizes_scan->begin(),
|
||||
column_sizes_scan->end(), column_sizes_scan->begin());
|
||||
#else
|
||||
thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
|
||||
cut_ptr_it + column_sizes_scan->size(),
|
||||
cuts_ptr->DevicePointer());
|
||||
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
|
||||
column_sizes_scan->end(), column_sizes_scan->begin());
|
||||
#endif
|
||||
}
|
||||
|
||||
inline size_t constexpr BytesPerElement(bool has_weight) {
|
||||
@ -179,8 +188,14 @@ void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
|
||||
&column_sizes_scan,
|
||||
&sorted_entries);
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
thrust::sort(thrust::hip::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), detail::EntryCompareOp());
|
||||
#else
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), detail::EntryCompareOp());
|
||||
#endif
|
||||
|
||||
if (sketch_container->HasCategorical()) {
|
||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||
@ -205,7 +220,13 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
size_t columns, size_t begin, size_t end,
|
||||
SketchContainer *sketch_container) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
#endif
|
||||
|
||||
info.weights_.SetDevice(device);
|
||||
auto weights = info.weights_.ConstDeviceSpan();
|
||||
|
||||
@ -238,11 +259,21 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
|
||||
return weights[group_idx];
|
||||
});
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
auto retit = thrust::copy_if(thrust::hip::par(alloc),
|
||||
weight_iter + begin, weight_iter + end,
|
||||
batch_iter + begin,
|
||||
d_temp_weights.data(), // output
|
||||
is_valid);
|
||||
#else
|
||||
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
|
||||
weight_iter + begin, weight_iter + end,
|
||||
batch_iter + begin,
|
||||
d_temp_weights.data(), // output
|
||||
is_valid);
|
||||
#endif
|
||||
|
||||
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
|
||||
} else {
|
||||
CHECK_EQ(batch.NumRows(), weights.size());
|
||||
@ -251,11 +282,21 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
[=]__device__(size_t idx) -> float {
|
||||
return weights[batch.GetElement(idx).row_idx];
|
||||
});
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
auto retit = thrust::copy_if(thrust::hip::par(alloc),
|
||||
weight_iter + begin, weight_iter + end,
|
||||
batch_iter + begin,
|
||||
d_temp_weights.data(), // output
|
||||
is_valid);
|
||||
#else
|
||||
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
|
||||
weight_iter + begin, weight_iter + end,
|
||||
batch_iter + begin,
|
||||
d_temp_weights.data(), // output
|
||||
is_valid);
|
||||
#endif
|
||||
|
||||
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user