Implement a general array view. (#7365)

* Replace the existing matrix and vector views.

This prepares for handling higher-dimensional data and predictions once we support multi-target models.
Jiaming Yuan 2021-11-05 04:16:11 +08:00 committed by GitHub
parent 232144ca09
commit b06040b6d0
11 changed files with 418 additions and 146 deletions

View File

@@ -1,113 +1,301 @@
/*!
- * Copyright 2021 by Contributors
+ * Copyright 2021 by XGBoost Contributors
* \file linalg.h
* \brief Linear algebra related utilities.
*/
#ifndef XGBOOST_LINALG_H_
#define XGBOOST_LINALG_H_
-#include <xgboost/span.h>
-#include <xgboost/host_device_vector.h>
+#include <xgboost/base.h>
+#include <xgboost/generic_parameters.h>
+#include <xgboost/span.h>
-#include <array>
#include <algorithm>
+#include <cassert>
+#include <type_traits>
#include <utility>
#include <vector>
namespace xgboost {
-/*!
- * \brief A view over a matrix on contiguous storage.
namespace linalg {
namespace detail {
template <typename S, typename Head, size_t D>
constexpr size_t Offset(S (&strides)[D], size_t n, size_t dim, Head head) {
assert(dim < D);
return n + head * strides[dim];
}
template <typename S, size_t D, typename Head, typename... Tail>
constexpr size_t Offset(S (&strides)[D], size_t n, size_t dim, Head head, Tail &&...rest) {
assert(dim < D);
return Offset(strides, n + (head * strides[dim]), dim + 1, rest...);
}
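// Worked example for the recursion above (shape chosen for illustration):
// a row-major 2x3x4 tensor has strides {12, 4, 1}, so
//   Offset(strides, 0, 0, 1, 2, 3)
// accumulates 0 + 1*12 = 12, then 12 + 2*4 = 20, then 20 + 3*1 = 23,
// which is the linear offset of element (1, 2, 3).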
struct AllTag {};
struct IntTag {};
/**
* \brief Calculate the dimension of the sliced tensor.
*/
template <typename T>
constexpr int32_t CalcSliceDim() {
return std::is_same<T, IntTag>::value ? 0 : 1;
}
template <typename T, typename... S>
constexpr std::enable_if_t<sizeof...(S) != 0, int32_t> CalcSliceDim() {
return CalcSliceDim<T>() + CalcSliceDim<S...>();
}
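// For example, CalcSliceDim<IntTag, AllTag, AllTag>() evaluates to 2: every
// integral index drops a dimension from the result, while every All() keeps one.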
template <int32_t D>
constexpr size_t CalcSize(size_t (&shape)[D]) {
size_t size = 1;
for (auto d : shape) {
size *= d;
}
return size;
}
template <typename S>
using RemoveCRType = std::remove_const_t<std::remove_reference_t<S>>;
template <typename S>
using IndexToTag = std::conditional_t<std::is_integral<RemoveCRType<S>>::value, IntTag, AllTag>;
template <int32_t n, typename Fn>
XGBOOST_DEVICE constexpr auto UnrollLoop(Fn fn) {
#if defined __CUDA_ARCH__
#pragma unroll n
#endif // defined __CUDA_ARCH__
for (int32_t i = 0; i < n; ++i) {
fn(i);
}
}
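// On the host this is an ordinary loop; when compiled for device,
// `#pragma unroll n` asks the compiler to fully unroll it using the
// compile-time trip count n.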
} // namespace detail
/**
* \brief Specify that all elements along the axis are used for the slice.
*/
constexpr detail::AllTag All() { return {}; }
/**
* \brief A tensor view with static type and shape. It implements indexing and slicing.
*
* Most algorithms in XGBoost are implemented for both CPU and GPU without relying on
* heavy linear algebra routines, so this class is a helper intended to ease high-level
* operations like indexing into the prediction tensor or the gradient matrix. It can
* be passed into a CUDA kernel as a normal argument in GPU algorithms.
*/
template <typename T, int32_t kDim = 5>
class TensorView {
public:
using ShapeT = size_t[kDim];
using StrideT = ShapeT;
private:
StrideT stride_{1};
ShapeT shape_{0};
common::Span<T> data_;
T* ptr_{nullptr};  // raw pointer to the data_ storage; avoids Span's bound checks.
size_t size_{0};
int32_t device_{-1};
// Unlike `Tensor`, the data_ can have arbitrary size since this is just a view.
XGBOOST_DEVICE void CalcSize() {
if (data_.empty()) {
size_ = 0;
} else {
size_ = detail::CalcSize(shape_);
}
}
struct SliceHelper {
size_t old_dim;
size_t new_dim;
size_t offset;
};
template <int32_t D, typename... S>
XGBOOST_DEVICE SliceHelper MakeSliceDim(size_t old_dim, size_t new_dim, size_t new_shape[D],
size_t new_stride[D], detail::AllTag) const {
new_stride[new_dim] = stride_[old_dim];
new_shape[new_dim] = shape_[old_dim];
return {old_dim + 1, new_dim + 1, 0};
}
template <int32_t D, typename... S>
XGBOOST_DEVICE SliceHelper MakeSliceDim(size_t old_dim, size_t new_dim, size_t new_shape[D],
size_t new_stride[D], detail::AllTag,
S &&...slices) const {
new_stride[new_dim] = stride_[old_dim];
new_shape[new_dim] = shape_[old_dim];
return MakeSliceDim<D>(old_dim + 1, new_dim + 1, new_shape, new_stride, slices...);
}
template <int32_t D, typename Index>
XGBOOST_DEVICE SliceHelper MakeSliceDim(size_t old_dim, size_t new_dim, size_t new_shape[D],
size_t new_stride[D], Index i) const {
return {old_dim + 1, new_dim, stride_[old_dim] * i};
}
template <int32_t D, typename Index, typename... S>
XGBOOST_DEVICE std::enable_if_t<std::is_integral<Index>::value, SliceHelper> MakeSliceDim(
size_t old_dim, size_t new_dim, size_t new_shape[D], size_t new_stride[D], Index i,
S &&...slices) const {
auto offset = stride_[old_dim] * i;
auto res = MakeSliceDim<D>(old_dim + 1, new_dim, new_shape, new_stride, slices...);
return {res.old_dim, res.new_dim, res.offset + offset};
}
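// How the overloads above compose (illustrative): for Slice(1, All()) on a
// 2-dim view, the integral overload consumes the `1`, contributing an offset of
// stride_[0] * 1 and no output dimension; the AllTag overload then copies
// shape_[1] and stride_[1] into the single dimension of the result.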
public:
size_t constexpr static kValueSize = sizeof(T);
size_t constexpr static kDimension = kDim;
public:
/**
* \brief Create a tensor with data and shape.
*
* \tparam I Type of the shape array element.
* \tparam D Size of the shape array, can be less than or equal to the tensor dimension.
*
* \param data Raw data input, can be const if this tensor has const type in its
* template parameter.
* \param shape shape of the tensor
* \param device Device ordinal
*/
template <typename I, int32_t D>
XGBOOST_DEVICE TensorView(common::Span<T> data, I const (&shape)[D], int32_t device)
: data_{data}, ptr_{data_.data()}, device_{device} {
static_assert(D > 0 && D <= kDim, "Invalid shape.");
// shape
detail::UnrollLoop<D>([&](auto i) { shape_[i] = shape[i]; });
for (auto i = D; i < kDim; ++i) {
shape_[i] = 1;
}
// stride
stride_[kDim - 1] = 1;
for (int32_t s = kDim - 2; s >= 0; --s) {
stride_[s] = shape_[s + 1] * stride_[s + 1];
}
this->CalcSize();
};
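// For example, a shape of {2, 3, 4} yields row-major strides {12, 4, 1}. A
// shape array shorter than kDim is padded with trailing 1s, so {6} on a
// 2-dim view becomes shape {6, 1} with strides {1, 1}.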
/**
* \brief Create a tensor with data, shape and strides. Don't use this constructor if
* stride can be calculated from shape.
*/
template <typename I, int32_t D>
XGBOOST_DEVICE TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
int32_t device)
: data_{data}, ptr_{data_.data()}, device_{device} {
static_assert(D == kDim, "Invalid shape & stride.");
detail::UnrollLoop<D>([&](auto i) {
shape_[i] = shape[i];
stride_[i] = stride[i];
});
this->CalcSize();
};
XGBOOST_DEVICE TensorView(TensorView const &that)
: data_{that.data_}, ptr_{data_.data()}, size_{that.size_}, device_{that.device_} {
detail::UnrollLoop<kDim>([&](auto i) {
stride_[i] = that.stride_[i];
shape_[i] = that.shape_[i];
});
}
/**
* \brief Index the tensor to obtain a scalar value.
*
* \code
*
* // Create a 3-dim tensor.
* TensorView<float, 3> t {data, shape, 0};
* float pi = 3.14159;
* t(1, 2, 3) = pi;
* ASSERT_EQ(t(1, 2, 3), pi);
*
* \endcode
*/
template <typename... Index>
XGBOOST_DEVICE T &operator()(Index &&...index) {
static_assert(sizeof...(index) <= kDim, "Invalid index.");
size_t offset = detail::Offset(stride_, 0ul, 0ul, index...);
return ptr_[offset];
}
/**
* \brief Index the tensor to obtain a scalar value.
*/
template <typename... Index>
XGBOOST_DEVICE T const &operator()(Index &&...index) const {
static_assert(sizeof...(index) <= kDim, "Invalid index.");
size_t offset = detail::Offset(stride_, 0ul, 0ul, index...);
return ptr_[offset];
}
/**
* \brief Slice the tensor. The returned tensor has inferred dim and shape.
*
* \code
*
* // Create a 3-dim tensor.
* TensorView<float, 3> t {data, shape, 0};
* // s has 2 dimensions (matrix)
* auto s = t.Slice(1, All(), All());
*
* \endcode
*/
template <typename... S>
XGBOOST_DEVICE auto Slice(S &&...slices) const {
static_assert(sizeof...(slices) <= kDim, "Invalid slice.");
int32_t constexpr kNewDim{detail::CalcSliceDim<detail::IndexToTag<S>...>()};
size_t new_shape[kNewDim];
size_t new_stride[kNewDim];
auto res = MakeSliceDim<kNewDim>(size_t(0), size_t(0), new_shape, new_stride, slices...);
// ret has a different type due to the changed dimension, so we cannot access its
// private fields.
TensorView<T, kNewDim> ret{data_.subspan(data_.empty() ? 0 : res.offset), new_shape, new_stride,
device_};
return ret;
}
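// Note that slicing preserves the parent's strides, so the result may be
// non-contiguous: on a 2x3x4 view, Slice(All(), 1, All()) yields a 2x4 view
// with strides {12, 1}, offset 4 elements into the original data.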
XGBOOST_DEVICE auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
/**
* Get the size of the i^th dimension.
*/
XGBOOST_DEVICE auto Shape(size_t i) const { return shape_[i]; }
XGBOOST_DEVICE auto Stride() const { return common::Span<size_t const, kDim>{stride_}; }
/**
* Get the stride of the i^th dimension; strides are counted in elements, not bytes.
*/
XGBOOST_DEVICE auto Stride(size_t i) const { return stride_[i]; }
XGBOOST_DEVICE auto cbegin() const { return data_.cbegin(); } // NOLINT
XGBOOST_DEVICE auto cend() const { return data_.cend(); } // NOLINT
XGBOOST_DEVICE auto begin() { return data_.begin(); } // NOLINT
XGBOOST_DEVICE auto end() { return data_.end(); } // NOLINT
XGBOOST_DEVICE size_t Size() const { return size_; }
XGBOOST_DEVICE auto Values() const { return data_; }
XGBOOST_DEVICE auto DeviceIdx() const { return device_; }
};
/**
* \brief A view over a vector, a 1-dim specialization of TensorView.
*
* \tparam T data type of vector
*/
template <typename T>
using VectorView = TensorView<T, 1>;
/**
* \brief A view over a matrix, a 2-dim specialization of TensorView.
*
* \tparam T data type of matrix
*/
-template <typename T> class MatrixView {
-int32_t device_;
-common::Span<T> values_;
-size_t strides_[2];
-size_t shape_[2];
-template <typename Vec> static auto InferValues(Vec *vec, int32_t device) {
-return device == GenericParameter::kCpuId ? vec->HostSpan()
-: vec->DeviceSpan();
-}
-public:
-/*!
- * \param vec storage.
- * \param strides Strides for matrix.
- * \param shape Rows and columns.
- * \param device Where the data is stored in.
- */
-MatrixView(HostDeviceVector<T> *vec, std::array<size_t, 2> strides,
-std::array<size_t, 2> shape, int32_t device)
-: device_{device}, values_{InferValues(vec, device)} {
-std::copy(strides.cbegin(), strides.cend(), strides_);
-std::copy(shape.cbegin(), shape.cend(), shape_);
-}
-MatrixView(HostDeviceVector<std::remove_const_t<T>> const *vec,
-std::array<size_t, 2> strides, std::array<size_t, 2> shape,
-int32_t device)
-: device_{device}, values_{InferValues(vec, device)} {
-std::copy(strides.cbegin(), strides.cend(), strides_);
-std::copy(shape.cbegin(), shape.cend(), shape_);
-}
-/*! \brief Row major constructor. */
-MatrixView(HostDeviceVector<T> *vec, std::array<size_t, 2> shape,
-int32_t device)
-: device_{device}, values_{InferValues(vec, device)} {
-std::copy(shape.cbegin(), shape.cend(), shape_);
-strides_[0] = shape[1];
-strides_[1] = 1;
-}
-MatrixView(std::vector<T> *vec, std::array<size_t, 2> shape)
-: device_{GenericParameter::kCpuId}, values_{*vec} {
-CHECK_EQ(vec->size(), shape[0] * shape[1]);
-std::copy(shape.cbegin(), shape.cend(), shape_);
-strides_[0] = shape[1];
-strides_[1] = 1;
-}
-MatrixView(HostDeviceVector<std::remove_const_t<T>> const *vec,
-std::array<size_t, 2> shape, int32_t device)
-: device_{device}, values_{InferValues(vec, device)} {
-std::copy(shape.cbegin(), shape.cend(), shape_);
-strides_[0] = shape[1];
-strides_[1] = 1;
-}
-XGBOOST_DEVICE T const &operator()(size_t r, size_t c) const {
-return values_[strides_[0] * r + strides_[1] * c];
-}
-XGBOOST_DEVICE T &operator()(size_t r, size_t c) {
-return values_[strides_[0] * r + strides_[1] * c];
-}
-auto Strides() const { return strides_; }
-auto Shape() const { return shape_; }
-auto Values() const { return values_; }
-auto Size() const { return shape_[0] * shape_[1]; }
-auto DeviceIdx() const { return device_; }
-};
-/*! \brief A slice for 1 column of MatrixView. Can be extended to row if needed. */
-template <typename T> class VectorView {
-MatrixView<T> matrix_;
-size_t column_;
-public:
-explicit VectorView(MatrixView<T> matrix, size_t column)
-: matrix_{std::move(matrix)}, column_{column} {}
-XGBOOST_DEVICE T &operator[](size_t i) {
-return matrix_(i, column_);
-}
-XGBOOST_DEVICE T const &operator[](size_t i) const {
-return matrix_(i, column_);
-}
-size_t Size() { return matrix_.Shape()[0]; }
-int32_t DeviceIdx() const { return matrix_.DeviceIdx(); }
-};
+template <typename T>
+using MatrixView = TensorView<T, 2>;
} // namespace linalg
} // namespace xgboost
#endif // XGBOOST_LINALG_H_
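A minimal usage sketch of the new view (illustrative only: the buffer, shapes, and the Example function below are hypothetical, while the calls mirror the header above):

#include <numeric>
#include <vector>
#include <xgboost/linalg.h>

void Example() {
  std::vector<float> buf(2 * 3 * 4);
  std::iota(buf.begin(), buf.end(), 0.0f);
  // A 3-dim view over the buffer on the CPU (device ordinal -1).
  xgboost::linalg::TensorView<float, 3> t{buf, {2, 3, 4}, -1};
  t(1, 2, 3) = 3.14f;  // scalar indexing through the computed strides
  // Slicing covers the removed classes: a matrix is a 2-dim slice, and a
  // vector is a 1-dim slice of that.
  auto mat = t.Slice(1, xgboost::linalg::All(), xgboost::linalg::All());  // 3x4
  auto col = mat.Slice(xgboost::linalg::All(), 0);  // 3 elements
  static_assert(decltype(col)::kDimension == 1, "");
}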

View File

@@ -73,7 +73,7 @@ class TreeUpdater : public Configurable {
* updated by the time this function returns.
*/
virtual bool UpdatePredictionCache(const DMatrix * /*data*/,
-VectorView<float> /*out_preds*/) {
+linalg::VectorView<float> /*out_preds*/) {
return false;
}

View File

@@ -243,7 +243,10 @@ class GBLinear : public GradientBooster {
// The bias is the last weight
out_scores->resize(model_.weight.size() - learner_model_param_->num_output_group, 0);
auto n_groups = learner_model_param_->num_output_group;
-MatrixView<float> scores{out_scores, {learner_model_param_->num_feature, n_groups}};
+linalg::TensorView<float, 2> scores{
+*out_scores,
+{learner_model_param_->num_feature, n_groups},
+GenericParameter::kCpuId};
for (size_t i = 0; i < learner_model_param_->num_feature; ++i) {
for (bst_group_t g = 0; g < n_groups; ++g) {
scores(i, g) = model_[i][g];

View File

@@ -229,16 +229,19 @@ void GBTree::DoBoost(DMatrix* p_fmat,
auto device = tparam_.tree_method != TreeMethod::kGPUHist
? GenericParameter::kCpuId
: generic_param_->gpu_id;
-auto out = MatrixView<float>(
-&predt->predictions,
-{static_cast<size_t>(p_fmat->Info().num_row_), static_cast<size_t>(ngroup)}, device);
+auto out = linalg::TensorView<float, 2>{
+device == GenericParameter::kCpuId ? predt->predictions.HostSpan()
+: predt->predictions.DeviceSpan(),
+{static_cast<size_t>(p_fmat->Info().num_row_),
+static_cast<size_t>(ngroup)},
+device};
CHECK_NE(ngroup, 0);
if (ngroup == 1) {
std::vector<std::unique_ptr<RegTree>> ret;
BoostNewTrees(in_gpair, p_fmat, 0, &ret);
const size_t num_new_trees = ret.size();
new_trees.push_back(std::move(ret));
-auto v_predt = VectorView<float>{out, 0};
+auto v_predt = out.Slice(linalg::All(), 0);
if (updaters_.size() > 0 && num_new_trees == 1 &&
predt->predictions.Size() > 0 &&
updaters_.back()->UpdatePredictionCache(p_fmat, v_predt)) {
@@ -257,7 +260,7 @@ void GBTree::DoBoost(DMatrix* p_fmat,
BoostNewTrees(&tmp, p_fmat, gid, &ret);
const size_t num_new_trees = ret.size();
new_trees.push_back(std::move(ret));
-auto v_predt = VectorView<float>{out, static_cast<size_t>(gid)};
+auto v_predt = out.Slice(linalg::All(), gid);
if (!(updaters_.size() > 0 && predt->predictions.Size() > 0 &&
num_new_trees == 1 &&
updaters_.back()->UpdatePredictionCache(p_fmat, v_predt))) {

View File

@@ -12,15 +12,14 @@ namespace gbm {
void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
bst_group_t n_groups, bst_group_t group_id,
HostDeviceVector<GradientPair> *out_gpair) {
-MatrixView<GradientPair const> in{
-in_gpair,
-{n_groups, 1ul},
+auto mat = linalg::TensorView<GradientPair const, 2>(
+in_gpair->ConstDeviceSpan(),
{in_gpair->Size() / n_groups, static_cast<size_t>(n_groups)},
-in_gpair->DeviceIdx()};
-auto v_in = VectorView<GradientPair const>{in, group_id};
+in_gpair->DeviceIdx());
+auto v_in = mat.Slice(linalg::All(), group_id);
out_gpair->Resize(v_in.Size());
auto d_out = out_gpair->DeviceSpan();
-dh::LaunchN(v_in.Size(), [=] __device__(size_t i) { d_out[i] = v_in[i]; });
+dh::LaunchN(v_in.Size(), [=] __device__(size_t i) { d_out[i] = v_in(i); });
}
void GPUDartPredictInc(common::Span<float> out_predts,

View File

@@ -13,6 +13,7 @@
#include <vector>
#include "rabit/rabit.h"
#include "xgboost/linalg.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/metric.h"
@@ -83,41 +84,45 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
CHECK_NE(n_classes, 0);
auto const &labels = info.labels_.ConstHostVector();
-std::vector<double> results(n_classes * 3, 0);
-auto s_results = common::Span<double>(results);
-auto local_area = s_results.subspan(0, n_classes);
-auto tp = s_results.subspan(n_classes, n_classes);
-auto auc = s_results.subspan(2 * n_classes, n_classes);
+std::vector<double> results_storage(n_classes * 3, 0);
+linalg::TensorView<double> results(results_storage,
+{n_classes, static_cast<size_t>(3)},
+GenericParameter::kCpuId);
+auto local_area = results.Slice(linalg::All(), 0);
+auto tp = results.Slice(linalg::All(), 1);
+auto auc = results.Slice(linalg::All(), 2);
auto weights = OptionalWeights{info.weights_.ConstHostSpan()};
+auto predts_t = linalg::TensorView<float const, 2>(
+predts, {static_cast<size_t>(info.num_row_), n_classes},
+GenericParameter::kCpuId);
if (!info.labels_.Empty()) {
common::ParallelFor(n_classes, n_threads, [&](auto c) {
std::vector<float> proba(info.labels_.Size());
std::vector<float> response(info.labels_.Size());
for (size_t i = 0; i < proba.size(); ++i) {
-proba[i] = predts[i * n_classes + c];
+proba[i] = predts_t(i, c);
response[i] = labels[i] == c ? 1.0f : 0.0;
}
double fp;
-std::tie(fp, tp[c], auc[c]) = binary_auc(proba, response, weights);
-local_area[c] = fp * tp[c];
+std::tie(fp, tp(c), auc(c)) = binary_auc(proba, response, weights);
+local_area(c) = fp * tp(c);
});
}
// There are two averages here: the first is among workers, the second among
// classes. The allreduce sums up fp/tp/auc for each class.
-rabit::Allreduce<rabit::op::Sum>(results.data(), results.size());
+rabit::Allreduce<rabit::op::Sum>(results.Values().data(), results.Values().size());
double auc_sum{0};
double tp_sum{0};
for (size_t c = 0; c < n_classes; ++c) {
-if (local_area[c] != 0) {
+if (local_area(c) != 0) {
// normalize and weight it by prevalence. After allreduce, `local_area`
// means the total covered area (not area under curve, rather it's the
// accessible area for each worker) for each class.
-auc_sum += auc[c] / local_area[c] * tp[c];
-tp_sum += tp[c];
+auc_sum += auc(c) / local_area(c) * tp(c);
+tp_sum += tp(c);
} else {
auc_sum = std::numeric_limits<double>::quiet_NaN();
break;

View File

@@ -496,7 +496,7 @@ struct GPUHistMakerDevice {
});
}
-void UpdatePredictionCache(VectorView<float> out_preds_d) {
+void UpdatePredictionCache(linalg::VectorView<float> out_preds_d) {
dh::safe_cuda(cudaSetDevice(device_id));
CHECK_EQ(out_preds_d.DeviceIdx(), device_id);
auto d_ridx = row_partitioner->GetRows();
@@ -512,13 +512,13 @@
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
-dh::LaunchN(d_ridx.size(), [=] __device__(int local_idx) {
+dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(
+int local_idx) mutable {
int pos = d_position[local_idx];
bst_float weight = evaluator.CalcWeight(
pos, param_d, GradStats{d_node_sum_gradients[pos]});
-static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
-auto v_predt = out_preds_d; // for some reason out_preds_d is const by both nvcc and clang.
-v_predt[d_ridx[local_idx]] += weight * param_d.learning_rate;
+out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate;
});
row_partitioner.reset();
}
@@ -834,7 +834,8 @@ class GPUHistMakerSpecialised {
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
}
-bool UpdatePredictionCache(const DMatrix* data, VectorView<bst_float> p_out_preds) {
+bool UpdatePredictionCache(const DMatrix *data,
+linalg::VectorView<bst_float> p_out_preds) {
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false;
}
@@ -920,8 +921,9 @@ class GPUHistMaker : public TreeUpdater {
}
}
-bool UpdatePredictionCache(const DMatrix *data,
-VectorView<bst_float> p_out_preds) override {
+bool
+UpdatePredictionCache(const DMatrix *data,
+linalg::VectorView<bst_float> p_out_preds) override {
if (hist_maker_param_.single_precision_histogram) {
return float_maker_->UpdatePredictionCache(data, p_out_preds);
} else {

View File

@@ -105,7 +105,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
}
bool QuantileHistMaker::UpdatePredictionCache(
-const DMatrix* data, VectorView<float> out_preds) {
+const DMatrix* data, linalg::VectorView<float> out_preds) {
if (hist_maker_param_.single_precision_histogram && float_builder_) {
return float_builder_->UpdatePredictionCache(data, out_preds);
} else if (double_builder_) {
@@ -319,7 +319,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
template<typename GradientSumT>
bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
const DMatrix* data,
-VectorView<float> out_preds) {
+linalg::VectorView<float> out_preds) {
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update().
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
@@ -352,7 +352,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
leaf_value = (*p_last_tree_)[nid].LeafValue();
for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
-out_preds[*it] += leaf_value;
+out_preds(*it) += leaf_value;
}
}
});

View File

@@ -105,7 +105,7 @@ class QuantileHistMaker: public TreeUpdater {
const std::vector<RegTree*>& trees) override;
bool UpdatePredictionCache(const DMatrix *data,
-VectorView<float> out_preds) override;
+linalg::VectorView<float> out_preds) override;
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
@@ -174,7 +174,7 @@ class QuantileHistMaker: public TreeUpdater {
RegTree* p_tree);
bool UpdatePredictionCache(const DMatrix* data,
-VectorView<float> out_preds);
+linalg::VectorView<float> out_preds);
protected:
// initialize temp data structure

View File

@@ -1,18 +1,19 @@
#include <gtest/gtest.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/linalg.h>
#include <numeric>
namespace xgboost {
namespace linalg {
auto MakeMatrixFromTest(HostDeviceVector<float> *storage, size_t n_rows, size_t n_cols) {
storage->Resize(n_rows * n_cols);
-auto& h_storage = storage->HostVector();
+auto &h_storage = storage->HostVector();
std::iota(h_storage.begin(), h_storage.end(), 0);
-auto m = MatrixView<float>{storage, {n_cols, 1}, {n_rows, n_cols}, -1};
+auto m = linalg::TensorView<float, 2>{h_storage, {n_rows, static_cast<size_t>(n_cols)}, -1};
return m;
}
TEST(Linalg, Matrix) {
@@ -28,11 +29,84 @@ TEST(Linalg, Vector) {
size_t kRows = 31, kCols = 77;
HostDeviceVector<float> storage;
auto m = MakeMatrixFromTest(&storage, kRows, kCols);
-auto v = VectorView<float>(m, 3);
+auto v = m.Slice(linalg::All(), 3);
for (size_t i = 0; i < v.Size(); ++i) {
-ASSERT_EQ(v[i], m(i, 3));
+ASSERT_EQ(v(i), m(i, 3));
}
-ASSERT_EQ(v[0], 3);
+ASSERT_EQ(v(0), 3);
}
TEST(Linalg, Tensor) {
std::vector<double> data(2 * 3 * 4, 0);
std::iota(data.begin(), data.end(), 0);
TensorView<double> t{data, {2, 3, 4}, -1};
ASSERT_EQ(t.Shape()[0], 2);
ASSERT_EQ(t.Shape()[1], 3);
ASSERT_EQ(t.Shape()[2], 4);
float v = t(0, 1, 2);
ASSERT_EQ(v, 6);
auto s = t.Slice(1, All(), All());
ASSERT_EQ(s.Shape().size(), 2);
ASSERT_EQ(s.Shape()[0], 3);
ASSERT_EQ(s.Shape()[1], 4);
std::vector<std::vector<double>> sol{
{12.0, 13.0, 14.0, 15.0}, {16.0, 17.0, 18.0, 19.0}, {20.0, 21.0, 22.0, 23.0}};
for (size_t i = 0; i < s.Shape()[0]; ++i) {
for (size_t j = 0; j < s.Shape()[1]; ++j) {
ASSERT_EQ(s(i, j), sol[i][j]);
}
}
{
// as vector
TensorView<double, 1> vec{data, {data.size()}, -1};
ASSERT_EQ(vec.Size(), data.size());
ASSERT_EQ(vec.Shape(0), data.size());
ASSERT_EQ(vec.Shape().size(), 1);
for (size_t i = 0; i < data.size(); ++i) {
ASSERT_EQ(vec(i), data[i]);
}
}
{
// as matrix
TensorView<double, 2> mat(data, {6, 4}, -1);
auto s = mat.Slice(2, All());
ASSERT_EQ(s.Shape().size(), 1);
s = mat.Slice(All(), 1);
ASSERT_EQ(s.Shape().size(), 1);
}
{
// assignment
TensorView<double, 3> t{data, {2, 3, 4}, 0};
double pi = 3.14159;
t(1, 2, 3) = pi;
ASSERT_EQ(t(1, 2, 3), pi);
}
{
// Don't specify the initial dimension; the tensor should be able to deduce the
// correct dim for Slice.
TensorView<double> t{data, {2, 3, 4}, 0};
auto s = t.Slice(1, 2, All());
static_assert(decltype(s)::kDimension == 1, "");
}
}
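// A possible follow-on check (hypothetical; reuses the 2x3x4 buffer from the
// test above) showing how the strides surface through the public API:
//
//   TensorView<double> t{data, {2, 3, 4}, -1};
//   ASSERT_EQ(t.Stride(0), 12);  // row-major: 3 * 4
//   ASSERT_EQ(t.Stride(1), 4);
//   ASSERT_EQ(t.Stride(2), 1);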
TEST(Linalg, Empty) {
auto t = TensorView<double, 2>{{}, {0, 3}, GenericParameter::kCpuId};
for (int32_t i : {0, 1, 2}) {
auto s = t.Slice(All(), i);
ASSERT_EQ(s.Size(), 0);
ASSERT_EQ(s.Shape().size(), 1);
ASSERT_EQ(s.Shape(0), 0);
}
}
} // namespace linalg
} // namespace xgboost

View File

@@ -400,10 +400,8 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
hist_maker.Configure(args, &generic_param);
hist_maker.Update(gpair, dmat, {tree});
-hist_maker.UpdatePredictionCache(
-dmat,
-VectorView<float>{
-MatrixView<float>(preds, {preds->Size(), 1}, preds->DeviceIdx()), 0});
+auto cache = linalg::VectorView<float>{preds->DeviceSpan(), {preds->Size()}, 0};
+hist_maker.UpdatePredictionCache(dmat, cache);
}
TEST(GpuHist, UniformSampling) {