Avoid thrust vector initialization. (#10544)

* Avoid thrust vector initialization.

- Add a wrapper for rmm device uvector.
- Split up the `Resize` method for HDV.
This commit is contained in:
Jiaming Yuan
2024-07-11 17:29:27 +08:00
committed by GitHub
parent 89da9f9741
commit 1ca4bfd20e
13 changed files with 510 additions and 291 deletions

View File

@@ -0,0 +1,21 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include "../../../src/common/device_vector.cuh"
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
namespace dh {
TEST(DeviceUVector, Basic) {
GlobalMemoryLogger().Clear();
std::int32_t verbosity{3};
std::swap(verbosity, xgboost::GlobalConfigThreadLocalStore::Get()->verbosity);
DeviceUVector<float> uvec;
uvec.Resize(12);
auto peak = GlobalMemoryLogger().PeakMemory();
auto n_bytes = sizeof(decltype(uvec)::value_type) * uvec.size();
ASSERT_EQ(peak, n_bytes);
std::swap(verbosity, xgboost::GlobalConfigThreadLocalStore::Get()->verbosity);
}
} // namespace dh

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2018-2023 XGBoost contributors
* Copyright 2018-2024, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <thrust/equal.h>
@@ -181,4 +181,41 @@ TEST(HostDeviceVector, Empty) {
ASSERT_FALSE(another.Empty());
ASSERT_TRUE(vec.Empty());
}
TEST(HostDeviceVector, Resize) {
auto check = [&](HostDeviceVector<float> const& vec) {
auto const& h_vec = vec.ConstHostSpan();
for (std::size_t i = 0; i < 4; ++i) {
ASSERT_EQ(h_vec[i], i + 1);
}
for (std::size_t i = 4; i < vec.Size(); ++i) {
ASSERT_EQ(h_vec[i], 3.0);
}
};
{
HostDeviceVector<float> vec{1.0f, 2.0f, 3.0f, 4.0f};
vec.SetDevice(DeviceOrd::CUDA(0));
vec.ConstDeviceSpan();
ASSERT_TRUE(vec.DeviceCanRead());
ASSERT_FALSE(vec.DeviceCanWrite());
vec.DeviceSpan();
vec.Resize(7, 3.0f);
ASSERT_TRUE(vec.DeviceCanWrite());
check(vec);
}
{
HostDeviceVector<float> vec{{1.0f, 2.0f, 3.0f, 4.0f}, DeviceOrd::CUDA(0)};
ASSERT_TRUE(vec.DeviceCanWrite());
vec.Resize(7, 3.0f);
ASSERT_TRUE(vec.DeviceCanWrite());
check(vec);
}
{
HostDeviceVector<float> vec{1.0f, 2.0f, 3.0f, 4.0f};
ASSERT_TRUE(vec.HostCanWrite());
vec.Resize(7, 3.0f);
ASSERT_TRUE(vec.HostCanWrite());
check(vec);
}
}
} // namespace xgboost::common