[GPU-Plugin] Improved load balancing search (#2521)
This commit is contained in:
parent
33ee7d1615
commit
c85bf9859e
@ -132,9 +132,12 @@ struct Timer {
|
||||
|
||||
void reset() { start = ClockT::now(); }
|
||||
int64_t elapsed() const { return (ClockT::now() - start).count(); }
|
||||
double elapsedSeconds() const {
|
||||
return elapsed() * ((double)ClockT::period::num / ClockT::period::den);
|
||||
}
|
||||
void printElapsed(std::string label) {
|
||||
// synchronize_n_devices(n_devices, dList);
|
||||
printf("%s:\t %lld\n", label.c_str(), elapsed());
|
||||
printf("%s:\t %fs\n", label.c_str(), elapsedSeconds());
|
||||
reset();
|
||||
}
|
||||
};
|
||||
@ -650,116 +653,124 @@ struct BernoulliRng {
|
||||
|
||||
// Load balancing search
|
||||
|
||||
template <typename func_t>
|
||||
class LauncherItr {
|
||||
public:
|
||||
int idx;
|
||||
func_t f;
|
||||
XGBOOST_DEVICE LauncherItr() : idx(0) {}
|
||||
XGBOOST_DEVICE LauncherItr(int idx, func_t f) : idx(idx), f(f) {}
|
||||
XGBOOST_DEVICE LauncherItr &operator=(int output) {
|
||||
f(idx, output);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
template <typename coordinate_t, typename segments_t, typename offset_t>
|
||||
void FindMergePartitions(int device_idx, coordinate_t *d_tile_coordinates, int num_tiles,
|
||||
int tile_size, segments_t segments, offset_t num_rows,
|
||||
offset_t num_elements) {
|
||||
dh::launch_n(device_idx, num_tiles + 1, [=] __device__(int idx) {
|
||||
offset_t diagonal = idx * tile_size;
|
||||
coordinate_t tile_coordinate;
|
||||
cub::CountingInputIterator<offset_t> nonzero_indices(0);
|
||||
|
||||
template <typename func_t>
|
||||
// Search the merge path
|
||||
// Cast to signed integer as this function can have negatives
|
||||
cub::MergePathSearch(static_cast<int64_t>(diagonal), segments + 1,
|
||||
nonzero_indices, static_cast<int64_t>(num_rows),
|
||||
static_cast<int64_t>(num_elements), tile_coordinate);
|
||||
|
||||
// Output starting offset
|
||||
d_tile_coordinates[idx] = tile_coordinate;
|
||||
});
|
||||
}
|
||||
|
||||
template <int TILE_SIZE, int ITEMS_PER_THREAD, int BLOCK_THREADS,
|
||||
typename offset_t, typename coordinate_t, typename func_t,
|
||||
typename segments_iter>
|
||||
__global__ void LbsKernel(coordinate_t *d_coordinates,
|
||||
segments_iter segment_end_offsets, func_t f,
|
||||
offset_t num_segments) {
|
||||
int tile = blockIdx.x;
|
||||
coordinate_t tile_start_coord = d_coordinates[tile];
|
||||
coordinate_t tile_end_coord = d_coordinates[tile + 1];
|
||||
int64_t tile_num_rows = tile_end_coord.x - tile_start_coord.x;
|
||||
int64_t tile_num_elements = tile_end_coord.y - tile_start_coord.y;
|
||||
|
||||
cub::CountingInputIterator<offset_t> tile_element_indices(tile_start_coord.y);
|
||||
coordinate_t thread_start_coord;
|
||||
|
||||
typedef typename std::iterator_traits<segments_iter>::value_type segment_t;
|
||||
__shared__ struct {
|
||||
segment_t tile_segment_end_offsets[TILE_SIZE + 1];
|
||||
segment_t output_segment[TILE_SIZE];
|
||||
} temp_storage;
|
||||
|
||||
for (auto item : dh::block_stride_range(int(0), int(tile_num_rows + 1))) {
|
||||
temp_storage.tile_segment_end_offsets[item] =
|
||||
segment_end_offsets[min(tile_start_coord.x + item, num_segments - 1)];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int64_t diag = threadIdx.x * ITEMS_PER_THREAD;
|
||||
|
||||
// Cast to signed integer as this function can have negatives
|
||||
cub::MergePathSearch(diag, // Diagonal
|
||||
temp_storage.tile_segment_end_offsets, // List A
|
||||
tile_element_indices, // List B
|
||||
tile_num_rows, tile_num_elements, thread_start_coord);
|
||||
|
||||
coordinate_t thread_current_coord = thread_start_coord;
|
||||
#pragma unroll
|
||||
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) {
|
||||
if (tile_element_indices[thread_current_coord.y] <
|
||||
temp_storage.tile_segment_end_offsets[thread_current_coord.x]) {
|
||||
temp_storage.output_segment[thread_current_coord.y] =
|
||||
thread_current_coord.x + tile_start_coord.x;
|
||||
++thread_current_coord.y;
|
||||
} else {
|
||||
++thread_current_coord.x;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto item : dh::block_stride_range(int(0), int(tile_num_elements))) {
|
||||
f(tile_start_coord.y + item, temp_storage.output_segment[item]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \class DiscardLambdaItr
|
||||
* \fn template <typename func_t, typename segments_iter, typename offset_t>
|
||||
* void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
|
||||
* segments_iter segments, offset_t num_segments, func_t f)
|
||||
*
|
||||
* \brief Thrust compatible iterator type - discards algorithm output and
|
||||
* launches device lambda with the index of the output and the algorithm output as arguments.
|
||||
*
|
||||
* \author Rory
|
||||
* \date 7/9/2017
|
||||
*/
|
||||
|
||||
class DiscardLambdaItr {
|
||||
public:
|
||||
// Required iterator traits
|
||||
typedef DiscardLambdaItr self_type; ///< My own type
|
||||
typedef ptrdiff_t
|
||||
difference_type; ///< Type to express the result of subtracting
|
||||
/// one iterator from another
|
||||
typedef LauncherItr<func_t>
|
||||
value_type; ///< The type of the element the iterator can point to
|
||||
typedef value_type *pointer; ///< The type of a pointer to an element the
|
||||
/// iterator can point to
|
||||
typedef value_type reference; ///< The type of a reference to an element the
|
||||
/// iterator can point to
|
||||
typedef typename thrust::detail::iterator_facade_category<
|
||||
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
|
||||
reference>::type iterator_category; ///< The iterator category
|
||||
private:
|
||||
difference_type offset;
|
||||
func_t f;
|
||||
|
||||
public:
|
||||
XGBOOST_DEVICE DiscardLambdaItr(func_t f) : offset(0), f(f) {}
|
||||
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, func_t f)
|
||||
: offset(offset), f(f) {}
|
||||
|
||||
XGBOOST_DEVICE self_type operator+(const int &b) const {
|
||||
return DiscardLambdaItr(offset + b, f);
|
||||
}
|
||||
XGBOOST_DEVICE self_type operator++() {
|
||||
offset++;
|
||||
return *this;
|
||||
}
|
||||
XGBOOST_DEVICE self_type operator++(int) {
|
||||
self_type retval = *this;
|
||||
offset++;
|
||||
return retval;
|
||||
}
|
||||
XGBOOST_DEVICE self_type &operator+=(const int &b) {
|
||||
offset += b;
|
||||
return *this;
|
||||
}
|
||||
XGBOOST_DEVICE reference operator*() const {
|
||||
return LauncherItr<func_t>(offset, f);
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE reference operator[](int idx) {
|
||||
self_type offset = (*this) + idx;
|
||||
return *offset;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \fn template <typename func_t, typename segments_t> void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count, thrust::device_ptr<segments_t> segments, int num_segments, func_t f)
|
||||
*
|
||||
* \brief Load balancing search function. Reads a CSR type matrix description and allows a function
|
||||
* to be executed on each element. Search 'modern GPU load balancing search for more
|
||||
* information'.
|
||||
* \brief Load balancing search function. Reads a CSR type matrix description
|
||||
* and allows a function to be executed on each element. Search 'modern GPU load
|
||||
* balancing search' for more information.
|
||||
*
|
||||
* \author Rory
|
||||
* \date 7/9/2017
|
||||
*
|
||||
* \tparam func_t Type of the function t.
|
||||
* \tparam segments_iter Type of the segments iterator.
|
||||
* \tparam offset_t Type of the offset.
|
||||
* \tparam segments_t Type of the segments t.
|
||||
* \param device_idx Zero-based index of the device.
|
||||
* \param [in,out] temp_memory Temporary memory allocator.
|
||||
* \param count Number of elements.
|
||||
* \param segments Device pointed to segments.
|
||||
* \param segments Device pointer to segments.
|
||||
* \param num_segments Number of segments.
|
||||
* \param f Lambda to be executed on matrix elements.
|
||||
*/
|
||||
|
||||
template <typename func_t, typename segments_t>
|
||||
void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count,
|
||||
thrust::device_ptr<segments_t> segments, int num_segments,
|
||||
func_t f) {
|
||||
safe_cuda(cudaSetDevice(device_idx));
|
||||
auto counting = thrust::make_counting_iterator(0);
|
||||
template <typename func_t, typename segments_iter, typename offset_t>
|
||||
void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
|
||||
segments_iter segments, offset_t num_segments, func_t f) {
|
||||
typedef typename cub::CubVector<offset_t, 2>::Type coordinate_t;
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
const int BLOCK_THREADS = 256;
|
||||
const int ITEMS_PER_THREAD = 1;
|
||||
const int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD;
|
||||
int num_tiles = dh::div_round_up(count + num_segments, BLOCK_THREADS);
|
||||
|
||||
auto f_wrapper = [=] __device__(int idx, int upper_bound) {
|
||||
f(idx, upper_bound - 1);
|
||||
};
|
||||
temp_memory->LazyAllocate(sizeof(coordinate_t) * (num_tiles + 1));
|
||||
coordinate_t *tmp_tile_coordinates =
|
||||
reinterpret_cast<coordinate_t *>(temp_memory->d_temp_storage);
|
||||
|
||||
DiscardLambdaItr<decltype(f_wrapper)> itr(f_wrapper);
|
||||
FindMergePartitions(device_idx, tmp_tile_coordinates, num_tiles, BLOCK_THREADS, segments,
|
||||
num_segments, count);
|
||||
|
||||
thrust::upper_bound(thrust::cuda::par(*temp_memory), segments,
|
||||
segments + num_segments, counting, counting + count, itr);
|
||||
LbsKernel<TILE_SIZE, ITEMS_PER_THREAD, BLOCK_THREADS, offset_t>
|
||||
<<<num_tiles, BLOCK_THREADS>>>(tmp_tile_coordinates, segments + 1, f,
|
||||
num_segments);
|
||||
}
|
||||
|
||||
} // namespace dh
|
||||
|
||||
@ -7,22 +7,72 @@
|
||||
#include "../../src/device_helpers.cuh"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
static const std::vector<int> gidx = {0, 2, 5, 1, 3, 6, 0, 2, 0, 7};
|
||||
static const std::vector<int> row_ptr = {0, 3, 6, 8, 10};
|
||||
static const std::vector<int> lbs_seg_output = {0, 0, 0, 1, 1, 1, 2, 2, 3, 3};
|
||||
void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
|
||||
thrust::host_vector<int> *row_ptr,
|
||||
thrust::host_vector<xgboost::bst_uint> *rows) {
|
||||
row_ptr->resize(num_rows + 1);
|
||||
int sum = 0;
|
||||
for (int i = 0; i <= num_rows; i++) {
|
||||
(*row_ptr)[i] = sum;
|
||||
sum += rand() % max_row_size; // NOLINT
|
||||
|
||||
thrust::device_vector<int> test_lbs() {
|
||||
thrust::device_vector<int> device_gidx = gidx;
|
||||
thrust::device_vector<int> device_row_ptr = row_ptr;
|
||||
thrust::device_vector<int> device_output_row(gidx.size(), 0);
|
||||
auto d_output_row = device_output_row.data();
|
||||
dh::CubMemory temp_memory;
|
||||
dh::TransformLbs(
|
||||
0, &temp_memory, gidx.size(), device_row_ptr.data(), row_ptr.size() - 1,
|
||||
[=] __device__(int idx, int ridx) { d_output_row[idx] = ridx; });
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
return device_output_row;
|
||||
if (i < num_rows) {
|
||||
for (int j = (*row_ptr)[i]; j < sum; j++) {
|
||||
(*rows).push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(lbs, Test) { ASSERT_TRUE(test_lbs() == lbs_seg_output); }
|
||||
void SpeedTest() {
|
||||
int num_rows = 1000000;
|
||||
int max_row_size = 100;
|
||||
dh::CubMemory temp_memory;
|
||||
thrust::host_vector<int> h_row_ptr;
|
||||
thrust::host_vector<xgboost::bst_uint> h_rows;
|
||||
CreateTestData(num_rows, max_row_size, &h_row_ptr, &h_rows);
|
||||
thrust::device_vector<int> row_ptr = h_row_ptr;
|
||||
thrust::device_vector<int> output_row(h_rows.size());
|
||||
auto d_output_row = output_row.data();
|
||||
|
||||
dh::Timer t;
|
||||
dh::TransformLbs(
|
||||
0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1,
|
||||
[=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; });
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
double time = t.elapsedSeconds();
|
||||
const int mb_size = 1048576;
|
||||
size_t size = (sizeof(int) * h_rows.size()) / mb_size;
|
||||
printf("size: %llumb, time: %fs, bandwidth: %fmb/s\n", size, time,
|
||||
size / time);
|
||||
}
|
||||
|
||||
void TestLbs() {
|
||||
srand(17);
|
||||
dh::CubMemory temp_memory;
|
||||
|
||||
std::vector<int> test_rows = {4, 100, 1000};
|
||||
std::vector<int> test_max_row_sizes = {4, 100, 1300};
|
||||
|
||||
for (auto num_rows : test_rows) {
|
||||
for (auto max_row_size : test_max_row_sizes) {
|
||||
thrust::host_vector<int> h_row_ptr;
|
||||
thrust::host_vector<xgboost::bst_uint> h_rows;
|
||||
CreateTestData(num_rows, max_row_size, &h_row_ptr, &h_rows);
|
||||
thrust::device_vector<size_t> row_ptr = h_row_ptr;
|
||||
thrust::device_vector<int> output_row(h_rows.size());
|
||||
auto d_output_row = output_row.data();
|
||||
|
||||
dh::TransformLbs(0, &temp_memory, h_rows.size(), dh::raw(row_ptr),
|
||||
row_ptr.size() - 1,
|
||||
[=] __device__(size_t idx, size_t ridx) {
|
||||
d_output_row[idx] = ridx;
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
ASSERT_TRUE(h_rows == output_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
TEST(cub_lbs, Test) { TestLbs(); }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user