Purge device_helpers.cuh (#5534)
* Simplifications with caching_device_vector * Purge device helpers
This commit is contained in:
parent
a2f54963b6
commit
ca4e05660e
@ -85,19 +85,6 @@ inline int32_t CudaGetPointerDevice(void* ptr) {
|
||||
return device;
|
||||
}
|
||||
|
||||
inline void CudaCheckPointerDevice(void* ptr) {
|
||||
auto ptr_device = CudaGetPointerDevice(ptr);
|
||||
int cur_device = -1;
|
||||
dh::safe_cuda(cudaGetDevice(&cur_device));
|
||||
CHECK_EQ(ptr_device, cur_device) << "pointer device: " << ptr_device
|
||||
<< "current device: " << cur_device;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T *Raw(const thrust::device_vector<T> &v) { // NOLINT
|
||||
return raw_pointer_cast(v.data());
|
||||
}
|
||||
|
||||
inline size_t AvailableMemory(int device_idx) {
|
||||
size_t device_free = 0;
|
||||
size_t device_total = 0;
|
||||
@ -552,161 +539,6 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
|
||||
cudaMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Copies std::vector to device span.
|
||||
*
|
||||
* \tparam T Generic type parameter.
|
||||
* \param dst Copy destination. Must be device memory.
|
||||
* \param src Copy source.
|
||||
*/
|
||||
template <typename T>
|
||||
void CopyVectorToDeviceSpan(xgboost::common::Span<T> dst ,const std::vector<T>&src)
|
||||
{
|
||||
CHECK_EQ(dst.size(), src.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(dst.data(), src.data(), dst.size() * sizeof(T),
|
||||
cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Device to device memory copy from src to dst. Spans must be the same size. Use subspan to
|
||||
* copy from a smaller array to a larger array.
|
||||
*
|
||||
* \tparam T Generic type parameter.
|
||||
* \param dst Copy destination. Must be device memory.
|
||||
* \param src Copy source. Must be device memory.
|
||||
*/
|
||||
template <typename T>
|
||||
void CopyDeviceSpan(xgboost::common::Span<T> dst,
|
||||
xgboost::common::Span<T> src) {
|
||||
CHECK_EQ(dst.size(), src.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(dst.data(), src.data(), dst.size() * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
|
||||
/*! \brief Helper for allocating large block of memory. */
|
||||
class BulkAllocator {
|
||||
std::vector<char *> d_ptr_;
|
||||
std::vector<size_t> size_;
|
||||
int device_idx_{-1};
|
||||
|
||||
static const int kAlign = 256;
|
||||
|
||||
size_t AlignRoundUp(size_t n) const {
|
||||
n = (n + kAlign - 1) / kAlign;
|
||||
return n * kAlign;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t GetSizeBytes(xgboost::common::Span<T> *first_vec, size_t first_size) {
|
||||
return AlignRoundUp(first_size * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
size_t GetSizeBytes(xgboost::common::Span<T> *first_vec, size_t first_size, Args... args) {
|
||||
return GetSizeBytes<T>(first_vec, first_size) + GetSizeBytes(args...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AllocateSpan(int device_idx, char *ptr, xgboost::common::Span<T> *first_vec,
|
||||
size_t first_size) {
|
||||
*first_vec = xgboost::common::Span<T>(reinterpret_cast<T *>(ptr), first_size);
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
void AllocateSpan(int device_idx, char *ptr, xgboost::common::Span<T> *first_vec,
|
||||
size_t first_size, Args... args) {
|
||||
AllocateSpan<T>(device_idx, ptr, first_vec, first_size);
|
||||
ptr += AlignRoundUp(first_size * sizeof(T));
|
||||
AllocateSpan(device_idx, ptr, args...);
|
||||
}
|
||||
|
||||
char *AllocateDevice(int device_idx, size_t bytes) {
|
||||
safe_cuda(cudaSetDevice(device_idx));
|
||||
XGBDeviceAllocator<char> allocator;
|
||||
return allocator.allocate(bytes).get();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t GetSizeBytes(DoubleBuffer<T> *first_vec, size_t first_size) {
|
||||
return 2 * AlignRoundUp(first_size * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
size_t GetSizeBytes(DoubleBuffer<T> *first_vec, size_t first_size, Args... args) {
|
||||
return GetSizeBytes<T>(first_vec, first_size) + GetSizeBytes(args...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AllocateSpan(int device_idx, char *ptr, DoubleBuffer<T> *first_vec,
|
||||
size_t first_size) {
|
||||
auto ptr1 = reinterpret_cast<T *>(ptr);
|
||||
auto ptr2 = ptr1 + first_size;
|
||||
first_vec->a = xgboost::common::Span<T>(ptr1, first_size);
|
||||
first_vec->b = xgboost::common::Span<T>(ptr2, first_size);
|
||||
first_vec->buff.d_buffers[0] = ptr1;
|
||||
first_vec->buff.d_buffers[1] = ptr2;
|
||||
first_vec->buff.selector = 0;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
void AllocateSpan(int device_idx, char *ptr, DoubleBuffer<T> *first_vec,
|
||||
size_t first_size, Args... args) {
|
||||
AllocateSpan<T>(device_idx, ptr, first_vec, first_size);
|
||||
ptr += (AlignRoundUp(first_size * sizeof(T)) * 2);
|
||||
AllocateSpan(device_idx, ptr, args...);
|
||||
}
|
||||
|
||||
public:
|
||||
BulkAllocator() = default;
|
||||
// prevent accidental copying, moving or assignment of this object
|
||||
BulkAllocator(const BulkAllocator&) = delete;
|
||||
BulkAllocator(BulkAllocator&&) = delete;
|
||||
void operator=(const BulkAllocator&) = delete;
|
||||
void operator=(BulkAllocator&&) = delete;
|
||||
|
||||
/*!
|
||||
* \brief Clear the bulk allocator.
|
||||
*
|
||||
* This frees the GPU memory managed by this allocator.
|
||||
*/
|
||||
void Clear() {
|
||||
if (d_ptr_.empty()) return;
|
||||
|
||||
safe_cuda(cudaSetDevice(device_idx_));
|
||||
size_t idx = 0;
|
||||
std::for_each(d_ptr_.begin(), d_ptr_.end(), [&](char *dptr) {
|
||||
XGBDeviceAllocator<char>().deallocate(thrust::device_ptr<char>(dptr), size_[idx++]);
|
||||
});
|
||||
d_ptr_.clear();
|
||||
size_.clear();
|
||||
}
|
||||
|
||||
~BulkAllocator() {
|
||||
Clear();
|
||||
}
|
||||
|
||||
// returns sum of bytes for all allocations
|
||||
size_t Size() {
|
||||
return std::accumulate(size_.begin(), size_.end(), static_cast<size_t>(0));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void Allocate(int device_idx, Args... args) {
|
||||
if (device_idx_ == -1) {
|
||||
device_idx_ = device_idx;
|
||||
}
|
||||
else CHECK(device_idx_ == device_idx);
|
||||
size_t size = GetSizeBytes(args...);
|
||||
|
||||
char *ptr = AllocateDevice(device_idx, size);
|
||||
|
||||
AllocateSpan(device_idx, ptr, args...);
|
||||
|
||||
d_ptr_.push_back(ptr);
|
||||
size_.push_back(size);
|
||||
}
|
||||
};
|
||||
|
||||
// Keep track of pinned memory allocation
|
||||
struct PinnedMemory {
|
||||
void *temp_storage{nullptr};
|
||||
@ -787,196 +619,6 @@ struct CubMemory {
|
||||
* Utility functions
|
||||
*/
|
||||
|
||||
// Load balancing search
|
||||
|
||||
template <typename CoordinateT, typename SegmentT, typename OffsetT>
|
||||
void FindMergePartitions(int device_idx, CoordinateT *d_tile_coordinates,
|
||||
size_t num_tiles, int tile_size, SegmentT segments,
|
||||
OffsetT num_rows, OffsetT num_elements) {
|
||||
dh::LaunchN(device_idx, num_tiles + 1, [=] __device__(int idx) {
|
||||
OffsetT diagonal = idx * tile_size;
|
||||
CoordinateT tile_coordinate;
|
||||
cub::CountingInputIterator<OffsetT> nonzero_indices(0);
|
||||
|
||||
// Search the merge path
|
||||
// Cast to signed integer as this function can have negatives
|
||||
cub::MergePathSearch(static_cast<int64_t>(diagonal), segments + 1,
|
||||
nonzero_indices, static_cast<int64_t>(num_rows),
|
||||
static_cast<int64_t>(num_elements), tile_coordinate);
|
||||
|
||||
// Output starting offset
|
||||
d_tile_coordinates[idx] = tile_coordinate;
|
||||
});
|
||||
}
|
||||
|
||||
template <int TILE_SIZE, int ITEMS_PER_THREAD, int BLOCK_THREADS,
|
||||
typename OffsetT, typename CoordinateT, typename FunctionT,
|
||||
typename SegmentIterT>
|
||||
__global__ void LbsKernel(CoordinateT *d_coordinates,
|
||||
SegmentIterT segment_end_offsets, FunctionT f,
|
||||
OffsetT num_segments) {
|
||||
int tile = blockIdx.x;
|
||||
CoordinateT tile_start_coord = d_coordinates[tile];
|
||||
CoordinateT tile_end_coord = d_coordinates[tile + 1];
|
||||
int64_t tile_num_rows = tile_end_coord.x - tile_start_coord.x;
|
||||
int64_t tile_num_elements = tile_end_coord.y - tile_start_coord.y;
|
||||
|
||||
cub::CountingInputIterator<OffsetT> tile_element_indices(tile_start_coord.y);
|
||||
CoordinateT thread_start_coord;
|
||||
|
||||
using SegmentT = typename std::iterator_traits<SegmentIterT>::value_type;
|
||||
__shared__ struct {
|
||||
SegmentT tile_segment_end_offsets[TILE_SIZE + 1];
|
||||
SegmentT output_segment[TILE_SIZE];
|
||||
} temp_storage;
|
||||
|
||||
for (auto item : dh::BlockStrideRange(int(0), int(tile_num_rows + 1))) {
|
||||
temp_storage.tile_segment_end_offsets[item] =
|
||||
segment_end_offsets[min(static_cast<size_t>(tile_start_coord.x + item),
|
||||
static_cast<size_t>(num_segments - 1))];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int64_t diag = threadIdx.x * ITEMS_PER_THREAD;
|
||||
|
||||
// Cast to signed integer as this function can have negatives
|
||||
cub::MergePathSearch(diag, // Diagonal
|
||||
temp_storage.tile_segment_end_offsets, // List A
|
||||
tile_element_indices, // List B
|
||||
tile_num_rows, tile_num_elements, thread_start_coord);
|
||||
|
||||
CoordinateT thread_current_coord = thread_start_coord;
|
||||
#pragma unroll
|
||||
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) {
|
||||
if (tile_element_indices[thread_current_coord.y] <
|
||||
temp_storage.tile_segment_end_offsets[thread_current_coord.x]) {
|
||||
temp_storage.output_segment[thread_current_coord.y] =
|
||||
thread_current_coord.x + tile_start_coord.x;
|
||||
++thread_current_coord.y;
|
||||
} else {
|
||||
++thread_current_coord.x;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto item : dh::BlockStrideRange(int(0), int(tile_num_elements))) {
|
||||
f(tile_start_coord.y + item, temp_storage.output_segment[item]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FunctionT, typename SegmentIterT, typename OffsetT>
|
||||
void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory,
|
||||
OffsetT count, SegmentIterT segments,
|
||||
OffsetT num_segments, FunctionT f) {
|
||||
using CoordinateT = typename cub::CubVector<OffsetT, 2>::Type;
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
const int BLOCK_THREADS = 256;
|
||||
const int ITEMS_PER_THREAD = 1;
|
||||
const int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD;
|
||||
auto num_tiles = xgboost::common::DivRoundUp(count + num_segments, BLOCK_THREADS);
|
||||
CHECK(num_tiles < std::numeric_limits<unsigned int>::max());
|
||||
|
||||
temp_memory->LazyAllocate(sizeof(CoordinateT) * (num_tiles + 1));
|
||||
CoordinateT *tmp_tile_coordinates =
|
||||
reinterpret_cast<CoordinateT *>(temp_memory->d_temp_storage);
|
||||
|
||||
FindMergePartitions(device_idx, tmp_tile_coordinates, num_tiles,
|
||||
BLOCK_THREADS, segments, num_segments, count);
|
||||
|
||||
LbsKernel<TILE_SIZE, ITEMS_PER_THREAD, BLOCK_THREADS, OffsetT>
|
||||
<<<uint32_t(num_tiles), BLOCK_THREADS>>>(tmp_tile_coordinates, // NOLINT
|
||||
segments + 1, f, num_segments);
|
||||
}
|
||||
|
||||
template <typename FunctionT, typename OffsetT>
|
||||
void DenseTransformLbs(int device_idx, OffsetT count, OffsetT num_segments,
|
||||
FunctionT f) {
|
||||
CHECK(count % num_segments == 0) << "Data is not dense.";
|
||||
|
||||
LaunchN(device_idx, count, [=] __device__(OffsetT idx) {
|
||||
OffsetT segment = idx / (count / num_segments);
|
||||
f(idx, segment);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* \fn template <typename FunctionT, typename SegmentIterT, typename OffsetT>
|
||||
* void TransformLbs(int device_idx, dh::CubMemory *temp_memory, OffsetT count,
|
||||
* SegmentIterT segments, OffsetT num_segments, bool is_dense, FunctionT f)
|
||||
*
|
||||
* \brief Load balancing search function. Reads a CSR type matrix description
|
||||
* and allows a function to be executed on each element. Search 'modern GPU load
|
||||
* balancing search' for more information.
|
||||
*
|
||||
* \author Rory
|
||||
* \date 7/9/2017
|
||||
*
|
||||
* \tparam FunctionT Type of the function t.
|
||||
* \tparam SegmentIterT Type of the segments iterator.
|
||||
* \tparam OffsetT Type of the offset.
|
||||
* \param device_idx Zero-based index of the device.
|
||||
* \param [in,out] temp_memory Temporary memory allocator.
|
||||
* \param count Number of elements.
|
||||
* \param segments Device pointer to segments.
|
||||
* \param num_segments Number of segments.
|
||||
* \param is_dense True if this object is dense.
|
||||
* \param f Lambda to be executed on matrix elements.
|
||||
*/
|
||||
|
||||
template <typename FunctionT, typename SegmentIterT, typename OffsetT>
|
||||
void TransformLbs(int device_idx, dh::CubMemory *temp_memory, OffsetT count,
|
||||
SegmentIterT segments, OffsetT num_segments, bool is_dense,
|
||||
FunctionT f) {
|
||||
if (is_dense) {
|
||||
DenseTransformLbs(device_idx, count, num_segments, f);
|
||||
} else {
|
||||
SparseTransformLbs(device_idx, temp_memory, count, segments, num_segments,
|
||||
f);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper function to sort the pairs using cub's segmented RadixSortPairs
|
||||
* @param tmp_mem cub temporary memory info
|
||||
* @param keys keys double-buffer array
|
||||
* @param vals the values double-buffer array
|
||||
* @param nVals number of elements in the array
|
||||
* @param nSegs number of segments
|
||||
* @param offsets the segments
|
||||
*/
|
||||
template <typename T1, typename T2>
|
||||
void SegmentedSort(dh::CubMemory *tmp_mem, dh::DoubleBuffer<T1> *keys,
|
||||
dh::DoubleBuffer<T2> *vals, int nVals, int nSegs,
|
||||
xgboost::common::Span<int> offsets, int start = 0,
|
||||
int end = sizeof(T1) * 8) {
|
||||
size_t tmpSize;
|
||||
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
||||
NULL, tmpSize, keys->CubBuffer(), vals->CubBuffer(), nVals, nSegs,
|
||||
offsets.data(), offsets.data() + 1, start, end));
|
||||
tmp_mem->LazyAllocate(tmpSize);
|
||||
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
||||
tmp_mem->d_temp_storage, tmpSize, keys->CubBuffer(), vals->CubBuffer(),
|
||||
nVals, nSegs, offsets.data(), offsets.data() + 1, start, end));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper function to perform device-wide sum-reduction
|
||||
* @param tmp_mem cub temporary memory info
|
||||
* @param in the input array to be reduced
|
||||
* @param out the output reduced value
|
||||
* @param nVals number of elements in the input array
|
||||
*/
|
||||
template <typename T>
|
||||
void SumReduction(dh::CubMemory* tmp_mem, xgboost::common::Span<T> in, xgboost::common::Span<T> out,
|
||||
int nVals) {
|
||||
size_t tmpSize;
|
||||
dh::safe_cuda(
|
||||
cub::DeviceReduce::Sum(NULL, tmpSize, in.data(), out.data(), nVals));
|
||||
tmp_mem->LazyAllocate(tmpSize);
|
||||
dh::safe_cuda(cub::DeviceReduce::Sum(tmp_mem->d_temp_storage, tmpSize,
|
||||
in.data(), out.data(), nVals));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper function to perform device-wide sum-reduction, returns to the
|
||||
* host
|
||||
@ -1004,79 +646,6 @@ typename std::iterator_traits<T>::value_type SumReduction(
|
||||
return sum;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Fill a given constant value across all elements in the buffer
|
||||
* @param out the buffer to be filled
|
||||
* @param len number of elements i the buffer
|
||||
* @param def default value to be filled
|
||||
*/
|
||||
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
|
||||
void FillConst(int device_idx, T *out, int len, T def) {
|
||||
dh::LaunchN<ItemsPerThread, BlkDim>(device_idx, len,
|
||||
[=] __device__(int i) { out[i] = def; });
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief gather elements
|
||||
* @param out1 output gathered array for the first buffer
|
||||
* @param in1 first input buffer
|
||||
* @param out2 output gathered array for the second buffer
|
||||
* @param in2 second input buffer
|
||||
* @param instId gather indices
|
||||
* @param nVals length of the buffers
|
||||
*/
|
||||
template <typename T1, typename T2, int BlkDim = 256, int ItemsPerThread = 4>
|
||||
void Gather(int device_idx, T1 *out1, const T1 *in1, T2 *out2, const T2 *in2,
|
||||
const int *instId, int nVals) {
|
||||
dh::LaunchN<ItemsPerThread, BlkDim>(device_idx, nVals,
|
||||
[=] __device__(int i) {
|
||||
int iid = instId[i];
|
||||
T1 v1 = in1[iid];
|
||||
T2 v2 = in2[iid];
|
||||
out1[i] = v1;
|
||||
out2[i] = v2;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief gather elements
|
||||
* @param out output gathered array
|
||||
* @param in input buffer
|
||||
* @param instId gather indices
|
||||
* @param nVals length of the buffers
|
||||
*/
|
||||
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
|
||||
void Gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
|
||||
dh::LaunchN<ItemsPerThread, BlkDim>(device_idx, nVals,
|
||||
[=] __device__(int i) {
|
||||
int iid = instId[i];
|
||||
out[i] = in[iid];
|
||||
});
|
||||
}
|
||||
|
||||
class SaveCudaContext {
|
||||
private:
|
||||
int saved_device_;
|
||||
|
||||
public:
|
||||
template <typename Functor>
|
||||
explicit SaveCudaContext (Functor func) : saved_device_{-1} {
|
||||
// When compiled with CUDA but running on CPU only device,
|
||||
// cudaGetDevice will fail.
|
||||
try {
|
||||
safe_cuda(cudaGetDevice(&saved_device_));
|
||||
} catch (const dmlc::Error &except) {
|
||||
saved_device_ = -1;
|
||||
}
|
||||
func();
|
||||
}
|
||||
~SaveCudaContext() {
|
||||
if (saved_device_ != -1) {
|
||||
safe_cuda(cudaSetDevice(saved_device_));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \class AllReducer
|
||||
*
|
||||
@ -1200,50 +769,12 @@ class AllReducer {
|
||||
return id;
|
||||
}
|
||||
#endif
|
||||
/** \brief Perform max all reduce operation on the host. This function first
|
||||
* reduces over omp threads then over nodes using rabit (which is not thread
|
||||
* safe) using the master thread. Uses naive reduce algorithm for local
|
||||
* threads, don't expect this to scale.*/
|
||||
void HostMaxAllReduce(std::vector<size_t> *p_data) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
auto &data = *p_data;
|
||||
// Wait in case some other thread is accessing host_data_
|
||||
#pragma omp barrier
|
||||
// Reset shared buffer
|
||||
#pragma omp single
|
||||
{
|
||||
host_data_.resize(data.size());
|
||||
std::fill(host_data_.begin(), host_data_.end(), size_t(0));
|
||||
}
|
||||
// Threads update shared array
|
||||
for (auto i = 0ull; i < data.size(); i++) {
|
||||
#pragma omp critical
|
||||
{ host_data_[i] = std::max(host_data_[i], data[i]); }
|
||||
}
|
||||
// Wait until all threads are finished
|
||||
#pragma omp barrier
|
||||
|
||||
// One thread performs all reduce across distributed nodes
|
||||
#pragma omp master
|
||||
{
|
||||
rabit::Allreduce<rabit::op::Max, size_t>(host_data_.data(),
|
||||
host_data_.size());
|
||||
}
|
||||
|
||||
#pragma omp barrier
|
||||
|
||||
// Threads can now read back all reduced values
|
||||
for (auto i = 0ull; i < data.size(); i++) {
|
||||
data[i] = host_data_[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
template <typename VectorT, typename T = typename VectorT::value_type,
|
||||
typename IndexT = typename xgboost::common::Span<T>::index_type>
|
||||
xgboost::common::Span<T> ToSpan(
|
||||
device_vector<T>& vec,
|
||||
VectorT &vec,
|
||||
IndexT offset = 0,
|
||||
IndexT size = std::numeric_limits<size_t>::max()) {
|
||||
size = size == std::numeric_limits<size_t>::max() ? vec.size() : size;
|
||||
@ -1467,6 +998,26 @@ class SegmentSorter {
|
||||
}
|
||||
};
|
||||
|
||||
// Atomic add function for gradients
|
||||
template <typename OutputGradientT, typename InputGradientT>
|
||||
DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
|
||||
const InputGradientT& gpair) {
|
||||
auto dst_ptr = reinterpret_cast<typename OutputGradientT::ValueT*>(dest);
|
||||
|
||||
atomicAdd(dst_ptr,
|
||||
static_cast<typename OutputGradientT::ValueT>(gpair.GetGrad()));
|
||||
atomicAdd(dst_ptr + 1,
|
||||
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
|
||||
}
|
||||
|
||||
|
||||
// Thrust version of this function causes error on Windows
|
||||
template <typename ReturnT, typename IterT, typename FuncT>
|
||||
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
|
||||
IterT iter, FuncT func) {
|
||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||
}
|
||||
|
||||
template <typename FunctionT>
|
||||
class LauncherItr {
|
||||
public:
|
||||
@ -1481,35 +1032,35 @@ public:
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Thrust compatible iterator type - discards algorithm output and launches device lambda
|
||||
* with the index of the output and the algorithm output as arguments.
|
||||
*
|
||||
* \author Rory
|
||||
* \date 7/9/2017
|
||||
*
|
||||
* \tparam FunctionT Type of the function t.
|
||||
*/
|
||||
* \brief Thrust compatible iterator type - discards algorithm output and launches device lambda
|
||||
* with the index of the output and the algorithm output as arguments.
|
||||
*
|
||||
* \author Rory
|
||||
* \date 7/9/2017
|
||||
*
|
||||
* \tparam FunctionT Type of the function t.
|
||||
*/
|
||||
template <typename FunctionT>
|
||||
class DiscardLambdaItr {
|
||||
public:
|
||||
// Required iterator traits
|
||||
using self_type = DiscardLambdaItr; // NOLINT
|
||||
using difference_type = ptrdiff_t; // NOLINT
|
||||
using value_type = void; // NOLINT
|
||||
using pointer = value_type *; // NOLINT
|
||||
using reference = LauncherItr<FunctionT>; // NOLINT
|
||||
using iterator_category = typename thrust::detail::iterator_facade_category< // NOLINT
|
||||
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
|
||||
reference>::type; // NOLINT
|
||||
// Required iterator traits
|
||||
using self_type = DiscardLambdaItr; // NOLINT
|
||||
using difference_type = ptrdiff_t; // NOLINT
|
||||
using value_type = void; // NOLINT
|
||||
using pointer = value_type *; // NOLINT
|
||||
using reference = LauncherItr<FunctionT>; // NOLINT
|
||||
using iterator_category = typename thrust::detail::iterator_facade_category< // NOLINT
|
||||
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
|
||||
reference>::type; // NOLINT
|
||||
private:
|
||||
difference_type offset_;
|
||||
FunctionT f_;
|
||||
public:
|
||||
XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {}
|
||||
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f)
|
||||
: offset_(offset), f_(f) {}
|
||||
XGBOOST_DEVICE self_type operator+(const int &b) const {
|
||||
return DiscardLambdaItr(offset_ + b, f_);
|
||||
XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {}
|
||||
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f)
|
||||
: offset_(offset), f_(f) {}
|
||||
XGBOOST_DEVICE self_type operator+(const int &b) const {
|
||||
return DiscardLambdaItr(offset_ + b, f_);
|
||||
}
|
||||
XGBOOST_DEVICE self_type operator++() {
|
||||
offset_++;
|
||||
@ -1533,24 +1084,4 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Atomic add function for gradients
|
||||
template <typename OutputGradientT, typename InputGradientT>
|
||||
DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
|
||||
const InputGradientT& gpair) {
|
||||
auto dst_ptr = reinterpret_cast<typename OutputGradientT::ValueT*>(dest);
|
||||
|
||||
atomicAdd(dst_ptr,
|
||||
static_cast<typename OutputGradientT::ValueT>(gpair.GetGrad()));
|
||||
atomicAdd(dst_ptr + 1,
|
||||
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
|
||||
}
|
||||
|
||||
|
||||
// Thrust version of this function causes error on Windows
|
||||
template <typename ReturnT, typename IterT, typename FuncT>
|
||||
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
|
||||
IterT iter, FuncT func) {
|
||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||
}
|
||||
|
||||
} // namespace dh
|
||||
|
||||
@ -86,14 +86,13 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
std::make_pair(column_begin - col.cbegin(), column_end - col.cbegin()));
|
||||
row_ptr_.push_back(row_ptr_.back() + (column_end - column_begin));
|
||||
}
|
||||
ba_.Allocate(learner_param_->gpu_id, &data_, row_ptr_.back(), &gpair_,
|
||||
num_row_ * model_param.num_output_group);
|
||||
|
||||
data_.resize(row_ptr_.back());
|
||||
gpair_.resize(num_row_ * model_param.num_output_group);
|
||||
for (size_t fidx = 0; fidx < batch.Size(); fidx++) {
|
||||
auto col = batch[fidx];
|
||||
auto seg = column_segments[fidx];
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
data_.subspan(row_ptr_[fidx]).data(),
|
||||
data_.data().get() + row_ptr_[fidx],
|
||||
col.data() + seg.first,
|
||||
sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice));
|
||||
}
|
||||
@ -192,7 +191,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
// This needs to be public because of the __device__ lambda.
|
||||
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
|
||||
if (dbias == 0.0f) return;
|
||||
auto d_gpair = gpair_;
|
||||
auto d_gpair = dh::ToSpan(gpair_);
|
||||
dh::LaunchN(learner_param_->gpu_id, num_row_, [=] __device__(size_t idx) {
|
||||
auto &g = d_gpair[idx * num_groups + group_idx];
|
||||
g += GradientPair(g.GetHess() * dbias, 0);
|
||||
@ -202,9 +201,9 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
// This needs to be public because of the __device__ lambda.
|
||||
GradientPair GetGradient(int group_idx, int num_group, int fidx) {
|
||||
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
|
||||
common::Span<xgboost::Entry> d_col = data_.subspan(row_ptr_[fidx]);
|
||||
common::Span<xgboost::Entry> d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]);
|
||||
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||
common::Span<GradientPair> d_gpair = gpair_;
|
||||
common::Span<GradientPair> d_gpair = dh::ToSpan(gpair_);
|
||||
auto counting = thrust::make_counting_iterator(0ull);
|
||||
auto f = [=] __device__(size_t idx) {
|
||||
auto entry = d_col[idx];
|
||||
@ -219,8 +218,8 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
|
||||
// This needs to be public because of the __device__ lambda.
|
||||
void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) {
|
||||
common::Span<GradientPair> d_gpair = gpair_;
|
||||
common::Span<Entry> d_col = data_.subspan(row_ptr_[fidx]);
|
||||
common::Span<GradientPair> d_gpair = dh::ToSpan(gpair_);
|
||||
common::Span<Entry> d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]);
|
||||
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||
dh::LaunchN(learner_param_->gpu_id, col_size, [=] __device__(size_t idx) {
|
||||
auto entry = d_col[idx];
|
||||
@ -236,7 +235,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
|
||||
void UpdateGpair(const std::vector<GradientPair> &host_gpair) {
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
gpair_.data(),
|
||||
gpair_.data().get(),
|
||||
host_gpair.data(),
|
||||
gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice));
|
||||
}
|
||||
@ -247,10 +246,9 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
std::unique_ptr<FeatureSelector> selector_;
|
||||
common::Monitor monitor_;
|
||||
|
||||
dh::BulkAllocator ba_;
|
||||
std::vector<size_t> row_ptr_;
|
||||
common::Span<xgboost::Entry> data_;
|
||||
common::Span<GradientPair> gpair_;
|
||||
dh::device_vector<xgboost::Entry> data_;
|
||||
dh::caching_device_vector<GradientPair> gpair_;
|
||||
dh::CubMemory temp_;
|
||||
size_t num_row_;
|
||||
};
|
||||
|
||||
@ -187,9 +187,10 @@ ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl* pa
|
||||
size_t n_rows,
|
||||
const BatchParam& batch_param,
|
||||
float subsample)
|
||||
: original_page_(page), batch_param_(batch_param), subsample_(subsample) {
|
||||
ba_.Allocate(batch_param_.gpu_id, &sample_row_index_, n_rows);
|
||||
}
|
||||
: original_page_(page),
|
||||
batch_param_(batch_param),
|
||||
subsample_(subsample),
|
||||
sample_row_index_(n_rows) {}
|
||||
|
||||
GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientPair> gpair,
|
||||
DMatrix* dmat) {
|
||||
@ -207,12 +208,12 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
|
||||
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
|
||||
|
||||
// Index the sample rows.
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero());
|
||||
thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_),
|
||||
dh::tbegin(sample_row_index_));
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero());
|
||||
thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(),
|
||||
sample_row_index_.begin());
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
|
||||
dh::tbegin(sample_row_index_),
|
||||
dh::tbegin(sample_row_index_),
|
||||
sample_row_index_.begin(),
|
||||
sample_row_index_.begin(),
|
||||
ClearEmptyRows());
|
||||
|
||||
// Create a new ELLPACK page with empty rows.
|
||||
@ -224,7 +225,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
|
||||
// Compact the ELLPACK pages into the single sample page.
|
||||
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
|
||||
page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_);
|
||||
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||
}
|
||||
|
||||
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
||||
@ -233,23 +234,23 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
|
||||
GradientBasedSampling::GradientBasedSampling(EllpackPageImpl* page,
|
||||
size_t n_rows,
|
||||
const BatchParam& batch_param,
|
||||
float subsample) : page_(page), subsample_(subsample) {
|
||||
ba_.Allocate(batch_param.gpu_id,
|
||||
&threshold_, n_rows + 1,
|
||||
&grad_sum_, n_rows);
|
||||
}
|
||||
float subsample)
|
||||
: page_(page),
|
||||
subsample_(subsample),
|
||||
threshold_(n_rows + 1, 0.0f),
|
||||
grad_sum_(n_rows, 0.0f) {}
|
||||
|
||||
GradientBasedSample GradientBasedSampling::Sample(common::Span<GradientPair> gpair,
|
||||
DMatrix* dmat) {
|
||||
size_t n_rows = dmat->Info().num_row_;
|
||||
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
||||
gpair, threshold_, grad_sum_, n_rows * subsample_);
|
||||
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
||||
|
||||
// Perform Poisson sampling in place.
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
|
||||
thrust::counting_iterator<size_t>(0),
|
||||
dh::tbegin(gpair),
|
||||
PoissonSampling(threshold_,
|
||||
PoissonSampling(dh::ToSpan(threshold_),
|
||||
threshold_index,
|
||||
RandomWeight(common::GlobalRandom()())));
|
||||
return {n_rows, page_, gpair};
|
||||
@ -259,24 +260,25 @@ ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
|
||||
EllpackPageImpl* page,
|
||||
size_t n_rows,
|
||||
const BatchParam& batch_param,
|
||||
float subsample) : original_page_(page), batch_param_(batch_param), subsample_(subsample) {
|
||||
ba_.Allocate(batch_param.gpu_id,
|
||||
&threshold_, n_rows + 1,
|
||||
&grad_sum_, n_rows,
|
||||
&sample_row_index_, n_rows);
|
||||
}
|
||||
float subsample)
|
||||
: original_page_(page),
|
||||
batch_param_(batch_param),
|
||||
subsample_(subsample),
|
||||
threshold_(n_rows + 1, 0.0f),
|
||||
grad_sum_(n_rows, 0.0f),
|
||||
sample_row_index_(n_rows) {}
|
||||
|
||||
GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<GradientPair> gpair,
|
||||
DMatrix* dmat) {
|
||||
size_t n_rows = dmat->Info().num_row_;
|
||||
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
||||
gpair, threshold_, grad_sum_, n_rows * subsample_);
|
||||
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
||||
|
||||
// Perform Poisson sampling in place.
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
|
||||
thrust::counting_iterator<size_t>(0),
|
||||
dh::tbegin(gpair),
|
||||
PoissonSampling(threshold_,
|
||||
PoissonSampling(dh::ToSpan(threshold_),
|
||||
threshold_index,
|
||||
RandomWeight(common::GlobalRandom()())));
|
||||
|
||||
@ -288,12 +290,12 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<Gra
|
||||
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
|
||||
|
||||
// Index the sample rows.
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero());
|
||||
thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_),
|
||||
dh::tbegin(sample_row_index_));
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero());
|
||||
thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(),
|
||||
sample_row_index_.begin());
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
|
||||
dh::tbegin(sample_row_index_),
|
||||
dh::tbegin(sample_row_index_),
|
||||
sample_row_index_.begin(),
|
||||
sample_row_index_.begin(),
|
||||
ClearEmptyRows());
|
||||
|
||||
// Create a new ELLPACK page with empty rows.
|
||||
@ -305,7 +307,7 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<Gra
|
||||
// Compact the ELLPACK pages into the single sample page.
|
||||
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
|
||||
page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_);
|
||||
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||
}
|
||||
|
||||
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
|
||||
@ -358,21 +360,21 @@ GradientBasedSample GradientBasedSampler::Sample(common::Span<GradientPair> gpai
|
||||
return sample;
|
||||
}
|
||||
|
||||
size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
|
||||
common::Span<float> threshold,
|
||||
common::Span<float> grad_sum,
|
||||
size_t sample_rows) {
|
||||
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits<float>::max());
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
|
||||
dh::tbegin(threshold),
|
||||
size_t GradientBasedSampler::CalculateThresholdIndex(
|
||||
common::Span<GradientPair> gpair, common::Span<float> threshold,
|
||||
common::Span<float> grad_sum, size_t sample_rows) {
|
||||
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold),
|
||||
std::numeric_limits<float>::max());
|
||||
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
|
||||
CombineGradientPair());
|
||||
thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
|
||||
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1, dh::tbegin(grad_sum));
|
||||
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1,
|
||||
dh::tbegin(grad_sum));
|
||||
thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum),
|
||||
thrust::counting_iterator<size_t>(0),
|
||||
dh::tbegin(grad_sum),
|
||||
thrust::counting_iterator<size_t>(0), dh::tbegin(grad_sum),
|
||||
SampleRateDelta(threshold, gpair.size(), sample_rows));
|
||||
thrust::device_ptr<float> min = thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
|
||||
thrust::device_ptr<float> min =
|
||||
thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
|
||||
return thrust::distance(dh::tbegin(grad_sum), min) + 1;
|
||||
}
|
||||
|
||||
|
||||
@ -73,13 +73,12 @@ class ExternalMemoryUniformSampling : public SamplingStrategy {
|
||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
||||
|
||||
private:
|
||||
dh::BulkAllocator ba_;
|
||||
EllpackPageImpl* original_page_;
|
||||
BatchParam batch_param_;
|
||||
float subsample_;
|
||||
std::unique_ptr<EllpackPageImpl> page_;
|
||||
dh::device_vector<GradientPair> gpair_{};
|
||||
common::Span<size_t> sample_row_index_;
|
||||
dh::caching_device_vector<size_t> sample_row_index_;
|
||||
};
|
||||
|
||||
/*! \brief Gradient-based sampling in in-memory mode.. */
|
||||
@ -94,9 +93,8 @@ class GradientBasedSampling : public SamplingStrategy {
|
||||
private:
|
||||
EllpackPageImpl* page_;
|
||||
float subsample_;
|
||||
dh::BulkAllocator ba_;
|
||||
common::Span<float> threshold_;
|
||||
common::Span<float> grad_sum_;
|
||||
dh::caching_device_vector<float> threshold_;
|
||||
dh::caching_device_vector<float> grad_sum_;
|
||||
};
|
||||
|
||||
/*! \brief Gradient-based sampling in external memory mode.. */
|
||||
@ -109,15 +107,14 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
|
||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
||||
|
||||
private:
|
||||
dh::BulkAllocator ba_;
|
||||
EllpackPageImpl* original_page_;
|
||||
BatchParam batch_param_;
|
||||
float subsample_;
|
||||
common::Span<float> threshold_;
|
||||
common::Span<float> grad_sum_;
|
||||
dh::caching_device_vector<float> threshold_;
|
||||
dh::caching_device_vector<float> grad_sum_;
|
||||
std::unique_ptr<EllpackPageImpl> page_;
|
||||
dh::device_vector<GradientPair> gpair_;
|
||||
common::Span<size_t> sample_row_index_;
|
||||
dh::caching_device_vector<size_t> sample_row_index_;
|
||||
};
|
||||
|
||||
/*! \brief Draw a sample of rows from a DMatrix.
|
||||
|
||||
@ -408,25 +408,22 @@ struct GPUHistMakerDevice {
|
||||
EllpackPageImpl* page;
|
||||
BatchParam batch_param;
|
||||
|
||||
dh::BulkAllocator ba;
|
||||
|
||||
std::unique_ptr<RowPartitioner> row_partitioner;
|
||||
DeviceHistogram<GradientSumT> hist{};
|
||||
|
||||
/*! \brief Gradient pair for each row. */
|
||||
common::Span<GradientPair> gpair;
|
||||
|
||||
common::Span<int> monotone_constraints;
|
||||
common::Span<bst_float> prediction_cache;
|
||||
dh::caching_device_vector<int> monotone_constraints;
|
||||
dh::caching_device_vector<bst_float> prediction_cache;
|
||||
|
||||
/*! \brief Sum gradient for each node. */
|
||||
std::vector<GradientPair> node_sum_gradients;
|
||||
common::Span<GradientPair> node_sum_gradients_d;
|
||||
std::vector<GradientPair> host_node_sum_gradients;
|
||||
dh::caching_device_vector<GradientPair> node_sum_gradients;
|
||||
bst_uint n_rows;
|
||||
|
||||
TrainParam param;
|
||||
bool deterministic_histogram;
|
||||
bool prediction_cache_initialised;
|
||||
bool use_shared_memory_histograms {false};
|
||||
|
||||
GradientSumT histogram_rounding;
|
||||
@ -460,7 +457,6 @@ struct GPUHistMakerDevice {
|
||||
page(_page),
|
||||
n_rows(_n_rows),
|
||||
param(std::move(_param)),
|
||||
prediction_cache_initialised(false),
|
||||
column_sampler(column_sampler_seed),
|
||||
interaction_constraints(param, n_features),
|
||||
deterministic_histogram{deterministic_histogram},
|
||||
@ -513,7 +509,7 @@ struct GPUHistMakerDevice {
|
||||
param.colsample_bylevel, param.colsample_bytree);
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
this->interaction_constraints.Reset();
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
std::fill(host_node_sum_gradients.begin(), host_node_sum_gradients.end(),
|
||||
GradientPair());
|
||||
|
||||
auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat);
|
||||
@ -541,44 +537,26 @@ struct GPUHistMakerDevice {
|
||||
// Work out cub temporary memory requirement
|
||||
GPUTrainingParam gpu_param(param);
|
||||
DeviceSplitCandidateReduceOp op(gpu_param);
|
||||
size_t temp_storage_bytes = 0;
|
||||
DeviceSplitCandidate*dummy = nullptr;
|
||||
cub::DeviceReduce::Reduce(
|
||||
nullptr, temp_storage_bytes, dummy,
|
||||
dummy, num_columns, op,
|
||||
DeviceSplitCandidate());
|
||||
// size in terms of DeviceSplitCandidate
|
||||
size_t cub_memory_size =
|
||||
std::ceil(static_cast<double>(temp_storage_bytes) /
|
||||
sizeof(DeviceSplitCandidate));
|
||||
|
||||
// Allocate enough temporary memory
|
||||
// Result for each nidx
|
||||
// + intermediate result for each column
|
||||
// + cub reduce memory
|
||||
auto temp_span = temp_memory.GetSpan<DeviceSplitCandidate>(
|
||||
nidxs.size() + nidxs.size() * num_columns +cub_memory_size*nidxs.size());
|
||||
auto d_result_all = temp_span.subspan(0, nidxs.size());
|
||||
auto d_split_candidates_all =
|
||||
temp_span.subspan(d_result_all.size(), nidxs.size() * num_columns);
|
||||
auto d_cub_memory_all =
|
||||
temp_span.subspan(d_result_all.size() + d_split_candidates_all.size(),
|
||||
cub_memory_size * nidxs.size());
|
||||
dh::caching_device_vector<DeviceSplitCandidate> d_result_all(nidxs.size());
|
||||
dh::caching_device_vector<DeviceSplitCandidate> split_candidates_all(nidxs.size()*num_columns);
|
||||
|
||||
auto& streams = this->GetStreams(nidxs.size());
|
||||
for (auto i = 0ull; i < nidxs.size(); i++) {
|
||||
auto nidx = nidxs[i];
|
||||
auto p_feature_set = column_sampler.GetFeatureSet(tree.GetDepth(nidx));
|
||||
p_feature_set->SetDevice(device_id);
|
||||
common::Span<bst_feature_t> d_sampled_features = p_feature_set->DeviceSpan();
|
||||
common::Span<bst_feature_t> d_sampled_features =
|
||||
p_feature_set->DeviceSpan();
|
||||
common::Span<bst_feature_t> d_feature_set =
|
||||
interaction_constraints.Query(d_sampled_features, nidx);
|
||||
auto d_split_candidates =
|
||||
d_split_candidates_all.subspan(i * num_columns, d_feature_set.size());
|
||||
common::Span<DeviceSplitCandidate> d_split_candidates(
|
||||
split_candidates_all.data().get() + i * num_columns,
|
||||
d_feature_set.size());
|
||||
|
||||
DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
|
||||
DeviceNodeStats node(host_node_sum_gradients[nidx], nidx, param);
|
||||
|
||||
auto d_result = d_result_all.subspan(i, 1);
|
||||
common::Span<DeviceSplitCandidate> d_result(d_result_all.data().get() + i, 1);
|
||||
if (d_feature_set.empty()) {
|
||||
// Acting as a device side constructor for DeviceSplitCandidate.
|
||||
// DeviceSplitCandidate::IsValid is false so that ApplySplit can reject this
|
||||
@ -596,19 +574,22 @@ struct GPUHistMakerDevice {
|
||||
EvaluateSplitKernel<kBlockThreads, GradientSumT>,
|
||||
hist.GetNodeHistogram(nidx), d_feature_set, node, page->GetDeviceAccessor(device_id),
|
||||
gpu_param, d_split_candidates, node_value_constraints[nidx],
|
||||
monotone_constraints);
|
||||
dh::ToSpan(monotone_constraints));
|
||||
|
||||
// Reduce over features to find best feature
|
||||
auto d_cub_memory =
|
||||
d_cub_memory_all.subspan(i * cub_memory_size, cub_memory_size);
|
||||
size_t cub_bytes = d_cub_memory.size() * sizeof(DeviceSplitCandidate);
|
||||
cub::DeviceReduce::Reduce(reinterpret_cast<void*>(d_cub_memory.data()),
|
||||
size_t cub_bytes = 0;
|
||||
cub::DeviceReduce::Reduce(nullptr,
|
||||
cub_bytes, d_split_candidates.data(),
|
||||
d_result.data(), d_split_candidates.size(), op,
|
||||
DeviceSplitCandidate(), streams[i]);
|
||||
dh::caching_device_vector<char> cub_temp(cub_bytes);
|
||||
cub::DeviceReduce::Reduce(reinterpret_cast<void*>(cub_temp.data().get()),
|
||||
cub_bytes, d_split_candidates.data(),
|
||||
d_result.data(), d_split_candidates.size(), op,
|
||||
DeviceSplitCandidate(), streams[i]);
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaMemcpy(result_all.data(), d_result_all.data(),
|
||||
dh::safe_cuda(cudaMemcpy(result_all.data(), d_result_all.data().get(),
|
||||
sizeof(DeviceSplitCandidate) * d_result_all.size(),
|
||||
cudaMemcpyDeviceToHost));
|
||||
return std::vector<DeviceSplitCandidate>(result_all.begin(), result_all.end());
|
||||
@ -718,23 +699,23 @@ struct GPUHistMakerDevice {
|
||||
|
||||
void UpdatePredictionCache(bst_float* out_preds_d) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
if (!prediction_cache_initialised) {
|
||||
dh::safe_cuda(cudaMemcpyAsync(prediction_cache.data(), out_preds_d,
|
||||
auto d_ridx = row_partitioner->GetRows();
|
||||
if (prediction_cache.size() != d_ridx.size()) {
|
||||
prediction_cache.resize(d_ridx.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(prediction_cache.data().get(), out_preds_d,
|
||||
prediction_cache.size() * sizeof(bst_float),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
prediction_cache_initialised = true;
|
||||
|
||||
CalcWeightTrainParam param_d(param);
|
||||
|
||||
dh::safe_cuda(
|
||||
cudaMemcpyAsync(node_sum_gradients_d.data(), node_sum_gradients.data(),
|
||||
sizeof(GradientPair) * node_sum_gradients.size(),
|
||||
cudaMemcpyAsync(node_sum_gradients.data().get(), host_node_sum_gradients.data(),
|
||||
sizeof(GradientPair) * host_node_sum_gradients.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
auto d_position = row_partitioner->GetPosition();
|
||||
auto d_ridx = row_partitioner->GetRows();
|
||||
auto d_node_sum_gradients = node_sum_gradients_d.data();
|
||||
auto d_prediction_cache = prediction_cache.data();
|
||||
auto d_node_sum_gradients = node_sum_gradients.data().get();
|
||||
auto d_prediction_cache = prediction_cache.data().get();
|
||||
|
||||
dh::LaunchN(
|
||||
device_id, prediction_cache.size(), [=] __device__(int local_idx) {
|
||||
@ -745,7 +726,7 @@ struct GPUHistMakerDevice {
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
out_preds_d, prediction_cache.data(),
|
||||
out_preds_d, prediction_cache.data().get(),
|
||||
prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
|
||||
row_partitioner.reset();
|
||||
}
|
||||
@ -822,9 +803,9 @@ struct GPUHistMakerDevice {
|
||||
param, tree[candidate.nid].SplitIndex(), left_stats, right_stats,
|
||||
&node_value_constraints[tree[candidate.nid].LeftChild()],
|
||||
&node_value_constraints[tree[candidate.nid].RightChild()]);
|
||||
node_sum_gradients[tree[candidate.nid].LeftChild()] =
|
||||
host_node_sum_gradients[tree[candidate.nid].LeftChild()] =
|
||||
candidate.split.left_sum;
|
||||
node_sum_gradients[tree[candidate.nid].RightChild()] =
|
||||
host_node_sum_gradients[tree[candidate.nid].RightChild()] =
|
||||
candidate.split.right_sum;
|
||||
|
||||
interaction_constraints.Split(candidate.nid, tree[candidate.nid].SplitIndex(),
|
||||
@ -839,22 +820,22 @@ struct GPUHistMakerDevice {
|
||||
thrust::cuda::par(alloc),
|
||||
thrust::device_ptr<GradientPair const>(gpair.data()),
|
||||
thrust::device_ptr<GradientPair const>(gpair.data() + gpair.size()));
|
||||
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.data(), &root_sum, sizeof(root_sum),
|
||||
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients.data().get(), &root_sum, sizeof(root_sum),
|
||||
cudaMemcpyHostToDevice));
|
||||
reducer->AllReduceSum(
|
||||
reinterpret_cast<float*>(node_sum_gradients_d.data()),
|
||||
reinterpret_cast<float*>(node_sum_gradients_d.data()), 2);
|
||||
reinterpret_cast<float*>(node_sum_gradients.data().get()),
|
||||
reinterpret_cast<float*>(node_sum_gradients.data().get()), 2);
|
||||
reducer->Synchronize();
|
||||
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients.data(),
|
||||
node_sum_gradients_d.data(), sizeof(GradientPair),
|
||||
dh::safe_cuda(cudaMemcpyAsync(host_node_sum_gradients.data(),
|
||||
node_sum_gradients.data().get(), sizeof(GradientPair),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
this->BuildHist(kRootNIdx);
|
||||
this->AllReduceHist(kRootNIdx, reducer);
|
||||
|
||||
// Remember root stats
|
||||
p_tree->Stat(kRootNIdx).sum_hess = node_sum_gradients[kRootNIdx].GetHess();
|
||||
auto weight = CalcWeight(param, node_sum_gradients[kRootNIdx]);
|
||||
p_tree->Stat(kRootNIdx).sum_hess = host_node_sum_gradients[kRootNIdx].GetHess();
|
||||
auto weight = CalcWeight(param, host_node_sum_gradients[kRootNIdx]);
|
||||
p_tree->Stat(kRootNIdx).base_weight = weight;
|
||||
(*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);
|
||||
|
||||
@ -927,15 +908,12 @@ struct GPUHistMakerDevice {
|
||||
|
||||
template <typename GradientSumT>
|
||||
inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
|
||||
bst_node_t max_nodes { param.MaxNodes() };
|
||||
ba.Allocate(device_id,
|
||||
&prediction_cache, n_rows,
|
||||
&node_sum_gradients_d, max_nodes,
|
||||
&monotone_constraints, param.monotone_constraints.size());
|
||||
|
||||
dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);
|
||||
|
||||
node_sum_gradients.resize(max_nodes);
|
||||
if (!param.monotone_constraints.empty()) {
|
||||
// Copy assigning an empty vector causes an exception in MSVC debug builds
|
||||
monotone_constraints = param.monotone_constraints;
|
||||
}
|
||||
host_node_sum_gradients.resize(param.MaxNodes());
|
||||
node_sum_gradients.resize(param.MaxNodes());
|
||||
|
||||
// check if we can use shared memory for building histograms
|
||||
// (assuming atleast we need 2 CTAs per SM to maintain decent latency
|
||||
|
||||
@ -27,64 +27,13 @@ void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
|
||||
}
|
||||
}
|
||||
|
||||
void TestLbs() {
|
||||
srand(17);
|
||||
dh::CubMemory temp_memory;
|
||||
|
||||
std::vector<int> test_rows = {4, 100, 1000};
|
||||
std::vector<int> test_max_row_sizes = {4, 100, 1300};
|
||||
|
||||
for (auto num_rows : test_rows) {
|
||||
for (auto max_row_size : test_max_row_sizes) {
|
||||
thrust::host_vector<int> h_row_ptr;
|
||||
thrust::host_vector<xgboost::bst_uint> h_rows;
|
||||
CreateTestData(num_rows, max_row_size, &h_row_ptr, &h_rows);
|
||||
thrust::device_vector<size_t> row_ptr = h_row_ptr;
|
||||
thrust::device_vector<int> output_row(h_rows.size());
|
||||
auto d_output_row = output_row.data();
|
||||
|
||||
dh::TransformLbs(0, &temp_memory, h_rows.size(), dh::Raw(row_ptr),
|
||||
row_ptr.size() - 1, false,
|
||||
[=] __device__(size_t idx, size_t ridx) {
|
||||
d_output_row[idx] = ridx;
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
ASSERT_TRUE(h_rows == output_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CubLBS, Test) {
|
||||
TestLbs();
|
||||
}
|
||||
|
||||
TEST(SumReduce, Test) {
|
||||
thrust::device_vector<float> data(100, 1.0f);
|
||||
dh::CubMemory temp;
|
||||
auto sum = dh::SumReduction(&temp, dh::Raw(data), data.size());
|
||||
auto sum = dh::SumReduction(&temp, data.data().get(), data.size());
|
||||
ASSERT_NEAR(sum, 100.0f, 1e-5);
|
||||
}
|
||||
|
||||
void TestAllocator() {
|
||||
int n = 10;
|
||||
Span<float> a;
|
||||
Span<int> b;
|
||||
Span<size_t> c;
|
||||
dh::BulkAllocator ba;
|
||||
ba.Allocate(0, &a, n, &b, n, &c, n);
|
||||
|
||||
// Should be no illegal memory accesses
|
||||
dh::LaunchN(0, n, [=] __device__(size_t idx) { c[idx] = a[idx] + b[idx]; });
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
// Define the test in a function so we can use device lambda
|
||||
TEST(BulkAllocator, Test) {
|
||||
TestAllocator();
|
||||
}
|
||||
|
||||
template <typename T, typename Comp = thrust::less<T>>
|
||||
void TestUpperBoundImpl(const std::vector<T> &vec, T val_to_find,
|
||||
const Comp &comp = Comp()) {
|
||||
|
||||
@ -23,7 +23,7 @@ Json GenerateDenseColumn(std::string const& typestr, size_t kRows,
|
||||
d_data.resize(kRows);
|
||||
thrust::sequence(thrust::device, d_data.begin(), d_data.end(), 0.0f, 2.0f);
|
||||
|
||||
auto p_d_data = dh::Raw(d_data);
|
||||
auto p_d_data = d_data.data().get();
|
||||
|
||||
std::vector<Json> j_data {
|
||||
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
|
||||
@ -49,7 +49,7 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows,
|
||||
d_data[i] = i * 2.0;
|
||||
}
|
||||
|
||||
auto p_d_data = dh::Raw(d_data);
|
||||
auto p_d_data = d_data.data().get();
|
||||
|
||||
std::vector<Json> j_data {
|
||||
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
|
||||
|
||||
@ -25,7 +25,7 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons
|
||||
column["version"] = Integer(static_cast<Integer::Int>(1));
|
||||
column["typestr"] = String(typestr);
|
||||
|
||||
auto p_d_data = dh::Raw(d_data);
|
||||
auto p_d_data = d_data.data().get();
|
||||
std::vector<Json> j_data {
|
||||
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
|
||||
Json(Boolean(false))};
|
||||
|
||||
@ -24,37 +24,33 @@ namespace tree {
|
||||
|
||||
TEST(GpuHist, DeviceHistogram) {
|
||||
// Ensures that node allocates correctly after reaching `kStopGrowingSize`.
|
||||
dh::SaveCudaContext{
|
||||
[&]() {
|
||||
dh::safe_cuda(cudaSetDevice(0));
|
||||
constexpr size_t kNBins = 128;
|
||||
constexpr size_t kNNodes = 4;
|
||||
constexpr size_t kStopGrowing = kNNodes * kNBins * 2u;
|
||||
DeviceHistogram<GradientPairPrecise, kStopGrowing> histogram;
|
||||
histogram.Init(0, kNBins);
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
histogram.AllocateHistogram(i);
|
||||
}
|
||||
histogram.Reset();
|
||||
ASSERT_EQ(histogram.Data().size(), kStopGrowing);
|
||||
dh::safe_cuda(cudaSetDevice(0));
|
||||
constexpr size_t kNBins = 128;
|
||||
constexpr size_t kNNodes = 4;
|
||||
constexpr size_t kStopGrowing = kNNodes * kNBins * 2u;
|
||||
DeviceHistogram<GradientPairPrecise, kStopGrowing> histogram;
|
||||
histogram.Init(0, kNBins);
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
histogram.AllocateHistogram(i);
|
||||
}
|
||||
histogram.Reset();
|
||||
ASSERT_EQ(histogram.Data().size(), kStopGrowing);
|
||||
|
||||
// Use allocated memory but do not erase nidx_map.
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
histogram.AllocateHistogram(i);
|
||||
}
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
ASSERT_TRUE(histogram.HistogramExists(i));
|
||||
}
|
||||
// Use allocated memory but do not erase nidx_map.
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
histogram.AllocateHistogram(i);
|
||||
}
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
ASSERT_TRUE(histogram.HistogramExists(i));
|
||||
}
|
||||
|
||||
// Erase existing nidx_map.
|
||||
for (size_t i = kNNodes; i < kNNodes * 2; ++i) {
|
||||
histogram.AllocateHistogram(i);
|
||||
}
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
ASSERT_FALSE(histogram.HistogramExists(i));
|
||||
}
|
||||
}
|
||||
};
|
||||
// Erase existing nidx_map.
|
||||
for (size_t i = kNNodes; i < kNNodes * 2; ++i) {
|
||||
histogram.AllocateHistogram(i);
|
||||
}
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
ASSERT_FALSE(histogram.HistogramExists(i));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<GradientPairPrecise> GetHostHistGpair() {
|
||||
@ -187,16 +183,14 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
GPUHistMakerDevice<GradientPairPrecise>
|
||||
maker(0, page.get(), kNRows, param, kNCols, kNCols, true, batch_param);
|
||||
// Initialize GPUHistMakerDevice::node_sum_gradients
|
||||
maker.node_sum_gradients = {{6.4f, 12.8f}};
|
||||
maker.host_node_sum_gradients = {{6.4f, 12.8f}};
|
||||
|
||||
// Initialize GPUHistMakerDevice::cut
|
||||
auto cmat = GetHostCutMatrix();
|
||||
|
||||
// Copy cut matrix to device.
|
||||
page->Cuts() = cmat;
|
||||
maker.ba.Allocate(0, &(maker.monotone_constraints), kNCols);
|
||||
dh::CopyVectorToDeviceSpan(maker.monotone_constraints,
|
||||
param.monotone_constraints);
|
||||
maker.monotone_constraints = param.monotone_constraints;
|
||||
|
||||
// Initialize GPUHistMakerDevice::hist
|
||||
maker.hist.Init(0, (max_bins - 1) * kNCols);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user