[GPU-Plugin] Add load balancing search to gpu_hist. Add compressed iterator. (#2504)

This commit is contained in:
Rory Mitchell 2017-07-11 22:36:39 +12:00 committed by GitHub
parent 64c8f6fa6d
commit 530f01e21c
9 changed files with 523 additions and 222 deletions


@ -56,6 +56,7 @@ endfunction(set_default_configuration_release)
function(format_gencode_flags flags out)
  foreach(ver ${flags})
-   set(${out} "${${out}}-gencode arch=compute_${ver},code=sm_${ver};" PARENT_SCOPE)
+   set(${out} "${${out}}-gencode arch=compute_${ver},code=sm_${ver};")
  endforeach()
+ set(${out} "${${out}}" PARENT_SCOPE)
endfunction(format_gencode_flags flags)


@ -144,8 +144,10 @@ $ make PLUGIN_UPDATER_GPU=ON GTEST_PATH=${CACHE_PREFIX} test
```
## Changelog
+##### 2017/7/10
+* Memory performance improved 4x for gpu_hist
##### 2017/6/26
* Change API to use tree_method parameter
* Increase required cmake version to 3.5
* Add compute arch 3.5 to default archs


@ -15,7 +15,7 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm):
    param = {'objective': 'binary:logistic',
             'max_depth': 6,
-            'silent': 1,
+            'silent': 0,
             'n_gpus': 1,
             'gpu_id': 0,
             'eval_metric': 'auc'}
@ -26,6 +26,7 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm):
    xgb.train(param, dtrain, args.iterations)
    print ("Time: %s seconds" % (str(time.time() - tmp)))
+   param['silent'] = 1
    param['tree_method'] = cpu_algorithm
    print("Training with '%s'" % param['tree_method'])
    tmp = time.time()


@ -2,11 +2,13 @@
 * Copyright 2017 XGBoost contributors
 */
#pragma once
+#include <dmlc/logging.h>
+#include <thrust/binary_search.h>
#include <thrust/device_vector.h>
#include <thrust/random.h>
#include <thrust/system/cuda/error.h>
+#include <thrust/system/cuda/execution_policy.h>
#include <thrust/system_error.h>
+#include "nccl.h"
#include <algorithm>
#include <chrono>
#include <ctime>
@ -15,7 +17,7 @@
#include <sstream>
#include <string>
#include <vector>
-#include "nccl.h"

// Uncomment to enable
// #define DEVICE_TIMER
@ -121,87 +123,6 @@ inline int get_device_idx(int gpu_id) {
 * Timers
 */
#define MAX_WARPS 32 // Maximum number of warps to time
#define MAX_SLOTS 10
#define TIMER_BLOCKID 0 // Block to time
struct DeviceTimerGlobal {
#ifdef DEVICE_TIMER
clock_t total_clocks[MAX_SLOTS][MAX_WARPS];
int64_t count[MAX_SLOTS][MAX_WARPS];
#endif
// Clear device memory. Call at start of kernel.
__device__ void Init() {
#ifdef DEVICE_TIMER
if (blockIdx.x == TIMER_BLOCKID && threadIdx.x < MAX_WARPS) {
for (int SLOT = 0; SLOT < MAX_SLOTS; SLOT++) {
total_clocks[SLOT][threadIdx.x] = 0;
count[SLOT][threadIdx.x] = 0;
}
}
#endif
}
void HostPrint() {
#ifdef DEVICE_TIMER
DeviceTimerGlobal h_timer;
safe_cuda(
cudaMemcpyFromSymbol(&h_timer, (*this), sizeof(DeviceTimerGlobal)));
for (int SLOT = 0; SLOT < MAX_SLOTS; SLOT++) {
if (h_timer.count[SLOT][0] == 0) {
continue;
}
clock_t sum_clocks = 0;
int64_t sum_count = 0;
for (int WARP = 0; WARP < MAX_WARPS; WARP++) {
if (h_timer.count[SLOT][WARP] == 0) {
continue;
}
sum_clocks += h_timer.total_clocks[SLOT][WARP];
sum_count += h_timer.count[SLOT][WARP];
}
printf("Slot %d: %d clocks per call, called %d times.\n", SLOT,
sum_clocks / sum_count, h_timer.count[SLOT][0]);
}
#endif
}
};
struct DeviceTimer {
#ifdef DEVICE_TIMER
clock_t start;
int slot;
DeviceTimerGlobal &GTimer;
#endif
#ifdef DEVICE_TIMER
__device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) // NOLINT
: GTimer(GTimer),
start(clock()),
slot(slot) {}
#else
__device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) {} // NOLINT
#endif
__device__ void End() {
#ifdef DEVICE_TIMER
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
if (blockIdx.x == TIMER_BLOCKID && lane_id == 0) {
GTimer.count[slot][warp_id] += 1;
GTimer.total_clocks[slot][warp_id] += clock() - start;
}
#endif
}
};
struct Timer {
  typedef std::chrono::high_resolution_clock ClockT;
@ -549,23 +470,36 @@ struct CubMemory {
void *d_temp_storage;
size_t temp_storage_bytes;

+// Thrust
+typedef char value_type;

CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {}

~CubMemory() { Free(); }

void Free() {
-  if (d_temp_storage != NULL) {
+  if (this->IsAllocated()) {
    safe_cuda(cudaFree(d_temp_storage));
  }
}

-void LazyAllocate(size_t n_bytes) {
-  if (n_bytes > temp_storage_bytes) {
+void LazyAllocate(size_t num_bytes) {
+  if (num_bytes > temp_storage_bytes) {
    Free();
-    safe_cuda(cudaMalloc(&d_temp_storage, n_bytes));
-    temp_storage_bytes = n_bytes;
+    safe_cuda(cudaMalloc(&d_temp_storage, num_bytes));
+    temp_storage_bytes = num_bytes;
  }
}

+// Thrust
+char *allocate(std::ptrdiff_t num_bytes) {
+  LazyAllocate(num_bytes);
+  return reinterpret_cast<char *>(d_temp_storage);
+}
+
+// Thrust
+void deallocate(char *ptr, size_t n) {
+  // Do nothing
+}
+
bool IsAllocated() { return d_temp_storage != NULL; }
};
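The two "// Thrust" additions above (value_type plus allocate/deallocate) give CubMemory the minimal allocator interface Thrust expects, so the cached buffer can also back Thrust's internal temporary allocations when the struct is passed to thrust::cuda::par(...), which is how TransformLbs further down uses *temp_memory. A standalone sketch of the pattern, assuming nvcc with -std=c++11; CachedAllocator here is an illustrative stand-in, not a class from this codebase:

```cuda
// Minimal cached allocator with the same interface CubMemory exposes above.
// Thrust routes its temporary-storage requests through it, so repeated calls
// reuse one buffer instead of hitting cudaMalloc every time.
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>
#include <cuda_runtime.h>
#include <cstddef>

struct CachedAllocator {
  typedef char value_type;  // required by Thrust's allocator protocol
  void *storage;
  size_t bytes;
  CachedAllocator() : storage(NULL), bytes(0) {}
  ~CachedAllocator() { cudaFree(storage); }
  char *allocate(std::ptrdiff_t n) {
    if (static_cast<size_t>(n) > bytes) {  // grow only when needed
      cudaFree(storage);
      cudaMalloc(&storage, n);
      bytes = n;
    }
    return reinterpret_cast<char *>(storage);
  }
  void deallocate(char *, size_t) {}  // keep the buffer for reuse
};

int main() {
  thrust::device_vector<int> v(1 << 20, 1);
  CachedAllocator alloc;
  thrust::sort(thrust::cuda::par(alloc), v.begin(), v.end());
  thrust::sort(thrust::cuda::par(alloc), v.begin(), v.end());  // reuses buffer
  return 0;
}
```

Keeping one such allocator per device (the new temp_memory vector in gpu_hist_builder.cuh) means the per-level calls only touch cudaMalloc when the required size grows.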
@ -591,7 +525,7 @@ void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
  std::cout << "\n";
}

-template <typename T, memory_type MemoryT>
+template <typename T>
void print(const dvec<T> &v, size_t max_items = 10) {
  std::vector<T> h = v.as_vector();
  for (int i = 0; i < std::min(max_items, h.size()); i++) {
@ -714,4 +648,118 @@ struct BernoulliRng {
    t1234.printElapsed(name); \
  } while (0)
// Load balancing search
template <typename func_t>
class LauncherItr {
public:
int idx;
func_t f;
XGBOOST_DEVICE LauncherItr() : idx(0) {}
XGBOOST_DEVICE LauncherItr(int idx, func_t f) : idx(idx), f(f) {}
XGBOOST_DEVICE LauncherItr &operator=(int output) {
f(idx, output);
return *this;
}
};
template <typename func_t>
/**
* \class DiscardLambdaItr
*
* \brief Thrust compatible iterator type - discards algorithm output and
* launches device lambda with the index of the output and the algorithm output as arguments.
*
* \author Rory
* \date 7/9/2017
*/
class DiscardLambdaItr {
public:
// Required iterator traits
typedef DiscardLambdaItr self_type; ///< My own type
typedef ptrdiff_t
difference_type; ///< Type to express the result of subtracting
/// one iterator from another
typedef LauncherItr<func_t>
value_type; ///< The type of the element the iterator can point to
typedef value_type *pointer; ///< The type of a pointer to an element the
/// iterator can point to
typedef value_type reference; ///< The type of a reference to an element the
/// iterator can point to
typedef typename thrust::detail::iterator_facade_category<
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
reference>::type iterator_category; ///< The iterator category
private:
difference_type offset;
func_t f;
public:
XGBOOST_DEVICE DiscardLambdaItr(func_t f) : offset(0), f(f) {}
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, func_t f)
: offset(offset), f(f) {}
XGBOOST_DEVICE self_type operator+(const int &b) const {
return DiscardLambdaItr(offset + b, f);
}
XGBOOST_DEVICE self_type operator++() {
offset++;
return *this;
}
XGBOOST_DEVICE self_type operator++(int) {
self_type retval = *this;
offset++;
return retval;
}
XGBOOST_DEVICE self_type &operator+=(const int &b) {
offset += b;
return *this;
}
XGBOOST_DEVICE reference operator*() const {
return LauncherItr<func_t>(offset, f);
}
XGBOOST_DEVICE reference operator[](int idx) {
self_type offset = (*this) + idx;
return *offset;
}
};
/**
* \fn template <typename func_t, typename segments_t> void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count, thrust::device_ptr<segments_t> segments, int num_segments, func_t f)
*
 * \brief Load balancing search function. Reads a CSR-type matrix description and allows a
 *        function to be executed on each element. Search 'modern GPU load balancing search'
 *        for more information.
*
* \author Rory
* \date 7/9/2017
*
 * \tparam segments_t Type of the segment offsets.
* \param device_idx Zero-based index of the device.
* \param [in,out] temp_memory Temporary memory allocator.
* \param count Number of elements.
 * \param segments Device pointer to the segment offsets.
* \param num_segments Number of segments.
* \param f Lambda to be executed on matrix elements.
*/
template <typename func_t, typename segments_t>
void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count,
thrust::device_ptr<segments_t> segments, int num_segments,
func_t f) {
safe_cuda(cudaSetDevice(device_idx));
auto counting = thrust::make_counting_iterator(0);
auto f_wrapper = [=] __device__(int idx, int upper_bound) {
f(idx, upper_bound - 1);
};
DiscardLambdaItr<decltype(f_wrapper)> itr(f_wrapper);
thrust::upper_bound(thrust::cuda::par(*temp_memory), segments,
segments + num_segments, counting, counting + count, itr);
}
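For intuition, the same mapping can be computed on the host: each element index of a CSR matrix is assigned to the row that contains it by an upper_bound over the row pointers, and the user functor receives both indices. A minimal CPU sketch (illustrative only; the function and variable names below are not from the codebase):

```cpp
// Host-side analogue of TransformLbs: map each CSR element index to its row
// via upper_bound on the row pointers, then call the functor with both.
#include <algorithm>
#include <cstdio>
#include <vector>

template <typename Func>
void TransformLbsHost(const std::vector<int> &row_ptr, int count, Func f) {
  for (int idx = 0; idx < count; ++idx) {
    // First row pointer strictly greater than idx, minus one, is idx's row.
    auto it = std::upper_bound(row_ptr.begin(), row_ptr.end(), idx);
    int ridx = static_cast<int>(it - row_ptr.begin()) - 1;
    f(idx, ridx);
  }
}

int main() {
  std::vector<int> row_ptr = {0, 3, 6, 8, 10};  // 4 rows, 10 elements
  TransformLbsHost(row_ptr, 10, [](int idx, int ridx) {
    std::printf("element %d -> row %d\n", idx, ridx);
  });
  // Row indices printed: 0 0 0 1 1 1 2 2 3 3, matching lbs_seg_output in the
  // new unit test below.
  return 0;
}
```

On the device the upper_bound runs through Thrust with DiscardLambdaItr standing in for the output range, so no intermediate array of row indices has to be materialised.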
}  // namespace dh


@ -1,36 +1,45 @@
/*!
 * Copyright 2017 Rory mitchell
 */
-#include <cub/cub.cuh>
#include <thrust/binary_search.h>
#include <thrust/count.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
+#include <cub/cub.cuh>
#include <algorithm>
#include <functional>
#include <future>
#include <numeric>
#include "common.cuh"
#include "device_helpers.cuh"
+#include "dmlc/timer.h"
#include "gpu_hist_builder.cuh"

namespace xgboost {
namespace tree {

void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
-                     bst_uint begin, bst_uint end) {
+                     bst_uint element_begin, bst_uint element_end,
+                     bst_uint row_begin, bst_uint row_end, int n_bins) {
  dh::safe_cuda(cudaSetDevice(device_idx));
-  CHECK_EQ(gidx.size(), end - begin) << "gidx must be externally allocated";
-  CHECK_EQ(ridx.size(), end - begin) << "ridx must be externally allocated";
+  CHECK(gidx_buffer.size()) << "gidx_buffer must be externally allocated";
+  CHECK_EQ(row_ptr.size(), (row_end - row_begin) + 1)
+      << "row_ptr must be externally allocated";

-  thrust::copy(gmat.index.data() + begin, gmat.index.data() + end, gidx.tbegin());
-  thrust::device_vector<int> row_ptr = gmat.row_ptr;
-
-  auto counting = thrust::make_counting_iterator(begin);
-  thrust::upper_bound(row_ptr.begin(), row_ptr.end(), counting,
-                      counting + gidx.size(), ridx.tbegin());
-  thrust::transform(ridx.tbegin(), ridx.tend(), ridx.tbegin(),
-                    [=] __device__(int val) { return val - 1; });
+  common::CompressedBufferWriter cbw(n_bins);
+  std::vector<common::compressed_byte_t> host_buffer(gidx_buffer.size());
+  cbw.Write(host_buffer.data(), gmat.index.begin() + element_begin,
+            gmat.index.begin() + element_end);
+  gidx_buffer = host_buffer;
+  gidx = common::CompressedIterator<int>(gidx_buffer.data(), n_bins);
+
+  // row_ptr
+  thrust::copy(gmat.row_ptr.data() + row_begin,
+               gmat.row_ptr.data() + row_end + 1, row_ptr.tbegin());
+  // normalise row_ptr
+  bst_uint start = gmat.row_ptr[row_begin];
+  thrust::transform(row_ptr.tbegin(), row_ptr.tend(), row_ptr.tbegin(),
+                    [=] __device__(int val) { return val - start; });
}
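The row_ptr handling above is the multi-GPU detail worth noting: each device keeps only its slice of the global CSR row pointers, shifted so that its first entry is zero, which is what later lets TransformLbs hand out device-local row indices. A small host-side sketch of that normalisation, using a hypothetical two-device split and made-up numbers:

```cpp
// Sketch of the per-device row_ptr normalisation (illustrative values only).
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> global_row_ptr = {0, 3, 6, 8, 10};  // 4 rows, 10 elements
  int row_begin = 2, row_end = 4;  // rows [2, 4) assigned to the second device

  // Copy this device's slice of the row pointers, including the end offset.
  std::vector<int> local_row_ptr(global_row_ptr.begin() + row_begin,
                                 global_row_ptr.begin() + row_end + 1);
  // Subtract the first offset so the local row_ptr starts at zero.
  int start = local_row_ptr.front();
  for (int &v : local_row_ptr) v -= start;  // {6, 8, 10} -> {0, 2, 4}

  for (int v : local_row_ptr) std::printf("%d ", v);  // prints: 0 2 4
  return 0;
}
```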
void DeviceHist::Init(int n_bins_in) {
@ -59,10 +68,10 @@ HistBuilder::HistBuilder(bst_gpair* ptr, int n_bins)
__device__ void HistBuilder::Add(bst_gpair gpair, int gidx, int nidx) const {
  int hist_idx = nidx * n_bins + gidx;
  atomicAdd(&(d_hist[hist_idx].grad), gpair.grad);  // OPTMARK: This and below
                                                    // line lead to about 3X
                                                    // slowdown due to memory
                                                    // dependency and access
                                                    // pattern issues.
  atomicAdd(&(d_hist[hist_idx].hess), gpair.hess);
}
@ -170,7 +179,6 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
  // process)
}

CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column "
                                "block. Try setting 'tree_method' "
                                "parameter to 'exact'";
@ -219,6 +227,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
// ba.allocate(master_device, );

// allocate vectors across all devices
+temp_memory.resize(n_devices);
hist_vec.resize(n_devices);
nodes.resize(n_devices);
nodes_temp.resize(n_devices);
@ -269,18 +278,21 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
        h_feature_segments.size(),  // constant and same on all devices
        &prediction_cache[d_idx], num_rows_segment, &position[d_idx],
        num_rows_segment, &position_tmp[d_idx], num_rows_segment,
-       &device_gpair[d_idx], num_rows_segment, &device_matrix[d_idx].gidx,
-       num_elements_segment,  // constant and same on all devices
-       &device_matrix[d_idx].ridx,
-       num_elements_segment,  // constant and same on all devices
+       &device_gpair[d_idx], num_rows_segment,
+       &device_matrix[d_idx].gidx_buffer,
+       common::CompressedBufferWriter::CalculateBufferSize(
+           num_elements_segment,
+           n_bins),  // constant and same on all devices
+       &device_matrix[d_idx].row_ptr, num_rows_segment + 1,
        &gidx_feature_map[d_idx], n_bins,  // constant and same on all devices
        &gidx_fvalue_map[d_idx],
        hmat_.cut.size());  // constant and same on all devices

    // Copy Host to Device (assumes comes after ba.allocate that sets device)
-   device_matrix[d_idx].Init(device_idx, gmat_,
-                             device_element_segments[d_idx],
-                             device_element_segments[d_idx + 1]);
+   device_matrix[d_idx].Init(
+       device_idx, gmat_, device_element_segments[d_idx],
+       device_element_segments[d_idx + 1], device_row_segments[d_idx],
+       device_row_segments[d_idx + 1], n_bins);
    gidx_feature_map[d_idx] = h_gidx_feature_map;
    gidx_fvalue_map[d_idx] = hmat_.cut;
    feature_segments[d_idx] = h_feature_segments;
@ -338,39 +350,41 @@ void GPUHistBuilder::BuildHist(int depth) {
    size_t begin = device_element_segments[d_idx];
    size_t end = device_element_segments[d_idx + 1];
    size_t row_begin = device_row_segments[d_idx];
+   size_t row_end = device_row_segments[d_idx + 1];

-   auto d_ridx = device_matrix[d_idx].ridx.data();
-   auto d_gidx = device_matrix[d_idx].gidx.data();
+   auto d_gidx = device_matrix[d_idx].gidx;
+   auto d_row_ptr = device_matrix[d_idx].row_ptr.tbegin();
    auto d_position = position[d_idx].data();
    auto d_gpair = device_gpair[d_idx].data();
    auto d_left_child_smallest = left_child_smallest[d_idx].data();
    auto hist_builder = hist_vec[d_idx].GetBuilder();

-   dh::launch_n(device_idx, end - begin, [=] __device__(int local_idx) {
-     int ridx = d_ridx[local_idx];             // OPTMARK: latency
-     int nidx = d_position[ridx - row_begin];  // OPTMARK: latency
-     if (!is_active(nidx, depth)) return;
-
-     // Only increment smallest node
-     bool is_smallest =
-         (d_left_child_smallest[parent_nidx(nidx)] && is_left_child(nidx)) ||
-         (!d_left_child_smallest[parent_nidx(nidx)] && !is_left_child(nidx));
-     if (!is_smallest && depth > 0) return;
-
-     int gidx = d_gidx[local_idx];
-     bst_gpair gpair = d_gpair[ridx - row_begin];
-
-     hist_builder.Add(gpair, gidx, nidx);  // OPTMARK: This is slow, could use
-                                           // shared memory or cache results
-                                           // intead of writing to global
-                                           // memory every time in atomic way.
-   });
+   dh::TransformLbs(
+       device_idx, &temp_memory[d_idx], end - begin, d_row_ptr,
+       row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
+         int nidx = d_position[local_ridx];  // OPTMARK: latency
+         if (!is_active(nidx, depth)) return;
+
+         // Only increment smallest node
+         bool is_smallest = (d_left_child_smallest[parent_nidx(nidx)] &&
+                             is_left_child(nidx)) ||
+                            (!d_left_child_smallest[parent_nidx(nidx)] &&
+                             !is_left_child(nidx));
+         if (!is_smallest && depth > 0) return;
+
+         int gidx = d_gidx[local_idx];
+         bst_gpair gpair = d_gpair[local_ridx];
+
+         hist_builder.Add(gpair, gidx,
+                          nidx);  // OPTMARK: This is slow, could use
+                                  // shared memory or cache results
+                                  // intead of writing to global
+                                  // memory every time in atomic way.
+       });
  }

+ // dh::safe_cuda(cudaDeviceSynchronize());
  dh::synchronize_n_devices(n_devices, dList);
  // time.printElapsed("Add Time");

  // (in-place) reduce each element of histogram (for only current level) across
  // multiple gpus
@ -393,7 +407,7 @@ void GPUHistBuilder::BuildHist(int depth) {
    dh::safe_cuda(cudaSetDevice(device_idx));
    dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx])));
  }

  // if no NCCL, then presume only 1 GPU, then already correct
  // time.printElapsed("Reduce-Add Time");
@ -572,15 +586,15 @@ __global__ void find_split_kernel(
    left_child_smallest = &d_left_child_smallest_temp[blockIdx.x];
  }

  *Nodeleft =
      Node(split.left_sum,
           CalcGain(gpu_param, split.left_sum.grad, split.left_sum.hess),
           CalcWeight(gpu_param, split.left_sum.grad, split.left_sum.hess));
  *Noderight =
      Node(split.right_sum,
           CalcGain(gpu_param, split.right_sum.grad, split.right_sum.hess),
           CalcWeight(gpu_param, split.right_sum.grad, split.right_sum.hess));

  // Record smallest node
  if (split.left_sum.hess <= split.right_sum.hess) {
@ -650,9 +664,9 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
        feature_segments[d_idx].data(), depth, (info->num_col),
        (hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_temp[d_idx].data(),
        nodes_child_temp[d_idx].data(), nodes_offset_device,
        fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(),
        GPUTrainingParam(param), left_child_smallest_temp[d_idx].data(),
        colsample, feature_flags[d_idx].data());
  }

  // nccl only on devices that did split
@ -747,7 +761,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
        feature_segments[d_idx].data(), depth, (info->num_col),
        (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
        nodes_offset_device, fidx_min_map[d_idx].data(),
        gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
        left_child_smallest[d_idx].data(), colsample,
        feature_flags[d_idx].data());
@ -800,7 +814,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
        feature_segments[d_idx].data(), depth, (info->num_col),
        (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
        nodes_offset_device, fidx_min_map[d_idx].data(),
        gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
        left_child_smallest[d_idx].data(), colsample,
        feature_flags[d_idx].data());
  }
@ -811,57 +825,23 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
}

void GPUHistBuilder::InitFirstNode(const std::vector<bst_gpair>& gpair) {
-#ifdef _WIN32
-  // Visual studio complains about C:/Program Files (x86)/Microsoft Visual
-  // Studio 14.0/VC/bin/../../VC/INCLUDE\utility(445): error : static assertion
-  // failed with "tuple index out of bounds"
-  // and C:/Program Files (x86)/Microsoft Visual Studio
-  // 14.0/VC/bin/../../VC/INCLUDE\future(1888): error : no instance of function
-  // template "std::_Invoke_stored" matches the argument list
-  std::vector<bst_gpair> future_results(n_devices);
+  // Perform asynchronous reduction on each gpu
+  std::vector<bst_gpair> device_sums(n_devices);
+#pragma omp parallel for num_threads(n_devices)
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
     int device_idx = dList[d_idx];
+    dh::safe_cuda(cudaSetDevice(device_idx));
     auto begin = device_gpair[d_idx].tbegin();
     auto end = device_gpair[d_idx].tend();
     bst_gpair init = bst_gpair();
     auto binary_op = thrust::plus<bst_gpair>();
-    dh::safe_cuda(cudaSetDevice(device_idx));
-    future_results[d_idx] = thrust::reduce(begin, end, init, binary_op);
+    device_sums[d_idx] = thrust::reduce(begin, end, init, binary_op);
   }

-  // sum over devices on host (with blocking get())
   bst_gpair sum = bst_gpair();
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
-    int device_idx = dList[d_idx];
-    sum += future_results[d_idx];
+    sum += device_sums[d_idx];
   }
-#else
-  // asynch reduce per device
-  std::vector<std::future<bst_gpair>> future_results(n_devices);
-  for (int d_idx = 0; d_idx < n_devices; d_idx++) {
-    // std::async captures the algorithm parameters by value
-    // use std::launch::async to ensure the creation of a new thread
-    future_results[d_idx] = std::async(std::launch::async, [=] {
-      int device_idx = dList[d_idx];
-      dh::safe_cuda(cudaSetDevice(device_idx));
-      auto begin = device_gpair[d_idx].tbegin();
-      auto end = device_gpair[d_idx].tend();
-      bst_gpair init = bst_gpair();
-      auto binary_op = thrust::plus<bst_gpair>();
-      return thrust::reduce(begin, end, init, binary_op);
-    });
-  }
-  // sum over devices on host (with blocking get())
-  bst_gpair sum = bst_gpair();
-  for (int d_idx = 0; d_idx < n_devices; d_idx++) {
-    int device_idx = dList[d_idx];
-    sum += future_results[d_idx].get();
-  }
-#endif

  // Setup first node so all devices have same first node (here done same on all
  // devices, or could have done one device and Bcast if worried about exact
@ -874,11 +854,10 @@ void GPUHistBuilder::InitFirstNode(const std::vector<bst_gpair>& gpair) {
    dh::launch_n(device_idx, 1, [=] __device__(int idx) {
      bst_gpair sum_gradients = sum;
      d_nodes[idx] =
          Node(sum_gradients,
               CalcGain(gpu_param, sum_gradients.grad, sum_gradients.hess),
               CalcWeight(gpu_param, sum_gradients.grad, sum_gradients.hess));
    });
  }

  // synch all devices to host before moving on (No, can avoid because BuildHist
@ -901,7 +880,7 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
    auto d_position = position[d_idx].data();
    Node* d_nodes = nodes[d_idx].data();
    auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data();
-   auto d_gidx = device_matrix[d_idx].gidx.data();
+   auto d_gidx = device_matrix[d_idx].gidx;
    int n_columns = info->num_col;
    size_t begin = device_row_segments[d_idx];
    size_t end = device_row_segments[d_idx + 1];
@ -941,8 +920,8 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
    Node* d_nodes = nodes[d_idx].data();
    auto d_gidx_feature_map = gidx_feature_map[d_idx].data();
    auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data();
-   auto d_gidx = device_matrix[d_idx].gidx.data();
-   auto d_ridx = device_matrix[d_idx].ridx.data();
+   auto d_gidx = device_matrix[d_idx].gidx;
+   auto d_row_ptr = device_matrix[d_idx].row_ptr.tbegin();
    size_t row_begin = device_row_segments[d_idx];
    size_t row_end = device_row_segments[d_idx + 1];
@ -973,10 +952,11 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
    // Update node based on fvalue where exists
    // OPTMARK: This kernel is very inefficient for both compute and memory,
    // dominated by memory dependency / access patterns
-   dh::launch_n(
-       device_idx, element_end - element_begin, [=] __device__(int local_idx) {
-         int ridx = d_ridx[local_idx];
-         int pos = d_position[ridx - row_begin];
+   dh::TransformLbs(
+       device_idx, &temp_memory[d_idx], element_end - element_begin, d_row_ptr,
+       row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
+         int pos = d_position[local_ridx];
          if (!is_active(pos, depth)) {
            return;
          }
@ -997,9 +977,9 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
            float fvalue = d_gidx_fvalue_map[gidx];

            if (fvalue <= node.split.fvalue) {
-             d_position_tmp[ridx - row_begin] = left_child_nidx(pos);
+             d_position_tmp[local_ridx] = left_child_nidx(pos);
            } else {
-             d_position_tmp[ridx - row_begin] = right_child_nidx(pos);
+             d_position_tmp[local_ridx] = right_child_nidx(pos);
            }
          }
        });
@ -1026,10 +1006,6 @@ void GPUHistBuilder::ColSampleLevel() {
    h_feature_flags[fidx] = 1;
  }

- // copy from Host to Device for all devices
- // for(auto &f:feature_flags){ // this doesn't set device as should
- //   f = h_feature_flags;
- // }
  for (int d_idx = 0; d_idx < n_devices; d_idx++) {
    int device_idx = dList[d_idx];
    dh::safe_cuda(cudaSetDevice(device_idx));


@ -8,26 +8,20 @@
#include <vector>
#include "../../src/common/hist_util.h"
#include "../../src/tree/param.h"
+#include "../../src/common/compressed_iterator.h"
#include "device_helpers.cuh"
#include "types.cuh"

-#ifndef NCCL
-#define NCCL 1
-#endif
-
-#if (NCCL)
#include "nccl.h"
-#endif

namespace xgboost {
namespace tree {

struct DeviceGMat {
-  dh::dvec<int> gidx;
-  dh::dvec<int> ridx;
+  dh::dvec<common::compressed_byte_t> gidx_buffer;
+  common::CompressedIterator<int> gidx;
+  dh::dvec<int> row_ptr;
  void Init(int device_idx, const common::GHistIndexMatrix &gmat,
-           bst_uint begin, bst_uint end);
+           bst_uint begin, bst_uint end, bst_uint row_begin, bst_uint row_end, int n_bins);
};

struct HistBuilder {
@ -95,7 +89,6 @@ class GPUHistBuilder {
  dh::bulk_allocator<dh::memory_type::DEVICE> ba;
  // dh::bulk_allocator<dh::memory_type::DEVICE_MANAGED> ba; // can't be used
  // with NCCL
- dh::CubMemory cub_mem;

  std::vector<int> feature_set_tree;
  std::vector<int> feature_set_level;
@ -108,6 +101,7 @@ class GPUHistBuilder {
  std::vector<int> device_row_segments;
  std::vector<int> device_element_segments;

+ std::vector<dh::CubMemory> temp_memory;
  std::vector<DeviceHist> hist_vec;
  std::vector<dh::dvec<Node>> nodes;
  std::vector<dh::dvec<Node>> nodes_temp;
@ -126,10 +120,8 @@ class GPUHistBuilder {
  std::vector<dh::dvec<float>> gidx_fvalue_map;
  std::vector<cudaStream_t *> streams;
-#if (NCCL)
  std::vector<ncclComm_t> comms;
  std::vector<std::vector<ncclComm_t>> find_split_comms;
-#endif
};
}  // namespace tree
}  // namespace xgboost


@ -0,0 +1,28 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include <thrust/device_vector.h>
#include <xgboost/base.h>
#include "../../src/device_helpers.cuh"
#include "gtest/gtest.h"
static const std::vector<int> gidx = {0, 2, 5, 1, 3, 6, 0, 2, 0, 7};
static const std::vector<int> row_ptr = {0, 3, 6, 8, 10};
static const std::vector<int> lbs_seg_output = {0, 0, 0, 1, 1, 1, 2, 2, 3, 3};
thrust::device_vector<int> test_lbs() {
thrust::device_vector<int> device_gidx = gidx;
thrust::device_vector<int> device_row_ptr = row_ptr;
thrust::device_vector<int> device_output_row(gidx.size(), 0);
auto d_output_row = device_output_row.data();
dh::CubMemory temp_memory;
dh::TransformLbs(
0, &temp_memory, gidx.size(), device_row_ptr.data(), row_ptr.size() - 1,
[=] __device__(int idx, int ridx) { d_output_row[idx] = ridx; });
dh::safe_cuda(cudaDeviceSynchronize());
return device_output_row;
}
TEST(lbs, Test) { ASSERT_TRUE(test_lbs() == lbs_seg_output); }


@ -0,0 +1,199 @@
/*!
* Copyright 2017 by Contributors
* \file compressed_iterator.h
*/
#pragma once
#include <xgboost/base.h>
#include <cmath>
#include <cstddef>
#include "dmlc/logging.h"
namespace xgboost {
namespace common {
typedef unsigned char compressed_byte_t;
namespace detail {
inline void SetBit(compressed_byte_t *byte, int bit_idx) {
*byte |= 1 << bit_idx;
}
template <typename T>
inline T CheckBit(const T &byte, int bit_idx) {
return byte & (1 << bit_idx);
}
inline void ClearBit(compressed_byte_t *byte, int bit_idx) {
*byte &= ~(1 << bit_idx);
}
static const int padding = 4; // Assign padding so we can read slightly off
// the beginning of the array
// The number of bits required to represent a given unsigned range
static int SymbolBits(int num_symbols) {
return std::ceil(std::log2(num_symbols));
}
} // namespace detail
/**
* \class CompressedBufferWriter
*
* \brief Writes bit compressed symbols to a memory buffer. Use
* CompressedIterator to read symbols back from buffer. Currently limited to a
* maximum symbol size of 28 bits.
*
* \author Rory
* \date 7/9/2017
*/
class CompressedBufferWriter {
private:
int symbol_bits_;
size_t offset_;
public:
explicit CompressedBufferWriter(int num_symbols) : offset_(0) {
symbol_bits_ = detail::SymbolBits(num_symbols);
}
/**
* \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
* num_elements, int num_symbols)
*
 * \brief Calculates the number of bytes required for a given number of elements
* and a symbol range.
*
* \author Rory
* \date 7/9/2017
*
* \param num_elements Number of elements.
* \param num_symbols Max number of symbols (alphabet size)
*
* \return The calculated buffer size.
*/
static size_t CalculateBufferSize(int num_elements, int num_symbols) {
const int bits_per_byte = 8;
int compressed_size = std::ceil(
static_cast<double>(detail::SymbolBits(num_symbols) * num_elements) /
bits_per_byte);
return compressed_size + detail::padding;
}
template <typename T>
void WriteSymbol(compressed_byte_t *buffer, T symbol, size_t offset) {
const int bits_per_byte = 8;
for (int i = 0; i < symbol_bits_; i++) {
size_t byte_idx = ((offset + 1) * symbol_bits_ - (i + 1)) / bits_per_byte;
byte_idx += detail::padding;
int bit_idx =
((bits_per_byte + i) - ((offset + 1) * symbol_bits_)) % bits_per_byte;
if (detail::CheckBit(symbol, i)) {
detail::SetBit(&buffer[byte_idx], bit_idx);
} else {
detail::ClearBit(&buffer[byte_idx], bit_idx);
}
}
}
template <typename iter_t>
void Write(compressed_byte_t *buffer, iter_t input_begin, iter_t input_end) {
uint64_t tmp = 0;
int stored_bits = 0;
const int max_stored_bits = 64 - symbol_bits_;
int buffer_position = detail::padding;
const int num_symbols = input_end - input_begin;
for (int i = 0; i < num_symbols; i++) {
typename std::iterator_traits<iter_t>::value_type symbol = input_begin[i];
if (stored_bits > max_stored_bits) {
// Eject only full bytes
int tmp_bytes = stored_bits / 8;
for (int j = 0; j < tmp_bytes; j++) {
buffer[buffer_position] = tmp >> (stored_bits - (j + 1) * 8);
buffer_position++;
}
stored_bits -= tmp_bytes * 8;
tmp &= (1 << stored_bits) - 1;
}
// Store symbol
tmp <<= symbol_bits_;
tmp |= symbol;
stored_bits += symbol_bits_;
}
// Eject all bytes
int tmp_bytes = std::ceil(static_cast<float>(stored_bits) / 8);
for (int j = 0; j < tmp_bytes; j++) {
int shift_bits = stored_bits - (j + 1) * 8;
if (shift_bits >= 0) {
buffer[buffer_position] = tmp >> shift_bits;
} else {
buffer[buffer_position] = tmp << std::abs(shift_bits);
}
buffer_position++;
}
}
};
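To make the size arithmetic concrete, here is a worked example of the calculation CalculateBufferSize performs, with a hypothetical alphabet (the numbers are illustrative, not taken from any caller):

```cpp
// With 16 symbols each value needs ceil(log2(16)) = 4 bits, so 10 values pack
// into ceil(40 / 8) = 5 data bytes, plus the 4 padding bytes that let the
// iterator read slightly before the first symbol: 9 bytes in total.
#include <cmath>
#include <cstdio>

int main() {
  const int num_symbols = 16;
  const int num_elements = 10;
  const int padding = 4;  // mirrors detail::padding above

  int symbol_bits = static_cast<int>(std::ceil(std::log2(num_symbols)));
  int data_bytes = static_cast<int>(
      std::ceil(static_cast<double>(symbol_bits * num_elements) / 8));
  std::printf("bits per symbol: %d, buffer size: %d bytes\n", symbol_bits,
              data_bytes + padding);  // bits per symbol: 4, buffer size: 9 bytes
  return 0;
}
```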
template <typename T>
/**
* \class CompressedIterator
*
* \brief Read symbols from a bit compressed memory buffer. Usable on device and
* host.
*
* \author Rory
* \date 7/9/2017
*/
class CompressedIterator {
public:
typedef CompressedIterator<T> self_type; ///< My own type
typedef ptrdiff_t
difference_type; ///< Type to express the result of subtracting
/// one iterator from another
typedef T value_type; ///< The type of the element the iterator can point to
typedef value_type *pointer; ///< The type of a pointer to an element the
/// iterator can point to
typedef value_type reference; ///< The type of a reference to an element the
/// iterator can point to
private:
compressed_byte_t *buffer_;
int symbol_bits_;
size_t offset_;
public:
CompressedIterator() : buffer_(nullptr), symbol_bits_(0), offset_(0) {}
CompressedIterator(compressed_byte_t *buffer, int num_symbols)
: buffer_(buffer), offset_(0) {
symbol_bits_ = detail::SymbolBits(num_symbols);
}
XGBOOST_DEVICE reference operator*() const {
const int bits_per_byte = 8;
size_t start_bit_idx = ((offset_ + 1) * symbol_bits_ - 1);
size_t start_byte_idx = start_bit_idx / bits_per_byte;
start_byte_idx += detail::padding;
// Read 5 bytes - the maximum we will need
uint64_t tmp = static_cast<uint64_t>(buffer_[start_byte_idx - 4]) << 32 |
static_cast<uint64_t>(buffer_[start_byte_idx - 3]) << 24 |
static_cast<uint64_t>(buffer_[start_byte_idx - 2]) << 16 |
static_cast<uint64_t>(buffer_[start_byte_idx - 1]) << 8 |
buffer_[start_byte_idx];
int bit_shift =
(bits_per_byte - ((offset_ + 1) * symbol_bits_)) % bits_per_byte;
tmp >>= bit_shift;
// Mask off unneeded bits
uint64_t mask = (1 << symbol_bits_) - 1;
return static_cast<T>(tmp & mask);
}
XGBOOST_DEVICE reference operator[](int idx) const {
self_type offset = (*this);
offset.offset_ += idx;
return *offset;
}
};
} // namespace common
} // namespace xgboost


@ -0,0 +1,54 @@
#include "../../../src/common/compressed_iterator.h"
#include "gtest/gtest.h"
namespace xgboost {
namespace common {
TEST(CompressedIterator, Test) {
ASSERT_TRUE(detail::SymbolBits(256) == 8);
ASSERT_TRUE(detail::SymbolBits(150) == 8);
std::vector<int> test_cases = {3, 426, 21, 64, 256, 100000, INT32_MAX};
int num_elements = 1000;
int repetitions = 1000;
srand(9);
for (auto alphabet_size : test_cases) {
for (int i = 0; i < repetitions; i++) {
std::vector<int> input(num_elements);
std::generate(input.begin(), input.end(),
[=]() { return rand() % alphabet_size; });
CompressedBufferWriter cbw(alphabet_size);
// Test write entire array
std::vector<unsigned char> buffer(
CompressedBufferWriter::CalculateBufferSize(input.size(),
alphabet_size));
cbw.Write(buffer.data(), input.begin(), input.end());
CompressedIterator<int> ci(buffer.data(), alphabet_size);
std::vector<int> output(input.size());
for (int i = 0; i < input.size(); i++) {
output[i] = ci[i];
}
ASSERT_TRUE(input == output);
// Test write Symbol
std::vector<unsigned char> buffer2(
CompressedBufferWriter::CalculateBufferSize(input.size(),
alphabet_size));
for (int i = 0; i < input.size(); i++) {
cbw.WriteSymbol(buffer2.data(), input[i], i);
}
CompressedIterator<int> ci2(buffer2.data(), alphabet_size);
std::vector<int> output2(input.size());
for (int i = 0; i < input.size(); i++) {
output2[i] = ci2[i];
}
ASSERT_TRUE(input == output2);
}
}
}
} // namespace common
} // namespace xgboost