[GPU Plugin] Fast histogram speed improvements. Updated benchmarks. (#2258)
This commit is contained in:
committed by
Tianqi Chen
parent
98ea461532
commit
6bf968efe6
@@ -62,6 +62,25 @@ __host__ __device__ inline int n_nodes(int depth) {
|
||||
// Number of nodes on a single level of the tree: level `depth` holds
// 2^depth nodes (depth 0 holds one node).
__host__ __device__ inline int n_nodes_level(int depth) {
  const int level_width = 1 << depth;
  return level_width;
}
|
||||
|
||||
// Whether a node is currently being processed at the current depth.
// A node counts as active when its index is not below n_nodes(depth - 1).
__host__ __device__ inline bool is_active(int nidx, int depth) {
  const int threshold = n_nodes(depth - 1);
  return !(nidx < threshold);
}
|
||||
|
||||
// Index of the parent of node `nidx`, i.e. (nidx - 1) / 2 using integer
// division (the inverse of the 2 * p + 1 / 2 * p + 2 child numbering).
__host__ __device__ inline int parent_nidx(int nidx) {
  const int shifted = nidx - 1;
  return shifted / 2;
}
|
||||
|
||||
// Index of the left child of node `nidx`: 2 * nidx + 1.
__host__ __device__ inline int left_child_nidx(int nidx) {
  const int doubled = nidx * 2;
  return doubled + 1;
}
|
||||
|
||||
// Index of the right child of node `nidx`: 2 * nidx + 2.
__host__ __device__ inline int right_child_nidx(int nidx) {
  const int doubled = nidx * 2;
  return doubled + 2;
}
|
||||
|
||||
// True when `nidx` is a left child, i.e. when nidx % 2 equals 1.
// The remainder test is kept (rather than a bit mask) so that negative
// indices, whose remainder is 0 or -1, are reported as "not a left child".
__host__ __device__ inline bool is_left_child(int nidx) {
  const int remainder = nidx % 2;
  return remainder == 1;
}
|
||||
|
||||
enum NodeType {
|
||||
NODE = 0,
|
||||
LEAF = 1,
|
||||
@@ -96,7 +115,7 @@ inline void dense2sparse_tree(RegTree* p_tree,
|
||||
thrust::device_ptr<Node> nodes_begin,
|
||||
thrust::device_ptr<Node> nodes_end,
|
||||
const TrainParam& param) {
|
||||
RegTree & tree = *p_tree;
|
||||
RegTree& tree = *p_tree;
|
||||
thrust::host_vector<Node> h_nodes(nodes_begin, nodes_end);
|
||||
std::vector<NodeType> node_flags(h_nodes.size(), UNUSED);
|
||||
flag_nodes(h_nodes, &node_flags, 0, NODE);
|
||||
|
||||
@@ -9,15 +9,11 @@
|
||||
#include <thrust/system/cuda/error.h>
|
||||
#include <thrust/system_error.h>
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <ctime>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "cusparse_v2.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
// Uncomment to enable
|
||||
// #define DEVICE_TIMER
|
||||
@@ -43,20 +39,6 @@ inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file,
|
||||
|
||||
return code;
|
||||
}
|
||||
// Wraps a cuSPARSE call so that a failing status is reported together with
// the file and line of the call site.
#define safe_cusparse(ans) throw_on_cusparse_error((ans), __FILE__, __LINE__)

// Checks the status returned by a cuSPARSE call.
//
//   status  status code to inspect
//   file    source file of the call site (normally __FILE__)
//   line    source line of the call site (normally __LINE__)
//
// Returns `status` unchanged when it is CUSPARSE_STATUS_SUCCESS.
// Throws a std::string of the form "cusparse error: <file>(<line>)" otherwise.
inline cusparseStatus_t throw_on_cusparse_error(cusparseStatus_t status,
                                                const char *file, int line) {
  if (status != CUSPARSE_STATUS_SUCCESS) {
    std::stringstream ss;
    ss << "cusparse error: " << file << "(" << line << ")";
    // Take the whole buffer. Extracting with `ss >> text` stops at the first
    // whitespace and would reduce the message to just "cusparse".
    throw ss.str();
  }

  return status;
}
|
||||
|
||||
#define gpuErrchk(ans) \
|
||||
{ gpuAssert((ans), __FILE__, __LINE__); }
|
||||
@@ -153,40 +135,18 @@ struct DeviceTimer {
|
||||
};
|
||||
|
||||
struct Timer {
|
||||
volatile double start;
|
||||
typedef std::chrono::high_resolution_clock ClockT;
|
||||
|
||||
typedef std::chrono::high_resolution_clock::time_point TimePointT;
|
||||
TimePointT start;
|
||||
Timer() { reset(); }
|
||||
|
||||
double seconds_now() {
|
||||
#ifdef _WIN32
|
||||
static LARGE_INTEGER s_frequency;
|
||||
QueryPerformanceFrequency(&s_frequency);
|
||||
LARGE_INTEGER now;
|
||||
QueryPerformanceCounter(&now);
|
||||
return static_cast<double>(now.QuadPart) / s_frequency.QuadPart;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
void reset() {
|
||||
#ifdef _WIN32
|
||||
_ReadWriteBarrier();
|
||||
start = seconds_now();
|
||||
#endif
|
||||
}
|
||||
double elapsed() {
|
||||
#ifdef _WIN32
|
||||
_ReadWriteBarrier();
|
||||
return seconds_now() - start;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
void reset() { start = ClockT::now(); }
|
||||
int64_t elapsed() const { return (ClockT::now() - start).count(); }
|
||||
void printElapsed(std::string label) {
|
||||
#ifdef TIMERS
|
||||
safe_cuda(cudaDeviceSynchronize());
|
||||
printf("%s:\t %1.4fs\n", label.c_str(), elapsed());
|
||||
#endif
|
||||
printf("%s:\t %lld\n", label.c_str(), elapsed());
|
||||
reset();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
/*!
|
||||
* Copyright 2017 Rory mitchell
|
||||
*/
|
||||
*/
|
||||
#include <thrust/binary_search.h>
|
||||
#include <thrust/count.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include "common.cuh"
|
||||
#include "device_helpers.cuh"
|
||||
#include "gpu_hist_builder.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
void DeviceGMat::Init(const common::GHistIndexMatrix& gmat) {
|
||||
CHECK_EQ(gidx.size(), gmat.index.size())
|
||||
<< "gidx must be externally allocated";
|
||||
@@ -61,15 +62,19 @@ __device__ gpu_gpair HistBuilder::Get(int gidx, int nidx) const {
|
||||
return d_hist[nidx * n_bins + gidx];
|
||||
}
|
||||
|
||||
GPUHistBuilder::GPUHistBuilder() {}
|
||||
GPUHistBuilder::GPUHistBuilder()
|
||||
: initialised(false),
|
||||
is_dense(false),
|
||||
p_last_fmat_(nullptr),
|
||||
prediction_cache_initialised(false) {}
|
||||
|
||||
// Destructor: the body is empty, no explicit cleanup is performed here.
GPUHistBuilder::~GPUHistBuilder() {}
|
||||
|
||||
void GPUHistBuilder::Init(const TrainParam& param) {
|
||||
CHECK(param.max_depth < 16) << "Tree depth too large.";
|
||||
CHECK(param.grow_policy != TrainParam::kLossGuide)
|
||||
<< "Loss guided growth policy not supported. Use CPU algorithm.";
|
||||
this->param = param;
|
||||
initialised = false;
|
||||
is_dense = false;
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(param.gpu_id));
|
||||
if (!param.silent) {
|
||||
@@ -77,117 +82,24 @@ void GPUHistBuilder::Init(const TrainParam& param) {
|
||||
}
|
||||
}
|
||||
|
||||
// Binary functor for segmented reductions over (key, value) pairs.
//
// Combining `first` and `second` always yields the key of `second`. The value
// is op(first.value, second.value) when both keys are equal, and restarts
// from second.value when the key changes.
template <typename ReductionOpT>
struct ReduceBySegmentOp {
  ReductionOpT op;  // reduction applied to values that share a key

  __host__ __device__ __forceinline__ ReduceBySegmentOp() {}

  __host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op)
      : op(op) {}

  // KeyValuePairT must expose `key` and `value` members.
  template <typename KeyValuePairT>
  __host__ __device__ __forceinline__ KeyValuePairT
  operator()(const KeyValuePairT& first, const KeyValuePairT& second) {
    KeyValuePairT combined;
    combined.key = second.key;
    if (first.key == second.key) {
      // Same segment: keep accumulating.
      combined.value = op(first.value, second.value);
    } else {
      // New segment: start over from the incoming value.
      combined.value = second.value;
    }
    return combined;
  }
};
|
||||
|
||||
// Accumulates gradient pairs into the dense histogram of one tree level.
//
// Template parameters:
//   ITEMS_PER_THREAD  number of (row, bin) entries handled by each thread
//   BLOCK_THREADS     block size the CUB block primitives are instantiated for
//                     (NOTE(review): presumably must match blockDim.x at the
//                     launch site - confirm)
// Arguments:
//   d_dense_hist  histogram updated in place; the slot of an entry is
//                 (position - n_nodes(depth - 1)) * n_bins + bin index
//   d_ridx        row index of each entry
//   d_gidx        bin index of each entry
//   d_position    position of each row; rows with a negative position are
//                 skipped
//   d_gpair       gradient pair of each row
//   n_bins        bins per node used in the slot computation above
//   depth         tree depth passed to n_nodes(depth - 1)
//   n             total number of entries in d_ridx / d_gidx
//
// Each block loads the tile of ITEMS_PER_THREAD * BLOCK_THREADS entries that
// starts at blockIdx.x * TILE_SIZE (entries past `n` are padded with -1),
// sorts the tile by histogram slot, sums gradient pairs over runs of equal
// slots with a segmented inclusive scan, and lets only the last entry of each
// run issue the atomicAdd pair. CUB temporary storage lives in a statically
// sized __shared__ union.
template <int ITEMS_PER_THREAD, int BLOCK_THREADS>
__global__ void hist_kernel(gpu_gpair* d_dense_hist, int* d_ridx, int* d_gidx,
                            NodeIdT* d_position, gpu_gpair* d_gpair, int n_bins,
                            int depth, int n) {
  typedef cub::KeyValuePair<int, gpu_gpair> OffsetValuePairT;
  typedef cub::BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD,
                         cub::BLOCK_LOAD_VECTORIZE>
      BlockLoadT;
  typedef cub::BlockRadixSort<int, BLOCK_THREADS, ITEMS_PER_THREAD, int>
      BlockRadixSortT;
  typedef cub::BlockDiscontinuity<int, BLOCK_THREADS> BlockDiscontinuityKeysT;
  typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;
  typedef cub::BlockScan<OffsetValuePairT, BLOCK_THREADS,
                         cub::BLOCK_SCAN_WARP_SCANS>
      BlockScanT;

  // The phases below run one after another, so their scratch space can share
  // the same shared memory.
  union TempStorage {
    typename BlockLoadT::TempStorage load;
    typename BlockRadixSortT::TempStorage sort;
    typename BlockScanT::TempStorage scan;
    typename BlockDiscontinuityKeysT::TempStorage disc;
  };

  __shared__ TempStorage temp_storage;

  int rows[ITEMS_PER_THREAD];
  int bins[ITEMS_PER_THREAD];

  const int TILE_SIZE = ITEMS_PER_THREAD * BLOCK_THREADS;
  int tile_begin = blockIdx.x * TILE_SIZE;

  // Guarded loads: out-of-range items are filled with -1.
  BlockLoadT(temp_storage.load)
      .Load(d_ridx + tile_begin, rows, n - tile_begin, -1);
  BlockLoadT(temp_storage.load)
      .Load(d_gidx + tile_begin, bins, n - tile_begin, -1);

  // Histogram slot of every item, or -1 for padding / rows without a
  // non-negative position.
  int hist_idx[ITEMS_PER_THREAD];

  for (int item = 0; item < ITEMS_PER_THREAD; item++) {
    hist_idx[item] = -1;
    if (rows[item] > -1) {
      const NodeIdT pos = d_position[rows[item]];
      if (pos > -1) {
        hist_idx[item] = (pos - n_nodes(depth - 1)) * n_bins + bins[item];
      }
    }
  }

  __syncthreads();
  // Sort by slot, carrying the row index along as the value.
  BlockRadixSortT(temp_storage.sort).Sort(hist_idx, rows);

  OffsetValuePairT kv[ITEMS_PER_THREAD];

  for (int item = 0; item < ITEMS_PER_THREAD; item++) {
    kv[item].key = hist_idx[item];
    // Padding items (row == -1) leave their value untouched.
    if (rows[item] > -1) {
      kv[item].value = d_gpair[rows[item]];
    }
  }

  __syncthreads();
  // Segmented inclusive scan: running sum within each run of equal slots.
  BlockScanT(temp_storage.scan).InclusiveScan(kv, kv, ReduceBySegmentOpT());

  __syncthreads();
  // Flag the last item of each run; it holds the full sum for its slot.
  int flags[ITEMS_PER_THREAD];
  BlockDiscontinuityKeysT(temp_storage.disc)
      .FlagTails(flags, hist_idx, cub::Inequality());

  for (int item = 0; item < ITEMS_PER_THREAD; item++) {
    if (!flags[item]) {
      continue;
    }
    if (rows[item] > -1 && d_position[rows[item]] > -1) {
      atomicAdd(&(d_dense_hist[hist_idx[item]]._grad), kv[item].value._grad);
      atomicAdd(&(d_dense_hist[hist_idx[item]]._hess), kv[item].value._hess);
    }
  }
}
|
||||
|
||||
void GPUHistBuilder::BuildHist(int depth) {
|
||||
auto d_ridx = device_matrix.ridx.data();
|
||||
auto d_gidx = device_matrix.gidx.data();
|
||||
auto d_position = position.data();
|
||||
auto d_gpair = device_gpair.data();
|
||||
auto hist_builder = hist.GetBuilder();
|
||||
auto d_left_child_smallest = left_child_smallest.data();
|
||||
|
||||
dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) {
|
||||
int ridx = d_ridx[idx];
|
||||
int pos = d_position[ridx];
|
||||
if (!is_active(pos, depth)) return;
|
||||
|
||||
// Only increment even nodes
|
||||
if (pos < 0 || pos % 2 == 1) return;
|
||||
// Only increment smallest node
|
||||
bool is_smallest =
|
||||
(d_left_child_smallest[parent_nidx(pos)] && is_left_child(pos)) ||
|
||||
(!d_left_child_smallest[parent_nidx(pos)] && !is_left_child(pos));
|
||||
if (!is_smallest && depth > 0) return;
|
||||
|
||||
int gidx = d_gidx[idx];
|
||||
gpu_gpair gpair = d_gpair[ridx];
|
||||
@@ -199,19 +111,181 @@ void GPUHistBuilder::BuildHist(int depth) {
|
||||
|
||||
// Subtraction trick
|
||||
int n_sub_bins = (n_nodes_level(depth) / 2) * hist_builder.n_bins;
|
||||
if (n_sub_bins > 0) {
|
||||
if (depth > 0) {
|
||||
dh::launch_n(n_sub_bins, [=] __device__(int idx) {
|
||||
int nidx = n_nodes(depth - 1) + ((idx / hist_builder.n_bins) * 2);
|
||||
bool left_smallest = d_left_child_smallest[parent_nidx(nidx)];
|
||||
if (left_smallest) {
|
||||
nidx++; // If left is smallest switch to right child
|
||||
}
|
||||
|
||||
int gidx = idx % hist_builder.n_bins;
|
||||
gpu_gpair parent = hist_builder.Get(gidx, nidx / 2);
|
||||
gpu_gpair right = hist_builder.Get(gidx, nidx + 1);
|
||||
hist_builder.Add(parent - right, gidx, nidx);
|
||||
gpu_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx));
|
||||
int other_nidx = left_smallest ? nidx - 1 : nidx + 1;
|
||||
gpu_gpair other = hist_builder.Get(gidx, other_nidx);
|
||||
hist_builder.Add(parent - other, gidx, nidx);
|
||||
});
|
||||
}
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
// Finds the best split for every node of one tree level and creates the two
// child nodes.
//
// Work mapping: block blockIdx.x handles node n_nodes(depth - 1) + blockIdx.x,
// and within a feature thread threadIdx.x handles the feature's
// threadIdx.x-th histogram bin. BLOCK_THREADS is the block size the CUB
// primitives are instantiated for (NOTE(review): presumably must equal
// blockDim.x and be at least the number of bins of any feature - confirm at
// the launch sites).
//
// Arguments:
//   d_level_hist          histogram of the level, read at the offsets stored
//                         in d_feature_segments
//   d_feature_segments    begin offset of each (node, feature) histogram
//                         segment; entry i + 1 is used as the end of segment i
//   depth                 depth of the level being split
//   n_features            number of features looped over
//   n_bins                bins per node, used to turn a level offset into a
//                         per-node bin index
//   d_nodes               node array; the split of the node and both children
//                         are written here
//   d_fidx_min_map        per-feature value used as split value for bin 0
//   d_gidx_fvalue_map     per-bin value; the split value of bin g > 0 is taken
//                         from entry g - 1
//   gpu_param             training parameters forwarded to the gain helpers
//   d_left_child_smallest output flag per node: left child hessian sum <=
//                         right child hessian sum
//   colsample             when true, features whose d_feature_flags entry is 0
//                         are skipped
//   d_feature_flags       per-feature enable flags (read only if colsample)
//
// Shared memory is statically sized: the running best split, the per-feature
// arg-max and a union of CUB temporary storage.
template <int BLOCK_THREADS>
__global__ void find_split_kernel(
    const gpu_gpair* d_level_hist, int* d_feature_segments, int depth,
    int n_features, int n_bins, Node* d_nodes, float* d_fidx_min_map,
    float* d_gidx_fvalue_map, GPUTrainingParam gpu_param,
    bool* d_left_child_smallest, bool colsample, int* d_feature_flags) {
  typedef cub::KeyValuePair<int, float> ArgMaxT;
  typedef cub::BlockScan<gpu_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
      BlockScanT;
  typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
  typedef cub::BlockReduce<gpu_gpair, BLOCK_THREADS> SumReduceT;

  // Scan and reductions never overlap in time, so they share scratch space.
  union TempStorage {
    typename BlockScanT::TempStorage scan;
    typename MaxReduceT::TempStorage max_reduce;
    typename SumReduceT::TempStorage sum_reduce;
  };

  struct UninitializedSplit : cub::Uninitialized<Split> {};
  struct UninitializedGpair : cub::Uninitialized<gpu_gpair> {};

  __shared__ UninitializedSplit uninitialized_split;
  Split& split = uninitialized_split.Alias();
  __shared__ ArgMaxT block_max;
  __shared__ TempStorage temp_storage;

  // One thread constructs the shared best-split record.
  if (threadIdx.x == 0) {
    split = Split();
  }

  __syncthreads();

  int node_idx = n_nodes(depth - 1) + blockIdx.x;

  for (int fidx = 0; fidx < n_features; fidx++) {
    // The skip condition does not depend on the thread, so the barriers
    // below are still reached by the whole block or by nobody.
    if (colsample && d_feature_flags[fidx] == 0) continue;

    int begin = d_feature_segments[blockIdx.x * n_features + fidx];
    int end = d_feature_segments[blockIdx.x * n_features + fidx + 1];
    int gidx = (begin - (blockIdx.x * n_bins)) + threadIdx.x;
    bool thread_active = threadIdx.x < end - begin;

    // Exclusive prefix sum over the feature's bins; feature_sum receives the
    // total over all bins.
    gpu_gpair bin =
        thread_active ? d_level_hist[begin + threadIdx.x] : gpu_gpair();

    gpu_gpair feature_sum;
    BlockScanT(temp_storage.scan)
        .ExclusiveScan(bin, bin, gpu_gpair(), cub::Sum(), feature_sum);

    // Gain of splitting at this thread's bin.
    gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
    float parent_gain = d_nodes[node_idx].root_gain;

    // Gradient mass not covered by any bin of this feature.
    gpu_gpair missing = parent_sum - feature_sum;

    bool missing_left;
    float gain = thread_active
                     ? loss_chg_missing(bin, missing, parent_sum, parent_gain,
                                        gpu_param, missing_left)
                     : -FLT_MAX;
    __syncthreads();

    // Arg-max of the gain over the feature's bins; the result is taken from
    // thread 0 and published through shared memory.
    ArgMaxT candidate(threadIdx.x, gain);
    ArgMaxT best = MaxReduceT(temp_storage.max_reduce)
                       .Reduce(candidate, cub::ArgMax(), end - begin);

    if (threadIdx.x == 0) {
      block_max = best;
    }

    __syncthreads();

    // The winning thread offers its candidate to the shared split record.
    if (threadIdx.x == block_max.key) {
      float fvalue = (threadIdx.x == 0) ? d_fidx_min_map[fidx]
                                        : d_gidx_fvalue_map[gidx - 1];

      gpu_gpair left = missing_left ? bin + missing : bin;
      gpu_gpair right = parent_sum - left;

      split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
    }
    __syncthreads();
  }

  // Thread 0 stores the split and creates both children.
  if (threadIdx.x == 0) {
    d_nodes[node_idx].split = split;

    d_nodes[left_child_nidx(node_idx)] = Node(
        split.left_sum,
        CalcGain(gpu_param, split.left_sum.grad(), split.left_sum.hess()),
        CalcWeight(gpu_param, split.left_sum.grad(), split.left_sum.hess()));

    d_nodes[right_child_nidx(node_idx)] = Node(
        split.right_sum,
        CalcGain(gpu_param, split.right_sum.grad(), split.right_sum.hess()),
        CalcWeight(gpu_param, split.right_sum.grad(), split.right_sum.hess()));

    // Record which child has the smaller hessian sum (ties count as left).
    d_left_child_smallest[node_idx] =
        split.left_sum.hess() <= split.right_sum.hess();
  }
}
|
||||
|
||||
// Dispatches split finding for level `depth` to the implementation
// specialised for the configured param.max_bin:
//   max_bin <= 256   -> FindSplit256
//   max_bin <= 1024  -> FindSplit1024
//   otherwise        -> FindSplitLarge
void GPUHistBuilder::FindSplit(int depth) {
  if (param.max_bin <= 256) {
    this->FindSplit256(depth);
    return;
  }

  if (param.max_bin <= 1024) {
    this->FindSplit1024(depth);
    return;
  }

  this->FindSplitLarge(depth);
}
|
||||
|
||||
// Split finding for param.max_bin <= 256.
//
// Launches find_split_kernel with one block of 256 threads per node of the
// level (grid size n_nodes_level(depth)) and waits for it to finish.
// Column sampling is enabled in the kernel when either colsample_bylevel or
// colsample_bytree is below 1.0.
void GPUHistBuilder::FindSplit256(int depth) {
  CHECK_LE(param.max_bin, 256);
  const int BLOCK_THREADS = 256;
  const int GRID_SIZE = n_nodes_level(depth);
  bool colsample =
      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
  find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
      feature_flags.data());

  // A kernel launch returns no status of its own: launch-time failures (bad
  // configuration, too many resources requested) must be picked up
  // explicitly, otherwise the level would silently keep stale results.
  dh::safe_cuda(cudaGetLastError());
  // Surfaces errors raised while the kernel executes.
  dh::safe_cuda(cudaDeviceSynchronize());
}
|
||||
// Split finding for param.max_bin <= 1024.
//
// Launches find_split_kernel with one block of 1024 threads per node of the
// level (grid size n_nodes_level(depth)) and waits for it to finish.
// Column sampling is enabled in the kernel when either colsample_bylevel or
// colsample_bytree is below 1.0.
void GPUHistBuilder::FindSplit1024(int depth) {
  CHECK_LE(param.max_bin, 1024);
  const int BLOCK_THREADS = 1024;
  const int GRID_SIZE = n_nodes_level(depth);
  bool colsample =
      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
  find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
      feature_flags.data());

  // A kernel launch returns no status of its own. With 1024 threads per block
  // a launch can be rejected for resource reasons, so check explicitly
  // instead of continuing with stale results.
  dh::safe_cuda(cudaGetLastError());
  // Surfaces errors raised while the kernel executes.
  dh::safe_cuda(cudaDeviceSynchronize());
}
|
||||
void GPUHistBuilder::FindSplitLarge(int depth) {
|
||||
auto counting = thrust::make_counting_iterator(0);
|
||||
auto d_gidx_feature_map = gidx_feature_map.data();
|
||||
int n_bins = hmat_.row_ptr.back();
|
||||
@@ -295,6 +369,8 @@ void GPUHistBuilder::FindSplit(int depth) {
|
||||
auto d_argmax = argmax.data();
|
||||
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
|
||||
auto d_fidx_min_map = fidx_min_map.data();
|
||||
auto d_left_child_smallest = left_child_smallest.data();
|
||||
|
||||
dh::launch_n(n_nodes_level(depth), [=] __device__(int idx) {
|
||||
int max_idx = n_bins * idx + d_argmax[idx].key;
|
||||
int gidx = max_idx % n_bins;
|
||||
@@ -317,64 +393,78 @@ void GPUHistBuilder::FindSplit(int depth) {
|
||||
} else {
|
||||
fvalue = d_gidx_fvalue_map[gidx - 1];
|
||||
}
|
||||
|
||||
gpu_gpair left = missing_left ? scan + missing : scan;
|
||||
gpu_gpair right = parent_sum - left;
|
||||
d_nodes[node_idx].split.Update(loss_chg, missing_left, fvalue, fidx, left,
|
||||
right, gpu_param_alias);
|
||||
|
||||
int left_child_idx = n_nodes(depth) + idx * 2;
|
||||
int right_child_idx = n_nodes(depth) + idx * 2 + 1;
|
||||
d_nodes[left_child_idx] =
|
||||
d_nodes[left_child_nidx(node_idx)] =
|
||||
Node(left, CalcGain(gpu_param_alias, left.grad(), left.hess()),
|
||||
CalcWeight(gpu_param_alias, left.grad(), left.hess()));
|
||||
|
||||
d_nodes[right_child_idx] =
|
||||
d_nodes[right_child_nidx(node_idx)] =
|
||||
Node(right, CalcGain(gpu_param_alias, right.grad(), right.hess()),
|
||||
CalcWeight(gpu_param_alias, right.grad(), right.hess()));
|
||||
|
||||
// Record smallest node
|
||||
if (left.hess() <= right.hess()) {
|
||||
d_left_child_smallest[node_idx] = true;
|
||||
} else {
|
||||
d_left_child_smallest[node_idx] = false;
|
||||
}
|
||||
});
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
void GPUHistBuilder::InitFirstNode() {
|
||||
// Build the root node on the CPU and copy to device
|
||||
gpu_gpair sum_gradients =
|
||||
thrust::reduce(device_gpair.tbegin(), device_gpair.tend(),
|
||||
gpu_gpair(0, 0), thrust::plus<gpu_gpair>());
|
||||
auto d_gpair = device_gpair.data();
|
||||
auto d_node_sums = node_sums.data();
|
||||
auto d_nodes = nodes.data();
|
||||
auto gpu_param_alias = gpu_param;
|
||||
|
||||
Node tmp =
|
||||
Node(sum_gradients,
|
||||
CalcGain(param, sum_gradients.grad(), sum_gradients.hess()),
|
||||
CalcWeight(param, sum_gradients.grad(), sum_gradients.hess()));
|
||||
size_t temp_storage_bytes;
|
||||
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, d_gpair, d_node_sums,
|
||||
device_gpair.size(), cub::Sum(), gpu_gpair());
|
||||
cub_mem.LazyAllocate(temp_storage_bytes);
|
||||
cub::DeviceReduce::Reduce(cub_mem.d_temp_storage, cub_mem.temp_storage_bytes,
|
||||
d_gpair, d_node_sums, device_gpair.size(),
|
||||
cub::Sum(), gpu_gpair());
|
||||
|
||||
thrust::copy_n(&tmp, 1, nodes.tbegin());
|
||||
dh::launch_n(1, [=] __device__(int idx) {
|
||||
gpu_gpair sum_gradients = d_node_sums[idx];
|
||||
d_nodes[idx] = Node(
|
||||
sum_gradients,
|
||||
CalcGain(gpu_param_alias, sum_gradients.grad(), sum_gradients.hess()),
|
||||
CalcWeight(gpu_param_alias, sum_gradients.grad(),
|
||||
sum_gradients.hess()));
|
||||
});
|
||||
}
|
||||
|
||||
void GPUHistBuilder::UpdatePosition() {
|
||||
void GPUHistBuilder::UpdatePosition(int depth) {
|
||||
if (is_dense) {
|
||||
this->UpdatePositionDense();
|
||||
this->UpdatePositionDense(depth);
|
||||
} else {
|
||||
this->UpdatePositionSparse();
|
||||
this->UpdatePositionSparse(depth);
|
||||
}
|
||||
}
|
||||
|
||||
void GPUHistBuilder::UpdatePositionDense() {
|
||||
void GPUHistBuilder::UpdatePositionDense(int depth) {
|
||||
auto d_position = position.data();
|
||||
Node* d_nodes = nodes.data();
|
||||
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
|
||||
auto d_gidx = device_matrix.gidx.data();
|
||||
int n_columns = info->num_col;
|
||||
|
||||
int gidx_size = device_matrix.gidx.size();
|
||||
|
||||
dh::launch_n(position.size(), [=] __device__(int idx) {
|
||||
NodeIdT pos = d_position[idx];
|
||||
if (pos < 0) {
|
||||
if (!is_active(pos, depth)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Node node = d_nodes[pos];
|
||||
|
||||
if (node.IsLeaf()) {
|
||||
d_position[idx] = -1;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -383,14 +473,16 @@ void GPUHistBuilder::UpdatePositionDense() {
|
||||
float fvalue = d_gidx_fvalue_map[gidx];
|
||||
|
||||
if (fvalue <= node.split.fvalue) {
|
||||
d_position[idx] = pos * 2 + 1;
|
||||
d_position[idx] = left_child_nidx(pos);
|
||||
} else {
|
||||
d_position[idx] = pos * 2 + 2;
|
||||
d_position[idx] = right_child_nidx(pos);
|
||||
}
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
void GPUHistBuilder::UpdatePositionSparse() {
|
||||
void GPUHistBuilder::UpdatePositionSparse(int depth) {
|
||||
auto d_position = position.data();
|
||||
auto d_position_tmp = position_tmp.data();
|
||||
Node* d_nodes = nodes.data();
|
||||
@@ -402,14 +494,16 @@ void GPUHistBuilder::UpdatePositionSparse() {
|
||||
// Update missing direction
|
||||
dh::launch_n(position.size(), [=] __device__(int idx) {
|
||||
NodeIdT pos = d_position[idx];
|
||||
if (pos < 0) {
|
||||
if (!is_active(pos, depth)) {
|
||||
d_position_tmp[idx] = pos;
|
||||
return;
|
||||
}
|
||||
|
||||
Node node = d_nodes[pos];
|
||||
|
||||
if (node.IsLeaf()) {
|
||||
d_position_tmp[idx] = -1;
|
||||
d_position_tmp[idx] = pos;
|
||||
return;
|
||||
} else if (node.split.missing_left) {
|
||||
d_position_tmp[idx] = pos * 2 + 1;
|
||||
} else {
|
||||
@@ -417,11 +511,13 @@ void GPUHistBuilder::UpdatePositionSparse() {
|
||||
}
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
|
||||
// Update node based on fvalue where exists
|
||||
dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) {
|
||||
int ridx = d_ridx[idx];
|
||||
NodeIdT pos = d_position[ridx];
|
||||
if (pos < 0) {
|
||||
if (!is_active(pos, depth)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -438,23 +534,29 @@ void GPUHistBuilder::UpdatePositionSparse() {
|
||||
float fvalue = d_gidx_fvalue_map[gidx];
|
||||
|
||||
if (fvalue <= node.split.fvalue) {
|
||||
d_position_tmp[ridx] = pos * 2 + 1;
|
||||
d_position_tmp[ridx] = left_child_nidx(pos);
|
||||
} else {
|
||||
d_position_tmp[ridx] = pos * 2 + 2;
|
||||
d_position_tmp[ridx] = right_child_nidx(pos);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
|
||||
position = position_tmp;
|
||||
}
|
||||
|
||||
// Draws the per-tree feature subset.
//
// Does nothing when both colsample_bylevel and colsample_bytree are exactly
// 1.0. Otherwise feature_set_tree is filled with 0 .. num_col - 1 and then
// replaced by col_sample(feature_set_tree, colsample_bytree).
void GPUHistBuilder::ColSampleTree() {
  const bool sampling_disabled =
      param.colsample_bylevel == 1.0 && param.colsample_bytree == 1.0;

  if (!sampling_disabled) {
    feature_set_tree.resize(info->num_col);
    std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
    feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree);
  }
}
|
||||
|
||||
void GPUHistBuilder::ColSampleLevel() {
|
||||
if (param.colsample_bylevel == 1.0 && param.colsample_bytree == 1.0) return;
|
||||
|
||||
feature_set_level.resize(feature_set_tree.size());
|
||||
feature_set_level = col_sample(feature_set_tree, param.colsample_bylevel);
|
||||
std::vector<int> h_feature_flags(info->num_col, 0);
|
||||
@@ -491,18 +593,20 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||
int level_max_bins = n_nodes_level(param.max_depth - 1) * n_bins;
|
||||
|
||||
size_t free_memory = dh::available_memory();
|
||||
ba.allocate(
|
||||
&gidx_feature_map, n_bins, &hist_node_segments,
|
||||
n_nodes_level(param.max_depth - 1) + 1, &feature_segments,
|
||||
h_feature_segments.size(), &gain, level_max_bins, &position,
|
||||
gpair.size(), &position_tmp, gpair.size(), &nodes,
|
||||
n_nodes(param.max_depth), &gidx_fvalue_map, hmat_.cut.size(),
|
||||
&fidx_min_map, hmat_.min_val.size(), &argmax,
|
||||
n_nodes_level(param.max_depth - 1), &node_sums,
|
||||
n_nodes_level(param.max_depth - 1) * n_features, &hist_scan,
|
||||
level_max_bins, &device_gpair, gpair.size(), &device_matrix.gidx,
|
||||
gmat_.index.size(), &device_matrix.ridx, gmat_.index.size(), &hist.hist,
|
||||
n_nodes(param.max_depth - 1) * n_bins, &feature_flags, n_features);
|
||||
ba.allocate(&gidx_feature_map, n_bins, &hist_node_segments,
|
||||
n_nodes_level(param.max_depth - 1) + 1, &feature_segments,
|
||||
h_feature_segments.size(), &gain, level_max_bins, &position,
|
||||
gpair.size(), &position_tmp, gpair.size(), &nodes,
|
||||
n_nodes(param.max_depth), &gidx_fvalue_map, hmat_.cut.size(),
|
||||
&fidx_min_map, hmat_.min_val.size(), &argmax,
|
||||
n_nodes_level(param.max_depth - 1), &node_sums,
|
||||
n_nodes_level(param.max_depth - 1) * n_features, &hist_scan,
|
||||
level_max_bins, &device_gpair, gpair.size(),
|
||||
&device_matrix.gidx, gmat_.index.size(), &device_matrix.ridx,
|
||||
gmat_.index.size(), &hist.hist,
|
||||
n_nodes(param.max_depth - 1) * n_bins, &feature_flags,
|
||||
n_features, &left_child_smallest, n_nodes(param.max_depth - 1),
|
||||
&prediction_cache, gpair.size());
|
||||
|
||||
if (!param.silent) {
|
||||
const int mb_size = 1048576;
|
||||
@@ -529,10 +633,14 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||
thrust::sequence(hist_node_segments.tbegin(), hist_node_segments.tend(), 0,
|
||||
n_bins);
|
||||
|
||||
feature_flags.fill(1);
|
||||
|
||||
feature_segments = h_feature_segments;
|
||||
|
||||
hist.Init(n_bins);
|
||||
|
||||
prediction_cache.fill(0);
|
||||
|
||||
initialised = true;
|
||||
}
|
||||
nodes.fill(Node());
|
||||
@@ -540,6 +648,37 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||
device_gpair = gpair;
|
||||
subsample_gpair(&device_gpair, param.subsample);
|
||||
hist.Reset();
|
||||
p_last_fmat_ = &fmat;
|
||||
}
|
||||
|
||||
// Updates cached predictions with the tree held in `nodes`.
//
//   data         matrix the caller wants predictions for
//   p_out_preds  predictions for `data`; read to seed the device cache on the
//                first successful call and overwritten with the updated cache
//
// Returns false without touching *p_out_preds when there are no nodes or when
// `data` is not the matrix remembered in p_last_fmat_. Otherwise adds
// weight * learning_rate of the node each row is positioned at to the
// device-side cache, copies the cache into *p_out_preds and returns true.
bool GPUHistBuilder::UpdatePredictionCache(
    const DMatrix* data, std::vector<bst_float>* p_out_preds) {
  const bool same_matrix = p_last_fmat_ != nullptr && data == p_last_fmat_;
  if (nodes.empty() || !same_matrix) {
    return false;
  }

  std::vector<bst_float>& out_preds = *p_out_preds;
  CHECK_EQ(prediction_cache.size(), out_preds.size());

  // First use: start from the predictions supplied by the caller.
  if (!prediction_cache_initialised) {
    prediction_cache = out_preds;
    prediction_cache_initialised = true;
  }

  auto d_prediction_cache = prediction_cache.data();
  auto d_position = position.data();
  auto d_nodes = nodes.data();
  float eps = param.learning_rate;

  // One work item per row: add the scaled weight of the row's node.
  dh::launch_n(prediction_cache.size(), [=] __device__(int idx) {
    int node = d_position[idx];
    d_prediction_cache[idx] += d_nodes[node].weight * eps;
  });

  // Copy the updated cache into the caller's vector.
  thrust::copy(prediction_cache.tbegin(), prediction_cache.tend(),
               out_preds.data());

  return true;
}
|
||||
|
||||
void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
|
||||
@@ -547,12 +686,11 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
|
||||
this->InitData(gpair, *p_fmat, *p_tree);
|
||||
this->InitFirstNode();
|
||||
this->ColSampleTree();
|
||||
|
||||
for (int depth = 0; depth < param.max_depth; depth++) {
|
||||
this->ColSampleLevel();
|
||||
this->BuildHist(depth);
|
||||
this->FindSplit(depth);
|
||||
this->UpdatePosition();
|
||||
this->UpdatePosition(depth);
|
||||
}
|
||||
dense2sparse_tree(p_tree, nodes.tbegin(), nodes.tend(), param);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
* Copyright 2016 Rory mitchell
|
||||
*/
|
||||
#pragma once
|
||||
#include <cusparse.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <cub/util_type.cuh> // Need key value pair definition
|
||||
@@ -63,12 +62,17 @@ class GPUHistBuilder {
|
||||
RegTree *p_tree);
|
||||
void BuildHist(int depth);
|
||||
void FindSplit(int depth);
|
||||
void FindSplit256(int depth);
|
||||
void FindSplit1024(int depth);
|
||||
void FindSplitLarge(int depth);
|
||||
void InitFirstNode();
|
||||
void UpdatePosition();
|
||||
void UpdatePositionDense();
|
||||
void UpdatePositionSparse();
|
||||
void UpdatePosition(int depth);
|
||||
void UpdatePositionDense(int depth);
|
||||
void UpdatePositionSparse(int depth);
|
||||
void ColSampleTree();
|
||||
void ColSampleLevel();
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* p_out_preds);
|
||||
|
||||
TrainParam param;
|
||||
GPUTrainingParam gpu_param;
|
||||
@@ -78,6 +82,7 @@ class GPUHistBuilder {
|
||||
bool initialised;
|
||||
bool is_dense;
|
||||
DeviceGMat device_matrix;
|
||||
const DMatrix* p_last_fmat_;
|
||||
|
||||
dh::bulk_allocator ba;
|
||||
dh::CubMemory cub_mem;
|
||||
@@ -96,6 +101,9 @@ class GPUHistBuilder {
|
||||
dh::dvec<gpu_gpair> device_gpair;
|
||||
dh::dvec<Node> nodes;
|
||||
dh::dvec<int> feature_flags;
|
||||
dh::dvec<bool> left_child_smallest;
|
||||
dh::dvec<bst_float> prediction_cache;
|
||||
bool prediction_cache_initialised;
|
||||
|
||||
std::vector<int> feature_set_tree;
|
||||
std::vector<int> feature_set_level;
|
||||
|
||||
@@ -122,7 +122,7 @@ struct Split {
|
||||
gpu_gpair right_sum;
|
||||
|
||||
__host__ __device__ Split()
|
||||
: loss_chg(-FLT_MAX), missing_left(true), fvalue(0) {}
|
||||
: loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {}
|
||||
|
||||
__device__ void Update(float loss_chg_in, bool missing_left_in,
|
||||
float fvalue_in, int findex_in, gpu_gpair left_sum_in,
|
||||
|
||||
@@ -75,6 +75,11 @@ class GPUHistMaker : public TreeUpdater {
|
||||
param.learning_rate = lr;
|
||||
}
|
||||
|
||||
// Forwards the prediction-cache update to the underlying builder and returns
// its result unchanged.
bool UpdatePredictionCache(const DMatrix* data,
                           std::vector<bst_float>* out_preds) override {
  const bool cache_updated = builder.UpdatePredictionCache(data, out_preds);
  return cache_updated;
}
|
||||
|
||||
protected:
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
|
||||
Reference in New Issue
Block a user