diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst index e554386a4..94fadc469 100644 --- a/doc/gpu/index.rst +++ b/doc/gpu/index.rst @@ -37,30 +37,34 @@ Supported parameters .. |tick| unicode:: U+2714 .. |cross| unicode:: U+2718 -+--------------------------+---------------+--------------+ -| parameter | ``gpu_exact`` | ``gpu_hist`` | -+==========================+===============+==============+ -| ``subsample`` | |cross| | |tick| | -+--------------------------+---------------+--------------+ -| ``colsample_bytree`` | |cross| | |tick| | -+--------------------------+---------------+--------------+ -| ``colsample_bylevel`` | |cross| | |tick| | -+--------------------------+---------------+--------------+ -| ``max_bin`` | |cross| | |tick| | -+--------------------------+---------------+--------------+ -| ``gpu_id`` | |tick| | |tick| | -+--------------------------+---------------+--------------+ -| ``n_gpus`` | |cross| | |tick| | -+--------------------------+---------------+--------------+ -| ``predictor`` | |tick| | |tick| | -+--------------------------+---------------+--------------+ -| ``grow_policy`` | |cross| | |tick| | -+--------------------------+---------------+--------------+ -| ``monotone_constraints`` | |cross| | |tick| | -+--------------------------+---------------+--------------+ ++--------------------------------+---------------+--------------+ +| parameter | ``gpu_exact`` | ``gpu_hist`` | ++================================+===============+==============+ +| ``subsample`` | |cross| | |tick| | ++--------------------------------+---------------+--------------+ +| ``colsample_bytree`` | |cross| | |tick| | ++--------------------------------+---------------+--------------+ +| ``colsample_bylevel`` | |cross| | |tick| | ++--------------------------------+---------------+--------------+ +| ``max_bin`` | |cross| | |tick| | ++--------------------------------+---------------+--------------+ +| ``gpu_id`` | |tick| | |tick| | ++--------------------------------+---------------+--------------+ +| ``n_gpus`` | |cross| | |tick| | ++--------------------------------+---------------+--------------+ +| ``predictor`` | |tick| | |tick| | ++--------------------------------+---------------+--------------+ +| ``grow_policy`` | |cross| | |tick| | ++--------------------------------+---------------+--------------+ +| ``monotone_constraints`` | |cross| | |tick| | ++--------------------------------+---------------+--------------+ +| ``single_precision_histogram`` | |cross| | |tick| | ++--------------------------------+---------------+--------------+ GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``. +The experimental parameter ``single_precision_histogram`` can be set to True to enable building histograms using single precision. This may improve speed, in particular on older architectures. + The device ordinal can be selected using the ``gpu_id`` parameter, which defaults to 0. Multiple GPUs can be used with the ``gpu_hist`` tree method using the ``n_gpus`` parameter. which defaults to 1. If this is set to -1 all available GPUs will be used. If ``gpu_id`` is specified as non-zero, the selected gpu devices will be from ``gpu_id`` to ``gpu_id+n_gpus``, please note that ``gpu_id+n_gpus`` must be less than or equal to the number of available GPUs on your system. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU due to PCI bus bandwidth that can limit performance. @@ -121,6 +125,52 @@ For multi-gpu support, objective functions also honor the ``n_gpus`` parameter, which, by default is set to 1. To disable running objectives on GPU, just set ``n_gpus`` to 0. +Metric functions +=================== +Following table shows current support status for evaluation metrics on the GPU. + +.. |tick| unicode:: U+2714 +.. |cross| unicode:: U+2718 + ++-----------------+-------------+ +| Metric | GPU Support | ++=================+=============+ +| rmse | |tick| | ++-----------------+-------------+ +| mae | |tick| | ++-----------------+-------------+ +| logloss | |tick| | ++-----------------+-------------+ +| error | |tick| | ++-----------------+-------------+ +| merror | |cross| | ++-----------------+-------------+ +| mlogloss | |cross| | ++-----------------+-------------+ +| auc | |cross| | ++-----------------+-------------+ +| aucpr | |cross| | ++-----------------+-------------+ +| ndcg | |cross| | ++-----------------+-------------+ +| map | |cross| | ++-----------------+-------------+ +| poisson-nloglik | |tick| | ++-----------------+-------------+ +| gamma-nloglik | |tick| | ++-----------------+-------------+ +| cox-nloglik | |cross| | ++-----------------+-------------+ +| gamma-deviance | |tick| | ++-----------------+-------------+ +| tweedie-nloglik | |tick| | ++-----------------+-------------+ + +As for objective functions, metrics honor the ``n_gpus`` parameter, +which, by default is set to 1. To disable running metrics on GPU, just set +``n_gpus`` to 0. + + Benchmarks ========== You can run benchmarks on synthetic data for binary classification: @@ -152,12 +202,15 @@ References `Nvidia Parallel Forall: Gradient Boosting, Decision Trees and XGBoost with CUDA `_ -Authors +Contributors ======= -* Rory Mitchell +Many thanks to the following contributors (alphabetical order): +* Andrey Adinets +* Jiaming Yuan * Jonathan C. McKinney +* Philip Cho +* Rory Mitchell * Shankara Rao Thejaswi Nanditale * Vinay Deshpande -* ... and the rest of the H2O.ai and NVIDIA team. Please report bugs to the user forum https://discuss.xgboost.ai/. diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index e0d3e41ed..2d5393f3e 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -944,6 +944,32 @@ class AllReducer { #endif } + /** + * \brief Allreduce. Use in exactly the same way as NCCL but without needing + * streams or comms. + * + * \param communication_group_idx Zero-based index of the communication group. + * \param sendbuff The sendbuff. + * \param recvbuff The recvbuff. + * \param count Number of elements. + */ + + void AllReduceSum(int communication_group_idx, const float *sendbuff, + float *recvbuff, int count) { +#ifdef XGBOOST_USE_NCCL + CHECK(initialised_); + dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx))); + dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, + comms.at(communication_group_idx), + streams.at(communication_group_idx))); + if(communication_group_idx == 0) + { + allreduce_bytes_ += count * sizeof(float); + allreduce_calls_ += 1; + } +#endif + } + /** * \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms. * diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 7daf7fe0d..0f2f89b95 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -116,19 +116,19 @@ struct GPUSketcher { n_rows_(row_end - row_begin), param_(std::move(param)) { } - void Init(const SparsePage& row_batch, const MetaInfo& info) { + void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) { num_cols_ = info.num_col_; has_weights_ = info.weights_.Size() > 0; // find the batch size - if (param_.gpu_batch_nrows == 0) { + if (gpu_batch_nrows == 0) { // By default, use no more than 1/16th of GPU memory gpu_batch_nrows_ = dh::TotalMemory(device_) / (16 * num_cols_ * sizeof(Entry)); - } else if (param_.gpu_batch_nrows == -1) { + } else if (gpu_batch_nrows == -1) { gpu_batch_nrows_ = n_rows_; } else { - gpu_batch_nrows_ = param_.gpu_batch_nrows; + gpu_batch_nrows_ = gpu_batch_nrows; } if (gpu_batch_nrows_ > n_rows_) { gpu_batch_nrows_ = n_rows_; @@ -346,7 +346,8 @@ struct GPUSketcher { } }; - void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) { + void Sketch(const SparsePage& batch, const MetaInfo& info, + HistCutMatrix* hmat, int gpu_batch_nrows) { // create device shards shards_.resize(dist_.Devices().Size()); dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) { @@ -358,10 +359,11 @@ struct GPUSketcher { }); // compute sketches for each shard - dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { - shard->Init(batch, info); - shard->Sketch(batch, info); - }); + dh::ExecuteIndexShards(&shards_, + [&](int idx, std::unique_ptr& shard) { + shard->Init(batch, info, gpu_batch_nrows); + shard->Sketch(batch, info); + }); // merge the sketches from all shards // TODO(canonizer): do it in a tree-like reduction @@ -390,9 +392,9 @@ struct GPUSketcher { void DeviceSketch (const SparsePage& batch, const MetaInfo& info, - const tree::TrainParam& param, HistCutMatrix* hmat) { + const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows) { GPUSketcher sketcher(param, info.num_row_); - sketcher.Sketch(batch, info, hmat); + sketcher.Sketch(batch, info, hmat, gpu_batch_nrows); } } // namespace common diff --git a/src/common/hist_util.h b/src/common/hist_util.h index ad83dd6c8..1c112641d 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -72,7 +72,7 @@ struct HistCutMatrix { /*! \brief Builds the cut matrix on the GPU */ void DeviceSketch (const SparsePage& batch, const MetaInfo& info, - const tree::TrainParam& param, HistCutMatrix* hmat); + const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows); /*! * \brief A single row in global histogram index. diff --git a/src/tree/param.h b/src/tree/param.h index 3662193e5..8f607647e 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -77,8 +77,6 @@ struct TrainParam : public dmlc::Parameter { int gpu_id; // number of GPUs to use int n_gpus; - // number of rows in a single GPU batch - int gpu_batch_nrows; // the criteria to use for ranking splits std::string split_evaluator; @@ -205,11 +203,6 @@ struct TrainParam : public dmlc::Parameter { .set_lower_bound(-1) .set_default(1) .describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs"); - DMLC_DECLARE_FIELD(gpu_batch_nrows) - .set_lower_bound(-1) - .set_default(0) - .describe("Number of rows in a GPU batch, used for finding quantiles on GPU; " - "-1 to use all rows assignted to a GPU, and 0 to auto-deduce"); DMLC_DECLARE_FIELD(split_evaluator) .set_default("elastic_net,monotonic,interaction") .describe("The criteria to use for ranking splits"); diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 297b40e39..94b52e971 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -36,50 +36,16 @@ XGBOOST_DEVICE __forceinline__ double atomicAdd(double* address, double val) { namespace xgboost { namespace tree { -// Atomic add function for double precision gradients -__device__ __forceinline__ void AtomicAddGpair(GradientPairPrecise* dest, - const GradientPair& gpair) { - auto dst_ptr = reinterpret_cast(dest); - - atomicAdd(dst_ptr, static_cast(gpair.GetGrad())); - atomicAdd(dst_ptr + 1, static_cast(gpair.GetHess())); -} -// used by shared-memory atomics code -__device__ __forceinline__ void AtomicAddGpair(GradientPairPrecise* dest, - const GradientPairPrecise& gpair) { - auto dst_ptr = reinterpret_cast(dest); - - atomicAdd(dst_ptr, gpair.GetGrad()); - atomicAdd(dst_ptr + 1, gpair.GetHess()); -} - -// For integer gradients -__device__ __forceinline__ void AtomicAddGpair(GradientPairInteger* dest, - const GradientPair& gpair) { - auto dst_ptr = reinterpret_cast(dest); // NOLINT - GradientPairInteger tmp(gpair.GetGrad(), gpair.GetHess()); - auto src_ptr = reinterpret_cast(&tmp); +// Atomic add function for gradients +template +DEV_INLINE void AtomicAddGpair(OutputGradientT* dest, + const InputGradientT& gpair) { + auto dst_ptr = reinterpret_cast(dest); atomicAdd(dst_ptr, - static_cast(*src_ptr)); // NOLINT + static_cast(gpair.GetGrad())); atomicAdd(dst_ptr + 1, - static_cast(*(src_ptr + 1))); // NOLINT -} - -/** - * \brief Check maximum gradient value is below 2^16. This is to prevent - * overflow when using integer gradient summation. - */ - -inline void CheckGradientMax(const std::vector& gpair) { - auto* ptr = reinterpret_cast(gpair.data()); - float abs_max = - std::accumulate(ptr, ptr + (gpair.size() * 2), 0.f, - [=](float a, float b) { - return std::max(abs(a), abs(b)); }); - - CHECK_LT(abs_max, std::pow(2.0f, 16.0f)) - << "Labels are too large for this algorithm. Rescale to less than 2^16."; + static_cast(gpair.GetHess())); } struct GPUTrainingParam { diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 2c11269f4..e09e3e921 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -30,7 +30,25 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); -using GradientPairSumT = GradientPairPrecise; +// training parameters specific to this algorithm +struct GPUHistMakerTrainParam + : public dmlc::Parameter { + bool single_precision_histogram; + // number of rows in a single GPU batch + int gpu_batch_nrows; + // declare parameters + DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) { + DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( + "Use single precision to build histograms."); + DMLC_DECLARE_FIELD(gpu_batch_nrows) + .set_lower_bound(-1) + .set_default(0) + .describe("Number of rows in a GPU batch, used for finding quantiles on GPU; " + "-1 to use all rows assignted to a GPU, and 0 to auto-deduce"); + } +}; + +DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam); /*! * \brief @@ -42,20 +60,20 @@ using GradientPairSumT = GradientPairPrecise; * \param end * \param temp_storage Shared memory for intermediate result. */ -template -__device__ GradientPairSumT ReduceFeature(common::Span feature_histogram, +template +__device__ GradientSumT ReduceFeature(common::Span feature_histogram, TempStorageT* temp_storage) { - __shared__ cub::Uninitialized uninitialized_sum; - GradientPairSumT& shared_sum = uninitialized_sum.Alias(); + __shared__ cub::Uninitialized uninitialized_sum; + GradientSumT& shared_sum = uninitialized_sum.Alias(); - GradientPairSumT local_sum = GradientPairSumT(); + GradientSumT local_sum = GradientSumT(); // For loop sums features into one block size auto begin = feature_histogram.data(); auto end = begin + feature_histogram.size(); for (auto itr = begin; itr < end; itr += BLOCK_THREADS) { bool thread_active = itr + threadIdx.x < end; // Scan histogram - GradientPairSumT bin = thread_active ? *(itr + threadIdx.x) : GradientPairSumT(); + GradientSumT bin = thread_active ? *(itr + threadIdx.x) : GradientSumT(); local_sum += bin; } local_sum = ReduceT(temp_storage->sum_reduce).Reduce(local_sum, cub::Sum()); @@ -69,10 +87,10 @@ __device__ GradientPairSumT ReduceFeature(common::Span f /*! \brief Find the thread with best gain. */ template + typename MaxReduceT, typename TempStorageT, typename GradientSumT> __device__ void EvaluateFeature( int fidx, - common::Span node_histogram, + common::Span node_histogram, common::Span feature_segments, // cut.row_ptr float min_fvalue, // cut.min_value common::Span gidx_fvalue_map, // cut.cut @@ -86,22 +104,22 @@ __device__ void EvaluateFeature( uint32_t gidx_end = feature_segments[fidx + 1]; // end bin for i^th feature // Sum histogram bins for current feature - GradientPairSumT const feature_sum = ReduceFeature( + GradientSumT const feature_sum = ReduceFeature( node_histogram.subspan(gidx_begin, gidx_end - gidx_begin), temp_storage); - GradientPairSumT const parent_sum = GradientPairSumT(node.sum_gradients); - GradientPairSumT const missing = parent_sum - feature_sum; + GradientSumT const parent_sum = GradientSumT(node.sum_gradients); + GradientSumT const missing = parent_sum - feature_sum; float const null_gain = -std::numeric_limits::infinity(); - SumCallbackOp prefix_op = - SumCallbackOp(); + SumCallbackOp prefix_op = + SumCallbackOp(); for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += BLOCK_THREADS) { bool thread_active = (scan_begin + threadIdx.x) < gidx_end; // Gradient value for current bin. - GradientPairSumT bin = - thread_active ? node_histogram[scan_begin + threadIdx.x] : GradientPairSumT(); + GradientSumT bin = + thread_active ? node_histogram[scan_begin + threadIdx.x] : GradientSumT(); scan_t(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op); // Whether the gradient of missing values is put to the left side. @@ -117,7 +135,7 @@ __device__ void EvaluateFeature( // Find thread with best gain cub::KeyValuePair tuple(threadIdx.x, gain); cub::KeyValuePair best = - max_ReduceT(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax()); + MaxReduceT(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax()); __shared__ cub::KeyValuePair block_max; if (threadIdx.x == 0) { @@ -131,8 +149,8 @@ __device__ void EvaluateFeature( int gidx = scan_begin + threadIdx.x; float fvalue = gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1]; - GradientPairSumT left = missing_left ? bin + missing : bin; - GradientPairSumT right = parent_sum - left; + GradientSumT left = missing_left ? bin + missing : bin; + GradientSumT right = parent_sum - left; best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, GradientPair(left), @@ -143,9 +161,9 @@ __device__ void EvaluateFeature( } } -template +template __global__ void EvaluateSplitKernel( - common::Span + common::Span node_histogram, // histogram for gradients common::Span feature_set, // Selected features DeviceNodeStats node, @@ -160,10 +178,10 @@ __global__ void EvaluateSplitKernel( // KeyValuePair here used as threadIdx.x -> gain_value typedef cub::KeyValuePair ArgMaxT; typedef cub::BlockScan< - GradientPairSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS> BlockScanT; + GradientSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS> BlockScanT; typedef cub::BlockReduce MaxReduceT; - typedef cub::BlockReduce SumReduceT; + typedef cub::BlockReduce SumReduceT; union TempStorage { typename BlockScanT::TempStorage scan; @@ -233,10 +251,11 @@ __device__ int BinarySearchRow(bst_uint begin, bst_uint end, GidxIterT data, * \author Rory * \date 28/07/2018 */ +template struct DeviceHistogram { /*! \brief Map nidx to starting index of its histogram. */ std::map nidx_map; - thrust::device_vector data; + thrust::device_vector data; const size_t kStopGrowingSize = 1 << 26; // Do not grow beyond this size int n_bins; int device_id_; @@ -264,7 +283,7 @@ struct DeviceHistogram { std::pair old_entry = *nidx_map.begin(); nidx_map.erase(old_entry.first); dh::safe_cuda(cudaMemset(data.data().get() + old_entry.second, 0, - n_bins * sizeof(GradientPairSumT))); + n_bins * sizeof(GradientSumT))); nidx_map[nidx] = old_entry.second; } else { // Append new node histogram @@ -280,11 +299,11 @@ struct DeviceHistogram { * \param nidx Tree node index. * \return hist pointer. */ - common::Span GetNodeHistogram(int nidx) { + common::Span GetNodeHistogram(int nidx) { CHECK(this->HistogramExists(nidx)); auto ptr = data.data().get() + nidx_map[nidx]; - return common::Span( - reinterpret_cast(ptr), n_bins); + return common::Span( + reinterpret_cast(ptr), n_bins); } }; @@ -341,18 +360,17 @@ __global__ void compress_bin_ellpack_k( wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature); } -__global__ void sharedMemHistKernel(size_t row_stride, - const bst_uint* d_ridx, +template +__global__ void SharedMemHistKernel(size_t row_stride, const bst_uint* d_ridx, common::CompressedIterator d_gidx, int null_gidx_value, - GradientPairSumT* d_node_hist, + GradientSumT* d_node_hist, const GradientPair* d_gpair, - size_t segment_begin, - size_t n_elements) { + size_t segment_begin, size_t n_elements) { extern __shared__ char smem[]; - GradientPairSumT* smem_arr = reinterpret_cast(smem); // NOLINT + GradientSumT* smem_arr = reinterpret_cast(smem); // NOLINT for (auto i : dh::BlockStrideRange(0, null_gidx_value)) { - smem_arr[i] = GradientPairSumT(); + smem_arr[i] = GradientSumT(); } __syncthreads(); for (auto idx : dh::GridStrideRange(static_cast(0), n_elements)) { @@ -427,15 +445,18 @@ void SortPosition(dh::CubMemory* temp_memory, common::Span position, out_itr, position.size()); } +template struct DeviceShard; +template struct GPUHistBuilderBase { public: - virtual void Build(DeviceShard* shard, int idx) = 0; + virtual void Build(DeviceShard* shard, int idx) = 0; virtual ~GPUHistBuilderBase() = default; }; // Manage memory for a single GPU +template struct DeviceShard { int device_id_; dh::BulkAllocator ba; @@ -452,7 +473,7 @@ struct DeviceShard { /*! \brief Range of rows for each node. */ std::vector ridx_segments; - DeviceHistogram hist; + DeviceHistogram hist; /*! \brief global index of histogram, which is stored in ELLPack format. */ dh::DVec gidx_buffer; @@ -489,7 +510,7 @@ struct DeviceShard { dh::CubMemory temp_memory; - std::unique_ptr hist_builder; + std::unique_ptr> hist_builder; // TODO(canonizer): do add support multi-batch DMatrix here DeviceShard(int device_id, bst_uint row_begin, bst_uint row_end, @@ -567,7 +588,7 @@ struct DeviceShard { // One block for each feature int constexpr BLOCK_THREADS = 256; - EvaluateSplitKernel + EvaluateSplitKernel <<>>( hist.GetNodeHistogram(nidx), feature_set.DeviceSpan(device_id_), node, cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(), @@ -719,8 +740,9 @@ struct DeviceShard { } }; -struct SharedMemHistBuilder : public GPUHistBuilderBase { - void Build(DeviceShard* shard, int nidx) override { +template +struct SharedMemHistBuilder : public GPUHistBuilderBase { + void Build(DeviceShard* shard, int nidx) override { auto segment = shard->ridx_segments[nidx]; auto segment_begin = segment.begin; auto d_node_hist = shard->hist.GetNodeHistogram(nidx); @@ -731,7 +753,7 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase { int null_gidx_value = shard->null_gidx_value; auto n_elements = segment.Size() * shard->row_stride; - const size_t smem_size = sizeof(GradientPairSumT) * shard->null_gidx_value; + const size_t smem_size = sizeof(GradientSumT) * shard->null_gidx_value; const int items_per_thread = 8; const int block_threads = 256; const int grid_size = @@ -741,14 +763,15 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase { return; } dh::safe_cuda(cudaSetDevice(shard->device_id_)); - sharedMemHistKernel<<>> + SharedMemHistKernel<<>> (shard->row_stride, d_ridx, d_gidx, null_gidx_value, d_node_hist.data(), d_gpair, segment_begin, n_elements); } }; -struct GlobalMemHistBuilder : public GPUHistBuilderBase { - void Build(DeviceShard* shard, int nidx) override { +template +struct GlobalMemHistBuilder : public GPUHistBuilderBase { + void Build(DeviceShard* shard, int nidx) override { Segment segment = shard->ridx_segments[nidx]; auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data(); common::CompressedIterator d_gidx = shard->gidx; @@ -771,7 +794,8 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase { } }; -inline void DeviceShard::InitCompressedData( +template +inline void DeviceShard::InitCompressedData( const common::HistCutMatrix& hmat, const SparsePage& row_batch) { n_bins = hmat.row_ptr.back(); null_gidx_value = hmat.row_ptr.back(); @@ -820,19 +844,21 @@ inline void DeviceShard::InitCompressedData( // check if we can use shared memory for building histograms // (assuming atleast we need 2 CTAs per SM to maintain decent latency hiding) - auto histogram_size = sizeof(GradientPairSumT) * null_gidx_value; + auto histogram_size = sizeof(GradientSumT) * null_gidx_value; auto max_smem = dh::MaxSharedMemory(device_id_); if (histogram_size <= max_smem) { - hist_builder.reset(new SharedMemHistBuilder); + hist_builder.reset(new SharedMemHistBuilder); } else { - hist_builder.reset(new GlobalMemHistBuilder); + hist_builder.reset(new GlobalMemHistBuilder); } // Init histogram hist.Init(device_id_, hmat.row_ptr.back()); } -inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) { + +template +inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) { int num_symbols = n_bins + 1; // bin and compress entries in batches of rows size_t gpu_batch_nrows = @@ -882,14 +908,17 @@ inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) { entries_d.shrink_to_fit(); } -class GPUHistMaker : public TreeUpdater { + +template +class GPUHistMakerSpecialised{ public: struct ExpandEntry; - GPUHistMaker() : initialised_(false), p_last_fmat_(nullptr) {} + GPUHistMakerSpecialised() : initialised_(false), p_last_fmat_(nullptr) {} void Init( - const std::vector>& args) override { + const std::vector>& args) { param_.InitAllowUnknown(args); + hist_maker_param_.InitAllowUnknown(args); CHECK(param_.n_gpus != 0) << "Must have at least one device"; n_devices_ = param_.n_gpus; dist_ = GPUDistribution::Block(GPUSet::All(param_.gpu_id, param_.n_gpus)); @@ -906,7 +935,7 @@ class GPUHistMaker : public TreeUpdater { } void Update(HostDeviceVector* gpair, DMatrix* dmat, - const std::vector& trees) override { + const std::vector& trees) { monitor_.Start("Update", dist_.Devices()); GradStats::CheckInfo(dmat->Info()); // rescale learning rate according to size of trees @@ -943,23 +972,24 @@ class GPUHistMaker : public TreeUpdater { const SparsePage& batch = *batch_iter; // Create device shards shards_.resize(n_devices); - dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) { + dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr>& shard) { size_t start = dist_.ShardStart(info_->num_row_, i); size_t size = dist_.ShardSize(info_->num_row_, i); - shard = std::unique_ptr - (new DeviceShard(dist_.Devices().DeviceId(i), + shard = std::unique_ptr> + (new DeviceShard(dist_.Devices().DeviceId(i), start, start + size, param_)); shard->InitRowPtrs(batch); }); // Find the cuts. monitor_.Start("Quantiles", dist_.Devices()); - common::DeviceSketch(batch, *info_, param_, &hmat_); + common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows); n_bins_ = hmat_.row_ptr.back(); monitor_.Stop("Quantiles", dist_.Devices()); monitor_.Start("BinningCompression", dist_.Devices()); - dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { + dh::ExecuteIndexShards(&shards_, [&](int idx, + std::unique_ptr>& shard) { shard->InitCompressedData(hmat_, batch); }); monitor_.Stop("BinningCompression", dist_.Devices()); @@ -983,9 +1013,11 @@ class GPUHistMaker : public TreeUpdater { monitor_.Start("InitDataReset", dist_.Devices()); gpair->Reshard(dist_); - dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { - shard->Reset(gpair); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + shard->Reset(gpair); + }); monitor_.Stop("InitDataReset", dist_.Devices()); } @@ -993,14 +1025,17 @@ class GPUHistMaker : public TreeUpdater { if (shards_.size() == 1) return; monitor_.Start("AllReduce"); - dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { - auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data(); - reducer_.AllReduceSum( - dist_.Devices().Index(shard->device_id_), - reinterpret_cast(d_node_hist), - reinterpret_cast(d_node_hist), - n_bins_ * (sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT))); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data(); + reducer_.AllReduceSum( + dist_.Devices().Index(shard->device_id_), + reinterpret_cast(d_node_hist), + reinterpret_cast(d_node_hist), + n_bins_ * (sizeof(GradientSumT) / + sizeof(typename GradientSumT::ValueT))); + }); monitor_.Stop("AllReduce"); } @@ -1026,9 +1061,11 @@ class GPUHistMaker : public TreeUpdater { } // Build histogram for node with the smallest number of training examples - dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { - shard->BuildHist(build_hist_nidx); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + shard->BuildHist(build_hist_nidx); + }); this->AllReduceHist(build_hist_nidx); @@ -1041,15 +1078,19 @@ class GPUHistMaker : public TreeUpdater { if (do_subtraction_trick) { // Calculate other histogram using subtraction trick - dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { - shard->SubtractionTrick(nidx_parent, build_hist_nidx, - subtraction_trick_nidx); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + shard->SubtractionTrick(nidx_parent, build_hist_nidx, + subtraction_trick_nidx); + }); } else { // Calculate other histogram manually - dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { - shard->BuildHist(subtraction_trick_nidx); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + shard->BuildHist(subtraction_trick_nidx); + }); this->AllReduceHist(subtraction_trick_nidx); } @@ -1066,19 +1107,22 @@ class GPUHistMaker : public TreeUpdater { // Sum gradients std::vector tmp_sums(shards_.size()); - dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) { - dh::safe_cuda(cudaSetDevice(shard->device_id_)); - tmp_sums[i] = - dh::SumReduction(shard->temp_memory, shard->gpair.Data(), - shard->gpair.Size()); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int i, std::unique_ptr>& shard) { + dh::safe_cuda(cudaSetDevice(shard->device_id_)); + tmp_sums[i] = dh::SumReduction( + shard->temp_memory, shard->gpair.Data(), shard->gpair.Size()); + }); GradientPair sum_gradient = std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair()); // Generate root histogram - dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { - shard->BuildHist(root_nidx); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + shard->BuildHist(root_nidx); + }); this->AllReduceHist(root_nidx); @@ -1122,11 +1166,13 @@ class GPUHistMaker : public TreeUpdater { } auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_; - dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { - shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx, - split_gidx, default_dir_left, - is_dense, fidx_begin, fidx_end); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx, split_gidx, + default_dir_left, is_dense, fidx_begin, + fidx_end); + }); } void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) { @@ -1223,15 +1269,17 @@ class GPUHistMaker : public TreeUpdater { } bool UpdatePredictionCache( - const DMatrix* data, HostDeviceVector* p_out_preds) override { + const DMatrix* data, HostDeviceVector* p_out_preds) { monitor_.Start("UpdatePredictionCache", dist_.Devices()); if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data) return false; p_out_preds->Reshard(dist_.Devices()); - dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) { - shard->UpdatePredictionCache( - p_out_preds->DevicePointer(shard->device_id_)); - }); + dh::ExecuteIndexShards( + &shards_, + [&](int idx, std::unique_ptr>& shard) { + shard->UpdatePredictionCache( + p_out_preds->DevicePointer(shard->device_id_)); + }); monitor_.Stop("UpdatePredictionCache", dist_.Devices()); return true; } @@ -1286,6 +1334,7 @@ class GPUHistMaker : public TreeUpdater { } } TrainParam param_; + GPUHistMakerTrainParam hist_maker_param_; common::HistCutMatrix hmat_; common::GHistIndexMatrix gmat_; MetaInfo* info_; @@ -1293,7 +1342,7 @@ class GPUHistMaker : public TreeUpdater { int n_devices_; int n_bins_; - std::vector> shards_; + std::vector>> shards_; common::ColumnSampler column_sampler_; using ExpandQueue = std::priority_queue, std::function>; @@ -1308,6 +1357,46 @@ class GPUHistMaker : public TreeUpdater { GPUDistribution dist_; }; +class GPUHistMaker : public TreeUpdater { + public: + void Init( + const std::vector>& args) override { + hist_maker_param_.InitAllowUnknown(args); + float_maker_.reset(); + double_maker_.reset(); + if (hist_maker_param_.single_precision_histogram) { + float_maker_.reset(new GPUHistMakerSpecialised()); + float_maker_->Init(args); + } else { + double_maker_.reset(new GPUHistMakerSpecialised()); + double_maker_->Init(args); + } + } + + void Update(HostDeviceVector* gpair, DMatrix* dmat, + const std::vector& trees) override { + if (hist_maker_param_.single_precision_histogram) { + float_maker_->Update(gpair, dmat, trees); + } else { + double_maker_->Update(gpair, dmat, trees); + } + } + + bool UpdatePredictionCache( + const DMatrix* data, HostDeviceVector* p_out_preds) override { + if (hist_maker_param_.single_precision_histogram) { + return float_maker_->UpdatePredictionCache(data, p_out_preds); + } else { + return double_maker_->UpdatePredictionCache(data, p_out_preds); + } + } + + private: + GPUHistMakerTrainParam hist_maker_param_; + std::unique_ptr> float_maker_; + std::unique_ptr> double_maker_; +}; + XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") .describe("Grow tree with GPU.") .set_body([]() { return new GPUHistMaker(); }); diff --git a/tests/cpp/common/test_gpu_hist_util.cu b/tests/cpp/common/test_gpu_hist_util.cu index f7bce5580..7c4fbd745 100644 --- a/tests/cpp/common/test_gpu_hist_util.cu +++ b/tests/cpp/common/test_gpu_hist_util.cu @@ -30,7 +30,7 @@ void TestDeviceSketch(const GPUSet& devices) { p.gpu_id = 0; p.n_gpus = devices.Size(); // ensure that the exact quantiles are found - p.gpu_batch_nrows = nrows * 10; + int gpu_batch_nrows = nrows * 10; // find quantiles on the CPU HistCutMatrix hmat_cpu; @@ -39,7 +39,7 @@ void TestDeviceSketch(const GPUSet& devices) { // find the cuts on the GPU const SparsePage& batch = *(*dmat)->GetRowBatches().begin(); HistCutMatrix hmat_gpu; - DeviceSketch(batch, (*dmat)->Info(), p, &hmat_gpu); + DeviceSketch(batch, (*dmat)->Info(), p, &hmat_gpu, gpu_batch_nrows); // compare the cuts double eps = 1e-2; diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index cd4096fca..4f8e7d6ff 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -17,7 +17,8 @@ namespace xgboost { namespace tree { -void BuildGidx(DeviceShard* shard, int n_rows, int n_cols, +template +void BuildGidx(DeviceShard* shard, int n_rows, int n_cols, bst_float sparsity=0) { auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3); const SparsePage& batch = *(*dmat)->GetRowBatches().begin(); @@ -48,7 +49,7 @@ TEST(GpuHist, BuildGidxDense) { param.n_gpus = 1; param.max_leaves = 0; - DeviceShard shard(0, 0, n_rows, param); + DeviceShard shard(0, 0, n_rows, param); BuildGidx(&shard, n_rows, n_cols); std::vector h_gidx_buffer; @@ -87,7 +88,7 @@ TEST(GpuHist, BuildGidxSparse) { param.n_gpus = 1; param.max_leaves = 0; - DeviceShard shard(0, 0, n_rows, param); + DeviceShard shard(0, 0, n_rows, param); BuildGidx(&shard, n_rows, n_cols, 0.9f); std::vector h_gidx_buffer; @@ -122,7 +123,8 @@ std::vector GetHostHistGpair() { return hist_gpair; } -void TestBuildHist(GPUHistBuilderBase& builder) { +template +void TestBuildHist(GPUHistBuilderBase& builder) { int const n_rows = 16, n_cols = 8; TrainParam param; @@ -130,7 +132,7 @@ void TestBuildHist(GPUHistBuilderBase& builder) { param.n_gpus = 1; param.max_leaves = 0; - DeviceShard shard(0, 0, n_rows, param); + DeviceShard shard(0, 0, n_rows, param); BuildGidx(&shard, n_rows, n_cols); @@ -166,13 +168,14 @@ void TestBuildHist(GPUHistBuilderBase& builder) { shard.ridx.CurrentDVec().tend()); builder.Build(&shard, 0); - DeviceHistogram d_hist = shard.hist; + DeviceHistogram d_hist = shard.hist; auto node_histogram = d_hist.GetNodeHistogram(0); // d_hist.data stored in float, not gradient pair - thrust::host_vector h_result (d_hist.data.size()/2); - size_t data_size = sizeof(GradientPairSumT) / ( - sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT)); + thrust::host_vector h_result (d_hist.data.size()/2); + size_t data_size = + sizeof(GradientSumT) / + (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)); data_size *= d_hist.data.size(); dh::safe_cuda(cudaMemcpy(h_result.data(), node_histogram.data(), data_size, cudaMemcpyDeviceToHost)); @@ -186,13 +189,17 @@ void TestBuildHist(GPUHistBuilderBase& builder) { } TEST(GpuHist, BuildHistGlobalMem) { - GlobalMemHistBuilder builder; - TestBuildHist(builder); + GlobalMemHistBuilder double_builder; + TestBuildHist(double_builder); + GlobalMemHistBuilder float_builder; + TestBuildHist(float_builder); } TEST(GpuHist, BuildHistSharedMem) { - SharedMemHistBuilder builder; - TestBuildHist(builder); + SharedMemHistBuilder double_builder; + TestBuildHist(double_builder); + SharedMemHistBuilder float_builder; + TestBuildHist(float_builder); } common::HistCutMatrix GetHostCutMatrix () { @@ -236,7 +243,7 @@ TEST(GpuHist, EvaluateSplits) { int max_bins = 4; // Initialize DeviceShard - std::unique_ptr shard {new DeviceShard(0, 0, n_rows, param)}; + std::unique_ptr> shard {new DeviceShard(0, 0, n_rows, param)}; // Initialize DeviceShard::node_sum_gradients shard->node_sum_gradients = {{6.4f, 12.8f}}; @@ -244,7 +251,7 @@ TEST(GpuHist, EvaluateSplits) { common::HistCutMatrix cmat = GetHostCutMatrix(); // Copy cut matrix to device. - DeviceShard::DeviceHistCutMatrix cut; + DeviceShard::DeviceHistCutMatrix cut; shard->ba.Allocate(0, true, &(shard->cut_.feature_segments), cmat.row_ptr.size(), &(shard->cut_.min_fvalue), cmat.min_val.size(), @@ -271,9 +278,9 @@ TEST(GpuHist, EvaluateSplits) { thrust::copy(hist.begin(), hist.end(), shard->hist.data.begin()); - // Initialize GPUHistMaker - GPUHistMaker hist_maker = GPUHistMaker(); + GPUHistMakerSpecialised hist_maker = + GPUHistMakerSpecialised(); hist_maker.param_ = param; hist_maker.shards_.push_back(std::move(shard)); hist_maker.column_sampler_.Init(n_cols, @@ -301,7 +308,8 @@ TEST(GpuHist, EvaluateSplits) { } TEST(GpuHist, ApplySplit) { - GPUHistMaker hist_maker = GPUHistMaker(); + GPUHistMakerSpecialised hist_maker = + GPUHistMakerSpecialised(); int constexpr nid = 0; int constexpr n_rows = 16; int constexpr n_cols = 8; @@ -315,7 +323,7 @@ TEST(GpuHist, ApplySplit) { } hist_maker.shards_.resize(1); - hist_maker.shards_[0].reset(new DeviceShard(0, 0, n_rows, param)); + hist_maker.shards_[0].reset(new DeviceShard(0, 0, n_rows, param)); auto& shard = hist_maker.shards_.at(0); shard->ridx_segments.resize(3); // 3 nodes. @@ -337,7 +345,7 @@ TEST(GpuHist, ApplySplit) { 0.59, 4, // fvalue has to be equal to one of the cut field GradientPair(8.2, 2.8), GradientPair(6.3, 3.6), GPUTrainingParam(param)); - GPUHistMaker::ExpandEntry candidate_entry {0, 0, candidate, 0}; + GPUHistMakerSpecialised::ExpandEntry candidate_entry {0, 0, candidate, 0}; candidate_entry.nid = nid; auto const& nodes = tree.GetNodes(); diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 8c4cb1cda..4cb683a88 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -31,12 +31,14 @@ class TestGPU(unittest.TestCase): assert_gpu_results(cpu_results, gpu_results) def test_gpu_hist(self): - variable_param = {'n_gpus': [-1], 'max_depth': [2, 8], - 'max_leaves': [255, 4], - 'max_bin': [2, 256], 'min_child_weight': [0, 1], - 'lambda': [0.0, 1.0], - 'grow_policy': ['lossguide']} - for param in parameter_combinations(variable_param): + test_param = parameter_combinations({'n_gpus': [1], 'max_depth': [2, 8], + 'max_leaves': [255, 4], + 'max_bin': [2, 256], + 'grow_policy': ['lossguide']}) + test_param.append({'single_precision_histogram': True}) + test_param.append({'min_child_weight': 0, + 'lambda': 0}) + for param in test_param: param['tree_method'] = 'gpu_hist' gpu_results = run_suite(param, select_datasets=datasets) assert_results_non_increasing(gpu_results, 1e-2)