diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst
index e554386a4..94fadc469 100644
--- a/doc/gpu/index.rst
+++ b/doc/gpu/index.rst
@@ -37,30 +37,34 @@ Supported parameters
.. |tick| unicode:: U+2714
.. |cross| unicode:: U+2718
-+--------------------------+---------------+--------------+
-| parameter | ``gpu_exact`` | ``gpu_hist`` |
-+==========================+===============+==============+
-| ``subsample`` | |cross| | |tick| |
-+--------------------------+---------------+--------------+
-| ``colsample_bytree`` | |cross| | |tick| |
-+--------------------------+---------------+--------------+
-| ``colsample_bylevel`` | |cross| | |tick| |
-+--------------------------+---------------+--------------+
-| ``max_bin`` | |cross| | |tick| |
-+--------------------------+---------------+--------------+
-| ``gpu_id`` | |tick| | |tick| |
-+--------------------------+---------------+--------------+
-| ``n_gpus`` | |cross| | |tick| |
-+--------------------------+---------------+--------------+
-| ``predictor`` | |tick| | |tick| |
-+--------------------------+---------------+--------------+
-| ``grow_policy`` | |cross| | |tick| |
-+--------------------------+---------------+--------------+
-| ``monotone_constraints`` | |cross| | |tick| |
-+--------------------------+---------------+--------------+
++--------------------------------+---------------+--------------+
+| parameter | ``gpu_exact`` | ``gpu_hist`` |
++================================+===============+==============+
+| ``subsample`` | |cross| | |tick| |
++--------------------------------+---------------+--------------+
+| ``colsample_bytree`` | |cross| | |tick| |
++--------------------------------+---------------+--------------+
+| ``colsample_bylevel`` | |cross| | |tick| |
++--------------------------------+---------------+--------------+
+| ``max_bin`` | |cross| | |tick| |
++--------------------------------+---------------+--------------+
+| ``gpu_id`` | |tick| | |tick| |
++--------------------------------+---------------+--------------+
+| ``n_gpus`` | |cross| | |tick| |
++--------------------------------+---------------+--------------+
+| ``predictor`` | |tick| | |tick| |
++--------------------------------+---------------+--------------+
+| ``grow_policy`` | |cross| | |tick| |
++--------------------------------+---------------+--------------+
+| ``monotone_constraints`` | |cross| | |tick| |
++--------------------------------+---------------+--------------+
+| ``single_precision_histogram`` | |cross| | |tick| |
++--------------------------------+---------------+--------------+
GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``.
+The experimental parameter ``single_precision_histogram`` can be set to ``True`` to build histograms using single precision instead of double precision. This may improve speed, in particular on older architectures.
+
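+For example, a minimal sketch using the native Python interface (the synthetic
+data here is purely illustrative):
+
+.. code-block:: python
+
+  import numpy as np
+  import xgboost as xgb
+
+  # Small random regression problem, only used to illustrate the parameters.
+  X = np.random.randn(1000, 10)
+  y = np.random.randn(1000)
+  dtrain = xgb.DMatrix(X, label=y)
+
+  params = {
+      'tree_method': 'gpu_hist',
+      # Experimental: build histograms in single precision.
+      'single_precision_histogram': True,
+  }
+  bst = xgb.train(params, dtrain, num_boost_round=10)
+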
The device ordinal can be selected using the ``gpu_id`` parameter, which defaults to 0.
Multiple GPUs can be used with the ``gpu_hist`` tree method using the ``n_gpus`` parameter, which defaults to 1. If this is set to -1, all available GPUs will be used. If ``gpu_id`` is specified as non-zero, the selected GPU devices will range from ``gpu_id`` to ``gpu_id+n_gpus``; note that ``gpu_id+n_gpus`` must be less than or equal to the number of available GPUs on your system. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU because PCI bus bandwidth can limit performance.
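+
+For example, a sketch of a multi-GPU configuration using the parameters described
+above (this assumes a machine with more than one GPU, and reuses ``dtrain`` from
+the sketch above):
+
+.. code-block:: python
+
+  params = {
+      'tree_method': 'gpu_hist',
+      'gpu_id': 0,   # first device ordinal to use
+      'n_gpus': -1,  # use all available GPUs
+  }
+  bst = xgb.train(params, dtrain, num_boost_round=10)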
@@ -121,6 +125,52 @@ For multi-gpu support, objective functions also honor the ``n_gpus`` parameter,
which, by default is set to 1. To disable running objectives on GPU, just set
``n_gpus`` to 0.
+Metric functions
+===================
+The following table shows the current support status for evaluation metrics on the GPU.
+
+.. |tick| unicode:: U+2714
+.. |cross| unicode:: U+2718
+
++-----------------+-------------+
+| Metric | GPU Support |
++=================+=============+
+| rmse | |tick| |
++-----------------+-------------+
+| mae | |tick| |
++-----------------+-------------+
+| logloss | |tick| |
++-----------------+-------------+
+| error | |tick| |
++-----------------+-------------+
+| merror | |cross| |
++-----------------+-------------+
+| mlogloss | |cross| |
++-----------------+-------------+
+| auc | |cross| |
++-----------------+-------------+
+| aucpr | |cross| |
++-----------------+-------------+
+| ndcg | |cross| |
++-----------------+-------------+
+| map | |cross| |
++-----------------+-------------+
+| poisson-nloglik | |tick| |
++-----------------+-------------+
+| gamma-nloglik | |tick| |
++-----------------+-------------+
+| cox-nloglik | |cross| |
++-----------------+-------------+
+| gamma-deviance | |tick| |
++-----------------+-------------+
+| tweedie-nloglik | |tick| |
++-----------------+-------------+
+
+As with objective functions, metrics honor the ``n_gpus`` parameter, which
+defaults to 1. To disable running metrics on the GPU, set ``n_gpus`` to 0.
+
+
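+For example, a sketch that requests one of the GPU-supported metrics from the
+table above during training (``dtrain`` as in the earlier sketch):
+
+.. code-block:: python
+
+  params = {
+      'tree_method': 'gpu_hist',
+      'eval_metric': 'rmse',  # supported on the GPU per the table above
+  }
+  bst = xgb.train(params, dtrain, num_boost_round=10,
+                  evals=[(dtrain, 'train')])
+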
Benchmarks
==========
You can run benchmarks on synthetic data for binary classification:
@@ -152,12 +202,15 @@ References
`Nvidia Parallel Forall: Gradient Boosting, Decision Trees and XGBoost with CUDA `_
-Authors
+Contributors
-=======
+============
-* Rory Mitchell
+Many thanks to the following contributors (alphabetical order):
+
+* Andrey Adinets
+* Jiaming Yuan
* Jonathan C. McKinney
+* Philip Cho
+* Rory Mitchell
* Shankara Rao Thejaswi Nanditale
* Vinay Deshpande
-* ... and the rest of the H2O.ai and NVIDIA team.
Please report bugs to the user forum https://discuss.xgboost.ai/.
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index e0d3e41ed..2d5393f3e 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -944,6 +944,32 @@ class AllReducer {
#endif
}
+ /**
+ * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+ * streams or comms.
+ *
+ * \param communication_group_idx Zero-based index of the communication group.
+ * \param sendbuff The sendbuff.
+ * \param recvbuff The recvbuff.
+ * \param count Number of elements.
+ */
+
+ void AllReduceSum(int communication_group_idx, const float *sendbuff,
+ float *recvbuff, int count) {
+#ifdef XGBOOST_USE_NCCL
+ CHECK(initialised_);
+ dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
+ dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum,
+ comms.at(communication_group_idx),
+ streams.at(communication_group_idx)));
+    if (communication_group_idx == 0) {
+ allreduce_bytes_ += count * sizeof(float);
+ allreduce_calls_ += 1;
+ }
+#endif
+ }
+
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
*
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 7daf7fe0d..0f2f89b95 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -116,19 +116,19 @@ struct GPUSketcher {
n_rows_(row_end - row_begin), param_(std::move(param)) {
}
- void Init(const SparsePage& row_batch, const MetaInfo& info) {
+ void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) {
num_cols_ = info.num_col_;
has_weights_ = info.weights_.Size() > 0;
// find the batch size
- if (param_.gpu_batch_nrows == 0) {
+ if (gpu_batch_nrows == 0) {
// By default, use no more than 1/16th of GPU memory
gpu_batch_nrows_ = dh::TotalMemory(device_) /
(16 * num_cols_ * sizeof(Entry));
- } else if (param_.gpu_batch_nrows == -1) {
+ } else if (gpu_batch_nrows == -1) {
gpu_batch_nrows_ = n_rows_;
} else {
- gpu_batch_nrows_ = param_.gpu_batch_nrows;
+ gpu_batch_nrows_ = gpu_batch_nrows;
}
if (gpu_batch_nrows_ > n_rows_) {
gpu_batch_nrows_ = n_rows_;
@@ -346,7 +346,8 @@ struct GPUSketcher {
}
};
- void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) {
+ void Sketch(const SparsePage& batch, const MetaInfo& info,
+ HistCutMatrix* hmat, int gpu_batch_nrows) {
// create device shards
shards_.resize(dist_.Devices().Size());
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) {
@@ -358,10 +359,11 @@ struct GPUSketcher {
});
// compute sketches for each shard
- dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
- shard->Init(batch, info);
- shard->Sketch(batch, info);
- });
+ dh::ExecuteIndexShards(&shards_,
+ [&](int idx, std::unique_ptr& shard) {
+ shard->Init(batch, info, gpu_batch_nrows);
+ shard->Sketch(batch, info);
+ });
// merge the sketches from all shards
// TODO(canonizer): do it in a tree-like reduction
@@ -390,9 +392,9 @@ struct GPUSketcher {
void DeviceSketch
(const SparsePage& batch, const MetaInfo& info,
- const tree::TrainParam& param, HistCutMatrix* hmat) {
+ const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows) {
GPUSketcher sketcher(param, info.num_row_);
- sketcher.Sketch(batch, info, hmat);
+ sketcher.Sketch(batch, info, hmat, gpu_batch_nrows);
}
} // namespace common
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index ad83dd6c8..1c112641d 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -72,7 +72,7 @@ struct HistCutMatrix {
/*! \brief Builds the cut matrix on the GPU */
void DeviceSketch
(const SparsePage& batch, const MetaInfo& info,
- const tree::TrainParam& param, HistCutMatrix* hmat);
+ const tree::TrainParam& param, HistCutMatrix* hmat, int gpu_batch_nrows);
/*!
* \brief A single row in global histogram index.
diff --git a/src/tree/param.h b/src/tree/param.h
index 3662193e5..8f607647e 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -77,8 +77,6 @@ struct TrainParam : public dmlc::Parameter {
int gpu_id;
// number of GPUs to use
int n_gpus;
- // number of rows in a single GPU batch
- int gpu_batch_nrows;
// the criteria to use for ranking splits
std::string split_evaluator;
@@ -205,11 +203,6 @@ struct TrainParam : public dmlc::Parameter {
.set_lower_bound(-1)
.set_default(1)
.describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
- DMLC_DECLARE_FIELD(gpu_batch_nrows)
- .set_lower_bound(-1)
- .set_default(0)
- .describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
- "-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
DMLC_DECLARE_FIELD(split_evaluator)
.set_default("elastic_net,monotonic,interaction")
.describe("The criteria to use for ranking splits");
diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh
index 297b40e39..94b52e971 100644
--- a/src/tree/updater_gpu_common.cuh
+++ b/src/tree/updater_gpu_common.cuh
@@ -36,50 +36,16 @@ XGBOOST_DEVICE __forceinline__ double atomicAdd(double* address, double val) {
namespace xgboost {
namespace tree {
-// Atomic add function for double precision gradients
-__device__ __forceinline__ void AtomicAddGpair(GradientPairPrecise* dest,
- const GradientPair& gpair) {
- auto dst_ptr = reinterpret_cast(dest);
-
- atomicAdd(dst_ptr, static_cast(gpair.GetGrad()));
- atomicAdd(dst_ptr + 1, static_cast(gpair.GetHess()));
-}
-// used by shared-memory atomics code
-__device__ __forceinline__ void AtomicAddGpair(GradientPairPrecise* dest,
- const GradientPairPrecise& gpair) {
- auto dst_ptr = reinterpret_cast(dest);
-
- atomicAdd(dst_ptr, gpair.GetGrad());
- atomicAdd(dst_ptr + 1, gpair.GetHess());
-}
-
-// For integer gradients
-__device__ __forceinline__ void AtomicAddGpair(GradientPairInteger* dest,
- const GradientPair& gpair) {
- auto dst_ptr = reinterpret_cast(dest); // NOLINT
- GradientPairInteger tmp(gpair.GetGrad(), gpair.GetHess());
- auto src_ptr = reinterpret_cast(&tmp);
+// Atomic add function for gradients
+template <typename OutputGradientT, typename InputGradientT>
+DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
+ const InputGradientT& gpair) {
+  auto dst_ptr = reinterpret_cast<typename OutputGradientT::ValueT*>(dest);
atomicAdd(dst_ptr,
- static_cast(*src_ptr)); // NOLINT
+            static_cast<typename OutputGradientT::ValueT>(gpair.GetGrad()));
atomicAdd(dst_ptr + 1,
- static_cast(*(src_ptr + 1))); // NOLINT
-}
-
-/**
- * \brief Check maximum gradient value is below 2^16. This is to prevent
- * overflow when using integer gradient summation.
- */
-
-inline void CheckGradientMax(const std::vector& gpair) {
- auto* ptr = reinterpret_cast(gpair.data());
- float abs_max =
- std::accumulate(ptr, ptr + (gpair.size() * 2), 0.f,
- [=](float a, float b) {
- return std::max(abs(a), abs(b)); });
-
- CHECK_LT(abs_max, std::pow(2.0f, 16.0f))
- << "Labels are too large for this algorithm. Rescale to less than 2^16.";
+            static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
}
struct GPUTrainingParam {
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 2c11269f4..e09e3e921 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -30,7 +30,25 @@ namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
-using GradientPairSumT = GradientPairPrecise;
+// training parameters specific to this algorithm
+struct GPUHistMakerTrainParam
+    : public dmlc::Parameter<GPUHistMakerTrainParam> {
+ bool single_precision_histogram;
+ // number of rows in a single GPU batch
+ int gpu_batch_nrows;
+ // declare parameters
+ DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
+ DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
+ "Use single precision to build histograms.");
+ DMLC_DECLARE_FIELD(gpu_batch_nrows)
+ .set_lower_bound(-1)
+ .set_default(0)
+ .describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
+                "-1 to use all rows assigned to a GPU, and 0 to auto-deduce");
+ }
+};
+
+DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
/*!
* \brief
@@ -42,20 +60,20 @@ using GradientPairSumT = GradientPairPrecise;
* \param end
* \param temp_storage Shared memory for intermediate result.
*/
-template
-__device__ GradientPairSumT ReduceFeature(common::Span feature_histogram,
+template
+__device__ GradientSumT ReduceFeature(common::Span feature_histogram,
TempStorageT* temp_storage) {
- __shared__ cub::Uninitialized uninitialized_sum;
- GradientPairSumT& shared_sum = uninitialized_sum.Alias();
+ __shared__ cub::Uninitialized uninitialized_sum;
+ GradientSumT& shared_sum = uninitialized_sum.Alias();
- GradientPairSumT local_sum = GradientPairSumT();
+ GradientSumT local_sum = GradientSumT();
// For loop sums features into one block size
auto begin = feature_histogram.data();
auto end = begin + feature_histogram.size();
for (auto itr = begin; itr < end; itr += BLOCK_THREADS) {
bool thread_active = itr + threadIdx.x < end;
// Scan histogram
- GradientPairSumT bin = thread_active ? *(itr + threadIdx.x) : GradientPairSumT();
+ GradientSumT bin = thread_active ? *(itr + threadIdx.x) : GradientSumT();
local_sum += bin;
}
local_sum = ReduceT(temp_storage->sum_reduce).Reduce(local_sum, cub::Sum());
@@ -69,10 +87,10 @@ __device__ GradientPairSumT ReduceFeature(common::Span f
/*! \brief Find the thread with best gain. */
template
+ typename MaxReduceT, typename TempStorageT, typename GradientSumT>
__device__ void EvaluateFeature(
int fidx,
- common::Span node_histogram,
+ common::Span node_histogram,
common::Span feature_segments, // cut.row_ptr
float min_fvalue, // cut.min_value
common::Span gidx_fvalue_map, // cut.cut
@@ -86,22 +104,22 @@ __device__ void EvaluateFeature(
uint32_t gidx_end = feature_segments[fidx + 1]; // end bin for i^th feature
// Sum histogram bins for current feature
- GradientPairSumT const feature_sum = ReduceFeature(
+ GradientSumT const feature_sum = ReduceFeature(
node_histogram.subspan(gidx_begin, gidx_end - gidx_begin), temp_storage);
- GradientPairSumT const parent_sum = GradientPairSumT(node.sum_gradients);
- GradientPairSumT const missing = parent_sum - feature_sum;
+ GradientSumT const parent_sum = GradientSumT(node.sum_gradients);
+ GradientSumT const missing = parent_sum - feature_sum;
float const null_gain = -std::numeric_limits::infinity();
- SumCallbackOp prefix_op =
- SumCallbackOp();
+ SumCallbackOp prefix_op =
+ SumCallbackOp();
for (int scan_begin = gidx_begin; scan_begin < gidx_end;
scan_begin += BLOCK_THREADS) {
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
// Gradient value for current bin.
- GradientPairSumT bin =
- thread_active ? node_histogram[scan_begin + threadIdx.x] : GradientPairSumT();
+ GradientSumT bin =
+ thread_active ? node_histogram[scan_begin + threadIdx.x] : GradientSumT();
scan_t(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
// Whether the gradient of missing values is put to the left side.
@@ -117,7 +135,7 @@ __device__ void EvaluateFeature(
// Find thread with best gain
cub::KeyValuePair tuple(threadIdx.x, gain);
cub::KeyValuePair best =
- max_ReduceT(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax());
+ MaxReduceT(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax());
__shared__ cub::KeyValuePair block_max;
if (threadIdx.x == 0) {
@@ -131,8 +149,8 @@ __device__ void EvaluateFeature(
int gidx = scan_begin + threadIdx.x;
float fvalue =
gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1];
- GradientPairSumT left = missing_left ? bin + missing : bin;
- GradientPairSumT right = parent_sum - left;
+ GradientSumT left = missing_left ? bin + missing : bin;
+ GradientSumT right = parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir,
fvalue, fidx,
GradientPair(left),
@@ -143,9 +161,9 @@ __device__ void EvaluateFeature(
}
}
-template
+template
__global__ void EvaluateSplitKernel(
- common::Span
+ common::Span
node_histogram, // histogram for gradients
common::Span feature_set, // Selected features
DeviceNodeStats node,
@@ -160,10 +178,10 @@ __global__ void EvaluateSplitKernel(
// KeyValuePair here used as threadIdx.x -> gain_value
typedef cub::KeyValuePair ArgMaxT;
typedef cub::BlockScan<
- GradientPairSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS> BlockScanT;
+ GradientSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS> BlockScanT;
typedef cub::BlockReduce MaxReduceT;
- typedef cub::BlockReduce SumReduceT;
+ typedef cub::BlockReduce SumReduceT;
union TempStorage {
typename BlockScanT::TempStorage scan;
@@ -233,10 +251,11 @@ __device__ int BinarySearchRow(bst_uint begin, bst_uint end, GidxIterT data,
* \author Rory
* \date 28/07/2018
*/
+template
struct DeviceHistogram {
/*! \brief Map nidx to starting index of its histogram. */
std::map nidx_map;
- thrust::device_vector data;
+ thrust::device_vector data;
const size_t kStopGrowingSize = 1 << 26; // Do not grow beyond this size
int n_bins;
int device_id_;
@@ -264,7 +283,7 @@ struct DeviceHistogram {
std::pair old_entry = *nidx_map.begin();
nidx_map.erase(old_entry.first);
dh::safe_cuda(cudaMemset(data.data().get() + old_entry.second, 0,
- n_bins * sizeof(GradientPairSumT)));
+ n_bins * sizeof(GradientSumT)));
nidx_map[nidx] = old_entry.second;
} else {
// Append new node histogram
@@ -280,11 +299,11 @@ struct DeviceHistogram {
* \param nidx Tree node index.
* \return hist pointer.
*/
- common::Span GetNodeHistogram(int nidx) {
+ common::Span GetNodeHistogram(int nidx) {
CHECK(this->HistogramExists(nidx));
auto ptr = data.data().get() + nidx_map[nidx];
- return common::Span(
- reinterpret_cast(ptr), n_bins);
+ return common::Span(
+ reinterpret_cast(ptr), n_bins);
}
};
@@ -341,18 +360,17 @@ __global__ void compress_bin_ellpack_k(
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
}
-__global__ void sharedMemHistKernel(size_t row_stride,
- const bst_uint* d_ridx,
+template
+__global__ void SharedMemHistKernel(size_t row_stride, const bst_uint* d_ridx,
common::CompressedIterator d_gidx,
int null_gidx_value,
- GradientPairSumT* d_node_hist,
+ GradientSumT* d_node_hist,
const GradientPair* d_gpair,
- size_t segment_begin,
- size_t n_elements) {
+ size_t segment_begin, size_t n_elements) {
extern __shared__ char smem[];
- GradientPairSumT* smem_arr = reinterpret_cast(smem); // NOLINT
+ GradientSumT* smem_arr = reinterpret_cast(smem); // NOLINT
for (auto i : dh::BlockStrideRange(0, null_gidx_value)) {
- smem_arr[i] = GradientPairSumT();
+ smem_arr[i] = GradientSumT();
}
__syncthreads();
for (auto idx : dh::GridStrideRange(static_cast(0), n_elements)) {
@@ -427,15 +445,18 @@ void SortPosition(dh::CubMemory* temp_memory, common::Span position,
out_itr, position.size());
}
+template
struct DeviceShard;
+template
struct GPUHistBuilderBase {
public:
- virtual void Build(DeviceShard* shard, int idx) = 0;
+ virtual void Build(DeviceShard* shard, int idx) = 0;
virtual ~GPUHistBuilderBase() = default;
};
// Manage memory for a single GPU
+template
struct DeviceShard {
int device_id_;
dh::BulkAllocator ba;
@@ -452,7 +473,7 @@ struct DeviceShard {
/*! \brief Range of rows for each node. */
std::vector ridx_segments;
- DeviceHistogram hist;
+ DeviceHistogram hist;
/*! \brief global index of histogram, which is stored in ELLPack format. */
dh::DVec gidx_buffer;
@@ -489,7 +510,7 @@ struct DeviceShard {
dh::CubMemory temp_memory;
- std::unique_ptr hist_builder;
+ std::unique_ptr> hist_builder;
// TODO(canonizer): do add support multi-batch DMatrix here
DeviceShard(int device_id, bst_uint row_begin, bst_uint row_end,
@@ -567,7 +588,7 @@ struct DeviceShard {
// One block for each feature
int constexpr BLOCK_THREADS = 256;
- EvaluateSplitKernel
+ EvaluateSplitKernel
<<>>(
hist.GetNodeHistogram(nidx), feature_set.DeviceSpan(device_id_), node,
cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
@@ -719,8 +740,9 @@ struct DeviceShard {
}
};
-struct SharedMemHistBuilder : public GPUHistBuilderBase {
- void Build(DeviceShard* shard, int nidx) override {
+template
+struct SharedMemHistBuilder : public GPUHistBuilderBase {
+ void Build(DeviceShard* shard, int nidx) override {
auto segment = shard->ridx_segments[nidx];
auto segment_begin = segment.begin;
auto d_node_hist = shard->hist.GetNodeHistogram(nidx);
@@ -731,7 +753,7 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase {
int null_gidx_value = shard->null_gidx_value;
auto n_elements = segment.Size() * shard->row_stride;
- const size_t smem_size = sizeof(GradientPairSumT) * shard->null_gidx_value;
+ const size_t smem_size = sizeof(GradientSumT) * shard->null_gidx_value;
const int items_per_thread = 8;
const int block_threads = 256;
const int grid_size =
@@ -741,14 +763,15 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase {
return;
}
dh::safe_cuda(cudaSetDevice(shard->device_id_));
- sharedMemHistKernel<<>>
+ SharedMemHistKernel<<>>
(shard->row_stride, d_ridx, d_gidx, null_gidx_value, d_node_hist.data(), d_gpair,
segment_begin, n_elements);
}
};
-struct GlobalMemHistBuilder : public GPUHistBuilderBase {
- void Build(DeviceShard* shard, int nidx) override {
+template
+struct GlobalMemHistBuilder : public GPUHistBuilderBase {
+ void Build(DeviceShard* shard, int nidx) override {
Segment segment = shard->ridx_segments[nidx];
auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
common::CompressedIterator d_gidx = shard->gidx;
@@ -771,7 +794,8 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase {
}
};
-inline void DeviceShard::InitCompressedData(
+template
+inline void DeviceShard::InitCompressedData(
const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
n_bins = hmat.row_ptr.back();
null_gidx_value = hmat.row_ptr.back();
@@ -820,19 +844,21 @@ inline void DeviceShard::InitCompressedData(
// check if we can use shared memory for building histograms
// (assuming atleast we need 2 CTAs per SM to maintain decent latency hiding)
- auto histogram_size = sizeof(GradientPairSumT) * null_gidx_value;
+ auto histogram_size = sizeof(GradientSumT) * null_gidx_value;
auto max_smem = dh::MaxSharedMemory(device_id_);
if (histogram_size <= max_smem) {
- hist_builder.reset(new SharedMemHistBuilder);
+ hist_builder.reset(new SharedMemHistBuilder);
} else {
- hist_builder.reset(new GlobalMemHistBuilder);
+ hist_builder.reset(new GlobalMemHistBuilder);
}
// Init histogram
hist.Init(device_id_, hmat.row_ptr.back());
}
-inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) {
+
+template
+inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) {
int num_symbols = n_bins + 1;
// bin and compress entries in batches of rows
size_t gpu_batch_nrows =
@@ -882,14 +908,17 @@ inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) {
entries_d.shrink_to_fit();
}
-class GPUHistMaker : public TreeUpdater {
+
+template <typename GradientSumT>
+class GPUHistMakerSpecialised {
public:
struct ExpandEntry;
- GPUHistMaker() : initialised_(false), p_last_fmat_(nullptr) {}
+ GPUHistMakerSpecialised() : initialised_(false), p_last_fmat_(nullptr) {}
void Init(
- const std::vector>& args) override {
+ const std::vector>& args) {
param_.InitAllowUnknown(args);
+ hist_maker_param_.InitAllowUnknown(args);
CHECK(param_.n_gpus != 0) << "Must have at least one device";
n_devices_ = param_.n_gpus;
dist_ = GPUDistribution::Block(GPUSet::All(param_.gpu_id, param_.n_gpus));
@@ -906,7 +935,7 @@ class GPUHistMaker : public TreeUpdater {
}
void Update(HostDeviceVector* gpair, DMatrix* dmat,
- const std::vector& trees) override {
+ const std::vector& trees) {
monitor_.Start("Update", dist_.Devices());
GradStats::CheckInfo(dmat->Info());
// rescale learning rate according to size of trees
@@ -943,23 +972,24 @@ class GPUHistMaker : public TreeUpdater {
const SparsePage& batch = *batch_iter;
// Create device shards
shards_.resize(n_devices);
- dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) {
+ dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr>& shard) {
size_t start = dist_.ShardStart(info_->num_row_, i);
size_t size = dist_.ShardSize(info_->num_row_, i);
- shard = std::unique_ptr
- (new DeviceShard(dist_.Devices().DeviceId(i),
+ shard = std::unique_ptr>
+ (new DeviceShard(dist_.Devices().DeviceId(i),
start, start + size, param_));
shard->InitRowPtrs(batch);
});
// Find the cuts.
monitor_.Start("Quantiles", dist_.Devices());
- common::DeviceSketch(batch, *info_, param_, &hmat_);
+ common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows);
n_bins_ = hmat_.row_ptr.back();
monitor_.Stop("Quantiles", dist_.Devices());
monitor_.Start("BinningCompression", dist_.Devices());
- dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
+ dh::ExecuteIndexShards(&shards_, [&](int idx,
+ std::unique_ptr>& shard) {
shard->InitCompressedData(hmat_, batch);
});
monitor_.Stop("BinningCompression", dist_.Devices());
@@ -983,9 +1013,11 @@ class GPUHistMaker : public TreeUpdater {
monitor_.Start("InitDataReset", dist_.Devices());
gpair->Reshard(dist_);
- dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
- shard->Reset(gpair);
- });
+ dh::ExecuteIndexShards(
+ &shards_,
+ [&](int idx, std::unique_ptr>& shard) {
+ shard->Reset(gpair);
+ });
monitor_.Stop("InitDataReset", dist_.Devices());
}
@@ -993,14 +1025,17 @@ class GPUHistMaker : public TreeUpdater {
if (shards_.size() == 1) return;
monitor_.Start("AllReduce");
- dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
- auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
- reducer_.AllReduceSum(
- dist_.Devices().Index(shard->device_id_),
- reinterpret_cast(d_node_hist),
- reinterpret_cast(d_node_hist),
- n_bins_ * (sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT)));
- });
+ dh::ExecuteIndexShards(
+ &shards_,
+ [&](int idx, std::unique_ptr>& shard) {
+ auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
+ reducer_.AllReduceSum(
+ dist_.Devices().Index(shard->device_id_),
+ reinterpret_cast(d_node_hist),
+ reinterpret_cast(d_node_hist),
+ n_bins_ * (sizeof(GradientSumT) /
+ sizeof(typename GradientSumT::ValueT)));
+ });
monitor_.Stop("AllReduce");
}
@@ -1026,9 +1061,11 @@ class GPUHistMaker : public TreeUpdater {
}
// Build histogram for node with the smallest number of training examples
- dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
- shard->BuildHist(build_hist_nidx);
- });
+ dh::ExecuteIndexShards(
+ &shards_,
+ [&](int idx, std::unique_ptr>& shard) {
+ shard->BuildHist(build_hist_nidx);
+ });
this->AllReduceHist(build_hist_nidx);
@@ -1041,15 +1078,19 @@ class GPUHistMaker : public TreeUpdater {
if (do_subtraction_trick) {
// Calculate other histogram using subtraction trick
- dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
- shard->SubtractionTrick(nidx_parent, build_hist_nidx,
- subtraction_trick_nidx);
- });
+ dh::ExecuteIndexShards(
+ &shards_,
+ [&](int idx, std::unique_ptr>& shard) {
+ shard->SubtractionTrick(nidx_parent, build_hist_nidx,
+ subtraction_trick_nidx);
+ });
} else {
// Calculate other histogram manually
- dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
- shard->BuildHist(subtraction_trick_nidx);
- });
+ dh::ExecuteIndexShards(
+ &shards_,
+ [&](int idx, std::unique_ptr>& shard) {
+ shard->BuildHist(subtraction_trick_nidx);
+ });
this->AllReduceHist(subtraction_trick_nidx);
}
@@ -1066,19 +1107,22 @@ class GPUHistMaker : public TreeUpdater {
// Sum gradients
std::vector tmp_sums(shards_.size());
- dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) {
- dh::safe_cuda(cudaSetDevice(shard->device_id_));
- tmp_sums[i] =
- dh::SumReduction(shard->temp_memory, shard->gpair.Data(),
- shard->gpair.Size());
- });
+ dh::ExecuteIndexShards(
+ &shards_,
+ [&](int i, std::unique_ptr>& shard) {
+ dh::safe_cuda(cudaSetDevice(shard->device_id_));
+ tmp_sums[i] = dh::SumReduction(
+ shard->temp_memory, shard->gpair.Data(), shard->gpair.Size());
+ });
GradientPair sum_gradient =
std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
// Generate root histogram
- dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
- shard->BuildHist(root_nidx);
- });
+ dh::ExecuteIndexShards(
+ &shards_,
+ [&](int idx, std::unique_ptr>& shard) {
+ shard->BuildHist(root_nidx);
+ });
this->AllReduceHist(root_nidx);
@@ -1122,11 +1166,13 @@ class GPUHistMaker : public TreeUpdater {
}
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
- dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
- shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx,
- split_gidx, default_dir_left,
- is_dense, fidx_begin, fidx_end);
- });
+ dh::ExecuteIndexShards(
+ &shards_,
+ [&](int idx, std::unique_ptr>& shard) {
+ shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx, split_gidx,
+ default_dir_left, is_dense, fidx_begin,
+ fidx_end);
+ });
}
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
@@ -1223,15 +1269,17 @@ class GPUHistMaker : public TreeUpdater {
}
bool UpdatePredictionCache(
- const DMatrix* data, HostDeviceVector* p_out_preds) override {
+ const DMatrix* data, HostDeviceVector* p_out_preds) {
monitor_.Start("UpdatePredictionCache", dist_.Devices());
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
return false;
p_out_preds->Reshard(dist_.Devices());
- dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr& shard) {
- shard->UpdatePredictionCache(
- p_out_preds->DevicePointer(shard->device_id_));
- });
+ dh::ExecuteIndexShards(
+ &shards_,
+ [&](int idx, std::unique_ptr>& shard) {
+ shard->UpdatePredictionCache(
+ p_out_preds->DevicePointer(shard->device_id_));
+ });
monitor_.Stop("UpdatePredictionCache", dist_.Devices());
return true;
}
@@ -1286,6 +1334,7 @@ class GPUHistMaker : public TreeUpdater {
}
}
TrainParam param_;
+ GPUHistMakerTrainParam hist_maker_param_;
common::HistCutMatrix hmat_;
common::GHistIndexMatrix gmat_;
MetaInfo* info_;
@@ -1293,7 +1342,7 @@ class GPUHistMaker : public TreeUpdater {
int n_devices_;
int n_bins_;
- std::vector> shards_;
+ std::vector>> shards_;
common::ColumnSampler column_sampler_;
using ExpandQueue = std::priority_queue,
std::function>;
@@ -1308,6 +1357,46 @@ class GPUHistMaker : public TreeUpdater {
GPUDistribution dist_;
};
+class GPUHistMaker : public TreeUpdater {
+ public:
+ void Init(
+ const std::vector>& args) override {
+ hist_maker_param_.InitAllowUnknown(args);
+ float_maker_.reset();
+ double_maker_.reset();
+ if (hist_maker_param_.single_precision_histogram) {
+      float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
+ float_maker_->Init(args);
+ } else {
+      double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
+ double_maker_->Init(args);
+ }
+ }
+
+ void Update(HostDeviceVector* gpair, DMatrix* dmat,
+ const std::vector& trees) override {
+ if (hist_maker_param_.single_precision_histogram) {
+ float_maker_->Update(gpair, dmat, trees);
+ } else {
+ double_maker_->Update(gpair, dmat, trees);
+ }
+ }
+
+ bool UpdatePredictionCache(
+ const DMatrix* data, HostDeviceVector* p_out_preds) override {
+ if (hist_maker_param_.single_precision_histogram) {
+ return float_maker_->UpdatePredictionCache(data, p_out_preds);
+ } else {
+ return double_maker_->UpdatePredictionCache(data, p_out_preds);
+ }
+ }
+
+ private:
+ GPUHistMakerTrainParam hist_maker_param_;
+  std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
+  std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
+};
+
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUHistMaker(); });
diff --git a/tests/cpp/common/test_gpu_hist_util.cu b/tests/cpp/common/test_gpu_hist_util.cu
index f7bce5580..7c4fbd745 100644
--- a/tests/cpp/common/test_gpu_hist_util.cu
+++ b/tests/cpp/common/test_gpu_hist_util.cu
@@ -30,7 +30,7 @@ void TestDeviceSketch(const GPUSet& devices) {
p.gpu_id = 0;
p.n_gpus = devices.Size();
// ensure that the exact quantiles are found
- p.gpu_batch_nrows = nrows * 10;
+ int gpu_batch_nrows = nrows * 10;
// find quantiles on the CPU
HistCutMatrix hmat_cpu;
@@ -39,7 +39,7 @@ void TestDeviceSketch(const GPUSet& devices) {
// find the cuts on the GPU
const SparsePage& batch = *(*dmat)->GetRowBatches().begin();
HistCutMatrix hmat_gpu;
- DeviceSketch(batch, (*dmat)->Info(), p, &hmat_gpu);
+ DeviceSketch(batch, (*dmat)->Info(), p, &hmat_gpu, gpu_batch_nrows);
// compare the cuts
double eps = 1e-2;
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index cd4096fca..4f8e7d6ff 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -17,7 +17,8 @@
namespace xgboost {
namespace tree {
-void BuildGidx(DeviceShard* shard, int n_rows, int n_cols,
+template
+void BuildGidx(DeviceShard* shard, int n_rows, int n_cols,
bst_float sparsity=0) {
auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3);
const SparsePage& batch = *(*dmat)->GetRowBatches().begin();
@@ -48,7 +49,7 @@ TEST(GpuHist, BuildGidxDense) {
param.n_gpus = 1;
param.max_leaves = 0;
- DeviceShard shard(0, 0, n_rows, param);
+ DeviceShard shard(0, 0, n_rows, param);
BuildGidx(&shard, n_rows, n_cols);
std::vector h_gidx_buffer;
@@ -87,7 +88,7 @@ TEST(GpuHist, BuildGidxSparse) {
param.n_gpus = 1;
param.max_leaves = 0;
- DeviceShard shard(0, 0, n_rows, param);
+ DeviceShard shard(0, 0, n_rows, param);
BuildGidx(&shard, n_rows, n_cols, 0.9f);
std::vector h_gidx_buffer;
@@ -122,7 +123,8 @@ std::vector GetHostHistGpair() {
return hist_gpair;
}
-void TestBuildHist(GPUHistBuilderBase& builder) {
+template
+void TestBuildHist(GPUHistBuilderBase& builder) {
int const n_rows = 16, n_cols = 8;
TrainParam param;
@@ -130,7 +132,7 @@ void TestBuildHist(GPUHistBuilderBase& builder) {
param.n_gpus = 1;
param.max_leaves = 0;
- DeviceShard shard(0, 0, n_rows, param);
+ DeviceShard shard(0, 0, n_rows, param);
BuildGidx(&shard, n_rows, n_cols);
@@ -166,13 +168,14 @@ void TestBuildHist(GPUHistBuilderBase& builder) {
shard.ridx.CurrentDVec().tend());
builder.Build(&shard, 0);
- DeviceHistogram d_hist = shard.hist;
+ DeviceHistogram d_hist = shard.hist;
auto node_histogram = d_hist.GetNodeHistogram(0);
// d_hist.data stored in float, not gradient pair
- thrust::host_vector h_result (d_hist.data.size()/2);
- size_t data_size = sizeof(GradientPairSumT) / (
- sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT));
+ thrust::host_vector h_result (d_hist.data.size()/2);
+ size_t data_size =
+ sizeof(GradientSumT) /
+ (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT));
data_size *= d_hist.data.size();
dh::safe_cuda(cudaMemcpy(h_result.data(), node_histogram.data(), data_size,
cudaMemcpyDeviceToHost));
@@ -186,13 +189,17 @@ void TestBuildHist(GPUHistBuilderBase& builder) {
}
TEST(GpuHist, BuildHistGlobalMem) {
- GlobalMemHistBuilder builder;
- TestBuildHist(builder);
+ GlobalMemHistBuilder double_builder;
+ TestBuildHist(double_builder);
+ GlobalMemHistBuilder float_builder;
+ TestBuildHist(float_builder);
}
TEST(GpuHist, BuildHistSharedMem) {
- SharedMemHistBuilder builder;
- TestBuildHist(builder);
+ SharedMemHistBuilder double_builder;
+ TestBuildHist(double_builder);
+ SharedMemHistBuilder float_builder;
+ TestBuildHist(float_builder);
}
common::HistCutMatrix GetHostCutMatrix () {
@@ -236,7 +243,7 @@ TEST(GpuHist, EvaluateSplits) {
int max_bins = 4;
// Initialize DeviceShard
- std::unique_ptr shard {new DeviceShard(0, 0, n_rows, param)};
+ std::unique_ptr> shard {new DeviceShard(0, 0, n_rows, param)};
// Initialize DeviceShard::node_sum_gradients
shard->node_sum_gradients = {{6.4f, 12.8f}};
@@ -244,7 +251,7 @@ TEST(GpuHist, EvaluateSplits) {
common::HistCutMatrix cmat = GetHostCutMatrix();
// Copy cut matrix to device.
- DeviceShard::DeviceHistCutMatrix cut;
+ DeviceShard::DeviceHistCutMatrix cut;
shard->ba.Allocate(0, true,
&(shard->cut_.feature_segments), cmat.row_ptr.size(),
&(shard->cut_.min_fvalue), cmat.min_val.size(),
@@ -271,9 +278,9 @@ TEST(GpuHist, EvaluateSplits) {
thrust::copy(hist.begin(), hist.end(),
shard->hist.data.begin());
-
// Initialize GPUHistMaker
- GPUHistMaker hist_maker = GPUHistMaker();
+ GPUHistMakerSpecialised hist_maker =
+ GPUHistMakerSpecialised();
hist_maker.param_ = param;
hist_maker.shards_.push_back(std::move(shard));
hist_maker.column_sampler_.Init(n_cols,
@@ -301,7 +308,8 @@ TEST(GpuHist, EvaluateSplits) {
}
TEST(GpuHist, ApplySplit) {
- GPUHistMaker hist_maker = GPUHistMaker();
+ GPUHistMakerSpecialised hist_maker =
+ GPUHistMakerSpecialised();
int constexpr nid = 0;
int constexpr n_rows = 16;
int constexpr n_cols = 8;
@@ -315,7 +323,7 @@ TEST(GpuHist, ApplySplit) {
}
hist_maker.shards_.resize(1);
- hist_maker.shards_[0].reset(new DeviceShard(0, 0, n_rows, param));
+ hist_maker.shards_[0].reset(new DeviceShard(0, 0, n_rows, param));
auto& shard = hist_maker.shards_.at(0);
shard->ridx_segments.resize(3); // 3 nodes.
@@ -337,7 +345,7 @@ TEST(GpuHist, ApplySplit) {
0.59, 4, // fvalue has to be equal to one of the cut field
GradientPair(8.2, 2.8), GradientPair(6.3, 3.6),
GPUTrainingParam(param));
- GPUHistMaker::ExpandEntry candidate_entry {0, 0, candidate, 0};
+ GPUHistMakerSpecialised::ExpandEntry candidate_entry {0, 0, candidate, 0};
candidate_entry.nid = nid;
auto const& nodes = tree.GetNodes();
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index 8c4cb1cda..4cb683a88 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -31,12 +31,14 @@ class TestGPU(unittest.TestCase):
assert_gpu_results(cpu_results, gpu_results)
def test_gpu_hist(self):
- variable_param = {'n_gpus': [-1], 'max_depth': [2, 8],
- 'max_leaves': [255, 4],
- 'max_bin': [2, 256], 'min_child_weight': [0, 1],
- 'lambda': [0.0, 1.0],
- 'grow_policy': ['lossguide']}
- for param in parameter_combinations(variable_param):
+ test_param = parameter_combinations({'n_gpus': [1], 'max_depth': [2, 8],
+ 'max_leaves': [255, 4],
+ 'max_bin': [2, 256],
+ 'grow_policy': ['lossguide']})
+ test_param.append({'single_precision_histogram': True})
+ test_param.append({'min_child_weight': 0,
+ 'lambda': 0})
+ for param in test_param:
param['tree_method'] = 'gpu_hist'
gpu_results = run_suite(param, select_datasets=datasets)
assert_results_non_increasing(gpu_results, 1e-2)