Fix compilation with the latest ctk. (#10123)
This commit is contained in:
parent
617970a0c2
commit
56b1868278
@ -173,13 +173,13 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
|||||||
|
|
||||||
auto scan_key_it = dh::MakeTransformIterator<size_t>(
|
auto scan_key_it = dh::MakeTransformIterator<size_t>(
|
||||||
thrust::make_counting_iterator(0ul),
|
thrust::make_counting_iterator(0ul),
|
||||||
[=] __device__(size_t idx) { return dh::SegmentId(out_ptr, idx); });
|
[=] XGBOOST_DEVICE(size_t idx) { return dh::SegmentId(out_ptr, idx); });
|
||||||
|
|
||||||
auto scan_val_it = dh::MakeTransformIterator<Tuple>(
|
auto scan_val_it = dh::MakeTransformIterator<Tuple>(
|
||||||
merge_path.data(), [=] __device__(Tuple const &t) -> Tuple {
|
merge_path.data(), [=] XGBOOST_DEVICE(Tuple const &t) -> Tuple {
|
||||||
auto ind = get_ind(t); // == 0 if element is from x
|
auto ind = get_ind(t); // == 0 if element is from x
|
||||||
// x_counter, y_counter
|
// x_counter, y_counter
|
||||||
return thrust::make_tuple<uint64_t, uint64_t>(!ind, ind);
|
return thrust::tuple<std::uint64_t, std::uint64_t>{!ind, ind};
|
||||||
});
|
});
|
||||||
|
|
||||||
// Compute the index for both x and y (which of the element in a and b are used in each
|
// Compute the index for both x and y (which of the element in a and b are used in each
|
||||||
|
|||||||
@ -171,11 +171,10 @@ struct WriteCompressedEllpackFunctor {
|
|||||||
|
|
||||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||||
__device__ size_t operator()(Tuple out) {
|
__device__ size_t operator()(Tuple out) {
|
||||||
auto e = batch.GetElement(out.get<2>());
|
auto e = batch.GetElement(thrust::get<2>(out));
|
||||||
if (is_valid(e)) {
|
if (is_valid(e)) {
|
||||||
// -1 because the scan is inclusive
|
// -1 because the scan is inclusive
|
||||||
size_t output_position =
|
size_t output_position = accessor.row_stride * e.row_idx + thrust::get<1>(out) - 1;
|
||||||
accessor.row_stride * e.row_idx + out.get<1>() - 1;
|
|
||||||
uint32_t bin_idx = 0;
|
uint32_t bin_idx = 0;
|
||||||
if (common::IsCat(feature_types, e.column_idx)) {
|
if (common::IsCat(feature_types, e.column_idx)) {
|
||||||
bin_idx = accessor.SearchBin<true>(e.value, e.column_idx);
|
bin_idx = accessor.SearchBin<true>(e.value, e.column_idx);
|
||||||
@ -192,8 +191,8 @@ template <typename Tuple>
|
|||||||
struct TupleScanOp {
|
struct TupleScanOp {
|
||||||
__device__ Tuple operator()(Tuple a, Tuple b) {
|
__device__ Tuple operator()(Tuple a, Tuple b) {
|
||||||
// Key equal
|
// Key equal
|
||||||
if (a.template get<0>() == b.template get<0>()) {
|
if (thrust::get<0>(a) == thrust::get<0>(b)) {
|
||||||
b.template get<1>() += a.template get<1>();
|
thrust::get<1>(b) += thrust::get<1>(a);
|
||||||
return b;
|
return b;
|
||||||
}
|
}
|
||||||
// Not equal
|
// Not equal
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user