Single precision histograms on GPU (#3965)

* Allow single precision histogram summation in gpu_hist

* Add python test, reduce run-time of gpu_hist tests

* Update documentation
This commit is contained in:
Rory Mitchell 2018-12-10 10:55:30 +13:00 committed by GitHub
parent 9af6b689d6
commit 93f9ce9ef9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 351 additions and 212 deletions

View File

@ -37,30 +37,34 @@ Supported parameters
.. |tick| unicode:: U+2714
.. |cross| unicode:: U+2718
+--------------------------+---------------+--------------+
+--------------------------------+---------------+--------------+
| parameter | ``gpu_exact`` | ``gpu_hist`` |
+==========================+===============+==============+
+================================+===============+==============+
| ``subsample`` | |cross| | |tick| |
+--------------------------+---------------+--------------+
+--------------------------------+---------------+--------------+
| ``colsample_bytree`` | |cross| | |tick| |
+--------------------------+---------------+--------------+
+--------------------------------+---------------+--------------+
| ``colsample_bylevel`` | |cross| | |tick| |
+--------------------------+---------------+--------------+
+--------------------------------+---------------+--------------+
| ``max_bin`` | |cross| | |tick| |
+--------------------------+---------------+--------------+
+--------------------------------+---------------+--------------+
| ``gpu_id`` | |tick| | |tick| |
+--------------------------+---------------+--------------+
+--------------------------------+---------------+--------------+
| ``n_gpus`` | |cross| | |tick| |
+--------------------------+---------------+--------------+
+--------------------------------+---------------+--------------+
| ``predictor`` | |tick| | |tick| |
+--------------------------+---------------+--------------+
+--------------------------------+---------------+--------------+
| ``grow_policy`` | |cross| | |tick| |
+--------------------------+---------------+--------------+
+--------------------------------+---------------+--------------+
| ``monotone_constraints`` | |cross| | |tick| |
+--------------------------+---------------+--------------+
+--------------------------------+---------------+--------------+
| ``single_precision_histogram`` | |cross| | |tick| |
+--------------------------------+---------------+--------------+
GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``.
The experimental parameter ``single_precision_histogram`` can be set to True to enable building histograms using single precision. This may improve speed, in particular on older architectures.
The device ordinal can be selected using the ``gpu_id`` parameter, which defaults to 0.
Multiple GPUs can be used with the ``gpu_hist`` tree method using the ``n_gpus`` parameter. which defaults to 1. If this is set to -1 all available GPUs will be used. If ``gpu_id`` is specified as non-zero, the selected gpu devices will be from ``gpu_id`` to ``gpu_id+n_gpus``, please note that ``gpu_id+n_gpus`` must be less than or equal to the number of available GPUs on your system. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU due to PCI bus bandwidth that can limit performance.
@ -121,6 +125,52 @@ For multi-gpu support, objective functions also honor the ``n_gpus`` parameter,
which, by default is set to 1. To disable running objectives on GPU, just set
``n_gpus`` to 0.
Metric functions
===================
Following table shows current support status for evaluation metrics on the GPU.
.. |tick| unicode:: U+2714
.. |cross| unicode:: U+2718
+-----------------+-------------+
| Metric | GPU Support |
+=================+=============+
| rmse | |tick| |
+-----------------+-------------+
| mae | |tick| |
+-----------------+-------------+
| logloss | |tick| |
+-----------------+-------------+
| error | |tick| |
+-----------------+-------------+
| merror | |cross| |
+-----------------+-------------+
| mlogloss | |cross| |
+-----------------+-------------+
| auc | |cross| |
+-----------------+-------------+
| aucpr | |cross| |
+-----------------+-------------+
| ndcg | |cross| |
+-----------------+-------------+
| map | |cross| |
+-----------------+-------------+
| poisson-nloglik | |tick| |
+-----------------+-------------+
| gamma-nloglik | |tick| |
+-----------------+-------------+
| cox-nloglik | |cross| |
+-----------------+-------------+
| gamma-deviance | |tick| |
+-----------------+-------------+
| tweedie-nloglik | |tick| |
+-----------------+-------------+
As for objective functions, metrics honor the ``n_gpus`` parameter,
which, by default is set to 1. To disable running metrics on GPU, just set
``n_gpus`` to 0.
Benchmarks
==========
You can run benchmarks on synthetic data for binary classification:
@ -152,12 +202,15 @@ References
`Nvidia Parallel Forall: Gradient Boosting, Decision Trees and XGBoost with CUDA <https://devblogs.nvidia.com/parallelforall/gradient-boosting-decision-trees-xgboost-cuda/>`_
Authors
Contributors
=======
* Rory Mitchell
Many thanks to the following contributors (alphabetical order):
* Andrey Adinets
* Jiaming Yuan
* Jonathan C. McKinney
* Philip Cho
* Rory Mitchell
* Shankara Rao Thejaswi Nanditale
* Vinay Deshpande
* ... and the rest of the H2O.ai and NVIDIA team.
Please report bugs to the user forum https://discuss.xgboost.ai/.

View File

@ -944,6 +944,32 @@ class AllReducer {
#endif
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param communication_group_idx Zero-based index of the communication group.
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(int communication_group_idx, const float *sendbuff,
float *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum,
comms.at(communication_group_idx),
streams.at(communication_group_idx)));
if(communication_group_idx == 0)
{
allreduce_bytes_ += count * sizeof(float);
allreduce_calls_ += 1;
}
#endif
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
*

View File

@ -116,19 +116,19 @@ struct GPUSketcher {
n_rows_(row_end - row_begin), param_(std::move(param)) {
}
void Init(const SparsePage& row_batch, const MetaInfo& info) {
void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) {
num_cols_ = info.num_col_;
has_weights_ = info.weights_.Size() > 0;
// find the batch size
if (param_.gpu_batch_nrows == 0) {
if (gpu_batch_nrows == 0) {
// By default, use no more than 1/16th of GPU memory
gpu_batch_nrows_ = dh::TotalMemory(device_) /
(16 * num_cols_ * sizeof(Entry));
} else if (param_.gpu_batch_nrows == -1) {
} else if (gpu_batch_nrows == -1) {
gpu_batch_nrows_ = n_rows_;
} else {
gpu_batch_nrows_ = param_.gpu_batch_nrows;
gpu_batch_nrows_ = gpu_batch_nrows;
}
if (gpu_batch_nrows_ > n_rows_) {
gpu_batch_nrows_ = n_rows_;
@ -346,7 +346,8 @@ struct GPUSketcher {
}
};
void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) {
void Sketch(const SparsePage& batch, const MetaInfo& info,
HistCutMatrix* hmat, int gpu_batch_nrows) {
// create device shards
shards_.resize(dist_.Devices().Size());
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
@ -358,8 +359,9 @@ struct GPUSketcher {
});
// compute sketches for each shard
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
shard->Init(batch, info);
dh::ExecuteIndexShards(&shards_,
[&](int idx, std::unique_ptr<DeviceShard>& shard) {
shard->Init(batch, info, gpu_batch_nrows);
shard->Sketch(batch, info);
});
@ -390,9 +392,9 @@ struct GPUSketcher {
void DeviceSketch
(const SparsePage& batch, const MetaInfo& info,
const tree::TrainParam& param, HistCutMatrix* hmat) {
const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows) {
GPUSketcher sketcher(param, info.num_row_);
sketcher.Sketch(batch, info, hmat);
sketcher.Sketch(batch, info, hmat, gpu_batch_nrows);
}
} // namespace common

View File

@ -72,7 +72,7 @@ struct HistCutMatrix {
/*! \brief Builds the cut matrix on the GPU */
void DeviceSketch
(const SparsePage& batch, const MetaInfo& info,
const tree::TrainParam& param, HistCutMatrix* hmat);
const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows);
/*!
* \brief A single row in global histogram index.

View File

@ -77,8 +77,6 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
int gpu_id;
// number of GPUs to use
int n_gpus;
// number of rows in a single GPU batch
int gpu_batch_nrows;
// the criteria to use for ranking splits
std::string split_evaluator;
@ -205,11 +203,6 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
.set_lower_bound(-1)
.set_default(1)
.describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
DMLC_DECLARE_FIELD(gpu_batch_nrows)
.set_lower_bound(-1)
.set_default(0)
.describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
"-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
DMLC_DECLARE_FIELD(split_evaluator)
.set_default("elastic_net,monotonic,interaction")
.describe("The criteria to use for ranking splits");

View File

@ -36,50 +36,16 @@ XGBOOST_DEVICE __forceinline__ double atomicAdd(double* address, double val) {
namespace xgboost {
namespace tree {
// Atomic add function for double precision gradients
__device__ __forceinline__ void AtomicAddGpair(GradientPairPrecise* dest,
const GradientPair& gpair) {
auto dst_ptr = reinterpret_cast<double*>(dest);
atomicAdd(dst_ptr, static_cast<double>(gpair.GetGrad()));
atomicAdd(dst_ptr + 1, static_cast<double>(gpair.GetHess()));
}
// used by shared-memory atomics code
__device__ __forceinline__ void AtomicAddGpair(GradientPairPrecise* dest,
const GradientPairPrecise& gpair) {
auto dst_ptr = reinterpret_cast<double*>(dest);
atomicAdd(dst_ptr, gpair.GetGrad());
atomicAdd(dst_ptr + 1, gpair.GetHess());
}
// For integer gradients
__device__ __forceinline__ void AtomicAddGpair(GradientPairInteger* dest,
const GradientPair& gpair) {
auto dst_ptr = reinterpret_cast<unsigned long long int*>(dest); // NOLINT
GradientPairInteger tmp(gpair.GetGrad(), gpair.GetHess());
auto src_ptr = reinterpret_cast<GradientPairInteger::ValueT*>(&tmp);
// Atomic add function for gradients
template <typename OutputGradientT, typename InputGradientT>
DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
const InputGradientT& gpair) {
auto dst_ptr = reinterpret_cast<typename OutputGradientT::ValueT*>(dest);
atomicAdd(dst_ptr,
static_cast<unsigned long long int>(*src_ptr)); // NOLINT
static_cast<typename OutputGradientT::ValueT>(gpair.GetGrad()));
atomicAdd(dst_ptr + 1,
static_cast<unsigned long long int>(*(src_ptr + 1))); // NOLINT
}
/**
* \brief Check maximum gradient value is below 2^16. This is to prevent
* overflow when using integer gradient summation.
*/
inline void CheckGradientMax(const std::vector<GradientPair>& gpair) {
auto* ptr = reinterpret_cast<const float*>(gpair.data());
float abs_max =
std::accumulate(ptr, ptr + (gpair.size() * 2), 0.f,
[=](float a, float b) {
return std::max(abs(a), abs(b)); });
CHECK_LT(abs_max, std::pow(2.0f, 16.0f))
<< "Labels are too large for this algorithm. Rescale to less than 2^16.";
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
}
struct GPUTrainingParam {

View File

@ -30,7 +30,25 @@ namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
using GradientPairSumT = GradientPairPrecise;
// training parameters specific to this algorithm
struct GPUHistMakerTrainParam
: public dmlc::Parameter<GPUHistMakerTrainParam> {
bool single_precision_histogram;
// number of rows in a single GPU batch
int gpu_batch_nrows;
// declare parameters
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
"Use single precision to build histograms.");
DMLC_DECLARE_FIELD(gpu_batch_nrows)
.set_lower_bound(-1)
.set_default(0)
.describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
"-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
}
};
DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
/*!
* \brief
@ -42,20 +60,20 @@ using GradientPairSumT = GradientPairPrecise;
* \param end
* \param temp_storage Shared memory for intermediate result.
*/
template <int BLOCK_THREADS, typename ReduceT, typename TempStorageT>
__device__ GradientPairSumT ReduceFeature(common::Span<const GradientPairSumT> feature_histogram,
template <int BLOCK_THREADS, typename ReduceT, typename TempStorageT, typename GradientSumT>
__device__ GradientSumT ReduceFeature(common::Span<const GradientSumT> feature_histogram,
TempStorageT* temp_storage) {
__shared__ cub::Uninitialized<GradientPairSumT> uninitialized_sum;
GradientPairSumT& shared_sum = uninitialized_sum.Alias();
__shared__ cub::Uninitialized<GradientSumT> uninitialized_sum;
GradientSumT& shared_sum = uninitialized_sum.Alias();
GradientPairSumT local_sum = GradientPairSumT();
GradientSumT local_sum = GradientSumT();
// For loop sums features into one block size
auto begin = feature_histogram.data();
auto end = begin + feature_histogram.size();
for (auto itr = begin; itr < end; itr += BLOCK_THREADS) {
bool thread_active = itr + threadIdx.x < end;
// Scan histogram
GradientPairSumT bin = thread_active ? *(itr + threadIdx.x) : GradientPairSumT();
GradientSumT bin = thread_active ? *(itr + threadIdx.x) : GradientSumT();
local_sum += bin;
}
local_sum = ReduceT(temp_storage->sum_reduce).Reduce(local_sum, cub::Sum());
@ -69,10 +87,10 @@ __device__ GradientPairSumT ReduceFeature(common::Span<const GradientPairSumT> f
/*! \brief Find the thread with best gain. */
template <int BLOCK_THREADS, typename ReduceT, typename scan_t,
typename max_ReduceT, typename TempStorageT>
typename MaxReduceT, typename TempStorageT, typename GradientSumT>
__device__ void EvaluateFeature(
int fidx,
common::Span<const GradientPairSumT> node_histogram,
common::Span<const GradientSumT> node_histogram,
common::Span<const uint32_t> feature_segments, // cut.row_ptr
float min_fvalue, // cut.min_value
common::Span<const float> gidx_fvalue_map, // cut.cut
@ -86,22 +104,22 @@ __device__ void EvaluateFeature(
uint32_t gidx_end = feature_segments[fidx + 1]; // end bin for i^th feature
// Sum histogram bins for current feature
GradientPairSumT const feature_sum = ReduceFeature<BLOCK_THREADS, ReduceT>(
GradientSumT const feature_sum = ReduceFeature<BLOCK_THREADS, ReduceT>(
node_histogram.subspan(gidx_begin, gidx_end - gidx_begin), temp_storage);
GradientPairSumT const parent_sum = GradientPairSumT(node.sum_gradients);
GradientPairSumT const missing = parent_sum - feature_sum;
GradientSumT const parent_sum = GradientSumT(node.sum_gradients);
GradientSumT const missing = parent_sum - feature_sum;
float const null_gain = -std::numeric_limits<bst_float>::infinity();
SumCallbackOp<GradientPairSumT> prefix_op =
SumCallbackOp<GradientPairSumT>();
SumCallbackOp<GradientSumT> prefix_op =
SumCallbackOp<GradientSumT>();
for (int scan_begin = gidx_begin; scan_begin < gidx_end;
scan_begin += BLOCK_THREADS) {
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
// Gradient value for current bin.
GradientPairSumT bin =
thread_active ? node_histogram[scan_begin + threadIdx.x] : GradientPairSumT();
GradientSumT bin =
thread_active ? node_histogram[scan_begin + threadIdx.x] : GradientSumT();
scan_t(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
// Whether the gradient of missing values is put to the left side.
@ -117,7 +135,7 @@ __device__ void EvaluateFeature(
// Find thread with best gain
cub::KeyValuePair<int, float> tuple(threadIdx.x, gain);
cub::KeyValuePair<int, float> best =
max_ReduceT(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax());
MaxReduceT(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax());
__shared__ cub::KeyValuePair<int, float> block_max;
if (threadIdx.x == 0) {
@ -131,8 +149,8 @@ __device__ void EvaluateFeature(
int gidx = scan_begin + threadIdx.x;
float fvalue =
gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1];
GradientPairSumT left = missing_left ? bin + missing : bin;
GradientPairSumT right = parent_sum - left;
GradientSumT left = missing_left ? bin + missing : bin;
GradientSumT right = parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir,
fvalue, fidx,
GradientPair(left),
@ -143,9 +161,9 @@ __device__ void EvaluateFeature(
}
}
template <int BLOCK_THREADS>
template <int BLOCK_THREADS, typename GradientSumT>
__global__ void EvaluateSplitKernel(
common::Span<const GradientPairSumT>
common::Span<const GradientSumT>
node_histogram, // histogram for gradients
common::Span<const int> feature_set, // Selected features
DeviceNodeStats node,
@ -160,10 +178,10 @@ __global__ void EvaluateSplitKernel(
// KeyValuePair here used as threadIdx.x -> gain_value
typedef cub::KeyValuePair<int, float> ArgMaxT;
typedef cub::BlockScan<
GradientPairSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS> BlockScanT;
GradientSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS> BlockScanT;
typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
typedef cub::BlockReduce<GradientPairSumT, BLOCK_THREADS> SumReduceT;
typedef cub::BlockReduce<GradientSumT, BLOCK_THREADS> SumReduceT;
union TempStorage {
typename BlockScanT::TempStorage scan;
@ -233,10 +251,11 @@ __device__ int BinarySearchRow(bst_uint begin, bst_uint end, GidxIterT data,
* \author Rory
* \date 28/07/2018
*/
template <typename GradientSumT>
struct DeviceHistogram {
/*! \brief Map nidx to starting index of its histogram. */
std::map<int, size_t> nidx_map;
thrust::device_vector<GradientPairSumT::ValueT> data;
thrust::device_vector<typename GradientSumT::ValueT> data;
const size_t kStopGrowingSize = 1 << 26; // Do not grow beyond this size
int n_bins;
int device_id_;
@ -264,7 +283,7 @@ struct DeviceHistogram {
std::pair<int, size_t> old_entry = *nidx_map.begin();
nidx_map.erase(old_entry.first);
dh::safe_cuda(cudaMemset(data.data().get() + old_entry.second, 0,
n_bins * sizeof(GradientPairSumT)));
n_bins * sizeof(GradientSumT)));
nidx_map[nidx] = old_entry.second;
} else {
// Append new node histogram
@ -280,11 +299,11 @@ struct DeviceHistogram {
* \param nidx Tree node index.
* \return hist pointer.
*/
common::Span<GradientPairSumT> GetNodeHistogram(int nidx) {
common::Span<GradientSumT> GetNodeHistogram(int nidx) {
CHECK(this->HistogramExists(nidx));
auto ptr = data.data().get() + nidx_map[nidx];
return common::Span<GradientPairSumT>(
reinterpret_cast<GradientPairSumT*>(ptr), n_bins);
return common::Span<GradientSumT>(
reinterpret_cast<GradientSumT*>(ptr), n_bins);
}
};
@ -341,18 +360,17 @@ __global__ void compress_bin_ellpack_k(
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
}
__global__ void sharedMemHistKernel(size_t row_stride,
const bst_uint* d_ridx,
template <typename GradientSumT>
__global__ void SharedMemHistKernel(size_t row_stride, const bst_uint* d_ridx,
common::CompressedIterator<uint32_t> d_gidx,
int null_gidx_value,
GradientPairSumT* d_node_hist,
GradientSumT* d_node_hist,
const GradientPair* d_gpair,
size_t segment_begin,
size_t n_elements) {
size_t segment_begin, size_t n_elements) {
extern __shared__ char smem[];
GradientPairSumT* smem_arr = reinterpret_cast<GradientPairSumT*>(smem); // NOLINT
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
for (auto i : dh::BlockStrideRange(0, null_gidx_value)) {
smem_arr[i] = GradientPairSumT();
smem_arr[i] = GradientSumT();
}
__syncthreads();
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
@ -427,15 +445,18 @@ void SortPosition(dh::CubMemory* temp_memory, common::Span<int> position,
out_itr, position.size());
}
template <typename GradientSumT>
struct DeviceShard;
template <typename GradientSumT>
struct GPUHistBuilderBase {
public:
virtual void Build(DeviceShard* shard, int idx) = 0;
virtual void Build(DeviceShard<GradientSumT>* shard, int idx) = 0;
virtual ~GPUHistBuilderBase() = default;
};
// Manage memory for a single GPU
template <typename GradientSumT>
struct DeviceShard {
int device_id_;
dh::BulkAllocator<dh::MemoryType::kDevice> ba;
@ -452,7 +473,7 @@ struct DeviceShard {
/*! \brief Range of rows for each node. */
std::vector<Segment> ridx_segments;
DeviceHistogram hist;
DeviceHistogram<GradientSumT> hist;
/*! \brief global index of histogram, which is stored in ELLPack format. */
dh::DVec<common::CompressedByteT> gidx_buffer;
@ -489,7 +510,7 @@ struct DeviceShard {
dh::CubMemory temp_memory;
std::unique_ptr<GPUHistBuilderBase> hist_builder;
std::unique_ptr<GPUHistBuilderBase<GradientSumT>> hist_builder;
// TODO(canonizer): do add support multi-batch DMatrix here
DeviceShard(int device_id, bst_uint row_begin, bst_uint row_end,
@ -567,7 +588,7 @@ struct DeviceShard {
// One block for each feature
int constexpr BLOCK_THREADS = 256;
EvaluateSplitKernel<BLOCK_THREADS>
EvaluateSplitKernel<BLOCK_THREADS, GradientSumT>
<<<uint32_t(feature_set.Size()), BLOCK_THREADS, 0>>>(
hist.GetNodeHistogram(nidx), feature_set.DeviceSpan(device_id_), node,
cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
@ -719,8 +740,9 @@ struct DeviceShard {
}
};
struct SharedMemHistBuilder : public GPUHistBuilderBase {
void Build(DeviceShard* shard, int nidx) override {
template <typename GradientSumT>
struct SharedMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
auto segment = shard->ridx_segments[nidx];
auto segment_begin = segment.begin;
auto d_node_hist = shard->hist.GetNodeHistogram(nidx);
@ -731,7 +753,7 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase {
int null_gidx_value = shard->null_gidx_value;
auto n_elements = segment.Size() * shard->row_stride;
const size_t smem_size = sizeof(GradientPairSumT) * shard->null_gidx_value;
const size_t smem_size = sizeof(GradientSumT) * shard->null_gidx_value;
const int items_per_thread = 8;
const int block_threads = 256;
const int grid_size =
@ -741,14 +763,15 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase {
return;
}
dh::safe_cuda(cudaSetDevice(shard->device_id_));
sharedMemHistKernel<<<grid_size, block_threads, smem_size>>>
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>
(shard->row_stride, d_ridx, d_gidx, null_gidx_value, d_node_hist.data(), d_gpair,
segment_begin, n_elements);
}
};
struct GlobalMemHistBuilder : public GPUHistBuilderBase {
void Build(DeviceShard* shard, int nidx) override {
template <typename GradientSumT>
struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
Segment segment = shard->ridx_segments[nidx];
auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
common::CompressedIterator<uint32_t> d_gidx = shard->gidx;
@ -771,7 +794,8 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase {
}
};
inline void DeviceShard::InitCompressedData(
template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::InitCompressedData(
const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
n_bins = hmat.row_ptr.back();
null_gidx_value = hmat.row_ptr.back();
@ -820,19 +844,21 @@ inline void DeviceShard::InitCompressedData(
// check if we can use shared memory for building histograms
// (assuming atleast we need 2 CTAs per SM to maintain decent latency hiding)
auto histogram_size = sizeof(GradientPairSumT) * null_gidx_value;
auto histogram_size = sizeof(GradientSumT) * null_gidx_value;
auto max_smem = dh::MaxSharedMemory(device_id_);
if (histogram_size <= max_smem) {
hist_builder.reset(new SharedMemHistBuilder);
hist_builder.reset(new SharedMemHistBuilder<GradientSumT>);
} else {
hist_builder.reset(new GlobalMemHistBuilder);
hist_builder.reset(new GlobalMemHistBuilder<GradientSumT>);
}
// Init histogram
hist.Init(device_id_, hmat.row_ptr.back());
}
inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) {
template <typename GradientSumT>
inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_batch) {
int num_symbols = n_bins + 1;
// bin and compress entries in batches of rows
size_t gpu_batch_nrows =
@ -882,14 +908,17 @@ inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) {
entries_d.shrink_to_fit();
}
class GPUHistMaker : public TreeUpdater {
template <typename GradientSumT>
class GPUHistMakerSpecialised{
public:
struct ExpandEntry;
GPUHistMaker() : initialised_(false), p_last_fmat_(nullptr) {}
GPUHistMakerSpecialised() : initialised_(false), p_last_fmat_(nullptr) {}
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override {
const std::vector<std::pair<std::string, std::string>>& args) {
param_.InitAllowUnknown(args);
hist_maker_param_.InitAllowUnknown(args);
CHECK(param_.n_gpus != 0) << "Must have at least one device";
n_devices_ = param_.n_gpus;
dist_ = GPUDistribution::Block(GPUSet::All(param_.gpu_id, param_.n_gpus));
@ -906,7 +935,7 @@ class GPUHistMaker : public TreeUpdater {
}
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
const std::vector<RegTree*>& trees) {
monitor_.Start("Update", dist_.Devices());
GradStats::CheckInfo(dmat->Info());
// rescale learning rate according to size of trees
@ -943,23 +972,24 @@ class GPUHistMaker : public TreeUpdater {
const SparsePage& batch = *batch_iter;
// Create device shards
shards_.resize(n_devices);
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
size_t start = dist_.ShardStart(info_->num_row_, i);
size_t size = dist_.ShardSize(info_->num_row_, i);
shard = std::unique_ptr<DeviceShard>
(new DeviceShard(dist_.Devices().DeviceId(i),
shard = std::unique_ptr<DeviceShard<GradientSumT>>
(new DeviceShard<GradientSumT>(dist_.Devices().DeviceId(i),
start, start + size, param_));
shard->InitRowPtrs(batch);
});
// Find the cuts.
monitor_.Start("Quantiles", dist_.Devices());
common::DeviceSketch(batch, *info_, param_, &hmat_);
common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows);
n_bins_ = hmat_.row_ptr.back();
monitor_.Stop("Quantiles", dist_.Devices());
monitor_.Start("BinningCompression", dist_.Devices());
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(&shards_, [&](int idx,
std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->InitCompressedData(hmat_, batch);
});
monitor_.Stop("BinningCompression", dist_.Devices());
@ -983,7 +1013,9 @@ class GPUHistMaker : public TreeUpdater {
monitor_.Start("InitDataReset", dist_.Devices());
gpair->Reshard(dist_);
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->Reset(gpair);
});
monitor_.Stop("InitDataReset", dist_.Devices());
@ -993,13 +1025,16 @@ class GPUHistMaker : public TreeUpdater {
if (shards_.size() == 1) return;
monitor_.Start("AllReduce");
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
reducer_.AllReduceSum(
dist_.Devices().Index(shard->device_id_),
reinterpret_cast<GradientPairSumT::ValueT*>(d_node_hist),
reinterpret_cast<GradientPairSumT::ValueT*>(d_node_hist),
n_bins_ * (sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT)));
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
n_bins_ * (sizeof(GradientSumT) /
sizeof(typename GradientSumT::ValueT)));
});
monitor_.Stop("AllReduce");
}
@ -1026,7 +1061,9 @@ class GPUHistMaker : public TreeUpdater {
}
// Build histogram for node with the smallest number of training examples
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->BuildHist(build_hist_nidx);
});
@ -1041,13 +1078,17 @@ class GPUHistMaker : public TreeUpdater {
if (do_subtraction_trick) {
// Calculate other histogram using subtraction trick
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->SubtractionTrick(nidx_parent, build_hist_nidx,
subtraction_trick_nidx);
});
} else {
// Calculate other histogram manually
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->BuildHist(subtraction_trick_nidx);
});
@ -1066,17 +1107,20 @@ class GPUHistMaker : public TreeUpdater {
// Sum gradients
std::vector<GradientPair> tmp_sums(shards_.size());
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(
&shards_,
[&](int i, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id_));
tmp_sums[i] =
dh::SumReduction(shard->temp_memory, shard->gpair.Data(),
shard->gpair.Size());
tmp_sums[i] = dh::SumReduction(
shard->temp_memory, shard->gpair.Data(), shard->gpair.Size());
});
GradientPair sum_gradient =
std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
// Generate root histogram
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->BuildHist(root_nidx);
});
@ -1122,10 +1166,12 @@ class GPUHistMaker : public TreeUpdater {
}
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx,
split_gidx, default_dir_left,
is_dense, fidx_begin, fidx_end);
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx, split_gidx,
default_dir_left, is_dense, fidx_begin,
fidx_end);
});
}
@ -1223,12 +1269,14 @@ class GPUHistMaker : public TreeUpdater {
}
bool UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
monitor_.Start("UpdatePredictionCache", dist_.Devices());
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
return false;
p_out_preds->Reshard(dist_.Devices());
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->UpdatePredictionCache(
p_out_preds->DevicePointer(shard->device_id_));
});
@ -1286,6 +1334,7 @@ class GPUHistMaker : public TreeUpdater {
}
}
TrainParam param_;
GPUHistMakerTrainParam hist_maker_param_;
common::HistCutMatrix hmat_;
common::GHistIndexMatrix gmat_;
MetaInfo* info_;
@ -1293,7 +1342,7 @@ class GPUHistMaker : public TreeUpdater {
int n_devices_;
int n_bins_;
std::vector<std::unique_ptr<DeviceShard>> shards_;
std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> shards_;
common::ColumnSampler column_sampler_;
using ExpandQueue = std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>;
@ -1308,6 +1357,46 @@ class GPUHistMaker : public TreeUpdater {
GPUDistribution dist_;
};
class GPUHistMaker : public TreeUpdater {
public:
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override {
hist_maker_param_.InitAllowUnknown(args);
float_maker_.reset();
double_maker_.reset();
if (hist_maker_param_.single_precision_histogram) {
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
float_maker_->Init(args);
} else {
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
double_maker_->Init(args);
}
}
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
if (hist_maker_param_.single_precision_histogram) {
float_maker_->Update(gpair, dmat, trees);
} else {
double_maker_->Update(gpair, dmat, trees);
}
}
bool UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
if (hist_maker_param_.single_precision_histogram) {
return float_maker_->UpdatePredictionCache(data, p_out_preds);
} else {
return double_maker_->UpdatePredictionCache(data, p_out_preds);
}
}
private:
GPUHistMakerTrainParam hist_maker_param_;
std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
};
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUHistMaker(); });

View File

@ -30,7 +30,7 @@ void TestDeviceSketch(const GPUSet& devices) {
p.gpu_id = 0;
p.n_gpus = devices.Size();
// ensure that the exact quantiles are found
p.gpu_batch_nrows = nrows * 10;
int gpu_batch_nrows = nrows * 10;
// find quantiles on the CPU
HistCutMatrix hmat_cpu;
@ -39,7 +39,7 @@ void TestDeviceSketch(const GPUSet& devices) {
// find the cuts on the GPU
const SparsePage& batch = *(*dmat)->GetRowBatches().begin();
HistCutMatrix hmat_gpu;
DeviceSketch(batch, (*dmat)->Info(), p, &hmat_gpu);
DeviceSketch(batch, (*dmat)->Info(), p, &hmat_gpu, gpu_batch_nrows);
// compare the cuts
double eps = 1e-2;

View File

@ -17,7 +17,8 @@
namespace xgboost {
namespace tree {
void BuildGidx(DeviceShard* shard, int n_rows, int n_cols,
template <typename GradientSumT>
void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,
bst_float sparsity=0) {
auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3);
const SparsePage& batch = *(*dmat)->GetRowBatches().begin();
@ -48,7 +49,7 @@ TEST(GpuHist, BuildGidxDense) {
param.n_gpus = 1;
param.max_leaves = 0;
DeviceShard shard(0, 0, n_rows, param);
DeviceShard<GradientPairPrecise> shard(0, 0, n_rows, param);
BuildGidx(&shard, n_rows, n_cols);
std::vector<common::CompressedByteT> h_gidx_buffer;
@ -87,7 +88,7 @@ TEST(GpuHist, BuildGidxSparse) {
param.n_gpus = 1;
param.max_leaves = 0;
DeviceShard shard(0, 0, n_rows, param);
DeviceShard<GradientPairPrecise> shard(0, 0, n_rows, param);
BuildGidx(&shard, n_rows, n_cols, 0.9f);
std::vector<common::CompressedByteT> h_gidx_buffer;
@ -122,7 +123,8 @@ std::vector<GradientPairPrecise> GetHostHistGpair() {
return hist_gpair;
}
void TestBuildHist(GPUHistBuilderBase& builder) {
template <typename GradientSumT>
void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
int const n_rows = 16, n_cols = 8;
TrainParam param;
@ -130,7 +132,7 @@ void TestBuildHist(GPUHistBuilderBase& builder) {
param.n_gpus = 1;
param.max_leaves = 0;
DeviceShard shard(0, 0, n_rows, param);
DeviceShard<GradientSumT> shard(0, 0, n_rows, param);
BuildGidx(&shard, n_rows, n_cols);
@ -166,13 +168,14 @@ void TestBuildHist(GPUHistBuilderBase& builder) {
shard.ridx.CurrentDVec().tend());
builder.Build(&shard, 0);
DeviceHistogram d_hist = shard.hist;
DeviceHistogram<GradientSumT> d_hist = shard.hist;
auto node_histogram = d_hist.GetNodeHistogram(0);
// d_hist.data stored in float, not gradient pair
thrust::host_vector<GradientPairSumT> h_result (d_hist.data.size()/2);
size_t data_size = sizeof(GradientPairSumT) / (
sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT));
thrust::host_vector<GradientSumT> h_result (d_hist.data.size()/2);
size_t data_size =
sizeof(GradientSumT) /
(sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT));
data_size *= d_hist.data.size();
dh::safe_cuda(cudaMemcpy(h_result.data(), node_histogram.data(), data_size,
cudaMemcpyDeviceToHost));
@ -186,13 +189,17 @@ void TestBuildHist(GPUHistBuilderBase& builder) {
}
TEST(GpuHist, BuildHistGlobalMem) {
GlobalMemHistBuilder builder;
TestBuildHist(builder);
GlobalMemHistBuilder<GradientPairPrecise> double_builder;
TestBuildHist(double_builder);
GlobalMemHistBuilder<GradientPair> float_builder;
TestBuildHist(float_builder);
}
TEST(GpuHist, BuildHistSharedMem) {
SharedMemHistBuilder builder;
TestBuildHist(builder);
SharedMemHistBuilder<GradientPairPrecise> double_builder;
TestBuildHist(double_builder);
SharedMemHistBuilder<GradientPair> float_builder;
TestBuildHist(float_builder);
}
common::HistCutMatrix GetHostCutMatrix () {
@ -236,7 +243,7 @@ TEST(GpuHist, EvaluateSplits) {
int max_bins = 4;
// Initialize DeviceShard
std::unique_ptr<DeviceShard> shard {new DeviceShard(0, 0, n_rows, param)};
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard {new DeviceShard<GradientPairPrecise>(0, 0, n_rows, param)};
// Initialize DeviceShard::node_sum_gradients
shard->node_sum_gradients = {{6.4f, 12.8f}};
@ -244,7 +251,7 @@ TEST(GpuHist, EvaluateSplits) {
common::HistCutMatrix cmat = GetHostCutMatrix();
// Copy cut matrix to device.
DeviceShard::DeviceHistCutMatrix cut;
DeviceShard<GradientPairPrecise>::DeviceHistCutMatrix cut;
shard->ba.Allocate(0, true,
&(shard->cut_.feature_segments), cmat.row_ptr.size(),
&(shard->cut_.min_fvalue), cmat.min_val.size(),
@ -271,9 +278,9 @@ TEST(GpuHist, EvaluateSplits) {
thrust::copy(hist.begin(), hist.end(),
shard->hist.data.begin());
// Initialize GPUHistMaker
GPUHistMaker hist_maker = GPUHistMaker();
GPUHistMakerSpecialised<GradientPairPrecise> hist_maker =
GPUHistMakerSpecialised<GradientPairPrecise>();
hist_maker.param_ = param;
hist_maker.shards_.push_back(std::move(shard));
hist_maker.column_sampler_.Init(n_cols,
@ -301,7 +308,8 @@ TEST(GpuHist, EvaluateSplits) {
}
TEST(GpuHist, ApplySplit) {
GPUHistMaker hist_maker = GPUHistMaker();
GPUHistMakerSpecialised<GradientPairPrecise> hist_maker =
GPUHistMakerSpecialised<GradientPairPrecise>();
int constexpr nid = 0;
int constexpr n_rows = 16;
int constexpr n_cols = 8;
@ -315,7 +323,7 @@ TEST(GpuHist, ApplySplit) {
}
hist_maker.shards_.resize(1);
hist_maker.shards_[0].reset(new DeviceShard(0, 0, n_rows, param));
hist_maker.shards_[0].reset(new DeviceShard<GradientPairPrecise>(0, 0, n_rows, param));
auto& shard = hist_maker.shards_.at(0);
shard->ridx_segments.resize(3); // 3 nodes.
@ -337,7 +345,7 @@ TEST(GpuHist, ApplySplit) {
0.59, 4, // fvalue has to be equal to one of the cut field
GradientPair(8.2, 2.8), GradientPair(6.3, 3.6),
GPUTrainingParam(param));
GPUHistMaker::ExpandEntry candidate_entry {0, 0, candidate, 0};
GPUHistMakerSpecialised<GradientPairPrecise>::ExpandEntry candidate_entry {0, 0, candidate, 0};
candidate_entry.nid = nid;
auto const& nodes = tree.GetNodes();

View File

@ -31,12 +31,14 @@ class TestGPU(unittest.TestCase):
assert_gpu_results(cpu_results, gpu_results)
def test_gpu_hist(self):
variable_param = {'n_gpus': [-1], 'max_depth': [2, 8],
test_param = parameter_combinations({'n_gpus': [1], 'max_depth': [2, 8],
'max_leaves': [255, 4],
'max_bin': [2, 256], 'min_child_weight': [0, 1],
'lambda': [0.0, 1.0],
'grow_policy': ['lossguide']}
for param in parameter_combinations(variable_param):
'max_bin': [2, 256],
'grow_policy': ['lossguide']})
test_param.append({'single_precision_histogram': True})
test_param.append({'min_child_weight': 0,
'lambda': 0})
for param in test_param:
param['tree_method'] = 'gpu_hist'
gpu_results = run_suite(param, select_datasets=datasets)
assert_results_non_increasing(gpu_results, 1e-2)