Use the new DeviceOrd in the linalg module. (#9527)

Author: Jiaming Yuan, 2023-08-29 13:37:29 +08:00 (committed by GitHub)
parent 942b957eef
commit ddf2e68821
43 changed files with 252 additions and 273 deletions
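
In short, the patch replaces raw std::int32_t device ordinals (with -1 / Context::kCpuId meaning the CPU) by the typed DeviceOrd throughout the linalg module and its callers, usually obtained via Context::Device(). Below is a minimal sketch of the new call pattern; it uses only names that appear in this patch, but the exact headers and whether it compiles unchanged against a given revision are assumptions.

// Illustrative sketch (not part of the patch): old vs. new device handling.
#include <cstdint>            // for std::int32_t
#include <xgboost/context.h>  // for Context, DeviceOrd (assumed location)
#include <xgboost/linalg.h>   // for Tensor

namespace xgboost {
void DeviceOrdSketch(Context const* ctx, linalg::Tensor<float, 2> const& labels) {
  // Old style: labels.View(ctx->gpu_id), with -1 / Context::kCpuId meaning CPU.
  // New style: a typed device descriptor obtained from the Context.
  auto t = labels.View(ctx->Device());
  if (t.Device().IsCUDA()) {
    // The raw ordinal is still reachable where CUDA APIs need an int,
    // e.g. cudaSetDevice(t.Device().ordinal) as used elsewhere in the patch.
    std::int32_t ordinal = t.Device().ordinal;
    (void)ordinal;
  }
  // Explicit construction when no Context is at hand.
  auto cpu = DeviceOrd::CPU();      // replaces the -1 / kCpuId convention
  auto cuda0 = DeviceOrd::CUDA(0);  // CUDA device 0
  (void)cpu;
  (void)cuda0;
}
}  // namespace xgboost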

View File

@ -102,6 +102,14 @@ class HostDeviceVector {
bool Empty() const { return Size() == 0; }
size_t Size() const;
int DeviceIdx() const;
DeviceOrd Device() const {
auto idx = this->DeviceIdx();
if (idx == DeviceOrd::CPU().ordinal) {
return DeviceOrd::CPU();
} else {
return DeviceOrd::CUDA(idx);
}
}
common::Span<T> DeviceSpan();
common::Span<const T> ConstDeviceSpan() const;
common::Span<const T> DeviceSpan() const { return ConstDeviceSpan(); }
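
The new Device() accessor above derives a DeviceOrd from the legacy integer DeviceIdx(), so callers can branch on the typed result instead of comparing against -1. A hedged sketch of such a caller (the helper name is illustrative, not from the patch):

#include <xgboost/host_device_vector.h>  // for HostDeviceVector

template <typename T>
void ReadSpans(xgboost::HostDeviceVector<T> const& v) {  // illustrative helper
  if (v.Device().IsCPU()) {
    auto h = v.ConstHostSpan();    // data currently lives on the host
    (void)h;
  } else {
    auto d = v.ConstDeviceSpan();  // data lives on DeviceOrd::CUDA(v.Device().ordinal)
    (void)d;
  }
}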

View File

@ -330,7 +330,7 @@ struct LearnerModelParam {
multi_strategy{multi_strategy} {}
linalg::TensorView<float const, 1> BaseScore(Context const* ctx) const;
[[nodiscard]] linalg::TensorView<float const, 1> BaseScore(std::int32_t device) const;
[[nodiscard]] linalg::TensorView<float const, 1> BaseScore(DeviceOrd device) const;
void Copy(LearnerModelParam const& that);
[[nodiscard]] bool IsVectorLeaf() const noexcept {

View File

@ -302,7 +302,7 @@ class TensorView {
T *ptr_{nullptr}; // pointer of data_ to avoid bound check.
size_t size_{0};
int32_t device_{-1};
DeviceOrd device_;
// Unlike `Tensor`, the data_ can have arbitrary size since this is just a view.
LINALG_HD void CalcSize() {
@ -401,15 +401,11 @@ class TensorView {
* \param device Device ordinal
*/
template <typename I, std::int32_t D>
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], std::int32_t device)
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], DeviceOrd device)
: TensorView{data, shape, device, Order::kC} {}
template <typename I, std::int32_t D>
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], DeviceOrd device)
: TensorView{data, shape, device.ordinal, Order::kC} {}
template <typename I, int32_t D>
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], std::int32_t device, Order order)
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], DeviceOrd device, Order order)
: data_{data}, ptr_{data_.data()}, device_{device} {
static_assert(D > 0 && D <= kDim, "Invalid shape.");
// shape
@ -441,7 +437,7 @@ class TensorView {
*/
template <typename I, std::int32_t D>
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
std::int32_t device)
DeviceOrd device)
: data_{data}, ptr_{data_.data()}, device_{device} {
static_assert(D == kDim, "Invalid shape & stride.");
detail::UnrollLoop<D>([&](auto i) {
@ -450,16 +446,12 @@ class TensorView {
});
this->CalcSize();
}
template <typename I, std::int32_t D>
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
DeviceOrd device)
: TensorView{data, shape, stride, device.ordinal} {}
template <
typename U,
std::enable_if_t<common::detail::IsAllowedElementTypeConversion<U, T>::value> * = nullptr>
LINALG_HD TensorView(TensorView<U, kDim> const &that) // NOLINT
: data_{that.Values()}, ptr_{data_.data()}, size_{that.Size()}, device_{that.DeviceIdx()} {
: data_{that.Values()}, ptr_{data_.data()}, size_{that.Size()}, device_{that.Device()} {
detail::UnrollLoop<kDim>([&](auto i) {
stride_[i] = that.Stride(i);
shape_[i] = that.Shape(i);
@ -572,7 +564,7 @@ class TensorView {
/**
* \brief Obtain the CUDA device ordinal.
*/
LINALG_HD auto DeviceIdx() const { return device_; }
LINALG_HD auto Device() const { return device_; }
};
/**
@ -587,11 +579,11 @@ auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOL
typename Container::value_type>;
std::size_t in_shape[sizeof...(S)];
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->gpu_id};
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device()};
}
template <typename T, typename... S>
LINALG_HD auto MakeTensorView(std::int32_t device, common::Span<T> data, S &&...shape) {
LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T> data, S &&...shape) {
std::size_t in_shape[sizeof...(S)];
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, device};
@ -599,26 +591,26 @@ LINALG_HD auto MakeTensorView(std::int32_t device, common::Span<T> data, S &&...
template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
return MakeTensorView(ctx->gpu_id, data, std::forward<S>(shape)...);
return MakeTensorView(ctx->Device(), data, std::forward<S>(shape)...);
}
template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&...shape) {
std::size_t in_shape[sizeof...(S)];
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Ordinal(), order};
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device(), order};
}
template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan();
return MakeTensorView(ctx->gpu_id, span, std::forward<S>(shape)...);
return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
}
template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> const *data, S &&...shape) {
auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan();
return MakeTensorView(ctx->gpu_id, span, std::forward<S>(shape)...);
return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
}
/**
@ -661,20 +653,20 @@ using VectorView = TensorView<T, 1>;
* \param device (optional) Device ordinal, default to be host.
*/
template <typename T>
auto MakeVec(T *ptr, size_t s, int32_t device = -1) {
auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {
return linalg::TensorView<T, 1>{{ptr, s}, {s}, device};
}
template <typename T>
auto MakeVec(HostDeviceVector<T> *data) {
return MakeVec(data->DeviceIdx() == -1 ? data->HostPointer() : data->DevicePointer(),
data->Size(), data->DeviceIdx());
data->Size(), data->Device());
}
template <typename T>
auto MakeVec(HostDeviceVector<T> const *data) {
return MakeVec(data->DeviceIdx() == -1 ? data->ConstHostPointer() : data->ConstDevicePointer(),
data->Size(), data->DeviceIdx());
data->Size(), data->Device());
}
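
Note that the default device for the raw-pointer MakeVec changes from the magic -1 to DeviceOrd::CPU(), and the HostDeviceVector overloads now forward data->Device() directly. A small host-side sketch (illustrative, assuming the factory behaves as shown in this hunk):

#include <vector>
#include <xgboost/linalg.h>  // for MakeVec

void MakeVecSketch() {  // illustrative
  std::vector<float> buf(16, 1.0f);
  // Defaults to DeviceOrd::CPU(); the previous default was the ordinal -1.
  auto vec = xgboost::linalg::MakeVec(buf.data(), buf.size());
  // vec.Device().IsCPU() now holds; pass DeviceOrd::CUDA(i) explicitly for device memory.
  (void)vec;
}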
/**
@ -697,7 +689,7 @@ Json ArrayInterface(TensorView<T const, D> const &t) {
array_interface["data"] = std::vector<Json>(2);
array_interface["data"][0] = Integer{reinterpret_cast<int64_t>(t.Values().data())};
array_interface["data"][1] = Boolean{true};
if (t.DeviceIdx() >= 0) {
if (t.Device().IsCUDA()) {
// Change this once we have different CUDA stream.
array_interface["stream"] = Null{};
}
@ -856,49 +848,29 @@ class Tensor {
/**
* @brief Get a @ref TensorView for this tensor.
*/
TensorView<T, kDim> View(std::int32_t device) {
if (device >= 0) {
data_.SetDevice(device);
auto span = data_.DeviceSpan();
return {span, shape_, device, order_};
} else {
auto span = data_.HostSpan();
return {span, shape_, device, order_};
}
}
TensorView<T const, kDim> View(std::int32_t device) const {
if (device >= 0) {
data_.SetDevice(device);
auto span = data_.ConstDeviceSpan();
return {span, shape_, device, order_};
} else {
auto span = data_.ConstHostSpan();
return {span, shape_, device, order_};
}
}
auto View(DeviceOrd device) {
if (device.IsCUDA()) {
data_.SetDevice(device);
auto span = data_.DeviceSpan();
return TensorView<T, kDim>{span, shape_, device.ordinal, order_};
return TensorView<T, kDim>{span, shape_, device, order_};
} else {
auto span = data_.HostSpan();
return TensorView<T, kDim>{span, shape_, device.ordinal, order_};
return TensorView<T, kDim>{span, shape_, device, order_};
}
}
auto View(DeviceOrd device) const {
if (device.IsCUDA()) {
data_.SetDevice(device);
auto span = data_.ConstDeviceSpan();
return TensorView<T const, kDim>{span, shape_, device.ordinal, order_};
return TensorView<T const, kDim>{span, shape_, device, order_};
} else {
auto span = data_.ConstHostSpan();
return TensorView<T const, kDim>{span, shape_, device.ordinal, order_};
return TensorView<T const, kDim>{span, shape_, device, order_};
}
}
auto HostView() const { return this->View(-1); }
auto HostView() { return this->View(-1); }
auto HostView() { return this->View(DeviceOrd::CPU()); }
auto HostView() const { return this->View(DeviceOrd::CPU()); }
[[nodiscard]] size_t Size() const { return data_.Size(); }
auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
@ -975,6 +947,7 @@ class Tensor {
void SetDevice(int32_t device) const { data_.SetDevice(device); }
void SetDevice(DeviceOrd device) const { data_.SetDevice(device); }
[[nodiscard]] int32_t DeviceIdx() const { return data_.DeviceIdx(); }
[[nodiscard]] DeviceOrd Device() const { return data_.Device(); }
};
template <typename T>

View File

@ -37,12 +37,12 @@ class MultiTargetTree : public Model {
[[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
auto beg = nidx * this->NumTarget();
auto v = common::Span<float const>{weights_}.subspan(beg, this->NumTarget());
return linalg::MakeTensorView(Context::kCpuId, v, v.size());
return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
}
[[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
auto beg = nidx * this->NumTarget();
auto v = common::Span<float>{weights_}.subspan(beg, this->NumTarget());
return linalg::MakeTensorView(Context::kCpuId, v, v.size());
return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
}
public:

View File

@ -68,7 +68,7 @@ void CopyGradientFromCUDAArrays(Context const *ctx, ArrayInterface<2, false> con
auto &gpair = *out_gpair;
gpair.SetDevice(grad_dev);
gpair.Reshape(grad.Shape(0), grad.Shape(1));
auto d_gpair = gpair.View(grad_dev);
auto d_gpair = gpair.View(DeviceOrd::CUDA(grad_dev));
auto cuctx = ctx->CUDACtx();
DispatchDType(grad, DeviceOrd::CUDA(grad_dev), [&](auto &&t_grad) {

View File

@ -13,7 +13,7 @@ namespace xgboost {
namespace linalg {
template <typename T, int32_t D, typename Fn>
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
"For function with return, use transform instead.");
if (t.Contiguous()) {

View File

@ -133,7 +133,7 @@ struct WeightOp {
void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
CUDAContext const* cuctx = ctx->CUDACtx();
group_ptr_.SetDevice(ctx->gpu_id);
group_ptr_.SetDevice(ctx->Device());
if (info.group_ptr_.empty()) {
group_ptr_.Resize(2, 0);
group_ptr_.HostVector()[1] = info.num_row_;
@ -153,7 +153,7 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
max_group_size_ =
thrust::reduce(cuctx->CTP(), it, it + n_groups, 0ul, thrust::maximum<std::size_t>{});
threads_group_ptr_.SetDevice(ctx->gpu_id);
threads_group_ptr_.SetDevice(ctx->Device());
threads_group_ptr_.Resize(n_groups + 1, 0);
auto d_threads_group_ptr = threads_group_ptr_.DeviceSpan();
if (param_.HasTruncation()) {
@ -168,7 +168,7 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
n_cuda_threads_ = info.num_row_ * param_.NumPair();
}
sorted_idx_cache_.SetDevice(ctx->gpu_id);
sorted_idx_cache_.SetDevice(ctx->Device());
sorted_idx_cache_.Resize(info.labels.Size(), 0);
auto weight = common::MakeOptionalWeights(ctx, info.weights_);
@ -187,18 +187,18 @@ common::Span<std::size_t const> RankingCache::MakeRankOnCUDA(Context const* ctx,
void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
CUDAContext const* cuctx = ctx->CUDACtx();
auto labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
auto labels = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
CheckNDCGLabels(this->Param(), labels, CheckNDCGOp{cuctx});
auto d_group_ptr = this->DataGroupPtr(ctx);
std::size_t n_groups = d_group_ptr.size() - 1;
inv_idcg_ = linalg::Zeros<double>(ctx, n_groups);
auto d_inv_idcg = inv_idcg_.View(ctx->gpu_id);
auto d_inv_idcg = inv_idcg_.View(ctx->Device());
cuda_impl::CalcQueriesInvIDCG(ctx, labels, d_group_ptr, d_inv_idcg, this->Param());
CHECK_GE(this->Param().NumPair(), 1ul);
discounts_.SetDevice(ctx->gpu_id);
discounts_.SetDevice(ctx->Device());
discounts_.Resize(MaxGroupSize());
auto d_discount = discounts_.DeviceSpan();
dh::LaunchN(MaxGroupSize(), cuctx->Stream(),
@ -206,12 +206,12 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
}
void PreCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
auto const d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
CheckPreLabels("pre", d_label, CheckMAPOp{ctx->CUDACtx()});
}
void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
auto const d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
CheckPreLabels("map", d_label, CheckMAPOp{ctx->CUDACtx()});
}
} // namespace xgboost::ltr
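
The remaining call-site changes in this and the following files are largely mechanical: replace ctx->gpu_id with ctx->Device() when pinning storage and taking views. A hedged sketch of that recurring shape (helper and parameter names are illustrative):

#include <xgboost/context.h>             // for Context
#include <xgboost/host_device_vector.h>  // for HostDeviceVector
#include <xgboost/linalg.h>              // for Vector

void PinAndView(xgboost::Context const* ctx, xgboost::HostDeviceVector<double>* buf,
                xgboost::linalg::Vector<double>* vec) {  // illustrative helper
  buf->SetDevice(ctx->Device());         // was: buf->SetDevice(ctx->gpu_id)
  auto span = ctx->IsCPU() ? buf->HostSpan() : buf->DeviceSpan();
  auto view = vec->View(ctx->Device());  // was: vec->View(ctx->gpu_id)
  (void)span;
  (void)view;
}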

View File

@ -217,7 +217,7 @@ class RankingCache {
}
// Constructed as [1, n_samples] if group ptr is not supplied by the user
common::Span<bst_group_t const> DataGroupPtr(Context const* ctx) const {
group_ptr_.SetDevice(ctx->gpu_id);
group_ptr_.SetDevice(ctx->Device());
return ctx->IsCPU() ? group_ptr_.ConstHostSpan() : group_ptr_.ConstDeviceSpan();
}
@ -228,7 +228,7 @@ class RankingCache {
// Create a rank list by model prediction
common::Span<std::size_t const> SortedIdx(Context const* ctx, common::Span<float const> predt) {
if (sorted_idx_cache_.Empty()) {
sorted_idx_cache_.SetDevice(ctx->gpu_id);
sorted_idx_cache_.SetDevice(ctx->Device());
sorted_idx_cache_.Resize(predt.size());
}
if (ctx->IsCPU()) {
@ -242,7 +242,7 @@ class RankingCache {
common::Span<std::size_t> SortedIdxY(Context const* ctx, std::size_t n_samples) {
CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal();
if (y_sorted_idx_cache_.Empty()) {
y_sorted_idx_cache_.SetDevice(ctx->gpu_id);
y_sorted_idx_cache_.SetDevice(ctx->Device());
y_sorted_idx_cache_.Resize(n_samples);
}
return y_sorted_idx_cache_.DeviceSpan();
@ -250,7 +250,7 @@ class RankingCache {
common::Span<float> RankedY(Context const* ctx, std::size_t n_samples) {
CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal();
if (y_ranked_by_model_.Empty()) {
y_ranked_by_model_.SetDevice(ctx->gpu_id);
y_ranked_by_model_.SetDevice(ctx->Device());
y_ranked_by_model_.Resize(n_samples);
}
return y_ranked_by_model_.DeviceSpan();
@ -266,21 +266,21 @@ class RankingCache {
linalg::VectorView<GradientPair> CUDARounding(Context const* ctx) {
if (roundings_.Size() == 0) {
roundings_.SetDevice(ctx->gpu_id);
roundings_.SetDevice(ctx->Device());
roundings_.Reshape(Groups());
}
return roundings_.View(ctx->gpu_id);
return roundings_.View(ctx->Device());
}
common::Span<double> CUDACostRounding(Context const* ctx) {
if (cost_rounding_.Size() == 0) {
cost_rounding_.SetDevice(ctx->gpu_id);
cost_rounding_.SetDevice(ctx->Device());
cost_rounding_.Resize(1);
}
return cost_rounding_.DeviceSpan();
}
template <typename Type>
common::Span<Type> MaxLambdas(Context const* ctx, std::size_t n) {
max_lambdas_.SetDevice(ctx->gpu_id);
max_lambdas_.SetDevice(ctx->Device());
std::size_t bytes = n * sizeof(Type);
if (bytes != max_lambdas_.Size()) {
max_lambdas_.Resize(bytes);
@ -315,17 +315,17 @@ class NDCGCache : public RankingCache {
}
linalg::VectorView<double const> InvIDCG(Context const* ctx) const {
return inv_idcg_.View(ctx->gpu_id);
return inv_idcg_.View(ctx->Device());
}
common::Span<double const> Discount(Context const* ctx) const {
return ctx->IsCPU() ? discounts_.ConstHostSpan() : discounts_.ConstDeviceSpan();
}
linalg::VectorView<double> Dcg(Context const* ctx) {
if (dcg_.Size() == 0) {
dcg_.SetDevice(ctx->gpu_id);
dcg_.SetDevice(ctx->Device());
dcg_.Reshape(this->Groups());
}
return dcg_.View(ctx->gpu_id);
return dcg_.View(ctx->Device());
}
};
@ -396,7 +396,7 @@ class PreCache : public RankingCache {
common::Span<double> Pre(Context const* ctx) {
if (pre_.Empty()) {
pre_.SetDevice(ctx->gpu_id);
pre_.SetDevice(ctx->Device());
pre_.Resize(this->Groups());
}
return ctx->IsCPU() ? pre_.HostSpan() : pre_.DeviceSpan();
@ -427,21 +427,21 @@ class MAPCache : public RankingCache {
common::Span<double> NumRelevant(Context const* ctx) {
if (n_rel_.Empty()) {
n_rel_.SetDevice(ctx->gpu_id);
n_rel_.SetDevice(ctx->Device());
n_rel_.Resize(n_samples_);
}
return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan();
}
common::Span<double> Acc(Context const* ctx) {
if (acc_.Empty()) {
acc_.SetDevice(ctx->gpu_id);
acc_.SetDevice(ctx->Device());
acc_.Resize(n_samples_);
}
return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan();
}
common::Span<double> Map(Context const* ctx) {
if (map_.Empty()) {
map_.SetDevice(ctx->gpu_id);
map_.SetDevice(ctx->Device());
map_.Resize(this->Groups());
}
return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan();

View File

@ -20,9 +20,9 @@ namespace common {
void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
HostDeviceVector<float> const& weights, linalg::Tensor<float, 1>* out) {
if (!ctx->IsCPU()) {
weights.SetDevice(ctx->gpu_id);
weights.SetDevice(ctx->Device());
auto opt_weights = OptionalWeights(weights.ConstDeviceSpan());
auto t_v = t.View(ctx->gpu_id);
auto t_v = t.View(ctx->Device());
cuda_impl::Median(ctx, t_v, opt_weights, out);
}
@ -59,7 +59,7 @@ void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<flo
auto ret = std::accumulate(tloc.cbegin(), tloc.cend(), .0f);
out->HostView()(0) = ret;
} else {
cuda_impl::Mean(ctx, v.View(ctx->gpu_id), out->View(ctx->gpu_id));
cuda_impl::Mean(ctx, v.View(ctx->Device()), out->View(ctx->Device()));
}
}
} // namespace common

View File

@ -366,7 +366,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
// Groups is maintained by a higher level Python function. We should aim at deprecating
// the slice function.
if (this->labels.Size() != this->num_row_) {
auto t_labels = this->labels.View(this->labels.Data()->DeviceIdx());
auto t_labels = this->labels.View(this->labels.Data()->Device());
out.labels.Reshape(ridxs.size(), labels.Shape(1));
out.labels.Data()->HostVector() =
Gather(this->labels.Data()->HostVector(), ridxs, t_labels.Stride(0));
@ -394,7 +394,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
if (this->base_margin_.Size() != this->num_row_) {
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
<< "Incorrect size of base margin vector.";
auto t_margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx());
auto t_margin = this->base_margin_.View(this->base_margin_.Data()->Device());
out.base_margin_.Reshape(ridxs.size(), t_margin.Shape(1));
out.base_margin_.Data()->HostVector() =
Gather(this->base_margin_.Data()->HostVector(), ridxs, t_margin.Stride(0));
@ -445,7 +445,7 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
return;
}
p_out->Reshape(array.shape);
auto t_out = p_out->View(Context::kCpuId);
auto t_out = p_out->View(DeviceOrd::CPU());
CHECK(t_out.CContiguous());
auto const shape = t_out.Shape();
DispatchDType(array, DeviceOrd::CPU(), [&](auto&& in) {
@ -564,7 +564,7 @@ void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, Da
CHECK(key);
auto proc = [&](auto cast_d_ptr) {
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, Context::kCpuId);
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, DeviceOrd::CPU());
CHECK(t.CContiguous());
Json interface {
linalg::ArrayInterface(t)
@ -739,8 +739,7 @@ void MetaInfo::SynchronizeNumberOfColumns() {
namespace {
template <typename T>
void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
bool valid =
v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device;
bool valid = v.Device().IsCPU() || device == Context::kCpuId || v.DeviceIdx() == device;
if (!valid) {
LOG(FATAL) << "Invalid device ordinal. Data is associated with a different device ordinal than "
"the booster. The device ordinal of the data is: "

View File

@ -50,7 +50,7 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens
return;
}
p_out->Reshape(array.shape);
auto t = p_out->View(ptr_device);
auto t = p_out->View(DeviceOrd::CUDA(ptr_device));
linalg::ElementWiseTransformDevice(
t,
[=] __device__(size_t i, T) {

View File

@ -183,7 +183,7 @@ class GBLinear : public GradientBooster {
bst_layer_t layer_begin, bst_layer_t /*layer_end*/, bool) override {
model_.LazyInitModel();
LinearCheckLayer(layer_begin);
auto base_margin = p_fmat->Info().base_margin_.View(Context::kCpuId);
auto base_margin = p_fmat->Info().base_margin_.View(DeviceOrd::CPU());
const int ngroup = model_.learner_model_param->num_output_group;
const size_t ncolumns = model_.learner_model_param->num_feature + 1;
// allocate space for (#features + bias) times #groups times #rows
@ -250,10 +250,9 @@ class GBLinear : public GradientBooster {
// The bias is the last weight
out_scores->resize(model_.weight.size() - learner_model_param_->num_output_group, 0);
auto n_groups = learner_model_param_->num_output_group;
linalg::TensorView<float, 2> scores{
*out_scores,
{learner_model_param_->num_feature, n_groups},
Context::kCpuId};
auto scores = linalg::MakeTensorView(DeviceOrd::CPU(),
common::Span{out_scores->data(), out_scores->size()},
learner_model_param_->num_feature, n_groups);
for (size_t i = 0; i < learner_model_param_->num_feature; ++i) {
for (bst_group_t g = 0; g < n_groups; ++g) {
scores(i, g) = model_[i][g];
@ -275,12 +274,12 @@ class GBLinear : public GradientBooster {
monitor_.Start("PredictBatchInternal");
model_.LazyInitModel();
std::vector<bst_float> &preds = *out_preds;
auto base_margin = p_fmat->Info().base_margin_.View(Context::kCpuId);
auto base_margin = p_fmat->Info().base_margin_.View(DeviceOrd::CPU());
// start collecting the prediction
const int ngroup = model_.learner_model_param->num_output_group;
preds.resize(p_fmat->Info().num_row_ * ngroup);
auto base_score = learner_model_param_->BaseScore(Context::kCpuId);
auto base_score = learner_model_param_->BaseScore(DeviceOrd::CPU());
for (const auto &page : p_fmat->GetBatches<SparsePage>()) {
auto const& batch = page.GetView();
// output convention: nrow * k, where nrow is number of rows

View File

@ -754,7 +754,7 @@ class Dart : public GBTree {
auto n_groups = model_.learner_model_param->num_output_group;
PredictionCacheEntry predts; // temporary storage for prediction
if (ctx_->gpu_id != Context::kCpuId) {
if (ctx_->IsCUDA()) {
predts.predictions.SetDevice(ctx_->gpu_id);
}
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
@ -859,12 +859,12 @@ class Dart : public GBTree {
size_t n_rows = p_fmat->Info().num_row_;
if (predts.predictions.DeviceIdx() != Context::kCpuId) {
p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
auto base_score = model_.learner_model_param->BaseScore(predts.predictions.DeviceIdx());
auto base_score = model_.learner_model_param->BaseScore(predts.predictions.Device());
GPUDartInplacePredictInc(p_out_preds->predictions.DeviceSpan(),
predts.predictions.DeviceSpan(), w, n_rows, base_score, n_groups,
group);
} else {
auto base_score = model_.learner_model_param->BaseScore(Context::kCpuId);
auto base_score = model_.learner_model_param->BaseScore(DeviceOrd::CPU());
auto& h_predts = predts.predictions.HostVector();
auto& h_out_predts = p_out_preds->predictions.HostVector();
common::ParallelFor(n_rows, ctx_->Threads(), [&](auto ridx) {

View File

@ -279,15 +279,15 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy
// Make sure read access everywhere for thread-safe prediction.
std::as_const(base_score_).HostView();
if (!ctx->IsCPU()) {
std::as_const(base_score_).View(ctx->gpu_id);
std::as_const(base_score_).View(ctx->Device());
}
CHECK(std::as_const(base_score_).Data()->HostCanRead());
}
linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(int32_t device) const {
linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(DeviceOrd device) const {
// multi-class is not yet supported.
CHECK_EQ(base_score_.Size(), 1) << ModelNotFitted();
if (device == Context::kCpuId) {
if (device.IsCPU()) {
// Make sure that we won't run into race condition.
CHECK(base_score_.Data()->HostCanRead());
return base_score_.HostView();
@ -300,7 +300,7 @@ linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(int32_t device)
}
linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(Context const* ctx) const {
return this->BaseScore(ctx->gpu_id);
return this->BaseScore(ctx->Device());
}
void LearnerModelParam::Copy(LearnerModelParam const& that) {
@ -309,7 +309,7 @@ void LearnerModelParam::Copy(LearnerModelParam const& that) {
base_score_.Data()->Copy(*that.base_score_.Data());
std::as_const(base_score_).HostView();
if (that.base_score_.DeviceIdx() != Context::kCpuId) {
std::as_const(base_score_).View(that.base_score_.DeviceIdx());
std::as_const(base_score_).View(that.base_score_.Device());
}
CHECK_EQ(base_score_.Data()->DeviceCanRead(), that.base_score_.Data()->DeviceCanRead());
CHECK(base_score_.Data()->HostCanRead());
@ -388,7 +388,7 @@ class LearnerConfiguration : public Learner {
this->ConfigureTargets();
auto task = UsePtr(obj_)->Task();
linalg::Tensor<float, 1> base_score({1}, Ctx()->gpu_id);
linalg::Tensor<float, 1> base_score({1}, Ctx()->Device());
auto h_base_score = base_score.HostView();
// transform to margin
@ -424,7 +424,7 @@ class LearnerConfiguration : public Learner {
if (mparam_.boost_from_average && !UsePtr(gbm_)->ModelFitted()) {
if (p_fmat) {
auto const& info = p_fmat->Info();
info.Validate(Ctx()->gpu_id);
info.Validate(Ctx()->Ordinal());
// We estimate it from input data.
linalg::Tensor<float, 1> base_score;
InitEstimation(info, &base_score);
@ -1369,7 +1369,7 @@ class LearnerImpl : public LearnerIO {
auto& prediction = prediction_container_.Cache(data, ctx_.gpu_id);
this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
// Copy the prediction cache to output prediction. out_preds comes from C API
out_preds->SetDevice(ctx_.gpu_id);
out_preds->SetDevice(ctx_.Device());
out_preds->Resize(prediction.predictions.Size());
out_preds->Copy(prediction.predictions);
if (!output_margin) {

View File

@ -82,22 +82,19 @@ template <typename BinaryAUC>
double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaInfo const &info,
size_t n_classes, int32_t n_threads, BinaryAUC &&binary_auc) {
CHECK_NE(n_classes, 0);
auto const labels = info.labels.View(Context::kCpuId);
auto const labels = info.labels.HostView();
if (labels.Shape(0) != 0) {
CHECK_EQ(labels.Shape(1), 1) << "AUC doesn't support multi-target model.";
}
std::vector<double> results_storage(n_classes * 3, 0);
linalg::TensorView<double, 2> results(results_storage, {n_classes, static_cast<size_t>(3)},
Context::kCpuId);
auto results = linalg::MakeTensorView(ctx, results_storage, n_classes, 3);
auto local_area = results.Slice(linalg::All(), 0);
auto tp = results.Slice(linalg::All(), 1);
auto auc = results.Slice(linalg::All(), 2);
auto weights = common::OptionalWeights{info.weights_.ConstHostSpan()};
auto predts_t = linalg::TensorView<float const, 2>(
predts, {static_cast<size_t>(info.num_row_), n_classes},
Context::kCpuId);
auto predts_t = linalg::MakeTensorView(ctx, predts, info.num_row_, n_classes);
if (info.labels.Size() != 0) {
common::ParallelFor(n_classes, n_threads, [&](auto c) {
@ -108,8 +105,8 @@ double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaI
response[i] = labels(i) == c ? 1.0f : 0.0;
}
double fp;
std::tie(fp, tp(c), auc(c)) =
binary_auc(ctx, proba, linalg::MakeVec(response.data(), response.size(), -1), weights);
std::tie(fp, tp(c), auc(c)) = binary_auc(
ctx, proba, linalg::MakeVec(response.data(), response.size(), ctx->Device()), weights);
local_area(c) = fp * tp(c);
});
}
@ -220,7 +217,7 @@ std::pair<double, uint32_t> RankingAUC(Context const *ctx, std::vector<float> co
CHECK_GE(info.group_ptr_.size(), 2);
uint32_t n_groups = info.group_ptr_.size() - 1;
auto s_predts = common::Span<float const>{predts};
auto labels = info.labels.View(Context::kCpuId);
auto labels = info.labels.View(ctx->Device());
auto s_weights = info.weights_.ConstHostSpan();
std::atomic<uint32_t> invalid_groups{0};
@ -363,8 +360,8 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
info.labels.HostView().Slice(linalg::All(), 0),
common::OptionalWeights{info.weights_.ConstHostSpan()});
} else {
std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info,
ctx_->gpu_id, &this->d_cache_);
std::tie(fp, tp, auc) =
GPUBinaryROCAUC(predts.ConstDeviceSpan(), info, ctx_->Device(), &this->d_cache_);
}
return std::make_tuple(fp, tp, auc);
}
@ -381,8 +378,7 @@ XGBOOST_REGISTER_METRIC(EvalAUC, "auc")
#if !defined(XGBOOST_USE_CUDA)
std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const>, MetaInfo const &,
std::int32_t,
std::shared_ptr<DeviceAUCCache> *) {
DeviceOrd, std::shared_ptr<DeviceAUCCache> *) {
common::AssertGPUSupport();
return {};
}
@ -414,8 +410,8 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
BinaryPRAUC(ctx_, predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
common::OptionalWeights{info.weights_.ConstHostSpan()});
} else {
std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info,
ctx_->gpu_id, &this->d_cache_);
std::tie(pr, re, auc) =
GPUBinaryPRAUC(predts.ConstDeviceSpan(), info, ctx_->Device(), &this->d_cache_);
}
return std::make_tuple(pr, re, auc);
}
@ -459,7 +455,7 @@ XGBOOST_REGISTER_METRIC(AUCPR, "aucpr")
#if !defined(XGBOOST_USE_CUDA)
std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const>, MetaInfo const &,
std::int32_t, std::shared_ptr<DeviceAUCCache> *) {
DeviceOrd, std::shared_ptr<DeviceAUCCache> *) {
common::AssertGPUSupport();
return {};
}

View File

@ -85,11 +85,11 @@ void InitCacheOnce(common::Span<float const> predts, std::shared_ptr<DeviceAUCCa
template <typename Fn>
std::tuple<double, double, double>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, common::Span<size_t const> d_sorted_idx,
DeviceOrd device, common::Span<size_t const> d_sorted_idx,
Fn area_fn, std::shared_ptr<DeviceAUCCache> cache) {
auto labels = info.labels.View(device);
auto weights = info.weights_.ConstDeviceSpan();
dh::safe_cuda(cudaSetDevice(device));
dh::safe_cuda(cudaSetDevice(device.ordinal));
CHECK_NE(labels.Size(), 0);
CHECK_EQ(labels.Size(), predts.size());
@ -168,7 +168,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
}
std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> predts,
MetaInfo const &info, std::int32_t device,
MetaInfo const &info, DeviceOrd device,
std::shared_ptr<DeviceAUCCache> *p_cache) {
auto &cache = *p_cache;
InitCacheOnce<false>(predts, p_cache);
@ -309,9 +309,10 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
* up each class in all kernels.
*/
template <bool scale, typename Fn>
double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<uint32_t> d_class_ptr,
size_t n_classes, std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
dh::safe_cuda(cudaSetDevice(device));
double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
common::Span<uint32_t> d_class_ptr, size_t n_classes,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
dh::safe_cuda(cudaSetDevice(device.ordinal));
/**
* Sorted idx
*/
@ -467,11 +468,12 @@ double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
MultiClassSortedIdx(ctx, predts, dh::ToSpan(class_ptr), cache);
auto fn = [] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev,
double tp, size_t /*class_id*/) {
auto fn = [] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev, double tp,
size_t /*class_id*/) {
return TrapezoidArea(fp_prev, fp, tp_prev, tp);
};
return GPUMultiClassAUCOVR<true>(info, ctx->gpu_id, dh::ToSpan(class_ptr), n_classes, cache, fn);
return GPUMultiClassAUCOVR<true>(info, ctx->Device(), dh::ToSpan(class_ptr), n_classes, cache,
fn);
}
namespace {
@ -512,7 +514,7 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
/**
* Sort the labels
*/
auto d_labels = info.labels.View(ctx->gpu_id);
auto d_labels = info.labels.View(ctx->Device());
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
common::SegmentedArgSort<false, false>(ctx, d_labels.Values(), d_group_ptr, d_sorted_idx);
@ -604,7 +606,7 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
}
std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> predts,
MetaInfo const &info, std::int32_t device,
MetaInfo const &info, DeviceOrd device,
std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
InitCacheOnce<false>(predts, p_cache);
@ -662,7 +664,7 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
/**
* Get total positive/negative
*/
auto labels = info.labels.View(ctx->gpu_id);
auto labels = info.labels.View(ctx->Device());
auto n_samples = info.num_row_;
dh::caching_device_vector<Pair> totals(n_classes);
auto key_it =
@ -695,13 +697,13 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[class_id].first);
};
return GPUMultiClassAUCOVR<false>(info, ctx->gpu_id, d_class_ptr, n_classes, cache, fn);
return GPUMultiClassAUCOVR<false>(info, ctx->Device(), d_class_ptr, n_classes, cache, fn);
}
template <typename Fn>
std::pair<double, uint32_t>
GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
common::Span<uint32_t> d_group_ptr, int32_t device,
common::Span<uint32_t> d_group_ptr, DeviceOrd device,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
/**
* Sorted idx
@ -843,7 +845,7 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
common::SegmentedArgSort<false, false>(ctx, predts, d_group_ptr, d_sorted_idx);
dh::XGBDeviceAllocator<char> alloc;
auto labels = info.labels.View(ctx->gpu_id);
auto labels = info.labels.View(ctx->Device());
if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels.Values()),
dh::tend(labels.Values()), PRAUCLabelInvalid{})) {
InvalidLabels();
@ -882,7 +884,7 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[group_id].first);
};
return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->gpu_id, cache, fn);
return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->Device(), cache, fn);
}
} // namespace metric
} // namespace xgboost

View File

@ -30,7 +30,7 @@ XGBOOST_DEVICE inline double TrapezoidArea(double x0, double x1, double y0, doub
struct DeviceAUCCache;
std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> predts,
MetaInfo const &info, std::int32_t device,
MetaInfo const &info, DeviceOrd,
std::shared_ptr<DeviceAUCCache> *p_cache);
double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
@ -45,7 +45,7 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
* PR AUC *
**********/
std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> predts,
MetaInfo const &info, std::int32_t device,
MetaInfo const &info, DeviceOrd,
std::shared_ptr<DeviceAUCCache> *p_cache);
double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,

View File

@ -45,7 +45,7 @@ namespace {
template <typename Fn>
PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
PackedReduceResult result;
auto labels = info.labels.View(ctx->gpu_id);
auto labels = info.labels.View(ctx->Device());
if (ctx->IsCPU()) {
auto n_threads = ctx->Threads();
std::vector<double> score_tloc(n_threads, 0.0);
@ -183,10 +183,10 @@ class PseudoErrorLoss : public MetricNoCache {
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
CHECK_EQ(info.labels.Shape(0), info.num_row_);
auto labels = info.labels.View(ctx_->gpu_id);
preds.SetDevice(ctx_->gpu_id);
auto labels = info.labels.View(ctx_->Device());
preds.SetDevice(ctx_->Device());
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
info.weights_.SetDevice(ctx_->gpu_id);
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan());
float slope = this->param_.huber_slope;
@ -349,11 +349,11 @@ struct EvalEWiseBase : public MetricNoCache {
if (info.labels.Size() != 0) {
CHECK_NE(info.labels.Shape(1), 0);
}
auto labels = info.labels.View(ctx_->gpu_id);
info.weights_.SetDevice(ctx_->gpu_id);
auto labels = info.labels.View(ctx_->Device());
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan());
preds.SetDevice(ctx_->gpu_id);
preds.SetDevice(ctx_->Device());
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
auto d_policy = policy_;
@ -444,16 +444,16 @@ class QuantileError : public MetricNoCache {
}
auto const* ctx = ctx_;
auto y_true = info.labels.View(ctx->gpu_id);
preds.SetDevice(ctx->gpu_id);
alpha_.SetDevice(ctx->gpu_id);
auto y_true = info.labels.View(ctx->Device());
preds.SetDevice(ctx->Device());
alpha_.SetDevice(ctx->Device());
auto alpha = ctx->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan();
std::size_t n_targets = preds.Size() / info.num_row_ / alpha_.Size();
CHECK_NE(n_targets, 0);
auto y_predt = linalg::MakeTensorView(ctx, &preds, static_cast<std::size_t>(info.num_row_),
alpha_.Size(), n_targets);
info.weights_.SetDevice(ctx->gpu_id);
info.weights_.SetDevice(ctx->Device());
common::OptionalWeights weight{ctx->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan()};

View File

@ -75,7 +75,7 @@ struct EvalAMS : public MetricNoCache {
const double br = 10.0;
unsigned thresindex = 0;
double s_tp = 0.0, b_fp = 0.0, tams = 0.0;
const auto& labels = info.labels.View(Context::kCpuId);
const auto& labels = info.labels.View(DeviceOrd::CPU());
for (unsigned i = 0; i < static_cast<unsigned>(ndata-1) && i < ntop; ++i) {
const unsigned ridx = rec[i].second;
const bst_float wt = info.GetWeight(ridx);
@ -134,7 +134,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
std::vector<double> sum_tloc(ctx_->Threads(), 0.0);
{
const auto& labels = info.labels.View(Context::kCpuId);
const auto& labels = info.labels.HostView();
const auto &h_preds = preds.ConstHostVector();
dmlc::OMPException exc;

View File

@ -33,7 +33,7 @@ PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt,
std::shared_ptr<ltr::PreCache> p_cache) {
auto d_gptr = p_cache->DataGroupPtr(ctx);
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
auto d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
predt.SetDevice(ctx->gpu_id);
auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());
@ -89,7 +89,7 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
if (!d_weight.Empty()) {
CHECK_EQ(d_weight.weights.size(), p_cache->Groups());
}
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
auto d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
predt.SetDevice(ctx->gpu_id);
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), predt.Size());
@ -119,9 +119,9 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::MAPCache> p_cache) {
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
auto d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
predt.SetDevice(ctx->gpu_id);
predt.SetDevice(ctx->Device());
auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),

View File

@ -19,7 +19,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
dh::device_vector<size_t>* p_ridx, HostDeviceVector<size_t>* p_nptr,
HostDeviceVector<bst_node_t>* p_nidx, RegTree const& tree) {
// copy position to buffer
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
auto cuctx = ctx->CUDACtx();
size_t n_samples = position.size();
dh::device_vector<bst_node_t> sorted_position(position.size());
@ -86,11 +86,11 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
*/
auto& nidx = *p_nidx;
auto& nptr = *p_nptr;
nidx.SetDevice(ctx->gpu_id);
nidx.SetDevice(ctx->Device());
nidx.Resize(n_leaf);
auto d_node_idx = nidx.DeviceSpan();
nptr.SetDevice(ctx->gpu_id);
nptr.SetDevice(ctx->Device());
nptr.Resize(n_leaf + 1, 0);
auto d_node_ptr = nptr.DeviceSpan();
@ -142,7 +142,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
dh::device_vector<size_t> ridx;
HostDeviceVector<size_t> nptr;
HostDeviceVector<bst_node_t> nidx;
@ -155,13 +155,13 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
}
HostDeviceVector<float> quantiles;
predt.SetDevice(ctx->gpu_id);
predt.SetDevice(ctx->Device());
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_,
predt.Size() / info.num_row_);
CHECK_LT(group_idx, d_predt.Shape(1));
auto t_predt = d_predt.Slice(linalg::All(), group_idx);
auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), IdxY(info, group_idx));
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
auto d_row_index = dh::ToSpan(ridx);
auto seg_beg = nptr.DevicePointer();
@ -178,7 +178,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
if (info.weights_.Empty()) {
common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
} else {
info.weights_.SetDevice(ctx->gpu_id);
info.weights_.SetDevice(ctx->Device());
auto d_weights = info.weights_.ConstDeviceSpan();
CHECK_EQ(d_weights.size(), d_row_index.size());
auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));

View File

@ -109,12 +109,12 @@ class LambdaRankObj : public FitIntercept {
lj_.SetDevice(ctx_->gpu_id);
if (ctx_->IsCPU()) {
cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id),
lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_,
cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
} else {
cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id),
lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_,
cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
}
@ -354,9 +354,9 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
const MetaInfo& info, linalg::Matrix<GradientPair>* out_gpair) {
if (ctx_->IsCUDA()) {
cuda_impl::LambdaRankGetGradientNDCG(
ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
out_gpair);
ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()),
tj_minus_.View(ctx_->Device()), li_full_.View(ctx_->Device()),
lj_full_.View(ctx_->Device()), out_gpair);
return;
}
@ -477,9 +477,9 @@ class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective.";
if (ctx_->IsCUDA()) {
return cuda_impl::LambdaRankGetGradientMAP(
ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
out_gpair);
ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()),
tj_minus_.View(ctx_->Device()), li_full_.View(ctx_->Device()),
lj_full_.View(ctx_->Device()), out_gpair);
}
auto gptr = p_cache_->DataGroupPtr(ctx_).data();
@ -567,9 +567,9 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the pairwise objective.";
if (ctx_->IsCUDA()) {
return cuda_impl::LambdaRankGetGradientPairwise(
ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
out_gpair);
ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()),
tj_minus_.View(ctx_->Device()), li_full_.View(ctx_->Device()),
lj_full_.View(ctx_->Device()), out_gpair);
}
auto gptr = p_cache_->DataGroupPtr(ctx_);

View File

@ -306,7 +306,7 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const
CHECK_NE(d_rounding.Size(), 0);
auto label = info.labels.View(ctx->gpu_id);
auto label = info.labels.View(ctx->Device());
auto predts = preds.ConstDeviceSpan();
auto gpairs = out_gpair->View(ctx->Device());
thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.Values().data(), gpairs.Size(),
@ -348,7 +348,7 @@ common::Span<std::size_t const> SortY(Context const* ctx, MetaInfo const& info,
common::Span<std::size_t const> d_rank,
std::shared_ptr<ltr::RankingCache> p_cache) {
auto const d_group_ptr = p_cache->DataGroupPtr(ctx);
auto label = info.labels.View(ctx->gpu_id);
auto label = info.labels.View(ctx->Device());
// The buffer for ranked y is necessary as cub segmented sort accepts only pointer.
auto d_y_ranked = p_cache->RankedY(ctx, info.num_row_);
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_y_ranked.size(),
@ -374,13 +374,13 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
linalg::VectorView<double> li, linalg::VectorView<double> lj,
linalg::Matrix<GradientPair>* out_gpair) {
// boilerplate
std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id));
auto device = ctx->Device();
dh::safe_cuda(cudaSetDevice(device.ordinal));
auto const d_inv_IDCG = p_cache->InvIDCG(ctx);
auto const discount = p_cache->Discount(ctx);
info.labels.SetDevice(device_id);
preds.SetDevice(device_id);
info.labels.SetDevice(device);
preds.SetDevice(device);
auto const exp_gain = p_cache->Param().ndcg_exp_gain;
auto delta_ndcg = [=] XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high,
@ -403,7 +403,7 @@ void MAPStat(Context const* ctx, MetaInfo const& info, common::Span<std::size_t
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) -> std::size_t { return dh::SegmentId(group_ptr, i); });
auto label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
auto label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
auto const* cuctx = ctx->CUDACtx();
{
@ -442,11 +442,11 @@ void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
linalg::Matrix<GradientPair>* out_gpair) {
std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id));
auto device = ctx->Device();
dh::safe_cuda(cudaSetDevice(device.ordinal));
info.labels.SetDevice(device_id);
predt.SetDevice(device_id);
info.labels.SetDevice(device);
predt.SetDevice(device);
CHECK(p_cache);
@ -481,11 +481,11 @@ void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
linalg::Matrix<GradientPair>* out_gpair) {
std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id));
auto device = ctx->Device();
dh::safe_cuda(cudaSetDevice(device.ordinal));
info.labels.SetDevice(device_id);
predt.SetDevice(device_id);
info.labels.SetDevice(device);
predt.SetDevice(device);
auto d_predt = predt.ConstDeviceSpan();
auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt);
@ -517,11 +517,11 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double
auto const d_group_ptr = p_cache->DataGroupPtr(ctx);
auto n_groups = d_group_ptr.size() - 1;
auto ti_plus = p_ti_plus->View(ctx->gpu_id);
auto tj_minus = p_tj_minus->View(ctx->gpu_id);
auto ti_plus = p_ti_plus->View(ctx->Device());
auto tj_minus = p_tj_minus->View(ctx->Device());
auto li = p_li->View(ctx->gpu_id);
auto lj = p_lj->View(ctx->gpu_id);
auto li = p_li->View(ctx->Device());
auto lj = p_lj->View(ctx->Device());
CHECK_EQ(li.Size(), ti_plus.Size());
auto const& param = p_cache->Param();

View File

@ -62,7 +62,7 @@ class QuantileRegression : public ObjFunction {
CHECK_GE(n_targets, n_alphas);
CHECK_EQ(preds.Size(), info.num_row_ * n_targets);
auto labels = info.labels.View(ctx_->gpu_id);
auto labels = info.labels.View(ctx_->Device());
out_gpair->SetDevice(ctx_->Device());
CHECK_EQ(info.labels.Shape(1), 1)
@ -131,7 +131,7 @@ class QuantileRegression : public ObjFunction {
#if defined(XGBOOST_USE_CUDA)
alpha_.SetDevice(ctx_->gpu_id);
auto d_alpha = alpha_.ConstDeviceSpan();
auto d_labels = info.labels.View(ctx_->gpu_id);
auto d_labels = info.labels.View(ctx_->Device());
auto seg_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) { return i * d_labels.Shape(0); });

View File

@ -69,7 +69,7 @@ class RegLossObj : public FitIntercept {
public:
void ValidateLabel(MetaInfo const& info) {
auto label = info.labels.View(ctx_->Ordinal());
auto label = info.labels.View(ctx_->Device());
auto valid = ctx_->DispatchDevice(
[&] {
return std::all_of(linalg::cbegin(label), linalg::cend(label),
@ -244,7 +244,7 @@ class PseudoHuberRegression : public FitIntercept {
CheckRegInputs(info, preds);
auto slope = param_.huber_slope;
CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
auto labels = info.labels.View(ctx_->gpu_id);
auto labels = info.labels.View(ctx_->Device());
out_gpair->SetDevice(ctx_->gpu_id);
out_gpair->Reshape(info.num_row_, this->Targets(info));
@ -698,7 +698,7 @@ class MeanAbsoluteError : public ObjFunction {
void GetGradient(HostDeviceVector<float> const& preds, const MetaInfo& info,
std::int32_t /*iter*/, linalg::Matrix<GradientPair>* out_gpair) override {
CheckRegInputs(info, preds);
auto labels = info.labels.View(ctx_->gpu_id);
auto labels = info.labels.View(ctx_->Device());
out_gpair->SetDevice(ctx_->Device());
out_gpair->Reshape(info.num_row_, this->Targets(info));

View File

@ -663,7 +663,7 @@ class CPUPredictor : public Predictor {
std::size_t n_samples = p_fmat->Info().num_row_;
std::size_t n_groups = model.learner_model_param->OutputLength();
CHECK_EQ(out_preds->size(), n_samples * n_groups);
linalg::TensorView<float, 2> out_predt{*out_preds, {n_samples, n_groups}, ctx_->gpu_id};
auto out_predt = linalg::MakeTensorView(ctx_, *out_preds, n_samples, n_groups);
if (!p_fmat->PageExists<SparsePage>()) {
std::vector<Entry> workspace(p_fmat->Info().num_col_ * kUnroll * n_threads);
@ -732,7 +732,7 @@ class CPUPredictor : public Predictor {
std::vector<RegTree::FVec> thread_temp;
InitThreadTemp(n_threads * kBlockSize, &thread_temp);
std::size_t n_groups = model.learner_model_param->OutputLength();
linalg::TensorView<float, 2> out_predt{predictions, {m->NumRows(), n_groups}, Context::kCpuId};
auto out_predt = linalg::MakeTensorView(ctx_, predictions, m->NumRows(), n_groups);
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>(
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}, n_threads), model,
tree_begin, tree_end, &thread_temp, n_threads, out_predt);
@ -878,8 +878,8 @@ class CPUPredictor : public Predictor {
common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
});
auto base_margin = info.base_margin_.View(Context::kCpuId);
auto base_score = model.learner_model_param->BaseScore(Context::kCpuId)(0);
auto base_margin = info.base_margin_.View(ctx_->Device());
auto base_score = model.learner_model_param->BaseScore(ctx_->Device())(0);
// start collecting the contributions
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
auto page = batch.GetView();

View File

@ -60,7 +60,7 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_fl
} else {
// cannot rely on the Resize to fill as it might skip if the size is already correct.
out_preds->Resize(n);
auto base_score = model.learner_model_param->BaseScore(Context::kCpuId)(0);
auto base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU())(0);
out_preds->Fill(base_score);
}
}

View File

@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientP
gpair.SetDevice(ctx->Device());
auto gpair_t = gpair.View(ctx->Device());
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->Device()));
}
} // namespace tree
} // namespace xgboost

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 by XGBoost Contributors
* Copyright 2022-2023 by XGBoost Contributors
*
* \brief Utilities for estimating initial score.
*/
@ -41,7 +41,7 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
auto sample = i % gpair.Shape(0);
return GradientPairPrecise{gpair(sample, target)};
});
auto d_sum = sum.View(ctx->gpu_id);
auto d_sum = sum.View(ctx->Device());
CHECK(d_sum.CContiguous());
dh::XGBCachingDeviceAllocator<char> alloc;

View File

@ -774,7 +774,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
std::vector<Partitioner> const &partitioner,
linalg::VectorView<float> out_preds) {
auto const &tree = *p_last_tree;
CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);
CHECK(out_preds.Device().IsCPU());
size_t n_nodes = p_last_tree->GetNodes().size();
for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);
@ -809,7 +809,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
auto n_nodes = mttree->Size();
auto n_targets = tree.NumTargets();
CHECK_EQ(out_preds.Shape(1), n_targets);
CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);
CHECK(out_preds.Device().IsCPU());
for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);

View File

@ -516,9 +516,10 @@ struct GPUHistMakerDevice {
}
CHECK(p_tree);
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id);
CHECK(out_preds_d.Device().IsCUDA());
CHECK_EQ(out_preds_d.Device().ordinal, ctx_->Ordinal());
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
auto d_position = dh::ToSpan(positions);
CHECK_EQ(out_preds_d.Size(), d_position.size());

View File

@ -3,7 +3,7 @@
*/
#include <gtest/gtest.h>
#include <xgboost/context.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/linalg.h>
#include <cstddef> // size_t
@ -14,8 +14,8 @@
namespace xgboost::linalg {
namespace {
auto kCpuId = Context::kCpuId;
}
DeviceOrd CPU() { return DeviceOrd::CPU(); }
} // namespace
auto MakeMatrixFromTest(HostDeviceVector<float> *storage, std::size_t n_rows, std::size_t n_cols) {
storage->Resize(n_rows * n_cols);
@ -23,7 +23,7 @@ auto MakeMatrixFromTest(HostDeviceVector<float> *storage, std::size_t n_rows, st
std::iota(h_storage.begin(), h_storage.end(), 0);
auto m = linalg::TensorView<float, 2>{h_storage, {n_rows, static_cast<size_t>(n_cols)}, -1};
auto m = linalg::TensorView<float, 2>{h_storage, {n_rows, static_cast<size_t>(n_cols)}, CPU()};
return m;
}
@ -31,7 +31,7 @@ TEST(Linalg, MatrixView) {
size_t kRows = 31, kCols = 77;
HostDeviceVector<float> storage;
auto m = MakeMatrixFromTest(&storage, kRows, kCols);
ASSERT_EQ(m.DeviceIdx(), kCpuId);
ASSERT_EQ(m.Device(), CPU());
ASSERT_EQ(m(0, 0), 0);
ASSERT_EQ(m(kRows - 1, kCols - 1), storage.Size() - 1);
}
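Host views are now constructed with an explicit DeviceOrd::CPU() where the tests previously passed -1 or kCpuId. A minimal standalone sketch of the construction exercised above (shape and data are arbitrary):

#include <xgboost/context.h>  // DeviceOrd (assumed location)
#include <xgboost/linalg.h>   // TensorView

#include <vector>

void HostViewSketch() {
  std::vector<double> data(6 * 4, 0.0);
  // The last constructor argument used to be the integer -1; it is now a DeviceOrd.
  xgboost::linalg::TensorView<double, 2> mat{data, {6, 4}, xgboost::DeviceOrd::CPU()};
  static_cast<void>(mat.Device().IsCPU());  // true for this view
}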
@ -76,7 +76,7 @@ TEST(Linalg, TensorView) {
{
// as vector
TensorView<double, 1> vec{data, {data.size()}, -1};
TensorView<double, 1> vec{data, {data.size()}, CPU()};
ASSERT_EQ(vec.Size(), data.size());
ASSERT_EQ(vec.Shape(0), data.size());
ASSERT_EQ(vec.Shape().size(), 1);
@ -87,7 +87,7 @@ TEST(Linalg, TensorView) {
{
// as matrix
TensorView<double, 2> mat(data, {6, 4}, -1);
TensorView<double, 2> mat(data, {6, 4}, CPU());
auto s = mat.Slice(2, All());
ASSERT_EQ(s.Shape().size(), 1);
s = mat.Slice(All(), 1);
@ -96,7 +96,7 @@ TEST(Linalg, TensorView) {
{
// assignment
TensorView<double, 3> t{data, {2, 3, 4}, 0};
TensorView<double, 3> t{data, {2, 3, 4}, CPU()};
double pi = 3.14159;
auto old = t(1, 2, 3);
t(1, 2, 3) = pi;
@ -201,7 +201,7 @@ TEST(Linalg, TensorView) {
}
{
// f-contiguous
TensorView<double, 3> t{data, {4, 3, 2}, {1, 4, 12}, kCpuId};
TensorView<double, 3> t{data, {4, 3, 2}, {1, 4, 12}, CPU()};
ASSERT_TRUE(t.Contiguous());
ASSERT_TRUE(t.FContiguous());
ASSERT_FALSE(t.CContiguous());
@ -210,11 +210,11 @@ TEST(Linalg, TensorView) {
TEST(Linalg, Tensor) {
{
Tensor<float, 3> t{{2, 3, 4}, kCpuId, Order::kC};
auto view = t.View(kCpuId);
Tensor<float, 3> t{{2, 3, 4}, CPU(), Order::kC};
auto view = t.View(CPU());
auto const &as_const = t;
auto k_view = as_const.View(kCpuId);
auto k_view = as_const.View(CPU());
size_t n = 2 * 3 * 4;
ASSERT_EQ(t.Size(), n);
@ -229,7 +229,7 @@ TEST(Linalg, Tensor) {
}
{
// Reshape
Tensor<float, 3> t{{2, 3, 4}, kCpuId, Order::kC};
Tensor<float, 3> t{{2, 3, 4}, CPU(), Order::kC};
t.Reshape(4, 3, 2);
ASSERT_EQ(t.Size(), 24);
ASSERT_EQ(t.Shape(2), 2);
@ -247,7 +247,7 @@ TEST(Linalg, Tensor) {
TEST(Linalg, Empty) {
{
auto t = TensorView<double, 2>{{}, {0, 3}, kCpuId, Order::kC};
auto t = TensorView<double, 2>{{}, {0, 3}, CPU(), Order::kC};
for (int32_t i : {0, 1, 2}) {
auto s = t.Slice(All(), i);
ASSERT_EQ(s.Size(), 0);
@ -256,9 +256,9 @@ TEST(Linalg, Empty) {
}
}
{
auto t = Tensor<double, 2>{{0, 3}, kCpuId, Order::kC};
auto t = Tensor<double, 2>{{0, 3}, CPU(), Order::kC};
ASSERT_EQ(t.Size(), 0);
auto view = t.View(kCpuId);
auto view = t.View(CPU());
for (int32_t i : {0, 1, 2}) {
auto s = view.Slice(All(), i);
@ -270,7 +270,7 @@ TEST(Linalg, Empty) {
}
TEST(Linalg, ArrayInterface) {
auto cpu = kCpuId;
auto cpu = CPU();
auto t = Tensor<double, 2>{{3, 3}, cpu, Order::kC};
auto v = t.View(cpu);
std::iota(v.Values().begin(), v.Values().end(), 0);
@ -315,16 +315,16 @@ TEST(Linalg, Popc) {
}
TEST(Linalg, Stack) {
Tensor<float, 3> l{{2, 3, 4}, kCpuId, Order::kC};
ElementWiseTransformHost(l.View(kCpuId), omp_get_max_threads(),
Tensor<float, 3> l{{2, 3, 4}, CPU(), Order::kC};
ElementWiseTransformHost(l.View(CPU()), omp_get_max_threads(),
[=](size_t i, float) { return i; });
Tensor<float, 3> r_0{{2, 3, 4}, kCpuId, Order::kC};
ElementWiseTransformHost(r_0.View(kCpuId), omp_get_max_threads(),
Tensor<float, 3> r_0{{2, 3, 4}, CPU(), Order::kC};
ElementWiseTransformHost(r_0.View(CPU()), omp_get_max_threads(),
[=](size_t i, float) { return i; });
Stack(&l, r_0);
Tensor<float, 3> r_1{{0, 3, 4}, kCpuId, Order::kC};
Tensor<float, 3> r_1{{0, 3, 4}, CPU(), Order::kC};
Stack(&l, r_1);
ASSERT_EQ(l.Shape(0), 4);
@ -335,7 +335,7 @@ TEST(Linalg, Stack) {
TEST(Linalg, FOrder) {
std::size_t constexpr kRows = 16, kCols = 3;
std::vector<float> data(kRows * kCols);
MatrixView<float> mat{data, {kRows, kCols}, Context::kCpuId, Order::kF};
MatrixView<float> mat{data, {kRows, kCols}, CPU(), Order::kF};
float k{0};
for (std::size_t i = 0; i < kRows; ++i) {
for (std::size_t j = 0; j < kCols; ++j) {

View File

@ -11,17 +11,18 @@
namespace xgboost::linalg {
namespace {
void TestElementWiseKernel() {
auto device = DeviceOrd::CUDA(0);
Tensor<float, 3> l{{2, 3, 4}, 0};
{
/**
* Non-contiguous
*/
// GPU view
auto t = l.View(0).Slice(linalg::All(), 1, linalg::All());
auto t = l.View(device).Slice(linalg::All(), 1, linalg::All());
ASSERT_FALSE(t.CContiguous());
ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; });
// CPU view
t = l.View(Context::kCpuId).Slice(linalg::All(), 1, linalg::All());
t = l.View(DeviceOrd::CPU()).Slice(linalg::All(), 1, linalg::All());
size_t k = 0;
for (size_t i = 0; i < l.Shape(0); ++i) {
for (size_t j = 0; j < l.Shape(2); ++j) {
@ -29,7 +30,7 @@ void TestElementWiseKernel() {
}
}
t = l.View(0).Slice(linalg::All(), 1, linalg::All());
t = l.View(device).Slice(linalg::All(), 1, linalg::All());
ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); });
}
@ -37,11 +38,11 @@ void TestElementWiseKernel() {
/**
* Contiguous
*/
auto t = l.View(0);
auto t = l.View(device);
ElementWiseTransformDevice(t, [] XGBOOST_DEVICE(size_t i, float) { return i; });
ASSERT_TRUE(t.CContiguous());
// CPU view
t = l.View(Context::kCpuId);
t = l.View(DeviceOrd::CPU());
size_t ind = 0;
for (size_t i = 0; i < l.Shape(0); ++i) {

View File
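The kernel test above writes through a CUDA view and verifies through a CPU view of the same tensor. A condensed sketch of that round trip, assuming a CUDA build of the same test file (ElementWiseTransformDevice and the extended device lambda come from it):

// Sketch (CUDA build): fill on the device, then read back through a host view.
xgboost::linalg::Tensor<float, 1> t{{16}, xgboost::DeviceOrd::CUDA(0), xgboost::linalg::Order::kC};
auto d_view = t.View(xgboost::DeviceOrd::CUDA(0));  // device memory
xgboost::linalg::ElementWiseTransformDevice(
    d_view, [] XGBOOST_DEVICE(std::size_t i, float) { return static_cast<float>(i); });
auto h_view = t.View(xgboost::DeviceOrd::CPU());    // syncs the buffer back to the host
// h_view(i) == i for every i after the transform.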

@ -41,7 +41,7 @@ void TestCalcQueriesInvIDCG() {
p.UpdateAllowUnknown(Args{{"ndcg_exp_gain", "false"}});
cuda_impl::CalcQueriesInvIDCG(&ctx, linalg::MakeTensorView(&ctx, d_scores, d_scores.size()),
dh::ToSpan(group_ptr), inv_IDCG.View(ctx.gpu_id), p);
dh::ToSpan(group_ptr), inv_IDCG.View(ctx.Device()), p);
for (std::size_t i = 0; i < n_groups; ++i) {
double inv_idcg = inv_IDCG(i);
ASSERT_NEAR(inv_idcg, 0.00551782, kRtEps);

View File

@ -47,7 +47,7 @@ class StatsGPU : public ::testing::Test {
data.insert(data.cend(), seg.begin(), seg.end());
data.insert(data.cend(), seg.begin(), seg.end());
linalg::Tensor<float, 1> arr{data.cbegin(), data.cend(), {data.size()}, 0};
auto d_arr = arr.View(0);
auto d_arr = arr.View(DeviceOrd::CUDA(0));
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
@ -71,8 +71,8 @@ class StatsGPU : public ::testing::Test {
}
void Weighted() {
auto d_arr = arr_.View(0);
auto d_key = indptr_.View(0);
auto d_arr = arr_.View(DeviceOrd::CUDA(0));
auto d_key = indptr_.View(DeviceOrd::CUDA(0));
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
@ -81,7 +81,7 @@ class StatsGPU : public ::testing::Test {
dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) { return d_arr(i); });
linalg::Tensor<float, 1> weights{{10}, 0};
linalg::ElementWiseTransformDevice(weights.View(0),
linalg::ElementWiseTransformDevice(weights.View(DeviceOrd::CUDA(0)),
[=] XGBOOST_DEVICE(std::size_t, float) { return 1.0; });
auto w_it = weights.Data()->ConstDevicePointer();
for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) {
@ -102,7 +102,7 @@ class StatsGPU : public ::testing::Test {
data.insert(data.cend(), seg.begin(), seg.end());
data.insert(data.cend(), seg.begin(), seg.end());
linalg::Tensor<float, 1> arr{data.cbegin(), data.cend(), {data.size()}, 0};
auto d_arr = arr.View(0);
auto d_arr = arr.View(DeviceOrd::CUDA(0));
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
@ -125,8 +125,8 @@ class StatsGPU : public ::testing::Test {
}
void NonWeighted() {
auto d_arr = arr_.View(0);
auto d_key = indptr_.View(0);
auto d_arr = arr_.View(DeviceOrd::CUDA(0));
auto d_key = indptr_.View(DeviceOrd::CUDA(0));
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul), [=] __device__(std::size_t i) { return d_key(i); });

View File

@ -22,7 +22,7 @@ TEST(ArrayInterface, Initialize) {
HostDeviceVector<size_t> u64_storage(storage.Size());
std::string u64_arr_str{ArrayInterfaceStr(linalg::TensorView<size_t const, 2>{
u64_storage.ConstHostSpan(), {kRows, kCols}, Context::kCpuId})};
u64_storage.ConstHostSpan(), {kRows, kCols}, DeviceOrd::CPU()})};
std::copy(storage.ConstHostVector().cbegin(), storage.ConstHostVector().cend(),
u64_storage.HostSpan().begin());
auto u64_arr = ArrayInterface<2>{u64_arr_str};

View File

@ -129,8 +129,8 @@ TEST(MetaInfo, SaveLoadBinary) {
EXPECT_EQ(inforead.group_ptr_, info.group_ptr_);
EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector());
auto orig_margin = info.base_margin_.View(xgboost::Context::kCpuId);
auto read_margin = inforead.base_margin_.View(xgboost::Context::kCpuId);
auto orig_margin = info.base_margin_.View(xgboost::DeviceOrd::CPU());
auto read_margin = inforead.base_margin_.View(xgboost::DeviceOrd::CPU());
EXPECT_TRUE(std::equal(orig_margin.Values().cbegin(), orig_margin.Values().cend(),
read_margin.Values().cbegin()));
@ -267,8 +267,8 @@ TEST(MetaInfo, Validate) {
xgboost::HostDeviceVector<xgboost::bst_group_t> d_groups{groups};
d_groups.SetDevice(0);
d_groups.DevicePointer(); // pull to device
std::string arr_interface_str{ArrayInterfaceStr(
xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0))};
std::string arr_interface_str{ArrayInterfaceStr(xgboost::linalg::MakeVec(
d_groups.ConstDevicePointer(), d_groups.Size(), xgboost::DeviceOrd::CUDA(0)))};
EXPECT_THROW(info.SetInfo(ctx, "group", xgboost::StringView{arr_interface_str}), dmlc::Error);
#endif // defined(XGBOOST_USE_CUDA)
}
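The group-validation test above now names the CUDA device explicitly when wrapping the device pointer. A condensed sketch of wrapping a device buffer in a one-dimensional view and rendering its array-interface string (CUDA build and device 0 assumed):

#include <xgboost/host_device_vector.h>  // HostDeviceVector
#include <xgboost/linalg.h>              // MakeVec, ArrayInterfaceStr

#include <string>

std::string GroupArrayInterfaceSketch() {
  xgboost::HostDeviceVector<xgboost::bst_group_t> groups{0, 3, 7};  // example group pointer values
  groups.SetDevice(0);     // integer form, unchanged by this hunk
  groups.DevicePointer();  // force the copy onto the device
  return xgboost::linalg::ArrayInterfaceStr(xgboost::linalg::MakeVec(
      groups.ConstDevicePointer(), groups.Size(), xgboost::DeviceOrd::CUDA(0)));
}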
@ -307,5 +307,5 @@ TEST(MetaInfo, HostExtend) {
}
namespace xgboost {
TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(Context::kCpuId); }
TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(DeviceOrd::CPU()); }
} // namespace xgboost

View File

@ -65,7 +65,7 @@ TEST(MetaInfo, FromInterface) {
}
info.SetInfo(ctx, "base_margin", str.c_str());
auto const h_base_margin = info.base_margin_.View(Context::kCpuId);
auto const h_base_margin = info.base_margin_.View(DeviceOrd::CPU());
ASSERT_EQ(h_base_margin.Size(), d_data.size());
for (size_t i = 0; i < d_data.size(); ++i) {
ASSERT_EQ(h_base_margin(i), d_data[i]);
@ -83,7 +83,7 @@ TEST(MetaInfo, FromInterface) {
}
TEST(MetaInfo, GPUStridedData) {
TestMetaInfoStridedData(0);
TestMetaInfoStridedData(DeviceOrd::CUDA(0));
}
TEST(MetaInfo, Group) {

View File

@ -14,10 +14,10 @@
#include "../../../src/data/array_interface.h"
namespace xgboost {
inline void TestMetaInfoStridedData(int32_t device) {
inline void TestMetaInfoStridedData(DeviceOrd device) {
MetaInfo info;
Context ctx;
ctx.UpdateAllowUnknown(Args{{"gpu_id", std::to_string(device)}});
ctx.UpdateAllowUnknown(Args{{"device", device.Name()}});
{
// labels
linalg::Tensor<float, 3> labels;
@ -28,9 +28,9 @@ inline void TestMetaInfoStridedData(int32_t device) {
ASSERT_EQ(t_labels.Shape().size(), 2);
info.SetInfo(ctx, "label", StringView{ArrayInterfaceStr(t_labels)});
auto const& h_result = info.labels.View(-1);
auto const& h_result = info.labels.View(DeviceOrd::CPU());
ASSERT_EQ(h_result.Shape().size(), 2);
auto in_labels = labels.View(-1);
auto in_labels = labels.View(DeviceOrd::CPU());
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) {
auto tup = linalg::UnravelIndex(i, h_result.Shape());
auto i0 = std::get<0>(tup);
@ -62,9 +62,9 @@ inline void TestMetaInfoStridedData(int32_t device) {
ASSERT_EQ(t_margin.Shape().size(), 2);
info.SetInfo(ctx, "base_margin", StringView{ArrayInterfaceStr(t_margin)});
auto const& h_result = info.base_margin_.View(-1);
auto const& h_result = info.base_margin_.View(DeviceOrd::CPU());
ASSERT_EQ(h_result.Shape().size(), 2);
auto in_margin = base_margin.View(-1);
auto in_margin = base_margin.View(DeviceOrd::CPU());
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
auto tup = linalg::UnravelIndex(i, h_result.Shape());
auto i0 = std::get<0>(tup);

View File
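With the helper taking a DeviceOrd, the test above configures the Context through the generic "device" parameter using the ordinal's string form instead of the retired gpu_id integer. A small sketch of that configuration step (the exact strings produced by Name(), e.g. "cpu" / "cuda:0", are assumed):

#include <xgboost/context.h>  // Context, DeviceOrd, Args

void ConfigureCudaContextSketch() {
  xgboost::Context ctx;
  auto device = xgboost::DeviceOrd::CUDA(0);
  ctx.UpdateAllowUnknown(xgboost::Args{{"device", device.Name()}});  // e.g. "cuda:0"
  // ctx.Device() now reports the CUDA ordinal and ctx.IsCPU() is false.
}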

@ -298,8 +298,8 @@ TEST(SimpleDMatrix, Slice) {
ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx),
out->Info().weights_.HostVector().at(i));
auto out_margin = out->Info().base_margin_.View(Context::kCpuId);
auto in_margin = margin.View(Context::kCpuId);
auto out_margin = out->Info().base_margin_.View(DeviceOrd::CPU());
auto in_margin = margin.View(DeviceOrd::CPU());
for (size_t j = 0; j < kClasses; ++j) {
ASSERT_EQ(out_margin(i, j), in_margin(ridx, j));
}
@ -372,8 +372,8 @@ TEST(SimpleDMatrix, SliceCol) {
out->Info().labels_upper_bound_.HostVector().at(i));
ASSERT_EQ(p_m->Info().weights_.HostVector().at(i), out->Info().weights_.HostVector().at(i));
auto out_margin = out->Info().base_margin_.View(Context::kCpuId);
auto in_margin = margin.View(Context::kCpuId);
auto out_margin = out->Info().base_margin_.View(DeviceOrd::CPU());
auto in_margin = margin.View(DeviceOrd::CPU());
for (size_t j = 0; j < kClasses; ++j) {
ASSERT_EQ(out_margin(i, j), in_margin(i, j));
}

View File

@ -39,9 +39,9 @@ void TestGPUMakePair() {
auto make_args = [&](std::shared_ptr<ltr::RankingCache> p_cache, auto rank_idx,
common::Span<std::size_t const> y_sorted_idx) {
linalg::Vector<double> dummy;
auto d = dummy.View(ctx.gpu_id);
auto d = dummy.View(ctx.Device());
linalg::Vector<GradientPair> dgpair;
auto dg = dgpair.View(ctx.gpu_id);
auto dg = dgpair.View(ctx.Device());
cuda_impl::KernelInputs args{
d,
d,
@ -50,9 +50,9 @@ void TestGPUMakePair() {
p_cache->DataGroupPtr(&ctx),
p_cache->CUDAThreadsGroupPtr(),
rank_idx,
info.labels.View(ctx.gpu_id),
info.labels.View(ctx.Device()),
predt.ConstDeviceSpan(),
linalg::MatrixView<GradientPair>{common::Span<GradientPair>{}, {0}, 0},
linalg::MatrixView<GradientPair>{common::Span<GradientPair>{}, {0}, DeviceOrd::CUDA(0)},
dg,
nullptr,
y_sorted_idx,

View File

@ -226,7 +226,7 @@ TEST(GPUPredictor, ShapStump) {
auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
gpu_predictor->PredictContribution(dmat.get(), &predictions, model);
auto& phis = predictions.HostVector();
auto base_score = mparam.BaseScore(Context::kCpuId)(0);
auto base_score = mparam.BaseScore(DeviceOrd::CPU())(0);
EXPECT_EQ(phis[0], 0.0);
EXPECT_EQ(phis[1], base_score);
EXPECT_EQ(phis[2], 0.0);

View File

@ -287,7 +287,7 @@ void TestCategoricalPrediction(Context const* ctx, bool is_column_split) {
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
auto score = mparam.BaseScore(Context::kCpuId)(0);
auto score = mparam.BaseScore(DeviceOrd::CPU())(0);
ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
right_weight + score); // go to right for matching cat
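Both predictor tests above read the intercept through the DeviceOrd overload of BaseScore, which returns a one-dimensional TensorView<float const, 1>. A minimal sketch of that read on the host (mparam stands for a configured LearnerModelParam, as in the tests):

// Sketch: fetch the base score on the host; element 0 is the value used to
// seed the output predictions in InitOutPredictions.
xgboost::linalg::TensorView<float const, 1> base = mparam.BaseScore(xgboost::DeviceOrd::CPU());
float intercept = base(0);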