Implement typed storage for tensor. (#7429)
* Add `Tensor` class. * Add elementwise kernel for CPU and GPU. * Add unravel index. * Move some computation to compile time.
This commit is contained in:
parent
d27a11ff87
commit
a7057fa64c
@ -6,12 +6,16 @@
|
||||
#ifndef XGBOOST_LINALG_H_
|
||||
#define XGBOOST_LINALG_H_
|
||||
|
||||
#include <dmlc/endian.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/span.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@ -19,16 +23,35 @@
|
||||
namespace xgboost {
|
||||
namespace linalg {
|
||||
namespace detail {
|
||||
template <typename S, typename Head, size_t D>
|
||||
constexpr size_t Offset(S (&strides)[D], size_t n, size_t dim, Head head) {
|
||||
assert(dim < D);
|
||||
|
||||
struct ArrayInterfaceHandler {
|
||||
template <typename T>
|
||||
static constexpr char TypeChar() {
|
||||
return (std::is_floating_point<T>::value
|
||||
? 'f'
|
||||
: (std::is_integral<T>::value ? (std::is_signed<T>::value ? 'i' : 'u') : '\0'));
|
||||
}
|
||||
};
|
||||
|
||||
template <size_t dim, typename S, typename Head, size_t D>
|
||||
constexpr size_t Offset(S (&strides)[D], size_t n, Head head) {
|
||||
static_assert(dim < D, "");
|
||||
return n + head * strides[dim];
|
||||
}
|
||||
|
||||
template <typename S, size_t D, typename Head, typename... Tail>
|
||||
constexpr size_t Offset(S (&strides)[D], size_t n, size_t dim, Head head, Tail &&...rest) {
|
||||
assert(dim < D);
|
||||
return Offset(strides, n + (head * strides[dim]), dim + 1, rest...);
|
||||
template <size_t dim, typename S, size_t D, typename Head, typename... Tail>
|
||||
constexpr std::enable_if_t<sizeof...(Tail) != 0, size_t> Offset(S (&strides)[D], size_t n,
|
||||
Head head, Tail &&...rest) {
|
||||
static_assert(dim < D, "");
|
||||
return Offset<dim + 1>(strides, n + (head * strides[dim]), std::forward<Tail>(rest)...);
|
||||
}
|
||||
|
||||
template <int32_t D>
|
||||
constexpr void CalcStride(size_t (&shape)[D], size_t (&stride)[D]) {
|
||||
stride[D - 1] = 1;
|
||||
for (int32_t s = D - 2; s >= 0; --s) {
|
||||
stride[s] = shape[s + 1] * stride[s + 1];
|
||||
}
|
||||
}
|
||||
|
||||
struct AllTag {};
|
||||
@ -71,6 +94,101 @@ XGBOOST_DEVICE constexpr auto UnrollLoop(Fn fn) {
|
||||
fn(i);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int32_t NativePopc(T v) {
|
||||
int c = 0;
|
||||
for (; v != 0; v &= v - 1) c++;
|
||||
return c;
|
||||
}
|
||||
|
||||
inline XGBOOST_DEVICE int Popc(uint32_t v) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return __popc(v);
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
return __builtin_popcount(v);
|
||||
#elif defined(_MSC_VER)
|
||||
return __popcnt(v);
|
||||
#else
|
||||
return NativePopc(v);
|
||||
#endif // compiler
|
||||
}
|
||||
|
||||
inline XGBOOST_DEVICE int Popc(uint64_t v) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return __popcll(v);
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
return __builtin_popcountll(v);
|
||||
#elif defined(_MSC_VER)
|
||||
return __popcnt64(v);
|
||||
#else
|
||||
return NativePopc(v);
|
||||
#endif // compiler
|
||||
}
|
||||
|
||||
template <class T, std::size_t N, std::size_t... Idx>
|
||||
constexpr auto Arr2Tup(T (&arr)[N], std::index_sequence<Idx...>) {
|
||||
return std::make_tuple(arr[Idx]...);
|
||||
}
|
||||
|
||||
template <class T, std::size_t N>
|
||||
constexpr auto Arr2Tup(T (&arr)[N]) {
|
||||
return Arr2Tup(arr, std::make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
// uint division optimization inspired by the CIndexer in cupy. Division operation is
|
||||
// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64
|
||||
// bit when the index is smaller, then try to avoid division when it's exp of 2.
|
||||
template <typename I, int32_t D>
|
||||
XGBOOST_DEVICE auto UnravelImpl(I idx, common::Span<size_t const, D> shape) {
|
||||
size_t index[D]{0};
|
||||
static_assert(std::is_signed<decltype(D)>::value,
|
||||
"Don't change the type without changing the for loop.");
|
||||
for (int32_t dim = D; --dim > 0;) {
|
||||
auto s = static_cast<std::remove_const_t<std::remove_reference_t<I>>>(shape[dim]);
|
||||
if (s & (s - 1)) {
|
||||
auto t = idx / s;
|
||||
index[dim] = idx - t * s;
|
||||
idx = t;
|
||||
} else { // exp of 2
|
||||
index[dim] = idx & (s - 1);
|
||||
idx >>= Popc(s - 1);
|
||||
}
|
||||
}
|
||||
index[0] = idx;
|
||||
return Arr2Tup(index);
|
||||
}
|
||||
|
||||
template <size_t dim, typename I, int32_t D>
|
||||
void ReshapeImpl(size_t (&out_shape)[D], I s) {
|
||||
static_assert(dim < D, "");
|
||||
out_shape[dim] = s;
|
||||
}
|
||||
|
||||
template <size_t dim, int32_t D, typename... S, typename I,
|
||||
std::enable_if_t<sizeof...(S) != 0> * = nullptr>
|
||||
void ReshapeImpl(size_t (&out_shape)[D], I &&s, S &&...rest) {
|
||||
static_assert(dim < D, "");
|
||||
out_shape[dim] = s;
|
||||
ReshapeImpl<dim + 1>(out_shape, std::forward<S>(rest)...);
|
||||
}
|
||||
|
||||
template <typename Fn, typename Tup, size_t... I>
|
||||
XGBOOST_DEVICE decltype(auto) constexpr Apply(Fn &&f, Tup &&t, std::index_sequence<I...>) {
|
||||
return f(std::get<I>(t)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* C++ 17 style apply.
|
||||
*
|
||||
* \param f function to apply
|
||||
* \param t tuple of arguments
|
||||
*/
|
||||
template <typename Fn, typename Tup>
|
||||
XGBOOST_DEVICE decltype(auto) constexpr Apply(Fn &&f, Tup &&t) {
|
||||
constexpr auto kSize = std::tuple_size<Tup>::value;
|
||||
return Apply(std::forward<Fn>(f), std::forward<Tup>(t), std::make_index_sequence<kSize>{});
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
@ -85,6 +203,11 @@ constexpr detail::AllTag All() { return {}; }
|
||||
* much linear algebra routines, this class is a helper intended to ease some high level
|
||||
* operations like indexing into prediction tensor or gradient matrix. It can be passed
|
||||
* into CUDA kernel as normal argument for GPU algorithms.
|
||||
*
|
||||
* Ideally we should add a template parameter `bool on_host` so that the compiler can
|
||||
* prevent passing/accessing the wrong view, but inheritance is heavily used in XGBoost so
|
||||
* some functions expect data types that can be used in everywhere (update prediction
|
||||
* cache for example).
|
||||
*/
|
||||
template <typename T, int32_t kDim = 5>
|
||||
class TensorView {
|
||||
@ -96,7 +219,7 @@ class TensorView {
|
||||
StrideT stride_{1};
|
||||
ShapeT shape_{0};
|
||||
common::Span<T> data_;
|
||||
T* ptr_{nullptr}; // pointer of data_ to avoid bound check.
|
||||
T *ptr_{nullptr}; // pointer of data_ to avoid bound check.
|
||||
|
||||
size_t size_{0};
|
||||
int32_t device_{-1};
|
||||
@ -110,42 +233,45 @@ class TensorView {
|
||||
}
|
||||
}
|
||||
|
||||
struct SliceHelper {
|
||||
size_t old_dim;
|
||||
size_t new_dim;
|
||||
size_t offset;
|
||||
};
|
||||
|
||||
template <int32_t D, typename... S>
|
||||
XGBOOST_DEVICE SliceHelper MakeSliceDim(size_t old_dim, size_t new_dim, size_t new_shape[D],
|
||||
size_t new_stride[D], detail::AllTag) const {
|
||||
template <size_t old_dim, size_t new_dim, int32_t D, typename... S>
|
||||
XGBOOST_DEVICE size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D],
|
||||
detail::AllTag) const {
|
||||
static_assert(new_dim < D, "");
|
||||
static_assert(old_dim < kDim, "");
|
||||
new_stride[new_dim] = stride_[old_dim];
|
||||
new_shape[new_dim] = shape_[old_dim];
|
||||
return {old_dim + 1, new_dim + 1, 0};
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <int32_t D, typename... S>
|
||||
XGBOOST_DEVICE SliceHelper MakeSliceDim(size_t old_dim, size_t new_dim, size_t new_shape[D],
|
||||
size_t new_stride[D], detail::AllTag,
|
||||
/**
|
||||
* \brief Slice dimension for All tag.
|
||||
*/
|
||||
template <size_t old_dim, size_t new_dim, int32_t D, typename... S>
|
||||
XGBOOST_DEVICE size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], detail::AllTag,
|
||||
S &&...slices) const {
|
||||
static_assert(new_dim < D, "");
|
||||
static_assert(old_dim < kDim, "");
|
||||
new_stride[new_dim] = stride_[old_dim];
|
||||
new_shape[new_dim] = shape_[old_dim];
|
||||
return MakeSliceDim<D>(old_dim + 1, new_dim + 1, new_shape, new_stride, slices...);
|
||||
return MakeSliceDim<old_dim + 1, new_dim + 1, D>(new_shape, new_stride,
|
||||
std::forward<S>(slices)...);
|
||||
}
|
||||
|
||||
template <int32_t D, typename Index>
|
||||
XGBOOST_DEVICE SliceHelper MakeSliceDim(size_t old_dim, size_t new_dim, size_t new_shape[D],
|
||||
size_t new_stride[D], Index i) const {
|
||||
return {old_dim + 1, new_dim, stride_[old_dim] * i};
|
||||
template <size_t old_dim, size_t new_dim, int32_t D, typename Index>
|
||||
XGBOOST_DEVICE size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], Index i) const {
|
||||
static_assert(old_dim < kDim, "");
|
||||
return stride_[old_dim] * i;
|
||||
}
|
||||
|
||||
template <int32_t D, typename Index, typename... S>
|
||||
XGBOOST_DEVICE std::enable_if_t<std::is_integral<Index>::value, SliceHelper> MakeSliceDim(
|
||||
size_t old_dim, size_t new_dim, size_t new_shape[D], size_t new_stride[D], Index i,
|
||||
S &&...slices) const {
|
||||
/**
|
||||
* \brief Slice dimension for Index tag.
|
||||
*/
|
||||
template <size_t old_dim, size_t new_dim, int32_t D, typename Index, typename... S>
|
||||
XGBOOST_DEVICE std::enable_if_t<std::is_integral<Index>::value, size_t> MakeSliceDim(
|
||||
size_t new_shape[D], size_t new_stride[D], Index i, S &&...slices) const {
|
||||
static_assert(old_dim < kDim, "");
|
||||
auto offset = stride_[old_dim] * i;
|
||||
auto res = MakeSliceDim<D>(old_dim + 1, new_dim, new_shape, new_stride, slices...);
|
||||
return {res.old_dim, res.new_dim, res.offset + offset};
|
||||
auto res =
|
||||
MakeSliceDim<old_dim + 1, new_dim, D>(new_shape, new_stride, std::forward<S>(slices)...);
|
||||
return res + offset;
|
||||
}
|
||||
|
||||
public:
|
||||
@ -174,12 +300,10 @@ class TensorView {
|
||||
shape_[i] = 1;
|
||||
}
|
||||
// stride
|
||||
stride_[kDim - 1] = 1;
|
||||
for (int32_t s = kDim - 2; s >= 0; --s) {
|
||||
stride_[s] = shape_[s + 1] * stride_[s + 1];
|
||||
}
|
||||
detail::CalcStride(shape_, stride_);
|
||||
// size
|
||||
this->CalcSize();
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create a tensor with data, shape and strides. Don't use this constructor if
|
||||
@ -195,7 +319,7 @@ class TensorView {
|
||||
stride_[i] = stride[i];
|
||||
});
|
||||
this->CalcSize();
|
||||
};
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE TensorView(TensorView const &that)
|
||||
: data_{that.data_}, ptr_{data_.data()}, size_{that.size_}, device_{that.device_} {
|
||||
@ -221,7 +345,8 @@ class TensorView {
|
||||
template <typename... Index>
|
||||
XGBOOST_DEVICE T &operator()(Index &&...index) {
|
||||
static_assert(sizeof...(index) <= kDim, "Invalid index.");
|
||||
size_t offset = detail::Offset(stride_, 0ul, 0ul, index...);
|
||||
size_t offset = detail::Offset<0ul>(stride_, 0ul, std::forward<Index>(index)...);
|
||||
assert(offset < data_.size() && "Out of bound access.");
|
||||
return ptr_[offset];
|
||||
}
|
||||
/**
|
||||
@ -230,12 +355,14 @@ class TensorView {
|
||||
template <typename... Index>
|
||||
XGBOOST_DEVICE T const &operator()(Index &&...index) const {
|
||||
static_assert(sizeof...(index) <= kDim, "Invalid index.");
|
||||
size_t offset = detail::Offset(stride_, 0ul, 0ul, index...);
|
||||
size_t offset = detail::Offset<0ul>(stride_, 0ul, std::forward<Index>(index)...);
|
||||
assert(offset < data_.size() && "Out of bound access.");
|
||||
return ptr_[offset];
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Slice the tensor. The returned tensor has inferred dim and shape.
|
||||
* \brief Slice the tensor. The returned tensor has inferred dim and shape. Scalar
|
||||
* result is not supported.
|
||||
*
|
||||
* \code
|
||||
*
|
||||
@ -252,10 +379,10 @@ class TensorView {
|
||||
int32_t constexpr kNewDim{detail::CalcSliceDim<detail::IndexToTag<S>...>()};
|
||||
size_t new_shape[kNewDim];
|
||||
size_t new_stride[kNewDim];
|
||||
auto res = MakeSliceDim<kNewDim>(size_t(0), size_t(0), new_shape, new_stride, slices...);
|
||||
auto offset = MakeSliceDim<0, 0, kNewDim>(new_shape, new_stride, std::forward<S>(slices)...);
|
||||
// ret is a different type due to changed dimension, so we can not access its private
|
||||
// fields.
|
||||
TensorView<T, kNewDim> ret{data_.subspan(data_.empty() ? 0 : res.offset), new_shape, new_stride,
|
||||
TensorView<T, kNewDim> ret{data_.subspan(data_.empty() ? 0 : offset), new_shape, new_stride,
|
||||
device_};
|
||||
return ret;
|
||||
}
|
||||
@ -275,12 +402,91 @@ class TensorView {
|
||||
XGBOOST_DEVICE auto cend() const { return data_.cend(); } // NOLINT
|
||||
XGBOOST_DEVICE auto begin() { return data_.begin(); } // NOLINT
|
||||
XGBOOST_DEVICE auto end() { return data_.end(); } // NOLINT
|
||||
|
||||
/**
|
||||
* \brief Number of items in the tensor.
|
||||
*/
|
||||
XGBOOST_DEVICE size_t Size() const { return size_; }
|
||||
/**
|
||||
* \brief Whether it's a contiguous array. (c and f contiguous are both contiguous)
|
||||
*/
|
||||
XGBOOST_DEVICE bool Contiguous() const { return size_ == data_.size(); }
|
||||
/**
|
||||
* \brief Obtain the raw data.
|
||||
*/
|
||||
XGBOOST_DEVICE auto Values() const { return data_; }
|
||||
/**
|
||||
* \brief Obtain the CUDA device ordinal.
|
||||
*/
|
||||
XGBOOST_DEVICE auto DeviceIdx() const { return device_; }
|
||||
|
||||
/**
|
||||
* \brief Array Interface defined by
|
||||
* <a href="https://numpy.org/doc/stable/reference/arrays.interface.html">numpy</a>.
|
||||
*
|
||||
* `stream` is optionally included when data is on CUDA device.
|
||||
*/
|
||||
Json ArrayInterface() const {
|
||||
Json array_interface{Object{}};
|
||||
array_interface["data"] = std::vector<Json>(2);
|
||||
array_interface["data"][0] = Integer(reinterpret_cast<int64_t>(data_.data()));
|
||||
array_interface["data"][1] = Boolean{true};
|
||||
if (this->DeviceIdx() >= 0) {
|
||||
// Change this once we have different CUDA stream.
|
||||
array_interface["stream"] = Null{};
|
||||
}
|
||||
std::vector<Json> shape(Shape().size());
|
||||
std::vector<Json> stride(Stride().size());
|
||||
for (size_t i = 0; i < Shape().size(); ++i) {
|
||||
shape[i] = Integer(Shape(i));
|
||||
stride[i] = Integer(Stride(i) * sizeof(T));
|
||||
}
|
||||
array_interface["shape"] = Array{shape};
|
||||
array_interface["strides"] = Array{stride};
|
||||
array_interface["version"] = 3;
|
||||
|
||||
char constexpr kT = detail::ArrayInterfaceHandler::TypeChar<T>();
|
||||
static_assert(kT != '\0', "");
|
||||
if (DMLC_LITTLE_ENDIAN) {
|
||||
array_interface["typestr"] = String{"<" + (kT + std::to_string(sizeof(T)))};
|
||||
} else {
|
||||
array_interface["typestr"] = String{">" + (kT + std::to_string(sizeof(T)))};
|
||||
}
|
||||
return array_interface;
|
||||
}
|
||||
/**
|
||||
* \brief Same as const version, but returns non-readonly data pointer.
|
||||
*/
|
||||
Json ArrayInterface() {
|
||||
auto const &as_const = *this;
|
||||
auto res = as_const.ArrayInterface();
|
||||
res["data"][1] = Boolean{false};
|
||||
return res;
|
||||
}
|
||||
|
||||
auto ArrayInterfaceStr() const {
|
||||
std::string str;
|
||||
Json::Dump(this->ArrayInterface(), &str);
|
||||
return str;
|
||||
}
|
||||
auto ArrayInterfaceStr() {
|
||||
std::string str;
|
||||
Json::Dump(this->ArrayInterface(), &str);
|
||||
return str;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Turns linear index into multi-dimension index. Similar to numpy unravel.
|
||||
*/
|
||||
template <size_t D>
|
||||
XGBOOST_DEVICE auto UnravelIndex(size_t idx, common::Span<size_t const, D> shape) {
|
||||
if (idx > std::numeric_limits<uint32_t>::max()) {
|
||||
return detail::UnravelImpl<uint64_t, D>(static_cast<uint64_t>(idx), shape);
|
||||
} else {
|
||||
return detail::UnravelImpl<uint32_t, D>(static_cast<uint32_t>(idx), shape);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief A view over a vector, specialization of Tensor
|
||||
*
|
||||
@ -289,6 +495,19 @@ class TensorView {
|
||||
template <typename T>
|
||||
using VectorView = TensorView<T, 1>;
|
||||
|
||||
/**
|
||||
* \brief Create a vector view from contigious memory.
|
||||
*
|
||||
* \param ptr Pointer to the contigious memory.
|
||||
* \param s Size of the vector.
|
||||
* \param device (optional) Device ordinal, default to be host.
|
||||
*/
|
||||
template <typename T>
|
||||
auto MakeVec(T *ptr, size_t s, int32_t device = -1) {
|
||||
using U = std::remove_const_t<std::remove_pointer_t<decltype(ptr)>> const;
|
||||
return linalg::TensorView<U, 1>{{ptr, s}, {s}, device};
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief A view over a matrix, specialization of Tensor.
|
||||
*
|
||||
@ -296,6 +515,163 @@ using VectorView = TensorView<T, 1>;
|
||||
*/
|
||||
template <typename T>
|
||||
using MatrixView = TensorView<T, 2>;
|
||||
|
||||
/**
|
||||
* \brief A tensor storage. To use it for other functionality like slicing one needs to
|
||||
* obtain a view first. This way we can use it on both host and device.
|
||||
*/
|
||||
template <typename T, int32_t kDim = 5>
|
||||
class Tensor {
|
||||
public:
|
||||
using ShapeT = size_t[kDim];
|
||||
using StrideT = ShapeT;
|
||||
|
||||
private:
|
||||
HostDeviceVector<T> data_;
|
||||
ShapeT shape_{0};
|
||||
|
||||
public:
|
||||
Tensor() = default;
|
||||
|
||||
/**
|
||||
* \brief Create a tensor with shape and device ordinal. The storage is initialized
|
||||
* automatically.
|
||||
*
|
||||
* See \ref TensorView for parameters of this constructor.
|
||||
*/
|
||||
template <typename I, int32_t D>
|
||||
explicit Tensor(I const (&shape)[D], int32_t device) {
|
||||
// No device unroll as this is a host only function.
|
||||
std::copy(shape, shape + D, shape_);
|
||||
for (auto i = D; i < kDim; ++i) {
|
||||
shape_[i] = 1;
|
||||
}
|
||||
auto size = detail::CalcSize(shape_);
|
||||
if (device >= 0) {
|
||||
data_.SetDevice(device);
|
||||
}
|
||||
data_.Resize(size);
|
||||
if (device >= 0) {
|
||||
data_.DevicePointer(); // Pull to device
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Initialize from 2 host iterators.
|
||||
*/
|
||||
template <typename It, typename I, int32_t D>
|
||||
explicit Tensor(It begin, It end, I const (&shape)[D], int32_t device) {
|
||||
// shape
|
||||
static_assert(D <= kDim, "Invalid shape.");
|
||||
std::copy(shape, shape + D, shape_);
|
||||
for (auto i = D; i < kDim; ++i) {
|
||||
shape_[i] = 1;
|
||||
}
|
||||
auto &h_vec = data_.HostVector();
|
||||
h_vec.insert(h_vec.begin(), begin, end);
|
||||
if (device >= 0) {
|
||||
data_.SetDevice(device);
|
||||
data_.DevicePointer(); // Pull to device;
|
||||
}
|
||||
CHECK_EQ(data_.Size(), detail::CalcSize(shape_));
|
||||
}
|
||||
/**
|
||||
* \brief Get a \ref TensorView for this tensor.
|
||||
*/
|
||||
TensorView<T, kDim> View(int32_t device) {
|
||||
if (device >= 0) {
|
||||
data_.SetDevice(device);
|
||||
auto span = data_.DeviceSpan();
|
||||
return {span, shape_, device};
|
||||
} else {
|
||||
auto span = data_.HostSpan();
|
||||
return {span, shape_, device};
|
||||
}
|
||||
}
|
||||
TensorView<T const, kDim> View(int32_t device) const {
|
||||
if (device >= 0) {
|
||||
data_.SetDevice(device);
|
||||
auto span = data_.ConstDeviceSpan();
|
||||
return {span, shape_, device};
|
||||
} else {
|
||||
auto span = data_.ConstHostSpan();
|
||||
return {span, shape_, device};
|
||||
}
|
||||
}
|
||||
|
||||
size_t Size() const { return data_.Size(); }
|
||||
auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
|
||||
auto Shape(size_t i) const { return shape_[i]; }
|
||||
|
||||
HostDeviceVector<T> *Data() { return &data_; }
|
||||
HostDeviceVector<T> const *Data() const { return &data_; }
|
||||
|
||||
/**
|
||||
* \brief Visitor function for modification that changes shape and data.
|
||||
*
|
||||
* \tparam Fn function that takes a pointer to `HostDeviceVector` and a static sized
|
||||
* span as parameters.
|
||||
*/
|
||||
template <typename Fn>
|
||||
void ModifyInplace(Fn &&fn) {
|
||||
fn(this->Data(), common::Span<size_t, kDim>{this->shape_});
|
||||
CHECK_EQ(this->Data()->Size(), detail::CalcSize(this->shape_))
|
||||
<< "Inconsistent size after modification.";
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Reshape the tensor.
|
||||
*
|
||||
* If the total size is changed, then data in this tensor is no longer valid.
|
||||
*/
|
||||
template <typename... S>
|
||||
void Reshape(S &&...s) {
|
||||
static_assert(sizeof...(S) <= kDim, "Invalid shape.");
|
||||
detail::ReshapeImpl<0>(shape_, std::forward<S>(s)...);
|
||||
auto constexpr kEnd = sizeof...(S);
|
||||
static_assert(kEnd <= kDim, "Invalid shape.");
|
||||
std::fill(shape_ + kEnd, shape_ + kDim, 1);
|
||||
auto n = detail::CalcSize(shape_);
|
||||
data_.Resize(n);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Reshape the tensor.
|
||||
*
|
||||
* If the total size is changed, then data in this tensor is no longer valid.
|
||||
*/
|
||||
template <int32_t D>
|
||||
void Reshape(size_t (&shape)[D]) {
|
||||
static_assert(D <= kDim, "Invalid shape.");
|
||||
std::copy(shape, shape + D, this->shape_);
|
||||
std::fill(shape_ + D, shape_ + kDim, 1);
|
||||
auto n = detail::CalcSize(shape_);
|
||||
data_.Resize(n);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Set device ordinal for this tensor.
|
||||
*/
|
||||
void SetDevice(int32_t device) { data_.SetDevice(device); }
|
||||
};
|
||||
|
||||
// Only first axis is supported for now.
|
||||
template <typename T, int32_t D>
|
||||
void Stack(Tensor<T, D> *l, Tensor<T, D> const &r) {
|
||||
if (r.Data()->DeviceIdx() >= 0) {
|
||||
l->Data()->SetDevice(r.Data()->DeviceIdx());
|
||||
}
|
||||
l->ModifyInplace([&](HostDeviceVector<T> *data, common::Span<size_t, D> shape) {
|
||||
for (size_t i = 1; i < D; ++i) {
|
||||
if (shape[i] == 0) {
|
||||
shape[i] = r.Shape(i);
|
||||
} else {
|
||||
CHECK_EQ(shape[i], r.Shape(i));
|
||||
}
|
||||
}
|
||||
data->Extend(*r.Data());
|
||||
shape[0] = l->Shape(0) + r.Shape(0);
|
||||
});
|
||||
}
|
||||
} // namespace linalg
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_LINALG_H_
|
||||
|
||||
@ -170,6 +170,7 @@ void HostDeviceVector<T>::SetDevice(int) const {}
|
||||
|
||||
// explicit instantiations are required, as HostDeviceVector isn't header-only
|
||||
template class HostDeviceVector<bst_float>;
|
||||
template class HostDeviceVector<double>;
|
||||
template class HostDeviceVector<GradientPair>;
|
||||
template class HostDeviceVector<int32_t>; // bst_node_t
|
||||
template class HostDeviceVector<uint8_t>;
|
||||
|
||||
@ -398,6 +398,7 @@ void HostDeviceVector<T>::Resize(size_t new_size, T v) {
|
||||
|
||||
// explicit instantiations are required, as HostDeviceVector isn't header-only
|
||||
template class HostDeviceVector<bst_float>;
|
||||
template class HostDeviceVector<double>;
|
||||
template class HostDeviceVector<GradientPair>;
|
||||
template class HostDeviceVector<int32_t>; // bst_node_t
|
||||
template class HostDeviceVector<uint8_t>;
|
||||
|
||||
25
src/common/linalg_op.cuh
Normal file
25
src/common/linalg_op.cuh
Normal file
@ -0,0 +1,25 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_
|
||||
#define XGBOOST_COMMON_LINALG_OP_CUH_
|
||||
#include "device_helpers.cuh"
|
||||
#include "xgboost/linalg.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace linalg {
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
|
||||
if (t.Contiguous()) {
|
||||
auto ptr = t.Values().data();
|
||||
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
|
||||
} else {
|
||||
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
|
||||
T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
|
||||
v = fn(i, v);
|
||||
});
|
||||
}
|
||||
}
|
||||
} // namespace linalg
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_LINALG_OP_CUH_
|
||||
25
src/common/linalg_op.h
Normal file
25
src/common/linalg_op.h
Normal file
@ -0,0 +1,25 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_LINALG_OP_H_
|
||||
#define XGBOOST_COMMON_LINALG_OP_H_
|
||||
#include "threading_utils.h"
|
||||
#include "xgboost/linalg.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace linalg {
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
|
||||
if (t.Contiguous()) {
|
||||
auto ptr = t.Values().data();
|
||||
common::ParallelFor(t.Size(), n_threads, [&](size_t i) { ptr[i] = fn(i, ptr[i]); });
|
||||
} else {
|
||||
common::ParallelFor(t.Size(), n_threads, [&](size_t i) {
|
||||
auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
|
||||
v = fn(i, v);
|
||||
});
|
||||
}
|
||||
}
|
||||
} // namespace linalg
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_LINALG_OP_H_
|
||||
@ -1,11 +1,21 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/linalg.h>
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "../../../src/common/linalg_op.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace linalg {
|
||||
namespace {
|
||||
auto kCpuId = GenericParameter::kCpuId;
|
||||
}
|
||||
|
||||
auto MakeMatrixFromTest(HostDeviceVector<float> *storage, size_t n_rows, size_t n_cols) {
|
||||
storage->Resize(n_rows * n_cols);
|
||||
auto &h_storage = storage->HostVector();
|
||||
@ -16,16 +26,16 @@ auto MakeMatrixFromTest(HostDeviceVector<float> *storage, size_t n_rows, size_t
|
||||
return m;
|
||||
}
|
||||
|
||||
TEST(Linalg, Matrix) {
|
||||
TEST(Linalg, MatrixView) {
|
||||
size_t kRows = 31, kCols = 77;
|
||||
HostDeviceVector<float> storage;
|
||||
auto m = MakeMatrixFromTest(&storage, kRows, kCols);
|
||||
ASSERT_EQ(m.DeviceIdx(), GenericParameter::kCpuId);
|
||||
ASSERT_EQ(m.DeviceIdx(), kCpuId);
|
||||
ASSERT_EQ(m(0, 0), 0);
|
||||
ASSERT_EQ(m(kRows - 1, kCols - 1), storage.Size() - 1);
|
||||
}
|
||||
|
||||
TEST(Linalg, Vector) {
|
||||
TEST(Linalg, VectorView) {
|
||||
size_t kRows = 31, kCols = 77;
|
||||
HostDeviceVector<float> storage;
|
||||
auto m = MakeMatrixFromTest(&storage, kRows, kCols);
|
||||
@ -37,7 +47,7 @@ TEST(Linalg, Vector) {
|
||||
ASSERT_EQ(v(0), 3);
|
||||
}
|
||||
|
||||
TEST(Linalg, Tensor) {
|
||||
TEST(Linalg, TensorView) {
|
||||
std::vector<double> data(2 * 3 * 4, 0);
|
||||
std::iota(data.begin(), data.end(), 0);
|
||||
|
||||
@ -99,14 +109,123 @@ TEST(Linalg, Tensor) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Linalg, Tensor) {
|
||||
{
|
||||
Tensor<float, 3> t{{2, 3, 4}, kCpuId};
|
||||
auto view = t.View(kCpuId);
|
||||
|
||||
auto const &as_const = t;
|
||||
auto k_view = as_const.View(kCpuId);
|
||||
|
||||
size_t n = 2 * 3 * 4;
|
||||
ASSERT_EQ(t.Size(), n);
|
||||
ASSERT_TRUE(std::equal(k_view.cbegin(), k_view.cbegin(), view.begin()));
|
||||
|
||||
Tensor<float, 3> t_0{std::move(t)};
|
||||
ASSERT_EQ(t_0.Size(), n);
|
||||
ASSERT_EQ(t_0.Shape(0), 2);
|
||||
ASSERT_EQ(t_0.Shape(1), 3);
|
||||
ASSERT_EQ(t_0.Shape(2), 4);
|
||||
}
|
||||
{
|
||||
// Reshape
|
||||
Tensor<float, 3> t{{2, 3, 4}, kCpuId};
|
||||
t.Reshape(4, 3, 2);
|
||||
ASSERT_EQ(t.Size(), 24);
|
||||
ASSERT_EQ(t.Shape(2), 2);
|
||||
t.Reshape(1);
|
||||
ASSERT_EQ(t.Size(), 1);
|
||||
t.Reshape(0, 0, 0);
|
||||
ASSERT_EQ(t.Size(), 0);
|
||||
t.Reshape(0, 3, 0);
|
||||
ASSERT_EQ(t.Size(), 0);
|
||||
ASSERT_EQ(t.Shape(1), 3);
|
||||
t.Reshape(3, 3, 3);
|
||||
ASSERT_EQ(t.Size(), 27);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Linalg, Empty) {
|
||||
auto t = TensorView<double, 2>{{}, {0, 3}, GenericParameter::kCpuId};
|
||||
{
|
||||
auto t = TensorView<double, 2>{{}, {0, 3}, kCpuId};
|
||||
for (int32_t i : {0, 1, 2}) {
|
||||
auto s = t.Slice(All(), i);
|
||||
ASSERT_EQ(s.Size(), 0);
|
||||
ASSERT_EQ(s.Shape().size(), 1);
|
||||
ASSERT_EQ(s.Shape(0), 0);
|
||||
}
|
||||
}
|
||||
{
|
||||
auto t = Tensor<double, 2>{{0, 3}, kCpuId};
|
||||
ASSERT_EQ(t.Size(), 0);
|
||||
auto view = t.View(kCpuId);
|
||||
|
||||
for (int32_t i : {0, 1, 2}) {
|
||||
auto s = view.Slice(All(), i);
|
||||
ASSERT_EQ(s.Size(), 0);
|
||||
ASSERT_EQ(s.Shape().size(), 1);
|
||||
ASSERT_EQ(s.Shape(0), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Linalg, ArrayInterface) {
|
||||
auto cpu = kCpuId;
|
||||
auto t = Tensor<double, 2>{{3, 3}, cpu};
|
||||
auto v = t.View(cpu);
|
||||
std::iota(v.begin(), v.end(), 0);
|
||||
auto arr = Json::Load(StringView{v.ArrayInterfaceStr()});
|
||||
ASSERT_EQ(get<Integer>(arr["shape"][0]), 3);
|
||||
ASSERT_EQ(get<Integer>(arr["strides"][0]), 3 * sizeof(double));
|
||||
|
||||
ASSERT_FALSE(get<Boolean>(arr["data"][1]));
|
||||
ASSERT_EQ(reinterpret_cast<double *>(get<Integer>(arr["data"][0])), v.Values().data());
|
||||
}
|
||||
|
||||
TEST(Linalg, Popc) {
|
||||
{
|
||||
uint32_t v{0};
|
||||
ASSERT_EQ(detail::NativePopc(v), 0);
|
||||
ASSERT_EQ(detail::Popc(v), 0);
|
||||
v = 1;
|
||||
ASSERT_EQ(detail::NativePopc(v), 1);
|
||||
ASSERT_EQ(detail::Popc(v), 1);
|
||||
v = 0xffffffff;
|
||||
ASSERT_EQ(detail::NativePopc(v), 32);
|
||||
ASSERT_EQ(detail::Popc(v), 32);
|
||||
}
|
||||
{
|
||||
uint64_t v{0};
|
||||
ASSERT_EQ(detail::NativePopc(v), 0);
|
||||
ASSERT_EQ(detail::Popc(v), 0);
|
||||
v = 1;
|
||||
ASSERT_EQ(detail::NativePopc(v), 1);
|
||||
ASSERT_EQ(detail::Popc(v), 1);
|
||||
v = 0xffffffff;
|
||||
ASSERT_EQ(detail::NativePopc(v), 32);
|
||||
ASSERT_EQ(detail::Popc(v), 32);
|
||||
v = 0xffffffffffffffff;
|
||||
ASSERT_EQ(detail::NativePopc(v), 64);
|
||||
ASSERT_EQ(detail::Popc(v), 64);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Linalg, Stack) {
|
||||
Tensor<float, 3> l{{2, 3, 4}, kCpuId};
|
||||
ElementWiseKernelHost(l.View(kCpuId), omp_get_max_threads(),
|
||||
[=](size_t i, float v) { return i; });
|
||||
Tensor<float, 3> r_0{{2, 3, 4}, kCpuId};
|
||||
ElementWiseKernelHost(r_0.View(kCpuId), omp_get_max_threads(),
|
||||
[=](size_t i, float v) { return i; });
|
||||
|
||||
Stack(&l, r_0);
|
||||
|
||||
Tensor<float, 3> r_1{{0, 3, 4}, kCpuId};
|
||||
Stack(&l, r_1);
|
||||
ASSERT_EQ(l.Shape(0), 4);
|
||||
|
||||
Stack(&r_1, l);
|
||||
ASSERT_EQ(r_1.Shape(0), l.Shape(0));
|
||||
}
|
||||
} // namespace linalg
|
||||
} // namespace xgboost
|
||||
|
||||
62
tests/cpp/common/test_linalg.cu
Normal file
62
tests/cpp/common/test_linalg.cu
Normal file
@ -0,0 +1,62 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/common/linalg_op.cuh"
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "xgboost/linalg.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace linalg {
|
||||
namespace {
|
||||
void TestElementWiseKernel() {
|
||||
Tensor<float, 3> l{{2, 3, 4}, 0};
|
||||
{
|
||||
/**
|
||||
* Non-contiguous
|
||||
*/
|
||||
// GPU view
|
||||
auto t = l.View(0).Slice(linalg::All(), 1, linalg::All());
|
||||
ASSERT_FALSE(t.Contiguous());
|
||||
ElementWiseKernelDevice(t, [] __device__(size_t i, float) { return i; });
|
||||
// CPU view
|
||||
t = l.View(GenericParameter::kCpuId).Slice(linalg::All(), 1, linalg::All());
|
||||
size_t k = 0;
|
||||
for (size_t i = 0; i < l.Shape(0); ++i) {
|
||||
for (size_t j = 0; j < l.Shape(2); ++j) {
|
||||
ASSERT_EQ(k++, t(i, j));
|
||||
}
|
||||
}
|
||||
|
||||
t = l.View(0).Slice(linalg::All(), 1, linalg::All());
|
||||
ElementWiseKernelDevice(t, [] __device__(size_t i, float v) {
|
||||
SPAN_CHECK(v == i);
|
||||
return v;
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
/**
|
||||
* Contiguous
|
||||
*/
|
||||
auto t = l.View(0);
|
||||
ElementWiseKernelDevice(t, [] __device__(size_t i, float) { return i; });
|
||||
ASSERT_TRUE(t.Contiguous());
|
||||
// CPU view
|
||||
t = l.View(GenericParameter::kCpuId);
|
||||
|
||||
size_t ind = 0;
|
||||
for (size_t i = 0; i < l.Shape(0); ++i) {
|
||||
for (size_t j = 0; j < l.Shape(1); ++j) {
|
||||
for (size_t k = 0; k < l.Shape(2); ++k) {
|
||||
ASSERT_EQ(ind++, t(i, j, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); }
|
||||
} // namespace linalg
|
||||
} // namespace xgboost
|
||||
Loading…
x
Reference in New Issue
Block a user