Fix compilation with the latest ctk. (#10123)

This commit is contained in:
Jiaming Yuan 2024-03-15 08:04:41 +08:00 committed by GitHub
parent 617970a0c2
commit 56b1868278
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 8 deletions

View File

@ -173,13 +173,13 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
auto scan_key_it = dh::MakeTransformIterator<size_t>(
thrust::make_counting_iterator(0ul),
[=] __device__(size_t idx) { return dh::SegmentId(out_ptr, idx); });
[=] XGBOOST_DEVICE(size_t idx) { return dh::SegmentId(out_ptr, idx); });
auto scan_val_it = dh::MakeTransformIterator<Tuple>(
merge_path.data(), [=] __device__(Tuple const &t) -> Tuple {
merge_path.data(), [=] XGBOOST_DEVICE(Tuple const &t) -> Tuple {
auto ind = get_ind(t); // == 0 if element is from x
// x_counter, y_counter
return thrust::make_tuple<uint64_t, uint64_t>(!ind, ind);
return thrust::tuple<std::uint64_t, std::uint64_t>{!ind, ind};
});
// Compute the index for both x and y (which of the element in a and b are used in each

View File

@ -171,11 +171,10 @@ struct WriteCompressedEllpackFunctor {
using Tuple = thrust::tuple<size_t, size_t, size_t>;
__device__ size_t operator()(Tuple out) {
auto e = batch.GetElement(out.get<2>());
auto e = batch.GetElement(thrust::get<2>(out));
if (is_valid(e)) {
// -1 because the scan is inclusive
size_t output_position =
accessor.row_stride * e.row_idx + out.get<1>() - 1;
size_t output_position = accessor.row_stride * e.row_idx + thrust::get<1>(out) - 1;
uint32_t bin_idx = 0;
if (common::IsCat(feature_types, e.column_idx)) {
bin_idx = accessor.SearchBin<true>(e.value, e.column_idx);
@ -192,8 +191,8 @@ template <typename Tuple>
struct TupleScanOp {
__device__ Tuple operator()(Tuple a, Tuple b) {
// Key equal
if (a.template get<0>() == b.template get<0>()) {
b.template get<1>() += a.template get<1>();
if (thrust::get<0>(a) == thrust::get<0>(b)) {
thrust::get<1>(b) += thrust::get<1>(a);
return b;
}
// Not equal