diff --git a/plugin/updater_gpu/src/device_helpers.cuh b/plugin/updater_gpu/src/device_helpers.cuh index 2950a8f51..705cbcf43 100644 --- a/plugin/updater_gpu/src/device_helpers.cuh +++ b/plugin/updater_gpu/src/device_helpers.cuh @@ -132,9 +132,12 @@ struct Timer { void reset() { start = ClockT::now(); } int64_t elapsed() const { return (ClockT::now() - start).count(); } + double elapsedSeconds() const { + return elapsed() * ((double)ClockT::period::num / ClockT::period::den); + } void printElapsed(std::string label) { // synchronize_n_devices(n_devices, dList); - printf("%s:\t %lld\n", label.c_str(), elapsed()); + printf("%s:\t %fs\n", label.c_str(), elapsedSeconds()); reset(); } }; @@ -650,116 +653,124 @@ struct BernoulliRng { // Load balancing search -template -class LauncherItr { - public: - int idx; - func_t f; - XGBOOST_DEVICE LauncherItr() : idx(0) {} - XGBOOST_DEVICE LauncherItr(int idx, func_t f) : idx(idx), f(f) {} - XGBOOST_DEVICE LauncherItr &operator=(int output) { - f(idx, output); - return *this; - } -}; +template +void FindMergePartitions(int device_idx, coordinate_t *d_tile_coordinates, int num_tiles, + int tile_size, segments_t segments, offset_t num_rows, + offset_t num_elements) { + dh::launch_n(device_idx, num_tiles + 1, [=] __device__(int idx) { + offset_t diagonal = idx * tile_size; + coordinate_t tile_coordinate; + cub::CountingInputIterator nonzero_indices(0); -template + // Search the merge path + // Cast to signed integer as this function can have negatives + cub::MergePathSearch(static_cast(diagonal), segments + 1, + nonzero_indices, static_cast(num_rows), + static_cast(num_elements), tile_coordinate); + + // Output starting offset + d_tile_coordinates[idx] = tile_coordinate; + }); +} + +template +__global__ void LbsKernel(coordinate_t *d_coordinates, + segments_iter segment_end_offsets, func_t f, + offset_t num_segments) { + int tile = blockIdx.x; + coordinate_t tile_start_coord = d_coordinates[tile]; + coordinate_t tile_end_coord = d_coordinates[tile + 1]; + int64_t tile_num_rows = tile_end_coord.x - tile_start_coord.x; + int64_t tile_num_elements = tile_end_coord.y - tile_start_coord.y; + + cub::CountingInputIterator tile_element_indices(tile_start_coord.y); + coordinate_t thread_start_coord; + + typedef typename std::iterator_traits::value_type segment_t; + __shared__ struct { + segment_t tile_segment_end_offsets[TILE_SIZE + 1]; + segment_t output_segment[TILE_SIZE]; + } temp_storage; + + for (auto item : dh::block_stride_range(int(0), int(tile_num_rows + 1))) { + temp_storage.tile_segment_end_offsets[item] = + segment_end_offsets[min(tile_start_coord.x + item, num_segments - 1)]; + } + __syncthreads(); + + int64_t diag = threadIdx.x * ITEMS_PER_THREAD; + + // Cast to signed integer as this function can have negatives + cub::MergePathSearch(diag, // Diagonal + temp_storage.tile_segment_end_offsets, // List A + tile_element_indices, // List B + tile_num_rows, tile_num_elements, thread_start_coord); + + coordinate_t thread_current_coord = thread_start_coord; +#pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { + if (tile_element_indices[thread_current_coord.y] < + temp_storage.tile_segment_end_offsets[thread_current_coord.x]) { + temp_storage.output_segment[thread_current_coord.y] = + thread_current_coord.x + tile_start_coord.x; + ++thread_current_coord.y; + } else { + ++thread_current_coord.x; + } + } + __syncthreads(); + + for (auto item : dh::block_stride_range(int(0), int(tile_num_elements))) { + f(tile_start_coord.y + item, temp_storage.output_segment[item]); + } +} /** - * \class DiscardLambdaItr + * \fn template + * void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count, + * segments_iter segments, offset_t num_segments, func_t f) * - * \brief Thrust compatible iterator type - discards algorithm output and - * launches device lambda with the index of the output and the algorithm output as arguments. - * - * \author Rory - * \date 7/9/2017 - */ - -class DiscardLambdaItr { - public: - // Required iterator traits - typedef DiscardLambdaItr self_type; ///< My own type - typedef ptrdiff_t - difference_type; ///< Type to express the result of subtracting - /// one iterator from another - typedef LauncherItr - value_type; ///< The type of the element the iterator can point to - typedef value_type *pointer; ///< The type of a pointer to an element the - /// iterator can point to - typedef value_type reference; ///< The type of a reference to an element the - /// iterator can point to - typedef typename thrust::detail::iterator_facade_category< - thrust::any_system_tag, thrust::random_access_traversal_tag, value_type, - reference>::type iterator_category; ///< The iterator category - private: - difference_type offset; - func_t f; - - public: - XGBOOST_DEVICE DiscardLambdaItr(func_t f) : offset(0), f(f) {} - XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, func_t f) - : offset(offset), f(f) {} - - XGBOOST_DEVICE self_type operator+(const int &b) const { - return DiscardLambdaItr(offset + b, f); - } - XGBOOST_DEVICE self_type operator++() { - offset++; - return *this; - } - XGBOOST_DEVICE self_type operator++(int) { - self_type retval = *this; - offset++; - return retval; - } - XGBOOST_DEVICE self_type &operator+=(const int &b) { - offset += b; - return *this; - } - XGBOOST_DEVICE reference operator*() const { - return LauncherItr(offset, f); - } - - XGBOOST_DEVICE reference operator[](int idx) { - self_type offset = (*this) + idx; - return *offset; - } -}; - -/** - * \fn template void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count, thrust::device_ptr segments, int num_segments, func_t f) - * - * \brief Load balancing search function. Reads a CSR type matrix description and allows a function - * to be executed on each element. Search 'modern GPU load balancing search for more - * information'. + * \brief Load balancing search function. Reads a CSR type matrix description + * and allows a function to be executed on each element. Search 'modern GPU load + * balancing search' for more information. * * \author Rory * \date 7/9/2017 * - * \tparam segments_t Type of the segments t. + * \tparam func_t Type of the function t. + * \tparam segments_iter Type of the segments iterator. + * \tparam offset_t Type of the offset. + * \tparam segments_t Type of the segments t. * \param device_idx Zero-based index of the device. * \param [in,out] temp_memory Temporary memory allocator. * \param count Number of elements. - * \param segments Device pointed to segments. + * \param segments Device pointer to segments. * \param num_segments Number of segments. * \param f Lambda to be executed on matrix elements. */ -template -void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count, - thrust::device_ptr segments, int num_segments, - func_t f) { - safe_cuda(cudaSetDevice(device_idx)); - auto counting = thrust::make_counting_iterator(0); +template +void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count, + segments_iter segments, offset_t num_segments, func_t f) { + typedef typename cub::CubVector::Type coordinate_t; + dh::safe_cuda(cudaSetDevice(device_idx)); + const int BLOCK_THREADS = 256; + const int ITEMS_PER_THREAD = 1; + const int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD; + int num_tiles = dh::div_round_up(count + num_segments, BLOCK_THREADS); - auto f_wrapper = [=] __device__(int idx, int upper_bound) { - f(idx, upper_bound - 1); - }; + temp_memory->LazyAllocate(sizeof(coordinate_t) * (num_tiles + 1)); + coordinate_t *tmp_tile_coordinates = + reinterpret_cast(temp_memory->d_temp_storage); - DiscardLambdaItr itr(f_wrapper); + FindMergePartitions(device_idx, tmp_tile_coordinates, num_tiles, BLOCK_THREADS, segments, + num_segments, count); - thrust::upper_bound(thrust::cuda::par(*temp_memory), segments, - segments + num_segments, counting, counting + count, itr); + LbsKernel + <<>>(tmp_tile_coordinates, segments + 1, f, + num_segments); } } // namespace dh diff --git a/plugin/updater_gpu/test/cpp/test_device_helpers.cu b/plugin/updater_gpu/test/cpp/test_device_helpers.cu index 910f668bc..6bd520b71 100644 --- a/plugin/updater_gpu/test/cpp/test_device_helpers.cu +++ b/plugin/updater_gpu/test/cpp/test_device_helpers.cu @@ -7,22 +7,72 @@ #include "../../src/device_helpers.cuh" #include "gtest/gtest.h" -static const std::vector gidx = {0, 2, 5, 1, 3, 6, 0, 2, 0, 7}; -static const std::vector row_ptr = {0, 3, 6, 8, 10}; -static const std::vector lbs_seg_output = {0, 0, 0, 1, 1, 1, 2, 2, 3, 3}; +void CreateTestData(xgboost::bst_uint num_rows, int max_row_size, + thrust::host_vector *row_ptr, + thrust::host_vector *rows) { + row_ptr->resize(num_rows + 1); + int sum = 0; + for (int i = 0; i <= num_rows; i++) { + (*row_ptr)[i] = sum; + sum += rand() % max_row_size; // NOLINT -thrust::device_vector test_lbs() { - thrust::device_vector device_gidx = gidx; - thrust::device_vector device_row_ptr = row_ptr; - thrust::device_vector device_output_row(gidx.size(), 0); - auto d_output_row = device_output_row.data(); - dh::CubMemory temp_memory; - dh::TransformLbs( - 0, &temp_memory, gidx.size(), device_row_ptr.data(), row_ptr.size() - 1, - [=] __device__(int idx, int ridx) { d_output_row[idx] = ridx; }); - - dh::safe_cuda(cudaDeviceSynchronize()); - return device_output_row; + if (i < num_rows) { + for (int j = (*row_ptr)[i]; j < sum; j++) { + (*rows).push_back(i); + } + } + } } -TEST(lbs, Test) { ASSERT_TRUE(test_lbs() == lbs_seg_output); } +void SpeedTest() { + int num_rows = 1000000; + int max_row_size = 100; + dh::CubMemory temp_memory; + thrust::host_vector h_row_ptr; + thrust::host_vector h_rows; + CreateTestData(num_rows, max_row_size, &h_row_ptr, &h_rows); + thrust::device_vector row_ptr = h_row_ptr; + thrust::device_vector output_row(h_rows.size()); + auto d_output_row = output_row.data(); + + dh::Timer t; + dh::TransformLbs( + 0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1, + [=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; }); + + dh::safe_cuda(cudaDeviceSynchronize()); + double time = t.elapsedSeconds(); + const int mb_size = 1048576; + size_t size = (sizeof(int) * h_rows.size()) / mb_size; + printf("size: %llumb, time: %fs, bandwidth: %fmb/s\n", size, time, + size / time); +} + +void TestLbs() { + srand(17); + dh::CubMemory temp_memory; + + std::vector test_rows = {4, 100, 1000}; + std::vector test_max_row_sizes = {4, 100, 1300}; + + for (auto num_rows : test_rows) { + for (auto max_row_size : test_max_row_sizes) { + thrust::host_vector h_row_ptr; + thrust::host_vector h_rows; + CreateTestData(num_rows, max_row_size, &h_row_ptr, &h_rows); + thrust::device_vector row_ptr = h_row_ptr; + thrust::device_vector output_row(h_rows.size()); + auto d_output_row = output_row.data(); + + dh::TransformLbs(0, &temp_memory, h_rows.size(), dh::raw(row_ptr), + row_ptr.size() - 1, + [=] __device__(size_t idx, size_t ridx) { + d_output_row[idx] = ridx; + }); + + dh::safe_cuda(cudaDeviceSynchronize()); + ASSERT_TRUE(h_rows == output_row); + } + } +} +TEST(cub_lbs, Test) { TestLbs(); }