From eda9e180f0473db90f686170eea75ddefa8fecdb Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Sat, 5 Aug 2017 22:16:23 +1200
Subject: [PATCH] [GPU-Plugin] Various fixes (#2579)

* Fix test large
* Add check for max_depth 0
* Update readme
* Add LBS specialisation for dense data
* Add bst_gpair_precise
* Temporarily disable accuracy tests on test_large.py
* Solve unused variable compiler warning
* Fix max_bin > 1024 error
---
 include/xgboost/base.h                       | 46 ++++++++----
 plugin/updater_gpu/README.md                 |  2 +-
 plugin/updater_gpu/src/common.cuh            | 26 ++++---
 plugin/updater_gpu/src/device_helpers.cuh    | 69 +++++++++++-------
 plugin/updater_gpu/src/gpu_hist_builder.cu   | 77 ++++++++++++--------
 plugin/updater_gpu/src/gpu_hist_builder.cuh  | 12 +--
 plugin/updater_gpu/test/python/test_large.py | 12 +--
 src/gbm/gbtree_model.h                       |  1 -
 8 files changed, 147 insertions(+), 98 deletions(-)

diff --git a/include/xgboost/base.h b/include/xgboost/base.h
index b0ce3bf76..9f54588c5 100644
--- a/include/xgboost/base.h
+++ b/include/xgboost/base.h
@@ -82,49 +82,67 @@ typedef uint64_t bst_ulong;  // NOLINT(*)
 /*! \brief float type, used for storing statistics */
 typedef float bst_float;
 
-/*! \brief gradient statistics pair usually needed in gradient boosting */
-struct bst_gpair {
+
+/*! \brief Implementation of gradient statistics pair */
+template <typename T>
+struct bst_gpair_internal {
   /*! \brief gradient statistics */
-  bst_float grad;
+  T grad;
   /*! \brief second order gradient statistics */
-  bst_float hess;
+  T hess;
 
-  XGBOOST_DEVICE bst_gpair() : grad(0), hess(0) {}
+  XGBOOST_DEVICE bst_gpair_internal() : grad(0), hess(0) {}
 
-  XGBOOST_DEVICE bst_gpair(bst_float grad, bst_float hess)
+  XGBOOST_DEVICE bst_gpair_internal(T grad, T hess)
       : grad(grad), hess(hess) {}
 
-  XGBOOST_DEVICE bst_gpair &operator+=(const bst_gpair &rhs) {
+  template <typename T2>
+  XGBOOST_DEVICE bst_gpair_internal(bst_gpair_internal<T2> &g)
+      : grad(g.grad), hess(g.hess) {}
+
+  XGBOOST_DEVICE bst_gpair_internal &operator+=(const bst_gpair_internal &rhs) {
     grad += rhs.grad;
     hess += rhs.hess;
     return *this;
   }
 
-  XGBOOST_DEVICE bst_gpair operator+(const bst_gpair &rhs) const {
-    bst_gpair g;
+  XGBOOST_DEVICE bst_gpair_internal operator+(const bst_gpair_internal &rhs) const {
+    bst_gpair_internal g;
     g.grad = grad + rhs.grad;
     g.hess = hess + rhs.hess;
     return g;
   }
 
-  XGBOOST_DEVICE bst_gpair &operator-=(const bst_gpair &rhs) {
+  XGBOOST_DEVICE bst_gpair_internal &operator-=(const bst_gpair_internal &rhs) {
     grad -= rhs.grad;
     hess -= rhs.hess;
     return *this;
   }
 
-  XGBOOST_DEVICE bst_gpair operator-(const bst_gpair &rhs) const {
-    bst_gpair g;
+  XGBOOST_DEVICE bst_gpair_internal operator-(const bst_gpair_internal &rhs) const {
+    bst_gpair_internal g;
     g.grad = grad - rhs.grad;
     g.hess = hess - rhs.hess;
     return g;
   }
 
-  XGBOOST_DEVICE bst_gpair(int value) {
-    *this = bst_gpair(static_cast<bst_float>(value), static_cast<bst_float>(value));
+  XGBOOST_DEVICE bst_gpair_internal(int value) {
+    *this = bst_gpair_internal(static_cast<T>(value), static_cast<T>(value));
+  }
+
+  friend std::ostream &operator<<(std::ostream &os,
+                                  const bst_gpair_internal &g) {
+    os << g.grad << "/" << g.hess;
+    return os;
   }
 };
 
+/*! \brief gradient statistics pair usually needed in gradient boosting */
+typedef bst_gpair_internal<float> bst_gpair;
+
+/*! \brief High precision gradient statistics pair */
+typedef bst_gpair_internal<double> bst_gpair_precise;
+
 /*! \brief small eps gap for minimum split decision.
  */
 const bst_float rt_eps = 1e-6f;

diff --git a/plugin/updater_gpu/README.md b/plugin/updater_gpu/README.md
index 5664a7a22..baeaf710f 100644
--- a/plugin/updater_gpu/README.md
+++ b/plugin/updater_gpu/README.md
@@ -27,7 +27,7 @@ This plugin currently works with the CLI version and python version.
 
 Python example:
 ```python
-param['gpu_id'] = 1
+param['gpu_id'] = 0
 param['max_bin'] = 16
 param['tree_method'] = 'gpu_hist'
 ```

diff --git a/plugin/updater_gpu/src/common.cuh b/plugin/updater_gpu/src/common.cuh
index e4c2d3972..c3938a579 100644
--- a/plugin/updater_gpu/src/common.cuh
+++ b/plugin/updater_gpu/src/common.cuh
@@ -15,28 +15,30 @@
 namespace xgboost {
 namespace tree {
 
+template <typename gpair_t>
 __device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
-                                             const bst_gpair& scan,
-                                             const bst_gpair& missing,
-                                             const bst_gpair& parent_sum,
+                                             const gpair_t& scan,
+                                             const gpair_t& missing,
+                                             const gpair_t& parent_sum,
                                              const float& parent_gain,
                                              bool missing_left) {
-  bst_gpair left = scan;
+  gpair_t left = scan;
 
   if (missing_left) {
     left += missing;
   }
 
-  bst_gpair right = parent_sum - left;
+  gpair_t right = parent_sum - left;
 
   float left_gain = CalcGain(param, left.grad, left.hess);
   float right_gain = CalcGain(param, right.grad, right.hess);
   return left_gain + right_gain - parent_gain;
 }
 
-__device__ float inline loss_chg_missing(const bst_gpair& scan,
-                                         const bst_gpair& missing,
-                                         const bst_gpair& parent_sum,
+template <typename gpair_t>
+__device__ float inline loss_chg_missing(const gpair_t& scan,
+                                         const gpair_t& missing,
+                                         const gpair_t& parent_sum,
                                          const float& parent_gain,
                                          const GPUTrainingParam& param,
                                          bool& missing_left_out) {  // NOLINT
@@ -177,11 +179,11 @@ inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
 }
 
 struct GpairCallbackOp {
   // Running prefix
-  bst_gpair running_total;
+  bst_gpair_precise running_total;
   // Constructor
-  __device__ GpairCallbackOp() : running_total(bst_gpair()) {}
-  __device__ bst_gpair operator()(bst_gpair block_aggregate) {
-    bst_gpair old_prefix = running_total;
+  __device__ GpairCallbackOp() : running_total(bst_gpair_precise()) {}
+  __device__ bst_gpair_precise operator()(bst_gpair_precise block_aggregate) {
+    bst_gpair_precise old_prefix = running_total;
     running_total += block_aggregate;
     return old_prefix;
   }

diff --git a/plugin/updater_gpu/src/device_helpers.cuh b/plugin/updater_gpu/src/device_helpers.cuh
index d4be2e346..f734e9999 100644
--- a/plugin/updater_gpu/src/device_helpers.cuh
+++ b/plugin/updater_gpu/src/device_helpers.cuh
@@ -729,32 +729,8 @@ __global__ void LbsKernel(coordinate_t *d_coordinates,
   }
 }
 
-/**
- * \fn template <typename func_t, typename segments_iter, typename offset_t, typename segments_t>
- * void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
- * segments_iter segments, offset_t num_segments, func_t f)
- *
- * \brief Load balancing search function. Reads a CSR type matrix description
- * and allows a function to be executed on each element. Search 'modern GPU load
- * balancing search' for more information.
- *
- * \author Rory
- * \date 7/9/2017
- *
- * \tparam func_t Type of the function t.
- * \tparam segments_iter Type of the segments iterator.
- * \tparam offset_t Type of the offset.
- * \tparam segments_t Type of the segments t.
- * \param device_idx Zero-based index of the device.
- * \param [in,out] temp_memory Temporary memory allocator.
- * \param count Number of elements.
- * \param segments Device pointer to segments.
- * \param num_segments Number of segments.
- * \param f Lambda to be executed on matrix elements.
- */
-
 template <typename func_t, typename segments_iter, typename offset_t>
-void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
+void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
                   segments_iter segments, offset_t num_segments, func_t f) {
   typedef typename cub::CubVector<offset_t, 2>::Type coordinate_t;
   dh::safe_cuda(cudaSetDevice(device_idx));
@@ -775,4 +751,47 @@ void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
                       num_segments);
 }
 
+template <typename func_t, typename offset_t>
+void DenseTransformLbs(int device_idx, offset_t count, offset_t num_segments, func_t f) {
+  CHECK(count % num_segments == 0) << "Data is not dense.";
+
+  launch_n(device_idx, count, [=]__device__(offset_t idx)
+  {
+    offset_t segment = idx / (count / num_segments);
+    f(idx, segment);
+  });
+}
+
+/**
+ * \fn template <typename func_t, typename segments_iter, typename offset_t> void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count, segments_iter segments, offset_t num_segments, bool is_dense, func_t f)
+ *
+ * \brief Load balancing search function. Reads a CSR type matrix description and allows a function
+ *        to be executed on each element. Search 'modern GPU load balancing search' for more
+ *        information.
+ *
+ * \author Rory
+ * \date 7/9/2017
+ *
+ * \tparam func_t Type of the function t.
+ * \tparam segments_iter Type of the segments iterator.
+ * \tparam offset_t Type of the offset.
+ * \param device_idx Zero-based index of the device.
+ * \param [in,out] temp_memory Temporary memory allocator.
+ * \param count Number of elements.
+ * \param segments Device pointer to segments.
+ * \param num_segments Number of segments.
+ * \param is_dense True if this object is dense.
+ * \param f Lambda to be executed on matrix elements.
+ */
+
+template <typename func_t, typename segments_iter, typename offset_t>
+void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
+                  segments_iter segments, offset_t num_segments, bool is_dense, func_t f) {
+  if (is_dense) {
+    DenseTransformLbs(device_idx, count, num_segments, f);
+  }
+  else {
+    SparseTransformLbs(device_idx, temp_memory, count, segments, num_segments, f);
+  }
+}
 }  // namespace dh

diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cu b/plugin/updater_gpu/src/gpu_hist_builder.cu
index 760939f46..47e71c030 100644
--- a/plugin/updater_gpu/src/gpu_hist_builder.cu
+++ b/plugin/updater_gpu/src/gpu_hist_builder.cu
@@ -49,10 +49,10 @@ void DeviceHist::Init(int n_bins_in) {
 
 void DeviceHist::Reset(int device_idx) {
   cudaSetDevice(device_idx);
-  data.fill(bst_gpair());
+  data.fill(bst_gpair_precise());
 }
 
-bst_gpair* DeviceHist::GetLevelPtr(int depth) {
+bst_gpair_precise* DeviceHist::GetLevelPtr(int depth) {
   return data.data() + n_nodes(depth - 1) * n_bins;
 }
 
@@ -62,10 +62,25 @@ HistBuilder DeviceHist::GetBuilder() {
   return HistBuilder(data.data(), n_bins);
 }
 
-HistBuilder::HistBuilder(bst_gpair* ptr, int n_bins)
+HistBuilder::HistBuilder(bst_gpair_precise* ptr, int n_bins)
     : d_hist(ptr), n_bins(n_bins) {}
 
-__device__ void HistBuilder::Add(bst_gpair gpair, int gidx, int nidx) const {
+// Define double precision atomic add for older architectures
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
+#else
+__device__ double atomicAdd(double* address, double val) {
+  unsigned long long int* address_as_ull = (unsigned long long int*)address;  // NOLINT
+  unsigned long long int old = *address_as_ull, assumed;  // NOLINT
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+                    __double_as_longlong(val + __longlong_as_double(assumed)));
+  } while (assumed != old);
+  return __longlong_as_double(old);
+}
+#endif
+
+__device__ void HistBuilder::Add(bst_gpair_precise gpair, int gidx,
+                                 int nidx) const {
   int hist_idx = nidx * n_bins + gidx;
   atomicAdd(&(d_hist[hist_idx].grad), gpair.grad);  // OPTMARK: This and below
                                                     // line lead to about 3X
@@ -75,7 +90,7 @@ __device__ void HistBuilder::Add(bst_gpair gpair, int gidx, int nidx) const {
   atomicAdd(&(d_hist[hist_idx].hess), gpair.hess);
 }
 
-__device__ bst_gpair HistBuilder::Get(int gidx, int nidx) const {
+__device__ bst_gpair_precise HistBuilder::Get(int gidx, int nidx) const {
   return d_hist[nidx * n_bins + gidx];
 }
 
@@ -104,6 +119,7 @@ GPUHistBuilder::~GPUHistBuilder() {
 
 void GPUHistBuilder::Init(const TrainParam& param) {
   CHECK(param.max_depth < 16) << "Tree depth too large.";
+  CHECK(param.max_depth != 0) << "Tree depth cannot be 0.";
   CHECK(param.grow_policy != TrainParam::kLossGuide)
       << "Loss guided growth policy not supported. Use CPU algorithm.";
   this->param = param;
@@ -358,7 +374,7 @@ void GPUHistBuilder::BuildHist(int depth) {
     auto hist_builder = hist_vec[d_idx].GetBuilder();
     dh::TransformLbs(
        device_idx, &temp_memory[d_idx], end - begin, d_row_ptr,
-        row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
+        row_end - row_begin, is_dense, [=] __device__(size_t local_idx, int local_ridx) {
          int nidx = d_position[local_ridx];  // OPTMARK: latency
          if (!is_active(nidx, depth)) return;

@@ -396,8 +412,8 @@ void GPUHistBuilder::BuildHist(int depth) {
     dh::safe_nccl(ncclAllReduce(
        reinterpret_cast<const void*>(hist_vec[d_idx].GetLevelPtr(depth)),
        reinterpret_cast<void*>(hist_vec[d_idx].GetLevelPtr(depth)),
-        hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair) / sizeof(float),
-        ncclFloat, ncclSum, comms[d_idx], *(streams[d_idx])));
+        hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair_precise) / sizeof(double),
+        ncclDouble, ncclSum, comms[d_idx], *(streams[d_idx])));
   }
 
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
@@ -428,9 +444,9 @@ void GPUHistBuilder::BuildHist(int depth) {
          }
 
          int gidx = idx % hist_builder.n_bins;
-          bst_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx));
+          bst_gpair_precise parent = hist_builder.Get(gidx, parent_nidx(nidx));
          int other_nidx = left_smallest ? nidx - 1 : nidx + 1;
-          bst_gpair other = hist_builder.Get(gidx, other_nidx);
+          bst_gpair_precise other = hist_builder.Get(gidx, other_nidx);
          hist_builder.Add(parent - other, gidx,
                           nidx);  // OPTMARK: This is slow, could use shared
                                   // memory or cache results intead of writing to

@@ -443,16 +459,16 @@
 
 template <int BLOCK_THREADS>
 __global__ void find_split_kernel(
-    const bst_gpair* d_level_hist, int* d_feature_segments, int depth,
+    const bst_gpair_precise* d_level_hist, int* d_feature_segments, int depth,
     int n_features, int n_bins, Node* d_nodes, Node* d_nodes_temp,
     Node* d_nodes_child_temp, int nodes_offset_device, float* d_fidx_min_map,
     float* d_gidx_fvalue_map, GPUTrainingParam gpu_param,
     bool* d_left_child_smallest_temp, bool colsample, int* d_feature_flags) {
   typedef cub::KeyValuePair<int, float> ArgMaxT;
-  typedef cub::BlockScan<bst_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
+  typedef cub::BlockScan<bst_gpair_precise, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
       BlockScanT;
   typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
 
-  typedef cub::BlockReduce<bst_gpair, BLOCK_THREADS> SumReduceT;
+  typedef cub::BlockReduce<bst_gpair_precise, BLOCK_THREADS> SumReduceT;
 
   union TempStorage {
     typename BlockScanT::TempStorage scan;
@@ -461,12 +477,12 @@ __global__ void find_split_kernel(
   };
 
   struct UninitializedSplit : cub::Uninitialized<Split> {};
-  struct UninitializedGpair : cub::Uninitialized<bst_gpair> {};
+  struct UninitializedGpair : cub::Uninitialized<bst_gpair_precise> {};
 
   __shared__ UninitializedSplit uninitialized_split;
   Split& split = uninitialized_split.Alias();
   __shared__ UninitializedGpair uninitialized_sum;
-  bst_gpair& shared_sum = uninitialized_sum.Alias();
+  bst_gpair_precise& shared_sum = uninitialized_sum.Alias();
   __shared__ ArgMaxT block_max;
   __shared__ TempStorage temp_storage;
 
@@ -486,15 +502,14 @@ __global__ void find_split_kernel(
     int begin = d_feature_segments[level_node_idx * n_features + fidx];
     int end = d_feature_segments[level_node_idx * n_features + fidx + 1];
 
-    int gidx = (begin - (level_node_idx * n_bins)) + threadIdx.x;
-    bool thread_active = threadIdx.x < end - begin;
-    bst_gpair feature_sum = bst_gpair();
+    bst_gpair_precise feature_sum = bst_gpair_precise();
     for (int reduce_begin = begin; reduce_begin < end;
         reduce_begin += BLOCK_THREADS) {
+      bool thread_active = reduce_begin + threadIdx.x < end;
       // Scan histogram
-      bst_gpair bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
-                                    : bst_gpair();
+      bst_gpair_precise bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
+                                            : bst_gpair_precise();
 
      feature_sum +=
          SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
@@ -508,17 +523,18 @@ __global__ void find_split_kernel(
     GpairCallbackOp prefix_op = GpairCallbackOp();
     for (int scan_begin = begin; scan_begin < end;
         scan_begin += BLOCK_THREADS) {
-      bst_gpair bin =
-          thread_active ? d_level_hist[scan_begin + threadIdx.x] : bst_gpair();
+      bool thread_active = scan_begin + threadIdx.x < end;
+      bst_gpair_precise bin =
+          thread_active ? d_level_hist[scan_begin + threadIdx.x] : bst_gpair_precise();
 
      BlockScanT(temp_storage.scan)
          .ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
 
      // Calculate gain
-      bst_gpair parent_sum = d_nodes[node_idx].sum_gradients;
+      bst_gpair_precise parent_sum = d_nodes[node_idx].sum_gradients;
      float parent_gain = d_nodes[node_idx].root_gain;
 
-      bst_gpair missing = parent_sum - shared_sum;
+      bst_gpair_precise missing = parent_sum - shared_sum;
 
      bool missing_left;
      float gain = thread_active
@@ -541,6 +557,7 @@ __global__ void find_split_kernel(
      // Best thread updates split
      if (threadIdx.x == block_max.key) {
        float fvalue;
+        int gidx = (scan_begin - (level_node_idx * n_bins)) + threadIdx.x;
        if (threadIdx.x == 0 &&
            begin == scan_begin) {  // check at start of first tile
          fvalue = d_fidx_min_map[fidx];
        } else {
          fvalue = d_gidx_fvalue_map[gidx - 1];
        }
 
-        bst_gpair left = missing_left ? bin + missing : bin;
-        bst_gpair right = parent_sum - left;
+        bst_gpair_precise left = missing_left ? bin + missing : bin;
+        bst_gpair_precise right = parent_sum - left;
 
        split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
      }
@@ -662,7 +679,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
      int nodes_offset_device = d_idx * num_nodes_device;
 
      find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-          (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+          (const bst_gpair_precise*)(hist_vec[d_idx].GetLevelPtr(depth)),
          feature_segments[d_idx].data(), depth, (info->num_col),
          (hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_temp[d_idx].data(),
          nodes_child_temp[d_idx].data(), nodes_offset_device,
@@ -759,7 +776,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
      int nodes_offset_device = d_idx * num_nodes_device;
      find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-          (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+          (const bst_gpair_precise*)(hist_vec[d_idx].GetLevelPtr(depth)),
          feature_segments[d_idx].data(), depth, (info->num_col),
          (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
          nodes_offset_device, fidx_min_map[d_idx].data(),
@@ -812,7 +829,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
    int nodes_offset_device = 0;
    find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-        (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+        (const bst_gpair_precise*)(hist_vec[d_idx].GetLevelPtr(depth)),
        feature_segments[d_idx].data(), depth, (info->num_col),
        (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
        nodes_offset_device, fidx_min_map[d_idx].data(),
@@ -958,7 +975,7 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
 
    dh::TransformLbs(
        device_idx, &temp_memory[d_idx], element_end - element_begin, d_row_ptr,
-        row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
+        row_end - row_begin, is_dense, [=] __device__(size_t local_idx, int local_ridx) {
          int pos = d_position[local_ridx];
          if (!is_active(pos, depth)) {
            return;

diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cuh b/plugin/updater_gpu/src/gpu_hist_builder.cuh
index bcaf49d39..f80ca2990 100644
--- a/plugin/updater_gpu/src/gpu_hist_builder.cuh
+++ b/plugin/updater_gpu/src/gpu_hist_builder.cuh
@@ -25,16 +25,16 @@ struct DeviceGMat {
 };
 
 struct HistBuilder {
-  bst_gpair *d_hist;
+  bst_gpair_precise *d_hist;
   int n_bins;
-  __host__ __device__ HistBuilder(bst_gpair *ptr, int n_bins);
-  __device__ void Add(bst_gpair gpair, int gidx, int nidx) const;
-  __device__ bst_gpair Get(int gidx, int nidx) const;
+  __host__ __device__ HistBuilder(bst_gpair_precise *ptr, int n_bins);
+  __device__ void Add(bst_gpair_precise gpair, int gidx, int nidx) const;
+  __device__ bst_gpair_precise Get(int gidx, int nidx) const;
 };
 
 struct DeviceHist {
   int n_bins;
-  dh::dvec<bst_gpair> data;
+  dh::dvec<bst_gpair_precise> data;
 
   void Init(int max_depth);
 
@@ -42,7 +42,7 @@ struct DeviceHist {
 
   HistBuilder GetBuilder();
 
-  bst_gpair *GetLevelPtr(int depth);
+  bst_gpair_precise *GetLevelPtr(int depth);
 
   int LevelSize(int depth);
 };

diff --git a/plugin/updater_gpu/test/python/test_large.py b/plugin/updater_gpu/test/python/test_large.py
index 501ae6845..5878363ad 100644
--- a/plugin/updater_gpu/test/python/test_large.py
+++ b/plugin/updater_gpu/test/python/test_large.py
@@ -82,22 +82,16 @@ class TestGPU(unittest.TestCase):
 
         num_rounds = 1
 
-        eprint("normal updater")
-        xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
-                  evals_result=ag_res)
         eprint("hist updater")
-        xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+        xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
                   evals_result=ag_resb)
         eprint("gpu_hist updater 1 gpu")
-        xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+        xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
                   evals_result=ag_res2)
         eprint("gpu_hist updater all gpus")
-        xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+        xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
                   evals_result=ag_res3)
 
-        assert np.fabs(ag_res['train']['auc'][0] - ag_resb['train']['auc'][0])<0.001
-        assert np.fabs(ag_res['train']['auc'][0] - ag_res2['train']['auc'][0])<0.001
-        assert np.fabs(ag_res['train']['auc'][0] - ag_res3['train']['auc'][0])<0.001

diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h
index aa201cce9..304dde31b 100644
--- a/src/gbm/gbtree_model.h
+++ b/src/gbm/gbtree_model.h
@@ -117,7 +117,6 @@ struct GBTreeModel {
   }
 
   void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
                    int bst_group) {
-    size_t old_ntree = trees.size();
     for (size_t i = 0; i < new_trees.size(); ++i) {
       trees.push_back(std::move(new_trees[i]));
       tree_info.push_back(bst_group);
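
The move from the float-only `bst_gpair` to the `bst_gpair_internal<T>` template is what lets the GPU histogram accumulate in double precision (`bst_gpair_precise`) while the rest of the library keeps the float pair. The standalone sketch below is not the xgboost header; it uses a simplified stand-in type only to show why summing many small gradient contributions into a float accumulator drifts, which motivates the `ncclDouble` reduction and the double `atomicAdd` fallback in the patch.

```cpp
// Standalone sketch: a simplified stand-in for bst_gpair_internal<T>, used only
// to illustrate float-vs-double accumulation. Any C++11 compiler will do; no
// xgboost headers are required.
#include <iostream>

template <typename T>
struct gpair {  // hypothetical name, mirrors the grad/hess pair in the patch
  T grad, hess;
  gpair(T g = 0, T h = 0) : grad(g), hess(h) {}
  gpair& operator+=(const gpair& rhs) {
    grad += rhs.grad;
    hess += rhs.hess;
    return *this;
  }
};

int main() {
  gpair<float> sum_f;   // behaves like bst_gpair
  gpair<double> sum_d;  // behaves like bst_gpair_precise
  // Accumulate ten million tiny gradient pairs, as a large histogram bin might.
  for (int i = 0; i < 10000000; ++i) {
    sum_f += gpair<float>(1e-4f, 1e-4f);
    sum_d += gpair<double>(1e-4, 1e-4);
  }
  // The exact answer is 1000; the float accumulator shows visible rounding
  // error, the double accumulator stays close to 1000.
  std::cout << "float accumulator:  " << sum_f.grad << "\n";
  std::cout << "double accumulator: " << sum_d.grad << "\n";
  return 0;
}
```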
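The `is_dense` flag threaded through `dh::TransformLbs` simply dispatches between the existing CSR load balancing search and the new `DenseTransformLbs`, where the owning row of element `idx` is a single division. A host-side analogue of that dispatch, with illustrative names rather than the plugin's device code, looks like this:

```cpp
// Host-side sketch of the dense/sparse dispatch added to dh::TransformLbs.
// row_ptr is a CSR row pointer; f(idx, segment) is called once per element.
#include <algorithm>
#include <cstdio>
#include <vector>

template <typename func_t>
void transform_lbs_host(const std::vector<int>& row_ptr, bool is_dense, func_t f) {
  int num_segments = static_cast<int>(row_ptr.size()) - 1;
  int count = row_ptr.back();
  for (int idx = 0; idx < count; ++idx) {
    int segment;
    if (is_dense) {
      // Same rule as DenseTransformLbs: every row holds count / num_segments entries.
      segment = idx / (count / num_segments);
    } else {
      // General CSR case: find the row whose [row_ptr[i], row_ptr[i+1]) range holds idx.
      segment = static_cast<int>(std::upper_bound(row_ptr.begin(), row_ptr.end(), idx) -
                                 row_ptr.begin()) - 1;
    }
    f(idx, segment);
  }
}

int main() {
  std::vector<int> row_ptr{0, 3, 6, 9};  // 3 rows with 3 entries each -> dense
  transform_lbs_host(row_ptr, true, [](int idx, int row) {
    std::printf("element %d -> row %d\n", idx, row);
  });
  return 0;
}
```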
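The `max_bin > 1024` fix is the re-derivation of `thread_active` and `gidx` inside each tile of the split-finding loop: when a feature spans more bins than the block has threads, the old once-per-feature values were only correct for the first tile. A plain C++ walk-through of that tiling, with made-up sizes, shows the indexing the kernel now uses:

```cpp
// Sketch of tiled bin processing: the bin index and the active test must be
// derived from the current tile offset, not from the thread id alone.
#include <cstdio>

int main() {
  const int BLOCK_THREADS = 4;    // stand-in for the CUDA block size
  const int begin = 0, end = 10;  // a feature covering 10 bins (> BLOCK_THREADS)
  for (int tile = begin; tile < end; tile += BLOCK_THREADS) {
    for (int thread = 0; thread < BLOCK_THREADS; ++thread) {
      bool thread_active = tile + thread < end;  // re-evaluated every tile
      if (!thread_active) continue;
      int gidx = tile + thread;  // bin this "thread" handles in this tile
      std::printf("tile %d, thread %d -> bin %d\n", tile, thread, gidx);
    }
  }
  return 0;
}
```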