Define CUDA Context. (#8604)
We will transition to non-default and non-blocking CUDA stream.
This commit is contained in:
@@ -1,18 +1,19 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2019-2022 by XGBoost Contributors
|
||||
*
|
||||
* \file data.cu
|
||||
* \brief Handles setting metainfo from array interface.
|
||||
*/
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "array_interface.h"
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/linalg_op.cuh"
|
||||
#include "array_interface.h"
|
||||
#include "device_adapter.cuh"
|
||||
#include "simple_dmatrix.h"
|
||||
#include "validation.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace {
|
||||
@@ -25,7 +26,7 @@ auto SetDeviceToPtr(void const* ptr) {
|
||||
}
|
||||
|
||||
template <typename T, int32_t D>
|
||||
void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
|
||||
void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
|
||||
ArrayInterface<D> array(arr_interface);
|
||||
if (array.n == 0) {
|
||||
p_out->SetDevice(0);
|
||||
@@ -43,15 +44,19 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
|
||||
// set data
|
||||
data->Resize(array.n);
|
||||
dh::safe_cuda(cudaMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T),
|
||||
cudaMemcpyDefault));
|
||||
cudaMemcpyDefault, ctx->Stream()));
|
||||
});
|
||||
return;
|
||||
}
|
||||
p_out->Reshape(array.shape);
|
||||
auto t = p_out->View(ptr_device);
|
||||
linalg::ElementWiseTransformDevice(t, [=] __device__(size_t i, T) {
|
||||
return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, array.shape));
|
||||
});
|
||||
linalg::ElementWiseTransformDevice(
|
||||
t,
|
||||
[=] __device__(size_t i, T) {
|
||||
return linalg::detail::Apply(TypedIndex<T, D>{array},
|
||||
linalg::UnravelIndex<D>(i, array.shape));
|
||||
},
|
||||
ctx->Stream());
|
||||
}
|
||||
|
||||
void CopyGroupInfoImpl(ArrayInterface<1> column, std::vector<bst_group_t>* out) {
|
||||
@@ -115,14 +120,13 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Context is not used until we have CUDA stream.
|
||||
void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
|
||||
void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) {
|
||||
// multi-dim float info
|
||||
if (key == "base_margin") {
|
||||
CopyTensorInfoImpl(array, &base_margin_);
|
||||
CopyTensorInfoImpl(ctx.CUDACtx(), array, &base_margin_);
|
||||
return;
|
||||
} else if (key == "label") {
|
||||
CopyTensorInfoImpl(array, &labels);
|
||||
CopyTensorInfoImpl(ctx.CUDACtx(), array, &labels);
|
||||
auto ptr = labels.Data()->ConstDevicePointer();
|
||||
auto valid = thrust::none_of(thrust::device, ptr, ptr + labels.Size(), data::LabelsCheck{});
|
||||
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
|
||||
@@ -142,7 +146,7 @@ void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
|
||||
}
|
||||
// float info
|
||||
linalg::Tensor<float, 1> t;
|
||||
CopyTensorInfoImpl(array, &t);
|
||||
CopyTensorInfoImpl(ctx.CUDACtx(), array, &t);
|
||||
if (key == "weight") {
|
||||
this->weights_ = std::move(*t.Data());
|
||||
auto ptr = weights_.ConstDevicePointer();
|
||||
@@ -156,7 +160,7 @@ void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
|
||||
this->feature_weights = std::move(*t.Data());
|
||||
auto d_feature_weights = feature_weights.ConstDeviceSpan();
|
||||
auto valid =
|
||||
thrust::none_of(thrust::device, d_feature_weights.data(),
|
||||
thrust::none_of(ctx.CUDACtx()->CTP(), d_feature_weights.data(),
|
||||
d_feature_weights.data() + d_feature_weights.size(), data::WeightsCheck{});
|
||||
CHECK(valid) << "Feature weight must be greater than 0.";
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user