Reorder if-else statements to allow using of cpu branches for sycl-devices (#9682)
This commit is contained in:
committed by
GitHub
parent
4c0e4422d0
commit
ea9f09716b
@@ -46,7 +46,26 @@ template <typename Fn>
|
||||
PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
|
||||
PackedReduceResult result;
|
||||
auto labels = info.labels.View(ctx->Device());
|
||||
if (ctx->IsCPU()) {
|
||||
if (ctx->IsCUDA()) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::counting_iterator<size_t> begin(0);
|
||||
thrust::counting_iterator<size_t> end = begin + labels.Size();
|
||||
result = thrust::transform_reduce(
|
||||
thrust::cuda::par(alloc), begin, end,
|
||||
[=] XGBOOST_DEVICE(size_t i) {
|
||||
auto idx = linalg::UnravelIndex(i, labels.Shape());
|
||||
auto sample_id = std::get<0>(idx);
|
||||
auto target_id = std::get<1>(idx);
|
||||
auto res = loss(i, sample_id, target_id);
|
||||
float v{std::get<0>(res)}, wt{std::get<1>(res)};
|
||||
return PackedReduceResult{v, wt};
|
||||
},
|
||||
PackedReduceResult{}, thrust::plus<PackedReduceResult>());
|
||||
#else
|
||||
common::AssertGPUSupport();
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
} else {
|
||||
auto n_threads = ctx->Threads();
|
||||
std::vector<double> score_tloc(n_threads, 0.0);
|
||||
std::vector<double> weight_tloc(n_threads, 0.0);
|
||||
@@ -69,25 +88,6 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
|
||||
double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0);
|
||||
double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0);
|
||||
result = PackedReduceResult{residue_sum, weights_sum};
|
||||
} else {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::counting_iterator<size_t> begin(0);
|
||||
thrust::counting_iterator<size_t> end = begin + labels.Size();
|
||||
result = thrust::transform_reduce(
|
||||
thrust::cuda::par(alloc), begin, end,
|
||||
[=] XGBOOST_DEVICE(size_t i) {
|
||||
auto idx = linalg::UnravelIndex(i, labels.Shape());
|
||||
auto sample_id = std::get<0>(idx);
|
||||
auto target_id = std::get<1>(idx);
|
||||
auto res = loss(i, sample_id, target_id);
|
||||
float v{std::get<0>(res)}, wt{std::get<1>(res)};
|
||||
return PackedReduceResult{v, wt};
|
||||
},
|
||||
PackedReduceResult{}, thrust::plus<PackedReduceResult>());
|
||||
#else
|
||||
common::AssertGPUSupport();
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -185,10 +185,10 @@ class PseudoErrorLoss : public MetricNoCache {
|
||||
CHECK_EQ(info.labels.Shape(0), info.num_row_);
|
||||
auto labels = info.labels.View(ctx_->Device());
|
||||
preds.SetDevice(ctx_->Device());
|
||||
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
|
||||
auto predts = ctx_->IsCUDA() ? preds.ConstDeviceSpan() : preds.ConstHostSpan();
|
||||
info.weights_.SetDevice(ctx_->Device());
|
||||
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
|
||||
: info.weights_.ConstDeviceSpan());
|
||||
common::OptionalWeights weights(ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
|
||||
: info.weights_.ConstHostSpan());
|
||||
float slope = this->param_.huber_slope;
|
||||
CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
|
||||
PackedReduceResult result =
|
||||
@@ -351,10 +351,10 @@ struct EvalEWiseBase : public MetricNoCache {
|
||||
}
|
||||
auto labels = info.labels.View(ctx_->Device());
|
||||
info.weights_.SetDevice(ctx_->Device());
|
||||
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
|
||||
: info.weights_.ConstDeviceSpan());
|
||||
common::OptionalWeights weights(ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
|
||||
: info.weights_.ConstHostSpan());
|
||||
preds.SetDevice(ctx_->Device());
|
||||
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
|
||||
auto predts = ctx_->IsCUDA() ? preds.ConstDeviceSpan() : preds.ConstHostSpan();
|
||||
|
||||
auto d_policy = policy_;
|
||||
auto result =
|
||||
|
||||
Reference in New Issue
Block a user