[GPU-Plugin] (#2227)

* Add fast histogram algorithm
* Fix Linux build
* Add 'gpu_id' parameter
This commit is contained in:
Rory Mitchell 2017-04-26 11:37:10 +12:00 committed by Tianqi Chen
parent d281c6aafa
commit 8ab5d4611c
25 changed files with 1318 additions and 492 deletions

View File

@ -97,7 +97,7 @@ if(PLUGIN_UPDATER_GPU)
#Find cub #Find cub
set(CUB_DIRECTORY "" CACHE PATH "CUB 1.5.4 directory") set(CUB_DIRECTORY "" CACHE PATH "CUB 1.5.4 directory")
include_directories(${CUB_DIRECTORY}) include_directories(${CUB_DIRECTORY})
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35;")
if(NOT MSVC) if(NOT MSVC)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC")
endif() endif()
@ -105,8 +105,9 @@ if(PLUGIN_UPDATER_GPU)
plugin/updater_gpu/src/updater_gpu.cc plugin/updater_gpu/src/updater_gpu.cc
) )
find_package(CUDA QUIET REQUIRED) find_package(CUDA QUIET REQUIRED)
file(GLOB_RECURSE CUDA_SOURCES "plugin/updater_gpu/src/*")
cuda_add_library(updater_gpu STATIC cuda_add_library(updater_gpu STATIC
plugin/updater_gpu/src/gpu_builder.cu ${CUDA_SOURCES}
) )
endif() endif()

View File

@ -63,3 +63,5 @@ List of Contributors
* [Alex Bain](https://github.com/convexquad) * [Alex Bain](https://github.com/convexquad)
* [Baltazar Bieniek](https://github.com/bbieniek) * [Baltazar Bieniek](https://github.com/bbieniek)
* [Adam Pocock](https://github.com/Craigacp) * [Adam Pocock](https://github.com/Craigacp)
* [Rory Mitchell](https://github.com/RAMitchell)
- Rory is the author of the GPU plugin and also contributed the CMake build system and Windows continuous integration.

View File

@ -21,6 +21,7 @@ The same code runs on major distributed environment (Hadoop, SGE, MPI) and can s
What's New What's New
---------- ----------
* [XGBoost GPU support with fast histogram algorithm](https://github.com/dmlc/xgboost/tree/master/plugin/updater_gpu)
* [XGBoost4J: Portable Distributed XGboost in Spark, Flink and Dataflow](http://dmlc.ml/2016/03/14/xgboost4j-portable-distributed-xgboost-in-spark-flink-and-dataflow.html), see [JVM-Package](https://github.com/dmlc/xgboost/tree/master/jvm-packages) * [XGBoost4J: Portable Distributed XGboost in Spark, Flink and Dataflow](http://dmlc.ml/2016/03/14/xgboost4j-portable-distributed-xgboost-in-spark-flink-and-dataflow.html), see [JVM-Package](https://github.com/dmlc/xgboost/tree/master/jvm-packages)
* [Story and Lessons Behind the Evolution of XGBoost](http://homes.cs.washington.edu/~tqchen/2016/03/10/story-and-lessons-behind-the-evolution-of-xgboost.html) * [Story and Lessons Behind the Evolution of XGBoost](http://homes.cs.washington.edu/~tqchen/2016/03/10/story-and-lessons-behind-the-evolution-of-xgboost.html)
* [Tutorial: Distributed XGBoost on AWS with YARN](https://xgboost.readthedocs.io/en/latest/tutorials/aws_yarn.html) * [Tutorial: Distributed XGBoost on AWS with YARN](https://xgboost.readthedocs.io/en/latest/tutorials/aws_yarn.html)

View File

@ -48,7 +48,7 @@
#define XGBOOST_ALIGNAS(X) #define XGBOOST_ALIGNAS(X)
#endif #endif
#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8 #if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8 && !defined(__CUDACC__)
#include <parallel/algorithm> #include <parallel/algorithm>
#define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z)) #define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z))
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) __gnu_parallel::stable_sort((X), (Y), (Z)) #define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) __gnu_parallel::stable_sort((X), (Y), (Z))

View File

@ -1,29 +1,25 @@
# CUDA Accelerated Tree Construction Algorithm # CUDA Accelerated Tree Construction Algorithms
This plugin adds GPU accelerated tree construction algorithms to XGBoost.
## Benchmarks
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for performance benchmarks
## Usage ## Usage
Specify the updater parameter as 'grow_gpu'. Specify the 'updater' parameter as one of the following algorithms.
updater | Description
--- | ---
grow_gpu | The standard XGBoost tree construction algorithm. Performs an exact search for splits. Slower and uses considerably more memory than 'grow_gpu_hist'.
grow_gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Faster and uses considerably less memory. Splits may be less accurate.
All algorithms currently use only a single GPU. The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0.
This plugin currently works with the CLI version and Python version. This plugin currently works with the CLI version and Python version.
Python example: Python example:
```python ```python
param['gpu_id'] = 1
param['updater'] = 'grow_gpu' param['updater'] = 'grow_gpu'
``` ```
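The histogram-based updater is selected the same way. Below is a minimal sketch adapted from the benchmark script added in this commit (the dataset and parameter values are illustrative only):

```python
import xgboost as xgb
from sklearn.datasets import make_classification

# Small synthetic binary classification problem (illustrative only)
X, y = make_classification(10000, n_features=20, random_state=7)
dtrain = xgb.DMatrix(X, y)

param = {'objective': 'binary:logistic',
         'updater': 'grow_gpu_hist',  # fast histogram algorithm on the GPU
         'gpu_id': 0,                 # device ordinal, defaults to 0
         'max_depth': 6,
         'silent': 1}
bst = xgb.train(param, dtrain, num_boost_round=10)
```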
## Benchmarks
## Memory usage [See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for performance benchmarks of the 'grow_gpu' updater.
Device memory usage can be calculated as approximately:
```
bytes = (10 x n_rows) + (40 x n_rows x n_columns x column_density) + (64 x max_nodes) + (76 x max_nodes_level x n_columns)
```
The maximum number of nodes needed for a given tree depth d is 2<sup>d+1</sup> - 1. The maximum number of nodes on any given level is 2<sup>d</sup>.
Data is stored in a sparse format. For example, missing values produced by one-hot encoding are not stored. If a one-hot encoding separates a categorical variable into 5 columns, the density of these columns is 1/5 = 0.2.
A 4GB graphics card will process approximately 3.5 million rows of the well known Kaggle higgs dataset.
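As a rough worked example of the formula above (the numbers are illustrative only):

```python
# Approximate device memory for 1,000,000 rows, 50 dense columns, depth 8
n_rows, n_cols, density, depth = 1000000, 50, 1.0, 8
max_nodes = 2 ** (depth + 1) - 1   # 511
max_nodes_level = 2 ** depth       # 256
bytes_needed = (10 * n_rows
                + 40 * n_rows * n_cols * density
                + 64 * max_nodes
                + 76 * max_nodes_level * n_cols)
print(bytes_needed / 2.0 ** 30)    # roughly 1.9 GiB
```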
## Dependencies ## Dependencies
A CUDA-capable GPU with compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler). A CUDA-capable GPU with compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler).
@ -58,6 +54,15 @@ On Windows cmake will generate an xgboost.sln solution file in the build directo
The build process generates an xgboost library and executable as normal but containing the GPU tree construction algorithm. The build process generates an xgboost library and executable as normal but containing the GPU tree construction algorithm.
## Changelog
##### 2017/4/25
* Add fast histogram algorithm
* Fix Linux build
* Add 'gpu_id' parameter
## References
[Mitchell, Rory, and Eibe Frank. Accelerating the XGBoost algorithm using GPU computing. No. e2911v1. PeerJ Preprints, 2017.](https://peerj.com/preprints/2911/)
## Author ## Author
Rory Mitchell Rory Mitchell

View File

@ -0,0 +1,32 @@
#pylint: skip-file
import xgboost as xgb
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import time
n = 1000000
num_rounds = 100
X,y = make_classification(n, n_features=50, random_state=7)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)
param = {'objective': 'binary:logistic',
'tree_method': 'exact',
'updater': 'grow_gpu_hist',
'max_depth': 8,
'silent': 1,
'eval_metric': 'auc'}
res = {}
tmp = time.time()
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
print ("GPU: %s seconds" % (str(time.time() - tmp)))
tmp = time.time()
param['updater'] = 'grow_fast_histmaker'
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')], evals_result=res)
print ("CPU: %s seconds" % (str(time.time() - tmp)))

View File

@ -1,65 +0,0 @@
#!/usr/bin/python
#pylint: skip-file
# this is the example script to use xgboost to train
import numpy as np
import xgboost as xgb
import time
# path to where the data lies
dpath = '../../demo/data'
# load in training data, directly use numpy
dtrain = np.loadtxt( dpath+'/training.csv', delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s') } )
dtrain = np.concatenate((dtrain, np.copy(dtrain)))
dtrain = np.concatenate((dtrain, np.copy(dtrain)))
dtrain = np.concatenate((dtrain, np.copy(dtrain)))
test_size = len(dtrain)
print(len(dtrain))
print ('finish loading from csv ')
label = dtrain[:,32]
data = dtrain[:,1:31]
# rescale weight to make it same as test set
weight = dtrain[:,31] * float(test_size) / len(label)
sum_wpos = sum( weight[i] for i in range(len(label)) if label[i] == 1.0 )
sum_wneg = sum( weight[i] for i in range(len(label)) if label[i] == 0.0 )
# print weight statistics
print ('weight statistics: wpos=%g, wneg=%g, ratio=%g' % ( sum_wpos, sum_wneg, sum_wneg/sum_wpos ))
# construct xgboost.DMatrix from numpy array, treat -999.0 as missing value
xgmat = xgb.DMatrix( data, label=label, missing = -999.0, weight=weight )
# setup parameters for xgboost
param = {}
# use logistic regression loss
param['objective'] = 'binary:logitraw'
# scale weight of positive examples
param['scale_pos_weight'] = sum_wneg/sum_wpos
param['bst:eta'] = 0.1
param['max_depth'] = 15
param['eval_metric'] = 'auc'
param['nthread'] = 16
plst = param.items()+[('eval_metric', 'ams@0.15')]
watchlist = [ (xgmat,'train') ]
num_round = 10
print ("training xgboost")
threads = [16]
for i in threads:
param['nthread'] = i
tmp = time.time()
plst = param.items()+[('eval_metric', 'ams@0.15')]
bst = xgb.train( plst, xgmat, num_round, watchlist );
print ("XGBoost with %d thread costs: %s seconds" % (i, str(time.time() - tmp)))
print ("training xgboost - gpu tree construction")
param['updater'] = 'grow_gpu'
tmp = time.time()
plst = param.items()+[('eval_metric', 'ams@0.15')]
bst = xgb.train( plst, xgmat, num_round, watchlist );
print ("XGBoost GPU: %s seconds" % (str(time.time() - tmp)))
print ('finish training')

View File

@ -0,0 +1,153 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include <vector>
#include "../../../src/common/random.h"
#include "../../../src/tree/param.h"
#include "device_helpers.cuh"
#include "types.cuh"
namespace xgboost {
namespace tree {
// When we split on a value which has no left neighbour, define its left
// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS
// This produces a split value slightly lower than the current instance
#define FVALUE_EPS 0.0001
__device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
const gpu_gpair& scan,
const gpu_gpair& missing,
const gpu_gpair& parent_sum,
const float& parent_gain,
bool missing_left) {
gpu_gpair left = scan;
if (missing_left) {
left += missing;
}
gpu_gpair right = parent_sum - left;
float left_gain = CalcGain(param, left.grad(), left.hess());
float right_gain = CalcGain(param, right.grad(), right.hess());
return left_gain + right_gain - parent_gain;
}
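// Evaluate the loss change for sending missing values left versus right and
// return the larger of the two, recording the chosen direction in
// missing_left_out.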
__device__ float inline loss_chg_missing(const gpu_gpair& scan,
const gpu_gpair& missing,
const gpu_gpair& parent_sum,
const float& parent_gain,
const GPUTrainingParam& param,
bool& missing_left_out) { // NOLINT
float missing_left_loss =
device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true);
float missing_right_loss = device_calc_loss_chg(
param, scan, missing, parent_sum, parent_gain, false);
if (missing_left_loss >= missing_right_loss) {
missing_left_out = true;
return missing_left_loss;
} else {
missing_left_out = false;
return missing_right_loss;
}
}
// Total number of nodes in tree, given depth
__host__ __device__ inline int n_nodes(int depth) {
return (1 << (depth + 1)) - 1;
}
// Number of nodes at this level of the tree
__host__ __device__ inline int n_nodes_level(int depth) { return 1 << depth; }
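// e.g. a tree of depth 2 has n_nodes(2) == 7 nodes in total and
// n_nodes_level(2) == 4 nodes on its deepest level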
enum NodeType {
NODE = 0,
LEAF = 1,
UNUSED = 2,
};
// Recursively label node types
inline void flag_nodes(const thrust::host_vector<Node>& nodes,
std::vector<NodeType>* node_flags, int nid,
NodeType type) {
if (nid >= nodes.size() || type == UNUSED) {
return;
}
const Node& n = nodes[nid];
// Current node and all children are valid
if (n.split.loss_chg > rt_eps) {
(*node_flags)[nid] = NODE;
flag_nodes(nodes, node_flags, nid * 2 + 1, NODE);
flag_nodes(nodes, node_flags, nid * 2 + 2, NODE);
} else {
// Current node is leaf, therefore is valid but all children are invalid
(*node_flags)[nid] = LEAF;
flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED);
flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED);
}
}
// Copy gpu dense representation of tree to xgboost sparse representation
inline void dense2sparse_tree(RegTree* p_tree,
thrust::device_ptr<Node> nodes_begin,
thrust::device_ptr<Node> nodes_end,
const TrainParam& param) {
RegTree & tree = *p_tree;
thrust::host_vector<Node> h_nodes(nodes_begin, nodes_end);
std::vector<NodeType> node_flags(h_nodes.size(), UNUSED);
flag_nodes(h_nodes, &node_flags, 0, NODE);
int nid = 0;
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
NodeType flag = node_flags[gpu_nid];
const Node& n = h_nodes[gpu_nid];
if (flag == NODE) {
tree.AddChilds(nid);
tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left);
tree.stat(nid).loss_chg = n.split.loss_chg;
tree.stat(nid).base_weight = n.weight;
tree.stat(nid).sum_hess = n.sum_gradients.hess();
tree[tree[nid].cleft()].set_leaf(0);
tree[tree[nid].cright()].set_leaf(0);
nid++;
} else if (flag == LEAF) {
tree[nid].set_leaf(n.weight * param.learning_rate);
tree.stat(nid).sum_hess = n.sum_gradients.hess();
nid++;
}
}
}
// Set gradient pair to 0 with p = 1 - subsample
inline void subsample_gpair(dh::dvec<gpu_gpair>* p_gpair, float subsample) {
if (subsample == 1.0) {
return;
}
dh::dvec<gpu_gpair>& gpair = *p_gpair;
auto d_gpair = gpair.data();
dh::BernoulliRng rng(subsample, common::GlobalRandom()());
dh::launch_n(gpair.size(), [=] __device__(int i) {
if (!rng(i)) {
d_gpair[i] = gpu_gpair();
}
});
}
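// Return a random subset containing a fraction 'colsample' of the supplied
// feature indices (used for column sampling).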
inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
int n = colsample * features.size();
CHECK_GT(n, 0);
std::shuffle(features.begin(), features.end(), common::GlobalRandom());
features.resize(n);
return features;
}
} // namespace tree
} // namespace xgboost

View File

@ -1,18 +1,19 @@
/*! /*!
* Copyright 2016 Rory mitchell * Copyright 2016 Rory mitchell
*/ */
#pragma once #pragma once
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <device_launch_parameters.h> #include <device_launch_parameters.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/random.h>
#include <thrust/system/cuda/error.h> #include <thrust/system/cuda/error.h>
#include <thrust/system_error.h> #include <thrust/system_error.h>
#include <thrust/random.h>
#include <algorithm> #include <algorithm>
#include <ctime> #include <ctime>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "cusparse_v2.h"
#ifdef _WIN32 #ifdef _WIN32
#include <windows.h> #include <windows.h>
@ -30,7 +31,8 @@ namespace dh {
#define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__) #define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__)
cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) { inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file,
int line) {
if (code != cudaSuccess) { if (code != cudaSuccess) {
std::stringstream ss; std::stringstream ss;
ss << file << "(" << line << ")"; ss << file << "(" << line << ")";
@ -41,16 +43,29 @@ cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) {
return code; return code;
} }
#define safe_cusparse(ans) throw_on_cusparse_error((ans), __FILE__, __LINE__)
#define gpuErrchk(ans) \ inline cusparseStatus_t throw_on_cusparse_error(cusparseStatus_t status,
const char *file, int line) {
if (status != CUSPARSE_STATUS_SUCCESS) {
std::stringstream ss;
ss << "cusparse error: " << file << "(" << line << ")";
std::string error_text;
ss >> error_text;
throw error_text;
}
return status;
}
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); } { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, inline void gpuAssert(cudaError_t code, const char *file, int line,
bool abort = true) { bool abort = true) {
if (code != cudaSuccess) { if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line); line);
if (abort) if (abort) exit(code);
exit(code);
} }
} }
@ -119,10 +134,10 @@ struct DeviceTimer {
#endif #endif
#ifdef DEVICE_TIMER #ifdef DEVICE_TIMER
__device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) // NOLINT __device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) // NOLINT
: GTimer(GTimer), start(clock()), slot(slot) {} : GTimer(GTimer), start(clock()), slot(slot) {}
#else #else
__device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) {} // NOLINT __device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) {} // NOLINT
#endif #endif
__device__ void End() { __device__ void End() {
@ -224,21 +239,22 @@ class range {
iterator end_; iterator end_;
}; };
template <typename T> __device__ range grid_stride_range(T begin, T end) { template <typename T>
__device__ range grid_stride_range(T begin, T end) {
begin += blockDim.x * blockIdx.x + threadIdx.x; begin += blockDim.x * blockIdx.x + threadIdx.x;
range r(begin, end); range r(begin, end);
r.step(gridDim.x * blockDim.x); r.step(gridDim.x * blockDim.x);
return r; return r;
} }
template <typename T> __device__ range block_stride_range(T begin, T end) { template <typename T>
__device__ range block_stride_range(T begin, T end) {
begin += threadIdx.x; begin += threadIdx.x;
range r(begin, end); range r(begin, end);
r.step(blockDim.x); r.step(blockDim.x);
return r; return r;
} }
// Threadblock iterates over range, filling with value // Threadblock iterates over range, filling with value
template <typename IterT, typename ValueT> template <typename IterT, typename ValueT>
__device__ void block_fill(IterT begin, size_t n, ValueT value) { __device__ void block_fill(IterT begin, size_t n, ValueT value) {
@ -253,7 +269,8 @@ __device__ void block_fill(IterT begin, size_t n, ValueT value) {
class bulk_allocator; class bulk_allocator;
template <typename T> class dvec { template <typename T>
class dvec {
friend bulk_allocator; friend bulk_allocator;
private: private:
@ -302,7 +319,8 @@ template <typename T> class dvec {
return thrust::device_pointer_cast(_ptr + size()); return thrust::device_pointer_cast(_ptr + size());
} }
template <typename T2> dvec &operator=(const std::vector<T2> &other) { template <typename T2>
dvec &operator=(const std::vector<T2> &other) {
if (other.size() != size()) { if (other.size() != size()) {
throw std::runtime_error( throw std::runtime_error(
"Cannot copy assign vector to dvec, sizes are different"); "Cannot copy assign vector to dvec, sizes are different");
@ -331,7 +349,8 @@ class bulk_allocator {
const size_t align = 256; const size_t align = 256;
template <typename SizeT> size_t align_round_up(SizeT n) { template <typename SizeT>
size_t align_round_up(SizeT n) {
if (n % align == 0) { if (n % align == 0) {
return n; return n;
} else { } else {
@ -357,7 +376,7 @@ class bulk_allocator {
template <typename T, typename SizeT, typename... Args> template <typename T, typename SizeT, typename... Args>
void allocate_dvec(char *ptr, dvec<T> *first_vec, SizeT first_size, void allocate_dvec(char *ptr, dvec<T> *first_vec, SizeT first_size,
Args... args) { Args... args) {
first_vec->external_allocate(static_cast<void*>(ptr), first_size); first_vec->external_allocate(static_cast<void *>(ptr), first_size);
ptr += align_round_up(first_size * sizeof(T)); ptr += align_round_up(first_size * sizeof(T));
allocate_dvec(ptr, args...); allocate_dvec(ptr, args...);
} }
@ -366,14 +385,15 @@ class bulk_allocator {
bulk_allocator() : _size(0), d_ptr(NULL) {} bulk_allocator() : _size(0), d_ptr(NULL) {}
~bulk_allocator() { ~bulk_allocator() {
if (!d_ptr == NULL) { if (!(d_ptr == nullptr)) {
safe_cuda(cudaFree(d_ptr)); safe_cuda(cudaFree(d_ptr));
} }
} }
size_t size() { return _size; } size_t size() { return _size; }
template <typename... Args> void allocate(Args... args) { template <typename... Args>
void allocate(Args... args) {
if (d_ptr != NULL) { if (d_ptr != NULL) {
throw std::runtime_error("Bulk allocator already allocated"); throw std::runtime_error("Bulk allocator already allocated");
} }
@ -393,14 +413,19 @@ struct CubMemory {
CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {} CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {}
~CubMemory() { ~CubMemory() { Free(); }
void Free() {
if (d_temp_storage != NULL) { if (d_temp_storage != NULL) {
safe_cuda(cudaFree(d_temp_storage)); safe_cuda(cudaFree(d_temp_storage));
} }
} }
void Allocate() { void LazyAllocate(size_t n_bytes) {
safe_cuda(cudaMalloc(&d_temp_storage, temp_storage_bytes)); if (n_bytes > temp_storage_bytes) {
Free();
safe_cuda(cudaMalloc(&d_temp_storage, n_bytes));
temp_storage_bytes = n_bytes;
}
} }
bool IsAllocated() { return d_temp_storage != NULL; } bool IsAllocated() { return d_temp_storage != NULL; }
@ -453,36 +478,47 @@ void print(char *label, const thrust::device_vector<T> &v,
std::cout << "\n"; std::cout << "\n";
} }
template <typename T1, typename T2> T1 div_round_up(const T1 a, const T2 b) { template <typename T1, typename T2>
T1 div_round_up(const T1 a, const T2 b) {
return static_cast<T1>(ceil(static_cast<double>(a) / b)); return static_cast<T1>(ceil(static_cast<double>(a) / b));
} }
template <typename T> thrust::device_ptr<T> dptr(T *d_ptr) { template <typename T>
thrust::device_ptr<T> dptr(T *d_ptr) {
return thrust::device_pointer_cast(d_ptr); return thrust::device_pointer_cast(d_ptr);
} }
template <typename T> T *raw(thrust::device_vector<T> &v) { // NOLINT template <typename T>
T *raw(thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data()); return raw_pointer_cast(v.data());
} }
template <typename T> size_t size_bytes(const thrust::device_vector<T> &v) { template <typename T>
const T *raw(const thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
template <typename T>
size_t size_bytes(const thrust::device_vector<T> &v) {
return sizeof(T) * v.size(); return sizeof(T) * v.size();
} }
/* /*
* Kernel launcher * Kernel launcher
*/ */
template <typename L> __global__ void launch_n_kernel(size_t n, L lambda) { template <typename L>
__global__ void launch_n_kernel(size_t n, L lambda) {
for (auto i : grid_stride_range(static_cast<size_t>(0), n)) { for (auto i : grid_stride_range(static_cast<size_t>(0), n)) {
lambda(i); lambda(i);
} }
} }
template <typename L, int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256> template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
inline void launch_n(size_t n, L lambda) { inline void launch_n(size_t n, L lambda) {
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS); const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
#if defined(__CUDACC__)
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(n, lambda); launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(n, lambda);
#endif
} }
/* /*
@ -493,7 +529,7 @@ struct BernoulliRng {
float p; float p;
int seed; int seed;
__host__ __device__ BernoulliRng(float p, int seed):p(p), seed(seed) {} __host__ __device__ BernoulliRng(float p, int seed) : p(p), seed(seed) {}
__host__ __device__ bool operator()(const int i) const { __host__ __device__ bool operator()(const int i) const {
thrust::default_random_engine rng(seed); thrust::default_random_engine rng(seed);
@ -504,5 +540,4 @@ struct BernoulliRng {
} }
}; };
} // namespace dh } // namespace dh

View File

@ -9,7 +9,7 @@
#include "find_split_multiscan.cuh" #include "find_split_multiscan.cuh"
#include "find_split_sorting.cuh" #include "find_split_sorting.cuh"
#include "gpu_data.cuh" #include "gpu_data.cuh"
#include "types_functions.cuh" #include "types.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {

View File

@ -6,7 +6,8 @@
#include <xgboost/base.h> #include <xgboost/base.h>
#include "device_helpers.cuh" #include "device_helpers.cuh"
#include "gpu_data.cuh" #include "gpu_data.cuh"
#include "types_functions.cuh" #include "types.cuh"
#include "common.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {

View File

@ -5,7 +5,8 @@
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <xgboost/base.h> #include <xgboost/base.h>
#include "device_helpers.cuh" #include "device_helpers.cuh"
#include "types_functions.cuh" #include "types.cuh"
#include "common.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {

View File

@ -0,0 +1,15 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include "types.cuh"
#include "../../../src/tree/param.h"
#include "../../../src/common/random.h"
namespace xgboost {
namespace tree {
} // namespace tree
} // namespace xgboost

View File

@ -1,6 +1,7 @@
/*! /*!
* Copyright 2016 Rory mitchell * Copyright 2016 Rory mitchell
*/ */
#include "gpu_builder.cuh"
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
@ -11,31 +12,35 @@
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include <thrust/sequence.h> #include <thrust/sequence.h>
#include <algorithm> #include <algorithm>
#include <random>
#include <numeric> #include <numeric>
#include <random>
#include <vector> #include <vector>
#include "../../../src/common/random.h" #include "../../../src/common/random.h"
#include "common.cuh"
#include "device_helpers.cuh" #include "device_helpers.cuh"
#include "find_split.cuh" #include "find_split.cuh"
#include "gpu_builder.cuh"
#include "types_functions.cuh"
#include "gpu_data.cuh" #include "gpu_data.cuh"
#include "types.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); } GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); }
void GPUBuilder::Init(const TrainParam &param_in) { void GPUBuilder::Init(const TrainParam& param_in) {
param = param_in; param = param_in;
CHECK(param.max_depth < 16) << "Tree depth too large."; CHECK(param.max_depth < 16) << "Tree depth too large.";
dh::safe_cuda(cudaSetDevice(param.gpu_id));
if (!param.silent) {
LOG(CONSOLE) << "Device: [" << param.gpu_id << "] " << dh::device_name();
}
} }
GPUBuilder::~GPUBuilder() { delete gpu_data; } GPUBuilder::~GPUBuilder() { delete gpu_data; }
void GPUBuilder::UpdateNodeId(int level) { void GPUBuilder::UpdateNodeId(int level) {
auto *d_node_id_instance = gpu_data->node_id_instance.data(); auto* d_node_id_instance = gpu_data->node_id_instance.data();
Node *d_nodes = gpu_data->nodes.data(); Node* d_nodes = gpu_data->nodes.data();
dh::launch_n(gpu_data->node_id_instance.size(), [=] __device__(int i) { dh::launch_n(gpu_data->node_id_instance.size(), [=] __device__(int i) {
NodeIdT item_node_id = d_node_id_instance[i]; NodeIdT item_node_id = d_node_id_instance[i];
@ -57,10 +62,10 @@ void GPUBuilder::UpdateNodeId(int level) {
dh::safe_cuda(cudaDeviceSynchronize()); dh::safe_cuda(cudaDeviceSynchronize());
auto *d_fvalues = gpu_data->fvalues.data(); auto* d_fvalues = gpu_data->fvalues.data();
auto *d_instance_id = gpu_data->instance_id.data(); auto* d_instance_id = gpu_data->instance_id.data();
auto *d_node_id = gpu_data->node_id.data(); auto* d_node_id = gpu_data->node_id.data();
auto *d_feature_id = gpu_data->feature_id.data(); auto* d_feature_id = gpu_data->feature_id.data();
// Update node based on fvalue where exists // Update node based on fvalue where exists
dh::launch_n(gpu_data->fvalues.size(), [=] __device__(int i) { dh::launch_n(gpu_data->fvalues.size(), [=] __device__(int i) {
@ -128,75 +133,52 @@ void GPUBuilder::Sort(int level) {
} }
void GPUBuilder::ColsampleTree() { void GPUBuilder::ColsampleTree() {
unsigned n = static_cast<unsigned>( unsigned n =
param.colsample_bytree * gpu_data->n_features); static_cast<unsigned>(param.colsample_bytree * gpu_data->n_features);
CHECK_GT(n, 0); CHECK_GT(n, 0);
feature_set_tree.resize(gpu_data->n_features); feature_set_tree.resize(gpu_data->n_features);
std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0); std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
std::shuffle(feature_set_tree.begin(), feature_set_tree.end(), std::shuffle(feature_set_tree.begin(), feature_set_tree.end(),
common::GlobalRandom()); common::GlobalRandom());
} }
void GPUBuilder::Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat, void GPUBuilder::Update(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
RegTree *p_tree) { RegTree* p_tree) {
try { this->InitData(gpair, *p_fmat, *p_tree);
dh::Timer update; this->InitFirstNode();
dh::Timer t; this->ColsampleTree();
this->InitData(gpair, *p_fmat, *p_tree);
t.printElapsed("init data");
this->InitFirstNode();
this->ColsampleTree();
for (int level = 0; level < param.max_depth; level++) { for (int level = 0; level < param.max_depth; level++) {
bool use_multiscan_algorithm = level < multiscan_levels; bool use_multiscan_algorithm = level < multiscan_levels;
t.reset(); if (level > 0) {
if (level > 0) { this->UpdateNodeId(level);
dh::Timer update_node;
this->UpdateNodeId(level);
update_node.printElapsed("node");
}
if (level > 0 && !use_multiscan_algorithm) {
dh::Timer s;
this->Sort(level);
s.printElapsed("sort");
}
dh::Timer split;
find_split(gpu_data, param, level, use_multiscan_algorithm,
feature_set_tree, &feature_set_level);
split.printElapsed("split");
t.printElapsed("level");
} }
this->CopyTree(*p_tree);
update.printElapsed("update"); if (level > 0 && !use_multiscan_algorithm) {
} catch (thrust::system_error &e) { this->Sort(level);
std::cerr << "CUDA error: " << e.what() << std::endl; }
exit(-1);
} catch (const std::exception &e) { find_split(gpu_data, param, level, use_multiscan_algorithm,
std::cerr << "Error: " << e.what() << std::endl; feature_set_tree, &feature_set_level);
exit(-1);
} catch (...) {
std::cerr << "Unknown exception." << std::endl;
exit(-1);
} }
dense2sparse_tree(p_tree, gpu_data->nodes.tbegin(), gpu_data->nodes.tend(),
param);
} }
void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, void GPUBuilder::InitData(const std::vector<bst_gpair>& gpair, DMatrix& fmat,
const RegTree &tree) { const RegTree& tree) {
CHECK(fmat.SingleColBlock()) << "GPUMaker: must have single column block"; CHECK(fmat.SingleColBlock()) << "grow_gpu: must have single column block. "
"Try setting 'tree_method' parameter to "
"'exact'";
if (gpu_data->IsAllocated()) { if (gpu_data->IsAllocated()) {
gpu_data->Reset(gpair, param.subsample); gpu_data->Reset(gpair, param.subsample);
return; return;
} }
dh::Timer t;
MetaInfo info = fmat.info(); MetaInfo info = fmat.info();
std::vector<int> foffsets; std::vector<int> foffsets;
@ -208,31 +190,27 @@ void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
instance_id.reserve(info.num_col * info.num_row); instance_id.reserve(info.num_col * info.num_row);
feature_id.reserve(info.num_col * info.num_row); feature_id.reserve(info.num_col * info.num_row);
dmlc::DataIter<ColBatch> *iter = fmat.ColIterator(); dmlc::DataIter<ColBatch>* iter = fmat.ColIterator();
while (iter->Next()) { while (iter->Next()) {
const ColBatch &batch = iter->Value(); const ColBatch& batch = iter->Value();
for (int i = 0; i < batch.size; i++) { for (int i = 0; i < batch.size; i++) {
const ColBatch::Inst &col = batch[i]; const ColBatch::Inst& col = batch[i];
for (const ColBatch::Entry *it = col.data; it != col.data + col.length; for (const ColBatch::Entry* it = col.data; it != col.data + col.length;
it++) { it++) {
bst_uint inst_id = it->index; bst_uint inst_id = it->index;
fvalues.push_back(it->fvalue); fvalues.push_back(it->fvalue);
instance_id.push_back(inst_id); instance_id.push_back(inst_id);
feature_id.push_back(i); feature_id.push_back(i);
} }
foffsets.push_back(fvalues.size()); foffsets.push_back(fvalues.size());
} }
} }
t.printElapsed("dmatrix");
t.reset();
gpu_data->Init(fvalues, foffsets, instance_id, feature_id, gpair, gpu_data->Init(fvalues, foffsets, instance_id, feature_id, gpair,
info.num_row, info.num_col, param.max_depth, param); info.num_row, info.num_col, param.max_depth, param);
t.printElapsed("gpu init");
} }
void GPUBuilder::InitFirstNode() { void GPUBuilder::InitFirstNode() {
@ -248,60 +226,5 @@ void GPUBuilder::InitFirstNode() {
thrust::copy_n(&tmp, 1, gpu_data->nodes.tbegin()); thrust::copy_n(&tmp, 1, gpu_data->nodes.tbegin());
} }
enum NodeType {
NODE = 0,
LEAF = 1,
UNUSED = 2,
};
// Recursively label node types
void flag_nodes(const thrust::host_vector<Node> &nodes,
std::vector<NodeType> *node_flags, int nid, NodeType type) {
if (nid >= nodes.size() || type == UNUSED) {
return;
}
const Node &n = nodes[nid];
// Current node and all children are valid
if (n.split.loss_chg > rt_eps) {
(*node_flags)[nid] = NODE;
flag_nodes(nodes, node_flags, nid * 2 + 1, NODE);
flag_nodes(nodes, node_flags, nid * 2 + 2, NODE);
} else {
// Current node is leaf, therefore is valid but all children are invalid
(*node_flags)[nid] = LEAF;
flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED);
flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED);
}
}
// Copy gpu dense representation of tree to xgboost sparse representation
void GPUBuilder::CopyTree(RegTree &tree) {
std::vector<Node> h_nodes = gpu_data->nodes.as_vector();
std::vector<NodeType> node_flags(h_nodes.size(), UNUSED);
flag_nodes(h_nodes, &node_flags, 0, NODE);
int nid = 0;
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
NodeType flag = node_flags[gpu_nid];
const Node &n = h_nodes[gpu_nid];
if (flag == NODE) {
tree.AddChilds(nid);
tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left);
tree.stat(nid).loss_chg = n.split.loss_chg;
tree.stat(nid).base_weight = n.weight;
tree.stat(nid).sum_hess = n.sum_gradients.hess();
tree[tree[nid].cleft()].set_leaf(0);
tree[tree[nid].cright()].set_leaf(0);
nid++;
} else if (flag == LEAF) {
tree[nid].set_leaf(n.weight * param.learning_rate);
tree.stat(nid).sum_hess = n.sum_gradients.hess();
nid++;
}
}
}
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -1,6 +1,6 @@
/*! /*!
* Copyright 2016 Rory mitchell * Copyright 2016 Rory mitchell
*/ */
#pragma once #pragma once
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
#include <vector> #include <vector>
@ -19,10 +19,7 @@ class GPUBuilder {
void Init(const TrainParam &param); void Init(const TrainParam &param);
~GPUBuilder(); ~GPUBuilder();
void UpdateParam(const TrainParam &param) void UpdateParam(const TrainParam &param) { this->param = param; }
{
this->param = param;
}
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat, void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
RegTree *p_tree); RegTree *p_tree);
@ -30,13 +27,11 @@ class GPUBuilder {
void UpdateNodeId(int level); void UpdateNodeId(int level);
private: private:
void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
const RegTree &tree); const RegTree &tree);
float GetSubsamplingRate(MetaInfo info);
void Sort(int level); void Sort(int level);
void InitFirstNode(); void InitFirstNode();
void CopyTree(RegTree &tree); // NOLINT
void ColsampleTree(); void ColsampleTree();
TrainParam param; TrainParam param;

View File

@ -1,14 +1,15 @@
/*! /*!
* Copyright 2016 Rory mitchell * Copyright 2016 Rory mitchell
*/ */
#pragma once #pragma once
#include <cub/cub.cuh>
#include <xgboost/logging.h>
#include <thrust/sequence.h> #include <thrust/sequence.h>
#include <xgboost/logging.h>
#include <cub/cub.cuh>
#include <vector> #include <vector>
#include "device_helpers.cuh"
#include "../../src/tree/param.h" #include "../../src/tree/param.h"
#include "types_functions.cuh" #include "common.cuh"
#include "device_helpers.cuh"
#include "types.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -67,9 +68,8 @@ struct GPUData {
cub::DoubleBuffer<int> db_value; cub::DoubleBuffer<int> db_value;
cub::DeviceSegmentedRadixSort::SortPairs( cub::DeviceSegmentedRadixSort::SortPairs(
cub_mem.data(), cub_mem_size, db_key, cub_mem.data(), cub_mem_size, db_key, db_value, in_fvalues.size(),
db_value, in_fvalues.size(), n_features, n_features, foffsets.data(), foffsets.data() + 1);
foffsets.data(), foffsets.data() + 1);
// Allocate memory // Allocate memory
size_t free_memory = dh::available_memory(); size_t free_memory = dh::available_memory();
@ -100,7 +100,6 @@ struct GPUData {
param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda, param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
param_in.reg_alpha, param_in.max_delta_step); param_in.reg_alpha, param_in.max_delta_step);
allocated = true; allocated = true;
this->Reset(in_gpair, param_in.subsample); this->Reset(in_gpair, param_in.subsample);
@ -109,33 +108,16 @@ struct GPUData {
thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()), thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
fvalues.tbegin(), node_id.tbegin())); fvalues.tbegin(), node_id.tbegin()));
dh::safe_cuda(cudaGetLastError()); dh::safe_cuda(cudaGetLastError());
} }
~GPUData() {} ~GPUData() {}
// Set gradient pair to 0 with p = 1 - subsample
void MarkSubsample(float subsample) {
if (subsample == 1.0) {
return;
}
auto d_gpair = gpair.data();
dh::BernoulliRng rng(subsample, common::GlobalRandom()());
dh::launch_n(n_instances, [=] __device__(int i) {
if (!rng(i)) {
d_gpair[i] = gpu_gpair();
}
});
}
// Reset memory for new boosting iteration // Reset memory for new boosting iteration
void Reset(const std::vector<bst_gpair> &in_gpair, float subsample) { void Reset(const std::vector<bst_gpair> &in_gpair, float subsample) {
CHECK(allocated); CHECK(allocated);
gpair = in_gpair; gpair = in_gpair;
this->MarkSubsample(subsample); subsample_gpair(&gpair, subsample);
instance_id = instance_id_cached; instance_id = instance_id_cached;
fvalues = fvalues_cached; fvalues = fvalues_cached;
nodes.fill(Node()); nodes.fill(Node());
@ -153,7 +135,6 @@ struct GPUData {
auto d_instance_id = instance_id.data(); auto d_instance_id = instance_id.data();
dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) { dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
// Item item = d_items[i];
d_node_id[i] = d_node_id_instance[d_instance_id[i]]; d_node_id[i] = d_node_id_instance[d_instance_id[i]];
}); });
} }

View File

@ -0,0 +1,560 @@
/*!
* Copyright 2017 Rory mitchell
*/
#include <thrust/binary_search.h>
#include <thrust/count.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <cub/cub.cuh>
#include <numeric>
#include <algorithm>
#include <functional>
#include "common.cuh"
#include "device_helpers.cuh"
#include "gpu_hist_builder.cuh"
namespace xgboost {
namespace tree {
void DeviceGMat::Init(const common::GHistIndexMatrix& gmat) {
CHECK_EQ(gidx.size(), gmat.index.size())
<< "gidx must be externally allocated";
CHECK_EQ(ridx.size(), gmat.index.size())
<< "ridx must be externally allocated";
gidx = gmat.index;
thrust::device_vector<int> row_ptr = gmat.row_ptr;
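// For each entry position, upper_bound yields one past the zero-based row
// containing it; the transform below subtracts one to give per-entry row
// indices.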
auto counting = thrust::make_counting_iterator(0);
thrust::upper_bound(row_ptr.begin(), row_ptr.end(), counting,
counting + gidx.size(), ridx.tbegin());
thrust::transform(ridx.tbegin(), ridx.tend(), ridx.tbegin(),
[=] __device__(int val) { return val - 1; });
}
void DeviceHist::Init(int n_bins_in) {
this->n_bins = n_bins_in;
CHECK(!hist.empty()) << "DeviceHist must be externally allocated";
}
void DeviceHist::Reset() { hist.fill(gpu_gpair()); }
gpu_gpair* DeviceHist::GetLevelPtr(int depth) {
return hist.data() + n_nodes(depth - 1) * n_bins;
}
int DeviceHist::LevelSize(int depth) { return n_bins * n_nodes_level(depth); }
HistBuilder DeviceHist::GetBuilder() {
return HistBuilder(hist.data(), n_bins);
}
HistBuilder::HistBuilder(gpu_gpair* ptr, int n_bins)
: d_hist(ptr), n_bins(n_bins) {}
__device__ void HistBuilder::Add(gpu_gpair gpair, int gidx, int nidx) const {
int hist_idx = nidx * n_bins + gidx;
atomicAdd(&(d_hist[hist_idx]._grad), gpair._grad);
atomicAdd(&(d_hist[hist_idx]._hess), gpair._hess);
}
__device__ gpu_gpair HistBuilder::Get(int gidx, int nidx) const {
return d_hist[nidx * n_bins + gidx];
}
GPUHistBuilder::GPUHistBuilder() {}
GPUHistBuilder::~GPUHistBuilder() {}
void GPUHistBuilder::Init(const TrainParam& param) {
CHECK(param.max_depth < 16) << "Tree depth too large.";
this->param = param;
initialised = false;
is_dense = false;
dh::safe_cuda(cudaSetDevice(param.gpu_id));
if (!param.silent) {
LOG(CONSOLE) << "Device: [" << param.gpu_id << "] " << dh::device_name();
}
}
template <typename ReductionOpT>
struct ReduceBySegmentOp {
ReductionOpT op;
__host__ __device__ __forceinline__ ReduceBySegmentOp() {}
__host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op)
: op(op) {}
template <typename KeyValuePairT>
__host__ __device__ __forceinline__ KeyValuePairT
operator()(const KeyValuePairT& first, const KeyValuePairT& second) {
KeyValuePairT retval;
retval.key = second.key;
retval.value =
first.key != second.key ? second.value : op(first.value, second.value);
return retval;
}
};
template <int ITEMS_PER_THREAD, int BLOCK_THREADS>
__global__ void hist_kernel(gpu_gpair* d_dense_hist, int* d_ridx, int* d_gidx,
NodeIdT* d_position, gpu_gpair* d_gpair, int n_bins,
int depth, int n) {
typedef cub::KeyValuePair<int, gpu_gpair> OffsetValuePairT;
typedef cub::BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD,
cub::BLOCK_LOAD_VECTORIZE>
BlockLoadT;
typedef cub::BlockRadixSort<int, BLOCK_THREADS, ITEMS_PER_THREAD, int>
BlockRadixSortT;
typedef cub::BlockDiscontinuity<int, BLOCK_THREADS> BlockDiscontinuityKeysT;
typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;
typedef cub::BlockScan<OffsetValuePairT, BLOCK_THREADS,
cub::BLOCK_SCAN_WARP_SCANS>
BlockScanT;
union TempStorage {
typename BlockLoadT::TempStorage load;
typename BlockRadixSortT::TempStorage sort;
typename BlockScanT::TempStorage scan;
typename BlockDiscontinuityKeysT::TempStorage disc;
};
__shared__ TempStorage temp_storage;
int ridx[ITEMS_PER_THREAD];
int gidx[ITEMS_PER_THREAD];
const int TILE_SIZE = ITEMS_PER_THREAD * BLOCK_THREADS;
int block_offset = blockIdx.x * TILE_SIZE;
BlockLoadT(temp_storage.load)
.Load(d_ridx + block_offset, ridx, n - block_offset, -1);
BlockLoadT(temp_storage.load)
.Load(d_gidx + block_offset, gidx, n - block_offset, -1);
int hist_idx[ITEMS_PER_THREAD];
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
if (ridx[ITEM] > -1 && d_position[ridx[ITEM]] > -1) {
hist_idx[ITEM] =
(d_position[ridx[ITEM]] - n_nodes(depth - 1)) * n_bins + gidx[ITEM];
} else {
hist_idx[ITEM] = -1;
}
}
__syncthreads();
BlockRadixSortT(temp_storage.sort).Sort(hist_idx, ridx);
OffsetValuePairT kv[ITEMS_PER_THREAD];
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
kv[ITEM].key = hist_idx[ITEM];
if (ridx[ITEM] > -1) {
kv[ITEM].value = d_gpair[ridx[ITEM]];
}
}
__syncthreads();
// Scan
BlockScanT(temp_storage.scan).InclusiveScan(kv, kv, ReduceBySegmentOpT());
__syncthreads();
int flags[ITEMS_PER_THREAD];
BlockDiscontinuityKeysT(temp_storage.disc)
.FlagTails(flags, hist_idx, cub::Inequality());
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
if (flags[ITEM]) {
if (ridx[ITEM] > -1 && d_position[ridx[ITEM]] > -1) {
atomicAdd(&(d_dense_hist[hist_idx[ITEM]]._grad), kv[ITEM].value._grad);
atomicAdd(&(d_dense_hist[hist_idx[ITEM]]._hess), kv[ITEM].value._hess);
}
}
}
}
void GPUHistBuilder::BuildHist(int depth) {
auto d_ridx = device_matrix.ridx.data();
auto d_gidx = device_matrix.gidx.data();
auto d_position = position.data();
auto d_gpair = device_gpair.data();
auto hist_builder = hist.GetBuilder();
dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) {
int ridx = d_ridx[idx];
int pos = d_position[ridx];
// Only increment even nodes
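// (the root and right children); left-child histograms are derived
// afterwards by the subtraction trick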
if (pos < 0 || pos % 2 == 1) return;
int gidx = d_gidx[idx];
gpu_gpair gpair = d_gpair[ridx];
hist_builder.Add(gpair, gidx, pos);
});
dh::safe_cuda(cudaDeviceSynchronize());
// Subtraction trick
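// Each left-child histogram is computed as its parent's histogram minus the
// right sibling's histogram, so only half of the histograms on a level are
// built with atomics.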
int n_sub_bins = (n_nodes_level(depth) / 2) * hist_builder.n_bins;
if (n_sub_bins > 0) {
dh::launch_n(n_sub_bins, [=] __device__(int idx) {
int nidx = n_nodes(depth - 1) + ((idx / hist_builder.n_bins) * 2);
int gidx = idx % hist_builder.n_bins;
gpu_gpair parent = hist_builder.Get(gidx, nidx / 2);
gpu_gpair right = hist_builder.Get(gidx, nidx + 1);
hist_builder.Add(parent - right, gidx, nidx);
});
}
dh::safe_cuda(cudaDeviceSynchronize());
}
void GPUHistBuilder::FindSplit(int depth) {
auto counting = thrust::make_counting_iterator(0);
auto d_gidx_feature_map = gidx_feature_map.data();
int n_bins = hmat_.row_ptr.back();
int n_features = hmat_.row_ptr.size() - 1;
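// Scan segments end at feature boundaries: two histogram positions share a
// segment only if their bins map to the same feature.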
auto feature_boundary = [=] __device__(int idx_a, int idx_b) {
int gidx_a = idx_a % n_bins;
int gidx_b = idx_b % n_bins;
return d_gidx_feature_map[gidx_a] == d_gidx_feature_map[gidx_b];
}; // NOLINT
// Reduce node sums
{
size_t temp_storage_bytes;
cub::DeviceSegmentedReduce::Reduce(
nullptr, temp_storage_bytes, hist.GetLevelPtr(depth), node_sums.data(),
n_nodes_level(depth) * n_features, feature_segments.data(),
feature_segments.data() + 1, cub::Sum(), gpu_gpair());
cub_mem.LazyAllocate(temp_storage_bytes);
cub::DeviceSegmentedReduce::Reduce(
cub_mem.d_temp_storage, cub_mem.temp_storage_bytes,
hist.GetLevelPtr(depth), node_sums.data(),
n_nodes_level(depth) * n_features, feature_segments.data(),
feature_segments.data() + 1, cub::Sum(), gpu_gpair());
}
// Scan
thrust::exclusive_scan_by_key(
counting, counting + hist.LevelSize(depth),
thrust::device_pointer_cast(hist.GetLevelPtr(depth)), hist_scan.tbegin(),
gpu_gpair(), feature_boundary);
// Calculate gain
auto d_gain = gain.data();
auto d_nodes = nodes.data();
auto d_node_sums = node_sums.data();
auto d_hist_scan = hist_scan.data();
GPUTrainingParam gpu_param_alias =
gpu_param; // Must be local variable to be used in device lambda
bool colsample =
param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
auto d_feature_flags = feature_flags.data();
dh::launch_n(hist.LevelSize(depth), [=] __device__(int idx) {
int node_segment = idx / n_bins;
int node_idx = n_nodes(depth - 1) + node_segment;
gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
float parent_gain = d_nodes[node_idx].root_gain;
int gidx = idx % n_bins;
int findex = d_gidx_feature_map[gidx];
// colsample
if (colsample && d_feature_flags[d_gidx_feature_map[gidx]] == 0) {
d_gain[idx] = 0;
} else {
gpu_gpair scan = d_hist_scan[idx];
gpu_gpair sum = d_node_sums[node_segment * n_features + findex];
gpu_gpair missing = parent_sum - sum;
bool missing_left;
d_gain[idx] = loss_chg_missing(scan, missing, parent_sum, parent_gain,
gpu_param_alias, missing_left);
}
});
dh::safe_cuda(cudaDeviceSynchronize());
// Find best gain
{
size_t temp_storage_bytes;
cub::DeviceSegmentedReduce::ArgMax(nullptr, temp_storage_bytes, gain.data(),
argmax.data(), n_nodes_level(depth),
hist_node_segments.data(),
hist_node_segments.data() + 1);
cub_mem.LazyAllocate(temp_storage_bytes);
cub::DeviceSegmentedReduce::ArgMax(
cub_mem.d_temp_storage, cub_mem.temp_storage_bytes, gain.data(),
argmax.data(), n_nodes_level(depth), hist_node_segments.data(),
hist_node_segments.data() + 1);
}
auto d_argmax = argmax.data();
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
auto d_fidx_min_map = fidx_min_map.data();
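// One thread per node on this level: decode the arg-max bin into a split
// (feature, value, default direction) and initialise the two child nodes.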
dh::launch_n(n_nodes_level(depth), [=] __device__(int idx) {
int max_idx = n_bins * idx + d_argmax[idx].key;
int gidx = max_idx % n_bins;
int fidx = d_gidx_feature_map[gidx];
int node_segment = max_idx / n_bins;
int node_idx = n_nodes(depth - 1) + node_segment;
gpu_gpair scan = d_hist_scan[max_idx];
gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
float parent_gain = d_nodes[node_idx].root_gain;
gpu_gpair sum = d_node_sums[node_segment * n_features + fidx];
gpu_gpair missing = parent_sum - sum;
bool missing_left;
float loss_chg = loss_chg_missing(scan, missing, parent_sum, parent_gain,
gpu_param_alias, missing_left);
float fvalue;
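// Split at the lower boundary of the chosen bin: the feature minimum for the
// feature's first bin, otherwise the previous bin's cut value.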
if (gidx == 0 || fidx != d_gidx_feature_map[gidx - 1]) {
fvalue = d_fidx_min_map[fidx];
} else {
fvalue = d_gidx_fvalue_map[gidx - 1];
}
gpu_gpair left = missing_left ? scan + missing : scan;
gpu_gpair right = parent_sum - left;
d_nodes[node_idx].split.Update(loss_chg, missing_left, fvalue, fidx, left,
right, gpu_param_alias);
int left_child_idx = n_nodes(depth) + idx * 2;
int right_child_idx = n_nodes(depth) + idx * 2 + 1;
d_nodes[left_child_idx] =
Node(left, CalcGain(gpu_param_alias, left.grad(), left.hess()),
CalcWeight(gpu_param_alias, left.grad(), left.hess()));
d_nodes[right_child_idx] =
Node(right, CalcGain(gpu_param_alias, right.grad(), right.hess()),
CalcWeight(gpu_param_alias, right.grad(), right.hess()));
});
dh::safe_cuda(cudaDeviceSynchronize());
}
void GPUHistBuilder::InitFirstNode() {
// Build the root node on the CPU and copy to device
gpu_gpair sum_gradients =
thrust::reduce(device_gpair.tbegin(), device_gpair.tend(),
gpu_gpair(0, 0), thrust::plus<gpu_gpair>());
Node tmp =
Node(sum_gradients,
CalcGain(param, sum_gradients.grad(), sum_gradients.hess()),
CalcWeight(param, sum_gradients.grad(), sum_gradients.hess()));
thrust::copy_n(&tmp, 1, nodes.tbegin());
}
void GPUHistBuilder::UpdatePosition() {
if (is_dense) {
this->UpdatePositionDense();
} else {
this->UpdatePositionSparse();
}
}
void GPUHistBuilder::UpdatePositionDense() {
auto d_position = position.data();
Node* d_nodes = nodes.data();
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
auto d_gidx = device_matrix.gidx.data();
int n_columns = info->num_col;
dh::launch_n(position.size(), [=] __device__(int idx) {
NodeIdT pos = d_position[idx];
if (pos < 0) {
return;
}
Node node = d_nodes[pos];
if (node.IsLeaf()) {
d_position[idx] = -1;
return;
}
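// Dense data: every row stores exactly num_col bin indices, so the bin for
// the split feature can be indexed directly.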
int gidx = d_gidx[idx * n_columns + node.split.findex];
float fvalue = d_gidx_fvalue_map[gidx];
if (fvalue <= node.split.fvalue) {
d_position[idx] = pos * 2 + 1;
} else {
d_position[idx] = pos * 2 + 2;
}
});
}
void GPUHistBuilder::UpdatePositionSparse() {
auto d_position = position.data();
auto d_position_tmp = position_tmp.data();
Node* d_nodes = nodes.data();
auto d_gidx_feature_map = gidx_feature_map.data();
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
auto d_gidx = device_matrix.gidx.data();
auto d_ridx = device_matrix.ridx.data();
// Update missing direction
dh::launch_n(position.size(), [=] __device__(int idx) {
NodeIdT pos = d_position[idx];
if (pos < 0) {
return;
}
Node node = d_nodes[pos];
if (node.IsLeaf()) {
d_position_tmp[idx] = -1;
} else if (node.split.missing_left) {
d_position_tmp[idx] = pos * 2 + 1;
} else {
d_position_tmp[idx] = pos * 2 + 2;
}
});
// Update node based on fvalue where exists
dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) {
int ridx = d_ridx[idx];
NodeIdT pos = d_position[ridx];
if (pos < 0) {
return;
}
Node node = d_nodes[pos];
if (node.IsLeaf()) {
return;
}
int gidx = d_gidx[idx];
int findex = d_gidx_feature_map[gidx];
if (findex == node.split.findex) {
float fvalue = d_gidx_fvalue_map[gidx];
if (fvalue <= node.split.fvalue) {
d_position_tmp[ridx] = pos * 2 + 1;
} else {
d_position_tmp[ridx] = pos * 2 + 2;
}
}
});
position = position_tmp;
}
void GPUHistBuilder::ColSampleTree() {
feature_set_tree.resize(info->num_col);
std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree);
}
void GPUHistBuilder::ColSampleLevel() {
feature_set_level.resize(feature_set_tree.size());
feature_set_level = col_sample(feature_set_tree, param.colsample_bylevel);
std::vector<int> h_feature_flags(info->num_col, 0);
for (auto fidx : feature_set_level) {
h_feature_flags[fidx] = 1;
}
feature_flags = h_feature_flags;
}
void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
DMatrix& fmat, // NOLINT
const RegTree& tree) {
if (!initialised) {
CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column "
"block. Try setting 'tree_method' "
"parameter to 'exact'";
info = &fmat.info();
is_dense = info->num_nonzero == info->num_col * info->num_row;
hmat_.Init(&fmat, param.max_bin);
gmat_.cut = &hmat_;
gmat_.Init(&fmat);
int n_bins = hmat_.row_ptr.back();
int n_features = hmat_.row_ptr.size() - 1;
// Build feature segments
std::vector<int> h_feature_segments;
for (int node = 0; node < n_nodes_level(param.max_depth - 1); node++) {
for (int fidx = 0; fidx < hmat_.row_ptr.size() - 1; fidx++) {
h_feature_segments.push_back(hmat_.row_ptr[fidx] + node * n_bins);
}
}
h_feature_segments.push_back(n_nodes_level(param.max_depth - 1) * n_bins);
int level_max_bins = n_nodes_level(param.max_depth - 1) * n_bins;
size_t free_memory = dh::available_memory();
ba.allocate(
&gidx_feature_map, n_bins, &hist_node_segments,
n_nodes_level(param.max_depth - 1) + 1, &feature_segments,
h_feature_segments.size(), &gain, level_max_bins, &position,
gpair.size(), &position_tmp, gpair.size(), &nodes,
n_nodes(param.max_depth), &gidx_fvalue_map, hmat_.cut.size(),
&fidx_min_map, hmat_.min_val.size(), &argmax,
n_nodes_level(param.max_depth - 1), &node_sums,
n_nodes_level(param.max_depth - 1) * n_features, &hist_scan,
level_max_bins, &device_gpair, gpair.size(), &device_matrix.gidx,
gmat_.index.size(), &device_matrix.ridx, gmat_.index.size(), &hist.hist,
n_nodes(param.max_depth - 1) * n_bins, &feature_flags, n_features);
if (!param.silent) {
const int mb_size = 1048576;
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
<< free_memory / mb_size << " MB on " << dh::device_name();
}
// Construct feature map
std::vector<int> h_gidx_feature_map(n_bins);
for (int row = 0; row < hmat_.row_ptr.size() - 1; row++) {
for (int i = hmat_.row_ptr[row]; i < hmat_.row_ptr[row + 1]; i++) {
h_gidx_feature_map[i] = row;
}
}
gidx_feature_map = h_gidx_feature_map;
// Construct device matrix
device_matrix.Init(gmat_);
gidx_fvalue_map = hmat_.cut;
fidx_min_map = hmat_.min_val;
thrust::sequence(hist_node_segments.tbegin(), hist_node_segments.tend(), 0,
n_bins);
feature_segments = h_feature_segments;
hist.Init(n_bins);
initialised = true;
}
nodes.fill(Node());
position.fill(0);
device_gpair = gpair;
subsample_gpair(&device_gpair, param.subsample);
hist.Reset();
}
void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
DMatrix* p_fmat, RegTree* p_tree) {
this->InitData(gpair, *p_fmat, *p_tree);
this->InitFirstNode();
this->ColSampleTree();
for (int depth = 0; depth < param.max_depth; depth++) {
this->ColSampleLevel();
this->BuildHist(depth);
this->FindSplit(depth);
this->UpdatePosition();
}
dense2sparse_tree(p_tree, nodes.tbegin(), nodes.tend(), param);
}
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,104 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include <cusparse.h>
#include <thrust/device_vector.h>
#include <xgboost/tree_updater.h>
#include <cub/util_type.cuh> // Need key value pair definition
#include <vector>
#include "../../src/common/hist_util.h"
#include "../../src/tree/param.h"
#include "device_helpers.cuh"
#include "types.cuh"
namespace xgboost {
namespace tree {
struct DeviceGMat {
dh::dvec<int> gidx;
dh::dvec<int> ridx;
void Init(const common::GHistIndexMatrix &gmat);
};
struct HistBuilder {
gpu_gpair *d_hist;
int n_bins;
__host__ __device__ HistBuilder(gpu_gpair *ptr, int n_bins);
__device__ void Add(gpu_gpair gpair, int gidx, int nidx) const;
__device__ gpu_gpair Get(int gidx, int nidx) const;
};
struct DeviceHist {
int n_bins;
dh::dvec<gpu_gpair> hist;
void Init(int max_depth);
void Reset();
HistBuilder GetBuilder();
gpu_gpair *GetLevelPtr(int depth);
int LevelSize(int depth);
};
class GPUHistBuilder {
public:
GPUHistBuilder();
~GPUHistBuilder();
void Init(const TrainParam &param);
void UpdateParam(const TrainParam &param) {
this->param = param;
this->gpu_param = GPUTrainingParam(param.min_child_weight, param.reg_lambda,
param.reg_alpha, param.max_delta_step);
}
void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
const RegTree &tree);
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
RegTree *p_tree);
void BuildHist(int depth);
void FindSplit(int depth);
void InitFirstNode();
void UpdatePosition();
void UpdatePositionDense();
void UpdatePositionSparse();
void ColSampleTree();
void ColSampleLevel();
TrainParam param;
GPUTrainingParam gpu_param;
common::HistCutMatrix hmat_;
common::GHistIndexMatrix gmat_;
MetaInfo *info;
bool initialised;
bool is_dense;
DeviceGMat device_matrix;
dh::bulk_allocator ba;
dh::CubMemory cub_mem;
dh::dvec<int> gidx_feature_map;
dh::dvec<int> hist_node_segments;
dh::dvec<int> feature_segments;
dh::dvec<float> gain;
dh::dvec<NodeIdT> position;
dh::dvec<NodeIdT> position_tmp;
dh::dvec<float> gidx_fvalue_map;
dh::dvec<float> fidx_min_map;
DeviceHist hist;
dh::dvec<cub::KeyValuePair<int, float>> argmax;
dh::dvec<gpu_gpair> node_sums;
dh::dvec<gpu_gpair> hist_scan;
dh::dvec<gpu_gpair> device_gpair;
dh::dvec<Node> nodes;
dh::dvec<int> feature_flags;
std::vector<int> feature_set_tree;
std::vector<int> feature_set_level;
};
} // namespace tree
} // namespace xgboost
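
The header above only declares the data structures; as a rough CPU analogue of what HistBuilder::Add accumulates (an illustrative sketch, not the plugin's CUDA kernel, and build_node_histograms is a hypothetical name), each instance contributes its gradient pair to one slot per (node at the current level, quantised feature bin):

import numpy as np

# Hypothetical CPU analogue of HistBuilder::Add: accumulate (grad, hess)
# into one histogram slot per (node at the current level, feature bin).
def build_node_histograms(gidx, nidx, grad, hess, n_nodes, n_bins):
    hist = np.zeros((n_nodes, n_bins, 2))
    for g, n, gr, h in zip(gidx, nidx, grad, hess):
        hist[n, g, 0] += gr   # sum of gradients in this bin/node
        hist[n, g, 1] += h    # sum of hessians in this bin/node
    return hist

The dvec members hist_scan and argmax in GPUHistBuilder suggest that FindSplit then scans each node's histogram per feature and takes an arg-max over candidate bins, but that logic lives in the .cu file rather than this header.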

View File

@ -1,52 +0,0 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include "types.cuh"
#include "../../../src/tree/param.h"
// When we split on a value which has no left neighbour, define its left
// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS
// This produces a split value slightly lower than the current instance
#define FVALUE_EPS 0.0001
namespace xgboost {
namespace tree {
__device__ __forceinline__ float
device_calc_loss_chg(const GPUTrainingParam &param, const gpu_gpair &scan,
const gpu_gpair &missing, const gpu_gpair &parent_sum,
const float &parent_gain, bool missing_left) {
gpu_gpair left = scan;
if (missing_left) {
left += missing;
}
gpu_gpair right = parent_sum - left;
float left_gain = CalcGain(param, left.grad(), left.hess());
float right_gain = CalcGain(param, right.grad(), right.hess());
return left_gain + right_gain - parent_gain;
}
__device__ __forceinline__ float
loss_chg_missing(const gpu_gpair &scan, const gpu_gpair &missing,
const gpu_gpair &parent_sum, const float &parent_gain,
const GPUTrainingParam &param, bool &missing_left_out) { // NOLINT
float missing_left_loss =
device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true);
float missing_right_loss = device_calc_loss_chg(
param, scan, missing, parent_sum, parent_gain, false);
if (missing_left_loss >= missing_right_loss) {
missing_left_out = true;
return missing_left_loss;
} else {
missing_left_out = false;
return missing_right_loss;
}
}
} // namespace tree
} // namespace xgboost
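
For reference, device_calc_loss_chg above computes the usual second-order split gain. Writing (G_L, H_L) for the left sum (the scan plus the missing-value sum when missing_left is true) and (G_R, H_R) = parent_sum - left, and assuming CalcGain reduces to G^2/(H + lambda) when reg_alpha and max_delta_step are zero, the returned change in loss is

\mathrm{loss\_chg} = \frac{G_L^2}{H_L + \lambda} + \frac{G_R^2}{H_R + \lambda} - \frac{(G_L + G_R)^2}{H_L + H_R + \lambda}

loss_chg_missing simply evaluates this twice, once with the missing sum folded into the left child and once into the right, and keeps the larger value, reporting the chosen direction through missing_left_out.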

View File

@ -1,8 +1,11 @@
/*!
 * Copyright 2016 Rory mitchell
 */
#pragma once
#include <thrust/device_vector.h>
#include <xgboost/base.h>
#include <xgboost/tree_model.h>
#include <cfloat>
#include <tuple>  // The linter is not very smart and thinks we need this
namespace xgboost {
@ -104,8 +107,10 @@ struct GPUTrainingParam {
  __host__ __device__ GPUTrainingParam(float min_child_weight_in,
                                       float reg_lambda_in, float reg_alpha_in,
                                       float max_delta_step_in)
      : min_child_weight(min_child_weight_in),
        reg_lambda(reg_lambda_in),
        reg_alpha(reg_alpha_in),
        max_delta_step(max_delta_step_in) {}
};
struct Split {
@ -123,8 +128,9 @@ struct Split {
                float fvalue_in, int findex_in, gpu_gpair left_sum_in,
                gpu_gpair right_sum_in,
                const GPUTrainingParam &param) {
    if (loss_chg_in > loss_chg &&
        left_sum_in.hess() >= param.min_child_weight &&
        right_sum_in.hess() >= param.min_child_weight) {
      loss_chg = loss_chg_in;
      missing_left = missing_left_in;
      fvalue = fvalue_in;
@ -135,7 +141,7 @@ struct Split {
  }
  // Does not check minimum weight
  __device__ void Update(Split &s) {  // NOLINT
    if (s.loss_chg > loss_chg) {
      loss_chg = s.loss_chg;
      missing_left = s.missing_left;
@ -146,7 +152,7 @@ struct Split {
    }
  }
  __host__ __device__ void Print() {
    printf("Loss: %1.4f\n", loss_chg);
    printf("Missing left: %d\n", missing_left);
    printf("fvalue: %1.4f\n", fvalue);
@ -160,7 +166,7 @@ struct Split {
struct split_reduce_op {
  template <typename T>
  __device__ __forceinline__ T operator()(T &a, T b) {  // NOLINT
    b.Update(a);
    return b;
  }
@ -184,5 +190,6 @@ struct Node {
  __host__ __device__ bool IsLeaf() { return split.loss_chg == -FLT_MAX; }
};
}  // namespace tree
}  // namespace xgboost

View File

@ -1,6 +0,0 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include "types.cuh"
#include "loss_functions.cuh"

View File

@ -7,30 +7,37 @@
#include "../../src/common/sync.h" #include "../../src/common/sync.h"
#include "../../src/tree/param.h" #include "../../src/tree/param.h"
#include "gpu_builder.cuh" #include "gpu_builder.cuh"
#include "gpu_hist_builder.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpumaker); DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
/*! \brief column-wise update to construct a tree */ /*! \brief column-wise update to construct a tree */
template <typename TStats> class GPUMaker : public TreeUpdater { template <typename TStats>
class GPUMaker : public TreeUpdater {
public: public:
void void Init(
Init(const std::vector<std::pair<std::string, std::string>> &args) override { const std::vector<std::pair<std::string, std::string>>& args) override {
param.InitAllowUnknown(args); param.InitAllowUnknown(args);
builder.Init(param); builder.Init(param);
} }
void Update(const std::vector<bst_gpair> &gpair, DMatrix *dmat, void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree *> &trees) override { const std::vector<RegTree*>& trees) override {
TStats::CheckInfo(dmat->info()); TStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees // rescale learning rate according to size of trees
float lr = param.learning_rate; float lr = param.learning_rate;
param.learning_rate = lr / trees.size(); param.learning_rate = lr / trees.size();
builder.UpdateParam(param); builder.UpdateParam(param);
// build tree
for (size_t i = 0; i < trees.size(); ++i) { try {
builder.Update(gpair, dmat, trees[i]); // build tree
for (size_t i = 0; i < trees.size(); ++i) {
builder.Update(gpair, dmat, trees[i]);
}
} catch (const std::exception& e) {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
} }
param.learning_rate = lr; param.learning_rate = lr;
} }
@ -41,9 +48,45 @@ template <typename TStats> class GPUMaker : public TreeUpdater {
GPUBuilder builder; GPUBuilder builder;
}; };
template <typename TStats>
class GPUHistMaker : public TreeUpdater {
public:
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override {
param.InitAllowUnknown(args);
builder.Init(param);
}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
TStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
builder.UpdateParam(param);
// build tree
try {
for (size_t i = 0; i < trees.size(); ++i) {
builder.Update(gpair, dmat, trees[i]);
}
} catch (const std::exception& e) {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
}
protected:
// training parameter
TrainParam param;
GPUHistBuilder builder;
};
XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu") XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
.describe("Grow tree with GPU.") .describe("Grow tree with GPU.")
.set_body([]() { return new GPUMaker<GradStats>(); }); .set_body([]() { return new GPUMaker<GradStats>(); });
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUHistMaker<GradStats>(); });
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -1,137 +0,0 @@
#pylint: skip-file
import numpy as np
import xgboost as xgb
import os
import pandas as pd
import urllib2
class bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
def get_last_eval_callback(result):
def callback(env):
result.append(env.evaluation_result_list[-1][1])
callback.after_iteration = True
return callback
def load_adult():
path = "../../demo/data/adult.data"
if(not os.path.isfile(path)):
data = urllib2.urlopen('http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data')
with open(path,'wb') as output:
output.write(data.read())
train_set = pd.read_csv( path, header=None)
train_set.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation',
'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
'wage_class']
train_nomissing = train_set.replace(' ?', np.nan).dropna()
for feature in train_nomissing.columns: # Loop through all columns in the dataframe
if train_nomissing[feature].dtype == 'object': # Only apply for columns with categorical strings
train_nomissing[feature] = pd.Categorical(train_nomissing[feature]).codes # Replace strings with an integer
y_train = train_nomissing.pop('wage_class')
return xgb.DMatrix( train_nomissing, label=y_train)
def load_higgs():
higgs_path = '../../demo/data/training.csv'
dtrain = np.loadtxt(higgs_path, delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s'.encode('utf-8')) } )
#dtrain = dtrain[0:200000,:]
label = dtrain[:,32]
data = dtrain[:,1:31]
weight = dtrain[:,31]
return xgb.DMatrix( data, label=label, missing = -999.0, weight=weight )
def load_dermatology():
data = np.loadtxt('../../demo/data/dermatology.data', delimiter=',',converters={33: lambda x:int(x == '?'), 34: lambda x:int(x)-1 } )
sz = data.shape
X = data[:,0:33]
Y = data[:, 34]
return xgb.DMatrix( X, label=Y)
def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
#Check GPU test evaluation is approximately equal to CPU test evaluation
def check_result(cpu_result, gpu_result):
for i in range(len(cpu_result)):
if not isclose(cpu_result[i], gpu_result[i], 0.1, 0.02):
return False
return True
#Get data
data = []
params = []
data.append(load_higgs())
params.append({})
data.append( load_adult())
params.append({})
data.append(xgb.DMatrix('../../demo/data/agaricus.txt.test'))
params.append({'objective':'binary:logistic'})
#if(os.path.isfile("../../demo/data/dermatology.data")):
data.append(load_dermatology())
params.append({'objective':'multi:softmax', 'num_class': 6})
num_round = 5
num_pass = 0
num_fail = 0
test_depth = [ 1, 6, 9, 11, 15 ]
#test_depth = [ 1 ]
for test in range(0, len(data)):
for depth in test_depth:
xgmat = data[test]
cpu_result = []
param = params[test]
param['max_depth'] = depth
param['updater'] = 'grow_colmaker'
xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(cpu_result)])
#bst = xgb.train( param, xgmat, 1);
#bst.dump_model('reference_model.txt','', True)
gpu_result = []
param['updater'] = 'grow_gpu'
xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(gpu_result)])
#bst = xgb.train( param, xgmat, 1);
#bst.dump_model('dump.raw.txt','', True)
if check_result(cpu_result, gpu_result):
print(bcolors.OKGREEN + "Pass" + bcolors.ENDC)
num_pass = num_pass + 1
else:
print(bcolors.FAIL + "Fail" + bcolors.ENDC)
num_fail = num_fail + 1
print("cpu rmse: "+str(cpu_result))
print("gpu rmse: "+str(gpu_result))
print(str(num_pass)+"/"+str(num_pass + num_fail)+" passed")

View File

@ -0,0 +1,221 @@
#pylint: skip-file
import sys
sys.path.append("../../tests/python")
import xgboost as xgb
import testing as tm
import numpy as np
import unittest
rng = np.random.RandomState(1994)
dpath = '../../demo/data/'
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
class TestGPU(unittest.TestCase):
def test_grow_gpu(self):
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
except:
from sklearn.cross_validation import train_test_split
ag_param = {'max_depth': 2,
'tree_method': 'exact',
'nthread': 1,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': 2,
'updater': 'grow_gpu',
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_res = {}
ag_res2 = {}
num_rounds = 10
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res)
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res2)
assert ag_res['train']['auc'] == ag_res2['train']['auc']
assert ag_res['test']['auc'] == ag_res2['test']['auc']
digits = load_digits(2)
X = digits['data']
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu',
'max_depth': 3,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert self.non_decreasing(res['test']['auc'])
# fail-safe test for dense data
from sklearn.datasets import load_svmlight_file
X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
X2 = X2.toarray()
dtrain2 = xgb.DMatrix(X2, label=y2)
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu',
'max_depth': 2,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in rng.choice(X2.shape[0], size=10, replace=False):
X2[i, j] = 2
dtrain3 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in np.random.choice(X2.shape[0], size=10, replace=False):
X2[i, j] = 3
dtrain4 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
def test_grow_gpu_hist(self):
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
except:
from sklearn.cross_validation import train_test_split
# regression test --- hist must be same as exact on all-categorical data
ag_param = {'max_depth': 2,
'tree_method': 'exact',
'nthread': 1,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': 2,
'updater': 'grow_gpu_hist',
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_res = {}
ag_res2 = {}
num_rounds = 10
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res)
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res2)
assert ag_res['train']['auc'] == ag_res2['train']['auc']
assert ag_res['test']['auc'] == ag_res2['test']['auc']
digits = load_digits(2)
X = digits['data']
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu_hist',
'max_depth': 3,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert self.non_decreasing(res['test']['auc'])
# fail-safe test for dense data
from sklearn.datasets import load_svmlight_file
X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
X2 = X2.toarray()
dtrain2 = xgb.DMatrix(X2, label=y2)
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu_hist',
'grow_policy': 'depthwise',
'max_depth': 2,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in rng.choice(X2.shape[0], size=10, replace=False):
X2[i, j] = 2
dtrain3 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in np.random.choice(X2.shape[0], size=10, replace=False):
X2[i, j] = 3
dtrain4 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
# fail-safe test for max_bin=2
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu_hist',
'max_depth': 2,
'eval_metric': 'auc',
'max_bin': 2}
res = {}
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
# subsampling
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu_hist',
'max_depth': 3,
'eval_metric': 'auc',
'colsample_bytree': 0.5,
'colsample_bylevel': 0.5,
'subsample': 0.5
}
res = {}
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
def non_decreasing(self, L):
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))

View File

@ -79,6 +79,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
  bool refresh_leaf;
  // auxiliary data structure
  std::vector<int> monotone_constraints;
  // gpu to use for single gpu algorithms
  int gpu_id;
  // declare the parameters
  DMLC_DECLARE_PARAMETER(TrainParam) {
    DMLC_DECLARE_FIELD(learning_rate)
@ -186,6 +188,10 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
    DMLC_DECLARE_FIELD(monotone_constraints)
        .set_default(std::vector<int>())
        .describe("Constraint of variable monotonicity");
    DMLC_DECLARE_FIELD(gpu_id)
        .set_lower_bound(0)
        .set_default(0)
        .describe("gpu to use for single gpu algorithms");
    // add alias of parameters
    DMLC_DECLARE_ALIAS(reg_lambda, lambda);
    DMLC_DECLARE_ALIAS(reg_alpha, alpha);
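
Together with the grow_gpu_hist updater registered earlier in this commit, the new gpu_id parameter can be set from the Python API. A minimal sketch, with the data path and parameter values chosen for illustration (they come from the bundled demo data, not from anything this change mandates):

import xgboost as xgb

dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
param = {'objective': 'binary:logistic',
         'updater': 'grow_gpu_hist',   # fast histogram algorithm on the GPU
         'gpu_id': 0,                  # which device to use (default 0)
         'max_depth': 3,
         'eval_metric': 'auc'}
bst = xgb.train(param, dtrain, num_boost_round=10,
                evals=[(dtrain, 'train')])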