diff --git a/CMakeLists.txt b/CMakeLists.txt
index f1318ddc6..bb4168616 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -97,7 +97,7 @@ if(PLUGIN_UPDATER_GPU)
   #Find cub
   set(CUB_DIRECTORY "" CACHE PATH "CUB 1.5.4 directory")
   include_directories(${CUB_DIRECTORY})
-  set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35")
+  set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35;")
   if(NOT MSVC)
     set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC")
   endif()
@@ -105,8 +105,9 @@ if(PLUGIN_UPDATER_GPU)
     plugin/updater_gpu/src/updater_gpu.cc
   )
   find_package(CUDA QUIET REQUIRED)
+  file(GLOB_RECURSE CUDA_SOURCES "plugin/updater_gpu/src/*")
   cuda_add_library(updater_gpu STATIC
-    plugin/updater_gpu/src/gpu_builder.cu
+    ${CUDA_SOURCES}
   )
 endif()
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index ce2aba503..6102f0ec6 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -63,3 +63,5 @@ List of Contributors
 * [Alex Bain](https://github.com/convexquad)
 * [Baltazar Bieniek](https://github.com/bbieniek)
 * [Adam Pocock](https://github.com/Craigacp)
+* [Rory Mitchell](https://github.com/RAMitchell)
+  - Rory is the author of the GPU plugin and also contributed the CMake build system and Windows continuous integration.
diff --git a/README.md b/README.md
index 36e97f88e..c707d35e6 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ The same code runs on major distributed environment (Hadoop, SGE, MPI) and can s
 What's New
 ----------
+* [XGBoost GPU support with fast histogram algorithm](https://github.com/dmlc/xgboost/tree/master/plugin/updater_gpu)
 * [XGBoost4J: Portable Distributed XGboost in Spark, Flink and Dataflow](http://dmlc.ml/2016/03/14/xgboost4j-portable-distributed-xgboost-in-spark-flink-and-dataflow.html), see [JVM-Package](https://github.com/dmlc/xgboost/tree/master/jvm-packages)
 * [Story and Lessons Behind the Evolution of XGBoost](http://homes.cs.washington.edu/~tqchen/2016/03/10/story-and-lessons-behind-the-evolution-of-xgboost.html)
 * [Tutorial: Distributed XGBoost on AWS with YARN](https://xgboost.readthedocs.io/en/latest/tutorials/aws_yarn.html)
diff --git a/include/xgboost/base.h b/include/xgboost/base.h
index 17dae0659..14e16eb63 100644
--- a/include/xgboost/base.h
+++ b/include/xgboost/base.h
@@ -48,7 +48,7 @@
 #define XGBOOST_ALIGNAS(X)
 #endif
 
-#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8
+#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8 && !defined(__CUDACC__)
 #include <parallel/algorithm>
 #define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z))
 #define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) __gnu_parallel::stable_sort((X), (Y), (Z))
diff --git a/plugin/updater_gpu/README.md b/plugin/updater_gpu/README.md
index 1bc7d982c..3ae8ba10c 100644
--- a/plugin/updater_gpu/README.md
+++ b/plugin/updater_gpu/README.md
@@ -1,29 +1,25 @@
-# CUDA Accelerated Tree Construction Algorithm
-
-## Benchmarks
-
-[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for performance benchmarks
-
+# CUDA Accelerated Tree Construction Algorithms
+This plugin adds GPU-accelerated tree construction algorithms to XGBoost.
 ## Usage
-Specify the updater parameter as 'grow_gpu'.
+Specify the 'updater' parameter as one of the following algorithms.
+
+updater | Description
+--- | ---
+grow_gpu | The standard XGBoost tree construction algorithm. Performs an exact search for splits. Slower and uses considerably more memory than 'grow_gpu_hist'.
+grow_gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Faster and uses considerably less memory. Splits may be less accurate.
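+
+For example, a minimal sketch selecting the histogram updater (the 'max_bin' value here is illustrative; it caps the number of histogram bins built by 'grow_gpu_hist'):
+
+```python
+param['updater'] = 'grow_gpu_hist'
+param['max_bin'] = 256  # illustrative value; fewer bins is faster but coarser
+```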
+
+All algorithms currently use only a single GPU. The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0.
+
 This plugin currently works with the CLI version and python version.
 
 Python example:
 ```python
+param['gpu_id'] = 1
 param['updater'] = 'grow_gpu'
 ```
+## Benchmarks
-
-## Memory usage
-Device memory usage can be calculated as approximately:
-```
-bytes = (10 x n_rows) + (40 x n_rows x n_columns x column_density) + (64 x max_nodes) + (76 x max_nodes_level x n_columns)
-```
-The maximum number of nodes needed for a given tree depth d is 2^(d+1) - 1. The maximum number of nodes on any given level is 2^d.
+[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for performance benchmarks of the 'grow_gpu' updater.
-
-Data is stored in a sparse format. For example, missing values produced by one hot encoding are not stored. If a one hot encoding separates a categorical variable into 5 columns the density of these columns is 1/5 = 0.2.
-
-A 4GB graphics card will process approximately 3.5 million rows of the well known Kaggle higgs dataset.
 
 ## Dependencies
 A CUDA capable GPU with at least compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler).
@@ -58,6 +54,15 @@ On Windows cmake will generate an xgboost.sln solution file in the build directory.
 
 The build process generates an xgboost library and executable as normal but containing the GPU tree construction algorithm.
 
+## Changelog
+##### 2017/4/25
+* Add fast histogram algorithm
+* Fix Linux build
+* Add 'gpu_id' parameter
+
+## References
+[Mitchell, Rory, and Eibe Frank. Accelerating the XGBoost algorithm using GPU computing. No. e2911v1. PeerJ Preprints, 2017.](https://peerj.com/preprints/2911/)
+
 ## Author
 Rory Mitchell
 
diff --git a/plugin/updater_gpu/benchmark/benchmark.py b/plugin/updater_gpu/benchmark/benchmark.py
new file mode 100644
index 000000000..0c61210f1
--- /dev/null
+++ b/plugin/updater_gpu/benchmark/benchmark.py
@@ -0,0 +1,32 @@
+#pylint: skip-file
+import xgboost as xgb
+import numpy as np
+from sklearn.datasets import make_classification
+from sklearn.model_selection import train_test_split
+import time
+
+n = 1000000
+num_rounds = 100
+
+X, y = make_classification(n, n_features=50, random_state=7)
+X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+dtrain = xgb.DMatrix(X_train, y_train)
+dtest = xgb.DMatrix(X_test, y_test)
+
+param = {'objective': 'binary:logistic',
+         'tree_method': 'exact',
+         'updater': 'grow_gpu_hist',
+         'max_depth': 8,
+         'silent': 1,
+         'eval_metric': 'auc'}
+res = {}
+# Time training with the GPU fast histogram updater
+tmp = time.time()
+xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
+          evals_result=res)
+print("GPU: %s seconds" % (time.time() - tmp))
+
+# Time the equivalent CPU fast histogram updater for comparison
+tmp = time.time()
+param['updater'] = 'grow_fast_histmaker'
+xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')], evals_result=res)
+print("CPU: %s seconds" % (time.time() - tmp))
+
diff --git a/plugin/updater_gpu/speed_test.py b/plugin/updater_gpu/speed_test.py
deleted file mode 100644
index fc76d98b2..000000000
--- a/plugin/updater_gpu/speed_test.py
+++ /dev/null
@@ -1,65 +0,0 @@
-#!/usr/bin/pytho#!/usr/bin/python
-#pylint: skip-file
-# this is the example script to use xgboost to train
-import numpy as np
-import xgboost as xgb
-import time
-
-# path to where the data lies
-dpath = '../../demo/data'
-
-# load in training data, directly use numpy
-dtrain = np.loadtxt(
dpath+'/training.csv', delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s') } ) -dtrain = np.concatenate((dtrain, np.copy(dtrain))) -dtrain = np.concatenate((dtrain, np.copy(dtrain))) -dtrain = np.concatenate((dtrain, np.copy(dtrain))) -test_size = len(dtrain) - -print(len(dtrain)) -print ('finish loading from csv ') - -label = dtrain[:,32] -data = dtrain[:,1:31] -# rescale weight to make it same as test set -weight = dtrain[:,31] * float(test_size) / len(label) - -sum_wpos = sum( weight[i] for i in range(len(label)) if label[i] == 1.0 ) -sum_wneg = sum( weight[i] for i in range(len(label)) if label[i] == 0.0 ) - -# print weight statistics -print ('weight statistics: wpos=%g, wneg=%g, ratio=%g' % ( sum_wpos, sum_wneg, sum_wneg/sum_wpos )) - -# construct xgboost.DMatrix from numpy array, treat -999.0 as missing value -xgmat = xgb.DMatrix( data, label=label, missing = -999.0, weight=weight ) - -# setup parameters for xgboost -param = {} -# use logistic regression loss -param['objective'] = 'binary:logitraw' -# scale weight of positive examples -param['scale_pos_weight'] = sum_wneg/sum_wpos -param['bst:eta'] = 0.1 -param['max_depth'] = 15 -param['eval_metric'] = 'auc' -param['nthread'] = 16 - -plst = param.items()+[('eval_metric', 'ams@0.15')] - -watchlist = [ (xgmat,'train') ] -num_round = 10 -print ("training xgboost") -threads = [16] -for i in threads: - param['nthread'] = i - tmp = time.time() - plst = param.items()+[('eval_metric', 'ams@0.15')] - bst = xgb.train( plst, xgmat, num_round, watchlist ); - print ("XGBoost with %d thread costs: %s seconds" % (i, str(time.time() - tmp))) - -print ("training xgboost - gpu tree construction") -param['updater'] = 'grow_gpu' -tmp = time.time() -plst = param.items()+[('eval_metric', 'ams@0.15')] -bst = xgb.train( plst, xgmat, num_round, watchlist ); -print ("XGBoost GPU: %s seconds" % (str(time.time() - tmp))) -print ('finish training') diff --git a/plugin/updater_gpu/src/common.cuh b/plugin/updater_gpu/src/common.cuh new file mode 100644 index 000000000..52e4c1854 --- /dev/null +++ b/plugin/updater_gpu/src/common.cuh @@ -0,0 +1,153 @@ +/*! 
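+ * \brief Helpers shared by the GPU tree builders: device loss/gain
+ *        evaluation, node counting, dense-to-sparse tree conversion,
+ *        and row/column sampling.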
+ * Copyright 2016 Rory mitchell + */ +#pragma once +#include +#include "../../../src/common/random.h" +#include "../../../src/tree/param.h" +#include "device_helpers.cuh" +#include "types.cuh" + +namespace xgboost { +namespace tree { +// When we split on a value which has no left neighbour, define its left +// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS +// This produces a split value slightly lower than the current instance +#define FVALUE_EPS 0.0001 + +__device__ inline float device_calc_loss_chg(const GPUTrainingParam& param, + const gpu_gpair& scan, + const gpu_gpair& missing, + const gpu_gpair& parent_sum, + const float& parent_gain, + bool missing_left) { + gpu_gpair left = scan; + + if (missing_left) { + left += missing; + } + + gpu_gpair right = parent_sum - left; + + float left_gain = CalcGain(param, left.grad(), left.hess()); + float right_gain = CalcGain(param, right.grad(), right.hess()); + return left_gain + right_gain - parent_gain; +} + +__device__ float inline loss_chg_missing(const gpu_gpair& scan, + const gpu_gpair& missing, + const gpu_gpair& parent_sum, + const float& parent_gain, + const GPUTrainingParam& param, + bool& missing_left_out) { // NOLINT + float missing_left_loss = + device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true); + float missing_right_loss = device_calc_loss_chg( + param, scan, missing, parent_sum, parent_gain, false); + + if (missing_left_loss >= missing_right_loss) { + missing_left_out = true; + return missing_left_loss; + } else { + missing_left_out = false; + return missing_right_loss; + } +} + +// Total number of nodes in tree, given depth +__host__ __device__ inline int n_nodes(int depth) { + return (1 << (depth + 1)) - 1; +} + +// Number of nodes at this level of the tree +__host__ __device__ inline int n_nodes_level(int depth) { return 1 << depth; } + +enum NodeType { + NODE = 0, + LEAF = 1, + UNUSED = 2, +}; + +// Recursively label node types +inline void flag_nodes(const thrust::host_vector& nodes, + std::vector* node_flags, int nid, + NodeType type) { + if (nid >= nodes.size() || type == UNUSED) { + return; + } + + const Node& n = nodes[nid]; + + // Current node and all children are valid + if (n.split.loss_chg > rt_eps) { + (*node_flags)[nid] = NODE; + flag_nodes(nodes, node_flags, nid * 2 + 1, NODE); + flag_nodes(nodes, node_flags, nid * 2 + 2, NODE); + } else { + // Current node is leaf, therefore is valid but all children are invalid + (*node_flags)[nid] = LEAF; + flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED); + flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED); + } +} + +// Copy gpu dense representation of tree to xgboost sparse representation +inline void dense2sparse_tree(RegTree* p_tree, + thrust::device_ptr nodes_begin, + thrust::device_ptr nodes_end, + const TrainParam& param) { + RegTree & tree = *p_tree; + thrust::host_vector h_nodes(nodes_begin, nodes_end); + std::vector node_flags(h_nodes.size(), UNUSED); + flag_nodes(h_nodes, &node_flags, 0, NODE); + + int nid = 0; + for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) { + NodeType flag = node_flags[gpu_nid]; + const Node& n = h_nodes[gpu_nid]; + if (flag == NODE) { + tree.AddChilds(nid); + tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left); + tree.stat(nid).loss_chg = n.split.loss_chg; + tree.stat(nid).base_weight = n.weight; + tree.stat(nid).sum_hess = n.sum_gradients.hess(); + tree[tree[nid].cleft()].set_leaf(0); + tree[tree[nid].cright()].set_leaf(0); + nid++; + } else if (flag == LEAF) { + 
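+      // A leaf keeps its optimal weight, scaled by the learning rate
+      // (shrinkage) when copied into the sparse RegTree; its children were
+      // flagged UNUSED above and are skipped.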
tree[nid].set_leaf(n.weight * param.learning_rate); + tree.stat(nid).sum_hess = n.sum_gradients.hess(); + nid++; + } + } +} + +// Set gradient pair to 0 with p = 1 - subsample +inline void subsample_gpair(dh::dvec* p_gpair, float subsample) { + if (subsample == 1.0) { + return; + } + + dh::dvec& gpair = *p_gpair; + + auto d_gpair = gpair.data(); + dh::BernoulliRng rng(subsample, common::GlobalRandom()()); + + dh::launch_n(gpair.size(), [=] __device__(int i) { + if (!rng(i)) { + d_gpair[i] = gpu_gpair(); + } + }); +} + +inline std::vector col_sample(std::vector features, float colsample) { + int n = colsample * features.size(); + CHECK_GT(n, 0); + + std::shuffle(features.begin(), features.end(), common::GlobalRandom()); + features.resize(n); + + return features; +} +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/device_helpers.cuh b/plugin/updater_gpu/src/device_helpers.cuh index ee8e4a2fc..6191c4c39 100644 --- a/plugin/updater_gpu/src/device_helpers.cuh +++ b/plugin/updater_gpu/src/device_helpers.cuh @@ -1,18 +1,19 @@ /*! * Copyright 2016 Rory mitchell -*/ + */ #pragma once #include #include #include +#include #include #include -#include #include #include #include #include #include +#include "cusparse_v2.h" #ifdef _WIN32 #include @@ -30,7 +31,8 @@ namespace dh { #define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__) -cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) { +inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, + int line) { if (code != cudaSuccess) { std::stringstream ss; ss << file << "(" << line << ")"; @@ -41,16 +43,29 @@ cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) { return code; } +#define safe_cusparse(ans) throw_on_cusparse_error((ans), __FILE__, __LINE__) -#define gpuErrchk(ans) \ +inline cusparseStatus_t throw_on_cusparse_error(cusparseStatus_t status, + const char *file, int line) { + if (status != CUSPARSE_STATUS_SUCCESS) { + std::stringstream ss; + ss << "cusparse error: " << file << "(" << line << ")"; + std::string error_text; + ss >> error_text; + throw error_text; + } + + return status; +} + +#define gpuErrchk(ans) \ { gpuAssert((ans), __FILE__, __LINE__); } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { if (code != cudaSuccess) { fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); - if (abort) - exit(code); + if (abort) exit(code); } } @@ -119,10 +134,10 @@ struct DeviceTimer { #endif #ifdef DEVICE_TIMER - __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) // NOLINT + __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) // NOLINT : GTimer(GTimer), start(clock()), slot(slot) {} #else - __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) {} // NOLINT + __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) {} // NOLINT #endif __device__ void End() { @@ -224,21 +239,22 @@ class range { iterator end_; }; -template __device__ range grid_stride_range(T begin, T end) { +template +__device__ range grid_stride_range(T begin, T end) { begin += blockDim.x * blockIdx.x + threadIdx.x; range r(begin, end); r.step(gridDim.x * blockDim.x); return r; } -template __device__ range block_stride_range(T begin, T end) { +template +__device__ range block_stride_range(T begin, T end) { begin += threadIdx.x; range r(begin, end); r.step(blockDim.x); return r; } - // Threadblock iterates over range, filling with value template __device__ void block_fill(IterT begin, 
size_t n, ValueT value) { @@ -253,7 +269,8 @@ __device__ void block_fill(IterT begin, size_t n, ValueT value) { class bulk_allocator; -template class dvec { +template +class dvec { friend bulk_allocator; private: @@ -302,7 +319,8 @@ template class dvec { return thrust::device_pointer_cast(_ptr + size()); } - template dvec &operator=(const std::vector &other) { + template + dvec &operator=(const std::vector &other) { if (other.size() != size()) { throw std::runtime_error( "Cannot copy assign vector to dvec, sizes are different"); @@ -331,7 +349,8 @@ class bulk_allocator { const size_t align = 256; - template size_t align_round_up(SizeT n) { + template + size_t align_round_up(SizeT n) { if (n % align == 0) { return n; } else { @@ -357,7 +376,7 @@ class bulk_allocator { template void allocate_dvec(char *ptr, dvec *first_vec, SizeT first_size, Args... args) { - first_vec->external_allocate(static_cast(ptr), first_size); + first_vec->external_allocate(static_cast(ptr), first_size); ptr += align_round_up(first_size * sizeof(T)); allocate_dvec(ptr, args...); } @@ -366,14 +385,15 @@ class bulk_allocator { bulk_allocator() : _size(0), d_ptr(NULL) {} ~bulk_allocator() { - if (!d_ptr == NULL) { + if (!(d_ptr == nullptr)) { safe_cuda(cudaFree(d_ptr)); } } size_t size() { return _size; } - template void allocate(Args... args) { + template + void allocate(Args... args) { if (d_ptr != NULL) { throw std::runtime_error("Bulk allocator already allocated"); } @@ -393,14 +413,19 @@ struct CubMemory { CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {} - ~CubMemory() { + ~CubMemory() { Free(); } + void Free() { if (d_temp_storage != NULL) { safe_cuda(cudaFree(d_temp_storage)); } } - void Allocate() { - safe_cuda(cudaMalloc(&d_temp_storage, temp_storage_bytes)); + void LazyAllocate(size_t n_bytes) { + if (n_bytes > temp_storage_bytes) { + Free(); + safe_cuda(cudaMalloc(&d_temp_storage, n_bytes)); + temp_storage_bytes = n_bytes; + } } bool IsAllocated() { return d_temp_storage != NULL; } @@ -453,47 +478,58 @@ void print(char *label, const thrust::device_vector &v, std::cout << "\n"; } -template T1 div_round_up(const T1 a, const T2 b) { +template +T1 div_round_up(const T1 a, const T2 b) { return static_cast(ceil(static_cast(a) / b)); } -template thrust::device_ptr dptr(T *d_ptr) { +template +thrust::device_ptr dptr(T *d_ptr) { return thrust::device_pointer_cast(d_ptr); } -template T *raw(thrust::device_vector &v) { // NOLINT +template +T *raw(thrust::device_vector &v) { // NOLINT return raw_pointer_cast(v.data()); } -template size_t size_bytes(const thrust::device_vector &v) { +template +const T *raw(const thrust::device_vector &v) { // NOLINT + return raw_pointer_cast(v.data()); +} + +template +size_t size_bytes(const thrust::device_vector &v) { return sizeof(T) * v.size(); } /* * Kernel launcher */ -template __global__ void launch_n_kernel(size_t n, L lambda) { +template +__global__ void launch_n_kernel(size_t n, L lambda) { for (auto i : grid_stride_range(static_cast(0), n)) { lambda(i); } } -template +template inline void launch_n(size_t n, L lambda) { const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS); - +#if defined(__CUDACC__) launch_n_kernel<<>>(n, lambda); +#endif } /* - * Random + * Random */ struct BernoulliRng { float p; int seed; - __host__ __device__ BernoulliRng(float p, int seed):p(p), seed(seed) {} + __host__ __device__ BernoulliRng(float p, int seed) : p(p), seed(seed) {} __host__ __device__ bool operator()(const int i) const { thrust::default_random_engine 
rng(seed); @@ -504,5 +540,4 @@ struct BernoulliRng { } }; - } // namespace dh diff --git a/plugin/updater_gpu/src/find_split.cuh b/plugin/updater_gpu/src/find_split.cuh index 070ea2efd..2451387cf 100644 --- a/plugin/updater_gpu/src/find_split.cuh +++ b/plugin/updater_gpu/src/find_split.cuh @@ -9,7 +9,7 @@ #include "find_split_multiscan.cuh" #include "find_split_sorting.cuh" #include "gpu_data.cuh" -#include "types_functions.cuh" +#include "types.cuh" namespace xgboost { namespace tree { diff --git a/plugin/updater_gpu/src/find_split_multiscan.cuh b/plugin/updater_gpu/src/find_split_multiscan.cuh index 41af54c7e..7abda0195 100644 --- a/plugin/updater_gpu/src/find_split_multiscan.cuh +++ b/plugin/updater_gpu/src/find_split_multiscan.cuh @@ -6,7 +6,8 @@ #include #include "device_helpers.cuh" #include "gpu_data.cuh" -#include "types_functions.cuh" +#include "types.cuh" +#include "common.cuh" namespace xgboost { namespace tree { diff --git a/plugin/updater_gpu/src/find_split_sorting.cuh b/plugin/updater_gpu/src/find_split_sorting.cuh index dae597ce1..02fe1aa24 100644 --- a/plugin/updater_gpu/src/find_split_sorting.cuh +++ b/plugin/updater_gpu/src/find_split_sorting.cuh @@ -5,7 +5,8 @@ #include #include #include "device_helpers.cuh" -#include "types_functions.cuh" +#include "types.cuh" +#include "common.cuh" namespace xgboost { namespace tree { diff --git a/plugin/updater_gpu/src/functions.cuh b/plugin/updater_gpu/src/functions.cuh new file mode 100644 index 000000000..c9f2ff863 --- /dev/null +++ b/plugin/updater_gpu/src/functions.cuh @@ -0,0 +1,15 @@ +/*! + * Copyright 2016 Rory mitchell +*/ +#pragma once +#include "types.cuh" +#include "../../../src/tree/param.h" +#include "../../../src/common/random.h" + + +namespace xgboost { +namespace tree { + + +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/gpu_builder.cu b/plugin/updater_gpu/src/gpu_builder.cu index a8c242443..f16dc9fb9 100644 --- a/plugin/updater_gpu/src/gpu_builder.cu +++ b/plugin/updater_gpu/src/gpu_builder.cu @@ -1,6 +1,7 @@ /*! 
* Copyright 2016 Rory mitchell -*/ + */ +#include "gpu_builder.cuh" #include #include #include @@ -11,31 +12,35 @@ #include #include #include -#include #include +#include #include #include "../../../src/common/random.h" +#include "common.cuh" #include "device_helpers.cuh" #include "find_split.cuh" -#include "gpu_builder.cuh" -#include "types_functions.cuh" #include "gpu_data.cuh" +#include "types.cuh" namespace xgboost { namespace tree { - GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); } -void GPUBuilder::Init(const TrainParam ¶m_in) { +void GPUBuilder::Init(const TrainParam& param_in) { param = param_in; CHECK(param.max_depth < 16) << "Tree depth too large."; + + dh::safe_cuda(cudaSetDevice(param.gpu_id)); + if (!param.silent) { + LOG(CONSOLE) << "Device: [" << param.gpu_id << "] " << dh::device_name(); + } } GPUBuilder::~GPUBuilder() { delete gpu_data; } void GPUBuilder::UpdateNodeId(int level) { - auto *d_node_id_instance = gpu_data->node_id_instance.data(); - Node *d_nodes = gpu_data->nodes.data(); + auto* d_node_id_instance = gpu_data->node_id_instance.data(); + Node* d_nodes = gpu_data->nodes.data(); dh::launch_n(gpu_data->node_id_instance.size(), [=] __device__(int i) { NodeIdT item_node_id = d_node_id_instance[i]; @@ -57,10 +62,10 @@ void GPUBuilder::UpdateNodeId(int level) { dh::safe_cuda(cudaDeviceSynchronize()); - auto *d_fvalues = gpu_data->fvalues.data(); - auto *d_instance_id = gpu_data->instance_id.data(); - auto *d_node_id = gpu_data->node_id.data(); - auto *d_feature_id = gpu_data->feature_id.data(); + auto* d_fvalues = gpu_data->fvalues.data(); + auto* d_instance_id = gpu_data->instance_id.data(); + auto* d_node_id = gpu_data->node_id.data(); + auto* d_feature_id = gpu_data->feature_id.data(); // Update node based on fvalue where exists dh::launch_n(gpu_data->fvalues.size(), [=] __device__(int i) { @@ -128,75 +133,52 @@ void GPUBuilder::Sort(int level) { } void GPUBuilder::ColsampleTree() { - unsigned n = static_cast( - param.colsample_bytree * gpu_data->n_features); + unsigned n = + static_cast(param.colsample_bytree * gpu_data->n_features); CHECK_GT(n, 0); feature_set_tree.resize(gpu_data->n_features); std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0); std::shuffle(feature_set_tree.begin(), feature_set_tree.end(), - common::GlobalRandom()); + common::GlobalRandom()); } -void GPUBuilder::Update(const std::vector &gpair, DMatrix *p_fmat, - RegTree *p_tree) { - try { - dh::Timer update; - dh::Timer t; - this->InitData(gpair, *p_fmat, *p_tree); - t.printElapsed("init data"); - this->InitFirstNode(); - this->ColsampleTree(); +void GPUBuilder::Update(const std::vector& gpair, DMatrix* p_fmat, + RegTree* p_tree) { + this->InitData(gpair, *p_fmat, *p_tree); + this->InitFirstNode(); + this->ColsampleTree(); - for (int level = 0; level < param.max_depth; level++) { - bool use_multiscan_algorithm = level < multiscan_levels; + for (int level = 0; level < param.max_depth; level++) { + bool use_multiscan_algorithm = level < multiscan_levels; - t.reset(); - if (level > 0) { - dh::Timer update_node; - this->UpdateNodeId(level); - update_node.printElapsed("node"); - } - - if (level > 0 && !use_multiscan_algorithm) { - dh::Timer s; - this->Sort(level); - s.printElapsed("sort"); - } - - dh::Timer split; - find_split(gpu_data, param, level, use_multiscan_algorithm, - feature_set_tree, &feature_set_level); - - split.printElapsed("split"); - - t.printElapsed("level"); + if (level > 0) { + this->UpdateNodeId(level); } - this->CopyTree(*p_tree); - 
update.printElapsed("update"); - } catch (thrust::system_error &e) { - std::cerr << "CUDA error: " << e.what() << std::endl; - exit(-1); - } catch (const std::exception &e) { - std::cerr << "Error: " << e.what() << std::endl; - exit(-1); - } catch (...) { - std::cerr << "Unknown exception." << std::endl; - exit(-1); + + if (level > 0 && !use_multiscan_algorithm) { + this->Sort(level); + } + + find_split(gpu_data, param, level, use_multiscan_algorithm, + feature_set_tree, &feature_set_level); } + + dense2sparse_tree(p_tree, gpu_data->nodes.tbegin(), gpu_data->nodes.tend(), + param); } -void GPUBuilder::InitData(const std::vector &gpair, DMatrix &fmat, - const RegTree &tree) { - CHECK(fmat.SingleColBlock()) << "GPUMaker: must have single column block"; +void GPUBuilder::InitData(const std::vector& gpair, DMatrix& fmat, + const RegTree& tree) { + CHECK(fmat.SingleColBlock()) << "grow_gpu: must have single column block. " + "Try setting 'tree_method' parameter to " + "'exact'"; if (gpu_data->IsAllocated()) { gpu_data->Reset(gpair, param.subsample); return; } - dh::Timer t; - MetaInfo info = fmat.info(); std::vector foffsets; @@ -208,31 +190,27 @@ void GPUBuilder::InitData(const std::vector &gpair, DMatrix &fmat, instance_id.reserve(info.num_col * info.num_row); feature_id.reserve(info.num_col * info.num_row); - dmlc::DataIter *iter = fmat.ColIterator(); + dmlc::DataIter* iter = fmat.ColIterator(); while (iter->Next()) { - const ColBatch &batch = iter->Value(); + const ColBatch& batch = iter->Value(); for (int i = 0; i < batch.size; i++) { - const ColBatch::Inst &col = batch[i]; + const ColBatch::Inst& col = batch[i]; - for (const ColBatch::Entry *it = col.data; it != col.data + col.length; + for (const ColBatch::Entry* it = col.data; it != col.data + col.length; it++) { bst_uint inst_id = it->index; - fvalues.push_back(it->fvalue); - instance_id.push_back(inst_id); - feature_id.push_back(i); + fvalues.push_back(it->fvalue); + instance_id.push_back(inst_id); + feature_id.push_back(i); } foffsets.push_back(fvalues.size()); } } - t.printElapsed("dmatrix"); - t.reset(); gpu_data->Init(fvalues, foffsets, instance_id, feature_id, gpair, info.num_row, info.num_col, param.max_depth, param); - - t.printElapsed("gpu init"); } void GPUBuilder::InitFirstNode() { @@ -248,60 +226,5 @@ void GPUBuilder::InitFirstNode() { thrust::copy_n(&tmp, 1, gpu_data->nodes.tbegin()); } - -enum NodeType { - NODE = 0, - LEAF = 1, - UNUSED = 2, -}; - -// Recursively label node types -void flag_nodes(const thrust::host_vector &nodes, - std::vector *node_flags, int nid, NodeType type) { - if (nid >= nodes.size() || type == UNUSED) { - return; - } - - const Node &n = nodes[nid]; - - // Current node and all children are valid - if (n.split.loss_chg > rt_eps) { - (*node_flags)[nid] = NODE; - flag_nodes(nodes, node_flags, nid * 2 + 1, NODE); - flag_nodes(nodes, node_flags, nid * 2 + 2, NODE); - } else { - // Current node is leaf, therefore is valid but all children are invalid - (*node_flags)[nid] = LEAF; - flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED); - flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED); - } -} - -// Copy gpu dense representation of tree to xgboost sparse representation -void GPUBuilder::CopyTree(RegTree &tree) { - std::vector h_nodes = gpu_data->nodes.as_vector(); - std::vector node_flags(h_nodes.size(), UNUSED); - flag_nodes(h_nodes, &node_flags, 0, NODE); - - int nid = 0; - for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) { - NodeType flag = node_flags[gpu_nid]; - const Node &n = 
h_nodes[gpu_nid]; - if (flag == NODE) { - tree.AddChilds(nid); - tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left); - tree.stat(nid).loss_chg = n.split.loss_chg; - tree.stat(nid).base_weight = n.weight; - tree.stat(nid).sum_hess = n.sum_gradients.hess(); - tree[tree[nid].cleft()].set_leaf(0); - tree[tree[nid].cright()].set_leaf(0); - nid++; - } else if (flag == LEAF) { - tree[nid].set_leaf(n.weight * param.learning_rate); - tree.stat(nid).sum_hess = n.sum_gradients.hess(); - nid++; - } - } -} } // namespace tree } // namespace xgboost diff --git a/plugin/updater_gpu/src/gpu_builder.cuh b/plugin/updater_gpu/src/gpu_builder.cuh index bfdfa6d38..452f5c148 100644 --- a/plugin/updater_gpu/src/gpu_builder.cuh +++ b/plugin/updater_gpu/src/gpu_builder.cuh @@ -1,6 +1,6 @@ /*! * Copyright 2016 Rory mitchell -*/ + */ #pragma once #include #include @@ -19,10 +19,7 @@ class GPUBuilder { void Init(const TrainParam ¶m); ~GPUBuilder(); - void UpdateParam(const TrainParam ¶m) - { - this->param = param; - } + void UpdateParam(const TrainParam ¶m) { this->param = param; } void Update(const std::vector &gpair, DMatrix *p_fmat, RegTree *p_tree); @@ -30,13 +27,11 @@ class GPUBuilder { void UpdateNodeId(int level); private: - void InitData(const std::vector &gpair, DMatrix &fmat, // NOLINT + void InitData(const std::vector &gpair, DMatrix &fmat, // NOLINT const RegTree &tree); - float GetSubsamplingRate(MetaInfo info); void Sort(int level); void InitFirstNode(); - void CopyTree(RegTree &tree); // NOLINT void ColsampleTree(); TrainParam param; diff --git a/plugin/updater_gpu/src/gpu_data.cuh b/plugin/updater_gpu/src/gpu_data.cuh index f3bc675d7..2cf74612d 100644 --- a/plugin/updater_gpu/src/gpu_data.cuh +++ b/plugin/updater_gpu/src/gpu_data.cuh @@ -1,14 +1,15 @@ /*! 
* Copyright 2016 Rory mitchell -*/ + */ #pragma once -#include -#include #include +#include +#include #include -#include "device_helpers.cuh" #include "../../src/tree/param.h" -#include "types_functions.cuh" +#include "common.cuh" +#include "device_helpers.cuh" +#include "types.cuh" namespace xgboost { namespace tree { @@ -67,9 +68,8 @@ struct GPUData { cub::DoubleBuffer db_value; cub::DeviceSegmentedRadixSort::SortPairs( - cub_mem.data(), cub_mem_size, db_key, - db_value, in_fvalues.size(), n_features, - foffsets.data(), foffsets.data() + 1); + cub_mem.data(), cub_mem_size, db_key, db_value, in_fvalues.size(), + n_features, foffsets.data(), foffsets.data() + 1); // Allocate memory size_t free_memory = dh::available_memory(); @@ -100,7 +100,6 @@ struct GPUData { param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda, param_in.reg_alpha, param_in.max_delta_step); - allocated = true; this->Reset(in_gpair, param_in.subsample); @@ -109,33 +108,16 @@ struct GPUData { thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()), fvalues.tbegin(), node_id.tbegin())); - dh::safe_cuda(cudaGetLastError()); } ~GPUData() {} - // Set gradient pair to 0 with p = 1 - subsample - void MarkSubsample(float subsample) { - if (subsample == 1.0) { - return; - } - - auto d_gpair = gpair.data(); - dh::BernoulliRng rng(subsample, common::GlobalRandom()()); - - dh::launch_n(n_instances, [=] __device__(int i) { - if (!rng(i)) { - d_gpair[i] = gpu_gpair(); - } - }); - } - // Reset memory for new boosting iteration void Reset(const std::vector &in_gpair, float subsample) { CHECK(allocated); gpair = in_gpair; - this->MarkSubsample(subsample); + subsample_gpair(&gpair, subsample); instance_id = instance_id_cached; fvalues = fvalues_cached; nodes.fill(Node()); @@ -153,7 +135,6 @@ struct GPUData { auto d_instance_id = instance_id.data(); dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) { - // Item item = d_items[i]; d_node_id[i] = d_node_id_instance[d_instance_id[i]]; }); } diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cu b/plugin/updater_gpu/src/gpu_hist_builder.cu new file mode 100644 index 000000000..6b6eaf1ed --- /dev/null +++ b/plugin/updater_gpu/src/gpu_hist_builder.cu @@ -0,0 +1,560 @@ +/*! 
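+ * \brief GPU implementation of the fast histogram tree construction
+ *        algorithm, registered as the 'grow_gpu_hist' updater.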
+ * Copyright 2017 Rory mitchell +*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include "common.cuh" +#include "device_helpers.cuh" +#include "gpu_hist_builder.cuh" + +namespace xgboost { +namespace tree { +void DeviceGMat::Init(const common::GHistIndexMatrix& gmat) { + CHECK_EQ(gidx.size(), gmat.index.size()) + << "gidx must be externally allocated"; + CHECK_EQ(ridx.size(), gmat.index.size()) + << "ridx must be externally allocated"; + + gidx = gmat.index; + thrust::device_vector row_ptr = gmat.row_ptr; + + auto counting = thrust::make_counting_iterator(0); + thrust::upper_bound(row_ptr.begin(), row_ptr.end(), counting, + counting + gidx.size(), ridx.tbegin()); + thrust::transform(ridx.tbegin(), ridx.tend(), ridx.tbegin(), + [=] __device__(int val) { return val - 1; }); +} + +void DeviceHist::Init(int n_bins_in) { + this->n_bins = n_bins_in; + CHECK(!hist.empty()) << "DeviceHist must be externally allocated"; +} + +void DeviceHist::Reset() { hist.fill(gpu_gpair()); } + +gpu_gpair* DeviceHist::GetLevelPtr(int depth) { + return hist.data() + n_nodes(depth - 1) * n_bins; +} + +int DeviceHist::LevelSize(int depth) { return n_bins * n_nodes_level(depth); } + +HistBuilder DeviceHist::GetBuilder() { + return HistBuilder(hist.data(), n_bins); +} + +HistBuilder::HistBuilder(gpu_gpair* ptr, int n_bins) + : d_hist(ptr), n_bins(n_bins) {} + +__device__ void HistBuilder::Add(gpu_gpair gpair, int gidx, int nidx) const { + int hist_idx = nidx * n_bins + gidx; + atomicAdd(&(d_hist[hist_idx]._grad), gpair._grad); + atomicAdd(&(d_hist[hist_idx]._hess), gpair._hess); +} + +__device__ gpu_gpair HistBuilder::Get(int gidx, int nidx) const { + return d_hist[nidx * n_bins + gidx]; +} + +GPUHistBuilder::GPUHistBuilder() {} + +GPUHistBuilder::~GPUHistBuilder() {} + +void GPUHistBuilder::Init(const TrainParam& param) { + CHECK(param.max_depth < 16) << "Tree depth too large."; + this->param = param; + initialised = false; + is_dense = false; + + dh::safe_cuda(cudaSetDevice(param.gpu_id)); + if (!param.silent) { + LOG(CONSOLE) << "Device: [" << param.gpu_id << "] " << dh::device_name(); + } +} + +template +struct ReduceBySegmentOp { + ReductionOpT op; + + __host__ __device__ __forceinline__ ReduceBySegmentOp() {} + + __host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op) + : op(op) {} + + template + __host__ __device__ __forceinline__ KeyValuePairT + operator()(const KeyValuePairT& first, const KeyValuePairT& second) { + KeyValuePairT retval; + retval.key = second.key; + retval.value = + first.key != second.key ? 
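+        // Segmented scan operator: when the key (histogram bin index)
+        // changes between adjacent items a new segment starts, so take the
+        // new value as-is; otherwise accumulate into the running sum.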
second.value : op(first.value, second.value); + return retval; + } +}; + +template +__global__ void hist_kernel(gpu_gpair* d_dense_hist, int* d_ridx, int* d_gidx, + NodeIdT* d_position, gpu_gpair* d_gpair, int n_bins, + int depth, int n) { + typedef cub::KeyValuePair OffsetValuePairT; + typedef cub::BlockLoad + BlockLoadT; + typedef cub::BlockRadixSort + BlockRadixSortT; + typedef cub::BlockDiscontinuity BlockDiscontinuityKeysT; + typedef ReduceBySegmentOp ReduceBySegmentOpT; + typedef cub::BlockScan + BlockScanT; + + union TempStorage { + typename BlockLoadT::TempStorage load; + typename BlockRadixSortT::TempStorage sort; + typename BlockScanT::TempStorage scan; + typename BlockDiscontinuityKeysT::TempStorage disc; + }; + + __shared__ TempStorage temp_storage; + + int ridx[ITEMS_PER_THREAD]; + int gidx[ITEMS_PER_THREAD]; + + const int TILE_SIZE = ITEMS_PER_THREAD * BLOCK_THREADS; + int block_offset = blockIdx.x * TILE_SIZE; + + BlockLoadT(temp_storage.load) + .Load(d_ridx + block_offset, ridx, n - block_offset, -1); + BlockLoadT(temp_storage.load) + .Load(d_gidx + block_offset, gidx, n - block_offset, -1); + + int hist_idx[ITEMS_PER_THREAD]; + + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) { + if (ridx[ITEM] > -1 && d_position[ridx[ITEM]] > -1) { + hist_idx[ITEM] = + (d_position[ridx[ITEM]] - n_nodes(depth - 1)) * n_bins + gidx[ITEM]; + } else { + hist_idx[ITEM] = -1; + } + } + + __syncthreads(); + BlockRadixSortT(temp_storage.sort).Sort(hist_idx, ridx); + + OffsetValuePairT kv[ITEMS_PER_THREAD]; + + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) { + kv[ITEM].key = hist_idx[ITEM]; + if (ridx[ITEM] > -1) { + kv[ITEM].value = d_gpair[ridx[ITEM]]; + } + } + + __syncthreads(); + // Scan + BlockScanT(temp_storage.scan).InclusiveScan(kv, kv, ReduceBySegmentOpT()); + + __syncthreads(); + int flags[ITEMS_PER_THREAD]; + BlockDiscontinuityKeysT(temp_storage.disc) + .FlagTails(flags, hist_idx, cub::Inequality()); + + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) { + if (flags[ITEM]) { + if (ridx[ITEM] > -1 && d_position[ridx[ITEM]] > -1) { + atomicAdd(&(d_dense_hist[hist_idx[ITEM]]._grad), kv[ITEM].value._grad); + atomicAdd(&(d_dense_hist[hist_idx[ITEM]]._hess), kv[ITEM].value._hess); + } + } + } +} + +void GPUHistBuilder::BuildHist(int depth) { + auto d_ridx = device_matrix.ridx.data(); + auto d_gidx = device_matrix.gidx.data(); + auto d_position = position.data(); + auto d_gpair = device_gpair.data(); + auto hist_builder = hist.GetBuilder(); + + dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) { + int ridx = d_ridx[idx]; + int pos = d_position[ridx]; + + // Only increment even nodes + if (pos < 0 || pos % 2 == 1) return; + + int gidx = d_gidx[idx]; + gpu_gpair gpair = d_gpair[ridx]; + + hist_builder.Add(gpair, gidx, pos); + }); + + dh::safe_cuda(cudaDeviceSynchronize()); + + // Subtraction trick + int n_sub_bins = (n_nodes_level(depth) / 2) * hist_builder.n_bins; + if (n_sub_bins > 0) { + dh::launch_n(n_sub_bins, [=] __device__(int idx) { + int nidx = n_nodes(depth - 1) + ((idx / hist_builder.n_bins) * 2); + int gidx = idx % hist_builder.n_bins; + gpu_gpair parent = hist_builder.Get(gidx, nidx / 2); + gpu_gpair right = hist_builder.Get(gidx, nidx + 1); + hist_builder.Add(parent - right, gidx, nidx); + }); + } + dh::safe_cuda(cudaDeviceSynchronize()); +} + +void GPUHistBuilder::FindSplit(int depth) { + auto counting = thrust::make_counting_iterator(0); + auto d_gidx_feature_map = gidx_feature_map.data(); + int n_bins = hmat_.row_ptr.back(); + int n_features = 
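+      // hmat_.row_ptr holds one offset per feature plus a final terminator,
+      // so row_ptr.back() is the total bin count over all features and
+      // row_ptr.size() - 1 is the feature count: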
hmat_.row_ptr.size() - 1; + + auto feature_boundary = [=] __device__(int idx_a, int idx_b) { + int gidx_a = idx_a % n_bins; + int gidx_b = idx_b % n_bins; + return d_gidx_feature_map[gidx_a] == d_gidx_feature_map[gidx_b]; + }; // NOLINT + + // Reduce node sums + { + size_t temp_storage_bytes; + cub::DeviceSegmentedReduce::Reduce( + nullptr, temp_storage_bytes, hist.GetLevelPtr(depth), node_sums.data(), + n_nodes_level(depth) * n_features, feature_segments.data(), + feature_segments.data() + 1, cub::Sum(), gpu_gpair()); + cub_mem.LazyAllocate(temp_storage_bytes); + cub::DeviceSegmentedReduce::Reduce( + cub_mem.d_temp_storage, cub_mem.temp_storage_bytes, + hist.GetLevelPtr(depth), node_sums.data(), + n_nodes_level(depth) * n_features, feature_segments.data(), + feature_segments.data() + 1, cub::Sum(), gpu_gpair()); + } + + // Scan + thrust::exclusive_scan_by_key( + counting, counting + hist.LevelSize(depth), + thrust::device_pointer_cast(hist.GetLevelPtr(depth)), hist_scan.tbegin(), + gpu_gpair(), feature_boundary); + + // Calculate gain + auto d_gain = gain.data(); + auto d_nodes = nodes.data(); + auto d_node_sums = node_sums.data(); + auto d_hist_scan = hist_scan.data(); + GPUTrainingParam gpu_param_alias = + gpu_param; // Must be local variable to be used in device lambda + bool colsample = + param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0; + auto d_feature_flags = feature_flags.data(); + + dh::launch_n(hist.LevelSize(depth), [=] __device__(int idx) { + int node_segment = idx / n_bins; + int node_idx = n_nodes(depth - 1) + node_segment; + gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients; + float parent_gain = d_nodes[node_idx].root_gain; + int gidx = idx % n_bins; + int findex = d_gidx_feature_map[gidx]; + + // colsample + if (colsample && d_feature_flags[d_gidx_feature_map[gidx]] == 0) { + d_gain[idx] = 0; + } else { + gpu_gpair scan = d_hist_scan[idx]; + gpu_gpair sum = d_node_sums[node_segment * n_features + findex]; + gpu_gpair missing = parent_sum - sum; + + bool missing_left; + d_gain[idx] = loss_chg_missing(scan, missing, parent_sum, parent_gain, + gpu_param_alias, missing_left); + } + }); + dh::safe_cuda(cudaDeviceSynchronize()); + + // Find best gain + { + size_t temp_storage_bytes; + cub::DeviceSegmentedReduce::ArgMax(nullptr, temp_storage_bytes, gain.data(), + argmax.data(), n_nodes_level(depth), + hist_node_segments.data(), + hist_node_segments.data() + 1); + cub_mem.LazyAllocate(temp_storage_bytes); + cub::DeviceSegmentedReduce::ArgMax( + cub_mem.d_temp_storage, cub_mem.temp_storage_bytes, gain.data(), + argmax.data(), n_nodes_level(depth), hist_node_segments.data(), + hist_node_segments.data() + 1); + } + + auto d_argmax = argmax.data(); + auto d_gidx_fvalue_map = gidx_fvalue_map.data(); + auto d_fidx_min_map = fidx_min_map.data(); + dh::launch_n(n_nodes_level(depth), [=] __device__(int idx) { + int max_idx = n_bins * idx + d_argmax[idx].key; + int gidx = max_idx % n_bins; + int fidx = d_gidx_feature_map[gidx]; + int node_segment = max_idx / n_bins; + int node_idx = n_nodes(depth - 1) + node_segment; + gpu_gpair scan = d_hist_scan[max_idx]; + gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients; + float parent_gain = d_nodes[node_idx].root_gain; + gpu_gpair sum = d_node_sums[node_segment * n_features + fidx]; + gpu_gpair missing = parent_sum - sum; + + bool missing_left; + float loss_chg = loss_chg_missing(scan, missing, parent_sum, parent_gain, + gpu_param_alias, missing_left); + + float fvalue; + if (gidx == 0 || fidx != d_gidx_feature_map[gidx - 1]) 
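+    // Pick the split value for the best bin: the first bin of a feature has
+    // no left neighbour, so fall back to the feature's minimum value;
+    // otherwise use the right edge of the previous bin.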
{ + fvalue = d_fidx_min_map[fidx]; + } else { + fvalue = d_gidx_fvalue_map[gidx - 1]; + } + + gpu_gpair left = missing_left ? scan + missing : scan; + gpu_gpair right = parent_sum - left; + d_nodes[node_idx].split.Update(loss_chg, missing_left, fvalue, fidx, left, + right, gpu_param_alias); + + int left_child_idx = n_nodes(depth) + idx * 2; + int right_child_idx = n_nodes(depth) + idx * 2 + 1; + d_nodes[left_child_idx] = + Node(left, CalcGain(gpu_param_alias, left.grad(), left.hess()), + CalcWeight(gpu_param_alias, left.grad(), left.hess())); + + d_nodes[right_child_idx] = + Node(right, CalcGain(gpu_param_alias, right.grad(), right.hess()), + CalcWeight(gpu_param_alias, right.grad(), right.hess())); + }); + dh::safe_cuda(cudaDeviceSynchronize()); +} + +void GPUHistBuilder::InitFirstNode() { + // Build the root node on the CPU and copy to device + gpu_gpair sum_gradients = + thrust::reduce(device_gpair.tbegin(), device_gpair.tend(), + gpu_gpair(0, 0), thrust::plus()); + + Node tmp = + Node(sum_gradients, + CalcGain(param, sum_gradients.grad(), sum_gradients.hess()), + CalcWeight(param, sum_gradients.grad(), sum_gradients.hess())); + + thrust::copy_n(&tmp, 1, nodes.tbegin()); +} + +void GPUHistBuilder::UpdatePosition() { + if (is_dense) { + this->UpdatePositionDense(); + } else { + this->UpdatePositionSparse(); + } +} + +void GPUHistBuilder::UpdatePositionDense() { + auto d_position = position.data(); + Node* d_nodes = nodes.data(); + auto d_gidx_fvalue_map = gidx_fvalue_map.data(); + auto d_gidx = device_matrix.gidx.data(); + int n_columns = info->num_col; + + dh::launch_n(position.size(), [=] __device__(int idx) { + NodeIdT pos = d_position[idx]; + if (pos < 0) { + return; + } + + Node node = d_nodes[pos]; + + if (node.IsLeaf()) { + d_position[idx] = -1; + return; + } + + int gidx = d_gidx[idx * n_columns + node.split.findex]; + + float fvalue = d_gidx_fvalue_map[gidx]; + + if (fvalue <= node.split.fvalue) { + d_position[idx] = pos * 2 + 1; + } else { + d_position[idx] = pos * 2 + 2; + } + }); +} + +void GPUHistBuilder::UpdatePositionSparse() { + auto d_position = position.data(); + auto d_position_tmp = position_tmp.data(); + Node* d_nodes = nodes.data(); + auto d_gidx_feature_map = gidx_feature_map.data(); + auto d_gidx_fvalue_map = gidx_fvalue_map.data(); + auto d_gidx = device_matrix.gidx.data(); + auto d_ridx = device_matrix.ridx.data(); + + // Update missing direction + dh::launch_n(position.size(), [=] __device__(int idx) { + NodeIdT pos = d_position[idx]; + if (pos < 0) { + return; + } + + Node node = d_nodes[pos]; + + if (node.IsLeaf()) { + d_position_tmp[idx] = -1; + } else if (node.split.missing_left) { + d_position_tmp[idx] = pos * 2 + 1; + } else { + d_position_tmp[idx] = pos * 2 + 2; + } + }); + + // Update node based on fvalue where exists + dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) { + int ridx = d_ridx[idx]; + NodeIdT pos = d_position[ridx]; + if (pos < 0) { + return; + } + + Node node = d_nodes[pos]; + + if (node.IsLeaf()) { + return; + } + + int gidx = d_gidx[idx]; + int findex = d_gidx_feature_map[gidx]; + + if (findex == node.split.findex) { + float fvalue = d_gidx_fvalue_map[gidx]; + + if (fvalue <= node.split.fvalue) { + d_position_tmp[ridx] = pos * 2 + 1; + } else { + d_position_tmp[ridx] = pos * 2 + 2; + } + } + }); + + position = position_tmp; +} + +void GPUHistBuilder::ColSampleTree() { + feature_set_tree.resize(info->num_col); + std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0); + feature_set_tree = col_sample(feature_set_tree, 
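+      // 'colsample_bytree' draws the feature subset once per tree;
+      // ColSampleLevel() subsamples that subset again for each level.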
param.colsample_bytree); +} + +void GPUHistBuilder::ColSampleLevel() { + feature_set_level.resize(feature_set_tree.size()); + feature_set_level = col_sample(feature_set_tree, param.colsample_bylevel); + std::vector h_feature_flags(info->num_col, 0); + for (auto fidx : feature_set_level) { + h_feature_flags[fidx] = 1; + } + feature_flags = h_feature_flags; +} + +void GPUHistBuilder::InitData(const std::vector& gpair, + DMatrix& fmat, // NOLINT + const RegTree& tree) { + if (!initialised) { + CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column " + "block. Try setting 'tree_method' " + "parameter to 'exact'"; + info = &fmat.info(); + is_dense = info->num_nonzero == info->num_col * info->num_row; + hmat_.Init(&fmat, param.max_bin); + gmat_.cut = &hmat_; + gmat_.Init(&fmat); + int n_bins = hmat_.row_ptr.back(); + int n_features = hmat_.row_ptr.size() - 1; + + // Build feature segments + std::vector h_feature_segments; + for (int node = 0; node < n_nodes_level(param.max_depth - 1); node++) { + for (int fidx = 0; fidx < hmat_.row_ptr.size() - 1; fidx++) { + h_feature_segments.push_back(hmat_.row_ptr[fidx] + node * n_bins); + } + } + h_feature_segments.push_back(n_nodes_level(param.max_depth - 1) * n_bins); + + int level_max_bins = n_nodes_level(param.max_depth - 1) * n_bins; + + size_t free_memory = dh::available_memory(); + ba.allocate( + &gidx_feature_map, n_bins, &hist_node_segments, + n_nodes_level(param.max_depth - 1) + 1, &feature_segments, + h_feature_segments.size(), &gain, level_max_bins, &position, + gpair.size(), &position_tmp, gpair.size(), &nodes, + n_nodes(param.max_depth), &gidx_fvalue_map, hmat_.cut.size(), + &fidx_min_map, hmat_.min_val.size(), &argmax, + n_nodes_level(param.max_depth - 1), &node_sums, + n_nodes_level(param.max_depth - 1) * n_features, &hist_scan, + level_max_bins, &device_gpair, gpair.size(), &device_matrix.gidx, + gmat_.index.size(), &device_matrix.ridx, gmat_.index.size(), &hist.hist, + n_nodes(param.max_depth - 1) * n_bins, &feature_flags, n_features); + + if (!param.silent) { + const int mb_size = 1048576; + LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/" + << free_memory / mb_size << " MB on " << dh::device_name(); + } + + // Construct feature map + std::vector h_gidx_feature_map(n_bins); + for (int row = 0; row < hmat_.row_ptr.size() - 1; row++) { + for (int i = hmat_.row_ptr[row]; i < hmat_.row_ptr[row + 1]; i++) { + h_gidx_feature_map[i] = row; + } + } + + gidx_feature_map = h_gidx_feature_map; + + // Construct device matrix + device_matrix.Init(gmat_); + + gidx_fvalue_map = hmat_.cut; + fidx_min_map = hmat_.min_val; + + thrust::sequence(hist_node_segments.tbegin(), hist_node_segments.tend(), 0, + n_bins); + + feature_segments = h_feature_segments; + + hist.Init(n_bins); + + initialised = true; + } + nodes.fill(Node()); + position.fill(0); + device_gpair = gpair; + subsample_gpair(&device_gpair, param.subsample); + hist.Reset(); +} + +void GPUHistBuilder::Update(const std::vector& gpair, + DMatrix* p_fmat, RegTree* p_tree) { + this->InitData(gpair, *p_fmat, *p_tree); + this->InitFirstNode(); + this->ColSampleTree(); + + for (int depth = 0; depth < param.max_depth; depth++) { + this->ColSampleLevel(); + this->BuildHist(depth); + this->FindSplit(depth); + this->UpdatePosition(); + } + dense2sparse_tree(p_tree, nodes.tbegin(), nodes.tend(), param); +} +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cuh b/plugin/updater_gpu/src/gpu_hist_builder.cuh new file mode 100644 index 
000000000..810d4e2ba --- /dev/null +++ b/plugin/updater_gpu/src/gpu_hist_builder.cuh @@ -0,0 +1,104 @@ +/*! + * Copyright 2016 Rory mitchell + */ +#pragma once +#include +#include +#include +#include // Need key value pair definition +#include +#include "../../src/common/hist_util.h" +#include "../../src/tree/param.h" +#include "device_helpers.cuh" +#include "types.cuh" + +namespace xgboost { + +namespace tree { + +struct DeviceGMat { + dh::dvec gidx; + dh::dvec ridx; + void Init(const common::GHistIndexMatrix &gmat); +}; + +struct HistBuilder { + gpu_gpair *d_hist; + int n_bins; + __host__ __device__ HistBuilder(gpu_gpair *ptr, int n_bins); + __device__ void Add(gpu_gpair gpair, int gidx, int nidx) const; + __device__ gpu_gpair Get(int gidx, int nidx) const; +}; + +struct DeviceHist { + int n_bins; + dh::dvec hist; + + void Init(int max_depth); + + void Reset(); + + HistBuilder GetBuilder(); + + gpu_gpair *GetLevelPtr(int depth); + + int LevelSize(int depth); +}; + +class GPUHistBuilder { + public: + GPUHistBuilder(); + ~GPUHistBuilder(); + void Init(const TrainParam ¶m); + + void UpdateParam(const TrainParam ¶m) { + this->param = param; + this->gpu_param = GPUTrainingParam(param.min_child_weight, param.reg_lambda, + param.reg_alpha, param.max_delta_step); + } + + void InitData(const std::vector &gpair, DMatrix &fmat, // NOLINT + const RegTree &tree); + void Update(const std::vector &gpair, DMatrix *p_fmat, + RegTree *p_tree); + void BuildHist(int depth); + void FindSplit(int depth); + void InitFirstNode(); + void UpdatePosition(); + void UpdatePositionDense(); + void UpdatePositionSparse(); + void ColSampleTree(); + void ColSampleLevel(); + + TrainParam param; + GPUTrainingParam gpu_param; + common::HistCutMatrix hmat_; + common::GHistIndexMatrix gmat_; + MetaInfo *info; + bool initialised; + bool is_dense; + DeviceGMat device_matrix; + + dh::bulk_allocator ba; + dh::CubMemory cub_mem; + dh::dvec gidx_feature_map; + dh::dvec hist_node_segments; + dh::dvec feature_segments; + dh::dvec gain; + dh::dvec position; + dh::dvec position_tmp; + dh::dvec gidx_fvalue_map; + dh::dvec fidx_min_map; + DeviceHist hist; + dh::dvec> argmax; + dh::dvec node_sums; + dh::dvec hist_scan; + dh::dvec device_gpair; + dh::dvec nodes; + dh::dvec feature_flags; + + std::vector feature_set_tree; + std::vector feature_set_level; +}; +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/loss_functions.cuh b/plugin/updater_gpu/src/loss_functions.cuh deleted file mode 100644 index 00796d051..000000000 --- a/plugin/updater_gpu/src/loss_functions.cuh +++ /dev/null @@ -1,52 +0,0 @@ -/*! 
- * Copyright 2016 Rory mitchell -*/ -#pragma once -#include "types.cuh" -#include "../../../src/tree/param.h" - -// When we split on a value which has no left neighbour, define its left -// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS -// This produces a split value slightly lower than the current instance -#define FVALUE_EPS 0.0001 - -namespace xgboost { -namespace tree { - - -__device__ __forceinline__ float -device_calc_loss_chg(const GPUTrainingParam ¶m, const gpu_gpair &scan, - const gpu_gpair &missing, const gpu_gpair &parent_sum, - const float &parent_gain, bool missing_left) { - gpu_gpair left = scan; - - if (missing_left) { - left += missing; - } - - gpu_gpair right = parent_sum - left; - - float left_gain = CalcGain(param, left.grad(), left.hess()); - float right_gain = CalcGain(param, right.grad(), right.hess()); - return left_gain + right_gain - parent_gain; -} - -__device__ __forceinline__ float -loss_chg_missing(const gpu_gpair &scan, const gpu_gpair &missing, - const gpu_gpair &parent_sum, const float &parent_gain, - const GPUTrainingParam ¶m, bool &missing_left_out) { // NOLINT - float missing_left_loss = - device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true); - float missing_right_loss = device_calc_loss_chg( - param, scan, missing, parent_sum, parent_gain, false); - - if (missing_left_loss >= missing_right_loss) { - missing_left_out = true; - return missing_left_loss; - } else { - missing_left_out = false; - return missing_right_loss; - } -} -} // namespace tree -} // namespace xgboost diff --git a/plugin/updater_gpu/src/types.cuh b/plugin/updater_gpu/src/types.cuh index 544d1f0e1..0e1e37da3 100644 --- a/plugin/updater_gpu/src/types.cuh +++ b/plugin/updater_gpu/src/types.cuh @@ -1,8 +1,11 @@ /*! 
* Copyright 2016 Rory mitchell -*/ + */ #pragma once +#include #include +#include +#include #include // The linter is not very smart and thinks we need this namespace xgboost { @@ -104,8 +107,10 @@ struct GPUTrainingParam { __host__ __device__ GPUTrainingParam(float min_child_weight_in, float reg_lambda_in, float reg_alpha_in, float max_delta_step_in) - : min_child_weight(min_child_weight_in), reg_lambda(reg_lambda_in), - reg_alpha(reg_alpha_in), max_delta_step(max_delta_step_in) {} + : min_child_weight(min_child_weight_in), + reg_lambda(reg_lambda_in), + reg_alpha(reg_alpha_in), + max_delta_step(max_delta_step_in) {} }; struct Split { @@ -123,8 +128,9 @@ struct Split { float fvalue_in, int findex_in, gpu_gpair left_sum_in, gpu_gpair right_sum_in, const GPUTrainingParam ¶m) { - if (loss_chg_in > loss_chg && left_sum_in.hess() > param.min_child_weight && - right_sum_in.hess() > param.min_child_weight) { + if (loss_chg_in > loss_chg && + left_sum_in.hess() >= param.min_child_weight && + right_sum_in.hess() >= param.min_child_weight) { loss_chg = loss_chg_in; missing_left = missing_left_in; fvalue = fvalue_in; @@ -135,7 +141,7 @@ struct Split { } // Does not check minimum weight - __device__ void Update(Split &s) { // NOLINT + __device__ void Update(Split &s) { // NOLINT if (s.loss_chg > loss_chg) { loss_chg = s.loss_chg; missing_left = s.missing_left; @@ -146,7 +152,7 @@ struct Split { } } - __device__ void Print() { + __host__ __device__ void Print() { printf("Loss: %1.4f\n", loss_chg); printf("Missing left: %d\n", missing_left); printf("fvalue: %1.4f\n", fvalue); @@ -160,7 +166,7 @@ struct Split { struct split_reduce_op { template - __device__ __forceinline__ T operator()(T &a, T b) { // NOLINT + __device__ __forceinline__ T operator()(T &a, T b) { // NOLINT b.Update(a); return b; } @@ -184,5 +190,6 @@ struct Node { __host__ __device__ bool IsLeaf() { return split.loss_chg == -FLT_MAX; } }; + } // namespace tree } // namespace xgboost diff --git a/plugin/updater_gpu/src/types_functions.cuh b/plugin/updater_gpu/src/types_functions.cuh deleted file mode 100644 index f7bd8e65f..000000000 --- a/plugin/updater_gpu/src/types_functions.cuh +++ /dev/null @@ -1,6 +0,0 @@ -/*! - * Copyright 2016 Rory mitchell -*/ -#pragma once -#include "types.cuh" -#include "loss_functions.cuh" diff --git a/plugin/updater_gpu/src/updater_gpu.cc b/plugin/updater_gpu/src/updater_gpu.cc index 3c7badee2..78336df67 100644 --- a/plugin/updater_gpu/src/updater_gpu.cc +++ b/plugin/updater_gpu/src/updater_gpu.cc @@ -7,30 +7,37 @@ #include "../../src/common/sync.h" #include "../../src/tree/param.h" #include "gpu_builder.cuh" +#include "gpu_hist_builder.cuh" namespace xgboost { namespace tree { DMLC_REGISTRY_FILE_TAG(updater_gpumaker); /*! 
diff --git a/plugin/updater_gpu/src/updater_gpu.cc b/plugin/updater_gpu/src/updater_gpu.cc
index 3c7badee2..78336df67 100644
--- a/plugin/updater_gpu/src/updater_gpu.cc
+++ b/plugin/updater_gpu/src/updater_gpu.cc
@@ -7,30 +7,37 @@
 #include "../../src/common/sync.h"
 #include "../../src/tree/param.h"
 #include "gpu_builder.cuh"
+#include "gpu_hist_builder.cuh"
 
 namespace xgboost {
 namespace tree {
 
 DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
 
 /*! \brief column-wise update to construct a tree */
-template <typename TStats> class GPUMaker : public TreeUpdater {
+template <typename TStats>
+class GPUMaker : public TreeUpdater {
  public:
-  void
-  Init(const std::vector<std::pair<std::string, std::string>> &args) override {
+  void Init(
+      const std::vector<std::pair<std::string, std::string>>& args) override {
     param.InitAllowUnknown(args);
     builder.Init(param);
   }
 
-  void Update(const std::vector<bst_gpair> &gpair, DMatrix *dmat,
-              const std::vector<RegTree *> &trees) override {
+  void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
+              const std::vector<RegTree*>& trees) override {
     TStats::CheckInfo(dmat->info());
     // rescale learning rate according to size of trees
     float lr = param.learning_rate;
     param.learning_rate = lr / trees.size();
     builder.UpdateParam(param);
-    // build tree
-    for (size_t i = 0; i < trees.size(); ++i) {
-      builder.Update(gpair, dmat, trees[i]);
+
+    try {
+      // build tree
+      for (size_t i = 0; i < trees.size(); ++i) {
+        builder.Update(gpair, dmat, trees[i]);
+      }
+    } catch (const std::exception& e) {
+      LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
     }
     param.learning_rate = lr;
   }
@@ -41,9 +48,45 @@ template <typename TStats> class GPUMaker : public TreeUpdater {
   GPUBuilder builder;
 };
 
+template <typename TStats>
+class GPUHistMaker : public TreeUpdater {
+ public:
+  void Init(
+      const std::vector<std::pair<std::string, std::string>>& args) override {
+    param.InitAllowUnknown(args);
+    builder.Init(param);
+  }
+
+  void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
+              const std::vector<RegTree*>& trees) override {
+    TStats::CheckInfo(dmat->info());
+    // rescale learning rate according to size of trees
+    float lr = param.learning_rate;
+    param.learning_rate = lr / trees.size();
+    builder.UpdateParam(param);
+    // build tree
+    try {
+      for (size_t i = 0; i < trees.size(); ++i) {
+        builder.Update(gpair, dmat, trees[i]);
+      }
+    } catch (const std::exception& e) {
+      LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
+    }
+    param.learning_rate = lr;
+  }
+
+ protected:
+  // training parameter
+  TrainParam param;
+  GPUHistBuilder builder;
+};
+
 XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
     .describe("Grow tree with GPU.")
     .set_body([]() { return new GPUMaker<GradStats>(); });
 
+XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
+    .describe("Grow tree with GPU using the fast histogram algorithm.")
+    .set_body([]() { return new GPUHistMaker<GradStats>(); });
 }  // namespace tree
 }  // namespace xgboost
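Once registered, the two updaters are selected by name through the normal parameter interface, and `gpu_id` (added to `TrainParam` further below) picks the device ordinal. A minimal usage sketch, assuming a build with `PLUGIN_UPDATER_GPU` enabled and a CUDA-capable device; the synthetic dataset is only for illustration:

```python
import xgboost as xgb
from sklearn.datasets import make_classification

X, y = make_classification(10000, n_features=20, random_state=7)
dtrain = xgb.DMatrix(X, label=y)

param = {'objective': 'binary:logistic',
         'max_depth': 6,
         'gpu_id': 0,             # device ordinal, defaults to 0
         'updater': 'grow_gpu'}   # exact search; use 'grow_gpu_hist' for histograms
bst = xgb.train(param, dtrain, num_boost_round=10)
```

`grow_gpu_hist` trades some split accuracy for speed and a much smaller memory footprint, so it is usually the better starting point on large data.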
diff --git a/plugin/updater_gpu/test.py b/plugin/updater_gpu/test.py
deleted file mode 100644
index 4258b8d02..000000000
--- a/plugin/updater_gpu/test.py
+++ /dev/null
@@ -1,137 +0,0 @@
-#pylint: skip-file
-import numpy as np
-import xgboost as xgb
-import os
-import pandas as pd
-import urllib2
-
-class bcolors:
-    HEADER = '\033[95m'
-    OKBLUE = '\033[94m'
-    OKGREEN = '\033[92m'
-    WARNING = '\033[93m'
-    FAIL = '\033[91m'
-    ENDC = '\033[0m'
-    BOLD = '\033[1m'
-    UNDERLINE = '\033[4m'
-
-
-def get_last_eval_callback(result):
-
-    def callback(env):
-        result.append(env.evaluation_result_list[-1][1])
-
-    callback.after_iteration = True
-    return callback
-
-
-def load_adult():
-    path = "../../demo/data/adult.data"
-
-    if(not os.path.isfile(path)):
-        data = urllib2.urlopen('http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data')
-        with open(path,'wb') as output:
-            output.write(data.read())
-
-    train_set = pd.read_csv( path, header=None)
-
-    train_set.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation',
-                         'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
-                         'wage_class']
-    train_nomissing = train_set.replace(' ?', np.nan).dropna()
-    for feature in train_nomissing.columns:  # Loop through all columns in the dataframe
-        if train_nomissing[feature].dtype == 'object':  # Only apply for columns with categorical strings
-            train_nomissing[feature] = pd.Categorical(train_nomissing[feature]).codes  # Replace strings with an integer
-
-    y_train = train_nomissing.pop('wage_class')
-
-    return xgb.DMatrix( train_nomissing, label=y_train)
-
-
-def load_higgs():
-    higgs_path = '../../demo/data/training.csv'
-    dtrain = np.loadtxt(higgs_path, delimiter=',', skiprows=1, converters={32: lambda x: int(x == 's'.encode('utf-8'))})
-
-    #dtrain = dtrain[0:200000,:]
-    label = dtrain[:, 32]
-    data = dtrain[:, 1:31]
-    weight = dtrain[:, 31]
-
-    return xgb.DMatrix( data, label=label, missing = -999.0, weight=weight )
-
-def load_dermatology():
-    data = np.loadtxt('../../demo/data/dermatology.data', delimiter=',', converters={33: lambda x: int(x == '?'), 34: lambda x: int(x) - 1})
-    sz = data.shape
-
-    X = data[:, 0:33]
-    Y = data[:, 34]
-
-    return xgb.DMatrix( X, label=Y)
-
-def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
-    return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
-
-#Check GPU test evaluation is approximately equal to CPU test evaluation
-def check_result(cpu_result, gpu_result):
-    for i in range(len(cpu_result)):
-        if not isclose(cpu_result[i], gpu_result[i], 0.1, 0.02):
-            return False
-
-    return True
-
-
-#Get data
-data = []
-params = []
-data.append(load_higgs())
-params.append({})
-
-
-data.append( load_adult())
-params.append({})
-
-data.append(xgb.DMatrix('../../demo/data/agaricus.txt.test'))
-params.append({'objective':'binary:logistic'})
-
-#if(os.path.isfile("../../demo/data/dermatology.data")):
-data.append(load_dermatology())
-params.append({'objective':'multi:softmax', 'num_class': 6})
-
-num_round = 5
-
-num_pass = 0
-num_fail = 0
-
-test_depth = [ 1, 6, 9, 11, 15 ]
-#test_depth = [ 1 ]
-
-for test in range(0, len(data)):
-    for depth in test_depth:
-        xgmat = data[test]
-        cpu_result = []
-        param = params[test]
-        param['max_depth'] = depth
-        param['updater'] = 'grow_colmaker'
-        xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(cpu_result)])
-
-        #bst = xgb.train( param, xgmat, 1);
-        #bst.dump_model('reference_model.txt','', True)
-
-        gpu_result = []
-        param['updater'] = 'grow_gpu'
-        xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(gpu_result)])
-
-        #bst = xgb.train( param, xgmat, 1);
-        #bst.dump_model('dump.raw.txt','', True)
-
-        if check_result(cpu_result, gpu_result):
-            print(bcolors.OKGREEN + "Pass" + bcolors.ENDC)
-            num_pass = num_pass + 1
-        else:
-            print(bcolors.FAIL + "Fail" + bcolors.ENDC)
-            num_fail = num_fail + 1
-
-        print("cpu rmse: "+str(cpu_result))
-        print("gpu rmse: "+str(gpu_result))
-
-print(str(num_pass)+"/"+str(num_pass + num_fail)+" passed")
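The ad-hoc comparison script above is replaced by the `unittest`-based suite that follows. A sketch of one way to run it, assuming the working directory is `plugin/updater_gpu/test` so the relative data paths in the module resolve:

```python
import unittest

# discover and run the TestGPU cases defined in test.py below
suite = unittest.defaultTestLoader.discover('.', pattern='test.py')
unittest.TextTestRunner(verbosity=2).run(suite)
```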
diff --git a/plugin/updater_gpu/test/test.py b/plugin/updater_gpu/test/test.py
new file mode 100644
index 000000000..f3520bd11
--- /dev/null
+++ b/plugin/updater_gpu/test/test.py
@@ -0,0 +1,221 @@
+#pylint: skip-file
+import sys
+sys.path.append("../../tests/python")
+import xgboost as xgb
+import testing as tm
+import numpy as np
+import unittest
+
+rng = np.random.RandomState(1994)
+
+dpath = '../../demo/data/'
+ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+
+
+class TestGPU(unittest.TestCase):
+    def test_grow_gpu(self):
+        tm._skip_if_no_sklearn()
+        from sklearn.datasets import load_digits
+        try:
+            from sklearn.model_selection import train_test_split
+        except ImportError:
+            from sklearn.cross_validation import train_test_split
+
+        ag_param = {'max_depth': 2,
+                    'tree_method': 'exact',
+                    'nthread': 1,
+                    'eta': 1,
+                    'silent': 1,
+                    'objective': 'binary:logistic',
+                    'eval_metric': 'auc'}
+        ag_param2 = {'max_depth': 2,
+                     'updater': 'grow_gpu',
+                     'eta': 1,
+                     'silent': 1,
+                     'objective': 'binary:logistic',
+                     'eval_metric': 'auc'}
+        ag_res = {}
+        ag_res2 = {}
+
+        num_rounds = 10
+        xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                  evals_result=ag_res)
+        xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                  evals_result=ag_res2)
+        assert ag_res['train']['auc'] == ag_res2['train']['auc']
+        assert ag_res['test']['auc'] == ag_res2['test']['auc']
+
+        digits = load_digits(2)
+        X = digits['data']
+        y = digits['target']
+        X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+        dtrain = xgb.DMatrix(X_train, y_train)
+        dtest = xgb.DMatrix(X_test, y_test)
+
+        param = {'objective': 'binary:logistic',
+                 'updater': 'grow_gpu',
+                 'max_depth': 3,
+                 'eval_metric': 'auc'}
+        res = {}
+        xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
+                  evals_result=res)
+        assert self.non_decreasing(res['train']['auc'])
+        assert self.non_decreasing(res['test']['auc'])
+
+        # fail-safe test for dense data
+        from sklearn.datasets import load_svmlight_file
+        X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
+        X2 = X2.toarray()
+        dtrain2 = xgb.DMatrix(X2, label=y2)
+
+        param = {'objective': 'binary:logistic',
+                 'updater': 'grow_gpu',
+                 'max_depth': 2,
+                 'eval_metric': 'auc'}
+        res = {}
+        xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
+
+        assert self.non_decreasing(res['train']['auc'])
+        assert res['train']['auc'][0] >= 0.85
+
+        for j in range(X2.shape[1]):
+            for i in rng.choice(X2.shape[0], size=10, replace=False):
+                X2[i, j] = 2
+
+        dtrain3 = xgb.DMatrix(X2, label=y2)
+        res = {}
+
+        xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
+
+        assert self.non_decreasing(res['train']['auc'])
+        assert res['train']['auc'][0] >= 0.85
+
+        for j in range(X2.shape[1]):
+            for i in np.random.choice(X2.shape[0], size=10, replace=False):
+                X2[i, j] = 3
+
+        dtrain4 = xgb.DMatrix(X2, label=y2)
+        res = {}
+        xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res)
+        assert self.non_decreasing(res['train']['auc'])
+        assert res['train']['auc'][0] >= 0.85
+
+    def test_grow_gpu_hist(self):
+        tm._skip_if_no_sklearn()
+        from sklearn.datasets import load_digits
+        try:
+            from sklearn.model_selection import train_test_split
+        except ImportError:
+            from sklearn.cross_validation import train_test_split
+
+        # regression test --- hist must be same as exact on all-categorical data
+        ag_param = {'max_depth': 2,
+                    'tree_method': 'exact',
+                    'nthread': 1,
+                    'eta': 1,
+                    'silent': 1,
+                    'objective': 'binary:logistic',
+                    'eval_metric': 'auc'}
+        ag_param2 = {'max_depth': 2,
+                     'updater': 'grow_gpu_hist',
+                     'eta': 1,
+                     'silent': 1,
+                     'objective': 'binary:logistic',
+                     'eval_metric': 'auc'}
+        ag_res = {}
+        ag_res2 = {}
+
+        num_rounds = 10
+        xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                  evals_result=ag_res)
+        xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+                  evals_result=ag_res2)
+        assert ag_res['train']['auc'] == ag_res2['train']['auc']
+        assert ag_res['test']['auc'] == ag_res2['test']['auc']
+
+        digits = load_digits(2)
+        X = digits['data']
+        y = digits['target']
+        X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+        dtrain = xgb.DMatrix(X_train, y_train)
+        dtest = xgb.DMatrix(X_test, y_test)
+
+        param = {'objective': 'binary:logistic',
+                 'updater': 'grow_gpu_hist',
+                 'max_depth': 3,
+                 'eval_metric': 'auc'}
+        res = {}
+        xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
+                  evals_result=res)
+        assert self.non_decreasing(res['train']['auc'])
+        assert self.non_decreasing(res['test']['auc'])
+
+        # fail-safe test for dense data
+        from sklearn.datasets import load_svmlight_file
+        X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
+        X2 = X2.toarray()
+        dtrain2 = xgb.DMatrix(X2, label=y2)
+
+        param = {'objective': 'binary:logistic',
+                 'updater': 'grow_gpu_hist',
+                 'grow_policy': 'depthwise',
+                 'max_depth': 2,
+                 'eval_metric': 'auc'}
+        res = {}
+        xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
+
+        assert self.non_decreasing(res['train']['auc'])
+        assert res['train']['auc'][0] >= 0.85
+
+        for j in range(X2.shape[1]):
+            for i in rng.choice(X2.shape[0], size=10, replace=False):
+                X2[i, j] = 2
+
+        dtrain3 = xgb.DMatrix(X2, label=y2)
+        res = {}
+
+        xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
+
+        assert self.non_decreasing(res['train']['auc'])
+        assert res['train']['auc'][0] >= 0.85
+
+        for j in range(X2.shape[1]):
+            for i in np.random.choice(X2.shape[0], size=10, replace=False):
+                X2[i, j] = 3
+
+        dtrain4 = xgb.DMatrix(X2, label=y2)
+        res = {}
+        xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res)
+        assert self.non_decreasing(res['train']['auc'])
+        assert res['train']['auc'][0] >= 0.85
+
+        # fail-safe test for max_bin=2
+        param = {'objective': 'binary:logistic',
+                 'updater': 'grow_gpu_hist',
+                 'max_depth': 2,
+                 'eval_metric': 'auc',
+                 'max_bin': 2}
+        res = {}
+        xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
+        assert self.non_decreasing(res['train']['auc'])
+        assert res['train']['auc'][0] >= 0.85
+
+        # subsampling
+        param = {'objective': 'binary:logistic',
+                 'updater': 'grow_gpu_hist',
+                 'max_depth': 3,
+                 'eval_metric': 'auc',
+                 'colsample_bytree': 0.5,
+                 'colsample_bylevel': 0.5,
+                 'subsample': 0.5
+                 }
+        res = {}
+        xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
+        assert self.non_decreasing(res['train']['auc'])
+        assert res['train']['auc'][0] >= 0.85
+
+    def non_decreasing(self, L):
+        return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
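The `non_decreasing` helper used throughout the suite tolerates per-step dips of up to 1e-3 in the AUC sequence before declaring a regression, which absorbs floating-point noise between boosting rounds. A standalone sketch of its behaviour:

```python
def non_decreasing(L, tol=0.001):
    # True unless some step drops by tol or more
    return all((x - y) < tol for x, y in zip(L, L[1:]))

assert non_decreasing([0.80, 0.85, 0.8495])    # 0.0005 dip: tolerated
assert not non_decreasing([0.80, 0.85, 0.80])  # 0.05 dip: fails
```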
diff --git a/src/tree/param.h b/src/tree/param.h
index 590fb363e..23048cde3 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -79,6 +79,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
   bool refresh_leaf;
   // auxiliary data structure
   std::vector<bst_int> monotone_constraints;
+  // gpu to use for single gpu algorithms
+  int gpu_id;
   // declare the parameters
   DMLC_DECLARE_PARAMETER(TrainParam) {
     DMLC_DECLARE_FIELD(learning_rate)
@@ -186,6 +188,10 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
     DMLC_DECLARE_FIELD(monotone_constraints)
         .set_default(std::vector<bst_int>())
         .describe("Constraint of variable monotonicity");
+    DMLC_DECLARE_FIELD(gpu_id)
+        .set_lower_bound(0)
+        .set_default(0)
+        .describe("gpu to use for single gpu algorithms");
     // add alias of parameters
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);
     DMLC_DECLARE_ALIAS(reg_alpha, alpha);