[sycl] Reorder if-else statements to allow using CPU branches for SYCL devices (#10543)

* Reorder if-else statements for SYCL compatibility

* trigger check

---------

Co-authored-by: Dmitry Razdoburdin <>
Dmitry Razdoburdin 2024-07-05 10:31:48 +02:00 committed by GitHub
parent 620b2b155a
commit 513d7a7d84
14 changed files with 56 additions and 56 deletions
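
Every file below applies the same rewrite. A SYCL device answers false to both IsCPU() and IsCUDA(), so a dispatch that tests IsCPU() first sends SYCL contexts into the CUDA branch; testing IsCUDA() first leaves the CPU implementation as the fallback that SYCL devices can share. A minimal self-contained sketch of the idea (illustrative only; the simplified Device type below is not from the XGBoost sources):

#include <iostream>

enum class Backend { kCPU, kCUDA, kSycl };

struct Device {
  Backend backend;
  bool IsCPU() const { return backend == Backend::kCPU; }
  bool IsCUDA() const { return backend == Backend::kCUDA; }
};

// Old ordering: anything that is not CPU falls into the CUDA branch,
// so a SYCL device incorrectly requires the CUDA implementation.
void DispatchOld(Device d) {
  if (d.IsCPU()) {
    std::cout << "CPU path\n";
  } else {
    std::cout << "CUDA path\n";
  }
}

// New ordering: only genuine CUDA devices take the CUDA branch;
// CPU and SYCL devices share the host implementation.
void DispatchNew(Device d) {
  if (d.IsCUDA()) {
    std::cout << "CUDA path\n";
  } else {
    std::cout << "CPU path\n";
  }
}

int main() {
  Device sycl{Backend::kSycl};
  DispatchOld(sycl);  // prints "CUDA path" -- the wrong branch for SYCL
  DispatchNew(sycl);  // prints "CPU path"  -- SYCL reuses the host code
  return 0;
}

The same reasoning covers the ternary form (IsCPU() ? host : device becomes IsCUDA() ? device : host) and the negated test (IsCPU() becomes !IsCUDA()) used where the CPU branch must stay first.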

@@ -664,13 +664,13 @@ auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {
 template <typename T>
 auto MakeVec(HostDeviceVector<T> *data) {
-  return MakeVec(data->Device().IsCPU() ? data->HostPointer() : data->DevicePointer(), data->Size(),
-                 data->Device());
+  return MakeVec(data->Device().IsCUDA() ? data->DevicePointer() : data->HostPointer(),
+                 data->Size(), data->Device());
 }
 template <typename T>
 auto MakeVec(HostDeviceVector<T> const *data) {
-  return MakeVec(data->Device().IsCPU() ? data->ConstHostPointer() : data->ConstDevicePointer(),
+  return MakeVec(data->Device().IsCUDA() ? data->ConstDevicePointer() : data->ConstHostPointer(),
                  data->Size(), data->Device());
 }

@@ -179,14 +179,14 @@ class ColumnSampler {
     feature_set_tree_->SetDevice(ctx->Device());
     feature_set_tree_->Resize(num_col);
-    if (ctx->IsCPU()) {
-      std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
-    } else {
+    if (ctx->IsCUDA()) {
 #if defined(XGBOOST_USE_CUDA)
       cuda_impl::InitFeatureSet(ctx, feature_set_tree_);
 #else
       AssertGPUSupport();
 #endif
+    } else {
+      std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
     }
     feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
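
On a build compiled without CUDA, the hunk above is the difference between aborting and working: a SYCL context previously fell into the else branch and hit AssertGPUSupport(), whereas it now takes the std::iota host path, as in the sketch above.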

@@ -18,7 +18,7 @@
 namespace xgboost::common {
 void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
             HostDeviceVector<float> const& weights, linalg::Tensor<float, 1>* out) {
-  if (!ctx->IsCPU()) {
+  if (ctx->IsCUDA()) {
     weights.SetDevice(ctx->Device());
     auto opt_weights = OptionalWeights(weights.ConstDeviceSpan());
     auto t_v = t.View(ctx->Device());

@@ -45,17 +45,17 @@ struct EllpackDeviceAccessor {
       n_rows(n_rows),
       gidx_iter(gidx_iter),
       feature_types{feature_types} {
-    if (device.IsCPU()) {
-      gidx_fvalue_map = cuts->cut_values_.ConstHostSpan();
-      feature_segments = cuts->cut_ptrs_.ConstHostSpan();
-      min_fvalue = cuts->min_vals_.ConstHostSpan();
-    } else {
+    if (device.IsCUDA()) {
       cuts->cut_values_.SetDevice(device);
       cuts->cut_ptrs_.SetDevice(device);
       cuts->min_vals_.SetDevice(device);
       gidx_fvalue_map = cuts->cut_values_.ConstDeviceSpan();
       feature_segments = cuts->cut_ptrs_.ConstDeviceSpan();
       min_fvalue = cuts->min_vals_.ConstDeviceSpan();
+    } else {
+      gidx_fvalue_map = cuts->cut_values_.ConstHostSpan();
+      feature_segments = cuts->cut_ptrs_.ConstHostSpan();
+      min_fvalue = cuts->min_vals_.ConstHostSpan();
     }
   }
   // Get a matrix element, uses binary search for look up Return NaN if missing

@@ -41,10 +41,10 @@ IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle pro
   // hardcoded parameter.
   BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()};
-  if (ctx.IsCPU()) {
-    this->InitFromCPU(&ctx, p, iter_handle, missing, ref);
-  } else {
+  if (ctx.IsCUDA()) {
     this->InitFromCUDA(&ctx, p, iter_handle, missing, ref);
+  } else {
+    this->InitFromCPU(&ctx, p, iter_handle, missing, ref);
   }
   this->fmat_ctx_ = ctx;
@@ -73,10 +73,10 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur
   if (ref->PageExists<GHistIndexMatrix>() && ref->PageExists<EllpackPage>()) {
     // Both exists
-    if (ctx->IsCPU()) {
-      csr();
-    } else {
+    if (ctx->IsCUDA()) {
       ellpack();
+    } else {
+      csr();
     }
   } else if (ref->PageExists<GHistIndexMatrix>()) {
     csr();
@@ -84,10 +84,10 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur
     ellpack();
   } else {
     // None exist
-    if (ctx->IsCPU()) {
-      csr();
-    } else {
+    if (ctx->IsCUDA()) {
       ellpack();
+    } else {
+      csr();
     }
   }
   CHECK_EQ(ref->Info().num_col_, n_features)
@@ -297,9 +297,9 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(Context const* ctx
   }
   if (!ghist_) {
-    if (ctx->IsCPU()) {
+    if (!ctx->IsCUDA()) {
       ghist_ = std::make_shared<GHistIndexMatrix>(ctx, Info(), *ellpack_, param);
-    } else if (fmat_ctx_.IsCPU()) {
+    } else if (!fmat_ctx_.IsCUDA()) {
       ghist_ = std::make_shared<GHistIndexMatrix>(&fmat_ctx_, Info(), *ellpack_, param);
     } else {
       // Can happen when QDM is initialized on GPU, but a CPU version is queried by a different QDM
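
In GetGradientIndex above, the CPU branches keep their positions, so only the conditions are rewritten: IsCPU() becomes !IsCUDA(), which is also true for a SYCL context and therefore steers it to the host GHistIndexMatrix build.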

@@ -46,7 +46,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
   int32_t current_device;
   dh::safe_cuda(cudaGetDevice(&current_device));
   auto get_device = [&]() {
-    auto d = (ctx->IsCPU()) ? DeviceOrd::CUDA(current_device) : ctx->Device();
+    auto d = (ctx->IsCUDA()) ? ctx->Device() : DeviceOrd::CUDA(current_device);
     CHECK(!d.IsCPU());
     return d;
   };

@@ -56,7 +56,9 @@ std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
                                                 float missing) {
   bool type_error{false};
   std::shared_ptr<DMatrix> p_fmat{nullptr};
-  if (proxy->Ctx()->IsCPU()) {
+  if (proxy->Ctx()->IsCUDA()) {
+    p_fmat = cuda_impl::CreateDMatrixFromProxy(ctx, proxy, missing);
+  } else {
     p_fmat = data::HostAdapterDispatch<false>(
         proxy.get(),
         [&](auto const &adapter) {
@@ -65,8 +67,6 @@ std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
           return p_fmat;
         },
         &type_error);
-  } else {
-    p_fmat = cuda_impl::CreateDMatrixFromProxy(ctx, proxy, missing);
   }
   CHECK(p_fmat) << "Failed to fallback.";

@@ -11,7 +11,7 @@ void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
   this->batch_ = adapter;
   this->Info().num_col_ = adapter->NumColumns();
   this->Info().num_row_ = adapter->NumRows();
-  if (adapter->Device().IsCPU()) {
+  if (!adapter->Device().IsCUDA()) {
     // empty data
     CHECK_EQ(this->Info().num_row_, 0);
     ctx_ = ctx_.MakeCUDA(dh::CurrentDevice());
@@ -25,7 +25,7 @@ void DMatrixProxy::FromCudaArray(StringView interface_str) {
   this->batch_ = adapter;
   this->Info().num_col_ = adapter->NumColumns();
   this->Info().num_row_ = adapter->NumRows();
-  if (adapter->Device().IsCPU()) {
+  if (!adapter->Device().IsCUDA()) {
     // empty data
     CHECK_EQ(this->Info().num_row_, 0);
     ctx_ = ctx_.MakeCUDA(dh::CurrentDevice());

@@ -185,12 +185,12 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(Context const* ctx,
   CHECK_GE(param.max_bin, 2);
   // Used only by approx.
   auto sorted_sketch = param.regen;
-  if (ctx->IsCPU()) {
+  if (!ctx->IsCUDA()) {
     // The context passed in is on CPU, we pick it first since we prioritize the context
     // in Booster.
     gradient_index_.reset(new GHistIndexMatrix{ctx, this, param.max_bin, param.sparse_thresh,
                                                sorted_sketch, param.hess});
-  } else if (fmat_ctx_.IsCPU()) {
+  } else if (!fmat_ctx_.IsCUDA()) {
     // DMatrix was initialized on CPU, we use the context from initialization.
     gradient_index_.reset(new GHistIndexMatrix{&fmat_ctx_, this, param.max_bin,
                                                param.sparse_thresh, sorted_sketch, param.hess});

@@ -19,7 +19,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthr
                              DataSplitMode data_split_mode) {
   CHECK(data_split_mode != DataSplitMode::kCol)
       << "Column-wise data split is currently not supported on the GPU.";
-  auto device = (adapter->Device().IsCPU() || adapter->NumRows() == 0)
+  auto device = (!adapter->Device().IsCUDA() || adapter->NumRows() == 0)
                     ? DeviceOrd::CUDA(dh::CurrentDevice())
                     : adapter->Device();
   CHECK(device.IsCUDA());

@@ -20,7 +20,7 @@ std::size_t NFeaturesDevice(DMatrixProxy *proxy) {
 void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
   auto device = proxy->Device();
-  if (device.IsCPU()) {
+  if (!device.IsCUDA()) {
     device = DeviceOrd::CUDA(dh::CurrentDevice());
   }
   CHECK(device.IsCUDA());

@@ -336,12 +336,12 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
     double auc{0};
     uint32_t valid_groups = 0;
     auto n_threads = ctx_->Threads();
-    if (ctx_->IsCPU()) {
-      std::tie(auc, valid_groups) =
-          RankingAUC<true>(ctx_, predts.ConstHostVector(), info, n_threads);
-    } else {
+    if (ctx_->IsCUDA()) {
       std::tie(auc, valid_groups) =
           GPURankingAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
+    } else {
+      std::tie(auc, valid_groups) =
+          RankingAUC<true>(ctx_, predts.ConstHostVector(), info, n_threads);
     }
     return std::make_pair(auc, valid_groups);
   }
@@ -351,10 +351,10 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
     double auc{0};
     auto n_threads = ctx_->Threads();
    CHECK_NE(n_classes, 0);
-    if (ctx_->IsCPU()) {
-      auc = MultiClassOVR(ctx_, predts.ConstHostVector(), info, n_classes, n_threads, BinaryROCAUC);
-    } else {
+    if (ctx_->IsCUDA()) {
       auc = GPUMultiClassROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_, n_classes);
+    } else {
+      auc = MultiClassOVR(ctx_, predts.ConstHostVector(), info, n_classes, n_threads, BinaryROCAUC);
     }
     return auc;
   }
@@ -362,13 +362,13 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
   std::tuple<double, double, double>
   EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
     double fp, tp, auc;
-    if (ctx_->IsCPU()) {
+    if (ctx_->IsCUDA()) {
+      std::tie(fp, tp, auc) =
+          GPUBinaryROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
+    } else {
       std::tie(fp, tp, auc) = BinaryROCAUC(ctx_, predts.ConstHostVector(),
                                            info.labels.HostView().Slice(linalg::All(), 0),
                                            common::OptionalWeights{info.weights_.ConstHostSpan()});
-    } else {
-      std::tie(fp, tp, auc) =
-          GPUBinaryROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
     }
     return std::make_tuple(fp, tp, auc);
   }
@@ -413,23 +413,23 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
   std::tuple<double, double, double>
   EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
     double pr, re, auc;
-    if (ctx_->IsCPU()) {
+    if (ctx_->IsCUDA()) {
+      std::tie(pr, re, auc) = GPUBinaryPRAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
+    } else {
       std::tie(pr, re, auc) =
          BinaryPRAUC(ctx_, predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
                       common::OptionalWeights{info.weights_.ConstHostSpan()});
-    } else {
-      std::tie(pr, re, auc) = GPUBinaryPRAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
     }
     return std::make_tuple(pr, re, auc);
   }
   double EvalMultiClass(HostDeviceVector<float> const &predts, MetaInfo const &info,
                         size_t n_classes) {
-    if (ctx_->IsCPU()) {
+    if (ctx_->IsCUDA()) {
+      return GPUMultiClassPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_, n_classes);
+    } else {
       auto n_threads = this->ctx_->Threads();
       return MultiClassOVR(ctx_, predts.ConstHostSpan(), info, n_classes, n_threads, BinaryPRAUC);
-    } else {
-      return GPUMultiClassPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_, n_classes);
     }
   }
@@ -438,16 +438,16 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
     double auc{0};
     uint32_t valid_groups = 0;
     auto n_threads = ctx_->Threads();
-    if (ctx_->IsCPU()) {
+    if (ctx_->IsCUDA()) {
+      std::tie(auc, valid_groups) =
+          GPURankingPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_);
+    } else {
       auto labels = info.labels.Data()->ConstHostSpan();
       if (std::any_of(labels.cbegin(), labels.cend(), PRAUCLabelInvalid{})) {
        InvalidLabels();
       }
       std::tie(auc, valid_groups) =
           RankingAUC<false>(ctx_, predts.ConstHostVector(), info, n_threads);
-    } else {
-      std::tie(auc, valid_groups) =
-          GPURankingPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_);
     }
     return std::make_pair(auc, valid_groups);
   }

@@ -131,7 +131,7 @@ class MultiClassMetricsReduction {
       const HostDeviceVector<bst_float>& preds) {
     PackedReduceResult result;
-    if (device.IsCPU()) {
+    if (!device.IsCUDA()) {
      result =
          CpuReduceMetrics(weights, labels, preds, n_class, ctx.Threads());
     }

@@ -127,7 +127,7 @@ class ElementWiseSurvivalMetricsReduction {
       const HostDeviceVector<bst_float>& preds) {
     PackedReduceResult result;
-    if (ctx.IsCPU()) {
+    if (!ctx.IsCUDA()) {
       result = CpuReduceMetrics(weights, labels_lower_bound, labels_upper_bound,
                                 preds, ctx.Threads());