diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 527724537..c1ff749d0 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -9,6 +9,7 @@ #include #include #include +#include #include "common.h" #include "span.h" @@ -299,9 +300,14 @@ namespace detail{ * \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose. */ template -struct XGBDefaultDeviceAllocator : thrust::device_malloc_allocator { +struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator { using super_t = thrust::device_malloc_allocator; using pointer = thrust::device_ptr; + template + struct rebind + { + typedef XGBDefaultDeviceAllocatorImpl other; + }; pointer allocate(size_t n) { pointer ptr = super_t::allocate(n); GlobalMemoryLogger().RegisterAllocation(ptr.get(), n); @@ -312,16 +318,56 @@ struct XGBDefaultDeviceAllocator : thrust::device_malloc_allocator { return super_t::deallocate(ptr, n); } }; + +/** + * \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs allocations if verbose. Does not initialise memory on construction. + */ +template +struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator { + using pointer = thrust::device_ptr; + template + struct rebind + { + typedef XGBCachingDeviceAllocatorImpl other; + }; + cub::CachingDeviceAllocator& GetGlobalCachingAllocator () + { + // Configure allocator with maximum cached bin size of ~1GB and no limit on + // maximum cached bytes + static cub::CachingDeviceAllocator allocator(8,3,10); + return allocator; + } + pointer allocate(size_t n) { + T *ptr; + GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast(&ptr), + n * sizeof(T)); + pointer thrust_ptr = thrust::device_ptr(ptr); + GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n); + return thrust_ptr; + } + void deallocate(pointer ptr, size_t n) { + GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n); + GetGlobalCachingAllocator().DeviceFree(ptr.get()); + } + __host__ __device__ + void construct(T *) + { + // no-op + } +}; }; -// Declare xgboost allocator +// Declare xgboost allocators // Replacement of allocator with custom backend should occur here template -using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocator; +using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocatorImpl; +template +using XGBCachingDeviceAllocator = detail::XGBCachingDeviceAllocatorImpl; /** \brief Specialisation of thrust device vector using custom allocator. */ template using device_vector = thrust::device_vector>; - +template +using caching_device_vector = thrust::device_vector>; /** * \brief A double buffer, useful for algorithms like sort. */ @@ -331,6 +377,14 @@ class DoubleBuffer { cub::DoubleBuffer buff; xgboost::common::Span a, b; DoubleBuffer() = default; + template + DoubleBuffer(VectorT *v1, VectorT *v2) { + a = xgboost::common::Span(v1->data().get(), v1->size()); + b = xgboost::common::Span(v2->data().get(), v2->size()); + buff.d_buffers[0] = v1->data().get(); + buff.d_buffers[1] = v2->data().get(); + buff.selector = 0; + } size_t Size() const { CHECK_EQ(a.size(), b.size()); @@ -362,6 +416,20 @@ void CopyDeviceSpanToVector(std::vector *dst, xgboost::common::Span src) { cudaMemcpyDeviceToHost)); } +/** + * \brief Copies const device span to std::vector. + * + * \tparam T Generic type parameter. + * \param [in,out] dst Copy destination. + * \param src Copy source. Must be device memory. 
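The allocator hunk above plugs cub::CachingDeviceAllocator(8, 3, 10) behind a Thrust-compatible allocator so that dh::caching_device_vector reuses freed blocks instead of calling cudaMalloc/cudaFree for every temporary. A minimal sketch of the underlying CUB calls and what those bin parameters mean; the loop, sizes, and function name below are illustrative only, not part of the patch:

```cuda
#include <cub/util_allocator.cuh>

// bin_growth = 8, min_bin = 3, max_bin = 10: requests are rounded up to powers
// of 8 between 8^3 = 512 B and 8^10 = 1 GiB; larger requests bypass the cache.
cub::CachingDeviceAllocator allocator(8, 3, 10);

void RepeatedTempAllocations(cudaStream_t stream) {
  for (int iter = 0; iter < 100; ++iter) {
    void* d_temp = nullptr;
    // The first iteration performs a real cudaMalloc; later same-sized requests
    // are served from the cache with no driver call.
    allocator.DeviceAllocate(&d_temp, 64 << 10, stream);
    // ... launch kernels that use d_temp on `stream` ...
    allocator.DeviceFree(d_temp);  // returns the block to the cache, not to the driver
  }
}
```

This is what makes it cheap for the new RowPartitioner below to allocate its scan scratch buffers and index arrays per call or per boosting iteration.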
+ */ +template +void CopyDeviceSpanToVector(std::vector *dst, xgboost::common::Span src) { + CHECK_EQ(dst->size(), src.size()); + dh::safe_cuda(cudaMemcpyAsync(dst->data(), src.data(), dst->size() * sizeof(T), + cudaMemcpyDeviceToHost)); +} + /** * \brief Copies std::vector to device span. * @@ -1132,6 +1200,7 @@ class AllReducer { * safe) using the master thread. Uses naive reduce algorithm for local * threads, don't expect this to scale.*/ void HostMaxAllReduce(std::vector *p_data) { +#ifdef XGBOOST_USE_NCCL auto &data = *p_data; // Wait in case some other thread is accessing host_data #pragma omp barrier @@ -1162,6 +1231,7 @@ class AllReducer { for (auto i = 0ull; i < data.size(); i++) { data[i] = host_data[i]; } +#endif } }; diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu new file mode 100644 index 000000000..b509822b8 --- /dev/null +++ b/src/tree/gpu_hist/row_partitioner.cu @@ -0,0 +1,146 @@ + +/*! + * Copyright 2017-2019 XGBoost contributors + */ +#include +#include +#include "../../common/device_helpers.cuh" +#include "row_partitioner.cuh" + +namespace xgboost { +namespace tree { + +struct IndicateLeftTransform { + RowPartitioner::TreePositionT left_nidx; + explicit IndicateLeftTransform(RowPartitioner::TreePositionT left_nidx) + : left_nidx(left_nidx) {} + __host__ __device__ __forceinline__ int operator()( + const RowPartitioner::TreePositionT& x) const { + return x == left_nidx ? 1 : 0; + } +}; + +void RowPartitioner::SortPosition(common::Span position, + common::Span position_out, + common::Span ridx, + common::Span ridx_out, + TreePositionT left_nidx, + TreePositionT right_nidx, + int64_t* d_left_count, cudaStream_t stream) { + auto d_position_out = position_out.data(); + auto d_position_in = position.data(); + auto d_ridx_out = ridx_out.data(); + auto d_ridx_in = ridx.data(); + auto write_results = [=] __device__(size_t idx, int ex_scan_result) { + int scatter_address; + if (d_position_in[idx] == left_nidx) { + scatter_address = ex_scan_result; + } else { + scatter_address = (idx - ex_scan_result) + *d_left_count; + } + d_position_out[scatter_address] = d_position_in[idx]; + d_ridx_out[scatter_address] = d_ridx_in[idx]; + }; // NOLINT + + IndicateLeftTransform conversion_op(left_nidx); + cub::TransformInputIterator + in_itr(d_position_in, conversion_op); + dh::DiscardLambdaItr out_itr(write_results); + size_t temp_storage_bytes = 0; + cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr, + position.size(), stream); + dh::caching_device_vector temp_storage(temp_storage_bytes); + cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_storage_bytes, + in_itr, out_itr, position.size(), stream); +} +RowPartitioner::RowPartitioner(int device_idx, size_t num_rows) + : device_idx(device_idx) { + dh::safe_cuda(cudaSetDevice(device_idx)); + ridx_a.resize(num_rows); + ridx_b.resize(num_rows); + position_a.resize(num_rows); + position_b.resize(num_rows); + ridx = dh::DoubleBuffer{&ridx_a, &ridx_b}; + position = dh::DoubleBuffer{&position_a, &position_b}; + ridx_segments.emplace_back(Segment(0, num_rows)); + + thrust::sequence( + thrust::device_pointer_cast(ridx.CurrentSpan().data()), + thrust::device_pointer_cast(ridx.CurrentSpan().data() + ridx.Size())); + thrust::fill( + thrust::device_pointer_cast(position.Current()), + thrust::device_pointer_cast(position.Current() + position.Size()), 0); + left_counts.resize(256); + thrust::fill(left_counts.begin(), left_counts.end(), 0); + streams.resize(2); + for (auto& stream : 
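RowPartitioner::SortPosition above is a stable two-way partition built from a single exclusive scan: the scan counts, for each element, how many earlier elements belong to the left node, and the write_results functor turns that count into a scatter address. A host-side sketch of the same arithmetic, for intuition only (names are illustrative):

```cuda
#include <cstdint>
#include <vector>

// position_out/ridx_out receive the left-node entries first (in original order),
// then the right-node entries (also in original order) -- a stable partition.
void SortPositionHostSketch(const std::vector<int>& position,
                            const std::vector<unsigned>& ridx, int left_nidx,
                            std::vector<int>* position_out,
                            std::vector<unsigned>* ridx_out) {
  int64_t left_count = 0;
  for (int p : position) left_count += (p == left_nidx);
  position_out->resize(position.size());
  ridx_out->resize(ridx.size());
  int64_t ex_scan = 0;  // exclusive prefix sum of the "is left" indicator
  for (size_t idx = 0; idx < position.size(); ++idx) {
    int64_t scatter =
        (position[idx] == left_nidx)
            ? ex_scan                                             // pack into the left block
            : static_cast<int64_t>(idx) - ex_scan + left_count;   // pack into the right block
    (*position_out)[scatter] = position[idx];
    (*ridx_out)[scatter] = ridx[idx];
    ex_scan += (position[idx] == left_nidx);
  }
}
```

On the device the prefix sum comes from cub::DeviceScan::ExclusiveSum and the scatter happens inside the dh::DiscardLambdaItr assignment, so no intermediate scan output is ever materialised.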
streams) { + dh::safe_cuda(cudaStreamCreate(&stream)); + } +} +RowPartitioner::~RowPartitioner() { + dh::safe_cuda(cudaSetDevice(device_idx)); + for (auto& stream : streams) { + dh::safe_cuda(cudaStreamDestroy(stream)); + } +} + +common::Span RowPartitioner::GetRows( + TreePositionT nidx) { + auto segment = ridx_segments.at(nidx); + // Return empty span here as a valid result + // Will error if we try to construct a span from a pointer with size 0 + if (segment.Size() == 0) { + return common::Span(); + } + return ridx.CurrentSpan().subspan(segment.begin, segment.Size()); +} + +common::Span RowPartitioner::GetRows() { + return ridx.CurrentSpan(); +} + +common::Span +RowPartitioner::GetPosition() { + return position.CurrentSpan(); +} +std::vector RowPartitioner::GetRowsHost( + TreePositionT nidx) { + auto span = GetRows(nidx); + std::vector rows(span.size()); + dh::CopyDeviceSpanToVector(&rows, span); + return rows; +} + +std::vector RowPartitioner::GetPositionHost() { + auto span = GetPosition(); + std::vector position(span.size()); + dh::CopyDeviceSpanToVector(&position, span); + return position; +} + +void RowPartitioner::SortPositionAndCopy(const Segment& segment, + TreePositionT left_nidx, + TreePositionT right_nidx, + int64_t* d_left_count, + cudaStream_t stream) { + SortPosition( + common::Span(position.Current() + segment.begin, + segment.Size()), + common::Span(position.other() + segment.begin, + segment.Size()), + common::Span(ridx.Current() + segment.begin, segment.Size()), + common::Span(ridx.other() + segment.begin, segment.Size()), + left_nidx, right_nidx, d_left_count, stream); + // Copy back key/value + const auto d_position_current = position.Current() + segment.begin; + const auto d_position_other = position.other() + segment.begin; + const auto d_ridx_current = ridx.Current() + segment.begin; + const auto d_ridx_other = ridx.other() + segment.begin; + dh::LaunchN(device_idx, segment.Size(), stream, [=] __device__(size_t idx) { + d_position_current[idx] = d_position_other[idx]; + d_ridx_current[idx] = d_ridx_other[idx]; + }); +} +}; // namespace tree +}; // namespace xgboost diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh new file mode 100644 index 000000000..5ff45bd2b --- /dev/null +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -0,0 +1,186 @@ +/*! + * Copyright 2017-2019 XGBoost contributors + */ +#pragma once +#include "../../common/device_helpers.cuh" + +namespace xgboost { +namespace tree { + +/*! \brief Count how many rows are assigned to left node. */ +__forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment) { +#if __CUDACC_VER_MAJOR__ > 8 + int mask = __activemask(); + unsigned ballot = __ballot_sync(mask, increment); + int leader = __ffs(mask) - 1; + if (threadIdx.x % 32 == leader) { + atomicAdd(reinterpret_cast(d_count), // NOLINT + static_cast(__popc(ballot))); // NOLINT + } +#else + unsigned ballot = __ballot(increment); + if (threadIdx.x % 32 == 0) { + atomicAdd(reinterpret_cast(d_count), // NOLINT + static_cast(__popc(ballot))); // NOLINT + } +#endif +} + +/** \brief Class responsible for tracking subsets of rows as we add splits and + * partition training rows into different leaf nodes. */ +class RowPartitioner { + public: + using TreePositionT = int; + using RowIndexT = bst_uint; + struct Segment; + + private: + int device_idx; + /*! \brief Range of rows for each node. 
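AtomicIncrement above is a warp-aggregated counter: instead of every thread issuing its own atomicAdd, the warp votes with __ballot_sync and a single leader lane adds the popcount. A standalone sketch of the same pattern using CUDA 9+ intrinsics (the kernel, names, and predicate are illustrative; the pre-CUDA-9 branch in the patch uses __ballot instead):

```cuda
// Each active warp performs one atomicAdd of the ballot popcount instead of up
// to 32 separate atomics, which matters when most rows take the same branch.
__global__ void CountGreaterThan(const int* data, int n, int threshold,
                                 unsigned long long* count) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  bool pred = (i < n) && (data[i] > threshold);
  unsigned mask = __activemask();               // lanes currently executing this code
  unsigned ballot = __ballot_sync(mask, pred);  // one bit per active lane with pred true
  int leader = __ffs(mask) - 1;                 // lowest active lane does the update
  if ((threadIdx.x % 32) == leader) {
    atomicAdd(count, static_cast<unsigned long long>(__popc(ballot)));
  }
}
```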
*/ + std::vector ridx_segments; + dh::caching_device_vector ridx_a; + dh::caching_device_vector ridx_b; + dh::caching_device_vector position_a; + dh::caching_device_vector position_b; + dh::DoubleBuffer ridx; + dh::DoubleBuffer position; + dh::caching_device_vector + left_counts; // Useful to keep a bunch of zeroed memory for sort position + std::vector streams; + + public: + RowPartitioner(int device_idx, size_t num_rows); + ~RowPartitioner(); + RowPartitioner(const RowPartitioner&) = delete; + RowPartitioner& operator=(const RowPartitioner&) = delete; + + /** + * \brief Gets the row indices of training instances in a given node. + */ + common::Span GetRows(TreePositionT nidx); + + /** + * \brief Gets all training rows in the set. + */ + common::Span GetRows(); + + /** + * \brief Gets the tree position of all training instances. + */ + common::Span GetPosition(); + + /** + * \brief Convenience method for testing + */ + std::vector GetRowsHost(TreePositionT nidx); + + /** + * \brief Convenience method for testing + */ + std::vector GetPositionHost(); + + /** + * \brief Updates the tree position for set of training instances being split + * into left and right child nodes. Accepts a user-defined lambda specifying + * which branch each training instance should go down. + * + * \tparam UpdatePositionOpT + * \param nidx The index of the node being split. + * \param left_nidx The left child index. + * \param right_nidx The right child index. + * \param op Device lambda. Should provide the row index as an + * argument and return the new position for this training instance. + */ + template + void UpdatePosition(TreePositionT nidx, TreePositionT left_nidx, + TreePositionT right_nidx, UpdatePositionOpT op) { + dh::safe_cuda(cudaSetDevice(device_idx)); + Segment segment = ridx_segments.at(nidx); + auto d_ridx = ridx.CurrentSpan(); + auto d_position = position.CurrentSpan(); + if (left_counts.size() <= nidx) { + left_counts.resize((nidx * 2) + 1); + thrust::fill(left_counts.begin(), left_counts.end(), 0); + } + int64_t* d_left_count = left_counts.data().get() + nidx; + // Launch 1 thread for each row + dh::LaunchN<1, 128>(device_idx, segment.Size(), [=] __device__(size_t idx) { + idx += segment.begin; + RowIndexT ridx = d_ridx[idx]; + // Missing value + TreePositionT new_position = op(ridx); + KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx); + AtomicIncrement(d_left_count, new_position == left_nidx); + d_position[idx] = new_position; + }); + // Overlap device to host memory copy (left_count) with sort + int64_t left_count; + dh::safe_cuda(cudaMemcpyAsync(&left_count, d_left_count, sizeof(int64_t), + cudaMemcpyDeviceToHost, streams[0])); + + SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count, + streams[1]); + + dh::safe_cuda(cudaStreamSynchronize(streams[0])); + CHECK_LE(left_count, segment.Size()); + CHECK_GE(left_count, 0); + ridx_segments.resize(std::max(int(ridx_segments.size()), + std::max(left_nidx, right_nidx) + 1)); + ridx_segments[left_nidx] = + Segment(segment.begin, segment.begin + left_count); + ridx_segments[right_nidx] = + Segment(segment.begin + left_count, segment.end); + } + + /** + * \brief Finalise the position of all training instances after tree + * construction is complete. Does not update any other meta information in + * this data structure, so should only be used at the end of training. + * + * \param op Device lambda. 
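UpdatePosition is driven entirely by the caller's device lambda; the partitioner only requires the returned value to be left_nidx or right_nidx (enforced by KERNEL_CHECK) and handles the counting, sorting, and segment bookkeeping itself. A usage sketch along the lines of the unit test added later in this diff; the include path and the even/odd split rule are illustrative assumptions:

```cuda
#include "src/tree/gpu_hist/row_partitioner.cuh"  // assumed path, adjust to the build setup

void ExampleSplit(int n_rows) {
  xgboost::tree::RowPartitioner partitioner(/*device_idx=*/0, n_rows);
  // Send even rows to child 1 and odd rows to child 2 of the root node 0.
  partitioner.UpdatePosition(
      /*nidx=*/0, /*left_nidx=*/1, /*right_nidx=*/2,
      [=] __device__(xgboost::tree::RowPartitioner::RowIndexT ridx) {
        return ridx % 2 == 0 ? 1 : 2;
      });
  // Rows are now physically regrouped: GetRows(1) and GetRows(2) return
  // contiguous spans that the histogram kernel can iterate without any offset.
  auto left = partitioner.GetRowsHost(1);
  auto right = partitioner.GetRowsHost(2);
}
```

Note that the device lambda requires nvcc's extended-lambda support, which the gpu_hist plugin already enables.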
Should provide the row index and current + * position as an argument and return the new position for this training + * instance. + */ + template + void FinalisePosition(FinalisePositionOpT op) { + auto d_position = position.Current(); + const auto d_ridx = ridx.Current(); + dh::LaunchN(device_idx, position.Size(), [=] __device__(size_t idx) { + auto position = d_position[idx]; + RowIndexT ridx = d_ridx[idx]; + d_position[idx] = op(ridx, position); + }); + } + + /** + * \brief Optimised routine for sorting key value pairs into left and right + * segments. Based on a single pass of exclusive scan, uses iterators to + * redirect inputs and outputs. + */ + void SortPosition(common::Span position, + common::Span position_out, + common::Span ridx, + common::Span ridx_out, TreePositionT left_nidx, + TreePositionT right_nidx, int64_t* d_left_count, + cudaStream_t stream = nullptr); + + /*! \brief Sort row indices according to position. */ + void SortPositionAndCopy(const Segment& segment, TreePositionT left_nidx, + TreePositionT right_nidx, int64_t* d_left_count, + cudaStream_t stream); + /** \brief Used to demarcate a contiguous set of row indices associated with + * some tree node. */ + struct Segment { + size_t begin; + size_t end; + + Segment() : begin{0}, end{0} {} + + Segment(size_t begin, size_t end) : begin(begin), end(end) { + CHECK_GE(end, begin); + } + size_t Size() const { return end - begin; } + }; +}; +}; // namespace tree +}; // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 9750c6f8c..714e6258c 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include @@ -25,6 +24,7 @@ #include "param.h" #include "updater_gpu_common.cuh" #include "constraints.cuh" +#include "gpu_hist/row_partitioner.cuh" namespace xgboost { namespace tree { @@ -515,10 +515,9 @@ __global__ void CompressBinEllpackKernel( template __global__ void SharedMemHistKernel(ELLPackMatrix matrix, - const bst_uint* d_ridx, + common::Span d_ridx, GradientSumT* d_node_hist, - const GradientPair* d_gpair, - size_t segment_begin, size_t n_elements, + const GradientPair* d_gpair, size_t n_elements, bool use_shared_memory_histograms) { extern __shared__ char smem[]; GradientSumT* smem_arr = reinterpret_cast(smem); // NOLINT @@ -527,7 +526,7 @@ __global__ void SharedMemHistKernel(ELLPackMatrix matrix, __syncthreads(); } for (auto idx : dh::GridStrideRange(static_cast(0), n_elements)) { - int ridx = d_ridx[idx / matrix.row_stride + segment_begin]; + int ridx = d_ridx[idx / matrix.row_stride ]; int gidx = matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride]; if (gidx != matrix.null_gidx_value) { @@ -549,86 +548,6 @@ __global__ void SharedMemHistKernel(ELLPackMatrix matrix, } } -struct Segment { - size_t begin; - size_t end; - - Segment() : begin{0}, end{0} {} - - Segment(size_t begin, size_t end) : begin(begin), end(end) { - CHECK_GE(end, begin); - } - size_t Size() const { return end - begin; } -}; - -/** \brief Returns a one if the left node index is encountered, otherwise return - * zero. */ -struct IndicateLeftTransform { - int left_nidx; - explicit IndicateLeftTransform(int left_nidx) : left_nidx(left_nidx) {} - __host__ __device__ __forceinline__ int operator()(const int& x) const { - return x == left_nidx ? 1 : 0; - } -}; - -/** - * \brief Optimised routine for sorting key value pairs into left and right - * segments. 
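The SharedMemHistKernel change above only alters how rows are indexed (a span per node instead of a segment offset), but since this kernel is central to gpu_hist it is worth recalling the privatised-histogram pattern it relies on. A generic integer version as a sketch; the real kernel accumulates GradientSumT pairs and falls back to global atomics when the histogram does not fit in shared memory:

```cuda
// Each block accumulates into a shared-memory copy of the histogram and flushes
// it once with global atomics, reducing global atomic traffic roughly by the
// block size.
__global__ void SharedMemHist(const int* bins, size_t n, int num_bins, int* hist) {
  extern __shared__ int smem_hist[];
  for (int i = threadIdx.x; i < num_bins; i += blockDim.x) smem_hist[i] = 0;
  __syncthreads();
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    atomicAdd(&smem_hist[bins[i]], 1);
  }
  __syncthreads();
  for (int i = threadIdx.x; i < num_bins; i += blockDim.x) {
    atomicAdd(&hist[i], smem_hist[i]);
  }
}
// Launch with: SharedMemHist<<<grid, block, num_bins * sizeof(int)>>>(bins, n, num_bins, hist);
```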
Based on a single pass of exclusive scan, uses iterators to - * redirect inputs and outputs. - */ -inline void SortPosition(dh::CubMemory* temp_memory, common::Span position, - common::Span position_out, common::Span ridx, - common::Span ridx_out, int left_nidx, - int right_nidx, int64_t* d_left_count, - cudaStream_t stream = nullptr) { - auto d_position_out = position_out.data(); - auto d_position_in = position.data(); - auto d_ridx_out = ridx_out.data(); - auto d_ridx_in = ridx.data(); - auto write_results = [=] __device__(size_t idx, int ex_scan_result) { - int scatter_address; - if (d_position_in[idx] == left_nidx) { - scatter_address = ex_scan_result; - } else { - scatter_address = (idx - ex_scan_result) + *d_left_count; - } - d_position_out[scatter_address] = d_position_in[idx]; - d_ridx_out[scatter_address] = d_ridx_in[idx]; - }; // NOLINT - - IndicateLeftTransform conversion_op(left_nidx); - cub::TransformInputIterator in_itr( - d_position_in, conversion_op); - dh::DiscardLambdaItr out_itr(write_results); - size_t temp_storage_bytes = 0; - cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr, - position.size(), stream); - temp_memory->LazyAllocate(temp_storage_bytes); - cub::DeviceScan::ExclusiveSum(temp_memory->d_temp_storage, - temp_memory->temp_storage_bytes, in_itr, - out_itr, position.size(), stream); -} - -/*! \brief Count how many rows are assigned to left node. */ -__forceinline__ __device__ void CountLeft(int64_t* d_count, int val, - int left_nidx) { -#if __CUDACC_VER_MAJOR__ > 8 - int mask = __activemask(); - unsigned ballot = __ballot_sync(mask, val == left_nidx); - int leader = __ffs(mask) - 1; - if (threadIdx.x % 32 == leader) { - atomicAdd(reinterpret_cast(d_count), // NOLINT - static_cast(__popc(ballot))); // NOLINT - } -#else - unsigned ballot = __ballot(val == left_nidx); - if (threadIdx.x % 32 == 0) { - atomicAdd(reinterpret_cast(d_count), // NOLINT - static_cast(__popc(ballot))); // NOLINT - } -#endif -} - // Instances of this type are created while creating the histogram bins for the // entire dataset across multiple sparse page batches. This keeps track of the number // of rows to process from a batch and the position from which to process on each device. @@ -671,8 +590,7 @@ struct DeviceShard { ELLPackMatrix ellpack_matrix; - /*! \brief Range of rows for each node. */ - std::vector ridx_segments; + std::unique_ptr row_partitioner; DeviceHistogram hist; /*! \brief row_ptr form HistCutMatrix. */ @@ -684,9 +602,6 @@ struct DeviceShard { /*! \brief global index of histogram, which is stored in ELLPack format. */ common::Span gidx_buffer; - /*! \brief Row indices relative to this shard, necessary for sorting rows. */ - dh::DoubleBuffer ridx; - dh::DoubleBuffer position; /*! \brief Gradient pair for each row. */ common::Span gpair; @@ -696,8 +611,8 @@ struct DeviceShard { /*! \brief Sum gradient for each node. */ std::vector node_sum_gradients; common::Span node_sum_gradients_d; - dh::device_vector - left_counts; // Useful to keep a bunch of zeroed memory for sort position + /*! \brief On-device feature set, only actually used on one of the devices */ + dh::device_vector feature_set_d; /*! The row offset for this shard. 
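The removal above retires the dh::CubMemory/LazyAllocate scratch buffer; the replacement in row_partitioner.cu sizes its scan scratch with the usual two-phase CUB call and takes the memory from the caching allocator, so repeated per-node calls effectively reuse one block. A self-contained sketch of that pattern (the function name and the device_helpers include path are assumptions):

```cuda
#include <cub/cub.cuh>
#include "src/common/device_helpers.cuh"  // assumed path for dh::caching_device_vector

void ExclusiveSumExample(const int* d_in, int* d_out, int n, cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Phase 1: with a null temp pointer, CUB only reports the scratch size needed.
  cub::DeviceScan::ExclusiveSum(nullptr, temp_bytes, d_in, d_out, n, stream);
  // Scratch comes from XGBCachingDeviceAllocator, so same-sized requests are cheap.
  dh::caching_device_vector<char> temp_storage(temp_bytes);
  // Phase 2: run the scan for real.
  cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_bytes, d_in,
                                d_out, n, stream);
}
```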
*/ bst_uint row_begin_idx; bst_uint row_end_idx; @@ -783,24 +698,10 @@ struct DeviceShard { param.colsample_bylevel, param.colsample_bytree); dh::safe_cuda(cudaSetDevice(device_id)); this->interaction_constraints.Reset(); - - thrust::fill( - thrust::device_pointer_cast(position.Current()), - thrust::device_pointer_cast(position.Current() + position.Size()), 0); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); - if (left_counts.size() < 256) { - left_counts.resize(256); - } else { - dh::safe_cuda(cudaMemsetAsync(left_counts.data().get(), 0, - sizeof(int64_t) * left_counts.size())); - } - thrust::sequence( - thrust::device_pointer_cast(ridx.CurrentSpan().data()), - thrust::device_pointer_cast(ridx.CurrentSpan().data() + ridx.Size())); + row_partitioner.reset(new RowPartitioner(device_id, n_rows)); - std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0)); - ridx_segments.front() = Segment(0, ridx.Size()); dh::safe_cuda(cudaMemcpyAsync( gpair.data(), dh_gpair->ConstDevicePointer(device_id), gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost)); @@ -892,12 +793,11 @@ struct DeviceShard { void BuildHist(int nidx) { hist.AllocateHistogram(nidx); - auto segment = ridx_segments[nidx]; auto d_node_hist = hist.GetNodeHistogram(nidx); - auto d_ridx = ridx.Current(); + auto d_ridx = row_partitioner->GetRows(nidx); auto d_gpair = gpair.data(); - auto n_elements = segment.Size() * ellpack_matrix.row_stride; + auto n_elements = d_ridx.size() * ellpack_matrix.row_stride; const size_t smem_size = use_shared_memory_histograms @@ -911,8 +811,8 @@ struct DeviceShard { return; } SharedMemHistKernel<<>>( - ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, segment.begin, - n_elements, use_shared_memory_histograms); + ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements, + use_shared_memory_histograms); } void SubtractionTrick(int nidx_parent, int nidx_histogram, @@ -936,21 +836,13 @@ struct DeviceShard { } void UpdatePosition(int nidx, RegTree::Node split_node) { - CHECK(!split_node.IsLeaf()) <<"Node must not be leaf"; - Segment segment = ridx_segments[nidx]; - bst_uint* d_ridx = ridx.Current(); - int* d_position = position.Current(); - if (left_counts.size() <= nidx) { - left_counts.resize((nidx * 2) + 1); - } - int64_t* d_left_count = left_counts.data().get() + nidx; - auto d_matrix = this->ellpack_matrix; - // Launch 1 thread for each row - dh::LaunchN<1, 128>( - device_id, segment.Size(), [=] __device__(bst_uint idx) { - idx += segment.begin; - bst_uint ridx = d_ridx[idx]; - bst_float element = d_matrix.GetElement(ridx, split_node.SplitIndex()); + auto d_matrix = ellpack_matrix; + + row_partitioner->UpdatePosition( + nidx, split_node.LeftChild(), split_node.RightChild(), + [=] __device__(bst_uint ridx) { + bst_float element = + d_matrix.GetElement(ridx, split_node.SplitIndex()); // Missing value int new_position = 0; if (isnan(element)) { @@ -962,49 +854,8 @@ struct DeviceShard { new_position = split_node.RightChild(); } } - CountLeft(d_left_count, new_position, split_node.LeftChild()); - d_position[idx] = new_position; + return new_position; }); - - // Overlap device to host memory copy (left_count) with sort - auto& streams = this->GetStreams(2); - auto tmp_pinned = pinned_memory.GetSpan(1); - dh::safe_cuda(cudaMemcpyAsync(tmp_pinned.data(), d_left_count, sizeof(int64_t), - cudaMemcpyDeviceToHost, streams[0])); - - SortPositionAndCopy(segment, split_node.LeftChild(), split_node.RightChild(), d_left_count, - streams[1]); - - 
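With the BuildHist change above, the kernel launches over n_elements = d_ridx.size() * row_stride and recovers the (row, slot) pair purely from the flat index, which is why the old segment_begin argument disappears. A host restatement of that indexing contract, for reference only (names are illustrative):

```cuda
#include <cstddef>
#include <vector>

// Element idx of the flattened iteration space belongs to ELLPACK slot
// (idx % row_stride) of training row ridx[idx / row_stride]; because
// GetRows(nidx) already returns a contiguous span for the node, no extra
// offset is required.
void VisitNodeElements(const std::vector<unsigned>& ridx, std::size_t row_stride) {
  for (std::size_t idx = 0; idx < ridx.size() * row_stride; ++idx) {
    std::size_t row = ridx[idx / row_stride];  // training row index
    std::size_t slot = idx % row_stride;       // column slot within the ELLPACK row
    (void)row;
    (void)slot;
  }
}
```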
dh::safe_cuda(cudaStreamSynchronize(streams[0])); - int64_t left_count = tmp_pinned[0]; - CHECK_LE(left_count, segment.Size()); - CHECK_GE(left_count, 0); - ridx_segments[split_node.LeftChild()] = - Segment(segment.begin, segment.begin + left_count); - ridx_segments[split_node.RightChild()] = - Segment(segment.begin + left_count, segment.end); - } - - /*! \brief Sort row indices according to position. */ - void SortPositionAndCopy(const Segment& segment, int left_nidx, - int right_nidx, int64_t* d_left_count, - cudaStream_t stream) { - SortPosition( - &temp_memory, - common::Span(position.Current() + segment.begin, segment.Size()), - common::Span(position.other() + segment.begin, segment.Size()), - common::Span(ridx.Current() + segment.begin, segment.Size()), - common::Span(ridx.other() + segment.begin, segment.Size()), - left_nidx, right_nidx, d_left_count, stream); - // Copy back key/value - const auto d_position_current = position.Current() + segment.begin; - const auto d_position_other = position.other() + segment.begin; - const auto d_ridx_current = ridx.Current() + segment.begin; - const auto d_ridx_other = ridx.other() + segment.begin; - dh::LaunchN(device_id, segment.Size(), stream, [=] __device__(size_t idx) { - d_position_current[idx] = d_position_other[idx]; - d_ridx_current[idx] = d_ridx_other[idx]; - }); } // After tree update is finished, update the position of all training @@ -1016,30 +867,27 @@ struct DeviceShard { dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(), d_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice)); - auto d_position = position.Current(); - const auto d_ridx = ridx.Current(); - auto d_matrix = this->ellpack_matrix; - dh::LaunchN(device_id, position.Size(), [=] __device__(size_t idx) { - auto position = d_position[idx]; - auto node = d_nodes[position]; - bst_uint ridx = d_ridx[idx]; + auto d_matrix = ellpack_matrix; + row_partitioner->FinalisePosition( + [=] __device__(bst_uint ridx, int position) { + auto node = d_nodes[position]; - while (!node.IsLeaf()) { - bst_float element = d_matrix.GetElement(ridx, node.SplitIndex()); - // Missing value - if (isnan(element)) { - position = node.DefaultChild(); - } else { - if (element <= node.SplitCond()) { - position = node.LeftChild(); - } else { - position = node.RightChild(); + while (!node.IsLeaf()) { + bst_float element = d_matrix.GetElement(ridx, node.SplitIndex()); + // Missing value + if (isnan(element)) { + position = node.DefaultChild(); + } else { + if (element <= node.SplitCond()) { + position = node.LeftChild(); + } else { + position = node.RightChild(); + } + } + node = d_nodes[position]; } - } - node = d_nodes[position]; - } - d_position[idx] = position; - }); + return position; + }); } void UpdatePredictionCache(bst_float* out_preds_d) { @@ -1057,8 +905,8 @@ struct DeviceShard { cudaMemcpyAsync(node_sum_gradients_d.data(), node_sum_gradients.data(), sizeof(GradientPair) * node_sum_gradients.size(), cudaMemcpyHostToDevice)); - auto d_position = position.Current(); - auto d_ridx = ridx.Current(); + auto d_position = row_partitioner->GetPosition(); + auto d_ridx = row_partitioner->GetRows(); auto d_node_sum_gradients = node_sum_gradients_d.data(); auto d_prediction_cache = prediction_cache.data(); @@ -1096,13 +944,15 @@ struct DeviceShard { auto build_hist_nidx = nidx_left; auto subtraction_trick_nidx = nidx_right; - auto left_node_rows = ridx_segments[nidx_left].Size(); - auto right_node_rows = ridx_segments[nidx_right].Size(); + auto left_node_rows = 
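UpdatePredictionCache now pulls d_position and d_ridx straight from the partitioner rather than from DeviceShard's own buffers. Schematically the update is a gather-and-accumulate over the finalised positions; below is a simplified sketch that assumes per-leaf weights have already been materialised in d_leaf_weight, whereas the real code computes them on the fly from node_sum_gradients and the training parameters:

```cuda
__global__ void AddLeafWeightToCache(const int* d_position, const unsigned* d_ridx,
                                     const float* d_leaf_weight, size_t n_rows,
                                     float* d_prediction_cache) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n_rows) return;
  // Row d_ridx[idx] ended up in leaf d_position[idx]; accumulate that leaf's weight.
  d_prediction_cache[d_ridx[idx]] += d_leaf_weight[d_position[idx]];
}
```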
row_partitioner->GetRows(nidx_left).size(); + auto right_node_rows = row_partitioner->GetRows(nidx_right).size(); // Decide whether to build the left histogram or right histogram // Find the largest number of training instances on any given Shard // Assume this will be the bottleneck and avoid building this node if // possible - std::vector max_reduce = {left_node_rows, right_node_rows}; + std::vector max_reduce; + max_reduce.push_back(left_node_rows); + max_reduce.push_back(right_node_rows); reducer->HostMaxAllReduce(&max_reduce); bool fewer_right = max_reduce[1] < max_reduce[0]; if (fewer_right) { @@ -1199,6 +1049,7 @@ struct DeviceShard { void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, RegTree* p_tree, dh::AllReducer* reducer) { auto& tree = *p_tree; + monitor.StartCuda("Reset"); this->Reset(gpair_all, p_fmat->Info().num_col_); monitor.StopCuda("Reset"); @@ -1206,7 +1057,6 @@ struct DeviceShard { monitor.StartCuda("InitRoot"); this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_); monitor.StopCuda("InitRoot"); - auto timestamp = qexpand->size(); auto num_leaves = 1; @@ -1269,8 +1119,6 @@ inline void DeviceShard::InitCompressedData( ba.Allocate(device_id, &gpair, n_rows, - &ridx, n_rows, - &position, n_rows, &prediction_cache, n_rows, &node_sum_gradients_d, max_nodes, &feature_segments, hmat.row_ptr.size(), @@ -1284,7 +1132,6 @@ inline void DeviceShard::InitCompressedData( dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints); node_sum_gradients.resize(max_nodes); - ridx_segments.resize(max_nodes); // allocate compressed bin data int num_symbols = n_bins + 1; @@ -1303,7 +1150,6 @@ inline void DeviceShard::InitCompressedData( gidx_fvalue_map, row_stride, common::CompressedIterator(gidx_buffer.data(), num_symbols), is_dense, null_gidx_value); - // check if we can use shared memory for building histograms // (assuming atleast we need 2 CTAs per SM to maintain decent latency // hiding) diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index c9cb1e61a..6f425a3c2 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -97,7 +97,8 @@ TEST(bulkAllocator, Test) { } // Test thread safe max reduction -TEST(AllReducer, HostMaxAllReduce) { +#if defined(XGBOOST_USE_NCCL) +TEST(AllReducer, MGPU_HostMaxAllReduce) { dh::AllReducer reducer; size_t num_threads = 50; std::vector> thread_data(num_threads); @@ -112,3 +113,4 @@ TEST(AllReducer, HostMaxAllReduce) { ASSERT_EQ(data.front(), num_threads - 1); } } +#endif diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu new file mode 100644 index 000000000..1106d8486 --- /dev/null +++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu @@ -0,0 +1,125 @@ +#include +#include + +#include +#include +#include "../../../../src/tree/gpu_hist/row_partitioner.cuh" +#include "../../helpers.h" + +namespace xgboost { +namespace tree { + +void TestSortPosition(const std::vector& position_in, int left_idx, + int right_idx) { + std::vector left_count = { + std::count(position_in.begin(), position_in.end(), left_idx)}; + thrust::device_vector d_left_count = left_count; + thrust::device_vector position = position_in; + thrust::device_vector position_out(position.size()); + + thrust::device_vector ridx(position.size()); + thrust::sequence(ridx.begin(), ridx.end()); + thrust::device_vector ridx_out(ridx.size()); + RowPartitioner rp(0,10); + rp.SortPosition( + 
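The block above decides which child to build from scratch: all shards agree (via HostMaxAllReduce) on the child with fewer rows globally, build its histogram, and derive the sibling by the subtraction trick, which works because gradient sums are additive over rows. The arithmetic in host form, with GradientSumT simplified to a plain double per bin (a sketch, not the device implementation):

```cuda
#include <vector>

// other = parent - built, element-wise over the histogram bins.
void SubtractionTrickHost(const std::vector<double>& parent,
                          const std::vector<double>& built,
                          std::vector<double>* other) {
  other->resize(parent.size());
  for (size_t i = 0; i < parent.size(); ++i) {
    (*other)[i] = parent[i] - built[i];
  }
}
```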
common::Span(position.data().get(), position.size()), + common::Span(position_out.data().get(), position_out.size()), + common::Span(ridx.data().get(), ridx.size()), + common::Span(ridx_out.data().get(), ridx_out.size()), left_idx, + right_idx, d_left_count.data().get(), nullptr); + thrust::host_vector position_result = position_out; + thrust::host_vector ridx_result = ridx_out; + + // Check position is sorted + EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end())); + // Check row indices are sorted inside left and right segment + EXPECT_TRUE( + std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0])); + EXPECT_TRUE( + std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end())); + + // Check key value pairs are the same + for (auto i = 0ull; i < ridx_result.size(); i++) { + EXPECT_EQ(position_result[i], position_in[ridx_result[i]]); + } +} +TEST(GpuHist, SortPosition) { + TestSortPosition({1, 2, 1, 2, 1}, 1, 2); + TestSortPosition({1, 1, 1, 1}, 1, 2); + TestSortPosition({2, 2, 2, 2}, 1, 2); + TestSortPosition({1, 2, 1, 2, 3}, 1, 2); +} + +void TestUpdatePosition() { + const int kNumRows = 10; + RowPartitioner rp(0, kNumRows); + auto rows = rp.GetRowsHost(0); + EXPECT_EQ(rows.size(), kNumRows); + for (auto i = 0ull; i < kNumRows; i++) { + EXPECT_EQ(rows[i], i); + } + // Send the first five training instances to the right node + // and the second 5 to the left node + rp.UpdatePosition(0, 1, 2, + [=] __device__(RowPartitioner::RowIndexT ridx) { + if (ridx > 4) { + return 1; + } + else { + return 2; + } + }); + rows = rp.GetRowsHost(1); + for (auto r : rows) { + EXPECT_GT(r, 4); + } + rows = rp.GetRowsHost(2); + for (auto r : rows) { + EXPECT_LT(r, 5); + } + + // Split the left node again + rp.UpdatePosition(1, 3, 4, [=]__device__(RowPartitioner::RowIndexT ridx) + { + if (ridx < 7) { + return 3 + ; + } + return 4; + }); + EXPECT_EQ(rp.GetRows(3).size(), 2); + EXPECT_EQ(rp.GetRows(4).size(), 3); + // Check position is as expected + EXPECT_EQ(rp.GetPositionHost(), std::vector({3,3,4,4,4,2,2,2,2,2})); +} + +TEST(RowPartitioner, Basic) { TestUpdatePosition(); } + +void TestFinalise() { + const int kNumRows = 10; + RowPartitioner rp(0, kNumRows); + rp.FinalisePosition([=]__device__(RowPartitioner::RowIndexT ridx, int position) + { + return 7; + }); + auto position = rp.GetPositionHost(); + for(auto p:position) + { + EXPECT_EQ(p, 7); + } +} +TEST(RowPartitioner, Finalise) { TestFinalise(); } + +void TestIncorrectRow() { + RowPartitioner rp(0, 1); + rp.UpdatePosition(0, 1, 2, [=]__device__ (RowPartitioner::RowIndexT ridx) + { + return 4; // This is not the left branch or the right branch + }); +} + +TEST(RowPartitioner, IncorrectRow) { + ASSERT_DEATH({ TestIncorrectRow(); },".*"); +} +} // namespace tree +} // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 0031ac6d2..3aebc1d0f 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -206,16 +206,10 @@ void TestBuildHist(bool use_shared_memory_histograms) { dh::safe_cuda(cudaMemcpy(h_gidx_buffer.data(), d_gidx_buffer_ptr, sizeof(common::CompressedByteT) * shard.gidx_buffer.size(), cudaMemcpyDeviceToHost)); - auto gidx = common::CompressedIterator(h_gidx_buffer.data(), - num_symbols); - shard.ridx_segments.resize(1); - shard.ridx_segments[0] = Segment(0, kNRows); + shard.row_partitioner.reset(new RowPartitioner(0, kNRows)); shard.hist.AllocateHistogram(0); dh::CopyVectorToDeviceSpan(shard.gpair, h_gpair); - 
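The three EXPECT_TRUE checks in TestSortPosition above verify exactly the stable-partition property of the device routine. An equivalent host reference, which could be compared element-wise against ridx_out if a stricter check were ever wanted (the helper name is illustrative):

```cuda
#include <algorithm>
#include <numeric>
#include <vector>

// Rows whose position equals `left` come first, everything else after, with the
// original relative order preserved inside both groups -- the same ordering the
// scan-based device SortPosition produces.
std::vector<unsigned> ReferenceSortPosition(const std::vector<int>& position, int left) {
  std::vector<unsigned> ridx(position.size());
  std::iota(ridx.begin(), ridx.end(), 0u);
  std::stable_partition(ridx.begin(), ridx.end(),
                        [&](unsigned r) { return position[r] == left; });
  return ridx;
}
```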
thrust::sequence( - thrust::device_pointer_cast(shard.ridx.Current()), - thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size())); shard.use_shared_memory_histograms = use_shared_memory_histograms; shard.BuildHist(0); @@ -358,138 +352,6 @@ TEST(GpuHist, EvaluateSplits) { ASSERT_NEAR(res[1].fvalue, 0.26, xgboost::kRtEps); } -TEST(GpuHist, ApplySplit) { - int constexpr kNId = 0; - int constexpr kNRows = 16; - int constexpr kNCols = 8; - - TrainParam param; - std::vector> args = {}; - param.InitAllowUnknown(args); - // Initialize shard - for (size_t i = 0; i < kNCols; ++i) { - param.monotone_constraints.emplace_back(0); - } - std::unique_ptr> shard{ - new DeviceShard(0, 0, 0, kNRows, param, kNCols, - kNCols)}; - - shard->ridx_segments.resize(3); // 3 nodes. - shard->node_sum_gradients.resize(3); - - shard->ridx_segments[0] = Segment(0, kNRows); - shard->ba.Allocate(0, &(shard->ridx), kNRows, - &(shard->position), kNRows); - shard->ellpack_matrix.row_stride = kNCols; - thrust::sequence( - thrust::device_pointer_cast(shard->ridx.Current()), - thrust::device_pointer_cast(shard->ridx.Current() + shard->ridx.Size())); - RegTree tree; - - DeviceSplitCandidate candidate; - candidate.Update(2, kLeftDir, - 0.59, 4, // fvalue has to be equal to one of the cut field - GradientPair(8.2, 2.8), GradientPair(6.3, 3.6), - GPUTrainingParam(param)); - ExpandEntry candidate_entry {0, 0, candidate, 0}; - candidate_entry.nid = kNId; - - // Used to get bin_id in update position. - common::HistCutMatrix cmat = GetHostCutMatrix(); - - MetaInfo info; - info.num_row_ = kNRows; - info.num_col_ = kNCols; - info.num_nonzero_ = kNRows * kNCols; // Dense - - // Initialize gidx - int n_bins = 24; - int row_stride = kNCols; - int num_symbols = n_bins + 1; - size_t compressed_size_bytes = - common::CompressedBufferWriter::CalculateBufferSize(row_stride * kNRows, - num_symbols); - shard->ba.Allocate(0, &(shard->gidx_buffer), compressed_size_bytes, - &(shard->feature_segments), cmat.row_ptr.size(), - &(shard->min_fvalue), cmat.min_val.size(), - &(shard->gidx_fvalue_map), 24); - dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr); - dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut); - shard->ellpack_matrix.feature_segments = shard->feature_segments; - shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map; - dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val); - shard->ellpack_matrix.min_fvalue = shard->min_fvalue; - shard->ellpack_matrix.is_dense = true; - - common::CompressedBufferWriter wr(num_symbols); - // gidx 14 should go right, 12 goes left - std::vector h_gidx (kNRows * row_stride, 14); - h_gidx[4] = 12; - h_gidx[12] = 12; - std::vector h_gidx_compressed (compressed_size_bytes); - - wr.Write(h_gidx_compressed.data(), h_gidx.begin(), h_gidx.end()); - dh::CopyVectorToDeviceSpan(shard->gidx_buffer, h_gidx_compressed); - - shard->ellpack_matrix.gidx_iter = common::CompressedIterator( - shard->gidx_buffer.data(), num_symbols); - - shard->ApplySplit(candidate_entry, &tree); - shard->UpdatePosition(candidate_entry.nid, tree[candidate_entry.nid]); - - ASSERT_FALSE(tree[kNId].IsLeaf()); - - int left_nidx = tree[kNId].LeftChild(); - int right_nidx = tree[kNId].RightChild(); - - ASSERT_EQ(shard->ridx_segments[left_nidx].begin, 0); - ASSERT_EQ(shard->ridx_segments[left_nidx].end, 2); - ASSERT_EQ(shard->ridx_segments[right_nidx].begin, 2); - ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16); -} - -void TestSortPosition(const std::vector& position_in, int left_idx, - int 
right_idx) { - std::vector left_count = { - std::count(position_in.begin(), position_in.end(), left_idx)}; - thrust::device_vector d_left_count = left_count; - thrust::device_vector position = position_in; - thrust::device_vector position_out(position.size()); - - thrust::device_vector ridx(position.size()); - thrust::sequence(ridx.begin(), ridx.end()); - thrust::device_vector ridx_out(ridx.size()); - dh::CubMemory tmp; - SortPosition( - &tmp, common::Span(position.data().get(), position.size()), - common::Span(position_out.data().get(), position_out.size()), - common::Span(ridx.data().get(), ridx.size()), - common::Span(ridx_out.data().get(), ridx_out.size()), left_idx, - right_idx, d_left_count.data().get(), nullptr); - thrust::host_vector position_result = position_out; - thrust::host_vector ridx_result = ridx_out; - - // Check position is sorted - EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end())); - // Check row indices are sorted inside left and right segment - EXPECT_TRUE( - std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0])); - EXPECT_TRUE( - std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end())); - - // Check key value pairs are the same - for (auto i = 0ull; i < ridx_result.size(); i++) { - EXPECT_EQ(position_result[i], position_in[ridx_result[i]]); - } -} - -TEST(GpuHist, SortPosition) { - TestSortPosition({1, 2, 1, 2, 1}, 1, 2); - TestSortPosition({1, 1, 1, 1}, 1, 2); - TestSortPosition({2, 2, 2, 2}, 1, 2); - TestSortPosition({1, 2, 1, 2, 3}, 1, 2); -} - void TestHistogramIndexImpl(int n_gpus) { // Test if the compressed histogram index matches when using a sparse // dmatrix with and without using external memory