Use the new DeviceOrd in the linalg module. (#9527)
parent 942b957eef
commit ddf2e68821
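The diff below migrates raw integer device ordinals (ctx->gpu_id, Context::kCpuId, -1 meaning "host") to the typed DeviceOrd ordinal throughout the linalg module and its callers. Before the hunks, here is a minimal, self-contained sketch of the calling convention the change adopts. The real xgboost::DeviceOrd lives in XGBoost's context header; the struct below is a simplified stand-in written only to illustrate the factories and accessors used in the hunks (CPU(), CUDA(ordinal), .ordinal, .IsCUDA()), not the library's actual definition.

// Simplified stand-in for xgboost::DeviceOrd, for illustration only.
#include <cstdint>
#include <iostream>

struct DeviceOrd {
  enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU};
  std::int32_t ordinal{-1};  // -1 denotes the CPU, mirroring the old integer convention.

  bool IsCUDA() const { return device == kCUDA; }
  bool IsCPU() const { return device == kCPU; }

  static DeviceOrd CPU() { return {kCPU, -1}; }
  static DeviceOrd CUDA(std::int32_t i) { return {kCUDA, i}; }
};

int main() {
  // Old style: a bare integer, where -1 meant "CPU" and >= 0 meant a CUDA ordinal.
  std::int32_t gpu_id = 0;
  // New style: a typed ordinal that carries the device kind explicitly.
  DeviceOrd device = gpu_id < 0 ? DeviceOrd::CPU() : DeviceOrd::CUDA(gpu_id);
  std::cout << (device.IsCUDA() ? "cuda:" : "cpu:") << device.ordinal << "\n";
  return 0;
}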
@@ -102,6 +102,14 @@ class HostDeviceVector {
   bool Empty() const { return Size() == 0; }
   size_t Size() const;
   int DeviceIdx() const;
+  DeviceOrd Device() const {
+    auto idx = this->DeviceIdx();
+    if (idx == DeviceOrd::CPU().ordinal) {
+      return DeviceOrd::CPU();
+    } else {
+      return DeviceOrd::CUDA(idx);
+    }
+  }
   common::Span<T> DeviceSpan();
   common::Span<const T> ConstDeviceSpan() const;
   common::Span<const T> DeviceSpan() const { return ConstDeviceSpan(); }
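A hypothetical usage sketch (not part of this commit) of the accessor added above: callers can branch on the typed ordinal instead of comparing the raw integer returned by DeviceIdx(). Only Device(), ConstHostSpan() and ConstDeviceSpan() from the hunk are assumed; the function name is made up for illustration.

#include <cstddef>

#include <xgboost/host_device_vector.h>

// Pick whichever span is resident for this vector's current device.
std::size_t AccessibleSize(xgboost::HostDeviceVector<float> const& vec) {
  // Old style: if (vec.DeviceIdx() >= 0) { ... }
  // New style: the typed ordinal states the intent directly.
  return vec.Device().IsCUDA() ? vec.ConstDeviceSpan().size() : vec.ConstHostSpan().size();
}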
@@ -330,7 +330,7 @@ struct LearnerModelParam {
         multi_strategy{multi_strategy} {}

   linalg::TensorView<float const, 1> BaseScore(Context const* ctx) const;
-  [[nodiscard]] linalg::TensorView<float const, 1> BaseScore(std::int32_t device) const;
+  [[nodiscard]] linalg::TensorView<float const, 1> BaseScore(DeviceOrd device) const;

   void Copy(LearnerModelParam const& that);
   [[nodiscard]] bool IsVectorLeaf() const noexcept {
@@ -302,7 +302,7 @@ class TensorView {
   T *ptr_{nullptr};  // pointer of data_ to avoid bound check.

   size_t size_{0};
-  int32_t device_{-1};
+  DeviceOrd device_;

   // Unlike `Tensor`, the data_ can have arbitrary size since this is just a view.
   LINALG_HD void CalcSize() {

@@ -401,15 +401,11 @@ class TensorView {
    * \param device Device ordinal
    */
   template <typename I, std::int32_t D>
-  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], std::int32_t device)
+  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], DeviceOrd device)
       : TensorView{data, shape, device, Order::kC} {}

-  template <typename I, std::int32_t D>
-  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], DeviceOrd device)
-      : TensorView{data, shape, device.ordinal, Order::kC} {}
-
   template <typename I, int32_t D>
-  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], std::int32_t device, Order order)
+  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], DeviceOrd device, Order order)
       : data_{data}, ptr_{data_.data()}, device_{device} {
     static_assert(D > 0 && D <= kDim, "Invalid shape.");
     // shape

@@ -441,7 +437,7 @@ class TensorView {
    */
   template <typename I, std::int32_t D>
   LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
-                       std::int32_t device)
+                       DeviceOrd device)
       : data_{data}, ptr_{data_.data()}, device_{device} {
     static_assert(D == kDim, "Invalid shape & stride.");
     detail::UnrollLoop<D>([&](auto i) {

@@ -450,16 +446,12 @@ class TensorView {
     });
     this->CalcSize();
   }
-  template <typename I, std::int32_t D>
-  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
-                       DeviceOrd device)
-      : TensorView{data, shape, stride, device.ordinal} {}

   template <
       typename U,
       std::enable_if_t<common::detail::IsAllowedElementTypeConversion<U, T>::value> * = nullptr>
   LINALG_HD TensorView(TensorView<U, kDim> const &that)  // NOLINT
-      : data_{that.Values()}, ptr_{data_.data()}, size_{that.Size()}, device_{that.DeviceIdx()} {
+      : data_{that.Values()}, ptr_{data_.data()}, size_{that.Size()}, device_{that.Device()} {
     detail::UnrollLoop<kDim>([&](auto i) {
       stride_[i] = that.Stride(i);
       shape_[i] = that.Shape(i);

@@ -572,7 +564,7 @@ class TensorView {
   /**
    * \brief Obtain the CUDA device ordinal.
    */
-  LINALG_HD auto DeviceIdx() const { return device_; }
+  LINALG_HD auto Device() const { return device_; }
 };

 /**

@@ -587,11 +579,11 @@ auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) {  // NOL
                              typename Container::value_type>;
   std::size_t in_shape[sizeof...(S)];
   detail::IndexToArr(in_shape, std::forward<S>(shape)...);
-  return TensorView<T, sizeof...(S)>{data, in_shape, ctx->gpu_id};
+  return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device()};
 }

 template <typename T, typename... S>
-LINALG_HD auto MakeTensorView(std::int32_t device, common::Span<T> data, S &&...shape) {
+LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T> data, S &&...shape) {
   std::size_t in_shape[sizeof...(S)];
   detail::IndexToArr(in_shape, std::forward<S>(shape)...);
   return TensorView<T, sizeof...(S)>{data, in_shape, device};

@@ -599,26 +591,26 @@ LINALG_HD auto MakeTensorView(std::int32_t device, common::Span<T> data, S &&...

 template <typename T, typename... S>
 auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
-  return MakeTensorView(ctx->gpu_id, data, std::forward<S>(shape)...);
+  return MakeTensorView(ctx->Device(), data, std::forward<S>(shape)...);
 }

 template <typename T, typename... S>
 auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&...shape) {
   std::size_t in_shape[sizeof...(S)];
   detail::IndexToArr(in_shape, std::forward<S>(shape)...);
-  return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Ordinal(), order};
+  return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device(), order};
 }

 template <typename T, typename... S>
 auto MakeTensorView(Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
   auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan();
-  return MakeTensorView(ctx->gpu_id, span, std::forward<S>(shape)...);
+  return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
 }

 template <typename T, typename... S>
 auto MakeTensorView(Context const *ctx, HostDeviceVector<T> const *data, S &&...shape) {
   auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan();
-  return MakeTensorView(ctx->gpu_id, span, std::forward<S>(shape)...);
+  return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
 }

 /**

@@ -661,20 +653,20 @@ using VectorView = TensorView<T, 1>;
  * \param device (optional) Device ordinal, default to be host.
  */
 template <typename T>
-auto MakeVec(T *ptr, size_t s, int32_t device = -1) {
+auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {
   return linalg::TensorView<T, 1>{{ptr, s}, {s}, device};
 }

 template <typename T>
 auto MakeVec(HostDeviceVector<T> *data) {
   return MakeVec(data->DeviceIdx() == -1 ? data->HostPointer() : data->DevicePointer(),
-                 data->Size(), data->DeviceIdx());
+                 data->Size(), data->Device());
 }

 template <typename T>
 auto MakeVec(HostDeviceVector<T> const *data) {
   return MakeVec(data->DeviceIdx() == -1 ? data->ConstHostPointer() : data->ConstDevicePointer(),
-                 data->Size(), data->DeviceIdx());
+                 data->Size(), data->Device());
 }

 /**

@@ -697,7 +689,7 @@ Json ArrayInterface(TensorView<T const, D> const &t) {
   array_interface["data"] = std::vector<Json>(2);
   array_interface["data"][0] = Integer{reinterpret_cast<int64_t>(t.Values().data())};
   array_interface["data"][1] = Boolean{true};
-  if (t.DeviceIdx() >= 0) {
+  if (t.Device().IsCUDA()) {
     // Change this once we have different CUDA stream.
     array_interface["stream"] = Null{};
   }

@@ -856,49 +848,29 @@ class Tensor {
   /**
    * @brief Get a @ref TensorView for this tensor.
    */
-  TensorView<T, kDim> View(std::int32_t device) {
-    if (device >= 0) {
-      data_.SetDevice(device);
-      auto span = data_.DeviceSpan();
-      return {span, shape_, device, order_};
-    } else {
-      auto span = data_.HostSpan();
-      return {span, shape_, device, order_};
-    }
-  }
-  TensorView<T const, kDim> View(std::int32_t device) const {
-    if (device >= 0) {
-      data_.SetDevice(device);
-      auto span = data_.ConstDeviceSpan();
-      return {span, shape_, device, order_};
-    } else {
-      auto span = data_.ConstHostSpan();
-      return {span, shape_, device, order_};
-    }
-  }
   auto View(DeviceOrd device) {
     if (device.IsCUDA()) {
       data_.SetDevice(device);
       auto span = data_.DeviceSpan();
-      return TensorView<T, kDim>{span, shape_, device.ordinal, order_};
+      return TensorView<T, kDim>{span, shape_, device, order_};
     } else {
       auto span = data_.HostSpan();
-      return TensorView<T, kDim>{span, shape_, device.ordinal, order_};
+      return TensorView<T, kDim>{span, shape_, device, order_};
     }
   }
   auto View(DeviceOrd device) const {
     if (device.IsCUDA()) {
       data_.SetDevice(device);
       auto span = data_.ConstDeviceSpan();
-      return TensorView<T const, kDim>{span, shape_, device.ordinal, order_};
+      return TensorView<T const, kDim>{span, shape_, device, order_};
     } else {
       auto span = data_.ConstHostSpan();
-      return TensorView<T const, kDim>{span, shape_, device.ordinal, order_};
+      return TensorView<T const, kDim>{span, shape_, device, order_};
     }
   }

-  auto HostView() const { return this->View(-1); }
-  auto HostView() { return this->View(-1); }
+  auto HostView() { return this->View(DeviceOrd::CPU()); }
+  auto HostView() const { return this->View(DeviceOrd::CPU()); }

   [[nodiscard]] size_t Size() const { return data_.Size(); }
   auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }

@@ -975,6 +947,7 @@ class Tensor {
   void SetDevice(int32_t device) const { data_.SetDevice(device); }
   void SetDevice(DeviceOrd device) const { data_.SetDevice(device); }
   [[nodiscard]] int32_t DeviceIdx() const { return data_.DeviceIdx(); }
+  [[nodiscard]] DeviceOrd Device() const { return data_.Device(); }
 };

 template <typename T>
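A hypothetical usage sketch (not part of this commit) of the linalg entry points after the change: Tensor::View and the MakeTensorView helpers now take a DeviceOrd, typically obtained from Context::Device(), while HostView() always returns the host copy. The function name and setup below are assumptions for illustration; only calls visible in the hunks above are used.

#include <xgboost/context.h>
#include <xgboost/linalg.h>

void ViewOnConfiguredDevice(xgboost::Context const* ctx,
                            xgboost::linalg::Tensor<float, 2>* tensor) {
  // Old style: tensor->View(ctx->gpu_id), with -1 meaning the host.
  auto view = tensor->View(ctx->Device());  // CPU or CUDA view, depending on the context.
  auto host = tensor->HostView();           // Always the host copy.
  (void)view;
  (void)host;
}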
@@ -37,12 +37,12 @@ class MultiTargetTree : public Model {
   [[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
     auto beg = nidx * this->NumTarget();
     auto v = common::Span<float const>{weights_}.subspan(beg, this->NumTarget());
-    return linalg::MakeTensorView(Context::kCpuId, v, v.size());
+    return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
   }
   [[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
     auto beg = nidx * this->NumTarget();
     auto v = common::Span<float>{weights_}.subspan(beg, this->NumTarget());
-    return linalg::MakeTensorView(Context::kCpuId, v, v.size());
+    return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
   }

  public:
@@ -68,7 +68,7 @@ void CopyGradientFromCUDAArrays(Context const *ctx, ArrayInterface<2, false> con
   auto &gpair = *out_gpair;
   gpair.SetDevice(grad_dev);
   gpair.Reshape(grad.Shape(0), grad.Shape(1));
-  auto d_gpair = gpair.View(grad_dev);
+  auto d_gpair = gpair.View(DeviceOrd::CUDA(grad_dev));
   auto cuctx = ctx->CUDACtx();

   DispatchDType(grad, DeviceOrd::CUDA(grad_dev), [&](auto &&t_grad) {
@@ -13,7 +13,7 @@ namespace xgboost {
 namespace linalg {
 template <typename T, int32_t D, typename Fn>
 void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
-  dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
+  dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
   static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
                 "For function with return, use transform instead.");
   if (t.Contiguous()) {
@@ -133,7 +133,7 @@ struct WeightOp {
 void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
   CUDAContext const* cuctx = ctx->CUDACtx();

-  group_ptr_.SetDevice(ctx->gpu_id);
+  group_ptr_.SetDevice(ctx->Device());
   if (info.group_ptr_.empty()) {
     group_ptr_.Resize(2, 0);
     group_ptr_.HostVector()[1] = info.num_row_;

@@ -153,7 +153,7 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
   max_group_size_ =
       thrust::reduce(cuctx->CTP(), it, it + n_groups, 0ul, thrust::maximum<std::size_t>{});

-  threads_group_ptr_.SetDevice(ctx->gpu_id);
+  threads_group_ptr_.SetDevice(ctx->Device());
   threads_group_ptr_.Resize(n_groups + 1, 0);
   auto d_threads_group_ptr = threads_group_ptr_.DeviceSpan();
   if (param_.HasTruncation()) {

@@ -168,7 +168,7 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
     n_cuda_threads_ = info.num_row_ * param_.NumPair();
   }

-  sorted_idx_cache_.SetDevice(ctx->gpu_id);
+  sorted_idx_cache_.SetDevice(ctx->Device());
   sorted_idx_cache_.Resize(info.labels.Size(), 0);

   auto weight = common::MakeOptionalWeights(ctx, info.weights_);

@@ -187,18 +187,18 @@ common::Span<std::size_t const> RankingCache::MakeRankOnCUDA(Context const* ctx,

 void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
   CUDAContext const* cuctx = ctx->CUDACtx();
-  auto labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+  auto labels = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
   CheckNDCGLabels(this->Param(), labels, CheckNDCGOp{cuctx});

   auto d_group_ptr = this->DataGroupPtr(ctx);

   std::size_t n_groups = d_group_ptr.size() - 1;
   inv_idcg_ = linalg::Zeros<double>(ctx, n_groups);
-  auto d_inv_idcg = inv_idcg_.View(ctx->gpu_id);
+  auto d_inv_idcg = inv_idcg_.View(ctx->Device());
   cuda_impl::CalcQueriesInvIDCG(ctx, labels, d_group_ptr, d_inv_idcg, this->Param());
   CHECK_GE(this->Param().NumPair(), 1ul);

-  discounts_.SetDevice(ctx->gpu_id);
+  discounts_.SetDevice(ctx->Device());
   discounts_.Resize(MaxGroupSize());
   auto d_discount = discounts_.DeviceSpan();
   dh::LaunchN(MaxGroupSize(), cuctx->Stream(),

@@ -206,12 +206,12 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
 }

 void PreCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
-  auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+  auto const d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
   CheckPreLabels("pre", d_label, CheckMAPOp{ctx->CUDACtx()});
 }

 void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
-  auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+  auto const d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
   CheckPreLabels("map", d_label, CheckMAPOp{ctx->CUDACtx()});
 }
 }  // namespace xgboost::ltr
@@ -217,7 +217,7 @@ class RankingCache {
   }
   // Constructed as [1, n_samples] if group ptr is not supplied by the user
   common::Span<bst_group_t const> DataGroupPtr(Context const* ctx) const {
-    group_ptr_.SetDevice(ctx->gpu_id);
+    group_ptr_.SetDevice(ctx->Device());
     return ctx->IsCPU() ? group_ptr_.ConstHostSpan() : group_ptr_.ConstDeviceSpan();
   }

@@ -228,7 +228,7 @@ class RankingCache {
   // Create a rank list by model prediction
   common::Span<std::size_t const> SortedIdx(Context const* ctx, common::Span<float const> predt) {
     if (sorted_idx_cache_.Empty()) {
-      sorted_idx_cache_.SetDevice(ctx->gpu_id);
+      sorted_idx_cache_.SetDevice(ctx->Device());
       sorted_idx_cache_.Resize(predt.size());
     }
     if (ctx->IsCPU()) {

@@ -242,7 +242,7 @@ class RankingCache {
   common::Span<std::size_t> SortedIdxY(Context const* ctx, std::size_t n_samples) {
     CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal();
     if (y_sorted_idx_cache_.Empty()) {
-      y_sorted_idx_cache_.SetDevice(ctx->gpu_id);
+      y_sorted_idx_cache_.SetDevice(ctx->Device());
       y_sorted_idx_cache_.Resize(n_samples);
     }
     return y_sorted_idx_cache_.DeviceSpan();

@@ -250,7 +250,7 @@ class RankingCache {
   common::Span<float> RankedY(Context const* ctx, std::size_t n_samples) {
     CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal();
     if (y_ranked_by_model_.Empty()) {
-      y_ranked_by_model_.SetDevice(ctx->gpu_id);
+      y_ranked_by_model_.SetDevice(ctx->Device());
       y_ranked_by_model_.Resize(n_samples);
     }
     return y_ranked_by_model_.DeviceSpan();

@@ -266,21 +266,21 @@ class RankingCache {

   linalg::VectorView<GradientPair> CUDARounding(Context const* ctx) {
     if (roundings_.Size() == 0) {
-      roundings_.SetDevice(ctx->gpu_id);
+      roundings_.SetDevice(ctx->Device());
       roundings_.Reshape(Groups());
     }
-    return roundings_.View(ctx->gpu_id);
+    return roundings_.View(ctx->Device());
   }
   common::Span<double> CUDACostRounding(Context const* ctx) {
     if (cost_rounding_.Size() == 0) {
-      cost_rounding_.SetDevice(ctx->gpu_id);
+      cost_rounding_.SetDevice(ctx->Device());
       cost_rounding_.Resize(1);
     }
     return cost_rounding_.DeviceSpan();
   }
   template <typename Type>
   common::Span<Type> MaxLambdas(Context const* ctx, std::size_t n) {
-    max_lambdas_.SetDevice(ctx->gpu_id);
+    max_lambdas_.SetDevice(ctx->Device());
     std::size_t bytes = n * sizeof(Type);
     if (bytes != max_lambdas_.Size()) {
       max_lambdas_.Resize(bytes);

@@ -315,17 +315,17 @@ class NDCGCache : public RankingCache {
   }

   linalg::VectorView<double const> InvIDCG(Context const* ctx) const {
-    return inv_idcg_.View(ctx->gpu_id);
+    return inv_idcg_.View(ctx->Device());
   }
   common::Span<double const> Discount(Context const* ctx) const {
     return ctx->IsCPU() ? discounts_.ConstHostSpan() : discounts_.ConstDeviceSpan();
   }
   linalg::VectorView<double> Dcg(Context const* ctx) {
     if (dcg_.Size() == 0) {
-      dcg_.SetDevice(ctx->gpu_id);
+      dcg_.SetDevice(ctx->Device());
       dcg_.Reshape(this->Groups());
     }
-    return dcg_.View(ctx->gpu_id);
+    return dcg_.View(ctx->Device());
   }
 };

@@ -396,7 +396,7 @@ class PreCache : public RankingCache {

   common::Span<double> Pre(Context const* ctx) {
     if (pre_.Empty()) {
-      pre_.SetDevice(ctx->gpu_id);
+      pre_.SetDevice(ctx->Device());
       pre_.Resize(this->Groups());
     }
     return ctx->IsCPU() ? pre_.HostSpan() : pre_.DeviceSpan();

@@ -427,21 +427,21 @@ class MAPCache : public RankingCache {

   common::Span<double> NumRelevant(Context const* ctx) {
     if (n_rel_.Empty()) {
-      n_rel_.SetDevice(ctx->gpu_id);
+      n_rel_.SetDevice(ctx->Device());
       n_rel_.Resize(n_samples_);
     }
     return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan();
   }
   common::Span<double> Acc(Context const* ctx) {
     if (acc_.Empty()) {
-      acc_.SetDevice(ctx->gpu_id);
+      acc_.SetDevice(ctx->Device());
       acc_.Resize(n_samples_);
     }
     return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan();
   }
   common::Span<double> Map(Context const* ctx) {
     if (map_.Empty()) {
-      map_.SetDevice(ctx->gpu_id);
+      map_.SetDevice(ctx->Device());
       map_.Resize(this->Groups());
     }
     return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan();
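The ranking-cache hunks above repeat one pattern: bind the buffer to the context's device, then pick the host or device span. A hypothetical sketch of that pattern (not part of this commit) follows; the buffer and context are assumed to come from elsewhere, and only SetDevice(DeviceOrd), ConstHostSpan() and ConstDeviceSpan() from the hunks above are used.

#include <xgboost/context.h>
#include <xgboost/host_device_vector.h>

// Bind the buffer to the context's device, then return whichever span matches it.
xgboost::common::Span<double const> SelectSpan(xgboost::Context const* ctx,
                                               xgboost::HostDeviceVector<double> const& buf) {
  buf.SetDevice(ctx->Device());  // Records the target ordinal; callable on a const reference.
  return ctx->IsCPU() ? buf.ConstHostSpan() : buf.ConstDeviceSpan();
}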
@@ -20,9 +20,9 @@ namespace common {
 void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
             HostDeviceVector<float> const& weights, linalg::Tensor<float, 1>* out) {
   if (!ctx->IsCPU()) {
-    weights.SetDevice(ctx->gpu_id);
+    weights.SetDevice(ctx->Device());
     auto opt_weights = OptionalWeights(weights.ConstDeviceSpan());
-    auto t_v = t.View(ctx->gpu_id);
+    auto t_v = t.View(ctx->Device());
     cuda_impl::Median(ctx, t_v, opt_weights, out);
   }

@@ -59,7 +59,7 @@ void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<flo
     auto ret = std::accumulate(tloc.cbegin(), tloc.cend(), .0f);
     out->HostView()(0) = ret;
   } else {
-    cuda_impl::Mean(ctx, v.View(ctx->gpu_id), out->View(ctx->gpu_id));
+    cuda_impl::Mean(ctx, v.View(ctx->Device()), out->View(ctx->Device()));
   }
 }
 }  // namespace common
@@ -366,7 +366,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
   // Groups is maintained by a higher level Python function. We should aim at deprecating
   // the slice function.
   if (this->labels.Size() != this->num_row_) {
-    auto t_labels = this->labels.View(this->labels.Data()->DeviceIdx());
+    auto t_labels = this->labels.View(this->labels.Data()->Device());
     out.labels.Reshape(ridxs.size(), labels.Shape(1));
     out.labels.Data()->HostVector() =
         Gather(this->labels.Data()->HostVector(), ridxs, t_labels.Stride(0));

@@ -394,7 +394,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
   if (this->base_margin_.Size() != this->num_row_) {
     CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
         << "Incorrect size of base margin vector.";
-    auto t_margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx());
+    auto t_margin = this->base_margin_.View(this->base_margin_.Data()->Device());
     out.base_margin_.Reshape(ridxs.size(), t_margin.Shape(1));
     out.base_margin_.Data()->HostVector() =
         Gather(this->base_margin_.Data()->HostVector(), ridxs, t_margin.Stride(0));

@@ -445,7 +445,7 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
     return;
   }
   p_out->Reshape(array.shape);
-  auto t_out = p_out->View(Context::kCpuId);
+  auto t_out = p_out->View(DeviceOrd::CPU());
   CHECK(t_out.CContiguous());
   auto const shape = t_out.Shape();
   DispatchDType(array, DeviceOrd::CPU(), [&](auto&& in) {

@@ -564,7 +564,7 @@ void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, Da
   CHECK(key);
   auto proc = [&](auto cast_d_ptr) {
     using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
-    auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, Context::kCpuId);
+    auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, DeviceOrd::CPU());
     CHECK(t.CContiguous());
     Json interface {
       linalg::ArrayInterface(t)

@@ -739,8 +739,7 @@ void MetaInfo::SynchronizeNumberOfColumns() {
 namespace {
 template <typename T>
 void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
-  bool valid =
-      v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device;
+  bool valid = v.Device().IsCPU() || device == Context::kCpuId || v.DeviceIdx() == device;
   if (!valid) {
     LOG(FATAL) << "Invalid device ordinal. Data is associated with a different device ordinal than "
                   "the booster. The device ordinal of the data is: "
@@ -50,7 +50,7 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens
     return;
   }
   p_out->Reshape(array.shape);
-  auto t = p_out->View(ptr_device);
+  auto t = p_out->View(DeviceOrd::CUDA(ptr_device));
   linalg::ElementWiseTransformDevice(
       t,
       [=] __device__(size_t i, T) {
@@ -183,7 +183,7 @@ class GBLinear : public GradientBooster {
                    bst_layer_t layer_begin, bst_layer_t /*layer_end*/, bool) override {
     model_.LazyInitModel();
     LinearCheckLayer(layer_begin);
-    auto base_margin = p_fmat->Info().base_margin_.View(Context::kCpuId);
+    auto base_margin = p_fmat->Info().base_margin_.View(DeviceOrd::CPU());
     const int ngroup = model_.learner_model_param->num_output_group;
     const size_t ncolumns = model_.learner_model_param->num_feature + 1;
     // allocate space for (#features + bias) times #groups times #rows

@@ -250,10 +250,9 @@ class GBLinear : public GradientBooster {
     // The bias is the last weight
     out_scores->resize(model_.weight.size() - learner_model_param_->num_output_group, 0);
     auto n_groups = learner_model_param_->num_output_group;
-    linalg::TensorView<float, 2> scores{
-        *out_scores,
-        {learner_model_param_->num_feature, n_groups},
-        Context::kCpuId};
+    auto scores = linalg::MakeTensorView(DeviceOrd::CPU(),
+                                         common::Span{out_scores->data(), out_scores->size()},
+                                         learner_model_param_->num_feature, n_groups);
     for (size_t i = 0; i < learner_model_param_->num_feature; ++i) {
       for (bst_group_t g = 0; g < n_groups; ++g) {
         scores(i, g) = model_[i][g];

@@ -275,12 +274,12 @@ class GBLinear : public GradientBooster {
     monitor_.Start("PredictBatchInternal");
     model_.LazyInitModel();
     std::vector<bst_float> &preds = *out_preds;
-    auto base_margin = p_fmat->Info().base_margin_.View(Context::kCpuId);
+    auto base_margin = p_fmat->Info().base_margin_.View(DeviceOrd::CPU());
     // start collecting the prediction
     const int ngroup = model_.learner_model_param->num_output_group;
     preds.resize(p_fmat->Info().num_row_ * ngroup);

-    auto base_score = learner_model_param_->BaseScore(Context::kCpuId);
+    auto base_score = learner_model_param_->BaseScore(DeviceOrd::CPU());
     for (const auto &page : p_fmat->GetBatches<SparsePage>()) {
       auto const& batch = page.GetView();
       // output convention: nrow * k, where nrow is number of rows
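A hypothetical sketch (not part of this commit) of building a host-side 2-D view over an existing std::vector the way the gblinear hunk above now does: through linalg::MakeTensorView with an explicit DeviceOrd::CPU() and a common::Span over the vector's storage, instead of the raw TensorView constructor taking Context::kCpuId. The function name and shapes below are assumptions.

#include <cstddef>
#include <vector>

#include <xgboost/context.h>
#include <xgboost/linalg.h>
#include <xgboost/span.h>

void FillScores(std::vector<float>* out_scores, std::size_t n_features, std::size_t n_groups) {
  out_scores->resize(n_features * n_groups, 0.0f);
  // A (n_features x n_groups) host view over the vector's storage.
  auto scores = xgboost::linalg::MakeTensorView(
      xgboost::DeviceOrd::CPU(),
      xgboost::common::Span<float>{out_scores->data(), out_scores->size()},
      n_features, n_groups);
  for (std::size_t i = 0; i < n_features; ++i) {
    for (std::size_t g = 0; g < n_groups; ++g) {
      scores(i, g) = 0.0f;  // Placeholder for the per-feature, per-group weight.
    }
  }
}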
@@ -754,7 +754,7 @@ class Dart : public GBTree {
     auto n_groups = model_.learner_model_param->num_output_group;

     PredictionCacheEntry predts;  // temporary storage for prediction
-    if (ctx_->gpu_id != Context::kCpuId) {
+    if (ctx_->IsCUDA()) {
       predts.predictions.SetDevice(ctx_->gpu_id);
     }
     predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);

@@ -859,12 +859,12 @@ class Dart : public GBTree {
     size_t n_rows = p_fmat->Info().num_row_;
     if (predts.predictions.DeviceIdx() != Context::kCpuId) {
       p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
-      auto base_score = model_.learner_model_param->BaseScore(predts.predictions.DeviceIdx());
+      auto base_score = model_.learner_model_param->BaseScore(predts.predictions.Device());
       GPUDartInplacePredictInc(p_out_preds->predictions.DeviceSpan(),
                                predts.predictions.DeviceSpan(), w, n_rows, base_score, n_groups,
                                group);
     } else {
-      auto base_score = model_.learner_model_param->BaseScore(Context::kCpuId);
+      auto base_score = model_.learner_model_param->BaseScore(DeviceOrd::CPU());
       auto& h_predts = predts.predictions.HostVector();
       auto& h_out_predts = p_out_preds->predictions.HostVector();
       common::ParallelFor(n_rows, ctx_->Threads(), [&](auto ridx) {
@@ -279,15 +279,15 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy
   // Make sure read access everywhere for thread-safe prediction.
   std::as_const(base_score_).HostView();
   if (!ctx->IsCPU()) {
-    std::as_const(base_score_).View(ctx->gpu_id);
+    std::as_const(base_score_).View(ctx->Device());
   }
   CHECK(std::as_const(base_score_).Data()->HostCanRead());
 }

-linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(int32_t device) const {
+linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(DeviceOrd device) const {
   // multi-class is not yet supported.
   CHECK_EQ(base_score_.Size(), 1) << ModelNotFitted();
-  if (device == Context::kCpuId) {
+  if (device.IsCPU()) {
     // Make sure that we won't run into race condition.
     CHECK(base_score_.Data()->HostCanRead());
     return base_score_.HostView();

@@ -300,7 +300,7 @@ linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(int32_t device)
 }

 linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(Context const* ctx) const {
-  return this->BaseScore(ctx->gpu_id);
+  return this->BaseScore(ctx->Device());
 }

 void LearnerModelParam::Copy(LearnerModelParam const& that) {

@@ -309,7 +309,7 @@ void LearnerModelParam::Copy(LearnerModelParam const& that) {
   base_score_.Data()->Copy(*that.base_score_.Data());
   std::as_const(base_score_).HostView();
   if (that.base_score_.DeviceIdx() != Context::kCpuId) {
-    std::as_const(base_score_).View(that.base_score_.DeviceIdx());
+    std::as_const(base_score_).View(that.base_score_.Device());
   }
   CHECK_EQ(base_score_.Data()->DeviceCanRead(), that.base_score_.Data()->DeviceCanRead());
   CHECK(base_score_.Data()->HostCanRead());

@@ -388,7 +388,7 @@ class LearnerConfiguration : public Learner {
     this->ConfigureTargets();

     auto task = UsePtr(obj_)->Task();
-    linalg::Tensor<float, 1> base_score({1}, Ctx()->gpu_id);
+    linalg::Tensor<float, 1> base_score({1}, Ctx()->Device());
     auto h_base_score = base_score.HostView();

     // transform to margin

@@ -424,7 +424,7 @@ class LearnerConfiguration : public Learner {
     if (mparam_.boost_from_average && !UsePtr(gbm_)->ModelFitted()) {
       if (p_fmat) {
         auto const& info = p_fmat->Info();
-        info.Validate(Ctx()->gpu_id);
+        info.Validate(Ctx()->Ordinal());
         // We estimate it from input data.
         linalg::Tensor<float, 1> base_score;
         InitEstimation(info, &base_score);

@@ -1369,7 +1369,7 @@ class LearnerImpl : public LearnerIO {
     auto& prediction = prediction_container_.Cache(data, ctx_.gpu_id);
     this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
     // Copy the prediction cache to output prediction. out_preds comes from C API
-    out_preds->SetDevice(ctx_.gpu_id);
+    out_preds->SetDevice(ctx_.Device());
     out_preds->Resize(prediction.predictions.Size());
     out_preds->Copy(prediction.predictions);
     if (!output_margin) {
@@ -82,22 +82,19 @@ template <typename BinaryAUC>
 double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaInfo const &info,
                      size_t n_classes, int32_t n_threads, BinaryAUC &&binary_auc) {
   CHECK_NE(n_classes, 0);
-  auto const labels = info.labels.View(Context::kCpuId);
+  auto const labels = info.labels.HostView();
   if (labels.Shape(0) != 0) {
     CHECK_EQ(labels.Shape(1), 1) << "AUC doesn't support multi-target model.";
   }

   std::vector<double> results_storage(n_classes * 3, 0);
-  linalg::TensorView<double, 2> results(results_storage, {n_classes, static_cast<size_t>(3)},
-                                        Context::kCpuId);
+  auto results = linalg::MakeTensorView(ctx, results_storage, n_classes, 3);
   auto local_area = results.Slice(linalg::All(), 0);
   auto tp = results.Slice(linalg::All(), 1);
   auto auc = results.Slice(linalg::All(), 2);

   auto weights = common::OptionalWeights{info.weights_.ConstHostSpan()};
-  auto predts_t = linalg::TensorView<float const, 2>(
-      predts, {static_cast<size_t>(info.num_row_), n_classes},
-      Context::kCpuId);
+  auto predts_t = linalg::MakeTensorView(ctx, predts, info.num_row_, n_classes);

   if (info.labels.Size() != 0) {
     common::ParallelFor(n_classes, n_threads, [&](auto c) {

@@ -108,8 +105,8 @@ double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaI
         response[i] = labels(i) == c ? 1.0f : 0.0;
       }
       double fp;
-      std::tie(fp, tp(c), auc(c)) =
-          binary_auc(ctx, proba, linalg::MakeVec(response.data(), response.size(), -1), weights);
+      std::tie(fp, tp(c), auc(c)) = binary_auc(
+          ctx, proba, linalg::MakeVec(response.data(), response.size(), ctx->Device()), weights);
       local_area(c) = fp * tp(c);
     });
   }

@@ -220,7 +217,7 @@ std::pair<double, uint32_t> RankingAUC(Context const *ctx, std::vector<float> co
   CHECK_GE(info.group_ptr_.size(), 2);
   uint32_t n_groups = info.group_ptr_.size() - 1;
   auto s_predts = common::Span<float const>{predts};
-  auto labels = info.labels.View(Context::kCpuId);
+  auto labels = info.labels.View(ctx->Device());
   auto s_weights = info.weights_.ConstHostSpan();

   std::atomic<uint32_t> invalid_groups{0};

@@ -363,8 +360,8 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
                      info.labels.HostView().Slice(linalg::All(), 0),
                      common::OptionalWeights{info.weights_.ConstHostSpan()});
     } else {
-      std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info,
-                                              ctx_->gpu_id, &this->d_cache_);
+      std::tie(fp, tp, auc) =
+          GPUBinaryROCAUC(predts.ConstDeviceSpan(), info, ctx_->Device(), &this->d_cache_);
     }
     return std::make_tuple(fp, tp, auc);
   }

@@ -381,8 +378,7 @@ XGBOOST_REGISTER_METRIC(EvalAUC, "auc")

 #if !defined(XGBOOST_USE_CUDA)
 std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const>, MetaInfo const &,
-                                                   std::int32_t,
-                                                   std::shared_ptr<DeviceAUCCache> *) {
+                                                   DeviceOrd, std::shared_ptr<DeviceAUCCache> *) {
   common::AssertGPUSupport();
   return {};
 }

@@ -414,8 +410,8 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
           BinaryPRAUC(ctx_, predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
                       common::OptionalWeights{info.weights_.ConstHostSpan()});
     } else {
-      std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info,
-                                             ctx_->gpu_id, &this->d_cache_);
+      std::tie(pr, re, auc) =
+          GPUBinaryPRAUC(predts.ConstDeviceSpan(), info, ctx_->Device(), &this->d_cache_);
     }
     return std::make_tuple(pr, re, auc);
   }

@@ -459,7 +455,7 @@ XGBOOST_REGISTER_METRIC(AUCPR, "aucpr")

 #if !defined(XGBOOST_USE_CUDA)
 std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const>, MetaInfo const &,
-                                                  std::int32_t, std::shared_ptr<DeviceAUCCache> *) {
+                                                  DeviceOrd, std::shared_ptr<DeviceAUCCache> *) {
   common::AssertGPUSupport();
   return {};
 }
@@ -85,11 +85,11 @@ void InitCacheOnce(common::Span<float const> predts, std::shared_ptr<DeviceAUCCa
 template <typename Fn>
 std::tuple<double, double, double>
 GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
-             int32_t device, common::Span<size_t const> d_sorted_idx,
+             DeviceOrd device, common::Span<size_t const> d_sorted_idx,
              Fn area_fn, std::shared_ptr<DeviceAUCCache> cache) {
   auto labels = info.labels.View(device);
   auto weights = info.weights_.ConstDeviceSpan();
-  dh::safe_cuda(cudaSetDevice(device));
+  dh::safe_cuda(cudaSetDevice(device.ordinal));

   CHECK_NE(labels.Size(), 0);
   CHECK_EQ(labels.Size(), predts.size());

@@ -168,7 +168,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
 }

 std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> predts,
-                                                   MetaInfo const &info, std::int32_t device,
+                                                   MetaInfo const &info, DeviceOrd device,
                                                    std::shared_ptr<DeviceAUCCache> *p_cache) {
   auto &cache = *p_cache;
   InitCacheOnce<false>(predts, p_cache);

@@ -309,9 +309,10 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
  * up each class in all kernels.
  */
 template <bool scale, typename Fn>
-double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<uint32_t> d_class_ptr,
-                           size_t n_classes, std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
-  dh::safe_cuda(cudaSetDevice(device));
+double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
+                           common::Span<uint32_t> d_class_ptr, size_t n_classes,
+                           std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
+  dh::safe_cuda(cudaSetDevice(device.ordinal));
   /**
    * Sorted idx
    */

@@ -467,11 +468,12 @@ double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
   dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
   MultiClassSortedIdx(ctx, predts, dh::ToSpan(class_ptr), cache);

-  auto fn = [] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev,
-                              double tp, size_t /*class_id*/) {
+  auto fn = [] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev, double tp,
+                              size_t /*class_id*/) {
     return TrapezoidArea(fp_prev, fp, tp_prev, tp);
   };
-  return GPUMultiClassAUCOVR<true>(info, ctx->gpu_id, dh::ToSpan(class_ptr), n_classes, cache, fn);
+  return GPUMultiClassAUCOVR<true>(info, ctx->Device(), dh::ToSpan(class_ptr), n_classes, cache,
+                                   fn);
 }

 namespace {

@@ -512,7 +514,7 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
   /**
    * Sort the labels
    */
-  auto d_labels = info.labels.View(ctx->gpu_id);
+  auto d_labels = info.labels.View(ctx->Device());

   auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
   common::SegmentedArgSort<false, false>(ctx, d_labels.Values(), d_group_ptr, d_sorted_idx);

@@ -604,7 +606,7 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
 }

 std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> predts,
-                                                  MetaInfo const &info, std::int32_t device,
+                                                  MetaInfo const &info, DeviceOrd device,
                                                   std::shared_ptr<DeviceAUCCache> *p_cache) {
   auto& cache = *p_cache;
   InitCacheOnce<false>(predts, p_cache);

@@ -662,7 +664,7 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
   /**
    * Get total positive/negative
    */
-  auto labels = info.labels.View(ctx->gpu_id);
+  auto labels = info.labels.View(ctx->Device());
   auto n_samples = info.num_row_;
   dh::caching_device_vector<Pair> totals(n_classes);
   auto key_it =

@@ -695,13 +697,13 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
     return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
                                   d_totals[class_id].first);
   };
-  return GPUMultiClassAUCOVR<false>(info, ctx->gpu_id, d_class_ptr, n_classes, cache, fn);
+  return GPUMultiClassAUCOVR<false>(info, ctx->Device(), d_class_ptr, n_classes, cache, fn);
 }

 template <typename Fn>
 std::pair<double, uint32_t>
 GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
-                    common::Span<uint32_t> d_group_ptr, int32_t device,
+                    common::Span<uint32_t> d_group_ptr, DeviceOrd device,
                     std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
   /**
    * Sorted idx

@@ -843,7 +845,7 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
   common::SegmentedArgSort<false, false>(ctx, predts, d_group_ptr, d_sorted_idx);

   dh::XGBDeviceAllocator<char> alloc;
-  auto labels = info.labels.View(ctx->gpu_id);
+  auto labels = info.labels.View(ctx->Device());
   if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels.Values()),
                      dh::tend(labels.Values()), PRAUCLabelInvalid{})) {
     InvalidLabels();

@@ -882,7 +884,7 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
     return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
                                   d_totals[group_id].first);
   };
-  return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->gpu_id, cache, fn);
+  return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->Device(), cache, fn);
 }
 }  // namespace metric
 }  // namespace xgboost
|
} // namespace xgboost
|
||||||
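The changes in this file follow one mechanical pattern: functions that used to receive a raw std::int32_t device ordinal now receive a DeviceOrd, and the plain integer is only recovered at the CUDA runtime boundary. A minimal sketch of that calling convention, using only DeviceOrd members that appear in this commit (CPU(), CUDA(), ordinal, IsCUDA()); the helper name below is hypothetical and not part of the codebase:

    // Hypothetical helper, for illustration only.
    void SetCurrentDevice(xgboost::DeviceOrd device) {
      if (device.IsCUDA()) {
        // Only the CUDA runtime still needs the raw integer ordinal.
        dh::safe_cuda(cudaSetDevice(device.ordinal));
      }
    }

    // Callers pass a typed device instead of a bare int:
    //   SetCurrentDevice(xgboost::DeviceOrd::CUDA(0));
    //   SetCurrentDevice(xgboost::DeviceOrd::CPU());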
@ -30,7 +30,7 @@ XGBOOST_DEVICE inline double TrapezoidArea(double x0, double x1, double y0, doub
struct DeviceAUCCache;

std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> predts,
-MetaInfo const &info, std::int32_t device,
+MetaInfo const &info, DeviceOrd,
std::shared_ptr<DeviceAUCCache> *p_cache);

double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
@ -45,7 +45,7 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
* PR AUC *
**********/
std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> predts,
-MetaInfo const &info, std::int32_t device,
+MetaInfo const &info, DeviceOrd,
std::shared_ptr<DeviceAUCCache> *p_cache);

double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
@ -45,7 +45,7 @@ namespace {
template <typename Fn>
PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
PackedReduceResult result;
-auto labels = info.labels.View(ctx->gpu_id);
+auto labels = info.labels.View(ctx->Device());
if (ctx->IsCPU()) {
auto n_threads = ctx->Threads();
std::vector<double> score_tloc(n_threads, 0.0);
@ -183,10 +183,10 @@ class PseudoErrorLoss : public MetricNoCache {

double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
CHECK_EQ(info.labels.Shape(0), info.num_row_);
-auto labels = info.labels.View(ctx_->gpu_id);
-preds.SetDevice(ctx_->gpu_id);
+auto labels = info.labels.View(ctx_->Device());
+preds.SetDevice(ctx_->Device());
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
-info.weights_.SetDevice(ctx_->gpu_id);
+info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan());
float slope = this->param_.huber_slope;
@ -349,11 +349,11 @@ struct EvalEWiseBase : public MetricNoCache {
if (info.labels.Size() != 0) {
CHECK_NE(info.labels.Shape(1), 0);
}
-auto labels = info.labels.View(ctx_->gpu_id);
-info.weights_.SetDevice(ctx_->gpu_id);
+auto labels = info.labels.View(ctx_->Device());
+info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan());
-preds.SetDevice(ctx_->gpu_id);
+preds.SetDevice(ctx_->Device());
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();

auto d_policy = policy_;
@ -444,16 +444,16 @@ class QuantileError : public MetricNoCache {
}

auto const* ctx = ctx_;
-auto y_true = info.labels.View(ctx->gpu_id);
-preds.SetDevice(ctx->gpu_id);
-alpha_.SetDevice(ctx->gpu_id);
+auto y_true = info.labels.View(ctx->Device());
+preds.SetDevice(ctx->Device());
+alpha_.SetDevice(ctx->Device());
auto alpha = ctx->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan();
std::size_t n_targets = preds.Size() / info.num_row_ / alpha_.Size();
CHECK_NE(n_targets, 0);
auto y_predt = linalg::MakeTensorView(ctx, &preds, static_cast<std::size_t>(info.num_row_),
alpha_.Size(), n_targets);

-info.weights_.SetDevice(ctx->gpu_id);
+info.weights_.SetDevice(ctx->Device());
common::OptionalWeights weight{ctx->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan()};

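Most of the metric-side edits reduce to the same recipe: pin the HostDeviceVector to ctx->Device(), then pick the host or device span depending on where the context lives. A condensed sketch of that recipe using the same calls as the diff; `values` is a placeholder HostDeviceVector<float> and `ctx` a Context pointer:

    values.SetDevice(ctx->Device());
    auto span = ctx->IsCPU() ? values.ConstHostSpan()     // host copy
                             : values.ConstDeviceSpan();  // device copy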
@ -75,7 +75,7 @@ struct EvalAMS : public MetricNoCache {
const double br = 10.0;
unsigned thresindex = 0;
double s_tp = 0.0, b_fp = 0.0, tams = 0.0;
-const auto& labels = info.labels.View(Context::kCpuId);
+const auto& labels = info.labels.View(DeviceOrd::CPU());
for (unsigned i = 0; i < static_cast<unsigned>(ndata-1) && i < ntop; ++i) {
const unsigned ridx = rec[i].second;
const bst_float wt = info.GetWeight(ridx);
@ -134,7 +134,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
std::vector<double> sum_tloc(ctx_->Threads(), 0.0);

{
-const auto& labels = info.labels.View(Context::kCpuId);
+const auto& labels = info.labels.HostView();
const auto &h_preds = preds.ConstHostVector();

dmlc::OMPException exc;
@ -33,7 +33,7 @@ PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt,
std::shared_ptr<ltr::PreCache> p_cache) {
auto d_gptr = p_cache->DataGroupPtr(ctx);
-auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+auto d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);

predt.SetDevice(ctx->gpu_id);
auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());
@ -89,7 +89,7 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
if (!d_weight.Empty()) {
CHECK_EQ(d_weight.weights.size(), p_cache->Groups());
}
-auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+auto d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
predt.SetDevice(ctx->gpu_id);
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), predt.Size());

@ -119,9 +119,9 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::MAPCache> p_cache) {
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
-auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+auto d_label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);

-predt.SetDevice(ctx->gpu_id);
+predt.SetDevice(ctx->Device());
auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
@ -19,7 +19,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
dh::device_vector<size_t>* p_ridx, HostDeviceVector<size_t>* p_nptr,
HostDeviceVector<bst_node_t>* p_nidx, RegTree const& tree) {
// copy position to buffer
-dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
+dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
auto cuctx = ctx->CUDACtx();
size_t n_samples = position.size();
dh::device_vector<bst_node_t> sorted_position(position.size());
@ -86,11 +86,11 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
*/
auto& nidx = *p_nidx;
auto& nptr = *p_nptr;
-nidx.SetDevice(ctx->gpu_id);
+nidx.SetDevice(ctx->Device());
nidx.Resize(n_leaf);
auto d_node_idx = nidx.DeviceSpan();

-nptr.SetDevice(ctx->gpu_id);
+nptr.SetDevice(ctx->Device());
nptr.Resize(n_leaf + 1, 0);
auto d_node_ptr = nptr.DeviceSpan();

@ -142,7 +142,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
-dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
+dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
dh::device_vector<size_t> ridx;
HostDeviceVector<size_t> nptr;
HostDeviceVector<bst_node_t> nidx;
@ -155,13 +155,13 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
}

HostDeviceVector<float> quantiles;
-predt.SetDevice(ctx->gpu_id);
+predt.SetDevice(ctx->Device());

auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_,
predt.Size() / info.num_row_);
CHECK_LT(group_idx, d_predt.Shape(1));
auto t_predt = d_predt.Slice(linalg::All(), group_idx);
-auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), IdxY(info, group_idx));
+auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));

auto d_row_index = dh::ToSpan(ridx);
auto seg_beg = nptr.DevicePointer();
@ -178,7 +178,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
if (info.weights_.Empty()) {
common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
} else {
-info.weights_.SetDevice(ctx->gpu_id);
+info.weights_.SetDevice(ctx->Device());
auto d_weights = info.weights_.ConstDeviceSpan();
CHECK_EQ(d_weights.size(), d_row_index.size());
auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
@ -109,12 +109,12 @@ class LambdaRankObj : public FitIntercept {
lj_.SetDevice(ctx_->gpu_id);

if (ctx_->IsCPU()) {
-cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id),
-lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_,
+cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
+lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
} else {
-cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id),
-lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_,
+cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
+lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
}

@ -354,9 +354,9 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
const MetaInfo& info, linalg::Matrix<GradientPair>* out_gpair) {
if (ctx_->IsCUDA()) {
cuda_impl::LambdaRankGetGradientNDCG(
-ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
-tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
-out_gpair);
+ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()),
+tj_minus_.View(ctx_->Device()), li_full_.View(ctx_->Device()),
+lj_full_.View(ctx_->Device()), out_gpair);
return;
}

@ -477,9 +477,9 @@ class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective.";
if (ctx_->IsCUDA()) {
return cuda_impl::LambdaRankGetGradientMAP(
-ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
-tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
-out_gpair);
+ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()),
+tj_minus_.View(ctx_->Device()), li_full_.View(ctx_->Device()),
+lj_full_.View(ctx_->Device()), out_gpair);
}

auto gptr = p_cache_->DataGroupPtr(ctx_).data();
@ -567,9 +567,9 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the pairwise objective.";
if (ctx_->IsCUDA()) {
return cuda_impl::LambdaRankGetGradientPairwise(
-ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
-tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
-out_gpair);
+ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()),
+tj_minus_.View(ctx_->Device()), li_full_.View(ctx_->Device()),
+lj_full_.View(ctx_->Device()), out_gpair);
}

auto gptr = p_cache_->DataGroupPtr(ctx_);
@ -306,7 +306,7 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const

CHECK_NE(d_rounding.Size(), 0);

-auto label = info.labels.View(ctx->gpu_id);
+auto label = info.labels.View(ctx->Device());
auto predts = preds.ConstDeviceSpan();
auto gpairs = out_gpair->View(ctx->Device());
thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.Values().data(), gpairs.Size(),
@ -348,7 +348,7 @@ common::Span<std::size_t const> SortY(Context const* ctx, MetaInfo const& info,
common::Span<std::size_t const> d_rank,
std::shared_ptr<ltr::RankingCache> p_cache) {
auto const d_group_ptr = p_cache->DataGroupPtr(ctx);
-auto label = info.labels.View(ctx->gpu_id);
+auto label = info.labels.View(ctx->Device());
// The buffer for ranked y is necessary as cub segmented sort accepts only pointer.
auto d_y_ranked = p_cache->RankedY(ctx, info.num_row_);
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_y_ranked.size(),
@ -374,13 +374,13 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
linalg::VectorView<double> li, linalg::VectorView<double> lj,
linalg::Matrix<GradientPair>* out_gpair) {
// boilerplate
-std::int32_t device_id = ctx->gpu_id;
-dh::safe_cuda(cudaSetDevice(device_id));
+auto device = ctx->Device();
+dh::safe_cuda(cudaSetDevice(device.ordinal));
auto const d_inv_IDCG = p_cache->InvIDCG(ctx);
auto const discount = p_cache->Discount(ctx);

-info.labels.SetDevice(device_id);
-preds.SetDevice(device_id);
+info.labels.SetDevice(device);
+preds.SetDevice(device);

auto const exp_gain = p_cache->Param().ndcg_exp_gain;
auto delta_ndcg = [=] XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high,
@ -403,7 +403,7 @@ void MAPStat(Context const* ctx, MetaInfo const& info, common::Span<std::size_t
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) -> std::size_t { return dh::SegmentId(group_ptr, i); });
-auto label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+auto label = info.labels.View(ctx->Device()).Slice(linalg::All(), 0);
auto const* cuctx = ctx->CUDACtx();

{
@ -442,11 +442,11 @@ void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
linalg::Matrix<GradientPair>* out_gpair) {
-std::int32_t device_id = ctx->gpu_id;
-dh::safe_cuda(cudaSetDevice(device_id));
+auto device = ctx->Device();
+dh::safe_cuda(cudaSetDevice(device.ordinal));

-info.labels.SetDevice(device_id);
-predt.SetDevice(device_id);
+info.labels.SetDevice(device);
+predt.SetDevice(device);

CHECK(p_cache);

@ -481,11 +481,11 @@ void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
linalg::Matrix<GradientPair>* out_gpair) {
-std::int32_t device_id = ctx->gpu_id;
-dh::safe_cuda(cudaSetDevice(device_id));
+auto device = ctx->Device();
+dh::safe_cuda(cudaSetDevice(device.ordinal));

-info.labels.SetDevice(device_id);
-predt.SetDevice(device_id);
+info.labels.SetDevice(device);
+predt.SetDevice(device);

auto d_predt = predt.ConstDeviceSpan();
auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt);
@ -517,11 +517,11 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double
auto const d_group_ptr = p_cache->DataGroupPtr(ctx);
auto n_groups = d_group_ptr.size() - 1;

-auto ti_plus = p_ti_plus->View(ctx->gpu_id);
-auto tj_minus = p_tj_minus->View(ctx->gpu_id);
+auto ti_plus = p_ti_plus->View(ctx->Device());
+auto tj_minus = p_tj_minus->View(ctx->Device());

-auto li = p_li->View(ctx->gpu_id);
-auto lj = p_lj->View(ctx->gpu_id);
+auto li = p_li->View(ctx->Device());
+auto lj = p_lj->View(ctx->Device());
CHECK_EQ(li.Size(), ti_plus.Size());

auto const& param = p_cache->Param();
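The CUDA objective implementations share the same prologue, now written in terms of DeviceOrd: take the typed device from the context, make it current, and pin the inputs to it. A condensed sketch of that prologue under the same assumptions as the diff; `info` and `predt` stand in for the MetaInfo and prediction vector:

    auto device = ctx->Device();                   // typed device handle
    dh::safe_cuda(cudaSetDevice(device.ordinal));  // raw ordinal only at the CUDA boundary
    info.labels.SetDevice(device);                 // keep labels and predictions on that device
    predt.SetDevice(device);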
@ -62,7 +62,7 @@ class QuantileRegression : public ObjFunction {
CHECK_GE(n_targets, n_alphas);
CHECK_EQ(preds.Size(), info.num_row_ * n_targets);

-auto labels = info.labels.View(ctx_->gpu_id);
+auto labels = info.labels.View(ctx_->Device());

out_gpair->SetDevice(ctx_->Device());
CHECK_EQ(info.labels.Shape(1), 1)
@ -131,7 +131,7 @@ class QuantileRegression : public ObjFunction {
#if defined(XGBOOST_USE_CUDA)
alpha_.SetDevice(ctx_->gpu_id);
auto d_alpha = alpha_.ConstDeviceSpan();
-auto d_labels = info.labels.View(ctx_->gpu_id);
+auto d_labels = info.labels.View(ctx_->Device());
auto seg_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) { return i * d_labels.Shape(0); });
@ -69,7 +69,7 @@ class RegLossObj : public FitIntercept {

public:
void ValidateLabel(MetaInfo const& info) {
-auto label = info.labels.View(ctx_->Ordinal());
+auto label = info.labels.View(ctx_->Device());
auto valid = ctx_->DispatchDevice(
[&] {
return std::all_of(linalg::cbegin(label), linalg::cend(label),
@ -244,7 +244,7 @@ class PseudoHuberRegression : public FitIntercept {
CheckRegInputs(info, preds);
auto slope = param_.huber_slope;
CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
-auto labels = info.labels.View(ctx_->gpu_id);
+auto labels = info.labels.View(ctx_->Device());

out_gpair->SetDevice(ctx_->gpu_id);
out_gpair->Reshape(info.num_row_, this->Targets(info));
@ -698,7 +698,7 @@ class MeanAbsoluteError : public ObjFunction {
void GetGradient(HostDeviceVector<float> const& preds, const MetaInfo& info,
std::int32_t /*iter*/, linalg::Matrix<GradientPair>* out_gpair) override {
CheckRegInputs(info, preds);
-auto labels = info.labels.View(ctx_->gpu_id);
+auto labels = info.labels.View(ctx_->Device());

out_gpair->SetDevice(ctx_->Device());
out_gpair->Reshape(info.num_row_, this->Targets(info));
@ -663,7 +663,7 @@ class CPUPredictor : public Predictor {
std::size_t n_samples = p_fmat->Info().num_row_;
std::size_t n_groups = model.learner_model_param->OutputLength();
CHECK_EQ(out_preds->size(), n_samples * n_groups);
-linalg::TensorView<float, 2> out_predt{*out_preds, {n_samples, n_groups}, ctx_->gpu_id};
+auto out_predt = linalg::MakeTensorView(ctx_, *out_preds, n_samples, n_groups);

if (!p_fmat->PageExists<SparsePage>()) {
std::vector<Entry> workspace(p_fmat->Info().num_col_ * kUnroll * n_threads);
@ -732,7 +732,7 @@ class CPUPredictor : public Predictor {
std::vector<RegTree::FVec> thread_temp;
InitThreadTemp(n_threads * kBlockSize, &thread_temp);
std::size_t n_groups = model.learner_model_param->OutputLength();
-linalg::TensorView<float, 2> out_predt{predictions, {m->NumRows(), n_groups}, Context::kCpuId};
+auto out_predt = linalg::MakeTensorView(ctx_, predictions, m->NumRows(), n_groups);
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>(
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}, n_threads), model,
tree_begin, tree_end, &thread_temp, n_threads, out_predt);
@ -878,8 +878,8 @@ class CPUPredictor : public Predictor {
common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
});
-auto base_margin = info.base_margin_.View(Context::kCpuId);
-auto base_score = model.learner_model_param->BaseScore(Context::kCpuId)(0);
+auto base_margin = info.base_margin_.View(ctx_->Device());
+auto base_score = model.learner_model_param->BaseScore(ctx_->Device())(0);
// start collecting the contributions
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
auto page = batch.GetView();
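On the predictor side, the explicit TensorView constructor, which needed a device ordinal, gives way to linalg::MakeTensorView, which takes the device from the context. A before/after sketch; the names and sizes are the placeholders used in the surrounding code:

    // Before: the device is a raw ordinal baked into the view.
    // linalg::TensorView<float, 2> out_predt{*out_preds, {n_samples, n_groups}, ctx_->gpu_id};

    // After: the context supplies the device.
    auto out_predt = linalg::MakeTensorView(ctx_, *out_preds, n_samples, n_groups);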
@ -60,7 +60,7 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_fl
} else {
// cannot rely on the Resize to fill as it might skip if the size is already correct.
out_preds->Resize(n);
-auto base_score = model.learner_model_param->BaseScore(Context::kCpuId)(0);
+auto base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU())(0);
out_preds->Fill(base_score);
}
}
@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientP
gpair.SetDevice(ctx->Device());
auto gpair_t = gpair.View(ctx->Device());
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
-: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
+: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->Device()));
}
} // namespace tree
} // namespace xgboost
@ -1,5 +1,5 @@
/**
-* Copyright 2022 by XGBoost Contributors
+* Copyright 2022-2023 by XGBoost Contributors
*
* \brief Utilities for estimating initial score.
*/
@ -41,7 +41,7 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
auto sample = i % gpair.Shape(0);
return GradientPairPrecise{gpair(sample, target)};
});
-auto d_sum = sum.View(ctx->gpu_id);
+auto d_sum = sum.View(ctx->Device());
CHECK(d_sum.CContiguous());

dh::XGBCachingDeviceAllocator<char> alloc;
@ -774,7 +774,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
std::vector<Partitioner> const &partitioner,
linalg::VectorView<float> out_preds) {
auto const &tree = *p_last_tree;
-CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);
+CHECK(out_preds.Device().IsCPU());
size_t n_nodes = p_last_tree->GetNodes().size();
for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);
@ -809,7 +809,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
auto n_nodes = mttree->Size();
auto n_targets = tree.NumTargets();
CHECK_EQ(out_preds.Shape(1), n_targets);
-CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);
+CHECK(out_preds.Device().IsCPU());

for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);
@ -516,9 +516,10 @@ struct GPUHistMakerDevice {
}

CHECK(p_tree);
-dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
-CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id);
+CHECK(out_preds_d.Device().IsCUDA());
+CHECK_EQ(out_preds_d.Device().ordinal, ctx_->Ordinal());

+dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
auto d_position = dh::ToSpan(positions);
CHECK_EQ(out_preds_d.Size(), d_position.size());

@ -3,7 +3,7 @@
*/
#include <gtest/gtest.h>
#include <xgboost/context.h>
-#include <xgboost/host_device_vector.h>
+#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/linalg.h>

#include <cstddef> // size_t
@ -14,8 +14,8 @@

namespace xgboost::linalg {
namespace {
-auto kCpuId = Context::kCpuId;
-}
+DeviceOrd CPU() { return DeviceOrd::CPU(); }
+} // namespace

auto MakeMatrixFromTest(HostDeviceVector<float> *storage, std::size_t n_rows, std::size_t n_cols) {
storage->Resize(n_rows * n_cols);
@ -23,7 +23,7 @@ auto MakeMatrixFromTest(HostDeviceVector<float> *storage, std::size_t n_rows, st

std::iota(h_storage.begin(), h_storage.end(), 0);

-auto m = linalg::TensorView<float, 2>{h_storage, {n_rows, static_cast<size_t>(n_cols)}, -1};
+auto m = linalg::TensorView<float, 2>{h_storage, {n_rows, static_cast<size_t>(n_cols)}, CPU()};
return m;
}

@ -31,7 +31,7 @@ TEST(Linalg, MatrixView) {
size_t kRows = 31, kCols = 77;
HostDeviceVector<float> storage;
auto m = MakeMatrixFromTest(&storage, kRows, kCols);
-ASSERT_EQ(m.DeviceIdx(), kCpuId);
+ASSERT_EQ(m.Device(), CPU());
ASSERT_EQ(m(0, 0), 0);
ASSERT_EQ(m(kRows - 1, kCols - 1), storage.Size() - 1);
}
@ -76,7 +76,7 @@ TEST(Linalg, TensorView) {

{
// as vector
-TensorView<double, 1> vec{data, {data.size()}, -1};
+TensorView<double, 1> vec{data, {data.size()}, CPU()};
ASSERT_EQ(vec.Size(), data.size());
ASSERT_EQ(vec.Shape(0), data.size());
ASSERT_EQ(vec.Shape().size(), 1);
@ -87,7 +87,7 @@ TEST(Linalg, TensorView) {

{
// as matrix
-TensorView<double, 2> mat(data, {6, 4}, -1);
+TensorView<double, 2> mat(data, {6, 4}, CPU());
auto s = mat.Slice(2, All());
ASSERT_EQ(s.Shape().size(), 1);
s = mat.Slice(All(), 1);
@ -96,7 +96,7 @@ TEST(Linalg, TensorView) {

{
// assignment
-TensorView<double, 3> t{data, {2, 3, 4}, 0};
+TensorView<double, 3> t{data, {2, 3, 4}, CPU()};
double pi = 3.14159;
auto old = t(1, 2, 3);
t(1, 2, 3) = pi;
@ -201,7 +201,7 @@ TEST(Linalg, TensorView) {
}
{
// f-contiguous
-TensorView<double, 3> t{data, {4, 3, 2}, {1, 4, 12}, kCpuId};
+TensorView<double, 3> t{data, {4, 3, 2}, {1, 4, 12}, CPU()};
ASSERT_TRUE(t.Contiguous());
ASSERT_TRUE(t.FContiguous());
ASSERT_FALSE(t.CContiguous());
@ -210,11 +210,11 @@ TEST(Linalg, TensorView) {

TEST(Linalg, Tensor) {
{
-Tensor<float, 3> t{{2, 3, 4}, kCpuId, Order::kC};
-auto view = t.View(kCpuId);
+Tensor<float, 3> t{{2, 3, 4}, CPU(), Order::kC};
+auto view = t.View(CPU());

auto const &as_const = t;
-auto k_view = as_const.View(kCpuId);
+auto k_view = as_const.View(CPU());

size_t n = 2 * 3 * 4;
ASSERT_EQ(t.Size(), n);
@ -229,7 +229,7 @@ TEST(Linalg, Tensor) {
}
{
// Reshape
-Tensor<float, 3> t{{2, 3, 4}, kCpuId, Order::kC};
+Tensor<float, 3> t{{2, 3, 4}, CPU(), Order::kC};
t.Reshape(4, 3, 2);
ASSERT_EQ(t.Size(), 24);
ASSERT_EQ(t.Shape(2), 2);
@ -247,7 +247,7 @@ TEST(Linalg, Tensor) {

TEST(Linalg, Empty) {
{
-auto t = TensorView<double, 2>{{}, {0, 3}, kCpuId, Order::kC};
+auto t = TensorView<double, 2>{{}, {0, 3}, CPU(), Order::kC};
for (int32_t i : {0, 1, 2}) {
auto s = t.Slice(All(), i);
ASSERT_EQ(s.Size(), 0);
@ -256,9 +256,9 @@ TEST(Linalg, Empty) {
}
}
{
-auto t = Tensor<double, 2>{{0, 3}, kCpuId, Order::kC};
+auto t = Tensor<double, 2>{{0, 3}, CPU(), Order::kC};
ASSERT_EQ(t.Size(), 0);
-auto view = t.View(kCpuId);
+auto view = t.View(CPU());

for (int32_t i : {0, 1, 2}) {
auto s = view.Slice(All(), i);
@ -270,7 +270,7 @@ TEST(Linalg, Empty) {
}

TEST(Linalg, ArrayInterface) {
-auto cpu = kCpuId;
+auto cpu = CPU();
auto t = Tensor<double, 2>{{3, 3}, cpu, Order::kC};
auto v = t.View(cpu);
std::iota(v.Values().begin(), v.Values().end(), 0);
@ -315,16 +315,16 @@ TEST(Linalg, Popc) {
}

TEST(Linalg, Stack) {
-Tensor<float, 3> l{{2, 3, 4}, kCpuId, Order::kC};
-ElementWiseTransformHost(l.View(kCpuId), omp_get_max_threads(),
+Tensor<float, 3> l{{2, 3, 4}, CPU(), Order::kC};
+ElementWiseTransformHost(l.View(CPU()), omp_get_max_threads(),
[=](size_t i, float) { return i; });
-Tensor<float, 3> r_0{{2, 3, 4}, kCpuId, Order::kC};
-ElementWiseTransformHost(r_0.View(kCpuId), omp_get_max_threads(),
+Tensor<float, 3> r_0{{2, 3, 4}, CPU(), Order::kC};
+ElementWiseTransformHost(r_0.View(CPU()), omp_get_max_threads(),
[=](size_t i, float) { return i; });

Stack(&l, r_0);

-Tensor<float, 3> r_1{{0, 3, 4}, kCpuId, Order::kC};
+Tensor<float, 3> r_1{{0, 3, 4}, CPU(), Order::kC};
Stack(&l, r_1);
ASSERT_EQ(l.Shape(0), 4);

@ -335,7 +335,7 @@ TEST(Linalg, Stack) {
TEST(Linalg, FOrder) {
std::size_t constexpr kRows = 16, kCols = 3;
std::vector<float> data(kRows * kCols);
-MatrixView<float> mat{data, {kRows, kCols}, Context::kCpuId, Order::kF};
+MatrixView<float> mat{data, {kRows, kCols}, CPU(), Order::kF};
float k{0};
for (std::size_t i = 0; i < kRows; ++i) {
for (std::size_t j = 0; j < kCols; ++j) {
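In the tests, the sentinel ordinals (-1 for CPU, 0 for the first GPU) are replaced by explicit DeviceOrd constructors, which is what the small CPU() helper above wraps. A sketch of the two spellings, assuming the TensorView constructor shown in this file; the shape is arbitrary:

    // Old style: integer ordinal, -1 meaning CPU.
    // TensorView<double, 1> vec{data, {data.size()}, -1};

    // New style: a typed device ordinal.
    TensorView<double, 1> vec{data, {data.size()}, DeviceOrd::CPU()};
    // A GPU-resident view would pass DeviceOrd::CUDA(0) instead.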
@ -11,17 +11,18 @@
namespace xgboost::linalg {
namespace {
void TestElementWiseKernel() {
+auto device = DeviceOrd::CUDA(0);
Tensor<float, 3> l{{2, 3, 4}, 0};
{
/**
* Non-contiguous
*/
// GPU view
-auto t = l.View(0).Slice(linalg::All(), 1, linalg::All());
+auto t = l.View(device).Slice(linalg::All(), 1, linalg::All());
ASSERT_FALSE(t.CContiguous());
ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; });
// CPU view
-t = l.View(Context::kCpuId).Slice(linalg::All(), 1, linalg::All());
+t = l.View(DeviceOrd::CPU()).Slice(linalg::All(), 1, linalg::All());
size_t k = 0;
for (size_t i = 0; i < l.Shape(0); ++i) {
for (size_t j = 0; j < l.Shape(2); ++j) {
@ -29,7 +30,7 @@ void TestElementWiseKernel() {
}
}

-t = l.View(0).Slice(linalg::All(), 1, linalg::All());
+t = l.View(device).Slice(linalg::All(), 1, linalg::All());
ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); });
}

@ -37,11 +38,11 @@ void TestElementWiseKernel() {
/**
* Contiguous
*/
-auto t = l.View(0);
+auto t = l.View(device);
ElementWiseTransformDevice(t, [] XGBOOST_DEVICE(size_t i, float) { return i; });
ASSERT_TRUE(t.CContiguous());
// CPU view
-t = l.View(Context::kCpuId);
+t = l.View(DeviceOrd::CPU());

size_t ind = 0;
for (size_t i = 0; i < l.Shape(0); ++i) {
@ -41,7 +41,7 @@ void TestCalcQueriesInvIDCG() {
p.UpdateAllowUnknown(Args{{"ndcg_exp_gain", "false"}});

cuda_impl::CalcQueriesInvIDCG(&ctx, linalg::MakeTensorView(&ctx, d_scores, d_scores.size()),
-dh::ToSpan(group_ptr), inv_IDCG.View(ctx.gpu_id), p);
+dh::ToSpan(group_ptr), inv_IDCG.View(ctx.Device()), p);
for (std::size_t i = 0; i < n_groups; ++i) {
double inv_idcg = inv_IDCG(i);
ASSERT_NEAR(inv_idcg, 0.00551782, kRtEps);
@ -47,7 +47,7 @@ class StatsGPU : public ::testing::Test {
data.insert(data.cend(), seg.begin(), seg.end());
data.insert(data.cend(), seg.begin(), seg.end());
linalg::Tensor<float, 1> arr{data.cbegin(), data.cend(), {data.size()}, 0};
-auto d_arr = arr.View(0);
+auto d_arr = arr.View(DeviceOrd::CUDA(0));

auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
@ -71,8 +71,8 @@ class StatsGPU : public ::testing::Test {
}

void Weighted() {
-auto d_arr = arr_.View(0);
-auto d_key = indptr_.View(0);
+auto d_arr = arr_.View(DeviceOrd::CUDA(0));
+auto d_key = indptr_.View(DeviceOrd::CUDA(0));

auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
@ -81,7 +81,7 @@ class StatsGPU : public ::testing::Test {
dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) { return d_arr(i); });
linalg::Tensor<float, 1> weights{{10}, 0};
-linalg::ElementWiseTransformDevice(weights.View(0),
+linalg::ElementWiseTransformDevice(weights.View(DeviceOrd::CUDA(0)),
[=] XGBOOST_DEVICE(std::size_t, float) { return 1.0; });
auto w_it = weights.Data()->ConstDevicePointer();
for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) {
@ -102,7 +102,7 @@ class StatsGPU : public ::testing::Test {
data.insert(data.cend(), seg.begin(), seg.end());
data.insert(data.cend(), seg.begin(), seg.end());
linalg::Tensor<float, 1> arr{data.cbegin(), data.cend(), {data.size()}, 0};
-auto d_arr = arr.View(0);
+auto d_arr = arr.View(DeviceOrd::CUDA(0));

auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
@ -125,8 +125,8 @@ class StatsGPU : public ::testing::Test {
}

void NonWeighted() {
-auto d_arr = arr_.View(0);
-auto d_key = indptr_.View(0);
+auto d_arr = arr_.View(DeviceOrd::CUDA(0));
+auto d_key = indptr_.View(DeviceOrd::CUDA(0));

auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul), [=] __device__(std::size_t i) { return d_key(i); });
@ -22,7 +22,7 @@ TEST(ArrayInterface, Initialize) {

HostDeviceVector<size_t> u64_storage(storage.Size());
std::string u64_arr_str{ArrayInterfaceStr(linalg::TensorView<size_t const, 2>{
-u64_storage.ConstHostSpan(), {kRows, kCols}, Context::kCpuId})};
+u64_storage.ConstHostSpan(), {kRows, kCols}, DeviceOrd::CPU()})};
std::copy(storage.ConstHostVector().cbegin(), storage.ConstHostVector().cend(),
u64_storage.HostSpan().begin());
auto u64_arr = ArrayInterface<2>{u64_arr_str};
@ -129,8 +129,8 @@ TEST(MetaInfo, SaveLoadBinary) {
EXPECT_EQ(inforead.group_ptr_, info.group_ptr_);
EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector());

-auto orig_margin = info.base_margin_.View(xgboost::Context::kCpuId);
-auto read_margin = inforead.base_margin_.View(xgboost::Context::kCpuId);
+auto orig_margin = info.base_margin_.View(xgboost::DeviceOrd::CPU());
+auto read_margin = inforead.base_margin_.View(xgboost::DeviceOrd::CPU());
EXPECT_TRUE(std::equal(orig_margin.Values().cbegin(), orig_margin.Values().cend(),
read_margin.Values().cbegin()));

@ -267,8 +267,8 @@ TEST(MetaInfo, Validate) {
xgboost::HostDeviceVector<xgboost::bst_group_t> d_groups{groups};
d_groups.SetDevice(0);
d_groups.DevicePointer(); // pull to device
-std::string arr_interface_str{ArrayInterfaceStr(
-xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0))};
+std::string arr_interface_str{ArrayInterfaceStr(xgboost::linalg::MakeVec(
+d_groups.ConstDevicePointer(), d_groups.Size(), xgboost::DeviceOrd::CUDA(0)))};
EXPECT_THROW(info.SetInfo(ctx, "group", xgboost::StringView{arr_interface_str}), dmlc::Error);
#endif // defined(XGBOOST_USE_CUDA)
}
@ -307,5 +307,5 @@ TEST(MetaInfo, HostExtend) {
}

namespace xgboost {
-TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(Context::kCpuId); }
+TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(DeviceOrd::CPU()); }
} // namespace xgboost
@ -65,7 +65,7 @@ TEST(MetaInfo, FromInterface) {
}

info.SetInfo(ctx, "base_margin", str.c_str());
-auto const h_base_margin = info.base_margin_.View(Context::kCpuId);
+auto const h_base_margin = info.base_margin_.View(DeviceOrd::CPU());
ASSERT_EQ(h_base_margin.Size(), d_data.size());
for (size_t i = 0; i < d_data.size(); ++i) {
ASSERT_EQ(h_base_margin(i), d_data[i]);
@ -83,7 +83,7 @@ TEST(MetaInfo, FromInterface) {
}

TEST(MetaInfo, GPUStridedData) {
-TestMetaInfoStridedData(0);
+TestMetaInfoStridedData(DeviceOrd::CUDA(0));
}

TEST(MetaInfo, Group) {
|
|||||||
@ -14,10 +14,10 @@
|
|||||||
#include "../../../src/data/array_interface.h"
|
#include "../../../src/data/array_interface.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
inline void TestMetaInfoStridedData(int32_t device) {
|
inline void TestMetaInfoStridedData(DeviceOrd device) {
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
Context ctx;
|
Context ctx;
|
||||||
ctx.UpdateAllowUnknown(Args{{"gpu_id", std::to_string(device)}});
|
ctx.UpdateAllowUnknown(Args{{"device", device.Name()}});
|
||||||
{
|
{
|
||||||
// labels
|
// labels
|
||||||
linalg::Tensor<float, 3> labels;
|
linalg::Tensor<float, 3> labels;
|
||||||
@ -28,9 +28,9 @@ inline void TestMetaInfoStridedData(int32_t device) {
|
|||||||
ASSERT_EQ(t_labels.Shape().size(), 2);
|
ASSERT_EQ(t_labels.Shape().size(), 2);
|
||||||
|
|
||||||
info.SetInfo(ctx, "label", StringView{ArrayInterfaceStr(t_labels)});
|
info.SetInfo(ctx, "label", StringView{ArrayInterfaceStr(t_labels)});
|
||||||
auto const& h_result = info.labels.View(-1);
|
auto const& h_result = info.labels.View(DeviceOrd::CPU());
|
||||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||||
auto in_labels = labels.View(-1);
|
auto in_labels = labels.View(DeviceOrd::CPU());
|
||||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) {
|
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) {
|
||||||
auto tup = linalg::UnravelIndex(i, h_result.Shape());
|
auto tup = linalg::UnravelIndex(i, h_result.Shape());
|
||||||
auto i0 = std::get<0>(tup);
|
auto i0 = std::get<0>(tup);
|
||||||
@ -62,9 +62,9 @@ inline void TestMetaInfoStridedData(int32_t device) {
|
|||||||
ASSERT_EQ(t_margin.Shape().size(), 2);
|
ASSERT_EQ(t_margin.Shape().size(), 2);
|
||||||
|
|
||||||
info.SetInfo(ctx, "base_margin", StringView{ArrayInterfaceStr(t_margin)});
|
info.SetInfo(ctx, "base_margin", StringView{ArrayInterfaceStr(t_margin)});
|
||||||
auto const& h_result = info.base_margin_.View(-1);
|
auto const& h_result = info.base_margin_.View(DeviceOrd::CPU());
|
||||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||||
auto in_margin = base_margin.View(-1);
|
auto in_margin = base_margin.View(DeviceOrd::CPU());
|
||||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
|
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
|
||||||
auto tup = linalg::UnravelIndex(i, h_result.Shape());
|
auto tup = linalg::UnravelIndex(i, h_result.Shape());
|
||||||
auto i0 = std::get<0>(tup);
|
auto i0 = std::get<0>(tup);
|
||||||
|
|||||||
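A short usage sketch of the updated helper setup above (illustrative; it assumes device.Name() yields a string such as "cpu" or "cuda:0" that the "device" parameter accepts):

    // Configure the Context from a DeviceOrd rather than the removed integer "gpu_id".
    xgboost::Context ctx;
    xgboost::DeviceOrd device = xgboost::DeviceOrd::CUDA(0);
    ctx.UpdateAllowUnknown(xgboost::Args{{"device", device.Name()}});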
@@ -298,8 +298,8 @@ TEST(SimpleDMatrix, Slice) {
     ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx),
               out->Info().weights_.HostVector().at(i));

-    auto out_margin = out->Info().base_margin_.View(Context::kCpuId);
-    auto in_margin = margin.View(Context::kCpuId);
+    auto out_margin = out->Info().base_margin_.View(DeviceOrd::CPU());
+    auto in_margin = margin.View(DeviceOrd::CPU());
     for (size_t j = 0; j < kClasses; ++j) {
       ASSERT_EQ(out_margin(i, j), in_margin(ridx, j));
     }

@@ -372,8 +372,8 @@ TEST(SimpleDMatrix, SliceCol) {
               out->Info().labels_upper_bound_.HostVector().at(i));
     ASSERT_EQ(p_m->Info().weights_.HostVector().at(i), out->Info().weights_.HostVector().at(i));

-    auto out_margin = out->Info().base_margin_.View(Context::kCpuId);
-    auto in_margin = margin.View(Context::kCpuId);
+    auto out_margin = out->Info().base_margin_.View(DeviceOrd::CPU());
+    auto in_margin = margin.View(DeviceOrd::CPU());
     for (size_t j = 0; j < kClasses; ++j) {
       ASSERT_EQ(out_margin(i, j), in_margin(i, j));
     }

@@ -39,9 +39,9 @@ void TestGPUMakePair() {
   auto make_args = [&](std::shared_ptr<ltr::RankingCache> p_cache, auto rank_idx,
                        common::Span<std::size_t const> y_sorted_idx) {
     linalg::Vector<double> dummy;
-    auto d = dummy.View(ctx.gpu_id);
+    auto d = dummy.View(ctx.Device());
     linalg::Vector<GradientPair> dgpair;
-    auto dg = dgpair.View(ctx.gpu_id);
+    auto dg = dgpair.View(ctx.Device());
     cuda_impl::KernelInputs args{
         d,
         d,

@@ -50,9 +50,9 @@ void TestGPUMakePair() {
         p_cache->DataGroupPtr(&ctx),
         p_cache->CUDAThreadsGroupPtr(),
         rank_idx,
-        info.labels.View(ctx.gpu_id),
+        info.labels.View(ctx.Device()),
         predt.ConstDeviceSpan(),
-        linalg::MatrixView<GradientPair>{common::Span<GradientPair>{}, {0}, 0},
+        linalg::MatrixView<GradientPair>{common::Span<GradientPair>{}, {0}, DeviceOrd::CUDA(0)},
         dg,
         nullptr,
         y_sorted_idx,
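The ranking hunks above apply the same substitution to device-side views; a condensed sketch (assuming a Context ctx already configured for CUDA, as in TestGPUMakePair):

    // Views are requested through the Context's DeviceOrd instead of ctx.gpu_id.
    xgboost::linalg::Vector<double> dummy;
    auto d = dummy.View(ctx.Device());  // placed on whichever device ctx targets
    // An empty matrix view is likewise tagged with an explicit DeviceOrd.
    auto empty = xgboost::linalg::MatrixView<xgboost::GradientPair>{
        xgboost::common::Span<xgboost::GradientPair>{}, {0}, xgboost::DeviceOrd::CUDA(0)};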
@@ -226,7 +226,7 @@ TEST(GPUPredictor, ShapStump) {
   auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
   gpu_predictor->PredictContribution(dmat.get(), &predictions, model);
   auto& phis = predictions.HostVector();
-  auto base_score = mparam.BaseScore(Context::kCpuId)(0);
+  auto base_score = mparam.BaseScore(DeviceOrd::CPU())(0);
   EXPECT_EQ(phis[0], 0.0);
   EXPECT_EQ(phis[1], base_score);
   EXPECT_EQ(phis[2], 0.0);

@@ -287,7 +287,7 @@ void TestCategoricalPrediction(Context const* ctx, bool is_column_split) {

   predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
   predictor->PredictBatch(m.get(), &out_predictions, model, 0);
-  auto score = mparam.BaseScore(Context::kCpuId)(0);
+  auto score = mparam.BaseScore(DeviceOrd::CPU())(0);
   ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
   ASSERT_EQ(out_predictions.predictions.HostVector()[0],
             right_weight + score);  // go to right for matching cat
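Finally, the predictor tests read the base score the same way; a sketch (mparam stands in for the LearnerModelParam used in those tests):

    // BaseScore now takes a DeviceOrd and returns a 1-D view; index it for the scalar.
    auto base_score_view = mparam.BaseScore(xgboost::DeviceOrd::CPU());
    float base_score = base_score_view(0);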