From 3ca64ffa02dd2f3608fcf11486bffbef438a75d8 Mon Sep 17 00:00:00 2001
From: PSEUDOTENSOR / Jonathan McKinney
Date: Fri, 19 May 2017 19:16:24 -0700
Subject: [PATCH] [GPU-Plugin] Improved split finding performance. (#2325)

---
 plugin/updater_gpu/src/common.cuh             |  12 +
 plugin/updater_gpu/src/find_split_sorting.cuh |  67 ++--
 plugin/updater_gpu/src/gpu_hist_builder.cu    | 305 +++++++++---------
 plugin/updater_gpu/src/gpu_hist_builder.cuh   |  15 +-
 4 files changed, 206 insertions(+), 193 deletions(-)

diff --git a/plugin/updater_gpu/src/common.cuh b/plugin/updater_gpu/src/common.cuh
index ef8a8e504..2c26c26bc 100644
--- a/plugin/updater_gpu/src/common.cuh
+++ b/plugin/updater_gpu/src/common.cuh
@@ -168,5 +168,17 @@ inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
   return features;
 }
 
+struct GpairCallbackOp {
+  // Running prefix
+  gpu_gpair running_total;
+  // Constructor
+  __device__ GpairCallbackOp() : running_total(gpu_gpair()) {}
+  __device__ gpu_gpair operator()(gpu_gpair block_aggregate) {
+    gpu_gpair old_prefix = running_total;
+    running_total += block_aggregate;
+    return old_prefix;
+  }
+};
+
 }  // namespace tree
 }  // namespace xgboost
diff --git a/plugin/updater_gpu/src/find_split_sorting.cuh b/plugin/updater_gpu/src/find_split_sorting.cuh
index 02fe1aa24..08d720984 100644
--- a/plugin/updater_gpu/src/find_split_sorting.cuh
+++ b/plugin/updater_gpu/src/find_split_sorting.cuh
@@ -2,11 +2,11 @@
  * Copyright 2016 Rory mitchell
  */
 #pragma once
-#include <cub/cub.cuh>
 #include <xgboost/base.h>
+#include <cub/cub.cuh>
+#include "common.cuh"
 #include "device_helpers.cuh"
 #include "types.cuh"
 
 namespace xgboost {
 namespace tree {
@@ -48,19 +48,8 @@ struct GpairTupleCallbackOp {
   }
 };
 
-struct GpairCallbackOp {
-  // Running prefix
-  gpu_gpair running_total;
-  // Constructor
-  __device__ GpairCallbackOp() : running_total(gpu_gpair()) {}
-  __device__ gpu_gpair operator()(gpu_gpair block_aggregate) {
-    gpu_gpair old_prefix = running_total;
-    running_total += block_aggregate;
-    return old_prefix;
-  }
-};
-
-template <int BLOCK_THREADS> struct ReduceEnactorSorting {
+template <int BLOCK_THREADS>
+struct ReduceEnactorSorting {
   typedef cub::BlockScan<gpu_gpair, BLOCK_THREADS> GpairScanT;
   struct _TempStorage {
     typename GpairScanT::TempStorage gpair_scan;
@@ -82,13 +71,15 @@ template <int BLOCK_THREADS> struct ReduceEnactorSorting {
   const int level;
 
   __device__ __forceinline__
-  ReduceEnactorSorting(TempStorage &temp_storage, // NOLINT
+  ReduceEnactorSorting(TempStorage &temp_storage,  // NOLINT
                        gpu_gpair *d_block_node_sums, int *d_block_node_offsets,
                        ItemIter item_iter, const int level)
       : temp_storage(temp_storage.Alias()),
         d_block_node_sums(d_block_node_sums),
-        d_block_node_offsets(d_block_node_offsets), item_iter(item_iter),
-        callback_op(), level(level) {}
+        d_block_node_offsets(d_block_node_offsets),
+        item_iter(item_iter),
+        callback_op(),
+        level(level) {}
 
   __device__ __forceinline__ void LoadTile(const bst_uint &offset,
                                            const bst_uint &num_remaining) {
@@ -102,7 +93,7 @@ template <int BLOCK_THREADS> struct ReduceEnactorSorting {
       // Prevent overflow
       const int level_begin = (1 << level) - 1;
       node_id_adjusted =
-          max(static_cast<int>(node_id) - level_begin, -1); // NOLINT
+          max(static_cast<int>(node_id) - level_begin, -1);  // NOLINT
     }
   }
@@ -175,15 +166,18 @@ struct FindSplitEnactorSorting {
   const int level;
 
   __device__ __forceinline__ FindSplitEnactorSorting(
-      TempStorage &temp_storage, gpu_gpair *d_block_node_sums, // NOLINT
+      TempStorage &temp_storage, gpu_gpair *d_block_node_sums,  // NOLINT
       int *d_block_node_offsets, const ItemIter item_iter, const Node *d_nodes,
       const GPUTrainingParam &param, Split *d_split_candidates_out,
      const int level)
      : temp_storage(temp_storage.Alias()),
        d_block_node_sums(d_block_node_sums),
-        d_block_node_offsets(d_block_node_offsets), item_iter(item_iter),
-        d_nodes(d_nodes), d_split_candidates_out(d_split_candidates_out),
-        level(level), param(param) {}
+        d_block_node_offsets(d_block_node_offsets),
+        item_iter(item_iter),
+        d_nodes(d_nodes),
+        d_split_candidates_out(d_split_candidates_out),
+        level(level),
+        param(param) {}
 
   __device__ __forceinline__ void LoadTile(NodeIdT node_id_adjusted,
                                            const bst_uint &node_begin,
@@ -254,9 +248,9 @@ struct FindSplitEnactorSorting {
     return fvalue != left_fvalue;
   }
 
-  __device__ __forceinline__ void
-  EvaluateSplits(const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
-                 const bst_uint &offset, const bst_uint &num_remaining) {
+  __device__ __forceinline__ void EvaluateSplits(
+      const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
+      const bst_uint &offset, const bst_uint &num_remaining) {
     bool thread_active = LeftmostFvalue() && threadIdx.x < num_remaining &&
                          node_id_adjusted >= 0 && node_id >= 0;
 
@@ -289,10 +283,10 @@ struct FindSplitEnactorSorting {
     }
   }
 
-  __device__ __forceinline__ void
-  ProcessTile(const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
-              const bst_uint &offset, const bst_uint &num_remaining,
-              GpairCallbackOp &callback_op) { // NOLINT
+  __device__ __forceinline__ void ProcessTile(
+      const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
+      const bst_uint &offset, const bst_uint &num_remaining,
+      GpairCallbackOp &callback_op) {  // NOLINT
     LoadTile(node_id_adjusted, node_begin, offset, num_remaining);
 
     // Scan gpair
@@ -304,8 +298,8 @@ struct FindSplitEnactorSorting {
     EvaluateSplits(node_id_adjusted, node_begin, offset, num_remaining);
   }
 
-  __device__ __forceinline__ void
-  WriteBestSplit(const NodeIdT &node_id_adjusted) {
+  __device__ __forceinline__ void WriteBestSplit(
+      const NodeIdT &node_id_adjusted) {
     if (threadIdx.x < 32) {
       bool active = threadIdx.x < N_WARPS;
       float warp_loss =
@@ -370,7 +364,6 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
     const Node *d_nodes, bst_uint num_items, const int num_features,
     const int *d_feature_offsets, gpu_gpair *d_node_sums, int *d_node_offsets,
     const GPUTrainingParam param, const int *d_feature_flags, const int level) {
-
   if (num_items <= 0 || d_feature_flags[blockIdx.x] != 1) {
     return;
   }
@@ -400,7 +393,7 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
       .ProcessFeature(segment_begin, segment_end);
 }
 
-void find_split_candidates_sorted(GPUData * data, const int level) {
+void find_split_candidates_sorted(GPUData *data, const int level) {
   const int BLOCK_THREADS = 512;
 
   CHECK(BLOCK_THREADS / 32 < 32) << "Too many active warps.";
@@ -410,9 +403,9 @@ void find_split_candidates_sorted(GPUData * data, const int level) {
 
   find_split_candidates_sorted_kernel<
       BLOCK_THREADS><<<data->n_features, BLOCK_THREADS>>>(
       data->items_iter, data->split_candidates.data(), data->nodes.data(),
-      data->fvalues.size(), data->n_features,
-      data->foffsets.data(), data->node_sums.data(), data->node_offsets.data(),
-      data->param, data->feature_flags.data(), level);
+      data->fvalues.size(), data->n_features, data->foffsets.data(),
+      data->node_sums.data(), data->node_offsets.data(), data->param,
+      data->feature_flags.data(), level);
 
   dh::safe_cuda(cudaGetLastError());
   dh::safe_cuda(cudaDeviceSynchronize());
diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cu b/plugin/updater_gpu/src/gpu_hist_builder.cu
index 4e3390891..e70247a65 100644
--- a/plugin/updater_gpu/src/gpu_hist_builder.cu
+++ b/plugin/updater_gpu/src/gpu_hist_builder.cu
@@ -245,174 +245,183 @@ __global__ void find_split_kernel(
     }
   }
 }
+template <int BLOCK_THREADS>
+__global__ void find_split_general_kernel(
+    const gpu_gpair* d_level_hist, int* d_feature_segments, int depth,
+    int n_features, int n_bins, Node* d_nodes, float* d_fidx_min_map,
+    float* d_gidx_fvalue_map, GPUTrainingParam gpu_param,
+    bool* d_left_child_smallest, bool colsample, int* d_feature_flags) {
+  typedef cub::KeyValuePair<int, float> ArgMaxT;
+  typedef cub::BlockScan<gpu_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
+      BlockScanT;
+  typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
+  typedef cub::BlockReduce<gpu_gpair, BLOCK_THREADS> SumReduceT;
 
-void GPUHistBuilder::FindSplit(int depth) {
-  // Specialised based on max_bins
-  if (param.max_bin <= 256) {
-    this->FindSplit256(depth);
-  } else if (param.max_bin <= 1024) {
-    this->FindSplit1024(depth);
-  } else {
-    this->FindSplitLarge(depth);
-  }
-}
+  union TempStorage {
+    typename BlockScanT::TempStorage scan;
+    typename MaxReduceT::TempStorage max_reduce;
+    typename SumReduceT::TempStorage sum_reduce;
+  };
 
-void GPUHistBuilder::FindSplit256(int depth) {
-  CHECK_LE(param.max_bin, 256);
-  const int BLOCK_THREADS = 256;
-  const int GRID_SIZE = n_nodes_level(depth);
-  bool colsample =
-      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
-  find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
-      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
-      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
-      feature_flags.data());
+  struct UninitializedSplit : cub::Uninitialized<Split> {};
+  struct UninitializedGpair : cub::Uninitialized<gpu_gpair> {};
 
-  dh::safe_cuda(cudaDeviceSynchronize());
-}
-void GPUHistBuilder::FindSplit1024(int depth) {
-  CHECK_LE(param.max_bin, 1024);
-  const int BLOCK_THREADS = 1024;
-  const int GRID_SIZE = n_nodes_level(depth);
-  bool colsample =
-      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
-  find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
-      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
-      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
-      feature_flags.data());
+  __shared__ UninitializedSplit uninitialized_split;
+  Split& split = uninitialized_split.Alias();
+  __shared__ UninitializedGpair uninitialized_sum;
+  gpu_gpair& shared_sum = uninitialized_sum.Alias();
+  __shared__ ArgMaxT block_max;
+  __shared__ TempStorage temp_storage;
 
-  dh::safe_cuda(cudaDeviceSynchronize());
-}
-void GPUHistBuilder::FindSplitLarge(int depth) {
-  auto counting = thrust::make_counting_iterator(0);
-  auto d_gidx_feature_map = gidx_feature_map.data();
-  int n_bins = hmat_.row_ptr.back();
-  int n_features = hmat_.row_ptr.size() - 1;
-
-  auto feature_boundary = [=] __device__(int idx_a, int idx_b) {
-    int gidx_a = idx_a % n_bins;
-    int gidx_b = idx_b % n_bins;
-    return d_gidx_feature_map[gidx_a] == d_gidx_feature_map[gidx_b];
-  };  // NOLINT
-
-  // Reduce node sums
-  {
-    size_t temp_storage_bytes;
-    cub::DeviceSegmentedReduce::Reduce(
-        nullptr, temp_storage_bytes, hist.GetLevelPtr(depth), node_sums.data(),
-        n_nodes_level(depth) * n_features, feature_segments.data(),
-        feature_segments.data() + 1, cub::Sum(), gpu_gpair());
-    cub_mem.LazyAllocate(temp_storage_bytes);
-    cub::DeviceSegmentedReduce::Reduce(
-        cub_mem.d_temp_storage, cub_mem.temp_storage_bytes,
-        hist.GetLevelPtr(depth), node_sums.data(),
-        n_nodes_level(depth) * n_features, feature_segments.data(),
-        feature_segments.data() + 1, cub::Sum(), gpu_gpair());
+  if (threadIdx.x == 0) {
+    split = Split();
   }
 
-  // Scan
-  thrust::exclusive_scan_by_key(
-      counting, counting + hist.LevelSize(depth),
-      thrust::device_pointer_cast(hist.GetLevelPtr(depth)), hist_scan.tbegin(),
-      gpu_gpair(), feature_boundary);
+  __syncthreads();
 
-  // Calculate gain
-  auto d_gain = gain.data();
-  auto d_nodes = nodes.data();
-  auto d_node_sums = node_sums.data();
-  auto d_hist_scan = hist_scan.data();
-  GPUTrainingParam gpu_param_alias =
-      gpu_param;  // Must be local variable to be used in device lambda
-  bool colsample =
-      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
-  auto d_feature_flags = feature_flags.data();
+  int node_idx = n_nodes(depth - 1) + blockIdx.x;
 
-  dh::launch_n(hist.LevelSize(depth), [=] __device__(int idx) {
-    int node_segment = idx / n_bins;
-    int node_idx = n_nodes(depth - 1) + node_segment;
-    gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
-    float parent_gain = d_nodes[node_idx].root_gain;
-    int gidx = idx % n_bins;
-    int findex = d_gidx_feature_map[gidx];
+  for (int fidx = 0; fidx < n_features; fidx++) {
+    if (colsample && d_feature_flags[fidx] == 0) continue;
 
-    // colsample
-    if (colsample && d_feature_flags[d_gidx_feature_map[gidx]] == 0) {
-      d_gain[idx] = 0;
-    } else {
-      gpu_gpair scan = d_hist_scan[idx];
-      gpu_gpair sum = d_node_sums[node_segment * n_features + findex];
-      gpu_gpair missing = parent_sum - sum;
+    int begin = d_feature_segments[blockIdx.x * n_features + fidx];
+    int end = d_feature_segments[blockIdx.x * n_features + fidx + 1];
+    int gidx = (begin - (blockIdx.x * n_bins)) + threadIdx.x;
+    bool thread_active = threadIdx.x < end - begin;
+
+    gpu_gpair feature_sum = gpu_gpair();
+    for (int reduce_begin = begin; reduce_begin < end;
+         reduce_begin += BLOCK_THREADS) {
+      // Reduce histogram bins into the feature total
+      gpu_gpair bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
+                                    : gpu_gpair();
+
+      feature_sum +=
+          SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
+    }
+
+    if (threadIdx.x == 0) {
+      shared_sum = feature_sum;
+    }
+    // __syncthreads();  // no sync needed: the block-wide scan below synchronises
+
+    GpairCallbackOp prefix_op = GpairCallbackOp();
+    for (int scan_begin = begin; scan_begin < end;
+         scan_begin += BLOCK_THREADS) {
+      gpu_gpair bin =
+          thread_active ? d_level_hist[scan_begin + threadIdx.x] : gpu_gpair();
+
+      BlockScanT(temp_storage.scan)
+          .ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
+
+      // Calculate gain
+      gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
+      float parent_gain = d_nodes[node_idx].root_gain;
+
+      gpu_gpair missing = parent_sum - shared_sum;
       bool missing_left;
-      d_gain[idx] = loss_chg_missing(scan, missing, parent_sum, parent_gain,
-                                     gpu_param_alias, missing_left);
+      float gain = thread_active
+                       ? loss_chg_missing(bin, missing, parent_sum,
+                                          parent_gain, gpu_param, missing_left)
+                       : -FLT_MAX;
+      __syncthreads();
+
+      // Find thread with best gain
+      ArgMaxT tuple(threadIdx.x, gain);
+      ArgMaxT best =
+          MaxReduceT(temp_storage.max_reduce).Reduce(tuple, cub::ArgMax());
+
+      if (threadIdx.x == 0) {
+        block_max = best;
+      }
+
+      __syncthreads();
+
+      // Best thread updates split
+      if (threadIdx.x == block_max.key) {
+        float fvalue;
+        if (threadIdx.x == 0 &&
+            begin == scan_begin) {  // check at start of first tile
+          fvalue = d_fidx_min_map[fidx];
+        } else {
+          fvalue = d_gidx_fvalue_map[gidx - 1];
+        }
+
+        gpu_gpair left = missing_left ? bin + missing : bin;
+        gpu_gpair right = parent_sum - left;
+
+        split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
+      }
+      __syncthreads();
+    }  // end scan
+  }  // end over features
+
+  // Create node
+  if (threadIdx.x == 0) {
+    d_nodes[node_idx].split = split;
+    if (depth == 0) {
+      // split.Print();
     }
-  });
-  dh::safe_cuda(cudaDeviceSynchronize());
 
-  // Find best gain
-  {
-    size_t temp_storage_bytes;
-    cub::DeviceSegmentedReduce::ArgMax(nullptr, temp_storage_bytes, gain.data(),
-                                       argmax.data(), n_nodes_level(depth),
-                                       hist_node_segments.data(),
-                                       hist_node_segments.data() + 1);
-    cub_mem.LazyAllocate(temp_storage_bytes);
-    cub::DeviceSegmentedReduce::ArgMax(
-        cub_mem.d_temp_storage, cub_mem.temp_storage_bytes, gain.data(),
-        argmax.data(), n_nodes_level(depth), hist_node_segments.data(),
-        hist_node_segments.data() + 1);
-  }
+    d_nodes[left_child_nidx(node_idx)] = Node(
+        split.left_sum,
+        CalcGain(gpu_param, split.left_sum.grad(), split.left_sum.hess()),
+        CalcWeight(gpu_param, split.left_sum.grad(), split.left_sum.hess()));
 
-  auto d_argmax = argmax.data();
-  auto d_gidx_fvalue_map = gidx_fvalue_map.data();
-  auto d_fidx_min_map = fidx_min_map.data();
-  auto d_left_child_smallest = left_child_smallest.data();
-
-  dh::launch_n(n_nodes_level(depth), [=] __device__(int idx) {
-    int max_idx = n_bins * idx + d_argmax[idx].key;
-    int gidx = max_idx % n_bins;
-    int fidx = d_gidx_feature_map[gidx];
-    int node_segment = max_idx / n_bins;
-    int node_idx = n_nodes(depth - 1) + node_segment;
-    gpu_gpair scan = d_hist_scan[max_idx];
-    gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
-    float parent_gain = d_nodes[node_idx].root_gain;
-    gpu_gpair sum = d_node_sums[node_segment * n_features + fidx];
-    gpu_gpair missing = parent_sum - sum;
-
-    bool missing_left;
-    float loss_chg = loss_chg_missing(scan, missing, parent_sum, parent_gain,
-                                      gpu_param_alias, missing_left);
-
-    float fvalue;
-    if (gidx == 0 || fidx != d_gidx_feature_map[gidx - 1]) {
-      fvalue = d_fidx_min_map[fidx];
-    } else {
-      fvalue = d_gidx_fvalue_map[gidx - 1];
-    }
-    gpu_gpair left = missing_left ? scan + missing : scan;
-    gpu_gpair right = parent_sum - left;
-    d_nodes[node_idx].split.Update(loss_chg, missing_left, fvalue, fidx, left,
-                                   right, gpu_param_alias);
-
-    d_nodes[left_child_nidx(node_idx)] =
-        Node(left, CalcGain(gpu_param_alias, left.grad(), left.hess()),
-             CalcWeight(gpu_param_alias, left.grad(), left.hess()));
-
-    d_nodes[right_child_nidx(node_idx)] =
-        Node(right, CalcGain(gpu_param_alias, right.grad(), right.hess()),
-             CalcWeight(gpu_param_alias, right.grad(), right.hess()));
+    d_nodes[right_child_nidx(node_idx)] = Node(
+        split.right_sum,
+        CalcGain(gpu_param, split.right_sum.grad(), split.right_sum.hess()),
+        CalcWeight(gpu_param, split.right_sum.grad(), split.right_sum.hess()));
 
     // Record smallest node
-    if (left.hess() <= right.hess()) {
+    if (split.left_sum.hess() <= split.right_sum.hess()) {
       d_left_child_smallest[node_idx] = true;
     } else {
      d_left_child_smallest[node_idx] = false;
     }
-  });
+  }
+}
+
+#define MIN_BLOCK_THREADS 32
+#define MAX_BLOCK_THREADS 1024  // hard-coded maximum block size
+
+void GPUHistBuilder::FindSplit(int depth) {
+  // Specialised based on max_bins
+  this->FindSplitSpecialize<MIN_BLOCK_THREADS>(depth);
+}
+
+template <>
+void GPUHistBuilder::FindSplitSpecialize<MAX_BLOCK_THREADS>(int depth) {
+  const int GRID_SIZE = n_nodes_level(depth);
+  bool colsample =
+      param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
+
+  find_split_general_kernel<
+      MAX_BLOCK_THREADS><<<GRID_SIZE, MAX_BLOCK_THREADS>>>(
+      hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
+      hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
+      gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
+      feature_flags.data());
+
+  dh::safe_cuda(cudaDeviceSynchronize());
+}
+template <int BLOCK_THREADS>
+void GPUHistBuilder::FindSplitSpecialize(int depth) {
+  if (param.max_bin <= BLOCK_THREADS) {
+    const int GRID_SIZE = n_nodes_level(depth);
+    bool colsample =
+        param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
+
+    find_split_general_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
+        hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
+        hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
+        gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(),
+        colsample, feature_flags.data());
+  } else {
+    this->FindSplitSpecialize<BLOCK_THREADS + 32>(depth);
+  }
+  dh::safe_cuda(cudaDeviceSynchronize());
+}
 
diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cuh b/plugin/updater_gpu/src/gpu_hist_builder.cuh
index 7f6bc8f51..fb9a5068c 100644
--- a/plugin/updater_gpu/src/gpu_hist_builder.cuh
+++ b/plugin/updater_gpu/src/gpu_hist_builder.cuh
@@ -62,17 +62,16 @@ class GPUHistBuilder {
                       RegTree *p_tree);
   void BuildHist(int depth);
   void FindSplit(int depth);
-  void FindSplit256(int depth);
-  void FindSplit1024(int depth);
-  void FindSplitLarge(int depth);
+  template <int BLOCK_THREADS>
+  void FindSplitSpecialize(int depth);
   void InitFirstNode();
   void UpdatePosition(int depth);
   void UpdatePositionDense(int depth);
   void UpdatePositionSparse(int depth);
   void ColSampleTree();
   void ColSampleLevel();
-  bool UpdatePredictionCache(const DMatrix* data,
-                             std::vector<bst_float>* p_out_preds);
+  bool UpdatePredictionCache(const DMatrix *data,
+                             std::vector<bst_float> *p_out_preds);
 
   TrainParam param;
   GPUTrainingParam gpu_param;
@@ -82,7 +81,7 @@ class GPUHistBuilder {
   bool initialised;
   bool is_dense;
   DeviceGMat device_matrix;
-  const DMatrix* p_last_fmat_;
+  const DMatrix *p_last_fmat_;
 
   dh::bulk_allocator ba;
   dh::CubMemory cub_mem;
@@ -101,8 +100,8 @@ class GPUHistBuilder {
   dh::dvec<gpu_gpair> device_gpair;
   dh::dvec<Node> nodes;
   dh::dvec<int> feature_flags;
-  dh::dvec<int> left_child_smallest;
-  dh::dvec<float> prediction_cache;
+  dh::dvec<bool> left_child_smallest;
+  dh::dvec<bst_float> prediction_cache;
   bool prediction_cache_initialised;
 
   std::vector<int> feature_set_tree;
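
Note on the pattern the patch relies on: find_split_general_kernel walks each feature's bins in BLOCK_THREADS-sized tiles and carries a running prefix between tiles through cub::BlockScan's block-prefix callback, which is exactly the role of the GpairCallbackOp moved into common.cuh. Below is a minimal, standalone sketch of that tiling technique, not taken from the patch: it substitutes plain float for gpu_gpair, and the names PrefixOp and tiled_exclusive_scan are illustrative.

// tiled_scan_sketch.cu -- one block scans a segment longer than itself by
// looping over tiles; the callback seeds each tile with the sum so far.
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <vector>

struct PrefixOp {
  float running_total;  // running prefix carried across tiles
  __device__ PrefixOp() : running_total(0.0f) {}
  __device__ float operator()(float block_aggregate) {
    float old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;  // exclusive-scan seed for the next tile
  }
};

template <int BLOCK_THREADS>
__global__ void tiled_exclusive_scan(const float* in, float* out, int n) {
  typedef cub::BlockScan<float, BLOCK_THREADS> BlockScanT;
  __shared__ typename BlockScanT::TempStorage temp_storage;

  PrefixOp prefix_op;
  for (int tile = 0; tile < n; tile += BLOCK_THREADS) {
    int i = tile + threadIdx.x;
    float x = (i < n) ? in[i] : 0.0f;  // inactive threads contribute zero
    // Block-wide exclusive scan; prefix_op supplies the sum of earlier tiles.
    BlockScanT(temp_storage).ExclusiveScan(x, x, cub::Sum(), prefix_op);
    if (i < n) out[i] = x;
    __syncthreads();  // temp_storage is reused by the next tile
  }
}

int main() {
  const int n = 1000;
  std::vector<float> h_in(n, 1.0f);
  float* d_in = nullptr;
  float* d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  // One block of 256 threads scans all n elements, the same way one block
  // per node scans each feature's bin segment in the kernel above.
  tiled_exclusive_scan<256><<<1, 256>>>(d_in, d_out, n);
  cudaDeviceSynchronize();
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

This prefix-carrying loop is what lets a fixed-size thread block handle an arbitrary number of bins, so FindSplitSpecialize only has to pick a block size of at least max_bin (up to MAX_BLOCK_THREADS) rather than keeping separate 256/1024/large code paths.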