Define CUDA Context. (#8604)

We will transition to a non-default and non-blocking CUDA stream.
Jiaming Yuan 2022-12-20 15:15:07 +08:00 committed by GitHub
parent e01639548a
commit c6a8754c62
11 changed files with 120 additions and 62 deletions
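The description above talks about moving work onto a non-default, non-blocking CUDA stream. In plain CUDA terms that means creating a stream with the cudaStreamNonBlocking flag, so work queued on it never synchronizes implicitly with the legacy default (NULL) stream. The sketch below only illustrates that idea; the real dh::DefaultStream() used throughout this diff lives in device_helpers.cuh and its exact implementation is not shown here.

#include <cuda_runtime.h>

// Illustrative only: a process-wide, lazily created non-blocking stream.
inline cudaStream_t DefaultStreamSketch() {
  static cudaStream_t stream = [] {
    cudaStream_t s;
    // cudaStreamNonBlocking: work on this stream does not implicitly
    // synchronize with the legacy default stream.
    cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
    return s;
  }();
  return stream;
}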

include/xgboost/context.h

@@ -8,15 +8,14 @@
 #include <xgboost/logging.h>
 #include <xgboost/parameter.h>

+#include <memory>  // std::shared_ptr
 #include <string>

 namespace xgboost {

-struct Context : public XGBoostParameter<Context> {
- private:
-  // cached value for CFS CPU limit. (used in containerized env)
-  std::int32_t cfs_cpu_count_;  // NOLINT
+struct CUDAContext;
+
+struct Context : public XGBoostParameter<Context> {
  public:
   // Constant representing the device ID of CPU.
   static std::int32_t constexpr kCpuId = -1;
@@ -51,6 +50,7 @@ struct Context : public XGBoostParameter<Context> {
   bool IsCPU() const { return gpu_id == kCpuId; }
   bool IsCUDA() const { return !IsCPU(); }
+  CUDAContext const* CUDACtx() const;

   // declare parameters
   DMLC_DECLARE_PARAMETER(Context) {
@@ -73,6 +73,14 @@ struct Context : public XGBoostParameter<Context> {
         .set_default(false)
         .describe("Enable checking whether parameters are used or not.");
   }

+ private:
+  // mutable for lazy initialization for cuda context to avoid initializing CUDA at load.
+  // shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define p_impl
+  // while trying to hide CUDA code from host compiler.
+  mutable std::shared_ptr<CUDAContext> cuctx_;
+  // cached value for CFS CPU limit. (used in containerized env)
+  std::int32_t cfs_cpu_count_;  // NOLINT
 };
 }  // namespace xgboost

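The comment on cuctx_ above relies on a general C++ property: std::shared_ptr type-erases its deleter when the pointer is created, so a translation unit that only sees the forward declaration struct CUDAContext; can still destroy a Context, whereas a std::unique_ptr<CUDAContext> member would require the complete type wherever Context's destructor is instantiated. A minimal sketch of that pimpl pattern, with hypothetical names and shown as a single file for brevity:

#include <memory>

// --- the part a host-only header would contain ---
struct Impl;  // forward declaration; the full definition could live in a .cu file

class Widget {
 public:
  Widget();
  // No user-declared destructor: shared_ptr's deleter was type-erased at
  // construction, so the implicit ~Widget() compiles even where Impl is
  // incomplete. A unique_ptr<Impl> member would force ~Widget() to be
  // defined in a translation unit that sees the full Impl definition.
 private:
  std::shared_ptr<Impl> pimpl_;
};

// --- the part the implementation file (e.g. a CUDA TU) would contain ---
struct Impl {
  int value{42};
};

Widget::Widget() : pimpl_{std::make_shared<Impl>()} {}

int main() { Widget w; }  // Widget is usable and destructible through the header-style interface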
src/common/cuda_context.cuh (new file)

@@ -0,0 +1,28 @@
/**
 * Copyright 2022 by XGBoost Contributors
 */
#ifndef XGBOOST_COMMON_CUDA_CONTEXT_CUH_
#define XGBOOST_COMMON_CUDA_CONTEXT_CUH_
#include <thrust/execution_policy.h>

#include "device_helpers.cuh"

namespace xgboost {
struct CUDAContext {
 private:
  dh::XGBCachingDeviceAllocator<char> caching_alloc_;
  dh::XGBDeviceAllocator<char> alloc_;

 public:
  /**
   * \brief Caching thrust policy.
   */
  auto CTP() const { return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream()); }
  /**
   * \brief Thrust policy without caching allocator.
   */
  auto TP() const { return thrust::cuda::par(alloc_).on(dh::DefaultStream()); }
  auto Stream() const { return dh::DefaultStream(); }
};
}  // namespace xgboost
#endif  // XGBOOST_COMMON_CUDA_CONTEXT_CUH_

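A rough usage sketch for the new CUDAContext (illustration only: the helper function and buffer names are hypothetical, while Context::CUDACtx(), CTP() and Stream() are the members introduced in this commit):

#include <thrust/reduce.h>

#include "common/cuda_context.cuh"  // CUDAContext
#include "xgboost/context.h"        // Context

namespace xgboost {
// Hypothetical helper: sum a device buffer using the caching-allocator policy,
// so the reduction runs on the context's stream and temporary storage is pooled.
double SumOnCtxStream(Context const* ctx, double const* d_begin, double const* d_end) {
  auto const* cuctx = ctx->CUDACtx();  // lazily constructs the CUDAContext
  return thrust::reduce(cuctx->CTP(), d_begin, d_end, 0.0);
}
}  // namespace xgboost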
src/context.cc

@@ -5,7 +5,7 @@
  */
 #include <xgboost/context.h>

-#include "common/common.h"
+#include "common/common.h"  // AssertGPUSupport
 #include "common/threading_utils.h"

 namespace xgboost {
@@ -59,4 +59,11 @@ std::int32_t Context::Threads() const {
   }
   return n_threads;
 }
+
+#if !defined(XGBOOST_USE_CUDA)
+CUDAContext const* Context::CUDACtx() const {
+  common::AssertGPUSupport();
+  return nullptr;
+}
+#endif  // defined(XGBOOST_USE_CUDA)
 }  // namespace xgboost

src/context.cu (new file)

@@ -0,0 +1,14 @@
/**
 * Copyright 2022 by XGBoost Contributors
 */
#include "common/cuda_context.cuh"  // CUDAContext
#include "xgboost/context.h"

namespace xgboost {
CUDAContext const* Context::CUDACtx() const {
  if (!cuctx_) {
    cuctx_.reset(new CUDAContext{});
  }
  return cuctx_.get();
}
}  // namespace xgboost

src/data/data.cu

@@ -1,18 +1,19 @@
-/*!
- * Copyright 2019-2021 by XGBoost Contributors
+/**
+ * Copyright 2019-2022 by XGBoost Contributors
  *
  * \file data.cu
  * \brief Handles setting metainfo from array interface.
  */
-#include "xgboost/data.h"
-#include "xgboost/logging.h"
-#include "xgboost/json.h"
-#include "array_interface.h"
+#include "../common/cuda_context.cuh"
 #include "../common/device_helpers.cuh"
 #include "../common/linalg_op.cuh"
+#include "array_interface.h"
 #include "device_adapter.cuh"
 #include "simple_dmatrix.h"
 #include "validation.h"
+#include "xgboost/data.h"
+#include "xgboost/json.h"
+#include "xgboost/logging.h"

 namespace xgboost {
 namespace {
@@ -25,7 +26,7 @@ auto SetDeviceToPtr(void const* ptr) {
 }

 template <typename T, int32_t D>
-void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
+void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
   ArrayInterface<D> array(arr_interface);
   if (array.n == 0) {
     p_out->SetDevice(0);
@@ -43,15 +44,19 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
       // set data
       data->Resize(array.n);
       dh::safe_cuda(cudaMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T),
-                                    cudaMemcpyDefault));
+                                    cudaMemcpyDefault, ctx->Stream()));
     });
     return;
   }
   p_out->Reshape(array.shape);
   auto t = p_out->View(ptr_device);
-  linalg::ElementWiseTransformDevice(t, [=] __device__(size_t i, T) {
-    return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, array.shape));
-  });
+  linalg::ElementWiseTransformDevice(
+      t,
+      [=] __device__(size_t i, T) {
+        return linalg::detail::Apply(TypedIndex<T, D>{array},
+                                     linalg::UnravelIndex<D>(i, array.shape));
+      },
+      ctx->Stream());
 }

@@ -115,14 +120,13 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
   }
 }  // namespace

-// Context is not used until we have CUDA stream.
-void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
+void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) {
   // multi-dim float info
   if (key == "base_margin") {
-    CopyTensorInfoImpl(array, &base_margin_);
+    CopyTensorInfoImpl(ctx.CUDACtx(), array, &base_margin_);
     return;
   } else if (key == "label") {
-    CopyTensorInfoImpl(array, &labels);
+    CopyTensorInfoImpl(ctx.CUDACtx(), array, &labels);
     auto ptr = labels.Data()->ConstDevicePointer();
     auto valid = thrust::none_of(thrust::device, ptr, ptr + labels.Size(), data::LabelsCheck{});
     CHECK(valid) << "Label contains NaN, infinity or a value too large.";
@@ -142,7 +146,7 @@ void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
   }
   // float info
   linalg::Tensor<float, 1> t;
-  CopyTensorInfoImpl(array, &t);
+  CopyTensorInfoImpl(ctx.CUDACtx(), array, &t);
   if (key == "weight") {
     this->weights_ = std::move(*t.Data());
     auto ptr = weights_.ConstDevicePointer();
@@ -156,7 +160,7 @@ void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
     this->feature_weights = std::move(*t.Data());
     auto d_feature_weights = feature_weights.ConstDeviceSpan();
     auto valid =
-        thrust::none_of(thrust::device, d_feature_weights.data(),
+        thrust::none_of(ctx.CUDACtx()->CTP(), d_feature_weights.data(),
                         d_feature_weights.data() + d_feature_weights.size(), data::WeightsCheck{});
     CHECK(valid) << "Feature weight must be greater than 0.";
   } else {

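One consequence of enqueuing the copies above on ctx->Stream() instead of the legacy default stream is that any consumer outside that stream has to synchronize with it before reading the destination. A minimal, self-contained sketch of the pattern (hypothetical buffers, not XGBoost code):

#include <cuda_runtime.h>
#include <vector>

void CopyToDeviceOnStream(std::vector<float> const& host, float* d_dst, cudaStream_t stream) {
  // Asynchronous copy, ordered only with respect to `stream`.
  cudaMemcpyAsync(d_dst, host.data(), host.size() * sizeof(float),
                  cudaMemcpyDefault, stream);
  // Wait for the copy before the host (or another stream) consumes d_dst.
  cudaStreamSynchronize(stream);
}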
src/learner.cc

@@ -35,7 +35,7 @@
 #include "common/version.h"
 #include "xgboost/base.h"
 #include "xgboost/c_api.h"
-#include "xgboost/context.h"
+#include "xgboost/context.h"  // Context
 #include "xgboost/data.h"
 #include "xgboost/feature_map.h"
 #include "xgboost/gbm.h"

src/tree/gpu_hist/histogram.cu

@@ -267,12 +267,12 @@ __global__ void __launch_bounds__(kBlockThreads)
   }
 }

-void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
+void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
                             FeatureGroupsAccessor const& feature_groups,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> d_ridx,
-                            common::Span<GradientPairInt64> histogram,
-                            GradientQuantiser rounding, bool force_global_memory) {
+                            common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
+                            bool force_global_memory) {
   // decide whether to use shared memory
   int device = 0;
   dh::safe_cuda(cudaGetDevice(&device));
@@ -318,9 +318,9 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
         min(grid_size,
             unsigned(common::DivRoundUp(items_per_group, kMinItemsPerBlock)));
-    dh::LaunchKernel {dim3(grid_size, num_groups),
-                      static_cast<uint32_t>(kBlockThreads), smem_size}(
-        kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding);
+    dh::LaunchKernel{dim3(grid_size, num_groups), static_cast<uint32_t>(kBlockThreads), smem_size,
+                     ctx->Stream()}(kernel, matrix, feature_groups, d_ridx, histogram.data(),
+                                    gpair.data(), rounding);
   };

   if (shared) {

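For readers who have not seen dh::LaunchKernel, the extra ctx->Stream() argument plays the role of the fourth launch-configuration parameter of an ordinary CUDA kernel launch. A standalone sketch with a hypothetical kernel (not XGBoost code):

#include <cuda_runtime.h>

__global__ void ScaleKernel(float* data, int n, float factor) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

void LaunchScale(float* d_data, int n, float factor, cudaStream_t stream) {
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  size_t smem = 0;  // dynamic shared memory, analogous to smem_size above
  // The fourth launch parameter selects the stream the kernel is enqueued on,
  // which is presumably what passing ctx->Stream() to dh::LaunchKernel accomplishes.
  ScaleKernel<<<grid, block, smem, stream>>>(d_data, n, factor);
}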
src/tree/gpu_hist/histogram.cuh

@@ -5,9 +5,9 @@
 #define HISTOGRAM_CUH_
 #include <thrust/transform.h>

-#include "feature_groups.cuh"
+#include "../../common/cuda_context.cuh"
 #include "../../data/ellpack_page.cuh"
+#include "feature_groups.cuh"

 namespace xgboost {
 namespace tree {
@@ -56,12 +56,11 @@ public:
   }
 };

-void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
+void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
                             FeatureGroupsAccessor const& feature_groups,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> ridx,
-                            common::Span<GradientPairInt64> histogram,
-                            GradientQuantiser rounding,
+                            common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
                             bool force_global_memory = false);
 }  // namespace tree
 }  // namespace xgboost

src/tree/updater_gpu_hist.cu

@@ -20,6 +20,7 @@
 #include "../common/io.h"
 #include "../common/timer.h"
 #include "../data/ellpack_page.cuh"
+#include "../common/cuda_context.cuh"  // CUDAContext
 #include "constraints.cuh"
 #include "driver.h"
 #include "gpu_hist/evaluate_splits.cuh"
@@ -344,9 +345,9 @@ struct GPUHistMakerDevice {
   void BuildHist(int nidx) {
     auto d_node_hist = hist.GetNodeHistogram(nidx);
     auto d_ridx = row_partitioner->GetRows(nidx);
-    BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id),
-                           feature_groups->DeviceAccessor(ctx_->gpu_id), gpair,
-                           d_ridx, d_node_hist, *quantiser);
+    BuildGradientHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->gpu_id),
+                           feature_groups->DeviceAccessor(ctx_->gpu_id), gpair, d_ridx, d_node_hist,
+                           *quantiser);
   }

   // Attempt to do subtraction trick
@@ -646,7 +647,7 @@ struct GPUHistMakerDevice {
       return quantiser.ToFixedPoint(gpair);
     });
     GradientPairInt64 root_sum_quantised =
-        dh::Reduce(thrust::cuda::par(alloc), gpair_it, gpair_it + gpair.size(),
+        dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(),
                    GradientPairInt64{}, thrust::plus<GradientPairInt64>{});
     using ReduceT = typename decltype(root_sum_quantised)::ValueT;
     collective::Allreduce<collective::Operation::kSum>(

tests/cpp/tree/gpu_hist/test_histogram.cu

@@ -11,6 +11,7 @@ namespace xgboost {
 namespace tree {

 void TestDeterministicHistogram(bool is_dense, int shm_size) {
+  Context ctx = CreateEmptyGenericParam(0);
   size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
   float constexpr kLower = -1e-2, kUpper = 1e2;
@@ -34,9 +35,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
                     sizeof(GradientPairInt64));

     auto quantiser = GradientQuantiser(gpair.DeviceSpan());
-    BuildGradientHistogram(page->GetDeviceAccessor(0),
-                           feature_groups.DeviceAccessor(0), gpair.DeviceSpan(),
-                           ridx, d_histogram, quantiser);
+    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                           feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), ridx, d_histogram,
+                           quantiser);

     std::vector<GradientPairInt64> histogram_h(num_bins);
     dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
@@ -48,10 +49,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
       auto d_new_histogram = dh::ToSpan(new_histogram);

       auto quantiser = GradientQuantiser(gpair.DeviceSpan());
-      BuildGradientHistogram(page->GetDeviceAccessor(0),
-                             feature_groups.DeviceAccessor(0),
-                             gpair.DeviceSpan(), ridx, d_new_histogram,
-                             quantiser);
+      BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                             feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
+                             d_new_histogram, quantiser);

       std::vector<GradientPairInt64> new_histogram_h(num_bins);
       dh::safe_cuda(cudaMemcpy(new_histogram_h.data(), d_new_histogram.data(),
@@ -71,10 +71,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
       FeatureGroups single_group(page->Cuts());

       dh::device_vector<GradientPairInt64> baseline(num_bins);
-      BuildGradientHistogram(page->GetDeviceAccessor(0),
-                             single_group.DeviceAccessor(0),
-                             gpair.DeviceSpan(), ridx, dh::ToSpan(baseline),
-                             quantiser);
+      BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                             single_group.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
+                             dh::ToSpan(baseline), quantiser);

       std::vector<GradientPairInt64> baseline_h(num_bins);
       dh::safe_cuda(cudaMemcpy(baseline_h.data(), baseline.data().get(),
@@ -115,6 +114,7 @@ void ValidateCategoricalHistogram(size_t n_categories, common::Span<GradientPair
 // Test 1 vs rest categorical histogram is equivalent to one hot encoded data.
 void TestGPUHistogramCategorical(size_t num_categories) {
+  auto ctx = CreateEmptyGenericParam(0);
   size_t constexpr kRows = 340;
   size_t constexpr kBins = 256;
   auto x = GenerateRandomCategoricalSingleColumn(kRows, num_categories);
@@ -133,10 +133,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   for (auto const &batch : cat_m->GetBatches<EllpackPage>(batch_param)) {
     auto* page = batch.Impl();
     FeatureGroups single_group(page->Cuts());
-    BuildGradientHistogram(page->GetDeviceAccessor(0),
-                           single_group.DeviceAccessor(0),
-                           gpair.DeviceSpan(), ridx, dh::ToSpan(cat_hist),
-                           quantiser);
+    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                           single_group.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
+                           dh::ToSpan(cat_hist), quantiser);
   }

   /**
@@ -148,10 +147,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   for (auto const &batch : encode_m->GetBatches<EllpackPage>(batch_param)) {
     auto* page = batch.Impl();
     FeatureGroups single_group(page->Cuts());
-    BuildGradientHistogram(page->GetDeviceAccessor(0),
-                           single_group.DeviceAccessor(0),
-                           gpair.DeviceSpan(), ridx, dh::ToSpan(encode_hist),
-                           quantiser);
+    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                           single_group.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
+                           dh::ToSpan(encode_hist), quantiser);
   }

   std::vector<GradientPairInt64> h_cat_hist(cat_hist.size());

tests/cpp/tree/test_gpu_hist.cu

@@ -109,11 +109,10 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   maker.gpair = gpair.DeviceSpan();
   maker.quantiser.reset(new GradientQuantiser(maker.gpair));

-  BuildGradientHistogram(
-      page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0),
-      gpair.DeviceSpan(), maker.row_partitioner->GetRows(0),
-      maker.hist.GetNodeHistogram(0), *maker.quantiser,
-      !use_shared_memory_histograms);
+  BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                         maker.feature_groups->DeviceAccessor(0), gpair.DeviceSpan(),
+                         maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0),
+                         *maker.quantiser, !use_shared_memory_histograms);

   DeviceHistogramStorage<>& d_hist = maker.hist;