diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 856404107..30c262190 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -76,11 +76,20 @@ void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feat column_sizes_scan->begin(), [=] __device__(size_t column_size) { return thrust::min(num_cuts_per_feature, column_size); }); + +#if defined(XGBOOST_USE_HIP) + thrust::exclusive_scan(thrust::hip::par(alloc), cut_ptr_it, + cut_ptr_it + column_sizes_scan->size(), + cuts_ptr->DevicePointer()); + thrust::exclusive_scan(thrust::hip::par(alloc), column_sizes_scan->begin(), + column_sizes_scan->end(), column_sizes_scan->begin()); +#else thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it, cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer()); thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), column_sizes_scan->end(), column_sizes_scan->begin()); +#endif } inline size_t constexpr BytesPerElement(bool has_weight) { @@ -179,8 +188,14 @@ void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info, &column_sizes_scan, &sorted_entries); dh::XGBDeviceAllocator alloc; + +#if defined(XGBOOST_USE_HIP) + thrust::sort(thrust::hip::par(alloc), sorted_entries.begin(), + sorted_entries.end(), detail::EntryCompareOp()); +#else thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp()); +#endif if (sketch_container->HasCategorical()) { auto d_cuts_ptr = cuts_ptr.DeviceSpan(); @@ -205,7 +220,13 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, size_t columns, size_t begin, size_t end, SketchContainer *sketch_container) { dh::XGBCachingDeviceAllocator alloc; + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device)); +#else dh::safe_cuda(cudaSetDevice(device)); +#endif + info.weights_.SetDevice(device); auto weights = info.weights_.ConstDeviceSpan(); @@ -238,11 +259,21 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx); return weights[group_idx]; }); + +#if defined(XGBOOST_USE_HIP) + auto retit = thrust::copy_if(thrust::hip::par(alloc), + weight_iter + begin, weight_iter + end, + batch_iter + begin, + d_temp_weights.data(), // output + is_valid); +#else auto retit = thrust::copy_if(thrust::cuda::par(alloc), weight_iter + begin, weight_iter + end, batch_iter + begin, d_temp_weights.data(), // output is_valid); +#endif + CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size()); } else { CHECK_EQ(batch.NumRows(), weights.size()); @@ -251,11 +282,21 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, [=]__device__(size_t idx) -> float { return weights[batch.GetElement(idx).row_idx]; }); + +#if defined(XGBOOST_USE_HIP) + auto retit = thrust::copy_if(thrust::hip::par(alloc), + weight_iter + begin, weight_iter + end, + batch_iter + begin, + d_temp_weights.data(), // output + is_valid); +#else auto retit = thrust::copy_if(thrust::cuda::par(alloc), weight_iter + begin, weight_iter + end, batch_iter + begin, d_temp_weights.data(), // output is_valid); +#endif + CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size()); }