diff --git a/doc/gpu/index.md b/doc/gpu/index.md index da913f2f1..f8b375bb7 100644 --- a/doc/gpu/index.md +++ b/doc/gpu/index.md @@ -17,7 +17,7 @@ Specify the 'tree_method' parameter as one of the following algorithms. +==============+=================================================================================================================================================================================================================+ | gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' | +--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Cannot be used with labels larger in magnitude than 2^16 due to it's histogram aggregation algorithm. | +| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Will run very slowly on GPUs older than the Pascal architecture. | +--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ ``` @@ -44,17 +44,18 @@ Specify the 'tree_method' parameter as one of the following algorithms. +--------------------+------------+-----------+ | predictor | |tick| | |tick| | +--------------------+------------+-----------+ +| grow_policy | |cross| | |tick| | ++--------------------+------------+-----------+ -| ``` GPU accelerated prediction is enabled by default for the above mentioned 'tree_method' parameters but can be switched to CPU prediction by setting 'predictor':'cpu_predictor'. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting 'predictor':'gpu_predictor'. The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0. -Multiple GPUs can be used with the grow_gpu_hist parameter using the n_gpus parameter. which defaults to 1. If this is set to -1 all available GPUs will be used. If gpu_id is specified as non-zero, the gpu device order is mod(gpu_id + i) % n_visible_devices for i=0 to n_gpus-1. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU due to PCI bus bandwidth that can limit performance. For example, when n_features * n_bins * 2^depth divided by time of each round/iteration becomes comparable to the real PCI 16x bus bandwidth of order 4GB/s to 10GB/s, then AllReduce will dominant code speed and multiple GPUs become ineffective at increasing performance. Also, CPU overhead between GPU calls can limit usefulness of multiple GPUs. +Multiple GPUs can be used with the 'gpu_hist' tree_method by setting the n_gpus parameter, which defaults to 1. If this is set to -1, all available GPUs will be used. If gpu_id is specified as non-zero, the GPU device order is (gpu_id + i) % n_visible_devices for i = 0 to n_gpus - 1. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU because PCI bus bandwidth can limit performance. -This plugin currently works with the CLI version and python version. +This plugin currently works with the CLI, Python and R packages - see the installation guide for details.
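For reference, a minimal sketch of how the options described above are passed through the Python package. The parameter names ('tree_method', 'gpu_id', 'n_gpus', 'predictor', 'max_bin') come from the documentation above; the synthetic data, the 'binary:logistic' objective and the specific values are illustrative assumptions, not part of this patch.

```python
import numpy as np
import xgboost as xgb

# Synthetic data purely for illustration.
X = np.random.rand(1000, 50)
y = np.random.randint(2, size=1000)
dtrain = xgb.DMatrix(X, label=y)

params = {
    'objective': 'binary:logistic',  # assumed objective for this sketch
    'tree_method': 'gpu_hist',       # GPU fast histogram algorithm
    'max_depth': 6,
    'max_bin': 256,
    'gpu_id': 0,                     # device ordinal, defaults to 0
    'n_gpus': 1,                     # set to -1 to use all available GPUs
    # 'predictor': 'cpu_predictor',  # uncomment to conserve GPU memory
}

bst = xgb.train(params, dtrain, num_boost_round=100)
```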
Python example: ```python @@ -83,7 +84,6 @@ Training time on 1,000,000 rows x 50 columns with 500 boosting iterations a | exact | 1082.20 | +--------------+----------+ -| ``` [See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for additional performance benchmarks of the 'gpu_exact' tree_method. @@ -91,6 +91,8 @@ Training time on 1,000,000 rows x 50 columns with 500 boosting iterations a ## References [Mitchell R, Frank E. (2017) Accelerating the XGBoost algorithm using GPU computing. PeerJ Computer Science 3:e127 https://doi.org/10.7717/peerj-cs.127](https://peerj.com/articles/cs-127/) +[Nvidia Parallel Forall: Gradient Boosting, Decision Trees and XGBoost with CUDA](https://devblogs.nvidia.com/parallelforall/gradient-boosting-decision-trees-xgboost-cuda/) + ## Author Rory Mitchell Jonathan C. McKinney diff --git a/src/learner.cc b/src/learner.cc index 32e807137..117f73c92 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -111,7 +111,6 @@ struct LearnerTrainParam : public dmlc::Parameter { .add_enum("hist", 3) .add_enum("gpu_exact", 4) .add_enum("gpu_hist", 5) - .add_enum("gpu_hist_experimental", 6) .describe("Choice of tree construction method."); DMLC_DECLARE_FIELD(test_flag).set_default("").describe( "Internal test flag"); @@ -188,14 +187,6 @@ class LearnerImpl : public Learner { if (cfg_.count("predictor") == 0) { cfg_["predictor"] = "gpu_predictor"; } - } else if (tparam.tree_method == 6) { - this->AssertGPUSupport(); - if (cfg_.count("updater") == 0) { - cfg_["updater"] = "grow_gpu_hist_experimental,prune"; - } - if (cfg_.count("predictor") == 0) { - cfg_["predictor"] = "gpu_predictor"; - } } } @@ -468,7 +459,7 @@ class LearnerImpl : public Learner { // if not, initialize the column access. inline void LazyInitDMatrix(DMatrix* p_train) { if (tparam.tree_method == 3 || tparam.tree_method == 4 || - tparam.tree_method == 5 || tparam.tree_method == 6) { + tparam.tree_method == 5) { return; } diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index 2452dba55..32630903e 100644 --- a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -35,7 +35,6 @@ DMLC_REGISTRY_LINK_TAG(updater_sync); #ifdef XGBOOST_USE_CUDA DMLC_REGISTRY_LINK_TAG(updater_gpu); DMLC_REGISTRY_LINK_TAG(updater_gpu_hist); -DMLC_REGISTRY_LINK_TAG(updater_gpu_hist_experimental); #endif } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 77b48b157..7154779b5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -1,8 +1,13 @@ /*!
* Copyright 2017 XGBoost contributors */ +#include +#include +#include #include +#include #include +#include #include #include #include "../common/compressed_iterator.h" @@ -17,93 +22,110 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); -typedef bst_gpair_integer gpair_sum_t; +typedef bst_gpair_precise gpair_sum_t; -// Helper for explicit template specialisation -template -struct Int {}; +template +__device__ gpair_sum_t ReduceFeature(const gpair_sum_t* begin, + const gpair_sum_t* end, + temp_storage_t* temp_storage) { + __shared__ cub::Uninitialized uninitialized_sum; + gpair_sum_t& shared_sum = uninitialized_sum.Alias(); -struct DeviceGMat { - dh::dvec gidx_buffer; - common::CompressedIterator gidx; - dh::dvec row_ptr; - void Init(int device_idx, const common::GHistIndexMatrix& gmat, - bst_ulong element_begin, bst_ulong element_end, bst_ulong row_begin, - bst_ulong row_end, int n_bins) { - dh::safe_cuda(cudaSetDevice(device_idx)); - CHECK(gidx_buffer.size()) << "gidx_buffer must be externally allocated"; - CHECK_EQ(row_ptr.size(), (row_end - row_begin) + 1) - << "row_ptr must be externally allocated"; + gpair_sum_t local_sum = gpair_sum_t(); + for (auto itr = begin; itr < end; itr += BLOCK_THREADS) { + bool thread_active = itr + threadIdx.x < end; + // Scan histogram + gpair_sum_t bin = thread_active ? *(itr + threadIdx.x) : gpair_sum_t(); - common::CompressedBufferWriter cbw(n_bins); - std::vector host_buffer(gidx_buffer.size()); - cbw.Write(host_buffer.data(), gmat.index.begin() + element_begin, - gmat.index.begin() + element_end); - gidx_buffer = host_buffer; - gidx = common::CompressedIterator(gidx_buffer.data(), n_bins); - - // row_ptr - dh::safe_cuda(cudaMemcpy(row_ptr.data(), gmat.row_ptr.data() + row_begin, - row_ptr.size() * sizeof(size_t), - cudaMemcpyHostToDevice)); - // normalise row_ptr - size_t start = gmat.row_ptr[row_begin]; - auto d_row_ptr = row_ptr.data(); - dh::launch_n(row_ptr.device_idx(), row_ptr.size(), - [=] __device__(size_t idx) { d_row_ptr[idx] -= start; }); - } -}; - -struct HistHelper { - gpair_sum_t* d_hist; - int n_bins; - __host__ __device__ HistHelper(gpair_sum_t* ptr, int n_bins) - : d_hist(ptr), n_bins(n_bins) {} - - __device__ void Add(bst_gpair gpair, int gidx, int nidx) const { - int hist_idx = nidx * n_bins + gidx; - - AtomicAddGpair(d_hist + hist_idx, gpair); - } - __device__ gpair_sum_t Get(int gidx, int nidx) const { - return d_hist[nidx * n_bins + gidx]; - } -}; - -struct DeviceHist { - int n_bins; - dh::dvec data; - - void Init(int n_bins_in) { - this->n_bins = n_bins_in; - CHECK(!data.empty()) << "DeviceHist must be externally allocated"; + local_sum += reduce_t(temp_storage->sum_reduce).Reduce(bin, cub::Sum()); } - void Reset(int device_idx) { - cudaSetDevice(device_idx); - data.fill(gpair_sum_t()); + if (threadIdx.x == 0) { + shared_sum = local_sum; } + __syncthreads(); - HistHelper GetBuilder() { return HistHelper(data.data(), n_bins); } + return shared_sum; +} - gpair_sum_t* GetLevelPtr(int depth) { - return data.data() + n_nodes(depth - 1) * n_bins; +template +__device__ void EvaluateFeature(int fidx, const gpair_sum_t* hist, + const int* feature_segments, float min_fvalue, + const float* gidx_fvalue_map, + DeviceSplitCandidate* best_split, + const DeviceNodeStats& node, + const GPUTrainingParam& param, + temp_storage_t* temp_storage) { + int gidx_begin = feature_segments[fidx]; + int gidx_end = feature_segments[fidx + 1]; + + gpair_sum_t feature_sum = ReduceFeature( + hist + gidx_begin, hist + gidx_end, temp_storage); + + 
auto prefix_op = SumCallbackOp(); + for (int scan_begin = gidx_begin; scan_begin < gidx_end; + scan_begin += BLOCK_THREADS) { + bool thread_active = scan_begin + threadIdx.x < gidx_end; + + gpair_sum_t bin = + thread_active ? hist[scan_begin + threadIdx.x] : gpair_sum_t(); + scan_t(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op); + + // Calculate gain + gpair_sum_t parent_sum = gpair_sum_t(node.sum_gradients); + + gpair_sum_t missing = parent_sum - feature_sum; + + bool missing_left = true; + const float null_gain = -FLT_MAX; + float gain = null_gain; + if (thread_active) { + gain = loss_chg_missing(bin, missing, parent_sum, node.root_gain, param, + missing_left); + } + + __syncthreads(); + + // Find thread with best gain + cub::KeyValuePair tuple(threadIdx.x, gain); + cub::KeyValuePair best = + max_reduce_t(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax()); + + __shared__ cub::KeyValuePair block_max; + if (threadIdx.x == 0) { + block_max = best; + } + + __syncthreads(); + + // Best thread updates split + if (threadIdx.x == block_max.key) { + int gidx = scan_begin + threadIdx.x; + float fvalue = + gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1]; + + gpair_sum_t left = missing_left ? bin + missing : bin; + gpair_sum_t right = parent_sum - left; + + best_split->Update(gain, missing_left ? LeftDir : RightDir, fvalue, fidx, + left, right, param); + } + __syncthreads(); } - - int LevelSize(int depth) { return n_bins * n_nodes_level(depth); } -}; +} template -__global__ void find_split_kernel( - const gpair_sum_t* d_level_hist, int* d_feature_segments, int depth, - uint64_t n_features, int n_bins, DeviceNodeStats* d_nodes, - int nodes_offset_device, float* d_fidx_min_map, float* d_gidx_fvalue_map, - GPUTrainingParam gpu_param, bool* d_left_child_smallest_temp, - bool colsample, int* d_feature_flags) { +__global__ void evaluate_split_kernel( + const gpair_sum_t* d_hist, int nidx, uint64_t n_features, + DeviceNodeStats nodes, const int* d_feature_segments, + const float* d_fidx_min_map, const float* d_gidx_fvalue_map, + GPUTrainingParam gpu_param, DeviceSplitCandidate* d_split) { typedef cub::KeyValuePair ArgMaxT; typedef cub::BlockScan BlockScanT; typedef cub::BlockReduce MaxReduceT; + typedef cub::BlockReduce SumReduceT; union TempStorage { @@ -113,143 +135,367 @@ __global__ void find_split_kernel( }; __shared__ cub::Uninitialized uninitialized_split; - DeviceSplitCandidate& split = uninitialized_split.Alias(); - __shared__ cub::Uninitialized uninitialized_sum; - gpair_sum_t& shared_sum = uninitialized_sum.Alias(); - __shared__ ArgMaxT block_max; + DeviceSplitCandidate& best_split = uninitialized_split.Alias(); __shared__ TempStorage temp_storage; if (threadIdx.x == 0) { - split = DeviceSplitCandidate(); + best_split = DeviceSplitCandidate(); } __syncthreads(); - // below two are for accessing full-sized node list stored on each device - // always one block per node, BLOCK_THREADS threads per block - int level_node_idx = blockIdx.x + nodes_offset_device; - int node_idx = n_nodes(depth - 1) + level_node_idx; + auto fidx = blockIdx.x; + EvaluateFeature( + fidx, d_hist, d_feature_segments, d_fidx_min_map[fidx], d_gidx_fvalue_map, + &best_split, nodes, gpu_param, &temp_storage); - for (int fidx = 0; fidx < n_features; fidx++) { - if (colsample && d_feature_flags[fidx] == 0) continue; + __syncthreads(); - int begin = d_feature_segments[level_node_idx * n_features + fidx]; - int end = d_feature_segments[level_node_idx * n_features + fidx + 1]; - - gpair_sum_t 
feature_sum = gpair_sum_t(); - for (int reduce_begin = begin; reduce_begin < end; - reduce_begin += BLOCK_THREADS) { - bool thread_active = reduce_begin + threadIdx.x < end; - // Scan histogram - gpair_sum_t bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x] - : gpair_sum_t(); - - feature_sum += - SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum()); - } - - if (threadIdx.x == 0) { - shared_sum = feature_sum; - } - // __syncthreads(); // no need to synch because below there is a Scan - - auto prefix_op = SumCallbackOp(); - for (int scan_begin = begin; scan_begin < end; - scan_begin += BLOCK_THREADS) { - bool thread_active = scan_begin + threadIdx.x < end; - gpair_sum_t bin = thread_active ? d_level_hist[scan_begin + threadIdx.x] - : gpair_sum_t(); - - BlockScanT(temp_storage.scan) - .ExclusiveScan(bin, bin, cub::Sum(), prefix_op); - - // Calculate gain - gpair_sum_t parent_sum = gpair_sum_t(d_nodes[node_idx].sum_gradients); - float parent_gain = d_nodes[node_idx].root_gain; - - gpair_sum_t missing = parent_sum - shared_sum; - - bool missing_left; - float gain = thread_active - ? loss_chg_missing(bin, missing, parent_sum, parent_gain, - gpu_param, missing_left) - : -FLT_MAX; - __syncthreads(); - - // Find thread with best gain - ArgMaxT tuple(threadIdx.x, gain); - ArgMaxT best = - MaxReduceT(temp_storage.max_reduce).Reduce(tuple, cub::ArgMax()); - - if (threadIdx.x == 0) { - block_max = best; - } - - __syncthreads(); - - // Best thread updates split - if (threadIdx.x == block_max.key) { - float fvalue; - int gidx = (scan_begin - (level_node_idx * n_bins)) + threadIdx.x; - if (threadIdx.x == 0 && - begin == scan_begin) { // check at start of first tile - fvalue = d_fidx_min_map[fidx]; - } else { - fvalue = d_gidx_fvalue_map[gidx - 1]; - } - - gpair_sum_t left = missing_left ? bin + missing : bin; - gpair_sum_t right = parent_sum - left; - - split.Update(gain, missing_left ? 
LeftDir : RightDir, fvalue, fidx, - left, right, gpu_param); - } - __syncthreads(); - } // end scan - } // end over features - - // Create node - if (threadIdx.x == 0 && split.IsValid()) { - d_nodes[node_idx].SetSplit(split); - - DeviceNodeStats& left_child = d_nodes[left_child_nidx(node_idx)]; - DeviceNodeStats& right_child = d_nodes[right_child_nidx(node_idx)]; - bool& left_child_smallest = d_left_child_smallest_temp[node_idx]; - left_child = - DeviceNodeStats(split.left_sum, left_child_nidx(node_idx), gpu_param); - - right_child = - DeviceNodeStats(split.right_sum, right_child_nidx(node_idx), gpu_param); - - // Record smallest node - if (split.left_sum.GetHess() <= split.right_sum.GetHess()) { - left_child_smallest = true; - } else { - left_child_smallest = false; - } + if (threadIdx.x == 0) { + // Record best loss + d_split[fidx] = best_split; } } + +// Find a gidx value for a given feature otherwise return -1 if not found +template +__device__ int BinarySearchRow(bst_uint begin, bst_uint end, gidx_iter_t data, + int fidx_begin, int fidx_end) { + bst_uint previous_middle = UINT32_MAX; + while (end != begin) { + auto middle = begin + (end - begin) / 2; + if (middle == previous_middle) { + break; + } + previous_middle = middle; + + auto gidx = data[middle]; + + if (gidx >= fidx_begin && gidx < fidx_end) { + return gidx; + } else if (gidx < fidx_begin) { + begin = middle; + } else { + end = middle; + } + } + // Value is missing + return -1; +} + +struct DeviceHistogram { + dh::bulk_allocator ba; + dh::dvec data; + int n_bins; + void Init(int device_idx, int max_nodes, int n_bins, bool silent) { + this->n_bins = n_bins; + ba.allocate(device_idx, silent, &data, size_t(max_nodes) * size_t(n_bins)); + } + + void Reset() { data.fill(gpair_sum_t()); } + gpair_sum_t* GetHistPtr(int nidx) { return data.data() + nidx * n_bins; } + + void PrintNidx(int nidx) const { + auto h_data = data.as_vector(); + std::cout << "nidx " << nidx << ":\n"; + for (int i = n_bins * nidx; i < n_bins * (nidx + 1); i++) { + std::cout << h_data[i] << " "; + } + std::cout << "\n"; + } +}; + +// Manage memory for a single GPU +struct DeviceShard { + struct Segment { + size_t begin; + size_t end; + + Segment() : begin(0), end(0) {} + + Segment(size_t begin, size_t end) : begin(begin), end(end) { + CHECK_GE(end, begin); + } + size_t Size() const { return end - begin; } + }; + + int device_idx; + int normalised_device_idx; // Device index counting from param.gpu_id + dh::bulk_allocator ba; + dh::dvec gidx_buffer; + dh::dvec gpair; + dh::dvec2 ridx; // Row index relative to this shard + dh::dvec2 position; + std::vector ridx_segments; + dh::dvec feature_segments; + dh::dvec gidx_fvalue_map; + dh::dvec min_fvalue; + std::vector node_sum_gradients; + common::CompressedIterator gidx; + int row_stride; + bst_uint row_begin_idx; // The row offset for this shard + bst_uint row_end_idx; + bst_uint n_rows; + int n_bins; + int null_gidx_value; + DeviceHistogram hist; + TrainParam param; + + int64_t* tmp_pinned; // Small amount of staging memory + + std::vector streams; + + dh::CubMemory temp_memory; + + DeviceShard(int device_idx, int normalised_device_idx, + const common::GHistIndexMatrix& gmat, bst_uint row_begin, + bst_uint row_end, int n_bins, TrainParam param) + : device_idx(device_idx), + normalised_device_idx(normalised_device_idx), + row_begin_idx(row_begin), + row_end_idx(row_end), + n_rows(row_end - row_begin), + n_bins(n_bins), + null_gidx_value(n_bins), + param(param) { + // Convert to ELLPACK matrix representation + int 
max_elements_row = 0; + for (auto i = row_begin; i < row_end; i++) { + max_elements_row = + (std::max)(max_elements_row, + static_cast(gmat.row_ptr[i + 1] - gmat.row_ptr[i])); + } + row_stride = max_elements_row; + std::vector ellpack_matrix(row_stride * n_rows, null_gidx_value); + + for (auto i = row_begin; i < row_end; i++) { + int row_count = 0; + for (auto j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) { + ellpack_matrix[(i - row_begin) * row_stride + row_count] = + gmat.index[j]; + row_count++; + } + } + + // Allocate + int num_symbols = n_bins + 1; + size_t compressed_size_bytes = + common::CompressedBufferWriter::CalculateBufferSize( + ellpack_matrix.size(), num_symbols); + int max_nodes = + param.max_leaves > 0 ? param.max_leaves * 2 : n_nodes(param.max_depth); + ba.allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes, + &gpair, n_rows, &ridx, n_rows, &position, n_rows, + &feature_segments, gmat.cut->row_ptr.size(), &gidx_fvalue_map, + gmat.cut->cut.size(), &min_fvalue, gmat.cut->min_val.size()); + gidx_fvalue_map = gmat.cut->cut; + min_fvalue = gmat.cut->min_val; + feature_segments = gmat.cut->row_ptr; + + node_sum_gradients.resize(max_nodes); + ridx_segments.resize(max_nodes); + + // Compress gidx + common::CompressedBufferWriter cbw(num_symbols); + std::vector host_buffer(gidx_buffer.size()); + cbw.Write(host_buffer.data(), ellpack_matrix.begin(), ellpack_matrix.end()); + gidx_buffer = host_buffer; + gidx = + common::CompressedIterator(gidx_buffer.data(), num_symbols); + + common::CompressedIterator ci_host(host_buffer.data(), + num_symbols); + + // Init histogram + hist.Init(device_idx, max_nodes, gmat.cut->row_ptr.back(), param.silent); + + dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t))); + } + + ~DeviceShard() { + for (auto& stream : streams) { + dh::safe_cuda(cudaStreamDestroy(stream)); + } + dh::safe_cuda(cudaFreeHost(tmp_pinned)); + } + + // Get vector of at least n initialised streams + std::vector& GetStreams(int n) { + if (n > streams.size()) { + for (auto& stream : streams) { + dh::safe_cuda(cudaStreamDestroy(stream)); + } + + streams.clear(); + streams.resize(n); + + for (auto& stream : streams) { + dh::safe_cuda(cudaStreamCreate(&stream)); + } + } + + return streams; + } + + // Reset values for each update iteration + void Reset(const std::vector& host_gpair) { + dh::safe_cuda(cudaSetDevice(device_idx)); + position.current_dvec().fill(0); + std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), + bst_gpair()); + + thrust::sequence(ridx.current_dvec().tbegin(), ridx.current_dvec().tend()); + + std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0)); + ridx_segments.front() = Segment(0, ridx.size()); + this->gpair.copy(host_gpair.begin() + row_begin_idx, + host_gpair.begin() + row_end_idx); + subsample_gpair(&gpair, param.subsample, row_begin_idx); + hist.Reset(); + } + + void BuildHist(int nidx) { + auto segment = ridx_segments[nidx]; + auto d_node_hist = hist.GetHistPtr(nidx); + auto d_gidx = gidx; + auto d_ridx = ridx.current(); + auto d_gpair = gpair.data(); + auto row_stride = this->row_stride; + auto null_gidx_value = this->null_gidx_value; + auto n_elements = segment.Size() * row_stride; + + dh::launch_n(device_idx, n_elements, [=] __device__(size_t idx) { + int ridx = d_ridx[(idx / row_stride) + segment.begin]; + int gidx = d_gidx[ridx * row_stride + idx % row_stride]; + + if (gidx != null_gidx_value) { + AtomicAddGpair(d_node_hist + gidx, d_gpair[ridx]); + } + }); + } + void SubtractionTrick(int 
nidx_parent, int nidx_histogram, + int nidx_subtraction) { + auto d_node_hist_parent = hist.GetHistPtr(nidx_parent); + auto d_node_hist_histogram = hist.GetHistPtr(nidx_histogram); + auto d_node_hist_subtraction = hist.GetHistPtr(nidx_subtraction); + + dh::launch_n(device_idx, hist.n_bins, [=] __device__(size_t idx) { + d_node_hist_subtraction[idx] = + d_node_hist_parent[idx] - d_node_hist_histogram[idx]; + }); + } + + __device__ void CountLeft(int64_t* d_count, int val, int left_nidx) { + unsigned ballot = __ballot(val == left_nidx); + if (threadIdx.x % 32 == 0) { + atomicAdd(reinterpret_cast(d_count), // NOLINT + static_cast(__popc(ballot))); // NOLINT + } + } + + void UpdatePosition(int nidx, int left_nidx, int right_nidx, int fidx, + int split_gidx, bool default_dir_left, bool is_dense, + int fidx_begin, int fidx_end) { + dh::safe_cuda(cudaSetDevice(device_idx)); + temp_memory.LazyAllocate(sizeof(int64_t)); + auto d_left_count = temp_memory.Pointer(); + dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(int64_t))); + auto segment = ridx_segments[nidx]; + auto d_ridx = ridx.current(); + auto d_position = position.current(); + auto d_gidx = gidx; + auto row_stride = this->row_stride; + dh::launch_n<1, 512>( + device_idx, segment.Size(), [=] __device__(bst_uint idx) { + idx += segment.begin; + auto ridx = d_ridx[idx]; + auto row_begin = row_stride * ridx; + auto row_end = row_begin + row_stride; + auto gidx = -1; + if (is_dense) { + gidx = d_gidx[row_begin + fidx]; + } else { + gidx = BinarySearchRow(row_begin, row_end, d_gidx, fidx_begin, + fidx_end); + } + + int position; + if (gidx >= 0) { + // Feature is found + position = gidx <= split_gidx ? left_nidx : right_nidx; + } else { + // Feature is missing + position = default_dir_left ? left_nidx : right_nidx; + } + + CountLeft(d_left_count, position, left_nidx); + d_position[idx] = position; + }); + + dh::safe_cuda(cudaMemcpy(tmp_pinned, d_left_count, sizeof(int64_t), + cudaMemcpyDeviceToHost)); + auto left_count = *tmp_pinned; + + SortPosition(segment, left_nidx, right_nidx); + // dh::safe_cuda(cudaStreamSynchronize(stream)); + ridx_segments[left_nidx] = + Segment(segment.begin, segment.begin + left_count); + ridx_segments[right_nidx] = + Segment(segment.begin + left_count, segment.end); + } + + void SortPosition(const Segment& segment, int left_nidx, int right_nidx) { + int min_bits = 0; + int max_bits = static_cast( + std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1))); + + size_t temp_storage_bytes = 0; + cub::DeviceRadixSort::SortPairs( + nullptr, temp_storage_bytes, position.current() + segment.begin, + position.other() + segment.begin, ridx.current() + segment.begin, + ridx.other() + segment.begin, segment.Size(), min_bits, max_bits); + + temp_memory.LazyAllocate(temp_storage_bytes); + + cub::DeviceRadixSort::SortPairs( + temp_memory.d_temp_storage, temp_memory.temp_storage_bytes, + position.current() + segment.begin, position.other() + segment.begin, + ridx.current() + segment.begin, ridx.other() + segment.begin, + segment.Size(), min_bits, max_bits); + dh::safe_cuda(cudaMemcpy( + position.current() + segment.begin, position.other() + segment.begin, + segment.Size() * sizeof(int), cudaMemcpyDeviceToDevice)); + dh::safe_cuda(cudaMemcpy( + ridx.current() + segment.begin, ridx.other() + segment.begin, + segment.Size() * sizeof(bst_uint), cudaMemcpyDeviceToDevice)); + } +}; + class GPUHistMaker : public TreeUpdater { public: - GPUHistMaker() - : initialised(false), - is_dense(false), - p_last_fmat_(nullptr), - 
prediction_cache_initialised(false) {} + struct ExpandEntry; + + GPUHistMaker() : initialised(false) {} ~GPUHistMaker() {} void Init( const std::vector>& args) override { param.InitAllowUnknown(args); - CHECK(param.max_depth < 16) << "Tree depth too large."; - CHECK(param.max_depth != 0) << "Tree depth cannot be 0."; - CHECK(param.grow_policy != TrainParam::kLossGuide) - << "Loss guided growth policy not supported. Use CPU algorithm."; - this->param = param; - CHECK(param.n_gpus != 0) << "Must have at least one device"; + n_devices = param.n_gpus; + + dh::check_compute_capability(); + + if (param.grow_policy == TrainParam::kLossGuide) { + qexpand_.reset(new ExpandQueue(loss_guide)); + } else { + qexpand_.reset(new ExpandQueue(depth_wise)); + } + + monitor.Init("updater_gpu_hist", param.debug_verbose); } void Update(const std::vector& gpair, DMatrix* dmat, const std::vector& trees) override { + monitor.Start("Update"); GradStats::CheckInfo(dmat->info()); // rescale learning rate according to size of trees float lr = param.learning_rate; @@ -263,705 +509,390 @@ class GPUHistMaker : public TreeUpdater { LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl; } param.learning_rate = lr; + monitor.Stop("Update"); } - void InitData(const std::vector& gpair, DMatrix& fmat, // NOLINT - const RegTree& tree) { - common::Timer time1; - // set member num_rows and n_devices for rest of GPUHistBuilder members - info = &fmat.info(); - CHECK(info->num_row < std::numeric_limits::max()); - num_rows = static_cast(info->num_row); - n_devices = dh::n_devices(param.n_gpus, num_rows); + void InitDataOnce(DMatrix* dmat) { + info = &dmat->info(); + monitor.Start("Quantiles"); + hmat_.Init(dmat, param.max_bin); + gmat_.cut = &hmat_; + gmat_.Init(dmat); + monitor.Stop("Quantiles"); + n_bins = hmat_.row_ptr.back(); - if (!initialised) { - // Check gradients are within acceptable size range - CheckGradientMax(gpair); + int n_devices = dh::n_devices(param.n_gpus, info->num_row); - // Check compute capability is high enough - dh::check_compute_capability(); + bst_uint row_begin = 0; + bst_uint shard_size = + std::ceil(static_cast(info->num_row) / n_devices); - // reset static timers used across iterations - cpu_init_time = 0; - gpu_init_time = 0; - cpu_time.Reset(); - gpu_time = 0; - - // set dList member - dList.resize(n_devices); - for (int d_idx = 0; d_idx < n_devices; ++d_idx) { - int device_idx = (param.gpu_id + d_idx) % dh::n_visible_devices(); - dList[d_idx] = device_idx; - } - - // initialize nccl - reducer.Init(dList); - - is_dense = info->num_nonzero == info->num_col * info->num_row; - common::Timer time0; - hmat_.Init(&fmat, param.max_bin); - cpu_init_time += time0.ElapsedSeconds(); - if (param.debug_verbose) { // Only done once for each training session - LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init " - << time0.ElapsedSeconds() << " sec"; - fflush(stdout); - } - time0.Reset(); - - gmat_.cut = &hmat_; - cpu_init_time += time0.ElapsedSeconds(); - if (param.debug_verbose) { // Only done once for each training session - LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.cut " - << time0.ElapsedSeconds() << " sec"; - fflush(stdout); - } - time0.Reset(); - - gmat_.Init(&fmat); - cpu_init_time += time0.ElapsedSeconds(); - if (param.debug_verbose) { // Only done once for each training session - LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.Init() " - << time0.ElapsedSeconds() << " sec"; - fflush(stdout); - } - time0.Reset(); - - if (param.debug_verbose) { // Only done once for each training 
session - LOG(CONSOLE) - << "[GPU Plug-in] CPU Time for hmat_.Init, gmat_.cut, gmat_.Init " - << cpu_init_time << " sec"; - fflush(stdout); - } - - int n_bins = static_cast(hmat_.row_ptr.back()); - int n_features = static_cast(hmat_.row_ptr.size() - 1); - - // deliniate data onto multiple gpus - device_row_segments.push_back(0); - device_element_segments.push_back(0); - bst_uint offset = 0; - bst_uint shard_size = static_cast( - std::ceil(static_cast(num_rows) / n_devices)); - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - offset += shard_size; - offset = std::min(offset, num_rows); - device_row_segments.push_back(offset); - device_element_segments.push_back(gmat_.row_ptr[offset]); - } - - // Build feature segments - std::vector h_feature_segments; - for (int node = 0; node < n_nodes_level(param.max_depth - 1); node++) { - for (int fidx = 0; fidx < n_features; fidx++) { - h_feature_segments.push_back(hmat_.row_ptr[fidx] + node * n_bins); - } - } - h_feature_segments.push_back(n_nodes_level(param.max_depth - 1) * n_bins); - - // Construct feature map - std::vector h_gidx_feature_map(n_bins); - for (int fidx = 0; fidx < n_features; fidx++) { - for (auto i = hmat_.row_ptr[fidx]; i < hmat_.row_ptr[fidx + 1]; i++) { - h_gidx_feature_map[i] = fidx; - } - } - - int level_max_bins = n_nodes_level(param.max_depth - 1) * n_bins; - - // allocate unique common data that reside on master device (NOTE: None - // currently) - // int master_device=dList[0]; - // ba.allocate(master_device, ); - - // allocate vectors across all devices - temp_memory.resize(n_devices); - hist_vec.resize(n_devices); - nodes.resize(n_devices); - left_child_smallest.resize(n_devices); - feature_flags.resize(n_devices); - fidx_min_map.resize(n_devices); - feature_segments.resize(n_devices); - prediction_cache.resize(n_devices); - position.resize(n_devices); - position_tmp.resize(n_devices); - device_matrix.resize(n_devices); - device_gpair.resize(n_devices); - gidx_feature_map.resize(n_devices); - gidx_fvalue_map.resize(n_devices); - - // num_rows_segment: for sharding rows onto gpus for splitting data - // num_elements_segment: for sharding rows (of elements) onto gpus for - // splitting data - // max_num_nodes_device: for sharding nodes onto gpus for split finding - // All other variables have full copy on gpu, with copy either being - // identical or just current portion (like for histogram) before - // AllReduce - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - bst_uint num_rows_segment = - device_row_segments[d_idx + 1] - device_row_segments[d_idx]; - bst_ulong num_elements_segment = - device_element_segments[d_idx + 1] - device_element_segments[d_idx]; - - // ensure allocation doesn't overflow - size_t hist_size = static_cast(n_nodes(param.max_depth - 1)) * - static_cast(n_bins); - size_t nodes_size = static_cast(n_nodes(param.max_depth)); - size_t hmat_size = static_cast(hmat_.min_val.size()); - size_t buffer_size = static_cast( - common::CompressedBufferWriter::CalculateBufferSize( - static_cast(num_elements_segment), - static_cast(n_bins))); - - ba.allocate( - device_idx, param.silent, &(hist_vec[d_idx].data), hist_size, - &nodes[d_idx], n_nodes(param.max_depth), - &left_child_smallest[d_idx], nodes_size, &feature_flags[d_idx], - n_features, // may change but same on all devices - &fidx_min_map[d_idx], - hmat_size, // constant and same on all devices - &feature_segments[d_idx], - h_feature_segments.size(), // constant and same on all devices - 
&prediction_cache[d_idx], num_rows_segment, &position[d_idx], - num_rows_segment, &position_tmp[d_idx], num_rows_segment, - &device_gpair[d_idx], num_rows_segment, - &device_matrix[d_idx].gidx_buffer, - buffer_size, // constant and same on all devices - &device_matrix[d_idx].row_ptr, num_rows_segment + 1, - &gidx_feature_map[d_idx], - n_bins, // constant and same on all devices - &gidx_fvalue_map[d_idx], - hmat_.cut.size()); // constant and same on all devices - - // Copy Host to Device (assumes comes after ba.allocate that sets - // device) - device_matrix[d_idx].Init( - device_idx, gmat_, device_element_segments[d_idx], - device_element_segments[d_idx + 1], device_row_segments[d_idx], - device_row_segments[d_idx + 1], n_bins); - gidx_feature_map[d_idx] = h_gidx_feature_map; - gidx_fvalue_map[d_idx] = hmat_.cut; - feature_segments[d_idx] = h_feature_segments; - fidx_min_map[d_idx] = hmat_.min_val; - - // Initialize, no copy - hist_vec[d_idx].Init(n_bins); // init host object - prediction_cache[d_idx].fill(0); // init device object (assumes comes - // after ba.allocate that sets device) - feature_flags[d_idx].fill( - 1); // init device object (assumes comes after - // ba.allocate that sets device) - } + std::vector dList(n_devices); + for (int d_idx = 0; d_idx < n_devices; ++d_idx) { + int device_idx = (param.gpu_id + d_idx) % dh::n_visible_devices(); + dList[d_idx] = device_idx; } - // copy or init to do every iteration - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - dh::safe_cuda(cudaSetDevice(device_idx)); + reducer.Init(dList); - nodes[d_idx].fill(DeviceNodeStats()); - - position[d_idx].fill(0); - - device_gpair[d_idx].copy(gpair.begin() + device_row_segments[d_idx], - gpair.begin() + device_row_segments[d_idx + 1]); - - subsample_gpair(&device_gpair[d_idx], param.subsample, - device_row_segments[d_idx]); - - hist_vec[d_idx].Reset(device_idx); - - // left_child_smallest and left_child_smallest_temp don't need to be - // initialized + // Partition input matrix into row segments + std::vector row_segments; + shards.resize(n_devices); + row_segments.push_back(0); + for (int d_idx = 0; d_idx < n_devices; ++d_idx) { + bst_uint row_end = + std::min(static_cast(row_begin + shard_size), info->num_row); + row_segments.push_back(row_end); + row_begin = row_end; } - dh::synchronize_n_devices(n_devices, dList); - - if (!initialised) { - gpu_init_time = time1.ElapsedSeconds() - cpu_init_time; - gpu_time = -cpu_init_time; - if (param.debug_verbose) { // Only done once for each training session - LOG(CONSOLE) << "[GPU Plug-in] Time for GPU operations during First " - "Call to InitData() " - << gpu_init_time << " sec"; - fflush(stdout); - } + // Create device shards + omp_set_num_threads(shards.size()); +#pragma omp parallel + { + auto cpu_thread_id = omp_get_thread_num(); + shards[cpu_thread_id] = std::unique_ptr( + new DeviceShard(dList[cpu_thread_id], cpu_thread_id, gmat_, + row_segments[cpu_thread_id], + row_segments[cpu_thread_id + 1], n_bins, param)); } - p_last_fmat_ = &fmat; - initialised = true; } - void BuildHist(int depth) { - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - size_t begin = device_element_segments[d_idx]; - size_t end = device_element_segments[d_idx + 1]; - size_t row_begin = device_row_segments[d_idx]; - size_t row_end = device_row_segments[d_idx + 1]; + void InitData(const std::vector& gpair, DMatrix* dmat, + const RegTree& tree) { + monitor.Start("InitDataOnce"); + if (!initialised) { + 
this->InitDataOnce(dmat); + } + monitor.Stop("InitDataOnce"); - auto d_gidx = device_matrix[d_idx].gidx; - auto d_row_ptr = device_matrix[d_idx].row_ptr.tbegin(); - auto d_position = position[d_idx].data(); - auto d_gpair = device_gpair[d_idx].data(); - auto d_left_child_smallest = left_child_smallest[d_idx].data(); - auto hist_builder = hist_vec[d_idx].GetBuilder(); - dh::TransformLbs( - device_idx, &temp_memory[d_idx], end - begin, d_row_ptr, - row_end - row_begin, is_dense, - [=] __device__(size_t local_idx, int local_ridx) { - int nidx = d_position[local_ridx]; // OPTMARK: latency - if (!is_active(nidx, depth)) return; + column_sampler.Init(info->num_col, param); - // Only increment smallest node - bool is_smallest = (d_left_child_smallest[parent_nidx(nidx)] && - is_left_child(nidx)) || - (!d_left_child_smallest[parent_nidx(nidx)] && - !is_left_child(nidx)); - if (!is_smallest && depth > 0) return; + // Copy gpair & reset memory + monitor.Start("InitDataReset"); + omp_set_num_threads(shards.size()); +#pragma omp parallel + { + auto cpu_thread_id = omp_get_thread_num(); + shards[cpu_thread_id]->Reset(gpair); + } + monitor.Stop("InitDataReset"); + } - int gidx = d_gidx[local_idx]; - bst_gpair gpair = d_gpair[local_ridx]; - - hist_builder.Add(gpair, gidx, - nidx); // OPTMARK: This is slow, could use - // shared memory or cache results - // intead of writing to global - // memory every time in atomic way. - }); + void AllReduceHist(int nidx) { + for (auto& shard : shards) { + auto d_node_hist = shard->hist.GetHistPtr(nidx); + reducer.AllReduceSum( + shard->normalised_device_idx, + reinterpret_cast(d_node_hist), + reinterpret_cast(d_node_hist), + n_bins * (sizeof(gpair_sum_t) / sizeof(gpair_sum_t::value_t))); } - dh::synchronize_n_devices(n_devices, dList); - - // time.printElapsed("Add Time"); - - // (in-place) reduce each element of histogram (for only current level) - // across multiple gpus - // TODO(JCM): use out of place with pre-allocated buffer, but then have to - // copy - // back on device - // fprintf(stderr,"sizeof(bst_gpair)/sizeof(float)=%d\n",sizeof(bst_gpair)/sizeof(float)); - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - reducer.AllReduceSum(device_idx, - reinterpret_cast( - hist_vec[d_idx].GetLevelPtr(depth)), - reinterpret_cast( - hist_vec[d_idx].GetLevelPtr(depth)), - hist_vec[d_idx].LevelSize(depth) * - sizeof(gpair_sum_t) / - sizeof(gpair_sum_t::value_t)); - } reducer.Synchronize(); + } - // time.printElapsed("Reduce-Add Time"); + void BuildHistLeftRight(int nidx_parent, int nidx_left, int nidx_right) { + size_t left_node_max_elements = 0; + size_t right_node_max_elements = 0; + for (auto& shard : shards) { + left_node_max_elements = (std::max)( + left_node_max_elements, shard->ridx_segments[nidx_left].Size()); + right_node_max_elements = (std::max)( + right_node_max_elements, shard->ridx_segments[nidx_right].Size()); + } - // Subtraction trick (applied to all devices in same way -- to avoid doing - // on master and then Bcast) - if (depth > 0) { - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - dh::safe_cuda(cudaSetDevice(device_idx)); + auto build_hist_nidx = nidx_left; + auto subtraction_trick_nidx = nidx_right; - auto hist_builder = hist_vec[d_idx].GetBuilder(); - auto d_left_child_smallest = left_child_smallest[d_idx].data(); - int n_sub_bins = (n_nodes_level(depth) / 2) * hist_builder.n_bins; + if (right_node_max_elements < left_node_max_elements) { + build_hist_nidx = nidx_right; + 
subtraction_trick_nidx = nidx_left; + } - dh::launch_n(device_idx, n_sub_bins, [=] __device__(int idx) { - int nidx = n_nodes(depth - 1) + ((idx / hist_builder.n_bins) * 2); - bool left_smallest = d_left_child_smallest[parent_nidx(nidx)]; - if (left_smallest) { - nidx++; // If left is smallest switch to right child - } + for (auto& shard : shards) { + shard->BuildHist(build_hist_nidx); + } - int gidx = idx % hist_builder.n_bins; - gpair_sum_t parent = hist_builder.Get(gidx, parent_nidx(nidx)); - int other_nidx = left_smallest ? nidx - 1 : nidx + 1; - gpair_sum_t other = hist_builder.Get(gidx, other_nidx); - gpair_sum_t sub = parent - other; - hist_builder.Add( - bst_gpair(sub.GetGrad(), sub.GetHess()), gidx, - nidx); // OPTMARK: This is slow, could use shared - // memory or cache results intead of writing to - // global memory every time in atomic way. - }); + this->AllReduceHist(build_hist_nidx); + + for (auto& shard : shards) { + shard->SubtractionTrick(nidx_parent, build_hist_nidx, + subtraction_trick_nidx); + } + } + + // Returns best loss + std::vector EvaluateSplits( + const std::vector& nidx_set, RegTree* p_tree) { + auto columns = info->num_col; + std::vector best_splits(nidx_set.size()); + std::vector candidate_splits(nidx_set.size() * + columns); + // Use first device + auto& shard = shards.front(); + dh::safe_cuda(cudaSetDevice(shard->device_idx)); + shard->temp_memory.LazyAllocate(sizeof(DeviceSplitCandidate) * columns * + nidx_set.size()); + auto d_split = shard->temp_memory.Pointer(); + + auto& streams = shard->GetStreams(static_cast(nidx_set.size())); + + // Use streams to process nodes concurrently + for (auto i = 0; i < nidx_set.size(); i++) { + auto nidx = nidx_set[i]; + DeviceNodeStats node(shard->node_sum_gradients[nidx], nidx, param); + + const int BLOCK_THREADS = 256; + evaluate_split_kernel + <<>>( + shard->hist.GetHistPtr(nidx), nidx, info->num_col, node, + shard->feature_segments.data(), shard->min_fvalue.data(), + shard->gidx_fvalue_map.data(), GPUTrainingParam(param), + d_split + i * columns); + } + + dh::safe_cuda( + cudaMemcpy(candidate_splits.data(), shard->temp_memory.d_temp_storage, + sizeof(DeviceSplitCandidate) * columns * nidx_set.size(), + cudaMemcpyDeviceToHost)); + + for (auto i = 0; i < nidx_set.size(); i++) { + auto nidx = nidx_set[i]; + DeviceSplitCandidate nidx_best; + for (auto fidx = 0; fidx < columns; fidx++) { + auto& candidate = candidate_splits[i * columns + fidx]; + if (column_sampler.ColumnUsed(candidate.findex, + p_tree->GetDepth(nidx))) { + nidx_best.Update(candidate_splits[i * columns + fidx], param); + } } - dh::synchronize_n_devices(n_devices, dList); + best_splits[i] = nidx_best; } - } -#define MIN_BLOCK_THREADS 128 -#define CHUNK_BLOCK_THREADS 128 -// MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due -// to CUDA capability 35 and above requirement -// for Maximum number of threads per block -#define MAX_BLOCK_THREADS 512 - - void FindSplit(int depth) { - // Specialised based on max_bins - this->FindSplitSpecialize(depth, Int()); + return std::move(best_splits); } - template - void FindSplitSpecialize(int depth, Int) { - if (param.max_bin <= BLOCK_THREADS) { - LaunchFindSplit(depth); - } else { - this->FindSplitSpecialize(depth, - Int()); + void InitRoot(const std::vector& gpair, RegTree* p_tree) { + auto root_nidx = 0; + // Sum gradients + std::vector tmp_sums(shards.size()); + omp_set_num_threads(shards.size()); +#pragma omp parallel + { + auto cpu_thread_id = omp_get_thread_num(); + 
dh::safe_cuda(cudaSetDevice(shards[cpu_thread_id]->device_idx)); + tmp_sums[cpu_thread_id] = + thrust::reduce(thrust::cuda::par(shards[cpu_thread_id]->temp_memory), + shards[cpu_thread_id]->gpair.tbegin(), + shards[cpu_thread_id]->gpair.tend()); } + auto sum_gradient = + std::accumulate(tmp_sums.begin(), tmp_sums.end(), bst_gpair()); + + // Generate root histogram + for (auto& shard : shards) { + shard->BuildHist(root_nidx); + } + + this->AllReduceHist(root_nidx); + + // Remember root stats + p_tree->stat(root_nidx).sum_hess = sum_gradient.GetHess(); + p_tree->stat(root_nidx).base_weight = CalcWeight(param, sum_gradient); + + // Store sum gradients + for (auto& shard : shards) { + shard->node_sum_gradients[root_nidx] = sum_gradient; + } + + // Generate first split + auto splits = this->EvaluateSplits({root_nidx}, p_tree); + qexpand_->push( + ExpandEntry(root_nidx, p_tree->GetDepth(root_nidx), splits.front(), 0)); } - void FindSplitSpecialize(int depth, Int) { - this->LaunchFindSplit(depth); - } + void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) { + auto nidx = candidate.nid; + auto left_nidx = (*p_tree)[nidx].cleft(); + auto right_nidx = (*p_tree)[nidx].cright(); - template - void LaunchFindSplit(int depth) { - bool colsample = - param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0; - - int num_nodes_device = n_nodes_level(depth); - const int GRID_SIZE = num_nodes_device; - - // all GPUs do same work - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - dh::safe_cuda(cudaSetDevice(device_idx)); - - int nodes_offset_device = 0; - find_split_kernel<<>>( - hist_vec[d_idx].GetLevelPtr(depth), feature_segments[d_idx].data(), - depth, info->num_col, hmat_.row_ptr.back(), nodes[d_idx].data(), - nodes_offset_device, fidx_min_map[d_idx].data(), - gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param), - left_child_smallest[d_idx].data(), colsample, - feature_flags[d_idx].data()); - } - - // NOTE: No need to syncrhonize with host as all above pure P2P ops or - // on-device ops - } - void InitFirstNode(const std::vector& gpair) { - // Perform asynchronous reduction on each gpu - std::vector device_sums(n_devices); -#pragma omp parallel for num_threads(n_devices) - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - dh::safe_cuda(cudaSetDevice(device_idx)); - auto begin = device_gpair[d_idx].tbegin(); - auto end = device_gpair[d_idx].tend(); - bst_gpair init = bst_gpair(); - auto binary_op = thrust::plus(); - device_sums[d_idx] = thrust::reduce(begin, end, init, binary_op); - } - - bst_gpair sum = bst_gpair(); - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - sum += device_sums[d_idx]; - } - - // Setup first node so all devices have same first node (here done same on - // all devices, or could have done one device and Bcast if worried about - // exact precision issues) - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - - auto d_nodes = nodes[d_idx].data(); - auto gpu_param = GPUTrainingParam(param); - - dh::launch_n(device_idx, 1, [=] __device__(int idx) { - bst_gpair sum_gradients = sum; - d_nodes[idx] = DeviceNodeStats(sum_gradients, 0, gpu_param); - }); - } - // synch all devices to host before moving on (No, can avoid because - // BuildHist calls another kernel in default stream) - // dh::synchronize_n_devices(n_devices, dList); - } - void UpdatePosition(int depth) { - if (is_dense) { - this->UpdatePositionDense(depth); - } else { - this->UpdatePositionSparse(depth); 
- } - } - void UpdatePositionDense(int depth) { - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - - auto d_position = position[d_idx].data(); - DeviceNodeStats* d_nodes = nodes[d_idx].data(); - auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data(); - auto d_gidx = device_matrix[d_idx].gidx; - auto n_columns = info->num_col; - size_t begin = device_row_segments[d_idx]; - size_t end = device_row_segments[d_idx + 1]; - - dh::launch_n(device_idx, end - begin, [=] __device__(size_t local_idx) { - int pos = d_position[local_idx]; - if (!is_active(pos, depth)) { - return; - } - DeviceNodeStats node = d_nodes[pos]; - - if (node.IsLeaf()) { - return; - } - - int gidx = d_gidx[local_idx * static_cast(n_columns) + - static_cast(node.fidx)]; - - float fvalue = d_gidx_fvalue_map[gidx]; - - if (fvalue <= node.fvalue) { - d_position[local_idx] = left_child_nidx(pos); - } else { - d_position[local_idx] = right_child_nidx(pos); - } - }); - } - dh::synchronize_n_devices(n_devices, dList); - // dh::safe_cuda(cudaDeviceSynchronize()); - } - - void UpdatePositionSparse(int depth) { - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - - auto d_position = position[d_idx].data(); - auto d_position_tmp = position_tmp[d_idx].data(); - DeviceNodeStats* d_nodes = nodes[d_idx].data(); - auto d_gidx_feature_map = gidx_feature_map[d_idx].data(); - auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data(); - auto d_gidx = device_matrix[d_idx].gidx; - auto d_row_ptr = device_matrix[d_idx].row_ptr.tbegin(); - - size_t row_begin = device_row_segments[d_idx]; - size_t row_end = device_row_segments[d_idx + 1]; - size_t element_begin = device_element_segments[d_idx]; - size_t element_end = device_element_segments[d_idx + 1]; - - // Update missing direction - dh::launch_n(device_idx, row_end - row_begin, - [=] __device__(int local_idx) { - int pos = d_position[local_idx]; - if (!is_active(pos, depth)) { - d_position_tmp[local_idx] = pos; - return; - } - - DeviceNodeStats node = d_nodes[pos]; - - if (node.IsLeaf()) { - d_position_tmp[local_idx] = pos; - return; - } else if (node.dir == LeftDir) { - d_position_tmp[local_idx] = pos * 2 + 1; - } else { - d_position_tmp[local_idx] = pos * 2 + 2; - } - }); - - // Update node based on fvalue where exists - // OPTMARK: This kernel is very inefficient for both compute and memory, - // dominated by memory dependency / access patterns - - dh::TransformLbs( - device_idx, &temp_memory[d_idx], element_end - element_begin, - d_row_ptr, row_end - row_begin, is_dense, - [=] __device__(size_t local_idx, int local_ridx) { - int pos = d_position[local_ridx]; - if (!is_active(pos, depth)) { - return; - } - - DeviceNodeStats node = d_nodes[pos]; - - if (node.IsLeaf()) { - return; - } - - int gidx = d_gidx[local_idx]; - int findex = - d_gidx_feature_map[gidx]; // OPTMARK: slowest global - // memory access, maybe setup - // position, gidx, etc. as - // combined structure? 
- - if (findex == node.fidx) { - float fvalue = d_gidx_fvalue_map[gidx]; - - if (fvalue <= node.fvalue) { - d_position_tmp[local_ridx] = left_child_nidx(pos); - } else { - d_position_tmp[local_ridx] = right_child_nidx(pos); - } - } - }); - position[d_idx] = position_tmp[d_idx]; - } - dh::synchronize_n_devices(n_devices, dList); - } - void ColSampleTree() { - if (param.colsample_bylevel == 1.0 && param.colsample_bytree == 1.0) return; - - feature_set_tree.resize(info->num_col); - std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0); - feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree); - } - void ColSampleLevel() { - if (param.colsample_bylevel == 1.0 && param.colsample_bytree == 1.0) return; - - feature_set_level.resize(feature_set_tree.size()); - feature_set_level = col_sample(feature_set_tree, param.colsample_bylevel); - std::vector h_feature_flags(info->num_col, 0); - for (auto fidx : feature_set_level) { - h_feature_flags[fidx] = 1; - } - - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - dh::safe_cuda(cudaSetDevice(device_idx)); - - feature_flags[d_idx] = h_feature_flags; - } - dh::synchronize_n_devices(n_devices, dList); - } - bool UpdatePredictionCache(const DMatrix* data, - std::vector* p_out_preds) override { - std::vector& out_preds = *p_out_preds; - - if (nodes.empty() || !p_last_fmat_ || data != p_last_fmat_) { - return false; - } - - if (!prediction_cache_initialised) { - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - size_t row_begin = device_row_segments[d_idx]; - size_t row_end = device_row_segments[d_idx + 1]; - - prediction_cache[d_idx].copy(out_preds.begin() + row_begin, - out_preds.begin() + row_end); + // convert floating-point split_pt into corresponding bin_id + // split_cond = -1 indicates that split_pt is less than all known cut points + auto split_gidx = -1; + auto fidx = candidate.split.findex; + auto default_dir_left = candidate.split.dir == LeftDir; + auto fidx_begin = hmat_.row_ptr[fidx]; + auto fidx_end = hmat_.row_ptr[fidx + 1]; + for (auto i = fidx_begin; i < fidx_end; ++i) { + if (candidate.split.fvalue == hmat_.cut[i]) { + split_gidx = static_cast(i); } - prediction_cache_initialised = true; } - dh::synchronize_n_devices(n_devices, dList); - float eps = param.learning_rate; - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - size_t row_begin = device_row_segments[d_idx]; - size_t row_end = device_row_segments[d_idx + 1]; + auto is_dense = info->num_nonzero == info->num_row * info->num_col; - auto d_nodes = nodes[d_idx].data(); - auto d_position = position[d_idx].data(); - auto d_prediction_cache = prediction_cache[d_idx].data(); - - dh::launch_n(device_idx, prediction_cache[d_idx].size(), - [=] __device__(int local_idx) { - int pos = d_position[local_idx]; - d_prediction_cache[local_idx] += d_nodes[pos].weight * eps; - }); - - dh::safe_cuda( - cudaMemcpy(&out_preds[row_begin], prediction_cache[d_idx].data(), - prediction_cache[d_idx].size() * sizeof(bst_float), - cudaMemcpyDeviceToHost)); + omp_set_num_threads(shards.size()); +#pragma omp parallel + { + auto cpu_thread_id = omp_get_thread_num(); + shards[cpu_thread_id]->UpdatePosition(nidx, left_nidx, right_nidx, fidx, + split_gidx, default_dir_left, + is_dense, fidx_begin, fidx_end); } - dh::synchronize_n_devices(n_devices, dList); - - return true; } + + void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) { + // Add new leaves + RegTree& tree = *p_tree; + 
tree.AddChilds(candidate.nid); + auto& parent = tree[candidate.nid]; + parent.set_split(candidate.split.findex, candidate.split.fvalue, + candidate.split.dir == LeftDir); + tree.stat(candidate.nid).loss_chg = candidate.split.loss_chg; + + // Configure left child + auto left_weight = CalcWeight(param, candidate.split.left_sum); + tree[parent.cleft()].set_leaf(left_weight * param.learning_rate, 0); + tree.stat(parent.cleft()).base_weight = left_weight; + tree.stat(parent.cleft()).sum_hess = candidate.split.left_sum.GetHess(); + + // Configure right child + auto right_weight = CalcWeight(param, candidate.split.right_sum); + tree[parent.cright()].set_leaf(right_weight * param.learning_rate, 0); + tree.stat(parent.cright()).base_weight = right_weight; + tree.stat(parent.cright()).sum_hess = candidate.split.right_sum.GetHess(); + // Store sum gradients + for (auto& shard : shards) { + shard->node_sum_gradients[parent.cleft()] = candidate.split.left_sum; + shard->node_sum_gradients[parent.cright()] = candidate.split.right_sum; + } + this->UpdatePosition(candidate, p_tree); + } + void UpdateTree(const std::vector& gpair, DMatrix* p_fmat, RegTree* p_tree) { - common::Timer time0; + // Temporarily store number of threads so we can change it back later + int nthread = omp_get_max_threads(); - this->InitData(gpair, *p_fmat, *p_tree); - this->InitFirstNode(gpair); - this->ColSampleTree(); + auto& tree = *p_tree; - for (int depth = 0; depth < param.max_depth; depth++) { - this->ColSampleLevel(); - this->BuildHist(depth); - this->FindSplit(depth); - this->UpdatePosition(depth); + monitor.Start("InitData"); + this->InitData(gpair, p_fmat, *p_tree); + monitor.Stop("InitData"); + monitor.Start("InitRoot"); + this->InitRoot(gpair, p_tree); + monitor.Stop("InitRoot"); + + auto timestamp = qexpand_->size(); + auto num_leaves = 1; + + while (!qexpand_->empty()) { + auto candidate = qexpand_->top(); + qexpand_->pop(); + if (!candidate.IsValid(param, num_leaves)) continue; + // std::cout << candidate; + monitor.Start("ApplySplit"); + this->ApplySplit(candidate, p_tree); + monitor.Stop("ApplySplit"); + num_leaves++; + + auto left_child_nidx = tree[candidate.nid].cleft(); + auto right_child_nidx = tree[candidate.nid].cright(); + + // Only create child entries if needed + if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), + num_leaves)) { + monitor.Start("BuildHist"); + this->BuildHistLeftRight(candidate.nid, left_child_nidx, + right_child_nidx); + monitor.Stop("BuildHist"); + + monitor.Start("EvaluateSplits"); + auto splits = + this->EvaluateSplits({left_child_nidx, right_child_nidx}, p_tree); + qexpand_->push(ExpandEntry(left_child_nidx, + tree.GetDepth(left_child_nidx), splits[0], + timestamp++)); + qexpand_->push(ExpandEntry(right_child_nidx, + tree.GetDepth(right_child_nidx), splits[1], + timestamp++)); + monitor.Stop("EvaluateSplits"); + } } - // done with multi-GPU, pass back result from master to tree on host - int master_device = dList[0]; - dh::safe_cuda(cudaSetDevice(master_device)); - dense2sparse_tree(p_tree, nodes[0], param); - - gpu_time += time0.ElapsedSeconds(); - - if (param.debug_verbose) { - LOG(CONSOLE) - << "[GPU Plug-in] Cumulative GPU Time excluding initial time " - << (gpu_time - gpu_init_time) << " sec"; - fflush(stdout); - } - - if (param.debug_verbose) { - LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time " - << cpu_time.ElapsedSeconds() << " sec"; - LOG(CONSOLE) - << "[GPU Plug-in] Cumulative CPU Time excluding initial time " - << (cpu_time.ElapsedSeconds() - 
cpu_init_time - gpu_time) << " sec"; - fflush(stdout); - } + // Reset omp num threads + omp_set_num_threads(nthread); } - protected: + struct ExpandEntry { + int nid; + int depth; + DeviceSplitCandidate split; + uint64_t timestamp; + ExpandEntry(int nid, int depth, const DeviceSplitCandidate& split, + uint64_t timestamp) + : nid(nid), depth(depth), split(split), timestamp(timestamp) {} + bool IsValid(const TrainParam& param, int num_leaves) const { + if (split.loss_chg <= rt_eps) return false; + if (param.max_depth > 0 && depth == param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false; + return true; + } + + static bool ChildIsValid(const TrainParam& param, int depth, + int num_leaves) { + if (param.max_depth > 0 && depth == param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false; + return true; + } + + friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) { + os << "ExpandEntry: \n"; + os << "nidx: " << e.nid << "\n"; + os << "depth: " << e.depth << "\n"; + os << "loss: " << e.split.loss_chg << "\n"; + os << "left_sum: " << e.split.left_sum << "\n"; + os << "right_sum: " << e.split.right_sum << "\n"; + return os; + } + }; + + inline static bool depth_wise(ExpandEntry lhs, ExpandEntry rhs) { + if (lhs.depth == rhs.depth) { + return lhs.timestamp > rhs.timestamp; // favor small timestamp + } else { + return lhs.depth > rhs.depth; // favor small depth + } + } + inline static bool loss_guide(ExpandEntry lhs, ExpandEntry rhs) { + if (lhs.split.loss_chg == rhs.split.loss_chg) { + return lhs.timestamp > rhs.timestamp; // favor small timestamp + } else { + return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg + } + } TrainParam param; - // std::unique_ptr builder; common::HistCutMatrix hmat_; common::GHistIndexMatrix gmat_; MetaInfo* info; bool initialised; - bool is_dense; - const DMatrix* p_last_fmat_; - bool prediction_cache_initialised; - - dh::bulk_allocator ba; - - std::vector feature_set_tree; - std::vector feature_set_level; - - bst_uint num_rows; int n_devices; + int n_bins; - // below vectors are for each devices used - std::vector dList; - std::vector device_row_segments; - std::vector device_element_segments; - - std::vector temp_memory; - std::vector hist_vec; - std::vector> nodes; - std::vector> left_child_smallest; - std::vector> feature_flags; - std::vector> fidx_min_map; - std::vector> feature_segments; - std::vector> prediction_cache; - std::vector> position; - std::vector> position_tmp; - std::vector device_matrix; - std::vector> device_gpair; - std::vector> gidx_feature_map; - std::vector> gidx_fvalue_map; - + std::vector> shards; + ColumnSampler column_sampler; + typedef std::priority_queue, + std::function> + ExpandQueue; + std::unique_ptr qexpand_; + common::Monitor monitor; dh::AllReducer reducer; - - double cpu_init_time; - double gpu_init_time; - common::Timer cpu_time; - double gpu_time; }; -XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") +XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, + "grow_gpu_hist") .describe("Grow tree with GPU.") .set_body([]() { return new GPUHistMaker(); }); } // namespace tree diff --git a/src/tree/updater_gpu_hist_experimental.cu b/src/tree/updater_gpu_hist_experimental.cu deleted file mode 100644 index 67e71a3f1..000000000 --- a/src/tree/updater_gpu_hist_experimental.cu +++ /dev/null @@ -1,899 +0,0 @@ -/*! 
- * Copyright 2017 XGBoost contributors - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "../common/compressed_iterator.h" -#include "../common/device_helpers.cuh" -#include "../common/hist_util.h" -#include "../common/timer.h" -#include "param.h" -#include "updater_gpu_common.cuh" - -namespace xgboost { -namespace tree { - -DMLC_REGISTRY_FILE_TAG(updater_gpu_hist_experimental); - -typedef bst_gpair_precise gpair_sum_t; - -template -__device__ gpair_sum_t ReduceFeature(const gpair_sum_t* begin, - const gpair_sum_t* end, - temp_storage_t* temp_storage) { - __shared__ cub::Uninitialized uninitialized_sum; - gpair_sum_t& shared_sum = uninitialized_sum.Alias(); - - gpair_sum_t local_sum = gpair_sum_t(); - for (auto itr = begin; itr < end; itr += BLOCK_THREADS) { - bool thread_active = itr + threadIdx.x < end; - // Scan histogram - gpair_sum_t bin = thread_active ? *(itr + threadIdx.x) : gpair_sum_t(); - - local_sum += reduce_t(temp_storage->sum_reduce).Reduce(bin, cub::Sum()); - } - - if (threadIdx.x == 0) { - shared_sum = local_sum; - } - __syncthreads(); - - return shared_sum; -} - -template -__device__ void EvaluateFeature(int fidx, const gpair_sum_t* hist, - const int* feature_segments, float min_fvalue, - const float* gidx_fvalue_map, - DeviceSplitCandidate* best_split, - const DeviceNodeStats& node, - const GPUTrainingParam& param, - temp_storage_t* temp_storage) { - int gidx_begin = feature_segments[fidx]; - int gidx_end = feature_segments[fidx + 1]; - - gpair_sum_t feature_sum = ReduceFeature( - hist + gidx_begin, hist + gidx_end, temp_storage); - - auto prefix_op = SumCallbackOp(); - for (int scan_begin = gidx_begin; scan_begin < gidx_end; - scan_begin += BLOCK_THREADS) { - bool thread_active = scan_begin + threadIdx.x < gidx_end; - - gpair_sum_t bin = - thread_active ? hist[scan_begin + threadIdx.x] : gpair_sum_t(); - scan_t(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op); - - // Calculate gain - gpair_sum_t parent_sum = gpair_sum_t(node.sum_gradients); - - gpair_sum_t missing = parent_sum - feature_sum; - - bool missing_left = true; - const float null_gain = -FLT_MAX; - float gain = null_gain; - if (thread_active) { - gain = loss_chg_missing(bin, missing, parent_sum, node.root_gain, param, - missing_left); - } - - __syncthreads(); - - // Find thread with best gain - cub::KeyValuePair tuple(threadIdx.x, gain); - cub::KeyValuePair best = - max_reduce_t(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax()); - - __shared__ cub::KeyValuePair block_max; - if (threadIdx.x == 0) { - block_max = best; - } - - __syncthreads(); - - // Best thread updates split - if (threadIdx.x == block_max.key) { - int gidx = scan_begin + threadIdx.x; - float fvalue = - gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1]; - - gpair_sum_t left = missing_left ? bin + missing : bin; - gpair_sum_t right = parent_sum - left; - - best_split->Update(gain, missing_left ? 
LeftDir : RightDir, fvalue, fidx, - left, right, param); - } - __syncthreads(); - } -} - -template -__global__ void evaluate_split_kernel( - const gpair_sum_t* d_hist, int nidx, uint64_t n_features, - DeviceNodeStats nodes, const int* d_feature_segments, - const float* d_fidx_min_map, const float* d_gidx_fvalue_map, - GPUTrainingParam gpu_param, DeviceSplitCandidate* d_split) { - typedef cub::KeyValuePair ArgMaxT; - typedef cub::BlockScan - BlockScanT; - typedef cub::BlockReduce MaxReduceT; - - typedef cub::BlockReduce SumReduceT; - - union TempStorage { - typename BlockScanT::TempStorage scan; - typename MaxReduceT::TempStorage max_reduce; - typename SumReduceT::TempStorage sum_reduce; - }; - - __shared__ cub::Uninitialized uninitialized_split; - DeviceSplitCandidate& best_split = uninitialized_split.Alias(); - __shared__ TempStorage temp_storage; - - if (threadIdx.x == 0) { - best_split = DeviceSplitCandidate(); - } - - __syncthreads(); - - auto fidx = blockIdx.x; - EvaluateFeature( - fidx, d_hist, d_feature_segments, d_fidx_min_map[fidx], d_gidx_fvalue_map, - &best_split, nodes, gpu_param, &temp_storage); - - __syncthreads(); - - if (threadIdx.x == 0) { - // Record best loss - d_split[fidx] = best_split; - } -} - -// Find a gidx value for a given feature otherwise return -1 if not found -template -__device__ int BinarySearchRow(bst_uint begin, bst_uint end, gidx_iter_t data, - int fidx_begin, int fidx_end) { - bst_uint previous_middle = UINT32_MAX; - while (end != begin) { - auto middle = begin + (end - begin) / 2; - if (middle == previous_middle) { - break; - } - previous_middle = middle; - - auto gidx = data[middle]; - - if (gidx >= fidx_begin && gidx < fidx_end) { - return gidx; - } else if (gidx < fidx_begin) { - begin = middle; - } else { - end = middle; - } - } - // Value is missing - return -1; -} - -struct DeviceHistogram { - dh::bulk_allocator ba; - dh::dvec data; - int n_bins; - void Init(int device_idx, int max_nodes, int n_bins, bool silent) { - this->n_bins = n_bins; - ba.allocate(device_idx, silent, &data, size_t(max_nodes) * size_t(n_bins)); - } - - void Reset() { data.fill(gpair_sum_t()); } - gpair_sum_t* GetHistPtr(int nidx) { return data.data() + nidx * n_bins; } - - void PrintNidx(int nidx) const { - auto h_data = data.as_vector(); - std::cout << "nidx " << nidx << ":\n"; - for (int i = n_bins * nidx; i < n_bins * (nidx + 1); i++) { - std::cout << h_data[i] << " "; - } - std::cout << "\n"; - } -}; - -// Manage memory for a single GPU -struct DeviceShard { - struct Segment { - size_t begin; - size_t end; - - Segment() : begin(0), end(0) {} - - Segment(size_t begin, size_t end) : begin(begin), end(end) { - CHECK_GE(end, begin); - } - size_t Size() const { return end - begin; } - }; - - int device_idx; - int normalised_device_idx; // Device index counting from param.gpu_id - dh::bulk_allocator ba; - dh::dvec gidx_buffer; - dh::dvec gpair; - dh::dvec2 ridx; // Row index relative to this shard - dh::dvec2 position; - std::vector ridx_segments; - dh::dvec feature_segments; - dh::dvec gidx_fvalue_map; - dh::dvec min_fvalue; - std::vector node_sum_gradients; - common::CompressedIterator gidx; - int row_stride; - bst_uint row_begin_idx; // The row offset for this shard - bst_uint row_end_idx; - bst_uint n_rows; - int n_bins; - int null_gidx_value; - DeviceHistogram hist; - TrainParam param; - - int64_t* tmp_pinned; // Small amount of staging memory - - std::vector streams; - - dh::CubMemory temp_memory; - - DeviceShard(int device_idx, int normalised_device_idx, - const 
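// [Reviewer sketch -- not part of the diff.]
// BinarySearchRow above relies on the bin ids inside one ELLPACK row being
// sorted by feature, so the bin of a given feature (if present) can be found
// by binary search over the row slice. A host-side sketch with the same
// contract -- return the stored bin id in [fidx_begin, fidx_end), or -1 when
// the feature is missing from the row; names here are illustrative only:
#include <cstddef>
#include <cstdint>
#include <vector>

int BinarySearchRowHost(const std::vector<int>& row_bins,
                        int fidx_begin, int fidx_end) {
  std::size_t begin = 0, end = row_bins.size();
  std::size_t previous_middle = SIZE_MAX;
  while (end != begin) {
    std::size_t middle = begin + (end - begin) / 2;
    if (middle == previous_middle) break;  // range no longer shrinking
    previous_middle = middle;
    int gidx = row_bins[middle];
    if (gidx >= fidx_begin && gidx < fidx_end) return gidx;  // feature found
    if (gidx < fidx_begin) begin = middle; else end = middle;
  }
  return -1;  // feature value is missing in this row
}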
common::GHistIndexMatrix& gmat, bst_uint row_begin, - bst_uint row_end, int n_bins, TrainParam param) - : device_idx(device_idx), - normalised_device_idx(normalised_device_idx), - row_begin_idx(row_begin), - row_end_idx(row_end), - n_rows(row_end - row_begin), - n_bins(n_bins), - null_gidx_value(n_bins), - param(param) { - // Convert to ELLPACK matrix representation - int max_elements_row = 0; - for (auto i = row_begin; i < row_end; i++) { - max_elements_row = - (std::max)(max_elements_row, - static_cast(gmat.row_ptr[i + 1] - gmat.row_ptr[i])); - } - row_stride = max_elements_row; - std::vector ellpack_matrix(row_stride * n_rows, null_gidx_value); - - for (auto i = row_begin; i < row_end; i++) { - int row_count = 0; - for (auto j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) { - ellpack_matrix[(i - row_begin) * row_stride + row_count] = - gmat.index[j]; - row_count++; - } - } - - // Allocate - int num_symbols = n_bins + 1; - size_t compressed_size_bytes = - common::CompressedBufferWriter::CalculateBufferSize( - ellpack_matrix.size(), num_symbols); - int max_nodes = - param.max_leaves > 0 ? param.max_leaves * 2 : n_nodes(param.max_depth); - ba.allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes, - &gpair, n_rows, &ridx, n_rows, &position, n_rows, - &feature_segments, gmat.cut->row_ptr.size(), &gidx_fvalue_map, - gmat.cut->cut.size(), &min_fvalue, gmat.cut->min_val.size()); - gidx_fvalue_map = gmat.cut->cut; - min_fvalue = gmat.cut->min_val; - feature_segments = gmat.cut->row_ptr; - - node_sum_gradients.resize(max_nodes); - ridx_segments.resize(max_nodes); - - // Compress gidx - common::CompressedBufferWriter cbw(num_symbols); - std::vector host_buffer(gidx_buffer.size()); - cbw.Write(host_buffer.data(), ellpack_matrix.begin(), ellpack_matrix.end()); - gidx_buffer = host_buffer; - gidx = - common::CompressedIterator(gidx_buffer.data(), num_symbols); - - common::CompressedIterator ci_host(host_buffer.data(), - num_symbols); - - // Init histogram - hist.Init(device_idx, max_nodes, gmat.cut->row_ptr.back(), param.silent); - - dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t))); - } - - ~DeviceShard() { - for (auto& stream : streams) { - dh::safe_cuda(cudaStreamDestroy(stream)); - } - dh::safe_cuda(cudaFreeHost(tmp_pinned)); - } - - // Get vector of at least n initialised streams - std::vector& GetStreams(int n) { - if (n > streams.size()) { - for (auto& stream : streams) { - dh::safe_cuda(cudaStreamDestroy(stream)); - } - - streams.clear(); - streams.resize(n); - - for (auto& stream : streams) { - dh::safe_cuda(cudaStreamCreate(&stream)); - } - } - - return streams; - } - - // Reset values for each update iteration - void Reset(const std::vector& host_gpair) { - dh::safe_cuda(cudaSetDevice(device_idx)); - position.current_dvec().fill(0); - std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), - bst_gpair()); - - thrust::sequence(ridx.current_dvec().tbegin(), ridx.current_dvec().tend()); - - std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0)); - ridx_segments.front() = Segment(0, ridx.size()); - this->gpair.copy(host_gpair.begin() + row_begin_idx, - host_gpair.begin() + row_end_idx); - subsample_gpair(&gpair, param.subsample, row_begin_idx); - hist.Reset(); - } - - void BuildHist(int nidx) { - auto segment = ridx_segments[nidx]; - auto d_node_hist = hist.GetHistPtr(nidx); - auto d_gidx = gidx; - auto d_ridx = ridx.current(); - auto d_gpair = gpair.data(); - auto row_stride = this->row_stride; - auto null_gidx_value = this->null_gidx_value; 
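// [Reviewer sketch -- not part of the diff.]
// The DeviceShard constructor above re-packs its row range of the quantised
// CSR matrix into ELLPACK form: every row gets the same stride (the widest
// row in the shard), shorter rows are padded with null_gidx_value, and the
// result is then bit-compressed for the GPU. A CPU-only sketch of the
// re-packing step with hypothetical inputs (row_ptr/index mirror
// common::GHistIndexMatrix):
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<int> ToEllpack(const std::vector<std::size_t>& row_ptr,
                           const std::vector<int>& index,
                           int null_gidx_value, std::size_t* out_stride) {
  std::size_t n_rows = row_ptr.size() - 1;
  std::size_t stride = 0;
  for (std::size_t i = 0; i < n_rows; ++i) {
    stride = std::max(stride, row_ptr[i + 1] - row_ptr[i]);
  }
  std::vector<int> ellpack(stride * n_rows, null_gidx_value);  // pad with null
  for (std::size_t i = 0; i < n_rows; ++i) {
    std::size_t k = 0;
    for (std::size_t j = row_ptr[i]; j < row_ptr[i + 1]; ++j) {
      ellpack[i * stride + k++] = index[j];
    }
  }
  *out_stride = stride;
  return ellpack;
}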
- auto n_elements = segment.Size() * row_stride; - - dh::launch_n(device_idx, n_elements, [=] __device__(size_t idx) { - int ridx = d_ridx[(idx / row_stride) + segment.begin]; - int gidx = d_gidx[ridx * row_stride + idx % row_stride]; - - if (gidx != null_gidx_value) { - AtomicAddGpair(d_node_hist + gidx, d_gpair[ridx]); - } - }); - } - void SubtractionTrick(int nidx_parent, int nidx_histogram, - int nidx_subtraction) { - auto d_node_hist_parent = hist.GetHistPtr(nidx_parent); - auto d_node_hist_histogram = hist.GetHistPtr(nidx_histogram); - auto d_node_hist_subtraction = hist.GetHistPtr(nidx_subtraction); - - dh::launch_n(device_idx, hist.n_bins, [=] __device__(size_t idx) { - d_node_hist_subtraction[idx] = - d_node_hist_parent[idx] - d_node_hist_histogram[idx]; - }); - } - - __device__ void CountLeft(int64_t* d_count, int val, int left_nidx) { - unsigned ballot = __ballot(val == left_nidx); - if (threadIdx.x % 32 == 0) { - atomicAdd(reinterpret_cast(d_count), // NOLINT - static_cast(__popc(ballot))); // NOLINT - } - } - - void UpdatePosition(int nidx, int left_nidx, int right_nidx, int fidx, - int split_gidx, bool default_dir_left, bool is_dense, - int fidx_begin, int fidx_end) { - dh::safe_cuda(cudaSetDevice(device_idx)); - temp_memory.LazyAllocate(sizeof(int64_t)); - auto d_left_count = temp_memory.Pointer(); - dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(int64_t))); - auto segment = ridx_segments[nidx]; - auto d_ridx = ridx.current(); - auto d_position = position.current(); - auto d_gidx = gidx; - auto row_stride = this->row_stride; - dh::launch_n<1, 512>( - device_idx, segment.Size(), [=] __device__(bst_uint idx) { - idx += segment.begin; - auto ridx = d_ridx[idx]; - auto row_begin = row_stride * ridx; - auto row_end = row_begin + row_stride; - auto gidx = -1; - if (is_dense) { - gidx = d_gidx[row_begin + fidx]; - } else { - gidx = BinarySearchRow(row_begin, row_end, d_gidx, fidx_begin, - fidx_end); - } - - int position; - if (gidx >= 0) { - // Feature is found - position = gidx <= split_gidx ? left_nidx : right_nidx; - } else { - // Feature is missing - position = default_dir_left ? 
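// [Reviewer sketch -- not part of the diff.]
// SubtractionTrick above derives one child's histogram from its parent's:
// only the smaller child is accumulated from rows, and the sibling falls out
// as parent - built_child, bin by bin. An equivalent host-side sketch over
// plain (grad, hess) pairs; GradPair is illustrative, not the real bst_gpair:
#include <cstddef>
#include <vector>

struct GradPair {
  double grad = 0.0;
  double hess = 0.0;
};

void SubtractionTrickHost(const std::vector<GradPair>& parent,
                          const std::vector<GradPair>& built_child,
                          std::vector<GradPair>* sibling) {
  sibling->resize(parent.size());
  for (std::size_t bin = 0; bin < parent.size(); ++bin) {
    (*sibling)[bin].grad = parent[bin].grad - built_child[bin].grad;
    (*sibling)[bin].hess = parent[bin].hess - built_child[bin].hess;
  }
}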
left_nidx : right_nidx; - } - - CountLeft(d_left_count, position, left_nidx); - d_position[idx] = position; - }); - - dh::safe_cuda(cudaMemcpy(tmp_pinned, d_left_count, sizeof(int64_t), - cudaMemcpyDeviceToHost)); - auto left_count = *tmp_pinned; - - SortPosition(segment, left_nidx, right_nidx); - // dh::safe_cuda(cudaStreamSynchronize(stream)); - ridx_segments[left_nidx] = - Segment(segment.begin, segment.begin + left_count); - ridx_segments[right_nidx] = - Segment(segment.begin + left_count, segment.end); - } - - void SortPosition(const Segment& segment, int left_nidx, int right_nidx) { - int min_bits = 0; - int max_bits = static_cast( - std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1))); - - size_t temp_storage_bytes = 0; - cub::DeviceRadixSort::SortPairs( - nullptr, temp_storage_bytes, position.current() + segment.begin, - position.other() + segment.begin, ridx.current() + segment.begin, - ridx.other() + segment.begin, segment.Size(), min_bits, max_bits); - - temp_memory.LazyAllocate(temp_storage_bytes); - - cub::DeviceRadixSort::SortPairs( - temp_memory.d_temp_storage, temp_memory.temp_storage_bytes, - position.current() + segment.begin, position.other() + segment.begin, - ridx.current() + segment.begin, ridx.other() + segment.begin, - segment.Size(), min_bits, max_bits); - dh::safe_cuda(cudaMemcpy( - position.current() + segment.begin, position.other() + segment.begin, - segment.Size() * sizeof(int), cudaMemcpyDeviceToDevice)); - dh::safe_cuda(cudaMemcpy( - ridx.current() + segment.begin, ridx.other() + segment.begin, - segment.Size() * sizeof(bst_uint), cudaMemcpyDeviceToDevice)); - } -}; - -class GPUHistMakerExperimental : public TreeUpdater { - public: - struct ExpandEntry; - - GPUHistMakerExperimental() : initialised(false) {} - ~GPUHistMakerExperimental() {} - void Init( - const std::vector>& args) override { - param.InitAllowUnknown(args); - CHECK(param.n_gpus != 0) << "Must have at least one device"; - n_devices = param.n_gpus; - - dh::check_compute_capability(); - - if (param.grow_policy == TrainParam::kLossGuide) { - qexpand_.reset(new ExpandQueue(loss_guide)); - } else { - qexpand_.reset(new ExpandQueue(depth_wise)); - } - - monitor.Init("updater_gpu_hist_experimental", param.debug_verbose); - } - void Update(const std::vector& gpair, DMatrix* dmat, - const std::vector& trees) override { - monitor.Start("Update"); - GradStats::CheckInfo(dmat->info()); - // rescale learning rate according to size of trees - float lr = param.learning_rate; - param.learning_rate = lr / trees.size(); - // build tree - try { - for (size_t i = 0; i < trees.size(); ++i) { - this->UpdateTree(gpair, dmat, trees[i]); - } - } catch (const std::exception& e) { - LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl; - } - param.learning_rate = lr; - monitor.Stop("Update"); - } - - void InitDataOnce(DMatrix* dmat) { - info = &dmat->info(); - monitor.Start("Quantiles"); - hmat_.Init(dmat, param.max_bin); - gmat_.cut = &hmat_; - gmat_.Init(dmat); - monitor.Stop("Quantiles"); - n_bins = hmat_.row_ptr.back(); - - int n_devices = dh::n_devices(param.n_gpus, info->num_row); - - bst_uint row_begin = 0; - bst_uint shard_size = - std::ceil(static_cast(info->num_row) / n_devices); - - std::vector dList(n_devices); - for (int d_idx = 0; d_idx < n_devices; ++d_idx) { - int device_idx = (param.gpu_id + d_idx) % dh::n_visible_devices(); - dList[d_idx] = device_idx; - } - - reducer.Init(dList); - - // Partition input matrix into row segments - std::vector row_segments; - shards.resize(n_devices); - 
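// [Reviewer sketch -- not part of the diff.]
// SortPosition above uses a radix sort keyed on the new node position to
// split a node's row-index segment into a contiguous left part followed by a
// right part. The same effect on the host is a stable partition of
// (position, row index) pairs; a sketch with illustrative types:
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Reorders position/ridx in place so rows assigned to left_nidx come first,
// preserving relative order, and returns how many rows went left.
std::size_t PartitionRows(std::vector<int>* position,
                          std::vector<unsigned>* ridx, int left_nidx) {
  std::vector<std::pair<int, unsigned>> rows(position->size());
  for (std::size_t i = 0; i < rows.size(); ++i) {
    rows[i] = {(*position)[i], (*ridx)[i]};
  }
  auto mid = std::stable_partition(
      rows.begin(), rows.end(),
      [left_nidx](const std::pair<int, unsigned>& r) {
        return r.first == left_nidx;
      });
  for (std::size_t i = 0; i < rows.size(); ++i) {
    (*position)[i] = rows[i].first;
    (*ridx)[i] = rows[i].second;
  }
  return static_cast<std::size_t>(mid - rows.begin());
}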
row_segments.push_back(0); - for (int d_idx = 0; d_idx < n_devices; ++d_idx) { - bst_uint row_end = - std::min(static_cast(row_begin + shard_size), info->num_row); - row_segments.push_back(row_end); - row_begin = row_end; - } - - // Create device shards - omp_set_num_threads(shards.size()); -#pragma omp parallel - { - auto cpu_thread_id = omp_get_thread_num(); - shards[cpu_thread_id] = std::unique_ptr( - new DeviceShard(dList[cpu_thread_id], cpu_thread_id, gmat_, - row_segments[cpu_thread_id], - row_segments[cpu_thread_id + 1], n_bins, param)); - } - - initialised = true; - } - - void InitData(const std::vector& gpair, DMatrix* dmat, - const RegTree& tree) { - monitor.Start("InitDataOnce"); - if (!initialised) { - this->InitDataOnce(dmat); - } - monitor.Stop("InitDataOnce"); - - column_sampler.Init(info->num_col, param); - - // Copy gpair & reset memory - monitor.Start("InitDataReset"); - omp_set_num_threads(shards.size()); -#pragma omp parallel - { - auto cpu_thread_id = omp_get_thread_num(); - shards[cpu_thread_id]->Reset(gpair); - } - monitor.Stop("InitDataReset"); - } - - void AllReduceHist(int nidx) { - for (auto& shard : shards) { - auto d_node_hist = shard->hist.GetHistPtr(nidx); - reducer.AllReduceSum( - shard->normalised_device_idx, - reinterpret_cast(d_node_hist), - reinterpret_cast(d_node_hist), - n_bins * (sizeof(gpair_sum_t) / sizeof(gpair_sum_t::value_t))); - } - - reducer.Synchronize(); - } - - void BuildHistLeftRight(int nidx_parent, int nidx_left, int nidx_right) { - size_t left_node_max_elements = 0; - size_t right_node_max_elements = 0; - for (auto& shard : shards) { - left_node_max_elements = (std::max)( - left_node_max_elements, shard->ridx_segments[nidx_left].Size()); - right_node_max_elements = (std::max)( - right_node_max_elements, shard->ridx_segments[nidx_right].Size()); - } - - auto build_hist_nidx = nidx_left; - auto subtraction_trick_nidx = nidx_right; - - if (right_node_max_elements < left_node_max_elements) { - build_hist_nidx = nidx_right; - subtraction_trick_nidx = nidx_left; - } - - for (auto& shard : shards) { - shard->BuildHist(build_hist_nidx); - } - - this->AllReduceHist(build_hist_nidx); - - for (auto& shard : shards) { - shard->SubtractionTrick(nidx_parent, build_hist_nidx, - subtraction_trick_nidx); - } - } - - // Returns best loss - std::vector EvaluateSplits( - const std::vector& nidx_set, RegTree* p_tree) { - auto columns = info->num_col; - std::vector best_splits(nidx_set.size()); - std::vector candidate_splits(nidx_set.size() * - columns); - // Use first device - auto& shard = shards.front(); - dh::safe_cuda(cudaSetDevice(shard->device_idx)); - shard->temp_memory.LazyAllocate(sizeof(DeviceSplitCandidate) * columns * - nidx_set.size()); - auto d_split = shard->temp_memory.Pointer(); - - auto& streams = shard->GetStreams(static_cast(nidx_set.size())); - - // Use streams to process nodes concurrently - for (auto i = 0; i < nidx_set.size(); i++) { - auto nidx = nidx_set[i]; - DeviceNodeStats node(shard->node_sum_gradients[nidx], nidx, param); - - const int BLOCK_THREADS = 256; - evaluate_split_kernel - <<>>( - shard->hist.GetHistPtr(nidx), nidx, info->num_col, node, - shard->feature_segments.data(), shard->min_fvalue.data(), - shard->gidx_fvalue_map.data(), GPUTrainingParam(param), - d_split + i * columns); - } - - dh::safe_cuda( - cudaMemcpy(candidate_splits.data(), shard->temp_memory.d_temp_storage, - sizeof(DeviceSplitCandidate) * columns * nidx_set.size(), - cudaMemcpyDeviceToHost)); - - for (auto i = 0; i < nidx_set.size(); i++) { - auto nidx 
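// [Reviewer sketch -- not part of the diff.]
// InitDataOnce above splits the training rows evenly across the selected
// devices: ceil(num_rows / n_devices) rows per shard, with the last shard
// clipped to the matrix size. A tiny sketch of that segmentation, returning
// the row boundaries (segments[d], segments[d + 1]) for each shard d:
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<std::size_t> RowSegments(std::size_t num_rows, int n_devices) {
  std::size_t shard_size = (num_rows + n_devices - 1) / n_devices;  // ceil
  std::vector<std::size_t> segments{0};
  std::size_t row_begin = 0;
  for (int d = 0; d < n_devices; ++d) {
    std::size_t row_end = std::min(row_begin + shard_size, num_rows);
    segments.push_back(row_end);
    row_begin = row_end;
  }
  return segments;
}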
= nidx_set[i]; - DeviceSplitCandidate nidx_best; - for (auto fidx = 0; fidx < columns; fidx++) { - auto& candidate = candidate_splits[i * columns + fidx]; - if (column_sampler.ColumnUsed(candidate.findex, - p_tree->GetDepth(nidx))) { - nidx_best.Update(candidate_splits[i * columns + fidx], param); - } - } - best_splits[i] = nidx_best; - } - return std::move(best_splits); - } - - void InitRoot(const std::vector& gpair, RegTree* p_tree) { - auto root_nidx = 0; - // Sum gradients - std::vector tmp_sums(shards.size()); - omp_set_num_threads(shards.size()); -#pragma omp parallel - { - auto cpu_thread_id = omp_get_thread_num(); - dh::safe_cuda(cudaSetDevice(shards[cpu_thread_id]->device_idx)); - tmp_sums[cpu_thread_id] = - thrust::reduce(thrust::cuda::par(shards[cpu_thread_id]->temp_memory), - shards[cpu_thread_id]->gpair.tbegin(), - shards[cpu_thread_id]->gpair.tend()); - } - auto sum_gradient = - std::accumulate(tmp_sums.begin(), tmp_sums.end(), bst_gpair()); - - // Generate root histogram - for (auto& shard : shards) { - shard->BuildHist(root_nidx); - } - - this->AllReduceHist(root_nidx); - - // Remember root stats - p_tree->stat(root_nidx).sum_hess = sum_gradient.GetHess(); - p_tree->stat(root_nidx).base_weight = CalcWeight(param, sum_gradient); - - // Store sum gradients - for (auto& shard : shards) { - shard->node_sum_gradients[root_nidx] = sum_gradient; - } - - // Generate first split - auto splits = this->EvaluateSplits({root_nidx}, p_tree); - qexpand_->push( - ExpandEntry(root_nidx, p_tree->GetDepth(root_nidx), splits.front(), 0)); - } - - void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) { - auto nidx = candidate.nid; - auto left_nidx = (*p_tree)[nidx].cleft(); - auto right_nidx = (*p_tree)[nidx].cright(); - - // convert floating-point split_pt into corresponding bin_id - // split_cond = -1 indicates that split_pt is less than all known cut points - auto split_gidx = -1; - auto fidx = candidate.split.findex; - auto default_dir_left = candidate.split.dir == LeftDir; - auto fidx_begin = hmat_.row_ptr[fidx]; - auto fidx_end = hmat_.row_ptr[fidx + 1]; - for (auto i = fidx_begin; i < fidx_end; ++i) { - if (candidate.split.fvalue == hmat_.cut[i]) { - split_gidx = static_cast(i); - } - } - - auto is_dense = info->num_nonzero == info->num_row * info->num_col; - - omp_set_num_threads(shards.size()); -#pragma omp parallel - { - auto cpu_thread_id = omp_get_thread_num(); - shards[cpu_thread_id]->UpdatePosition(nidx, left_nidx, right_nidx, fidx, - split_gidx, default_dir_left, - is_dense, fidx_begin, fidx_end); - } - } - - void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) { - // Add new leaves - RegTree& tree = *p_tree; - tree.AddChilds(candidate.nid); - auto& parent = tree[candidate.nid]; - parent.set_split(candidate.split.findex, candidate.split.fvalue, - candidate.split.dir == LeftDir); - tree.stat(candidate.nid).loss_chg = candidate.split.loss_chg; - - // Configure left child - auto left_weight = CalcWeight(param, candidate.split.left_sum); - tree[parent.cleft()].set_leaf(left_weight * param.learning_rate, 0); - tree.stat(parent.cleft()).base_weight = left_weight; - tree.stat(parent.cleft()).sum_hess = candidate.split.left_sum.GetHess(); - - // Configure right child - auto right_weight = CalcWeight(param, candidate.split.right_sum); - tree[parent.cright()].set_leaf(right_weight * param.learning_rate, 0); - tree.stat(parent.cright()).base_weight = right_weight; - tree.stat(parent.cright()).sum_hess = candidate.split.right_sum.GetHess(); - // Store sum gradients - 
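// [Reviewer sketch -- not part of the diff.]
// EvaluateSplits above launches one block per feature on the GPU and then, on
// the host, keeps the best candidate whose feature survived column sampling.
// A reduced sketch of that host-side selection with an illustrative candidate
// type and a stand-in predicate for ColumnSampler::ColumnUsed:
#include <functional>
#include <vector>

struct SplitCandidate {
  int findex = -1;
  float loss_chg = 0.0f;
};

SplitCandidate BestAllowedSplit(
    const std::vector<SplitCandidate>& per_feature,
    const std::function<bool(int)>& column_used) {
  SplitCandidate best;  // loss_chg == 0 means "no usable split"
  for (const auto& c : per_feature) {
    if (column_used(c.findex) && c.loss_chg > best.loss_chg) best = c;
  }
  return best;
}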
for (auto& shard : shards) { - shard->node_sum_gradients[parent.cleft()] = candidate.split.left_sum; - shard->node_sum_gradients[parent.cright()] = candidate.split.right_sum; - } - this->UpdatePosition(candidate, p_tree); - } - - void UpdateTree(const std::vector& gpair, DMatrix* p_fmat, - RegTree* p_tree) { - // Temporarily store number of threads so we can change it back later - int nthread = omp_get_max_threads(); - - auto& tree = *p_tree; - - monitor.Start("InitData"); - this->InitData(gpair, p_fmat, *p_tree); - monitor.Stop("InitData"); - monitor.Start("InitRoot"); - this->InitRoot(gpair, p_tree); - monitor.Stop("InitRoot"); - - auto timestamp = qexpand_->size(); - auto num_leaves = 1; - - while (!qexpand_->empty()) { - auto candidate = qexpand_->top(); - qexpand_->pop(); - if (!candidate.IsValid(param, num_leaves)) continue; - // std::cout << candidate; - monitor.Start("ApplySplit"); - this->ApplySplit(candidate, p_tree); - monitor.Stop("ApplySplit"); - num_leaves++; - - auto left_child_nidx = tree[candidate.nid].cleft(); - auto right_child_nidx = tree[candidate.nid].cright(); - - // Only create child entries if needed - if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), - num_leaves)) { - monitor.Start("BuildHist"); - this->BuildHistLeftRight(candidate.nid, left_child_nidx, - right_child_nidx); - monitor.Stop("BuildHist"); - - monitor.Start("EvaluateSplits"); - auto splits = - this->EvaluateSplits({left_child_nidx, right_child_nidx}, p_tree); - qexpand_->push(ExpandEntry(left_child_nidx, - tree.GetDepth(left_child_nidx), splits[0], - timestamp++)); - qexpand_->push(ExpandEntry(right_child_nidx, - tree.GetDepth(right_child_nidx), splits[1], - timestamp++)); - monitor.Stop("EvaluateSplits"); - } - } - - // Reset omp num threads - omp_set_num_threads(nthread); - } - - struct ExpandEntry { - int nid; - int depth; - DeviceSplitCandidate split; - uint64_t timestamp; - ExpandEntry(int nid, int depth, const DeviceSplitCandidate& split, - uint64_t timestamp) - : nid(nid), depth(depth), split(split), timestamp(timestamp) {} - bool IsValid(const TrainParam& param, int num_leaves) const { - if (split.loss_chg <= rt_eps) return false; - if (param.max_depth > 0 && depth == param.max_depth) return false; - if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false; - return true; - } - - static bool ChildIsValid(const TrainParam& param, int depth, - int num_leaves) { - if (param.max_depth > 0 && depth == param.max_depth) return false; - if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false; - return true; - } - - friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) { - os << "ExpandEntry: \n"; - os << "nidx: " << e.nid << "\n"; - os << "depth: " << e.depth << "\n"; - os << "loss: " << e.split.loss_chg << "\n"; - os << "left_sum: " << e.split.left_sum << "\n"; - os << "right_sum: " << e.split.right_sum << "\n"; - return os; - } - }; - - inline static bool depth_wise(ExpandEntry lhs, ExpandEntry rhs) { - if (lhs.depth == rhs.depth) { - return lhs.timestamp > rhs.timestamp; // favor small timestamp - } else { - return lhs.depth > rhs.depth; // favor small depth - } - } - inline static bool loss_guide(ExpandEntry lhs, ExpandEntry rhs) { - if (lhs.split.loss_chg == rhs.split.loss_chg) { - return lhs.timestamp > rhs.timestamp; // favor small timestamp - } else { - return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg - } - } - TrainParam param; - common::HistCutMatrix hmat_; - common::GHistIndexMatrix gmat_; - 
MetaInfo* info; - bool initialised; - int n_devices; - int n_bins; - - std::vector> shards; - ColumnSampler column_sampler; - typedef std::priority_queue, - std::function> - ExpandQueue; - std::unique_ptr qexpand_; - common::Monitor monitor; - dh::AllReducer reducer; -}; - -XGBOOST_REGISTER_TREE_UPDATER(GPUHistMakerExperimental, - "grow_gpu_hist_experimental") - .describe("Grow tree with GPU.") - .set_body([]() { return new GPUHistMakerExperimental(); }); -} // namespace tree -} // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist_experimental.cu b/tests/cpp/tree/test_gpu_hist.cu similarity index 97% rename from tests/cpp/tree/test_gpu_hist_experimental.cu rename to tests/cpp/tree/test_gpu_hist.cu index 481a1b254..a6e34b0ce 100644 --- a/tests/cpp/tree/test_gpu_hist_experimental.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -8,7 +8,7 @@ #include "gtest/gtest.h" #include "../../../src/gbm/gbtree_model.h" -#include "../../../src/tree/updater_gpu_hist_experimental.cu" +#include "../../../src/tree/updater_gpu_hist.cu" namespace xgboost { namespace tree { diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index f7dbbb489..7d63e932d 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -117,14 +117,10 @@ def assert_updater_accuracy(tree_method, comparison_tree_method, variable_param, @attr('gpu') class TestGPU(unittest.TestCase): - def test_gpu_hist(self): - variable_param = {'max_depth': [2, 6, 11], 'max_bin': [2, 16, 1024], 'n_gpus': [1, -1]} - assert_updater_accuracy('gpu_hist', 'hist', variable_param, 0.02) - def test_gpu_exact(self): variable_param = {'max_depth': [2, 6, 15]} assert_updater_accuracy('gpu_exact', 'exact', variable_param, 0.02) - def test_gpu_hist_experimental(self): + def test_gpu_hist(self): variable_param = {'n_gpus': [1, -1], 'max_depth': [2, 6], 'max_leaves': [255, 4], 'max_bin': [2, 16, 1024]} - assert_updater_accuracy('gpu_hist_experimental', 'hist', variable_param, 0.01) + assert_updater_accuracy('gpu_hist', 'hist', variable_param, 0.01)