[GPU-Plugin] Improved split finding performance. (#2325)
commit 3ca64ffa02
parent 29289d2302
@@ -168,5 +168,17 @@ inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
   return features;
 }
 
+struct GpairCallbackOp {
+  // Running prefix
+  gpu_gpair running_total;
+  // Constructor
+  __device__ GpairCallbackOp() : running_total(gpu_gpair()) {}
+  __device__ gpu_gpair operator()(gpu_gpair block_aggregate) {
+    gpu_gpair old_prefix = running_total;
+    running_total += block_aggregate;
+    return old_prefix;
+  }
+};
+
 }  // namespace tree
 }  // namespace xgboost
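The GpairCallbackOp moved into this header is CUB's block-prefix-callback pattern: when cub::BlockScan consumes its input in tiles, thread 0 invokes the functor once per tile with that tile's aggregate, and the returned value becomes the exclusive prefix applied to the next tile. Below is a minimal, self-contained sketch of the same pattern; the float element type and the names PrefixOp and tiled_exclusive_scan are illustrative only and not part of this commit.

#include <cub/cub.cuh>

// Stateful running-prefix functor, same shape as GpairCallbackOp above.
struct PrefixOp {
  float running_total;
  __device__ PrefixOp() : running_total(0.0f) {}
  // Called by thread 0 once per tile with the tile-wide aggregate; the
  // returned value is used as the exclusive prefix for that tile.
  __device__ float operator()(float block_aggregate) {
    float old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};

template <int BLOCK_THREADS>
__global__ void tiled_exclusive_scan(const float* in, float* out, int n) {
  typedef cub::BlockScan<float, BLOCK_THREADS> BlockScanT;
  __shared__ typename BlockScanT::TempStorage temp_storage;
  PrefixOp prefix_op;  // carries the running sum from one tile to the next
  for (int tile = 0; tile < n; tile += BLOCK_THREADS) {
    int idx = tile + threadIdx.x;
    float x = idx < n ? in[idx] : 0.0f;
    float scanned;
    BlockScanT(temp_storage).ExclusiveScan(x, scanned, cub::Sum(), prefix_op);
    if (idx < n) out[idx] = scanned;
    __syncthreads();  // temp_storage is reused by the next tile
  }
}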
@@ -2,11 +2,11 @@
  * Copyright 2016 Rory mitchell
  */
 #pragma once
-#include <cub/cub.cuh>
 #include <xgboost/base.h>
+#include <cub/cub.cuh>
+#include "common.cuh"
 #include "device_helpers.cuh"
 #include "types.cuh"
-#include "common.cuh"
 
 namespace xgboost {
 namespace tree {
@@ -48,19 +48,8 @@ struct GpairTupleCallbackOp {
   }
 };
 
-struct GpairCallbackOp {
-  // Running prefix
-  gpu_gpair running_total;
-  // Constructor
-  __device__ GpairCallbackOp() : running_total(gpu_gpair()) {}
-  __device__ gpu_gpair operator()(gpu_gpair block_aggregate) {
-    gpu_gpair old_prefix = running_total;
-    running_total += block_aggregate;
-    return old_prefix;
-  }
-};
-
-template <int BLOCK_THREADS> struct ReduceEnactorSorting {
+template <int BLOCK_THREADS>
+struct ReduceEnactorSorting {
   typedef cub::BlockScan<ScanTuple, BLOCK_THREADS> GpairScanT;
   struct _TempStorage {
     typename GpairScanT::TempStorage gpair_scan;
@@ -87,8 +76,10 @@ template <int BLOCK_THREADS> struct ReduceEnactorSorting {
       ItemIter item_iter, const int level)
       : temp_storage(temp_storage.Alias()),
         d_block_node_sums(d_block_node_sums),
-        d_block_node_offsets(d_block_node_offsets), item_iter(item_iter),
-        callback_op(), level(level) {}
+        d_block_node_offsets(d_block_node_offsets),
+        item_iter(item_iter),
+        callback_op(),
+        level(level) {}
 
   __device__ __forceinline__ void LoadTile(const bst_uint &offset,
                                            const bst_uint &num_remaining) {
@@ -181,9 +172,12 @@ struct FindSplitEnactorSorting {
       const int level)
       : temp_storage(temp_storage.Alias()),
         d_block_node_sums(d_block_node_sums),
-        d_block_node_offsets(d_block_node_offsets), item_iter(item_iter),
-        d_nodes(d_nodes), d_split_candidates_out(d_split_candidates_out),
-        level(level), param(param) {}
+        d_block_node_offsets(d_block_node_offsets),
+        item_iter(item_iter),
+        d_nodes(d_nodes),
+        d_split_candidates_out(d_split_candidates_out),
+        level(level),
+        param(param) {}
 
   __device__ __forceinline__ void LoadTile(NodeIdT node_id_adjusted,
                                            const bst_uint &node_begin,
@@ -254,8 +248,8 @@ struct FindSplitEnactorSorting {
     return fvalue != left_fvalue;
   }
 
-  __device__ __forceinline__ void
-  EvaluateSplits(const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
-                 const bst_uint &offset, const bst_uint &num_remaining) {
+  __device__ __forceinline__ void EvaluateSplits(
+      const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
+      const bst_uint &offset, const bst_uint &num_remaining) {
     bool thread_active = LeftmostFvalue() && threadIdx.x < num_remaining &&
                          node_id_adjusted >= 0 && node_id >= 0;
@@ -289,8 +283,8 @@ struct FindSplitEnactorSorting {
     }
   }
 
-  __device__ __forceinline__ void
-  ProcessTile(const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
-              const bst_uint &offset, const bst_uint &num_remaining,
-              GpairCallbackOp &callback_op) {  // NOLINT
+  __device__ __forceinline__ void ProcessTile(
+      const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
+      const bst_uint &offset, const bst_uint &num_remaining,
+      GpairCallbackOp &callback_op) {  // NOLINT
     LoadTile(node_id_adjusted, node_begin, offset, num_remaining);
@@ -304,8 +298,8 @@ struct FindSplitEnactorSorting {
     EvaluateSplits(node_id_adjusted, node_begin, offset, num_remaining);
   }
 
-  __device__ __forceinline__ void
-  WriteBestSplit(const NodeIdT &node_id_adjusted) {
+  __device__ __forceinline__ void WriteBestSplit(
+      const NodeIdT &node_id_adjusted) {
     if (threadIdx.x < 32) {
       bool active = threadIdx.x < N_WARPS;
       float warp_loss =
@@ -370,7 +364,6 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
     const Node *d_nodes, bst_uint num_items, const int num_features,
     const int *d_feature_offsets, gpu_gpair *d_node_sums, int *d_node_offsets,
     const GPUTrainingParam param, const int *d_feature_flags, const int level) {
-
   if (num_items <= 0 || d_feature_flags[blockIdx.x] != 1) {
     return;
   }
@@ -400,7 +393,7 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
         .ProcessFeature(segment_begin, segment_end);
 }
 
-void find_split_candidates_sorted(GPUData * data, const int level) {
+void find_split_candidates_sorted(GPUData *data, const int level) {
   const int BLOCK_THREADS = 512;
 
   CHECK(BLOCK_THREADS / 32 < 32) << "Too many active warps.";
@@ -410,9 +403,9 @@ void find_split_candidates_sorted(GPUData * data, const int level) {
   find_split_candidates_sorted_kernel<
       BLOCK_THREADS><<<grid_size, BLOCK_THREADS>>>(
       data->items_iter, data->split_candidates.data(), data->nodes.data(),
-      data->fvalues.size(), data->n_features,
-      data->foffsets.data(), data->node_sums.data(), data->node_offsets.data(),
-      data->param, data->feature_flags.data(), level);
+      data->fvalues.size(), data->n_features, data->foffsets.data(),
+      data->node_sums.data(), data->node_offsets.data(), data->param,
+      data->feature_flags.data(), level);
 
   dh::safe_cuda(cudaGetLastError());
   dh::safe_cuda(cudaDeviceSynchronize());
@@ -245,174 +245,183 @@ __global__ void find_split_kernel(
     }
   }
 }
 
-void GPUHistBuilder::FindSplit(int depth) {
-  // Specialised based on max_bins
-  if (param.max_bin <= 256) {
-    this->FindSplit256(depth);
-  } else if (param.max_bin <= 1024) {
-    this->FindSplit1024(depth);
-  } else {
-    this->FindSplitLarge(depth);
-  }
-}
-
-void GPUHistBuilder::FindSplit256(int depth) {
-  CHECK_LE(param.max_bin, 256);
-  const int BLOCK_THREADS = 256;
-  const int GRID_SIZE = n_nodes_level(depth);
-  bool colsample =
-      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
-  find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
-      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
-      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
-      feature_flags.data());
-
-  dh::safe_cuda(cudaDeviceSynchronize());
-}
-
-void GPUHistBuilder::FindSplit1024(int depth) {
-  CHECK_LE(param.max_bin, 1024);
-  const int BLOCK_THREADS = 1024;
-  const int GRID_SIZE = n_nodes_level(depth);
-  bool colsample =
-      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
-  find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
-      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
-      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
-      feature_flags.data());
-
-  dh::safe_cuda(cudaDeviceSynchronize());
-}
-
-void GPUHistBuilder::FindSplitLarge(int depth) {
-  auto counting = thrust::make_counting_iterator(0);
-  auto d_gidx_feature_map = gidx_feature_map.data();
-  int n_bins = hmat_.row_ptr.back();
-  int n_features = hmat_.row_ptr.size() - 1;
-
-  auto feature_boundary = [=] __device__(int idx_a, int idx_b) {
-    int gidx_a = idx_a % n_bins;
-    int gidx_b = idx_b % n_bins;
-    return d_gidx_feature_map[gidx_a] == d_gidx_feature_map[gidx_b];
-  };  // NOLINT
-
-  // Reduce node sums
-  {
-    size_t temp_storage_bytes;
-    cub::DeviceSegmentedReduce::Reduce(
-        nullptr, temp_storage_bytes, hist.GetLevelPtr(depth), node_sums.data(),
-        n_nodes_level(depth) * n_features, feature_segments.data(),
-        feature_segments.data() + 1, cub::Sum(), gpu_gpair());
-    cub_mem.LazyAllocate(temp_storage_bytes);
-    cub::DeviceSegmentedReduce::Reduce(
-        cub_mem.d_temp_storage, cub_mem.temp_storage_bytes,
-        hist.GetLevelPtr(depth), node_sums.data(),
-        n_nodes_level(depth) * n_features, feature_segments.data(),
-        feature_segments.data() + 1, cub::Sum(), gpu_gpair());
-  }
-
-  // Scan
-  thrust::exclusive_scan_by_key(
-      counting, counting + hist.LevelSize(depth),
-      thrust::device_pointer_cast(hist.GetLevelPtr(depth)), hist_scan.tbegin(),
-      gpu_gpair(), feature_boundary);
-
-  // Calculate gain
-  auto d_gain = gain.data();
-  auto d_nodes = nodes.data();
-  auto d_node_sums = node_sums.data();
-  auto d_hist_scan = hist_scan.data();
-  GPUTrainingParam gpu_param_alias =
-      gpu_param;  // Must be local variable to be used in device lambda
-  bool colsample =
-      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
-  auto d_feature_flags = feature_flags.data();
-
-  dh::launch_n(hist.LevelSize(depth), [=] __device__(int idx) {
-    int node_segment = idx / n_bins;
-    int node_idx = n_nodes(depth - 1) + node_segment;
-    gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
-    float parent_gain = d_nodes[node_idx].root_gain;
-    int gidx = idx % n_bins;
-    int findex = d_gidx_feature_map[gidx];
-
-    // colsample
-    if (colsample && d_feature_flags[d_gidx_feature_map[gidx]] == 0) {
-      d_gain[idx] = 0;
-    } else {
-      gpu_gpair scan = d_hist_scan[idx];
-      gpu_gpair sum = d_node_sums[node_segment * n_features + findex];
-      gpu_gpair missing = parent_sum - sum;
-
-      bool missing_left;
-      d_gain[idx] = loss_chg_missing(scan, missing, parent_sum, parent_gain,
-                                     gpu_param_alias, missing_left);
-    }
-  });
-  dh::safe_cuda(cudaDeviceSynchronize());
-
-  // Find best gain
-  {
-    size_t temp_storage_bytes;
-    cub::DeviceSegmentedReduce::ArgMax(nullptr, temp_storage_bytes, gain.data(),
-                                       argmax.data(), n_nodes_level(depth),
-                                       hist_node_segments.data(),
-                                       hist_node_segments.data() + 1);
-    cub_mem.LazyAllocate(temp_storage_bytes);
-    cub::DeviceSegmentedReduce::ArgMax(
-        cub_mem.d_temp_storage, cub_mem.temp_storage_bytes, gain.data(),
-        argmax.data(), n_nodes_level(depth), hist_node_segments.data(),
-        hist_node_segments.data() + 1);
-  }
-
-  auto d_argmax = argmax.data();
-  auto d_gidx_fvalue_map = gidx_fvalue_map.data();
-  auto d_fidx_min_map = fidx_min_map.data();
-  auto d_left_child_smallest = left_child_smallest.data();
-
-  dh::launch_n(n_nodes_level(depth), [=] __device__(int idx) {
-    int max_idx = n_bins * idx + d_argmax[idx].key;
-    int gidx = max_idx % n_bins;
-    int fidx = d_gidx_feature_map[gidx];
-    int node_segment = max_idx / n_bins;
-    int node_idx = n_nodes(depth - 1) + node_segment;
-    gpu_gpair scan = d_hist_scan[max_idx];
-    gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
-    float parent_gain = d_nodes[node_idx].root_gain;
-    gpu_gpair sum = d_node_sums[node_segment * n_features + fidx];
-    gpu_gpair missing = parent_sum - sum;
-
-    bool missing_left;
-    float loss_chg = loss_chg_missing(scan, missing, parent_sum, parent_gain,
-                                      gpu_param_alias, missing_left);
-
-    float fvalue;
-    if (gidx == 0 || fidx != d_gidx_feature_map[gidx - 1]) {
-      fvalue = d_fidx_min_map[fidx];
-    } else {
-      fvalue = d_gidx_fvalue_map[gidx - 1];
-    }
-    gpu_gpair left = missing_left ? scan + missing : scan;
-    gpu_gpair right = parent_sum - left;
-    d_nodes[node_idx].split.Update(loss_chg, missing_left, fvalue, fidx, left,
-                                   right, gpu_param_alias);
-
-    d_nodes[left_child_nidx(node_idx)] =
-        Node(left, CalcGain(gpu_param_alias, left.grad(), left.hess()),
-             CalcWeight(gpu_param_alias, left.grad(), left.hess()));
-
-    d_nodes[right_child_nidx(node_idx)] =
-        Node(right, CalcGain(gpu_param_alias, right.grad(), right.hess()),
-             CalcWeight(gpu_param_alias, right.grad(), right.hess()));
-
-    // Record smallest node
-    if (left.hess() <= right.hess()) {
-      d_left_child_smallest[node_idx] = true;
-    } else {
-      d_left_child_smallest[node_idx] = false;
-    }
-  });
-
+template <int BLOCK_THREADS>
+__global__ void find_split_general_kernel(
+    const gpu_gpair* d_level_hist, int* d_feature_segments, int depth,
+    int n_features, int n_bins, Node* d_nodes, float* d_fidx_min_map,
+    float* d_gidx_fvalue_map, GPUTrainingParam gpu_param,
+    bool* d_left_child_smallest, bool colsample, int* d_feature_flags) {
+  typedef cub::KeyValuePair<int, float> ArgMaxT;
+  typedef cub::BlockScan<gpu_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
+      BlockScanT;
+  typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
+  typedef cub::BlockReduce<gpu_gpair, BLOCK_THREADS> SumReduceT;
+
+  union TempStorage {
+    typename BlockScanT::TempStorage scan;
+    typename MaxReduceT::TempStorage max_reduce;
+    typename SumReduceT::TempStorage sum_reduce;
+  };
+
+  struct UninitializedSplit : cub::Uninitialized<Split> {};
+  struct UninitializedGpair : cub::Uninitialized<gpu_gpair> {};
+
+  __shared__ UninitializedSplit uninitialized_split;
+  Split& split = uninitialized_split.Alias();
+  __shared__ UninitializedGpair uninitialized_sum;
+  gpu_gpair& shared_sum = uninitialized_sum.Alias();
+  __shared__ ArgMaxT block_max;
+  __shared__ TempStorage temp_storage;
+
+  if (threadIdx.x == 0) {
+    split = Split();
+  }
+
+  __syncthreads();
+
+  int node_idx = n_nodes(depth - 1) + blockIdx.x;
+
+  for (int fidx = 0; fidx < n_features; fidx++) {
+    if (colsample && d_feature_flags[fidx] == 0) continue;
+
+    int begin = d_feature_segments[blockIdx.x * n_features + fidx];
+    int end = d_feature_segments[blockIdx.x * n_features + fidx + 1];
+    int gidx = (begin - (blockIdx.x * n_bins)) + threadIdx.x;
+    bool thread_active = threadIdx.x < end - begin;
+
+    gpu_gpair feature_sum = gpu_gpair();
+    for (int reduce_begin = begin; reduce_begin < end;
+         reduce_begin += BLOCK_THREADS) {
+      // Scan histogram
+      gpu_gpair bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
+                                    : gpu_gpair();
+
+      feature_sum +=
+          SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
+    }
+
+    if (threadIdx.x == 0) {
+      shared_sum = feature_sum;
+    }
+    // __syncthreads(); // no need to synch because below there is a Scan
+
+    GpairCallbackOp prefix_op = GpairCallbackOp();
+    for (int scan_begin = begin; scan_begin < end;
+         scan_begin += BLOCK_THREADS) {
+      gpu_gpair bin =
+          thread_active ? d_level_hist[scan_begin + threadIdx.x] : gpu_gpair();
+
+      BlockScanT(temp_storage.scan)
+          .ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
+
+      // Calculate gain
+      gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
+      float parent_gain = d_nodes[node_idx].root_gain;
+
+      gpu_gpair missing = parent_sum - shared_sum;
+
+      bool missing_left;
+      float gain = thread_active
+                       ? loss_chg_missing(bin, missing, parent_sum, parent_gain,
+                                          gpu_param, missing_left)
+                       : -FLT_MAX;
+      __syncthreads();
+
+      // Find thread with best gain
+      ArgMaxT tuple(threadIdx.x, gain);
+      ArgMaxT best =
+          MaxReduceT(temp_storage.max_reduce).Reduce(tuple, cub::ArgMax());
+
+      if (threadIdx.x == 0) {
+        block_max = best;
+      }
+
+      __syncthreads();
+
+      // Best thread updates split
+      if (threadIdx.x == block_max.key) {
+        float fvalue;
+        if (threadIdx.x == 0 &&
+            begin == scan_begin) {  // check at start of first tile
+          fvalue = d_fidx_min_map[fidx];
+        } else {
+          fvalue = d_gidx_fvalue_map[gidx - 1];
+        }
+
+        gpu_gpair left = missing_left ? bin + missing : bin;
+        gpu_gpair right = parent_sum - left;
+
+        split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
+      }
+      __syncthreads();
+    }  // end scan
+  }    // end over features
+
+  // Create node
+  if (threadIdx.x == 0) {
+    d_nodes[node_idx].split = split;
+    if (depth == 0) {
+      // split.Print();
+    }
+
+    d_nodes[left_child_nidx(node_idx)] = Node(
+        split.left_sum,
+        CalcGain(gpu_param, split.left_sum.grad(), split.left_sum.hess()),
+        CalcWeight(gpu_param, split.left_sum.grad(), split.left_sum.hess()));
+
+    d_nodes[right_child_nidx(node_idx)] = Node(
+        split.right_sum,
+        CalcGain(gpu_param, split.right_sum.grad(), split.right_sum.hess()),
+        CalcWeight(gpu_param, split.right_sum.grad(), split.right_sum.hess()));
+
+    // Record smallest node
+    if (split.left_sum.hess() <= split.right_sum.hess()) {
+      d_left_child_smallest[node_idx] = true;
+    } else {
+      d_left_child_smallest[node_idx] = false;
+    }
+  }
+}
+
+#define MIN_BLOCK_THREADS 32
+#define MAX_BLOCK_THREADS 1024  // hard-coded maximum block size
+
+void GPUHistBuilder::FindSplit(int depth) {
+  // Specialised based on max_bins
+  this->FindSplitSpecialize<MIN_BLOCK_THREADS>(depth);
+}
+
+template <>
+void GPUHistBuilder::FindSplitSpecialize<MAX_BLOCK_THREADS>(int depth) {
+  const int GRID_SIZE = n_nodes_level(depth);
+  bool colsample =
+      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
+
+  find_split_general_kernel<
+      MAX_BLOCK_THREADS><<<GRID_SIZE, MAX_BLOCK_THREADS>>>(
+      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
+      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
+      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
+      feature_flags.data());
+
+  dh::safe_cuda(cudaDeviceSynchronize());
+}
+template <int BLOCK_THREADS>
+void GPUHistBuilder::FindSplitSpecialize(int depth) {
+  if (param.max_bin <= BLOCK_THREADS) {
+    const int GRID_SIZE = n_nodes_level(depth);
+    bool colsample =
+        param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
+
+    find_split_general_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
+        hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
+        hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
+        gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(),
+        colsample, feature_flags.data());
+  } else {
+    this->FindSplitSpecialize<BLOCK_THREADS + 32>(depth);
+  }
 
   dh::safe_cuda(cudaDeviceSynchronize());
 }
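find_split_general_kernel selects the winning split with a block-wide argmax: each thread wraps its gain in a cub::KeyValuePair, a cub::BlockReduce with cub::ArgMax picks the best (index, gain) pair, and thread 0 publishes it through shared memory so the winning thread can update the split. A reduced sketch of that idiom follows; the names best_gain_kernel, gains and out_best_idx are illustrative and not part of this commit.

#include <cfloat>
#include <cub/cub.cuh>

// Pick the thread holding the best gain via a block-wide ArgMax reduction.
template <int BLOCK_THREADS>
__global__ void best_gain_kernel(const float* gains, int n, int* out_best_idx) {
  typedef cub::KeyValuePair<int, float> ArgMaxT;
  typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
  __shared__ typename MaxReduceT::TempStorage temp_storage;
  __shared__ ArgMaxT block_max;

  float gain = threadIdx.x < n ? gains[threadIdx.x] : -FLT_MAX;
  ArgMaxT tuple(threadIdx.x, gain);
  // The reduction result is only valid in thread 0, so broadcast it.
  ArgMaxT best = MaxReduceT(temp_storage).Reduce(tuple, cub::ArgMax());
  if (threadIdx.x == 0) block_max = best;
  __syncthreads();

  // Every thread now knows which lane holds the winning candidate.
  if (threadIdx.x == block_max.key && threadIdx.x < n) {
    *out_best_idx = block_max.key;
  }
}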
@@ -62,17 +62,16 @@ class GPUHistBuilder {
               RegTree *p_tree);
   void BuildHist(int depth);
   void FindSplit(int depth);
-  void FindSplit256(int depth);
-  void FindSplit1024(int depth);
-  void FindSplitLarge(int depth);
+  template <int BLOCK_THREADS>
+  void FindSplitSpecialize(int depth);
   void InitFirstNode();
   void UpdatePosition(int depth);
   void UpdatePositionDense(int depth);
   void UpdatePositionSparse(int depth);
   void ColSampleTree();
   void ColSampleLevel();
-  bool UpdatePredictionCache(const DMatrix* data,
-                             std::vector<bst_float>* p_out_preds);
+  bool UpdatePredictionCache(const DMatrix *data,
+                             std::vector<bst_float> *p_out_preds);
 
   TrainParam param;
   GPUTrainingParam gpu_param;
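The templated FindSplitSpecialize declared here resolves the kernel's block size at compile time: starting from MIN_BLOCK_THREADS it recurses to the next larger instantiation until max_bin fits, and an explicit specialization at MAX_BLOCK_THREADS terminates the recursion. A host-only sketch of this dispatch pattern, with assumed names (FindSplitSpecializeSketch, kMinThreads, kMaxThreads) standing in for the real members:

#include <iostream>

constexpr int kMinThreads = 32;
constexpr int kMaxThreads = 1024;

// Walk up in steps of 32 until max_bin fits into one thread block.
template <int BLOCK_THREADS>
void FindSplitSpecializeSketch(int max_bin) {
  if (max_bin <= BLOCK_THREADS) {
    std::cout << "launch kernel with " << BLOCK_THREADS << " threads\n";
  } else {
    FindSplitSpecializeSketch<BLOCK_THREADS + 32>(max_bin);  // try next size
  }
}

// Terminator: at the hardware maximum, launch regardless of max_bin.
template <>
void FindSplitSpecializeSketch<kMaxThreads>(int max_bin) {
  std::cout << "launch kernel with " << kMaxThreads << " threads\n";
}

int main() {
  FindSplitSpecializeSketch<kMinThreads>(256);  // resolves to a 256-thread launch
}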
@@ -82,7 +81,7 @@ class GPUHistBuilder {
   bool initialised;
   bool is_dense;
   DeviceGMat device_matrix;
-  const DMatrix* p_last_fmat_;
+  const DMatrix *p_last_fmat_;
 
   dh::bulk_allocator ba;
   dh::CubMemory cub_mem;