[GPU-Plugin] (#2227)
* Add fast histogram algorithm * Fix Linux build * Add 'gpu_id' parameter
This commit is contained in:
parent
d281c6aafa
commit
8ab5d4611c
@ -97,7 +97,7 @@ if(PLUGIN_UPDATER_GPU)
|
|||||||
#Find cub
|
#Find cub
|
||||||
set(CUB_DIRECTORY "" CACHE PATH "CUB 1.5.4 directory")
|
set(CUB_DIRECTORY "" CACHE PATH "CUB 1.5.4 directory")
|
||||||
include_directories(${CUB_DIRECTORY})
|
include_directories(${CUB_DIRECTORY})
|
||||||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35")
|
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35;")
|
||||||
if(NOT MSVC)
|
if(NOT MSVC)
|
||||||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC")
|
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC")
|
||||||
endif()
|
endif()
|
||||||
@ -105,8 +105,9 @@ if(PLUGIN_UPDATER_GPU)
|
|||||||
plugin/updater_gpu/src/updater_gpu.cc
|
plugin/updater_gpu/src/updater_gpu.cc
|
||||||
)
|
)
|
||||||
find_package(CUDA QUIET REQUIRED)
|
find_package(CUDA QUIET REQUIRED)
|
||||||
|
file(GLOB_RECURSE CUDA_SOURCES "plugin/updater_gpu/src/*")
|
||||||
cuda_add_library(updater_gpu STATIC
|
cuda_add_library(updater_gpu STATIC
|
||||||
plugin/updater_gpu/src/gpu_builder.cu
|
${CUDA_SOURCES}
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|||||||
@ -63,3 +63,5 @@ List of Contributors
|
|||||||
* [Alex Bain](https://github.com/convexquad)
|
* [Alex Bain](https://github.com/convexquad)
|
||||||
* [Baltazar Bieniek](https://github.com/bbieniek)
|
* [Baltazar Bieniek](https://github.com/bbieniek)
|
||||||
* [Adam Pocock](https://github.com/Craigacp)
|
* [Adam Pocock](https://github.com/Craigacp)
|
||||||
|
* [Rory Mitchell](https://github.com/RAMitchell)
|
||||||
|
- Rory is the author of the GPU plugin and also contributed the CMake build system and Windows continuous integration.
|
||||||
|
|||||||
@ -21,6 +21,7 @@ The same code runs on major distributed environment (Hadoop, SGE, MPI) and can s
|
|||||||
|
|
||||||
What's New
|
What's New
|
||||||
----------
|
----------
|
||||||
|
* [XGBoost GPU support with fast histogram algorithm](https://github.com/dmlc/xgboost/tree/master/plugin/updater_gpu)
|
||||||
* [XGBoost4J: Portable Distributed XGboost in Spark, Flink and Dataflow](http://dmlc.ml/2016/03/14/xgboost4j-portable-distributed-xgboost-in-spark-flink-and-dataflow.html), see [JVM-Package](https://github.com/dmlc/xgboost/tree/master/jvm-packages)
|
* [XGBoost4J: Portable Distributed XGboost in Spark, Flink and Dataflow](http://dmlc.ml/2016/03/14/xgboost4j-portable-distributed-xgboost-in-spark-flink-and-dataflow.html), see [JVM-Package](https://github.com/dmlc/xgboost/tree/master/jvm-packages)
|
||||||
* [Story and Lessons Behind the Evolution of XGBoost](http://homes.cs.washington.edu/~tqchen/2016/03/10/story-and-lessons-behind-the-evolution-of-xgboost.html)
|
* [Story and Lessons Behind the Evolution of XGBoost](http://homes.cs.washington.edu/~tqchen/2016/03/10/story-and-lessons-behind-the-evolution-of-xgboost.html)
|
||||||
* [Tutorial: Distributed XGBoost on AWS with YARN](https://xgboost.readthedocs.io/en/latest/tutorials/aws_yarn.html)
|
* [Tutorial: Distributed XGBoost on AWS with YARN](https://xgboost.readthedocs.io/en/latest/tutorials/aws_yarn.html)
|
||||||
|
|||||||
@ -48,7 +48,7 @@
|
|||||||
#define XGBOOST_ALIGNAS(X)
|
#define XGBOOST_ALIGNAS(X)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8
|
#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8 && !defined(__CUDACC__)
|
||||||
#include <parallel/algorithm>
|
#include <parallel/algorithm>
|
||||||
#define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z))
|
#define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z))
|
||||||
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) __gnu_parallel::stable_sort((X), (Y), (Z))
|
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) __gnu_parallel::stable_sort((X), (Y), (Z))
|
||||||
|
|||||||
@ -1,29 +1,25 @@
|
|||||||
# CUDA Accelerated Tree Construction Algorithm
|
# CUDA Accelerated Tree Construction Algorithms
|
||||||
|
This plugin adds GPU accelerated tree construction algorithms to XGBoost.
|
||||||
## Benchmarks
|
|
||||||
|
|
||||||
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for performance benchmarks
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
Specify the updater parameter as 'grow_gpu'.
|
Specify the 'updater' parameter as one of the following algorithms.
|
||||||
|
updater | Description
|
||||||
|
--- | ---
|
||||||
|
grow_gpu | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'grow_gpu_hist'
|
||||||
|
grow_gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Faster and uses considerably less memory. Splits may be less accurate.
|
||||||
|
|
||||||
|
All algorithms currently use only a single GPU. The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0.
|
||||||
|
|
||||||
This plugin currently works with the CLI version and the Python version.
|
This plugin currently works with the CLI version and the Python version.
|
||||||
|
|
||||||
Python example:
|
Python example:
|
||||||
```python
|
```python
|
||||||
|
param['gpu_id'] = 1
|
||||||
param['updater'] = 'grow_gpu'
|
param['updater'] = 'grow_gpu'
|
||||||
```
|
```
|
||||||
|
## Benchmarks
|
||||||
|
|
||||||
## Memory usage
|
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for performance benchmarks of the 'grow_gpu' updater.
|
||||||
Device memory usage can be calculated as approximately:
|
|
||||||
```
|
|
||||||
bytes = (10 x n_rows) + (40 x n_rows x n_columns x column_density) + (64 x max_nodes) + (76 x max_nodes_level x n_columns)
|
|
||||||
```
|
|
||||||
The maximum number of nodes needed for a given tree depth d is 2<sup>d+1</sup> - 1. The maximum number of nodes on any given level is 2<sup>d</sup>.
|
|
||||||
|
|
||||||
Data is stored in a sparse format. For example, missing values produced by one hot encoding are not stored. If a one hot encoding separates a categorical variable into 5 columns the density of these columns is 1/5 = 0.2.
|
|
||||||
|
|
||||||
A 4GB graphics card will process approximately 3.5 million rows of the well known Kaggle higgs dataset.
|
|
||||||
|
|
||||||
## Dependencies
|
## Dependencies
|
||||||
A CUDA capable GPU with compute capability 3.5 or higher (the algorithm depends on shuffle and vote instructions introduced in Kepler).
|
A CUDA capable GPU with compute capability 3.5 or higher (the algorithm depends on shuffle and vote instructions introduced in Kepler).
|
||||||
@ -58,6 +54,15 @@ On Windows cmake will generate an xgboost.sln solution file in the build directo
|
|||||||
|
|
||||||
The build process generates an xgboost library and executable as normal but containing the GPU tree construction algorithm.
|
The build process generates an xgboost library and executable as normal but containing the GPU tree construction algorithm.
|
||||||
|
|
||||||
|
## Changelog
|
||||||
|
##### 2017/4/25
|
||||||
|
* Add fast histogram algorithm
|
||||||
|
* Fix Linux build
|
||||||
|
* Add 'gpu_id' parameter
|
||||||
|
|
||||||
|
## References
|
||||||
|
[Mitchell, Rory, and Eibe Frank. Accelerating the XGBoost algorithm using GPU computing. No. e2911v1. PeerJ Preprints, 2017.](https://peerj.com/preprints/2911/)
|
||||||
|
|
||||||
## Author
|
## Author
|
||||||
Rory Mitchell
|
Rory Mitchell
|
||||||
|
|
||||||
|
|||||||
32
plugin/updater_gpu/benchmark/benchmark.py
Normal file
32
plugin/updater_gpu/benchmark/benchmark.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
#pylint: skip-file
|
||||||
|
import xgboost as xgb
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.datasets import make_classification
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Problem size and boosting budget shared by both timed runs.
n = 1000000
num_rounds = 100

# Synthetic 50-feature binary classification task with a held-out test split.
X, y = make_classification(n, n_features=50, random_state=7)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)

# Parameters for the first run; only 'updater' is changed for the second run.
param = {
    'objective': 'binary:logistic',
    'tree_method': 'exact',
    'updater': 'grow_gpu_hist',
    'max_depth': 8,
    'silent': 1,
    'eval_metric': 'auc',
}
res = {}

# Run 1: train with the 'grow_gpu_hist' updater and report wall-clock time.
tmp = time.time()
xgb.train(param, dtrain, num_rounds,
          [(dtrain, 'train'), (dtest, 'test')],
          evals_result=res)
print("GPU: %s seconds" % (str(time.time() - tmp)))

# Run 2: same data and parameters with the 'grow_fast_histmaker' updater.
tmp = time.time()
param['updater'] = 'grow_fast_histmaker'
xgb.train(param, dtrain, num_rounds,
          [(dtrain, 'train'), (dtest, 'test')],
          evals_result=res)
print("CPU: %s seconds" % (str(time.time() - tmp)))
|
||||||
|
|
||||||
@ -1,65 +0,0 @@
|
|||||||
#!/usr/bin/python
#pylint: skip-file
# Example script: train xgboost on the data in dpath/training.csv, first with
# the default settings and then with the 'grow_gpu' updater, printing the
# wall-clock time of each run.
import numpy as np
import xgboost as xgb
import time

# path to where the data lies
dpath = '../../demo/data'

# load in training data, directly use numpy
# Column 32 is mapped to 1 when it equals 's' and to 0 otherwise. The converter
# accepts both str and bytes because np.loadtxt may hand it either one.
dtrain = np.loadtxt(dpath + '/training.csv', delimiter=',', skiprows=1,
                    converters={32: lambda x: int(x in ('s', b's'))})
# Double the data three times (8x the rows) to get a larger workload.
for _ in range(3):
    dtrain = np.concatenate((dtrain, np.copy(dtrain)))
test_size = len(dtrain)

print(len(dtrain))
print ('finish loading from csv ')

label = dtrain[:, 32]
data = dtrain[:, 1:31]
# rescale weight to make it same as test set
weight = dtrain[:, 31] * float(test_size) / len(label)

sum_wpos = sum(weight[i] for i in range(len(label)) if label[i] == 1.0)
sum_wneg = sum(weight[i] for i in range(len(label)) if label[i] == 0.0)

# print weight statistics
print ('weight statistics: wpos=%g, wneg=%g, ratio=%g' % ( sum_wpos, sum_wneg, sum_wneg/sum_wpos ))

# construct xgboost.DMatrix from numpy array, treat -999.0 as missing value
xgmat = xgb.DMatrix(data, label=label, missing=-999.0, weight=weight)

# setup parameters for xgboost
param = {}
# use logistic regression loss
param['objective'] = 'binary:logitraw'
# scale weight of positive examples
param['scale_pos_weight'] = sum_wneg / sum_wpos
param['bst:eta'] = 0.1
param['max_depth'] = 15
param['eval_metric'] = 'auc'
param['nthread'] = 16

watchlist = [(xgmat, 'train')]
num_round = 10
print ("training xgboost")
threads = [16]
for i in threads:
    param['nthread'] = i
    tmp = time.time()
    # dict.items() is a view on Python 3, so make it a list before appending
    # the second eval_metric entry.
    plst = list(param.items()) + [('eval_metric', 'ams@0.15')]
    bst = xgb.train(plst, xgmat, num_round, watchlist)
    print ("XGBoost with %d thread costs: %s seconds" % (i, str(time.time() - tmp)))

print ("training xgboost - gpu tree construction")
param['updater'] = 'grow_gpu'
tmp = time.time()
plst = list(param.items()) + [('eval_metric', 'ams@0.15')]
bst = xgb.train(plst, xgmat, num_round, watchlist)
print ("XGBoost GPU: %s seconds" % (str(time.time() - tmp)))
print ('finish training')
|
|
||||||
153
plugin/updater_gpu/src/common.cuh
Normal file
153
plugin/updater_gpu/src/common.cuh
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2016 Rory Mitchell
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <vector>
|
||||||
|
#include "../../../src/common/random.h"
|
||||||
|
#include "../../../src/tree/param.h"
|
||||||
|
#include "device_helpers.cuh"
|
||||||
|
#include "types.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
// When we split on a value which has no left neighbour, define its left
|
||||||
|
// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS
|
||||||
|
// This produces a split value slightly lower than the current instance
|
||||||
|
#define FVALUE_EPS 0.0001
|
||||||
|
|
||||||
|
// Loss change of a candidate split.
//   scan         - gradient sum that forms the left child
//   missing      - gradient sum added to the left child when missing_left is
//                  true (it stays on the right otherwise)
//   parent_sum   - gradient sum of the node being split
//   parent_gain  - gain of the unsplit node
// Returns gain(left) + gain(right) - parent_gain, with gains from CalcGain.
__device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
                                             const gpu_gpair& scan,
                                             const gpu_gpair& missing,
                                             const gpu_gpair& parent_sum,
                                             const float& parent_gain,
                                             bool missing_left) {
  gpu_gpair left_sum = scan;
  if (missing_left) left_sum += missing;

  // Whatever is not in the left child belongs to the right child.
  gpu_gpair right_sum = parent_sum - left_sum;

  float gain_l = CalcGain(param, left_sum.grad(), left_sum.hess());
  float gain_r = CalcGain(param, right_sum.grad(), right_sum.hess());
  return gain_l + gain_r - parent_gain;
}
|
||||||
|
|
||||||
|
// Evaluate a split with the `missing` sum sent left and sent right, and keep
// the better of the two. missing_left_out receives the chosen direction; a tie
// selects left. Returns the loss change of the chosen direction.
__device__ float inline loss_chg_missing(const gpu_gpair& scan,
                                         const gpu_gpair& missing,
                                         const gpu_gpair& parent_sum,
                                         const float& parent_gain,
                                         const GPUTrainingParam& param,
                                         bool& missing_left_out) {  // NOLINT
  const float chg_if_left = device_calc_loss_chg(param, scan, missing,
                                                 parent_sum, parent_gain, true);
  const float chg_if_right = device_calc_loss_chg(
      param, scan, missing, parent_sum, parent_gain, false);

  missing_left_out = chg_if_left >= chg_if_right;
  return missing_left_out ? chg_if_left : chg_if_right;
}
|
||||||
|
|
||||||
|
// Node count of a full binary tree whose deepest level is `depth`
// (a tree of depth 0 is a single node): 2^(depth + 1) - 1.
__host__ __device__ inline int n_nodes(int depth) {
  const int next_level_width = 1 << (depth + 1);
  return next_level_width - 1;
}
|
||||||
|
|
||||||
|
// Width of one tree level: level `depth` holds 2^depth nodes.
__host__ __device__ inline int n_nodes_level(int depth) {
  const int width = 1 << depth;
  return width;
}
|
||||||
|
|
||||||
|
// Role assigned to each slot of the dense node array by flag_nodes.
enum NodeType {
  NODE,    // 0: node with a split; both children are valid
  LEAF,    // 1: valid node without a split; its children are unused
  UNUSED,  // 2: slot that is not part of the tree
};
|
||||||
|
|
||||||
|
// Recursively label node types
|
||||||
|
inline void flag_nodes(const thrust::host_vector<Node>& nodes,
|
||||||
|
std::vector<NodeType>* node_flags, int nid,
|
||||||
|
NodeType type) {
|
||||||
|
if (nid >= nodes.size() || type == UNUSED) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Node& n = nodes[nid];
|
||||||
|
|
||||||
|
// Current node and all children are valid
|
||||||
|
if (n.split.loss_chg > rt_eps) {
|
||||||
|
(*node_flags)[nid] = NODE;
|
||||||
|
flag_nodes(nodes, node_flags, nid * 2 + 1, NODE);
|
||||||
|
flag_nodes(nodes, node_flags, nid * 2 + 2, NODE);
|
||||||
|
} else {
|
||||||
|
// Current node is leaf, therefore is valid but all children are invalid
|
||||||
|
(*node_flags)[nid] = LEAF;
|
||||||
|
flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED);
|
||||||
|
flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy gpu dense representation of tree to xgboost sparse representation
|
||||||
|
inline void dense2sparse_tree(RegTree* p_tree,
|
||||||
|
thrust::device_ptr<Node> nodes_begin,
|
||||||
|
thrust::device_ptr<Node> nodes_end,
|
||||||
|
const TrainParam& param) {
|
||||||
|
RegTree & tree = *p_tree;
|
||||||
|
thrust::host_vector<Node> h_nodes(nodes_begin, nodes_end);
|
||||||
|
std::vector<NodeType> node_flags(h_nodes.size(), UNUSED);
|
||||||
|
flag_nodes(h_nodes, &node_flags, 0, NODE);
|
||||||
|
|
||||||
|
int nid = 0;
|
||||||
|
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
|
||||||
|
NodeType flag = node_flags[gpu_nid];
|
||||||
|
const Node& n = h_nodes[gpu_nid];
|
||||||
|
if (flag == NODE) {
|
||||||
|
tree.AddChilds(nid);
|
||||||
|
tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left);
|
||||||
|
tree.stat(nid).loss_chg = n.split.loss_chg;
|
||||||
|
tree.stat(nid).base_weight = n.weight;
|
||||||
|
tree.stat(nid).sum_hess = n.sum_gradients.hess();
|
||||||
|
tree[tree[nid].cleft()].set_leaf(0);
|
||||||
|
tree[tree[nid].cright()].set_leaf(0);
|
||||||
|
nid++;
|
||||||
|
} else if (flag == LEAF) {
|
||||||
|
tree[nid].set_leaf(n.weight * param.learning_rate);
|
||||||
|
tree.stat(nid).sum_hess = n.sum_gradients.hess();
|
||||||
|
nid++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set gradient pair to 0 with p = 1 - subsample
|
||||||
|
inline void subsample_gpair(dh::dvec<gpu_gpair>* p_gpair, float subsample) {
|
||||||
|
if (subsample == 1.0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dh::dvec<gpu_gpair>& gpair = *p_gpair;
|
||||||
|
|
||||||
|
auto d_gpair = gpair.data();
|
||||||
|
dh::BernoulliRng rng(subsample, common::GlobalRandom()());
|
||||||
|
|
||||||
|
dh::launch_n(gpair.size(), [=] __device__(int i) {
|
||||||
|
if (!rng(i)) {
|
||||||
|
d_gpair[i] = gpu_gpair();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column subsampling: shuffle `features` with the global random engine and
// keep the first colsample * features.size() entries (truncated to int).
// The resulting count is checked to be greater than zero.
inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
  const int n_keep = colsample * features.size();
  CHECK_GT(n_keep, 0);

  auto& engine = common::GlobalRandom();
  std::shuffle(features.begin(), features.end(), engine);
  features.resize(n_keep);
  return features;
}
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
@ -1,18 +1,19 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2016 Rory Mitchell
|
* Copyright 2016 Rory Mitchell
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <device_launch_parameters.h>
|
#include <device_launch_parameters.h>
|
||||||
#include <thrust/device_vector.h>
|
#include <thrust/device_vector.h>
|
||||||
|
#include <thrust/random.h>
|
||||||
#include <thrust/system/cuda/error.h>
|
#include <thrust/system/cuda/error.h>
|
||||||
#include <thrust/system_error.h>
|
#include <thrust/system_error.h>
|
||||||
#include <thrust/random.h>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "cusparse_v2.h"
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#include <windows.h>
|
#include <windows.h>
|
||||||
@ -30,7 +31,8 @@ namespace dh {
|
|||||||
|
|
||||||
#define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__)
|
#define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__)
|
||||||
|
|
||||||
cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) {
|
inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file,
|
||||||
|
int line) {
|
||||||
if (code != cudaSuccess) {
|
if (code != cudaSuccess) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << file << "(" << line << ")";
|
ss << file << "(" << line << ")";
|
||||||
@ -41,16 +43,29 @@ cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) {
|
|||||||
|
|
||||||
return code;
|
return code;
|
||||||
}
|
}
|
||||||
|
#define safe_cusparse(ans) throw_on_cusparse_error((ans), __FILE__, __LINE__)
|
||||||
|
|
||||||
#define gpuErrchk(ans) \
|
// Check the status of a cuSPARSE call (see the safe_cusparse macro).
// On failure throws a std::string of the form
// "cusparse error: <file>(<line>)"; on success returns `status` unchanged.
inline cusparseStatus_t throw_on_cusparse_error(cusparseStatus_t status,
                                                const char *file, int line) {
  if (status != CUSPARSE_STATUS_SUCCESS) {
    std::stringstream ss;
    ss << "cusparse error: " << file << "(" << line << ")";
    // Throw the complete message. Extracting it with operator>> stops at the
    // first whitespace and would keep only the word "cusparse".
    throw ss.str();
  }

  return status;
}
|
||||||
|
|
||||||
|
#define gpuErrchk(ans) \
|
||||||
{ gpuAssert((ans), __FILE__, __LINE__); }
|
{ gpuAssert((ans), __FILE__, __LINE__); }
|
||||||
inline void gpuAssert(cudaError_t code, const char *file, int line,
|
inline void gpuAssert(cudaError_t code, const char *file, int line,
|
||||||
bool abort = true) {
|
bool abort = true) {
|
||||||
if (code != cudaSuccess) {
|
if (code != cudaSuccess) {
|
||||||
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
|
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
|
||||||
line);
|
line);
|
||||||
if (abort)
|
if (abort) exit(code);
|
||||||
exit(code);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,10 +134,10 @@ struct DeviceTimer {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef DEVICE_TIMER
|
#ifdef DEVICE_TIMER
|
||||||
__device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) // NOLINT
|
__device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) // NOLINT
|
||||||
: GTimer(GTimer), start(clock()), slot(slot) {}
|
: GTimer(GTimer), start(clock()), slot(slot) {}
|
||||||
#else
|
#else
|
||||||
__device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) {} // NOLINT
|
__device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) {} // NOLINT
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
__device__ void End() {
|
__device__ void End() {
|
||||||
@ -224,21 +239,22 @@ class range {
|
|||||||
iterator end_;
|
iterator end_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T> __device__ range grid_stride_range(T begin, T end) {
|
template <typename T>
|
||||||
|
__device__ range grid_stride_range(T begin, T end) {
|
||||||
begin += blockDim.x * blockIdx.x + threadIdx.x;
|
begin += blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
range r(begin, end);
|
range r(begin, end);
|
||||||
r.step(gridDim.x * blockDim.x);
|
r.step(gridDim.x * blockDim.x);
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T> __device__ range block_stride_range(T begin, T end) {
|
template <typename T>
|
||||||
|
__device__ range block_stride_range(T begin, T end) {
|
||||||
begin += threadIdx.x;
|
begin += threadIdx.x;
|
||||||
range r(begin, end);
|
range r(begin, end);
|
||||||
r.step(blockDim.x);
|
r.step(blockDim.x);
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Threadblock iterates over range, filling with value
|
// Threadblock iterates over range, filling with value
|
||||||
template <typename IterT, typename ValueT>
|
template <typename IterT, typename ValueT>
|
||||||
__device__ void block_fill(IterT begin, size_t n, ValueT value) {
|
__device__ void block_fill(IterT begin, size_t n, ValueT value) {
|
||||||
@ -253,7 +269,8 @@ __device__ void block_fill(IterT begin, size_t n, ValueT value) {
|
|||||||
|
|
||||||
class bulk_allocator;
|
class bulk_allocator;
|
||||||
|
|
||||||
template <typename T> class dvec {
|
template <typename T>
|
||||||
|
class dvec {
|
||||||
friend bulk_allocator;
|
friend bulk_allocator;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -302,7 +319,8 @@ template <typename T> class dvec {
|
|||||||
return thrust::device_pointer_cast(_ptr + size());
|
return thrust::device_pointer_cast(_ptr + size());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T2> dvec &operator=(const std::vector<T2> &other) {
|
template <typename T2>
|
||||||
|
dvec &operator=(const std::vector<T2> &other) {
|
||||||
if (other.size() != size()) {
|
if (other.size() != size()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Cannot copy assign vector to dvec, sizes are different");
|
"Cannot copy assign vector to dvec, sizes are different");
|
||||||
@ -331,7 +349,8 @@ class bulk_allocator {
|
|||||||
|
|
||||||
const size_t align = 256;
|
const size_t align = 256;
|
||||||
|
|
||||||
template <typename SizeT> size_t align_round_up(SizeT n) {
|
template <typename SizeT>
|
||||||
|
size_t align_round_up(SizeT n) {
|
||||||
if (n % align == 0) {
|
if (n % align == 0) {
|
||||||
return n;
|
return n;
|
||||||
} else {
|
} else {
|
||||||
@ -357,7 +376,7 @@ class bulk_allocator {
|
|||||||
template <typename T, typename SizeT, typename... Args>
|
template <typename T, typename SizeT, typename... Args>
|
||||||
void allocate_dvec(char *ptr, dvec<T> *first_vec, SizeT first_size,
|
void allocate_dvec(char *ptr, dvec<T> *first_vec, SizeT first_size,
|
||||||
Args... args) {
|
Args... args) {
|
||||||
first_vec->external_allocate(static_cast<void*>(ptr), first_size);
|
first_vec->external_allocate(static_cast<void *>(ptr), first_size);
|
||||||
ptr += align_round_up(first_size * sizeof(T));
|
ptr += align_round_up(first_size * sizeof(T));
|
||||||
allocate_dvec(ptr, args...);
|
allocate_dvec(ptr, args...);
|
||||||
}
|
}
|
||||||
@ -366,14 +385,15 @@ class bulk_allocator {
|
|||||||
bulk_allocator() : _size(0), d_ptr(NULL) {}
|
bulk_allocator() : _size(0), d_ptr(NULL) {}
|
||||||
|
|
||||||
~bulk_allocator() {
|
~bulk_allocator() {
|
||||||
if (!d_ptr == NULL) {
|
if (!(d_ptr == nullptr)) {
|
||||||
safe_cuda(cudaFree(d_ptr));
|
safe_cuda(cudaFree(d_ptr));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t size() { return _size; }
|
size_t size() { return _size; }
|
||||||
|
|
||||||
template <typename... Args> void allocate(Args... args) {
|
template <typename... Args>
|
||||||
|
void allocate(Args... args) {
|
||||||
if (d_ptr != NULL) {
|
if (d_ptr != NULL) {
|
||||||
throw std::runtime_error("Bulk allocator already allocated");
|
throw std::runtime_error("Bulk allocator already allocated");
|
||||||
}
|
}
|
||||||
@ -393,14 +413,19 @@ struct CubMemory {
|
|||||||
|
|
||||||
CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {}
|
CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {}
|
||||||
|
|
||||||
~CubMemory() {
|
~CubMemory() { Free(); }
|
||||||
|
void Free() {
|
||||||
if (d_temp_storage != NULL) {
|
if (d_temp_storage != NULL) {
|
||||||
safe_cuda(cudaFree(d_temp_storage));
|
safe_cuda(cudaFree(d_temp_storage));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Allocate() {
|
void LazyAllocate(size_t n_bytes) {
|
||||||
safe_cuda(cudaMalloc(&d_temp_storage, temp_storage_bytes));
|
if (n_bytes > temp_storage_bytes) {
|
||||||
|
Free();
|
||||||
|
safe_cuda(cudaMalloc(&d_temp_storage, n_bytes));
|
||||||
|
temp_storage_bytes = n_bytes;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsAllocated() { return d_temp_storage != NULL; }
|
bool IsAllocated() { return d_temp_storage != NULL; }
|
||||||
@ -453,36 +478,47 @@ void print(char *label, const thrust::device_vector<T> &v,
|
|||||||
std::cout << "\n";
|
std::cout << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T1, typename T2> T1 div_round_up(const T1 a, const T2 b) {
|
template <typename T1, typename T2>
|
||||||
|
T1 div_round_up(const T1 a, const T2 b) {
|
||||||
return static_cast<T1>(ceil(static_cast<double>(a) / b));
|
return static_cast<T1>(ceil(static_cast<double>(a) / b));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T> thrust::device_ptr<T> dptr(T *d_ptr) {
|
template <typename T>
|
||||||
|
thrust::device_ptr<T> dptr(T *d_ptr) {
|
||||||
return thrust::device_pointer_cast(d_ptr);
|
return thrust::device_pointer_cast(d_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T> T *raw(thrust::device_vector<T> &v) { // NOLINT
|
template <typename T>
|
||||||
|
T *raw(thrust::device_vector<T> &v) { // NOLINT
|
||||||
return raw_pointer_cast(v.data());
|
return raw_pointer_cast(v.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T> size_t size_bytes(const thrust::device_vector<T> &v) {
|
template <typename T>
|
||||||
|
const T *raw(const thrust::device_vector<T> &v) { // NOLINT
|
||||||
|
return raw_pointer_cast(v.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
size_t size_bytes(const thrust::device_vector<T> &v) {
|
||||||
return sizeof(T) * v.size();
|
return sizeof(T) * v.size();
|
||||||
}
|
}
|
||||||
/*
|
/*
|
||||||
* Kernel launcher
|
* Kernel launcher
|
||||||
*/
|
*/
|
||||||
|
|
||||||
template <typename L> __global__ void launch_n_kernel(size_t n, L lambda) {
|
template <typename L>
|
||||||
|
__global__ void launch_n_kernel(size_t n, L lambda) {
|
||||||
for (auto i : grid_stride_range(static_cast<size_t>(0), n)) {
|
for (auto i : grid_stride_range(static_cast<size_t>(0), n)) {
|
||||||
lambda(i);
|
lambda(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename L, int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256>
|
template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
|
||||||
inline void launch_n(size_t n, L lambda) {
|
inline void launch_n(size_t n, L lambda) {
|
||||||
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
|
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
|
||||||
|
#if defined(__CUDACC__)
|
||||||
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(n, lambda);
|
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(n, lambda);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -493,7 +529,7 @@ struct BernoulliRng {
|
|||||||
float p;
|
float p;
|
||||||
int seed;
|
int seed;
|
||||||
|
|
||||||
__host__ __device__ BernoulliRng(float p, int seed):p(p), seed(seed) {}
|
__host__ __device__ BernoulliRng(float p, int seed) : p(p), seed(seed) {}
|
||||||
|
|
||||||
__host__ __device__ bool operator()(const int i) const {
|
__host__ __device__ bool operator()(const int i) const {
|
||||||
thrust::default_random_engine rng(seed);
|
thrust::default_random_engine rng(seed);
|
||||||
@ -504,5 +540,4 @@ struct BernoulliRng {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
} // namespace dh
|
} // namespace dh
|
||||||
|
|||||||
@ -9,7 +9,7 @@
|
|||||||
#include "find_split_multiscan.cuh"
|
#include "find_split_multiscan.cuh"
|
||||||
#include "find_split_sorting.cuh"
|
#include "find_split_sorting.cuh"
|
||||||
#include "gpu_data.cuh"
|
#include "gpu_data.cuh"
|
||||||
#include "types_functions.cuh"
|
#include "types.cuh"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|||||||
@ -6,7 +6,8 @@
|
|||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
#include "device_helpers.cuh"
|
#include "device_helpers.cuh"
|
||||||
#include "gpu_data.cuh"
|
#include "gpu_data.cuh"
|
||||||
#include "types_functions.cuh"
|
#include "types.cuh"
|
||||||
|
#include "common.cuh"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|||||||
@ -5,7 +5,8 @@
|
|||||||
#include <cub/cub.cuh>
|
#include <cub/cub.cuh>
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
#include "device_helpers.cuh"
|
#include "device_helpers.cuh"
|
||||||
#include "types_functions.cuh"
|
#include "types.cuh"
|
||||||
|
#include "common.cuh"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|||||||
15
plugin/updater_gpu/src/functions.cuh
Normal file
15
plugin/updater_gpu/src/functions.cuh
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2016 Rory Mitchell
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include "types.cuh"
|
||||||
|
#include "../../../src/tree/param.h"
|
||||||
|
#include "../../../src/common/random.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
@ -1,6 +1,7 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2016 Rory Mitchell
|
* Copyright 2016 Rory Mitchell
|
||||||
*/
|
*/
|
||||||
|
#include "gpu_builder.cuh"
|
||||||
#include <cub/cub.cuh>
|
#include <cub/cub.cuh>
|
||||||
#include <cuda_profiler_api.h>
|
#include <cuda_profiler_api.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
@ -11,31 +12,35 @@
|
|||||||
#include <thrust/host_vector.h>
|
#include <thrust/host_vector.h>
|
||||||
#include <thrust/sequence.h>
|
#include <thrust/sequence.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <random>
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <random>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "../../../src/common/random.h"
|
#include "../../../src/common/random.h"
|
||||||
|
#include "common.cuh"
|
||||||
#include "device_helpers.cuh"
|
#include "device_helpers.cuh"
|
||||||
#include "find_split.cuh"
|
#include "find_split.cuh"
|
||||||
#include "gpu_builder.cuh"
|
|
||||||
#include "types_functions.cuh"
|
|
||||||
#include "gpu_data.cuh"
|
#include "gpu_data.cuh"
|
||||||
|
#include "types.cuh"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|
||||||
GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); }
|
GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); }
|
||||||
|
|
||||||
void GPUBuilder::Init(const TrainParam &param_in) {
|
void GPUBuilder::Init(const TrainParam& param_in) {
|
||||||
param = param_in;
|
param = param_in;
|
||||||
CHECK(param.max_depth < 16) << "Tree depth too large.";
|
CHECK(param.max_depth < 16) << "Tree depth too large.";
|
||||||
|
|
||||||
|
dh::safe_cuda(cudaSetDevice(param.gpu_id));
|
||||||
|
if (!param.silent) {
|
||||||
|
LOG(CONSOLE) << "Device: [" << param.gpu_id << "] " << dh::device_name();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
GPUBuilder::~GPUBuilder() { delete gpu_data; }
|
GPUBuilder::~GPUBuilder() { delete gpu_data; }
|
||||||
|
|
||||||
void GPUBuilder::UpdateNodeId(int level) {
|
void GPUBuilder::UpdateNodeId(int level) {
|
||||||
auto *d_node_id_instance = gpu_data->node_id_instance.data();
|
auto* d_node_id_instance = gpu_data->node_id_instance.data();
|
||||||
Node *d_nodes = gpu_data->nodes.data();
|
Node* d_nodes = gpu_data->nodes.data();
|
||||||
|
|
||||||
dh::launch_n(gpu_data->node_id_instance.size(), [=] __device__(int i) {
|
dh::launch_n(gpu_data->node_id_instance.size(), [=] __device__(int i) {
|
||||||
NodeIdT item_node_id = d_node_id_instance[i];
|
NodeIdT item_node_id = d_node_id_instance[i];
|
||||||
@ -57,10 +62,10 @@ void GPUBuilder::UpdateNodeId(int level) {
|
|||||||
|
|
||||||
dh::safe_cuda(cudaDeviceSynchronize());
|
dh::safe_cuda(cudaDeviceSynchronize());
|
||||||
|
|
||||||
auto *d_fvalues = gpu_data->fvalues.data();
|
auto* d_fvalues = gpu_data->fvalues.data();
|
||||||
auto *d_instance_id = gpu_data->instance_id.data();
|
auto* d_instance_id = gpu_data->instance_id.data();
|
||||||
auto *d_node_id = gpu_data->node_id.data();
|
auto* d_node_id = gpu_data->node_id.data();
|
||||||
auto *d_feature_id = gpu_data->feature_id.data();
|
auto* d_feature_id = gpu_data->feature_id.data();
|
||||||
|
|
||||||
// Update node based on fvalue where exists
|
// Update node based on fvalue where exists
|
||||||
dh::launch_n(gpu_data->fvalues.size(), [=] __device__(int i) {
|
dh::launch_n(gpu_data->fvalues.size(), [=] __device__(int i) {
|
||||||
@ -128,75 +133,52 @@ void GPUBuilder::Sort(int level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GPUBuilder::ColsampleTree() {
|
void GPUBuilder::ColsampleTree() {
|
||||||
unsigned n = static_cast<unsigned>(
|
unsigned n =
|
||||||
param.colsample_bytree * gpu_data->n_features);
|
static_cast<unsigned>(param.colsample_bytree * gpu_data->n_features);
|
||||||
CHECK_GT(n, 0);
|
CHECK_GT(n, 0);
|
||||||
|
|
||||||
feature_set_tree.resize(gpu_data->n_features);
|
feature_set_tree.resize(gpu_data->n_features);
|
||||||
std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
|
std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
|
||||||
std::shuffle(feature_set_tree.begin(), feature_set_tree.end(),
|
std::shuffle(feature_set_tree.begin(), feature_set_tree.end(),
|
||||||
common::GlobalRandom());
|
common::GlobalRandom());
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUBuilder::Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
|
void GPUBuilder::Update(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
|
||||||
RegTree *p_tree) {
|
RegTree* p_tree) {
|
||||||
try {
|
this->InitData(gpair, *p_fmat, *p_tree);
|
||||||
dh::Timer update;
|
this->InitFirstNode();
|
||||||
dh::Timer t;
|
this->ColsampleTree();
|
||||||
this->InitData(gpair, *p_fmat, *p_tree);
|
|
||||||
t.printElapsed("init data");
|
|
||||||
this->InitFirstNode();
|
|
||||||
this->ColsampleTree();
|
|
||||||
|
|
||||||
for (int level = 0; level < param.max_depth; level++) {
|
for (int level = 0; level < param.max_depth; level++) {
|
||||||
bool use_multiscan_algorithm = level < multiscan_levels;
|
bool use_multiscan_algorithm = level < multiscan_levels;
|
||||||
|
|
||||||
t.reset();
|
if (level > 0) {
|
||||||
if (level > 0) {
|
this->UpdateNodeId(level);
|
||||||
dh::Timer update_node;
|
|
||||||
this->UpdateNodeId(level);
|
|
||||||
update_node.printElapsed("node");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (level > 0 && !use_multiscan_algorithm) {
|
|
||||||
dh::Timer s;
|
|
||||||
this->Sort(level);
|
|
||||||
s.printElapsed("sort");
|
|
||||||
}
|
|
||||||
|
|
||||||
dh::Timer split;
|
|
||||||
find_split(gpu_data, param, level, use_multiscan_algorithm,
|
|
||||||
feature_set_tree, &feature_set_level);
|
|
||||||
|
|
||||||
split.printElapsed("split");
|
|
||||||
|
|
||||||
t.printElapsed("level");
|
|
||||||
}
|
}
|
||||||
this->CopyTree(*p_tree);
|
|
||||||
update.printElapsed("update");
|
if (level > 0 && !use_multiscan_algorithm) {
|
||||||
} catch (thrust::system_error &e) {
|
this->Sort(level);
|
||||||
std::cerr << "CUDA error: " << e.what() << std::endl;
|
}
|
||||||
exit(-1);
|
|
||||||
} catch (const std::exception &e) {
|
find_split(gpu_data, param, level, use_multiscan_algorithm,
|
||||||
std::cerr << "Error: " << e.what() << std::endl;
|
feature_set_tree, &feature_set_level);
|
||||||
exit(-1);
|
|
||||||
} catch (...) {
|
|
||||||
std::cerr << "Unknown exception." << std::endl;
|
|
||||||
exit(-1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dense2sparse_tree(p_tree, gpu_data->nodes.tbegin(), gpu_data->nodes.tend(),
|
||||||
|
param);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
|
void GPUBuilder::InitData(const std::vector<bst_gpair>& gpair, DMatrix& fmat,
|
||||||
const RegTree &tree) {
|
const RegTree& tree) {
|
||||||
CHECK(fmat.SingleColBlock()) << "GPUMaker: must have single column block";
|
CHECK(fmat.SingleColBlock()) << "grow_gpu: must have single column block. "
|
||||||
|
"Try setting 'tree_method' parameter to "
|
||||||
|
"'exact'";
|
||||||
|
|
||||||
if (gpu_data->IsAllocated()) {
|
if (gpu_data->IsAllocated()) {
|
||||||
gpu_data->Reset(gpair, param.subsample);
|
gpu_data->Reset(gpair, param.subsample);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dh::Timer t;
|
|
||||||
|
|
||||||
MetaInfo info = fmat.info();
|
MetaInfo info = fmat.info();
|
||||||
|
|
||||||
std::vector<int> foffsets;
|
std::vector<int> foffsets;
|
||||||
@ -208,31 +190,27 @@ void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
|
|||||||
instance_id.reserve(info.num_col * info.num_row);
|
instance_id.reserve(info.num_col * info.num_row);
|
||||||
feature_id.reserve(info.num_col * info.num_row);
|
feature_id.reserve(info.num_col * info.num_row);
|
||||||
|
|
||||||
dmlc::DataIter<ColBatch> *iter = fmat.ColIterator();
|
dmlc::DataIter<ColBatch>* iter = fmat.ColIterator();
|
||||||
|
|
||||||
while (iter->Next()) {
|
while (iter->Next()) {
|
||||||
const ColBatch &batch = iter->Value();
|
const ColBatch& batch = iter->Value();
|
||||||
|
|
||||||
for (int i = 0; i < batch.size; i++) {
|
for (int i = 0; i < batch.size; i++) {
|
||||||
const ColBatch::Inst &col = batch[i];
|
const ColBatch::Inst& col = batch[i];
|
||||||
|
|
||||||
for (const ColBatch::Entry *it = col.data; it != col.data + col.length;
|
for (const ColBatch::Entry* it = col.data; it != col.data + col.length;
|
||||||
it++) {
|
it++) {
|
||||||
bst_uint inst_id = it->index;
|
bst_uint inst_id = it->index;
|
||||||
fvalues.push_back(it->fvalue);
|
fvalues.push_back(it->fvalue);
|
||||||
instance_id.push_back(inst_id);
|
instance_id.push_back(inst_id);
|
||||||
feature_id.push_back(i);
|
feature_id.push_back(i);
|
||||||
}
|
}
|
||||||
foffsets.push_back(fvalues.size());
|
foffsets.push_back(fvalues.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.printElapsed("dmatrix");
|
|
||||||
t.reset();
|
|
||||||
gpu_data->Init(fvalues, foffsets, instance_id, feature_id, gpair,
|
gpu_data->Init(fvalues, foffsets, instance_id, feature_id, gpair,
|
||||||
info.num_row, info.num_col, param.max_depth, param);
|
info.num_row, info.num_col, param.max_depth, param);
|
||||||
|
|
||||||
t.printElapsed("gpu init");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUBuilder::InitFirstNode() {
|
void GPUBuilder::InitFirstNode() {
|
||||||
@ -248,60 +226,5 @@ void GPUBuilder::InitFirstNode() {
|
|||||||
|
|
||||||
thrust::copy_n(&tmp, 1, gpu_data->nodes.tbegin());
|
thrust::copy_n(&tmp, 1, gpu_data->nodes.tbegin());
|
||||||
}
|
}
|
||||||
|
|
||||||
enum NodeType {
|
|
||||||
NODE = 0,
|
|
||||||
LEAF = 1,
|
|
||||||
UNUSED = 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Recursively label node types
|
|
||||||
void flag_nodes(const thrust::host_vector<Node> &nodes,
|
|
||||||
std::vector<NodeType> *node_flags, int nid, NodeType type) {
|
|
||||||
if (nid >= nodes.size() || type == UNUSED) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const Node &n = nodes[nid];
|
|
||||||
|
|
||||||
// Current node and all children are valid
|
|
||||||
if (n.split.loss_chg > rt_eps) {
|
|
||||||
(*node_flags)[nid] = NODE;
|
|
||||||
flag_nodes(nodes, node_flags, nid * 2 + 1, NODE);
|
|
||||||
flag_nodes(nodes, node_flags, nid * 2 + 2, NODE);
|
|
||||||
} else {
|
|
||||||
// Current node is leaf, therefore is valid but all children are invalid
|
|
||||||
(*node_flags)[nid] = LEAF;
|
|
||||||
flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED);
|
|
||||||
flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy gpu dense representation of tree to xgboost sparse representation
|
|
||||||
void GPUBuilder::CopyTree(RegTree &tree) {
|
|
||||||
std::vector<Node> h_nodes = gpu_data->nodes.as_vector();
|
|
||||||
std::vector<NodeType> node_flags(h_nodes.size(), UNUSED);
|
|
||||||
flag_nodes(h_nodes, &node_flags, 0, NODE);
|
|
||||||
|
|
||||||
int nid = 0;
|
|
||||||
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
|
|
||||||
NodeType flag = node_flags[gpu_nid];
|
|
||||||
const Node &n = h_nodes[gpu_nid];
|
|
||||||
if (flag == NODE) {
|
|
||||||
tree.AddChilds(nid);
|
|
||||||
tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left);
|
|
||||||
tree.stat(nid).loss_chg = n.split.loss_chg;
|
|
||||||
tree.stat(nid).base_weight = n.weight;
|
|
||||||
tree.stat(nid).sum_hess = n.sum_gradients.hess();
|
|
||||||
tree[tree[nid].cleft()].set_leaf(0);
|
|
||||||
tree[tree[nid].cright()].set_leaf(0);
|
|
||||||
nid++;
|
|
||||||
} else if (flag == LEAF) {
|
|
||||||
tree[nid].set_leaf(n.weight * param.learning_rate);
|
|
||||||
tree.stat(nid).sum_hess = n.sum_gradients.hess();
|
|
||||||
nid++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2016 Rory mitchell
|
* Copyright 2016 Rory mitchell
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <xgboost/tree_updater.h>
|
#include <xgboost/tree_updater.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -19,10 +19,7 @@ class GPUBuilder {
|
|||||||
void Init(const TrainParam &param);
|
void Init(const TrainParam &param);
|
||||||
~GPUBuilder();
|
~GPUBuilder();
|
||||||
|
|
||||||
void UpdateParam(const TrainParam &param)
|
void UpdateParam(const TrainParam &param) { this->param = param; }
|
||||||
{
|
|
||||||
this->param = param;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
|
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
|
||||||
RegTree *p_tree);
|
RegTree *p_tree);
|
||||||
@ -30,13 +27,11 @@ class GPUBuilder {
|
|||||||
void UpdateNodeId(int level);
|
void UpdateNodeId(int level);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
|
void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
|
||||||
const RegTree &tree);
|
const RegTree &tree);
|
||||||
|
|
||||||
float GetSubsamplingRate(MetaInfo info);
|
|
||||||
void Sort(int level);
|
void Sort(int level);
|
||||||
void InitFirstNode();
|
void InitFirstNode();
|
||||||
void CopyTree(RegTree &tree); // NOLINT
|
|
||||||
void ColsampleTree();
|
void ColsampleTree();
|
||||||
|
|
||||||
TrainParam param;
|
TrainParam param;
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2016 Rory mitchell
|
* Copyright 2016 Rory mitchell
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <cub/cub.cuh>
|
|
||||||
#include <xgboost/logging.h>
|
|
||||||
#include <thrust/sequence.h>
|
#include <thrust/sequence.h>
|
||||||
|
#include <xgboost/logging.h>
|
||||||
|
#include <cub/cub.cuh>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "device_helpers.cuh"
|
|
||||||
#include "../../src/tree/param.h"
|
#include "../../src/tree/param.h"
|
||||||
#include "types_functions.cuh"
|
#include "common.cuh"
|
||||||
|
#include "device_helpers.cuh"
|
||||||
|
#include "types.cuh"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
@ -67,9 +68,8 @@ struct GPUData {
|
|||||||
cub::DoubleBuffer<int> db_value;
|
cub::DoubleBuffer<int> db_value;
|
||||||
|
|
||||||
cub::DeviceSegmentedRadixSort::SortPairs(
|
cub::DeviceSegmentedRadixSort::SortPairs(
|
||||||
cub_mem.data(), cub_mem_size, db_key,
|
cub_mem.data(), cub_mem_size, db_key, db_value, in_fvalues.size(),
|
||||||
db_value, in_fvalues.size(), n_features,
|
n_features, foffsets.data(), foffsets.data() + 1);
|
||||||
foffsets.data(), foffsets.data() + 1);
|
|
||||||
|
|
||||||
// Allocate memory
|
// Allocate memory
|
||||||
size_t free_memory = dh::available_memory();
|
size_t free_memory = dh::available_memory();
|
||||||
@ -100,7 +100,6 @@ struct GPUData {
|
|||||||
param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
|
param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
|
||||||
param_in.reg_alpha, param_in.max_delta_step);
|
param_in.reg_alpha, param_in.max_delta_step);
|
||||||
|
|
||||||
|
|
||||||
allocated = true;
|
allocated = true;
|
||||||
|
|
||||||
this->Reset(in_gpair, param_in.subsample);
|
this->Reset(in_gpair, param_in.subsample);
|
||||||
@ -109,33 +108,16 @@ struct GPUData {
|
|||||||
thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
|
thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
|
||||||
fvalues.tbegin(), node_id.tbegin()));
|
fvalues.tbegin(), node_id.tbegin()));
|
||||||
|
|
||||||
|
|
||||||
dh::safe_cuda(cudaGetLastError());
|
dh::safe_cuda(cudaGetLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
~GPUData() {}
|
~GPUData() {}
|
||||||
|
|
||||||
// Set gradient pair to 0 with p = 1 - subsample
|
|
||||||
void MarkSubsample(float subsample) {
|
|
||||||
if (subsample == 1.0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto d_gpair = gpair.data();
|
|
||||||
dh::BernoulliRng rng(subsample, common::GlobalRandom()());
|
|
||||||
|
|
||||||
dh::launch_n(n_instances, [=] __device__(int i) {
|
|
||||||
if (!rng(i)) {
|
|
||||||
d_gpair[i] = gpu_gpair();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset memory for new boosting iteration
|
// Reset memory for new boosting iteration
|
||||||
void Reset(const std::vector<bst_gpair> &in_gpair, float subsample) {
|
void Reset(const std::vector<bst_gpair> &in_gpair, float subsample) {
|
||||||
CHECK(allocated);
|
CHECK(allocated);
|
||||||
gpair = in_gpair;
|
gpair = in_gpair;
|
||||||
this->MarkSubsample(subsample);
|
subsample_gpair(&gpair, subsample);
|
||||||
instance_id = instance_id_cached;
|
instance_id = instance_id_cached;
|
||||||
fvalues = fvalues_cached;
|
fvalues = fvalues_cached;
|
||||||
nodes.fill(Node());
|
nodes.fill(Node());
|
||||||
@ -153,7 +135,6 @@ struct GPUData {
|
|||||||
auto d_instance_id = instance_id.data();
|
auto d_instance_id = instance_id.data();
|
||||||
|
|
||||||
dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
|
dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
|
||||||
// Item item = d_items[i];
|
|
||||||
d_node_id[i] = d_node_id_instance[d_instance_id[i]];
|
d_node_id[i] = d_node_id_instance[d_instance_id[i]];
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
560
plugin/updater_gpu/src/gpu_hist_builder.cu
Normal file
560
plugin/updater_gpu/src/gpu_hist_builder.cu
Normal file
@ -0,0 +1,560 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2017 Rory mitchell
|
||||||
|
*/
|
||||||
|
#include <thrust/binary_search.h>
|
||||||
|
#include <thrust/count.h>
|
||||||
|
#include <thrust/sequence.h>
|
||||||
|
#include <thrust/sort.h>
|
||||||
|
#include <cub/cub.cuh>
|
||||||
|
#include <numeric>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include "common.cuh"
|
||||||
|
#include "device_helpers.cuh"
|
||||||
|
#include "gpu_hist_builder.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
void DeviceGMat::Init(const common::GHistIndexMatrix& gmat) {
|
||||||
|
CHECK_EQ(gidx.size(), gmat.index.size())
|
||||||
|
<< "gidx must be externally allocated";
|
||||||
|
CHECK_EQ(ridx.size(), gmat.index.size())
|
||||||
|
<< "ridx must be externally allocated";
|
||||||
|
|
||||||
|
gidx = gmat.index;
|
||||||
|
thrust::device_vector<int> row_ptr = gmat.row_ptr;
|
||||||
|
|
||||||
|
auto counting = thrust::make_counting_iterator(0);
|
||||||
|
thrust::upper_bound(row_ptr.begin(), row_ptr.end(), counting,
|
||||||
|
counting + gidx.size(), ridx.tbegin());
|
||||||
|
thrust::transform(ridx.tbegin(), ridx.tend(), ridx.tbegin(),
|
||||||
|
[=] __device__(int val) { return val - 1; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeviceHist::Init(int n_bins_in) {
|
||||||
|
this->n_bins = n_bins_in;
|
||||||
|
CHECK(!hist.empty()) << "DeviceHist must be externally allocated";
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeviceHist::Reset() { hist.fill(gpu_gpair()); }
|
||||||
|
|
||||||
|
gpu_gpair* DeviceHist::GetLevelPtr(int depth) {
|
||||||
|
return hist.data() + n_nodes(depth - 1) * n_bins;
|
||||||
|
}
|
||||||
|
|
||||||
|
int DeviceHist::LevelSize(int depth) { return n_bins * n_nodes_level(depth); }
|
||||||
|
|
||||||
|
HistBuilder DeviceHist::GetBuilder() {
|
||||||
|
return HistBuilder(hist.data(), n_bins);
|
||||||
|
}
|
||||||
|
|
||||||
|
HistBuilder::HistBuilder(gpu_gpair* ptr, int n_bins)
|
||||||
|
: d_hist(ptr), n_bins(n_bins) {}
|
||||||
|
|
||||||
|
__device__ void HistBuilder::Add(gpu_gpair gpair, int gidx, int nidx) const {
|
||||||
|
int hist_idx = nidx * n_bins + gidx;
|
||||||
|
atomicAdd(&(d_hist[hist_idx]._grad), gpair._grad);
|
||||||
|
atomicAdd(&(d_hist[hist_idx]._hess), gpair._hess);
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ gpu_gpair HistBuilder::Get(int gidx, int nidx) const {
|
||||||
|
return d_hist[nidx * n_bins + gidx];
|
||||||
|
}
|
||||||
|
|
||||||
|
GPUHistBuilder::GPUHistBuilder() {}
|
||||||
|
|
||||||
|
GPUHistBuilder::~GPUHistBuilder() {}
|
||||||
|
|
||||||
|
void GPUHistBuilder::Init(const TrainParam& param) {
|
||||||
|
CHECK(param.max_depth < 16) << "Tree depth too large.";
|
||||||
|
this->param = param;
|
||||||
|
initialised = false;
|
||||||
|
is_dense = false;
|
||||||
|
|
||||||
|
dh::safe_cuda(cudaSetDevice(param.gpu_id));
|
||||||
|
if (!param.silent) {
|
||||||
|
LOG(CONSOLE) << "Device: [" << param.gpu_id << "] " << dh::device_name();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ReductionOpT>
|
||||||
|
struct ReduceBySegmentOp {
|
||||||
|
ReductionOpT op;
|
||||||
|
|
||||||
|
__host__ __device__ __forceinline__ ReduceBySegmentOp() {}
|
||||||
|
|
||||||
|
__host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op)
|
||||||
|
: op(op) {}
|
||||||
|
|
||||||
|
template <typename KeyValuePairT>
|
||||||
|
__host__ __device__ __forceinline__ KeyValuePairT
|
||||||
|
operator()(const KeyValuePairT& first, const KeyValuePairT& second) {
|
||||||
|
KeyValuePairT retval;
|
||||||
|
retval.key = second.key;
|
||||||
|
retval.value =
|
||||||
|
first.key != second.key ? second.value : op(first.value, second.value);
|
||||||
|
return retval;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int ITEMS_PER_THREAD, int BLOCK_THREADS>
|
||||||
|
__global__ void hist_kernel(gpu_gpair* d_dense_hist, int* d_ridx, int* d_gidx,
|
||||||
|
NodeIdT* d_position, gpu_gpair* d_gpair, int n_bins,
|
||||||
|
int depth, int n) {
|
||||||
|
typedef cub::KeyValuePair<int, gpu_gpair> OffsetValuePairT;
|
||||||
|
typedef cub::BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD,
|
||||||
|
cub::BLOCK_LOAD_VECTORIZE>
|
||||||
|
BlockLoadT;
|
||||||
|
typedef cub::BlockRadixSort<int, BLOCK_THREADS, ITEMS_PER_THREAD, int>
|
||||||
|
BlockRadixSortT;
|
||||||
|
typedef cub::BlockDiscontinuity<int, BLOCK_THREADS> BlockDiscontinuityKeysT;
|
||||||
|
typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;
|
||||||
|
typedef cub::BlockScan<OffsetValuePairT, BLOCK_THREADS,
|
||||||
|
cub::BLOCK_SCAN_WARP_SCANS>
|
||||||
|
BlockScanT;
|
||||||
|
|
||||||
|
union TempStorage {
|
||||||
|
typename BlockLoadT::TempStorage load;
|
||||||
|
typename BlockRadixSortT::TempStorage sort;
|
||||||
|
typename BlockScanT::TempStorage scan;
|
||||||
|
typename BlockDiscontinuityKeysT::TempStorage disc;
|
||||||
|
};
|
||||||
|
|
||||||
|
__shared__ TempStorage temp_storage;
|
||||||
|
|
||||||
|
int ridx[ITEMS_PER_THREAD];
|
||||||
|
int gidx[ITEMS_PER_THREAD];
|
||||||
|
|
||||||
|
const int TILE_SIZE = ITEMS_PER_THREAD * BLOCK_THREADS;
|
||||||
|
int block_offset = blockIdx.x * TILE_SIZE;
|
||||||
|
|
||||||
|
BlockLoadT(temp_storage.load)
|
||||||
|
.Load(d_ridx + block_offset, ridx, n - block_offset, -1);
|
||||||
|
BlockLoadT(temp_storage.load)
|
||||||
|
.Load(d_gidx + block_offset, gidx, n - block_offset, -1);
|
||||||
|
|
||||||
|
int hist_idx[ITEMS_PER_THREAD];
|
||||||
|
|
||||||
|
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
|
||||||
|
if (ridx[ITEM] > -1 && d_position[ridx[ITEM]] > -1) {
|
||||||
|
hist_idx[ITEM] =
|
||||||
|
(d_position[ridx[ITEM]] - n_nodes(depth - 1)) * n_bins + gidx[ITEM];
|
||||||
|
} else {
|
||||||
|
hist_idx[ITEM] = -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
BlockRadixSortT(temp_storage.sort).Sort(hist_idx, ridx);
|
||||||
|
|
||||||
|
OffsetValuePairT kv[ITEMS_PER_THREAD];
|
||||||
|
|
||||||
|
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
|
||||||
|
kv[ITEM].key = hist_idx[ITEM];
|
||||||
|
if (ridx[ITEM] > -1) {
|
||||||
|
kv[ITEM].value = d_gpair[ridx[ITEM]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
// Scan
|
||||||
|
BlockScanT(temp_storage.scan).InclusiveScan(kv, kv, ReduceBySegmentOpT());
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
int flags[ITEMS_PER_THREAD];
|
||||||
|
BlockDiscontinuityKeysT(temp_storage.disc)
|
||||||
|
.FlagTails(flags, hist_idx, cub::Inequality());
|
||||||
|
|
||||||
|
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
|
||||||
|
if (flags[ITEM]) {
|
||||||
|
if (ridx[ITEM] > -1 && d_position[ridx[ITEM]] > -1) {
|
||||||
|
atomicAdd(&(d_dense_hist[hist_idx[ITEM]]._grad), kv[ITEM].value._grad);
|
||||||
|
atomicAdd(&(d_dense_hist[hist_idx[ITEM]]._hess), kv[ITEM].value._hess);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUHistBuilder::BuildHist(int depth) {
|
||||||
|
auto d_ridx = device_matrix.ridx.data();
|
||||||
|
auto d_gidx = device_matrix.gidx.data();
|
||||||
|
auto d_position = position.data();
|
||||||
|
auto d_gpair = device_gpair.data();
|
||||||
|
auto hist_builder = hist.GetBuilder();
|
||||||
|
|
||||||
|
dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) {
|
||||||
|
int ridx = d_ridx[idx];
|
||||||
|
int pos = d_position[ridx];
|
||||||
|
|
||||||
|
// Only increment even nodes
|
||||||
|
if (pos < 0 || pos % 2 == 1) return;
|
||||||
|
|
||||||
|
int gidx = d_gidx[idx];
|
||||||
|
gpu_gpair gpair = d_gpair[ridx];
|
||||||
|
|
||||||
|
hist_builder.Add(gpair, gidx, pos);
|
||||||
|
});
|
||||||
|
|
||||||
|
dh::safe_cuda(cudaDeviceSynchronize());
|
||||||
|
|
||||||
|
// Subtraction trick
|
||||||
|
int n_sub_bins = (n_nodes_level(depth) / 2) * hist_builder.n_bins;
|
||||||
|
if (n_sub_bins > 0) {
|
||||||
|
dh::launch_n(n_sub_bins, [=] __device__(int idx) {
|
||||||
|
int nidx = n_nodes(depth - 1) + ((idx / hist_builder.n_bins) * 2);
|
||||||
|
int gidx = idx % hist_builder.n_bins;
|
||||||
|
gpu_gpair parent = hist_builder.Get(gidx, nidx / 2);
|
||||||
|
gpu_gpair right = hist_builder.Get(gidx, nidx + 1);
|
||||||
|
hist_builder.Add(parent - right, gidx, nidx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
dh::safe_cuda(cudaDeviceSynchronize());
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUHistBuilder::FindSplit(int depth) {
|
||||||
|
auto counting = thrust::make_counting_iterator(0);
|
||||||
|
auto d_gidx_feature_map = gidx_feature_map.data();
|
||||||
|
int n_bins = hmat_.row_ptr.back();
|
||||||
|
int n_features = hmat_.row_ptr.size() - 1;
|
||||||
|
|
||||||
|
auto feature_boundary = [=] __device__(int idx_a, int idx_b) {
|
||||||
|
int gidx_a = idx_a % n_bins;
|
||||||
|
int gidx_b = idx_b % n_bins;
|
||||||
|
return d_gidx_feature_map[gidx_a] == d_gidx_feature_map[gidx_b];
|
||||||
|
}; // NOLINT
|
||||||
|
|
||||||
|
// Reduce node sums
|
||||||
|
{
|
||||||
|
size_t temp_storage_bytes;
|
||||||
|
cub::DeviceSegmentedReduce::Reduce(
|
||||||
|
nullptr, temp_storage_bytes, hist.GetLevelPtr(depth), node_sums.data(),
|
||||||
|
n_nodes_level(depth) * n_features, feature_segments.data(),
|
||||||
|
feature_segments.data() + 1, cub::Sum(), gpu_gpair());
|
||||||
|
cub_mem.LazyAllocate(temp_storage_bytes);
|
||||||
|
cub::DeviceSegmentedReduce::Reduce(
|
||||||
|
cub_mem.d_temp_storage, cub_mem.temp_storage_bytes,
|
||||||
|
hist.GetLevelPtr(depth), node_sums.data(),
|
||||||
|
n_nodes_level(depth) * n_features, feature_segments.data(),
|
||||||
|
feature_segments.data() + 1, cub::Sum(), gpu_gpair());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan
|
||||||
|
thrust::exclusive_scan_by_key(
|
||||||
|
counting, counting + hist.LevelSize(depth),
|
||||||
|
thrust::device_pointer_cast(hist.GetLevelPtr(depth)), hist_scan.tbegin(),
|
||||||
|
gpu_gpair(), feature_boundary);
|
||||||
|
|
||||||
|
// Calculate gain
|
||||||
|
auto d_gain = gain.data();
|
||||||
|
auto d_nodes = nodes.data();
|
||||||
|
auto d_node_sums = node_sums.data();
|
||||||
|
auto d_hist_scan = hist_scan.data();
|
||||||
|
GPUTrainingParam gpu_param_alias =
|
||||||
|
gpu_param; // Must be local variable to be used in device lambda
|
||||||
|
bool colsample =
|
||||||
|
param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
|
||||||
|
auto d_feature_flags = feature_flags.data();
|
||||||
|
|
||||||
|
dh::launch_n(hist.LevelSize(depth), [=] __device__(int idx) {
|
||||||
|
int node_segment = idx / n_bins;
|
||||||
|
int node_idx = n_nodes(depth - 1) + node_segment;
|
||||||
|
gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
|
||||||
|
float parent_gain = d_nodes[node_idx].root_gain;
|
||||||
|
int gidx = idx % n_bins;
|
||||||
|
int findex = d_gidx_feature_map[gidx];
|
||||||
|
|
||||||
|
// colsample
|
||||||
|
if (colsample && d_feature_flags[d_gidx_feature_map[gidx]] == 0) {
|
||||||
|
d_gain[idx] = 0;
|
||||||
|
} else {
|
||||||
|
gpu_gpair scan = d_hist_scan[idx];
|
||||||
|
gpu_gpair sum = d_node_sums[node_segment * n_features + findex];
|
||||||
|
gpu_gpair missing = parent_sum - sum;
|
||||||
|
|
||||||
|
bool missing_left;
|
||||||
|
d_gain[idx] = loss_chg_missing(scan, missing, parent_sum, parent_gain,
|
||||||
|
gpu_param_alias, missing_left);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
dh::safe_cuda(cudaDeviceSynchronize());
|
||||||
|
|
||||||
|
// Find best gain
|
||||||
|
{
|
||||||
|
size_t temp_storage_bytes;
|
||||||
|
cub::DeviceSegmentedReduce::ArgMax(nullptr, temp_storage_bytes, gain.data(),
|
||||||
|
argmax.data(), n_nodes_level(depth),
|
||||||
|
hist_node_segments.data(),
|
||||||
|
hist_node_segments.data() + 1);
|
||||||
|
cub_mem.LazyAllocate(temp_storage_bytes);
|
||||||
|
cub::DeviceSegmentedReduce::ArgMax(
|
||||||
|
cub_mem.d_temp_storage, cub_mem.temp_storage_bytes, gain.data(),
|
||||||
|
argmax.data(), n_nodes_level(depth), hist_node_segments.data(),
|
||||||
|
hist_node_segments.data() + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto d_argmax = argmax.data();
|
||||||
|
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
|
||||||
|
auto d_fidx_min_map = fidx_min_map.data();
|
||||||
|
dh::launch_n(n_nodes_level(depth), [=] __device__(int idx) {
|
||||||
|
int max_idx = n_bins * idx + d_argmax[idx].key;
|
||||||
|
int gidx = max_idx % n_bins;
|
||||||
|
int fidx = d_gidx_feature_map[gidx];
|
||||||
|
int node_segment = max_idx / n_bins;
|
||||||
|
int node_idx = n_nodes(depth - 1) + node_segment;
|
||||||
|
gpu_gpair scan = d_hist_scan[max_idx];
|
||||||
|
gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
|
||||||
|
float parent_gain = d_nodes[node_idx].root_gain;
|
||||||
|
gpu_gpair sum = d_node_sums[node_segment * n_features + fidx];
|
||||||
|
gpu_gpair missing = parent_sum - sum;
|
||||||
|
|
||||||
|
bool missing_left;
|
||||||
|
float loss_chg = loss_chg_missing(scan, missing, parent_sum, parent_gain,
|
||||||
|
gpu_param_alias, missing_left);
|
||||||
|
|
||||||
|
float fvalue;
|
||||||
|
if (gidx == 0 || fidx != d_gidx_feature_map[gidx - 1]) {
|
||||||
|
fvalue = d_fidx_min_map[fidx];
|
||||||
|
} else {
|
||||||
|
fvalue = d_gidx_fvalue_map[gidx - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
gpu_gpair left = missing_left ? scan + missing : scan;
|
||||||
|
gpu_gpair right = parent_sum - left;
|
||||||
|
d_nodes[node_idx].split.Update(loss_chg, missing_left, fvalue, fidx, left,
|
||||||
|
right, gpu_param_alias);
|
||||||
|
|
||||||
|
int left_child_idx = n_nodes(depth) + idx * 2;
|
||||||
|
int right_child_idx = n_nodes(depth) + idx * 2 + 1;
|
||||||
|
d_nodes[left_child_idx] =
|
||||||
|
Node(left, CalcGain(gpu_param_alias, left.grad(), left.hess()),
|
||||||
|
CalcWeight(gpu_param_alias, left.grad(), left.hess()));
|
||||||
|
|
||||||
|
d_nodes[right_child_idx] =
|
||||||
|
Node(right, CalcGain(gpu_param_alias, right.grad(), right.hess()),
|
||||||
|
CalcWeight(gpu_param_alias, right.grad(), right.hess()));
|
||||||
|
});
|
||||||
|
dh::safe_cuda(cudaDeviceSynchronize());
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUHistBuilder::InitFirstNode() {
|
||||||
|
// Build the root node on the CPU and copy to device
|
||||||
|
gpu_gpair sum_gradients =
|
||||||
|
thrust::reduce(device_gpair.tbegin(), device_gpair.tend(),
|
||||||
|
gpu_gpair(0, 0), thrust::plus<gpu_gpair>());
|
||||||
|
|
||||||
|
Node tmp =
|
||||||
|
Node(sum_gradients,
|
||||||
|
CalcGain(param, sum_gradients.grad(), sum_gradients.hess()),
|
||||||
|
CalcWeight(param, sum_gradients.grad(), sum_gradients.hess()));
|
||||||
|
|
||||||
|
thrust::copy_n(&tmp, 1, nodes.tbegin());
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUHistBuilder::UpdatePosition() {
|
||||||
|
if (is_dense) {
|
||||||
|
this->UpdatePositionDense();
|
||||||
|
} else {
|
||||||
|
this->UpdatePositionSparse();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUHistBuilder::UpdatePositionDense() {
|
||||||
|
auto d_position = position.data();
|
||||||
|
Node* d_nodes = nodes.data();
|
||||||
|
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
|
||||||
|
auto d_gidx = device_matrix.gidx.data();
|
||||||
|
int n_columns = info->num_col;
|
||||||
|
|
||||||
|
dh::launch_n(position.size(), [=] __device__(int idx) {
|
||||||
|
NodeIdT pos = d_position[idx];
|
||||||
|
if (pos < 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Node node = d_nodes[pos];
|
||||||
|
|
||||||
|
if (node.IsLeaf()) {
|
||||||
|
d_position[idx] = -1;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int gidx = d_gidx[idx * n_columns + node.split.findex];
|
||||||
|
|
||||||
|
float fvalue = d_gidx_fvalue_map[gidx];
|
||||||
|
|
||||||
|
if (fvalue <= node.split.fvalue) {
|
||||||
|
d_position[idx] = pos * 2 + 1;
|
||||||
|
} else {
|
||||||
|
d_position[idx] = pos * 2 + 2;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUHistBuilder::UpdatePositionSparse() {
|
||||||
|
auto d_position = position.data();
|
||||||
|
auto d_position_tmp = position_tmp.data();
|
||||||
|
Node* d_nodes = nodes.data();
|
||||||
|
auto d_gidx_feature_map = gidx_feature_map.data();
|
||||||
|
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
|
||||||
|
auto d_gidx = device_matrix.gidx.data();
|
||||||
|
auto d_ridx = device_matrix.ridx.data();
|
||||||
|
|
||||||
|
// Update missing direction
|
||||||
|
dh::launch_n(position.size(), [=] __device__(int idx) {
|
||||||
|
NodeIdT pos = d_position[idx];
|
||||||
|
if (pos < 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Node node = d_nodes[pos];
|
||||||
|
|
||||||
|
if (node.IsLeaf()) {
|
||||||
|
d_position_tmp[idx] = -1;
|
||||||
|
} else if (node.split.missing_left) {
|
||||||
|
d_position_tmp[idx] = pos * 2 + 1;
|
||||||
|
} else {
|
||||||
|
d_position_tmp[idx] = pos * 2 + 2;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Update node based on fvalue where exists
|
||||||
|
dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) {
|
||||||
|
int ridx = d_ridx[idx];
|
||||||
|
NodeIdT pos = d_position[ridx];
|
||||||
|
if (pos < 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Node node = d_nodes[pos];
|
||||||
|
|
||||||
|
if (node.IsLeaf()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int gidx = d_gidx[idx];
|
||||||
|
int findex = d_gidx_feature_map[gidx];
|
||||||
|
|
||||||
|
if (findex == node.split.findex) {
|
||||||
|
float fvalue = d_gidx_fvalue_map[gidx];
|
||||||
|
|
||||||
|
if (fvalue <= node.split.fvalue) {
|
||||||
|
d_position_tmp[ridx] = pos * 2 + 1;
|
||||||
|
} else {
|
||||||
|
d_position_tmp[ridx] = pos * 2 + 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
position = position_tmp;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUHistBuilder::ColSampleTree() {
  // Start from the full list of column indices 0 .. num_col - 1 ...
  std::vector<int> all_columns(info->num_col);
  std::iota(all_columns.begin(), all_columns.end(), 0);
  // ... and keep the subset selected by col_sample for this tree.
  feature_set_tree = col_sample(all_columns, param.colsample_bytree);
}
|
||||||
|
|
||||||
|
void GPUHistBuilder::ColSampleLevel() {
|
||||||
|
feature_set_level.resize(feature_set_tree.size());
|
||||||
|
feature_set_level = col_sample(feature_set_tree, param.colsample_bylevel);
|
||||||
|
std::vector<int> h_feature_flags(info->num_col, 0);
|
||||||
|
for (auto fidx : feature_set_level) {
|
||||||
|
h_feature_flags[fidx] = 1;
|
||||||
|
}
|
||||||
|
feature_flags = h_feature_flags;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||||
|
DMatrix& fmat, // NOLINT
|
||||||
|
const RegTree& tree) {
|
||||||
|
if (!initialised) {
|
||||||
|
CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column "
|
||||||
|
"block. Try setting 'tree_method' "
|
||||||
|
"parameter to 'exact'";
|
||||||
|
info = &fmat.info();
|
||||||
|
is_dense = info->num_nonzero == info->num_col * info->num_row;
|
||||||
|
hmat_.Init(&fmat, param.max_bin);
|
||||||
|
gmat_.cut = &hmat_;
|
||||||
|
gmat_.Init(&fmat);
|
||||||
|
int n_bins = hmat_.row_ptr.back();
|
||||||
|
int n_features = hmat_.row_ptr.size() - 1;
|
||||||
|
|
||||||
|
// Build feature segments
|
||||||
|
std::vector<int> h_feature_segments;
|
||||||
|
for (int node = 0; node < n_nodes_level(param.max_depth - 1); node++) {
|
||||||
|
for (int fidx = 0; fidx < hmat_.row_ptr.size() - 1; fidx++) {
|
||||||
|
h_feature_segments.push_back(hmat_.row_ptr[fidx] + node * n_bins);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h_feature_segments.push_back(n_nodes_level(param.max_depth - 1) * n_bins);
|
||||||
|
|
||||||
|
int level_max_bins = n_nodes_level(param.max_depth - 1) * n_bins;
|
||||||
|
|
||||||
|
size_t free_memory = dh::available_memory();
|
||||||
|
ba.allocate(
|
||||||
|
&gidx_feature_map, n_bins, &hist_node_segments,
|
||||||
|
n_nodes_level(param.max_depth - 1) + 1, &feature_segments,
|
||||||
|
h_feature_segments.size(), &gain, level_max_bins, &position,
|
||||||
|
gpair.size(), &position_tmp, gpair.size(), &nodes,
|
||||||
|
n_nodes(param.max_depth), &gidx_fvalue_map, hmat_.cut.size(),
|
||||||
|
&fidx_min_map, hmat_.min_val.size(), &argmax,
|
||||||
|
n_nodes_level(param.max_depth - 1), &node_sums,
|
||||||
|
n_nodes_level(param.max_depth - 1) * n_features, &hist_scan,
|
||||||
|
level_max_bins, &device_gpair, gpair.size(), &device_matrix.gidx,
|
||||||
|
gmat_.index.size(), &device_matrix.ridx, gmat_.index.size(), &hist.hist,
|
||||||
|
n_nodes(param.max_depth - 1) * n_bins, &feature_flags, n_features);
|
||||||
|
|
||||||
|
if (!param.silent) {
|
||||||
|
const int mb_size = 1048576;
|
||||||
|
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
|
||||||
|
<< free_memory / mb_size << " MB on " << dh::device_name();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct feature map
|
||||||
|
std::vector<int> h_gidx_feature_map(n_bins);
|
||||||
|
for (int row = 0; row < hmat_.row_ptr.size() - 1; row++) {
|
||||||
|
for (int i = hmat_.row_ptr[row]; i < hmat_.row_ptr[row + 1]; i++) {
|
||||||
|
h_gidx_feature_map[i] = row;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gidx_feature_map = h_gidx_feature_map;
|
||||||
|
|
||||||
|
// Construct device matrix
|
||||||
|
device_matrix.Init(gmat_);
|
||||||
|
|
||||||
|
gidx_fvalue_map = hmat_.cut;
|
||||||
|
fidx_min_map = hmat_.min_val;
|
||||||
|
|
||||||
|
thrust::sequence(hist_node_segments.tbegin(), hist_node_segments.tend(), 0,
|
||||||
|
n_bins);
|
||||||
|
|
||||||
|
feature_segments = h_feature_segments;
|
||||||
|
|
||||||
|
hist.Init(n_bins);
|
||||||
|
|
||||||
|
initialised = true;
|
||||||
|
}
|
||||||
|
nodes.fill(Node());
|
||||||
|
position.fill(0);
|
||||||
|
device_gpair = gpair;
|
||||||
|
subsample_gpair(&device_gpair, param.subsample);
|
||||||
|
hist.Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
|
||||||
|
DMatrix* p_fmat, RegTree* p_tree) {
|
||||||
|
this->InitData(gpair, *p_fmat, *p_tree);
|
||||||
|
this->InitFirstNode();
|
||||||
|
this->ColSampleTree();
|
||||||
|
|
||||||
|
for (int depth = 0; depth < param.max_depth; depth++) {
|
||||||
|
this->ColSampleLevel();
|
||||||
|
this->BuildHist(depth);
|
||||||
|
this->FindSplit(depth);
|
||||||
|
this->UpdatePosition();
|
||||||
|
}
|
||||||
|
dense2sparse_tree(p_tree, nodes.tbegin(), nodes.tend(), param);
|
||||||
|
}
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
104
plugin/updater_gpu/src/gpu_hist_builder.cuh
Normal file
104
plugin/updater_gpu/src/gpu_hist_builder.cuh
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2016 Rory mitchell
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <cusparse.h>
|
||||||
|
#include <thrust/device_vector.h>
|
||||||
|
#include <xgboost/tree_updater.h>
|
||||||
|
#include <cub/util_type.cuh> // Need key value pair definition
|
||||||
|
#include <vector>
|
||||||
|
#include "../../src/common/hist_util.h"
|
||||||
|
#include "../../src/tree/param.h"
|
||||||
|
#include "device_helpers.cuh"
|
||||||
|
#include "types.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
|
||||||
|
namespace tree {
|
||||||
|
|
||||||
|
struct DeviceGMat {
|
||||||
|
dh::dvec<int> gidx;
|
||||||
|
dh::dvec<int> ridx;
|
||||||
|
void Init(const common::GHistIndexMatrix &gmat);
|
||||||
|
};
|
||||||
|
|
||||||
|
struct HistBuilder {
|
||||||
|
gpu_gpair *d_hist;
|
||||||
|
int n_bins;
|
||||||
|
__host__ __device__ HistBuilder(gpu_gpair *ptr, int n_bins);
|
||||||
|
__device__ void Add(gpu_gpair gpair, int gidx, int nidx) const;
|
||||||
|
__device__ gpu_gpair Get(int gidx, int nidx) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DeviceHist {
|
||||||
|
int n_bins;
|
||||||
|
dh::dvec<gpu_gpair> hist;
|
||||||
|
|
||||||
|
void Init(int max_depth);
|
||||||
|
|
||||||
|
void Reset();
|
||||||
|
|
||||||
|
HistBuilder GetBuilder();
|
||||||
|
|
||||||
|
gpu_gpair *GetLevelPtr(int depth);
|
||||||
|
|
||||||
|
int LevelSize(int depth);
|
||||||
|
};
|
||||||
|
|
||||||
|
class GPUHistBuilder {
|
||||||
|
public:
|
||||||
|
GPUHistBuilder();
|
||||||
|
~GPUHistBuilder();
|
||||||
|
void Init(const TrainParam ¶m);
|
||||||
|
|
||||||
|
void UpdateParam(const TrainParam ¶m) {
|
||||||
|
this->param = param;
|
||||||
|
this->gpu_param = GPUTrainingParam(param.min_child_weight, param.reg_lambda,
|
||||||
|
param.reg_alpha, param.max_delta_step);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
|
||||||
|
const RegTree &tree);
|
||||||
|
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
|
||||||
|
RegTree *p_tree);
|
||||||
|
void BuildHist(int depth);
|
||||||
|
void FindSplit(int depth);
|
||||||
|
void InitFirstNode();
|
||||||
|
void UpdatePosition();
|
||||||
|
void UpdatePositionDense();
|
||||||
|
void UpdatePositionSparse();
|
||||||
|
void ColSampleTree();
|
||||||
|
void ColSampleLevel();
|
||||||
|
|
||||||
|
TrainParam param;
|
||||||
|
GPUTrainingParam gpu_param;
|
||||||
|
common::HistCutMatrix hmat_;
|
||||||
|
common::GHistIndexMatrix gmat_;
|
||||||
|
MetaInfo *info;
|
||||||
|
bool initialised;
|
||||||
|
bool is_dense;
|
||||||
|
DeviceGMat device_matrix;
|
||||||
|
|
||||||
|
dh::bulk_allocator ba;
|
||||||
|
dh::CubMemory cub_mem;
|
||||||
|
dh::dvec<int> gidx_feature_map;
|
||||||
|
dh::dvec<int> hist_node_segments;
|
||||||
|
dh::dvec<int> feature_segments;
|
||||||
|
dh::dvec<float> gain;
|
||||||
|
dh::dvec<NodeIdT> position;
|
||||||
|
dh::dvec<NodeIdT> position_tmp;
|
||||||
|
dh::dvec<float> gidx_fvalue_map;
|
||||||
|
dh::dvec<float> fidx_min_map;
|
||||||
|
DeviceHist hist;
|
||||||
|
dh::dvec<cub::KeyValuePair<int, float>> argmax;
|
||||||
|
dh::dvec<gpu_gpair> node_sums;
|
||||||
|
dh::dvec<gpu_gpair> hist_scan;
|
||||||
|
dh::dvec<gpu_gpair> device_gpair;
|
||||||
|
dh::dvec<Node> nodes;
|
||||||
|
dh::dvec<int> feature_flags;
|
||||||
|
|
||||||
|
std::vector<int> feature_set_tree;
|
||||||
|
std::vector<int> feature_set_level;
|
||||||
|
};
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
@ -1,52 +0,0 @@
|
|||||||
/*!
|
|
||||||
* Copyright 2016 Rory mitchell
|
|
||||||
*/
|
|
||||||
#pragma once
|
|
||||||
#include "types.cuh"
|
|
||||||
#include "../../../src/tree/param.h"
|
|
||||||
|
|
||||||
// When we split on a value which has no left neighbour, define its left
|
|
||||||
// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS
|
|
||||||
// This produces a split value slightly lower than the current instance
|
|
||||||
#define FVALUE_EPS 0.0001
|
|
||||||
|
|
||||||
namespace xgboost {
|
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
|
|
||||||
// Loss change of a candidate split.
//   scan        - gradient sum of the non-missing instances going left
//   missing     - gradient sum of the instances with a missing value
//   parent_sum  - gradient sum of the node being split
//   parent_gain - gain of the node being split
//   missing_left - whether the missing instances are sent to the left child
// Returns gain(left) + gain(right) - parent_gain.
__device__ __forceinline__ float
device_calc_loss_chg(const GPUTrainingParam &param, const gpu_gpair &scan,
                     const gpu_gpair &missing, const gpu_gpair &parent_sum,
                     const float &parent_gain, bool missing_left) {
  gpu_gpair left = scan;

  if (missing_left) {
    left += missing;
  }

  // Whatever is not on the left is on the right.
  gpu_gpair right = parent_sum - left;

  float left_gain = CalcGain(param, left.grad(), left.hess());
  float right_gain = CalcGain(param, right.grad(), right.hess());
  return left_gain + right_gain - parent_gain;
}
|
|
||||||
|
|
||||||
// Evaluates a candidate split with the missing values sent left and sent
// right, and returns the larger loss change. `missing_left_out` receives the
// chosen direction; ties go to the left.
__device__ __forceinline__ float
loss_chg_missing(const gpu_gpair &scan, const gpu_gpair &missing,
                 const gpu_gpair &parent_sum, const float &parent_gain,
                 const GPUTrainingParam &param, bool &missing_left_out) { // NOLINT
  float missing_left_loss =
      device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true);
  float missing_right_loss = device_calc_loss_chg(
      param, scan, missing, parent_sum, parent_gain, false);

  if (missing_left_loss >= missing_right_loss) {
    missing_left_out = true;
    return missing_left_loss;
  } else {
    missing_left_out = false;
    return missing_right_loss;
  }
}
|
|
||||||
} // namespace tree
|
|
||||||
} // namespace xgboost
|
|
||||||
@ -1,8 +1,11 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2016 Rory mitchell
|
* Copyright 2016 Rory mitchell
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
#include <thrust/device_vector.h>
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
|
#include <xgboost/tree_model.h>
|
||||||
|
#include <cfloat>
|
||||||
#include <tuple> // The linter is not very smart and thinks we need this
|
#include <tuple> // The linter is not very smart and thinks we need this
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -104,8 +107,10 @@ struct GPUTrainingParam {
|
|||||||
__host__ __device__ GPUTrainingParam(float min_child_weight_in,
|
__host__ __device__ GPUTrainingParam(float min_child_weight_in,
|
||||||
float reg_lambda_in, float reg_alpha_in,
|
float reg_lambda_in, float reg_alpha_in,
|
||||||
float max_delta_step_in)
|
float max_delta_step_in)
|
||||||
: min_child_weight(min_child_weight_in), reg_lambda(reg_lambda_in),
|
: min_child_weight(min_child_weight_in),
|
||||||
reg_alpha(reg_alpha_in), max_delta_step(max_delta_step_in) {}
|
reg_lambda(reg_lambda_in),
|
||||||
|
reg_alpha(reg_alpha_in),
|
||||||
|
max_delta_step(max_delta_step_in) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Split {
|
struct Split {
|
||||||
@ -123,8 +128,9 @@ struct Split {
|
|||||||
float fvalue_in, int findex_in, gpu_gpair left_sum_in,
|
float fvalue_in, int findex_in, gpu_gpair left_sum_in,
|
||||||
gpu_gpair right_sum_in,
|
gpu_gpair right_sum_in,
|
||||||
const GPUTrainingParam ¶m) {
|
const GPUTrainingParam ¶m) {
|
||||||
if (loss_chg_in > loss_chg && left_sum_in.hess() > param.min_child_weight &&
|
if (loss_chg_in > loss_chg &&
|
||||||
right_sum_in.hess() > param.min_child_weight) {
|
left_sum_in.hess() >= param.min_child_weight &&
|
||||||
|
right_sum_in.hess() >= param.min_child_weight) {
|
||||||
loss_chg = loss_chg_in;
|
loss_chg = loss_chg_in;
|
||||||
missing_left = missing_left_in;
|
missing_left = missing_left_in;
|
||||||
fvalue = fvalue_in;
|
fvalue = fvalue_in;
|
||||||
@ -135,7 +141,7 @@ struct Split {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Does not check minimum weight
|
// Does not check minimum weight
|
||||||
__device__ void Update(Split &s) { // NOLINT
|
__device__ void Update(Split &s) { // NOLINT
|
||||||
if (s.loss_chg > loss_chg) {
|
if (s.loss_chg > loss_chg) {
|
||||||
loss_chg = s.loss_chg;
|
loss_chg = s.loss_chg;
|
||||||
missing_left = s.missing_left;
|
missing_left = s.missing_left;
|
||||||
@ -146,7 +152,7 @@ struct Split {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ void Print() {
|
__host__ __device__ void Print() {
|
||||||
printf("Loss: %1.4f\n", loss_chg);
|
printf("Loss: %1.4f\n", loss_chg);
|
||||||
printf("Missing left: %d\n", missing_left);
|
printf("Missing left: %d\n", missing_left);
|
||||||
printf("fvalue: %1.4f\n", fvalue);
|
printf("fvalue: %1.4f\n", fvalue);
|
||||||
@ -160,7 +166,7 @@ struct Split {
|
|||||||
|
|
||||||
struct split_reduce_op {
|
struct split_reduce_op {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ __forceinline__ T operator()(T &a, T b) { // NOLINT
|
__device__ __forceinline__ T operator()(T &a, T b) { // NOLINT
|
||||||
b.Update(a);
|
b.Update(a);
|
||||||
return b;
|
return b;
|
||||||
}
|
}
|
||||||
@ -184,5 +190,6 @@ struct Node {
|
|||||||
|
|
||||||
__host__ __device__ bool IsLeaf() { return split.loss_chg == -FLT_MAX; }
|
__host__ __device__ bool IsLeaf() { return split.loss_chg == -FLT_MAX; }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,6 +0,0 @@
|
|||||||
/*!
|
|
||||||
* Copyright 2016 Rory mitchell
|
|
||||||
*/
|
|
||||||
#pragma once
|
|
||||||
#include "types.cuh"
|
|
||||||
#include "loss_functions.cuh"
|
|
||||||
@ -7,30 +7,37 @@
|
|||||||
#include "../../src/common/sync.h"
|
#include "../../src/common/sync.h"
|
||||||
#include "../../src/tree/param.h"
|
#include "../../src/tree/param.h"
|
||||||
#include "gpu_builder.cuh"
|
#include "gpu_builder.cuh"
|
||||||
|
#include "gpu_hist_builder.cuh"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
|
DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
|
||||||
|
|
||||||
/*! \brief column-wise update to construct a tree */
|
/*! \brief column-wise update to construct a tree */
|
||||||
template <typename TStats> class GPUMaker : public TreeUpdater {
|
template <typename TStats>
|
||||||
|
class GPUMaker : public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
void
|
void Init(
|
||||||
Init(const std::vector<std::pair<std::string, std::string>> &args) override {
|
const std::vector<std::pair<std::string, std::string>>& args) override {
|
||||||
param.InitAllowUnknown(args);
|
param.InitAllowUnknown(args);
|
||||||
builder.Init(param);
|
builder.Init(param);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Update(const std::vector<bst_gpair> &gpair, DMatrix *dmat,
|
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
|
||||||
const std::vector<RegTree *> &trees) override {
|
const std::vector<RegTree*>& trees) override {
|
||||||
TStats::CheckInfo(dmat->info());
|
TStats::CheckInfo(dmat->info());
|
||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
float lr = param.learning_rate;
|
float lr = param.learning_rate;
|
||||||
param.learning_rate = lr / trees.size();
|
param.learning_rate = lr / trees.size();
|
||||||
builder.UpdateParam(param);
|
builder.UpdateParam(param);
|
||||||
// build tree
|
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
try {
|
||||||
builder.Update(gpair, dmat, trees[i]);
|
// build tree
|
||||||
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
|
builder.Update(gpair, dmat, trees[i]);
|
||||||
|
}
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
|
||||||
}
|
}
|
||||||
param.learning_rate = lr;
|
param.learning_rate = lr;
|
||||||
}
|
}
|
||||||
@ -41,9 +48,45 @@ template <typename TStats> class GPUMaker : public TreeUpdater {
|
|||||||
GPUBuilder builder;
|
GPUBuilder builder;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename TStats>
|
||||||
|
class GPUHistMaker : public TreeUpdater {
|
||||||
|
public:
|
||||||
|
void Init(
|
||||||
|
const std::vector<std::pair<std::string, std::string>>& args) override {
|
||||||
|
param.InitAllowUnknown(args);
|
||||||
|
builder.Init(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
|
||||||
|
const std::vector<RegTree*>& trees) override {
|
||||||
|
TStats::CheckInfo(dmat->info());
|
||||||
|
// rescale learning rate according to size of trees
|
||||||
|
float lr = param.learning_rate;
|
||||||
|
param.learning_rate = lr / trees.size();
|
||||||
|
builder.UpdateParam(param);
|
||||||
|
// build tree
|
||||||
|
try {
|
||||||
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
|
builder.Update(gpair, dmat, trees[i]);
|
||||||
|
}
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
|
||||||
|
}
|
||||||
|
param.learning_rate = lr;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// training parameter
|
||||||
|
TrainParam param;
|
||||||
|
GPUHistBuilder builder;
|
||||||
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
|
XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
|
||||||
.describe("Grow tree with GPU.")
|
.describe("Grow tree with GPU.")
|
||||||
.set_body([]() { return new GPUMaker<GradStats>(); });
|
.set_body([]() { return new GPUMaker<GradStats>(); });
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
||||||
|
.describe("Grow tree with GPU.")
|
||||||
|
.set_body([]() { return new GPUHistMaker<GradStats>(); });
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,137 +0,0 @@
|
|||||||
#pylint: skip-file
|
|
||||||
import numpy as np
|
|
||||||
import xgboost as xgb
|
|
||||||
import os
|
|
||||||
import pandas as pd
|
|
||||||
import urllib2
|
|
||||||
|
|
||||||
class bcolors:
    # ANSI escape sequences used to colour / style terminal output.
    # Wrap text as: bcolors.X + text + bcolors.ENDC
    HEADER    = '\033[95m'
    OKBLUE    = '\033[94m'
    OKGREEN   = '\033[92m'
    WARNING   = '\033[93m'
    FAIL      = '\033[91m'
    ENDC      = '\033[0m'   # resets all attributes
    BOLD      = '\033[1m'
    UNDERLINE = '\033[4m'
|
|
||||||
|
|
||||||
|
|
||||||
def get_last_eval_callback(result):
    # Returns a training callback that appends, on every call, the value of
    # the last entry in env.evaluation_result_list to `result`.
    def _record_last_metric(env):
        last_entry = env.evaluation_result_list[-1]
        result.append(last_entry[1])

    # Marker attribute read by the caller of the callback.
    _record_last_metric.after_iteration = True
    return _record_last_metric
|
|
||||||
|
|
||||||
|
|
||||||
def load_adult():
    # Loads the UCI "adult" data set as an xgb.DMatrix with 'wage_class' as
    # label. The file is downloaded on first use.
    path = "../../demo/data/adult.data"

    if not os.path.isfile(path):
        remote = urllib2.urlopen('http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data')
        with open(path, 'wb') as local_file:
            local_file.write(remote.read())

    frame = pd.read_csv(path, header=None)
    frame.columns = [
        'age', 'workclass', 'fnlwgt', 'education', 'education_num',
        'marital_status', 'occupation', 'relationship', 'race', 'sex',
        'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
        'wage_class',
    ]

    # Rows containing the ' ?' placeholder are dropped entirely.
    complete = frame.replace(' ?', np.nan).dropna()

    # Replace every string-valued column by integer category codes.
    for column in complete.columns:
        if complete[column].dtype == 'object':
            complete[column] = pd.Categorical(complete[column]).codes

    labels = complete.pop('wage_class')
    return xgb.DMatrix(complete, label=labels)
|
|
||||||
|
|
||||||
|
|
||||||
def load_higgs():
    # Loads the Higgs training CSV as a weighted xgb.DMatrix.
    higgs_path = '../../demo/data/training.csv'

    # Column 32 holds the class letter; it is converted to 1 for 's', else 0.
    label_converter = {32: lambda x:int(x=='s'.encode('utf-8'))}
    table = np.loadtxt(higgs_path, delimiter=',', skiprows=1,
                       converters=label_converter)

    features = table[:, 1:31]
    weights = table[:, 31]
    labels = table[:, 32]

    # -999.0 marks a missing value in this data set.
    return xgb.DMatrix(features, label=labels, missing=-999.0, weight=weights)
|
|
||||||
|
|
||||||
def load_dermatology():
    # Loads the dermatology data set as an xgb.DMatrix.
    # Column 33 is converted to a 0/1 "is '?'" flag and column 34 (the class)
    # is shifted to start at 0. The unused `sz = data.shape` local of the
    # previous version has been removed.
    data = np.loadtxt('../../demo/data/dermatology.data', delimiter=',',converters={33: lambda x:int(x == '?'), 34: lambda x:int(x)-1 } )

    X = data[:,0:33]
    Y = data[:, 34]

    return xgb.DMatrix( X, label=Y)
|
|
||||||
|
|
||||||
def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
    # True when a and b differ by no more than the larger of the relative
    # tolerance (scaled by the bigger magnitude) and the absolute tolerance.
    largest_magnitude = max(abs(a), abs(b))
    allowed = max(rel_tol * largest_magnitude, abs_tol)
    return abs(a - b) <= allowed
|
|
||||||
|
|
||||||
#Check GPU test evaluation is approximately equal to CPU test evaluation
|
|
||||||
def check_result(cpu_result, gpu_result):
    # True when every CPU value has an approximately equal GPU value at the
    # same index (10% relative / 0.02 absolute tolerance).
    return all(
        isclose(cpu_result[i], gpu_result[i], 0.1, 0.02)
        for i in range(len(cpu_result))
    )
|
|
||||||
|
|
||||||
|
|
||||||
#Get data
|
|
||||||
# Data sets and the extra booster parameters each one needs (same order).
data = [
    load_higgs(),
    load_adult(),
    xgb.DMatrix('../../demo/data/agaricus.txt.test'),
    #if(os.path.isfile("../../demo/data/dermatology.data")):
    load_dermatology(),
]
params = [
    {},
    {},
    {'objective':'binary:logistic'},
    {'objective':'multi:softmax', 'num_class': 6},
]

num_round = 5

num_pass = 0
num_fail = 0

test_depth = [ 1, 6, 9, 11, 15 ]
#test_depth = [ 1 ]

# For every data set and depth, cross-validate with the CPU updater and the
# GPU updater and compare the last evaluation metric of each round.
for xgmat, param in zip(data, params):
    for depth in test_depth:
        param['max_depth'] = depth

        cpu_result = []
        param['updater'] = 'grow_colmaker'
        xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(cpu_result)])

        #bst = xgb.train( param, xgmat, 1);
        #bst.dump_model('reference_model.txt','', True)

        gpu_result = []
        param['updater'] = 'grow_gpu'
        xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(gpu_result)])

        #bst = xgb.train( param, xgmat, 1);
        #bst.dump_model('dump.raw.txt','', True)

        if check_result(cpu_result, gpu_result):
            print(bcolors.OKGREEN + "Pass" + bcolors.ENDC)
            num_pass += 1
        else:
            print(bcolors.FAIL + "Fail" + bcolors.ENDC)
            num_fail += 1

        print("cpu rmse: "+str(cpu_result))
        print("gpu rmse: "+str(gpu_result))

print("%d/%d passed" % (num_pass, num_pass + num_fail))
|
|
||||||
221
plugin/updater_gpu/test/test.py
Normal file
221
plugin/updater_gpu/test/test.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
#pylint: skip-file
|
||||||
|
import sys
|
||||||
|
sys.path.append("../../tests/python")
|
||||||
|
import xgboost as xgb
|
||||||
|
import testing as tm
|
||||||
|
import numpy as np
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
rng = np.random.RandomState(1994)
|
||||||
|
|
||||||
|
dpath = '../../demo/data/'
|
||||||
|
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
|
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||||
|
|
||||||
|
|
||||||
|
class TestGPU(unittest.TestCase):
|
||||||
|
def test_grow_gpu(self):
    # Checks the 'grow_gpu' updater: it must reproduce the exact CPU method
    # on the agaricus data, and AUC must not decrease across rounds on the
    # digits data and on dense / perturbed agaricus data.
    tm._skip_if_no_sklearn()
    from sklearn.datasets import load_digits
    try:
        from sklearn.model_selection import train_test_split
    except ImportError:
        # Older scikit-learn releases keep it in cross_validation. Only an
        # import failure should trigger the fallback; a bare except would
        # also hide unrelated errors.
        from sklearn.cross_validation import train_test_split

    ag_param = {'max_depth': 2,
                'tree_method': 'exact',
                'nthread': 1,
                'eta': 1,
                'silent': 1,
                'objective': 'binary:logistic',
                'eval_metric': 'auc'}
    ag_param2 = {'max_depth': 2,
                 'updater': 'grow_gpu',
                 'eta': 1,
                 'silent': 1,
                 'objective': 'binary:logistic',
                 'eval_metric': 'auc'}
    ag_res = {}
    ag_res2 = {}

    num_rounds = 10
    xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
              evals_result=ag_res)
    xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
              evals_result=ag_res2)
    assert ag_res['train']['auc'] == ag_res2['train']['auc']
    assert ag_res['test']['auc'] == ag_res2['test']['auc']

    digits = load_digits(2)
    X = digits['data']
    y = digits['target']
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    dtrain = xgb.DMatrix(X_train, y_train)
    dtest = xgb.DMatrix(X_test, y_test)

    param = {'objective': 'binary:logistic',
             'updater': 'grow_gpu',
             'max_depth': 3,
             'eval_metric': 'auc'}
    res = {}
    xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
              evals_result=res)
    assert self.non_decreasing(res['train']['auc'])
    assert self.non_decreasing(res['test']['auc'])

    # fail-safe test for dense data
    from sklearn.datasets import load_svmlight_file
    X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
    X2 = X2.toarray()
    dtrain2 = xgb.DMatrix(X2, label=y2)

    param = {'objective': 'binary:logistic',
             'updater': 'grow_gpu',
             'max_depth': 2,
             'eval_metric': 'auc'}
    res = {}
    xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)

    assert self.non_decreasing(res['train']['auc'])
    assert res['train']['auc'][0] >= 0.85

    # Overwrite 10 random rows of every column with a new value.
    for j in range(X2.shape[1]):
        for i in rng.choice(X2.shape[0], size=10, replace=False):
            X2[i, j] = 2

    dtrain3 = xgb.DMatrix(X2, label=y2)
    res = {}

    xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)

    assert self.non_decreasing(res['train']['auc'])
    assert res['train']['auc'][0] >= 0.85

    # Second perturbation. Uses the seeded module-level `rng` like the loop
    # above (previously the unseeded np.random) so the test is reproducible.
    for j in range(X2.shape[1]):
        for i in rng.choice(X2.shape[0], size=10, replace=False):
            X2[i, j] = 3

    dtrain4 = xgb.DMatrix(X2, label=y2)
    res = {}
    xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res)
    assert self.non_decreasing(res['train']['auc'])
    assert res['train']['auc'][0] >= 0.85
|
||||||
|
|
||||||
|
|
||||||
|
def test_grow_gpu_hist(self):
|
||||||
|
tm._skip_if_no_sklearn()
|
||||||
|
from sklearn.datasets import load_digits
|
||||||
|
try:
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
except:
|
||||||
|
from sklearn.cross_validation import train_test_split
|
||||||
|
|
||||||
|
# regression test --- hist must be same as exact on all-categorial data
|
||||||
|
ag_param = {'max_depth': 2,
|
||||||
|
'tree_method': 'exact',
|
||||||
|
'nthread': 1,
|
||||||
|
'eta': 1,
|
||||||
|
'silent': 1,
|
||||||
|
'objective': 'binary:logistic',
|
||||||
|
'eval_metric': 'auc'}
|
||||||
|
ag_param2 = {'max_depth': 2,
|
||||||
|
'updater': 'grow_gpu_hist',
|
||||||
|
'eta': 1,
|
||||||
|
'silent': 1,
|
||||||
|
'objective': 'binary:logistic',
|
||||||
|
'eval_metric': 'auc'}
|
||||||
|
ag_res = {}
|
||||||
|
ag_res2 = {}
|
||||||
|
|
||||||
|
num_rounds = 10
|
||||||
|
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||||
|
evals_result=ag_res)
|
||||||
|
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||||
|
evals_result=ag_res2)
|
||||||
|
assert ag_res['train']['auc'] == ag_res2['train']['auc']
|
||||||
|
assert ag_res['test']['auc'] == ag_res2['test']['auc']
|
||||||
|
|
||||||
|
digits = load_digits(2)
|
||||||
|
X = digits['data']
|
||||||
|
y = digits['target']
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
|
||||||
|
dtrain = xgb.DMatrix(X_train, y_train)
|
||||||
|
dtest = xgb.DMatrix(X_test, y_test)
|
||||||
|
|
||||||
|
param = {'objective': 'binary:logistic',
|
||||||
|
'updater': 'grow_gpu_hist',
|
||||||
|
'max_depth': 3,
|
||||||
|
'eval_metric': 'auc'}
|
||||||
|
res = {}
|
||||||
|
xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
|
||||||
|
evals_result=res)
|
||||||
|
assert self.non_decreasing(res['train']['auc'])
|
||||||
|
assert self.non_decreasing(res['test']['auc'])
|
||||||
|
|
||||||
|
# fail-safe test for dense data
|
||||||
|
from sklearn.datasets import load_svmlight_file
|
||||||
|
X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
|
||||||
|
X2 = X2.toarray()
|
||||||
|
dtrain2 = xgb.DMatrix(X2, label=y2)
|
||||||
|
|
||||||
|
param = {'objective': 'binary:logistic',
|
||||||
|
'updater': 'grow_gpu_hist',
|
||||||
|
'grow_policy': 'depthwise',
|
||||||
|
'max_depth': 2,
|
||||||
|
'eval_metric': 'auc'}
|
||||||
|
res = {}
|
||||||
|
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
|
||||||
|
|
||||||
|
assert self.non_decreasing(res['train']['auc'])
|
||||||
|
assert res['train']['auc'][0] >= 0.85
|
||||||
|
|
||||||
|
for j in range(X2.shape[1]):
|
||||||
|
for i in rng.choice(X2.shape[0], size=10, replace=False):
|
||||||
|
X2[i, j] = 2
|
||||||
|
|
||||||
|
dtrain3 = xgb.DMatrix(X2, label=y2)
|
||||||
|
res = {}
|
||||||
|
|
||||||
|
xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
|
||||||
|
|
||||||
|
assert self.non_decreasing(res['train']['auc'])
|
||||||
|
assert res['train']['auc'][0] >= 0.85
|
||||||
|
|
||||||
|
for j in range(X2.shape[1]):
|
||||||
|
for i in np.random.choice(X2.shape[0], size=10, replace=False):
|
||||||
|
X2[i, j] = 3
|
||||||
|
|
||||||
|
dtrain4 = xgb.DMatrix(X2, label=y2)
|
||||||
|
res = {}
|
||||||
|
xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res)
|
||||||
|
assert self.non_decreasing(res['train']['auc'])
|
||||||
|
assert res['train']['auc'][0] >= 0.85
|
||||||
|
|
||||||
|
# fail-safe test for max_bin=2
|
||||||
|
param = {'objective': 'binary:logistic',
|
||||||
|
'updater': 'grow_gpu_hist',
|
||||||
|
'max_depth': 2,
|
||||||
|
'eval_metric': 'auc',
|
||||||
|
'max_bin': 2}
|
||||||
|
res = {}
|
||||||
|
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
|
||||||
|
assert self.non_decreasing(res['train']['auc'])
|
||||||
|
assert res['train']['auc'][0] >= 0.85
|
||||||
|
|
||||||
|
# subsampling
|
||||||
|
param = {'objective': 'binary:logistic',
|
||||||
|
'updater': 'grow_gpu_hist',
|
||||||
|
'max_depth': 3,
|
||||||
|
'eval_metric': 'auc',
|
||||||
|
'colsample_bytree': 0.5,
|
||||||
|
'colsample_bylevel': 0.5,
|
||||||
|
'subsample': 0.5
|
||||||
|
}
|
||||||
|
res = {}
|
||||||
|
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
|
||||||
|
assert self.non_decreasing(res['train']['auc'])
|
||||||
|
assert res['train']['auc'][0] >= 0.85
|
||||||
|
|
||||||
|
|
||||||
|
def non_decreasing(self, L):
|
||||||
|
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
|
||||||
@ -79,6 +79,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
bool refresh_leaf;
|
bool refresh_leaf;
|
||||||
// auxiliary data structure
|
// auxiliary data structure
|
||||||
std::vector<int> monotone_constraints;
|
std::vector<int> monotone_constraints;
|
||||||
|
// gpu to use for single gpu algorithms
|
||||||
|
int gpu_id;
|
||||||
// declare the parameters
|
// declare the parameters
|
||||||
DMLC_DECLARE_PARAMETER(TrainParam) {
|
DMLC_DECLARE_PARAMETER(TrainParam) {
|
||||||
DMLC_DECLARE_FIELD(learning_rate)
|
DMLC_DECLARE_FIELD(learning_rate)
|
||||||
@ -186,6 +188,10 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
DMLC_DECLARE_FIELD(monotone_constraints)
|
DMLC_DECLARE_FIELD(monotone_constraints)
|
||||||
.set_default(std::vector<int>())
|
.set_default(std::vector<int>())
|
||||||
.describe("Constraint of variable monotonicity");
|
.describe("Constraint of variable monotonicity");
|
||||||
|
DMLC_DECLARE_FIELD(gpu_id)
|
||||||
|
.set_lower_bound(0)
|
||||||
|
.set_default(0)
|
||||||
|
.describe("gpu to use for single gpu algorithms");
|
||||||
// add alias of parameters
|
// add alias of parameters
|
||||||
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
||||||
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
|
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user