Retire DVec class in favour of c++20 style span for device memory. (#4293)
parent c85181dd8a
commit 3f312e30db
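The net effect of the patch: dh::DVec<T> and dh::DVec2<T> go away, device buffers are held as xgboost::common::Span<T> (plus a small dh::DoubleBuffer<T> for cub radix sorts), BulkAllocator drops its MemoryType template parameter, and explicit helpers (CopyVectorToDeviceSpan, CopyDeviceSpanToVector, CopyDeviceSpan) replace DVec's copy/assignment methods. A minimal sketch of the post-patch allocation pattern, assuming the dh helpers from this diff are in scope (identifiers below are illustrative, not taken from the diff):

// Illustrative sketch only -- not part of the diff.
xgboost::common::Span<float> values;   // non-owning view of device memory
dh::BulkAllocator ba;                  // was dh::BulkAllocator<dh::MemoryType::kDevice>
ba.Allocate(0, &values, 1000);         // one cudaMalloc on device 0; the span points into it

// Spans are trivially copyable, so a device lambda can capture them by value;
// this replaces DVec::Fill().
dh::LaunchN(0, values.size(),
            [=] __device__(size_t i) { values[i] = 1.0f; });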
@@ -227,179 +227,79 @@ inline void LaunchN(int device_idx, size_t n, L lambda) {
  LaunchN<ITEMS_PER_THREAD, BLOCK_THREADS>(device_idx, n, nullptr, lambda);
}

/*
 * Memory
 */
/**
 * \brief A double buffer, useful for algorithms like sort.
 */

enum MemoryType { kDevice, kDeviceManaged };

template <MemoryType MemoryT>
class BulkAllocator;
template <typename T>
class DVec2;

template <typename T>
class DVec {
  friend class DVec2<T>;

 private:
  T *ptr_;
  size_t size_;
  int device_idx_;

class DoubleBuffer {
 public:
  void ExternalAllocate(int device_idx, void *ptr, size_t size) {
    if (!Empty()) {
      throw std::runtime_error("Tried to allocate DVec but already allocated");
    }
    ptr_ = static_cast<T *>(ptr);
    size_ = size;
    device_idx_ = device_idx;
    safe_cuda(cudaSetDevice(device_idx_));
  cub::DoubleBuffer<T> buff;
  xgboost::common::Span<T> a, b;
  DoubleBuffer() = default;

  size_t Size() const {
    CHECK_EQ(a.size(), b.size());
    return a.size();
  }
  cub::DoubleBuffer<T> &CubBuffer() { return buff; }

  T *Current() { return buff.Current(); }
  xgboost::common::Span<T> CurrentSpan() {
    return xgboost::common::Span<T>{
        buff.Current(),
        static_cast<typename xgboost::common::Span<T>::index_type>(Size())};
  }

  DVec() : ptr_(NULL), size_(0), device_idx_(-1) {}
  size_t Size() const { return size_; }
  int DeviceIdx() const { return device_idx_; }
  bool Empty() const { return ptr_ == NULL || size_ == 0; }

  T *Data() { return ptr_; }

  const T *Data() const { return ptr_; }

  xgboost::common::Span<const T> GetSpan() const {
    return xgboost::common::Span<const T>(ptr_, this->Size());
  }

  xgboost::common::Span<T> GetSpan() {
    return xgboost::common::Span<T>(ptr_, this->Size());
  }

  std::vector<T> AsVector() const {
    std::vector<T> h_vector(Size());
    safe_cuda(cudaSetDevice(device_idx_));
    safe_cuda(cudaMemcpy(h_vector.data(), ptr_, Size() * sizeof(T),
                         cudaMemcpyDeviceToHost));
    return h_vector;
  }

  void Fill(T value) {
    auto d_ptr = ptr_;
    LaunchN(device_idx_, Size(),
            [=] __device__(size_t idx) { d_ptr[idx] = value; });
  }

  void Print() {
    auto h_vector = this->AsVector();
    for (auto e : h_vector) {
      std::cout << e << " ";
    }
    std::cout << "\n";
  }

  thrust::device_ptr<T> tbegin() { return thrust::device_pointer_cast(ptr_); }

  thrust::device_ptr<T> tend() {
    return thrust::device_pointer_cast(ptr_ + Size());
  }

  template <typename T2>
  DVec &operator=(const std::vector<T2> &other) {
    this->copy(other.begin(), other.end());
    return *this;
  }

  DVec &operator=(DVec<T> &other) {
    if (other.Size() != Size()) {
      throw std::runtime_error(
          "Cannot copy assign DVec to DVec, sizes are different");
    }
    safe_cuda(cudaSetDevice(this->DeviceIdx()));
    if (other.DeviceIdx() == this->DeviceIdx()) {
      dh::safe_cuda(cudaMemcpyAsync(this->Data(), other.Data(),
                                    other.Size() * sizeof(T),
                                    cudaMemcpyDeviceToDevice));
    } else {
      std::cout << "deviceother: " << other.DeviceIdx()
                << " devicethis: " << this->DeviceIdx() << std::endl;
      std::cout << "size deviceother: " << other.Size()
                << " devicethis: " << this->DeviceIdx() << std::endl;
      throw std::runtime_error("Cannot copy to/from different devices");
    }

    return *this;
  }

  template <typename IterT>
  void copy(IterT begin, IterT end) {
    safe_cuda(cudaSetDevice(this->DeviceIdx()));
    if (end - begin != Size()) {
      LOG(FATAL) << "Cannot copy assign vector to DVec, sizes are different" <<
        " vector::Size(): " << end - begin << " DVec::Size(): " << Size();
    }
    thrust::copy(begin, end, this->tbegin());
  }

  void copy(thrust::device_ptr<T> begin, thrust::device_ptr<T> end) {
    safe_cuda(cudaSetDevice(this->DeviceIdx()));
    if (end - begin != Size()) {
      throw std::runtime_error(
          "Cannot copy assign vector to dvec, sizes are different");
    }
    safe_cuda(cudaMemcpyAsync(this->Data(), begin.get(), Size() * sizeof(T),
                              cudaMemcpyDefault));
  }
  T *other() { return buff.Alternate(); }
};

/**
 * @class DVec2 device_helpers.cuh
 * @brief wrapper for storing 2 DVec's which are needed for cub::DoubleBuffer
 * \brief Copies device span to std::vector.
 *
 * \tparam T Generic type parameter.
 * \param [in,out] dst Copy destination.
 * \param src Copy source. Must be device memory.
 */
template <typename T>
class DVec2 {
 private:
  DVec<T> d1_, d2_;
  cub::DoubleBuffer<T> buff_;
  int device_idx_;
void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<T> src) {
  CHECK_EQ(dst->size(), src.size());
  dh::safe_cuda(cudaMemcpyAsync(dst->data(), src.data(), dst->size() * sizeof(T),
                                cudaMemcpyDeviceToHost));
}

 public:
  void ExternalAllocate(int device_idx, void *ptr1, void *ptr2, size_t size) {
    if (!Empty()) {
      throw std::runtime_error("Tried to allocate DVec2 but already allocated");
    }
    device_idx_ = device_idx;
    d1_.ExternalAllocate(device_idx_, ptr1, size);
    d2_.ExternalAllocate(device_idx_, ptr2, size);
    buff_.d_buffers[0] = static_cast<T *>(ptr1);
    buff_.d_buffers[1] = static_cast<T *>(ptr2);
    buff_.selector = 0;
  }
  DVec2() : d1_(), d2_(), buff_(), device_idx_(-1) {}
/**
 * \brief Copies std::vector to device span.
 *
 * \tparam T Generic type parameter.
 * \param dst Copy destination. Must be device memory.
 * \param src Copy source.
 */
template <typename T>
void CopyVectorToDeviceSpan(xgboost::common::Span<T> dst, const std::vector<T> &src)
{
  CHECK_EQ(dst.size(), src.size());
  dh::safe_cuda(cudaMemcpyAsync(dst.data(), src.data(), dst.size() * sizeof(T),
                                cudaMemcpyHostToDevice));
}
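These two helpers take over the roles of DVec::operator=(const std::vector&) and DVec::AsVector(). A hedged round-trip sketch, assuming d_data is a device span obtained from BulkAllocator (the name is illustrative):

// Illustrative sketch only -- not part of the diff.
std::vector<int> h_in(d_data.size(), 1);
dh::CopyVectorToDeviceSpan(d_data, h_in);    // host -> device; sizes must match (CHECK_EQ)

std::vector<int> h_out(d_data.size());
dh::CopyDeviceSpanToVector(&h_out, d_data);  // device -> host; destination is pre-sized
dh::safe_cuda(cudaDeviceSynchronize());      // both copies use cudaMemcpyAsync on the default stream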

  size_t Size() const { return d1_.Size(); }
  int DeviceIdx() const { return device_idx_; }
  bool Empty() const { return d1_.Empty() || d2_.Empty(); }

  cub::DoubleBuffer<T> &buff() { return buff_; }

  DVec<T> &D1() { return d1_; }

  DVec<T> &D2() { return d2_; }

  T *Current() { return buff_.Current(); }
  xgboost::common::Span<T> CurrentSpan() {
    return xgboost::common::Span<T>{
        buff_.Current(),
        static_cast<typename xgboost::common::Span<T>::index_type>(Size())};
  }

  DVec<T> &CurrentDVec() { return buff_.selector == 0 ? D1() : D2(); }

  T *other() { return buff_.Alternate(); }
};
/**
 * \brief Device to device memory copy from src to dst. Spans must be the same size. Use subspan to
 *   copy from a smaller array to a larger array.
 *
 * \tparam T Generic type parameter.
 * \param dst Copy destination. Must be device memory.
 * \param src Copy source. Must be device memory.
 */
template <typename T>
void CopyDeviceSpan(xgboost::common::Span<T> dst,
                    xgboost::common::Span<T> src) {
  CHECK_EQ(dst.size(), src.size());
  dh::safe_cuda(cudaMemcpyAsync(dst.data(), src.data(), dst.size() * sizeof(T),
                                cudaMemcpyDeviceToDevice));
}
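Because CopyDeviceSpan checks that both spans have the same size, copying between buffers of different lengths goes through subspan, as the comment above suggests. A small sketch, assuming src and dst are device spans with dst.size() >= src.size() (names are illustrative):

// Illustrative sketch only -- not part of the diff.
dh::CopyDeviceSpan(dst.subspan(0, src.size()), src);  // copy src into the front of dst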

/*! \brief Helper for allocating large block of memory. */
template <MemoryType MemoryT>
class BulkAllocator {
  std::vector<char *> d_ptr_;
  std::vector<size_t> size_;
@@ -413,70 +313,73 @@ class BulkAllocator {
  }

  template <typename T>
  size_t GetSizeBytes(DVec<T> *first_vec, size_t first_size) {
  size_t GetSizeBytes(xgboost::common::Span<T> *first_vec, size_t first_size) {
    return AlignRoundUp(first_size * sizeof(T));
  }

  template <typename T, typename... Args>
  size_t GetSizeBytes(DVec<T> *first_vec, size_t first_size, Args... args) {
  size_t GetSizeBytes(xgboost::common::Span<T> *first_vec, size_t first_size, Args... args) {
    return GetSizeBytes<T>(first_vec, first_size) + GetSizeBytes(args...);
  }

  template <typename T>
  void AllocateDVec(int device_idx, char *ptr, DVec<T> *first_vec,
                    size_t first_size) {
    first_vec->ExternalAllocate(device_idx, static_cast<void *>(ptr),
                                first_size);
  void AllocateSpan(int device_idx, char *ptr, xgboost::common::Span<T> *first_vec,
                    size_t first_size) {
    *first_vec = xgboost::common::Span<T>(reinterpret_cast<T *>(ptr), first_size);
  }

  template <typename T, typename... Args>
  void AllocateDVec(int device_idx, char *ptr, DVec<T> *first_vec,
                    size_t first_size, Args... args) {
    AllocateDVec<T>(device_idx, ptr, first_vec, first_size);
  void AllocateSpan(int device_idx, char *ptr, xgboost::common::Span<T> *first_vec,
                    size_t first_size, Args... args) {
    AllocateSpan<T>(device_idx, ptr, first_vec, first_size);
    ptr += AlignRoundUp(first_size * sizeof(T));
    AllocateDVec(device_idx, ptr, args...);
    AllocateSpan(device_idx, ptr, args...);
  }

  char *AllocateDevice(int device_idx, size_t bytes, MemoryType t) {
  char *AllocateDevice(int device_idx, size_t bytes) {
    char *ptr;
    safe_cuda(cudaSetDevice(device_idx));
    safe_cuda(cudaMalloc(&ptr, bytes));
    return ptr;
  }

  template <typename T>
  size_t GetSizeBytes(DVec2<T> *first_vec, size_t first_size) {
  size_t GetSizeBytes(DoubleBuffer<T> *first_vec, size_t first_size) {
    return 2 * AlignRoundUp(first_size * sizeof(T));
  }

  template <typename T, typename... Args>
  size_t GetSizeBytes(DVec2<T> *first_vec, size_t first_size, Args... args) {
  size_t GetSizeBytes(DoubleBuffer<T> *first_vec, size_t first_size, Args... args) {
    return GetSizeBytes<T>(first_vec, first_size) + GetSizeBytes(args...);
  }

  template <typename T>
  void AllocateDVec(int device_idx, char *ptr, DVec2<T> *first_vec,
                    size_t first_size) {
    first_vec->ExternalAllocate(
        device_idx, static_cast<void *>(ptr),
        static_cast<void *>(ptr + AlignRoundUp(first_size * sizeof(T))),
        first_size);
  void AllocateSpan(int device_idx, char *ptr, DoubleBuffer<T> *first_vec,
                    size_t first_size) {
    auto ptr1 = reinterpret_cast<T *>(ptr);
    auto ptr2 = ptr1 + first_size;
    first_vec->a = xgboost::common::Span<T>(ptr1, first_size);
    first_vec->b = xgboost::common::Span<T>(ptr2, first_size);
    first_vec->buff.d_buffers[0] = ptr1;
    first_vec->buff.d_buffers[1] = ptr2;
    first_vec->buff.selector = 0;
  }

  template <typename T, typename... Args>
  void AllocateDVec(int device_idx, char *ptr, DVec2<T> *first_vec,
  void AllocateSpan(int device_idx, char *ptr, DoubleBuffer<T> *first_vec,
                    size_t first_size, Args... args) {
    AllocateDVec<T>(device_idx, ptr, first_vec, first_size);
    AllocateSpan<T>(device_idx, ptr, first_vec, first_size);
    ptr += (AlignRoundUp(first_size * sizeof(T)) * 2);
    AllocateDVec(device_idx, ptr, args...);
    AllocateSpan(device_idx, ptr, args...);
  }

 public:
  BulkAllocator() = default;
  // prevent accidental copying, moving or assignment of this object
  BulkAllocator(const BulkAllocator<MemoryT>&) = delete;
  BulkAllocator(BulkAllocator<MemoryT>&&) = delete;
  void operator=(const BulkAllocator<MemoryT>&) = delete;
  void operator=(BulkAllocator<MemoryT>&&) = delete;
  BulkAllocator(const BulkAllocator&) = delete;
  BulkAllocator(BulkAllocator&&) = delete;
  void operator=(const BulkAllocator&) = delete;
  void operator=(BulkAllocator&&) = delete;

  ~BulkAllocator() {
    for (size_t i = 0; i < d_ptr_.size(); i++) {
@@ -497,9 +400,9 @@ class BulkAllocator {
  void Allocate(int device_idx, Args... args) {
    size_t size = GetSizeBytes(args...);

    char *ptr = AllocateDevice(device_idx, size, MemoryT);
    char *ptr = AllocateDevice(device_idx, size);

    AllocateDVec(device_idx, ptr, args...);
    AllocateSpan(device_idx, ptr, args...);

    d_ptr_.push_back(ptr);
    size_.push_back(size);
@@ -582,28 +485,6 @@ struct CubMemory {
 * Utility functions
 */

template <typename T>
void Print(const DVec<T> &v, size_t max_items = 10) {
  std::vector<T> h = v.as_vector();
  for (size_t i = 0; i < std::min(max_items, h.size()); i++) {
    std::cout << " " << h[i];
  }
  std::cout << "\n";
}

/**
 * @brief Helper macro to measure timing on GPU
 * @param call the GPU call
 * @param name name used to track later
 * @param stream cuda stream where to measure time
 */
#define TIMEIT(call, name)    \
  do {                        \
    dh::Timer t1234;          \
    call;                     \
    t1234.printElapsed(name); \
  } while (0)
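A usage sketch of the macro; the name argument is just the label passed to Timer::printElapsed, and d_data is a hypothetical thrust::device_vector<float>:

// Illustrative sketch only -- not part of the diff.
TIMEIT(thrust::sort(d_data.begin(), d_data.end()), "sort_values");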

// Load balancing search

template <typename CoordinateT, typename SegmentT, typename OffsetT>
@@ -762,18 +643,18 @@ void TransformLbs(int device_idx, dh::CubMemory *temp_memory, OffsetT count,
 * @param offsets the segments
 */
template <typename T1, typename T2>
void SegmentedSort(dh::CubMemory *tmp_mem, dh::DVec2<T1> *keys,
                   dh::DVec2<T2> *vals, int nVals, int nSegs,
                   const dh::DVec<int> &offsets, int start = 0,
void SegmentedSort(dh::CubMemory *tmp_mem, dh::DoubleBuffer<T1> *keys,
                   dh::DoubleBuffer<T2> *vals, int nVals, int nSegs,
                   xgboost::common::Span<int> offsets, int start = 0,
                   int end = sizeof(T1) * 8) {
  size_t tmpSize;
  dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
      NULL, tmpSize, keys->buff(), vals->buff(), nVals, nSegs, offsets.Data(),
      offsets.Data() + 1, start, end));
      NULL, tmpSize, keys->CubBuffer(), vals->CubBuffer(), nVals, nSegs,
      offsets.data(), offsets.data() + 1, start, end));
  tmp_mem->LazyAllocate(tmpSize);
  dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
      tmp_mem->d_temp_storage, tmpSize, keys->buff(), vals->buff(), nVals,
      nSegs, offsets.Data(), offsets.Data() + 1, start, end));
      tmp_mem->d_temp_storage, tmpSize, keys->CubBuffer(), vals->CubBuffer(),
      nVals, nSegs, offsets.data(), offsets.data() + 1, start, end));
}
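dh::DoubleBuffer is the replacement for DVec2 here: it holds two same-sized spans plus the cub::DoubleBuffer whose selector SortPairs flips. A hedged sketch of sorting values together with their instance ids, mirroring how TransferAndSortData uses it later in this diff (sizes and names are illustrative):

// Illustrative sketch only -- not part of the diff.
int n = 1000, n_segs = 10;
dh::DoubleBuffer<float> keys;
dh::DoubleBuffer<int> vals;
xgboost::common::Span<int> offsets;            // n_segs + 1 segment boundaries
dh::BulkAllocator ba;
ba.Allocate(0, &keys, n, &vals, n, &offsets, n_segs + 1);
// ... fill keys.CurrentSpan(), vals.CurrentSpan() and offsets ...
dh::CubMemory tmp;
dh::SegmentedSort<float, int>(&tmp, &keys, &vals, n, n_segs, offsets);
// The sorted data is wherever the selector now points: keys.Current() / keys.CurrentSpan().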

/**
@@ -784,14 +665,14 @@ void SegmentedSort(dh::CubMemory *tmp_mem, dh::DVec2<T1> *keys,
 * @param nVals number of elements in the input array
 */
template <typename T>
void SumReduction(dh::CubMemory &tmp_mem, dh::DVec<T> &in, dh::DVec<T> &out,
void SumReduction(dh::CubMemory &tmp_mem, xgboost::common::Span<T> in, xgboost::common::Span<T> out,
                  int nVals) {
  size_t tmpSize;
  dh::safe_cuda(
      cub::DeviceReduce::Sum(NULL, tmpSize, in.Data(), out.Data(), nVals));
      cub::DeviceReduce::Sum(NULL, tmpSize, in.data(), out.data(), nVals));
  tmp_mem.LazyAllocate(tmpSize);
  dh::safe_cuda(cub::DeviceReduce::Sum(tmp_mem.d_temp_storage, tmpSize,
                                       in.Data(), out.Data(), nVals));
                                       in.data(), out.data(), nVals));
}
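The span overload keeps the shape of the old DVec overload; cub::DeviceReduce::Sum writes a single result element through out.data(). A hedged usage sketch (names are illustrative):

// Illustrative sketch only -- not part of the diff.
int n = 1000;
xgboost::common::Span<float> in, out;
dh::BulkAllocator ba;
ba.Allocate(0, &in, n, &out, 1);   // out only needs one element for the sum
// ... fill `in` on the device ...
dh::CubMemory tmp;
dh::SumReduction(tmp, in, out, n); // result lands in out[0]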

/**

@@ -19,18 +19,18 @@ namespace linear {

DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate);

void RescaleIndices(size_t ridx_begin, dh::DVec<xgboost::Entry> *data) {
  auto d_data = data->GetSpan();
  dh::LaunchN(data->DeviceIdx(), data->Size(),
              [=] __device__(size_t idx) { d_data[idx].index -= ridx_begin; });
void RescaleIndices(int device_idx, size_t ridx_begin,
                    common::Span<xgboost::Entry> data) {
  dh::LaunchN(device_idx, data.size(),
              [=] __device__(size_t idx) { data[idx].index -= ridx_begin; });
}

class DeviceShard {
  int device_id_;
  dh::BulkAllocator<dh::MemoryType::kDevice> ba_;
  dh::BulkAllocator ba_;
  std::vector<size_t> row_ptr_;
  dh::DVec<xgboost::Entry> data_;
  dh::DVec<GradientPair> gpair_;
  common::Span<xgboost::Entry> data_;
  common::Span<GradientPair> gpair_;
  dh::CubMemory temp_;
  size_t ridx_begin_;
  size_t ridx_end_;
@@ -73,12 +73,12 @@ class DeviceShard {
      auto col = batch[fidx];
      auto seg = column_segments[fidx];
      dh::safe_cuda(cudaMemcpy(
          data_.GetSpan().subspan(row_ptr_[fidx]).data(),
          data_.subspan(row_ptr_[fidx]).data(),
          col.data() + seg.first,
          sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice));
    }
    // Rescale indices with respect to current shard
    RescaleIndices(ridx_begin_, &data_);
    RescaleIndices(device_id_, ridx_begin_, data_);
  }

  bool IsEmpty() {
@@ -87,8 +87,10 @@ class DeviceShard {

  void UpdateGpair(const std::vector<GradientPair> &host_gpair,
                   const gbm::GBLinearModelParam &model_param) {
    gpair_.copy(host_gpair.begin() + ridx_begin_ * model_param.num_output_group,
                host_gpair.begin() + ridx_end_ * model_param.num_output_group);
    dh::safe_cuda(cudaMemcpyAsync(
        gpair_.data(),
        host_gpair.data() + ridx_begin_ * model_param.num_output_group,
        gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice));
  }

  GradientPair GetBiasGradient(int group_idx, int num_group) {
@@ -99,14 +101,14 @@ class DeviceShard {
    };  // NOLINT
    thrust::transform_iterator<decltype(f), decltype(counting), size_t> skip(
        counting, f);
    auto perm = thrust::make_permutation_iterator(gpair_.tbegin(), skip);
    auto perm = thrust::make_permutation_iterator(gpair_.data(), skip);

    return dh::SumReduction(temp_, perm, ridx_end_ - ridx_begin_);
  }

  void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
    if (dbias == 0.0f) return;
    auto d_gpair = gpair_.GetSpan();
    auto d_gpair = gpair_;
    dh::LaunchN(device_id_, ridx_end_ - ridx_begin_, [=] __device__(size_t idx) {
      auto &g = d_gpair[idx * num_groups + group_idx];
      g += GradientPair(g.GetHess() * dbias, 0);
@@ -115,9 +117,9 @@ class DeviceShard {

  GradientPair GetGradient(int group_idx, int num_group, int fidx) {
    dh::safe_cuda(cudaSetDevice(device_id_));
    common::Span<xgboost::Entry> d_col = data_.GetSpan().subspan(row_ptr_[fidx]);
    common::Span<xgboost::Entry> d_col = data_.subspan(row_ptr_[fidx]);
    size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
    common::Span<GradientPair> d_gpair = gpair_.GetSpan();
    common::Span<GradientPair> d_gpair = gpair_;
    auto counting = thrust::make_counting_iterator(0ull);
    auto f = [=] __device__(size_t idx) {
      auto entry = d_col[idx];
@@ -131,8 +133,8 @@ class DeviceShard {
  }

  void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) {
    common::Span<GradientPair> d_gpair = gpair_.GetSpan();
    common::Span<Entry> d_col = data_.GetSpan().subspan(row_ptr_[fidx]);
    common::Span<GradientPair> d_gpair = gpair_;
    common::Span<Entry> d_col = data_.subspan(row_ptr_[fidx]);
    size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
    dh::LaunchN(device_id_, col_size, [=] __device__(size_t idx) {
      auto entry = d_col[idx];
@@ -545,21 +545,21 @@ class GPUMaker : public TreeUpdater {
  /** whether we have initialized memory already (so as not to repeat!) */
  bool allocated_;
  /** feature values stored in column-major compressed format */
  dh::DVec2<float> vals_;
  dh::DVec<float> vals_cached_;
  dh::DoubleBuffer<float> vals_;
  common::Span<float> vals_cached_;
  /** corresponding instance id's of these featutre values */
  dh::DVec2<int> instIds_;
  dh::DVec<int> inst_ids_cached_;
  dh::DoubleBuffer<int> instIds_;
  common::Span<int> inst_ids_cached_;
  /** column offsets for these feature values */
  dh::DVec<int> colOffsets_;
  dh::DVec<GradientPair> gradsInst_;
  dh::DVec2<NodeIdT> nodeAssigns_;
  dh::DVec2<int> nodeLocations_;
  dh::DVec<DeviceNodeStats> nodes_;
  dh::DVec<NodeIdT> node_assigns_per_inst_;
  dh::DVec<GradientPair> gradsums_;
  dh::DVec<GradientPair> gradscans_;
  dh::DVec<ExactSplitCandidate> nodeSplits_;
  common::Span<int> colOffsets_;
  common::Span<GradientPair> gradsInst_;
  dh::DoubleBuffer<NodeIdT> nodeAssigns_;
  dh::DoubleBuffer<int> nodeLocations_;
  common::Span<DeviceNodeStats> nodes_;
  common::Span<NodeIdT> node_assigns_per_inst_;
  common::Span<GradientPair> gradsums_;
  common::Span<GradientPair> gradscans_;
  common::Span<ExactSplitCandidate> nodeSplits_;
  int n_vals_;
  int n_rows_;
  int n_cols_;
@@ -571,10 +571,10 @@ class GPUMaker : public TreeUpdater {
  GPUSet devices_;

  dh::CubMemory tmp_mem_;
  dh::DVec<GradientPair> tmpScanGradBuff_;
  dh::DVec<int> tmp_scan_key_buff_;
  dh::DVec<int> colIds_;
  dh::BulkAllocator<dh::MemoryType::kDevice> ba_;
  common::Span<GradientPair> tmpScanGradBuff_;
  common::Span<int> tmp_scan_key_buff_;
  common::Span<int> colIds_;
  dh::BulkAllocator ba_;

 public:
  GPUMaker() : allocated_{false} {}
@@ -615,8 +615,8 @@ class GPUMaker : public TreeUpdater {
    for (int i = 0; i < param_.max_depth; ++i) {
      if (i == 0) {
        // make sure to start on a fresh tree with sorted values!
        vals_.CurrentDVec() = vals_cached_;
        instIds_.CurrentDVec() = inst_ids_cached_;
        dh::CopyDeviceSpan(vals_.CurrentSpan(), vals_cached_);
        dh::CopyDeviceSpan(instIds_.CurrentSpan(), inst_ids_cached_);
        TransferGrads(gpair);
      }
      int nNodes = 1 << i;
@@ -630,13 +630,13 @@ class GPUMaker : public TreeUpdater {
  }

  void Split2Node(int nNodes, NodeIdT nodeStart) {
    auto d_nodes = nodes_.GetSpan();
    auto d_gradScans = gradscans_.GetSpan();
    auto d_gradsums = gradsums_.GetSpan();
    auto d_nodes = nodes_;
    auto d_gradScans = gradscans_;
    auto d_gradsums = gradsums_;
    auto d_nodeAssigns = nodeAssigns_.CurrentSpan();
    auto d_colIds = colIds_.GetSpan();
    auto d_colIds = colIds_;
    auto d_vals = vals_.Current();
    auto d_nodeSplits = nodeSplits_.Data();
    auto d_nodeSplits = nodeSplits_.data();
    int nUniqKeys = nNodes;
    float min_split_loss = param_.min_split_loss;
    auto gpu_param = GPUTrainingParam(param_);
@@ -679,13 +679,13 @@ class GPUMaker : public TreeUpdater {
  }

  void FindSplit(int level, NodeIdT nodeStart, int nNodes) {
    ReduceScanByKey(gradsums_.GetSpan(), gradscans_.GetSpan(), gradsInst_.GetSpan(),
    ReduceScanByKey(gradsums_, gradscans_, gradsInst_,
                    instIds_.CurrentSpan(), nodeAssigns_.CurrentSpan(), n_vals_, nNodes,
                    n_cols_, tmpScanGradBuff_.GetSpan(), tmp_scan_key_buff_.GetSpan(),
                    colIds_.GetSpan(), nodeStart);
    ArgMaxByKey(nodeSplits_.GetSpan(), gradscans_.GetSpan(), gradsums_.GetSpan(),
                vals_.CurrentSpan(), colIds_.GetSpan(), nodeAssigns_.CurrentSpan(),
                nodes_.GetSpan(), nNodes, nodeStart, n_vals_, param_,
                    n_cols_, tmpScanGradBuff_, tmp_scan_key_buff_,
                    colIds_, nodeStart);
    ArgMaxByKey(nodeSplits_, gradscans_, gradsums_,
                vals_.CurrentSpan(), colIds_, nodeAssigns_.CurrentSpan(),
                nodes_, nNodes, nodeStart, n_vals_, param_,
                level <= kMaxAbkLevels ? kAbkSmem : kAbkGmem);
    Split2Node(nNodes, nodeStart);
  }
@@ -707,7 +707,7 @@ class GPUMaker : public TreeUpdater {
    }
    std::vector<float> fval;
    std::vector<int> fId;
    std::vector<size_t> offset;
    std::vector<int> offset;
    ConvertToCsc(dmat, &fval, &fId, &offset);
    AllocateAllData(static_cast<int>(offset.size()));
    TransferAndSortData(fval, fId, offset);
@@ -715,7 +715,7 @@ class GPUMaker : public TreeUpdater {
  }

  void ConvertToCsc(DMatrix* dmat, std::vector<float>* fval,
                    std::vector<int>* fId, std::vector<size_t>* offset) {
                    std::vector<int>* fId, std::vector<int>* offset) {
    const MetaInfo& info = dmat->Info();
    CHECK(info.num_col_ < std::numeric_limits<int>::max());
    CHECK(info.num_row_ < std::numeric_limits<int>::max());
@@ -735,7 +735,7 @@ class GPUMaker : public TreeUpdater {
        fval->push_back(e.fvalue);
        fId->push_back(inst_id);
      }
      offset->push_back(fval->size());
      offset->push_back(static_cast<int>(fval->size()));
    }
  }
  CHECK(fval->size() < std::numeric_limits<int>::max());
@@ -744,19 +744,21 @@ class GPUMaker : public TreeUpdater {

  void TransferAndSortData(const std::vector<float>& fval,
                           const std::vector<int>& fId,
                           const std::vector<size_t>& offset) {
    vals_.CurrentDVec() = fval;
    instIds_.CurrentDVec() = fId;
    colOffsets_ = offset;
                           const std::vector<int>& offset) {
    dh::CopyVectorToDeviceSpan(vals_.CurrentSpan(), fval);
    dh::CopyVectorToDeviceSpan(instIds_.CurrentSpan(), fId);
    dh::CopyVectorToDeviceSpan(colOffsets_, offset);
    dh::SegmentedSort<float, int>(&tmp_mem_, &vals_, &instIds_, n_vals_, n_cols_,
                                  colOffsets_);
    vals_cached_ = vals_.CurrentDVec();
    inst_ids_cached_ = instIds_.CurrentDVec();
    AssignColIds<<<n_cols_, 512>>>(colIds_.Data(), colOffsets_.Data());
    dh::CopyDeviceSpan(vals_cached_, vals_.CurrentSpan());
    dh::CopyDeviceSpan(inst_ids_cached_, instIds_.CurrentSpan());
    AssignColIds<<<n_cols_, 512>>>(colIds_.data(), colOffsets_.data());
  }

  void TransferGrads(HostDeviceVector<GradientPair>* gpair) {
    gpair->GatherTo(gradsInst_.tbegin(), gradsInst_.tend());
    gpair->GatherTo(
        thrust::device_pointer_cast(gradsInst_.data()),
        thrust::device_pointer_cast(gradsInst_.data() + gradsInst_.size()));
    // evaluate the full-grad reduction for the root node
    dh::SumReduction<GradientPair>(tmp_mem_, gradsInst_, gradsums_, n_rows_);
  }
@@ -764,14 +766,22 @@ class GPUMaker : public TreeUpdater {
  void InitNodeData(int level, NodeIdT nodeStart, int nNodes) {
    // all instances belong to root node at the beginning!
    if (level == 0) {
      nodes_.Fill(DeviceNodeStats());
      nodeAssigns_.CurrentDVec().Fill(0);
      node_assigns_per_inst_.Fill(0);
      thrust::fill(thrust::device_pointer_cast(nodes_.data()),
                   thrust::device_pointer_cast(nodes_.data() + nodes_.size()),
                   DeviceNodeStats());
      thrust::fill(thrust::device_pointer_cast(nodeAssigns_.Current()),
                   thrust::device_pointer_cast(nodeAssigns_.Current() +
                                               nodeAssigns_.Size()),
                   0);
      thrust::fill(thrust::device_pointer_cast(node_assigns_per_inst_.data()),
                   thrust::device_pointer_cast(node_assigns_per_inst_.data() +
                                               node_assigns_per_inst_.size()),
                   0);
      // for root node, just update the gradient/score/weight/id info
      // before splitting it! Currently all data is on GPU, hence this
      // stupid little kernel
      auto d_nodes = nodes_.Data();
      auto d_sums = gradsums_.Data();
      auto d_nodes = nodes_;
      auto d_sums = gradsums_;
      auto gpu_params = GPUTrainingParam(param_);
      dh::LaunchN(param_.gpu_id, 1, [=] __device__(int idx) {
        d_nodes[0] = DeviceNodeStats(d_sums[0], 0, gpu_params);
@@ -781,17 +791,17 @@ class GPUMaker : public TreeUpdater {
      const int ItemsPerThread = 4;
      // assign default node ids first
      int nBlks = dh::DivRoundUp(n_rows_, BlkDim);
      FillDefaultNodeIds<<<nBlks, BlkDim>>>(node_assigns_per_inst_.Data(),
                                            nodes_.Data(), n_rows_);
      FillDefaultNodeIds<<<nBlks, BlkDim>>>(node_assigns_per_inst_.data(),
                                            nodes_.data(), n_rows_);
      // evaluate the correct child indices of non-missing values next
      nBlks = dh::DivRoundUp(n_vals_, BlkDim * ItemsPerThread);
      AssignNodeIds<<<nBlks, BlkDim>>>(
          node_assigns_per_inst_.Data(), nodeLocations_.Current(),
          nodeAssigns_.Current(), instIds_.Current(), nodes_.Data(),
          colOffsets_.Data(), vals_.Current(), n_vals_, n_cols_);
          node_assigns_per_inst_.data(), nodeLocations_.Current(),
          nodeAssigns_.Current(), instIds_.Current(), nodes_.data(),
          colOffsets_.data(), vals_.Current(), n_vals_, n_cols_);
      // gather the node assignments across all other columns too
      dh::Gather(param_.gpu_id, nodeAssigns_.Current(),
                 node_assigns_per_inst_.Data(), instIds_.Current(), n_vals_);
                 node_assigns_per_inst_.data(), instIds_.Current(), n_vals_);
      SortKeys(level);
    }
  }
@@ -804,14 +814,14 @@ class GPUMaker : public TreeUpdater {
    dh::Gather<float, int>(param_.gpu_id, vals_.other(),
                           vals_.Current(), instIds_.other(), instIds_.Current(),
                           nodeLocations_.Current(), n_vals_);
    vals_.buff().selector ^= 1;
    instIds_.buff().selector ^= 1;
    vals_.buff.selector ^= 1;
    instIds_.buff.selector ^= 1;
  }

  void MarkLeaves() {
    const int BlkDim = 128;
    int nBlks = dh::DivRoundUp(maxNodes_, BlkDim);
    MarkLeavesKernel<<<nBlks, BlkDim>>>(nodes_.Data(), maxNodes_);
    MarkLeavesKernel<<<nBlks, BlkDim>>>(nodes_.data(), maxNodes_);
  }
};

@@ -254,10 +254,13 @@ XGBOOST_DEVICE inline bool IsLeftChild(int nidx) {

// Copy gpu dense representation of tree to xgboost sparse representation
inline void Dense2SparseTree(RegTree* p_tree,
                             const dh::DVec<DeviceNodeStats>& nodes,
                             common::Span<DeviceNodeStats> nodes,
                             const TrainParam& param) {
  RegTree& tree = *p_tree;
  std::vector<DeviceNodeStats> h_nodes = nodes.AsVector();
  std::vector<DeviceNodeStats> h_nodes(nodes.size());
  dh::safe_cuda(cudaMemcpy(h_nodes.data(), nodes.data(),
                           nodes.size() * sizeof(DeviceNodeStats),
                           cudaMemcpyDeviceToHost));

  int nid = 0;
  for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
@@ -298,18 +301,16 @@ struct BernoulliRng {
};

// Set gradient pair to 0 with p = 1 - subsample
inline void SubsampleGradientPair(dh::DVec<GradientPair>* p_gpair, float subsample,
                                  int offset = 0) {
inline void SubsampleGradientPair(int device_idx,
                                  common::Span<GradientPair> d_gpair,
                                  float subsample, int offset = 0) {
  if (subsample == 1.0) {
    return;
  }

  dh::DVec<GradientPair>& gpair = *p_gpair;

  auto d_gpair = gpair.Data();
  BernoulliRng rng(subsample, common::GlobalRandom()());

  dh::LaunchN(gpair.DeviceIdx(), gpair.Size(), [=] XGBOOST_DEVICE(int i) {
  dh::LaunchN(device_idx, d_gpair.size(), [=] XGBOOST_DEVICE(int i) {
    if (!rng(i + offset)) {
      d_gpair[i] = GradientPair();
    }

@@ -601,7 +601,7 @@ struct DeviceShard {
  int n_bins;
  int device_id;

  dh::BulkAllocator<dh::MemoryType::kDevice> ba;
  dh::BulkAllocator ba;

  ELLPackMatrix ellpack_matrix;

@@ -610,27 +610,26 @@ struct DeviceShard {
  DeviceHistogram<GradientSumT> hist;

  /*! \brief row_ptr form HistCutMatrix. */
  dh::DVec<uint32_t> feature_segments;
  common::Span<uint32_t> feature_segments;
  /*! \brief minimum value for each feature. */
  dh::DVec<bst_float> min_fvalue;
  common::Span<bst_float> min_fvalue;
  /*! \brief Cut. */
  dh::DVec<bst_float> gidx_fvalue_map;
  common::Span<bst_float> gidx_fvalue_map;
  /*! \brief global index of histogram, which is stored in ELLPack format. */
  dh::DVec<common::CompressedByteT> gidx_buffer;
  common::Span<common::CompressedByteT> gidx_buffer;

  /*! \brief Row indices relative to this shard, necessary for sorting rows. */
  dh::DVec2<bst_uint> ridx;
  dh::DoubleBuffer<bst_uint> ridx;
  dh::DoubleBuffer<int> position;
  /*! \brief Gradient pair for each row. */
  dh::DVec<GradientPair> gpair;
  common::Span<GradientPair> gpair;

  dh::DVec2<int> position;

  dh::DVec<int> monotone_constraints;
  dh::DVec<bst_float> prediction_cache;
  common::Span<int> monotone_constraints;
  common::Span<bst_float> prediction_cache;

  /*! \brief Sum gradient for each node. */
  std::vector<GradientPair> node_sum_gradients;
  dh::DVec<GradientPair> node_sum_gradients_d;
  common::Span<GradientPair> node_sum_gradients_d;
  /*! \brief row offset in SparsePage (the input data). */
  thrust::device_vector<size_t> row_ptrs;
  /*! \brief On-device feature set, only actually used on one of the devices */
@@ -718,7 +717,9 @@ struct DeviceShard {
  // Reset values for each update iteration
  void Reset(HostDeviceVector<GradientPair>* dh_gpair) {
    dh::safe_cuda(cudaSetDevice(device_id));
    position.CurrentDVec().Fill(0);
    thrust::fill(
        thrust::device_pointer_cast(position.Current()),
        thrust::device_pointer_cast(position.Current() + position.Size()), 0);
    std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
              GradientPair());
    if (left_counts.size() < 256) {
@@ -727,13 +728,16 @@ struct DeviceShard {
      dh::safe_cuda(cudaMemsetAsync(left_counts.data().get(), 0,
                                    sizeof(int64_t) * left_counts.size()));
    }
    thrust::sequence(ridx.CurrentDVec().tbegin(), ridx.CurrentDVec().tend());
    thrust::sequence(
        thrust::device_pointer_cast(ridx.CurrentSpan().data()),
        thrust::device_pointer_cast(ridx.CurrentSpan().data() + ridx.Size()));

    std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
    ridx_segments.front() = Segment(0, ridx.Size());
    this->gpair.copy(dh_gpair->tcbegin(device_id),
                     dh_gpair->tcend(device_id));
    SubsampleGradientPair(&gpair, param.subsample, row_begin_idx);
    dh::safe_cuda(cudaMemcpyAsync(
        gpair.data(), dh_gpair->ConstDevicePointer(device_id),
        gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost));
    SubsampleGradientPair(device_id, gpair, param.subsample, row_begin_idx);
    hist.Reset();
  }

@@ -788,7 +792,7 @@ struct DeviceShard {
        <<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
            hist.GetNodeHistogram(nidx), d_feature_set, node, ellpack_matrix,
            gpu_param, d_split_candidates, value_constraints[nidx],
            monotone_constraints.GetSpan());
            monotone_constraints);

    // Reduce over features to find best feature
    auto d_result = d_result_all.subspan(i, 1);
@@ -943,8 +947,8 @@ struct DeviceShard {
  void UpdatePredictionCache(bst_float* out_preds_d) {
    dh::safe_cuda(cudaSetDevice(device_id));
    if (!prediction_cache_initialised) {
      dh::safe_cuda(cudaMemcpyAsync(prediction_cache.Data(), out_preds_d,
                                    prediction_cache.Size() * sizeof(bst_float),
      dh::safe_cuda(cudaMemcpyAsync(prediction_cache.data(), out_preds_d,
                                    prediction_cache.size() * sizeof(bst_float),
                                    cudaMemcpyDefault));
    }
    prediction_cache_initialised = true;
@@ -952,16 +956,16 @@ struct DeviceShard {
    CalcWeightTrainParam param_d(param);

    dh::safe_cuda(
        cudaMemcpyAsync(node_sum_gradients_d.Data(), node_sum_gradients.data(),
        cudaMemcpyAsync(node_sum_gradients_d.data(), node_sum_gradients.data(),
                        sizeof(GradientPair) * node_sum_gradients.size(),
                        cudaMemcpyHostToDevice));
    auto d_position = position.Current();
    auto d_ridx = ridx.Current();
    auto d_node_sum_gradients = node_sum_gradients_d.Data();
    auto d_prediction_cache = prediction_cache.Data();
    auto d_node_sum_gradients = node_sum_gradients_d.data();
    auto d_prediction_cache = prediction_cache.data();

    dh::LaunchN(
        device_id, prediction_cache.Size(), [=] __device__(int local_idx) {
        device_id, prediction_cache.size(), [=] __device__(int local_idx) {
          int pos = d_position[local_idx];
          bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]);
          d_prediction_cache[d_ridx[local_idx]] +=
@@ -969,8 +973,8 @@ struct DeviceShard {
        });

    dh::safe_cuda(cudaMemcpy(
        out_preds_d, prediction_cache.Data(),
        prediction_cache.Size() * sizeof(bst_float), cudaMemcpyDefault));
        out_preds_d, prediction_cache.data(),
        prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
  }
};

@@ -981,7 +985,7 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
  auto segment_begin = segment.begin;
  auto d_node_hist = shard->hist.GetNodeHistogram(nidx);
  auto d_ridx = shard->ridx.Current();
  auto d_gpair = shard->gpair.Data();
  auto d_gpair = shard->gpair.data();

  auto n_elements = segment.Size() * shard->ellpack_matrix.row_stride;

@@ -1006,7 +1010,7 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
  Segment segment = shard->ridx_segments[nidx];
  auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
  bst_uint* d_ridx = shard->ridx.Current();
  GradientPair* d_gpair = shard->gpair.Data();
  GradientPair* d_gpair = shard->gpair.data();

  size_t const n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
  auto d_matrix = shard->ellpack_matrix;
@@ -1043,10 +1047,11 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
      &gidx_fvalue_map, hmat.cut.size(),
      &min_fvalue, hmat.min_val.size(),
      &monotone_constraints, param.monotone_constraints.size());
  gidx_fvalue_map = hmat.cut;
  min_fvalue = hmat.min_val;
  feature_segments = hmat.row_ptr;
  monotone_constraints = param.monotone_constraints;

  dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.cut);
  dh::CopyVectorToDeviceSpan(min_fvalue, hmat.min_val);
  dh::CopyVectorToDeviceSpan(feature_segments, hmat.row_ptr);
  dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);

  node_sum_gradients.resize(max_nodes);
  ridx_segments.resize(max_nodes);
@@ -1063,14 +1068,16 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
      << "Max leaves and max depth cannot both be unconstrained for "
         "gpu_hist.";
  ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes);
  gidx_buffer.Fill(0);
  thrust::fill(
      thrust::device_pointer_cast(gidx_buffer.data()),
      thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);

  this->CreateHistIndices(row_batch, row_stride, null_gidx_value);

  ellpack_matrix.Init(
      feature_segments.GetSpan(), min_fvalue.GetSpan(),
      gidx_fvalue_map.GetSpan(), row_stride,
      common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols),
      feature_segments, min_fvalue,
      gidx_fvalue_map, row_stride,
      common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
      is_dense, null_gidx_value);

  // check if we can use shared memory for building histograms
@@ -1121,10 +1128,10 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
                  dh::DivRoundUp(row_stride, block3.y), 1);
    CompressBinEllpackKernel<<<grid3, block3>>>
        (common::CompressedBufferWriter(num_symbols),
         gidx_buffer.Data(),
         gidx_buffer.data(),
         row_ptrs.data().get() + batch_row_begin,
         entries_d.data().get(),
         gidx_fvalue_map.Data(), feature_segments.Data(),
         gidx_fvalue_map.data(), feature_segments.data(),
         batch_row_begin, batch_nrows,
         row_ptrs[batch_row_begin],
         row_stride, null_gidx_value);
@@ -1355,7 +1362,7 @@ class GPUHistMakerSpecialised{
        [&](int i, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
          dh::safe_cuda(cudaSetDevice(shard->device_id));
          tmp_sums[i] = dh::SumReduction(
              shard->temp_memory, shard->gpair.Data(), shard->gpair.Size());
              shard->temp_memory, shard->gpair.data(), shard->gpair.size());
        });

    GradientPair sum_gradient =

@@ -7,6 +7,8 @@
#include "../../../src/common/device_helpers.cuh"
#include "gtest/gtest.h"

using xgboost::common::Span;

struct Shard { int id; };

TEST(DeviceHelpers, Basic) {
@@ -71,3 +73,20 @@ TEST(sumReduce, Test) {
  auto sum = dh::SumReduction(temp, dh::Raw(data), data.size());
  ASSERT_NEAR(sum, 100.0f, 1e-5);
}

void TestAllocator() {
  int n = 10;
  Span<float> a;
  Span<int> b;
  Span<size_t> c;
  dh::BulkAllocator ba;
  ba.Allocate(0, &a, n, &b, n, &c, n);

  // Should be no illegal memory accesses
  dh::LaunchN(0, n, [=] __device__(size_t idx) { c[idx] = a[idx] + b[idx]; });

  dh::safe_cuda(cudaDeviceSynchronize());
}

// Define the test in a function so we can use device lambda
TEST(bulkAllocator, Test) { TestAllocator(); }

@@ -56,8 +56,8 @@ TEST(GpuHist, BuildGidxDense) {
  DeviceShard<GradientPairPrecise> shard(0, 0, kNRows, param);
  BuildGidx(&shard, kNRows, kNCols);

  std::vector<common::CompressedByteT> h_gidx_buffer;
  h_gidx_buffer = shard.gidx_buffer.AsVector();
  std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
  dh::CopyDeviceSpanToVector(&h_gidx_buffer, shard.gidx_buffer);
  common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);

  ASSERT_EQ(shard.ellpack_matrix.row_stride, kNCols);
@@ -95,8 +95,8 @@ TEST(GpuHist, BuildGidxSparse) {
  DeviceShard<GradientPairPrecise> shard(0, 0, kNRows, param);
  BuildGidx(&shard, kNRows, kNCols, 0.9f);

  std::vector<common::CompressedByteT> h_gidx_buffer;
  h_gidx_buffer = shard.gidx_buffer.AsVector();
  std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
  dh::CopyDeviceSpanToVector(&h_gidx_buffer, shard.gidx_buffer);
  common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);

  ASSERT_LE(shard.ellpack_matrix.row_stride, 3);
@@ -149,17 +149,14 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
    gpair = GradientPair(grad, hess);
  }

  thrust::device_vector<GradientPair> gpair (kNRows);
  gpair = h_gpair;

  int num_symbols = shard.n_bins + 1;

  thrust::host_vector<common::CompressedByteT> h_gidx_buffer (
      shard.gidx_buffer.Size());
      shard.gidx_buffer.size());

  common::CompressedByteT* d_gidx_buffer_ptr = shard.gidx_buffer.Data();
  common::CompressedByteT* d_gidx_buffer_ptr = shard.gidx_buffer.data();
  dh::safe_cuda(cudaMemcpy(h_gidx_buffer.data(), d_gidx_buffer_ptr,
                           sizeof(common::CompressedByteT) * shard.gidx_buffer.Size(),
                           sizeof(common::CompressedByteT) * shard.gidx_buffer.size(),
                           cudaMemcpyDeviceToHost));
  auto gidx = common::CompressedIterator<uint32_t>(h_gidx_buffer.data(),
                                                   num_symbols);
@@ -167,9 +164,10 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
  shard.ridx_segments.resize(1);
  shard.ridx_segments[0] = Segment(0, kNRows);
  shard.hist.AllocateHistogram(0);
  shard.gpair.copy(gpair.begin(), gpair.end());
  thrust::sequence(shard.ridx.CurrentDVec().tbegin(),
                   shard.ridx.CurrentDVec().tend());
  dh::CopyVectorToDeviceSpan(shard.gpair, h_gpair);
  thrust::sequence(
      thrust::device_pointer_cast(shard.ridx.Current()),
      thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size()));

  builder.Build(&shard, 0);
  DeviceHistogram<GradientSumT> d_hist = shard.hist;
@@ -262,14 +260,14 @@ TEST(GpuHist, EvaluateSplits) {
                &(shard->min_fvalue), cmat.min_val.size(),
                &(shard->gidx_fvalue_map), 24,
                &(shard->monotone_constraints), kNCols);
  shard->feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end());
  shard->gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end());
  shard->monotone_constraints.copy(param.monotone_constraints.begin(),
                                   param.monotone_constraints.end());
  shard->ellpack_matrix.feature_segments = shard->feature_segments.GetSpan();
  shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map.GetSpan();
  shard->min_fvalue.copy(cmat.min_val.begin(), cmat.min_val.end());
  shard->ellpack_matrix.min_fvalue = shard->min_fvalue.GetSpan();
  dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr);
  dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut);
  dh::CopyVectorToDeviceSpan(shard->monotone_constraints,
                             param.monotone_constraints);
  shard->ellpack_matrix.feature_segments = shard->feature_segments;
  shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map;
  dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val);
  shard->ellpack_matrix.min_fvalue = shard->min_fvalue;

  // Initialize DeviceShard::hist
  shard->hist.Init(0, (max_bins - 1) * kNCols);
@@ -344,8 +342,9 @@ TEST(GpuHist, ApplySplit) {
  shard->ba.Allocate(0, &(shard->ridx), kNRows,
                     &(shard->position), kNRows);
  shard->ellpack_matrix.row_stride = kNCols;
  thrust::sequence(shard->ridx.CurrentDVec().tbegin(),
                   shard->ridx.CurrentDVec().tend());
  thrust::sequence(
      thrust::device_pointer_cast(shard->ridx.Current()),
      thrust::device_pointer_cast(shard->ridx.Current() + shard->ridx.Size()));
  // Initialize GPUHistMaker
  hist_maker.param_ = param;
  RegTree tree;
@@ -378,12 +377,12 @@ TEST(GpuHist, ApplySplit) {
      &(shard->feature_segments), cmat.row_ptr.size(),
      &(shard->min_fvalue), cmat.min_val.size(),
      &(shard->gidx_fvalue_map), 24);
  shard->feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end());
  shard->gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end());
  shard->ellpack_matrix.feature_segments = shard->feature_segments.GetSpan();
  shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map.GetSpan();
  shard->min_fvalue.copy(cmat.min_val.begin(), cmat.min_val.end());
  shard->ellpack_matrix.min_fvalue = shard->min_fvalue.GetSpan();
  dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr);
  dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut);
  shard->ellpack_matrix.feature_segments = shard->feature_segments;
  shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map;
  dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val);
  shard->ellpack_matrix.min_fvalue = shard->min_fvalue;
  shard->ellpack_matrix.is_dense = true;

  common::CompressedBufferWriter wr(num_symbols);
@@ -394,10 +393,10 @@ TEST(GpuHist, ApplySplit) {
  std::vector<common::CompressedByteT> h_gidx_compressed (compressed_size_bytes);

  wr.Write(h_gidx_compressed.data(), h_gidx.begin(), h_gidx.end());
  shard->gidx_buffer.copy(h_gidx_compressed.begin(), h_gidx_compressed.end());
  dh::CopyVectorToDeviceSpan(shard->gidx_buffer, h_gidx_compressed);

  shard->ellpack_matrix.gidx_iter = common::CompressedIterator<uint32_t>(
      shard->gidx_buffer.Data(), num_symbols);
      shard->gidx_buffer.data(), num_symbols);

  hist_maker.info_ = &info;
  hist_maker.ApplySplit(candidate_entry, &tree);