Purge device_helpers.cuh (#5534)

* Simplifications with caching_device_vector

* Purge device helpers
Rory Mitchell 2020-04-15 21:51:56 +12:00 committed by GitHub
parent a2f54963b6
commit ca4e05660e
9 changed files with 182 additions and 733 deletions
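
For context, a minimal before/after sketch of the pattern this commit applies throughout (identifiers and sizes are illustrative, mirroring the diffs below): buffers that were previously carved out of a dh::BulkAllocator as raw common::Span views are now owned directly by dh::caching_device_vector, and dh::ToSpan provides a Span view where kernels still need one.

// Before: one BulkAllocator carves several spans out of a single allocation.
//   dh::BulkAllocator ba_;
//   common::Span<float> threshold_;
//   common::Span<float> grad_sum_;
//   ba_.Allocate(device_id, &threshold_, n_rows + 1, &grad_sum_, n_rows);
// After: each vector owns its memory through a caching allocator.
dh::caching_device_vector<float> threshold_(n_rows + 1, 0.0f);
dh::caching_device_vector<float> grad_sum_(n_rows, 0.0f);
auto d_threshold = dh::ToSpan(threshold_);  // Span view for device lambdas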


@ -85,19 +85,6 @@ inline int32_t CudaGetPointerDevice(void* ptr) {
return device;
}
inline void CudaCheckPointerDevice(void* ptr) {
auto ptr_device = CudaGetPointerDevice(ptr);
int cur_device = -1;
dh::safe_cuda(cudaGetDevice(&cur_device));
CHECK_EQ(ptr_device, cur_device) << "pointer device: " << ptr_device
<< "current device: " << cur_device;
}
template <typename T>
const T *Raw(const thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
inline size_t AvailableMemory(int device_idx) {
size_t device_free = 0;
size_t device_total = 0;
@ -552,161 +539,6 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
cudaMemcpyDeviceToHost));
}
/**
* \brief Copies std::vector to device span.
*
* \tparam T Generic type parameter.
* \param dst Copy destination. Must be device memory.
* \param src Copy source.
*/
template <typename T>
void CopyVectorToDeviceSpan(xgboost::common::Span<T> dst ,const std::vector<T>&src)
{
CHECK_EQ(dst.size(), src.size());
dh::safe_cuda(cudaMemcpyAsync(dst.data(), src.data(), dst.size() * sizeof(T),
cudaMemcpyHostToDevice));
}
/**
* \brief Device to device memory copy from src to dst. Spans must be the same size. Use subspan to
* copy from a smaller array to a larger array.
*
* \tparam T Generic type parameter.
* \param dst Copy destination. Must be device memory.
* \param src Copy source. Must be device memory.
*/
template <typename T>
void CopyDeviceSpan(xgboost::common::Span<T> dst,
xgboost::common::Span<T> src) {
CHECK_EQ(dst.size(), src.size());
dh::safe_cuda(cudaMemcpyAsync(dst.data(), src.data(), dst.size() * sizeof(T),
cudaMemcpyDeviceToDevice));
}
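For reference, a short usage sketch of the two copy helpers removed above, assuming the pre-commit device_helpers.cuh (buffer sizes are arbitrary):

std::vector<float> h_src(100, 1.0f);
dh::device_vector<float> buf(h_src.size()), other(h_src.size());
auto d_dst = dh::ToSpan(buf);
auto d_other = dh::ToSpan(other);
dh::CopyVectorToDeviceSpan(d_dst, h_src);  // host std::vector -> device span
dh::CopyDeviceSpan(d_other, d_dst);        // device -> device, sizes must match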
/*! \brief Helper for allocating large block of memory. */
class BulkAllocator {
std::vector<char *> d_ptr_;
std::vector<size_t> size_;
int device_idx_{-1};
static const int kAlign = 256;
size_t AlignRoundUp(size_t n) const {
n = (n + kAlign - 1) / kAlign;
return n * kAlign;
}
template <typename T>
size_t GetSizeBytes(xgboost::common::Span<T> *first_vec, size_t first_size) {
return AlignRoundUp(first_size * sizeof(T));
}
template <typename T, typename... Args>
size_t GetSizeBytes(xgboost::common::Span<T> *first_vec, size_t first_size, Args... args) {
return GetSizeBytes<T>(first_vec, first_size) + GetSizeBytes(args...);
}
template <typename T>
void AllocateSpan(int device_idx, char *ptr, xgboost::common::Span<T> *first_vec,
size_t first_size) {
*first_vec = xgboost::common::Span<T>(reinterpret_cast<T *>(ptr), first_size);
}
template <typename T, typename... Args>
void AllocateSpan(int device_idx, char *ptr, xgboost::common::Span<T> *first_vec,
size_t first_size, Args... args) {
AllocateSpan<T>(device_idx, ptr, first_vec, first_size);
ptr += AlignRoundUp(first_size * sizeof(T));
AllocateSpan(device_idx, ptr, args...);
}
char *AllocateDevice(int device_idx, size_t bytes) {
safe_cuda(cudaSetDevice(device_idx));
XGBDeviceAllocator<char> allocator;
return allocator.allocate(bytes).get();
}
template <typename T>
size_t GetSizeBytes(DoubleBuffer<T> *first_vec, size_t first_size) {
return 2 * AlignRoundUp(first_size * sizeof(T));
}
template <typename T, typename... Args>
size_t GetSizeBytes(DoubleBuffer<T> *first_vec, size_t first_size, Args... args) {
return GetSizeBytes<T>(first_vec, first_size) + GetSizeBytes(args...);
}
template <typename T>
void AllocateSpan(int device_idx, char *ptr, DoubleBuffer<T> *first_vec,
size_t first_size) {
auto ptr1 = reinterpret_cast<T *>(ptr);
auto ptr2 = ptr1 + first_size;
first_vec->a = xgboost::common::Span<T>(ptr1, first_size);
first_vec->b = xgboost::common::Span<T>(ptr2, first_size);
first_vec->buff.d_buffers[0] = ptr1;
first_vec->buff.d_buffers[1] = ptr2;
first_vec->buff.selector = 0;
}
template <typename T, typename... Args>
void AllocateSpan(int device_idx, char *ptr, DoubleBuffer<T> *first_vec,
size_t first_size, Args... args) {
AllocateSpan<T>(device_idx, ptr, first_vec, first_size);
ptr += (AlignRoundUp(first_size * sizeof(T)) * 2);
AllocateSpan(device_idx, ptr, args...);
}
public:
BulkAllocator() = default;
// prevent accidental copying, moving or assignment of this object
BulkAllocator(const BulkAllocator&) = delete;
BulkAllocator(BulkAllocator&&) = delete;
void operator=(const BulkAllocator&) = delete;
void operator=(BulkAllocator&&) = delete;
/*!
* \brief Clear the bulk allocator.
*
* This frees the GPU memory managed by this allocator.
*/
void Clear() {
if (d_ptr_.empty()) return;
safe_cuda(cudaSetDevice(device_idx_));
size_t idx = 0;
std::for_each(d_ptr_.begin(), d_ptr_.end(), [&](char *dptr) {
XGBDeviceAllocator<char>().deallocate(thrust::device_ptr<char>(dptr), size_[idx++]);
});
d_ptr_.clear();
size_.clear();
}
~BulkAllocator() {
Clear();
}
// returns sum of bytes for all allocations
size_t Size() {
return std::accumulate(size_.begin(), size_.end(), static_cast<size_t>(0));
}
template <typename... Args>
void Allocate(int device_idx, Args... args) {
if (device_idx_ == -1) {
device_idx_ = device_idx;
}
else CHECK(device_idx_ == device_idx);
size_t size = GetSizeBytes(args...);
char *ptr = AllocateDevice(device_idx, size);
AllocateSpan(device_idx, ptr, args...);
d_ptr_.push_back(ptr);
size_.push_back(size);
}
};
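The allocator above was typically used to fuse several device buffers into one allocation; a sketch mirroring the TestAllocator unit test that this commit also deletes:

int n = 10;
xgboost::common::Span<float> a;
xgboost::common::Span<int> b;
xgboost::common::Span<size_t> c;
dh::BulkAllocator ba;
ba.Allocate(0, &a, n, &b, n, &c, n);  // one device allocation, three aligned spans
dh::LaunchN(0, n, [=] __device__(size_t idx) { c[idx] = a[idx] + b[idx]; });
// All three spans are released together by ba.Clear() or ba's destructor.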
// Keep track of pinned memory allocation
struct PinnedMemory {
void *temp_storage{nullptr};
@ -787,196 +619,6 @@ struct CubMemory {
* Utility functions
*/
// Load balancing search
template <typename CoordinateT, typename SegmentT, typename OffsetT>
void FindMergePartitions(int device_idx, CoordinateT *d_tile_coordinates,
size_t num_tiles, int tile_size, SegmentT segments,
OffsetT num_rows, OffsetT num_elements) {
dh::LaunchN(device_idx, num_tiles + 1, [=] __device__(int idx) {
OffsetT diagonal = idx * tile_size;
CoordinateT tile_coordinate;
cub::CountingInputIterator<OffsetT> nonzero_indices(0);
// Search the merge path
// Cast to signed integer as this function can have negatives
cub::MergePathSearch(static_cast<int64_t>(diagonal), segments + 1,
nonzero_indices, static_cast<int64_t>(num_rows),
static_cast<int64_t>(num_elements), tile_coordinate);
// Output starting offset
d_tile_coordinates[idx] = tile_coordinate;
});
}
template <int TILE_SIZE, int ITEMS_PER_THREAD, int BLOCK_THREADS,
typename OffsetT, typename CoordinateT, typename FunctionT,
typename SegmentIterT>
__global__ void LbsKernel(CoordinateT *d_coordinates,
SegmentIterT segment_end_offsets, FunctionT f,
OffsetT num_segments) {
int tile = blockIdx.x;
CoordinateT tile_start_coord = d_coordinates[tile];
CoordinateT tile_end_coord = d_coordinates[tile + 1];
int64_t tile_num_rows = tile_end_coord.x - tile_start_coord.x;
int64_t tile_num_elements = tile_end_coord.y - tile_start_coord.y;
cub::CountingInputIterator<OffsetT> tile_element_indices(tile_start_coord.y);
CoordinateT thread_start_coord;
using SegmentT = typename std::iterator_traits<SegmentIterT>::value_type;
__shared__ struct {
SegmentT tile_segment_end_offsets[TILE_SIZE + 1];
SegmentT output_segment[TILE_SIZE];
} temp_storage;
for (auto item : dh::BlockStrideRange(int(0), int(tile_num_rows + 1))) {
temp_storage.tile_segment_end_offsets[item] =
segment_end_offsets[min(static_cast<size_t>(tile_start_coord.x + item),
static_cast<size_t>(num_segments - 1))];
}
__syncthreads();
int64_t diag = threadIdx.x * ITEMS_PER_THREAD;
// Cast to signed integer as this function can have negatives
cub::MergePathSearch(diag, // Diagonal
temp_storage.tile_segment_end_offsets, // List A
tile_element_indices, // List B
tile_num_rows, tile_num_elements, thread_start_coord);
CoordinateT thread_current_coord = thread_start_coord;
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) {
if (tile_element_indices[thread_current_coord.y] <
temp_storage.tile_segment_end_offsets[thread_current_coord.x]) {
temp_storage.output_segment[thread_current_coord.y] =
thread_current_coord.x + tile_start_coord.x;
++thread_current_coord.y;
} else {
++thread_current_coord.x;
}
}
__syncthreads();
for (auto item : dh::BlockStrideRange(int(0), int(tile_num_elements))) {
f(tile_start_coord.y + item, temp_storage.output_segment[item]);
}
}
template <typename FunctionT, typename SegmentIterT, typename OffsetT>
void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory,
OffsetT count, SegmentIterT segments,
OffsetT num_segments, FunctionT f) {
using CoordinateT = typename cub::CubVector<OffsetT, 2>::Type;
dh::safe_cuda(cudaSetDevice(device_idx));
const int BLOCK_THREADS = 256;
const int ITEMS_PER_THREAD = 1;
const int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD;
auto num_tiles = xgboost::common::DivRoundUp(count + num_segments, BLOCK_THREADS);
CHECK(num_tiles < std::numeric_limits<unsigned int>::max());
temp_memory->LazyAllocate(sizeof(CoordinateT) * (num_tiles + 1));
CoordinateT *tmp_tile_coordinates =
reinterpret_cast<CoordinateT *>(temp_memory->d_temp_storage);
FindMergePartitions(device_idx, tmp_tile_coordinates, num_tiles,
BLOCK_THREADS, segments, num_segments, count);
LbsKernel<TILE_SIZE, ITEMS_PER_THREAD, BLOCK_THREADS, OffsetT>
<<<uint32_t(num_tiles), BLOCK_THREADS>>>(tmp_tile_coordinates, // NOLINT
segments + 1, f, num_segments);
}
template <typename FunctionT, typename OffsetT>
void DenseTransformLbs(int device_idx, OffsetT count, OffsetT num_segments,
FunctionT f) {
CHECK(count % num_segments == 0) << "Data is not dense.";
LaunchN(device_idx, count, [=] __device__(OffsetT idx) {
OffsetT segment = idx / (count / num_segments);
f(idx, segment);
});
}
/**
* \fn template <typename FunctionT, typename SegmentIterT, typename OffsetT>
* void TransformLbs(int device_idx, dh::CubMemory *temp_memory, OffsetT count,
* SegmentIterT segments, OffsetT num_segments, bool is_dense, FunctionT f)
*
* \brief Load balancing search function. Reads a CSR type matrix description
* and allows a function to be executed on each element. Search 'modern GPU load
* balancing search' for more information.
*
* \author Rory
* \date 7/9/2017
*
* \tparam FunctionT Type of the function t.
* \tparam SegmentIterT Type of the segments iterator.
* \tparam OffsetT Type of the offset.
* \param device_idx Zero-based index of the device.
* \param [in,out] temp_memory Temporary memory allocator.
* \param count Number of elements.
* \param segments Device pointer to segments.
* \param num_segments Number of segments.
* \param is_dense True if this object is dense.
* \param f Lambda to be executed on matrix elements.
*/
template <typename FunctionT, typename SegmentIterT, typename OffsetT>
void TransformLbs(int device_idx, dh::CubMemory *temp_memory, OffsetT count,
SegmentIterT segments, OffsetT num_segments, bool is_dense,
FunctionT f) {
if (is_dense) {
DenseTransformLbs(device_idx, count, num_segments, f);
} else {
SparseTransformLbs(device_idx, temp_memory, count, segments, num_segments,
f);
}
}
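A usage sketch of TransformLbs, based on the TestLbs unit test removed later in this commit: given CSR row offsets, the lambda is called with each element index and the row (segment) that owns it.

dh::CubMemory temp_memory;
thrust::device_vector<size_t> row_ptr(std::vector<size_t>{0, 2, 5, 9});  // CSR offsets: 3 rows, 9 elements
thrust::device_vector<int> output_row(9);
auto d_output_row = output_row.data();
dh::TransformLbs(0, &temp_memory, output_row.size(), row_ptr.data().get(),
                 row_ptr.size() - 1, /*is_dense=*/false,
                 [=] __device__(size_t idx, size_t ridx) {
                   d_output_row[idx] = ridx;  // owning row of element idx
                 });
dh::safe_cuda(cudaDeviceSynchronize());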
/**
* @brief Helper function to sort the pairs using cub's segmented RadixSortPairs
* @param tmp_mem cub temporary memory info
* @param keys keys double-buffer array
* @param vals the values double-buffer array
* @param nVals number of elements in the array
* @param nSegs number of segments
* @param offsets the segments
*/
template <typename T1, typename T2>
void SegmentedSort(dh::CubMemory *tmp_mem, dh::DoubleBuffer<T1> *keys,
dh::DoubleBuffer<T2> *vals, int nVals, int nSegs,
xgboost::common::Span<int> offsets, int start = 0,
int end = sizeof(T1) * 8) {
size_t tmpSize;
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
NULL, tmpSize, keys->CubBuffer(), vals->CubBuffer(), nVals, nSegs,
offsets.data(), offsets.data() + 1, start, end));
tmp_mem->LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
tmp_mem->d_temp_storage, tmpSize, keys->CubBuffer(), vals->CubBuffer(),
nVals, nSegs, offsets.data(), offsets.data() + 1, start, end));
}
/**
* @brief Helper function to perform device-wide sum-reduction
* @param tmp_mem cub temporary memory info
* @param in the input array to be reduced
* @param out the output reduced value
* @param nVals number of elements in the input array
*/
template <typename T>
void SumReduction(dh::CubMemory* tmp_mem, xgboost::common::Span<T> in, xgboost::common::Span<T> out,
int nVals) {
size_t tmpSize;
dh::safe_cuda(
cub::DeviceReduce::Sum(NULL, tmpSize, in.data(), out.data(), nVals));
tmp_mem->LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceReduce::Sum(tmp_mem->d_temp_storage, tmpSize,
in.data(), out.data(), nVals));
}
/**
* @brief Helper function to perform device-wide sum-reduction, returns to the
* host
@ -1004,79 +646,6 @@ typename std::iterator_traits<T>::value_type SumReduction(
return sum;
}
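A usage sketch of the host-returning SumReduction overload, matching the SumReduce test that this commit keeps (switched to a raw pointer) further down:

thrust::device_vector<float> data(100, 1.0f);
dh::CubMemory temp;
auto sum = dh::SumReduction(&temp, data.data().get(), data.size());  // sum == 100.0f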
/**
* @brief Fill a given constant value across all elements in the buffer
* @param out the buffer to be filled
* @param len number of elements in the buffer
* @param def default value to be filled
*/
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
void FillConst(int device_idx, T *out, int len, T def) {
dh::LaunchN<ItemsPerThread, BlkDim>(device_idx, len,
[=] __device__(int i) { out[i] = def; });
}
/**
* @brief gather elements
* @param out1 output gathered array for the first buffer
* @param in1 first input buffer
* @param out2 output gathered array for the second buffer
* @param in2 second input buffer
* @param instId gather indices
* @param nVals length of the buffers
*/
template <typename T1, typename T2, int BlkDim = 256, int ItemsPerThread = 4>
void Gather(int device_idx, T1 *out1, const T1 *in1, T2 *out2, const T2 *in2,
const int *instId, int nVals) {
dh::LaunchN<ItemsPerThread, BlkDim>(device_idx, nVals,
[=] __device__(int i) {
int iid = instId[i];
T1 v1 = in1[iid];
T2 v2 = in2[iid];
out1[i] = v1;
out2[i] = v2;
});
}
/**
* @brief gather elements
* @param out output gathered array
* @param in input buffer
* @param instId gather indices
* @param nVals length of the buffers
*/
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
void Gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
dh::LaunchN<ItemsPerThread, BlkDim>(device_idx, nVals,
[=] __device__(int i) {
int iid = instId[i];
out[i] = in[iid];
});
}
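A sketch of the single-buffer Gather overload above, with arbitrary values; it performs out[i] = in[instId[i]]:

int n = 3;
thrust::device_vector<float> in(std::vector<float>{10.f, 20.f, 30.f});
thrust::device_vector<int> inst_id(std::vector<int>{2, 0, 1});
thrust::device_vector<float> out(n);
dh::Gather(0, out.data().get(), in.data().get(), inst_id.data().get(), n);
// out == {30, 10, 20}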
class SaveCudaContext {
private:
int saved_device_;
public:
template <typename Functor>
explicit SaveCudaContext (Functor func) : saved_device_{-1} {
// When compiled with CUDA but running on CPU only device,
// cudaGetDevice will fail.
try {
safe_cuda(cudaGetDevice(&saved_device_));
} catch (const dmlc::Error &except) {
saved_device_ = -1;
}
func();
}
~SaveCudaContext() {
if (saved_device_ != -1) {
safe_cuda(cudaSetDevice(saved_device_));
}
}
};
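SaveCudaContext runs the functor and then restores whichever device was current beforehand; a usage sketch (the TestGpuHist change below removes exactly this wrapper):

dh::SaveCudaContext{[&]() {
  dh::safe_cuda(cudaSetDevice(0));  // switch devices only for the duration of the functor
  // ... device work ...
}};
// The destructor has restored the previously current device here.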
/**
* \class AllReducer
*
@ -1200,50 +769,12 @@ class AllReducer {
return id;
}
#endif
/** \brief Perform max all reduce operation on the host. This function first
* reduces over omp threads then over nodes using rabit (which is not thread
* safe) using the master thread. Uses naive reduce algorithm for local
* threads, don't expect this to scale.*/
void HostMaxAllReduce(std::vector<size_t> *p_data) {
#ifdef XGBOOST_USE_NCCL
auto &data = *p_data;
// Wait in case some other thread is accessing host_data_
#pragma omp barrier
// Reset shared buffer
#pragma omp single
{
host_data_.resize(data.size());
std::fill(host_data_.begin(), host_data_.end(), size_t(0));
}
// Threads update shared array
for (auto i = 0ull; i < data.size(); i++) {
#pragma omp critical
{ host_data_[i] = std::max(host_data_[i], data[i]); }
}
// Wait until all threads are finished
#pragma omp barrier
// One thread performs all reduce across distributed nodes
#pragma omp master
{
rabit::Allreduce<rabit::op::Max, size_t>(host_data_.data(),
host_data_.size());
}
#pragma omp barrier
// Threads can now read back all reduced values
for (auto i = 0ull; i < data.size(); i++) {
data[i] = host_data_[i];
}
#endif
}
};
template <typename T,
template <typename VectorT, typename T = typename VectorT::value_type,
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(
device_vector<T>& vec,
VectorT &vec,
IndexT offset = 0,
IndexT size = std::numeric_limits<size_t>::max()) {
size = size == std::numeric_limits<size_t>::max() ? vec.size() : size;
@ -1467,6 +998,26 @@ class SegmentSorter {
}
};
// Atomic add function for gradients
template <typename OutputGradientT, typename InputGradientT>
DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
const InputGradientT& gpair) {
auto dst_ptr = reinterpret_cast<typename OutputGradientT::ValueT*>(dest);
atomicAdd(dst_ptr,
static_cast<typename OutputGradientT::ValueT>(gpair.GetGrad()));
atomicAdd(dst_ptr + 1,
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
}
// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}
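A sketch of AtomicAddGpair in use: every row's gradient pair is accumulated into a single higher-precision bin, with separate atomicAdds for the gradient and hessian halves (values are arbitrary).

dh::caching_device_vector<xgboost::GradientPair> gpair(100, xgboost::GradientPair(1.0f, 0.5f));
dh::caching_device_vector<xgboost::GradientPairPrecise> bin(1);
auto d_gpair = dh::ToSpan(gpair);
auto d_bin = bin.data().get();
dh::LaunchN(0, gpair.size(), [=] __device__(size_t i) {
  dh::AtomicAddGpair(d_bin, d_gpair[i]);  // atomicAdd on grad, then on hess
});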
template <typename FunctionT>
class LauncherItr {
public:
@ -1481,35 +1032,35 @@ public:
};
/**
* \brief Thrust compatible iterator type - discards algorithm output and launches device lambda
* with the index of the output and the algorithm output as arguments.
*
* \author Rory
* \date 7/9/2017
*
* \tparam FunctionT Type of the function t.
*/
* \brief Thrust compatible iterator type - discards algorithm output and launches device lambda
* with the index of the output and the algorithm output as arguments.
*
* \author Rory
* \date 7/9/2017
*
* \tparam FunctionT Type of the function t.
*/
template <typename FunctionT>
class DiscardLambdaItr {
public:
// Required iterator traits
using self_type = DiscardLambdaItr; // NOLINT
using difference_type = ptrdiff_t; // NOLINT
using value_type = void; // NOLINT
using pointer = value_type *; // NOLINT
using reference = LauncherItr<FunctionT>; // NOLINT
using iterator_category = typename thrust::detail::iterator_facade_category< // NOLINT
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
reference>::type; // NOLINT
// Required iterator traits
using self_type = DiscardLambdaItr; // NOLINT
using difference_type = ptrdiff_t; // NOLINT
using value_type = void; // NOLINT
using pointer = value_type *; // NOLINT
using reference = LauncherItr<FunctionT>; // NOLINT
using iterator_category = typename thrust::detail::iterator_facade_category< // NOLINT
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
reference>::type; // NOLINT
private:
difference_type offset_;
FunctionT f_;
public:
XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {}
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f)
: offset_(offset), f_(f) {}
XGBOOST_DEVICE self_type operator+(const int &b) const {
return DiscardLambdaItr(offset_ + b, f_);
XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {}
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f)
: offset_(offset), f_(f) {}
XGBOOST_DEVICE self_type operator+(const int &b) const {
return DiscardLambdaItr(offset_ + b, f_);
}
XGBOOST_DEVICE self_type operator++() {
offset_++;
@ -1533,24 +1084,4 @@ public:
}
};
// Atomic add function for gradients
template <typename OutputGradientT, typename InputGradientT>
DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
const InputGradientT& gpair) {
auto dst_ptr = reinterpret_cast<typename OutputGradientT::ValueT*>(dest);
atomicAdd(dst_ptr,
static_cast<typename OutputGradientT::ValueT>(gpair.GetGrad()));
atomicAdd(dst_ptr + 1,
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
}
// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}
} // namespace dh


@ -86,14 +86,13 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
std::make_pair(column_begin - col.cbegin(), column_end - col.cbegin()));
row_ptr_.push_back(row_ptr_.back() + (column_end - column_begin));
}
ba_.Allocate(learner_param_->gpu_id, &data_, row_ptr_.back(), &gpair_,
num_row_ * model_param.num_output_group);
data_.resize(row_ptr_.back());
gpair_.resize(num_row_ * model_param.num_output_group);
for (size_t fidx = 0; fidx < batch.Size(); fidx++) {
auto col = batch[fidx];
auto seg = column_segments[fidx];
dh::safe_cuda(cudaMemcpy(
data_.subspan(row_ptr_[fidx]).data(),
data_.data().get() + row_ptr_[fidx],
col.data() + seg.first,
sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice));
}
@ -192,7 +191,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
// This needs to be public because of the __device__ lambda.
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
if (dbias == 0.0f) return;
auto d_gpair = gpair_;
auto d_gpair = dh::ToSpan(gpair_);
dh::LaunchN(learner_param_->gpu_id, num_row_, [=] __device__(size_t idx) {
auto &g = d_gpair[idx * num_groups + group_idx];
g += GradientPair(g.GetHess() * dbias, 0);
@ -202,9 +201,9 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
// This needs to be public because of the __device__ lambda.
GradientPair GetGradient(int group_idx, int num_group, int fidx) {
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
common::Span<xgboost::Entry> d_col = data_.subspan(row_ptr_[fidx]);
common::Span<xgboost::Entry> d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]);
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
common::Span<GradientPair> d_gpair = gpair_;
common::Span<GradientPair> d_gpair = dh::ToSpan(gpair_);
auto counting = thrust::make_counting_iterator(0ull);
auto f = [=] __device__(size_t idx) {
auto entry = d_col[idx];
@ -219,8 +218,8 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
// This needs to be public because of the __device__ lambda.
void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) {
common::Span<GradientPair> d_gpair = gpair_;
common::Span<Entry> d_col = data_.subspan(row_ptr_[fidx]);
common::Span<GradientPair> d_gpair = dh::ToSpan(gpair_);
common::Span<Entry> d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]);
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
dh::LaunchN(learner_param_->gpu_id, col_size, [=] __device__(size_t idx) {
auto entry = d_col[idx];
@ -236,7 +235,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
void UpdateGpair(const std::vector<GradientPair> &host_gpair) {
dh::safe_cuda(cudaMemcpyAsync(
gpair_.data(),
gpair_.data().get(),
host_gpair.data(),
gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice));
}
@ -247,10 +246,9 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
std::unique_ptr<FeatureSelector> selector_;
common::Monitor monitor_;
dh::BulkAllocator ba_;
std::vector<size_t> row_ptr_;
common::Span<xgboost::Entry> data_;
common::Span<GradientPair> gpair_;
dh::device_vector<xgboost::Entry> data_;
dh::caching_device_vector<GradientPair> gpair_;
dh::CubMemory temp_;
size_t num_row_;
};
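The .data().get() calls introduced above are the standard way to get a raw pointer from a thrust vector: data() returns a thrust::device_ptr<T> and get() unwraps it for the CUDA runtime API, while dh::ToSpan gives the bounds-checked view used in device lambdas. A reduced sketch (float in place of xgboost::Entry):

dh::device_vector<float> data_(256);
std::vector<float> host(data_.size(), 1.0f);
dh::safe_cuda(cudaMemcpy(data_.data().get(), host.data(),
                         host.size() * sizeof(float), cudaMemcpyHostToDevice));
auto d_view = dh::ToSpan(data_);  // usable inside __device__ lambdas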


@ -187,9 +187,10 @@ ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl* pa
size_t n_rows,
const BatchParam& batch_param,
float subsample)
: original_page_(page), batch_param_(batch_param), subsample_(subsample) {
ba_.Allocate(batch_param_.gpu_id, &sample_row_index_, n_rows);
}
: original_page_(page),
batch_param_(batch_param),
subsample_(subsample),
sample_row_index_(n_rows) {}
GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
@ -207,12 +208,12 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
// Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero());
thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_),
dh::tbegin(sample_row_index_));
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero());
thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(),
sample_row_index_.begin());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(sample_row_index_),
dh::tbegin(sample_row_index_),
sample_row_index_.begin(),
sample_row_index_.begin(),
ClearEmptyRows());
// Create a new ELLPACK page with empty rows.
@ -224,7 +225,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_);
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
}
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
@ -233,23 +234,23 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
GradientBasedSampling::GradientBasedSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample) : page_(page), subsample_(subsample) {
ba_.Allocate(batch_param.gpu_id,
&threshold_, n_rows + 1,
&grad_sum_, n_rows);
}
float subsample)
: page_(page),
subsample_(subsample),
threshold_(n_rows + 1, 0.0f),
grad_sum_(n_rows, 0.0f) {}
GradientBasedSample GradientBasedSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
size_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, threshold_, grad_sum_, n_rows * subsample_);
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
// Perform Poisson sampling in place.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
dh::tbegin(gpair),
PoissonSampling(threshold_,
PoissonSampling(dh::ToSpan(threshold_),
threshold_index,
RandomWeight(common::GlobalRandom()())));
return {n_rows, page_, gpair};
@ -259,24 +260,25 @@ ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample) : original_page_(page), batch_param_(batch_param), subsample_(subsample) {
ba_.Allocate(batch_param.gpu_id,
&threshold_, n_rows + 1,
&grad_sum_, n_rows,
&sample_row_index_, n_rows);
}
float subsample)
: original_page_(page),
batch_param_(batch_param),
subsample_(subsample),
threshold_(n_rows + 1, 0.0f),
grad_sum_(n_rows, 0.0f),
sample_row_index_(n_rows) {}
GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
size_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, threshold_, grad_sum_, n_rows * subsample_);
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
// Perform Poisson sampling in place.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
dh::tbegin(gpair),
PoissonSampling(threshold_,
PoissonSampling(dh::ToSpan(threshold_),
threshold_index,
RandomWeight(common::GlobalRandom()())));
@ -288,12 +290,12 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<Gra
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
// Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero());
thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_),
dh::tbegin(sample_row_index_));
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero());
thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(),
sample_row_index_.begin());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(sample_row_index_),
dh::tbegin(sample_row_index_),
sample_row_index_.begin(),
sample_row_index_.begin(),
ClearEmptyRows());
// Create a new ELLPACK page with empty rows.
@ -305,7 +307,7 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<Gra
// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_);
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
}
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
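The row-indexing code above is a standard stream-compaction idiom: flag each kept row with 1, then exclusive-scan the flags so every kept row learns its slot in the compacted output. A standalone sketch with a device lambda standing in for the IsNonZero functor:

thrust::device_vector<float> gpair(std::vector<float>{0.f, 2.f, 0.f, 3.f});
thrust::device_vector<size_t> sample_row_index(gpair.size());
thrust::transform(gpair.begin(), gpair.end(), sample_row_index.begin(),
                  [] __device__(float g) { return g != 0.f ? size_t(1) : size_t(0); });
thrust::exclusive_scan(sample_row_index.begin(), sample_row_index.end(),
                       sample_row_index.begin());
// sample_row_index == {0, 0, 1, 1}: row 1 maps to slot 0, row 3 to slot 1.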
@ -358,21 +360,21 @@ GradientBasedSample GradientBasedSampler::Sample(common::Span<GradientPair> gpai
return sample;
}
size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
common::Span<float> threshold,
common::Span<float> grad_sum,
size_t sample_rows) {
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits<float>::max());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(threshold),
size_t GradientBasedSampler::CalculateThresholdIndex(
common::Span<GradientPair> gpair, common::Span<float> threshold,
common::Span<float> grad_sum, size_t sample_rows) {
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold),
std::numeric_limits<float>::max());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
CombineGradientPair());
thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1, dh::tbegin(grad_sum));
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1,
dh::tbegin(grad_sum));
thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum),
thrust::counting_iterator<size_t>(0),
dh::tbegin(grad_sum),
thrust::counting_iterator<size_t>(0), dh::tbegin(grad_sum),
SampleRateDelta(threshold, gpair.size(), sample_rows));
thrust::device_ptr<float> min = thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
thrust::device_ptr<float> min =
thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
return thrust::distance(dh::tbegin(grad_sum), min) + 1;
}


@ -73,13 +73,12 @@ class ExternalMemoryUniformSampling : public SamplingStrategy {
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
dh::BulkAllocator ba_;
EllpackPageImpl* original_page_;
BatchParam batch_param_;
float subsample_;
std::unique_ptr<EllpackPageImpl> page_;
dh::device_vector<GradientPair> gpair_{};
common::Span<size_t> sample_row_index_;
dh::caching_device_vector<size_t> sample_row_index_;
};
/*! \brief Gradient-based sampling in in-memory mode.. */
@ -94,9 +93,8 @@ class GradientBasedSampling : public SamplingStrategy {
private:
EllpackPageImpl* page_;
float subsample_;
dh::BulkAllocator ba_;
common::Span<float> threshold_;
common::Span<float> grad_sum_;
dh::caching_device_vector<float> threshold_;
dh::caching_device_vector<float> grad_sum_;
};
/*! \brief Gradient-based sampling in external memory mode.. */
@ -109,15 +107,14 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
private:
dh::BulkAllocator ba_;
EllpackPageImpl* original_page_;
BatchParam batch_param_;
float subsample_;
common::Span<float> threshold_;
common::Span<float> grad_sum_;
dh::caching_device_vector<float> threshold_;
dh::caching_device_vector<float> grad_sum_;
std::unique_ptr<EllpackPageImpl> page_;
dh::device_vector<GradientPair> gpair_;
common::Span<size_t> sample_row_index_;
dh::caching_device_vector<size_t> sample_row_index_;
};
/*! \brief Draw a sample of rows from a DMatrix.


@ -408,25 +408,22 @@ struct GPUHistMakerDevice {
EllpackPageImpl* page;
BatchParam batch_param;
dh::BulkAllocator ba;
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogram<GradientSumT> hist{};
/*! \brief Gradient pair for each row. */
common::Span<GradientPair> gpair;
common::Span<int> monotone_constraints;
common::Span<bst_float> prediction_cache;
dh::caching_device_vector<int> monotone_constraints;
dh::caching_device_vector<bst_float> prediction_cache;
/*! \brief Sum gradient for each node. */
std::vector<GradientPair> node_sum_gradients;
common::Span<GradientPair> node_sum_gradients_d;
std::vector<GradientPair> host_node_sum_gradients;
dh::caching_device_vector<GradientPair> node_sum_gradients;
bst_uint n_rows;
TrainParam param;
bool deterministic_histogram;
bool prediction_cache_initialised;
bool use_shared_memory_histograms {false};
GradientSumT histogram_rounding;
@ -460,7 +457,6 @@ struct GPUHistMakerDevice {
page(_page),
n_rows(_n_rows),
param(std::move(_param)),
prediction_cache_initialised(false),
column_sampler(column_sampler_seed),
interaction_constraints(param, n_features),
deterministic_histogram{deterministic_histogram},
@ -513,7 +509,7 @@ struct GPUHistMakerDevice {
param.colsample_bylevel, param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id));
this->interaction_constraints.Reset();
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
std::fill(host_node_sum_gradients.begin(), host_node_sum_gradients.end(),
GradientPair());
auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat);
@ -541,44 +537,26 @@ struct GPUHistMakerDevice {
// Work out cub temporary memory requirement
GPUTrainingParam gpu_param(param);
DeviceSplitCandidateReduceOp op(gpu_param);
size_t temp_storage_bytes = 0;
DeviceSplitCandidate*dummy = nullptr;
cub::DeviceReduce::Reduce(
nullptr, temp_storage_bytes, dummy,
dummy, num_columns, op,
DeviceSplitCandidate());
// size in terms of DeviceSplitCandidate
size_t cub_memory_size =
std::ceil(static_cast<double>(temp_storage_bytes) /
sizeof(DeviceSplitCandidate));
// Allocate enough temporary memory
// Result for each nidx
// + intermediate result for each column
// + cub reduce memory
auto temp_span = temp_memory.GetSpan<DeviceSplitCandidate>(
nidxs.size() + nidxs.size() * num_columns +cub_memory_size*nidxs.size());
auto d_result_all = temp_span.subspan(0, nidxs.size());
auto d_split_candidates_all =
temp_span.subspan(d_result_all.size(), nidxs.size() * num_columns);
auto d_cub_memory_all =
temp_span.subspan(d_result_all.size() + d_split_candidates_all.size(),
cub_memory_size * nidxs.size());
dh::caching_device_vector<DeviceSplitCandidate> d_result_all(nidxs.size());
dh::caching_device_vector<DeviceSplitCandidate> split_candidates_all(nidxs.size()*num_columns);
auto& streams = this->GetStreams(nidxs.size());
for (auto i = 0ull; i < nidxs.size(); i++) {
auto nidx = nidxs[i];
auto p_feature_set = column_sampler.GetFeatureSet(tree.GetDepth(nidx));
p_feature_set->SetDevice(device_id);
common::Span<bst_feature_t> d_sampled_features = p_feature_set->DeviceSpan();
common::Span<bst_feature_t> d_sampled_features =
p_feature_set->DeviceSpan();
common::Span<bst_feature_t> d_feature_set =
interaction_constraints.Query(d_sampled_features, nidx);
auto d_split_candidates =
d_split_candidates_all.subspan(i * num_columns, d_feature_set.size());
common::Span<DeviceSplitCandidate> d_split_candidates(
split_candidates_all.data().get() + i * num_columns,
d_feature_set.size());
DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
DeviceNodeStats node(host_node_sum_gradients[nidx], nidx, param);
auto d_result = d_result_all.subspan(i, 1);
common::Span<DeviceSplitCandidate> d_result(d_result_all.data().get() + i, 1);
if (d_feature_set.empty()) {
// Acting as a device side constructor for DeviceSplitCandidate.
// DeviceSplitCandidate::IsValid is false so that ApplySplit can reject this
@ -596,19 +574,22 @@ struct GPUHistMakerDevice {
EvaluateSplitKernel<kBlockThreads, GradientSumT>,
hist.GetNodeHistogram(nidx), d_feature_set, node, page->GetDeviceAccessor(device_id),
gpu_param, d_split_candidates, node_value_constraints[nidx],
monotone_constraints);
dh::ToSpan(monotone_constraints));
// Reduce over features to find best feature
auto d_cub_memory =
d_cub_memory_all.subspan(i * cub_memory_size, cub_memory_size);
size_t cub_bytes = d_cub_memory.size() * sizeof(DeviceSplitCandidate);
cub::DeviceReduce::Reduce(reinterpret_cast<void*>(d_cub_memory.data()),
size_t cub_bytes = 0;
cub::DeviceReduce::Reduce(nullptr,
cub_bytes, d_split_candidates.data(),
d_result.data(), d_split_candidates.size(), op,
DeviceSplitCandidate(), streams[i]);
dh::caching_device_vector<char> cub_temp(cub_bytes);
cub::DeviceReduce::Reduce(reinterpret_cast<void*>(cub_temp.data().get()),
cub_bytes, d_split_candidates.data(),
d_result.data(), d_split_candidates.size(), op,
DeviceSplitCandidate(), streams[i]);
}
dh::safe_cuda(cudaMemcpy(result_all.data(), d_result_all.data(),
dh::safe_cuda(cudaMemcpy(result_all.data(), d_result_all.data().get(),
sizeof(DeviceSplitCandidate) * d_result_all.size(),
cudaMemcpyDeviceToHost));
return std::vector<DeviceSplitCandidate>(result_all.begin(), result_all.end());
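The change above swaps the hand-sized workspace for CUB's usual two-phase call: run the reduction once with a null workspace pointer to query the temporary-storage size, allocate that many bytes, then run it again for real. A generic sketch of the same pattern:

thrust::device_vector<float> in(1000, 1.0f);
thrust::device_vector<float> out(1);
size_t temp_bytes = 0;
cub::DeviceReduce::Sum(nullptr, temp_bytes, in.data().get(),
                       out.data().get(), in.size());      // 1) size query only
dh::caching_device_vector<char> temp(temp_bytes);         // 2) allocate workspace
cub::DeviceReduce::Sum(temp.data().get(), temp_bytes, in.data().get(),
                       out.data().get(), in.size());      // 3) run the reduction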
@ -718,23 +699,23 @@ struct GPUHistMakerDevice {
void UpdatePredictionCache(bst_float* out_preds_d) {
dh::safe_cuda(cudaSetDevice(device_id));
if (!prediction_cache_initialised) {
dh::safe_cuda(cudaMemcpyAsync(prediction_cache.data(), out_preds_d,
auto d_ridx = row_partitioner->GetRows();
if (prediction_cache.size() != d_ridx.size()) {
prediction_cache.resize(d_ridx.size());
dh::safe_cuda(cudaMemcpyAsync(prediction_cache.data().get(), out_preds_d,
prediction_cache.size() * sizeof(bst_float),
cudaMemcpyDefault));
}
prediction_cache_initialised = true;
CalcWeightTrainParam param_d(param);
dh::safe_cuda(
cudaMemcpyAsync(node_sum_gradients_d.data(), node_sum_gradients.data(),
sizeof(GradientPair) * node_sum_gradients.size(),
cudaMemcpyAsync(node_sum_gradients.data().get(), host_node_sum_gradients.data(),
sizeof(GradientPair) * host_node_sum_gradients.size(),
cudaMemcpyHostToDevice));
auto d_position = row_partitioner->GetPosition();
auto d_ridx = row_partitioner->GetRows();
auto d_node_sum_gradients = node_sum_gradients_d.data();
auto d_prediction_cache = prediction_cache.data();
auto d_node_sum_gradients = node_sum_gradients.data().get();
auto d_prediction_cache = prediction_cache.data().get();
dh::LaunchN(
device_id, prediction_cache.size(), [=] __device__(int local_idx) {
@ -745,7 +726,7 @@ struct GPUHistMakerDevice {
});
dh::safe_cuda(cudaMemcpy(
out_preds_d, prediction_cache.data(),
out_preds_d, prediction_cache.data().get(),
prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
row_partitioner.reset();
}
@ -822,9 +803,9 @@ struct GPUHistMakerDevice {
param, tree[candidate.nid].SplitIndex(), left_stats, right_stats,
&node_value_constraints[tree[candidate.nid].LeftChild()],
&node_value_constraints[tree[candidate.nid].RightChild()]);
node_sum_gradients[tree[candidate.nid].LeftChild()] =
host_node_sum_gradients[tree[candidate.nid].LeftChild()] =
candidate.split.left_sum;
node_sum_gradients[tree[candidate.nid].RightChild()] =
host_node_sum_gradients[tree[candidate.nid].RightChild()] =
candidate.split.right_sum;
interaction_constraints.Split(candidate.nid, tree[candidate.nid].SplitIndex(),
@ -839,22 +820,22 @@ struct GPUHistMakerDevice {
thrust::cuda::par(alloc),
thrust::device_ptr<GradientPair const>(gpair.data()),
thrust::device_ptr<GradientPair const>(gpair.data() + gpair.size()));
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.data(), &root_sum, sizeof(root_sum),
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients.data().get(), &root_sum, sizeof(root_sum),
cudaMemcpyHostToDevice));
reducer->AllReduceSum(
reinterpret_cast<float*>(node_sum_gradients_d.data()),
reinterpret_cast<float*>(node_sum_gradients_d.data()), 2);
reinterpret_cast<float*>(node_sum_gradients.data().get()),
reinterpret_cast<float*>(node_sum_gradients.data().get()), 2);
reducer->Synchronize();
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients.data(),
node_sum_gradients_d.data(), sizeof(GradientPair),
dh::safe_cuda(cudaMemcpyAsync(host_node_sum_gradients.data(),
node_sum_gradients.data().get(), sizeof(GradientPair),
cudaMemcpyDeviceToHost));
this->BuildHist(kRootNIdx);
this->AllReduceHist(kRootNIdx, reducer);
// Remember root stats
p_tree->Stat(kRootNIdx).sum_hess = node_sum_gradients[kRootNIdx].GetHess();
auto weight = CalcWeight(param, node_sum_gradients[kRootNIdx]);
p_tree->Stat(kRootNIdx).sum_hess = host_node_sum_gradients[kRootNIdx].GetHess();
auto weight = CalcWeight(param, host_node_sum_gradients[kRootNIdx]);
p_tree->Stat(kRootNIdx).base_weight = weight;
(*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);
@ -927,15 +908,12 @@ struct GPUHistMakerDevice {
template <typename GradientSumT>
inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
bst_node_t max_nodes { param.MaxNodes() };
ba.Allocate(device_id,
&prediction_cache, n_rows,
&node_sum_gradients_d, max_nodes,
&monotone_constraints, param.monotone_constraints.size());
dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);
node_sum_gradients.resize(max_nodes);
if (!param.monotone_constraints.empty()) {
// Copy assigning an empty vector causes an exception in MSVC debug builds
monotone_constraints = param.monotone_constraints;
}
host_node_sum_gradients.resize(param.MaxNodes());
node_sum_gradients.resize(param.MaxNodes());
// check if we can use shared memory for building histograms
// (assuming at least we need 2 CTAs per SM to maintain decent latency


@ -27,64 +27,13 @@ void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
}
}
void TestLbs() {
srand(17);
dh::CubMemory temp_memory;
std::vector<int> test_rows = {4, 100, 1000};
std::vector<int> test_max_row_sizes = {4, 100, 1300};
for (auto num_rows : test_rows) {
for (auto max_row_size : test_max_row_sizes) {
thrust::host_vector<int> h_row_ptr;
thrust::host_vector<xgboost::bst_uint> h_rows;
CreateTestData(num_rows, max_row_size, &h_row_ptr, &h_rows);
thrust::device_vector<size_t> row_ptr = h_row_ptr;
thrust::device_vector<int> output_row(h_rows.size());
auto d_output_row = output_row.data();
dh::TransformLbs(0, &temp_memory, h_rows.size(), dh::Raw(row_ptr),
row_ptr.size() - 1, false,
[=] __device__(size_t idx, size_t ridx) {
d_output_row[idx] = ridx;
});
dh::safe_cuda(cudaDeviceSynchronize());
ASSERT_TRUE(h_rows == output_row);
}
}
}
TEST(CubLBS, Test) {
TestLbs();
}
TEST(SumReduce, Test) {
thrust::device_vector<float> data(100, 1.0f);
dh::CubMemory temp;
auto sum = dh::SumReduction(&temp, dh::Raw(data), data.size());
auto sum = dh::SumReduction(&temp, data.data().get(), data.size());
ASSERT_NEAR(sum, 100.0f, 1e-5);
}
void TestAllocator() {
int n = 10;
Span<float> a;
Span<int> b;
Span<size_t> c;
dh::BulkAllocator ba;
ba.Allocate(0, &a, n, &b, n, &c, n);
// Should be no illegal memory accesses
dh::LaunchN(0, n, [=] __device__(size_t idx) { c[idx] = a[idx] + b[idx]; });
dh::safe_cuda(cudaDeviceSynchronize());
}
// Define the test in a function so we can use device lambda
TEST(BulkAllocator, Test) {
TestAllocator();
}
template <typename T, typename Comp = thrust::less<T>>
void TestUpperBoundImpl(const std::vector<T> &vec, T val_to_find,
const Comp &comp = Comp()) {


@ -23,7 +23,7 @@ Json GenerateDenseColumn(std::string const& typestr, size_t kRows,
d_data.resize(kRows);
thrust::sequence(thrust::device, d_data.begin(), d_data.end(), 0.0f, 2.0f);
auto p_d_data = dh::Raw(d_data);
auto p_d_data = d_data.data().get();
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
@ -49,7 +49,7 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows,
d_data[i] = i * 2.0;
}
auto p_d_data = dh::Raw(d_data);
auto p_d_data = d_data.data().get();
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),


@ -25,7 +25,7 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String(typestr);
auto p_d_data = dh::Raw(d_data);
auto p_d_data = d_data.data().get();
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
Json(Boolean(false))};


@ -24,37 +24,33 @@ namespace tree {
TEST(GpuHist, DeviceHistogram) {
// Ensures that node allocates correctly after reaching `kStopGrowingSize`.
dh::SaveCudaContext{
[&]() {
dh::safe_cuda(cudaSetDevice(0));
constexpr size_t kNBins = 128;
constexpr size_t kNNodes = 4;
constexpr size_t kStopGrowing = kNNodes * kNBins * 2u;
DeviceHistogram<GradientPairPrecise, kStopGrowing> histogram;
histogram.Init(0, kNBins);
for (size_t i = 0; i < kNNodes; ++i) {
histogram.AllocateHistogram(i);
}
histogram.Reset();
ASSERT_EQ(histogram.Data().size(), kStopGrowing);
dh::safe_cuda(cudaSetDevice(0));
constexpr size_t kNBins = 128;
constexpr size_t kNNodes = 4;
constexpr size_t kStopGrowing = kNNodes * kNBins * 2u;
DeviceHistogram<GradientPairPrecise, kStopGrowing> histogram;
histogram.Init(0, kNBins);
for (size_t i = 0; i < kNNodes; ++i) {
histogram.AllocateHistogram(i);
}
histogram.Reset();
ASSERT_EQ(histogram.Data().size(), kStopGrowing);
// Use allocated memory but do not erase nidx_map.
for (size_t i = 0; i < kNNodes; ++i) {
histogram.AllocateHistogram(i);
}
for (size_t i = 0; i < kNNodes; ++i) {
ASSERT_TRUE(histogram.HistogramExists(i));
}
// Use allocated memory but do not erase nidx_map.
for (size_t i = 0; i < kNNodes; ++i) {
histogram.AllocateHistogram(i);
}
for (size_t i = 0; i < kNNodes; ++i) {
ASSERT_TRUE(histogram.HistogramExists(i));
}
// Erase existing nidx_map.
for (size_t i = kNNodes; i < kNNodes * 2; ++i) {
histogram.AllocateHistogram(i);
}
for (size_t i = 0; i < kNNodes; ++i) {
ASSERT_FALSE(histogram.HistogramExists(i));
}
}
};
// Erase existing nidx_map.
for (size_t i = kNNodes; i < kNNodes * 2; ++i) {
histogram.AllocateHistogram(i);
}
for (size_t i = 0; i < kNNodes; ++i) {
ASSERT_FALSE(histogram.HistogramExists(i));
}
}
std::vector<GradientPairPrecise> GetHostHistGpair() {
@ -187,16 +183,14 @@ TEST(GpuHist, EvaluateSplits) {
GPUHistMakerDevice<GradientPairPrecise>
maker(0, page.get(), kNRows, param, kNCols, kNCols, true, batch_param);
// Initialize GPUHistMakerDevice::node_sum_gradients
maker.node_sum_gradients = {{6.4f, 12.8f}};
maker.host_node_sum_gradients = {{6.4f, 12.8f}};
// Initialize GPUHistMakerDevice::cut
auto cmat = GetHostCutMatrix();
// Copy cut matrix to device.
page->Cuts() = cmat;
maker.ba.Allocate(0, &(maker.monotone_constraints), kNCols);
dh::CopyVectorToDeviceSpan(maker.monotone_constraints,
param.monotone_constraints);
maker.monotone_constraints = param.monotone_constraints;
// Initialize GPUHistMakerDevice::hist
maker.hist.Init(0, (max_bins - 1) * kNCols);