From ddf2e688219c0e5510a6931b7dac37e34ac14762 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 29 Aug 2023 13:37:29 +0800
Subject: [PATCH] Use the new `DeviceOrd` in the linalg module. (#9527)

---
 include/xgboost/host_device_vector.h       |  8 +++
 include/xgboost/learner.h                  |  2 +-
 include/xgboost/linalg.h                   | 73 +++++++---------------
 include/xgboost/multi_target_tree_model.h  |  4 +-
 src/c_api/c_api.cu                         |  2 +-
 src/common/linalg_op.cuh                   |  2 +-
 src/common/ranking_utils.cu                | 16 ++---
 src/common/ranking_utils.h                 | 30 ++++-----
 src/common/stats.cc                        |  6 +-
 src/data/data.cc                           | 11 ++--
 src/data/data.cu                           |  2 +-
 src/gbm/gblinear.cc                        | 13 ++--
 src/gbm/gbtree.cc                          |  6 +-
 src/learner.cc                             | 16 ++---
 src/metric/auc.cc                          | 28 ++++-----
 src/metric/auc.cu                          | 34 +++++-----
 src/metric/auc.h                           |  4 +-
 src/metric/elementwise_metric.cu           | 22 +++----
 src/metric/rank_metric.cc                  |  4 +-
 src/metric/rank_metric.cu                  |  8 +--
 src/objective/adaptive.cu                  | 14 ++---
 src/objective/lambdarank_obj.cc            | 26 ++++----
 src/objective/lambdarank_obj.cu            | 38 +++++------
 src/objective/quantile_obj.cu              |  4 +-
 src/objective/regression_obj.cu            |  6 +-
 src/predictor/cpu_predictor.cc             |  8 +--
 src/predictor/predictor.cc                 |  2 +-
 src/tree/fit_stump.cc                      |  2 +-
 src/tree/fit_stump.cu                      |  4 +-
 src/tree/hist/evaluate_splits.h            |  4 +-
 src/tree/updater_gpu_hist.cu               |  5 +-
 tests/cpp/common/test_linalg.cc            | 46 +++++++-------
 tests/cpp/common/test_linalg.cu            | 11 ++--
 tests/cpp/common/test_ranking_utils.cu     |  2 +-
 tests/cpp/common/test_stats.cu             | 14 ++---
 tests/cpp/data/test_array_interface.cc     |  2 +-
 tests/cpp/data/test_metainfo.cc            | 10 +--
 tests/cpp/data/test_metainfo.cu            |  4 +-
 tests/cpp/data/test_metainfo.h             | 12 ++--
 tests/cpp/data/test_simple_dmatrix.cc      |  8 +--
 tests/cpp/objective/test_lambdarank_obj.cu |  8 +--
 tests/cpp/predictor/test_gpu_predictor.cu  |  2 +-
 tests/cpp/predictor/test_predictor.cc      |  2 +-
 43 files changed, 252 insertions(+), 273 deletions(-)

diff --git a/include/xgboost/host_device_vector.h b/include/xgboost/host_device_vector.h
index b221d7206..ed7117d65 100644
--- a/include/xgboost/host_device_vector.h
+++ b/include/xgboost/host_device_vector.h
@@ -102,6 +102,14 @@ class HostDeviceVector {
   bool Empty() const { return Size() == 0; }
   size_t Size() const;
   int DeviceIdx() const;
+  DeviceOrd Device() const {
+    auto idx = this->DeviceIdx();
+    if (idx == DeviceOrd::CPU().ordinal) {
+      return DeviceOrd::CPU();
+    } else {
+      return DeviceOrd::CUDA(idx);
+    }
+  }
   common::Span<T> DeviceSpan();
   common::Span<const T> ConstDeviceSpan() const;
   common::Span<const T> DeviceSpan() const { return ConstDeviceSpan(); }
diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index cd081a2e8..939324e4a 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -330,7 +330,7 @@ struct LearnerModelParam {
         multi_strategy{multi_strategy} {}

   linalg::TensorView<float const, 1> BaseScore(Context const* ctx) const;
-  [[nodiscard]] linalg::TensorView<float const, 1> BaseScore(std::int32_t device) const;
+  [[nodiscard]] linalg::TensorView<float const, 1> BaseScore(DeviceOrd device) const;

   void Copy(LearnerModelParam const& that);
   [[nodiscard]] bool IsVectorLeaf() const noexcept {
diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h
index d95651ca7..b3ae2f169 100644
--- a/include/xgboost/linalg.h
+++ b/include/xgboost/linalg.h
@@ -302,7 +302,7 @@ class TensorView {

   T *ptr_{nullptr};  // pointer of data_ to avoid bound check.
   size_t size_{0};
-  int32_t device_{-1};
+  DeviceOrd device_;

   // Unlike `Tensor`, the data_ can have arbitrary size since this is just a view.
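  // A minimal usage sketch of the `DeviceOrd` type that replaces raw ordinals in
  // this hunk. Illustration only; it relies solely on the helpers visible in this
  // patch (`DeviceOrd::CPU()`, `DeviceOrd::CUDA(ordinal)`, `IsCUDA()`, and the
  // public `ordinal` member), not on any API beyond them:
  //
  //   DeviceOrd dev = DeviceOrd::CUDA(0);  // first CUDA device, dev.ordinal == 0
  //   if (dev.IsCUDA()) { /* take the device code path */ }
  //   DeviceOrd host = DeviceOrd::CPU();   // CPU().ordinal is the old kCpuId (-1)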
LINALG_HD void CalcSize() { @@ -401,15 +401,11 @@ class TensorView { * \param device Device ordinal */ template - LINALG_HD TensorView(common::Span data, I const (&shape)[D], std::int32_t device) + LINALG_HD TensorView(common::Span data, I const (&shape)[D], DeviceOrd device) : TensorView{data, shape, device, Order::kC} {} - template - LINALG_HD TensorView(common::Span data, I const (&shape)[D], DeviceOrd device) - : TensorView{data, shape, device.ordinal, Order::kC} {} - template - LINALG_HD TensorView(common::Span data, I const (&shape)[D], std::int32_t device, Order order) + LINALG_HD TensorView(common::Span data, I const (&shape)[D], DeviceOrd device, Order order) : data_{data}, ptr_{data_.data()}, device_{device} { static_assert(D > 0 && D <= kDim, "Invalid shape."); // shape @@ -441,7 +437,7 @@ class TensorView { */ template LINALG_HD TensorView(common::Span data, I const (&shape)[D], I const (&stride)[D], - std::int32_t device) + DeviceOrd device) : data_{data}, ptr_{data_.data()}, device_{device} { static_assert(D == kDim, "Invalid shape & stride."); detail::UnrollLoop([&](auto i) { @@ -450,16 +446,12 @@ class TensorView { }); this->CalcSize(); } - template - LINALG_HD TensorView(common::Span data, I const (&shape)[D], I const (&stride)[D], - DeviceOrd device) - : TensorView{data, shape, stride, device.ordinal} {} template < typename U, std::enable_if_t::value> * = nullptr> LINALG_HD TensorView(TensorView const &that) // NOLINT - : data_{that.Values()}, ptr_{data_.data()}, size_{that.Size()}, device_{that.DeviceIdx()} { + : data_{that.Values()}, ptr_{data_.data()}, size_{that.Size()}, device_{that.Device()} { detail::UnrollLoop([&](auto i) { stride_[i] = that.Stride(i); shape_[i] = that.Shape(i); @@ -572,7 +564,7 @@ class TensorView { /** * \brief Obtain the CUDA device ordinal. */ - LINALG_HD auto DeviceIdx() const { return device_; } + LINALG_HD auto Device() const { return device_; } }; /** @@ -587,11 +579,11 @@ auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOL typename Container::value_type>; std::size_t in_shape[sizeof...(S)]; detail::IndexToArr(in_shape, std::forward(shape)...); - return TensorView{data, in_shape, ctx->gpu_id}; + return TensorView{data, in_shape, ctx->Device()}; } template -LINALG_HD auto MakeTensorView(std::int32_t device, common::Span data, S &&...shape) { +LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span data, S &&...shape) { std::size_t in_shape[sizeof...(S)]; detail::IndexToArr(in_shape, std::forward(shape)...); return TensorView{data, in_shape, device}; @@ -599,26 +591,26 @@ LINALG_HD auto MakeTensorView(std::int32_t device, common::Span data, S &&... template auto MakeTensorView(Context const *ctx, common::Span data, S &&...shape) { - return MakeTensorView(ctx->gpu_id, data, std::forward(shape)...); + return MakeTensorView(ctx->Device(), data, std::forward(shape)...); } template auto MakeTensorView(Context const *ctx, Order order, common::Span data, S &&...shape) { std::size_t in_shape[sizeof...(S)]; detail::IndexToArr(in_shape, std::forward(shape)...); - return TensorView{data, in_shape, ctx->Ordinal(), order}; + return TensorView{data, in_shape, ctx->Device(), order}; } template auto MakeTensorView(Context const *ctx, HostDeviceVector *data, S &&...shape) { auto span = ctx->IsCPU() ? 
data->HostSpan() : data->DeviceSpan(); - return MakeTensorView(ctx->gpu_id, span, std::forward(shape)...); + return MakeTensorView(ctx->Device(), span, std::forward(shape)...); } template auto MakeTensorView(Context const *ctx, HostDeviceVector const *data, S &&...shape) { auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan(); - return MakeTensorView(ctx->gpu_id, span, std::forward(shape)...); + return MakeTensorView(ctx->Device(), span, std::forward(shape)...); } /** @@ -661,20 +653,20 @@ using VectorView = TensorView; * \param device (optional) Device ordinal, default to be host. */ template -auto MakeVec(T *ptr, size_t s, int32_t device = -1) { +auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) { return linalg::TensorView{{ptr, s}, {s}, device}; } template auto MakeVec(HostDeviceVector *data) { return MakeVec(data->DeviceIdx() == -1 ? data->HostPointer() : data->DevicePointer(), - data->Size(), data->DeviceIdx()); + data->Size(), data->Device()); } template auto MakeVec(HostDeviceVector const *data) { return MakeVec(data->DeviceIdx() == -1 ? data->ConstHostPointer() : data->ConstDevicePointer(), - data->Size(), data->DeviceIdx()); + data->Size(), data->Device()); } /** @@ -697,7 +689,7 @@ Json ArrayInterface(TensorView const &t) { array_interface["data"] = std::vector(2); array_interface["data"][0] = Integer{reinterpret_cast(t.Values().data())}; array_interface["data"][1] = Boolean{true}; - if (t.DeviceIdx() >= 0) { + if (t.Device().IsCUDA()) { // Change this once we have different CUDA stream. array_interface["stream"] = Null{}; } @@ -856,49 +848,29 @@ class Tensor { /** * @brief Get a @ref TensorView for this tensor. */ - TensorView View(std::int32_t device) { - if (device >= 0) { - data_.SetDevice(device); - auto span = data_.DeviceSpan(); - return {span, shape_, device, order_}; - } else { - auto span = data_.HostSpan(); - return {span, shape_, device, order_}; - } - } - TensorView View(std::int32_t device) const { - if (device >= 0) { - data_.SetDevice(device); - auto span = data_.ConstDeviceSpan(); - return {span, shape_, device, order_}; - } else { - auto span = data_.ConstHostSpan(); - return {span, shape_, device, order_}; - } - } auto View(DeviceOrd device) { if (device.IsCUDA()) { data_.SetDevice(device); auto span = data_.DeviceSpan(); - return TensorView{span, shape_, device.ordinal, order_}; + return TensorView{span, shape_, device, order_}; } else { auto span = data_.HostSpan(); - return TensorView{span, shape_, device.ordinal, order_}; + return TensorView{span, shape_, device, order_}; } } auto View(DeviceOrd device) const { if (device.IsCUDA()) { data_.SetDevice(device); auto span = data_.ConstDeviceSpan(); - return TensorView{span, shape_, device.ordinal, order_}; + return TensorView{span, shape_, device, order_}; } else { auto span = data_.ConstHostSpan(); - return TensorView{span, shape_, device.ordinal, order_}; + return TensorView{span, shape_, device, order_}; } } - auto HostView() const { return this->View(-1); } - auto HostView() { return this->View(-1); } + auto HostView() { return this->View(DeviceOrd::CPU()); } + auto HostView() const { return this->View(DeviceOrd::CPU()); } [[nodiscard]] size_t Size() const { return data_.Size(); } auto Shape() const { return common::Span{shape_}; } @@ -975,6 +947,7 @@ class Tensor { void SetDevice(int32_t device) const { data_.SetDevice(device); } void SetDevice(DeviceOrd device) const { data_.SetDevice(device); } [[nodiscard]] int32_t DeviceIdx() const { return data_.DeviceIdx(); } + 
[[nodiscard]] DeviceOrd Device() const { return data_.Device(); } }; template diff --git a/include/xgboost/multi_target_tree_model.h b/include/xgboost/multi_target_tree_model.h index 1ad7d6bf6..676c43196 100644 --- a/include/xgboost/multi_target_tree_model.h +++ b/include/xgboost/multi_target_tree_model.h @@ -37,12 +37,12 @@ class MultiTargetTree : public Model { [[nodiscard]] linalg::VectorView NodeWeight(bst_node_t nidx) const { auto beg = nidx * this->NumTarget(); auto v = common::Span{weights_}.subspan(beg, this->NumTarget()); - return linalg::MakeTensorView(Context::kCpuId, v, v.size()); + return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size()); } [[nodiscard]] linalg::VectorView NodeWeight(bst_node_t nidx) { auto beg = nidx * this->NumTarget(); auto v = common::Span{weights_}.subspan(beg, this->NumTarget()); - return linalg::MakeTensorView(Context::kCpuId, v, v.size()); + return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size()); } public: diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index 21674f785..1dddb1444 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -68,7 +68,7 @@ void CopyGradientFromCUDAArrays(Context const *ctx, ArrayInterface<2, false> con auto &gpair = *out_gpair; gpair.SetDevice(grad_dev); gpair.Reshape(grad.Shape(0), grad.Shape(1)); - auto d_gpair = gpair.View(grad_dev); + auto d_gpair = gpair.View(DeviceOrd::CUDA(grad_dev)); auto cuctx = ctx->CUDACtx(); DispatchDType(grad, DeviceOrd::CUDA(grad_dev), [&](auto &&t_grad) { diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh index 037ad1ff3..5d52e4100 100644 --- a/src/common/linalg_op.cuh +++ b/src/common/linalg_op.cuh @@ -13,7 +13,7 @@ namespace xgboost { namespace linalg { template void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) { - dh::safe_cuda(cudaSetDevice(t.DeviceIdx())); + dh::safe_cuda(cudaSetDevice(t.Device().ordinal)); static_assert(std::is_void>::value, "For function with return, use transform instead."); if (t.Contiguous()) { diff --git a/src/common/ranking_utils.cu b/src/common/ranking_utils.cu index 283ccc21d..5ad8a575c 100644 --- a/src/common/ranking_utils.cu +++ b/src/common/ranking_utils.cu @@ -133,7 +133,7 @@ struct WeightOp { void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { CUDAContext const* cuctx = ctx->CUDACtx(); - group_ptr_.SetDevice(ctx->gpu_id); + group_ptr_.SetDevice(ctx->Device()); if (info.group_ptr_.empty()) { group_ptr_.Resize(2, 0); group_ptr_.HostVector()[1] = info.num_row_; @@ -153,7 +153,7 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { max_group_size_ = thrust::reduce(cuctx->CTP(), it, it + n_groups, 0ul, thrust::maximum{}); - threads_group_ptr_.SetDevice(ctx->gpu_id); + threads_group_ptr_.SetDevice(ctx->Device()); threads_group_ptr_.Resize(n_groups + 1, 0); auto d_threads_group_ptr = threads_group_ptr_.DeviceSpan(); if (param_.HasTruncation()) { @@ -168,7 +168,7 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { n_cuda_threads_ = info.num_row_ * param_.NumPair(); } - sorted_idx_cache_.SetDevice(ctx->gpu_id); + sorted_idx_cache_.SetDevice(ctx->Device()); sorted_idx_cache_.Resize(info.labels.Size(), 0); auto weight = common::MakeOptionalWeights(ctx, info.weights_); @@ -187,18 +187,18 @@ common::Span RankingCache::MakeRankOnCUDA(Context const* ctx, void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { CUDAContext const* cuctx = ctx->CUDACtx(); - auto labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + 
auto labels = info.labels.View(ctx->Device()).Slice(linalg::All(), 0); CheckNDCGLabels(this->Param(), labels, CheckNDCGOp{cuctx}); auto d_group_ptr = this->DataGroupPtr(ctx); std::size_t n_groups = d_group_ptr.size() - 1; inv_idcg_ = linalg::Zeros(ctx, n_groups); - auto d_inv_idcg = inv_idcg_.View(ctx->gpu_id); + auto d_inv_idcg = inv_idcg_.View(ctx->Device()); cuda_impl::CalcQueriesInvIDCG(ctx, labels, d_group_ptr, d_inv_idcg, this->Param()); CHECK_GE(this->Param().NumPair(), 1ul); - discounts_.SetDevice(ctx->gpu_id); + discounts_.SetDevice(ctx->Device()); discounts_.Resize(MaxGroupSize()); auto d_discount = discounts_.DeviceSpan(); dh::LaunchN(MaxGroupSize(), cuctx->Stream(), @@ -206,12 +206,12 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { } void PreCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { - auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + auto const d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0); CheckPreLabels("pre", d_label, CheckMAPOp{ctx->CUDACtx()}); } void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { - auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + auto const d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0); CheckPreLabels("map", d_label, CheckMAPOp{ctx->CUDACtx()}); } } // namespace xgboost::ltr diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index 75622bd84..31531a597 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -217,7 +217,7 @@ class RankingCache { } // Constructed as [1, n_samples] if group ptr is not supplied by the user common::Span DataGroupPtr(Context const* ctx) const { - group_ptr_.SetDevice(ctx->gpu_id); + group_ptr_.SetDevice(ctx->Device()); return ctx->IsCPU() ? 
group_ptr_.ConstHostSpan() : group_ptr_.ConstDeviceSpan(); } @@ -228,7 +228,7 @@ class RankingCache { // Create a rank list by model prediction common::Span SortedIdx(Context const* ctx, common::Span predt) { if (sorted_idx_cache_.Empty()) { - sorted_idx_cache_.SetDevice(ctx->gpu_id); + sorted_idx_cache_.SetDevice(ctx->Device()); sorted_idx_cache_.Resize(predt.size()); } if (ctx->IsCPU()) { @@ -242,7 +242,7 @@ class RankingCache { common::Span SortedIdxY(Context const* ctx, std::size_t n_samples) { CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal(); if (y_sorted_idx_cache_.Empty()) { - y_sorted_idx_cache_.SetDevice(ctx->gpu_id); + y_sorted_idx_cache_.SetDevice(ctx->Device()); y_sorted_idx_cache_.Resize(n_samples); } return y_sorted_idx_cache_.DeviceSpan(); @@ -250,7 +250,7 @@ class RankingCache { common::Span RankedY(Context const* ctx, std::size_t n_samples) { CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal(); if (y_ranked_by_model_.Empty()) { - y_ranked_by_model_.SetDevice(ctx->gpu_id); + y_ranked_by_model_.SetDevice(ctx->Device()); y_ranked_by_model_.Resize(n_samples); } return y_ranked_by_model_.DeviceSpan(); @@ -266,21 +266,21 @@ class RankingCache { linalg::VectorView CUDARounding(Context const* ctx) { if (roundings_.Size() == 0) { - roundings_.SetDevice(ctx->gpu_id); + roundings_.SetDevice(ctx->Device()); roundings_.Reshape(Groups()); } - return roundings_.View(ctx->gpu_id); + return roundings_.View(ctx->Device()); } common::Span CUDACostRounding(Context const* ctx) { if (cost_rounding_.Size() == 0) { - cost_rounding_.SetDevice(ctx->gpu_id); + cost_rounding_.SetDevice(ctx->Device()); cost_rounding_.Resize(1); } return cost_rounding_.DeviceSpan(); } template common::Span MaxLambdas(Context const* ctx, std::size_t n) { - max_lambdas_.SetDevice(ctx->gpu_id); + max_lambdas_.SetDevice(ctx->Device()); std::size_t bytes = n * sizeof(Type); if (bytes != max_lambdas_.Size()) { max_lambdas_.Resize(bytes); @@ -315,17 +315,17 @@ class NDCGCache : public RankingCache { } linalg::VectorView InvIDCG(Context const* ctx) const { - return inv_idcg_.View(ctx->gpu_id); + return inv_idcg_.View(ctx->Device()); } common::Span Discount(Context const* ctx) const { return ctx->IsCPU() ? discounts_.ConstHostSpan() : discounts_.ConstDeviceSpan(); } linalg::VectorView Dcg(Context const* ctx) { if (dcg_.Size() == 0) { - dcg_.SetDevice(ctx->gpu_id); + dcg_.SetDevice(ctx->Device()); dcg_.Reshape(this->Groups()); } - return dcg_.View(ctx->gpu_id); + return dcg_.View(ctx->Device()); } }; @@ -396,7 +396,7 @@ class PreCache : public RankingCache { common::Span Pre(Context const* ctx) { if (pre_.Empty()) { - pre_.SetDevice(ctx->gpu_id); + pre_.SetDevice(ctx->Device()); pre_.Resize(this->Groups()); } return ctx->IsCPU() ? pre_.HostSpan() : pre_.DeviceSpan(); @@ -427,21 +427,21 @@ class MAPCache : public RankingCache { common::Span NumRelevant(Context const* ctx) { if (n_rel_.Empty()) { - n_rel_.SetDevice(ctx->gpu_id); + n_rel_.SetDevice(ctx->Device()); n_rel_.Resize(n_samples_); } return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan(); } common::Span Acc(Context const* ctx) { if (acc_.Empty()) { - acc_.SetDevice(ctx->gpu_id); + acc_.SetDevice(ctx->Device()); acc_.Resize(n_samples_); } return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan(); } common::Span Map(Context const* ctx) { if (map_.Empty()) { - map_.SetDevice(ctx->gpu_id); + map_.SetDevice(ctx->Device()); map_.Resize(this->Groups()); } return ctx->IsCPU() ? 
map_.HostSpan() : map_.DeviceSpan(); diff --git a/src/common/stats.cc b/src/common/stats.cc index 80fc2c50d..03ee00b87 100644 --- a/src/common/stats.cc +++ b/src/common/stats.cc @@ -20,9 +20,9 @@ namespace common { void Median(Context const* ctx, linalg::Tensor const& t, HostDeviceVector const& weights, linalg::Tensor* out) { if (!ctx->IsCPU()) { - weights.SetDevice(ctx->gpu_id); + weights.SetDevice(ctx->Device()); auto opt_weights = OptionalWeights(weights.ConstDeviceSpan()); - auto t_v = t.View(ctx->gpu_id); + auto t_v = t.View(ctx->Device()); cuda_impl::Median(ctx, t_v, opt_weights, out); } @@ -59,7 +59,7 @@ void Mean(Context const* ctx, linalg::Vector const& v, linalg::VectorHostView()(0) = ret; } else { - cuda_impl::Mean(ctx, v.View(ctx->gpu_id), out->View(ctx->gpu_id)); + cuda_impl::Mean(ctx, v.View(ctx->Device()), out->View(ctx->Device())); } } } // namespace common diff --git a/src/data/data.cc b/src/data/data.cc index 467770715..f143faf97 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -366,7 +366,7 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { // Groups is maintained by a higher level Python function. We should aim at deprecating // the slice function. if (this->labels.Size() != this->num_row_) { - auto t_labels = this->labels.View(this->labels.Data()->DeviceIdx()); + auto t_labels = this->labels.View(this->labels.Data()->Device()); out.labels.Reshape(ridxs.size(), labels.Shape(1)); out.labels.Data()->HostVector() = Gather(this->labels.Data()->HostVector(), ridxs, t_labels.Stride(0)); @@ -394,7 +394,7 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { if (this->base_margin_.Size() != this->num_row_) { CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0) << "Incorrect size of base margin vector."; - auto t_margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx()); + auto t_margin = this->base_margin_.View(this->base_margin_.Data()->Device()); out.base_margin_.Reshape(ridxs.size(), t_margin.Shape(1)); out.base_margin_.Data()->HostVector() = Gather(this->base_margin_.Data()->HostVector(), ridxs, t_margin.Stride(0)); @@ -445,7 +445,7 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::TensorReshape(array.shape); - auto t_out = p_out->View(Context::kCpuId); + auto t_out = p_out->View(DeviceOrd::CPU()); CHECK(t_out.CContiguous()); auto const shape = t_out.Shape(); DispatchDType(array, DeviceOrd::CPU(), [&](auto&& in) { @@ -564,7 +564,7 @@ void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, Da CHECK(key); auto proc = [&](auto cast_d_ptr) { using T = std::remove_pointer_t; - auto t = linalg::TensorView(common::Span{cast_d_ptr, num}, {num}, Context::kCpuId); + auto t = linalg::TensorView(common::Span{cast_d_ptr, num}, {num}, DeviceOrd::CPU()); CHECK(t.CContiguous()); Json interface { linalg::ArrayInterface(t) @@ -739,8 +739,7 @@ void MetaInfo::SynchronizeNumberOfColumns() { namespace { template void CheckDevice(std::int32_t device, HostDeviceVector const& v) { - bool valid = - v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device; + bool valid = v.Device().IsCPU() || device == Context::kCpuId || v.DeviceIdx() == device; if (!valid) { LOG(FATAL) << "Invalid device ordinal. Data is associated with a different device ordinal than " "the booster. 
The device ordinal of the data is: " diff --git a/src/data/data.cu b/src/data/data.cu index 0f1fda661..74db2b28c 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -50,7 +50,7 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens return; } p_out->Reshape(array.shape); - auto t = p_out->View(ptr_device); + auto t = p_out->View(DeviceOrd::CUDA(ptr_device)); linalg::ElementWiseTransformDevice( t, [=] __device__(size_t i, T) { diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index bf4f6b92f..4b05d55f3 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -183,7 +183,7 @@ class GBLinear : public GradientBooster { bst_layer_t layer_begin, bst_layer_t /*layer_end*/, bool) override { model_.LazyInitModel(); LinearCheckLayer(layer_begin); - auto base_margin = p_fmat->Info().base_margin_.View(Context::kCpuId); + auto base_margin = p_fmat->Info().base_margin_.View(DeviceOrd::CPU()); const int ngroup = model_.learner_model_param->num_output_group; const size_t ncolumns = model_.learner_model_param->num_feature + 1; // allocate space for (#features + bias) times #groups times #rows @@ -250,10 +250,9 @@ class GBLinear : public GradientBooster { // The bias is the last weight out_scores->resize(model_.weight.size() - learner_model_param_->num_output_group, 0); auto n_groups = learner_model_param_->num_output_group; - linalg::TensorView scores{ - *out_scores, - {learner_model_param_->num_feature, n_groups}, - Context::kCpuId}; + auto scores = linalg::MakeTensorView(DeviceOrd::CPU(), + common::Span{out_scores->data(), out_scores->size()}, + learner_model_param_->num_feature, n_groups); for (size_t i = 0; i < learner_model_param_->num_feature; ++i) { for (bst_group_t g = 0; g < n_groups; ++g) { scores(i, g) = model_[i][g]; @@ -275,12 +274,12 @@ class GBLinear : public GradientBooster { monitor_.Start("PredictBatchInternal"); model_.LazyInitModel(); std::vector &preds = *out_preds; - auto base_margin = p_fmat->Info().base_margin_.View(Context::kCpuId); + auto base_margin = p_fmat->Info().base_margin_.View(DeviceOrd::CPU()); // start collecting the prediction const int ngroup = model_.learner_model_param->num_output_group; preds.resize(p_fmat->Info().num_row_ * ngroup); - auto base_score = learner_model_param_->BaseScore(Context::kCpuId); + auto base_score = learner_model_param_->BaseScore(DeviceOrd::CPU()); for (const auto &page : p_fmat->GetBatches()) { auto const& batch = page.GetView(); // output convention: nrow * k, where nrow is number of rows diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index e9c5be003..50dfe9262 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -754,7 +754,7 @@ class Dart : public GBTree { auto n_groups = model_.learner_model_param->num_output_group; PredictionCacheEntry predts; // temporary storage for prediction - if (ctx_->gpu_id != Context::kCpuId) { + if (ctx_->IsCUDA()) { predts.predictions.SetDevice(ctx_->gpu_id); } predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0); @@ -859,12 +859,12 @@ class Dart : public GBTree { size_t n_rows = p_fmat->Info().num_row_; if (predts.predictions.DeviceIdx() != Context::kCpuId) { p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx()); - auto base_score = model_.learner_model_param->BaseScore(predts.predictions.DeviceIdx()); + auto base_score = model_.learner_model_param->BaseScore(predts.predictions.Device()); GPUDartInplacePredictInc(p_out_preds->predictions.DeviceSpan(), predts.predictions.DeviceSpan(), w, n_rows, base_score, n_groups, group); } else { - 
auto base_score = model_.learner_model_param->BaseScore(Context::kCpuId); + auto base_score = model_.learner_model_param->BaseScore(DeviceOrd::CPU()); auto& h_predts = predts.predictions.HostVector(); auto& h_out_predts = p_out_preds->predictions.HostVector(); common::ParallelFor(n_rows, ctx_->Threads(), [&](auto ridx) { diff --git a/src/learner.cc b/src/learner.cc index be562f972..33725b612 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -279,15 +279,15 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy // Make sure read access everywhere for thread-safe prediction. std::as_const(base_score_).HostView(); if (!ctx->IsCPU()) { - std::as_const(base_score_).View(ctx->gpu_id); + std::as_const(base_score_).View(ctx->Device()); } CHECK(std::as_const(base_score_).Data()->HostCanRead()); } -linalg::TensorView LearnerModelParam::BaseScore(int32_t device) const { +linalg::TensorView LearnerModelParam::BaseScore(DeviceOrd device) const { // multi-class is not yet supported. CHECK_EQ(base_score_.Size(), 1) << ModelNotFitted(); - if (device == Context::kCpuId) { + if (device.IsCPU()) { // Make sure that we won't run into race condition. CHECK(base_score_.Data()->HostCanRead()); return base_score_.HostView(); @@ -300,7 +300,7 @@ linalg::TensorView LearnerModelParam::BaseScore(int32_t device) } linalg::TensorView LearnerModelParam::BaseScore(Context const* ctx) const { - return this->BaseScore(ctx->gpu_id); + return this->BaseScore(ctx->Device()); } void LearnerModelParam::Copy(LearnerModelParam const& that) { @@ -309,7 +309,7 @@ void LearnerModelParam::Copy(LearnerModelParam const& that) { base_score_.Data()->Copy(*that.base_score_.Data()); std::as_const(base_score_).HostView(); if (that.base_score_.DeviceIdx() != Context::kCpuId) { - std::as_const(base_score_).View(that.base_score_.DeviceIdx()); + std::as_const(base_score_).View(that.base_score_.Device()); } CHECK_EQ(base_score_.Data()->DeviceCanRead(), that.base_score_.Data()->DeviceCanRead()); CHECK(base_score_.Data()->HostCanRead()); @@ -388,7 +388,7 @@ class LearnerConfiguration : public Learner { this->ConfigureTargets(); auto task = UsePtr(obj_)->Task(); - linalg::Tensor base_score({1}, Ctx()->gpu_id); + linalg::Tensor base_score({1}, Ctx()->Device()); auto h_base_score = base_score.HostView(); // transform to margin @@ -424,7 +424,7 @@ class LearnerConfiguration : public Learner { if (mparam_.boost_from_average && !UsePtr(gbm_)->ModelFitted()) { if (p_fmat) { auto const& info = p_fmat->Info(); - info.Validate(Ctx()->gpu_id); + info.Validate(Ctx()->Ordinal()); // We estimate it from input data. linalg::Tensor base_score; InitEstimation(info, &base_score); @@ -1369,7 +1369,7 @@ class LearnerImpl : public LearnerIO { auto& prediction = prediction_container_.Cache(data, ctx_.gpu_id); this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end); // Copy the prediction cache to output prediction. 
out_preds comes from C API - out_preds->SetDevice(ctx_.gpu_id); + out_preds->SetDevice(ctx_.Device()); out_preds->Resize(prediction.predictions.Size()); out_preds->Copy(prediction.predictions); if (!output_margin) { diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 473f5b02c..a2e7372fb 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -82,22 +82,19 @@ template double MultiClassOVR(Context const *ctx, common::Span predts, MetaInfo const &info, size_t n_classes, int32_t n_threads, BinaryAUC &&binary_auc) { CHECK_NE(n_classes, 0); - auto const labels = info.labels.View(Context::kCpuId); + auto const labels = info.labels.HostView(); if (labels.Shape(0) != 0) { CHECK_EQ(labels.Shape(1), 1) << "AUC doesn't support multi-target model."; } std::vector results_storage(n_classes * 3, 0); - linalg::TensorView results(results_storage, {n_classes, static_cast(3)}, - Context::kCpuId); + auto results = linalg::MakeTensorView(ctx, results_storage, n_classes, 3); auto local_area = results.Slice(linalg::All(), 0); auto tp = results.Slice(linalg::All(), 1); auto auc = results.Slice(linalg::All(), 2); auto weights = common::OptionalWeights{info.weights_.ConstHostSpan()}; - auto predts_t = linalg::TensorView( - predts, {static_cast(info.num_row_), n_classes}, - Context::kCpuId); + auto predts_t = linalg::MakeTensorView(ctx, predts, info.num_row_, n_classes); if (info.labels.Size() != 0) { common::ParallelFor(n_classes, n_threads, [&](auto c) { @@ -108,8 +105,8 @@ double MultiClassOVR(Context const *ctx, common::Span predts, MetaI response[i] = labels(i) == c ? 1.0f : 0.0; } double fp; - std::tie(fp, tp(c), auc(c)) = - binary_auc(ctx, proba, linalg::MakeVec(response.data(), response.size(), -1), weights); + std::tie(fp, tp(c), auc(c)) = binary_auc( + ctx, proba, linalg::MakeVec(response.data(), response.size(), ctx->Device()), weights); local_area(c) = fp * tp(c); }); } @@ -220,7 +217,7 @@ std::pair RankingAUC(Context const *ctx, std::vector co CHECK_GE(info.group_ptr_.size(), 2); uint32_t n_groups = info.group_ptr_.size() - 1; auto s_predts = common::Span{predts}; - auto labels = info.labels.View(Context::kCpuId); + auto labels = info.labels.View(ctx->Device()); auto s_weights = info.weights_.ConstHostSpan(); std::atomic invalid_groups{0}; @@ -363,8 +360,8 @@ class EvalROCAUC : public EvalAUC { info.labels.HostView().Slice(linalg::All(), 0), common::OptionalWeights{info.weights_.ConstHostSpan()}); } else { - std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info, - ctx_->gpu_id, &this->d_cache_); + std::tie(fp, tp, auc) = + GPUBinaryROCAUC(predts.ConstDeviceSpan(), info, ctx_->Device(), &this->d_cache_); } return std::make_tuple(fp, tp, auc); } @@ -381,8 +378,7 @@ XGBOOST_REGISTER_METRIC(EvalAUC, "auc") #if !defined(XGBOOST_USE_CUDA) std::tuple GPUBinaryROCAUC(common::Span, MetaInfo const &, - std::int32_t, - std::shared_ptr *) { + DeviceOrd, std::shared_ptr *) { common::AssertGPUSupport(); return {}; } @@ -414,8 +410,8 @@ class EvalPRAUC : public EvalAUC { BinaryPRAUC(ctx_, predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0), common::OptionalWeights{info.weights_.ConstHostSpan()}); } else { - std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info, - ctx_->gpu_id, &this->d_cache_); + std::tie(pr, re, auc) = + GPUBinaryPRAUC(predts.ConstDeviceSpan(), info, ctx_->Device(), &this->d_cache_); } return std::make_tuple(pr, re, auc); } @@ -459,7 +455,7 @@ XGBOOST_REGISTER_METRIC(AUCPR, "aucpr") #if !defined(XGBOOST_USE_CUDA) std::tuple 
GPUBinaryPRAUC(common::Span, MetaInfo const &, - std::int32_t, std::shared_ptr *) { + DeviceOrd, std::shared_ptr *) { common::AssertGPUSupport(); return {}; } diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 6e3032e42..dd9e4483f 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -85,11 +85,11 @@ void InitCacheOnce(common::Span predts, std::shared_ptr std::tuple GPUBinaryAUC(common::Span predts, MetaInfo const &info, - int32_t device, common::Span d_sorted_idx, + DeviceOrd device, common::Span d_sorted_idx, Fn area_fn, std::shared_ptr cache) { auto labels = info.labels.View(device); auto weights = info.weights_.ConstDeviceSpan(); - dh::safe_cuda(cudaSetDevice(device)); + dh::safe_cuda(cudaSetDevice(device.ordinal)); CHECK_NE(labels.Size(), 0); CHECK_EQ(labels.Size(), predts.size()); @@ -168,7 +168,7 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, } std::tuple GPUBinaryROCAUC(common::Span predts, - MetaInfo const &info, std::int32_t device, + MetaInfo const &info, DeviceOrd device, std::shared_ptr *p_cache) { auto &cache = *p_cache; InitCacheOnce(predts, p_cache); @@ -309,9 +309,10 @@ void SegmentedReduceAUC(common::Span d_unique_idx, * up each class in all kernels. */ template -double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span d_class_ptr, - size_t n_classes, std::shared_ptr cache, Fn area_fn) { - dh::safe_cuda(cudaSetDevice(device)); +double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device, + common::Span d_class_ptr, size_t n_classes, + std::shared_ptr cache, Fn area_fn) { + dh::safe_cuda(cudaSetDevice(device.ordinal)); /** * Sorted idx */ @@ -467,11 +468,12 @@ double GPUMultiClassROCAUC(Context const *ctx, common::Span predts, dh::TemporaryArray class_ptr(n_classes + 1, 0); MultiClassSortedIdx(ctx, predts, dh::ToSpan(class_ptr), cache); - auto fn = [] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev, - double tp, size_t /*class_id*/) { + auto fn = [] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev, double tp, + size_t /*class_id*/) { return TrapezoidArea(fp_prev, fp, tp_prev, tp); }; - return GPUMultiClassAUCOVR(info, ctx->gpu_id, dh::ToSpan(class_ptr), n_classes, cache, fn); + return GPUMultiClassAUCOVR(info, ctx->Device(), dh::ToSpan(class_ptr), n_classes, cache, + fn); } namespace { @@ -512,7 +514,7 @@ std::pair GPURankingAUC(Context const *ctx, common::Span< /** * Sort the labels */ - auto d_labels = info.labels.View(ctx->gpu_id); + auto d_labels = info.labels.View(ctx->Device()); auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); common::SegmentedArgSort(ctx, d_labels.Values(), d_group_ptr, d_sorted_idx); @@ -604,7 +606,7 @@ std::pair GPURankingAUC(Context const *ctx, common::Span< } std::tuple GPUBinaryPRAUC(common::Span predts, - MetaInfo const &info, std::int32_t device, + MetaInfo const &info, DeviceOrd device, std::shared_ptr *p_cache) { auto& cache = *p_cache; InitCacheOnce(predts, p_cache); @@ -662,7 +664,7 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span predts, /** * Get total positive/negative */ - auto labels = info.labels.View(ctx->gpu_id); + auto labels = info.labels.View(ctx->Device()); auto n_samples = info.num_row_; dh::caching_device_vector totals(n_classes); auto key_it = @@ -695,13 +697,13 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span predts, return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, d_totals[class_id].first); }; - return GPUMultiClassAUCOVR(info, ctx->gpu_id, d_class_ptr, n_classes, cache, fn); + return GPUMultiClassAUCOVR(info, 
ctx->Device(), d_class_ptr, n_classes, cache, fn); } template std::pair GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, - common::Span d_group_ptr, int32_t device, + common::Span d_group_ptr, DeviceOrd device, std::shared_ptr cache, Fn area_fn) { /** * Sorted idx @@ -843,7 +845,7 @@ std::pair GPURankingPRAUC(Context const *ctx, common::SegmentedArgSort(ctx, predts, d_group_ptr, d_sorted_idx); dh::XGBDeviceAllocator alloc; - auto labels = info.labels.View(ctx->gpu_id); + auto labels = info.labels.View(ctx->Device()); if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels.Values()), dh::tend(labels.Values()), PRAUCLabelInvalid{})) { InvalidLabels(); @@ -882,7 +884,7 @@ std::pair GPURankingPRAUC(Context const *ctx, return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, d_totals[group_id].first); }; - return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->gpu_id, cache, fn); + return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->Device(), cache, fn); } } // namespace metric } // namespace xgboost diff --git a/src/metric/auc.h b/src/metric/auc.h index d8e7f4344..fce1cc757 100644 --- a/src/metric/auc.h +++ b/src/metric/auc.h @@ -30,7 +30,7 @@ XGBOOST_DEVICE inline double TrapezoidArea(double x0, double x1, double y0, doub struct DeviceAUCCache; std::tuple GPUBinaryROCAUC(common::Span predts, - MetaInfo const &info, std::int32_t device, + MetaInfo const &info, DeviceOrd, std::shared_ptr *p_cache); double GPUMultiClassROCAUC(Context const *ctx, common::Span predts, @@ -45,7 +45,7 @@ std::pair GPURankingAUC(Context const *ctx, common::Span< * PR AUC * **********/ std::tuple GPUBinaryPRAUC(common::Span predts, - MetaInfo const &info, std::int32_t device, + MetaInfo const &info, DeviceOrd, std::shared_ptr *p_cache); double GPUMultiClassPRAUC(Context const *ctx, common::Span predts, diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index b6888610b..e16f9f8cc 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -45,7 +45,7 @@ namespace { template PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) { PackedReduceResult result; - auto labels = info.labels.View(ctx->gpu_id); + auto labels = info.labels.View(ctx->Device()); if (ctx->IsCPU()) { auto n_threads = ctx->Threads(); std::vector score_tloc(n_threads, 0.0); @@ -183,10 +183,10 @@ class PseudoErrorLoss : public MetricNoCache { double Eval(const HostDeviceVector& preds, const MetaInfo& info) override { CHECK_EQ(info.labels.Shape(0), info.num_row_); - auto labels = info.labels.View(ctx_->gpu_id); - preds.SetDevice(ctx_->gpu_id); + auto labels = info.labels.View(ctx_->Device()); + preds.SetDevice(ctx_->Device()); auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(); - info.weights_.SetDevice(ctx_->gpu_id); + info.weights_.SetDevice(ctx_->Device()); common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan() : info.weights_.ConstDeviceSpan()); float slope = this->param_.huber_slope; @@ -349,11 +349,11 @@ struct EvalEWiseBase : public MetricNoCache { if (info.labels.Size() != 0) { CHECK_NE(info.labels.Shape(1), 0); } - auto labels = info.labels.View(ctx_->gpu_id); - info.weights_.SetDevice(ctx_->gpu_id); + auto labels = info.labels.View(ctx_->Device()); + info.weights_.SetDevice(ctx_->Device()); common::OptionalWeights weights(ctx_->IsCPU() ? 
info.weights_.ConstHostSpan() : info.weights_.ConstDeviceSpan()); - preds.SetDevice(ctx_->gpu_id); + preds.SetDevice(ctx_->Device()); auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(); auto d_policy = policy_; @@ -444,16 +444,16 @@ class QuantileError : public MetricNoCache { } auto const* ctx = ctx_; - auto y_true = info.labels.View(ctx->gpu_id); - preds.SetDevice(ctx->gpu_id); - alpha_.SetDevice(ctx->gpu_id); + auto y_true = info.labels.View(ctx->Device()); + preds.SetDevice(ctx->Device()); + alpha_.SetDevice(ctx->Device()); auto alpha = ctx->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan(); std::size_t n_targets = preds.Size() / info.num_row_ / alpha_.Size(); CHECK_NE(n_targets, 0); auto y_predt = linalg::MakeTensorView(ctx, &preds, static_cast(info.num_row_), alpha_.Size(), n_targets); - info.weights_.SetDevice(ctx->gpu_id); + info.weights_.SetDevice(ctx->Device()); common::OptionalWeights weight{ctx->IsCPU() ? info.weights_.ConstHostSpan() : info.weights_.ConstDeviceSpan()}; diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 8df6e585f..41495164c 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -75,7 +75,7 @@ struct EvalAMS : public MetricNoCache { const double br = 10.0; unsigned thresindex = 0; double s_tp = 0.0, b_fp = 0.0, tams = 0.0; - const auto& labels = info.labels.View(Context::kCpuId); + const auto& labels = info.labels.View(DeviceOrd::CPU()); for (unsigned i = 0; i < static_cast(ndata-1) && i < ntop; ++i) { const unsigned ridx = rec[i].second; const bst_float wt = info.GetWeight(ridx); @@ -134,7 +134,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig { std::vector sum_tloc(ctx_->Threads(), 0.0); { - const auto& labels = info.labels.View(Context::kCpuId); + const auto& labels = info.labels.HostView(); const auto &h_preds = preds.ConstHostVector(); dmlc::OMPException exc; diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 9ba1baf8f..f79d52742 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -33,7 +33,7 @@ PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info, HostDeviceVector const &predt, std::shared_ptr p_cache) { auto d_gptr = p_cache->DataGroupPtr(ctx); - auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + auto d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0); predt.SetDevice(ctx->gpu_id); auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan()); @@ -89,7 +89,7 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, if (!d_weight.Empty()) { CHECK_EQ(d_weight.weights.size(), p_cache->Groups()); } - auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + auto d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0); predt.SetDevice(ctx->gpu_id); auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), predt.Size()); @@ -119,9 +119,9 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info, HostDeviceVector const &predt, bool minus, std::shared_ptr p_cache) { auto d_group_ptr = p_cache->DataGroupPtr(ctx); - auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + auto d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0); - predt.SetDevice(ctx->gpu_id); + predt.SetDevice(ctx->Device()); auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan()); auto key_it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), diff --git a/src/objective/adaptive.cu 
b/src/objective/adaptive.cu index bba8b85ad..29f70a8d8 100644 --- a/src/objective/adaptive.cu +++ b/src/objective/adaptive.cu @@ -19,7 +19,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span pos dh::device_vector* p_ridx, HostDeviceVector* p_nptr, HostDeviceVector* p_nidx, RegTree const& tree) { // copy position to buffer - dh::safe_cuda(cudaSetDevice(ctx->gpu_id)); + dh::safe_cuda(cudaSetDevice(ctx->Ordinal())); auto cuctx = ctx->CUDACtx(); size_t n_samples = position.size(); dh::device_vector sorted_position(position.size()); @@ -86,11 +86,11 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span pos */ auto& nidx = *p_nidx; auto& nptr = *p_nptr; - nidx.SetDevice(ctx->gpu_id); + nidx.SetDevice(ctx->Device()); nidx.Resize(n_leaf); auto d_node_idx = nidx.DeviceSpan(); - nptr.SetDevice(ctx->gpu_id); + nptr.SetDevice(ctx->Device()); nptr.Resize(n_leaf + 1, 0); auto d_node_ptr = nptr.DeviceSpan(); @@ -142,7 +142,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span pos void UpdateTreeLeafDevice(Context const* ctx, common::Span position, std::int32_t group_idx, MetaInfo const& info, float learning_rate, HostDeviceVector const& predt, float alpha, RegTree* p_tree) { - dh::safe_cuda(cudaSetDevice(ctx->gpu_id)); + dh::safe_cuda(cudaSetDevice(ctx->Ordinal())); dh::device_vector ridx; HostDeviceVector nptr; HostDeviceVector nidx; @@ -155,13 +155,13 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span pos } HostDeviceVector quantiles; - predt.SetDevice(ctx->gpu_id); + predt.SetDevice(ctx->Device()); auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_, predt.Size() / info.num_row_); CHECK_LT(group_idx, d_predt.Shape(1)); auto t_predt = d_predt.Slice(linalg::All(), group_idx); - auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), IdxY(info, group_idx)); + auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx)); auto d_row_index = dh::ToSpan(ridx); auto seg_beg = nptr.DevicePointer(); @@ -178,7 +178,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span pos if (info.weights_.Empty()) { common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles); } else { - info.weights_.SetDevice(ctx->gpu_id); + info.weights_.SetDevice(ctx->Device()); auto d_weights = info.weights_.ConstDeviceSpan(); CHECK_EQ(d_weights.size(), d_row_index.size()); auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index)); diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index 46fd77705..5a3a38fdf 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -109,12 +109,12 @@ class LambdaRankObj : public FitIntercept { lj_.SetDevice(ctx_->gpu_id); if (ctx_->IsCPU()) { - cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id), - lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_, + cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()), + lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_, &li_, &lj_, p_cache_); } else { - cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id), - lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_, + cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()), + lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_, &li_, &lj_, p_cache_); } @@ -354,9 +354,9 @@ class LambdaRankNDCG : public LambdaRankObj { const MetaInfo& info, linalg::Matrix* out_gpair) { if (ctx_->IsCUDA()) { 
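      // The device-dispatch idiom used throughout this patch, sketched here for
      // reference (names are taken from this diff, not from any additional API):
      // branch on ctx_->IsCUDA()/ctx_->IsCPU(), build linalg views with
      // ctx_->Device(), and keep the raw ordinal only for CUDA runtime calls:
      //
      //   auto t = tensor.View(ctx->Device());           // host or device view
      //   dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));  // ordinal for the CUDA runtime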
cuda_impl::LambdaRankGetGradientNDCG( - ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), - tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id), - out_gpair); + ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()), + tj_minus_.View(ctx_->Device()), li_full_.View(ctx_->Device()), + lj_full_.View(ctx_->Device()), out_gpair); return; } @@ -477,9 +477,9 @@ class LambdaRankMAP : public LambdaRankObj { CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective."; if (ctx_->IsCUDA()) { return cuda_impl::LambdaRankGetGradientMAP( - ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), - tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id), - out_gpair); + ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()), + tj_minus_.View(ctx_->Device()), li_full_.View(ctx_->Device()), + lj_full_.View(ctx_->Device()), out_gpair); } auto gptr = p_cache_->DataGroupPtr(ctx_).data(); @@ -567,9 +567,9 @@ class LambdaRankPairwise : public LambdaRankObjIsCUDA()) { return cuda_impl::LambdaRankGetGradientPairwise( - ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), - tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id), - out_gpair); + ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()), + tj_minus_.View(ctx_->Device()), li_full_.View(ctx_->Device()), + lj_full_.View(ctx_->Device()), out_gpair); } auto gptr = p_cache_->DataGroupPtr(ctx_); diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu index 0f57fce48..ac31a2c79 100644 --- a/src/objective/lambdarank_obj.cu +++ b/src/objective/lambdarank_obj.cu @@ -306,7 +306,7 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector const CHECK_NE(d_rounding.Size(), 0); - auto label = info.labels.View(ctx->gpu_id); + auto label = info.labels.View(ctx->Device()); auto predts = preds.ConstDeviceSpan(); auto gpairs = out_gpair->View(ctx->Device()); thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.Values().data(), gpairs.Size(), @@ -348,7 +348,7 @@ common::Span SortY(Context const* ctx, MetaInfo const& info, common::Span d_rank, std::shared_ptr p_cache) { auto const d_group_ptr = p_cache->DataGroupPtr(ctx); - auto label = info.labels.View(ctx->gpu_id); + auto label = info.labels.View(ctx->Device()); // The buffer for ranked y is necessary as cub segmented sort accepts only pointer. 
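  // (Annotation: the loop below gathers labels into the cached `RankedY` buffer in
  // model-ranked order; as the comment above notes, cub's segmented sort consumes
  // raw pointers, so the permuted labels must be materialized first.)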
auto d_y_ranked = p_cache->RankedY(ctx, info.num_row_); thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_y_ranked.size(), @@ -374,13 +374,13 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, linalg::VectorView li, linalg::VectorView lj, linalg::Matrix* out_gpair) { // boilerplate - std::int32_t device_id = ctx->gpu_id; - dh::safe_cuda(cudaSetDevice(device_id)); + auto device = ctx->Device(); + dh::safe_cuda(cudaSetDevice(device.ordinal)); auto const d_inv_IDCG = p_cache->InvIDCG(ctx); auto const discount = p_cache->Discount(ctx); - info.labels.SetDevice(device_id); - preds.SetDevice(device_id); + info.labels.SetDevice(device); + preds.SetDevice(device); auto const exp_gain = p_cache->Param().ndcg_exp_gain; auto delta_ndcg = [=] XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high, @@ -403,7 +403,7 @@ void MAPStat(Context const* ctx, MetaInfo const& info, common::Span( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> std::size_t { return dh::SegmentId(group_ptr, i); }); - auto label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + auto label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0); auto const* cuctx = ctx->CUDACtx(); { @@ -442,11 +442,11 @@ void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter, linalg::VectorView tj_minus, // input bias ratio linalg::VectorView li, linalg::VectorView lj, linalg::Matrix* out_gpair) { - std::int32_t device_id = ctx->gpu_id; - dh::safe_cuda(cudaSetDevice(device_id)); + auto device = ctx->Device(); + dh::safe_cuda(cudaSetDevice(device.ordinal)); - info.labels.SetDevice(device_id); - predt.SetDevice(device_id); + info.labels.SetDevice(device); + predt.SetDevice(device); CHECK(p_cache); @@ -481,11 +481,11 @@ void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, linalg::VectorView tj_minus, // input bias ratio linalg::VectorView li, linalg::VectorView lj, linalg::Matrix* out_gpair) { - std::int32_t device_id = ctx->gpu_id; - dh::safe_cuda(cudaSetDevice(device_id)); + auto device = ctx->Device(); + dh::safe_cuda(cudaSetDevice(device.ordinal)); - info.labels.SetDevice(device_id); - predt.SetDevice(device_id); + info.labels.SetDevice(device); + predt.SetDevice(device); auto d_predt = predt.ConstDeviceSpan(); auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt); @@ -517,11 +517,11 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorViewDataGroupPtr(ctx); auto n_groups = d_group_ptr.size() - 1; - auto ti_plus = p_ti_plus->View(ctx->gpu_id); - auto tj_minus = p_tj_minus->View(ctx->gpu_id); + auto ti_plus = p_ti_plus->View(ctx->Device()); + auto tj_minus = p_tj_minus->View(ctx->Device()); - auto li = p_li->View(ctx->gpu_id); - auto lj = p_lj->View(ctx->gpu_id); + auto li = p_li->View(ctx->Device()); + auto lj = p_lj->View(ctx->Device()); CHECK_EQ(li.Size(), ti_plus.Size()); auto const& param = p_cache->Param(); diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index 0774223e7..8d83b829b 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -62,7 +62,7 @@ class QuantileRegression : public ObjFunction { CHECK_GE(n_targets, n_alphas); CHECK_EQ(preds.Size(), info.num_row_ * n_targets); - auto labels = info.labels.View(ctx_->gpu_id); + auto labels = info.labels.View(ctx_->Device()); out_gpair->SetDevice(ctx_->Device()); CHECK_EQ(info.labels.Shape(1), 1) @@ -131,7 +131,7 @@ class QuantileRegression : public ObjFunction { #if defined(XGBOOST_USE_CUDA) 
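    // (Annotation: the next line still passes the raw ordinal. Both the ordinal
    // and `DeviceOrd` forms of `SetDevice` appear elsewhere in this patch, e.g.
    // `weights.SetDevice(ctx->Device())` vs. `predt.SetDevice(ctx->gpu_id)`, so
    // the two overloads coexist during the migration and this call still compiles.)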
alpha_.SetDevice(ctx_->gpu_id); auto d_alpha = alpha_.ConstDeviceSpan(); - auto d_labels = info.labels.View(ctx_->gpu_id); + auto d_labels = info.labels.View(ctx_->Device()); auto seg_it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return i * d_labels.Shape(0); }); diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 5751d6102..4f099a537 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -69,7 +69,7 @@ class RegLossObj : public FitIntercept { public: void ValidateLabel(MetaInfo const& info) { - auto label = info.labels.View(ctx_->Ordinal()); + auto label = info.labels.View(ctx_->Device()); auto valid = ctx_->DispatchDevice( [&] { return std::all_of(linalg::cbegin(label), linalg::cend(label), @@ -244,7 +244,7 @@ class PseudoHuberRegression : public FitIntercept { CheckRegInputs(info, preds); auto slope = param_.huber_slope; CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0."; - auto labels = info.labels.View(ctx_->gpu_id); + auto labels = info.labels.View(ctx_->Device()); out_gpair->SetDevice(ctx_->gpu_id); out_gpair->Reshape(info.num_row_, this->Targets(info)); @@ -698,7 +698,7 @@ class MeanAbsoluteError : public ObjFunction { void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, std::int32_t /*iter*/, linalg::Matrix* out_gpair) override { CheckRegInputs(info, preds); - auto labels = info.labels.View(ctx_->gpu_id); + auto labels = info.labels.View(ctx_->Device()); out_gpair->SetDevice(ctx_->Device()); out_gpair->Reshape(info.num_row_, this->Targets(info)); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index c092c0b04..26d8f3440 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -663,7 +663,7 @@ class CPUPredictor : public Predictor { std::size_t n_samples = p_fmat->Info().num_row_; std::size_t n_groups = model.learner_model_param->OutputLength(); CHECK_EQ(out_preds->size(), n_samples * n_groups); - linalg::TensorView out_predt{*out_preds, {n_samples, n_groups}, ctx_->gpu_id}; + auto out_predt = linalg::MakeTensorView(ctx_, *out_preds, n_samples, n_groups); if (!p_fmat->PageExists()) { std::vector workspace(p_fmat->Info().num_col_ * kUnroll * n_threads); @@ -732,7 +732,7 @@ class CPUPredictor : public Predictor { std::vector thread_temp; InitThreadTemp(n_threads * kBlockSize, &thread_temp); std::size_t n_groups = model.learner_model_param->OutputLength(); - linalg::TensorView out_predt{predictions, {m->NumRows(), n_groups}, Context::kCpuId}; + auto out_predt = linalg::MakeTensorView(ctx_, predictions, m->NumRows(), n_groups); PredictBatchByBlockOfRowsKernel, kBlockSize>( AdapterView(m.get(), missing, common::Span{workspace}, n_threads), model, tree_begin, tree_end, &thread_temp, n_threads, out_predt); @@ -878,8 +878,8 @@ class CPUPredictor : public Predictor { common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) { FillNodeMeanValues(model.trees[i].get(), &(mean_values[i])); }); - auto base_margin = info.base_margin_.View(Context::kCpuId); - auto base_score = model.learner_model_param->BaseScore(Context::kCpuId)(0); + auto base_margin = info.base_margin_.View(ctx_->Device()); + auto base_score = model.learner_model_param->BaseScore(ctx_->Device())(0); // start collecting the contributions for (const auto &batch : p_fmat->GetBatches()) { auto page = batch.GetView(); diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 2559447f3..4d7fc598f 100644 --- 
a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -60,7 +60,7 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVectorResize(n); - auto base_score = model.learner_model_param->BaseScore(Context::kCpuId)(0); + auto base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU())(0); out_preds->Fill(base_score); } } diff --git a/src/tree/fit_stump.cc b/src/tree/fit_stump.cc index ec654a1b2..ec1b6fe18 100644 --- a/src/tree/fit_stump.cc +++ b/src/tree/fit_stump.cc @@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::MatrixDevice()); auto gpair_t = gpair.View(ctx->Device()); ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView()) - : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id)); + : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->Device())); } } // namespace tree } // namespace xgboost diff --git a/src/tree/fit_stump.cu b/src/tree/fit_stump.cu index 33f92014e..40b2a0c96 100644 --- a/src/tree/fit_stump.cu +++ b/src/tree/fit_stump.cu @@ -1,5 +1,5 @@ /** - * Copyright 2022 by XGBoost Contributors + * Copyright 2022-2023 by XGBoost Contributors * * \brief Utilities for estimating initial score. */ @@ -41,7 +41,7 @@ void FitStump(Context const* ctx, linalg::TensorView gpai auto sample = i % gpair.Shape(0); return GradientPairPrecise{gpair(sample, target)}; }); - auto d_sum = sum.View(ctx->gpu_id); + auto d_sum = sum.View(ctx->Device()); CHECK(d_sum.CContiguous()); dh::XGBCachingDeviceAllocator alloc; diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 82dc99b12..d0267b0ed 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -774,7 +774,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree, std::vector const &partitioner, linalg::VectorView out_preds) { auto const &tree = *p_last_tree; - CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId); + CHECK(out_preds.Device().IsCPU()); size_t n_nodes = p_last_tree->GetNodes().size(); for (auto &part : partitioner) { CHECK_EQ(part.Size(), n_nodes); @@ -809,7 +809,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree, auto n_nodes = mttree->Size(); auto n_targets = tree.NumTargets(); CHECK_EQ(out_preds.Shape(1), n_targets); - CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId); + CHECK(out_preds.Device().IsCPU()); for (auto &part : partitioner) { CHECK_EQ(part.Size(), n_nodes); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 10fb913b3..0e42f1562 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -516,9 +516,10 @@ struct GPUHistMakerDevice { } CHECK(p_tree); - dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); - CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id); + CHECK(out_preds_d.Device().IsCUDA()); + CHECK_EQ(out_preds_d.Device().ordinal, ctx_->Ordinal()); + dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); auto d_position = dh::ToSpan(positions); CHECK_EQ(out_preds_d.Size(), d_position.size()); diff --git a/tests/cpp/common/test_linalg.cc b/tests/cpp/common/test_linalg.cc index b1a90d773..f345b3a78 100644 --- a/tests/cpp/common/test_linalg.cc +++ b/tests/cpp/common/test_linalg.cc @@ -3,7 +3,7 @@ */ #include #include -#include +#include // for HostDeviceVector #include #include // size_t @@ -14,8 +14,8 @@ namespace xgboost::linalg { namespace { -auto kCpuId = Context::kCpuId; -} +DeviceOrd CPU() { return DeviceOrd::CPU(); } +} // namespace auto MakeMatrixFromTest(HostDeviceVector *storage, std::size_t 
diff --git a/tests/cpp/common/test_linalg.cc b/tests/cpp/common/test_linalg.cc
index b1a90d773..f345b3a78 100644
--- a/tests/cpp/common/test_linalg.cc
+++ b/tests/cpp/common/test_linalg.cc
@@ -3,7 +3,7 @@
  */
 #include <gtest/gtest.h>
 #include <xgboost/context.h>
-#include <xgboost/host_device_vector.h>
+#include <xgboost/host_device_vector.h>  // for HostDeviceVector
 #include <xgboost/linalg.h>
 
 #include <cstddef>  // size_t
@@ -14,8 +14,8 @@
 
 namespace xgboost::linalg {
 namespace {
-auto kCpuId = Context::kCpuId;
-}
+DeviceOrd CPU() { return DeviceOrd::CPU(); }
+}  // namespace
 
 auto MakeMatrixFromTest(HostDeviceVector<float> *storage, std::size_t n_rows, std::size_t n_cols) {
   storage->Resize(n_rows * n_cols);
@@ -23,7 +23,7 @@ auto MakeMatrixFromTest(HostDeviceVector<float> *storage, std::size_t n_rows, st
 
   std::iota(h_storage.begin(), h_storage.end(), 0);
 
-  auto m = linalg::TensorView<float, 2>{h_storage, {n_rows, static_cast<std::size_t>(n_cols)}, -1};
+  auto m = linalg::TensorView<float, 2>{h_storage, {n_rows, static_cast<std::size_t>(n_cols)}, CPU()};
   return m;
 }
@@ -31,7 +31,7 @@ TEST(Linalg, MatrixView) {
   size_t kRows = 31, kCols = 77;
   HostDeviceVector<float> storage;
   auto m = MakeMatrixFromTest(&storage, kRows, kCols);
-  ASSERT_EQ(m.DeviceIdx(), kCpuId);
+  ASSERT_EQ(m.Device(), CPU());
   ASSERT_EQ(m(0, 0), 0);
   ASSERT_EQ(m(kRows - 1, kCols - 1), storage.Size() - 1);
 }
@@ -76,7 +76,7 @@ TEST(Linalg, TensorView) {
 
   {
     // as vector
-    TensorView<double, 1> vec{data, {data.size()}, -1};
+    TensorView<double, 1> vec{data, {data.size()}, CPU()};
     ASSERT_EQ(vec.Size(), data.size());
     ASSERT_EQ(vec.Shape(0), data.size());
     ASSERT_EQ(vec.Shape().size(), 1);
@@ -87,7 +87,7 @@
 
   {
     // as matrix
-    TensorView<double, 2> mat(data, {6, 4}, -1);
+    TensorView<double, 2> mat(data, {6, 4}, CPU());
     auto s = mat.Slice(2, All());
     ASSERT_EQ(s.Shape().size(), 1);
     s = mat.Slice(All(), 1);
@@ -96,7 +96,7 @@
 
   {
     // assignment
-    TensorView<double, 3> t{data, {2, 3, 4}, 0};
+    TensorView<double, 3> t{data, {2, 3, 4}, CPU()};
     double pi = 3.14159;
     auto old = t(1, 2, 3);
     t(1, 2, 3) = pi;
@@ -201,7 +201,7 @@
   }
   {
     // f-contiguous
-    TensorView<double, 3> t{data, {4, 3, 2}, {1, 4, 12}, kCpuId};
+    TensorView<double, 3> t{data, {4, 3, 2}, {1, 4, 12}, CPU()};
     ASSERT_TRUE(t.Contiguous());
     ASSERT_TRUE(t.FContiguous());
     ASSERT_FALSE(t.CContiguous());
@@ -210,11 +210,11 @@
 
 TEST(Linalg, Tensor) {
   {
-    Tensor<float, 3> t{{2, 3, 4}, kCpuId, Order::kC};
-    auto view = t.View(kCpuId);
+    Tensor<float, 3> t{{2, 3, 4}, CPU(), Order::kC};
+    auto view = t.View(CPU());
 
     auto const &as_const = t;
-    auto k_view = as_const.View(kCpuId);
+    auto k_view = as_const.View(CPU());
 
     size_t n = 2 * 3 * 4;
     ASSERT_EQ(t.Size(), n);
@@ -229,7 +229,7 @@
   }
   {
     // Reshape
-    Tensor<float, 3> t{{2, 3, 4}, kCpuId, Order::kC};
+    Tensor<float, 3> t{{2, 3, 4}, CPU(), Order::kC};
     t.Reshape(4, 3, 2);
     ASSERT_EQ(t.Size(), 24);
     ASSERT_EQ(t.Shape(2), 2);
@@ -247,7 +247,7 @@
 
 TEST(Linalg, Empty) {
   {
-    auto t = TensorView<float, 2>{{}, {0, 3}, kCpuId, Order::kC};
+    auto t = TensorView<float, 2>{{}, {0, 3}, CPU(), Order::kC};
     for (int32_t i : {0, 1, 2}) {
       auto s = t.Slice(All(), i);
       ASSERT_EQ(s.Size(), 0);
@@ -256,9 +256,9 @@
     }
   }
   {
-    auto t = Tensor<float, 2>{{0, 3}, kCpuId, Order::kC};
+    auto t = Tensor<float, 2>{{0, 3}, CPU(), Order::kC};
     ASSERT_EQ(t.Size(), 0);
-    auto view = t.View(kCpuId);
+    auto view = t.View(CPU());
 
     for (int32_t i : {0, 1, 2}) {
       auto s = view.Slice(All(), i);
@@ -270,7 +270,7 @@
 }
 
 TEST(Linalg, ArrayInterface) {
-  auto cpu = kCpuId;
+  auto cpu = CPU();
   auto t = Tensor<float, 2>{{3, 3}, cpu, Order::kC};
   auto v = t.View(cpu);
   std::iota(v.Values().begin(), v.Values().end(), 0);
@@ -315,16 +315,16 @@ TEST(Linalg, Popc) {
 }
 
 TEST(Linalg, Stack) {
-  Tensor<float, 3> l{{2, 3, 4}, kCpuId, Order::kC};
-  ElementWiseTransformHost(l.View(kCpuId), omp_get_max_threads(),
+  Tensor<float, 3> l{{2, 3, 4}, CPU(), Order::kC};
+  ElementWiseTransformHost(l.View(CPU()), omp_get_max_threads(),
                            [=](size_t i, float) { return i; });
 
-  Tensor<float, 3> r_0{{2, 3, 4}, kCpuId, Order::kC};
-  ElementWiseTransformHost(r_0.View(kCpuId), omp_get_max_threads(),
+  Tensor<float, 3> r_0{{2, 3, 4}, CPU(), Order::kC};
+  ElementWiseTransformHost(r_0.View(CPU()), omp_get_max_threads(),
                            [=](size_t i, float) { return i; });
 
   Stack(&l, r_0);
 
-  Tensor<float, 3> r_1{{0, 3, 4}, kCpuId, Order::kC};
+  Tensor<float, 3> r_1{{0, 3, 4}, CPU(), Order::kC};
   Stack(&l, r_1);
 
   ASSERT_EQ(l.Shape(0), 4);
@@ -335,7 +335,7 @@ TEST(Linalg, Stack) {
 TEST(Linalg, FOrder) {
   std::size_t constexpr kRows = 16, kCols = 3;
   std::vector<float> data(kRows * kCols);
-  MatrixView<float> mat{data, {kRows, kCols}, Context::kCpuId, Order::kF};
+  MatrixView<float> mat{data, {kRows, kCols}, CPU(), Order::kF};
   float k{0};
   for (std::size_t i = 0; i < kRows; ++i) {
     for (std::size_t j = 0; j < kCols; ++j) {
diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu
index be89d51bc..b88b8e127 100644
--- a/tests/cpp/common/test_linalg.cu
+++ b/tests/cpp/common/test_linalg.cu
@@ -11,17 +11,18 @@
 namespace xgboost::linalg {
 namespace {
 void TestElementWiseKernel() {
+  auto device = DeviceOrd::CUDA(0);
   Tensor<float, 3> l{{2, 3, 4}, 0};
   {
     /**
      * Non-contiguous
      */
     // GPU view
-    auto t = l.View(0).Slice(linalg::All(), 1, linalg::All());
+    auto t = l.View(device).Slice(linalg::All(), 1, linalg::All());
     ASSERT_FALSE(t.CContiguous());
     ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; });
     // CPU view
-    t = l.View(Context::kCpuId).Slice(linalg::All(), 1, linalg::All());
+    t = l.View(DeviceOrd::CPU()).Slice(linalg::All(), 1, linalg::All());
     size_t k = 0;
     for (size_t i = 0; i < l.Shape(0); ++i) {
       for (size_t j = 0; j < l.Shape(2); ++j) {
@@ -29,7 +30,7 @@ void TestElementWiseKernel() {
       }
     }
 
-    t = l.View(0).Slice(linalg::All(), 1, linalg::All());
+    t = l.View(device).Slice(linalg::All(), 1, linalg::All());
     ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); });
   }
 
@@ -37,11 +38,11 @@
     /**
      * Contiguous
      */
-    auto t = l.View(0);
+    auto t = l.View(device);
     ElementWiseTransformDevice(t, [] XGBOOST_DEVICE(size_t i, float) { return i; });
     ASSERT_TRUE(t.CContiguous());
     // CPU view
-    t = l.View(Context::kCpuId);
+    t = l.View(DeviceOrd::CPU());
 
     size_t ind = 0;
     for (size_t i = 0; i < l.Shape(0); ++i) {
diff --git a/tests/cpp/common/test_ranking_utils.cu b/tests/cpp/common/test_ranking_utils.cu
index d62f5f171..86ce4b6d0 100644
--- a/tests/cpp/common/test_ranking_utils.cu
+++ b/tests/cpp/common/test_ranking_utils.cu
@@ -41,7 +41,7 @@ void TestCalcQueriesInvIDCG() {
   p.UpdateAllowUnknown(Args{{"ndcg_exp_gain", "false"}});
   cuda_impl::CalcQueriesInvIDCG(&ctx, linalg::MakeTensorView(&ctx, d_scores, d_scores.size()),
-                                dh::ToSpan(group_ptr), inv_IDCG.View(ctx.gpu_id), p);
+                                dh::ToSpan(group_ptr), inv_IDCG.View(ctx.Device()), p);
 
   for (std::size_t i = 0; i < n_groups; ++i) {
     double inv_idcg = inv_IDCG(i);
     ASSERT_NEAR(inv_idcg, 0.00551782, kRtEps);
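The test_linalg.cu hunks lean on the pattern these tests exercise throughout: a Tensor owns one buffer, and View(DeviceOrd) selects which side of it you touch. A sketch under the assumption of a CUDA-enabled build; the shape and device 0 are illustrative:

    #include <xgboost/context.h>  // for DeviceOrd
    #include <xgboost/linalg.h>   // for Order, Tensor

    void Sketch() {
      // One owning tensor with 3 elements, initially resident on the host.
      xgboost::linalg::Tensor<float, 1> arr{{3ul}, xgboost::DeviceOrd::CPU(),
                                            xgboost::linalg::Order::kC};

      auto h_view = arr.View(xgboost::DeviceOrd::CPU());    // host-accessible view
      auto d_view = arr.View(xgboost::DeviceOrd::CUDA(0));  // copies the buffer to device 0
      (void)h_view;
      (void)d_view;  // only dereferenceable from device code
    }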
diff --git a/tests/cpp/common/test_stats.cu b/tests/cpp/common/test_stats.cu
index 08877ac8d..3dc90e069 100644
--- a/tests/cpp/common/test_stats.cu
+++ b/tests/cpp/common/test_stats.cu
@@ -47,7 +47,7 @@ class StatsGPU : public ::testing::Test {
     data.insert(data.cend(), seg.begin(), seg.end());
     data.insert(data.cend(), seg.begin(), seg.end());
     linalg::Tensor<float, 1> arr{data.cbegin(), data.cend(), {data.size()}, 0};
-    auto d_arr = arr.View(0);
+    auto d_arr = arr.View(DeviceOrd::CUDA(0));
 
     auto key_it = dh::MakeTransformIterator<std::size_t>(
         thrust::make_counting_iterator(0ul),
@@ -71,8 +71,8 @@ class StatsGPU : public ::testing::Test {
   }
 
   void Weighted() {
-    auto d_arr = arr_.View(0);
-    auto d_key = indptr_.View(0);
+    auto d_arr = arr_.View(DeviceOrd::CUDA(0));
+    auto d_key = indptr_.View(DeviceOrd::CUDA(0));
 
     auto key_it = dh::MakeTransformIterator<std::size_t>(
         thrust::make_counting_iterator(0ul),
@@ -81,7 +81,7 @@ class StatsGPU : public ::testing::Test {
         dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
                                          [=] XGBOOST_DEVICE(std::size_t i) { return d_arr(i); });
     linalg::Tensor<float, 1> weights{{10}, 0};
-    linalg::ElementWiseTransformDevice(weights.View(0),
+    linalg::ElementWiseTransformDevice(weights.View(DeviceOrd::CUDA(0)),
                                        [=] XGBOOST_DEVICE(std::size_t, float) { return 1.0; });
     auto w_it = weights.Data()->ConstDevicePointer();
     for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) {
@@ -102,7 +102,7 @@ class StatsGPU : public ::testing::Test {
     data.insert(data.cend(), seg.begin(), seg.end());
     data.insert(data.cend(), seg.begin(), seg.end());
     linalg::Tensor<float, 1> arr{data.cbegin(), data.cend(), {data.size()}, 0};
-    auto d_arr = arr.View(0);
+    auto d_arr = arr.View(DeviceOrd::CUDA(0));
 
     auto key_it = dh::MakeTransformIterator<std::size_t>(
         thrust::make_counting_iterator(0ul),
@@ -125,8 +125,8 @@ class StatsGPU : public ::testing::Test {
   }
 
   void NonWeighted() {
-    auto d_arr = arr_.View(0);
-    auto d_key = indptr_.View(0);
+    auto d_arr = arr_.View(DeviceOrd::CUDA(0));
+    auto d_key = indptr_.View(DeviceOrd::CUDA(0));
 
     auto key_it = dh::MakeTransformIterator<std::size_t>(
         thrust::make_counting_iterator(0ul), [=] __device__(std::size_t i) { return d_key(i); });
diff --git a/tests/cpp/data/test_array_interface.cc b/tests/cpp/data/test_array_interface.cc
index 7e0484842..b692a2aa5 100644
--- a/tests/cpp/data/test_array_interface.cc
+++ b/tests/cpp/data/test_array_interface.cc
@@ -22,7 +22,7 @@ TEST(ArrayInterface, Initialize) {
 
   HostDeviceVector<std::uint64_t> u64_storage(storage.Size());
   std::string u64_arr_str{ArrayInterfaceStr(linalg::TensorView<std::uint64_t const, 2>{
-      u64_storage.ConstHostSpan(), {kRows, kCols}, Context::kCpuId})};
+      u64_storage.ConstHostSpan(), {kRows, kCols}, DeviceOrd::CPU()})};
   std::copy(storage.ConstHostVector().cbegin(), storage.ConstHostVector().cend(),
             u64_storage.HostSpan().begin());
   auto u64_arr = ArrayInterface<2>{u64_arr_str};
diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc
index 5ebe1c6bd..dbaffb7cd 100644
--- a/tests/cpp/data/test_metainfo.cc
+++ b/tests/cpp/data/test_metainfo.cc
@@ -129,8 +129,8 @@ TEST(MetaInfo, SaveLoadBinary) {
   EXPECT_EQ(inforead.group_ptr_, info.group_ptr_);
   EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector());
 
-  auto orig_margin = info.base_margin_.View(xgboost::Context::kCpuId);
-  auto read_margin = inforead.base_margin_.View(xgboost::Context::kCpuId);
+  auto orig_margin = info.base_margin_.View(xgboost::DeviceOrd::CPU());
+  auto read_margin = inforead.base_margin_.View(xgboost::DeviceOrd::CPU());
   EXPECT_TRUE(std::equal(orig_margin.Values().cbegin(), orig_margin.Values().cend(),
                          read_margin.Values().cbegin()));
 
@@ -267,8 +267,8 @@ TEST(MetaInfo, Validate) {
   xgboost::HostDeviceVector<xgboost::bst_group_t> d_groups{groups};
   d_groups.SetDevice(0);
   d_groups.DevicePointer();  // pull to device
-  std::string arr_interface_str{ArrayInterfaceStr(
-      xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0))};
+  std::string arr_interface_str{ArrayInterfaceStr(xgboost::linalg::MakeVec(
+      d_groups.ConstDevicePointer(), d_groups.Size(), xgboost::DeviceOrd::CUDA(0)))};
   EXPECT_THROW(info.SetInfo(ctx, "group", xgboost::StringView{arr_interface_str}), dmlc::Error);
 #endif  // defined(XGBOOST_USE_CUDA)
 }
@@ -307,5 +307,5 @@ TEST(MetaInfo, HostExtend) {
 }
 
 namespace xgboost {
-TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(Context::kCpuId); }
+TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(DeviceOrd::CPU()); }
 }  // namespace xgboost
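The MetaInfo.Validate hunk above now tags a raw device pointer with an explicit DeviceOrd when wrapping it as a vector view. A minimal sketch of that MakeVec pattern on the host side; the buffer and function are illustrative:

    #include <xgboost/context.h>  // for DeviceOrd
    #include <xgboost/linalg.h>   // for MakeVec

    void Sketch() {
      float data[4] = {0.f, 1.f, 2.f, 3.f};
      // Wrap a raw pointer plus length as a 1-D view; the DeviceOrd records
      // where the memory lives. A CUDA-resident buffer would be tagged with
      // DeviceOrd::CUDA(ordinal) instead, as in the test above.
      auto vec = xgboost::linalg::MakeVec(data, 4, xgboost::DeviceOrd::CPU());
      (void)vec;
    }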
diff --git a/tests/cpp/data/test_metainfo.cu b/tests/cpp/data/test_metainfo.cu
index 95c8f5f39..4f02dfddc 100644
--- a/tests/cpp/data/test_metainfo.cu
+++ b/tests/cpp/data/test_metainfo.cu
@@ -65,7 +65,7 @@ TEST(MetaInfo, FromInterface) {
   }
 
   info.SetInfo(ctx, "base_margin", str.c_str());
-  auto const h_base_margin = info.base_margin_.View(Context::kCpuId);
+  auto const h_base_margin = info.base_margin_.View(DeviceOrd::CPU());
   ASSERT_EQ(h_base_margin.Size(), d_data.size());
   for (size_t i = 0; i < d_data.size(); ++i) {
     ASSERT_EQ(h_base_margin(i), d_data[i]);
@@ -83,7 +83,7 @@ TEST(MetaInfo, FromInterface) {
 }
 
 TEST(MetaInfo, GPUStridedData) {
-  TestMetaInfoStridedData(0);
+  TestMetaInfoStridedData(DeviceOrd::CUDA(0));
 }
 
 TEST(MetaInfo, Group) {
diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h
index 6e45b5062..fba882e0e 100644
--- a/tests/cpp/data/test_metainfo.h
+++ b/tests/cpp/data/test_metainfo.h
@@ -14,10 +14,10 @@
 #include "../../../src/data/array_interface.h"
 
 namespace xgboost {
-inline void TestMetaInfoStridedData(int32_t device) {
+inline void TestMetaInfoStridedData(DeviceOrd device) {
   MetaInfo info;
   Context ctx;
-  ctx.UpdateAllowUnknown(Args{{"gpu_id", std::to_string(device)}});
+  ctx.UpdateAllowUnknown(Args{{"device", device.Name()}});
   {
     // labels
     linalg::Tensor<float, 2> labels;
@@ -28,9 +28,9 @@ inline void TestMetaInfoStridedData(int32_t device) {
     ASSERT_EQ(t_labels.Shape().size(), 2);
 
     info.SetInfo(ctx, "label", StringView{ArrayInterfaceStr(t_labels)});
-    auto const& h_result = info.labels.View(-1);
+    auto const& h_result = info.labels.View(DeviceOrd::CPU());
     ASSERT_EQ(h_result.Shape().size(), 2);
-    auto in_labels = labels.View(-1);
+    auto in_labels = labels.View(DeviceOrd::CPU());
     linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) {
       auto tup = linalg::UnravelIndex(i, h_result.Shape());
       auto i0 = std::get<0>(tup);
@@ -62,9 +62,9 @@ inline void TestMetaInfoStridedData(int32_t device) {
     ASSERT_EQ(t_margin.Shape().size(), 2);
 
     info.SetInfo(ctx, "base_margin", StringView{ArrayInterfaceStr(t_margin)});
-    auto const& h_result = info.base_margin_.View(-1);
+    auto const& h_result = info.base_margin_.View(DeviceOrd::CPU());
     ASSERT_EQ(h_result.Shape().size(), 2);
-    auto in_margin = base_margin.View(-1);
+    auto in_margin = base_margin.View(DeviceOrd::CPU());
     linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
       auto tup = linalg::UnravelIndex(i, h_result.Shape());
       auto i0 = std::get<0>(tup);
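The test_metainfo.h hunk configures the Context from a DeviceOrd through its Name() string rather than an integer gpu_id. A sketch of that round trip; the exact "cpu"/"cuda:0" spellings of the device parameter are an assumption here, not stated by the patch:

    #include <xgboost/base.h>     // for Args
    #include <xgboost/context.h>  // for Context, DeviceOrd

    void Sketch() {
      xgboost::Context ctx;
      auto device = xgboost::DeviceOrd::CUDA(0);
      // DeviceOrd::Name() renders the `device` parameter spelling (e.g. "cpu",
      // "cuda:0"), so a test can be parameterised over devices without
      // branching on a raw ordinal.
      ctx.UpdateAllowUnknown(xgboost::Args{{"device", device.Name()}});
    }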
diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc
index 43d0877d3..f1d588196 100644
--- a/tests/cpp/data/test_simple_dmatrix.cc
+++ b/tests/cpp/data/test_simple_dmatrix.cc
@@ -298,8 +298,8 @@ TEST(SimpleDMatrix, Slice) {
         ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx),
                   out->Info().weights_.HostVector().at(i));
 
-        auto out_margin = out->Info().base_margin_.View(Context::kCpuId);
-        auto in_margin = margin.View(Context::kCpuId);
+        auto out_margin = out->Info().base_margin_.View(DeviceOrd::CPU());
+        auto in_margin = margin.View(DeviceOrd::CPU());
         for (size_t j = 0; j < kClasses; ++j) {
           ASSERT_EQ(out_margin(i, j), in_margin(ridx, j));
         }
@@ -372,8 +372,8 @@ TEST(SimpleDMatrix, SliceCol) {
                   out->Info().labels_upper_bound_.HostVector().at(i));
         ASSERT_EQ(p_m->Info().weights_.HostVector().at(i),
                   out->Info().weights_.HostVector().at(i));
-        auto out_margin = out->Info().base_margin_.View(Context::kCpuId);
-        auto in_margin = margin.View(Context::kCpuId);
+        auto out_margin = out->Info().base_margin_.View(DeviceOrd::CPU());
+        auto in_margin = margin.View(DeviceOrd::CPU());
         for (size_t j = 0; j < kClasses; ++j) {
           ASSERT_EQ(out_margin(i, j), in_margin(i, j));
         }
diff --git a/tests/cpp/objective/test_lambdarank_obj.cu b/tests/cpp/objective/test_lambdarank_obj.cu
index 1c13665fc..c80ec20fc 100644
--- a/tests/cpp/objective/test_lambdarank_obj.cu
+++ b/tests/cpp/objective/test_lambdarank_obj.cu
@@ -39,9 +39,9 @@ void TestGPUMakePair() {
   auto make_args = [&](std::shared_ptr<ltr::RankingCache> p_cache, auto rank_idx,
                        common::Span<std::size_t const> y_sorted_idx) {
     linalg::Vector<double> dummy;
-    auto d = dummy.View(ctx.gpu_id);
+    auto d = dummy.View(ctx.Device());
     linalg::Vector<GradientPair> dgpair;
-    auto dg = dgpair.View(ctx.gpu_id);
+    auto dg = dgpair.View(ctx.Device());
     cuda_impl::KernelInputs args{
         d,
         d,
@@ -50,9 +50,9 @@ void TestGPUMakePair() {
         p_cache->DataGroupPtr(&ctx),
         p_cache->CUDAThreadsGroupPtr(),
         rank_idx,
-        info.labels.View(ctx.gpu_id),
+        info.labels.View(ctx.Device()),
         predt.ConstDeviceSpan(),
-        linalg::MatrixView<GradientPair>{common::Span<GradientPair>{}, {0}, 0},
+        linalg::MatrixView<GradientPair>{common::Span<GradientPair>{}, {0}, DeviceOrd::CUDA(0)},
        dg,
         nullptr,
         y_sorted_idx,
diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu
index 3a65e3e06..f31158482 100644
--- a/tests/cpp/predictor/test_gpu_predictor.cu
+++ b/tests/cpp/predictor/test_gpu_predictor.cu
@@ -226,7 +226,7 @@ TEST(GPUPredictor, ShapStump) {
   auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
   gpu_predictor->PredictContribution(dmat.get(), &predictions, model);
   auto& phis = predictions.HostVector();
-  auto base_score = mparam.BaseScore(Context::kCpuId)(0);
+  auto base_score = mparam.BaseScore(DeviceOrd::CPU())(0);
   EXPECT_EQ(phis[0], 0.0);
   EXPECT_EQ(phis[1], base_score);
   EXPECT_EQ(phis[2], 0.0);
diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc
index 993504c57..a9f218c0c 100644
--- a/tests/cpp/predictor/test_predictor.cc
+++ b/tests/cpp/predictor/test_predictor.cc
@@ -287,7 +287,7 @@ void TestCategoricalPrediction(Context const* ctx, bool is_column_split) {
   predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
   predictor->PredictBatch(m.get(), &out_predictions, model, 0);
 
-  auto score = mparam.BaseScore(Context::kCpuId)(0);
+  auto score = mparam.BaseScore(DeviceOrd::CPU())(0);
   ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
   ASSERT_EQ(out_predictions.predictions.HostVector()[0],
             right_weight + score);  // go to right for matching cat
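The predictor tests above read the model intercept through LearnerModelParam::BaseScore(DeviceOrd), which returns a one-element const view rather than a raw float. A sketch of that access pattern; the free function is illustrative, only the BaseScore signature comes from the patch:

    #include <xgboost/context.h>  // for DeviceOrd
    #include <xgboost/learner.h>  // for LearnerModelParam

    float ReadBaseScore(xgboost::LearnerModelParam const& mparam) {
      // Request the view on the host side; (0) reads the single element.
      auto base_score = mparam.BaseScore(xgboost::DeviceOrd::CPU());
      return base_score(0);
    }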