diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h
index cb7668f4c..553486dac 100644
--- a/include/xgboost/linalg.h
+++ b/include/xgboost/linalg.h
@@ -664,13 +664,13 @@ auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {
 
 template <typename T>
 auto MakeVec(HostDeviceVector<T> *data) {
-  return MakeVec(data->Device().IsCPU() ? data->HostPointer() : data->DevicePointer(), data->Size(),
-                 data->Device());
+  return MakeVec(data->Device().IsCUDA() ? data->DevicePointer() : data->HostPointer(),
+                 data->Size(), data->Device());
 }
 
 template <typename T>
 auto MakeVec(HostDeviceVector<T> const *data) {
-  return MakeVec(data->Device().IsCPU() ? data->ConstHostPointer() : data->ConstDevicePointer(),
+  return MakeVec(data->Device().IsCUDA() ? data->ConstDevicePointer() : data->ConstHostPointer(),
                  data->Size(), data->Device());
 }
diff --git a/src/common/random.h b/src/common/random.h
index 6d7a1bb49..3aed3384a 100644
--- a/src/common/random.h
+++ b/src/common/random.h
@@ -179,14 +179,14 @@ class ColumnSampler {
     feature_set_tree_->SetDevice(ctx->Device());
     feature_set_tree_->Resize(num_col);
 
-    if (ctx->IsCPU()) {
-      std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
-    } else {
+    if (ctx->IsCUDA()) {
 #if defined(XGBOOST_USE_CUDA)
       cuda_impl::InitFeatureSet(ctx, feature_set_tree_);
 #else
       AssertGPUSupport();
 #endif
+    } else {
+      std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
     }
 
     feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
diff --git a/src/common/stats.cc b/src/common/stats.cc
index bbf969fcc..72c917bed 100644
--- a/src/common/stats.cc
+++ b/src/common/stats.cc
@@ -18,7 +18,7 @@ namespace xgboost::common {
 
 void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
             HostDeviceVector<float> const& weights, linalg::Tensor<float, 1>* out) {
-  if (!ctx->IsCPU()) {
+  if (ctx->IsCUDA()) {
     weights.SetDevice(ctx->Device());
     auto opt_weights = OptionalWeights(weights.ConstDeviceSpan());
     auto t_v = t.View(ctx->Device());
diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh
index 04960458f..d1f9472df 100644
--- a/src/data/ellpack_page.cuh
+++ b/src/data/ellpack_page.cuh
@@ -45,17 +45,17 @@ struct EllpackDeviceAccessor {
         n_rows(n_rows),
         gidx_iter(gidx_iter),
         feature_types{feature_types} {
-    if (device.IsCPU()) {
-      gidx_fvalue_map = cuts->cut_values_.ConstHostSpan();
-      feature_segments = cuts->cut_ptrs_.ConstHostSpan();
-      min_fvalue = cuts->min_vals_.ConstHostSpan();
-    } else {
+    if (device.IsCUDA()) {
       cuts->cut_values_.SetDevice(device);
       cuts->cut_ptrs_.SetDevice(device);
       cuts->min_vals_.SetDevice(device);
       gidx_fvalue_map = cuts->cut_values_.ConstDeviceSpan();
       feature_segments = cuts->cut_ptrs_.ConstDeviceSpan();
       min_fvalue = cuts->min_vals_.ConstDeviceSpan();
+    } else {
+      gidx_fvalue_map = cuts->cut_values_.ConstHostSpan();
+      feature_segments = cuts->cut_ptrs_.ConstHostSpan();
+      min_fvalue = cuts->min_vals_.ConstHostSpan();
     }
   }
   // Get a matrix element, uses binary search for look up Return NaN if missing
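A note on the predicate flip in the hunks above: `IsCPU() ? host : cuda` and `IsCUDA() ? cuda : host` are only equivalent while CPU and CUDA are the sole device types. Once a third backend (e.g. a SYCL device) is representable, the `IsCPU()` form mis-routes it to the CUDA branch, while the `IsCUDA()` form falls back to the always-valid host path. A minimal sketch with a hypothetical three-kind device type (a stripped-down stand-in, not XGBoost's actual `DeviceOrd` from include/xgboost/context.h):

#include <cstdio>

// Illustrative stand-in: a device descriptor with more than two backends.
struct Device {
  enum class Kind { kCPU, kCUDA, kSyclDefault } kind{Kind::kCPU};
  bool IsCPU() const { return kind == Kind::kCPU; }
  bool IsCUDA() const { return kind == Kind::kCUDA; }
};

const float* PickSpan(Device d, const float* host, const float* cuda) {
  // Old predicate: everything that is not the CPU is assumed to be CUDA,
  // so a SYCL device would wrongly receive the CUDA pointer:
  //   return d.IsCPU() ? host : cuda;
  // New predicate: only a CUDA device gets the CUDA pointer; any other
  // backend takes the host path, which is always valid.
  return d.IsCUDA() ? cuda : host;
}

int main() {
  float h[1] = {1.f}, c[1] = {2.f};
  Device sycl{Device::Kind::kSyclDefault};
  std::printf("%f\n", *PickSpan(sycl, h, c));  // prints the host value
}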
diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc
index e581e90ca..368aeb2ac 100644
--- a/src/data/iterative_dmatrix.cc
+++ b/src/data/iterative_dmatrix.cc
@@ -41,10 +41,10 @@ IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle pro
   // hardcoded parameter.
   BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()};
 
-  if (ctx.IsCPU()) {
-    this->InitFromCPU(&ctx, p, iter_handle, missing, ref);
-  } else {
+  if (ctx.IsCUDA()) {
     this->InitFromCUDA(&ctx, p, iter_handle, missing, ref);
+  } else {
+    this->InitFromCPU(&ctx, p, iter_handle, missing, ref);
   }
   this->fmat_ctx_ = ctx;
@@ -73,10 +73,10 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur
   if (ref->PageExists<GHistIndexMatrix>() && ref->PageExists<EllpackPage>()) {
     // Both exists
-    if (ctx->IsCPU()) {
-      csr();
-    } else {
+    if (ctx->IsCUDA()) {
       ellpack();
+    } else {
+      csr();
     }
   } else if (ref->PageExists<GHistIndexMatrix>()) {
     csr();
@@ -84,10 +84,10 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur
     ellpack();
   } else {
     // None exist
-    if (ctx->IsCPU()) {
-      csr();
-    } else {
+    if (ctx->IsCUDA()) {
       ellpack();
+    } else {
+      csr();
     }
   }
   CHECK_EQ(ref->Info().num_col_, n_features)
@@ -297,9 +297,9 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(Context const* ctx
   }
 
   if (!ghist_) {
-    if (ctx->IsCPU()) {
+    if (!ctx->IsCUDA()) {
      ghist_ = std::make_shared<GHistIndexMatrix>(ctx, Info(), *ellpack_, param);
-    } else if (fmat_ctx_.IsCPU()) {
+    } else if (!fmat_ctx_.IsCUDA()) {
       ghist_ = std::make_shared<GHistIndexMatrix>(&fmat_ctx_, Info(), *ellpack_, param);
     } else {
       // Can happen when QDM is initialized on GPU, but a CPU version is queried by a different QDM
diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu
index 868875bf7..2e8da2c7e 100644
--- a/src/data/iterative_dmatrix.cu
+++ b/src/data/iterative_dmatrix.cu
@@ -46,7 +46,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
   int32_t current_device;
   dh::safe_cuda(cudaGetDevice(&current_device));
   auto get_device = [&]() {
-    auto d = (ctx->IsCPU()) ? DeviceOrd::CUDA(current_device) : ctx->Device();
+    auto d = (ctx->IsCUDA()) ? ctx->Device() : DeviceOrd::CUDA(current_device);
     CHECK(!d.IsCPU());
     return d;
   };
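The rewritten `get_device` lambda above keeps the context's ordinal when the context already names a CUDA device, and otherwise falls back to whatever device the CUDA runtime reports as current. A rough free-standing sketch of that normalization, assuming only the CUDA runtime API (`NormalizeForCUDA` and the `DeviceOrd` stand-in are hypothetical, not the actual member function; compile with nvcc):

#include <cassert>
#include <cuda_runtime.h>

// Stand-in device descriptor: a negative ordinal plays the role of the CPU.
struct DeviceOrd {
  int ordinal{-1};
  bool IsCPU() const { return ordinal < 0; }
  bool IsCUDA() const { return ordinal >= 0; }
  static DeviceOrd CUDA(int o) { return DeviceOrd{o}; }
};

DeviceOrd NormalizeForCUDA(DeviceOrd ctx_device) {
  int current_device = 0;
  cudaGetDevice(&current_device);
  // Honour the context's ordinal if it is already CUDA; otherwise use the
  // runtime's current device, so the result is never a CPU ordinal.
  auto d = ctx_device.IsCUDA() ? ctx_device : DeviceOrd::CUDA(current_device);
  assert(!d.IsCPU());  // mirrors CHECK(!d.IsCPU()) in the patch
  return d;
}

int main() {
  DeviceOrd cpu{};  // a CPU context falls back to the runtime's current GPU
  return NormalizeForCUDA(cpu).ordinal >= 0 ? 0 : 1;
}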
diff --git a/src/data/proxy_dmatrix.cc b/src/data/proxy_dmatrix.cc
index a28448e3b..bcefb4999 100644
--- a/src/data/proxy_dmatrix.cc
+++ b/src/data/proxy_dmatrix.cc
@@ -56,7 +56,9 @@ std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
                                                 float missing) {
   bool type_error{false};
   std::shared_ptr<DMatrix> p_fmat{nullptr};
-  if (proxy->Ctx()->IsCPU()) {
+  if (proxy->Ctx()->IsCUDA()) {
+    p_fmat = cuda_impl::CreateDMatrixFromProxy(ctx, proxy, missing);
+  } else {
     p_fmat = data::HostAdapterDispatch(
         proxy.get(),
         [&](auto const &adapter) {
@@ -65,8 +67,6 @@ std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
           return p_fmat;
         },
         &type_error);
-  } else {
-    p_fmat = cuda_impl::CreateDMatrixFromProxy(ctx, proxy, missing);
   }
 
   CHECK(p_fmat) << "Failed to fallback.";
diff --git a/src/data/proxy_dmatrix.cu b/src/data/proxy_dmatrix.cu
index cd76e49cf..fb484f5e3 100644
--- a/src/data/proxy_dmatrix.cu
+++ b/src/data/proxy_dmatrix.cu
@@ -11,7 +11,7 @@ void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
   this->batch_ = adapter;
   this->Info().num_col_ = adapter->NumColumns();
   this->Info().num_row_ = adapter->NumRows();
-  if (adapter->Device().IsCPU()) {
+  if (!adapter->Device().IsCUDA()) {
     // empty data
     CHECK_EQ(this->Info().num_row_, 0);
     ctx_ = ctx_.MakeCUDA(dh::CurrentDevice());
@@ -25,7 +25,7 @@ void DMatrixProxy::FromCudaArray(StringView interface_str) {
   this->batch_ = adapter;
   this->Info().num_col_ = adapter->NumColumns();
   this->Info().num_row_ = adapter->NumRows();
-  if (adapter->Device().IsCPU()) {
+  if (!adapter->Device().IsCUDA()) {
     // empty data
     CHECK_EQ(this->Info().num_row_, 0);
     ctx_ = ctx_.MakeCUDA(dh::CurrentDevice());
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index f54d1c43e..e4b82b7de 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -185,12 +185,12 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(Context const* ctx,
     CHECK_GE(param.max_bin, 2);
     // Used only by approx.
     auto sorted_sketch = param.regen;
-    if (ctx->IsCPU()) {
+    if (!ctx->IsCUDA()) {
       // The context passed in is on CPU, we pick it first since we prioritize the context
       // in Booster.
       gradient_index_.reset(new GHistIndexMatrix{ctx, this, param.max_bin, param.sparse_thresh,
                                                  sorted_sketch, param.hess});
-    } else if (fmat_ctx_.IsCPU()) {
+    } else if (!fmat_ctx_.IsCUDA()) {
       // DMatrix was initialized on CPU, we use the context from initialization.
       gradient_index_.reset(new GHistIndexMatrix{&fmat_ctx_, this, param.max_bin,
                                                  param.sparse_thresh, sorted_sketch, param.hess});
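The `GetGradientIndex` hunks in iterative_dmatrix.cc and simple_dmatrix.cc above encode the same two-level fallback: prefer the caller's (booster's) context when it can serve the request on the host, then the context the DMatrix was constructed with, and only then build a CPU representation by other means. A compressed sketch of that priority order, with hypothetical stand-in types rather than the real DMatrix machinery:

#include <memory>

struct Context {
  bool is_cuda{false};
  bool IsCUDA() const { return is_cuda; }
};
struct GHistIndex {  // stand-in for GHistIndexMatrix
  explicit GHistIndex(Context const*) {}
};

std::shared_ptr<GHistIndex> GetOrBuild(Context const* caller, Context const* init,
                                       std::shared_ptr<GHistIndex> cached) {
  if (cached) return cached;
  if (!caller->IsCUDA()) {
    return std::make_shared<GHistIndex>(caller);  // booster context wins
  } else if (!init->IsCUDA()) {
    return std::make_shared<GHistIndex>(init);    // fall back to init-time context
  }
  // Both contexts are CUDA: a host representation must be built another way.
  return nullptr;
}

int main() {
  Context cpu{false}, gpu{true};
  auto g = GetOrBuild(&gpu, &cpu, nullptr);  // falls back to the init context
  return g ? 0 : 1;
}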
diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu
index e5b4d18f7..c177784a3 100644
--- a/src/data/simple_dmatrix.cu
+++ b/src/data/simple_dmatrix.cu
@@ -19,7 +19,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthr
                              DataSplitMode data_split_mode) {
   CHECK(data_split_mode != DataSplitMode::kCol)
       << "Column-wise data split is currently not supported on the GPU.";
-  auto device = (adapter->Device().IsCPU() || adapter->NumRows() == 0)
+  auto device = (!adapter->Device().IsCUDA() || adapter->NumRows() == 0)
                     ? DeviceOrd::CUDA(dh::CurrentDevice())
                     : adapter->Device();
   CHECK(device.IsCUDA());
diff --git a/src/data/sparse_page_source.cu b/src/data/sparse_page_source.cu
index 40037eedc..99032aeaa 100644
--- a/src/data/sparse_page_source.cu
+++ b/src/data/sparse_page_source.cu
@@ -20,7 +20,7 @@ std::size_t NFeaturesDevice(DMatrixProxy *proxy) {
 
 void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
   auto device = proxy->Device();
-  if (device.IsCPU()) {
+  if (!device.IsCUDA()) {
     device = DeviceOrd::CUDA(dh::CurrentDevice());
   }
   CHECK(device.IsCUDA());
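SimpleDMatrix's CUDA constructor and `DevicePush` above share one idiom: data that arrives on the host, or an empty batch whose device is indeterminate, is routed to the currently active GPU, while device-resident data keeps its own ordinal. A plain-C++ sketch of that predicate (the adapter and `DeviceOrd` types are stand-ins, and the current CUDA ordinal is passed in rather than queried from the runtime):

#include <cstdio>

struct DeviceOrd {
  int ordinal{-1};  // -1 plays the role of the CPU ordinal here
  bool IsCUDA() const { return ordinal >= 0; }
  static DeviceOrd CUDA(int o) { return DeviceOrd{o}; }
};

struct FakeAdapter {
  DeviceOrd device{};
  long n_rows{0};
  DeviceOrd Device() const { return device; }
  long NumRows() const { return n_rows; }
};

// Mirrors the patched predicate: host data and empty batches are both routed
// to the currently active GPU; device-resident data keeps its ordinal.
DeviceOrd ChooseDevice(FakeAdapter const& adapter, int current_cuda_ordinal) {
  return (!adapter.Device().IsCUDA() || adapter.NumRows() == 0)
             ? DeviceOrd::CUDA(current_cuda_ordinal)
             : adapter.Device();
}

int main() {
  FakeAdapter host{DeviceOrd{-1}, 100};      // host data -> current GPU
  FakeAdapter empty{DeviceOrd::CUDA(1), 0};  // empty batch -> current GPU
  std::printf("%d %d\n", ChooseDevice(host, 0).ordinal,
              ChooseDevice(empty, 0).ordinal);  // prints "0 0"
}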
diff --git a/src/metric/auc.cc b/src/metric/auc.cc
index 6de0d1f12..fcb774a4a 100644
--- a/src/metric/auc.cc
+++ b/src/metric/auc.cc
@@ -336,12 +336,12 @@
     double auc{0};
     uint32_t valid_groups = 0;
     auto n_threads = ctx_->Threads();
-    if (ctx_->IsCPU()) {
-      std::tie(auc, valid_groups) =
-          RankingAUC<true>(ctx_, predts.ConstHostVector(), info, n_threads);
-    } else {
+    if (ctx_->IsCUDA()) {
       std::tie(auc, valid_groups) =
           GPURankingAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
+    } else {
+      std::tie(auc, valid_groups) =
+          RankingAUC<true>(ctx_, predts.ConstHostVector(), info, n_threads);
     }
     return std::make_pair(auc, valid_groups);
   }
@@ -351,10 +351,10 @@
     double auc{0};
     auto n_threads = ctx_->Threads();
     CHECK_NE(n_classes, 0);
-    if (ctx_->IsCPU()) {
-      auc = MultiClassOVR(ctx_, predts.ConstHostVector(), info, n_classes, n_threads, BinaryROCAUC);
-    } else {
+    if (ctx_->IsCUDA()) {
       auc = GPUMultiClassROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_, n_classes);
+    } else {
+      auc = MultiClassOVR(ctx_, predts.ConstHostVector(), info, n_classes, n_threads, BinaryROCAUC);
     }
     return auc;
   }
@@ -362,13 +362,13 @@
   std::tuple<double, double, double> EvalBinary(HostDeviceVector<float> const &predts,
                                                 MetaInfo const &info) {
     double fp, tp, auc;
-    if (ctx_->IsCPU()) {
+    if (ctx_->IsCUDA()) {
+      std::tie(fp, tp, auc) =
+          GPUBinaryROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
+    } else {
       std::tie(fp, tp, auc) =
           BinaryROCAUC(ctx_, predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0),
                        common::OptionalWeights{info.weights_.ConstHostSpan()});
-    } else {
-      std::tie(fp, tp, auc) =
-          GPUBinaryROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
     }
     return std::make_tuple(fp, tp, auc);
   }
@@ -413,23 +413,23 @@
   std::tuple<double, double, double> EvalBinary(HostDeviceVector<float> const &predts,
                                                 MetaInfo const &info) {
     double pr, re, auc;
-    if (ctx_->IsCPU()) {
+    if (ctx_->IsCUDA()) {
+      std::tie(pr, re, auc) = GPUBinaryPRAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
+    } else {
       std::tie(pr, re, auc) =
           BinaryPRAUC(ctx_, predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
                       common::OptionalWeights{info.weights_.ConstHostSpan()});
-    } else {
-      std::tie(pr, re, auc) = GPUBinaryPRAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
     }
     return std::make_tuple(pr, re, auc);
   }
 
   double EvalMultiClass(HostDeviceVector<float> const &predts, MetaInfo const &info,
                         size_t n_classes) {
-    if (ctx_->IsCPU()) {
+    if (ctx_->IsCUDA()) {
+      return GPUMultiClassPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_, n_classes);
+    } else {
       auto n_threads = this->ctx_->Threads();
       return MultiClassOVR(ctx_, predts.ConstHostSpan(), info, n_classes, n_threads, BinaryPRAUC);
-    } else {
-      return GPUMultiClassPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_, n_classes);
     }
   }
@@ -438,16 +438,16 @@
     double auc{0};
     uint32_t valid_groups = 0;
     auto n_threads = ctx_->Threads();
-    if (ctx_->IsCPU()) {
+    if (ctx_->IsCUDA()) {
+      std::tie(auc, valid_groups) =
+          GPURankingPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_);
+    } else {
       auto labels = info.labels.Data()->ConstHostSpan();
       if (std::any_of(labels.cbegin(), labels.cend(), PRAUCLabelInvalid{})) {
         InvalidLabels();
       }
       std::tie(auc, valid_groups) =
           RankingAUC<false>(ctx_, predts.ConstHostVector(), info, n_threads);
-    } else {
-      std::tie(auc, valid_groups) =
-          GPURankingPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_);
     }
     return std::make_pair(auc, valid_groups);
   }
diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu
index e51509fc7..70738fdf0 100644
--- a/src/metric/multiclass_metric.cu
+++ b/src/metric/multiclass_metric.cu
@@ -131,7 +131,7 @@ class MultiClassMetricsReduction {
       const HostDeviceVector<bst_float>& preds) {
     PackedReduceResult result;
 
-    if (device.IsCPU()) {
+    if (!device.IsCUDA()) {
       result =
           CpuReduceMetrics(weights, labels, preds, n_class, ctx.Threads());
     }
diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu
index 9c57be3ab..d8ef7eb95 100644
--- a/src/metric/survival_metric.cu
+++ b/src/metric/survival_metric.cu
@@ -127,7 +127,7 @@ class ElementWiseSurvivalMetricsReduction {
      const HostDeviceVector<float>& preds) {
     PackedReduceResult result;
 
-    if (ctx.IsCPU()) {
+    if (!ctx.IsCUDA()) {
       result = CpuReduceMetrics(weights, labels_lower_bound, labels_upper_bound,
                                 preds, ctx.Threads());
     }
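The metric reductions close the pattern out: after this change the host reduction runs for every non-CUDA context, and the device path stays behind the CUDA compile guard, mirroring the structure in random.h above. A simplified sketch of that dispatch shape (the types and `CpuReduceMetrics` here are stand-ins, not the real reduction classes):

#include <numeric>
#include <vector>

struct Context {
  bool is_cuda{false};
  bool IsCUDA() const { return is_cuda; }
};
struct PackedReduceResult {
  double residue{0}, weights{0};
};

PackedReduceResult CpuReduceMetrics(const std::vector<double>& errs) {
  PackedReduceResult r;
  r.residue = std::accumulate(errs.begin(), errs.end(), 0.0);
  r.weights = static_cast<double>(errs.size());
  return r;
}

PackedReduceResult Reduce(const Context& ctx, const std::vector<double>& errs) {
  PackedReduceResult result;
  if (!ctx.IsCUDA()) {  // host fallback for CPU and any non-CUDA device
    result = CpuReduceMetrics(errs);
  }
#if defined(XGBOOST_USE_CUDA)
  else {
    // result = DeviceReduceMetrics(...);  // CUDA path, omitted in this sketch
  }
#endif
  return result;
}

int main() {
  Context cpu{};  // a non-CUDA context takes the host path
  return Reduce(cpu, {0.25, 0.75}).residue == 1.0 ? 0 : 1;
}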