diff --git a/.gitignore b/.gitignore
index 803e0e6d8..aa8032113 100644
--- a/.gitignore
+++ b/.gitignore
@@ -90,7 +90,6 @@ lib/
 
 # spark
 metastore_db
-plugin/updater_gpu/test/cpp/data
 /include/xgboost/build_config.h
 
 # files from R-package source install
diff --git a/R-package/demo/gpu_accelerated.R b/R-package/demo/gpu_accelerated.R
index 770148a5d..321255c72 100644
--- a/R-package/demo/gpu_accelerated.R
+++ b/R-package/demo/gpu_accelerated.R
@@ -30,7 +30,5 @@ wl <- list(train = dtrain, test = dtest)
 # - similar to the 'hist'
 # - the fastest option for moderately large datasets
 # - current limitations: max_depth < 16, does not implement guided loss
-# You can use tree_method = 'gpu_exact' for another GPU accelerated algorithm,
-# which is slower, more memory-hungry, but does not use binning.
 param <- list(objective = 'reg:logistic', eval_metric = 'auc', subsample = 0.5, nthread = 4,
               max_bin = 64, tree_method = 'gpu_hist')
diff --git a/doc/build.rst b/doc/build.rst
index e9fca36b1..2dadadb53 100644
--- a/doc/build.rst
+++ b/doc/build.rst
@@ -13,7 +13,7 @@ Installation Guide
   # * xgboost-{version}-py2.py3-none-win_amd64.whl
   pip3 install xgboost
 
-* The binary wheel will support GPU algorithms (`gpu_exact`, `gpu_hist`) on machines with NVIDIA GPUs. Please note that **training with multiple GPUs is only supported for Linux platform**. See :doc:`gpu/index`.
+* The binary wheel will support GPU algorithms (`gpu_hist`) on machines with NVIDIA GPUs. Please note that **training with multiple GPUs is only supported for the Linux platform**. See :doc:`gpu/index`.
 * Currently, we provide binary wheels for 64-bit Linux and Windows.
 
 ****************************
diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst
index 4aefc1d6f..5c335e66a 100644
--- a/doc/gpu/index.rst
+++ b/doc/gpu/index.rst
@@ -26,8 +26,6 @@ Algorithms
 +-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 | tree_method           | Description                                                                                                                                                           |
 +=======================+=======================================================================================================================================================================+
-| gpu_exact (deprecated)| The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than ``gpu_hist``.                      |
-+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 | gpu_hist              | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Will run very slowly on GPUs older than Pascal architecture. |
 +-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 
@@ -37,31 +35,31 @@ Supported parameters
 .. |tick| unicode:: U+2714
 .. |cross| unicode:: U+2718
 
-+--------------------------------+----------------------------+--------------+
-| parameter                      | ``gpu_exact`` (deprecated) | ``gpu_hist`` |
-+================================+============================+==============+
-| ``subsample``                  | |cross|                    | |tick|       |
-+--------------------------------+----------------------------+--------------+
-| ``colsample_bytree``           | |cross|                    | |tick|       |
-+--------------------------------+----------------------------+--------------+
-| ``colsample_bylevel``          | |cross|                    | |tick|       |
-+--------------------------------+----------------------------+--------------+
-| ``max_bin``                    | |cross|                    | |tick|       |
-+--------------------------------+----------------------------+--------------+
-| ``gpu_id``                     | |tick|                     | |tick|       |
-+--------------------------------+----------------------------+--------------+
-| ``n_gpus`` (deprecated)        | |cross|                    | |tick|       |
-+--------------------------------+----------------------------+--------------+
-| ``predictor``                  | |tick|                     | |tick|       |
-+--------------------------------+----------------------------+--------------+
-| ``grow_policy``                | |cross|                    | |tick|       |
-+--------------------------------+----------------------------+--------------+
-| ``monotone_constraints``       | |cross|                    | |tick|       |
-+--------------------------------+----------------------------+--------------+
-| ``interaction_constraints``    | |cross|                    | |tick|       |
-+--------------------------------+----------------------------+--------------+
-| ``single_precision_histogram`` | |cross|                    | |tick|       |
-+--------------------------------+----------------------------+--------------+
++--------------------------------+--------------+
+| parameter                      | ``gpu_hist`` |
++================================+==============+
+| ``subsample``                  | |tick|       |
++--------------------------------+--------------+
+| ``colsample_bytree``           | |tick|       |
++--------------------------------+--------------+
+| ``colsample_bylevel``          | |tick|       |
++--------------------------------+--------------+
+| ``max_bin``                    | |tick|       |
++--------------------------------+--------------+
+| ``gpu_id``                     | |tick|       |
++--------------------------------+--------------+
+| ``n_gpus`` (deprecated)        | |tick|       |
++--------------------------------+--------------+
+| ``predictor``                  | |tick|       |
++--------------------------------+--------------+
+| ``grow_policy``                | |tick|       |
++--------------------------------+--------------+
+| ``monotone_constraints``       | |tick|       |
++--------------------------------+--------------+
+| ``interaction_constraints``    | |tick|       |
++--------------------------------+--------------+
+| ``single_precision_histogram`` | |tick|       |
++--------------------------------+--------------+
 
 GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``.
 
@@ -194,12 +192,10 @@ Training time time on 1,000,000 rows x 50 columns with 500 boosting iterations a
 +--------------+----------+
 | hist         | 63.55    |
 +--------------+----------+
-| gpu_exact    | 161.08   |
-+--------------+----------+
 | exact        | 1082.20  |
 +--------------+----------+
 
-See `GPU Accelerated XGBoost `_ and `Updates to the XGBoost GPU algorithms `_ for additional performance benchmarks of the ``gpu_exact`` and ``gpu_hist`` tree methods.
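The supported-parameter table and the predictor notes above translate directly into configuration calls. Below is a purely illustrative sketch through XGBoost's public C API (the file name train.libsvm is a placeholder and the SAFE macro is local to the example); it sets several of the listed parameters and overrides the default predictor to conserve GPU memory:

// Minimal C-API sketch: one gpu_hist training loop with explicit predictor.
#include <xgboost/c_api.h>
#include <cstdio>
#include <cstdlib>

#define SAFE(call) do { if ((call) != 0) { \
    std::fprintf(stderr, "%s\n", XGBGetLastError()); std::exit(1); } } while (0)

int main() {
  DMatrixHandle dtrain;
  SAFE(XGDMatrixCreateFromFile("train.libsvm", 0, &dtrain));  // placeholder path

  BoosterHandle booster;
  SAFE(XGBoosterCreate(&dtrain, 1, &booster));
  SAFE(XGBoosterSetParam(booster, "tree_method", "gpu_hist"));
  SAFE(XGBoosterSetParam(booster, "max_bin", "64"));
  SAFE(XGBoosterSetParam(booster, "subsample", "0.5"));
  // gpu_predictor is the default with gpu_hist; cpu_predictor trades
  // prediction speed for GPU memory, as the paragraph above explains.
  SAFE(XGBoosterSetParam(booster, "predictor", "cpu_predictor"));

  for (int iter = 0; iter < 10; ++iter) {
    SAFE(XGBoosterUpdateOneIter(booster, iter, dtrain));
  }
  SAFE(XGBoosterFree(booster));
  SAFE(XGDMatrixFree(dtrain));
  return 0;
}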
+See `GPU Accelerated XGBoost `_ and `Updates to the XGBoost GPU algorithms `_ for additional performance benchmarks of the ``gpu_hist`` tree method.
 
 Developer notes
 ===============
diff --git a/doc/parameter.rst b/doc/parameter.rst
index e7a2fd630..f19ddb462 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -184,7 +184,7 @@ Parameters for Tree Booster
   - The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
 
     - ``cpu_predictor``: Multicore CPU prediction algorithm.
-    - ``gpu_predictor``: Prediction using GPU. Default when ``tree_method`` is ``gpu_exact`` or ``gpu_hist``.
+    - ``gpu_predictor``: Prediction using GPU. Default when ``tree_method`` is ``gpu_hist``.
 
 * ``num_parallel_tree``, [default=1]
   - Number of parallel trees constructed during each iteration. This option is used to support boosted random forest.
diff --git a/doc/tutorials/external_memory.rst b/doc/tutorials/external_memory.rst
index c5f064fdc..3541b0d21 100644
--- a/doc/tutorials/external_memory.rst
+++ b/doc/tutorials/external_memory.rst
@@ -15,7 +15,7 @@ path to a cache file that XGBoost will use for external memory cache.
 
 .. note:: External memory is not available with GPU algorithms
 
-  External memory is not available when ``tree_method`` is set to ``gpu_exact`` or ``gpu_hist``.
+  External memory is not available when ``tree_method`` is set to ``gpu_hist``.
 
 The following code was extracted from `demo/guide-python/external_memory.py `_:
 
diff --git a/make/config.mk b/make/config.mk
index a2a48f1e9..fc4bb7ae5 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -75,9 +75,3 @@ CUB_PATH ?= cub
 
 # you can also add your own plugin like this
 #
 # XGB_PLUGINS += plugin/example/plugin.mk
-
-# plugin to build tree on GPUs using CUDA
-PLUGIN_UPDATER_GPU ?= OFF
-ifeq ($(PLUGIN_UPDATER_GPU),ON)
-  XGB_PLUGINS += plugin/updater_gpu/plugin.mk
-endif
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 6a5c0714c..7a911a4b8 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -146,13 +146,6 @@ void GBTree::ConfigureUpdaters(const std::map<std::string, std::string>& cfg) {
         "single updater grow_quantile_histmaker.";
       tparam_.updater_seq = "grow_quantile_histmaker";
       break;
-    case TreeMethod::kGPUExact:
-      this->AssertGPUSupport();
-      tparam_.updater_seq = "grow_gpu,prune";
-      if (cfg.find("predictor") == cfg.cend()) {
-        tparam_.predictor = "gpu_predictor";
-      }
-      break;
     case TreeMethod::kGPUHist:
       this->AssertGPUSupport();
       tparam_.updater_seq = "grow_gpu_hist";
diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h
index 91fef174b..fa6ede83f 100644
--- a/src/gbm/gbtree.h
+++ b/src/gbm/gbtree.h
@@ -30,7 +30,7 @@ namespace xgboost {
 
 enum class TreeMethod : int {
   kAuto = 0, kApprox = 1, kExact = 2, kHist = 3,
-  kGPUExact = 4, kGPUHist = 5
+  kGPUHist = 5
 };
 
 // boosting process types
@@ -88,7 +88,6 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
       .add_enum("approx", TreeMethod::kApprox)
       .add_enum("exact", TreeMethod::kExact)
      .add_enum("hist", TreeMethod::kHist)
-      .add_enum("gpu_exact", TreeMethod::kGPUExact)
       .add_enum("gpu_hist", TreeMethod::kGPUHist)
       .describe("Choice of tree construction method.");
   }
@@ -171,8 +170,7 @@ class GBTree : public GradientBooster {
   bool UseGPU() const override {
     return
         tparam_.predictor == "gpu_predictor" ||
-        tparam_.tree_method == TreeMethod::kGPUHist ||
-        tparam_.tree_method == TreeMethod::kGPUExact;
+        tparam_.tree_method == TreeMethod::kGPUHist;
   }
 
   void Load(dmlc::Stream* fi) override {
diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc
index 179c013fc..629306ffb 100644
---
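Two details in the gbtree hunks above are worth spelling out: the string names registered via add_enum map onto TreeMethod values, and kGPUHist keeps its old value 5 rather than being renumbered to 4, presumably so any configuration or serialized parameter that recorded the integer keeps its meaning. A standalone sketch of the string-to-enum-to-updater dispatch that ConfigureUpdaters performs (updater sequence strings taken from the hunks above; the rest is illustrative):

#include <iostream>
#include <map>
#include <string>

enum class TreeMethod : int {
  kAuto = 0, kApprox = 1, kExact = 2, kHist = 3, kGPUHist = 5  // 4 retired
};

std::string UpdaterSequence(TreeMethod m) {
  switch (m) {
    case TreeMethod::kApprox:  return "grow_histmaker,prune";
    case TreeMethod::kExact:   return "grow_colmaker,prune";
    case TreeMethod::kHist:    return "grow_quantile_histmaker";
    case TreeMethod::kGPUHist: return "grow_gpu_hist";
    default:                   return "";  // kAuto: decided later from the data
  }
}

int main() {
  const std::map<std::string, TreeMethod> names{
      {"auto", TreeMethod::kAuto},   {"approx", TreeMethod::kApprox},
      {"exact", TreeMethod::kExact}, {"hist", TreeMethod::kHist},
      {"gpu_hist", TreeMethod::kGPUHist}};
  std::cout << UpdaterSequence(names.at("gpu_hist")) << "\n";  // grow_gpu_hist
}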
a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -37,7 +37,6 @@ DMLC_REGISTRY_LINK_TAG(updater_quantile_hist); DMLC_REGISTRY_LINK_TAG(updater_histmaker); DMLC_REGISTRY_LINK_TAG(updater_sync); #ifdef XGBOOST_USE_CUDA -DMLC_REGISTRY_LINK_TAG(updater_gpu); DMLC_REGISTRY_LINK_TAG(updater_gpu_hist); #endif // XGBOOST_USE_CUDA } // namespace tree diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu deleted file mode 100644 index dd6ab3eec..000000000 --- a/src/tree/updater_gpu.cu +++ /dev/null @@ -1,844 +0,0 @@ -/*! - * Copyright 2017-2018 XGBoost contributors - */ -#include -#include -#include -#include -#include - -#include "../common/common.h" -#include "param.h" -#include "updater_gpu_common.cuh" - -namespace xgboost { -namespace tree { - -DMLC_REGISTRY_FILE_TAG(updater_gpu); - -template -XGBOOST_DEVICE float inline LossChangeMissing(const GradientPairT& scan, - const GradientPairT& missing, - const GradientPairT& parent_sum, - const float& parent_gain, - const GPUTrainingParam& param, - bool& missing_left_out) { // NOLINT - // Put gradients of missing values to left - float missing_left_loss = - DeviceCalcLossChange(param, scan + missing, parent_sum, parent_gain); - float missing_right_loss = - DeviceCalcLossChange(param, scan, parent_sum, parent_gain); - - if (missing_left_loss >= missing_right_loss) { - missing_left_out = true; - return missing_left_loss; - } else { - missing_left_out = false; - return missing_right_loss; - } -} - -/** - * @brief Absolute BFS order IDs to col-wise unique IDs based on user input - * @param tid the index of the element that this thread should access - * @param abs the array of absolute IDs - * @param colIds the array of column IDs for each element - * @param nodeStart the start of the node ID at this level - * @param nKeys number of nodes at this level. 
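LossChangeMissing above is the heart of the deleted exact GPU split search: the gradient mass of the missing-value bucket is tried on both sides of the split, and the better direction wins. A serial sketch of the same decision, using the standard XGBoost split gain G^2/(H + lambda) with the gamma penalty omitted (all names here are illustrative):

#include <cstdio>

struct GradStats { float g, h; };

float CalcGain(GradStats s, float lambda) { return s.g * s.g / (s.h + lambda); }

float LossChangeMissing(GradStats scan, GradStats missing, GradStats parent,
                        float parent_gain, float lambda, bool* missing_left) {
  GradStats left_with  {scan.g + missing.g, scan.h + missing.h};   // missing left
  GradStats right_wo   {parent.g - left_with.g, parent.h - left_with.h};
  GradStats left_wo    {scan.g, scan.h};                           // missing right
  GradStats right_with {parent.g - scan.g, parent.h - scan.h};

  float loss_left  = CalcGain(left_with, lambda) + CalcGain(right_wo, lambda)   - parent_gain;
  float loss_right = CalcGain(left_wo, lambda)   + CalcGain(right_with, lambda) - parent_gain;
  *missing_left = loss_left >= loss_right;  // ties send missing values left
  return *missing_left ? loss_left : loss_right;
}

int main() {
  bool left;
  GradStats scan{2.f, 3.f}, missing{1.f, 1.f}, parent{5.f, 10.f};
  float chg = LossChangeMissing(scan, missing, parent,
                                CalcGain(parent, 1.f), 1.f, &left);
  std::printf("gain change %.3f, missing goes %s\n", chg, left ? "left" : "right");
}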
- * @return the uniq key - */ -static HOST_DEV_INLINE NodeIdT Abs2UniqueKey(int tid, - common::Span abs, - common::Span colIds, - NodeIdT nodeStart, int nKeys) { - int a = abs[tid]; - if (a == kUnusedNode) return a; - return ((a - nodeStart) + (colIds[tid] * nKeys)); -} - -/** - * @struct Pair - * @brief Pair used for key basd scan operations on GradientPair - */ -struct Pair { - int key; - GradientPair value; -}; - -/** define a key that's not used at all in the entire boosting process */ -static const int kNoneKey = -100; - -/** - * @brief Allocate temporary buffers needed for scan operations - * @param tmpScans gradient buffer - * @param tmpKeys keys buffer - * @param size number of elements that will be scanned - */ -template -int ScanTempBufferSize(int size) { - int num_blocks = common::DivRoundUp(size, BLKDIM_L1L3); - return num_blocks; -} - -struct AddByKey { - template - HOST_DEV_INLINE T operator()(const T& first, const T& second) const { - T result; - if (first.key == second.key) { - result.key = first.key; - result.value = first.value + second.value; - } else { - result.key = second.key; - result.value = second.value; - } - return result; - } -}; - -/** - * @brief Gradient value getter function - * @param id the index into the vals or instIds array to which to fetch - * @param vals the gradient value buffer - * @param instIds instance index buffer - * @return the expected gradient value - */ -HOST_DEV_INLINE GradientPair Get(int id, - common::Span vals, - common::Span instIds) { - id = instIds[id]; - return vals[id]; -} - -template -__global__ void CubScanByKeyL1( - common::Span scans, - common::Span vals, - common::Span instIds, - common::Span mScans, - common::Span mKeys, - common::Span keys, - int nUniqKeys, - common::Span colIds, NodeIdT nodeStart, - const int size) { - Pair rootPair = {kNoneKey, GradientPair(0.f, 0.f)}; - int myKey; - GradientPair myValue; - using BlockScan = cub::BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - Pair threadData; - int tid = blockIdx.x * BLKDIM_L1L3 + threadIdx.x; - if (tid < size) { - myKey = Abs2UniqueKey(tid, keys, colIds, nodeStart, nUniqKeys); - myValue = Get(tid, vals, instIds); - } else { - myKey = kNoneKey; - myValue = {}; - } - threadData.key = myKey; - threadData.value = myValue; - // get previous key, especially needed for the last thread in this block - // in order to pass on the partial scan values. - // this statement MUST appear before the checks below! - // else, the result of this shuffle operation will be undefined -#if (__CUDACC_VER_MAJOR__ >= 9) - int previousKey = __shfl_up_sync(0xFFFFFFFF, myKey, 1); -#else - int previousKey = __shfl_up(myKey, 1); -#endif // (__CUDACC_VER_MAJOR__ >= 9) - // Collectively compute the block-wide exclusive prefix sum - BlockScan(temp_storage) - .ExclusiveScan(threadData, threadData, rootPair, AddByKey()); - if (tid < size) { - scans[tid] = threadData.value; - } else { - return; - } - if (threadIdx.x == BLKDIM_L1L3 - 1) { - threadData.value = - (myKey == previousKey) ? 
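CubScanByKeyL1 above computes, per thread block, a prefix scan that restarts whenever the key changes; the __shfl_up of the previous thread's key is what lets the last thread of a block decide whether its partial sum may carry into the next block. A serial reference for the contract (illustrative, not from the codebase): within each run of equal keys, out[i] is the sum of the values before i.

#include <cstdio>
#include <vector>

struct Grad { float g, h; };

std::vector<Grad> ExclusiveScanByKey(const std::vector<int>& keys,
                                     const std::vector<Grad>& vals) {
  std::vector<Grad> out(vals.size());
  Grad running{0.f, 0.f};
  for (size_t i = 0; i < vals.size(); ++i) {
    if (i == 0 || keys[i] != keys[i - 1]) running = {0.f, 0.f};  // new segment
    out[i] = running;
    running.g += vals[i].g;
    running.h += vals[i].h;
  }
  return out;
}

int main() {
  std::vector<int> keys{0, 0, 0, 1, 1};
  std::vector<Grad> vals{{1, 1}, {2, 1}, {3, 1}, {4, 1}, {5, 1}};
  for (const Grad& s : ExclusiveScanByKey(keys, vals))
    std::printf("(%.0f,%.0f) ", s.g, s.h);  // (0,0) (1,1) (3,2) (0,0) (4,1)
}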
threadData.value : GradientPair(0.0f, 0.0f); - mKeys[blockIdx.x] = myKey; - mScans[blockIdx.x] = threadData.value + myValue; - } -} - -template -__global__ void CubScanByKeyL2(common::Span mScans, - common::Span mKeys, int mLength) { - using BlockScan = cub::BlockScan; - Pair threadData; - __shared__ typename BlockScan::TempStorage temp_storage; - for (int i = threadIdx.x; i < mLength; i += BLKSIZE - 1) { - threadData.key = mKeys[i]; - threadData.value = mScans[i]; - BlockScan(temp_storage).InclusiveScan(threadData, threadData, AddByKey()); - mScans[i] = threadData.value; - __syncthreads(); - } -} - -template -__global__ void CubScanByKeyL3(common::Span sums, - common::Span scans, - common::Span vals, - common::Span instIds, - common::Span mScans, - common::Span mKeys, - common::Span keys, - int nUniqKeys, - common::Span colIds, NodeIdT nodeStart, - const int size) { - int relId = threadIdx.x; - int tid = (blockIdx.x * BLKDIM_L1L3) + relId; - // to avoid the following warning from nvcc: - // __shared__ memory variable with non-empty constructor or destructor - // (potential race between threads) - __shared__ char gradBuff[sizeof(GradientPair)]; - __shared__ int s_mKeys; - GradientPair* s_mScans = reinterpret_cast(gradBuff); - if (tid >= size) return; - // cache block-wide partial scan info - if (relId == 0) { - s_mKeys = (blockIdx.x > 0) ? mKeys[blockIdx.x - 1] : kNoneKey; - s_mScans[0] = (blockIdx.x > 0) ? mScans[blockIdx.x - 1] : GradientPair(); - } - int myKey = Abs2UniqueKey(tid, keys, colIds, nodeStart, nUniqKeys); - int previousKey = - tid == 0 ? kNoneKey - : Abs2UniqueKey(tid - 1, keys, colIds, nodeStart, nUniqKeys); - GradientPair my_value = scans[tid]; - __syncthreads(); - if (blockIdx.x > 0 && s_mKeys == previousKey) { - my_value += s_mScans[0]; - } - if (tid == size - 1) { - sums[previousKey] = my_value + Get(tid, vals, instIds); - } - if ((previousKey != myKey) && (previousKey >= 0)) { - sums[previousKey] = my_value; - my_value = GradientPair(0.0f, 0.0f); - } - scans[tid] = my_value; -} - -/** - * @brief Performs fused reduce and scan by key functionality. It is assumed - * that - * the keys occur contiguously! - * @param sums the output gradient reductions for each element performed - * key-wise - * @param scans the output gradient scans for each element performed key-wise - * @param vals the gradients evaluated for each observation. - * @param instIds instance ids for each element - * @param keys keys to be used to segment the reductions. They need not occur - * contiguously in contrast to scan_by_key. Currently, we need one key per - * value in the 'vals' array. 
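The three kernels compose into what the comments call the cub-pyramid: L1 produces per-block local scans plus one (key, partial sum) pair per block, L2 scans those partials by key, and L3 adds the previous block's partial as a carry-in wherever the block's leading keys continue the previous block's last key. A serial sketch of that composition, assuming (as the original does) that equal keys are contiguous:

#include <cstdio>
#include <vector>

int main() {
  const int kBlock = 4;
  std::vector<int>   keys{0, 0, 0, 0,  0, 0, 1, 1,  1, 1, 1, 1};
  std::vector<float> vals(keys.size(), 1.f);
  std::vector<float> scans(keys.size());

  // Phase 1: exclusive scan inside each block + per-block trailing partials.
  size_t num_blocks = (keys.size() + kBlock - 1) / kBlock;
  std::vector<float> block_sum(num_blocks);
  std::vector<int>   block_key(num_blocks);
  for (size_t b = 0; b < num_blocks; ++b) {
    float run = 0.f;
    for (size_t i = b * kBlock; i < (b + 1) * kBlock && i < keys.size(); ++i) {
      bool new_seg = (i > b * kBlock && keys[i] != keys[i - 1]);
      if (new_seg) run = 0.f;
      scans[i] = run;
      run += vals[i];
      block_key[b] = keys[i];                       // key of the last element
      block_sum[b] = new_seg ? vals[i] : scans[i] + vals[i];
    }
  }
  // Phase 2: inclusive scan-by-key over the block partials.
  for (size_t b = 1; b < num_blocks; ++b)
    if (block_key[b] == block_key[b - 1]) block_sum[b] += block_sum[b - 1];
  // Phase 3: carry-in for leading elements whose key continues the prior block.
  for (size_t b = 1; b < num_blocks; ++b)
    for (size_t i = b * kBlock; i < (b + 1) * kBlock && i < keys.size(); ++i) {
      if (keys[i] != block_key[b - 1]) break;       // segment ended at block edge
      scans[i] += block_sum[b - 1];
    }
  for (float s : scans) std::printf("%.0f ", s);    // 0 1 2 3 4 5 0 1 2 3 4 5
}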
- * @param size number of elements in the 'vals' array - * @param nUniqKeys max number of uniq keys found per column - * @param nCols number of columns - * @param tmpScans temporary scan buffer needed for cub-pyramid algo - * @param tmpKeys temporary key buffer needed for cub-pyramid algo - * @param colIds column indices for each element in the array - * @param nodeStart index of the leftmost node in the current level - */ -template -void ReduceScanByKey(common::Span sums, - common::Span scans, - common::Span vals, - common::Span instIds, - common::Span keys, - int size, int nUniqKeys, int nCols, - common::Span tmpScans, - common::Span tmpKeys, - common::Span colIds, NodeIdT nodeStart) { - int nBlks = common::DivRoundUp(size, BLKDIM_L1L3); - cudaMemset(sums.data(), 0, nUniqKeys * nCols * sizeof(GradientPair)); - CubScanByKeyL1 - <<>>(scans, vals, instIds, tmpScans, tmpKeys, keys, - nUniqKeys, colIds, nodeStart, size); - CubScanByKeyL2<<<1, BLKDIM_L2>>>(tmpScans, tmpKeys, nBlks); - CubScanByKeyL3 - <<>>(sums, scans, vals, instIds, tmpScans, tmpKeys, - keys, nUniqKeys, colIds, nodeStart, size); -} - -/** - * @struct ExactSplitCandidate - * @brief Abstraction of a possible split in the decision tree - */ -struct ExactSplitCandidate { - /** the optimal gain score for this node */ - float score; - /** index where to split in the DMatrix */ - int index; - - HOST_DEV_INLINE ExactSplitCandidate() : score{-FLT_MAX}, index{INT_MAX} {} - - /** - * @brief Whether the split info is valid to be used to create a new child - * @param minSplitLoss minimum score above which decision to split is made - * @return true if splittable, else false - */ - HOST_DEV_INLINE bool IsSplittable(float minSplitLoss) const { - return ((score >= minSplitLoss) && (index != INT_MAX)); - } -}; - -/** - * @enum ArgMaxByKeyAlgo best_split_evaluation.cuh - * @brief Help decide which algorithm to use for multi-argmax operation - */ -enum ArgMaxByKeyAlgo { - /** simplest, use gmem-atomics for all updates */ - kAbkGmem = 0, - /** use smem-atomics for updates (when number of keys are less) */ - kAbkSmem -}; - -/** max depth until which to use shared mem based atomics for argmax */ -static const int kMaxAbkLevels = 3; - -HOST_DEV_INLINE ExactSplitCandidate MaxSplit(ExactSplitCandidate a, - ExactSplitCandidate b) { - ExactSplitCandidate out; - if (a.score < b.score) { - out.score = b.score; - out.index = b.index; - } else if (a.score == b.score) { - out.score = a.score; - out.index = (a.index < b.index) ? a.index : b.index; - } else { - out.score = a.score; - out.index = a.index; - } - return out; -} - -DEV_INLINE void AtomicArgMax(ExactSplitCandidate* address, - ExactSplitCandidate val) { - unsigned long long* intAddress = reinterpret_cast(address); // NOLINT - unsigned long long old = *intAddress; // NOLINT - unsigned long long assumed = old; // NOLINT - do { - assumed = old; - ExactSplitCandidate res = - MaxSplit(val, *reinterpret_cast(&assumed)); - old = atomicCAS(intAddress, assumed, *reinterpret_cast(&res)); - } while (assumed != old); -} - -DEV_INLINE void ArgMaxWithAtomics( - int id, - common::Span nodeSplits, - common::Span gradScans, - common::Span gradSums, - common::Span vals, - common::Span colIds, - common::Span nodeAssigns, - common::Span nodes, int nUniqKeys, - NodeIdT nodeStart, int len, - const GPUTrainingParam& param) { - int nodeId = nodeAssigns[id]; - // @todo: this is really a bad check! 
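AtomicArgMax below implements a lock-free argmax by reinterpreting the 8-byte candidate struct as an unsigned long long and spinning on atomicCAS until the stored value is at least as good as the proposed one. A self-contained CUDA sketch of the same pattern (names here are illustrative):

#include <cfloat>
#include <climits>
#include <cstdio>
#include <cuda_runtime.h>

struct Candidate { float score; int index; };   // 8 bytes, CAS-able as one word

__host__ __device__ Candidate Better(Candidate a, Candidate b) {
  if (a.score != b.score) return a.score > b.score ? a : b;
  return a.index < b.index ? a : b;              // tie-break on smaller index
}

__device__ void AtomicArgMax(Candidate* address, Candidate val) {
  auto* as_ull = reinterpret_cast<unsigned long long*>(address);
  unsigned long long old = *as_ull, assumed;
  do {
    assumed = old;
    Candidate cur  = *reinterpret_cast<Candidate*>(&assumed);
    Candidate best = Better(val, cur);
    old = atomicCAS(as_ull, assumed, *reinterpret_cast<unsigned long long*>(&best));
  } while (assumed != old);                      // retry if someone raced us
}

__global__ void FindBest(const float* scores, int n, Candidate* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) AtomicArgMax(out, Candidate{scores[i], i});
}

int main() {
  float h_scores[] = {0.1f, 0.9f, 0.4f};
  float* d_scores; Candidate* d_best;
  cudaMalloc(&d_scores, sizeof(h_scores));
  cudaMalloc(&d_best, sizeof(Candidate));
  Candidate init{-FLT_MAX, INT_MAX};
  cudaMemcpy(d_scores, h_scores, sizeof(h_scores), cudaMemcpyHostToDevice);
  cudaMemcpy(d_best, &init, sizeof(init), cudaMemcpyHostToDevice);
  FindBest<<<1, 32>>>(d_scores, 3, d_best);
  Candidate best;
  cudaMemcpy(&best, d_best, sizeof(best), cudaMemcpyDeviceToHost);
  std::printf("best score %.1f at index %d\n", best.score, best.index);  // 0.9, 1
}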
but will be fixed when we move - // to key-based reduction - if ((id == 0) || - !((nodeId == nodeAssigns[id - 1]) && (colIds[id] == colIds[id - 1]) && - (vals[id] == vals[id - 1]))) { - if (nodeId != kUnusedNode) { - int sumId = Abs2UniqueKey(id, nodeAssigns, colIds, nodeStart, nUniqKeys); - GradientPair colSum = gradSums[sumId]; - int uid = nodeId - nodeStart; - DeviceNodeStats node_stat = nodes[nodeId]; - GradientPair parentSum = node_stat.sum_gradients; - float parentGain = node_stat.root_gain; - bool tmp; - ExactSplitCandidate s; - GradientPair missing = parentSum - colSum; - s.score = LossChangeMissing(gradScans[id], missing, parentSum, parentGain, - param, tmp); - s.index = id; - AtomicArgMax(&nodeSplits[uid], s); - } // end if nodeId != UNUSED_NODE - } // end if id == 0 ... -} - -__global__ void AtomicArgMaxByKeyGmem( - common::Span nodeSplits, - common::Span gradScans, - common::Span gradSums, - common::Span vals, - common::Span colIds, - common::Span nodeAssigns, - common::Span nodes, - int nUniqKeys, - NodeIdT nodeStart, - int len, - const TrainParam param) { - int id = threadIdx.x + (blockIdx.x * blockDim.x); - const int stride = blockDim.x * gridDim.x; - for (; id < len; id += stride) { - ArgMaxWithAtomics(id, nodeSplits, gradScans, gradSums, vals, colIds, - nodeAssigns, nodes, nUniqKeys, nodeStart, len, - GPUTrainingParam(param)); - } -} - -__global__ void AtomicArgMaxByKeySmem( - common::Span nodeSplits, - common::Span gradScans, - common::Span gradSums, - common::Span vals, - common::Span colIds, - common::Span nodeAssigns, - common::Span nodes, - int nUniqKeys, NodeIdT nodeStart, int len, const GPUTrainingParam param) { - extern __shared__ char sArr[]; - common::Span sNodeSplits = - common::Span( - reinterpret_cast(sArr), - static_cast::index_type>( - nUniqKeys * sizeof(ExactSplitCandidate))); - int tid = threadIdx.x; - ExactSplitCandidate defVal; - - for (int i = tid; i < nUniqKeys; i += blockDim.x) { - sNodeSplits[i] = defVal; - } - __syncthreads(); - int id = tid + (blockIdx.x * blockDim.x); - const int stride = blockDim.x * gridDim.x; - for (; id < len; id += stride) { - ArgMaxWithAtomics(id, sNodeSplits, gradScans, gradSums, vals, colIds, - nodeAssigns, nodes, nUniqKeys, nodeStart, len, param); - } - __syncthreads(); - for (int i = tid; i < nUniqKeys; i += blockDim.x) { - ExactSplitCandidate s = sNodeSplits[i]; - AtomicArgMax(&nodeSplits[i], s); - } -} - -/** - * @brief Performs argmax_by_key functionality but for cases when keys need not - * occur contiguously - * @param nodeSplits will contain information on best split for each node - * @param gradScans exclusive sum on sorted segments for each col - * @param gradSums gradient sum for each column in DMatrix based on to node-ids - * @param vals feature values - * @param colIds column index for each element in the feature values array - * @param nodeAssigns node-id assignments to each element in DMatrix - * @param nodes pointer to all nodes for this tree in BFS order - * @param nUniqKeys number of unique node-ids in this level - * @param nodeStart start index of the node-ids in this level - * @param len number of elements - * @param param training parameters - * @param algo which algorithm to use for argmax_by_key - */ -template -void ArgMaxByKey(common::Span nodeSplits, - common::Span gradScans, - common::Span gradSums, - common::Span vals, - common::Span colIds, - common::Span nodeAssigns, - common::Span nodes, - int nUniqKeys, - NodeIdT nodeStart, int len, const TrainParam param, - ArgMaxByKeyAlgo algo, - GPUSet 
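AtomicArgMaxByKeySmem above privatizes the candidate table in shared memory so that the hot atomics stay on-chip and only one flush per key per block hits global memory; that only pays off while the key count (2^level nodes) is small, which is what the kMaxAbkLevels cutoff encodes. A compact CUDA sketch of the pattern, reusing Candidate and AtomicArgMax from the previous sketch (so not fully self-contained); out[] must be pre-initialized to {-FLT_MAX, INT_MAX} by the caller:

__global__ void ArgMaxByKeySmem(const float* scores, const int* keys, int n,
                                int num_keys, Candidate* out) {
  extern __shared__ Candidate table[];               // num_keys entries
  for (int k = threadIdx.x; k < num_keys; k += blockDim.x)
    table[k] = Candidate{-FLT_MAX, INT_MAX};         // per-block private table
  __syncthreads();
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x)                  // grid-stride loop
    AtomicArgMax(&table[keys[i]], Candidate{scores[i], i});
  __syncthreads();
  for (int k = threadIdx.x; k < num_keys; k += blockDim.x)
    AtomicArgMax(&out[k], table[k]);                 // one global flush per key
}
// Launch: ArgMaxByKeySmem<<<blocks, threads, num_keys * sizeof(Candidate)>>>(...);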
const& devices) { - dh::FillConst( - *(devices.begin()), nodeSplits.data(), nUniqKeys, - ExactSplitCandidate()); - int nBlks = common::DivRoundUp(len, ITEMS_PER_THREAD * BLKDIM); - switch (algo) { - case kAbkGmem: - AtomicArgMaxByKeyGmem<<>>( - nodeSplits, gradScans, gradSums, vals, colIds, nodeAssigns, nodes, - nUniqKeys, nodeStart, len, param); - break; - case kAbkSmem: - AtomicArgMaxByKeySmem<<>>( - nodeSplits, gradScans, gradSums, vals, colIds, nodeAssigns, nodes, - nUniqKeys, nodeStart, len, GPUTrainingParam(param)); - break; - default: - throw std::runtime_error("argMaxByKey: Bad algo passed!"); - } -} - -__global__ void AssignColIds(int* colIds, const int* colOffsets) { - int myId = blockIdx.x; - int start = colOffsets[myId]; - int end = colOffsets[myId + 1]; - for (int id = start + threadIdx.x; id < end; id += blockDim.x) { - colIds[id] = myId; - } -} - -__global__ void FillDefaultNodeIds(NodeIdT* nodeIdsPerInst, - const DeviceNodeStats* nodes, int n_rows) { - int id = threadIdx.x + (blockIdx.x * blockDim.x); - if (id >= n_rows) { - return; - } - // if this element belongs to none of the currently active node-id's - NodeIdT nId = nodeIdsPerInst[id]; - if (nId == kUnusedNode) { - return; - } - const DeviceNodeStats n = nodes[nId]; - NodeIdT result; - if (n.IsLeaf() || n.IsUnused()) { - result = kUnusedNode; - } else if (n.dir == kLeftDir) { - result = (2 * n.idx) + 1; - } else { - result = (2 * n.idx) + 2; - } - nodeIdsPerInst[id] = result; -} - -__global__ void AssignNodeIds(NodeIdT* nodeIdsPerInst, int* nodeLocations, - const NodeIdT* nodeIds, const int* instId, - const DeviceNodeStats* nodes, - const int* colOffsets, const float* vals, - int nVals, int nCols) { - int id = threadIdx.x + (blockIdx.x * blockDim.x); - const int stride = blockDim.x * gridDim.x; - for (; id < nVals; id += stride) { - // fusing generation of indices for node locations - nodeLocations[id] = id; - // using nodeIds here since the previous kernel would have updated - // the nodeIdsPerInst with all default assignments - int nId = nodeIds[id]; - // if this element belongs to none of the currently active node-id's - if (nId != kUnusedNode) { - const DeviceNodeStats n = nodes[nId]; - int colId = n.fidx; - // printf("nid=%d colId=%d id=%d\n", nId, colId, id); - int start = colOffsets[colId]; - int end = colOffsets[colId + 1]; - // @todo: too much wasteful threads!! - if ((id >= start) && (id < end) && !(n.IsLeaf() || n.IsUnused())) { - NodeIdT result = (2 * n.idx) + 1 + (vals[id] >= n.fvalue); - nodeIdsPerInst[instId[id]] = result; - } - } - } -} - -__global__ void MarkLeavesKernel(DeviceNodeStats* nodes, int len) { - int id = (blockIdx.x * blockDim.x) + threadIdx.x; - if ((id < len) && !nodes[id].IsUnused()) { - int lid = (id << 1) + 1; - int rid = (id << 1) + 2; - if ((lid >= len) || (rid >= len)) { - nodes[id].root_gain = -FLT_MAX; // bottom-most nodes - } else if (nodes[lid].IsUnused() && nodes[rid].IsUnused()) { - nodes[id].root_gain = -FLT_MAX; // unused child nodes - } - } -} - -class GPUMaker : public TreeUpdater { - protected: - TrainParam param_; - /** whether we have initialized memory already (so as not to repeat!) 
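FillDefaultNodeIds and AssignNodeIds above rely on the dense, implicitly indexed tree layout: nodes live in a flat array in BFS order, the children of node i sit at 2i+1 and 2i+2, and a tree of max_depth d needs 2^(d+1)-1 slots (maxNodes_) with at most 2^d leaves (maxLeaves_). A small sketch of that arithmetic, with illustrative helper names:

#include <cstdio>

inline int LeftChild(int i)      { return 2 * i + 1; }
inline int RightChild(int i)     { return 2 * i + 2; }
inline int LevelStart(int level) { return (1 << level) - 1; }  // the nodeStart above
inline int LevelSize(int level)  { return 1 << level; }        // the nNodes above

int main() {
  int max_depth = 3;
  std::printf("maxNodes=%d maxLeaves=%d\n",
              (1 << (max_depth + 1)) - 1, 1 << max_depth);     // 15, 8
  for (int level = 0; level <= max_depth; ++level)
    std::printf("level %d occupies [%d, %d)\n", level, LevelStart(level),
                LevelStart(level) + LevelSize(level));
  std::printf("children of node 2: %d, %d\n", LeftChild(2), RightChild(2));  // 5, 6
}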
*/ - bool allocated_; - /** feature values stored in column-major compressed format */ - dh::DoubleBuffer vals_; - common::Span vals_cached_; - /** corresponding instance id's of these featutre values */ - dh::DoubleBuffer instIds_; - common::Span inst_ids_cached_; - /** column offsets for these feature values */ - common::Span colOffsets_; - common::Span gradsInst_; - dh::DoubleBuffer nodeAssigns_; - dh::DoubleBuffer nodeLocations_; - common::Span nodes_; - common::Span node_assigns_per_inst_; - common::Span gradsums_; - common::Span gradscans_; - common::Span nodeSplits_; - int n_vals_; - int n_rows_; - int n_cols_; - int maxNodes_; - int maxLeaves_; - - // devices are only used for sharding the HostDeviceVector passed as a parameter; - // the algorithm works with a single GPU only - GPUSet devices_; - - dh::CubMemory tmp_mem_; - common::Span tmpScanGradBuff_; - common::Span tmp_scan_key_buff_; - common::Span colIds_; - dh::BulkAllocator ba_; - - public: - GPUMaker() : allocated_{false} {} - ~GPUMaker() override = default; - - char const* Name() const override { - return "gpu_exact"; - } - - void Configure(const Args &args) override { - param_.InitAllowUnknown(args); - maxNodes_ = (1 << (param_.max_depth + 1)) - 1; - maxLeaves_ = 1 << param_.max_depth; - - devices_ = GPUSet::All(tparam_->gpu_id, tparam_->n_gpus); - } - - void Update(HostDeviceVector* gpair, DMatrix* dmat, - const std::vector& trees) override { - // rescale learning rate according to size of trees - float lr = param_.learning_rate; - param_.learning_rate = lr / trees.size(); - - gpair->Shard(devices_); - - try { - // build tree - for (auto tree : trees) { - UpdateTree(gpair, dmat, tree); - } - } catch (const std::exception& e) { - LOG(FATAL) << "grow_gpu exception: " << e.what() << std::endl; - } - param_.learning_rate = lr; - } - /// @note: Update should be only after Init!! - void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, - RegTree* hTree) { - if (!allocated_) { - SetupOneTimeData(dmat); - } - for (int i = 0; i < param_.max_depth; ++i) { - if (i == 0) { - // make sure to start on a fresh tree with sorted values! 
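Update() above rescales the learning rate by 1/trees.size() before growing the trees of one boosting round. The point, sketched below, is that when num_parallel_tree trees are grown per round (boosted random forests, as noted in the doc/parameter.rst hunk earlier) their shrunken leaf values are summed, so each tree must get lr/N for the round's effective step to equal lr:

#include <cstdio>

int main() {
  float lr = 0.3f;
  int num_parallel_tree = 4;                   // trees grown per boosting round
  float per_tree_lr = lr / num_parallel_tree;  // the lr / trees.size() rescale
  float leaf_value = 0.5f;  // suppose every tree predicts the same leaf value
  std::printf("one tree: %.3f, forest of 4: %.3f\n",
              lr * leaf_value,
              num_parallel_tree * per_tree_lr * leaf_value);  // both 0.150
}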
- dh::CopyDeviceSpan(vals_.CurrentSpan(), vals_cached_); - dh::CopyDeviceSpan(instIds_.CurrentSpan(), inst_ids_cached_); - TransferGrads(gpair); - } - int nNodes = 1 << i; - NodeIdT nodeStart = nNodes - 1; - InitNodeData(i, nodeStart, nNodes); - FindSplit(i, nodeStart, nNodes); - } - // mark all the used nodes with unused children as leaf nodes - MarkLeaves(); - Dense2SparseTree(hTree, nodes_, param_); - } - - void Split2Node(int nNodes, NodeIdT nodeStart) { - auto d_nodes = nodes_; - auto d_gradScans = gradscans_; - auto d_gradsums = gradsums_; - auto d_nodeAssigns = nodeAssigns_.CurrentSpan(); - auto d_colIds = colIds_; - auto d_vals = vals_.Current(); - auto d_nodeSplits = nodeSplits_.data(); - int nUniqKeys = nNodes; - float min_split_loss = param_.min_split_loss; - auto gpu_param = GPUTrainingParam(param_); - - dh::LaunchN(*(devices_.begin()), nNodes, [=] __device__(int uid) { - int absNodeId = uid + nodeStart; - ExactSplitCandidate s = d_nodeSplits[uid]; - if (s.IsSplittable(min_split_loss)) { - int idx = s.index; - int nodeInstId = - Abs2UniqueKey(idx, d_nodeAssigns, d_colIds, nodeStart, nUniqKeys); - bool missingLeft = true; - const DeviceNodeStats& n = d_nodes[absNodeId]; - GradientPair gradScan = d_gradScans[idx]; - GradientPair gradSum = d_gradsums[nodeInstId]; - float thresh = d_vals[idx]; - int colId = d_colIds[idx]; - // get the default direction for the current node - GradientPair missing = n.sum_gradients - gradSum; - LossChangeMissing(gradScan, missing, n.sum_gradients, n.root_gain, - gpu_param, missingLeft); - // get the score/weight/id/gradSum for left and right child nodes - GradientPair lGradSum = missingLeft ? gradScan + missing : gradScan; - GradientPair rGradSum = n.sum_gradients - lGradSum; - - // Create children - d_nodes[LeftChildNodeIdx(absNodeId)] = - DeviceNodeStats(lGradSum, LeftChildNodeIdx(absNodeId), gpu_param); - d_nodes[RightChildNodeIdx(absNodeId)] = - DeviceNodeStats(rGradSum, RightChildNodeIdx(absNodeId), gpu_param); - // Set split for parent - d_nodes[absNodeId].SetSplit(thresh, colId, - missingLeft ? kLeftDir : kRightDir, lGradSum, - rGradSum); - } else { - // cannot be split further, so this node is a leaf! - d_nodes[absNodeId].root_gain = -FLT_MAX; - } - }); - } - - void FindSplit(int level, NodeIdT nodeStart, int nNodes) { - ReduceScanByKey(gradsums_, gradscans_, gradsInst_, - instIds_.CurrentSpan(), nodeAssigns_.CurrentSpan(), n_vals_, nNodes, - n_cols_, tmpScanGradBuff_, tmp_scan_key_buff_, - colIds_, nodeStart); - auto devices = GPUSet::All(tparam_->gpu_id, tparam_->n_gpus); - ArgMaxByKey(nodeSplits_, gradscans_, gradsums_, - vals_.CurrentSpan(), colIds_, nodeAssigns_.CurrentSpan(), - nodes_, nNodes, nodeStart, n_vals_, param_, - level <= kMaxAbkLevels ? 
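Split2Node above does the per-node bookkeeping once the winning split index is known: the left child's gradient sum is the exclusive scan at that index, plus the missing bucket if the re-derived default direction is left, and the right child is whatever remains of the parent. A CPU sketch of exactly that arithmetic (values are made up):

#include <cstdio>

struct Grad { float g, h; };
Grad operator+(Grad a, Grad b) { return {a.g + b.g, a.h + b.h}; }
Grad operator-(Grad a, Grad b) { return {a.g - b.g, a.h - b.h}; }

int main() {
  Grad parent{10.f, 20.f};          // sum over all rows reaching the node
  Grad col_sum{7.f, 15.f};          // sum over rows with a value in the split column
  Grad scan{3.f, 6.f};              // exclusive scan at the winning split index
  Grad missing = parent - col_sum;  // rows with no entry in this column
  bool missing_left = true;         // re-derived via LossChangeMissing

  Grad left  = missing_left ? scan + missing : scan;
  Grad right = parent - left;
  std::printf("left (%.0f,%.0f) right (%.0f,%.0f)\n",
              left.g, left.h, right.g, right.h);  // left (6,11) right (4,9)
}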
kAbkSmem : kAbkGmem, - devices); - Split2Node(nNodes, nodeStart); - } - - void AllocateAllData(int offsetSize) { - int tmpBuffSize = ScanTempBufferSize(n_vals_); - ba_.Allocate(*(devices_.begin()), &vals_, n_vals_, - &vals_cached_, n_vals_, &instIds_, n_vals_, &inst_ids_cached_, n_vals_, - &colOffsets_, offsetSize, &gradsInst_, n_rows_, &nodeAssigns_, n_vals_, - &nodeLocations_, n_vals_, &nodes_, maxNodes_, &node_assigns_per_inst_, - n_rows_, &gradsums_, maxLeaves_ * n_cols_, &gradscans_, n_vals_, - &nodeSplits_, maxLeaves_, &tmpScanGradBuff_, tmpBuffSize, - &tmp_scan_key_buff_, tmpBuffSize, &colIds_, n_vals_); - } - - void SetupOneTimeData(DMatrix* dmat) { - if (!dmat->SingleColBlock()) { - LOG(FATAL) << "exact::GPUBuilder - must have 1 column block"; - } - std::vector fval; - std::vector fId; - std::vector offset; - ConvertToCsc(dmat, &fval, &fId, &offset); - AllocateAllData(static_cast(offset.size())); - TransferAndSortData(fval, fId, offset); - allocated_ = true; - } - - void ConvertToCsc(DMatrix* dmat, std::vector* fval, - std::vector* fId, std::vector* offset) { - const MetaInfo& info = dmat->Info(); - CHECK(info.num_col_ < std::numeric_limits::max()); - CHECK(info.num_row_ < std::numeric_limits::max()); - n_rows_ = static_cast(info.num_row_); - n_cols_ = static_cast(info.num_col_); - offset->reserve(n_cols_ + 1); - offset->push_back(0); - fval->reserve(n_cols_ * n_rows_); - fId->reserve(n_cols_ * n_rows_); - // in case you end up with a DMatrix having no column access - // then make sure to enable that before copying the data! - for (const auto& batch : dmat->GetBatches()) { - for (int i = 0; i < batch.Size(); i++) { - auto col = batch[i]; - for (const Entry& e : col) { - int inst_id = static_cast(e.index); - fval->push_back(e.fvalue); - fId->push_back(inst_id); - } - offset->push_back(static_cast(fval->size())); - } - } - CHECK(fval->size() < std::numeric_limits::max()); - n_vals_ = static_cast(fval->size()); - } - - void TransferAndSortData(const std::vector& fval, - const std::vector& fId, - const std::vector& offset) { - dh::CopyVectorToDeviceSpan(vals_.CurrentSpan(), fval); - dh::CopyVectorToDeviceSpan(instIds_.CurrentSpan(), fId); - dh::CopyVectorToDeviceSpan(colOffsets_, offset); - dh::SegmentedSort(&tmp_mem_, &vals_, &instIds_, n_vals_, n_cols_, - colOffsets_); - dh::CopyDeviceSpan(vals_cached_, vals_.CurrentSpan()); - dh::CopyDeviceSpan(inst_ids_cached_, instIds_.CurrentSpan()); - AssignColIds<<>>(colIds_.data(), colOffsets_.data()); - } - - void TransferGrads(HostDeviceVector* gpair) { - gpair->GatherTo( - thrust::device_pointer_cast(gradsInst_.data()), - thrust::device_pointer_cast(gradsInst_.data() + gradsInst_.size())); - // evaluate the full-grad reduction for the root node - dh::SumReduction(tmp_mem_, gradsInst_, gradsums_, n_rows_); - } - - void InitNodeData(int level, NodeIdT nodeStart, int nNodes) { - // all instances belong to root node at the beginning! - if (level == 0) { - thrust::fill(thrust::device_pointer_cast(nodes_.data()), - thrust::device_pointer_cast(nodes_.data() + nodes_.size()), - DeviceNodeStats()); - thrust::fill(thrust::device_pointer_cast(nodeAssigns_.Current()), - thrust::device_pointer_cast(nodeAssigns_.Current() + - nodeAssigns_.Size()), - 0); - thrust::fill(thrust::device_pointer_cast(node_assigns_per_inst_.data()), - thrust::device_pointer_cast(node_assigns_per_inst_.data() + - node_assigns_per_inst_.size()), - 0); - // for root node, just update the gradient/score/weight/id info - // before splitting it! 
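ConvertToCsc above builds the column-major layout the updater works on: per-column runs of (feature value, row id), with offsets[c]..offsets[c+1] delimiting column c; each column is later value-sorted on the device by TransferAndSortData. A small CPU sketch of that layout (data and names are illustrative):

#include <cstdio>
#include <vector>

int main() {
  // 3 rows x 2 cols, with a missing entry at (row 1, col 0).
  struct Entry { int row, col; float value; };
  std::vector<Entry> cells{{0,0,0.5f}, {0,1,2.0f}, {1,1,1.0f}, {2,0,0.7f}, {2,1,3.0f}};

  int n_cols = 2;
  std::vector<float> fval;          // feature values, column-major
  std::vector<int>   inst_id;       // owning row of each value
  std::vector<int>   offsets{0};    // column boundaries
  for (int c = 0; c < n_cols; ++c) {          // one pass per column, for clarity
    for (const Entry& e : cells)
      if (e.col == c) { fval.push_back(e.value); inst_id.push_back(e.row); }
    offsets.push_back(static_cast<int>(fval.size()));
  }
  for (int c = 0; c < n_cols; ++c)
    for (int i = offsets[c]; i < offsets[c + 1]; ++i)
      std::printf("col %d: row %d val %.1f\n", c, inst_id[i], fval[i]);
}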
Currently all data is on GPU, hence this - // stupid little kernel - auto d_nodes = nodes_; - auto d_sums = gradsums_; - auto gpu_params = GPUTrainingParam(param_); - dh::LaunchN(*(devices_.begin()), 1, [=] __device__(int idx) { - d_nodes[0] = DeviceNodeStats(d_sums[0], 0, gpu_params); - }); - } else { - const int BlkDim = 256; - const int ItemsPerThread = 4; - // assign default node ids first - int nBlks = common::DivRoundUp(n_rows_, BlkDim); - FillDefaultNodeIds<<>>(node_assigns_per_inst_.data(), - nodes_.data(), n_rows_); - // evaluate the correct child indices of non-missing values next - nBlks = common::DivRoundUp(n_vals_, BlkDim * ItemsPerThread); - AssignNodeIds<<>>( - node_assigns_per_inst_.data(), nodeLocations_.Current(), - nodeAssigns_.Current(), instIds_.Current(), nodes_.data(), - colOffsets_.data(), vals_.Current(), n_vals_, n_cols_); - // gather the node assignments across all other columns too - dh::Gather(*(devices_.begin()), nodeAssigns_.Current(), - node_assigns_per_inst_.data(), instIds_.Current(), n_vals_); - SortKeys(level); - } - } - - void SortKeys(int level) { - // segmented-sort the arrays based on node-id's - // but we don't need more than level+1 bits for sorting! - SegmentedSort(&tmp_mem_, &nodeAssigns_, &nodeLocations_, n_vals_, n_cols_, - colOffsets_, 0, level + 1); - dh::Gather(*(devices_.begin()), vals_.other(), - vals_.Current(), instIds_.other(), instIds_.Current(), - nodeLocations_.Current(), n_vals_); - vals_.buff.selector ^= 1; - instIds_.buff.selector ^= 1; - } - - void MarkLeaves() { - const int BlkDim = 128; - int nBlks = common::DivRoundUp(maxNodes_, BlkDim); - MarkLeavesKernel<<>>(nodes_.data(), maxNodes_); - } -}; - -XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu") - .describe("Grow tree with GPU.") - .set_body([]() { - LOG(WARNING) << "The gpu_exact tree method is deprecated and may be " - "removed in a future version."; - return new GPUMaker(); - }); - -} // namespace tree -} // namespace xgboost diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index df9fb44c1..10ba161dd 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -36,10 +36,6 @@ TEST(GBTree, SelectTreeMethod) { ASSERT_EQ(tparam.updater_seq, "grow_quantile_histmaker"); #ifdef XGBOOST_USE_CUDA generic_param.InitAllowUnknown(std::vector{Arg{"n_gpus", "1"}}); - gbtree.ConfigureWithKnownData({Arg("tree_method", "gpu_exact"), - Arg("num_feature", n_feat)}, p_dmat); - ASSERT_EQ(tparam.updater_seq, "grow_gpu,prune"); - ASSERT_EQ(tparam.predictor, "gpu_predictor"); gbtree.ConfigureWithKnownData({Arg("tree_method", "gpu_hist"), Arg("num_feature", n_feat)}, p_dmat); ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist"); diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 46d7bc738..6cb910903 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -171,13 +171,6 @@ TEST(Learner, GPUConfiguration) { ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); ASSERT_EQ(learner->GetGenericParameter().n_gpus, 1); } - { - std::unique_ptr learner {Learner::Create(mat)}; - learner->SetParams({Arg{"tree_method", "gpu_exact"}}); - learner->UpdateOneIter(0, p_dmat.get()); - ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); - ASSERT_EQ(learner->GetGenericParameter().n_gpus, 1); - } { std::unique_ptr learner {Learner::Create(mat)}; learner->SetParams({Arg{"tree_method", "gpu_hist"}}); diff --git a/tests/cpp/tree/test_gpu_exact.cu b/tests/cpp/tree/test_gpu_exact.cu deleted file mode 100644 index cacc20311..000000000 --- 
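SortKeys above exploits the fact that node ids at tree level L lie in [2^L - 1, 2^(L+1) - 1), so only level+1 low bits ever differ and the radix sort can skip the remaining passes. A self-contained sketch of that limited-bit trick using CUB's documented begin_bit/end_bit arguments (the segmented variant in the deleted code works the same way):

#include <cstdio>
#include <cub/cub.cuh>

int main() {
  int h_keys[] = {6, 3, 5, 4};      // node ids at level 2: values in [3, 7)
  int h_vals[] = {0, 1, 2, 3};
  int *d_ki, *d_ko, *d_vi, *d_vo;
  cudaMalloc(&d_ki, sizeof(h_keys)); cudaMalloc(&d_ko, sizeof(h_keys));
  cudaMalloc(&d_vi, sizeof(h_vals)); cudaMalloc(&d_vo, sizeof(h_vals));
  cudaMemcpy(d_ki, h_keys, sizeof(h_keys), cudaMemcpyHostToDevice);
  cudaMemcpy(d_vi, h_vals, sizeof(h_vals), cudaMemcpyHostToDevice);

  void* d_temp = nullptr; size_t temp_bytes = 0;
  int level = 2;
  // First call sizes the temp buffer; begin_bit=0, end_bit=level+1 restricts
  // the sort to the low 3 bits, enough to order ids 3..6.
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_ki, d_ko, d_vi, d_vo,
                                  4, 0, level + 1);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_ki, d_ko, d_vi, d_vo,
                                  4, 0, level + 1);
  int out[4];
  cudaMemcpy(out, d_ko, sizeof(out), cudaMemcpyDeviceToHost);
  std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // 3 4 5 6
}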
a/tests/cpp/tree/test_gpu_exact.cu
+++ /dev/null
@@ -1,49 +0,0 @@
-#include <gtest/gtest.h>
-#include <xgboost/tree_updater.h>
-
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "../helpers.h"
-
-namespace xgboost {
-namespace tree {
-
-TEST(GPUExact, Update) {
-  using Arg = std::pair<std::string, std::string>;
-  auto lparam = CreateEmptyGenericParam(0, 1);
-  std::vector<Arg> args{{"max_depth", "1"}};
-
-  auto* p_gpuexact_maker = TreeUpdater::Create("grow_gpu", &lparam);
-  p_gpuexact_maker->Configure(args);
-
-  size_t constexpr kNRows = 4;
-  size_t constexpr kNCols = 8;
-  bst_float constexpr kSparsity = 0.0f;
-
-  auto dmat = CreateDMatrix(kNRows, kNCols, kSparsity, 3);
-  std::vector<GradientPair> h_gpair(kNRows);
-  for (size_t i = 0; i < kNRows; ++i) {
-    h_gpair[i] = GradientPair(i % 2, 1);
-  }
-  HostDeviceVector<GradientPair> gpair (h_gpair);
-  RegTree tree;
-
-  p_gpuexact_maker->Update(&gpair, (*dmat).get(), {&tree});
-  auto const& nodes = tree.GetNodes();
-  ASSERT_EQ(nodes.size(), 3);
-
-  float constexpr kRtEps = 1e-6;
-  ASSERT_NEAR(tree.Stat(0).sum_hess, 4, kRtEps);
-  ASSERT_NEAR(tree.Stat(1).sum_hess, 2, kRtEps);
-  ASSERT_NEAR(tree.Stat(2).sum_hess, 2, kRtEps);
-
-  ASSERT_NEAR(tree.Stat(0).loss_chg, 0.8f, kRtEps);
-
-  delete dmat;
-  delete p_gpuexact_maker;
-}
-
-} // namespace tree
-} // namespace xgboost
\ No newline at end of file
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index 5c26ae11c..5039093b4 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -20,16 +20,6 @@ datasets = ["Boston", "Cancer", "Digits", "Sparse regression",
 
 
 class TestGPU(unittest.TestCase):
-    def test_gpu_exact(self):
-        variable_param = {'max_depth': [2, 6, 15], }
-        for param in parameter_combinations(variable_param):
-            param['tree_method'] = 'gpu_exact'
-            gpu_results = run_suite(param, select_datasets=datasets)
-            assert_results_non_increasing(gpu_results, 1e-2)
-            param['tree_method'] = 'exact'
-            cpu_results = run_suite(param, select_datasets=datasets)
-            assert_gpu_results(cpu_results, gpu_results)
-
     def test_gpu_hist(self):
         test_param = parameter_combinations({'n_gpus': [1], 'max_depth': [2, 8],
                                              'max_leaves': [255, 4],
@@ -65,7 +55,7 @@ class TestGPU(unittest.TestCase):
                           'max_leaves': [255, 4],
                           'max_bin': [2, 64],
                           'grow_policy': ['lossguide'],
-                          'tree_method': ['gpu_hist', 'gpu_exact']}
+                          'tree_method': ['gpu_hist']}
         for param in parameter_combinations(variable_param):
             gpu_results = run_suite(param, select_datasets=datasets)
             assert_results_non_increasing(gpu_results, 1e-2)
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 01ebe8531..2fff86bf4 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -444,9 +444,9 @@ def test_sklearn_n_jobs():
 
 
 def test_kwargs():
-    params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
+    params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
     clf = xgb.XGBClassifier(n_estimators=1000, **params)
-    assert clf.get_params()['updater'] == 'grow_gpu'
+    assert clf.get_params()['updater'] == 'grow_gpu_hist'
     assert clf.get_params()['subsample'] == .5
     assert clf.get_params()['n_estimators'] == 1000
 
@@ -472,7 +472,7 @@ def test_kwargs_grid_search():
 
 
 def test_kwargs_error():
-    params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
+    params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
     with pytest.raises(TypeError):
         clf = xgb.XGBClassifier(n_jobs=1000, **params)
     assert isinstance(clf, xgb.XGBClassifier)
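With the GPUExact test deleted, a hypothetical gtest-style smoke test for the remaining grow_gpu_hist updater could mirror its structure. This is a sketch only: the helpers (CreateEmptyGenericParam, CreateDMatrix) and the exact Update signature are assumed to match the deleted test's tests/cpp/helpers.h conventions, and the assertion is deliberately weaker since gpu_hist's split values need not match gpu_exact's.

#include <gtest/gtest.h>
#include <xgboost/tree_updater.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "../helpers.h"

namespace xgboost {
namespace tree {

TEST(GPUHist, BasicUpdate) {
  auto lparam = CreateEmptyGenericParam(0, 1);
  std::vector<std::pair<std::string, std::string>> args{{"max_depth", "1"}};

  std::unique_ptr<TreeUpdater> updater{
      TreeUpdater::Create("grow_gpu_hist", &lparam)};
  updater->Configure(args);

  size_t constexpr kNRows = 4, kNCols = 8;
  auto dmat = CreateDMatrix(kNRows, kNCols, 0.0f, 3);
  std::vector<GradientPair> h_gpair(kNRows);
  for (size_t i = 0; i < kNRows; ++i) h_gpair[i] = GradientPair(i % 2, 1);
  HostDeviceVector<GradientPair> gpair(h_gpair);

  RegTree tree;
  updater->Update(&gpair, (*dmat).get(), {&tree});
  ASSERT_GT(tree.GetNodes().size(), 1);  // at least one split was made
  delete dmat;
}

}  // namespace tree
}  // namespace xgboost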