Reorder if-else statements to allow using CPU branches for SYCL devices (#9682)
commit ea9f09716b
parent 4c0e4422d0
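The change is mechanical across the diff below: device dispatch written as `ctx->IsCPU() ? host : device` or `if (ctx->IsCPU()) {...} else {...}` is flipped so that the CUDA branch is the one tested explicitly. A SYCL device is neither CPU nor CUDA, so under the old ordering it fell into the CUDA-only path; under the new ordering every non-CUDA device takes the host path. A minimal sketch of the effect, using hypothetical stand-ins rather than the real `xgboost::Context` and span getters:

```cpp
// Hedged sketch: Ctx, HostSpan, DeviceSpan are illustrative placeholders,
// not the actual XGBoost symbols.
#include <iostream>
#include <string>

enum class DeviceKind { kCPU, kCUDA, kSycl };  // hypothetical device kinds

struct Ctx {  // hypothetical stand-in for xgboost::Context
  DeviceKind device{DeviceKind::kCPU};
  bool IsCPU() const { return device == DeviceKind::kCPU; }
  bool IsCUDA() const { return device == DeviceKind::kCUDA; }
};

std::string HostSpan() { return "host span"; }
std::string DeviceSpan() { return "CUDA device span"; }

int main() {
  Ctx ctx{DeviceKind::kSycl};  // a SYCL device: neither CPU nor CUDA

  // Old ordering: SYCL lands in the CUDA branch because it is "not CPU".
  auto before = ctx.IsCPU() ? HostSpan() : DeviceSpan();
  // New ordering: only CUDA takes the device branch; CPU and SYCL share the host branch.
  auto after = ctx.IsCUDA() ? DeviceSpan() : HostSpan();

  std::cout << "before: " << before << "\nafter:  " << after << '\n';
}
```

With the flipped check, another non-CUDA backend can reuse these call sites without further edits.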
@@ -603,13 +603,13 @@ auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&.

 template <typename T, typename... S>
 auto MakeTensorView(Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
-  auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan();
+  auto span = ctx->IsCUDA() ? data->DeviceSpan() : data->HostSpan();
   return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
 }

 template <typename T, typename... S>
 auto MakeTensorView(Context const *ctx, HostDeviceVector<T> const *data, S &&...shape) {
-  auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan();
+  auto span = ctx->IsCUDA() ? data->ConstDeviceSpan() : data->ConstHostSpan();
   return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
 }
@@ -42,7 +42,7 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_

 template <typename T, int32_t D, typename Fn>
 void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
-  ctx->IsCPU() ? ElementWiseKernelHost(t, ctx->Threads(), fn) : ElementWiseKernelDevice(t, fn);
+  ctx->IsCUDA() ? ElementWiseKernelDevice(t, fn) : ElementWiseKernelHost(t, ctx->Threads(), fn);
 }
 }  // namespace linalg
 }  // namespace xgboost
@@ -55,7 +55,7 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr)

 template <typename T, int32_t D, typename Fn>
 void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
-  if (!ctx->IsCPU()) {
+  if (ctx->IsCUDA()) {
     common::AssertGPUSupport();
   }
   ElementWiseKernelHost(t, ctx->Threads(), fn);
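In the CUDA-less stub above, the reorder also changes which devices trip the GPU-support assertion: only an explicit CUDA request asserts, while CPU and SYCL contexts simply run the host kernel. A rough sketch of that behaviour, under assumed stand-in names (`Ctx`, `AssertGPUSupport`, and `ElementWiseKernelHost` here are placeholders, not the real XGBoost symbols):

```cpp
// Hedged sketch of the CUDA-less build path after the reorder.
#include <cassert>
#include <cstdio>

enum class DeviceKind { kCPU, kCUDA, kSycl };  // hypothetical device kinds

struct Ctx {  // hypothetical stand-in for xgboost::Context
  DeviceKind device{DeviceKind::kCPU};
  bool IsCUDA() const { return device == DeviceKind::kCUDA; }
  int Threads() const { return 4; }
};

void AssertGPUSupport() { assert(false && "not compiled with CUDA support"); }

template <typename Fn>
void ElementWiseKernelHost(int n, int /*n_threads*/, Fn&& fn) {
  for (int i = 0; i < n; ++i) fn(i);  // plain sequential loop as a placeholder
}

template <typename Fn>
void ElementWiseKernel(Ctx const* ctx, int n, Fn&& fn) {
  if (ctx->IsCUDA()) {   // was: if (!ctx->IsCPU()), which also caught SYCL devices
    AssertGPUSupport();  // CUDA requested, but this build has no CUDA support
  }
  ElementWiseKernelHost(n, ctx->Threads(), fn);
}

int main() {
  Ctx sycl{DeviceKind::kSycl};
  ElementWiseKernel(&sycl, 3, [](int i) { std::printf("%d ", i); });  // stays on the host path
  std::printf("\n");
}
```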
@@ -11,13 +11,14 @@
 namespace xgboost {
 namespace common {
 double Reduce(Context const* ctx, HostDeviceVector<float> const& values) {
-  if (ctx->IsCPU()) {
+  if (ctx->IsCUDA()) {
+    return cuda_impl::Reduce(ctx, values);
+  } else {
     auto const& h_values = values.ConstHostVector();
     auto result = cpu_impl::Reduce(ctx, h_values.cbegin(), h_values.cend(), 0.0);
     static_assert(std::is_same<decltype(result), double>::value);
     return result;
   }
-  return cuda_impl::Reduce(ctx, values);
 }
 }  // namespace common
 }  // namespace xgboost
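The `Reduce` hunk above shows the same reordering applied to a full if/else: the CUDA implementation sits in the explicit branch and everything else, CPU and SYCL alike, falls into the host branch. A self-contained sketch under assumed names (`Ctx`, `cpu_impl`, and `cuda_impl` are illustrative placeholders, not the actual `xgboost::common` code):

```cpp
// Hedged sketch of the reordered if/else dispatch.
#include <iostream>
#include <numeric>
#include <vector>

enum class DeviceKind { kCPU, kCUDA, kSycl };  // hypothetical device kinds

struct Ctx {  // hypothetical stand-in for xgboost::Context
  DeviceKind device{DeviceKind::kCPU};
  bool IsCUDA() const { return device == DeviceKind::kCUDA; }
};

namespace cpu_impl {
double Reduce(std::vector<float> const& values) {
  return std::accumulate(values.cbegin(), values.cend(), 0.0);  // host reduction
}
}  // namespace cpu_impl

namespace cuda_impl {
double Reduce(std::vector<float> const& values) {
  // Placeholder: a real build would launch the reduction on the GPU.
  return std::accumulate(values.cbegin(), values.cend(), 0.0);
}
}  // namespace cuda_impl

// After the reorder, only an explicit CUDA device takes the device branch;
// CPU, SYCL, or any future backend reuses the host branch.
double Reduce(Ctx const* ctx, std::vector<float> const& values) {
  if (ctx->IsCUDA()) {
    return cuda_impl::Reduce(values);
  } else {
    return cpu_impl::Reduce(values);
  }
}

int main() {
  std::vector<float> values{1.0f, 2.0f, 3.0f};
  Ctx sycl{DeviceKind::kSycl};
  std::cout << Reduce(&sycl, values) << '\n';  // 6: the SYCL context uses the host branch
}
```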
@@ -26,7 +26,7 @@ inline OptionalWeights MakeOptionalWeights(Context const* ctx,
   if (ctx->IsCUDA()) {
     weights.SetDevice(ctx->Device());
   }
-  return OptionalWeights{ctx->IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()};
+  return OptionalWeights{ctx->IsCUDA() ? weights.ConstDeviceSpan() : weights.ConstHostSpan()};
 }
 }  // namespace xgboost::common
 #endif  // XGBOOST_COMMON_OPTIONAL_WEIGHT_H_
@@ -197,10 +197,10 @@ class RankingCache {
       CHECK_EQ(info.group_ptr_.back(), info.labels.Size())
           << error::GroupSize() << "the size of label.";
     }
-    if (ctx->IsCPU()) {
-      this->InitOnCPU(ctx, info);
-    } else {
+    if (ctx->IsCUDA()) {
       this->InitOnCUDA(ctx, info);
+    } else {
+      this->InitOnCPU(ctx, info);
     }
     if (!info.weights_.Empty()) {
       CHECK_EQ(Groups(), info.weights_.Size()) << error::GroupWeight();
@@ -218,7 +218,7 @@ class RankingCache {
   // Constructed as [1, n_samples] if group ptr is not supplied by the user
   common::Span<bst_group_t const> DataGroupPtr(Context const* ctx) const {
     group_ptr_.SetDevice(ctx->Device());
-    return ctx->IsCPU() ? group_ptr_.ConstHostSpan() : group_ptr_.ConstDeviceSpan();
+    return ctx->IsCUDA() ? group_ptr_.ConstDeviceSpan() : group_ptr_.ConstHostSpan();
   }

   [[nodiscard]] auto const& Param() const { return param_; }
@@ -231,10 +231,10 @@ class RankingCache {
       sorted_idx_cache_.SetDevice(ctx->Device());
       sorted_idx_cache_.Resize(predt.size());
     }
-    if (ctx->IsCPU()) {
-      return this->MakeRankOnCPU(ctx, predt);
-    } else {
+    if (ctx->IsCUDA()) {
       return this->MakeRankOnCUDA(ctx, predt);
+    } else {
+      return this->MakeRankOnCPU(ctx, predt);
     }
   }
   // The function simply returns a uninitialized buffer as this is only used by the
@@ -307,10 +307,10 @@ class NDCGCache : public RankingCache {
  public:
   NDCGCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
       : RankingCache{ctx, info, p} {
-    if (ctx->IsCPU()) {
-      this->InitOnCPU(ctx, info);
-    } else {
+    if (ctx->IsCUDA()) {
       this->InitOnCUDA(ctx, info);
+    } else {
+      this->InitOnCPU(ctx, info);
     }
   }
@@ -318,7 +318,7 @@ class NDCGCache : public RankingCache {
     return inv_idcg_.View(ctx->Device());
   }
   common::Span<double const> Discount(Context const* ctx) const {
-    return ctx->IsCPU() ? discounts_.ConstHostSpan() : discounts_.ConstDeviceSpan();
+    return ctx->IsCUDA() ? discounts_.ConstDeviceSpan() : discounts_.ConstHostSpan();
   }
   linalg::VectorView<double> Dcg(Context const* ctx) {
     if (dcg_.Size() == 0) {
@@ -387,10 +387,10 @@ class PreCache : public RankingCache {
  public:
   PreCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
       : RankingCache{ctx, info, p} {
-    if (ctx->IsCPU()) {
-      this->InitOnCPU(ctx, info);
-    } else {
+    if (ctx->IsCUDA()) {
       this->InitOnCUDA(ctx, info);
+    } else {
+      this->InitOnCPU(ctx, info);
     }
   }
@@ -399,7 +399,7 @@ class PreCache : public RankingCache {
       pre_.SetDevice(ctx->Device());
       pre_.Resize(this->Groups());
     }
-    return ctx->IsCPU() ? pre_.HostSpan() : pre_.DeviceSpan();
+    return ctx->IsCUDA() ? pre_.DeviceSpan() : pre_.HostSpan();
   }
 };
@@ -418,10 +418,10 @@ class MAPCache : public RankingCache {
  public:
   MAPCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
       : RankingCache{ctx, info, p}, n_samples_{static_cast<std::size_t>(info.num_row_)} {
-    if (ctx->IsCPU()) {
-      this->InitOnCPU(ctx, info);
-    } else {
+    if (ctx->IsCUDA()) {
       this->InitOnCUDA(ctx, info);
+    } else {
+      this->InitOnCPU(ctx, info);
     }
   }
@@ -430,21 +430,21 @@ class MAPCache : public RankingCache {
       n_rel_.SetDevice(ctx->Device());
       n_rel_.Resize(n_samples_);
     }
-    return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan();
+    return ctx->IsCUDA() ? n_rel_.DeviceSpan() : n_rel_.HostSpan();
   }
   common::Span<double> Acc(Context const* ctx) {
     if (acc_.Empty()) {
       acc_.SetDevice(ctx->Device());
       acc_.Resize(n_samples_);
     }
-    return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan();
+    return ctx->IsCUDA() ? acc_.DeviceSpan() : acc_.HostSpan();
   }
   common::Span<double> Map(Context const* ctx) {
     if (map_.Empty()) {
       map_.SetDevice(ctx->Device());
       map_.Resize(this->Groups());
     }
-    return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan();
+    return ctx->IsCUDA() ? map_.DeviceSpan() : map_.HostSpan();
   }
 };
@@ -49,7 +49,9 @@ void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<flo
   out->SetDevice(ctx->Device());
   out->Reshape(1);

-  if (ctx->IsCPU()) {
+  if (ctx->IsCUDA()) {
+    cuda_impl::Mean(ctx, v.View(ctx->Device()), out->View(ctx->Device()));
+  } else {
     auto h_v = v.HostView();
     float n = v.Size();
     MemStackAllocator<float, DefaultMaxThreads()> tloc(ctx->Threads(), 0.0f);
@@ -57,8 +59,6 @@ void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<flo
                  [&](auto i) { tloc[omp_get_thread_num()] += h_v(i) / n; });
     auto ret = std::accumulate(tloc.cbegin(), tloc.cend(), .0f);
     out->HostView()(0) = ret;
-  } else {
-    cuda_impl::Mean(ctx, v.View(ctx->Device()), out->View(ctx->Device()));
   }
 }
 }  // namespace xgboost::common
@@ -278,7 +278,7 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy
   std::swap(base_score_, base_margin);
   // Make sure read access everywhere for thread-safe prediction.
   std::as_const(base_score_).HostView();
-  if (!ctx->IsCPU()) {
+  if (ctx->IsCUDA()) {
     std::as_const(base_score_).View(ctx->Device());
   }
   CHECK(std::as_const(base_score_).Data()->HostCanRead());
@@ -287,7 +287,7 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy
 linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(DeviceOrd device) const {
   // multi-class is not yet supported.
   CHECK_EQ(base_score_.Size(), 1) << ModelNotFitted();
-  if (device.IsCPU()) {
+  if (!device.IsCUDA()) {
     // Make sure that we won't run into race condition.
     CHECK(base_score_.Data()->HostCanRead());
     return base_score_.HostView();
@@ -46,7 +46,26 @@ template <typename Fn>
 PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
   PackedReduceResult result;
   auto labels = info.labels.View(ctx->Device());
-  if (ctx->IsCPU()) {
+  if (ctx->IsCUDA()) {
+#if defined(XGBOOST_USE_CUDA)
+    dh::XGBCachingDeviceAllocator<char> alloc;
+    thrust::counting_iterator<size_t> begin(0);
+    thrust::counting_iterator<size_t> end = begin + labels.Size();
+    result = thrust::transform_reduce(
+        thrust::cuda::par(alloc), begin, end,
+        [=] XGBOOST_DEVICE(size_t i) {
+          auto idx = linalg::UnravelIndex(i, labels.Shape());
+          auto sample_id = std::get<0>(idx);
+          auto target_id = std::get<1>(idx);
+          auto res = loss(i, sample_id, target_id);
+          float v{std::get<0>(res)}, wt{std::get<1>(res)};
+          return PackedReduceResult{v, wt};
+        },
+        PackedReduceResult{}, thrust::plus<PackedReduceResult>());
+#else
+    common::AssertGPUSupport();
+#endif  // defined(XGBOOST_USE_CUDA)
+  } else {
     auto n_threads = ctx->Threads();
     std::vector<double> score_tloc(n_threads, 0.0);
     std::vector<double> weight_tloc(n_threads, 0.0);
@@ -69,25 +88,6 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
     double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0);
     double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0);
     result = PackedReduceResult{residue_sum, weights_sum};
-  } else {
-#if defined(XGBOOST_USE_CUDA)
-    dh::XGBCachingDeviceAllocator<char> alloc;
-    thrust::counting_iterator<size_t> begin(0);
-    thrust::counting_iterator<size_t> end = begin + labels.Size();
-    result = thrust::transform_reduce(
-        thrust::cuda::par(alloc), begin, end,
-        [=] XGBOOST_DEVICE(size_t i) {
-          auto idx = linalg::UnravelIndex(i, labels.Shape());
-          auto sample_id = std::get<0>(idx);
-          auto target_id = std::get<1>(idx);
-          auto res = loss(i, sample_id, target_id);
-          float v{std::get<0>(res)}, wt{std::get<1>(res)};
-          return PackedReduceResult{v, wt};
-        },
-        PackedReduceResult{}, thrust::plus<PackedReduceResult>());
-#else
-    common::AssertGPUSupport();
-#endif  // defined(XGBOOST_USE_CUDA)
   }
   return result;
 }
@@ -185,10 +185,10 @@ class PseudoErrorLoss : public MetricNoCache {
     CHECK_EQ(info.labels.Shape(0), info.num_row_);
     auto labels = info.labels.View(ctx_->Device());
     preds.SetDevice(ctx_->Device());
-    auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
+    auto predts = ctx_->IsCUDA() ? preds.ConstDeviceSpan() : preds.ConstHostSpan();
     info.weights_.SetDevice(ctx_->Device());
-    common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
-                                                  : info.weights_.ConstDeviceSpan());
+    common::OptionalWeights weights(ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
+                                                   : info.weights_.ConstHostSpan());
     float slope = this->param_.huber_slope;
     CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
     PackedReduceResult result =
@@ -351,10 +351,10 @@ struct EvalEWiseBase : public MetricNoCache {
     }
     auto labels = info.labels.View(ctx_->Device());
     info.weights_.SetDevice(ctx_->Device());
-    common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
-                                                  : info.weights_.ConstDeviceSpan());
+    common::OptionalWeights weights(ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
+                                                   : info.weights_.ConstHostSpan());
     preds.SetDevice(ctx_->Device());
-    auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
+    auto predts = ctx_->IsCUDA() ? preds.ConstDeviceSpan() : preds.ConstHostSpan();

     auto d_policy = policy_;
     auto result =
@@ -96,13 +96,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
 inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<bst_node_t> const& position,
                            std::int32_t group_idx, MetaInfo const& info, float learning_rate,
                            HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
-  if (ctx->IsCPU()) {
-    detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate,
-                               predt, alpha, p_tree);
-  } else {
+  if (ctx->IsCUDA()) {
     position.SetDevice(ctx->Device());
     detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, learning_rate,
                                  predt, alpha, p_tree);
+  } else {
+    detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate,
+                               predt, alpha, p_tree);
   }
 }
 }  // namespace obj
@@ -108,14 +108,14 @@ class LambdaRankObj : public FitIntercept {
     li_.SetDevice(ctx_->Device());
     lj_.SetDevice(ctx_->Device());

-    if (ctx_->IsCPU()) {
-      cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
-                                             lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
-                                             &li_, &lj_, p_cache_);
-    } else {
+    if (ctx_->IsCUDA()) {
       cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
                                               lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
                                               &li_, &lj_, p_cache_);
+    } else {
+      cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
+                                             lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
+                                             &li_, &lj_, p_cache_);
     }

     li_full_.Data()->Fill(0.0);
@@ -71,15 +71,15 @@ class QuantileRegression : public ObjFunction {
     auto gpair = out_gpair->View(ctx_->Device());

     info.weights_.SetDevice(ctx_->Device());
-    common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
-                                                 : info.weights_.ConstDeviceSpan()};
+    common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
+                                                  : info.weights_.ConstHostSpan()};

     preds.SetDevice(ctx_->Device());
     auto predt = linalg::MakeVec(&preds);
     auto n_samples = info.num_row_;

     alpha_.SetDevice(ctx_->Device());
-    auto alpha = ctx_->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan();
+    auto alpha = ctx_->IsCUDA() ? alpha_.ConstDeviceSpan() : alpha_.ConstHostSpan();

     linalg::ElementWiseKernel(
         ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable {
@@ -107,27 +107,7 @@ class QuantileRegression : public ObjFunction {
       base_score->Reshape(n_targets);

       double sw{0};
-      if (ctx_->IsCPU()) {
-        auto quantiles = base_score->HostView();
-        auto h_weights = info.weights_.ConstHostVector();
-        if (info.weights_.Empty()) {
-          sw = info.num_row_;
-        } else {
-          sw = std::accumulate(std::cbegin(h_weights), std::cend(h_weights), 0.0);
-        }
-        for (bst_target_t t{0}; t < n_targets; ++t) {
-          auto alpha = param_.quantile_alpha[t];
-          auto h_labels = info.labels.HostView();
-          if (h_weights.empty()) {
-            quantiles(t) =
-                common::Quantile(ctx_, alpha, linalg::cbegin(h_labels), linalg::cend(h_labels));
-          } else {
-            CHECK_EQ(h_weights.size(), h_labels.Size());
-            quantiles(t) = common::WeightedQuantile(ctx_, alpha, linalg::cbegin(h_labels),
-                                                    linalg::cend(h_labels), std::cbegin(h_weights));
-          }
-        }
-      } else {
+      if (ctx_->IsCUDA()) {
 #if defined(XGBOOST_USE_CUDA)
         alpha_.SetDevice(ctx_->Device());
         auto d_alpha = alpha_.ConstDeviceSpan();
@@ -164,6 +144,26 @@ class QuantileRegression : public ObjFunction {
 #else
         common::AssertGPUSupport();
 #endif  // defined(XGBOOST_USE_CUDA)
+      } else {
+        auto quantiles = base_score->HostView();
+        auto h_weights = info.weights_.ConstHostVector();
+        if (info.weights_.Empty()) {
+          sw = info.num_row_;
+        } else {
+          sw = std::accumulate(std::cbegin(h_weights), std::cend(h_weights), 0.0);
+        }
+        for (bst_target_t t{0}; t < n_targets; ++t) {
+          auto alpha = param_.quantile_alpha[t];
+          auto h_labels = info.labels.HostView();
+          if (h_weights.empty()) {
+            quantiles(t) =
+                common::Quantile(ctx_, alpha, linalg::cbegin(h_labels), linalg::cend(h_labels));
+          } else {
+            CHECK_EQ(h_weights.size(), h_labels.Size());
+            quantiles(t) = common::WeightedQuantile(ctx_, alpha, linalg::cbegin(h_labels),
+                                                    linalg::cend(h_labels), std::cbegin(h_weights));
+          }
+        }
       }

       // For multiple quantiles, we should extend the base score to a vector instead of
@@ -254,8 +254,8 @@ class PseudoHuberRegression : public FitIntercept {
     auto predt = linalg::MakeVec(&preds);

     info.weights_.SetDevice(ctx_->Device());
-    common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
-                                                 : info.weights_.ConstDeviceSpan()};
+    common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
+                                                  : info.weights_.ConstHostSpan()};

     linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable {
       auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape()));
@@ -714,8 +714,8 @@ class MeanAbsoluteError : public ObjFunction {
     preds.SetDevice(ctx_->Device());
     auto predt = linalg::MakeVec(&preds);
     info.weights_.SetDevice(ctx_->Device());
-    common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
-                                                 : info.weights_.ConstDeviceSpan()};
+    common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
+                                                  : info.weights_.ConstHostSpan()};

     linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, float y) mutable {
       auto sign = [](auto x) {
@@ -72,7 +72,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientP

   gpair.SetDevice(ctx->Device());
   auto gpair_t = gpair.View(ctx->Device());
-  ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
-               : cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device()));
+  ctx->IsCUDA() ? cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device()))
+                : cpu_impl::FitStump(ctx, info, gpair_t, out->HostView());
 }
 }  // namespace xgboost::tree