Refactor out row partitioning logic from gpu_hist, introduce caching device vectors (#4554)

This commit is contained in:
Rory Mitchell 2019-06-20 18:24:09 +12:00 committed by GitHub
parent 0c50f8417a
commit 221e163185
7 changed files with 582 additions and 345 deletions

View File

@ -9,6 +9,7 @@
#include <thrust/system_error.h>
#include <xgboost/logging.h>
#include <rabit/rabit.h>
#include <cub/util_allocator.cuh>
#include "common.h"
#include "span.h"
@ -299,9 +300,14 @@ namespace detail{
* \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
*/
template <class T>
struct XGBDefaultDeviceAllocator : thrust::device_malloc_allocator<T> {
struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
using super_t = thrust::device_malloc_allocator<T>;
using pointer = thrust::device_ptr<T>;
template<typename U>
struct rebind
{
typedef XGBDefaultDeviceAllocatorImpl<U> other;
};
pointer allocate(size_t n) {
pointer ptr = super_t::allocate(n);
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n);
@ -312,16 +318,56 @@ struct XGBDefaultDeviceAllocator : thrust::device_malloc_allocator<T> {
return super_t::deallocate(ptr, n);
}
};
/**
* \brief Caching memory allocator. Uses cub::CachingDeviceAllocator as a
*  back-end and logs allocations if verbose. Does not initialise memory on
*  construction.
*/
template <class T>
struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
using pointer = thrust::device_ptr<T>;
template<typename U>
struct rebind
{
typedef XGBCachingDeviceAllocatorImpl<U> other;
};
cub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
// Configure allocator with bin growth 8, smallest bin 8^3 bytes and largest
// cached bin 8^10 bytes (~1GB), with no limit on the total number of cached bytes
static cub::CachingDeviceAllocator allocator(8, 3, 10);
return allocator;
}
pointer allocate(size_t n) {
T *ptr;
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
n * sizeof(T));
pointer thrust_ptr = thrust::device_ptr<T>(ptr);
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n);
return thrust_ptr;
}
void deallocate(pointer ptr, size_t n) {
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n);
GetGlobalCachingAllocator().DeviceFree(ptr.get());
}
__host__ __device__
void construct(T *) {
// no-op: do not value-initialise memory returned by the caching allocator
}
};
};
// Declare xgboost allocator
// Declare xgboost allocators
// Replacement of allocator with custom backend should occur here
template <typename T>
using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocator<T>;
using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocatorImpl<T>;
template <typename T>
using XGBCachingDeviceAllocator = detail::XGBCachingDeviceAllocatorImpl<T>;
/** \brief Specialisation of thrust device vector using custom allocator. */
template <typename T>
using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>;
template <typename T>
using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>;
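// Editor's sketch (not part of this header): caching_device_vector is a drop-in
// replacement for dh::device_vector intended for short-lived scratch buffers that
// are created many times per iteration; allocations are served from the
// cub::CachingDeviceAllocator cache rather than hitting cudaMalloc/cudaFree each time.
// Assuming <thrust/fill.h> is available, usage looks like:
//
//   void ExampleScratch(size_t n) {
//     dh::caching_device_vector<int> scratch(n);        // likely reuses a cached block
//     thrust::fill(scratch.begin(), scratch.end(), 0);  // memory is not zeroed on allocation
//   }                                                   // block returns to the cache here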
/**
* \brief A double buffer, useful for algorithms like sort.
*/
@ -331,6 +377,14 @@ class DoubleBuffer {
cub::DoubleBuffer<T> buff;
xgboost::common::Span<T> a, b;
DoubleBuffer() = default;
template <typename VectorT>
DoubleBuffer(VectorT *v1, VectorT *v2) {
a = xgboost::common::Span<T>(v1->data().get(), v1->size());
b = xgboost::common::Span<T>(v2->data().get(), v2->size());
buff.d_buffers[0] = v1->data().get();
buff.d_buffers[1] = v2->data().get();
buff.selector = 0;
}
size_t Size() const {
CHECK_EQ(a.size(), b.size());
@ -362,6 +416,20 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<T> src) {
cudaMemcpyDeviceToHost));
}
/**
* \brief Copies const device span to std::vector.
*
* \tparam T Generic type parameter.
* \param [in,out] dst Copy destination.
* \param src Copy source. Must be device memory.
*/
template <typename T>
void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T> src) {
CHECK_EQ(dst->size(), src.size());
dh::safe_cuda(cudaMemcpyAsync(dst->data(), src.data(), dst->size() * sizeof(T),
cudaMemcpyDeviceToHost));
}
/**
* \brief Copies std::vector to device span.
*
@ -1132,6 +1200,7 @@ class AllReducer {
* safe) using the master thread. Uses naive reduce algorithm for local
* threads, don't expect this to scale.*/
void HostMaxAllReduce(std::vector<size_t> *p_data) {
#ifdef XGBOOST_USE_NCCL
auto &data = *p_data;
// Wait in case some other thread is accessing host_data
#pragma omp barrier
@ -1162,6 +1231,7 @@ class AllReducer {
for (auto i = 0ull; i < data.size(); i++) {
data[i] = host_data[i];
}
#endif
}
};

View File

@ -0,0 +1,146 @@
/*!
* Copyright 2017-2019 XGBoost contributors
*/
#include <thrust/sequence.h>
#include <vector>
#include "../../common/device_helpers.cuh"
#include "row_partitioner.cuh"
namespace xgboost {
namespace tree {
struct IndicateLeftTransform {
RowPartitioner::TreePositionT left_nidx;
explicit IndicateLeftTransform(RowPartitioner::TreePositionT left_nidx)
: left_nidx(left_nidx) {}
__host__ __device__ __forceinline__ int operator()(
const RowPartitioner::TreePositionT& x) const {
return x == left_nidx ? 1 : 0;
}
};
void RowPartitioner::SortPosition(common::Span<TreePositionT> position,
common::Span<TreePositionT> position_out,
common::Span<RowIndexT> ridx,
common::Span<RowIndexT> ridx_out,
TreePositionT left_nidx,
TreePositionT right_nidx,
int64_t* d_left_count, cudaStream_t stream) {
auto d_position_out = position_out.data();
auto d_position_in = position.data();
auto d_ridx_out = ridx_out.data();
auto d_ridx_in = ridx.data();
auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
int scatter_address;
if (d_position_in[idx] == left_nidx) {
scatter_address = ex_scan_result;
} else {
scatter_address = (idx - ex_scan_result) + *d_left_count;
}
d_position_out[scatter_address] = d_position_in[idx];
d_ridx_out[scatter_address] = d_ridx_in[idx];
}; // NOLINT
IndicateLeftTransform conversion_op(left_nidx);
cub::TransformInputIterator<TreePositionT, IndicateLeftTransform,
TreePositionT*>
in_itr(d_position_in, conversion_op);
dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
position.size(), stream);
dh::caching_device_vector<uint8_t> temp_storage(temp_storage_bytes);
cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_storage_bytes,
in_itr, out_itr, position.size(), stream);
}
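// Editor's worked example: with left_nidx = 1 and inputs
//   position = [1, 2, 1, 2, 1], ridx = [0, 1, 2, 3, 4], *d_left_count = 3,
// the transform iterator yields the flags [1, 0, 1, 0, 1], whose exclusive scan is
// [0, 1, 1, 2, 2]. Left rows scatter to their scan value (0, 1, 2) and the rest to
// (idx - scan) + *d_left_count (3, 4), producing
//   position_out = [1, 1, 1, 2, 2], ridx_out = [0, 2, 4, 1, 3],
// i.e. a stable partition of both arrays in a single scan pass.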
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
: device_idx(device_idx) {
dh::safe_cuda(cudaSetDevice(device_idx));
ridx_a.resize(num_rows);
ridx_b.resize(num_rows);
position_a.resize(num_rows);
position_b.resize(num_rows);
ridx = dh::DoubleBuffer<RowIndexT>{&ridx_a, &ridx_b};
position = dh::DoubleBuffer<TreePositionT>{&position_a, &position_b};
ridx_segments.emplace_back(Segment(0, num_rows));
thrust::sequence(
thrust::device_pointer_cast(ridx.CurrentSpan().data()),
thrust::device_pointer_cast(ridx.CurrentSpan().data() + ridx.Size()));
thrust::fill(
thrust::device_pointer_cast(position.Current()),
thrust::device_pointer_cast(position.Current() + position.Size()), 0);
left_counts.resize(256);
thrust::fill(left_counts.begin(), left_counts.end(), 0);
streams.resize(2);
for (auto& stream : streams) {
dh::safe_cuda(cudaStreamCreate(&stream));
}
}
RowPartitioner::~RowPartitioner() {
dh::safe_cuda(cudaSetDevice(device_idx));
for (auto& stream : streams) {
dh::safe_cuda(cudaStreamDestroy(stream));
}
}
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
TreePositionT nidx) {
auto segment = ridx_segments.at(nidx);
// Return an empty span here as a valid result;
// constructing a span from a pointer with size 0 would trigger an error
if (segment.Size() == 0) {
return common::Span<const RowPartitioner::RowIndexT>();
}
return ridx.CurrentSpan().subspan(segment.begin, segment.Size());
}
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
return ridx.CurrentSpan();
}
common::Span<const RowPartitioner::TreePositionT>
RowPartitioner::GetPosition() {
return position.CurrentSpan();
}
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
TreePositionT nidx) {
auto span = GetRows(nidx);
std::vector<RowIndexT> rows(span.size());
dh::CopyDeviceSpanToVector(&rows, span);
return rows;
}
std::vector<RowPartitioner::TreePositionT> RowPartitioner::GetPositionHost() {
auto span = GetPosition();
std::vector<TreePositionT> position(span.size());
dh::CopyDeviceSpanToVector(&position, span);
return position;
}
void RowPartitioner::SortPositionAndCopy(const Segment& segment,
TreePositionT left_nidx,
TreePositionT right_nidx,
int64_t* d_left_count,
cudaStream_t stream) {
SortPosition(
common::Span<TreePositionT>(position.Current() + segment.begin,
segment.Size()),
common::Span<TreePositionT>(position.other() + segment.begin,
segment.Size()),
common::Span<RowIndexT>(ridx.Current() + segment.begin, segment.Size()),
common::Span<RowIndexT>(ridx.other() + segment.begin, segment.Size()),
left_nidx, right_nidx, d_left_count, stream);
// Copy back key/value
const auto d_position_current = position.Current() + segment.begin;
const auto d_position_other = position.other() + segment.begin;
const auto d_ridx_current = ridx.Current() + segment.begin;
const auto d_ridx_other = ridx.other() + segment.begin;
dh::LaunchN(device_idx, segment.Size(), stream, [=] __device__(size_t idx) {
d_position_current[idx] = d_position_other[idx];
d_ridx_current[idx] = d_ridx_other[idx];
});
}
}; // namespace tree
}; // namespace xgboost

View File

@ -0,0 +1,186 @@
/*!
* Copyright 2017-2019 XGBoost contributors
*/
#pragma once
#include "../../common/device_helpers.cuh"
namespace xgboost {
namespace tree {
/*! \brief Warp-aggregated atomic increment: adds the number of active threads with
 * increment == true to d_count. Used to count rows assigned to the left node. */
__forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment) {
#if __CUDACC_VER_MAJOR__ > 8
int mask = __activemask();
unsigned ballot = __ballot_sync(mask, increment);
int leader = __ffs(mask) - 1;
if (threadIdx.x % 32 == leader) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
#else
unsigned ballot = __ballot(increment);
if (threadIdx.x % 32 == 0) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
#endif
}
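// Editor's note: the ballot aggregates the per-thread increments across the active
// warp so that at most one atomicAdd per warp reaches d_count; for example, if 20 of
// 32 active threads pass increment == true, only the elected leader thread adds 20.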
/** \brief Class responsible for tracking subsets of rows as we add splits and
* partition training rows into different leaf nodes. */
class RowPartitioner {
public:
using TreePositionT = int;
using RowIndexT = bst_uint;
struct Segment;
private:
int device_idx;
/*! \brief Range of rows for each node. */
std::vector<Segment> ridx_segments;
dh::caching_device_vector<RowIndexT> ridx_a;
dh::caching_device_vector<RowIndexT> ridx_b;
dh::caching_device_vector<TreePositionT> position_a;
dh::caching_device_vector<TreePositionT> position_b;
dh::DoubleBuffer<RowIndexT> ridx;
dh::DoubleBuffer<TreePositionT> position;
dh::caching_device_vector<int64_t>
left_counts; // Per-node counters of rows sent to the left child; kept zero-initialised for SortPosition
std::vector<cudaStream_t> streams;
public:
RowPartitioner(int device_idx, size_t num_rows);
~RowPartitioner();
RowPartitioner(const RowPartitioner&) = delete;
RowPartitioner& operator=(const RowPartitioner&) = delete;
/**
* \brief Gets the row indices of training instances in a given node.
*/
common::Span<const RowIndexT> GetRows(TreePositionT nidx);
/**
* \brief Gets all training rows in the set.
*/
common::Span<const RowIndexT> GetRows();
/**
* \brief Gets the tree position of all training instances.
*/
common::Span<const TreePositionT> GetPosition();
/**
* \brief Convenience method for testing
*/
std::vector<RowIndexT> GetRowsHost(TreePositionT nidx);
/**
* \brief Convenience method for testing
*/
std::vector<TreePositionT> GetPositionHost();
/**
* \brief Updates the tree position for set of training instances being split
* into left and right child nodes. Accepts a user-defined lambda specifying
* which branch each training instance should go down.
*
* \tparam UpdatePositionOpT
* \param nidx The index of the node being split.
* \param left_nidx The left child index.
* \param right_nidx The right child index.
* \param op Device lambda taking the row index as an argument and
* returning the new position for this training instance.
*/
template <typename UpdatePositionOpT>
void UpdatePosition(TreePositionT nidx, TreePositionT left_nidx,
TreePositionT right_nidx, UpdatePositionOpT op) {
dh::safe_cuda(cudaSetDevice(device_idx));
Segment segment = ridx_segments.at(nidx);
auto d_ridx = ridx.CurrentSpan();
auto d_position = position.CurrentSpan();
if (left_counts.size() <= nidx) {
left_counts.resize((nidx * 2) + 1);
thrust::fill(left_counts.begin(), left_counts.end(), 0);
}
int64_t* d_left_count = left_counts.data().get() + nidx;
// Launch 1 thread for each row
dh::LaunchN<1, 128>(device_idx, segment.Size(), [=] __device__(size_t idx) {
idx += segment.begin;
RowIndexT ridx = d_ridx[idx];
TreePositionT new_position = op(ridx);  // user-defined branch decision
KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
AtomicIncrement(d_left_count, new_position == left_nidx);
d_position[idx] = new_position;
});
// Overlap device to host memory copy (left_count) with sort
int64_t left_count;
dh::safe_cuda(cudaMemcpyAsync(&left_count, d_left_count, sizeof(int64_t),
cudaMemcpyDeviceToHost, streams[0]));
SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count,
streams[1]);
dh::safe_cuda(cudaStreamSynchronize(streams[0]));
CHECK_LE(left_count, segment.Size());
CHECK_GE(left_count, 0);
ridx_segments.resize(std::max(int(ridx_segments.size()),
std::max(left_nidx, right_nidx) + 1));
ridx_segments[left_nidx] =
Segment(segment.begin, segment.begin + left_count);
ridx_segments[right_nidx] =
Segment(segment.begin + left_count, segment.end);
}
/**
* \brief Finalise the position of all training instances after tree
* construction is complete. Does not update any other meta information in
* this data structure, so should only be used at the end of training.
*
* \param op Device lambda taking the row index and current position as
* arguments and returning the new position for this training
* instance.
*/
template <typename FinalisePositionOpT>
void FinalisePosition(FinalisePositionOpT op) {
auto d_position = position.Current();
const auto d_ridx = ridx.Current();
dh::LaunchN(device_idx, position.Size(), [=] __device__(size_t idx) {
auto position = d_position[idx];
RowIndexT ridx = d_ridx[idx];
d_position[idx] = op(ridx, position);
});
}
/**
* \brief Optimised routine for sorting key value pairs into left and right
* segments. Based on a single pass of exclusive scan, uses iterators to
* redirect inputs and outputs.
*/
void SortPosition(common::Span<TreePositionT> position,
common::Span<TreePositionT> position_out,
common::Span<RowIndexT> ridx,
common::Span<RowIndexT> ridx_out, TreePositionT left_nidx,
TreePositionT right_nidx, int64_t* d_left_count,
cudaStream_t stream = nullptr);
/*! \brief Sort row indices according to position. */
void SortPositionAndCopy(const Segment& segment, TreePositionT left_nidx,
TreePositionT right_nidx, int64_t* d_left_count,
cudaStream_t stream);
/** \brief Used to demarcate a contiguous set of row indices associated with
* some tree node. */
struct Segment {
size_t begin;
size_t end;
Segment() : begin{0}, end{0} {}
Segment(size_t begin, size_t end) : begin(begin), end(end) {
CHECK_GE(end, begin);
}
size_t Size() const { return end - begin; }
};
};
}; // namespace tree
}; // namespace xgboost
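A minimal usage sketch of the new RowPartitioner API (editor's illustration, not
part of the commit; the split decision and final position mapping are placeholders):

using xgboost::tree::RowPartitioner;
void ExampleUpdate(size_t num_rows) {
  RowPartitioner partitioner(0, num_rows);          // all rows start in node 0
  partitioner.UpdatePosition(0, 1, 2,
      [=] __device__(RowPartitioner::RowIndexT ridx) {
        return ridx % 2 == 0 ? 1 : 2;               // placeholder split decision
      });
  auto left_rows = partitioner.GetRows(1);          // device span of rows now in node 1
  partitioner.FinalisePosition(
      [=] __device__(RowPartitioner::RowIndexT ridx, int position) {
        return position;                            // here: keep positions unchanged
      });
}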

View File

@ -6,7 +6,6 @@
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <xgboost/tree_updater.h>
#include <algorithm>
#include <cmath>
@ -25,6 +24,7 @@
#include "param.h"
#include "updater_gpu_common.cuh"
#include "constraints.cuh"
#include "gpu_hist/row_partitioner.cuh"
namespace xgboost {
namespace tree {
@ -515,10 +515,9 @@ __global__ void CompressBinEllpackKernel(
template <typename GradientSumT>
__global__ void SharedMemHistKernel(ELLPackMatrix matrix,
const bst_uint* d_ridx,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
GradientSumT* d_node_hist,
const GradientPair* d_gpair,
size_t segment_begin, size_t n_elements,
const GradientPair* d_gpair, size_t n_elements,
bool use_shared_memory_histograms) {
extern __shared__ char smem[];
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
@ -527,7 +526,7 @@ __global__ void SharedMemHistKernel(ELLPackMatrix matrix,
__syncthreads();
}
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
int ridx = d_ridx[idx / matrix.row_stride + segment_begin];
int ridx = d_ridx[idx / matrix.row_stride];
int gidx =
matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
if (gidx != matrix.null_gidx_value) {
@ -549,86 +548,6 @@ __global__ void SharedMemHistKernel(ELLPackMatrix matrix,
}
}
struct Segment {
size_t begin;
size_t end;
Segment() : begin{0}, end{0} {}
Segment(size_t begin, size_t end) : begin(begin), end(end) {
CHECK_GE(end, begin);
}
size_t Size() const { return end - begin; }
};
/** \brief Returns one if the left node index is encountered, otherwise returns
* zero. */
struct IndicateLeftTransform {
int left_nidx;
explicit IndicateLeftTransform(int left_nidx) : left_nidx(left_nidx) {}
__host__ __device__ __forceinline__ int operator()(const int& x) const {
return x == left_nidx ? 1 : 0;
}
};
/**
* \brief Optimised routine for sorting key value pairs into left and right
* segments. Based on a single pass of exclusive scan, uses iterators to
* redirect inputs and outputs.
*/
inline void SortPosition(dh::CubMemory* temp_memory, common::Span<int> position,
common::Span<int> position_out, common::Span<bst_uint> ridx,
common::Span<bst_uint> ridx_out, int left_nidx,
int right_nidx, int64_t* d_left_count,
cudaStream_t stream = nullptr) {
auto d_position_out = position_out.data();
auto d_position_in = position.data();
auto d_ridx_out = ridx_out.data();
auto d_ridx_in = ridx.data();
auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
int scatter_address;
if (d_position_in[idx] == left_nidx) {
scatter_address = ex_scan_result;
} else {
scatter_address = (idx - ex_scan_result) + *d_left_count;
}
d_position_out[scatter_address] = d_position_in[idx];
d_ridx_out[scatter_address] = d_ridx_in[idx];
}; // NOLINT
IndicateLeftTransform conversion_op(left_nidx);
cub::TransformInputIterator<int, IndicateLeftTransform, int*> in_itr(
d_position_in, conversion_op);
dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
position.size(), stream);
temp_memory->LazyAllocate(temp_storage_bytes);
cub::DeviceScan::ExclusiveSum(temp_memory->d_temp_storage,
temp_memory->temp_storage_bytes, in_itr,
out_itr, position.size(), stream);
}
/*! \brief Count how many rows are assigned to left node. */
__forceinline__ __device__ void CountLeft(int64_t* d_count, int val,
int left_nidx) {
#if __CUDACC_VER_MAJOR__ > 8
int mask = __activemask();
unsigned ballot = __ballot_sync(mask, val == left_nidx);
int leader = __ffs(mask) - 1;
if (threadIdx.x % 32 == leader) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
#else
unsigned ballot = __ballot(val == left_nidx);
if (threadIdx.x % 32 == 0) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
#endif
}
// Instances of this type are created while creating the histogram bins for the
// entire dataset across multiple sparse page batches. This keeps track of the number
// of rows to process from a batch and the position from which to process on each device.
@ -671,8 +590,7 @@ struct DeviceShard {
ELLPackMatrix ellpack_matrix;
/*! \brief Range of rows for each node. */
std::vector<Segment> ridx_segments;
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogram<GradientSumT> hist;
/*! \brief row_ptr from HistCutMatrix. */
@ -684,9 +602,6 @@ struct DeviceShard {
/*! \brief global index of histogram, which is stored in ELLPack format. */
common::Span<common::CompressedByteT> gidx_buffer;
/*! \brief Row indices relative to this shard, necessary for sorting rows. */
dh::DoubleBuffer<bst_uint> ridx;
dh::DoubleBuffer<int> position;
/*! \brief Gradient pair for each row. */
common::Span<GradientPair> gpair;
@ -696,8 +611,8 @@ struct DeviceShard {
/*! \brief Sum gradient for each node. */
std::vector<GradientPair> node_sum_gradients;
common::Span<GradientPair> node_sum_gradients_d;
dh::device_vector<int64_t>
left_counts; // Useful to keep a bunch of zeroed memory for sort position
/*! \brief On-device feature set, only actually used on one of the devices */
dh::device_vector<int> feature_set_d;
/*! The row offset for this shard. */
bst_uint row_begin_idx;
bst_uint row_end_idx;
@ -783,24 +698,10 @@ struct DeviceShard {
param.colsample_bylevel, param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id));
this->interaction_constraints.Reset();
thrust::fill(
thrust::device_pointer_cast(position.Current()),
thrust::device_pointer_cast(position.Current() + position.Size()), 0);
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
GradientPair());
if (left_counts.size() < 256) {
left_counts.resize(256);
} else {
dh::safe_cuda(cudaMemsetAsync(left_counts.data().get(), 0,
sizeof(int64_t) * left_counts.size()));
}
thrust::sequence(
thrust::device_pointer_cast(ridx.CurrentSpan().data()),
thrust::device_pointer_cast(ridx.CurrentSpan().data() + ridx.Size()));
row_partitioner.reset(new RowPartitioner(device_id, n_rows));
std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
ridx_segments.front() = Segment(0, ridx.Size());
dh::safe_cuda(cudaMemcpyAsync(
gpair.data(), dh_gpair->ConstDevicePointer(device_id),
gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost));
@ -892,12 +793,11 @@ struct DeviceShard {
void BuildHist(int nidx) {
hist.AllocateHistogram(nidx);
auto segment = ridx_segments[nidx];
auto d_node_hist = hist.GetNodeHistogram(nidx);
auto d_ridx = ridx.Current();
auto d_ridx = row_partitioner->GetRows(nidx);
auto d_gpair = gpair.data();
auto n_elements = segment.Size() * ellpack_matrix.row_stride;
auto n_elements = d_ridx.size() * ellpack_matrix.row_stride;
const size_t smem_size =
use_shared_memory_histograms
@ -911,8 +811,8 @@ struct DeviceShard {
return;
}
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, segment.begin,
n_elements, use_shared_memory_histograms);
ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
use_shared_memory_histograms);
}
void SubtractionTrick(int nidx_parent, int nidx_histogram,
@ -936,21 +836,13 @@ struct DeviceShard {
}
void UpdatePosition(int nidx, RegTree::Node split_node) {
CHECK(!split_node.IsLeaf()) <<"Node must not be leaf";
Segment segment = ridx_segments[nidx];
bst_uint* d_ridx = ridx.Current();
int* d_position = position.Current();
if (left_counts.size() <= nidx) {
left_counts.resize((nidx * 2) + 1);
}
int64_t* d_left_count = left_counts.data().get() + nidx;
auto d_matrix = this->ellpack_matrix;
// Launch 1 thread for each row
dh::LaunchN<1, 128>(
device_id, segment.Size(), [=] __device__(bst_uint idx) {
idx += segment.begin;
bst_uint ridx = d_ridx[idx];
bst_float element = d_matrix.GetElement(ridx, split_node.SplitIndex());
auto d_matrix = ellpack_matrix;
row_partitioner->UpdatePosition(
nidx, split_node.LeftChild(), split_node.RightChild(),
[=] __device__(bst_uint ridx) {
bst_float element =
d_matrix.GetElement(ridx, split_node.SplitIndex());
// Missing value
int new_position = 0;
if (isnan(element)) {
@ -962,49 +854,8 @@ struct DeviceShard {
new_position = split_node.RightChild();
}
}
CountLeft(d_left_count, new_position, split_node.LeftChild());
d_position[idx] = new_position;
return new_position;
});
// Overlap device to host memory copy (left_count) with sort
auto& streams = this->GetStreams(2);
auto tmp_pinned = pinned_memory.GetSpan<int64_t>(1);
dh::safe_cuda(cudaMemcpyAsync(tmp_pinned.data(), d_left_count, sizeof(int64_t),
cudaMemcpyDeviceToHost, streams[0]));
SortPositionAndCopy(segment, split_node.LeftChild(), split_node.RightChild(), d_left_count,
streams[1]);
dh::safe_cuda(cudaStreamSynchronize(streams[0]));
int64_t left_count = tmp_pinned[0];
CHECK_LE(left_count, segment.Size());
CHECK_GE(left_count, 0);
ridx_segments[split_node.LeftChild()] =
Segment(segment.begin, segment.begin + left_count);
ridx_segments[split_node.RightChild()] =
Segment(segment.begin + left_count, segment.end);
}
/*! \brief Sort row indices according to position. */
void SortPositionAndCopy(const Segment& segment, int left_nidx,
int right_nidx, int64_t* d_left_count,
cudaStream_t stream) {
SortPosition(
&temp_memory,
common::Span<int>(position.Current() + segment.begin, segment.Size()),
common::Span<int>(position.other() + segment.begin, segment.Size()),
common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
left_nidx, right_nidx, d_left_count, stream);
// Copy back key/value
const auto d_position_current = position.Current() + segment.begin;
const auto d_position_other = position.other() + segment.begin;
const auto d_ridx_current = ridx.Current() + segment.begin;
const auto d_ridx_other = ridx.other() + segment.begin;
dh::LaunchN(device_id, segment.Size(), stream, [=] __device__(size_t idx) {
d_position_current[idx] = d_position_other[idx];
d_ridx_current[idx] = d_ridx_other[idx];
});
}
// After tree update is finished, update the position of all training
@ -1016,30 +867,27 @@ struct DeviceShard {
dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
d_nodes.size() * sizeof(RegTree::Node),
cudaMemcpyHostToDevice));
auto d_position = position.Current();
const auto d_ridx = ridx.Current();
auto d_matrix = this->ellpack_matrix;
dh::LaunchN(device_id, position.Size(), [=] __device__(size_t idx) {
auto position = d_position[idx];
auto node = d_nodes[position];
bst_uint ridx = d_ridx[idx];
auto d_matrix = ellpack_matrix;
row_partitioner->FinalisePosition(
[=] __device__(bst_uint ridx, int position) {
auto node = d_nodes[position];
while (!node.IsLeaf()) {
bst_float element = d_matrix.GetElement(ridx, node.SplitIndex());
// Missing value
if (isnan(element)) {
position = node.DefaultChild();
} else {
if (element <= node.SplitCond()) {
position = node.LeftChild();
} else {
position = node.RightChild();
while (!node.IsLeaf()) {
bst_float element = d_matrix.GetElement(ridx, node.SplitIndex());
// Missing value
if (isnan(element)) {
position = node.DefaultChild();
} else {
if (element <= node.SplitCond()) {
position = node.LeftChild();
} else {
position = node.RightChild();
}
}
node = d_nodes[position];
}
}
node = d_nodes[position];
}
d_position[idx] = position;
});
return position;
});
}
void UpdatePredictionCache(bst_float* out_preds_d) {
@ -1057,8 +905,8 @@ struct DeviceShard {
cudaMemcpyAsync(node_sum_gradients_d.data(), node_sum_gradients.data(),
sizeof(GradientPair) * node_sum_gradients.size(),
cudaMemcpyHostToDevice));
auto d_position = position.Current();
auto d_ridx = ridx.Current();
auto d_position = row_partitioner->GetPosition();
auto d_ridx = row_partitioner->GetRows();
auto d_node_sum_gradients = node_sum_gradients_d.data();
auto d_prediction_cache = prediction_cache.data();
@ -1096,13 +944,15 @@ struct DeviceShard {
auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
auto left_node_rows = ridx_segments[nidx_left].Size();
auto right_node_rows = ridx_segments[nidx_right].Size();
auto left_node_rows = row_partitioner->GetRows(nidx_left).size();
auto right_node_rows = row_partitioner->GetRows(nidx_right).size();
// Decide whether to build the left histogram or right histogram
// Find the largest number of training instances on any given Shard
// Assume this will be the bottleneck and avoid building this node if
// possible
std::vector<size_t> max_reduce = {left_node_rows, right_node_rows};
std::vector<size_t> max_reduce;
max_reduce.push_back(left_node_rows);
max_reduce.push_back(right_node_rows);
reducer->HostMaxAllReduce(&max_reduce);
bool fewer_right = max_reduce[1] < max_reduce[0];
if (fewer_right) {
@ -1199,6 +1049,7 @@ struct DeviceShard {
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
RegTree* p_tree, dh::AllReducer* reducer) {
auto& tree = *p_tree;
monitor.StartCuda("Reset");
this->Reset(gpair_all, p_fmat->Info().num_col_);
monitor.StopCuda("Reset");
@ -1206,7 +1057,6 @@ struct DeviceShard {
monitor.StartCuda("InitRoot");
this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_);
monitor.StopCuda("InitRoot");
auto timestamp = qexpand->size();
auto num_leaves = 1;
@ -1269,8 +1119,6 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
ba.Allocate(device_id,
&gpair, n_rows,
&ridx, n_rows,
&position, n_rows,
&prediction_cache, n_rows,
&node_sum_gradients_d, max_nodes,
&feature_segments, hmat.row_ptr.size(),
@ -1284,7 +1132,6 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);
node_sum_gradients.resize(max_nodes);
ridx_segments.resize(max_nodes);
// allocate compressed bin data
int num_symbols = n_bins + 1;
@ -1303,7 +1150,6 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
gidx_fvalue_map, row_stride,
common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
is_dense, null_gidx_value);
// check if we can use shared memory for building histograms
// (assuming we need at least 2 CTAs per SM to maintain decent latency
// hiding)

View File

@ -97,7 +97,8 @@ TEST(bulkAllocator, Test) {
}
// Test thread safe max reduction
TEST(AllReducer, HostMaxAllReduce) {
#if defined(XGBOOST_USE_NCCL)
TEST(AllReducer, MGPU_HostMaxAllReduce) {
dh::AllReducer reducer;
size_t num_threads = 50;
std::vector<std::vector<size_t>> thread_data(num_threads);
@ -112,3 +113,4 @@ TEST(AllReducer, HostMaxAllReduce) {
ASSERT_EQ(data.front(), num_threads - 1);
}
}
#endif

View File

@ -0,0 +1,125 @@
#include <gtest/gtest.h>
#include <vector>
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
#include "../../helpers.h"
namespace xgboost {
namespace tree {
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
int right_idx) {
std::vector<int64_t> left_count = {
std::count(position_in.begin(), position_in.end(), left_idx)};
thrust::device_vector<int64_t> d_left_count = left_count;
thrust::device_vector<int> position = position_in;
thrust::device_vector<int> position_out(position.size());
thrust::device_vector<RowPartitioner::RowIndexT> ridx(position.size());
thrust::sequence(ridx.begin(), ridx.end());
thrust::device_vector<RowPartitioner::RowIndexT> ridx_out(ridx.size());
RowPartitioner rp(0,10);
rp.SortPosition(
common::Span<int>(position.data().get(), position.size()),
common::Span<int>(position_out.data().get(), position_out.size()),
common::Span<RowPartitioner::RowIndexT>(ridx.data().get(), ridx.size()),
common::Span<RowPartitioner::RowIndexT>(ridx_out.data().get(), ridx_out.size()), left_idx,
right_idx, d_left_count.data().get(), nullptr);
thrust::host_vector<int> position_result = position_out;
thrust::host_vector<int> ridx_result = ridx_out;
// Check position is sorted
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
// Check row indices are sorted inside left and right segment
EXPECT_TRUE(
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0]));
EXPECT_TRUE(
std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end()));
// Check key value pairs are the same
for (auto i = 0ull; i < ridx_result.size(); i++) {
EXPECT_EQ(position_result[i], position_in[ridx_result[i]]);
}
}
TEST(GpuHist, SortPosition) {
TestSortPosition({1, 2, 1, 2, 1}, 1, 2);
TestSortPosition({1, 1, 1, 1}, 1, 2);
TestSortPosition({2, 2, 2, 2}, 1, 2);
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
}
void TestUpdatePosition() {
const int kNumRows = 10;
RowPartitioner rp(0, kNumRows);
auto rows = rp.GetRowsHost(0);
EXPECT_EQ(rows.size(), kNumRows);
for (auto i = 0ull; i < kNumRows; i++) {
EXPECT_EQ(rows[i], i);
}
// Send the first five training instances to the right node
// and the last five to the left node
rp.UpdatePosition(0, 1, 2,
[=] __device__(RowPartitioner::RowIndexT ridx) {
if (ridx > 4) {
return 1;
} else {
return 2;
}
});
rows = rp.GetRowsHost(1);
for (auto r : rows) {
EXPECT_GT(r, 4);
}
rows = rp.GetRowsHost(2);
for (auto r : rows) {
EXPECT_LT(r, 5);
}
// Split the left node again
rp.UpdatePosition(1, 3, 4, [=] __device__(RowPartitioner::RowIndexT ridx) {
if (ridx < 7) {
return 3;
}
return 4;
});
EXPECT_EQ(rp.GetRows(3).size(), 2);
EXPECT_EQ(rp.GetRows(4).size(), 3);
// Check position is as expected
EXPECT_EQ(rp.GetPositionHost(), std::vector<RowPartitioner::TreePositionT>({3,3,4,4,4,2,2,2,2,2}));
}
TEST(RowPartitioner, Basic) { TestUpdatePosition(); }
void TestFinalise() {
const int kNumRows = 10;
RowPartitioner rp(0, kNumRows);
rp.FinalisePosition([=] __device__(RowPartitioner::RowIndexT ridx, int position) {
return 7;
});
auto position = rp.GetPositionHost();
for (auto p : position) {
EXPECT_EQ(p, 7);
}
}
TEST(RowPartitioner, Finalise) { TestFinalise(); }
void TestIncorrectRow() {
RowPartitioner rp(0, 1);
rp.UpdatePosition(0, 1, 2, [=] __device__(RowPartitioner::RowIndexT ridx) {
return 4;  // This is not the left branch or the right branch
});
}
TEST(RowPartitioner, IncorrectRow) {
ASSERT_DEATH({ TestIncorrectRow(); }, ".*");
}
} // namespace tree
} // namespace xgboost

View File

@ -206,16 +206,10 @@ void TestBuildHist(bool use_shared_memory_histograms) {
dh::safe_cuda(cudaMemcpy(h_gidx_buffer.data(), d_gidx_buffer_ptr,
sizeof(common::CompressedByteT) * shard.gidx_buffer.size(),
cudaMemcpyDeviceToHost));
auto gidx = common::CompressedIterator<uint32_t>(h_gidx_buffer.data(),
num_symbols);
shard.ridx_segments.resize(1);
shard.ridx_segments[0] = Segment(0, kNRows);
shard.row_partitioner.reset(new RowPartitioner(0, kNRows));
shard.hist.AllocateHistogram(0);
dh::CopyVectorToDeviceSpan(shard.gpair, h_gpair);
thrust::sequence(
thrust::device_pointer_cast(shard.ridx.Current()),
thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size()));
shard.use_shared_memory_histograms = use_shared_memory_histograms;
shard.BuildHist(0);
@ -358,138 +352,6 @@ TEST(GpuHist, EvaluateSplits) {
ASSERT_NEAR(res[1].fvalue, 0.26, xgboost::kRtEps);
}
TEST(GpuHist, ApplySplit) {
int constexpr kNId = 0;
int constexpr kNRows = 16;
int constexpr kNCols = 8;
TrainParam param;
std::vector<std::pair<std::string, std::string>> args = {};
param.InitAllowUnknown(args);
// Initialize shard
for (size_t i = 0; i < kNCols; ++i) {
param.monotone_constraints.emplace_back(0);
}
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols,
kNCols)};
shard->ridx_segments.resize(3); // 3 nodes.
shard->node_sum_gradients.resize(3);
shard->ridx_segments[0] = Segment(0, kNRows);
shard->ba.Allocate(0, &(shard->ridx), kNRows,
&(shard->position), kNRows);
shard->ellpack_matrix.row_stride = kNCols;
thrust::sequence(
thrust::device_pointer_cast(shard->ridx.Current()),
thrust::device_pointer_cast(shard->ridx.Current() + shard->ridx.Size()));
RegTree tree;
DeviceSplitCandidate candidate;
candidate.Update(2, kLeftDir,
0.59, 4, // fvalue has to be equal to one of the cut field
GradientPair(8.2, 2.8), GradientPair(6.3, 3.6),
GPUTrainingParam(param));
ExpandEntry candidate_entry {0, 0, candidate, 0};
candidate_entry.nid = kNId;
// Used to get bin_id in update position.
common::HistCutMatrix cmat = GetHostCutMatrix();
MetaInfo info;
info.num_row_ = kNRows;
info.num_col_ = kNCols;
info.num_nonzero_ = kNRows * kNCols; // Dense
// Initialize gidx
int n_bins = 24;
int row_stride = kNCols;
int num_symbols = n_bins + 1;
size_t compressed_size_bytes =
common::CompressedBufferWriter::CalculateBufferSize(row_stride * kNRows,
num_symbols);
shard->ba.Allocate(0, &(shard->gidx_buffer), compressed_size_bytes,
&(shard->feature_segments), cmat.row_ptr.size(),
&(shard->min_fvalue), cmat.min_val.size(),
&(shard->gidx_fvalue_map), 24);
dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr);
dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut);
shard->ellpack_matrix.feature_segments = shard->feature_segments;
shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map;
dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val);
shard->ellpack_matrix.min_fvalue = shard->min_fvalue;
shard->ellpack_matrix.is_dense = true;
common::CompressedBufferWriter wr(num_symbols);
// gidx 14 should go right, 12 goes left
std::vector<int> h_gidx (kNRows * row_stride, 14);
h_gidx[4] = 12;
h_gidx[12] = 12;
std::vector<common::CompressedByteT> h_gidx_compressed (compressed_size_bytes);
wr.Write(h_gidx_compressed.data(), h_gidx.begin(), h_gidx.end());
dh::CopyVectorToDeviceSpan(shard->gidx_buffer, h_gidx_compressed);
shard->ellpack_matrix.gidx_iter = common::CompressedIterator<uint32_t>(
shard->gidx_buffer.data(), num_symbols);
shard->ApplySplit(candidate_entry, &tree);
shard->UpdatePosition(candidate_entry.nid, tree[candidate_entry.nid]);
ASSERT_FALSE(tree[kNId].IsLeaf());
int left_nidx = tree[kNId].LeftChild();
int right_nidx = tree[kNId].RightChild();
ASSERT_EQ(shard->ridx_segments[left_nidx].begin, 0);
ASSERT_EQ(shard->ridx_segments[left_nidx].end, 2);
ASSERT_EQ(shard->ridx_segments[right_nidx].begin, 2);
ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16);
}
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
int right_idx) {
std::vector<int64_t> left_count = {
std::count(position_in.begin(), position_in.end(), left_idx)};
thrust::device_vector<int64_t> d_left_count = left_count;
thrust::device_vector<int> position = position_in;
thrust::device_vector<int> position_out(position.size());
thrust::device_vector<bst_uint> ridx(position.size());
thrust::sequence(ridx.begin(), ridx.end());
thrust::device_vector<bst_uint> ridx_out(ridx.size());
dh::CubMemory tmp;
SortPosition(
&tmp, common::Span<int>(position.data().get(), position.size()),
common::Span<int>(position_out.data().get(), position_out.size()),
common::Span<bst_uint>(ridx.data().get(), ridx.size()),
common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
right_idx, d_left_count.data().get(), nullptr);
thrust::host_vector<int> position_result = position_out;
thrust::host_vector<int> ridx_result = ridx_out;
// Check position is sorted
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
// Check row indices are sorted inside left and right segment
EXPECT_TRUE(
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0]));
EXPECT_TRUE(
std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end()));
// Check key value pairs are the same
for (auto i = 0ull; i < ridx_result.size(); i++) {
EXPECT_EQ(position_result[i], position_in[ridx_result[i]]);
}
}
TEST(GpuHist, SortPosition) {
TestSortPosition({1, 2, 1, 2, 1}, 1, 2);
TestSortPosition({1, 1, 1, 1}, 1, 2);
TestSortPosition({2, 2, 2, 2}, 1, 2);
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
}
void TestHistogramIndexImpl(int n_gpus) {
// Test if the compressed histogram index matches when using a sparse
// dmatrix with and without using external memory