[EM] Avoid synchronous calls and unnecessary ATS access. (#10811)
- Pass context into various functions. - Factor out some CUDA algorithms. - Use ATS only for update position.
This commit is contained in:
@@ -1,16 +1,17 @@
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/copy.h> // thrust::copy
|
||||
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
#include "../../../src/common/threading_utils.cuh"
|
||||
#include "../helpers.h" // for MakeCUDACtx
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace xgboost::common {
|
||||
TEST(SegmentedTrapezoidThreads, Basic) {
|
||||
size_t constexpr kElements = 24, kGroups = 3;
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
dh::device_vector<size_t> offset_ptr(kGroups + 1, 0);
|
||||
offset_ptr[0] = 0;
|
||||
offset_ptr[1] = 8;
|
||||
@@ -19,11 +20,11 @@ TEST(SegmentedTrapezoidThreads, Basic) {
|
||||
|
||||
size_t h = 1;
|
||||
dh::device_vector<size_t> thread_ptr(kGroups + 1, 0);
|
||||
size_t total = SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
|
||||
size_t total = SegmentedTrapezoidThreads(&ctx, dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
|
||||
ASSERT_EQ(total, kElements - kGroups);
|
||||
|
||||
h = 2;
|
||||
SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
|
||||
SegmentedTrapezoidThreads(&ctx, dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
|
||||
std::vector<size_t> h_thread_ptr(thread_ptr.size());
|
||||
thrust::copy(thread_ptr.cbegin(), thread_ptr.cend(), h_thread_ptr.begin());
|
||||
for (size_t i = 1; i < h_thread_ptr.size(); ++i) {
|
||||
@@ -31,7 +32,7 @@ TEST(SegmentedTrapezoidThreads, Basic) {
|
||||
}
|
||||
|
||||
h = 7;
|
||||
SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
|
||||
SegmentedTrapezoidThreads(&ctx, dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
|
||||
thrust::copy(thread_ptr.cbegin(), thread_ptr.cend(), h_thread_ptr.begin());
|
||||
for (size_t i = 1; i < h_thread_ptr.size(); ++i) {
|
||||
ASSERT_EQ(h_thread_ptr[i] - h_thread_ptr[i - 1], 28);
|
||||
@@ -66,5 +67,4 @@ TEST(SegmentedTrapezoidThreads, Unravel) {
|
||||
ASSERT_EQ(i, 6);
|
||||
ASSERT_EQ(j, 7);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::common
|
||||
|
||||
Reference in New Issue
Block a user