Implement typed storage for tensor. (#7429)
* Add `Tensor` class. * Add elementwise kernel for CPU and GPU. * Add unravel index. * Move some computation to compile time.
This commit is contained in:
@@ -170,6 +170,7 @@ void HostDeviceVector<T>::SetDevice(int) const {}
|
||||
|
||||
// explicit instantiations are required, as HostDeviceVector isn't header-only
|
||||
template class HostDeviceVector<bst_float>;
|
||||
template class HostDeviceVector<double>;
|
||||
template class HostDeviceVector<GradientPair>;
|
||||
template class HostDeviceVector<int32_t>; // bst_node_t
|
||||
template class HostDeviceVector<uint8_t>;
|
||||
|
||||
@@ -398,6 +398,7 @@ void HostDeviceVector<T>::Resize(size_t new_size, T v) {
|
||||
|
||||
// explicit instantiations are required, as HostDeviceVector isn't header-only
|
||||
template class HostDeviceVector<bst_float>;
|
||||
template class HostDeviceVector<double>;
|
||||
template class HostDeviceVector<GradientPair>;
|
||||
template class HostDeviceVector<int32_t>; // bst_node_t
|
||||
template class HostDeviceVector<uint8_t>;
|
||||
|
||||
25
src/common/linalg_op.cuh
Normal file
25
src/common/linalg_op.cuh
Normal file
@@ -0,0 +1,25 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_
|
||||
#define XGBOOST_COMMON_LINALG_OP_CUH_
|
||||
#include "device_helpers.cuh"
|
||||
#include "xgboost/linalg.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace linalg {
|
||||
/*!
 * \brief Apply a function to every element of a tensor view using a device kernel.
 *
 * The callable is invoked as `fn(i, v)`, where `i` is the linear element index in
 * `[0, t.Size())` and `v` is the current value of that element. The value returned by
 * `fn` is written back to the same element.
 *
 * \tparam T  Element type of the tensor view.
 * \tparam D  Dimension parameter of the tensor view.
 * \tparam Fn Device-callable type; it is copied by value into the kernel lambda.
 *
 * \param t  View of the tensor whose elements are overwritten.
 * \param fn Function applied to each element.
 * \param s  CUDA stream handed to dh::LaunchN; defaults to nullptr.
 */
template <typename T, int32_t D, typename Fn>
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
  if (t.Contiguous()) {
    // Contiguous view: the linear index can address the underlying values directly.
    auto ptr = t.Values().data();
    // `mutable` lets the by-value copy of `fn` have a non-const call operator, which
    // keeps this branch consistent with the non-contiguous branch below.
    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { ptr[i] = fn(i, ptr[i]); });
  } else {
    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
      // Convert the linear index into per-dimension coordinates before the lookup.
      T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
      v = fn(i, v);
    });
  }
}
|
||||
} // namespace linalg
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_LINALG_OP_CUH_
|
||||
25
src/common/linalg_op.h
Normal file
25
src/common/linalg_op.h
Normal file
@@ -0,0 +1,25 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_LINALG_OP_H_
|
||||
#define XGBOOST_COMMON_LINALG_OP_H_
|
||||
#include "threading_utils.h"
|
||||
#include "xgboost/linalg.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace linalg {
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
|
||||
if (t.Contiguous()) {
|
||||
auto ptr = t.Values().data();
|
||||
common::ParallelFor(t.Size(), n_threads, [&](size_t i) { ptr[i] = fn(i, ptr[i]); });
|
||||
} else {
|
||||
common::ParallelFor(t.Size(), n_threads, [&](size_t i) {
|
||||
auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
|
||||
v = fn(i, v);
|
||||
});
|
||||
}
|
||||
}
|
||||
} // namespace linalg
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_LINALG_OP_H_
|
||||
Reference in New Issue
Block a user