[GPU-Plugin] Add load balancing search to gpu_hist. Add compressed iterator. (#2504)
parent 64c8f6fa6d
commit 530f01e21c
@@ -56,6 +56,7 @@ endfunction(set_default_configuration_release)

function(format_gencode_flags flags out)
  foreach(ver ${flags})
-    set(${out} "${${out}}-gencode arch=compute_${ver},code=sm_${ver};" PARENT_SCOPE)
+    set(${out} "${${out}}-gencode arch=compute_${ver},code=sm_${ver};")
  endforeach()
+  set(${out} "${${out}}" PARENT_SCOPE)
endfunction(format_gencode_flags flags)
@@ -144,8 +144,10 @@ $ make PLUGIN_UPDATER_GPU=ON GTEST_PATH=${CACHE_PREFIX} test
```

## Changelog
-##### 2017/6/26
+##### 2017/7/10
+* Memory performance improved 4x for gpu_hist
+
+##### 2017/6/26
* Change API to use tree_method parameter
* Increase required cmake version to 3.5
* Add compute arch 3.5 to default archs
@@ -15,7 +15,7 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm):

    param = {'objective': 'binary:logistic',
             'max_depth': 6,
-            'silent': 1,
+            'silent': 0,
             'n_gpus': 1,
             'gpu_id': 0,
             'eval_metric': 'auc'}
@@ -26,6 +26,7 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm):
    xgb.train(param, dtrain, args.iterations)
    print ("Time: %s seconds" % (str(time.time() - tmp)))

+    param['silent'] = 1
    param['tree_method'] = cpu_algorithm
    print("Training with '%s'" % param['tree_method'])
    tmp = time.time()
@@ -2,11 +2,13 @@
 * Copyright 2017 XGBoost contributors
 */
#pragma once
+#include <dmlc/logging.h>
+#include <thrust/binary_search.h>
#include <thrust/device_vector.h>
#include <thrust/random.h>
#include <thrust/system/cuda/error.h>
+#include <thrust/system/cuda/execution_policy.h>
#include <thrust/system_error.h>
-#include "nccl.h"
#include <algorithm>
#include <chrono>
#include <ctime>
@@ -15,7 +17,7 @@
#include <sstream>
#include <string>
#include <vector>
+#include "nccl.h"

// Uncomment to enable
// #define DEVICE_TIMER
@@ -121,87 +123,6 @@ inline int get_device_idx(int gpu_id) {
 * Timers
 */

-#define MAX_WARPS 32  // Maximum number of warps to time
-#define MAX_SLOTS 10
-#define TIMER_BLOCKID 0  // Block to time
-struct DeviceTimerGlobal {
-#ifdef DEVICE_TIMER
-  clock_t total_clocks[MAX_SLOTS][MAX_WARPS];
-  int64_t count[MAX_SLOTS][MAX_WARPS];
-#endif
-
-  // Clear device memory. Call at start of kernel.
-  __device__ void Init() {
-#ifdef DEVICE_TIMER
-    if (blockIdx.x == TIMER_BLOCKID && threadIdx.x < MAX_WARPS) {
-      for (int SLOT = 0; SLOT < MAX_SLOTS; SLOT++) {
-        total_clocks[SLOT][threadIdx.x] = 0;
-        count[SLOT][threadIdx.x] = 0;
-      }
-    }
-#endif
-  }
-
-  void HostPrint() {
-#ifdef DEVICE_TIMER
-    DeviceTimerGlobal h_timer;
-    safe_cuda(
-        cudaMemcpyFromSymbol(&h_timer, (*this), sizeof(DeviceTimerGlobal)));
-
-    for (int SLOT = 0; SLOT < MAX_SLOTS; SLOT++) {
-      if (h_timer.count[SLOT][0] == 0) {
-        continue;
-      }
-
-      clock_t sum_clocks = 0;
-      int64_t sum_count = 0;
-
-      for (int WARP = 0; WARP < MAX_WARPS; WARP++) {
-        if (h_timer.count[SLOT][WARP] == 0) {
-          continue;
-        }
-
-        sum_clocks += h_timer.total_clocks[SLOT][WARP];
-        sum_count += h_timer.count[SLOT][WARP];
-      }
-
-      printf("Slot %d: %d clocks per call, called %d times.\n", SLOT,
-             sum_clocks / sum_count, h_timer.count[SLOT][0]);
-    }
-#endif
-  }
-};
-
-struct DeviceTimer {
-#ifdef DEVICE_TIMER
-  clock_t start;
-  int slot;
-  DeviceTimerGlobal &GTimer;
-#endif
-
-#ifdef DEVICE_TIMER
-  __device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot)  // NOLINT
-      : GTimer(GTimer),
-        start(clock()),
-        slot(slot) {}
-#else
-  __device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) {}  // NOLINT
-#endif
-
-  __device__ void End() {
-#ifdef DEVICE_TIMER
-    int warp_id = threadIdx.x / 32;
-    int lane_id = threadIdx.x % 32;
-    if (blockIdx.x == TIMER_BLOCKID && lane_id == 0) {
-      GTimer.count[slot][warp_id] += 1;
-      GTimer.total_clocks[slot][warp_id] += clock() - start;
-    }
-#endif
-  }
-};
-
struct Timer {
  typedef std::chrono::high_resolution_clock ClockT;
@@ -549,23 +470,36 @@ struct CubMemory {
  void *d_temp_storage;
  size_t temp_storage_bytes;

+  // Thrust
+  typedef char value_type;
+
  CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {}

  ~CubMemory() { Free(); }

  void Free() {
-    if (d_temp_storage != NULL) {
+    if (this->IsAllocated()) {
      safe_cuda(cudaFree(d_temp_storage));
    }
  }

-  void LazyAllocate(size_t n_bytes) {
-    if (n_bytes > temp_storage_bytes) {
+  void LazyAllocate(size_t num_bytes) {
+    if (num_bytes > temp_storage_bytes) {
      Free();
-      safe_cuda(cudaMalloc(&d_temp_storage, n_bytes));
-      temp_storage_bytes = n_bytes;
+      safe_cuda(cudaMalloc(&d_temp_storage, num_bytes));
+      temp_storage_bytes = num_bytes;
    }
  }

+  // Thrust
+  char *allocate(std::ptrdiff_t num_bytes) {
+    LazyAllocate(num_bytes);
+    return reinterpret_cast<char *>(d_temp_storage);
+  }
+
+  // Thrust
+  void deallocate(char *ptr, size_t n) {
+    // Do nothing
+  }
+
  bool IsAllocated() { return d_temp_storage != NULL; }
};
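The `value_type`, `allocate`, and `deallocate` members added above are the minimal interface Thrust expects from a custom temporary allocator; they are what allow later hunks to pass a `CubMemory` object to `thrust::cuda::par(...)` so scratch memory is reused across algorithm calls instead of being re-allocated. A standalone sketch of that pattern follows; the `CachedAllocator` type, the sort calls, and the sizes are illustrative assumptions, not part of this commit.

```
// Hedged sketch: a CubMemory-style cached allocator backing a Thrust policy.
// Compile with nvcc; assumes a CUDA-capable device.
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>
#include <cuda_runtime.h>
#include <cstdio>

struct CachedAllocator {       // minimal stand-in for dh::CubMemory's Thrust hooks
  typedef char value_type;
  void *storage = nullptr;
  size_t bytes = 0;
  char *allocate(std::ptrdiff_t n) {
    if (static_cast<size_t>(n) > bytes) {  // grow lazily, keep the block around
      if (storage) cudaFree(storage);
      cudaMalloc(&storage, n);
      bytes = n;
    }
    return reinterpret_cast<char *>(storage);
  }
  void deallocate(char *, size_t) {}       // keep the buffer for the next call
  ~CachedAllocator() { if (storage) cudaFree(storage); }
};

int main() {
  thrust::device_vector<int> keys(1 << 20, 1);
  CachedAllocator alloc;
  // thrust::cuda::par(alloc) routes Thrust's temporary allocations to alloc.
  thrust::sort(thrust::cuda::par(alloc), keys.begin(), keys.end());
  thrust::sort(thrust::cuda::par(alloc), keys.begin(), keys.end());  // reuses buffer
  std::printf("done\n");
  return 0;
}
```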
@@ -591,7 +525,7 @@ void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
  std::cout << "\n";
}

-template <typename T, memory_type MemoryT>
+template <typename T>
void print(const dvec<T> &v, size_t max_items = 10) {
  std::vector<T> h = v.as_vector();
  for (int i = 0; i < std::min(max_items, h.size()); i++) {
@@ -714,4 +648,118 @@ struct BernoulliRng {
    t1234.printElapsed(name); \
  } while (0)

+// Load balancing search
+
+template <typename func_t>
+class LauncherItr {
+ public:
+  int idx;
+  func_t f;
+  XGBOOST_DEVICE LauncherItr() : idx(0) {}
+  XGBOOST_DEVICE LauncherItr(int idx, func_t f) : idx(idx), f(f) {}
+  XGBOOST_DEVICE LauncherItr &operator=(int output) {
+    f(idx, output);
+    return *this;
+  }
+};
+
+template <typename func_t>
+
+/**
+ * \class DiscardLambdaItr
+ *
+ * \brief Thrust compatible iterator type - discards algorithm output and
+ * launches a device lambda with the index of the output and the algorithm
+ * output as arguments.
+ *
+ * \author Rory
+ * \date 7/9/2017
+ */
+
+class DiscardLambdaItr {
+ public:
+  // Required iterator traits
+  typedef DiscardLambdaItr self_type;  ///< My own type
+  typedef ptrdiff_t
+      difference_type;  ///< Type to express the result of subtracting
+                        /// one iterator from another
+  typedef LauncherItr<func_t>
+      value_type;  ///< The type of the element the iterator can point to
+  typedef value_type *pointer;  ///< The type of a pointer to an element the
+                                /// iterator can point to
+  typedef value_type reference;  ///< The type of a reference to an element the
+                                 /// iterator can point to
+  typedef typename thrust::detail::iterator_facade_category<
+      thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
+      reference>::type iterator_category;  ///< The iterator category
+ private:
+  difference_type offset;
+  func_t f;
+
+ public:
+  XGBOOST_DEVICE DiscardLambdaItr(func_t f) : offset(0), f(f) {}
+  XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, func_t f)
+      : offset(offset), f(f) {}
+
+  XGBOOST_DEVICE self_type operator+(const int &b) const {
+    return DiscardLambdaItr(offset + b, f);
+  }
+  XGBOOST_DEVICE self_type operator++() {
+    offset++;
+    return *this;
+  }
+  XGBOOST_DEVICE self_type operator++(int) {
+    self_type retval = *this;
+    offset++;
+    return retval;
+  }
+  XGBOOST_DEVICE self_type &operator+=(const int &b) {
+    offset += b;
+    return *this;
+  }
+  XGBOOST_DEVICE reference operator*() const {
+    return LauncherItr<func_t>(offset, f);
+  }
+
+  XGBOOST_DEVICE reference operator[](int idx) {
+    self_type offset = (*this) + idx;
+    return *offset;
+  }
+};
+
+/**
+ * \fn template <typename func_t, typename segments_t> void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count, thrust::device_ptr<segments_t> segments, int num_segments, func_t f)
+ *
+ * \brief Load balancing search function. Reads a CSR type matrix description
+ * and allows a function to be executed on each element. Search for 'modern GPU
+ * load balancing search' for more information.
+ *
+ * \author Rory
+ * \date 7/9/2017
+ *
+ * \tparam segments_t Type of the segments.
+ * \param device_idx Zero-based index of the device.
+ * \param [in,out] temp_memory Temporary memory allocator.
+ * \param count Number of elements.
+ * \param segments Device pointer to segments.
+ * \param num_segments Number of segments.
+ * \param f Lambda to be executed on matrix elements.
+ */
+
+template <typename func_t, typename segments_t>
+void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count,
+                  thrust::device_ptr<segments_t> segments, int num_segments,
+                  func_t f) {
+  safe_cuda(cudaSetDevice(device_idx));
+  auto counting = thrust::make_counting_iterator(0);
+
+  auto f_wrapper = [=] __device__(int idx, int upper_bound) {
+    f(idx, upper_bound - 1);
+  };
+
+  DiscardLambdaItr<decltype(f_wrapper)> itr(f_wrapper);
+
+  thrust::upper_bound(thrust::cuda::par(*temp_memory), segments,
+                      segments + num_segments, counting, counting + count, itr);
+}
+
} // namespace dh
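In other words, TransformLbs expands a CSR row pointer so the lambda sees both the flat element index and the row that element belongs to. A minimal usage sketch follows; it mirrors the unit test added later in this commit, but the function name `ExampleLbs` and the literal values are illustrative, and compilation assumes nvcc with extended device lambdas enabled.

```
// Hedged usage sketch for dh::TransformLbs (not part of the commit itself).
#include <thrust/device_vector.h>
#include <vector>
#include "device_helpers.cuh"  // assumed relative include path for dh::TransformLbs

void ExampleLbs() {
  // 4 rows with 3, 3, 2 and 2 elements respectively.
  thrust::device_vector<int> row_ptr = std::vector<int>{0, 3, 6, 8, 10};
  thrust::device_vector<int> row_of_element(10, -1);
  auto d_out = row_of_element.data();

  dh::CubMemory temp_memory;
  dh::TransformLbs(0, &temp_memory, 10, row_ptr.data(), 4,
                   [=] __device__(int idx, int ridx) {
                     d_out[idx] = ridx;  // element idx belongs to row ridx
                   });
  // row_of_element now holds {0,0,0, 1,1,1, 2,2, 3,3}.
}
```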
@@ -1,36 +1,45 @@
/*!
 * Copyright 2017 Rory mitchell
 */
-#include <cub/cub.cuh>
#include <thrust/binary_search.h>
#include <thrust/count.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
+#include <cub/cub.cuh>
#include <algorithm>
#include <functional>
#include <future>
#include <numeric>
#include "common.cuh"
#include "device_helpers.cuh"
+#include "dmlc/timer.h"
#include "gpu_hist_builder.cuh"

namespace xgboost {
namespace tree {

void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
-                      bst_uint begin, bst_uint end) {
+                      bst_uint element_begin, bst_uint element_end,
+                      bst_uint row_begin, bst_uint row_end, int n_bins) {
  dh::safe_cuda(cudaSetDevice(device_idx));
-  CHECK_EQ(gidx.size(), end - begin) << "gidx must be externally allocated";
-  CHECK_EQ(ridx.size(), end - begin) << "ridx must be externally allocated";
+  CHECK(gidx_buffer.size()) << "gidx_buffer must be externally allocated";
+  CHECK_EQ(row_ptr.size(), (row_end - row_begin) + 1)
+      << "row_ptr must be externally allocated";

-  thrust::copy(gmat.index.data() + begin, gmat.index.data() + end, gidx.tbegin());
-  thrust::device_vector<int> row_ptr = gmat.row_ptr;
+  common::CompressedBufferWriter cbw(n_bins);
+  std::vector<common::compressed_byte_t> host_buffer(gidx_buffer.size());
+  cbw.Write(host_buffer.data(), gmat.index.begin() + element_begin,
+            gmat.index.begin() + element_end);
+  gidx_buffer = host_buffer;
+  gidx = common::CompressedIterator<int>(gidx_buffer.data(), n_bins);

-  auto counting = thrust::make_counting_iterator(begin);
-  thrust::upper_bound(row_ptr.begin(), row_ptr.end(), counting,
-                      counting + gidx.size(), ridx.tbegin());
-  thrust::transform(ridx.tbegin(), ridx.tend(), ridx.tbegin(),
-                    [=] __device__(int val) { return val - 1; });
+  // row_ptr
+  thrust::copy(gmat.row_ptr.data() + row_begin,
+               gmat.row_ptr.data() + row_end + 1, row_ptr.tbegin());
+  // normalise row_ptr
+  bst_uint start = gmat.row_ptr[row_begin];
+  thrust::transform(row_ptr.tbegin(), row_ptr.tend(), row_ptr.tbegin(),
+                    [=] __device__(int val) { return val - start; });
}

void DeviceHist::Init(int n_bins_in) {
@@ -59,10 +68,10 @@ HistBuilder::HistBuilder(bst_gpair* ptr, int n_bins)
__device__ void HistBuilder::Add(bst_gpair gpair, int gidx, int nidx) const {
  int hist_idx = nidx * n_bins + gidx;
  atomicAdd(&(d_hist[hist_idx].grad), gpair.grad);  // OPTMARK: This and below
                                                    // line lead to about 3X
                                                    // slowdown due to memory
                                                    // dependency and access
                                                    // pattern issues.
  atomicAdd(&(d_hist[hist_idx].hess), gpair.hess);
}
@@ -170,7 +179,6 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
  // process)
  }
-

  CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column "
                                  "block. Try setting 'tree_method' "
                                  "parameter to 'exact'";
@@ -219,6 +227,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
  // ba.allocate(master_device, );

  // allocate vectors across all devices
+  temp_memory.resize(n_devices);
  hist_vec.resize(n_devices);
  nodes.resize(n_devices);
  nodes_temp.resize(n_devices);
@@ -269,18 +278,21 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
      h_feature_segments.size(),  // constant and same on all devices
      &prediction_cache[d_idx], num_rows_segment, &position[d_idx],
      num_rows_segment, &position_tmp[d_idx], num_rows_segment,
-     &device_gpair[d_idx], num_rows_segment, &device_matrix[d_idx].gidx,
-     num_elements_segment,  // constant and same on all devices
-     &device_matrix[d_idx].ridx,
-     num_elements_segment,  // constant and same on all devices
+     &device_gpair[d_idx], num_rows_segment,
+     &device_matrix[d_idx].gidx_buffer,
+     common::CompressedBufferWriter::CalculateBufferSize(
+         num_elements_segment,
+         n_bins),  // constant and same on all devices
+     &device_matrix[d_idx].row_ptr, num_rows_segment + 1,
      &gidx_feature_map[d_idx], n_bins,  // constant and same on all devices
      &gidx_fvalue_map[d_idx],
      hmat_.cut.size());  // constant and same on all devices

  // Copy Host to Device (assumes comes after ba.allocate that sets device)
-  device_matrix[d_idx].Init(device_idx, gmat_,
-                            device_element_segments[d_idx],
-                            device_element_segments[d_idx + 1]);
+  device_matrix[d_idx].Init(
+      device_idx, gmat_, device_element_segments[d_idx],
+      device_element_segments[d_idx + 1], device_row_segments[d_idx],
+      device_row_segments[d_idx + 1], n_bins);
  gidx_feature_map[d_idx] = h_gidx_feature_map;
  gidx_fvalue_map[d_idx] = hmat_.cut;
  feature_segments[d_idx] = h_feature_segments;
@@ -338,39 +350,41 @@ void GPUHistBuilder::BuildHist(int depth) {
    size_t begin = device_element_segments[d_idx];
    size_t end = device_element_segments[d_idx + 1];
    size_t row_begin = device_row_segments[d_idx];
+    size_t row_end = device_row_segments[d_idx + 1];

-    auto d_ridx = device_matrix[d_idx].ridx.data();
-    auto d_gidx = device_matrix[d_idx].gidx.data();
+    auto d_gidx = device_matrix[d_idx].gidx;
+    auto d_row_ptr = device_matrix[d_idx].row_ptr.tbegin();
    auto d_position = position[d_idx].data();
    auto d_gpair = device_gpair[d_idx].data();
    auto d_left_child_smallest = left_child_smallest[d_idx].data();
    auto hist_builder = hist_vec[d_idx].GetBuilder();

-    dh::launch_n(device_idx, end - begin, [=] __device__(int local_idx) {
-      int ridx = d_ridx[local_idx];  // OPTMARK: latency
-      int nidx = d_position[ridx - row_begin];  // OPTMARK: latency
-      if (!is_active(nidx, depth)) return;
-
-      // Only increment smallest node
-      bool is_smallest =
-          (d_left_child_smallest[parent_nidx(nidx)] && is_left_child(nidx)) ||
-          (!d_left_child_smallest[parent_nidx(nidx)] && !is_left_child(nidx));
-      if (!is_smallest && depth > 0) return;
-
-      int gidx = d_gidx[local_idx];
-      bst_gpair gpair = d_gpair[ridx - row_begin];
-
-      hist_builder.Add(gpair, gidx, nidx);  // OPTMARK: This is slow, could use
-                                            // shared memory or cache results
-                                            // instead of writing to global
-                                            // memory every time in atomic way.
-    });
+    dh::TransformLbs(
+        device_idx, &temp_memory[d_idx], end - begin, d_row_ptr,
+        row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
+          int nidx = d_position[local_ridx];  // OPTMARK: latency
+          if (!is_active(nidx, depth)) return;
+
+          // Only increment smallest node
+          bool is_smallest = (d_left_child_smallest[parent_nidx(nidx)] &&
+                              is_left_child(nidx)) ||
+                             (!d_left_child_smallest[parent_nidx(nidx)] &&
+                              !is_left_child(nidx));
+          if (!is_smallest && depth > 0) return;
+
+          int gidx = d_gidx[local_idx];
+          bst_gpair gpair = d_gpair[local_ridx];
+
+          hist_builder.Add(gpair, gidx,
+                           nidx);  // OPTMARK: This is slow, could use
+                                   // shared memory or cache results
+                                   // instead of writing to global
+                                   // memory every time in atomic way.
+        });
  }

-  // dh::safe_cuda(cudaDeviceSynchronize());
  dh::synchronize_n_devices(n_devices, dList);

  // time.printElapsed("Add Time");

  // (in-place) reduce each element of histogram (for only current level) across
  // multiple gpus
@@ -393,7 +407,7 @@ void GPUHistBuilder::BuildHist(int depth) {
    dh::safe_cuda(cudaSetDevice(device_idx));
    dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx])));
  }
  // if no NCCL, then presume only 1 GPU, then already correct

  // time.printElapsed("Reduce-Add Time");
@@ -572,15 +586,15 @@ __global__ void find_split_kernel(
    left_child_smallest = &d_left_child_smallest_temp[blockIdx.x];
  }

-  *Nodeleft = Node(
-      split.left_sum,
+  *Nodeleft =
+      Node(split.left_sum,
           CalcGain(gpu_param, split.left_sum.grad, split.left_sum.hess),
           CalcWeight(gpu_param, split.left_sum.grad, split.left_sum.hess));

-  *Noderight = Node(
-      split.right_sum,
+  *Noderight =
+      Node(split.right_sum,
           CalcGain(gpu_param, split.right_sum.grad, split.right_sum.hess),
           CalcWeight(gpu_param, split.right_sum.grad, split.right_sum.hess));

  // Record smallest node
  if (split.left_sum.hess <= split.right_sum.hess) {
@@ -650,9 +664,9 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
        feature_segments[d_idx].data(), depth, (info->num_col),
        (hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_temp[d_idx].data(),
        nodes_child_temp[d_idx].data(), nodes_offset_device,
-       fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
-       left_child_smallest_temp[d_idx].data(), colsample,
-       feature_flags[d_idx].data());
+       fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(),
+       GPUTrainingParam(param), left_child_smallest_temp[d_idx].data(),
+       colsample, feature_flags[d_idx].data());
  }

  // nccl only on devices that did split
@@ -747,7 +761,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
        feature_segments[d_idx].data(), depth, (info->num_col),
        (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
        nodes_offset_device, fidx_min_map[d_idx].data(),
        gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
        left_child_smallest[d_idx].data(), colsample,
        feature_flags[d_idx].data());
@@ -800,7 +814,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
        feature_segments[d_idx].data(), depth, (info->num_col),
        (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
        nodes_offset_device, fidx_min_map[d_idx].data(),
        gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
        left_child_smallest[d_idx].data(), colsample,
        feature_flags[d_idx].data());
  }
@@ -811,57 +825,23 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
}

void GPUHistBuilder::InitFirstNode(const std::vector<bst_gpair>& gpair) {
-#ifdef _WIN32
-  // Visual studio complains about C:/Program Files (x86)/Microsoft Visual
-  // Studio 14.0/VC/bin/../../VC/INCLUDE\utility(445): error : static assertion
-  // failed with "tuple index out of bounds"
-  // and C:/Program Files (x86)/Microsoft Visual Studio
-  // 14.0/VC/bin/../../VC/INCLUDE\future(1888): error : no instance of function
-  // template "std::_Invoke_stored" matches the argument list
-  std::vector<bst_gpair> future_results(n_devices);
+  // Perform asynchronous reduction on each gpu
+  std::vector<bst_gpair> device_sums(n_devices);
+#pragma omp parallel for num_threads(n_devices)
  for (int d_idx = 0; d_idx < n_devices; d_idx++) {
    int device_idx = dList[d_idx];
+    dh::safe_cuda(cudaSetDevice(device_idx));
    auto begin = device_gpair[d_idx].tbegin();
    auto end = device_gpair[d_idx].tend();
    bst_gpair init = bst_gpair();
    auto binary_op = thrust::plus<bst_gpair>();
-    dh::safe_cuda(cudaSetDevice(device_idx));
-    future_results[d_idx] = thrust::reduce(begin, end, init, binary_op);
+    device_sums[d_idx] = thrust::reduce(begin, end, init, binary_op);
  }

-  // sum over devices on host (with blocking get())
  bst_gpair sum = bst_gpair();
  for (int d_idx = 0; d_idx < n_devices; d_idx++) {
-    int device_idx = dList[d_idx];
-    sum += future_results[d_idx];
+    sum += device_sums[d_idx];
  }
-#else
-  // asynch reduce per device
-
-  std::vector<std::future<bst_gpair>> future_results(n_devices);
-  for (int d_idx = 0; d_idx < n_devices; d_idx++) {
-    // std::async captures the algorithm parameters by value
-    // use std::launch::async to ensure the creation of a new thread
-    future_results[d_idx] = std::async(std::launch::async, [=] {
-      int device_idx = dList[d_idx];
-      dh::safe_cuda(cudaSetDevice(device_idx));
-      auto begin = device_gpair[d_idx].tbegin();
-      auto end = device_gpair[d_idx].tend();
-      bst_gpair init = bst_gpair();
-      auto binary_op = thrust::plus<bst_gpair>();
-      return thrust::reduce(begin, end, init, binary_op);
-    });
-  }
-
-  // sum over devices on host (with blocking get())
-  bst_gpair sum = bst_gpair();
-  for (int d_idx = 0; d_idx < n_devices; d_idx++) {
-    int device_idx = dList[d_idx];
-    sum += future_results[d_idx].get();
-  }
-#endif

  // Setup first node so all devices have same first node (here done same on all
  // devices, or could have done one device and Bcast if worried about exact
|
|||||||
|
|
||||||
dh::launch_n(device_idx, 1, [=] __device__(int idx) {
|
dh::launch_n(device_idx, 1, [=] __device__(int idx) {
|
||||||
bst_gpair sum_gradients = sum;
|
bst_gpair sum_gradients = sum;
|
||||||
d_nodes[idx] = Node(
|
d_nodes[idx] =
|
||||||
sum_gradients,
|
Node(sum_gradients,
|
||||||
CalcGain(gpu_param, sum_gradients.grad, sum_gradients.hess),
|
CalcGain(gpu_param, sum_gradients.grad, sum_gradients.hess),
|
||||||
CalcWeight(gpu_param, sum_gradients.grad,
|
CalcWeight(gpu_param, sum_gradients.grad, sum_gradients.hess));
|
||||||
sum_gradients.hess));
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
// synch all devices to host before moving on (No, can avoid because BuildHist
|
// synch all devices to host before moving on (No, can avoid because BuildHist
|
||||||
@@ -901,7 +880,7 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
    auto d_position = position[d_idx].data();
    Node* d_nodes = nodes[d_idx].data();
    auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data();
-    auto d_gidx = device_matrix[d_idx].gidx.data();
+    auto d_gidx = device_matrix[d_idx].gidx;
    int n_columns = info->num_col;
    size_t begin = device_row_segments[d_idx];
    size_t end = device_row_segments[d_idx + 1];
@@ -941,8 +920,8 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
    Node* d_nodes = nodes[d_idx].data();
    auto d_gidx_feature_map = gidx_feature_map[d_idx].data();
    auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data();
-    auto d_gidx = device_matrix[d_idx].gidx.data();
-    auto d_ridx = device_matrix[d_idx].ridx.data();
+    auto d_gidx = device_matrix[d_idx].gidx;
+    auto d_row_ptr = device_matrix[d_idx].row_ptr.tbegin();

    size_t row_begin = device_row_segments[d_idx];
    size_t row_end = device_row_segments[d_idx + 1];
@@ -973,10 +952,11 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
    // Update node based on fvalue where exists
    // OPTMARK: This kernel is very inefficient for both compute and memory,
    // dominated by memory dependency / access patterns
-    dh::launch_n(
-        device_idx, element_end - element_begin, [=] __device__(int local_idx) {
-          int ridx = d_ridx[local_idx];
-          int pos = d_position[ridx - row_begin];
+    dh::TransformLbs(
+        device_idx, &temp_memory[d_idx], element_end - element_begin, d_row_ptr,
+        row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
+          int pos = d_position[local_ridx];
          if (!is_active(pos, depth)) {
            return;
          }
@@ -997,9 +977,9 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
          float fvalue = d_gidx_fvalue_map[gidx];

          if (fvalue <= node.split.fvalue) {
-            d_position_tmp[ridx - row_begin] = left_child_nidx(pos);
+            d_position_tmp[local_ridx] = left_child_nidx(pos);
          } else {
-            d_position_tmp[ridx - row_begin] = right_child_nidx(pos);
+            d_position_tmp[local_ridx] = right_child_nidx(pos);
          }
        }
      });
@@ -1026,10 +1006,6 @@ void GPUHistBuilder::ColSampleLevel() {
    h_feature_flags[fidx] = 1;
  }

-  // copy from Host to Device for all devices
-  // for(auto &f:feature_flags){ // this doesn't set device as should
-  //  f = h_feature_flags;
-  // }
  for (int d_idx = 0; d_idx < n_devices; d_idx++) {
    int device_idx = dList[d_idx];
    dh::safe_cuda(cudaSetDevice(device_idx));
@@ -8,26 +8,20 @@
#include <vector>
#include "../../src/common/hist_util.h"
#include "../../src/tree/param.h"
+#include "../../src/common/compressed_iterator.h"
#include "device_helpers.cuh"
#include "types.cuh"

-#ifndef NCCL
-#define NCCL 1
-#endif
-
-#if (NCCL)
#include "nccl.h"
-#endif

namespace xgboost {

namespace tree {

struct DeviceGMat {
-  dh::dvec<int> gidx;
-  dh::dvec<int> ridx;
+  dh::dvec<common::compressed_byte_t> gidx_buffer;
+  common::CompressedIterator<int> gidx;
+  dh::dvec<int> row_ptr;
  void Init(int device_idx, const common::GHistIndexMatrix &gmat,
-            bst_uint begin, bst_uint end);
+            bst_uint begin, bst_uint end, bst_uint row_begin, bst_uint row_end,
+            int n_bins);
};

struct HistBuilder {
@@ -95,7 +89,6 @@ class GPUHistBuilder {
  dh::bulk_allocator<dh::memory_type::DEVICE> ba;
  // dh::bulk_allocator<dh::memory_type::DEVICE_MANAGED> ba;  // can't be used
  // with NCCL
-  dh::CubMemory cub_mem;

  std::vector<int> feature_set_tree;
  std::vector<int> feature_set_level;
@@ -108,6 +101,7 @@ class GPUHistBuilder {
  std::vector<int> device_row_segments;
  std::vector<int> device_element_segments;

+  std::vector<dh::CubMemory> temp_memory;
  std::vector<DeviceHist> hist_vec;
  std::vector<dh::dvec<Node>> nodes;
  std::vector<dh::dvec<Node>> nodes_temp;
@@ -126,10 +120,8 @@ class GPUHistBuilder {
  std::vector<dh::dvec<float>> gidx_fvalue_map;

  std::vector<cudaStream_t *> streams;
-#if (NCCL)
  std::vector<ncclComm_t> comms;
  std::vector<std::vector<ncclComm_t>> find_split_comms;
-#endif
};
}  // namespace tree
}  // namespace xgboost
plugin/updater_gpu/test/cpp/test_device_helpers.cu (new file, 28 lines)
@@ -0,0 +1,28 @@
+
+/*!
+ * Copyright 2017 XGBoost contributors
+ */
+#include <thrust/device_vector.h>
+#include <xgboost/base.h>
+#include "../../src/device_helpers.cuh"
+#include "gtest/gtest.h"
+
+static const std::vector<int> gidx = {0, 2, 5, 1, 3, 6, 0, 2, 0, 7};
+static const std::vector<int> row_ptr = {0, 3, 6, 8, 10};
+static const std::vector<int> lbs_seg_output = {0, 0, 0, 1, 1, 1, 2, 2, 3, 3};
+
+thrust::device_vector<int> test_lbs() {
+  thrust::device_vector<int> device_gidx = gidx;
+  thrust::device_vector<int> device_row_ptr = row_ptr;
+  thrust::device_vector<int> device_output_row(gidx.size(), 0);
+  auto d_output_row = device_output_row.data();
+  dh::CubMemory temp_memory;
+  dh::TransformLbs(
+      0, &temp_memory, gidx.size(), device_row_ptr.data(), row_ptr.size() - 1,
+      [=] __device__(int idx, int ridx) { d_output_row[idx] = ridx; });
+
+  dh::safe_cuda(cudaDeviceSynchronize());
+  return device_output_row;
+}
+
+TEST(lbs, Test) { ASSERT_TRUE(test_lbs() == lbs_seg_output); }
src/common/compressed_iterator.h (new file, 199 lines)
@@ -0,0 +1,199 @@
+/*!
+ * Copyright 2017 by Contributors
+ * \file compressed_iterator.h
+ */
+#pragma once
+#include <xgboost/base.h>
+#include <cmath>
+#include <cstddef>
+#include "dmlc/logging.h"
+
+namespace xgboost {
+namespace common {
+
+typedef unsigned char compressed_byte_t;
+
+namespace detail {
+inline void SetBit(compressed_byte_t *byte, int bit_idx) {
+  *byte |= 1 << bit_idx;
+}
+template <typename T>
+inline T CheckBit(const T &byte, int bit_idx) {
+  return byte & (1 << bit_idx);
+}
+inline void ClearBit(compressed_byte_t *byte, int bit_idx) {
+  *byte &= ~(1 << bit_idx);
+}
+static const int padding = 4;  // Assign padding so we can read slightly off
+                               // the beginning of the array
+
+// The number of bits required to represent a given unsigned range
+static int SymbolBits(int num_symbols) {
+  return std::ceil(std::log2(num_symbols));
+}
+}  // namespace detail
+
+/**
+ * \class CompressedBufferWriter
+ *
+ * \brief Writes bit compressed symbols to a memory buffer. Use
+ * CompressedIterator to read symbols back from buffer. Currently limited to a
+ * maximum symbol size of 28 bits.
+ *
+ * \author Rory
+ * \date 7/9/2017
+ */
+
+class CompressedBufferWriter {
+ private:
+  int symbol_bits_;
+  size_t offset_;
+
+ public:
+  explicit CompressedBufferWriter(int num_symbols) : offset_(0) {
+    symbol_bits_ = detail::SymbolBits(num_symbols);
+  }
+
+  /**
+   * \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
+   * num_elements, int num_symbols)
+   *
+   * \brief Calculates number of bytes required for a given number of elements
+   * and a symbol range.
+   *
+   * \author Rory
+   * \date 7/9/2017
+   *
+   * \param num_elements Number of elements.
+   * \param num_symbols Max number of symbols (alphabet size)
+   *
+   * \return The calculated buffer size.
+   */
+
+  static size_t CalculateBufferSize(int num_elements, int num_symbols) {
+    const int bits_per_byte = 8;
+    int compressed_size = std::ceil(
+        static_cast<double>(detail::SymbolBits(num_symbols) * num_elements) /
+        bits_per_byte);
+    return compressed_size + detail::padding;
+  }
+
+  template <typename T>
+  void WriteSymbol(compressed_byte_t *buffer, T symbol, size_t offset) {
+    const int bits_per_byte = 8;
+
+    for (int i = 0; i < symbol_bits_; i++) {
+      size_t byte_idx = ((offset + 1) * symbol_bits_ - (i + 1)) / bits_per_byte;
+      byte_idx += detail::padding;
+      int bit_idx =
+          ((bits_per_byte + i) - ((offset + 1) * symbol_bits_)) % bits_per_byte;
+
+      if (detail::CheckBit(symbol, i)) {
+        detail::SetBit(&buffer[byte_idx], bit_idx);
+      } else {
+        detail::ClearBit(&buffer[byte_idx], bit_idx);
+      }
+    }
+  }
+  template <typename iter_t>
+  void Write(compressed_byte_t *buffer, iter_t input_begin, iter_t input_end) {
+    uint64_t tmp = 0;
+    int stored_bits = 0;
+    const int max_stored_bits = 64 - symbol_bits_;
+    int buffer_position = detail::padding;
+    const int num_symbols = input_end - input_begin;
+    for (int i = 0; i < num_symbols; i++) {
+      typename std::iterator_traits<iter_t>::value_type symbol = input_begin[i];
+      if (stored_bits > max_stored_bits) {
+        // Eject only full bytes
+        int tmp_bytes = stored_bits / 8;
+        for (int j = 0; j < tmp_bytes; j++) {
+          buffer[buffer_position] = tmp >> (stored_bits - (j + 1) * 8);
+          buffer_position++;
+        }
+        stored_bits -= tmp_bytes * 8;
+        tmp &= (1 << stored_bits) - 1;
+      }
+      // Store symbol
+      tmp <<= symbol_bits_;
+      tmp |= symbol;
+      stored_bits += symbol_bits_;
+    }
+
+    // Eject all bytes
+    int tmp_bytes = std::ceil(static_cast<float>(stored_bits) / 8);
+    for (int j = 0; j < tmp_bytes; j++) {
+      int shift_bits = stored_bits - (j + 1) * 8;
+      if (shift_bits >= 0) {
+        buffer[buffer_position] = tmp >> shift_bits;
+      } else {
+        buffer[buffer_position] = tmp << std::abs(shift_bits);
+      }
+      buffer_position++;
+    }
+  }
+};
+
+template <typename T>
+
+/**
+ * \class CompressedIterator
+ *
+ * \brief Read symbols from a bit compressed memory buffer. Usable on device
+ * and host.
+ *
+ * \author Rory
+ * \date 7/9/2017
+ */
+
+class CompressedIterator {
+ public:
+  typedef CompressedIterator<T> self_type;  ///< My own type
+  typedef ptrdiff_t
+      difference_type;  ///< Type to express the result of subtracting
+                        /// one iterator from another
+  typedef T value_type;  ///< The type of the element the iterator can point to
+  typedef value_type *pointer;  ///< The type of a pointer to an element the
+                                /// iterator can point to
+  typedef value_type reference;  ///< The type of a reference to an element the
+                                 /// iterator can point to
+ private:
+  compressed_byte_t *buffer_;
+  int symbol_bits_;
+  size_t offset_;
+
+ public:
+  CompressedIterator() : buffer_(nullptr), symbol_bits_(0), offset_(0) {}
+  CompressedIterator(compressed_byte_t *buffer, int num_symbols)
+      : buffer_(buffer), offset_(0) {
+    symbol_bits_ = detail::SymbolBits(num_symbols);
+  }
+
+  XGBOOST_DEVICE reference operator*() const {
+    const int bits_per_byte = 8;
+    size_t start_bit_idx = ((offset_ + 1) * symbol_bits_ - 1);
+    size_t start_byte_idx = start_bit_idx / bits_per_byte;
+    start_byte_idx += detail::padding;
+
+    // Read 5 bytes - the maximum we will need
+    uint64_t tmp = static_cast<uint64_t>(buffer_[start_byte_idx - 4]) << 32 |
+                   static_cast<uint64_t>(buffer_[start_byte_idx - 3]) << 24 |
+                   static_cast<uint64_t>(buffer_[start_byte_idx - 2]) << 16 |
+                   static_cast<uint64_t>(buffer_[start_byte_idx - 1]) << 8 |
+                   buffer_[start_byte_idx];
+    int bit_shift =
+        (bits_per_byte - ((offset_ + 1) * symbol_bits_)) % bits_per_byte;
+    tmp >>= bit_shift;
+    // Mask off unneeded bits
+    uint64_t mask = (1 << symbol_bits_) - 1;
+    return static_cast<T>(tmp & mask);
+  }
+
+  XGBOOST_DEVICE reference operator[](int idx) const {
+    self_type offset = (*this);
+    offset.offset_ += idx;
+    return *offset;
+  }
+};
+}  // namespace common
+}  // namespace xgboost
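A quick host-only sketch of the write/read round trip described above (the 16-symbol alphabet and the literal values are made up for illustration; the randomized test that actually ships with this commit follows below). With 16 symbols, SymbolBits gives 4 bits per symbol, so 10 symbols pack into ceil(40 / 8) = 5 bytes plus the 4 padding bytes:

```
// Hedged sketch, not part of the commit. Compiles as plain C++ or CUDA host code.
#include <vector>
#include <cassert>
#include "compressed_iterator.h"  // assumed include path to src/common/compressed_iterator.h

int main() {
  using namespace xgboost::common;
  const int num_symbols = 16;  // SymbolBits(16) == 4 bits per symbol
  std::vector<int> input = {0, 15, 7, 3, 9, 1, 2, 14, 5, 8};

  std::vector<compressed_byte_t> buffer(
      CompressedBufferWriter::CalculateBufferSize(input.size(), num_symbols));
  assert(buffer.size() == 5 + 4);  // 10 * 4 bits -> 5 bytes, plus 4 padding bytes

  CompressedBufferWriter cbw(num_symbols);
  cbw.Write(buffer.data(), input.begin(), input.end());

  CompressedIterator<int> it(buffer.data(), num_symbols);
  for (size_t i = 0; i < input.size(); ++i) assert(it[i] == input[i]);
  return 0;
}
```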
tests/cpp/common/test_compressed_iterator.cc (new file, 54 lines)
@@ -0,0 +1,54 @@
+#include "../../../src/common/compressed_iterator.h"
+#include "gtest/gtest.h"
+
+namespace xgboost {
+namespace common {
+TEST(CompressedIterator, Test) {
+  ASSERT_TRUE(detail::SymbolBits(256) == 8);
+  ASSERT_TRUE(detail::SymbolBits(150) == 8);
+  std::vector<int> test_cases = {3, 426, 21, 64, 256, 100000, INT32_MAX};
+  int num_elements = 1000;
+  int repetitions = 1000;
+  srand(9);
+
+  for (auto alphabet_size : test_cases) {
+    for (int i = 0; i < repetitions; i++) {
+      std::vector<int> input(num_elements);
+      std::generate(input.begin(), input.end(),
+                    [=]() { return rand() % alphabet_size; });
+      CompressedBufferWriter cbw(alphabet_size);
+
+      // Test write entire array
+      std::vector<unsigned char> buffer(
+          CompressedBufferWriter::CalculateBufferSize(input.size(),
+                                                      alphabet_size));
+
+      cbw.Write(buffer.data(), input.begin(), input.end());
+
+      CompressedIterator<int> ci(buffer.data(), alphabet_size);
+      std::vector<int> output(input.size());
+      for (int i = 0; i < input.size(); i++) {
+        output[i] = ci[i];
+      }
+
+      ASSERT_TRUE(input == output);
+
+      // Test write Symbol
+      std::vector<unsigned char> buffer2(
+          CompressedBufferWriter::CalculateBufferSize(input.size(),
+                                                      alphabet_size));
+      for (int i = 0; i < input.size(); i++) {
+        cbw.WriteSymbol(buffer2.data(), input[i], i);
+      }
+      CompressedIterator<int> ci2(buffer2.data(), alphabet_size);
+      std::vector<int> output2(input.size());
+      for (int i = 0; i < input.size(); i++) {
+        output2[i] = ci2[i];
+      }
+      ASSERT_TRUE(input == output2);
+    }
+  }
+}
+
+}  // namespace common
+}  // namespace xgboost