From 530f01e21c005e28caebc1236067c23e20326b16 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 11 Jul 2017 22:36:39 +1200 Subject: [PATCH] [GPU-Plugin] Add load balancing search to gpu_hist. Add compressed iterator. (#2504) --- cmake/Utils.cmake | 3 +- plugin/updater_gpu/README.md | 4 +- plugin/updater_gpu/benchmark/benchmark.py | 3 +- plugin/updater_gpu/src/device_helpers.cuh | 226 +++++++++++------- plugin/updater_gpu/src/gpu_hist_builder.cu | 208 +++++++--------- plugin/updater_gpu/src/gpu_hist_builder.cuh | 20 +- .../test/cpp/test_device_helpers.cu | 28 +++ src/common/compressed_iterator.h | 199 +++++++++++++++ tests/cpp/common/test_compressed_iterator.cc | 54 +++++ 9 files changed, 523 insertions(+), 222 deletions(-) create mode 100644 plugin/updater_gpu/test/cpp/test_device_helpers.cu create mode 100644 src/common/compressed_iterator.h create mode 100644 tests/cpp/common/test_compressed_iterator.cc diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake index 57dd30de2..eca94c687 100644 --- a/cmake/Utils.cmake +++ b/cmake/Utils.cmake @@ -56,6 +56,7 @@ endfunction(set_default_configuration_release) function(format_gencode_flags flags out) foreach(ver ${flags}) - set(${out} "${${out}}-gencode arch=compute_${ver},code=sm_${ver};" PARENT_SCOPE) + set(${out} "${${out}}-gencode arch=compute_${ver},code=sm_${ver};") endforeach() + set(${out} "${${out}}" PARENT_SCOPE) endfunction(format_gencode_flags flags) \ No newline at end of file diff --git a/plugin/updater_gpu/README.md b/plugin/updater_gpu/README.md index 01ca4edfa..5664a7a22 100644 --- a/plugin/updater_gpu/README.md +++ b/plugin/updater_gpu/README.md @@ -144,8 +144,10 @@ $ make PLUGIN_UPDATER_GPU=ON GTEST_PATH=${CACHE_PREFIX} test ``` ## Changelog -##### 2017/6/26 +##### 2017/7/10 +* Memory performance improved 4x for gpu_hist +##### 2017/6/26 * Change API to use tree_method parameter * Increase required cmake version to 3.5 * Add compute arch 3.5 to default archs diff --git a/plugin/updater_gpu/benchmark/benchmark.py b/plugin/updater_gpu/benchmark/benchmark.py index f81232645..f29f2c756 100644 --- a/plugin/updater_gpu/benchmark/benchmark.py +++ b/plugin/updater_gpu/benchmark/benchmark.py @@ -15,7 +15,7 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm): param = {'objective': 'binary:logistic', 'max_depth': 6, - 'silent': 1, + 'silent': 0, 'n_gpus': 1, 'gpu_id': 0, 'eval_metric': 'auc'} @@ -26,6 +26,7 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm): xgb.train(param, dtrain, args.iterations) print ("Time: %s seconds" % (str(time.time() - tmp))) + param['silent'] = 1 param['tree_method'] = cpu_algorithm print("Training with '%s'" % param['tree_method']) tmp = time.time() diff --git a/plugin/updater_gpu/src/device_helpers.cuh b/plugin/updater_gpu/src/device_helpers.cuh index be30b8063..2950a8f51 100644 --- a/plugin/updater_gpu/src/device_helpers.cuh +++ b/plugin/updater_gpu/src/device_helpers.cuh @@ -2,11 +2,13 @@ * Copyright 2017 XGBoost contributors */ #pragma once +#include +#include #include #include #include +#include #include -#include "nccl.h" #include #include #include @@ -15,7 +17,7 @@ #include #include #include - +#include "nccl.h" // Uncomment to enable // #define DEVICE_TIMER @@ -121,87 +123,6 @@ inline int get_device_idx(int gpu_id) { * Timers */ -#define MAX_WARPS 32 // Maximum number of warps to time -#define MAX_SLOTS 10 -#define TIMER_BLOCKID 0 // Block to time -struct DeviceTimerGlobal { -#ifdef DEVICE_TIMER - - clock_t total_clocks[MAX_SLOTS][MAX_WARPS]; - int64_t count[MAX_SLOTS][MAX_WARPS]; - 
-#endif - - // Clear device memory. Call at start of kernel. - __device__ void Init() { -#ifdef DEVICE_TIMER - if (blockIdx.x == TIMER_BLOCKID && threadIdx.x < MAX_WARPS) { - for (int SLOT = 0; SLOT < MAX_SLOTS; SLOT++) { - total_clocks[SLOT][threadIdx.x] = 0; - count[SLOT][threadIdx.x] = 0; - } - } -#endif - } - - void HostPrint() { -#ifdef DEVICE_TIMER - DeviceTimerGlobal h_timer; - safe_cuda( - cudaMemcpyFromSymbol(&h_timer, (*this), sizeof(DeviceTimerGlobal))); - - for (int SLOT = 0; SLOT < MAX_SLOTS; SLOT++) { - if (h_timer.count[SLOT][0] == 0) { - continue; - } - - clock_t sum_clocks = 0; - int64_t sum_count = 0; - - for (int WARP = 0; WARP < MAX_WARPS; WARP++) { - if (h_timer.count[SLOT][WARP] == 0) { - continue; - } - - sum_clocks += h_timer.total_clocks[SLOT][WARP]; - sum_count += h_timer.count[SLOT][WARP]; - } - - printf("Slot %d: %d clocks per call, called %d times.\n", SLOT, - sum_clocks / sum_count, h_timer.count[SLOT][0]); - } -#endif - } -}; - -struct DeviceTimer { -#ifdef DEVICE_TIMER - clock_t start; - int slot; - DeviceTimerGlobal >imer; -#endif - -#ifdef DEVICE_TIMER - __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) // NOLINT - : GTimer(GTimer), - start(clock()), - slot(slot) {} -#else - __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) {} // NOLINT -#endif - - __device__ void End() { -#ifdef DEVICE_TIMER - int warp_id = threadIdx.x / 32; - int lane_id = threadIdx.x % 32; - if (blockIdx.x == TIMER_BLOCKID && lane_id == 0) { - GTimer.count[slot][warp_id] += 1; - GTimer.total_clocks[slot][warp_id] += clock() - start; - } -#endif - } -}; - struct Timer { typedef std::chrono::high_resolution_clock ClockT; @@ -549,23 +470,36 @@ struct CubMemory { void *d_temp_storage; size_t temp_storage_bytes; + // Thrust + typedef char value_type; + CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {} ~CubMemory() { Free(); } void Free() { - if (d_temp_storage != NULL) { + if (this->IsAllocated()) { safe_cuda(cudaFree(d_temp_storage)); } } - void LazyAllocate(size_t n_bytes) { - if (n_bytes > temp_storage_bytes) { + void LazyAllocate(size_t num_bytes) { + if (num_bytes > temp_storage_bytes) { Free(); - safe_cuda(cudaMalloc(&d_temp_storage, n_bytes)); - temp_storage_bytes = n_bytes; + safe_cuda(cudaMalloc(&d_temp_storage, num_bytes)); + temp_storage_bytes = num_bytes; } } + // Thrust + char *allocate(std::ptrdiff_t num_bytes) { + LazyAllocate(num_bytes); + return reinterpret_cast(d_temp_storage); + } + + // Thrust + void deallocate(char *ptr, size_t n) { + // Do nothing + } bool IsAllocated() { return d_temp_storage != NULL; } }; @@ -591,7 +525,7 @@ void print(const thrust::device_vector &v, size_t max_items = 10) { std::cout << "\n"; } -template +template void print(const dvec &v, size_t max_items = 10) { std::vector h = v.as_vector(); for (int i = 0; i < std::min(max_items, h.size()); i++) { @@ -714,4 +648,118 @@ struct BernoulliRng { t1234.printElapsed(name); \ } while (0) +// Load balancing search + +template +class LauncherItr { + public: + int idx; + func_t f; + XGBOOST_DEVICE LauncherItr() : idx(0) {} + XGBOOST_DEVICE LauncherItr(int idx, func_t f) : idx(idx), f(f) {} + XGBOOST_DEVICE LauncherItr &operator=(int output) { + f(idx, output); + return *this; + } +}; + +template + +/** + * \class DiscardLambdaItr + * + * \brief Thrust compatible iterator type - discards algorithm output and + * launches device lambda with the index of the output and the algorithm output as arguments. 
+ * + * \author Rory + * \date 7/9/2017 + */ + +class DiscardLambdaItr { + public: + // Required iterator traits + typedef DiscardLambdaItr self_type; ///< My own type + typedef ptrdiff_t + difference_type; ///< Type to express the result of subtracting + /// one iterator from another + typedef LauncherItr + value_type; ///< The type of the element the iterator can point to + typedef value_type *pointer; ///< The type of a pointer to an element the + /// iterator can point to + typedef value_type reference; ///< The type of a reference to an element the + /// iterator can point to + typedef typename thrust::detail::iterator_facade_category< + thrust::any_system_tag, thrust::random_access_traversal_tag, value_type, + reference>::type iterator_category; ///< The iterator category + private: + difference_type offset; + func_t f; + + public: + XGBOOST_DEVICE DiscardLambdaItr(func_t f) : offset(0), f(f) {} + XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, func_t f) + : offset(offset), f(f) {} + + XGBOOST_DEVICE self_type operator+(const int &b) const { + return DiscardLambdaItr(offset + b, f); + } + XGBOOST_DEVICE self_type operator++() { + offset++; + return *this; + } + XGBOOST_DEVICE self_type operator++(int) { + self_type retval = *this; + offset++; + return retval; + } + XGBOOST_DEVICE self_type &operator+=(const int &b) { + offset += b; + return *this; + } + XGBOOST_DEVICE reference operator*() const { + return LauncherItr(offset, f); + } + + XGBOOST_DEVICE reference operator[](int idx) { + self_type offset = (*this) + idx; + return *offset; + } +}; + +/** + * \fn template void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count, thrust::device_ptr segments, int num_segments, func_t f) + * + * \brief Load balancing search function. Reads a CSR type matrix description and allows a function + * to be executed on each element. Search 'modern GPU load balancing search for more + * information'. + * + * \author Rory + * \date 7/9/2017 + * + * \tparam segments_t Type of the segments t. + * \param device_idx Zero-based index of the device. + * \param [in,out] temp_memory Temporary memory allocator. + * \param count Number of elements. + * \param segments Device pointed to segments. + * \param num_segments Number of segments. + * \param f Lambda to be executed on matrix elements. + */ + +template +void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count, + thrust::device_ptr segments, int num_segments, + func_t f) { + safe_cuda(cudaSetDevice(device_idx)); + auto counting = thrust::make_counting_iterator(0); + + auto f_wrapper = [=] __device__(int idx, int upper_bound) { + f(idx, upper_bound - 1); + }; + + DiscardLambdaItr itr(f_wrapper); + + thrust::upper_bound(thrust::cuda::par(*temp_memory), segments, + segments + num_segments, counting, counting + count, itr); +} + } // namespace dh diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cu b/plugin/updater_gpu/src/gpu_hist_builder.cu index f45adecb8..ffb861f06 100644 --- a/plugin/updater_gpu/src/gpu_hist_builder.cu +++ b/plugin/updater_gpu/src/gpu_hist_builder.cu @@ -1,36 +1,45 @@ /*! 
* Copyright 2017 Rory mitchell */ -#include #include #include #include #include +#include #include #include #include #include #include "common.cuh" #include "device_helpers.cuh" +#include "dmlc/timer.h" #include "gpu_hist_builder.cuh" namespace xgboost { namespace tree { void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat, - bst_uint begin, bst_uint end) { + bst_uint element_begin, bst_uint element_end, + bst_uint row_begin, bst_uint row_end, int n_bins) { dh::safe_cuda(cudaSetDevice(device_idx)); - CHECK_EQ(gidx.size(), end - begin) << "gidx must be externally allocated"; - CHECK_EQ(ridx.size(), end - begin) << "ridx must be externally allocated"; + CHECK(gidx_buffer.size()) << "gidx_buffer must be externally allocated"; + CHECK_EQ(row_ptr.size(), (row_end - row_begin) + 1) + << "row_ptr must be externally allocated"; - thrust::copy(gmat.index.data() + begin, gmat.index.data() + end, gidx.tbegin()); - thrust::device_vector row_ptr = gmat.row_ptr; + common::CompressedBufferWriter cbw(n_bins); + std::vector host_buffer(gidx_buffer.size()); + cbw.Write(host_buffer.data(), gmat.index.begin() + element_begin, + gmat.index.begin() + element_end); + gidx_buffer = host_buffer; + gidx = common::CompressedIterator(gidx_buffer.data(), n_bins); - auto counting = thrust::make_counting_iterator(begin); - thrust::upper_bound(row_ptr.begin(), row_ptr.end(), counting, - counting + gidx.size(), ridx.tbegin()); - thrust::transform(ridx.tbegin(), ridx.tend(), ridx.tbegin(), - [=] __device__(int val) { return val - 1; }); + // row_ptr + thrust::copy(gmat.row_ptr.data() + row_begin, + gmat.row_ptr.data() + row_end + 1, row_ptr.tbegin()); + // normalise row_ptr + bst_uint start = gmat.row_ptr[row_begin]; + thrust::transform(row_ptr.tbegin(), row_ptr.tend(), row_ptr.tbegin(), + [=] __device__(int val) { return val - start; }); } void DeviceHist::Init(int n_bins_in) { @@ -59,10 +68,10 @@ HistBuilder::HistBuilder(bst_gpair* ptr, int n_bins) __device__ void HistBuilder::Add(bst_gpair gpair, int gidx, int nidx) const { int hist_idx = nidx * n_bins + gidx; atomicAdd(&(d_hist[hist_idx].grad), gpair.grad); // OPTMARK: This and below - // line lead to about 3X - // slowdown due to memory - // dependency and access - // pattern issues. + // line lead to about 3X + // slowdown due to memory + // dependency and access + // pattern issues. atomicAdd(&(d_hist[hist_idx].hess), gpair.hess); } @@ -170,7 +179,6 @@ void GPUHistBuilder::InitData(const std::vector& gpair, // process) } - CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column " "block. 
Try setting 'tree_method' " "parameter to 'exact'"; @@ -219,6 +227,7 @@ void GPUHistBuilder::InitData(const std::vector& gpair, // ba.allocate(master_device, ); // allocate vectors across all devices + temp_memory.resize(n_devices); hist_vec.resize(n_devices); nodes.resize(n_devices); nodes_temp.resize(n_devices); @@ -269,18 +278,21 @@ void GPUHistBuilder::InitData(const std::vector& gpair, h_feature_segments.size(), // constant and same on all devices &prediction_cache[d_idx], num_rows_segment, &position[d_idx], num_rows_segment, &position_tmp[d_idx], num_rows_segment, - &device_gpair[d_idx], num_rows_segment, &device_matrix[d_idx].gidx, - num_elements_segment, // constant and same on all devices - &device_matrix[d_idx].ridx, - num_elements_segment, // constant and same on all devices + &device_gpair[d_idx], num_rows_segment, + &device_matrix[d_idx].gidx_buffer, + common::CompressedBufferWriter::CalculateBufferSize( + num_elements_segment, + n_bins), // constant and same on all devices + &device_matrix[d_idx].row_ptr, num_rows_segment + 1, &gidx_feature_map[d_idx], n_bins, // constant and same on all devices &gidx_fvalue_map[d_idx], hmat_.cut.size()); // constant and same on all devices // Copy Host to Device (assumes comes after ba.allocate that sets device) - device_matrix[d_idx].Init(device_idx, gmat_, - device_element_segments[d_idx], - device_element_segments[d_idx + 1]); + device_matrix[d_idx].Init( + device_idx, gmat_, device_element_segments[d_idx], + device_element_segments[d_idx + 1], device_row_segments[d_idx], + device_row_segments[d_idx + 1], n_bins); gidx_feature_map[d_idx] = h_gidx_feature_map; gidx_fvalue_map[d_idx] = hmat_.cut; feature_segments[d_idx] = h_feature_segments; @@ -338,39 +350,41 @@ void GPUHistBuilder::BuildHist(int depth) { size_t begin = device_element_segments[d_idx]; size_t end = device_element_segments[d_idx + 1]; size_t row_begin = device_row_segments[d_idx]; + size_t row_end = device_row_segments[d_idx + 1]; - auto d_ridx = device_matrix[d_idx].ridx.data(); - auto d_gidx = device_matrix[d_idx].gidx.data(); + auto d_gidx = device_matrix[d_idx].gidx; + auto d_row_ptr = device_matrix[d_idx].row_ptr.tbegin(); auto d_position = position[d_idx].data(); auto d_gpair = device_gpair[d_idx].data(); auto d_left_child_smallest = left_child_smallest[d_idx].data(); auto hist_builder = hist_vec[d_idx].GetBuilder(); + dh::TransformLbs( + device_idx, &temp_memory[d_idx], end - begin, d_row_ptr, + row_end - row_begin, [=] __device__(int local_idx, int local_ridx) { + int nidx = d_position[local_ridx]; // OPTMARK: latency + if (!is_active(nidx, depth)) return; - dh::launch_n(device_idx, end - begin, [=] __device__(int local_idx) { - int ridx = d_ridx[local_idx]; // OPTMARK: latency - int nidx = d_position[ridx - row_begin]; // OPTMARK: latency - if (!is_active(nidx, depth)) return; + // Only increment smallest node + bool is_smallest = (d_left_child_smallest[parent_nidx(nidx)] && + is_left_child(nidx)) || + (!d_left_child_smallest[parent_nidx(nidx)] && + !is_left_child(nidx)); + if (!is_smallest && depth > 0) return; - // Only increment smallest node - bool is_smallest = - (d_left_child_smallest[parent_nidx(nidx)] && is_left_child(nidx)) || - (!d_left_child_smallest[parent_nidx(nidx)] && !is_left_child(nidx)); - if (!is_smallest && depth > 0) return; + int gidx = d_gidx[local_idx]; + bst_gpair gpair = d_gpair[local_ridx]; - int gidx = d_gidx[local_idx]; - bst_gpair gpair = d_gpair[ridx - row_begin]; - - hist_builder.Add(gpair, gidx, nidx); // OPTMARK: This is slow, could 
use - // shared memory or cache results - // intead of writing to global - // memory every time in atomic way. - }); + hist_builder.Add(gpair, gidx, + nidx); // OPTMARK: This is slow, could use + // shared memory or cache results + // intead of writing to global + // memory every time in atomic way. + }); } - // dh::safe_cuda(cudaDeviceSynchronize()); dh::synchronize_n_devices(n_devices, dList); -// time.printElapsed("Add Time"); + // time.printElapsed("Add Time"); // (in-place) reduce each element of histogram (for only current level) across // multiple gpus @@ -393,7 +407,7 @@ void GPUHistBuilder::BuildHist(int depth) { dh::safe_cuda(cudaSetDevice(device_idx)); dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx]))); } -// if no NCCL, then presume only 1 GPU, then already correct + // if no NCCL, then presume only 1 GPU, then already correct // time.printElapsed("Reduce-Add Time"); @@ -572,15 +586,15 @@ __global__ void find_split_kernel( left_child_smallest = &d_left_child_smallest_temp[blockIdx.x]; } - *Nodeleft = Node( - split.left_sum, - CalcGain(gpu_param, split.left_sum.grad, split.left_sum.hess), - CalcWeight(gpu_param, split.left_sum.grad, split.left_sum.hess)); + *Nodeleft = + Node(split.left_sum, + CalcGain(gpu_param, split.left_sum.grad, split.left_sum.hess), + CalcWeight(gpu_param, split.left_sum.grad, split.left_sum.hess)); - *Noderight = Node( - split.right_sum, - CalcGain(gpu_param, split.right_sum.grad, split.right_sum.hess), - CalcWeight(gpu_param, split.right_sum.grad, split.right_sum.hess)); + *Noderight = + Node(split.right_sum, + CalcGain(gpu_param, split.right_sum.grad, split.right_sum.hess), + CalcWeight(gpu_param, split.right_sum.grad, split.right_sum.hess)); // Record smallest node if (split.left_sum.hess <= split.right_sum.hess) { @@ -650,9 +664,9 @@ void GPUHistBuilder::LaunchFindSplit(int depth) { feature_segments[d_idx].data(), depth, (info->num_col), (hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_temp[d_idx].data(), nodes_child_temp[d_idx].data(), nodes_offset_device, - fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param), - left_child_smallest_temp[d_idx].data(), colsample, - feature_flags[d_idx].data()); + fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(), + GPUTrainingParam(param), left_child_smallest_temp[d_idx].data(), + colsample, feature_flags[d_idx].data()); } // nccl only on devices that did split @@ -747,7 +761,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) { feature_segments[d_idx].data(), depth, (info->num_col), (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL, nodes_offset_device, fidx_min_map[d_idx].data(), - gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param), + gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param), left_child_smallest[d_idx].data(), colsample, feature_flags[d_idx].data()); @@ -800,7 +814,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) { feature_segments[d_idx].data(), depth, (info->num_col), (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL, nodes_offset_device, fidx_min_map[d_idx].data(), - gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param), + gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param), left_child_smallest[d_idx].data(), colsample, feature_flags[d_idx].data()); } @@ -811,57 +825,23 @@ void GPUHistBuilder::LaunchFindSplit(int depth) { } void GPUHistBuilder::InitFirstNode(const std::vector& gpair) { -#ifdef _WIN32 - // Visual studio complains about C:/Program Files (x86)/Microsoft Visual - // Studio 14.0/VC/bin/../../VC/INCLUDE\utility(445): 
error : static assertion - // failed with "tuple index out of bounds" - // and C:/Program Files (x86)/Microsoft Visual Studio - // 14.0/VC/bin/../../VC/INCLUDE\future(1888): error : no instance of function - // template "std::_Invoke_stored" matches the argument list - std::vector future_results(n_devices); + // Perform asynchronous reduction on each gpu + std::vector device_sums(n_devices); +#pragma omp parallel for num_threads(n_devices) for (int d_idx = 0; d_idx < n_devices; d_idx++) { int device_idx = dList[d_idx]; - + dh::safe_cuda(cudaSetDevice(device_idx)); auto begin = device_gpair[d_idx].tbegin(); auto end = device_gpair[d_idx].tend(); bst_gpair init = bst_gpair(); auto binary_op = thrust::plus(); - - dh::safe_cuda(cudaSetDevice(device_idx)); - future_results[d_idx] = thrust::reduce(begin, end, init, binary_op); + device_sums[d_idx] = thrust::reduce(begin, end, init, binary_op); } - // sum over devices on host (with blocking get()) bst_gpair sum = bst_gpair(); for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - sum += future_results[d_idx]; + sum += device_sums[d_idx]; } -#else - // asynch reduce per device - - std::vector> future_results(n_devices); - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - // std::async captures the algorithm parameters by value - // use std::launch::async to ensure the creation of a new thread - future_results[d_idx] = std::async(std::launch::async, [=] { - int device_idx = dList[d_idx]; - dh::safe_cuda(cudaSetDevice(device_idx)); - auto begin = device_gpair[d_idx].tbegin(); - auto end = device_gpair[d_idx].tend(); - bst_gpair init = bst_gpair(); - auto binary_op = thrust::plus(); - return thrust::reduce(begin, end, init, binary_op); - }); - } - - // sum over devices on host (with blocking get()) - bst_gpair sum = bst_gpair(); - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - sum += future_results[d_idx].get(); - } -#endif // Setup first node so all devices have same first node (here done same on all // devices, or could have done one device and Bcast if worried about exact @@ -874,11 +854,10 @@ void GPUHistBuilder::InitFirstNode(const std::vector& gpair) { dh::launch_n(device_idx, 1, [=] __device__(int idx) { bst_gpair sum_gradients = sum; - d_nodes[idx] = Node( - sum_gradients, - CalcGain(gpu_param, sum_gradients.grad, sum_gradients.hess), - CalcWeight(gpu_param, sum_gradients.grad, - sum_gradients.hess)); + d_nodes[idx] = + Node(sum_gradients, + CalcGain(gpu_param, sum_gradients.grad, sum_gradients.hess), + CalcWeight(gpu_param, sum_gradients.grad, sum_gradients.hess)); }); } // synch all devices to host before moving on (No, can avoid because BuildHist @@ -901,7 +880,7 @@ void GPUHistBuilder::UpdatePositionDense(int depth) { auto d_position = position[d_idx].data(); Node* d_nodes = nodes[d_idx].data(); auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data(); - auto d_gidx = device_matrix[d_idx].gidx.data(); + auto d_gidx = device_matrix[d_idx].gidx; int n_columns = info->num_col; size_t begin = device_row_segments[d_idx]; size_t end = device_row_segments[d_idx + 1]; @@ -941,8 +920,8 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) { Node* d_nodes = nodes[d_idx].data(); auto d_gidx_feature_map = gidx_feature_map[d_idx].data(); auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data(); - auto d_gidx = device_matrix[d_idx].gidx.data(); - auto d_ridx = device_matrix[d_idx].ridx.data(); + auto d_gidx = device_matrix[d_idx].gidx; + auto d_row_ptr = 
device_matrix[d_idx].row_ptr.tbegin(); size_t row_begin = device_row_segments[d_idx]; size_t row_end = device_row_segments[d_idx + 1]; @@ -973,10 +952,11 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) { // Update node based on fvalue where exists // OPTMARK: This kernel is very inefficient for both compute and memory, // dominated by memory dependency / access patterns - dh::launch_n( - device_idx, element_end - element_begin, [=] __device__(int local_idx) { - int ridx = d_ridx[local_idx]; - int pos = d_position[ridx - row_begin]; + + dh::TransformLbs( + device_idx, &temp_memory[d_idx], element_end - element_begin, d_row_ptr, + row_end - row_begin, [=] __device__(int local_idx, int local_ridx) { + int pos = d_position[local_ridx]; if (!is_active(pos, depth)) { return; } @@ -997,9 +977,9 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) { float fvalue = d_gidx_fvalue_map[gidx]; if (fvalue <= node.split.fvalue) { - d_position_tmp[ridx - row_begin] = left_child_nidx(pos); + d_position_tmp[local_ridx] = left_child_nidx(pos); } else { - d_position_tmp[ridx - row_begin] = right_child_nidx(pos); + d_position_tmp[local_ridx] = right_child_nidx(pos); } } }); @@ -1026,10 +1006,6 @@ void GPUHistBuilder::ColSampleLevel() { h_feature_flags[fidx] = 1; } - // copy from Host to Device for all devices - // for(auto &f:feature_flags){ // this doesn't set device as should - // f = h_feature_flags; - // } for (int d_idx = 0; d_idx < n_devices; d_idx++) { int device_idx = dList[d_idx]; dh::safe_cuda(cudaSetDevice(device_idx)); diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cuh b/plugin/updater_gpu/src/gpu_hist_builder.cuh index a5cd57736..1264faef4 100644 --- a/plugin/updater_gpu/src/gpu_hist_builder.cuh +++ b/plugin/updater_gpu/src/gpu_hist_builder.cuh @@ -8,26 +8,20 @@ #include #include "../../src/common/hist_util.h" #include "../../src/tree/param.h" +#include "../../src/common/compressed_iterator.h" #include "device_helpers.cuh" #include "types.cuh" - -#ifndef NCCL -#define NCCL 1 -#endif - -#if (NCCL) #include "nccl.h" -#endif namespace xgboost { - namespace tree { struct DeviceGMat { - dh::dvec gidx; - dh::dvec ridx; + dh::dvec gidx_buffer; + common::CompressedIterator gidx; + dh::dvec row_ptr; void Init(int device_idx, const common::GHistIndexMatrix &gmat, - bst_uint begin, bst_uint end); + bst_uint begin, bst_uint end, bst_uint row_begin, bst_uint row_end,int n_bins); }; struct HistBuilder { @@ -95,7 +89,6 @@ class GPUHistBuilder { dh::bulk_allocator ba; // dh::bulk_allocator ba; // can't be used // with NCCL - dh::CubMemory cub_mem; std::vector feature_set_tree; std::vector feature_set_level; @@ -108,6 +101,7 @@ class GPUHistBuilder { std::vector device_row_segments; std::vector device_element_segments; + std::vector temp_memory; std::vector hist_vec; std::vector> nodes; std::vector> nodes_temp; @@ -126,10 +120,8 @@ class GPUHistBuilder { std::vector> gidx_fvalue_map; std::vector streams; -#if (NCCL) std::vector comms; std::vector> find_split_comms; -#endif }; } // namespace tree } // namespace xgboost diff --git a/plugin/updater_gpu/test/cpp/test_device_helpers.cu b/plugin/updater_gpu/test/cpp/test_device_helpers.cu new file mode 100644 index 000000000..910f668bc --- /dev/null +++ b/plugin/updater_gpu/test/cpp/test_device_helpers.cu @@ -0,0 +1,28 @@ + +/*! 
+ * Copyright 2017 XGBoost contributors + */ +#include +#include +#include "../../src/device_helpers.cuh" +#include "gtest/gtest.h" + +static const std::vector gidx = {0, 2, 5, 1, 3, 6, 0, 2, 0, 7}; +static const std::vector row_ptr = {0, 3, 6, 8, 10}; +static const std::vector lbs_seg_output = {0, 0, 0, 1, 1, 1, 2, 2, 3, 3}; + +thrust::device_vector test_lbs() { + thrust::device_vector device_gidx = gidx; + thrust::device_vector device_row_ptr = row_ptr; + thrust::device_vector device_output_row(gidx.size(), 0); + auto d_output_row = device_output_row.data(); + dh::CubMemory temp_memory; + dh::TransformLbs( + 0, &temp_memory, gidx.size(), device_row_ptr.data(), row_ptr.size() - 1, + [=] __device__(int idx, int ridx) { d_output_row[idx] = ridx; }); + + dh::safe_cuda(cudaDeviceSynchronize()); + return device_output_row; +} + +TEST(lbs, Test) { ASSERT_TRUE(test_lbs() == lbs_seg_output); } diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h new file mode 100644 index 000000000..794c93398 --- /dev/null +++ b/src/common/compressed_iterator.h @@ -0,0 +1,199 @@ +/*! + * Copyright 2017 by Contributors + * \file compressed_iterator.h + */ +#pragma once +#include +#include +#include +#include "dmlc/logging.h" + +namespace xgboost { +namespace common { + +typedef unsigned char compressed_byte_t; + +namespace detail { +inline void SetBit(compressed_byte_t *byte, int bit_idx) { + *byte |= 1 << bit_idx; +} +template +inline T CheckBit(const T &byte, int bit_idx) { + return byte & (1 << bit_idx); +} +inline void ClearBit(compressed_byte_t *byte, int bit_idx) { + *byte &= ~(1 << bit_idx); +} +static const int padding = 4; // Assign padding so we can read slightly off + // the beginning of the array + +// The number of bits required to represent a given unsigned range +static int SymbolBits(int num_symbols) { + return std::ceil(std::log2(num_symbols)); +} +} // namespace detail + +/** + * \class CompressedBufferWriter + * + * \brief Writes bit compressed symbols to a memory buffer. Use + * CompressedIterator to read symbols back from buffer. Currently limited to a + * maximum symbol size of 28 bits. + * + * \author Rory + * \date 7/9/2017 + */ + +class CompressedBufferWriter { + private: + int symbol_bits_; + size_t offset_; + + public: + explicit CompressedBufferWriter(int num_symbols) : offset_(0) { + symbol_bits_ = detail::SymbolBits(num_symbols); + } + + /** + * \fn static size_t CompressedBufferWriter::CalculateBufferSize(int + * num_elements, int num_symbols) + * + * \brief Calculates number of bytes requiredm for a given number of elements + * and a symbol range. + * + * \author Rory + * \date 7/9/2017 + * + * \param num_elements Number of elements. + * \param num_symbols Max number of symbols (alphabet size) + * + * \return The calculated buffer size. 
+ */ + + static size_t CalculateBufferSize(int num_elements, int num_symbols) { + const int bits_per_byte = 8; + int compressed_size = std::ceil( + static_cast(detail::SymbolBits(num_symbols) * num_elements) / + bits_per_byte); + return compressed_size + detail::padding; + } + + template + void WriteSymbol(compressed_byte_t *buffer, T symbol, size_t offset) { + const int bits_per_byte = 8; + + for (int i = 0; i < symbol_bits_; i++) { + size_t byte_idx = ((offset + 1) * symbol_bits_ - (i + 1)) / bits_per_byte; + byte_idx += detail::padding; + int bit_idx = + ((bits_per_byte + i) - ((offset + 1) * symbol_bits_)) % bits_per_byte; + + if (detail::CheckBit(symbol, i)) { + detail::SetBit(&buffer[byte_idx], bit_idx); + } else { + detail::ClearBit(&buffer[byte_idx], bit_idx); + } + } + } + template + void Write(compressed_byte_t *buffer, iter_t input_begin, iter_t input_end) { + uint64_t tmp = 0; + int stored_bits = 0; + const int max_stored_bits = 64 - symbol_bits_; + int buffer_position = detail::padding; + const int num_symbols = input_end - input_begin; + for (int i = 0; i < num_symbols; i++) { + typename std::iterator_traits::value_type symbol = input_begin[i]; + if (stored_bits > max_stored_bits) { + // Eject only full bytes + int tmp_bytes = stored_bits / 8; + for (int j = 0; j < tmp_bytes; j++) { + buffer[buffer_position] = tmp >> (stored_bits - (j + 1) * 8); + buffer_position++; + } + stored_bits -= tmp_bytes * 8; + tmp &= (1 << stored_bits) - 1; + } + // Store symbol + tmp <<= symbol_bits_; + tmp |= symbol; + stored_bits += symbol_bits_; + } + + // Eject all bytes + int tmp_bytes = std::ceil(static_cast(stored_bits) / 8); + for (int j = 0; j < tmp_bytes; j++) { + int shift_bits = stored_bits - (j + 1) * 8; + if (shift_bits >= 0) { + buffer[buffer_position] = tmp >> shift_bits; + } else { + buffer[buffer_position] = tmp << std::abs(shift_bits); + } + buffer_position++; + } + } +}; + +template + +/** + * \class CompressedIterator + * + * \brief Read symbols from a bit compressed memory buffer. Usable on device and + * host. 
+ * + * \author Rory + * \date 7/9/2017 + */ + +class CompressedIterator { + public: + typedef CompressedIterator self_type; ///< My own type + typedef ptrdiff_t + difference_type; ///< Type to express the result of subtracting + /// one iterator from another + typedef T value_type; ///< The type of the element the iterator can point to + typedef value_type *pointer; ///< The type of a pointer to an element the + /// iterator can point to + typedef value_type reference; ///< The type of a reference to an element the + /// iterator can point to + private: + compressed_byte_t *buffer_; + int symbol_bits_; + size_t offset_; + + public: + CompressedIterator() : buffer_(nullptr), symbol_bits_(0), offset_(0) {} + CompressedIterator(compressed_byte_t *buffer, int num_symbols) + : buffer_(buffer), offset_(0) { + symbol_bits_ = detail::SymbolBits(num_symbols); + } + + XGBOOST_DEVICE reference operator*() const { + const int bits_per_byte = 8; + size_t start_bit_idx = ((offset_ + 1) * symbol_bits_ - 1); + size_t start_byte_idx = start_bit_idx / bits_per_byte; + start_byte_idx += detail::padding; + + // Read 5 bytes - the maximum we will need + uint64_t tmp = static_cast(buffer_[start_byte_idx - 4]) << 32 | + static_cast(buffer_[start_byte_idx - 3]) << 24 | + static_cast(buffer_[start_byte_idx - 2]) << 16 | + static_cast(buffer_[start_byte_idx - 1]) << 8 | + buffer_[start_byte_idx]; + int bit_shift = + (bits_per_byte - ((offset_ + 1) * symbol_bits_)) % bits_per_byte; + tmp >>= bit_shift; + // Mask off unneeded bits + uint64_t mask = (1 << symbol_bits_) - 1; + return static_cast(tmp & mask); + } + + XGBOOST_DEVICE reference operator[](int idx) const { + self_type offset = (*this); + offset.offset_ += idx; + return *offset; + } +}; +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/common/test_compressed_iterator.cc b/tests/cpp/common/test_compressed_iterator.cc new file mode 100644 index 000000000..4ea1fc8e1 --- /dev/null +++ b/tests/cpp/common/test_compressed_iterator.cc @@ -0,0 +1,54 @@ +#include "../../../src/common/compressed_iterator.h" +#include "gtest/gtest.h" + +namespace xgboost { +namespace common { +TEST(CompressedIterator, Test) { + ASSERT_TRUE(detail::SymbolBits(256) == 8); + ASSERT_TRUE(detail::SymbolBits(150) == 8); + std::vector test_cases = {3, 426, 21, 64, 256, 100000, INT32_MAX}; + int num_elements = 1000; + int repetitions = 1000; + srand(9); + + for (auto alphabet_size : test_cases) { + for (int i = 0; i < repetitions; i++) { + std::vector input(num_elements); + std::generate(input.begin(), input.end(), + [=]() { return rand() % alphabet_size; }); + CompressedBufferWriter cbw(alphabet_size); + + // Test write entire array + std::vector buffer( + CompressedBufferWriter::CalculateBufferSize(input.size(), + alphabet_size)); + + cbw.Write(buffer.data(), input.begin(), input.end()); + + CompressedIterator ci(buffer.data(), alphabet_size); + std::vector output(input.size()); + for (int i = 0; i < input.size(); i++) { + output[i] = ci[i]; + } + + ASSERT_TRUE(input == output); + + // Test write Symbol + std::vector buffer2( + CompressedBufferWriter::CalculateBufferSize(input.size(), + alphabet_size)); + for (int i = 0; i < input.size(); i++) { + cbw.WriteSymbol(buffer2.data(), input[i], i); + } + CompressedIterator ci2(buffer.data(), alphabet_size); + std::vector output2(input.size()); + for (int i = 0; i < input.size(); i++) { + output2[i] = ci2[i]; + } + ASSERT_TRUE(input == output2); + } + } +} + +} // namespace common +} // namespace xgboost \ No newline at end of 
file
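
The load balancing search added in device_helpers.cuh is easiest to see on a small CSR example. The following is a minimal host-side sketch, not part of the patch itself: the include path and the ScatterRowIndices wrapper are assumptions, and it simply replays the case covered by test_device_helpers.cu, where TransformLbs hands the device lambda each element index together with the row (segment) that element belongs to.

// Sketch only: assumed include path and wrapper name; mirrors
// plugin/updater_gpu/test/cpp/test_device_helpers.cu from this patch.
#include <vector>
#include <thrust/device_vector.h>
#include "device_helpers.cuh"  // adjust path to plugin/updater_gpu/src/

void ScatterRowIndices() {
  // CSR row pointer for 4 rows holding 3, 3, 2 and 2 elements.
  std::vector<int> h_row_ptr = {0, 3, 6, 8, 10};
  thrust::device_vector<int> row_ptr = h_row_ptr;
  thrust::device_vector<int> row_of_element(10, -1);
  auto d_out = row_of_element.data();

  dh::CubMemory temp_memory;  // scratch memory reused by the search
  dh::TransformLbs(0 /* device */, &temp_memory, 10 /* elements */,
                   row_ptr.data(), 4 /* segments */,
                   [=] __device__(int idx, int ridx) {
                     // idx runs over elements 0..9, ridx is the owning row
                     d_out[idx] = ridx;
                   });
  // row_of_element now holds {0,0,0, 1,1,1, 2,2, 3,3}
}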
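
The compressed buffer added in compressed_iterator.h can be exercised on its own as well. CalculateBufferSize is bits-per-symbol times element count, rounded up to whole bytes, plus 4 bytes of padding; 1000 symbols drawn from an alphabet of 256 therefore need ceil(8 * 1000 / 8) + 4 = 1004 bytes. The round trip below is a small host-only sketch modeled on test_compressed_iterator.cc; the include path and the standalone main() are assumptions.

// Sketch only: host-side round trip through the compressed buffer.
#include <cstdio>
#include <vector>
#include "src/common/compressed_iterator.h"  // adjust to your include path

int main() {
  using xgboost::common::CompressedBufferWriter;
  using xgboost::common::CompressedIterator;
  using xgboost::common::compressed_byte_t;

  const int num_symbols = 256;  // alphabet size, packed at 8 bits per symbol
  std::vector<int> input = {0, 2, 5, 1, 3, 6, 0, 2, 0, 7};

  // Ask the writer how many bytes it needs (includes the 4-byte padding).
  std::vector<compressed_byte_t> buffer(
      CompressedBufferWriter::CalculateBufferSize(input.size(), num_symbols));

  // Pack the symbols, then read them back through the iterator.
  CompressedBufferWriter cbw(num_symbols);
  cbw.Write(buffer.data(), input.begin(), input.end());

  CompressedIterator<int> ci(buffer.data(), num_symbols);
  for (size_t i = 0; i < input.size(); ++i) {
    printf("%d ", ci[i]);  // prints the original symbols: 0 2 5 1 3 6 0 2 0 7
  }
  printf("\n");
  return 0;
}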