This commit is contained in:
Hui Liu 2023-10-23 22:29:48 -07:00
parent 558352afc9
commit 79319dfd4d
8 changed files with 115 additions and 100 deletions

View File

@ -36,21 +36,21 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
private:
/**
 * Length of a device UUID expressed in uint64_t words. Computed as the byte
 * size of the platform's UUID type (the `uuid` member of cudaDeviceProp under
 * XGBOOST_USE_CUDA, hipUUID under XGBOOST_USE_HIP) divided by sizeof(uint64_t).
 * Used as the static extent of the span accepted by GetCudaUUID.
 */
static constexpr std::size_t kUuidLength =
#if defined(XGBOOST_USE_HIP)
sizeof(hipUUID) / sizeof(uint64_t);
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
#elif defined(XGBOOST_USE_HIP)
sizeof(hipUUID) / sizeof(uint64_t);
#endif
/**
 * Copy the UUID of the device selected by device_ordinal_ into `uuid`.
 *
 * Under XGBOOST_USE_CUDA the UUID is read from the cudaDeviceProp filled in by
 * cudaGetDeviceProperties; under XGBOOST_USE_HIP it is obtained from
 * hipDeviceGetUuid. In both cases the raw bytes are memcpy'd into the span,
 * whose extent (kUuidLength) is derived from the size of the same UUID type.
 *
 * NOTE(review): the CUDA query is checked through dh::safe_cuda, but the
 * return value of hipDeviceGetUuid is discarded, so a failed HIP query would
 * copy an uninitialized hipUUID. Confirm whether the HIP call should be
 * checked the same way.
 *
 * @param uuid Output span of kUuidLength uint64_t words receiving the UUID bytes.
 */
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
#if defined(XGBOOST_USE_HIP)
hipUUID id;
hipDeviceGetUuid(&id, device_ordinal_);
std::memcpy(uuid.data(), static_cast<void *>(&id), sizeof(id));
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
cudaDeviceProp prob{};
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
#elif defined(XGBOOST_USE_HIP)
hipUUID id;
hipDeviceGetUuid(&id, device_ordinal_);
std::memcpy(uuid.data(), static_cast<void *>(&id), sizeof(id));
#endif
}

View File

@ -11,10 +11,10 @@
#include <cstddef> // size_t
#include <cstdint> // int32_t
#if defined(XGBOOST_USE_HIP)
#include <hipcub/hipcub.hpp>
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
#include <cub/cub.cuh> // DispatchSegmentedRadixSort,NullType,DoubleBuffer
#elif defined(XGBOOST_USE_HIP)
#include <hipcub/hipcub.hpp>
#endif
#include <iterator> // distance

View File

@ -175,17 +175,17 @@ void GetColumnSizesScan(DeviceOrd device, size_t num_columns, std::size_t num_cu
return thrust::min(num_cuts_per_feature, column_size);
});
#if defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_CUDA)
thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer());
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
#elif defined(XGBOOST_USE_HIP)
thrust::exclusive_scan(thrust::hip::par(alloc), cut_ptr_it,
cut_ptr_it + column_sizes_scan->size(),
cuts_ptr->DevicePointer());
thrust::exclusive_scan(thrust::hip::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
#elif defined(XGBOOST_USE_CUDA)
thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer());
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
#endif
}
@ -309,12 +309,12 @@ void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
&sorted_entries);
dh::XGBDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_HIP)
thrust::sort(thrust::hip::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
#elif defined(XGBOOST_USE_HIP)
thrust::sort(thrust::hip::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
#endif
if (sketch_container->HasCategorical()) {
@ -374,14 +374,14 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
return weights[group_idx];
});
#if defined(XGBOOST_USE_HIP)
auto retit = thrust::copy_if(thrust::hip::par(alloc),
#if defined(XGBOOST_USE_CUDA)
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
weight_iter + begin, weight_iter + end,
batch_iter + begin,
d_temp_weights.data(), // output
is_valid);
#elif defined(XGBOOST_USE_CUDA)
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
#elif defined(XGBOOST_USE_HIP)
auto retit = thrust::copy_if(thrust::hip::par(alloc),
weight_iter + begin, weight_iter + end,
batch_iter + begin,
d_temp_weights.data(), // output
@ -397,14 +397,14 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
return weights[batch.GetElement(idx).row_idx];
});
#if defined(XGBOOST_USE_HIP)
auto retit = thrust::copy_if(thrust::hip::par(alloc),
#if defined(XGBOOST_USE_CUDA)
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
weight_iter + begin, weight_iter + end,
batch_iter + begin,
d_temp_weights.data(), // output
is_valid);
#elif defined(XGBOOST_USE_CUDA)
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
#elif defined(XGBOOST_USE_HIP)
auto retit = thrust::copy_if(thrust::hip::par(alloc),
weight_iter + begin, weight_iter + end,
batch_iter + begin,
d_temp_weights.data(), // output

View File

@ -184,15 +184,15 @@ class SketchContainer {
d_column_scan = this->columns_ptr_.DeviceSpan();
#if defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_CUDA)
size_t n_uniques = dh::SegmentedUnique(
thrust::hip::par(alloc), d_column_scan.data(),
thrust::cuda::par(alloc), d_column_scan.data(),
d_column_scan.data() + d_column_scan.size(), entries.data(),
entries.data() + entries.size(), scan_out.DevicePointer(),
entries.data(), detail::SketchUnique{}, key_comp);
#elif defined(XGBOOST_USE_CUDA)
#elif defined(XGBOOST_USE_HIP)
size_t n_uniques = dh::SegmentedUnique(
thrust::cuda::par(alloc), d_column_scan.data(),
thrust::hip::par(alloc), d_column_scan.data(),
d_column_scan.data() + d_column_scan.size(), entries.data(),
entries.data() + entries.size(), scan_out.DevicePointer(),
entries.data(), detail::SketchUnique{}, key_comp);

View File

@ -217,12 +217,12 @@ void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_b
auto scan_val = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
detail::WeightOp<WIter>{w_begin, d_sorted_idx});
#if defined(XGBOOST_USE_HIP)
thrust::inclusive_scan_by_key(thrust::hip::par(caching), scan_key, scan_key + n_weights,
scan_val, weights_cdf.begin());
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
scan_val, weights_cdf.begin());
#elif defined(XGBOOST_USE_HIP)
thrust::inclusive_scan_by_key(thrust::hip::par(caching), scan_key, scan_key + n_weights,
scan_val, weights_cdf.begin());
#endif
auto n_segments = std::distance(seg_beg, seg_end) - 1;

View File

@ -6,10 +6,10 @@
#include <algorithm>
#include <cassert>
#if defined(XGBOOST_USE_HIP)
#include <hipcub/hipcub.hpp> // NOLINT
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
#include <cub/cub.cuh> // NOLINT
#elif defined(XGBOOST_USE_HIP)
#include <hipcub/hipcub.hpp> // NOLINT
#endif
#include <limits>
@ -127,16 +127,16 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
thrust::make_counting_iterator(0),
[=] XGBOOST_DEVICE(size_t i) { return predts[d_sorted_idx[i]]; });
#if defined(XGBOOST_USE_HIP)
auto end_unique = thrust::unique_by_key_copy(
thrust::hip::par(alloc), uni_key, uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx), thrust::make_discard_iterator(),
dh::tbegin(d_unique_idx));
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
auto end_unique = thrust::unique_by_key_copy(
thrust::cuda::par(alloc), uni_key, uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx), thrust::make_discard_iterator(),
dh::tbegin(d_unique_idx));
#elif defined(XGBOOST_USE_HIP)
auto end_unique = thrust::unique_by_key_copy(
thrust::hip::par(alloc), uni_key, uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx), thrust::make_discard_iterator(),
dh::tbegin(d_unique_idx));
#endif
d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx));
@ -179,10 +179,10 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
Pair last = cache->fptp.back();
#if defined(XGBOOST_USE_HIP)
double auc = thrust::reduce(thrust::hip::par(alloc), in, in + d_unique_idx.size());
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
double auc = thrust::reduce(thrust::cuda::par(alloc), in, in + d_unique_idx.size());
#elif defined(XGBOOST_USE_HIP)
double auc = thrust::reduce(thrust::hip::par(alloc), in, in + d_unique_idx.size());
#endif
return std::make_tuple(last.first, last.second, auc);
@ -239,14 +239,14 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
double tp_sum;
double auc_sum;
#if defined(XGBOOST_USE_HIP)
thrust::tie(auc_sum, tp_sum) =
thrust::reduce(thrust::hip::par(alloc), reduce_in, reduce_in + n_classes,
Pair{0.0, 0.0}, PairPlus<double, double>{});
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
thrust::tie(auc_sum, tp_sum) =
thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
Pair{0.0, 0.0}, PairPlus<double, double>{});
#elif defined(XGBOOST_USE_HIP)
thrust::tie(auc_sum, tp_sum) =
thrust::reduce(thrust::hip::par(alloc), reduce_in, reduce_in + n_classes,
Pair{0.0, 0.0}, PairPlus<double, double>{});
#endif
if (tp_sum != 0 && !std::isnan(auc_sum)) {
@ -329,12 +329,12 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
return auc;
});
#if defined(XGBOOST_USE_HIP)
thrust::reduce_by_key(thrust::hip::par(alloc), key_in,
#if defined(XGBOOST_USE_CUDA)
thrust::reduce_by_key(thrust::cuda::par(alloc), key_in,
key_in + d_unique_idx.size(), val_in,
thrust::make_discard_iterator(), dh::tbegin(d_auc));
#elif defined(XGBOOST_USE_CUDA)
thrust::reduce_by_key(thrust::cuda::par(alloc), key_in,
#elif defined(XGBOOST_USE_HIP)
thrust::reduce_by_key(thrust::hip::par(alloc), key_in,
key_in + d_unique_idx.size(), val_in,
thrust::make_discard_iterator(), dh::tbegin(d_auc));
#endif
@ -410,9 +410,9 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
dh::TemporaryArray<uint32_t> unique_class_ptr(d_class_ptr.size());
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
#if defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_CUDA)
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::hip::par(alloc),
thrust::cuda::par(alloc),
dh::tbegin(d_class_ptr),
dh::tend(d_class_ptr),
uni_key,
@ -421,9 +421,9 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
d_unique_class_ptr.data(),
dh::tbegin(d_unique_idx),
thrust::equal_to<thrust::pair<uint32_t, float>>{});
#elif defined(XGBOOST_USE_CUDA)
#elif defined(XGBOOST_USE_HIP)
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::cuda::par(alloc),
thrust::hip::par(alloc),
dh::tbegin(d_class_ptr),
dh::tend(d_class_ptr),
uni_key,
@ -553,14 +553,14 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
thrust::make_counting_iterator(0),
[=] XGBOOST_DEVICE(size_t i) { return d_group_ptr[i + 1] - d_group_ptr[i]; });
#if defined(XGBOOST_USE_HIP)
size_t n_valid = thrust::count_if(
thrust::hip::par(alloc), check_it, check_it + group_ptr.size() - 1,
[=] XGBOOST_DEVICE(size_t len) { return len >= 3; });
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
size_t n_valid = thrust::count_if(
thrust::cuda::par(alloc), check_it, check_it + group_ptr.size() - 1,
[=] XGBOOST_DEVICE(size_t len) { return len >= 3; });
#elif defined(XGBOOST_USE_HIP)
size_t n_valid = thrust::count_if(
thrust::hip::par(alloc), check_it, check_it + group_ptr.size() - 1,
[=] XGBOOST_DEVICE(size_t len) { return len >= 3; });
#endif
if (n_valid < info.group_ptr_.size() - 1) {
@ -659,12 +659,12 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
/**
* Scale the AUC with number of items in each group.
*/
#if defined(XGBOOST_USE_HIP)
double auc = thrust::reduce(thrust::hip::par(alloc), dh::tbegin(s_d_auc),
dh::tend(s_d_auc), 0.0);
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
double auc = thrust::reduce(thrust::cuda::par(alloc), dh::tbegin(s_d_auc),
dh::tend(s_d_auc), 0.0);
#elif defined(XGBOOST_USE_HIP)
double auc = thrust::reduce(thrust::hip::par(alloc), dh::tbegin(s_d_auc),
dh::tend(s_d_auc), 0.0);
#endif
return std::make_pair(auc, n_valid);
@ -694,14 +694,14 @@ std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> pred
dh::XGBCachingDeviceAllocator<char> alloc;
double total_pos, total_neg;
#if defined(XGBOOST_USE_HIP)
thrust::tie(total_pos, total_neg) =
thrust::reduce(thrust::hip::par(alloc), it, it + labels.Size(),
Pair{0.0, 0.0}, PairPlus<double, double>{});
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
thrust::tie(total_pos, total_neg) =
thrust::reduce(thrust::cuda::par(alloc), it, it + labels.Size(),
Pair{0.0, 0.0}, PairPlus<double, double>{});
#elif defined(XGBOOST_USE_HIP)
thrust::tie(total_pos, total_neg) =
thrust::reduce(thrust::hip::par(alloc), it, it + labels.Size(),
Pair{0.0, 0.0}, PairPlus<double, double>{});
#endif
if (total_pos <= 0.0 || total_neg <= 0.0) {
@ -755,13 +755,13 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
});
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_HIP)
thrust::reduce_by_key(thrust::hip::par(alloc), key_it,
#if defined(XGBOOST_USE_CUDA)
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<double, double>{});
#elif defined(XGBOOST_USE_CUDA)
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
#elif defined(XGBOOST_USE_HIP)
thrust::reduce_by_key(thrust::hip::par(alloc), key_it,
key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<double, double>{});
@ -834,9 +834,9 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
dh::TemporaryArray<uint32_t> unique_class_ptr(d_group_ptr.size());
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
#if defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_CUDA)
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::hip::par(alloc),
thrust::cuda::par(alloc),
dh::tbegin(d_group_ptr),
dh::tend(d_group_ptr),
uni_key,
@ -845,9 +845,9 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
d_unique_class_ptr.data(),
dh::tbegin(d_unique_idx),
thrust::equal_to<thrust::pair<uint32_t, float>>{});
#elif defined(XGBOOST_USE_CUDA)
#elif defined(XGBOOST_USE_HIP)
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::cuda::par(alloc),
thrust::hip::par(alloc),
dh::tbegin(d_group_ptr),
dh::tend(d_group_ptr),
uni_key,
@ -909,14 +909,14 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
return thrust::make_pair(0.0, static_cast<uint32_t>(1));
});
#if defined(XGBOOST_USE_HIP)
thrust::tie(auc, invalid_groups) = thrust::reduce(
thrust::hip::par(alloc), it, it + n_groups,
thrust::pair<double, uint32_t>(0.0, 0), PairPlus<double, uint32_t>{});
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
thrust::tie(auc, invalid_groups) = thrust::reduce(
thrust::cuda::par(alloc), it, it + n_groups,
thrust::pair<double, uint32_t>(0.0, 0), PairPlus<double, uint32_t>{});
#elif defined(XGBOOST_USE_HIP)
thrust::tie(auc, invalid_groups) = thrust::reduce(
thrust::hip::par(alloc), it, it + n_groups,
thrust::pair<double, uint32_t>(0.0, 0), PairPlus<double, uint32_t>{});
#endif
}
return std::make_pair(auc, n_groups - invalid_groups);
@ -949,13 +949,13 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
dh::XGBDeviceAllocator<char> alloc;
auto labels = info.labels.View(ctx->Device());
#if defined(XGBOOST_USE_HIP)
if (thrust::any_of(thrust::hip::par(alloc), dh::tbegin(labels.Values()),
#if defined(XGBOOST_USE_CUDA)
if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels.Values()),
dh::tend(labels.Values()), PRAUCLabelInvalid{})) {
InvalidLabels();
}
#elif defined(XGBOOST_USE_CUDA)
if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels.Values()),
#elif defined(XGBOOST_USE_HIP)
if (thrust::any_of(thrust::hip::par(alloc), dh::tbegin(labels.Values()),
dh::tend(labels.Values()), PRAUCLabelInvalid{})) {
InvalidLabels();
}
@ -981,13 +981,13 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
return thrust::make_pair(y * w, (1.0 - y) * w);
});
#if defined(XGBOOST_USE_HIP)
thrust::reduce_by_key(thrust::hip::par(alloc), key_it,
#if defined(XGBOOST_USE_CUDA)
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<double, double>{});
#elif defined(XGBOOST_USE_CUDA)
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
#elif defined(XGBOOST_USE_HIP)
thrust::reduce_by_key(thrust::hip::par(alloc), key_it,
key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<double, double>{});

View File

@ -62,6 +62,21 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
return PackedReduceResult{v, wt};
},
PackedReduceResult{}, thrust::plus<PackedReduceResult>());
#elif defined(XGBOOST_USE_HIP)
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::counting_iterator<size_t> begin(0);
thrust::counting_iterator<size_t> end = begin + labels.Size();
result = thrust::transform_reduce(
thrust::hip::par(alloc), begin, end,
[=] XGBOOST_DEVICE(size_t i) {
auto idx = linalg::UnravelIndex(i, labels.Shape());
auto sample_id = std::get<0>(idx);
auto target_id = std::get<1>(idx);
auto res = loss(i, sample_id, target_id);
float v{std::get<0>(res)}, wt{std::get<1>(res)};
return PackedReduceResult{v, wt};
},
PackedReduceResult{}, thrust::plus<PackedReduceResult>());
#else
common::AssertGPUSupport();
#endif // defined(XGBOOST_USE_CUDA)

View File

@ -11,7 +11,9 @@
#include "evaluate_splits.cuh"
#include "expand_entry.cuh"
#if defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_CUDA)
#define WARP_SIZE 32
#elif defined(XGBOOST_USE_HIP)
#include <hip/hip_cooperative_groups.h>
#ifdef __AMDGCN_WAVEFRONT_SIZE
@ -20,8 +22,6 @@
#endif
#define WARP_SIZE WAVEFRONT_SIZE
#elif defined(XGBOOST_USE_CUDA)
#define WARP_SIZE 32
#endif
#if defined(XGBOOST_USE_HIP)