Refactor out row partitioning logic from gpu_hist, introduce caching device vectors (#4554)
This commit is contained in:
parent
0c50f8417a
commit
221e163185
@ -9,6 +9,7 @@
|
||||
#include <thrust/system_error.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <rabit/rabit.h>
|
||||
#include <cub/util_allocator.cuh>
|
||||
|
||||
#include "common.h"
|
||||
#include "span.h"
|
||||
@ -299,9 +300,14 @@ namespace detail{
|
||||
* \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
|
||||
*/
|
||||
template <class T>
|
||||
struct XGBDefaultDeviceAllocator : thrust::device_malloc_allocator<T> {
|
||||
struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
|
||||
using super_t = thrust::device_malloc_allocator<T>;
|
||||
using pointer = thrust::device_ptr<T>;
|
||||
template<typename U>
|
||||
struct rebind
|
||||
{
|
||||
typedef XGBDefaultDeviceAllocatorImpl<U> other;
|
||||
};
|
||||
pointer allocate(size_t n) {
|
||||
pointer ptr = super_t::allocate(n);
|
||||
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n);
|
||||
@ -312,16 +318,56 @@ struct XGBDefaultDeviceAllocator : thrust::device_malloc_allocator<T> {
|
||||
return super_t::deallocate(ptr, n);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs allocations if verbose. Does not initialise memory on construction.
|
||||
*/
|
||||
template <class T>
|
||||
struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
|
||||
using pointer = thrust::device_ptr<T>;
|
||||
template<typename U>
|
||||
struct rebind
|
||||
{
|
||||
typedef XGBCachingDeviceAllocatorImpl<U> other;
|
||||
};
|
||||
cub::CachingDeviceAllocator& GetGlobalCachingAllocator ()
|
||||
{
|
||||
// Configure allocator with maximum cached bin size of ~1GB and no limit on
|
||||
// maximum cached bytes
|
||||
static cub::CachingDeviceAllocator allocator(8,3,10);
|
||||
return allocator;
|
||||
}
|
||||
pointer allocate(size_t n) {
|
||||
T *ptr;
|
||||
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
|
||||
n * sizeof(T));
|
||||
pointer thrust_ptr = thrust::device_ptr<T>(ptr);
|
||||
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n);
|
||||
return thrust_ptr;
|
||||
}
|
||||
void deallocate(pointer ptr, size_t n) {
|
||||
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n);
|
||||
GetGlobalCachingAllocator().DeviceFree(ptr.get());
|
||||
}
|
||||
__host__ __device__
|
||||
void construct(T *)
|
||||
{
|
||||
// no-op
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Declare xgboost allocator
|
||||
// Declare xgboost allocators
|
||||
// Replacement of allocator with custom backend should occur here
|
||||
template <typename T>
|
||||
using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocator<T>;
|
||||
using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocatorImpl<T>;
|
||||
template <typename T>
|
||||
using XGBCachingDeviceAllocator = detail::XGBCachingDeviceAllocatorImpl<T>;
|
||||
/** \brief Specialisation of thrust device vector using custom allocator. */
|
||||
template <typename T>
|
||||
using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>;
|
||||
|
||||
template <typename T>
|
||||
using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>;
|
||||
/**
|
||||
* \brief A double buffer, useful for algorithms like sort.
|
||||
*/
|
||||
@ -331,6 +377,14 @@ class DoubleBuffer {
|
||||
cub::DoubleBuffer<T> buff;
|
||||
xgboost::common::Span<T> a, b;
|
||||
DoubleBuffer() = default;
|
||||
template <typename VectorT>
|
||||
DoubleBuffer(VectorT *v1, VectorT *v2) {
|
||||
a = xgboost::common::Span<T>(v1->data().get(), v1->size());
|
||||
b = xgboost::common::Span<T>(v2->data().get(), v2->size());
|
||||
buff.d_buffers[0] = v1->data().get();
|
||||
buff.d_buffers[1] = v2->data().get();
|
||||
buff.selector = 0;
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
CHECK_EQ(a.size(), b.size());
|
||||
@ -362,6 +416,20 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<T> src) {
|
||||
cudaMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Copies const device span to std::vector.
|
||||
*
|
||||
* \tparam T Generic type parameter.
|
||||
* \param [in,out] dst Copy destination.
|
||||
* \param src Copy source. Must be device memory.
|
||||
*/
|
||||
template <typename T>
|
||||
void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T> src) {
|
||||
CHECK_EQ(dst->size(), src.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(dst->data(), src.data(), dst->size() * sizeof(T),
|
||||
cudaMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Copies std::vector to device span.
|
||||
*
|
||||
@ -1132,6 +1200,7 @@ class AllReducer {
|
||||
* safe) using the master thread. Uses naive reduce algorithm for local
|
||||
* threads, don't expect this to scale.*/
|
||||
void HostMaxAllReduce(std::vector<size_t> *p_data) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
auto &data = *p_data;
|
||||
// Wait in case some other thread is accessing host_data
|
||||
#pragma omp barrier
|
||||
@ -1162,6 +1231,7 @@ class AllReducer {
|
||||
for (auto i = 0ull; i < data.size(); i++) {
|
||||
data[i] = host_data[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
146
src/tree/gpu_hist/row_partitioner.cu
Normal file
146
src/tree/gpu_hist/row_partitioner.cu
Normal file
@ -0,0 +1,146 @@
|
||||
|
||||
/*!
|
||||
* Copyright 2017-2019 XGBoost contributors
|
||||
*/
|
||||
#include <thrust/sequence.h>
|
||||
#include <vector>
|
||||
#include "../../common/device_helpers.cuh"
|
||||
#include "row_partitioner.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
struct IndicateLeftTransform {
|
||||
RowPartitioner::TreePositionT left_nidx;
|
||||
explicit IndicateLeftTransform(RowPartitioner::TreePositionT left_nidx)
|
||||
: left_nidx(left_nidx) {}
|
||||
__host__ __device__ __forceinline__ int operator()(
|
||||
const RowPartitioner::TreePositionT& x) const {
|
||||
return x == left_nidx ? 1 : 0;
|
||||
}
|
||||
};
|
||||
|
||||
void RowPartitioner::SortPosition(common::Span<TreePositionT> position,
|
||||
common::Span<TreePositionT> position_out,
|
||||
common::Span<RowIndexT> ridx,
|
||||
common::Span<RowIndexT> ridx_out,
|
||||
TreePositionT left_nidx,
|
||||
TreePositionT right_nidx,
|
||||
int64_t* d_left_count, cudaStream_t stream) {
|
||||
auto d_position_out = position_out.data();
|
||||
auto d_position_in = position.data();
|
||||
auto d_ridx_out = ridx_out.data();
|
||||
auto d_ridx_in = ridx.data();
|
||||
auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
|
||||
int scatter_address;
|
||||
if (d_position_in[idx] == left_nidx) {
|
||||
scatter_address = ex_scan_result;
|
||||
} else {
|
||||
scatter_address = (idx - ex_scan_result) + *d_left_count;
|
||||
}
|
||||
d_position_out[scatter_address] = d_position_in[idx];
|
||||
d_ridx_out[scatter_address] = d_ridx_in[idx];
|
||||
}; // NOLINT
|
||||
|
||||
IndicateLeftTransform conversion_op(left_nidx);
|
||||
cub::TransformInputIterator<TreePositionT, IndicateLeftTransform,
|
||||
TreePositionT*>
|
||||
in_itr(d_position_in, conversion_op);
|
||||
dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
|
||||
size_t temp_storage_bytes = 0;
|
||||
cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
|
||||
position.size(), stream);
|
||||
dh::caching_device_vector<uint8_t> temp_storage(temp_storage_bytes);
|
||||
cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_storage_bytes,
|
||||
in_itr, out_itr, position.size(), stream);
|
||||
}
|
||||
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
|
||||
: device_idx(device_idx) {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
ridx_a.resize(num_rows);
|
||||
ridx_b.resize(num_rows);
|
||||
position_a.resize(num_rows);
|
||||
position_b.resize(num_rows);
|
||||
ridx = dh::DoubleBuffer<RowIndexT>{&ridx_a, &ridx_b};
|
||||
position = dh::DoubleBuffer<TreePositionT>{&position_a, &position_b};
|
||||
ridx_segments.emplace_back(Segment(0, num_rows));
|
||||
|
||||
thrust::sequence(
|
||||
thrust::device_pointer_cast(ridx.CurrentSpan().data()),
|
||||
thrust::device_pointer_cast(ridx.CurrentSpan().data() + ridx.Size()));
|
||||
thrust::fill(
|
||||
thrust::device_pointer_cast(position.Current()),
|
||||
thrust::device_pointer_cast(position.Current() + position.Size()), 0);
|
||||
left_counts.resize(256);
|
||||
thrust::fill(left_counts.begin(), left_counts.end(), 0);
|
||||
streams.resize(2);
|
||||
for (auto& stream : streams) {
|
||||
dh::safe_cuda(cudaStreamCreate(&stream));
|
||||
}
|
||||
}
|
||||
RowPartitioner::~RowPartitioner() {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
for (auto& stream : streams) {
|
||||
dh::safe_cuda(cudaStreamDestroy(stream));
|
||||
}
|
||||
}
|
||||
|
||||
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
|
||||
TreePositionT nidx) {
|
||||
auto segment = ridx_segments.at(nidx);
|
||||
// Return empty span here as a valid result
|
||||
// Will error if we try to construct a span from a pointer with size 0
|
||||
if (segment.Size() == 0) {
|
||||
return common::Span<const RowPartitioner::RowIndexT>();
|
||||
}
|
||||
return ridx.CurrentSpan().subspan(segment.begin, segment.Size());
|
||||
}
|
||||
|
||||
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
|
||||
return ridx.CurrentSpan();
|
||||
}
|
||||
|
||||
common::Span<const RowPartitioner::TreePositionT>
|
||||
RowPartitioner::GetPosition() {
|
||||
return position.CurrentSpan();
|
||||
}
|
||||
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
|
||||
TreePositionT nidx) {
|
||||
auto span = GetRows(nidx);
|
||||
std::vector<RowIndexT> rows(span.size());
|
||||
dh::CopyDeviceSpanToVector(&rows, span);
|
||||
return rows;
|
||||
}
|
||||
|
||||
std::vector<RowPartitioner::TreePositionT> RowPartitioner::GetPositionHost() {
|
||||
auto span = GetPosition();
|
||||
std::vector<TreePositionT> position(span.size());
|
||||
dh::CopyDeviceSpanToVector(&position, span);
|
||||
return position;
|
||||
}
|
||||
|
||||
void RowPartitioner::SortPositionAndCopy(const Segment& segment,
|
||||
TreePositionT left_nidx,
|
||||
TreePositionT right_nidx,
|
||||
int64_t* d_left_count,
|
||||
cudaStream_t stream) {
|
||||
SortPosition(
|
||||
common::Span<TreePositionT>(position.Current() + segment.begin,
|
||||
segment.Size()),
|
||||
common::Span<TreePositionT>(position.other() + segment.begin,
|
||||
segment.Size()),
|
||||
common::Span<RowIndexT>(ridx.Current() + segment.begin, segment.Size()),
|
||||
common::Span<RowIndexT>(ridx.other() + segment.begin, segment.Size()),
|
||||
left_nidx, right_nidx, d_left_count, stream);
|
||||
// Copy back key/value
|
||||
const auto d_position_current = position.Current() + segment.begin;
|
||||
const auto d_position_other = position.other() + segment.begin;
|
||||
const auto d_ridx_current = ridx.Current() + segment.begin;
|
||||
const auto d_ridx_other = ridx.other() + segment.begin;
|
||||
dh::LaunchN(device_idx, segment.Size(), stream, [=] __device__(size_t idx) {
|
||||
d_position_current[idx] = d_position_other[idx];
|
||||
d_ridx_current[idx] = d_ridx_other[idx];
|
||||
});
|
||||
}
|
||||
}; // namespace tree
|
||||
}; // namespace xgboost
|
||||
186
src/tree/gpu_hist/row_partitioner.cuh
Normal file
186
src/tree/gpu_hist/row_partitioner.cuh
Normal file
@ -0,0 +1,186 @@
|
||||
/*!
|
||||
* Copyright 2017-2019 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include "../../common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
/*! \brief Count how many rows are assigned to left node. */
|
||||
__forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment) {
|
||||
#if __CUDACC_VER_MAJOR__ > 8
|
||||
int mask = __activemask();
|
||||
unsigned ballot = __ballot_sync(mask, increment);
|
||||
int leader = __ffs(mask) - 1;
|
||||
if (threadIdx.x % 32 == leader) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
|
||||
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
|
||||
}
|
||||
#else
|
||||
unsigned ballot = __ballot(increment);
|
||||
if (threadIdx.x % 32 == 0) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
|
||||
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/** \brief Class responsible for tracking subsets of rows as we add splits and
|
||||
* partition training rows into different leaf nodes. */
|
||||
class RowPartitioner {
|
||||
public:
|
||||
using TreePositionT = int;
|
||||
using RowIndexT = bst_uint;
|
||||
struct Segment;
|
||||
|
||||
private:
|
||||
int device_idx;
|
||||
/*! \brief Range of rows for each node. */
|
||||
std::vector<Segment> ridx_segments;
|
||||
dh::caching_device_vector<RowIndexT> ridx_a;
|
||||
dh::caching_device_vector<RowIndexT> ridx_b;
|
||||
dh::caching_device_vector<TreePositionT> position_a;
|
||||
dh::caching_device_vector<TreePositionT> position_b;
|
||||
dh::DoubleBuffer<RowIndexT> ridx;
|
||||
dh::DoubleBuffer<TreePositionT> position;
|
||||
dh::caching_device_vector<int64_t>
|
||||
left_counts; // Useful to keep a bunch of zeroed memory for sort position
|
||||
std::vector<cudaStream_t> streams;
|
||||
|
||||
public:
|
||||
RowPartitioner(int device_idx, size_t num_rows);
|
||||
~RowPartitioner();
|
||||
RowPartitioner(const RowPartitioner&) = delete;
|
||||
RowPartitioner& operator=(const RowPartitioner&) = delete;
|
||||
|
||||
/**
|
||||
* \brief Gets the row indices of training instances in a given node.
|
||||
*/
|
||||
common::Span<const RowIndexT> GetRows(TreePositionT nidx);
|
||||
|
||||
/**
|
||||
* \brief Gets all training rows in the set.
|
||||
*/
|
||||
common::Span<const RowIndexT> GetRows();
|
||||
|
||||
/**
|
||||
* \brief Gets the tree position of all training instances.
|
||||
*/
|
||||
common::Span<const TreePositionT> GetPosition();
|
||||
|
||||
/**
|
||||
* \brief Convenience method for testing
|
||||
*/
|
||||
std::vector<RowIndexT> GetRowsHost(TreePositionT nidx);
|
||||
|
||||
/**
|
||||
* \brief Convenience method for testing
|
||||
*/
|
||||
std::vector<TreePositionT> GetPositionHost();
|
||||
|
||||
/**
|
||||
* \brief Updates the tree position for set of training instances being split
|
||||
* into left and right child nodes. Accepts a user-defined lambda specifying
|
||||
* which branch each training instance should go down.
|
||||
*
|
||||
* \tparam UpdatePositionOpT
|
||||
* \param nidx The index of the node being split.
|
||||
* \param left_nidx The left child index.
|
||||
* \param right_nidx The right child index.
|
||||
* \param op Device lambda. Should provide the row index as an
|
||||
* argument and return the new position for this training instance.
|
||||
*/
|
||||
template <typename UpdatePositionOpT>
|
||||
void UpdatePosition(TreePositionT nidx, TreePositionT left_nidx,
|
||||
TreePositionT right_nidx, UpdatePositionOpT op) {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
Segment segment = ridx_segments.at(nidx);
|
||||
auto d_ridx = ridx.CurrentSpan();
|
||||
auto d_position = position.CurrentSpan();
|
||||
if (left_counts.size() <= nidx) {
|
||||
left_counts.resize((nidx * 2) + 1);
|
||||
thrust::fill(left_counts.begin(), left_counts.end(), 0);
|
||||
}
|
||||
int64_t* d_left_count = left_counts.data().get() + nidx;
|
||||
// Launch 1 thread for each row
|
||||
dh::LaunchN<1, 128>(device_idx, segment.Size(), [=] __device__(size_t idx) {
|
||||
idx += segment.begin;
|
||||
RowIndexT ridx = d_ridx[idx];
|
||||
// Missing value
|
||||
TreePositionT new_position = op(ridx);
|
||||
KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
|
||||
AtomicIncrement(d_left_count, new_position == left_nidx);
|
||||
d_position[idx] = new_position;
|
||||
});
|
||||
// Overlap device to host memory copy (left_count) with sort
|
||||
int64_t left_count;
|
||||
dh::safe_cuda(cudaMemcpyAsync(&left_count, d_left_count, sizeof(int64_t),
|
||||
cudaMemcpyDeviceToHost, streams[0]));
|
||||
|
||||
SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count,
|
||||
streams[1]);
|
||||
|
||||
dh::safe_cuda(cudaStreamSynchronize(streams[0]));
|
||||
CHECK_LE(left_count, segment.Size());
|
||||
CHECK_GE(left_count, 0);
|
||||
ridx_segments.resize(std::max(int(ridx_segments.size()),
|
||||
std::max(left_nidx, right_nidx) + 1));
|
||||
ridx_segments[left_nidx] =
|
||||
Segment(segment.begin, segment.begin + left_count);
|
||||
ridx_segments[right_nidx] =
|
||||
Segment(segment.begin + left_count, segment.end);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Finalise the position of all training instances after tree
|
||||
* construction is complete. Does not update any other meta information in
|
||||
* this data structure, so should only be used at the end of training.
|
||||
*
|
||||
* \param op Device lambda. Should provide the row index and current
|
||||
* position as an argument and return the new position for this training
|
||||
* instance.
|
||||
*/
|
||||
template <typename FinalisePositionOpT>
|
||||
void FinalisePosition(FinalisePositionOpT op) {
|
||||
auto d_position = position.Current();
|
||||
const auto d_ridx = ridx.Current();
|
||||
dh::LaunchN(device_idx, position.Size(), [=] __device__(size_t idx) {
|
||||
auto position = d_position[idx];
|
||||
RowIndexT ridx = d_ridx[idx];
|
||||
d_position[idx] = op(ridx, position);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Optimised routine for sorting key value pairs into left and right
|
||||
* segments. Based on a single pass of exclusive scan, uses iterators to
|
||||
* redirect inputs and outputs.
|
||||
*/
|
||||
void SortPosition(common::Span<TreePositionT> position,
|
||||
common::Span<TreePositionT> position_out,
|
||||
common::Span<RowIndexT> ridx,
|
||||
common::Span<RowIndexT> ridx_out, TreePositionT left_nidx,
|
||||
TreePositionT right_nidx, int64_t* d_left_count,
|
||||
cudaStream_t stream = nullptr);
|
||||
|
||||
/*! \brief Sort row indices according to position. */
|
||||
void SortPositionAndCopy(const Segment& segment, TreePositionT left_nidx,
|
||||
TreePositionT right_nidx, int64_t* d_left_count,
|
||||
cudaStream_t stream);
|
||||
/** \brief Used to demarcate a contiguous set of row indices associated with
|
||||
* some tree node. */
|
||||
struct Segment {
|
||||
size_t begin;
|
||||
size_t end;
|
||||
|
||||
Segment() : begin{0}, end{0} {}
|
||||
|
||||
Segment(size_t begin, size_t end) : begin(begin), end(end) {
|
||||
CHECK_GE(end, begin);
|
||||
}
|
||||
size_t Size() const { return end - begin; }
|
||||
};
|
||||
};
|
||||
}; // namespace tree
|
||||
}; // namespace xgboost
|
||||
@ -6,7 +6,6 @@
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
#include <thrust/reduce.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
@ -25,6 +24,7 @@
|
||||
#include "param.h"
|
||||
#include "updater_gpu_common.cuh"
|
||||
#include "constraints.cuh"
|
||||
#include "gpu_hist/row_partitioner.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -515,10 +515,9 @@ __global__ void CompressBinEllpackKernel(
|
||||
|
||||
template <typename GradientSumT>
|
||||
__global__ void SharedMemHistKernel(ELLPackMatrix matrix,
|
||||
const bst_uint* d_ridx,
|
||||
common::Span<const RowPartitioner::RowIndexT> d_ridx,
|
||||
GradientSumT* d_node_hist,
|
||||
const GradientPair* d_gpair,
|
||||
size_t segment_begin, size_t n_elements,
|
||||
const GradientPair* d_gpair, size_t n_elements,
|
||||
bool use_shared_memory_histograms) {
|
||||
extern __shared__ char smem[];
|
||||
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
|
||||
@ -527,7 +526,7 @@ __global__ void SharedMemHistKernel(ELLPackMatrix matrix,
|
||||
__syncthreads();
|
||||
}
|
||||
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
|
||||
int ridx = d_ridx[idx / matrix.row_stride + segment_begin];
|
||||
int ridx = d_ridx[idx / matrix.row_stride ];
|
||||
int gidx =
|
||||
matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
|
||||
if (gidx != matrix.null_gidx_value) {
|
||||
@ -549,86 +548,6 @@ __global__ void SharedMemHistKernel(ELLPackMatrix matrix,
|
||||
}
|
||||
}
|
||||
|
||||
struct Segment {
|
||||
size_t begin;
|
||||
size_t end;
|
||||
|
||||
Segment() : begin{0}, end{0} {}
|
||||
|
||||
Segment(size_t begin, size_t end) : begin(begin), end(end) {
|
||||
CHECK_GE(end, begin);
|
||||
}
|
||||
size_t Size() const { return end - begin; }
|
||||
};
|
||||
|
||||
/** \brief Returns a one if the left node index is encountered, otherwise return
|
||||
* zero. */
|
||||
struct IndicateLeftTransform {
|
||||
int left_nidx;
|
||||
explicit IndicateLeftTransform(int left_nidx) : left_nidx(left_nidx) {}
|
||||
__host__ __device__ __forceinline__ int operator()(const int& x) const {
|
||||
return x == left_nidx ? 1 : 0;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Optimised routine for sorting key value pairs into left and right
|
||||
* segments. Based on a single pass of exclusive scan, uses iterators to
|
||||
* redirect inputs and outputs.
|
||||
*/
|
||||
inline void SortPosition(dh::CubMemory* temp_memory, common::Span<int> position,
|
||||
common::Span<int> position_out, common::Span<bst_uint> ridx,
|
||||
common::Span<bst_uint> ridx_out, int left_nidx,
|
||||
int right_nidx, int64_t* d_left_count,
|
||||
cudaStream_t stream = nullptr) {
|
||||
auto d_position_out = position_out.data();
|
||||
auto d_position_in = position.data();
|
||||
auto d_ridx_out = ridx_out.data();
|
||||
auto d_ridx_in = ridx.data();
|
||||
auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
|
||||
int scatter_address;
|
||||
if (d_position_in[idx] == left_nidx) {
|
||||
scatter_address = ex_scan_result;
|
||||
} else {
|
||||
scatter_address = (idx - ex_scan_result) + *d_left_count;
|
||||
}
|
||||
d_position_out[scatter_address] = d_position_in[idx];
|
||||
d_ridx_out[scatter_address] = d_ridx_in[idx];
|
||||
}; // NOLINT
|
||||
|
||||
IndicateLeftTransform conversion_op(left_nidx);
|
||||
cub::TransformInputIterator<int, IndicateLeftTransform, int*> in_itr(
|
||||
d_position_in, conversion_op);
|
||||
dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
|
||||
size_t temp_storage_bytes = 0;
|
||||
cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
|
||||
position.size(), stream);
|
||||
temp_memory->LazyAllocate(temp_storage_bytes);
|
||||
cub::DeviceScan::ExclusiveSum(temp_memory->d_temp_storage,
|
||||
temp_memory->temp_storage_bytes, in_itr,
|
||||
out_itr, position.size(), stream);
|
||||
}
|
||||
|
||||
/*! \brief Count how many rows are assigned to left node. */
|
||||
__forceinline__ __device__ void CountLeft(int64_t* d_count, int val,
|
||||
int left_nidx) {
|
||||
#if __CUDACC_VER_MAJOR__ > 8
|
||||
int mask = __activemask();
|
||||
unsigned ballot = __ballot_sync(mask, val == left_nidx);
|
||||
int leader = __ffs(mask) - 1;
|
||||
if (threadIdx.x % 32 == leader) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
|
||||
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
|
||||
}
|
||||
#else
|
||||
unsigned ballot = __ballot(val == left_nidx);
|
||||
if (threadIdx.x % 32 == 0) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
|
||||
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Instances of this type are created while creating the histogram bins for the
|
||||
// entire dataset across multiple sparse page batches. This keeps track of the number
|
||||
// of rows to process from a batch and the position from which to process on each device.
|
||||
@ -671,8 +590,7 @@ struct DeviceShard {
|
||||
|
||||
ELLPackMatrix ellpack_matrix;
|
||||
|
||||
/*! \brief Range of rows for each node. */
|
||||
std::vector<Segment> ridx_segments;
|
||||
std::unique_ptr<RowPartitioner> row_partitioner;
|
||||
DeviceHistogram<GradientSumT> hist;
|
||||
|
||||
/*! \brief row_ptr form HistCutMatrix. */
|
||||
@ -684,9 +602,6 @@ struct DeviceShard {
|
||||
/*! \brief global index of histogram, which is stored in ELLPack format. */
|
||||
common::Span<common::CompressedByteT> gidx_buffer;
|
||||
|
||||
/*! \brief Row indices relative to this shard, necessary for sorting rows. */
|
||||
dh::DoubleBuffer<bst_uint> ridx;
|
||||
dh::DoubleBuffer<int> position;
|
||||
/*! \brief Gradient pair for each row. */
|
||||
common::Span<GradientPair> gpair;
|
||||
|
||||
@ -696,8 +611,8 @@ struct DeviceShard {
|
||||
/*! \brief Sum gradient for each node. */
|
||||
std::vector<GradientPair> node_sum_gradients;
|
||||
common::Span<GradientPair> node_sum_gradients_d;
|
||||
dh::device_vector<int64_t>
|
||||
left_counts; // Useful to keep a bunch of zeroed memory for sort position
|
||||
/*! \brief On-device feature set, only actually used on one of the devices */
|
||||
dh::device_vector<int> feature_set_d;
|
||||
/*! The row offset for this shard. */
|
||||
bst_uint row_begin_idx;
|
||||
bst_uint row_end_idx;
|
||||
@ -783,24 +698,10 @@ struct DeviceShard {
|
||||
param.colsample_bylevel, param.colsample_bytree);
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
this->interaction_constraints.Reset();
|
||||
|
||||
thrust::fill(
|
||||
thrust::device_pointer_cast(position.Current()),
|
||||
thrust::device_pointer_cast(position.Current() + position.Size()), 0);
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
GradientPair());
|
||||
if (left_counts.size() < 256) {
|
||||
left_counts.resize(256);
|
||||
} else {
|
||||
dh::safe_cuda(cudaMemsetAsync(left_counts.data().get(), 0,
|
||||
sizeof(int64_t) * left_counts.size()));
|
||||
}
|
||||
thrust::sequence(
|
||||
thrust::device_pointer_cast(ridx.CurrentSpan().data()),
|
||||
thrust::device_pointer_cast(ridx.CurrentSpan().data() + ridx.Size()));
|
||||
row_partitioner.reset(new RowPartitioner(device_id, n_rows));
|
||||
|
||||
std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
|
||||
ridx_segments.front() = Segment(0, ridx.Size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
gpair.data(), dh_gpair->ConstDevicePointer(device_id),
|
||||
gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost));
|
||||
@ -892,12 +793,11 @@ struct DeviceShard {
|
||||
|
||||
void BuildHist(int nidx) {
|
||||
hist.AllocateHistogram(nidx);
|
||||
auto segment = ridx_segments[nidx];
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx);
|
||||
auto d_ridx = ridx.Current();
|
||||
auto d_ridx = row_partitioner->GetRows(nidx);
|
||||
auto d_gpair = gpair.data();
|
||||
|
||||
auto n_elements = segment.Size() * ellpack_matrix.row_stride;
|
||||
auto n_elements = d_ridx.size() * ellpack_matrix.row_stride;
|
||||
|
||||
const size_t smem_size =
|
||||
use_shared_memory_histograms
|
||||
@ -911,8 +811,8 @@ struct DeviceShard {
|
||||
return;
|
||||
}
|
||||
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
|
||||
ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, segment.begin,
|
||||
n_elements, use_shared_memory_histograms);
|
||||
ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
|
||||
use_shared_memory_histograms);
|
||||
}
|
||||
|
||||
void SubtractionTrick(int nidx_parent, int nidx_histogram,
|
||||
@ -936,21 +836,13 @@ struct DeviceShard {
|
||||
}
|
||||
|
||||
void UpdatePosition(int nidx, RegTree::Node split_node) {
|
||||
CHECK(!split_node.IsLeaf()) <<"Node must not be leaf";
|
||||
Segment segment = ridx_segments[nidx];
|
||||
bst_uint* d_ridx = ridx.Current();
|
||||
int* d_position = position.Current();
|
||||
if (left_counts.size() <= nidx) {
|
||||
left_counts.resize((nidx * 2) + 1);
|
||||
}
|
||||
int64_t* d_left_count = left_counts.data().get() + nidx;
|
||||
auto d_matrix = this->ellpack_matrix;
|
||||
// Launch 1 thread for each row
|
||||
dh::LaunchN<1, 128>(
|
||||
device_id, segment.Size(), [=] __device__(bst_uint idx) {
|
||||
idx += segment.begin;
|
||||
bst_uint ridx = d_ridx[idx];
|
||||
bst_float element = d_matrix.GetElement(ridx, split_node.SplitIndex());
|
||||
auto d_matrix = ellpack_matrix;
|
||||
|
||||
row_partitioner->UpdatePosition(
|
||||
nidx, split_node.LeftChild(), split_node.RightChild(),
|
||||
[=] __device__(bst_uint ridx) {
|
||||
bst_float element =
|
||||
d_matrix.GetElement(ridx, split_node.SplitIndex());
|
||||
// Missing value
|
||||
int new_position = 0;
|
||||
if (isnan(element)) {
|
||||
@ -962,49 +854,8 @@ struct DeviceShard {
|
||||
new_position = split_node.RightChild();
|
||||
}
|
||||
}
|
||||
CountLeft(d_left_count, new_position, split_node.LeftChild());
|
||||
d_position[idx] = new_position;
|
||||
return new_position;
|
||||
});
|
||||
|
||||
// Overlap device to host memory copy (left_count) with sort
|
||||
auto& streams = this->GetStreams(2);
|
||||
auto tmp_pinned = pinned_memory.GetSpan<int64_t>(1);
|
||||
dh::safe_cuda(cudaMemcpyAsync(tmp_pinned.data(), d_left_count, sizeof(int64_t),
|
||||
cudaMemcpyDeviceToHost, streams[0]));
|
||||
|
||||
SortPositionAndCopy(segment, split_node.LeftChild(), split_node.RightChild(), d_left_count,
|
||||
streams[1]);
|
||||
|
||||
dh::safe_cuda(cudaStreamSynchronize(streams[0]));
|
||||
int64_t left_count = tmp_pinned[0];
|
||||
CHECK_LE(left_count, segment.Size());
|
||||
CHECK_GE(left_count, 0);
|
||||
ridx_segments[split_node.LeftChild()] =
|
||||
Segment(segment.begin, segment.begin + left_count);
|
||||
ridx_segments[split_node.RightChild()] =
|
||||
Segment(segment.begin + left_count, segment.end);
|
||||
}
|
||||
|
||||
/*! \brief Sort row indices according to position. */
|
||||
void SortPositionAndCopy(const Segment& segment, int left_nidx,
|
||||
int right_nidx, int64_t* d_left_count,
|
||||
cudaStream_t stream) {
|
||||
SortPosition(
|
||||
&temp_memory,
|
||||
common::Span<int>(position.Current() + segment.begin, segment.Size()),
|
||||
common::Span<int>(position.other() + segment.begin, segment.Size()),
|
||||
common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
|
||||
common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
|
||||
left_nidx, right_nidx, d_left_count, stream);
|
||||
// Copy back key/value
|
||||
const auto d_position_current = position.Current() + segment.begin;
|
||||
const auto d_position_other = position.other() + segment.begin;
|
||||
const auto d_ridx_current = ridx.Current() + segment.begin;
|
||||
const auto d_ridx_other = ridx.other() + segment.begin;
|
||||
dh::LaunchN(device_id, segment.Size(), stream, [=] __device__(size_t idx) {
|
||||
d_position_current[idx] = d_position_other[idx];
|
||||
d_ridx_current[idx] = d_ridx_other[idx];
|
||||
});
|
||||
}
|
||||
|
||||
// After tree update is finished, update the position of all training
|
||||
@ -1016,30 +867,27 @@ struct DeviceShard {
|
||||
dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
|
||||
d_nodes.size() * sizeof(RegTree::Node),
|
||||
cudaMemcpyHostToDevice));
|
||||
auto d_position = position.Current();
|
||||
const auto d_ridx = ridx.Current();
|
||||
auto d_matrix = this->ellpack_matrix;
|
||||
dh::LaunchN(device_id, position.Size(), [=] __device__(size_t idx) {
|
||||
auto position = d_position[idx];
|
||||
auto node = d_nodes[position];
|
||||
bst_uint ridx = d_ridx[idx];
|
||||
auto d_matrix = ellpack_matrix;
|
||||
row_partitioner->FinalisePosition(
|
||||
[=] __device__(bst_uint ridx, int position) {
|
||||
auto node = d_nodes[position];
|
||||
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetElement(ridx, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
} else {
|
||||
if (element <= node.SplitCond()) {
|
||||
position = node.LeftChild();
|
||||
} else {
|
||||
position = node.RightChild();
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetElement(ridx, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
} else {
|
||||
if (element <= node.SplitCond()) {
|
||||
position = node.LeftChild();
|
||||
} else {
|
||||
position = node.RightChild();
|
||||
}
|
||||
}
|
||||
node = d_nodes[position];
|
||||
}
|
||||
}
|
||||
node = d_nodes[position];
|
||||
}
|
||||
d_position[idx] = position;
|
||||
});
|
||||
return position;
|
||||
});
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(bst_float* out_preds_d) {
|
||||
@ -1057,8 +905,8 @@ struct DeviceShard {
|
||||
cudaMemcpyAsync(node_sum_gradients_d.data(), node_sum_gradients.data(),
|
||||
sizeof(GradientPair) * node_sum_gradients.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
auto d_position = position.Current();
|
||||
auto d_ridx = ridx.Current();
|
||||
auto d_position = row_partitioner->GetPosition();
|
||||
auto d_ridx = row_partitioner->GetRows();
|
||||
auto d_node_sum_gradients = node_sum_gradients_d.data();
|
||||
auto d_prediction_cache = prediction_cache.data();
|
||||
|
||||
@ -1096,13 +944,15 @@ struct DeviceShard {
|
||||
auto build_hist_nidx = nidx_left;
|
||||
auto subtraction_trick_nidx = nidx_right;
|
||||
|
||||
auto left_node_rows = ridx_segments[nidx_left].Size();
|
||||
auto right_node_rows = ridx_segments[nidx_right].Size();
|
||||
auto left_node_rows = row_partitioner->GetRows(nidx_left).size();
|
||||
auto right_node_rows = row_partitioner->GetRows(nidx_right).size();
|
||||
// Decide whether to build the left histogram or right histogram
|
||||
// Find the largest number of training instances on any given Shard
|
||||
// Assume this will be the bottleneck and avoid building this node if
|
||||
// possible
|
||||
std::vector<size_t> max_reduce = {left_node_rows, right_node_rows};
|
||||
std::vector<size_t> max_reduce;
|
||||
max_reduce.push_back(left_node_rows);
|
||||
max_reduce.push_back(right_node_rows);
|
||||
reducer->HostMaxAllReduce(&max_reduce);
|
||||
bool fewer_right = max_reduce[1] < max_reduce[0];
|
||||
if (fewer_right) {
|
||||
@ -1199,6 +1049,7 @@ struct DeviceShard {
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
|
||||
RegTree* p_tree, dh::AllReducer* reducer) {
|
||||
auto& tree = *p_tree;
|
||||
|
||||
monitor.StartCuda("Reset");
|
||||
this->Reset(gpair_all, p_fmat->Info().num_col_);
|
||||
monitor.StopCuda("Reset");
|
||||
@ -1206,7 +1057,6 @@ struct DeviceShard {
|
||||
monitor.StartCuda("InitRoot");
|
||||
this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_);
|
||||
monitor.StopCuda("InitRoot");
|
||||
|
||||
auto timestamp = qexpand->size();
|
||||
auto num_leaves = 1;
|
||||
|
||||
@ -1269,8 +1119,6 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
|
||||
ba.Allocate(device_id,
|
||||
&gpair, n_rows,
|
||||
&ridx, n_rows,
|
||||
&position, n_rows,
|
||||
&prediction_cache, n_rows,
|
||||
&node_sum_gradients_d, max_nodes,
|
||||
&feature_segments, hmat.row_ptr.size(),
|
||||
@ -1284,7 +1132,6 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);
|
||||
|
||||
node_sum_gradients.resize(max_nodes);
|
||||
ridx_segments.resize(max_nodes);
|
||||
|
||||
// allocate compressed bin data
|
||||
int num_symbols = n_bins + 1;
|
||||
@ -1303,7 +1150,6 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
gidx_fvalue_map, row_stride,
|
||||
common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
|
||||
is_dense, null_gidx_value);
|
||||
|
||||
// check if we can use shared memory for building histograms
|
||||
// (assuming atleast we need 2 CTAs per SM to maintain decent latency
|
||||
// hiding)
|
||||
|
||||
@ -97,7 +97,8 @@ TEST(bulkAllocator, Test) {
|
||||
}
|
||||
|
||||
// Test thread safe max reduction
|
||||
TEST(AllReducer, HostMaxAllReduce) {
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
TEST(AllReducer, MGPU_HostMaxAllReduce) {
|
||||
dh::AllReducer reducer;
|
||||
size_t num_threads = 50;
|
||||
std::vector<std::vector<size_t>> thread_data(num_threads);
|
||||
@ -112,3 +113,4 @@ TEST(AllReducer, HostMaxAllReduce) {
|
||||
ASSERT_EQ(data.front(), num_threads - 1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
125
tests/cpp/tree/gpu_hist/test_row_partitioner.cu
Normal file
125
tests/cpp/tree/gpu_hist/test_row_partitioner.cu
Normal file
@ -0,0 +1,125 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
|
||||
#include "../../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
|
||||
int right_idx) {
|
||||
std::vector<int64_t> left_count = {
|
||||
std::count(position_in.begin(), position_in.end(), left_idx)};
|
||||
thrust::device_vector<int64_t> d_left_count = left_count;
|
||||
thrust::device_vector<int> position = position_in;
|
||||
thrust::device_vector<int> position_out(position.size());
|
||||
|
||||
thrust::device_vector<RowPartitioner::RowIndexT> ridx(position.size());
|
||||
thrust::sequence(ridx.begin(), ridx.end());
|
||||
thrust::device_vector<RowPartitioner::RowIndexT> ridx_out(ridx.size());
|
||||
RowPartitioner rp(0,10);
|
||||
rp.SortPosition(
|
||||
common::Span<int>(position.data().get(), position.size()),
|
||||
common::Span<int>(position_out.data().get(), position_out.size()),
|
||||
common::Span<RowPartitioner::RowIndexT>(ridx.data().get(), ridx.size()),
|
||||
common::Span<RowPartitioner::RowIndexT>(ridx_out.data().get(), ridx_out.size()), left_idx,
|
||||
right_idx, d_left_count.data().get(), nullptr);
|
||||
thrust::host_vector<int> position_result = position_out;
|
||||
thrust::host_vector<int> ridx_result = ridx_out;
|
||||
|
||||
// Check position is sorted
|
||||
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
|
||||
// Check row indices are sorted inside left and right segment
|
||||
EXPECT_TRUE(
|
||||
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0]));
|
||||
EXPECT_TRUE(
|
||||
std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end()));
|
||||
|
||||
// Check key value pairs are the same
|
||||
for (auto i = 0ull; i < ridx_result.size(); i++) {
|
||||
EXPECT_EQ(position_result[i], position_in[ridx_result[i]]);
|
||||
}
|
||||
}
|
||||
TEST(GpuHist, SortPosition) {
|
||||
TestSortPosition({1, 2, 1, 2, 1}, 1, 2);
|
||||
TestSortPosition({1, 1, 1, 1}, 1, 2);
|
||||
TestSortPosition({2, 2, 2, 2}, 1, 2);
|
||||
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
|
||||
}
|
||||
|
||||
void TestUpdatePosition() {
|
||||
const int kNumRows = 10;
|
||||
RowPartitioner rp(0, kNumRows);
|
||||
auto rows = rp.GetRowsHost(0);
|
||||
EXPECT_EQ(rows.size(), kNumRows);
|
||||
for (auto i = 0ull; i < kNumRows; i++) {
|
||||
EXPECT_EQ(rows[i], i);
|
||||
}
|
||||
// Send the first five training instances to the right node
|
||||
// and the second 5 to the left node
|
||||
rp.UpdatePosition(0, 1, 2,
|
||||
[=] __device__(RowPartitioner::RowIndexT ridx) {
|
||||
if (ridx > 4) {
|
||||
return 1;
|
||||
}
|
||||
else {
|
||||
return 2;
|
||||
}
|
||||
});
|
||||
rows = rp.GetRowsHost(1);
|
||||
for (auto r : rows) {
|
||||
EXPECT_GT(r, 4);
|
||||
}
|
||||
rows = rp.GetRowsHost(2);
|
||||
for (auto r : rows) {
|
||||
EXPECT_LT(r, 5);
|
||||
}
|
||||
|
||||
// Split the left node again
|
||||
rp.UpdatePosition(1, 3, 4, [=]__device__(RowPartitioner::RowIndexT ridx)
|
||||
{
|
||||
if (ridx < 7) {
|
||||
return 3
|
||||
;
|
||||
}
|
||||
return 4;
|
||||
});
|
||||
EXPECT_EQ(rp.GetRows(3).size(), 2);
|
||||
EXPECT_EQ(rp.GetRows(4).size(), 3);
|
||||
// Check position is as expected
|
||||
EXPECT_EQ(rp.GetPositionHost(), std::vector<RowPartitioner::TreePositionT>({3,3,4,4,4,2,2,2,2,2}));
|
||||
}
|
||||
|
||||
TEST(RowPartitioner, Basic) { TestUpdatePosition(); }
|
||||
|
||||
void TestFinalise() {
|
||||
const int kNumRows = 10;
|
||||
RowPartitioner rp(0, kNumRows);
|
||||
rp.FinalisePosition([=]__device__(RowPartitioner::RowIndexT ridx, int position)
|
||||
{
|
||||
return 7;
|
||||
});
|
||||
auto position = rp.GetPositionHost();
|
||||
for(auto p:position)
|
||||
{
|
||||
EXPECT_EQ(p, 7);
|
||||
}
|
||||
}
|
||||
TEST(RowPartitioner, Finalise) { TestFinalise(); }
|
||||
|
||||
void TestIncorrectRow() {
|
||||
RowPartitioner rp(0, 1);
|
||||
rp.UpdatePosition(0, 1, 2, [=]__device__ (RowPartitioner::RowIndexT ridx)
|
||||
{
|
||||
return 4; // This is not the left branch or the right branch
|
||||
});
|
||||
}
|
||||
|
||||
TEST(RowPartitioner, IncorrectRow) {
|
||||
ASSERT_DEATH({ TestIncorrectRow(); },".*");
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
@ -206,16 +206,10 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
dh::safe_cuda(cudaMemcpy(h_gidx_buffer.data(), d_gidx_buffer_ptr,
|
||||
sizeof(common::CompressedByteT) * shard.gidx_buffer.size(),
|
||||
cudaMemcpyDeviceToHost));
|
||||
auto gidx = common::CompressedIterator<uint32_t>(h_gidx_buffer.data(),
|
||||
num_symbols);
|
||||
|
||||
shard.ridx_segments.resize(1);
|
||||
shard.ridx_segments[0] = Segment(0, kNRows);
|
||||
shard.row_partitioner.reset(new RowPartitioner(0, kNRows));
|
||||
shard.hist.AllocateHistogram(0);
|
||||
dh::CopyVectorToDeviceSpan(shard.gpair, h_gpair);
|
||||
thrust::sequence(
|
||||
thrust::device_pointer_cast(shard.ridx.Current()),
|
||||
thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size()));
|
||||
|
||||
shard.use_shared_memory_histograms = use_shared_memory_histograms;
|
||||
shard.BuildHist(0);
|
||||
@ -358,138 +352,6 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
ASSERT_NEAR(res[1].fvalue, 0.26, xgboost::kRtEps);
|
||||
}
|
||||
|
||||
TEST(GpuHist, ApplySplit) {
|
||||
int constexpr kNId = 0;
|
||||
int constexpr kNRows = 16;
|
||||
int constexpr kNCols = 8;
|
||||
|
||||
TrainParam param;
|
||||
std::vector<std::pair<std::string, std::string>> args = {};
|
||||
param.InitAllowUnknown(args);
|
||||
// Initialize shard
|
||||
for (size_t i = 0; i < kNCols; ++i) {
|
||||
param.monotone_constraints.emplace_back(0);
|
||||
}
|
||||
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
|
||||
new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols,
|
||||
kNCols)};
|
||||
|
||||
shard->ridx_segments.resize(3); // 3 nodes.
|
||||
shard->node_sum_gradients.resize(3);
|
||||
|
||||
shard->ridx_segments[0] = Segment(0, kNRows);
|
||||
shard->ba.Allocate(0, &(shard->ridx), kNRows,
|
||||
&(shard->position), kNRows);
|
||||
shard->ellpack_matrix.row_stride = kNCols;
|
||||
thrust::sequence(
|
||||
thrust::device_pointer_cast(shard->ridx.Current()),
|
||||
thrust::device_pointer_cast(shard->ridx.Current() + shard->ridx.Size()));
|
||||
RegTree tree;
|
||||
|
||||
DeviceSplitCandidate candidate;
|
||||
candidate.Update(2, kLeftDir,
|
||||
0.59, 4, // fvalue has to be equal to one of the cut field
|
||||
GradientPair(8.2, 2.8), GradientPair(6.3, 3.6),
|
||||
GPUTrainingParam(param));
|
||||
ExpandEntry candidate_entry {0, 0, candidate, 0};
|
||||
candidate_entry.nid = kNId;
|
||||
|
||||
// Used to get bin_id in update position.
|
||||
common::HistCutMatrix cmat = GetHostCutMatrix();
|
||||
|
||||
MetaInfo info;
|
||||
info.num_row_ = kNRows;
|
||||
info.num_col_ = kNCols;
|
||||
info.num_nonzero_ = kNRows * kNCols; // Dense
|
||||
|
||||
// Initialize gidx
|
||||
int n_bins = 24;
|
||||
int row_stride = kNCols;
|
||||
int num_symbols = n_bins + 1;
|
||||
size_t compressed_size_bytes =
|
||||
common::CompressedBufferWriter::CalculateBufferSize(row_stride * kNRows,
|
||||
num_symbols);
|
||||
shard->ba.Allocate(0, &(shard->gidx_buffer), compressed_size_bytes,
|
||||
&(shard->feature_segments), cmat.row_ptr.size(),
|
||||
&(shard->min_fvalue), cmat.min_val.size(),
|
||||
&(shard->gidx_fvalue_map), 24);
|
||||
dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr);
|
||||
dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut);
|
||||
shard->ellpack_matrix.feature_segments = shard->feature_segments;
|
||||
shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map;
|
||||
dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val);
|
||||
shard->ellpack_matrix.min_fvalue = shard->min_fvalue;
|
||||
shard->ellpack_matrix.is_dense = true;
|
||||
|
||||
common::CompressedBufferWriter wr(num_symbols);
|
||||
// gidx 14 should go right, 12 goes left
|
||||
std::vector<int> h_gidx (kNRows * row_stride, 14);
|
||||
h_gidx[4] = 12;
|
||||
h_gidx[12] = 12;
|
||||
std::vector<common::CompressedByteT> h_gidx_compressed (compressed_size_bytes);
|
||||
|
||||
wr.Write(h_gidx_compressed.data(), h_gidx.begin(), h_gidx.end());
|
||||
dh::CopyVectorToDeviceSpan(shard->gidx_buffer, h_gidx_compressed);
|
||||
|
||||
shard->ellpack_matrix.gidx_iter = common::CompressedIterator<uint32_t>(
|
||||
shard->gidx_buffer.data(), num_symbols);
|
||||
|
||||
shard->ApplySplit(candidate_entry, &tree);
|
||||
shard->UpdatePosition(candidate_entry.nid, tree[candidate_entry.nid]);
|
||||
|
||||
ASSERT_FALSE(tree[kNId].IsLeaf());
|
||||
|
||||
int left_nidx = tree[kNId].LeftChild();
|
||||
int right_nidx = tree[kNId].RightChild();
|
||||
|
||||
ASSERT_EQ(shard->ridx_segments[left_nidx].begin, 0);
|
||||
ASSERT_EQ(shard->ridx_segments[left_nidx].end, 2);
|
||||
ASSERT_EQ(shard->ridx_segments[right_nidx].begin, 2);
|
||||
ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16);
|
||||
}
|
||||
|
||||
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
|
||||
int right_idx) {
|
||||
std::vector<int64_t> left_count = {
|
||||
std::count(position_in.begin(), position_in.end(), left_idx)};
|
||||
thrust::device_vector<int64_t> d_left_count = left_count;
|
||||
thrust::device_vector<int> position = position_in;
|
||||
thrust::device_vector<int> position_out(position.size());
|
||||
|
||||
thrust::device_vector<bst_uint> ridx(position.size());
|
||||
thrust::sequence(ridx.begin(), ridx.end());
|
||||
thrust::device_vector<bst_uint> ridx_out(ridx.size());
|
||||
dh::CubMemory tmp;
|
||||
SortPosition(
|
||||
&tmp, common::Span<int>(position.data().get(), position.size()),
|
||||
common::Span<int>(position_out.data().get(), position_out.size()),
|
||||
common::Span<bst_uint>(ridx.data().get(), ridx.size()),
|
||||
common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
|
||||
right_idx, d_left_count.data().get(), nullptr);
|
||||
thrust::host_vector<int> position_result = position_out;
|
||||
thrust::host_vector<int> ridx_result = ridx_out;
|
||||
|
||||
// Check position is sorted
|
||||
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
|
||||
// Check row indices are sorted inside left and right segment
|
||||
EXPECT_TRUE(
|
||||
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0]));
|
||||
EXPECT_TRUE(
|
||||
std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end()));
|
||||
|
||||
// Check key value pairs are the same
|
||||
for (auto i = 0ull; i < ridx_result.size(); i++) {
|
||||
EXPECT_EQ(position_result[i], position_in[ridx_result[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GpuHist, SortPosition) {
|
||||
TestSortPosition({1, 2, 1, 2, 1}, 1, 2);
|
||||
TestSortPosition({1, 1, 1, 1}, 1, 2);
|
||||
TestSortPosition({2, 2, 2, 2}, 1, 2);
|
||||
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
|
||||
}
|
||||
|
||||
void TestHistogramIndexImpl(int n_gpus) {
|
||||
// Test if the compressed histogram index matches when using a sparse
|
||||
// dmatrix with and without using external memory
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user