Add namespace aliases to reduce duplicated CUDA/HIP code

This commit is contained in:
Hui Liu
2023-10-27 09:11:55 -07:00
parent e00131c465
commit 4a4b528d54
19 changed files with 110 additions and 407 deletions

View File

@@ -17,6 +17,9 @@
// HIP build only: alias the CUDA-side namespace names to their HIP
// counterparts, so that code spelled with `cub::` and `thrust::cuda::`
// (e.g. `thrust::cuda::par`) resolves to `hipcub::` and `thrust::hip::`.
#if defined(XGBOOST_USE_HIP)
namespace cub = hipcub;
namespace thrust {
namespace cuda = thrust::hip;
}  // namespace thrust
#endif  // defined(XGBOOST_USE_HIP)
namespace xgboost {
@@ -122,13 +125,8 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
group_ptr_.resize(h_num_runs_out + 1, 0);
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_CUDA)
thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
cnt.begin() + h_num_runs_out, cnt.begin());
#elif defined(XGBOOST_USE_HIP)
thrust::inclusive_scan(thrust::hip::par(alloc), cnt.begin(),
cnt.begin() + h_num_runs_out, cnt.begin());
#endif
thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
group_ptr_.begin() + 1);

View File

@@ -17,6 +17,12 @@
#include "adapter.h"
#include "array_interface.h"
// HIP build only: make `thrust::cuda` an alias of `thrust::hip`, so calls
// written as `thrust::cuda::par(...)` in this file resolve to the HIP
// execution policy without a separate per-backend branch.
#if defined(XGBOOST_USE_HIP)
namespace thrust {
namespace cuda = thrust::hip;
}  // namespace thrust
#endif  // defined(XGBOOST_USE_HIP)
namespace xgboost {
namespace data {
@@ -246,17 +252,10 @@ std::size_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_row_t> offs
});
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_CUDA)
bst_row_t row_stride =
dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data()) + offset.size(),
static_cast<bst_row_t>(0), thrust::maximum<bst_row_t>());
#elif defined(XGBOOST_USE_HIP)
bst_row_t row_stride =
dh::Reduce(thrust::hip::par(alloc), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data()) + offset.size(),
static_cast<bst_row_t>(0), thrust::maximum<bst_row_t>());
#endif
return row_stride;
}
@@ -280,13 +279,8 @@ bool NoInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
// intervals to early stop. But we expect all data to be valid here, using small
// intervals only decreases performance due to excessive kernel launch and stream
// synchronization.
#if defined(XGBOOST_USE_CUDA)
auto valid = dh::Reduce(thrust::cuda::par(alloc), value_iter, value_iter + batch.Size(), true,
thrust::logical_and<>{});
#elif defined(XGBOOST_USE_HIP)
auto valid = dh::Reduce(thrust::hip::par(alloc), value_iter, value_iter + batch.Size(), true,
thrust::logical_and<>{});
#endif
return valid;
}
}; // namespace data

View File

@@ -16,6 +16,12 @@
#include "simple_batch_iterator.h"
#include "sparse_page_source.h"
// When compiling for HIP, expose `thrust::hip` under the name `thrust::cuda`
// so the `thrust::cuda::par(alloc)` spelling used in this file resolves to
// the HIP backend.
#if defined(XGBOOST_USE_HIP)
namespace thrust {
namespace cuda = thrust::hip;
}  // namespace thrust
#endif  // defined(XGBOOST_USE_HIP)
namespace xgboost::data {
void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
DataIterHandle iter_handle, float missing,
@@ -86,11 +92,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
return GetRowCounts(value, row_counts_span, get_device(), missing);
}));
#if defined(XGBOOST_USE_CUDA)
nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end());
#elif defined(XGBOOST_USE_HIP)
nnz += thrust::reduce(thrust::hip::par(alloc), row_counts.begin(), row_counts.end());
#endif
batches++;
} while (iter.Next());

View File

@@ -13,6 +13,12 @@
#include "../common/error_msg.h" // for InfInData
#include "device_adapter.cuh" // for HasInfInData
// Under XGBOOST_USE_HIP, `thrust::cuda` is declared as an alias for
// `thrust::hip`; this lets the scan call in this file keep the single
// `thrust::cuda::par(alloc)` spelling for both backends.
#if defined(XGBOOST_USE_HIP)
namespace thrust {
namespace cuda = thrust::hip;
}  // namespace thrust
#endif  // defined(XGBOOST_USE_HIP)
namespace xgboost::data {
#if defined(XGBOOST_USE_CUDA)
@@ -69,15 +75,9 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
});
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_CUDA)
thrust::exclusive_scan(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data()));
#elif defined(XGBOOST_USE_HIP)
thrust::exclusive_scan(thrust::hip::par(alloc), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data()));
#endif
}
template <typename AdapterBatchT>