complete test porting

amdsc21 2023-03-11 02:17:05 +01:00
parent 9bf16a2ca6
commit 3a07b1edf8
24 changed files with 183 additions and 16 deletions

View File

@@ -2,6 +2,9 @@
* Copyright 2017-2023 XGBoost contributors
*/
#pragma once
#if defined(XGBOOST_USE_CUDA)
#include <thrust/binary_search.h> // thrust::upper_bound
#include <thrust/device_malloc_allocator.h>
#include <thrust/device_ptr.h>
@@ -1381,3 +1384,7 @@ class LDGIterator {
}
};
} // namespace dh
#elif defined(XGBOOST_USE_HIP)
#include" device_helpers.hip.h"
#endif
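
The hunk above sets the porting convention used throughout this commit: the whole CUDA implementation stays fenced behind XGBOOST_USE_CUDA, and a HIP build pulls in a parallel device_helpers.hip.h instead. A condensed sketch of the resulting file shape (the elided body stands for the existing CUDA-only code):

#pragma once
#if defined(XGBOOST_USE_CUDA)
// ... existing CUDA implementation (namespace dh) ...
#elif defined(XGBOOST_USE_HIP)
#include "device_helpers.hip.h"  // HIP port of the same helpers
#endif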

View File

@@ -364,6 +364,8 @@ TEST(CAPI, BuildInfo) {
ASSERT_TRUE(get<Object const>(loaded).find("USE_OPENMP") != get<Object const>(loaded).cend());
ASSERT_TRUE(get<Object const>(loaded).find("USE_CUDA") != get<Object const>(loaded).cend());
ASSERT_TRUE(get<Object const>(loaded).find("USE_NCCL") != get<Object const>(loaded).cend());
ASSERT_TRUE(get<Object const>(loaded).find("USE_HIP") != get<Object const>(loaded).cend());
ASSERT_TRUE(get<Object const>(loaded).find("USE_RCCL") != get<Object const>(loaded).cend());
}
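
The two new assertions extend the build-info check to the ROCm backend. A minimal sketch of what the test exercises, using the public XGBuildInfo C API that this test file already calls (the helper name here is hypothetical):

#include <xgboost/c_api.h>
#include <string>

inline std::string BuildInfoJson() {
  char const *out{nullptr};
  XGBuildInfo(&out);        // fills `out` with a JSON string of compile-time options
  return std::string{out};  // includes USE_CUDA, USE_NCCL, and now USE_HIP / USE_RCCL
}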
TEST(CAPI, NullPtr) {

View File

@@ -0,0 +1,2 @@
#include "test_algorithm.cu"

View File

@@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "test_bitfield.cu"
#endif
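
Most new files in this commit are thin *.hip wrappers like the two above: they re-compile the existing CUDA test source under hipcc, and the guarded variant stays inert unless the HIP backend is enabled. The general shape (file names hypothetical):

// test_foo.hip -- reuse the CUDA test source verbatim on ROCm builds
#if defined(XGBOOST_USE_HIP)
#include "test_foo.cu"  // hipcc compiles the unmodified CUDA tests
#endif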

View File

@@ -126,7 +126,13 @@ TEST(DeviceHelpers, Reduce) {
size_t kSize = std::numeric_limits<uint32_t>::max();
auto it = thrust::make_counting_iterator(0ul);
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_CUDA)
auto batched = dh::Reduce(thrust::cuda::par(alloc), it, it + kSize, 0ul, thrust::maximum<size_t>{});
#elif defined(XGBOOST_USE_HIP)
auto batched = dh::Reduce(thrust::hip::par(alloc), it, it + kSize, 0ul, thrust::maximum<size_t>{});
#endif
CHECK_EQ(batched, kSize - 1);
}
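
The only backend-specific piece here is the Thrust execution policy: CUDA Thrust exposes thrust::cuda::par and rocThrust exposes thrust::hip::par, both accepting an allocator. A hypothetical helper (not part of this commit) that would collapse such #if blocks:

#include <thrust/execution_policy.h>

template <typename Alloc>
auto ParPolicy(Alloc &alloc) {  // returns the backend's allocator-aware policy
#if defined(XGBOOST_USE_CUDA)
  return thrust::cuda::par(alloc);
#elif defined(XGBOOST_USE_HIP)
  return thrust::hip::par(alloc);
#endif
}
// usage: dh::Reduce(ParPolicy(alloc), it, it + kSize, 0ul, thrust::maximum<size_t>{});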
@@ -170,6 +176,10 @@ TEST(Allocator, OOM) {
ASSERT_THROW({dh::caching_device_vector<char> vec(size);}, dmlc::Error);
ASSERT_THROW({dh::device_vector<char> vec(size);}, dmlc::Error);
// Clear last error so we don't fail subsequent tests
#if defined(XGBOOST_USE_CUDA)
cudaGetLastError();
#elif defined(XGBOOST_USE_HIP)
hipGetLastError();
#endif
}
} // namespace xgboost
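
The *GetLastError() call matters because both runtimes record the failure from the intentional OOM above, and a later error check would re-report it unless it is popped. A minimal sketch of the reset, CUDA variant shown (the HIP variant is the same with hip* names):

#include <cuda_runtime.h>

inline void ClearAfterIntentionalOom() {
  void *p{nullptr};
  cudaMalloc(&p, ~size_t{0});  // deliberately fails: the request is absurdly large
  cudaGetLastError();          // pops the recorded error; later calls start clean
}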

View File

@@ -0,0 +1,2 @@
#include "test_device_helpers.cu"

View File

@@ -32,7 +32,11 @@ struct ReadSymbolFunction {
};
TEST(CompressedIterator, TestGPU) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
std::vector<int> test_cases = {1, 3, 426, 21, 64, 256, 100000, INT32_MAX};
int num_elements = 1000;
int repetitions = 1000;

View File

@@ -0,0 +1,2 @@
#include "test_gpu_compressed_iterator.cu"

View File

@@ -53,7 +53,13 @@ TEST(HistUtil, SketchBatchNumElements) {
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
size_t constexpr kCols = 10000;
int device;
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaGetDevice(&device));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipGetDevice(&device));
#endif
auto avail = static_cast<size_t>(dh::AvailableMemory(device) * 0.8);
auto per_elem = detail::BytesPerElement(false);
auto avail_elem = avail / per_elem;

View File

@@ -18,6 +18,9 @@
#ifdef __CUDACC__
#include <xgboost/json.h>
#include "../../../src/data/device_adapter.cuh"
#elif defined(__HIP_PLATFORM_AMD__)
#include <xgboost/json.h>
#include "../../../src/data/device_adapter.hip.h"
#endif // __CUDACC__
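
This helpers header keys off compiler-identification macros rather than build flags: nvcc defines __CUDACC__ when compiling CUDA sources, while hipcc (via HIP's platform headers) defines __HIP_PLATFORM_AMD__ when targeting AMD GPUs. The combined check used below therefore means "compiling as device code on either backend"; a hypothetical convenience macro would read:

#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
#define XGBOOST_TEST_DEVICE_COMPILE 1  // hypothetical; nvcc or hipcc-for-AMD
#else
#define XGBOOST_TEST_DEVICE_COMPILE 0
#endif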
// Some helper functions used to test both GPU and CPU algorithms
@@ -47,7 +50,7 @@ inline std::vector<float> GenerateRandomWeights(int num_rows) {
return w;
}
#ifdef __CUDACC__
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
inline data::CupyAdapter AdapterFromData(const thrust::device_vector<float> &x,
int num_rows, int num_columns) {
Json array_interface{Object()};

View File

@@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "test_hist_util.cu"
#endif

View File

@@ -6,7 +6,12 @@
#include <thrust/equal.h>
#include <thrust/iterator/counting_iterator.h>
#if defined(XGBOOST_USE_CUDA)
#include "../../../src/common/device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "../../../src/common/device_helpers.hip.h"
#endif
#include <xgboost/host_device_vector.h>
namespace xgboost {
@@ -14,9 +19,16 @@ namespace common {
namespace {
void SetDeviceForTest(int device) {
int n_devices;
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaGetDeviceCount(&n_devices));
device %= n_devices;
dh::safe_cuda(cudaSetDevice(device));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipGetDeviceCount(&n_devices));
device %= n_devices;
dh::safe_cuda(hipSetDevice(device));
#endif
}
} // namespace
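
SetDeviceForTest is representative of most hunks in this commit: HIP mirrors the CUDA device-management API one-to-one (hipGetDeviceCount/hipSetDevice for cudaGetDeviceCount/cudaSetDevice), so each port is a mechanical rename inside an #if/#elif pair. A hypothetical unified wrapper (not in this commit) that would remove the duplication:

inline void SetDeviceChecked(int device) {  // hypothetical helper
#if defined(XGBOOST_USE_CUDA)
  dh::safe_cuda(cudaSetDevice(device));
#elif defined(XGBOOST_USE_HIP)
  dh::safe_cuda(hipSetDevice(device));
#endif
}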

View File

@@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "test_host_device_vector.cu"
#endif

View File

@@ -0,0 +1,2 @@
#include "test_linalg.cu"

View File

@@ -80,7 +80,11 @@ TEST(GPUQuantile, Unique) {
// if with_error is true, the test tolerates floating point error
void TestQuantileElemRank(int32_t device, Span<SketchEntry const> in,
Span<bst_row_t const> d_columns_ptr, bool with_error = false) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device));
#endif
std::vector<SketchEntry> h_in(in.size());
dh::CopyDeviceSpanToVector(&h_in, in);
std::vector<bst_row_t> h_columns_ptr(d_columns_ptr.size());

View File

@@ -0,0 +1,2 @@
#include "test_quantile.cu"

View File

@@ -7,7 +7,12 @@
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#if defined(XGBOOST_USE_CUDA)
#include "../../../src/common/device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "../../../src/common/device_helpers.hip.h"
#endif
#include <xgboost/span.h>
#include "test_span.h"
@@ -20,19 +25,37 @@ struct TestStatus {
public:
TestStatus() {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMalloc(&status_, sizeof(int)));
int h_status = 1;
dh::safe_cuda(cudaMemcpy(status_, &h_status,
sizeof(int), cudaMemcpyHostToDevice));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMalloc(&status_, sizeof(int)));
int h_status = 1;
dh::safe_cuda(hipMemcpy(status_, &h_status,
sizeof(int), hipMemcpyHostToDevice));
#endif
}
~TestStatus() {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaFree(status_));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipFree(status_));
#endif
}
int Get() {
int h_status;
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMemcpy(&h_status, status_,
sizeof(int), cudaMemcpyDeviceToHost));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpy(&h_status, status_,
sizeof(int), hipMemcpyDeviceToHost));
#endif
return h_status;
}
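
TestStatus shows that the rename extends past function names to the enums: hipMemcpyHostToDevice and hipMemcpyDeviceToHost correspond directly to their cudaMemcpy* counterparts. Hypothetical aliases (not in this commit) illustrating the one-to-one mapping the port relies on:

#if defined(XGBOOST_USE_HIP)
#define gpuMalloc             hipMalloc
#define gpuMemcpy             hipMemcpy
#define gpuFree               hipFree
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
#else
#define gpuMalloc             cudaMalloc
#define gpuMemcpy             cudaMemcpy
#define gpuFree               cudaFree
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
#endif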
@@ -89,14 +112,22 @@ TEST(GPUSpan, FromOther) {
}
TEST(GPUSpan, Assignment) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestAssignment{status.Data()});
ASSERT_EQ(status.Get(), 1);
}
TEST(GPUSpan, TestStatus) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestTestStatus{status.Data()});
ASSERT_EQ(status.Get(), -1);
@@ -119,7 +150,11 @@ struct TestEqual {
};
TEST(GPUSpan, WithTrust) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
// Not advised to initialize a span with a host_vector, since h_vec.data() is
// a host function.
thrust::host_vector<float> h_vec (16);
@@ -156,14 +191,22 @@ TEST(GPUSpan, WithTrust) {
}
TEST(GPUSpan, BeginEnd) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestBeginEnd{status.Data()});
ASSERT_EQ(status.Get(), 1);
}
TEST(GPUSpan, RBeginREnd) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestRBeginREnd{status.Data()});
ASSERT_EQ(status.Get(), 1);
@@ -195,14 +238,22 @@ TEST(GPUSpan, Modify) {
}
TEST(GPUSpan, Observers) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestObservers{status.Data()});
ASSERT_EQ(status.Get(), 1);
}
TEST(GPUSpan, Compare) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestIterCompare{status.Data()});
ASSERT_EQ(status.Get(), 1);
@@ -222,7 +273,11 @@ struct TestElementAccess {
};
TEST(GPUSpanDeathTest, ElementAccess) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
auto test_element_access = []() {
thrust::host_vector<float> h_vec (16);
InitializeRange(h_vec.begin(), h_vec.end());
@@ -320,8 +375,13 @@ void TestFrontBack() {
// make sure the termination happens inside this test.
try {
dh::LaunchN(1, [=] __device__(size_t) { s.front(); });
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaDeviceSynchronize());
dh::safe_cuda(cudaGetLastError());
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipDeviceSynchronize());
dh::safe_cuda(hipGetLastError());
#endif
} catch (dmlc::Error const& e) {
std::terminate();
}
@@ -331,8 +391,13 @@ void TestFrontBack() {
{
try {
dh::LaunchN(1, [=] __device__(size_t) { s.back(); });
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaDeviceSynchronize());
dh::safe_cuda(cudaGetLastError());
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipDeviceSynchronize());
dh::safe_cuda(hipGetLastError());
#endif
} catch (dmlc::Error const& e) {
std::terminate();
}
@@ -382,42 +447,66 @@ TEST(GPUSpanDeathTest, Subspan) {
}
TEST(GPUSpanIter, Construct) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestIterConstruct{status.Data()});
ASSERT_EQ(status.Get(), 1);
}
TEST(GPUSpanIter, Ref) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestIterRef{status.Data()});
ASSERT_EQ(status.Get(), 1);
}
TEST(GPUSpanIter, Calculate) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestIterCalculate{status.Data()});
ASSERT_EQ(status.Get(), 1);
}
TEST(GPUSpanIter, Compare) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestIterCompare{status.Data()});
ASSERT_EQ(status.Get(), 1);
}
TEST(GPUSpan, AsBytes) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestAsBytes{status.Data()});
ASSERT_EQ(status.Get(), 1);
}
TEST(GPUSpan, AsWritableBytes) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(0));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(0));
#endif
TestStatus status;
dh::LaunchN(16, TestAsWritableBytes{status.Data()});
ASSERT_EQ(status.Get(), 1);

View File

@@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "test_span.cu"
#endif

View File

@@ -70,13 +70,13 @@ TEST(Stats, Median) {
auto m = out(0);
ASSERT_EQ(m, .5f);
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
ctx.gpu_id = 0;
ASSERT_FALSE(ctx.IsCPU());
Median(&ctx, values, weights, &out);
m = out(0);
ASSERT_EQ(m, .5f);
#endif // defined(XGBOOST_USE_CUDA)
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
}
{
@@ -89,12 +89,12 @@ TEST(Stats, Median) {
ASSERT_EQ(out(0), .5f);
ASSERT_EQ(out(1), .5f);
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
ctx.gpu_id = 0;
Median(&ctx, values, weights, &out);
ASSERT_EQ(out(0), .5f);
ASSERT_EQ(out(1), .5f);
#endif // defined(XGBOOST_USE_CUDA)
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
}
}
@@ -121,12 +121,12 @@ TEST(Stats, Mean) {
TestMean(&ctx);
}
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
TEST(Stats, GPUMean) {
Context ctx;
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
TestMean(&ctx);
}
#endif // defined(XGBOOST_USE_CUDA)
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
} // namespace common
} // namespace xgboost
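
Unlike the files above, the stats tests need no API renames: they never call the runtime directly, so widening each guard to include XGBOOST_USE_HIP is enough. The shared body picks the device through the Context, and Median/Mean dispatch internally; a minimal sketch mirroring the test above:

inline void RunGpuMean() {  // sketch of the backend-agnostic body
  Context ctx;
  ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});  // GPU path on CUDA or HIP builds
  TestMean(&ctx);  // same call as the CPU test; dispatch happens inside
}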

View File

@@ -0,0 +1,2 @@
#include "test_stats.cu"

View File

@@ -0,0 +1,2 @@
#include "test_threading_utils.cu"

View File

@@ -11,7 +11,7 @@
#include "../../../src/common/transform.h"
#include "../helpers.h"
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
#define TRANSFORM_GPU 0
@@ -53,7 +53,7 @@ TEST(Transform, DeclareUnifiedTest(Basic)) {
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
}
#if !defined(__CUDACC__)
#if !defined(__CUDACC__) && !defined(__HIP_PLATFORM_AMD__)
TEST(TransformDeathTest, Exception) {
size_t const kSize {16};
std::vector<bst_float> h_in(kSize);

View File

@@ -40,13 +40,13 @@ TEST(GBTree, SelectTreeMethod) {
gbtree.Configure({{"booster", "dart"}, {"tree_method", "hist"}});
ASSERT_EQ(tparam.updater_seq, "grow_quantile_histmaker");
#ifdef XGBOOST_USE_CUDA
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
gbtree.Configure({{"tree_method", "gpu_hist"}});
ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist");
gbtree.Configure({{"booster", "dart"}, {"tree_method", "gpu_hist"}});
ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist");
#endif // XGBOOST_USE_CUDA
#endif // XGBOOST_USE_CUDA || XGBOOST_USE_HIP
}
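
With the widened guard, a ROCm build exercises the same selection logic: the tree_method string "gpu_hist" and the resulting updater "grow_gpu_hist" are backend-neutral, so only the compile-time gate changes. A minimal usage sketch under that assumption, reusing p_dmat from the surrounding tests:

inline void ConfigureGpuHist(std::shared_ptr<DMatrix> p_dmat) {  // sketch
  std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
  learner->SetParam("tree_method", "gpu_hist");  // maps to grow_gpu_hist on either backend
  learner->SetParam("gpu_id", "0");
  learner->Configure();
}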
TEST(GBTree, PredictionCache) {
@@ -110,7 +110,7 @@ TEST(GBTree, WrongUpdater) {
ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
}
#ifdef XGBOOST_USE_CUDA
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
TEST(GBTree, ChoosePredictor) {
// The test ensures data don't get pulled into the device.
size_t constexpr kRows = 17;
@@ -162,7 +162,7 @@ TEST(GBTree, ChoosePredictor) {
// data is not pulled back into host
ASSERT_FALSE(data.HostCanWrite());
}
#endif // XGBOOST_USE_CUDA
#endif // XGBOOST_USE_CUDA || XGBOOST_USE_HIP
// Some other parts of the test are in "Tree.JsonIO".
TEST(GBTree, JsonIO) {
@@ -294,12 +294,12 @@ class Dart : public testing::TestWithParam<char const*> {
TEST_P(Dart, Prediction) { this->Run(GetParam()); }
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart,
testing::Values("auto", "cpu_predictor", "gpu_predictor"));
#else
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("auto", "cpu_predictor"));
#endif // defined(XGBOOST_USE_CUDA)
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
std::pair<Json, Json> TestModelSlice(std::string booster) {