Add managed memory allocator. (#10711)

This commit is contained in:
Jiaming Yuan 2024-08-17 03:02:34 +08:00 committed by GitHub
parent 8d7fe262d9
commit ec3f327c20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 128 additions and 71 deletions

View File

@ -132,7 +132,7 @@ class DataIteratorProxy {
bool cache_on_host_{true}; // TODO(Bobby): Make this optional.
template <typename T>
using Alloc = xgboost::common::cuda::pinned_allocator<T>;
using Alloc = xgboost::common::cuda_impl::pinned_allocator<T>;
template <typename U>
using HostVector = std::vector<U, Alloc<U>>;

View File

@ -1,20 +1,19 @@
/*!
* Copyright 2022 by XGBoost Contributors
* \file common.h
* \brief cuda pinned allocator for usage with thrust containers
/**
* Copyright 2022-2024, XGBoost Contributors
*
* @brief cuda pinned allocator for usage with thrust containers
*/
#pragma once
#include <cstddef>
#include <limits>
#include <cuda_runtime.h>
#include <cstddef> // for size_t
#include <limits> // for numeric_limits
#include "common.h"
namespace xgboost {
namespace common {
namespace cuda {
namespace xgboost::common::cuda_impl {
// \p pinned_allocator is a CUDA-specific host memory allocator
// that employs \c cudaMallocHost for allocation.
//
@ -22,72 +21,94 @@ namespace cuda {
// that Thrust used to provide.
//
// \see https://en.cppreference.com/w/cpp/memory/allocator
template <typename T>
class pinned_allocator;
template <>
class pinned_allocator<void> {
public:
using value_type = void; // NOLINT: The type of the elements in the allocator
using pointer = void*; // NOLINT: The type returned by address() / allocate()
using const_pointer = const void*; // NOLINT: The type returned by address()
using size_type = std::size_t; // NOLINT: The type used for the size of the allocation
using difference_type = std::ptrdiff_t; // NOLINT: The type of the distance between two pointers
template <typename U>
struct rebind { // NOLINT
using other = pinned_allocator<U>; // NOLINT: The rebound type
};
};
template <typename T>
class pinned_allocator {
public:
using value_type = T; // NOLINT: The type of the elements in the allocator
struct PinnedAllocPolicy {
using pointer = T*; // NOLINT: The type returned by address() / allocate()
using const_pointer = const T*; // NOLINT: The type returned by address()
using reference = T&; // NOLINT: The parameter type for address()
using const_reference = const T&; // NOLINT: The parameter type for address()
using size_type = std::size_t; // NOLINT: The type used for the size of the allocation
using difference_type = std::ptrdiff_t; // NOLINT: The type of the distance between two pointers
using value_type = T; // NOLINT: The type of the elements in the allocator
template <typename U>
struct rebind { // NOLINT
using other = pinned_allocator<U>; // NOLINT: The rebound type
};
size_type max_size() const { // NOLINT
return std::numeric_limits<size_type>::max() / sizeof(value_type);
}
XGBOOST_DEVICE inline pinned_allocator() {}; // NOLINT: host/device markup ignored on defaulted functions
XGBOOST_DEVICE inline ~pinned_allocator() {} // NOLINT: host/device markup ignored on defaulted functions
XGBOOST_DEVICE inline pinned_allocator(pinned_allocator const&) {} // NOLINT: host/device markup ignored on defaulted functions
pinned_allocator& operator=(pinned_allocator const& that) = default;
pinned_allocator& operator=(pinned_allocator&& that) = default;
template <typename U>
XGBOOST_DEVICE inline pinned_allocator(pinned_allocator<U> const&) {} // NOLINT
XGBOOST_DEVICE inline pointer address(reference r) { return &r; } // NOLINT
XGBOOST_DEVICE inline const_pointer address(const_reference r) { return &r; } // NOLINT
inline pointer allocate(size_type cnt, const_pointer = nullptr) { // NOLINT
if (cnt > this->max_size()) { throw std::bad_alloc(); } // end if
pointer allocate(size_type cnt, const_pointer = nullptr) { // NOLINT
if (cnt > this->max_size()) {
throw std::bad_alloc{};
} // end if
pointer result(nullptr);
dh::safe_cuda(cudaMallocHost(reinterpret_cast<void**>(&result), cnt * sizeof(value_type)));
return result;
}
inline void deallocate(pointer p, size_type) { dh::safe_cuda(cudaFreeHost(p)); } // NOLINT
inline size_type max_size() const { return (std::numeric_limits<size_type>::max)() / sizeof(T); } // NOLINT
XGBOOST_DEVICE inline bool operator==(pinned_allocator const& x) const { return true; }
XGBOOST_DEVICE inline bool operator!=(pinned_allocator const& x) const {
return !operator==(x);
}
void deallocate(pointer p, size_type) { dh::safe_cuda(cudaFreeHost(p)); } // NOLINT
};
} // namespace cuda
} // namespace common
} // namespace xgboost
template <typename T>
struct ManagedAllocPolicy {
using pointer = T*; // NOLINT: The type returned by address() / allocate()
using const_pointer = const T*; // NOLINT: The type returned by address()
using size_type = std::size_t; // NOLINT: The type used for the size of the allocation
using value_type = T; // NOLINT: The type of the elements in the allocator
size_type max_size() const { // NOLINT
return std::numeric_limits<size_type>::max() / sizeof(value_type);
}
pointer allocate(size_type cnt, const_pointer = nullptr) { // NOLINT
if (cnt > this->max_size()) {
throw std::bad_alloc{};
} // end if
pointer result(nullptr);
dh::safe_cuda(cudaMallocManaged(reinterpret_cast<void**>(&result), cnt * sizeof(value_type)));
return result;
}
void deallocate(pointer p, size_type) { dh::safe_cuda(cudaFree(p)); } // NOLINT
};
template <typename T, template <typename> typename Policy>
class CudaHostAllocatorImpl : public Policy<T> { // NOLINT
public:
using value_type = typename Policy<T>::value_type; // NOLINT
using pointer = typename Policy<T>::pointer; // NOLINT
using const_pointer = typename Policy<T>::const_pointer; // NOLINT
using size_type = typename Policy<T>::size_type; // NOLINT
using reference = T&; // NOLINT: The parameter type for address()
using const_reference = const T&; // NOLINT: The parameter type for address()
using difference_type = std::ptrdiff_t; // NOLINT: The type of the distance between two pointers
template <typename U>
struct rebind { // NOLINT
using other = CudaHostAllocatorImpl<U, Policy>; // NOLINT: The rebound type
};
CudaHostAllocatorImpl() = default;
~CudaHostAllocatorImpl() = default;
CudaHostAllocatorImpl(CudaHostAllocatorImpl const&) = default;
CudaHostAllocatorImpl& operator=(CudaHostAllocatorImpl const& that) = default;
CudaHostAllocatorImpl& operator=(CudaHostAllocatorImpl&& that) = default;
template <typename U>
CudaHostAllocatorImpl(CudaHostAllocatorImpl<U, Policy> const&) {} // NOLINT
pointer address(reference r) { return &r; } // NOLINT
const_pointer address(const_reference r) { return &r; } // NOLINT
bool operator==(CudaHostAllocatorImpl const& x) const { return true; }
bool operator!=(CudaHostAllocatorImpl const& x) const { return !operator==(x); }
};
template <typename T>
using pinned_allocator = CudaHostAllocatorImpl<T, PinnedAllocPolicy>; // NOLINT
template <typename T>
using managed_allocator = CudaHostAllocatorImpl<T, ManagedAllocPolicy>; // NOLINT
} // namespace xgboost::common::cuda_impl

View File

@ -20,7 +20,7 @@
namespace xgboost::data {
struct EllpackHostCache {
thrust::host_vector<std::int8_t, common::cuda::pinned_allocator<std::int8_t>> cache;
thrust::host_vector<std::int8_t, common::cuda_impl::pinned_allocator<std::int8_t>> cache;
void Resize(std::size_t n, dh::CUDAStreamView stream) {
stream.Sync(); // Prevent partial copy inside resize.

View File

@ -57,7 +57,7 @@ struct CatAccessor {
class GPUHistEvaluator {
using CatST = common::CatBitField::value_type; // categorical storage type
// use pinned memory to stage the categories, used for sort based splits.
using Alloc = xgboost::common::cuda::pinned_allocator<CatST>;
using Alloc = xgboost::common::cuda_impl::pinned_allocator<CatST>;
private:
TreeEvaluator tree_evaluator_;

View File

@ -0,0 +1,36 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h> // for Context
#include <vector>
#include "../../../src/common/cuda_pinned_allocator.h"
#include "../../../src/common/device_helpers.cuh" // for DefaultStream
#include "../../../src/common/numeric.h" // for Iota
namespace xgboost {
TEST(CudaHostMalloc, Pinned) {
std::vector<float, common::cuda_impl::pinned_allocator<float>> vec;
vec.resize(10);
ASSERT_EQ(vec.size(), 10);
Context ctx;
common::Iota(&ctx, vec.begin(), vec.end(), 0);
float k = 0;
for (auto v : vec) {
ASSERT_EQ(v, k);
++k;
}
}
TEST(CudaHostMalloc, Managed) {
std::vector<float, common::cuda_impl::managed_allocator<float>> vec;
vec.resize(10);
#if defined(__linux__)
dh::safe_cuda(
cudaMemPrefetchAsync(vec.data(), vec.size() * sizeof(float), 0, dh::DefaultStream()));
#endif
dh::DefaultStream().Sync();
}
} // namespace xgboost