diff --git a/include/xgboost/base.h b/include/xgboost/base.h
index cc08bfebf..259c7a891 100644
--- a/include/xgboost/base.h
+++ b/include/xgboost/base.h
@@ -8,6 +8,7 @@
 #include <dmlc/base.h>
 #include <dmlc/omp.h>
+#include <cmath>
 
 /*!
  * \brief string flag for R library, to leave hooks when needed.
@@ -163,7 +164,7 @@ class bst_gpair_internal {
   friend std::ostream &operator<<(std::ostream &os,
                                   const bst_gpair_internal<T> &g) {
-    os << g.grad_ << "/" << g.hess_;
+    os << g.GetGrad() << "/" << g.GetHess();
     return os;
   }
 };
@@ -178,11 +179,11 @@ inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetHess() const {
 }
 template<>
 inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetGrad(float g) {
-  grad_ = g * 1e5;
+  grad_ = std::round(g * 1e5);
 }
 template<>
 inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetHess(float h) {
-  hess_ = h * 1e5;
+  hess_ = std::round(h * 1e5);
 }
 
 }  // namespace detail
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index e72098f4d..40cc62200 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -2,9 +2,9 @@
  * Copyright 2017 XGBoost contributors
  */
 #pragma once
+#include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -58,10 +58,20 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
   return code;
 }
 
+template <typename T>
+T *raw(thrust::device_vector<T> &v) {  // NOLINT
+  return raw_pointer_cast(v.data());
+}
+
+template <typename T>
+const T *raw(const thrust::device_vector<T> &v) {  // NOLINT
+  return raw_pointer_cast(v.data());
+}
+
 inline int n_visible_devices() {
   int n_visgpus = 0;
-  cudaGetDeviceCount(&n_visgpus);
+  dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
   return n_visgpus;
 }
 
@@ -127,29 +137,6 @@ inline int get_device_idx(int gpu_id) {
   return (std::abs(gpu_id) + 0) % dh::n_visible_devices();
 }
 
-/*
- * Timers
- */
-
-struct Timer {
-  typedef std::chrono::high_resolution_clock ClockT;
-  typedef std::chrono::high_resolution_clock::time_point TimePointT;
-  TimePointT start;
-  Timer() { reset(); }
-  void reset() { start = ClockT::now(); }
-  int64_t elapsed() const { return (ClockT::now() - start).count(); }
-  double elapsedSeconds() const {
-    return elapsed() * ((double)ClockT::period::num / ClockT::period::den);
-  }
-  void printElapsed(std::string label) {
-    // synchronize_n_devices(n_devices, dList);
-    printf("%s:\t %fs\n", label.c_str(), elapsedSeconds());
-    reset();
-  }
-};
-
 /*
  * Range iterator
 */
@@ -224,6 +211,68 @@ __device__ void block_fill(IterT begin, size_t n, ValueT value) {
   }
 }
 
+/*
+ * Kernel launcher
+ */
+
+template <typename T1, typename T2>
+T1 div_round_up(const T1 a, const T2 b) {
+  return static_cast<T1>(ceil(static_cast<double>(a) / b));
+}
+
+template <typename L>
+__global__ void launch_n_kernel(size_t begin, size_t end, L lambda) {
+  for (auto i : grid_stride_range(begin, end)) {
+    lambda(i);
+  }
+}
+template <typename L>
+__global__ void launch_n_kernel(int device_idx, size_t begin, size_t end,
+                                L lambda) {
+  for (auto i : grid_stride_range(begin, end)) {
+    lambda(i, device_idx);
+  }
+}
+
+template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
+inline void launch_n(int device_idx, size_t n, L lambda) {
+  if (n == 0) {
+    return;
+  }
+
+  safe_cuda(cudaSetDevice(device_idx));
+  // TODO: Template on n so GRID_SIZE always fits into int.
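+  // Illustrative sizing (numbers not from this patch): with the assumed
+  // template defaults BLOCK_THREADS = 256 and ITEMS_PER_THREAD = 8,
+  // launching over n = 1 << 20 elements gives
+  // GRID_SIZE = div_round_up(1048576, 2048) = 512 blocks, and each thread
+  // then covers its 8 elements via the grid-stride loop in launch_n_kernel.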
+ const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS); + launch_n_kernel<<>>(static_cast(0), n, + lambda); +} + +/* + * Timers + */ + +struct Timer { + typedef std::chrono::high_resolution_clock ClockT; + typedef std::chrono::high_resolution_clock::time_point TimePointT; + typedef std::chrono::high_resolution_clock::duration DurationT; + typedef std::chrono::duration SecondsT; + + TimePointT start; + DurationT elapsed; + Timer() { Reset(); } + void Reset() { + elapsed = DurationT::zero(); + Start(); + } + void Start() { start = ClockT::now(); } + void Stop() { elapsed += ClockT::now() - start; } + double ElapsedSeconds() const { return SecondsT(elapsed).count(); } + void PrintElapsed(std::string label) { + printf("%s:\t %fs\n", label.c_str(), SecondsT(elapsed).count()); + Reset(); + } +}; + /* * Memory */ @@ -273,8 +322,9 @@ class dvec { } void fill(T value) { - safe_cuda(cudaSetDevice(_device_idx)); - thrust::fill_n(thrust::device_pointer_cast(_ptr), size(), value); + auto d_ptr = _ptr; + launch_n(_device_idx, size(), + [=] __device__(size_t idx) { d_ptr[idx] = value; }); } void print() { @@ -304,7 +354,9 @@ class dvec { } safe_cuda(cudaSetDevice(this->device_idx())); if (other.device_idx() == this->device_idx()) { - thrust::copy(other.tbegin(), other.tend(), this->tbegin()); + dh::safe_cuda(cudaMemcpy(this->data(), other.data(), + other.size() * sizeof(T), + cudaMemcpyDeviceToDevice)); } else { std::cout << "deviceother: " << other.device_idx() << " devicethis: " << this->device_idx() << std::endl; @@ -496,6 +548,12 @@ struct CubMemory { ~CubMemory() { Free(); } + template + T* Pointer() + { + return static_cast(d_temp_storage); + } + void Free() { if (this->IsAllocated()) { safe_cuda(cudaFree(d_temp_storage)); @@ -527,15 +585,6 @@ struct CubMemory { * Utility functions */ -template -void print(const thrust::device_vector &v, size_t max_items = 10) { - thrust::host_vector h = v; - for (size_t i = 0; i < std::min(max_items, h.size()); i++) { - std::cout << " " << h[i]; - } - std::cout << "\n"; -} - template void print(const dvec &v, size_t max_items = 10) { std::vector h = v.as_vector(); @@ -545,91 +594,6 @@ void print(const dvec &v, size_t max_items = 10) { std::cout << "\n"; } -template -void print(char *label, const thrust::device_vector &v, - const char *format = "%d ", size_t max = 10) { - thrust::host_vector h_v = v; - std::cout << label << ":\n"; - for (size_t i = 0; i < std::min(static_cast(h_v.size()), max); i++) { - printf(format, h_v[i]); - } - std::cout << "\n"; -} - -template -T1 div_round_up(const T1 a, const T2 b) { - return static_cast(ceil(static_cast(a) / b)); -} - -template -thrust::device_ptr dptr(T *d_ptr) { - return thrust::device_pointer_cast(d_ptr); -} - -template -T *raw(thrust::device_vector &v) { // NOLINT - return raw_pointer_cast(v.data()); -} - -template -const T *raw(const thrust::device_vector &v) { // NOLINT - return raw_pointer_cast(v.data()); -} - -template -size_t size_bytes(const thrust::device_vector &v) { - return sizeof(T) * v.size(); -} -/* - * Kernel launcher - */ - -template -__global__ void launch_n_kernel(size_t begin, size_t end, L lambda) { - for (auto i : grid_stride_range(begin, end)) { - lambda(i); - } -} -template -__global__ void launch_n_kernel(int device_idx, size_t begin, size_t end, - L lambda) { - for (auto i : grid_stride_range(begin, end)) { - lambda(i, device_idx); - } -} - -template -inline void launch_n(int device_idx, size_t n, L lambda) { - safe_cuda(cudaSetDevice(device_idx)); - // TODO: Template on n so 
GRID_SIZE always fits into int. - const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS); -#if defined(__CUDACC__) - launch_n_kernel<<>>(static_cast(0), n, - lambda); -#endif -} - -// if n_devices=-1, then use all visible devices -template -inline void multi_launch_n(size_t n, int n_devices, L lambda) { - n_devices = n_devices < 0 ? n_visible_devices() : n_devices; - CHECK_LE(n_devices, n_visible_devices()) << "Number of devices requested " - "needs to be less than equal to " - "number of visible devices."; - // TODO: Template on n so GRID_SIZE always fits into int. - const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS); -#if defined(__CUDACC__) - n_devices = n_devices > n ? n : n_devices; - for (int device_idx = 0; device_idx < n_devices; device_idx++) { - safe_cuda(cudaSetDevice(device_idx)); - size_t begin = (n / n_devices) * device_idx; - size_t end = std::min((n / n_devices) * (device_idx + 1), n); - launch_n_kernel<<>>(device_idx, begin, end, - lambda); - } -#endif -} - /** * @brief Helper macro to measure timing on GPU * @param call the GPU call diff --git a/src/learner.cc b/src/learner.cc index 689e1f977..7555fadc6 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -110,6 +110,7 @@ struct LearnerTrainParam : public dmlc::Parameter { .add_enum("hist", 3) .add_enum("gpu_exact", 4) .add_enum("gpu_hist", 5) + .add_enum("gpu_hist_experimental", 6) .describe("Choice of tree construction method."); DMLC_DECLARE_FIELD(test_flag).set_default("").describe( "Internal test flag"); @@ -178,6 +179,13 @@ class LearnerImpl : public Learner { if (cfg_.count("predictor") == 0) { cfg_["predictor"] = "gpu_predictor"; } + } else if (tparam.tree_method == 6) { + if (cfg_.count("updater") == 0) { + cfg_["updater"] = "grow_gpu_hist_experimental,prune"; + } + if (cfg_.count("predictor") == 0) { + cfg_["predictor"] = "gpu_predictor"; + } } } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 702a01a8f..8bf921d5c 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -3,6 +3,7 @@ */ #include #include +#include #include #include #include diff --git a/src/tree/param.h b/src/tree/param.h index 646955397..20f0feee2 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -271,6 +271,7 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess return -2.0 * (ret + p.reg_alpha * std::abs(w)); } } + // calculate weight given the statistics template XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad, @@ -292,6 +293,11 @@ XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad, return dw; } +template +XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, gpair_t sum_grad) { + return CalcWeight(p, sum_grad.GetGrad(), sum_grad.GetHess()); +} + /*! \brief core statistics used for tree construction */ struct XGBOOST_ALIGNAS(16) GradStats { /*! 
\brief sum gradient statistics */ diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index 32630903e..2452dba55 100644 --- a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -35,6 +35,7 @@ DMLC_REGISTRY_LINK_TAG(updater_sync); #ifdef XGBOOST_USE_CUDA DMLC_REGISTRY_LINK_TAG(updater_gpu); DMLC_REGISTRY_LINK_TAG(updater_gpu_hist); +DMLC_REGISTRY_LINK_TAG(updater_gpu_hist_experimental); #endif } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu index fc79e55e6..4ef24450a 100644 --- a/src/tree/updater_gpu.cu +++ b/src/tree/updater_gpu.cu @@ -23,14 +23,13 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu); */ static HOST_DEV_INLINE node_id_t abs2uniqKey(int tid, const node_id_t* abs, - const int* colIds, node_id_t nodeStart, - int nKeys) { + const int* colIds, + node_id_t nodeStart, int nKeys) { int a = abs[tid]; if (a == UNUSED_NODE) return a; return ((a - nodeStart) + (colIds[tid] * nKeys)); } - /** * @struct Pair * @brief Pair used for key basd scan operations on bst_gpair @@ -284,7 +283,7 @@ DEV_INLINE void atomicArgMax(ExactSplitCandidate* address, DEV_INLINE void argMaxWithAtomics( int id, ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans, const bst_gpair* gradSums, const float* vals, const int* colIds, - const node_id_t* nodeAssigns, const DeviceDenseNode* nodes, int nUniqKeys, + const node_id_t* nodeAssigns, const DeviceNodeStats* nodes, int nUniqKeys, node_id_t nodeStart, int len, const GPUTrainingParam& param) { int nodeId = nodeAssigns[id]; // @todo: this is really a bad check! but will be fixed when we move @@ -296,7 +295,7 @@ DEV_INLINE void argMaxWithAtomics( int sumId = abs2uniqKey(id, nodeAssigns, colIds, nodeStart, nUniqKeys); bst_gpair colSum = gradSums[sumId]; int uid = nodeId - nodeStart; - DeviceDenseNode n = nodes[nodeId]; + DeviceNodeStats n = nodes[nodeId]; bst_gpair parentSum = n.sum_gradients; float parentGain = n.root_gain; bool tmp; @@ -313,7 +312,7 @@ DEV_INLINE void argMaxWithAtomics( __global__ void atomicArgMaxByKeyGmem( ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans, const bst_gpair* gradSums, const float* vals, const int* colIds, - const node_id_t* nodeAssigns, const DeviceDenseNode* nodes, int nUniqKeys, + const node_id_t* nodeAssigns, const DeviceNodeStats* nodes, int nUniqKeys, node_id_t nodeStart, int len, const TrainParam param) { int id = threadIdx.x + (blockIdx.x * blockDim.x); const int stride = blockDim.x * gridDim.x; @@ -327,7 +326,7 @@ __global__ void atomicArgMaxByKeyGmem( __global__ void atomicArgMaxByKeySmem( ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans, const bst_gpair* gradSums, const float* vals, const int* colIds, - const node_id_t* nodeAssigns, const DeviceDenseNode* nodes, int nUniqKeys, + const node_id_t* nodeAssigns, const DeviceNodeStats* nodes, int nUniqKeys, node_id_t nodeStart, int len, const TrainParam param) { extern __shared__ char sArr[]; ExactSplitCandidate* sNodeSplits = @@ -372,7 +371,7 @@ template void argMaxByKey(ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans, const bst_gpair* gradSums, const float* vals, const int* colIds, const node_id_t* nodeAssigns, - const DeviceDenseNode* nodes, int nUniqKeys, + const DeviceNodeStats* nodes, int nUniqKeys, node_id_t nodeStart, int len, const TrainParam param, ArgMaxByKeyAlgo algo) { dh::fillConst( @@ -406,7 +405,7 @@ __global__ void assignColIds(int* colIds, const int* colOffsets) { } __global__ void fillDefaultNodeIds(node_id_t* nodeIdsPerInst, - const DeviceDenseNode* 
nodes, int nRows) { + const DeviceNodeStats* nodes, int nRows) { int id = threadIdx.x + (blockIdx.x * blockDim.x); if (id >= nRows) { return; @@ -416,7 +415,7 @@ __global__ void fillDefaultNodeIds(node_id_t* nodeIdsPerInst, if (nId == UNUSED_NODE) { return; } - const DeviceDenseNode n = nodes[nId]; + const DeviceNodeStats n = nodes[nId]; node_id_t result; if (n.IsLeaf() || n.IsUnused()) { result = UNUSED_NODE; @@ -430,7 +429,7 @@ __global__ void fillDefaultNodeIds(node_id_t* nodeIdsPerInst, __global__ void assignNodeIds(node_id_t* nodeIdsPerInst, int* nodeLocations, const node_id_t* nodeIds, const int* instId, - const DeviceDenseNode* nodes, + const DeviceNodeStats* nodes, const int* colOffsets, const float* vals, int nVals, int nCols) { int id = threadIdx.x + (blockIdx.x * blockDim.x); @@ -443,7 +442,7 @@ __global__ void assignNodeIds(node_id_t* nodeIdsPerInst, int* nodeLocations, int nId = nodeIds[id]; // if this element belongs to none of the currently active node-id's if (nId != UNUSED_NODE) { - const DeviceDenseNode n = nodes[nId]; + const DeviceNodeStats n = nodes[nId]; int colId = n.fidx; // printf("nid=%d colId=%d id=%d\n", nId, colId, id); int start = colOffsets[colId]; @@ -457,7 +456,7 @@ __global__ void assignNodeIds(node_id_t* nodeIdsPerInst, int* nodeLocations, } } -__global__ void markLeavesKernel(DeviceDenseNode* nodes, int len) { +__global__ void markLeavesKernel(DeviceNodeStats* nodes, int len) { int id = (blockIdx.x * blockDim.x) + threadIdx.x; if ((id < len) && !nodes[id].IsUnused()) { int lid = (id << 1) + 1; @@ -486,7 +485,7 @@ class GPUMaker : public TreeUpdater { dh::dvec gradsInst; dh::dvec2 nodeAssigns; dh::dvec2 nodeLocations; - dh::dvec nodes; + dh::dvec nodes; dh::dvec nodeAssignsPerInst; dh::dvec gradSums; dh::dvec gradScans; @@ -573,7 +572,7 @@ class GPUMaker : public TreeUpdater { int nodeInstId = abs2uniqKey(idx, d_nodeAssigns, d_colIds, nodeStart, nUniqKeys); bool missingLeft = true; - const DeviceDenseNode& n = d_nodes[absNodeId]; + const DeviceNodeStats& n = d_nodes[absNodeId]; bst_gpair gradScan = d_gradScans[idx]; bst_gpair gradSum = d_gradSums[nodeInstId]; float thresh = d_vals[idx]; @@ -588,12 +587,13 @@ class GPUMaker : public TreeUpdater { // Create children d_nodes[left_child_nidx(absNodeId)] = - DeviceDenseNode(lGradSum, left_child_nidx(absNodeId), gpu_param); + DeviceNodeStats(lGradSum, left_child_nidx(absNodeId), gpu_param); d_nodes[right_child_nidx(absNodeId)] = - DeviceDenseNode(rGradSum, right_child_nidx(absNodeId), gpu_param); + DeviceNodeStats(rGradSum, right_child_nidx(absNodeId), gpu_param); // Set split for parent d_nodes[absNodeId].SetSplit(thresh, colId, - missingLeft ? LeftDir : RightDir); + missingLeft ? LeftDir : RightDir, lGradSum, + rGradSum); } else { // cannot be split further, so this node is a leaf! d_nodes[absNodeId].root_gain = -FLT_MAX; @@ -677,7 +677,7 @@ class GPUMaker : public TreeUpdater { instIds.current_dvec() = fId; colOffsets = offset; dh::segmentedSort(&tmp_mem, &vals, &instIds, nVals, nCols, - colOffsets); + colOffsets); vals_cached = vals.current_dvec(); instIds_cached = instIds.current_dvec(); assignColIds<<>>(colIds.data(), colOffsets.data()); @@ -695,7 +695,7 @@ class GPUMaker : public TreeUpdater { void initNodeData(int level, node_id_t nodeStart, int nNodes) { // all instances belong to root node at the beginning! 
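 // (at level 0 every entry of nodeAssignsPerInst is zeroed and the root
 // DeviceNodeStats is built on-device from the reduced gradient sum
 // d_sums[0]; at deeper levels assignments are instead derived from each
 // parent's split via fillDefaultNodeIds/assignNodeIds below)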
if (level == 0) { - nodes.fill(DeviceDenseNode()); + nodes.fill(DeviceNodeStats()); nodeAssigns.current_dvec().fill(0); nodeAssignsPerInst.fill(0); // for root node, just update the gradient/score/weight/id info @@ -705,7 +705,7 @@ class GPUMaker : public TreeUpdater { auto d_sums = gradSums.data(); auto gpu_params = GPUTrainingParam(param); dh::launch_n(param.gpu_id, 1, [=] __device__(int idx) { - d_nodes[0] = DeviceDenseNode(d_sums[0], 0, gpu_params); + d_nodes[0] = DeviceNodeStats(d_sums[0], 0, gpu_params); }); } else { const int BlkDim = 256; @@ -722,7 +722,7 @@ class GPUMaker : public TreeUpdater { colOffsets.data(), vals.current(), nVals, nCols); // gather the node assignments across all other columns too dh::gather(dh::get_device_idx(param.gpu_id), nodeAssigns.current(), - nodeAssignsPerInst.data(), instIds.current(), nVals); + nodeAssignsPerInst.data(), instIds.current(), nVals); sortKeys(level); } } @@ -733,8 +733,8 @@ class GPUMaker : public TreeUpdater { segmentedSort(&tmp_mem, &nodeAssigns, &nodeLocations, nVals, nCols, colOffsets, 0, level + 1); dh::gather(dh::get_device_idx(param.gpu_id), vals.other(), - vals.current(), instIds.other(), instIds.current(), - nodeLocations.current(), nVals); + vals.current(), instIds.other(), instIds.current(), + nodeLocations.current(), nVals); vals.buff().selector ^= 1; instIds.buff().selector ^= 1; } diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index e31c692ae..32513d7a4 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -4,13 +4,13 @@ #pragma once #include #include +#include #include #include #include +#include "../common/device_helpers.cuh" #include "../common/random.h" #include "param.h" -#include -#include "../common/device_helpers.cuh" namespace xgboost { namespace tree { @@ -52,7 +52,47 @@ enum DefaultDirection { RightDir }; -struct DeviceDenseNode { +struct DeviceSplitCandidate { + float loss_chg; + DefaultDirection dir; + float fvalue; + int findex; + bst_gpair_integer left_sum; + bst_gpair_integer right_sum; + + __host__ __device__ DeviceSplitCandidate() + : loss_chg(-FLT_MAX), dir(LeftDir), fvalue(0), findex(-1) {} + + template + __host__ __device__ void Update(const DeviceSplitCandidate &other, + const param_t& param) { + if (other.loss_chg > loss_chg && + other.left_sum.GetHess() >= param.min_child_weight && + other.right_sum.GetHess() >= param.min_child_weight) { + *this = other; + } + } + + __device__ void Update(float loss_chg_in, DefaultDirection dir_in, + float fvalue_in, int findex_in, + bst_gpair_integer left_sum_in, + bst_gpair_integer right_sum_in, + const GPUTrainingParam& param) { + if (loss_chg_in > loss_chg && + left_sum_in.GetHess() >= param.min_child_weight && + right_sum_in.GetHess() >= param.min_child_weight) { + loss_chg = loss_chg_in; + dir = dir_in; + fvalue = fvalue_in; + left_sum = left_sum_in; + right_sum = right_sum_in; + findex = findex_in; + } + } + __device__ bool IsValid() const { return loss_chg > 0.0f; } +}; + +struct DeviceNodeStats { bst_gpair sum_gradients; float root_gain; float weight; @@ -61,35 +101,50 @@ struct DeviceDenseNode { DefaultDirection dir; /** threshold value for comparison */ float fvalue; + bst_gpair left_sum; + bst_gpair right_sum; /** \brief The feature index. 
*/ int fidx; /** node id (used as key for reduce/scan) */ node_id_t idx; - HOST_DEV_INLINE DeviceDenseNode() + HOST_DEV_INLINE DeviceNodeStats() : sum_gradients(), root_gain(-FLT_MAX), weight(-FLT_MAX), dir(LeftDir), fvalue(0.f), + left_sum(), + right_sum(), fidx(UNUSED_NODE), idx(UNUSED_NODE) {} - HOST_DEV_INLINE DeviceDenseNode(bst_gpair sum_gradients, node_id_t nidx, - const GPUTrainingParam& param) + template + HOST_DEV_INLINE DeviceNodeStats(bst_gpair sum_gradients, node_id_t nidx, + const param_t& param) : sum_gradients(sum_gradients), dir(LeftDir), fvalue(0.f), fidx(UNUSED_NODE), idx(nidx) { - this->root_gain = CalcGain(param, sum_gradients.GetGrad(), sum_gradients.GetHess()); - this->weight = CalcWeight(param, sum_gradients.GetGrad(), sum_gradients.GetHess()); + this->root_gain = + CalcGain(param, sum_gradients.GetGrad(), sum_gradients.GetHess()); + this->weight = + CalcWeight(param, sum_gradients.GetGrad(), sum_gradients.GetHess()); } - HOST_DEV_INLINE void SetSplit(float fvalue, int fidx, DefaultDirection dir) { + HOST_DEV_INLINE void SetSplit(float fvalue, int fidx, DefaultDirection dir, + bst_gpair left_sum, bst_gpair right_sum) { this->fvalue = fvalue; this->fidx = fidx; this->dir = dir; + this->left_sum = left_sum; + this->right_sum = right_sum; + } + + HOST_DEV_INLINE void SetSplit(const DeviceSplitCandidate& split) { + this->SetSplit(split.fvalue, split.findex, split.dir, split.left_sum, + split.right_sum); } /** Tells whether this node is part of the decision tree */ @@ -101,18 +156,23 @@ struct DeviceDenseNode { } }; +template +struct SumCallbackOp { + // Running prefix + T running_total; + // Constructor + __device__ SumCallbackOp() : running_total(T()) {} + __device__ T operator()(T block_aggregate) { + T old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + template __device__ inline float device_calc_loss_chg( - const GPUTrainingParam& param, const gpair_t& scan, const gpair_t& missing, - const gpair_t& parent_sum, const float& parent_gain, bool missing_left) { - gpair_t left = scan; - - if (missing_left) { - left += missing; - } - + const GPUTrainingParam& param, const gpair_t& left, const gpair_t& parent_sum, const float& parent_gain) { gpair_t right = parent_sum - left; - float left_gain = CalcGain(param, left.GetGrad(), left.GetHess()); float right_gain = CalcGain(param, right.GetGrad(), right.GetHess()); return left_gain + right_gain - parent_gain; @@ -126,9 +186,9 @@ __device__ float inline loss_chg_missing(const gpair_t& scan, const GPUTrainingParam& param, bool& missing_left_out) { // NOLINT float missing_left_loss = - device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true); + device_calc_loss_chg(param, scan + missing, parent_sum, parent_gain); float missing_right_loss = device_calc_loss_chg( - param, scan, missing, parent_sum, parent_gain, false); + param, scan, parent_sum, parent_gain); if (missing_left_loss >= missing_right_loss) { missing_left_out = true; @@ -168,14 +228,14 @@ __host__ __device__ inline bool is_left_child(int nidx) { // Copy gpu dense representation of tree to xgboost sparse representation inline void dense2sparse_tree(RegTree* p_tree, - const dh::dvec& nodes, + const dh::dvec& nodes, const TrainParam& param) { RegTree& tree = *p_tree; - std::vector h_nodes = nodes.as_vector(); + std::vector h_nodes = nodes.as_vector(); int nid = 0; for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) { - const DeviceDenseNode& n = h_nodes[gpu_nid]; + const DeviceNodeStats& n = 
h_nodes[gpu_nid]; if (!n.IsUnused() && !n.IsLeaf()) { tree.AddChilds(nid); tree[nid].set_split(n.fidx, n.fvalue, n.dir == LeftDir); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index a18c57086..22857c9f3 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -43,12 +43,14 @@ struct DeviceGMat { gidx = common::CompressedIterator(gidx_buffer.data(), n_bins); // row_ptr - thrust::copy(gmat.row_ptr.data() + row_begin, - gmat.row_ptr.data() + row_end + 1, row_ptr.tbegin()); + dh::safe_cuda(cudaMemcpy(row_ptr.data(), gmat.row_ptr.data() + row_begin, + row_ptr.size() * sizeof(size_t), + cudaMemcpyHostToDevice)); // normalise row_ptr size_t start = gmat.row_ptr[row_begin]; - thrust::transform(row_ptr.tbegin(), row_ptr.tend(), row_ptr.tbegin(), - [=] __device__(size_t val) { return val - start; }); + auto d_row_ptr = row_ptr.data(); + dh::launch_n(row_ptr.device_idx(), row_ptr.size(), + [=] __device__(size_t idx) { d_row_ptr[idx] -= start; }); } }; @@ -61,12 +63,15 @@ struct HistHelper { __device__ void Add(bst_gpair gpair, int gidx, int nidx) const { int hist_idx = nidx * n_bins + gidx; - auto dst_ptr = reinterpret_cast(&d_hist[hist_idx]); // NOLINT + auto dst_ptr = + reinterpret_cast(&d_hist[hist_idx]); // NOLINT gpair_sum_t tmp(gpair.GetGrad(), gpair.GetHess()); auto src_ptr = reinterpret_cast(&tmp); - atomicAdd(dst_ptr, static_cast(*src_ptr)); // NOLINT - atomicAdd(dst_ptr + 1, static_cast(*(src_ptr + 1))); // NOLINT + atomicAdd(dst_ptr, + static_cast(*src_ptr)); // NOLINT + atomicAdd(dst_ptr + 1, + static_cast(*(src_ptr + 1))); // NOLINT } __device__ gpair_sum_t Get(int gidx, int nidx) const { return d_hist[nidx * n_bins + gidx]; @@ -96,51 +101,10 @@ struct DeviceHist { int LevelSize(int depth) { return n_bins * n_nodes_level(depth); } }; -struct SplitCandidate { - float loss_chg; - bool missing_left; - float fvalue; - int findex; - gpair_sum_t left_sum; - gpair_sum_t right_sum; - - __host__ __device__ SplitCandidate() - : loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {} - - __device__ void Update(float loss_chg_in, bool missing_left_in, - float fvalue_in, int findex_in, - gpair_sum_t left_sum_in, gpair_sum_t right_sum_in, - const GPUTrainingParam& param) { - if (loss_chg_in > loss_chg && - left_sum_in.GetHess() >= param.min_child_weight && - right_sum_in.GetHess() >= param.min_child_weight) { - loss_chg = loss_chg_in; - missing_left = missing_left_in; - fvalue = fvalue_in; - left_sum = left_sum_in; - right_sum = right_sum_in; - findex = findex_in; - } - } - __device__ bool IsValid() const { return loss_chg > 0.0f; } -}; - -struct GpairCallbackOp { - // Running prefix - gpair_sum_t running_total; - // Constructor - __device__ GpairCallbackOp() : running_total(gpair_sum_t()) {} - __device__ bst_gpair operator()(bst_gpair block_aggregate) { - gpair_sum_t old_prefix = running_total; - running_total += block_aggregate; - return old_prefix; - } -}; - template __global__ void find_split_kernel( const gpair_sum_t* d_level_hist, int* d_feature_segments, int depth, - int n_features, int n_bins, DeviceDenseNode* d_nodes, + int n_features, int n_bins, DeviceNodeStats* d_nodes, int nodes_offset_device, float* d_fidx_min_map, float* d_gidx_fvalue_map, GPUTrainingParam gpu_param, bool* d_left_child_smallest_temp, bool colsample, int* d_feature_flags) { @@ -156,15 +120,15 @@ __global__ void find_split_kernel( typename SumReduceT::TempStorage sum_reduce; }; - __shared__ cub::Uninitialized uninitialized_split; - SplitCandidate& split = 
uninitialized_split.Alias(); + __shared__ cub::Uninitialized uninitialized_split; + DeviceSplitCandidate& split = uninitialized_split.Alias(); __shared__ cub::Uninitialized uninitialized_sum; gpair_sum_t& shared_sum = uninitialized_sum.Alias(); __shared__ ArgMaxT block_max; __shared__ TempStorage temp_storage; if (threadIdx.x == 0) { - split = SplitCandidate(); + split = DeviceSplitCandidate(); } __syncthreads(); @@ -197,7 +161,7 @@ __global__ void find_split_kernel( } // __syncthreads(); // no need to synch because below there is a Scan - GpairCallbackOp prefix_op = GpairCallbackOp(); + auto prefix_op = SumCallbackOp(); for (int scan_begin = begin; scan_begin < end; scan_begin += BLOCK_THREADS) { bool thread_active = scan_begin + threadIdx.x < end; @@ -245,7 +209,8 @@ __global__ void find_split_kernel( gpair_sum_t left = missing_left ? bin + missing : bin; gpair_sum_t right = parent_sum - left; - split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param); + split.Update(gain, missing_left ? LeftDir : RightDir, fvalue, fidx, + left, right, gpu_param); } __syncthreads(); } // end scan @@ -253,17 +218,16 @@ __global__ void find_split_kernel( // Create node if (threadIdx.x == 0 && split.IsValid()) { - d_nodes[node_idx].SetSplit(split.fvalue, split.findex, - split.missing_left ? LeftDir : RightDir); + d_nodes[node_idx].SetSplit(split); - DeviceDenseNode& left_child = d_nodes[left_child_nidx(node_idx)]; - DeviceDenseNode& right_child = d_nodes[right_child_nidx(node_idx)]; + DeviceNodeStats& left_child = d_nodes[left_child_nidx(node_idx)]; + DeviceNodeStats& right_child = d_nodes[right_child_nidx(node_idx)]; bool& left_child_smallest = d_left_child_smallest_temp[node_idx]; left_child = - DeviceDenseNode(split.left_sum, left_child_nidx(node_idx), gpu_param); + DeviceNodeStats(split.left_sum, left_child_nidx(node_idx), gpu_param); right_child = - DeviceDenseNode(split.right_sum, right_child_nidx(node_idx), gpu_param); + DeviceNodeStats(split.right_sum, right_child_nidx(node_idx), gpu_param); // Record smallest node if (split.left_sum.GetHess() <= split.right_sum.GetHess()) { @@ -336,7 +300,7 @@ class GPUHistMaker : public TreeUpdater { // reset static timers used across iterations cpu_init_time = 0; gpu_init_time = 0; - cpu_time.reset(); + cpu_time.Reset(); gpu_time = 0; // set dList member @@ -399,31 +363,31 @@ class GPUHistMaker : public TreeUpdater { is_dense = info->num_nonzero == info->num_col * info->num_row; dh::Timer time0; hmat_.Init(&fmat, param.max_bin); - cpu_init_time += time0.elapsedSeconds(); + cpu_init_time += time0.ElapsedSeconds(); if (param.debug_verbose) { // Only done once for each training session LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init " - << time0.elapsedSeconds() << " sec"; + << time0.ElapsedSeconds() << " sec"; fflush(stdout); } - time0.reset(); + time0.Reset(); gmat_.cut = &hmat_; - cpu_init_time += time0.elapsedSeconds(); + cpu_init_time += time0.ElapsedSeconds(); if (param.debug_verbose) { // Only done once for each training session LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.cut " - << time0.elapsedSeconds() << " sec"; + << time0.ElapsedSeconds() << " sec"; fflush(stdout); } - time0.reset(); + time0.Reset(); gmat_.Init(&fmat); - cpu_init_time += time0.elapsedSeconds(); + cpu_init_time += time0.ElapsedSeconds(); if (param.debug_verbose) { // Only done once for each training session LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.Init() " - << time0.elapsedSeconds() << " sec"; + << time0.ElapsedSeconds() << " sec"; fflush(stdout); } - 
time0.reset(); + time0.Reset(); if (param.debug_verbose) { // Only done once for each training session LOG(CONSOLE) @@ -563,9 +527,9 @@ class GPUHistMaker : public TreeUpdater { int device_idx = dList[d_idx]; dh::safe_cuda(cudaSetDevice(device_idx)); - nodes[d_idx].fill(DeviceDenseNode()); - nodes_temp[d_idx].fill(DeviceDenseNode()); - nodes_child_temp[d_idx].fill(DeviceDenseNode()); + nodes[d_idx].fill(DeviceNodeStats()); + nodes_temp[d_idx].fill(DeviceNodeStats()); + nodes_child_temp[d_idx].fill(DeviceNodeStats()); position[d_idx].fill(0); @@ -584,7 +548,7 @@ class GPUHistMaker : public TreeUpdater { dh::synchronize_n_devices(n_devices, dList); if (!initialised) { - gpu_init_time = time1.elapsedSeconds() - cpu_init_time; + gpu_init_time = time1.ElapsedSeconds() - cpu_init_time; gpu_time = -cpu_init_time; if (param.debug_verbose) { // Only done once for each training session LOG(CONSOLE) << "[GPU Plug-in] Time for GPU operations during First " @@ -701,12 +665,12 @@ class GPUHistMaker : public TreeUpdater { dh::synchronize_n_devices(n_devices, dList); } } -#define MIN_BLOCK_THREADS 32 -#define CHUNK_BLOCK_THREADS 32 +#define MIN_BLOCK_THREADS 128 +#define CHUNK_BLOCK_THREADS 128 // MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due // to CUDA capability 35 and above requirement // for Maximum number of threads per block -#define MAX_BLOCK_THREADS 1024 +#define MAX_BLOCK_THREADS 512 void FindSplit(int depth) { // Specialised based on max_bins @@ -783,7 +747,7 @@ class GPUHistMaker : public TreeUpdater { dh::launch_n(device_idx, 1, [=] __device__(int idx) { bst_gpair sum_gradients = sum; - d_nodes[idx] = DeviceDenseNode(sum_gradients, 0, gpu_param); + d_nodes[idx] = DeviceNodeStats(sum_gradients, 0, gpu_param); }); } // synch all devices to host before moving on (No, can avoid because @@ -802,7 +766,7 @@ class GPUHistMaker : public TreeUpdater { int device_idx = dList[d_idx]; auto d_position = position[d_idx].data(); - DeviceDenseNode* d_nodes = nodes[d_idx].data(); + DeviceNodeStats* d_nodes = nodes[d_idx].data(); auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data(); auto d_gidx = device_matrix[d_idx].gidx; int n_columns = info->num_col; @@ -814,7 +778,7 @@ class GPUHistMaker : public TreeUpdater { if (!is_active(pos, depth)) { return; } - DeviceDenseNode node = d_nodes[pos]; + DeviceNodeStats node = d_nodes[pos]; if (node.IsLeaf()) { return; @@ -842,7 +806,7 @@ class GPUHistMaker : public TreeUpdater { auto d_position = position[d_idx].data(); auto d_position_tmp = position_tmp[d_idx].data(); - DeviceDenseNode* d_nodes = nodes[d_idx].data(); + DeviceNodeStats* d_nodes = nodes[d_idx].data(); auto d_gidx_feature_map = gidx_feature_map[d_idx].data(); auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data(); auto d_gidx = device_matrix[d_idx].gidx; @@ -862,7 +826,7 @@ class GPUHistMaker : public TreeUpdater { return; } - DeviceDenseNode node = d_nodes[pos]; + DeviceNodeStats node = d_nodes[pos]; if (node.IsLeaf()) { d_position_tmp[local_idx] = pos; @@ -887,7 +851,7 @@ class GPUHistMaker : public TreeUpdater { return; } - DeviceDenseNode node = d_nodes[pos]; + DeviceNodeStats node = d_nodes[pos]; if (node.IsLeaf()) { return; @@ -976,8 +940,10 @@ class GPUHistMaker : public TreeUpdater { d_prediction_cache[local_idx] += d_nodes[pos].weight * eps; }); - thrust::copy(prediction_cache[d_idx].tbegin(), - prediction_cache[d_idx].tend(), &out_preds[row_begin]); + dh::safe_cuda( + cudaMemcpy(&out_preds[row_begin], prediction_cache[d_idx].data(), + prediction_cache[d_idx].size() * 
sizeof(bst_float), + cudaMemcpyDeviceToHost)); } dh::synchronize_n_devices(n_devices, dList); @@ -1003,7 +969,7 @@ class GPUHistMaker : public TreeUpdater { dh::safe_cuda(cudaSetDevice(master_device)); dense2sparse_tree(p_tree, nodes[0], param); - gpu_time += time0.elapsedSeconds(); + gpu_time += time0.ElapsedSeconds(); if (param.debug_verbose) { LOG(CONSOLE) @@ -1014,10 +980,10 @@ class GPUHistMaker : public TreeUpdater { if (param.debug_verbose) { LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time " - << cpu_time.elapsedSeconds() << " sec"; + << cpu_time.ElapsedSeconds() << " sec"; LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time excluding initial time " - << (cpu_time.elapsedSeconds() - cpu_init_time - gpu_time) << " sec"; + << (cpu_time.ElapsedSeconds() - cpu_init_time - gpu_time) << " sec"; fflush(stdout); } } @@ -1048,9 +1014,9 @@ class GPUHistMaker : public TreeUpdater { std::vector temp_memory; std::vector hist_vec; - std::vector> nodes; - std::vector> nodes_temp; - std::vector> nodes_child_temp; + std::vector> nodes; + std::vector> nodes_temp; + std::vector> nodes_child_temp; std::vector> left_child_smallest; std::vector> left_child_smallest_temp; std::vector> feature_flags; diff --git a/src/tree/updater_gpu_hist_experimental.cu b/src/tree/updater_gpu_hist_experimental.cu new file mode 100644 index 000000000..6193b16f4 --- /dev/null +++ b/src/tree/updater_gpu_hist_experimental.cu @@ -0,0 +1,833 @@ +/*! + * Copyright 2017 XGBoost contributors + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include "../common/compressed_iterator.h" +#include "../common/device_helpers.cuh" +#include "../common/hist_util.h" +#include "param.h" +#include "updater_gpu_common.cuh" + +namespace xgboost { +namespace tree { + +DMLC_REGISTRY_FILE_TAG(updater_gpu_hist_experimental); + +template +__device__ bst_gpair_integer ReduceFeature(const bst_gpair_integer* begin, + const bst_gpair_integer* end, + temp_storage_t* temp_storage) { + __shared__ cub::Uninitialized uninitialized_sum; + bst_gpair_integer& shared_sum = uninitialized_sum.Alias(); + + bst_gpair_integer local_sum = bst_gpair_integer(); + for (auto itr = begin; itr < end; itr += BLOCK_THREADS) { + bool thread_active = itr + threadIdx.x < end; + // Scan histogram + bst_gpair_integer bin = + thread_active ? *(itr + threadIdx.x) : bst_gpair_integer(); + + local_sum += reduce_t(temp_storage->sum_reduce).Reduce(bin, cub::Sum()); + } + + if (threadIdx.x == 0) { + shared_sum = local_sum; + } + __syncthreads(); + + return shared_sum; +} + +template +__device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist, + const int* feature_segments, float min_fvalue, + const float* gidx_fvalue_map, + DeviceSplitCandidate* best_split, + const DeviceNodeStats& node, + const GPUTrainingParam& param, + temp_storage_t* temp_storage) { + int gidx_begin = feature_segments[fidx]; + int gidx_end = feature_segments[fidx + 1]; + + bst_gpair_integer feature_sum = ReduceFeature( + hist + gidx_begin, hist + gidx_end, temp_storage); + + auto prefix_op = SumCallbackOp(); + for (int scan_begin = gidx_begin; scan_begin < gidx_end; + scan_begin += BLOCK_THREADS) { + bool thread_active = scan_begin + threadIdx.x < gidx_end; + + bst_gpair_integer bin = + thread_active ? 
hist[scan_begin + threadIdx.x] : bst_gpair_integer(); + scan_t(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op); + + // Calculate gain + bst_gpair_integer parent_sum = bst_gpair_integer(node.sum_gradients); + + bst_gpair_integer missing = parent_sum - feature_sum; + + bool missing_left = true; + const float null_gain = -FLT_MAX; + float gain = null_gain; + if (thread_active) { + gain = loss_chg_missing(bin, missing, parent_sum, node.root_gain, param, + missing_left); + } + + __syncthreads(); + + // Find thread with best gain + cub::KeyValuePair tuple(threadIdx.x, gain); + cub::KeyValuePair best = + max_reduce_t(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax()); + + __shared__ cub::KeyValuePair block_max; + if (threadIdx.x == 0) { + block_max = best; + } + + __syncthreads(); + + // Best thread updates split + if (threadIdx.x == block_max.key) { + int gidx = scan_begin + threadIdx.x; + float fvalue = + gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1]; + + bst_gpair_integer left = missing_left ? bin + missing : bin; + bst_gpair_integer right = parent_sum - left; + + best_split->Update(gain, missing_left ? LeftDir : RightDir, fvalue, fidx, + left, right, param); + } + __syncthreads(); + } +} + +template +__global__ void evaluate_split_kernel(const bst_gpair_integer* d_hist, int nidx, + int n_features, DeviceNodeStats nodes, + const int* d_feature_segments, + const float* d_fidx_min_map, + const float* d_gidx_fvalue_map, + GPUTrainingParam gpu_param, + DeviceSplitCandidate* d_split) { + typedef cub::KeyValuePair ArgMaxT; + typedef cub::BlockScan + BlockScanT; + typedef cub::BlockReduce MaxReduceT; + + typedef cub::BlockReduce SumReduceT; + + union TempStorage { + typename BlockScanT::TempStorage scan; + typename MaxReduceT::TempStorage max_reduce; + typename SumReduceT::TempStorage sum_reduce; + }; + + __shared__ cub::Uninitialized uninitialized_split; + DeviceSplitCandidate& best_split = uninitialized_split.Alias(); + __shared__ TempStorage temp_storage; + + if (threadIdx.x == 0) { + best_split = DeviceSplitCandidate(); + } + + __syncthreads(); + + auto fidx = blockIdx.x; + EvaluateFeature( + fidx, d_hist, d_feature_segments, d_fidx_min_map[fidx], d_gidx_fvalue_map, + &best_split, nodes, gpu_param, &temp_storage); + + __syncthreads(); + + if (threadIdx.x == 0) { + // Record best loss + d_split[fidx] = best_split; + } +} + +// Find a gidx value for a given feature otherwise return -1 if not found +template +__device__ int BinarySearchRow(bst_uint begin, bst_uint end, gidx_iter_t data, + int fidx_begin, int fidx_end) { + // for(auto i = begin; i < end; i++) + //{ + // auto gidx = data[i]; + // if (gidx >= fidx_begin&&gidx < fidx_end) return gidx; + //} + // return -1; + + bst_uint previous_middle = UINT32_MAX; + while (end != begin) { + auto middle = begin + (end - begin) / 2; + if (middle == previous_middle) { + break; + } + previous_middle = middle; + + auto gidx = data[middle]; + + if (gidx >= fidx_begin && gidx < fidx_end) { + return gidx; + } else if (gidx < fidx_begin) { + begin = middle; + } else { + end = middle; + } + } + // Value is missing + return -1; +} + +template +__global__ void RadixSortSmall(bst_uint* d_ridx, int* d_position, bst_uint n) { + typedef cub::BlockRadixSort BlockRadixSort; + __shared__ typename BlockRadixSort::TempStorage temp_storage; + + bool thread_active = threadIdx.x < n; + int thread_key[1]; + bst_uint thread_value[1]; + thread_key[0] = thread_active ? d_position[threadIdx.x] : INT_MAX; + thread_value[0] = thread_active ? 
d_ridx[threadIdx.x] : UINT_MAX; + BlockRadixSort(temp_storage).Sort(thread_key, thread_value); + + if (thread_active) { + d_position[threadIdx.x] = thread_key[0]; + d_ridx[threadIdx.x] = thread_value[0]; + } +} + +struct DeviceHistogram { + dh::bulk_allocator ba; + dh::dvec data; + std::map node_map; + int n_bins; + void Init(int device_idx, int max_nodes, int n_bins, bool silent) { + this->n_bins = n_bins; + ba.allocate(device_idx, silent, &data, max_nodes * n_bins); + } + + void Reset() { + data.fill(bst_gpair_integer()); + node_map.clear(); + } + + void AddNode(int nidx) { + CHECK_EQ(node_map.count(nidx), 0) + << nidx << " already exists in the histogram."; + node_map[nidx] = data.data() + n_bins * node_map.size(); + } +}; + +// Manage memory for a single GPU +struct DeviceShard { + int device_idx; + int normalised_device_idx; // Device index counting from param.gpu_id + dh::bulk_allocator ba; + dh::dvec gidx_buffer; + dh::dvec gpair; + dh::dvec2 ridx; + dh::dvec2 position; + std::vector> ridx_segments; + dh::dvec feature_segments; + dh::dvec gidx_fvalue_map; + dh::dvec min_fvalue; + std::vector node_sum_gradients; + common::CompressedIterator gidx; + int row_stride; + bst_uint row_start_idx; + bst_uint row_end_idx; + bst_uint n_rows; + int n_bins; + int null_gidx_value; + DeviceHistogram hist; + + std::vector streams; + + dh::CubMemory temp_memory; + + DeviceShard(int device_idx, int normalised_device_idx, + const common::GHistIndexMatrix& gmat, bst_uint row_begin, + bst_uint row_end, int n_bins, TrainParam param) + : device_idx(device_idx), + normalised_device_idx(normalised_device_idx), + row_start_idx(row_begin), + row_end_idx(row_end), + n_rows(row_end - row_begin), + n_bins(n_bins), + null_gidx_value(n_bins) { + // Convert to ELLPACK matrix representation + int max_elements_row = 0; + for (int i = row_begin; i < row_end; i++) { + max_elements_row = + (std::max)(max_elements_row, + static_cast(gmat.row_ptr[i + 1] - gmat.row_ptr[i])); + } + row_stride = max_elements_row; + std::vector ellpack_matrix(row_stride * n_rows, null_gidx_value); + + for (int i = row_begin; i < row_end; i++) { + int row_count = 0; + for (int j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) { + ellpack_matrix[i * row_stride + row_count] = gmat.index[j]; + row_count++; + } + } + + // Allocate + int num_symbols = n_bins + 1; + size_t compressed_size_bytes = + common::CompressedBufferWriter::CalculateBufferSize( + ellpack_matrix.size(), num_symbols); + int max_nodes = + param.max_leaves > 0 ? 
param.max_leaves * 2 : n_nodes(param.max_depth); + ba.allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes, + &gpair, n_rows, &ridx, n_rows, &position, n_rows, + &feature_segments, gmat.cut->row_ptr.size(), &gidx_fvalue_map, + gmat.cut->cut.size(), &min_fvalue, gmat.cut->min_val.size()); + gidx_fvalue_map = gmat.cut->cut; + min_fvalue = gmat.cut->min_val; + feature_segments = gmat.cut->row_ptr; + + node_sum_gradients.resize(max_nodes); + ridx_segments.resize(max_nodes); + + // Compress gidx + common::CompressedBufferWriter cbw(num_symbols); + std::vector host_buffer(gidx_buffer.size()); + cbw.Write(host_buffer.data(), ellpack_matrix.begin(), ellpack_matrix.end()); + gidx_buffer = host_buffer; + gidx = + common::CompressedIterator(gidx_buffer.data(), num_symbols); + + common::CompressedIterator ci_host(host_buffer.data(), + num_symbols); + + // Init histogram + hist.Init(device_idx, max_nodes, gmat.cut->row_ptr.back(), param.silent); + } + + ~DeviceShard() { + for (auto& stream : streams) { + dh::safe_cuda(cudaStreamDestroy(stream)); + } + } + + // Get vector of at least n initialised streams + std::vector& GetStreams(int n) { + if (n > streams.size()) { + for (auto& stream : streams) { + dh::safe_cuda(cudaStreamDestroy(stream)); + } + + streams.clear(); + streams.resize(n); + + for (auto& stream : streams) { + dh::safe_cuda(cudaStreamCreate(&stream)); + } + } + + return streams; + } + + // Reset values for each update iteration + void Reset(const std::vector& host_gpair) { + position.current_dvec().fill(0); + std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), + bst_gpair()); + // TODO(rory): support subsampling + thrust::sequence(ridx.current_dvec().tbegin(), ridx.current_dvec().tend(), + row_start_idx); + std::fill(ridx_segments.begin(), ridx_segments.end(), std::make_pair(0, 0)); + ridx_segments.front() = std::make_pair(0, ridx.size()); + this->gpair.copy(host_gpair.begin() + row_start_idx, + host_gpair.begin() + row_end_idx); + hist.Reset(); + } + + __device__ void IncrementHist(bst_gpair gpair, int gidx, + bst_gpair_integer* node_hist) const { + auto dst_ptr = + reinterpret_cast(&node_hist[gidx]); // NOLINT + bst_gpair_integer tmp(gpair.GetGrad(), gpair.GetHess()); + auto src_ptr = reinterpret_cast(&tmp); + + atomicAdd(dst_ptr, + static_cast(*src_ptr)); // NOLINT + atomicAdd(dst_ptr + 1, + static_cast(*(src_ptr + 1))); // NOLINT + } + + void BuildHist(int nidx) { + hist.AddNode(nidx); + auto d_node_hist = hist.node_map[nidx]; + auto d_gidx = gidx; + auto d_ridx = ridx.current(); + auto d_gpair = gpair.data(); + auto row_stride = this->row_stride; + auto null_gidx_value = this->null_gidx_value; + auto segment = ridx_segments[nidx]; + auto n_elements = (segment.second - segment.first) * row_stride; + + dh::launch_n(device_idx, n_elements, [=] __device__(size_t idx) { + int relative_ridx = d_ridx[(idx / row_stride) + segment.first]; + int gidx = d_gidx[relative_ridx * row_stride + idx % row_stride]; + if (gidx != null_gidx_value) { + bst_gpair gpair = d_gpair[relative_ridx]; + IncrementHist(gpair, gidx, d_node_hist); + } + }); + } + void SortPosition(const std::pair& segment, int left_nidx, + int right_nidx) { + auto n = segment.second - segment.first; + int min_bits = 0; + int max_bits = std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1)); + // const int SINGLE_TILE_SIZE = 1024; + // if (n < SINGLE_TILE_SIZE) { + // RadixSortSmall + // <<<1, SINGLE_TILE_SIZE>>>(ridx.current() + segment.first, + // position.current() + segment.first, n); + //} else { 
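+    // cub::DeviceRadixSort::SortPairs below follows CUB's two-phase
+    // pattern: the first call, with a null temp-storage pointer, only
+    // computes temp_storage_bytes; the second call performs the actual
+    // key/value sort. The [min_bits, max_bits) range restricts the radix
+    // passes to the few bits in which left_nidx and right_nidx can differ,
+    // so the sort stays cheap.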
+ + size_t temp_storage_bytes = 0; + cub::DeviceRadixSort::SortPairs( + nullptr, temp_storage_bytes, position.current() + segment.first, + position.other() + segment.first, ridx.current() + segment.first, + ridx.other() + segment.first, n, min_bits, max_bits); + + temp_memory.LazyAllocate(temp_storage_bytes); + + cub::DeviceRadixSort::SortPairs( + temp_memory.d_temp_storage, temp_memory.temp_storage_bytes, + position.current() + segment.first, position.other() + segment.first, + ridx.current() + segment.first, ridx.other() + segment.first, n, + min_bits, max_bits); + dh::safe_cuda(cudaMemcpy(position.current() + segment.first, + position.other() + segment.first, n * sizeof(int), + cudaMemcpyDeviceToDevice)); + dh::safe_cuda(cudaMemcpy(ridx.current() + segment.first, + ridx.other() + segment.first, n * sizeof(bst_uint), + cudaMemcpyDeviceToDevice)); + //} + } +}; + +class GPUHistMakerExperimental : public TreeUpdater { + public: + struct ExpandEntry; + + GPUHistMakerExperimental() : initialised(false) {} + ~GPUHistMakerExperimental() {} + void Init( + const std::vector>& args) override { + param.InitAllowUnknown(args); + CHECK(param.n_gpus != 0) << "Must have at least one device"; + CHECK(param.n_gpus <= 1 && param.n_gpus != -1) + << "Only one GPU currently supported"; + n_devices = param.n_gpus; + + if (param.grow_policy == TrainParam::kLossGuide) { + qexpand_.reset(new ExpandQueue(loss_guide)); + } else { + qexpand_.reset(new ExpandQueue(depth_wise)); + } + + monitor.Init("updater_gpu_hist_experimental", param.debug_verbose); + } + void Update(const std::vector& gpair, DMatrix* dmat, + const std::vector& trees) override { + GradStats::CheckInfo(dmat->info()); + // rescale learning rate according to size of trees + float lr = param.learning_rate; + param.learning_rate = lr / trees.size(); + // build tree + try { + for (size_t i = 0; i < trees.size(); ++i) { + this->UpdateTree(gpair, dmat, trees[i]); + } + } catch (const std::exception& e) { + LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl; + } + param.learning_rate = lr; + } + + void InitDataOnce(DMatrix* dmat) { + info = &dmat->info(); + hmat_.Init(dmat, param.max_bin); + gmat_.cut = &hmat_; + gmat_.Init(dmat); + n_bins = hmat_.row_ptr.back(); + shards.emplace_back(param.gpu_id, 0, gmat_, 0, info->num_row, n_bins, + param); + initialised = true; + } + + void InitData(const std::vector& gpair, DMatrix* dmat, + const RegTree& tree) { + if (!initialised) { + this->InitDataOnce(dmat); + } + + this->ColSampleTree(); + + // Copy gpair & reset memory + for (auto& shard : shards) { + shard.Reset(gpair); + } + } + + void BuildHist(int nidx) { + for (auto& shard : shards) { + shard.BuildHist(nidx); + } + } + + // Returns best loss + std::vector EvaluateSplits( + const std::vector& nidx_set, RegTree* p_tree) { + auto columns = info->num_col; + std::vector best_splits(nidx_set.size()); + std::vector candidate_splits(nidx_set.size() * + columns); + // Use first device + auto& shard = shards.front(); + dh::safe_cuda(cudaSetDevice(shard.device_idx)); + shard.temp_memory.LazyAllocate(sizeof(DeviceSplitCandidate) * columns * + nidx_set.size()); + auto d_split = shard.temp_memory.Pointer(); + + auto& streams = shard.GetStreams(nidx_set.size()); + + // Use streams to process nodes concurrently + for (auto i = 0; i < nidx_set.size(); i++) { + auto nidx = nidx_set[i]; + DeviceNodeStats node(shard.node_sum_gradients[nidx], nidx, param); + + const int BLOCK_THREADS = 256; + evaluate_split_kernel + <<>>( + shard.hist.node_map[nidx], nidx, 
info->num_col, node, + shard.feature_segments.data(), shard.min_fvalue.data(), + shard.gidx_fvalue_map.data(), GPUTrainingParam(param), + d_split + i * columns); + } + + dh::safe_cuda( + cudaMemcpy(candidate_splits.data(), shard.temp_memory.d_temp_storage, + sizeof(DeviceSplitCandidate) * columns * nidx_set.size(), + cudaMemcpyDeviceToHost)); + + for (auto i = 0; i < nidx_set.size(); i++) { + DeviceSplitCandidate nidx_best; + for (auto fidx = 0; fidx < columns; fidx++) { + nidx_best.Update(candidate_splits[i * columns + fidx], param); + } + best_splits[i] = nidx_best; + } + return std::move(best_splits); + } + + void InitRoot(const std::vector& gpair, RegTree* p_tree) { + int root_nidx = 0; + BuildHist(root_nidx); + + // TODO(rory): support sub sampling + // TODO(rory): not asynchronous + bst_gpair sum_gradient; + for (auto& shard : shards) { + sum_gradient += thrust::reduce(shard.gpair.tbegin(), shard.gpair.tend()); + } + + // Remember root stats + p_tree->stat(root_nidx).sum_hess = sum_gradient.GetHess(); + p_tree->stat(root_nidx).base_weight = CalcWeight(param, sum_gradient); + + // Store sum gradients + for (auto& shard : shards) { + shard.node_sum_gradients[root_nidx] = sum_gradient; + } + + auto splits = this->EvaluateSplits({root_nidx}, p_tree); + + // Generate candidate + qexpand_->push( + ExpandEntry(root_nidx, p_tree->GetDepth(root_nidx), splits.front(), 0)); + } + + struct MatchingFunctor : public thrust::unary_function { + int val; + __host__ __device__ MatchingFunctor(int val) : val(val) {} + __host__ __device__ int operator()(int x) const { return x == val; } + }; + + __device__ void CountLeft(bst_uint* d_count, int val, int left_nidx) { + unsigned ballot = __ballot(val == left_nidx); + if (threadIdx.x % 32 == 0) { + atomicAdd(d_count, __popc(ballot)); + } + } + + void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) { + auto nidx = candidate.nid; + auto is_dense = info->num_nonzero == info->num_row * info->num_col; + auto left_nidx = (*p_tree)[nidx].cleft(); + auto right_nidx = (*p_tree)[nidx].cright(); + + // convert floating-point split_pt into corresponding bin_id + // split_cond = -1 indicates that split_pt is less than all known cut points + auto split_gidx = -1; + auto fidx = candidate.split.findex; + auto default_dir_left = candidate.split.dir == LeftDir; + auto fidx_begin = hmat_.row_ptr[fidx]; + auto fidx_end = hmat_.row_ptr[fidx + 1]; + for (auto i = fidx_begin; i < fidx_end; ++i) { + if (candidate.split.fvalue == hmat_.cut[i]) { + split_gidx = static_cast(i); + } + } + + for (auto& shard : shards) { + monitor.Start("update position kernel"); + shard.temp_memory.LazyAllocate(sizeof(bst_uint)); + auto d_left_count = shard.temp_memory.Pointer(); + dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(bst_uint))); + dh::safe_cuda(cudaSetDevice(shard.device_idx)); + auto segment = shard.ridx_segments[nidx]; + CHECK_GT(segment.second - segment.first, 0); + auto d_ridx = shard.ridx.current(); + auto d_position = shard.position.current(); + auto d_gidx = shard.gidx; + auto row_stride = shard.row_stride; + dh::launch_n<1, 512>( + shard.device_idx, segment.second - segment.first, + [=] __device__(bst_uint idx) { + idx += segment.first; + auto ridx = d_ridx[idx]; + auto row_begin = row_stride * ridx; + auto row_end = row_begin + row_stride; + auto gidx = -1; + if (is_dense) { + gidx = d_gidx[row_begin + fidx]; + } else { + gidx = BinarySearchRow(row_begin, row_end, d_gidx, fidx_begin, + fidx_end); + } + + int position; + if (gidx >= 0) { + // Feature is found + 
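// (a found value splits on its bin index: bins at or below split_gidx go
// left, the rest go right; rows where the feature is missing take the
// default direction chosen when the split was evaluated)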
position = gidx <= split_gidx ? left_nidx : right_nidx; + } else { + // Feature is missing + position = default_dir_left ? left_nidx : right_nidx; + } + + CountLeft(d_left_count, position, left_nidx); + d_position[idx] = position; + }); + + bst_uint left_count; + dh::safe_cuda(cudaMemcpy(&left_count, d_left_count, sizeof(bst_uint), + cudaMemcpyDeviceToHost)); + monitor.Stop("update position kernel"); + + monitor.Start("sort"); + shard.SortPosition(segment, left_nidx, right_nidx); + monitor.Stop("sort"); + shard.ridx_segments[left_nidx] = + std::make_pair(segment.first, segment.first + left_count); + shard.ridx_segments[right_nidx] = + std::make_pair(segment.first + left_count, segment.second); + } + } + + void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) { + // Add new leaves + RegTree& tree = *p_tree; + tree.AddChilds(candidate.nid); + auto& parent = tree[candidate.nid]; + parent.set_split(candidate.split.findex, candidate.split.fvalue, + candidate.split.dir == LeftDir); + tree.stat(candidate.nid).loss_chg = candidate.split.loss_chg; + + // Configure left child + auto left_weight = CalcWeight(param, candidate.split.left_sum); + tree[parent.cleft()].set_leaf(left_weight * param.learning_rate, 0); + tree.stat(parent.cleft()).base_weight = left_weight; + tree.stat(parent.cleft()).sum_hess = candidate.split.left_sum.GetHess(); + + // Configure right child + auto right_weight = CalcWeight(param, candidate.split.right_sum); + tree[parent.cright()].set_leaf(right_weight * param.learning_rate, 0); + tree.stat(parent.cright()).base_weight = right_weight; + tree.stat(parent.cright()).sum_hess = candidate.split.right_sum.GetHess(); + // Store sum gradients + for (auto& shard : shards) { + shard.node_sum_gradients[parent.cleft()] = candidate.split.left_sum; + shard.node_sum_gradients[parent.cright()] = candidate.split.right_sum; + } + this->UpdatePosition(candidate, p_tree); + } + + void ColSampleTree() { + if (param.colsample_bylevel == 1.0 && param.colsample_bytree == 1.0) return; + + feature_set_tree.resize(info->num_col); + std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0); + feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree); + } + + struct Monitor { + bool debug_verbose = false; + std::string label = ""; + std::map timer_map; + + ~Monitor() { + if (!debug_verbose) return; + + std::cout << "Monitor: " << label << "\n"; + for (auto& kv : timer_map) { + kv.second.PrintElapsed(kv.first); + } + } + void Init(std::string label, bool debug_verbose) { + this->debug_verbose = debug_verbose; + this->label = label; + } + void Start(const std::string& name) { timer_map[name].Start(); } + void Stop(const std::string& name) { timer_map[name].Stop(); } + }; + + void UpdateTree(const std::vector& gpair, DMatrix* p_fmat, + RegTree* p_tree) { + auto& tree = *p_tree; + + monitor.Start("InitData"); + this->InitData(gpair, p_fmat, *p_tree); + monitor.Stop("InitData"); + monitor.Start("InitRoot"); + this->InitRoot(gpair, p_tree); + monitor.Stop("InitRoot"); + + unsigned timestamp = qexpand_->size(); + auto num_leaves = 1; + + while (!qexpand_->empty()) { + auto candidate = qexpand_->top(); + qexpand_->pop(); + if (!candidate.IsValid(param, num_leaves)) continue; + // std::cout << candidate; + monitor.Start("ApplySplit"); + this->ApplySplit(candidate, p_tree); + monitor.Stop("ApplySplit"); + num_leaves++; + + auto left_child_nidx = tree[candidate.nid].cleft(); + auto right_child_nidx = tree[candidate.nid].cright(); + + // Only create child entries if needed + if 
+      if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
+                                    num_leaves)) {
+        monitor.Start("BuildHist");
+        this->BuildHist(left_child_nidx);
+        this->BuildHist(right_child_nidx);
+        monitor.Stop("BuildHist");
+
+        monitor.Start("EvaluateSplits");
+        auto splits =
+            this->EvaluateSplits({left_child_nidx, right_child_nidx}, p_tree);
+        qexpand_->push(ExpandEntry(left_child_nidx,
+                                   tree.GetDepth(left_child_nidx), splits[0],
+                                   timestamp++));
+        qexpand_->push(ExpandEntry(right_child_nidx,
+                                   tree.GetDepth(right_child_nidx), splits[1],
+                                   timestamp++));
+        monitor.Stop("EvaluateSplits");
+      }
+    }
+  }
+
+  struct ExpandEntry {
+    int nid;
+    int depth;
+    DeviceSplitCandidate split;
+    unsigned timestamp;
+    ExpandEntry(int nid, int depth, const DeviceSplitCandidate& split,
+                unsigned timestamp)
+        : nid(nid), depth(depth), split(split), timestamp(timestamp) {}
+    bool IsValid(const TrainParam& param, int num_leaves) const {
+      if (split.loss_chg <= rt_eps) return false;
+      if (param.max_depth > 0 && depth == param.max_depth) return false;
+      if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
+      return true;
+    }
+
+    static bool ChildIsValid(const TrainParam& param, int depth,
+                             int num_leaves) {
+      if (param.max_depth > 0 && depth == param.max_depth) return false;
+      if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
+      return true;
+    }
+
+    friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
+      os << "ExpandEntry:\n";
+      os << "nidx: " << e.nid << "\n";
+      os << "depth: " << e.depth << "\n";
+      os << "loss: " << e.split.loss_chg << "\n";
+      os << "left_sum: " << e.split.left_sum << "\n";
+      os << "right_sum: " << e.split.right_sum << "\n";
+      return os;
+    }
+  };
+
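+  // Comparators for the expansion queue. std::priority_queue pops the
+  // greatest element first, so returning true here means lhs is expanded
+  // *after* rhs.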
+  inline static bool depth_wise(ExpandEntry lhs, ExpandEntry rhs) {
+    if (lhs.depth == rhs.depth) {
+      return lhs.timestamp > rhs.timestamp;  // favor small timestamp
+    } else {
+      return lhs.depth > rhs.depth;  // favor small depth
+    }
+  }
+  inline static bool loss_guide(ExpandEntry lhs, ExpandEntry rhs) {
+    if (lhs.split.loss_chg == rhs.split.loss_chg) {
+      return lhs.timestamp > rhs.timestamp;  // favor small timestamp
+    } else {
+      return lhs.split.loss_chg < rhs.split.loss_chg;  // favor large loss_chg
+    }
+  }
+
+  TrainParam param;
+  common::HistCutMatrix hmat_;
+  common::GHistIndexMatrix gmat_;
+  MetaInfo* info;
+  bool initialised;
+  int n_devices;
+  int n_bins;
+
+  std::vector<DeviceShard> shards;
+  std::vector<int> feature_set_tree;
+  std::vector<int> feature_set_level;
+  typedef std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
+                              std::function<bool(ExpandEntry, ExpandEntry)>>
+      ExpandQueue;
+  std::unique_ptr<ExpandQueue> qexpand_;
+  Monitor monitor;
+};
+
+XGBOOST_REGISTER_TREE_UPDATER(GPUHistMakerExperimental,
+                              "grow_gpu_hist_experimental")
+    .describe("Grow tree with GPU.")
+    .set_body([]() { return new GPUHistMakerExperimental(); });
+}  // namespace tree
+}  // namespace xgboost
diff --git a/tests/benchmark/benchmark.py b/tests/benchmark/benchmark.py
index e0b51a934..2ee17ffb8 100644
--- a/tests/benchmark/benchmark.py
+++ b/tests/benchmark/benchmark.py
@@ -5,57 +5,45 @@ import numpy as np
 from sklearn.datasets import make_classification
 from sklearn.model_selection import train_test_split
 import time
+import ast
+
+rng = np.random.RandomState(1994)
 
-def run_benchmark(args, gpu_algorithm, cpu_algorithm):
+def run_benchmark(args):
     print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns))
     print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size))
     tmp = time.time()
     X, y = make_classification(args.rows, n_features=args.columns, random_state=7)
+    if args.sparsity > 0.0:
+        X = np.array([[np.nan if rng.uniform(0, 1) < args.sparsity else x for x in x_row] for x_row in X])
+
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_size, random_state=7)
     print ("Generate Time: %s seconds" % (str(time.time() - tmp)))
     tmp = time.time()
     print ("DMatrix Start")
-    # omp way
     dtrain = xgb.DMatrix(X_train, y_train, nthread=-1)
     dtest = xgb.DMatrix(X_test, y_test, nthread=-1)
     print ("DMatrix Time: %s seconds" % (str(time.time() - tmp)))
 
-    param = {'objective': 'binary:logistic',
-             'max_depth': 6,
-             'silent': 0,
-             'n_gpus': 1,
-             'gpu_id': 0,
-             'eval_metric': 'error',
-             'debug_verbose': 0,
-             }
+    param = {'objective': 'binary:logistic'}
+    if args.params != '':
+        param.update(ast.literal_eval(args.params))
 
-    param['tree_method'] = gpu_algorithm
+    param['tree_method'] = args.tree_method
     print("Training with '%s'" % param['tree_method'])
     tmp = time.time()
     xgb.train(param, dtrain, args.iterations, evals=[(dtest, "test")])
     print ("Train Time: %s seconds" % (str(time.time() - tmp)))
 
-    param['silent'] = 1
-    param['tree_method'] = cpu_algorithm
-    print("Training with '%s'" % param['tree_method'])
-    tmp = time.time()
-    xgb.train(param, dtrain, args.iterations, evals=[(dtest, "test")])
-    print ("Time: %s seconds" % (str(time.time() - tmp)))
-
-
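+# Example invocation (parameter values are illustrative):
+#   python benchmark.py --tree_method gpu_hist --rows 100000 --columns 50 \
+#       --sparsity 0.25 --params "{'max_depth': 3, 'max_bin': 64}"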
 parser = argparse.ArgumentParser()
-parser.add_argument('--algorithm', choices=['all', 'gpu_exact', 'gpu_hist'], default='all')
+parser.add_argument('--tree_method', default='gpu_hist')
+parser.add_argument('--sparsity', type=float, default=0.0)
 parser.add_argument('--rows', type=int, default=1000000)
 parser.add_argument('--columns', type=int, default=50)
 parser.add_argument('--iterations', type=int, default=500)
 parser.add_argument('--test_size', type=float, default=0.25)
+parser.add_argument('--params', default='', help='Provide additional parameters as a Python dict string, e.g. --params "{\'max_depth\':2}"')
 args = parser.parse_args()
 
-if 'gpu_hist' in args.algorithm:
-    run_benchmark(args, args.algorithm, 'hist')
-elif 'gpu_exact' in args.algorithm:
-    run_benchmark(args, args.algorithm, 'exact')
-elif 'all' in args.algorithm:
-    run_benchmark(args, 'gpu_exact', 'exact')
-    run_benchmark(args, 'gpu_hist', 'hist')
+run_benchmark(args)
diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu
index 63eec0fcd..3c49a9993 100644
--- a/tests/cpp/common/test_device_helpers.cu
+++ b/tests/cpp/common/test_device_helpers.cu
@@ -41,7 +41,7 @@ void SpeedTest() {
               [=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; });
   dh::safe_cuda(cudaDeviceSynchronize());
-  double time = t.elapsedSeconds();
+  double time = t.ElapsedSeconds();
   const int mb_size = 1048576;
   size_t size = (sizeof(int) * h_rows.size()) / mb_size;
   printf("size: %llumb, time: %fs, bandwidth: %fmb/s\n", size, time,
diff --git a/tests/cpp/tree/test_gpu_hist_experimental.cu b/tests/cpp/tree/test_gpu_hist_experimental.cu
new file mode 100644
index 000000000..fd12aabb7
--- /dev/null
+++ b/tests/cpp/tree/test_gpu_hist_experimental.cu
@@ -0,0 +1,72 @@
+
+/*!
+ * Copyright 2017 XGBoost contributors
+ */
+#include
+#include
+#include "../helpers.h"
+#include "gtest/gtest.h"
+
+#include "../../../src/tree/updater_gpu_hist_experimental.cu"
+#include "../../../src/gbm/gbtree_model.h"
+
+namespace xgboost {
+namespace tree {
+TEST(gpu_hist_experimental, TestSparseShard) {
+  int rows = 100;
+  int columns = 80;
+  int max_bins = 4;
+  auto dmat = CreateDMatrix(rows, columns, 0.9);
+  common::HistCutMatrix hmat;
+  common::GHistIndexMatrix gmat;
+  hmat.Init(dmat.get(), max_bins);
+  gmat.cut = &hmat;
+  gmat.Init(dmat.get());
+  DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(), TrainParam());
+
+  ASSERT_LT(shard.row_stride, columns);
+
+  auto host_gidx_buffer = shard.gidx_buffer.as_vector();
+
+  common::CompressedIterator<uint32_t> gidx(host_gidx_buffer.data(),
+                                            hmat.row_ptr.back() + 1);
+
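+  // Rows are stored at a fixed stride in the compressed buffer: each row's
+  // gidx values come first, and any remaining slots are padded with
+  // null_gidx_value.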
+  for (int i = 0; i < rows; i++) {
+    int row_offset = 0;
+    for (int j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
+      ASSERT_EQ(gidx[i * shard.row_stride + row_offset], gmat.index[j]);
+      row_offset++;
+    }
+
+    for (; row_offset < shard.row_stride; row_offset++) {
+      ASSERT_EQ(gidx[i * shard.row_stride + row_offset], shard.null_gidx_value);
+    }
+  }
+}
+
+TEST(gpu_hist_experimental, TestDenseShard) {
+  int rows = 100;
+  int columns = 80;
+  int max_bins = 4;
+  auto dmat = CreateDMatrix(rows, columns, 0);
+  common::HistCutMatrix hmat;
+  common::GHistIndexMatrix gmat;
+  hmat.Init(dmat.get(), max_bins);
+  gmat.cut = &hmat;
+  gmat.Init(dmat.get());
+  DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(), TrainParam());
+
+  ASSERT_EQ(shard.row_stride, columns);
+
+  auto host_gidx_buffer = shard.gidx_buffer.as_vector();
+
+  common::CompressedIterator<uint32_t> gidx(host_gidx_buffer.data(),
+                                            hmat.row_ptr.back() + 1);
+
+  for (size_t i = 0; i < gmat.index.size(); i++) {
+    ASSERT_EQ(gidx[i], gmat.index[i]);
+  }
+}
+
+}  // namespace tree
+}  // namespace xgboost
\ No newline at end of file
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index a9a68a14f..cb430ed55 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -7,316 +7,114 @@ import xgboost as xgb
 import numpy as np
 import unittest
 from nose.plugins.attrib import attr
+from sklearn.datasets import load_digits, load_boston, load_breast_cancer, make_regression
 
 rng = np.random.RandomState(1994)
-dpath = 'demo/data/'
+
+
+def non_increasing(L, tolerance):
+    return all((y - x) < tolerance for x, y in zip(L, L[1:]))
+
+
+# Check the result is non-increasing (within tolerance) and that the final
+# metric value is close to the comparison tree method's.
+def assert_accuracy(res, tree_method, comparison_tree_method, tolerance):
+    assert non_increasing(res[tree_method], tolerance)
+    assert np.allclose(res[tree_method][-1], res[comparison_tree_method][-1], 1e-3, 1e-2)
 
-def eprint(*args, **kwargs):
-    print(*args, file=sys.stderr, **kwargs)
-    print(*args, file=sys.stdout, **kwargs)
 
+def train_boston(param_in, comparison_tree_method):
+    data = load_boston()
+    dtrain = xgb.DMatrix(data.data, label=data.target)
+    param = {}
+    param.update(param_in)
+    res_tmp = {}
+    res = {}
+    num_rounds = 10
+    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
+    res[param['tree_method']] = res_tmp['train']['rmse']
+    param["tree_method"] = comparison_tree_method
+    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
+    res[comparison_tree_method] = res_tmp['train']['rmse']
+
+    return res
+
+
+def train_digits(param_in, comparison_tree_method):
+    data = load_digits()
+    dtrain = xgb.DMatrix(data.data, label=data.target)
+    param = {}
+    param['objective'] = 'multi:softmax'
+    param['num_class'] = 10
+    param.update(param_in)
+    res_tmp = {}
+    res = {}
+    num_rounds = 10
+    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
+    res[param['tree_method']] = res_tmp['train']['merror']
+    param["tree_method"] = comparison_tree_method
+    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
+    res[comparison_tree_method] = res_tmp['train']['merror']
+    return res
+
+
+def train_cancer(param_in, comparison_tree_method):
+    data = load_breast_cancer()
+    dtrain = xgb.DMatrix(data.data, label=data.target)
+    param = {}
+    param['objective'] = 'binary:logistic'
+    param.update(param_in)
+    res_tmp = {}
+    res = {}
+    num_rounds = 10
+    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
+    res[param['tree_method']] = res_tmp['train']['error']
+    param["tree_method"] = comparison_tree_method
+    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
+    res[comparison_tree_method] = res_tmp['train']['error']
+    return res
+
+
+def train_sparse(param_in, comparison_tree_method):
+    n = 5000
+    sparsity = 0.75
+    X, y = make_regression(n, random_state=rng)
+    X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x for x in x_row] for x_row in X])
+    dtrain = xgb.DMatrix(X, label=y)
+    param = {}
+    param.update(param_in)
+    res_tmp = {}
+    res = {}
+    num_rounds = 10
+    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
+    res[param['tree_method']] = res_tmp['train']['rmse']
+    param["tree_method"] = comparison_tree_method
+    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
+    res[comparison_tree_method] = res_tmp['train']['rmse']
+    return res
+
+
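+# Sweep each parameter value in turn, training with the tree method under
+# test and with its comparison method, then check convergence and final
+# accuracy against each other.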
+def assert_updater_accuracy(tree_method, comparison_tree_method, variable_param, tolerance):
+    param = {'tree_method': tree_method}
+    for k, values in variable_param.items():
+        for val in values:
+            param_tmp = param.copy()
+            param_tmp[k] = val
+            print(param_tmp, file=sys.stderr)
+            assert_accuracy(train_boston(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
+            assert_accuracy(train_digits(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
+            assert_accuracy(train_cancer(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
+            assert_accuracy(train_sparse(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
 
 
 @attr('gpu')
 class TestGPU(unittest.TestCase):
-    def test_grow_gpu(self):
-        from sklearn.datasets import load_digits
-        try:
-            from sklearn.model_selection import train_test_split
-        except:
-            from sklearn.cross_validation import train_test_split
+    def test_gpu_hist(self):
+        variable_param = {'max_depth': [2, 6, 11], 'max_bin': [2, 16, 1024], 'n_gpus': [1, -1]}
+        assert_updater_accuracy('gpu_hist', 'hist', variable_param, 0.02)
 
-        ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+    def test_gpu_exact(self):
+        variable_param = {'max_depth': [2, 6, 15]}
+        assert_updater_accuracy('gpu_exact', 'exact', variable_param, 0.02)
 
-        ag_param = {'max_depth': 2,
-                    'tree_method': 'exact',
-                    'nthread': 0,
-                    'eta': 1,
-                    'silent': 1,
-                    'debug_verbose': 0,
-                    'objective': 'binary:logistic',
-                    'eval_metric': 'auc'}
-        ag_param2 = {'max_depth': 2,
-                     'tree_method': 'gpu_exact',
-                     'nthread': 0,
-                     'eta': 1,
-                     'silent': 1,
-                     'debug_verbose': 0,
-                     'objective': 'binary:logistic',
-                     'eval_metric': 'auc'}
-        ag_res = {}
-        ag_res2 = {}
-
-        num_rounds = 10
-        xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
-                  evals_result=ag_res)
-        xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
-                  evals_result=ag_res2)
-        assert ag_res['train']['auc'] == ag_res2['train']['auc']
-        assert ag_res['test']['auc'] == ag_res2['test']['auc']
-
-        digits = load_digits(2)
-        X = digits['data']
-        y = digits['target']
-        X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
-        dtrain = xgb.DMatrix(X_train, y_train)
-        dtest = xgb.DMatrix(X_test, y_test)
-
-        param = {'objective': 'binary:logistic',
-                 'nthread': 0,
-                 'tree_method': 'gpu_exact',
-                 'max_depth': 3,
-                 'debug_verbose': 0,
-                 'eval_metric': 'auc'}
-        res = {}
-        xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
-                  evals_result=res)
-        assert self.non_decreasing(res['train']['auc'])
-        assert self.non_decreasing(res['test']['auc'])
-
-        # fail-safe test for dense data
-        from sklearn.datasets import load_svmlight_file
-        X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
-        X2 = X2.toarray()
-        dtrain2 = xgb.DMatrix(X2, label=y2)
-
-        param = {'objective': 'binary:logistic',
-                 'nthread': 0,
-                 'tree_method': 'gpu_exact',
-                 'max_depth': 2,
-                 'debug_verbose': 0,
-                 'eval_metric': 'auc'}
-        res = {}
-        xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
-
-        assert self.non_decreasing(res['train']['auc'])
-        assert res['train']['auc'][0] >= 0.85
-
-        for j in range(X2.shape[1]):
-            for i in rng.choice(X2.shape[0], size=num_rounds, replace=False):
-                X2[i, j] = 2
-
-        dtrain3 = xgb.DMatrix(X2, label=y2)
-        res = {}
-
-        xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
-
-        assert self.non_decreasing(res['train']['auc'])
-        assert res['train']['auc'][0] >= 0.85
-
-        for j in range(X2.shape[1]):
-            for i in np.random.choice(X2.shape[0], size=num_rounds, replace=False):
-                X2[i, j] = 3
-
-        dtrain4 = xgb.DMatrix(X2, label=y2)
-        res = {}
-        xgb.train(param, dtrain4, num_rounds, [(dtrain4, 'train')], evals_result=res)
-        assert self.non_decreasing(res['train']['auc'])
-        assert res['train']['auc'][0] >= 0.85
-
-    def test_grow_gpu_hist(self):
-        n_gpus = -1
-        from sklearn.datasets import load_digits
-        try:
-            from sklearn.model_selection import train_test_split
-        except:
-            from sklearn.cross_validation import train_test_split
-
-        ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
-
-        for max_depth in range(3, 10):  # TODO: Doesn't work with 2 for some tests
-            # eprint("max_depth=%d" % (max_depth))
-
-            for max_bin_i in range(3, 11):
-                max_bin = np.power(2, max_bin_i)
-                # eprint("max_bin=%d" % (max_bin))
-
-
-
-                # regression test --- hist must be same as exact on all-categorial data
-                ag_param = {'max_depth': max_depth,
-                            'tree_method': 'exact',
-                            'nthread': 0,
-                            'eta': 1,
-                            'silent': 1,
-                            'debug_verbose': 0,
-                            'objective': 'binary:logistic',
-                            'eval_metric': 'auc'}
-                ag_param2 = {'max_depth': max_depth,
-                             'nthread': 0,
-                             'tree_method': 'gpu_hist',
-                             'eta': 1,
-                             'silent': 1,
-                             'debug_verbose': 0,
-                             'n_gpus': 1,
-                             'objective': 'binary:logistic',
-                             'max_bin': max_bin,
-                             'eval_metric': 'auc'}
-                ag_param3 = {'max_depth': max_depth,
-                             'nthread': 0,
-                             'tree_method': 'gpu_hist',
-                             'eta': 1,
-                             'silent': 1,
-                             'debug_verbose': 0,
-                             'n_gpus': n_gpus,
-                             'objective': 'binary:logistic',
-                             'max_bin': max_bin,
-                             'eval_metric': 'auc'}
-                ag_res = {}
-                ag_res2 = {}
-                ag_res3 = {}
-
-                num_rounds = 10
-                # eprint("normal updater");
-                xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
-                          evals_result=ag_res)
-                # eprint("grow_gpu_hist updater 1 gpu");
-                xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
-                          evals_result=ag_res2)
-                # eprint("grow_gpu_hist updater %d gpus" % (n_gpus));
-                xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
-                          evals_result=ag_res3)
-                # assert 1==0
-                assert ag_res['train']['auc'] == ag_res2['train']['auc']
-                assert ag_res['test']['auc'] == ag_res2['test']['auc']
-                assert ag_res['test']['auc'] == ag_res3['test']['auc']
-
-                ######################################################################
-                digits = load_digits(2)
-                X = digits['data']
-                y = digits['target']
-                X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
-                dtrain = xgb.DMatrix(X_train, y_train)
-                dtest = xgb.DMatrix(X_test, y_test)
-
-                param = {'objective': 'binary:logistic',
-                         'tree_method': 'gpu_hist',
-                         'nthread': 0,
-                         'max_depth': max_depth,
-                         'n_gpus': 1,
-                         'max_bin': max_bin,
-                         'debug_verbose': 0,
-                         'eval_metric': 'auc'}
-                res = {}
-                # eprint("digits: grow_gpu_hist updater 1 gpu");
-                xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
-                          evals_result=res)
-                assert self.non_decreasing(res['train']['auc'])
-                # assert self.non_decreasing(res['test']['auc'])
-                param2 = {'objective': 'binary:logistic',
-                          'nthread': 0,
-                          'tree_method': 'gpu_hist',
-                          'max_depth': max_depth,
-                          'n_gpus': n_gpus,
-                          'max_bin': max_bin,
-                          'debug_verbose': 0,
-                          'eval_metric': 'auc'}
-                res2 = {}
-                # eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
-                xgb.train(param2, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
-                          evals_result=res2)
-                assert self.non_decreasing(res2['train']['auc'])
-                # assert self.non_decreasing(res2['test']['auc'])
-                assert res['train']['auc'] == res2['train']['auc']
-                # assert res['test']['auc'] == res2['test']['auc']
-
-                ######################################################################
-                # fail-safe test for dense data
-                from sklearn.datasets import load_svmlight_file
-                X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
-                X2 = X2.toarray()
-                dtrain2 = xgb.DMatrix(X2, label=y2)
-
-                param = {'objective': 'binary:logistic',
-                         'nthread': 0,
-                         'tree_method': 'gpu_hist',
-                         'max_depth': max_depth,
-                         'n_gpus': n_gpus,
-                         'max_bin': max_bin,
-                         'debug_verbose': 0,
-                         'eval_metric': 'auc'}
-                res = {}
-                xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
-
-                assert self.non_decreasing(res['train']['auc'])
-                if max_bin > 32:
-                    assert res['train']['auc'][0] >= 0.85
-
-                for j in range(X2.shape[1]):
-                    for i in rng.choice(X2.shape[0], size=num_rounds, replace=False):
-                        X2[i, j] = 2
-
-                dtrain3 = xgb.DMatrix(X2, label=y2)
-                res = {}
-
-                xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
-
-                assert self.non_decreasing(res['train']['auc'])
-                if max_bin > 32:
-                    assert res['train']['auc'][0] >= 0.85
-
-                for j in range(X2.shape[1]):
-                    for i in np.random.choice(X2.shape[0], size=num_rounds, replace=False):
-                        X2[i, j] = 3
-
-                dtrain4 = xgb.DMatrix(X2, label=y2)
-                res = {}
-                xgb.train(param, dtrain4, num_rounds, [(dtrain4, 'train')], evals_result=res)
-                assert self.non_decreasing(res['train']['auc'])
-                if max_bin > 32:
-                    assert res['train']['auc'][0] >= 0.85
-
-                ######################################################################
-                # fail-safe test for max_bin
-                param = {'objective': 'binary:logistic',
-                         'nthread': 0,
-                         'tree_method': 'gpu_hist',
-                         'max_depth': max_depth,
-                         'n_gpus': n_gpus,
-                         'debug_verbose': 0,
-                         'eval_metric': 'auc',
-                         'max_bin': max_bin}
-                res = {}
-                xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
-                assert self.non_decreasing(res['train']['auc'])
-                if max_bin > 32:
-                    assert res['train']['auc'][0] >= 0.85
-
-                ######################################################################
-                # subsampling
-                param = {'objective': 'binary:logistic',
-                         'nthread': 0,
-                         'tree_method': 'gpu_hist',
-                         'max_depth': max_depth,
-                         'n_gpus': n_gpus,
-                         'eval_metric': 'auc',
-                         'colsample_bytree': 0.5,
-                         'colsample_bylevel': 0.5,
-                         'subsample': 0.5,
-                         'debug_verbose': 0,
-                         'max_bin': max_bin}
-                res = {}
-                xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
-                assert self.non_decreasing(res['train']['auc'])
-                if max_bin > 32:
-                    assert res['train']['auc'][0] >= 0.85
-
-                ######################################################################
-                # fail-safe test for max_bin=2
-                param = {'objective': 'binary:logistic',
-                         'nthread': 0,
-                         'tree_method': 'gpu_hist',
-                         'max_depth': 2,
-                         'n_gpus': n_gpus,
-                         'debug_verbose': 0,
-                         'eval_metric': 'auc',
-                         'max_bin': 2}
-                res = {}
-                xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
-                assert self.non_decreasing(res['train']['auc'])
-                if max_bin > 32:
-                    assert res['train']['auc'][0] >= 0.85
-
-    def non_decreasing(self, L):
-        return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
+    def test_gpu_hist_experimental(self):
+        variable_param = {'max_depth': [2, 6], 'max_leaves': [255, 4], 'max_bin': [2, 16, 1024]}
+        assert_updater_accuracy('gpu_hist_experimental', 'hist', variable_param, 0.01)