Reorder if-else statements to allow using of cpu branches for sycl-devices (#9682)

This commit is contained in:
Dmitry Razdoburdin 2023-10-18 04:55:33 +02:00 committed by GitHub
parent 4c0e4422d0
commit ea9f09716b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 99 additions and 98 deletions

View File

@ -603,13 +603,13 @@ auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&.
template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan();
auto span = ctx->IsCUDA() ? data->DeviceSpan() : data->HostSpan();
return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
}
template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> const *data, S &&...shape) {
auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan();
auto span = ctx->IsCUDA() ? data->ConstDeviceSpan() : data->ConstHostSpan();
return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
}

View File

@ -42,7 +42,7 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_
template <typename T, int32_t D, typename Fn>
void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
ctx->IsCPU() ? ElementWiseKernelHost(t, ctx->Threads(), fn) : ElementWiseKernelDevice(t, fn);
ctx->IsCUDA() ? ElementWiseKernelDevice(t, fn) : ElementWiseKernelHost(t, ctx->Threads(), fn);
}
} // namespace linalg
} // namespace xgboost

View File

@ -55,7 +55,7 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr)
template <typename T, int32_t D, typename Fn>
void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
if (!ctx->IsCPU()) {
if (ctx->IsCUDA()) {
common::AssertGPUSupport();
}
ElementWiseKernelHost(t, ctx->Threads(), fn);

View File

@ -11,13 +11,14 @@
namespace xgboost {
namespace common {
double Reduce(Context const* ctx, HostDeviceVector<float> const& values) {
if (ctx->IsCPU()) {
if (ctx->IsCUDA()) {
return cuda_impl::Reduce(ctx, values);
} else {
auto const& h_values = values.ConstHostVector();
auto result = cpu_impl::Reduce(ctx, h_values.cbegin(), h_values.cend(), 0.0);
static_assert(std::is_same<decltype(result), double>::value);
return result;
}
return cuda_impl::Reduce(ctx, values);
}
} // namespace common
} // namespace xgboost

View File

@ -26,7 +26,7 @@ inline OptionalWeights MakeOptionalWeights(Context const* ctx,
if (ctx->IsCUDA()) {
weights.SetDevice(ctx->Device());
}
return OptionalWeights{ctx->IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()};
return OptionalWeights{ctx->IsCUDA() ? weights.ConstDeviceSpan() : weights.ConstHostSpan()};
}
} // namespace xgboost::common
#endif // XGBOOST_COMMON_OPTIONAL_WEIGHT_H_

View File

@ -197,10 +197,10 @@ class RankingCache {
CHECK_EQ(info.group_ptr_.back(), info.labels.Size())
<< error::GroupSize() << "the size of label.";
}
if (ctx->IsCPU()) {
this->InitOnCPU(ctx, info);
} else {
if (ctx->IsCUDA()) {
this->InitOnCUDA(ctx, info);
} else {
this->InitOnCPU(ctx, info);
}
if (!info.weights_.Empty()) {
CHECK_EQ(Groups(), info.weights_.Size()) << error::GroupWeight();
@ -218,7 +218,7 @@ class RankingCache {
// Constructed as [1, n_samples] if group ptr is not supplied by the user
common::Span<bst_group_t const> DataGroupPtr(Context const* ctx) const {
group_ptr_.SetDevice(ctx->Device());
return ctx->IsCPU() ? group_ptr_.ConstHostSpan() : group_ptr_.ConstDeviceSpan();
return ctx->IsCUDA() ? group_ptr_.ConstDeviceSpan() : group_ptr_.ConstHostSpan();
}
[[nodiscard]] auto const& Param() const { return param_; }
@ -231,10 +231,10 @@ class RankingCache {
sorted_idx_cache_.SetDevice(ctx->Device());
sorted_idx_cache_.Resize(predt.size());
}
if (ctx->IsCPU()) {
return this->MakeRankOnCPU(ctx, predt);
} else {
if (ctx->IsCUDA()) {
return this->MakeRankOnCUDA(ctx, predt);
} else {
return this->MakeRankOnCPU(ctx, predt);
}
}
// The function simply returns a uninitialized buffer as this is only used by the
@ -307,10 +307,10 @@ class NDCGCache : public RankingCache {
public:
NDCGCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
: RankingCache{ctx, info, p} {
if (ctx->IsCPU()) {
this->InitOnCPU(ctx, info);
} else {
if (ctx->IsCUDA()) {
this->InitOnCUDA(ctx, info);
} else {
this->InitOnCPU(ctx, info);
}
}
@ -318,7 +318,7 @@ class NDCGCache : public RankingCache {
return inv_idcg_.View(ctx->Device());
}
common::Span<double const> Discount(Context const* ctx) const {
return ctx->IsCPU() ? discounts_.ConstHostSpan() : discounts_.ConstDeviceSpan();
return ctx->IsCUDA() ? discounts_.ConstDeviceSpan() : discounts_.ConstHostSpan();
}
linalg::VectorView<double> Dcg(Context const* ctx) {
if (dcg_.Size() == 0) {
@ -387,10 +387,10 @@ class PreCache : public RankingCache {
public:
PreCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
: RankingCache{ctx, info, p} {
if (ctx->IsCPU()) {
this->InitOnCPU(ctx, info);
} else {
if (ctx->IsCUDA()) {
this->InitOnCUDA(ctx, info);
} else {
this->InitOnCPU(ctx, info);
}
}
@ -399,7 +399,7 @@ class PreCache : public RankingCache {
pre_.SetDevice(ctx->Device());
pre_.Resize(this->Groups());
}
return ctx->IsCPU() ? pre_.HostSpan() : pre_.DeviceSpan();
return ctx->IsCUDA() ? pre_.DeviceSpan() : pre_.HostSpan();
}
};
@ -418,10 +418,10 @@ class MAPCache : public RankingCache {
public:
MAPCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
: RankingCache{ctx, info, p}, n_samples_{static_cast<std::size_t>(info.num_row_)} {
if (ctx->IsCPU()) {
this->InitOnCPU(ctx, info);
} else {
if (ctx->IsCUDA()) {
this->InitOnCUDA(ctx, info);
} else {
this->InitOnCPU(ctx, info);
}
}
@ -430,21 +430,21 @@ class MAPCache : public RankingCache {
n_rel_.SetDevice(ctx->Device());
n_rel_.Resize(n_samples_);
}
return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan();
return ctx->IsCUDA() ? n_rel_.DeviceSpan() : n_rel_.HostSpan();
}
common::Span<double> Acc(Context const* ctx) {
if (acc_.Empty()) {
acc_.SetDevice(ctx->Device());
acc_.Resize(n_samples_);
}
return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan();
return ctx->IsCUDA() ? acc_.DeviceSpan() : acc_.HostSpan();
}
common::Span<double> Map(Context const* ctx) {
if (map_.Empty()) {
map_.SetDevice(ctx->Device());
map_.Resize(this->Groups());
}
return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan();
return ctx->IsCUDA() ? map_.DeviceSpan() : map_.HostSpan();
}
};

View File

@ -49,7 +49,9 @@ void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<flo
out->SetDevice(ctx->Device());
out->Reshape(1);
if (ctx->IsCPU()) {
if (ctx->IsCUDA()) {
cuda_impl::Mean(ctx, v.View(ctx->Device()), out->View(ctx->Device()));
} else {
auto h_v = v.HostView();
float n = v.Size();
MemStackAllocator<float, DefaultMaxThreads()> tloc(ctx->Threads(), 0.0f);
@ -57,8 +59,6 @@ void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<flo
[&](auto i) { tloc[omp_get_thread_num()] += h_v(i) / n; });
auto ret = std::accumulate(tloc.cbegin(), tloc.cend(), .0f);
out->HostView()(0) = ret;
} else {
cuda_impl::Mean(ctx, v.View(ctx->Device()), out->View(ctx->Device()));
}
}
} // namespace xgboost::common

View File

@ -278,7 +278,7 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy
std::swap(base_score_, base_margin);
// Make sure read access everywhere for thread-safe prediction.
std::as_const(base_score_).HostView();
if (!ctx->IsCPU()) {
if (ctx->IsCUDA()) {
std::as_const(base_score_).View(ctx->Device());
}
CHECK(std::as_const(base_score_).Data()->HostCanRead());
@ -287,7 +287,7 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy
linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(DeviceOrd device) const {
// multi-class is not yet supported.
CHECK_EQ(base_score_.Size(), 1) << ModelNotFitted();
if (device.IsCPU()) {
if (!device.IsCUDA()) {
// Make sure that we won't run into race condition.
CHECK(base_score_.Data()->HostCanRead());
return base_score_.HostView();

View File

@ -46,7 +46,26 @@ template <typename Fn>
PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
PackedReduceResult result;
auto labels = info.labels.View(ctx->Device());
if (ctx->IsCPU()) {
if (ctx->IsCUDA()) {
#if defined(XGBOOST_USE_CUDA)
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::counting_iterator<size_t> begin(0);
thrust::counting_iterator<size_t> end = begin + labels.Size();
result = thrust::transform_reduce(
thrust::cuda::par(alloc), begin, end,
[=] XGBOOST_DEVICE(size_t i) {
auto idx = linalg::UnravelIndex(i, labels.Shape());
auto sample_id = std::get<0>(idx);
auto target_id = std::get<1>(idx);
auto res = loss(i, sample_id, target_id);
float v{std::get<0>(res)}, wt{std::get<1>(res)};
return PackedReduceResult{v, wt};
},
PackedReduceResult{}, thrust::plus<PackedReduceResult>());
#else
common::AssertGPUSupport();
#endif // defined(XGBOOST_USE_CUDA)
} else {
auto n_threads = ctx->Threads();
std::vector<double> score_tloc(n_threads, 0.0);
std::vector<double> weight_tloc(n_threads, 0.0);
@ -69,25 +88,6 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0);
double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0);
result = PackedReduceResult{residue_sum, weights_sum};
} else {
#if defined(XGBOOST_USE_CUDA)
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::counting_iterator<size_t> begin(0);
thrust::counting_iterator<size_t> end = begin + labels.Size();
result = thrust::transform_reduce(
thrust::cuda::par(alloc), begin, end,
[=] XGBOOST_DEVICE(size_t i) {
auto idx = linalg::UnravelIndex(i, labels.Shape());
auto sample_id = std::get<0>(idx);
auto target_id = std::get<1>(idx);
auto res = loss(i, sample_id, target_id);
float v{std::get<0>(res)}, wt{std::get<1>(res)};
return PackedReduceResult{v, wt};
},
PackedReduceResult{}, thrust::plus<PackedReduceResult>());
#else
common::AssertGPUSupport();
#endif // defined(XGBOOST_USE_CUDA)
}
return result;
}
@ -185,10 +185,10 @@ class PseudoErrorLoss : public MetricNoCache {
CHECK_EQ(info.labels.Shape(0), info.num_row_);
auto labels = info.labels.View(ctx_->Device());
preds.SetDevice(ctx_->Device());
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
auto predts = ctx_->IsCUDA() ? preds.ConstDeviceSpan() : preds.ConstHostSpan();
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan());
common::OptionalWeights weights(ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan());
float slope = this->param_.huber_slope;
CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
PackedReduceResult result =
@ -351,10 +351,10 @@ struct EvalEWiseBase : public MetricNoCache {
}
auto labels = info.labels.View(ctx_->Device());
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan());
common::OptionalWeights weights(ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan());
preds.SetDevice(ctx_->Device());
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
auto predts = ctx_->IsCUDA() ? preds.ConstDeviceSpan() : preds.ConstHostSpan();
auto d_policy = policy_;
auto result =

View File

@ -96,13 +96,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<bst_node_t> const& position,
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
if (ctx->IsCPU()) {
detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate,
predt, alpha, p_tree);
} else {
if (ctx->IsCUDA()) {
position.SetDevice(ctx->Device());
detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, learning_rate,
predt, alpha, p_tree);
} else {
detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate,
predt, alpha, p_tree);
}
}
} // namespace obj

View File

@ -108,14 +108,14 @@ class LambdaRankObj : public FitIntercept {
li_.SetDevice(ctx_->Device());
lj_.SetDevice(ctx_->Device());
if (ctx_->IsCPU()) {
cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
} else {
if (ctx_->IsCUDA()) {
cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
} else {
cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
}
li_full_.Data()->Fill(0.0);

View File

@ -71,15 +71,15 @@ class QuantileRegression : public ObjFunction {
auto gpair = out_gpair->View(ctx_->Device());
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan()};
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()};
preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds);
auto n_samples = info.num_row_;
alpha_.SetDevice(ctx_->Device());
auto alpha = ctx_->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan();
auto alpha = ctx_->IsCUDA() ? alpha_.ConstDeviceSpan() : alpha_.ConstHostSpan();
linalg::ElementWiseKernel(
ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable {
@ -107,27 +107,7 @@ class QuantileRegression : public ObjFunction {
base_score->Reshape(n_targets);
double sw{0};
if (ctx_->IsCPU()) {
auto quantiles = base_score->HostView();
auto h_weights = info.weights_.ConstHostVector();
if (info.weights_.Empty()) {
sw = info.num_row_;
} else {
sw = std::accumulate(std::cbegin(h_weights), std::cend(h_weights), 0.0);
}
for (bst_target_t t{0}; t < n_targets; ++t) {
auto alpha = param_.quantile_alpha[t];
auto h_labels = info.labels.HostView();
if (h_weights.empty()) {
quantiles(t) =
common::Quantile(ctx_, alpha, linalg::cbegin(h_labels), linalg::cend(h_labels));
} else {
CHECK_EQ(h_weights.size(), h_labels.Size());
quantiles(t) = common::WeightedQuantile(ctx_, alpha, linalg::cbegin(h_labels),
linalg::cend(h_labels), std::cbegin(h_weights));
}
}
} else {
if (ctx_->IsCUDA()) {
#if defined(XGBOOST_USE_CUDA)
alpha_.SetDevice(ctx_->Device());
auto d_alpha = alpha_.ConstDeviceSpan();
@ -164,6 +144,26 @@ class QuantileRegression : public ObjFunction {
#else
common::AssertGPUSupport();
#endif // defined(XGBOOST_USE_CUDA)
} else {
auto quantiles = base_score->HostView();
auto h_weights = info.weights_.ConstHostVector();
if (info.weights_.Empty()) {
sw = info.num_row_;
} else {
sw = std::accumulate(std::cbegin(h_weights), std::cend(h_weights), 0.0);
}
for (bst_target_t t{0}; t < n_targets; ++t) {
auto alpha = param_.quantile_alpha[t];
auto h_labels = info.labels.HostView();
if (h_weights.empty()) {
quantiles(t) =
common::Quantile(ctx_, alpha, linalg::cbegin(h_labels), linalg::cend(h_labels));
} else {
CHECK_EQ(h_weights.size(), h_labels.Size());
quantiles(t) = common::WeightedQuantile(ctx_, alpha, linalg::cbegin(h_labels),
linalg::cend(h_labels), std::cbegin(h_weights));
}
}
}
// For multiple quantiles, we should extend the base score to a vector instead of

View File

@ -254,8 +254,8 @@ class PseudoHuberRegression : public FitIntercept {
auto predt = linalg::MakeVec(&preds);
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan()};
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()};
linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable {
auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape()));
@ -714,8 +714,8 @@ class MeanAbsoluteError : public ObjFunction {
preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds);
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan()};
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()};
linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, float y) mutable {
auto sign = [](auto x) {

View File

@ -72,7 +72,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientP
gpair.SetDevice(ctx->Device());
auto gpair_t = gpair.View(ctx->Device());
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device()));
ctx->IsCUDA() ? cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device()))
: cpu_impl::FitStump(ctx, info, gpair_t, out->HostView());
}
} // namespace xgboost::tree