[GPU-Plugin] (#2227)

* Add fast histogram algorithm
* Fix Linux build
* Add 'gpu_id' parameter
Rory Mitchell 2017-04-26 11:37:10 +12:00 committed by Tianqi Chen
parent d281c6aafa
commit 8ab5d4611c
25 changed files with 1318 additions and 492 deletions

View File

@ -97,7 +97,7 @@ if(PLUGIN_UPDATER_GPU)
#Find cub
set(CUB_DIRECTORY "" CACHE PATH "CUB 1.5.4 directory")
include_directories(${CUB_DIRECTORY})
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35;")
if(NOT MSVC)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC")
endif()
@ -105,8 +105,9 @@ if(PLUGIN_UPDATER_GPU)
plugin/updater_gpu/src/updater_gpu.cc
)
find_package(CUDA QUIET REQUIRED)
file(GLOB_RECURSE CUDA_SOURCES "plugin/updater_gpu/src/*")
cuda_add_library(updater_gpu STATIC
plugin/updater_gpu/src/gpu_builder.cu
${CUDA_SOURCES}
)
endif()

View File

@ -63,3 +63,5 @@ List of Contributors
* [Alex Bain](https://github.com/convexquad)
* [Baltazar Bieniek](https://github.com/bbieniek)
* [Adam Pocock](https://github.com/Craigacp)
* [Rory Mitchell](https://github.com/RAMitchell)
- Rory is the author of the GPU plugin and also contributed the CMake build system and Windows continuous integration.

View File

@ -21,6 +21,7 @@ The same code runs on major distributed environment (Hadoop, SGE, MPI) and can s
What's New
----------
* [XGBoost GPU support with fast histogram algorithm](https://github.com/dmlc/xgboost/tree/master/plugin/updater_gpu)
* [XGBoost4J: Portable Distributed XGBoost in Spark, Flink and Dataflow](http://dmlc.ml/2016/03/14/xgboost4j-portable-distributed-xgboost-in-spark-flink-and-dataflow.html), see [JVM-Package](https://github.com/dmlc/xgboost/tree/master/jvm-packages)
* [Story and Lessons Behind the Evolution of XGBoost](http://homes.cs.washington.edu/~tqchen/2016/03/10/story-and-lessons-behind-the-evolution-of-xgboost.html)
* [Tutorial: Distributed XGBoost on AWS with YARN](https://xgboost.readthedocs.io/en/latest/tutorials/aws_yarn.html)

View File

@ -48,7 +48,7 @@
#define XGBOOST_ALIGNAS(X)
#endif
#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8
#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8 && !defined(__CUDACC__)
#include <parallel/algorithm>
#define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z))
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) __gnu_parallel::stable_sort((X), (Y), (Z))

View File

@ -1,29 +1,25 @@
# CUDA Accelerated Tree Construction Algorithm
## Benchmarks
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for performance benchmarks
# CUDA Accelerated Tree Construction Algorithms
This plugin adds GPU accelerated tree construction algorithms to XGBoost.
## Usage
Specify the updater parameter as 'grow_gpu'.
Specify the 'updater' parameter as one of the following algorithms.
updater | Description
--- | ---
grow_gpu | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'grow_gpu_hist'.
grow_gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Faster and uses considerably less memory. Splits may be less accurate.
All algorithms currently use only a single GPU. The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0.
This plugin currently works with the CLI and Python versions of XGBoost.
Python example:
```python
param['gpu_id'] = 1
param['updater'] = 'grow_gpu'
```
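A fuller end-to-end sketch is shown below (illustrative only; it assumes scikit-learn is available for generating synthetic data, mirroring the benchmark script in this plugin):
```python
import xgboost as xgb
from sklearn.datasets import make_classification

# Synthetic binary classification problem
X, y = make_classification(10000, n_features=20, random_state=7)
dtrain = xgb.DMatrix(X, label=y)

param = {'objective': 'binary:logistic',
         'updater': 'grow_gpu_hist',  # or 'grow_gpu' for the exact algorithm
         'gpu_id': 0,
         'max_depth': 6,
         'eval_metric': 'auc'}
bst = xgb.train(param, dtrain, num_boost_round=10, evals=[(dtrain, 'train')])
```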
## Benchmarks
## Memory usage
Device memory usage can be calculated as approximately:
```
bytes = (10 x n_rows) + (40 x n_rows x n_columns x column_density) + (64 x max_nodes) + (76 x max_nodes_level x n_columns)
```
The maximum number of nodes needed for a given tree depth d is 2<sup>d+1</sup> - 1. The maximum number of nodes on any given level is 2<sup>d</sup>.
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for performance benchmarks of the 'grow_gpu' updater.
Data is stored in a sparse format. For example, missing values produced by one-hot encoding are not stored. If a one-hot encoding separates a categorical variable into 5 columns, the density of these columns is 1/5 = 0.2.
A 4GB graphics card will process approximately 3.5 million rows of the well-known Kaggle Higgs dataset.
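As a rough feasibility check, the estimate above can be evaluated directly before training. The function below is an illustrative sketch (the name and defaults are not part of the plugin):
```python
def approx_device_bytes(n_rows, n_columns, column_density=1.0, max_depth=8):
    # Node counts follow the formulas above
    max_nodes = 2 ** (max_depth + 1) - 1
    max_nodes_level = 2 ** max_depth
    return (10 * n_rows
            + 40 * n_rows * n_columns * column_density
            + 64 * max_nodes
            + 76 * max_nodes_level * n_columns)

# Roughly 4GB for 3.5 million rows of the 30-column Higgs data at depth 15
print(approx_device_bytes(3.5e6, 30, column_density=1.0, max_depth=15) / 2**30, "GiB")
```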
## Dependencies
A CUDA-capable GPU with compute capability 3.5 or higher (the algorithms depend on shuffle and vote instructions introduced in Kepler).
@ -58,6 +54,15 @@ On Windows cmake will generate an xgboost.sln solution file in the build directory
The build process generates an xgboost library and executable as normal but containing the GPU tree construction algorithm.
## Changelog
##### 2017/4/25
* Add fast histogram algorithm
* Fix Linux build
* Add 'gpu_id' parameter
## References
[Mitchell, Rory, and Eibe Frank. Accelerating the XGBoost algorithm using GPU computing. No. e2911v1. PeerJ Preprints, 2017.](https://peerj.com/preprints/2911/)
## Author
Rory Mitchell

View File

@ -0,0 +1,32 @@
#pylint: skip-file
import xgboost as xgb
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import time
n = 1000000
num_rounds = 100
X,y = make_classification(n, n_features=50, random_state=7)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)
param = {'objective': 'binary:logistic',
'tree_method': 'exact',
'updater': 'grow_gpu_hist',
'max_depth': 8,
'silent': 1,
'eval_metric': 'auc'}
res = {}
tmp = time.time()
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
print ("GPU: %s seconds" % (str(time.time() - tmp)))
tmp = time.time()
param['updater'] = 'grow_fast_histmaker'
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')], evals_result=res)
print ("CPU: %s seconds" % (str(time.time() - tmp)))

View File

@ -1,65 +0,0 @@
#!/usr/bin/python
#pylint: skip-file
# this is the example script to use xgboost to train
import numpy as np
import xgboost as xgb
import time
# path to where the data lies
dpath = '../../demo/data'
# load in training data, directly use numpy
dtrain = np.loadtxt( dpath+'/training.csv', delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s') } )
dtrain = np.concatenate((dtrain, np.copy(dtrain)))
dtrain = np.concatenate((dtrain, np.copy(dtrain)))
dtrain = np.concatenate((dtrain, np.copy(dtrain)))
test_size = len(dtrain)
print(len(dtrain))
print ('finish loading from csv ')
label = dtrain[:,32]
data = dtrain[:,1:31]
# rescale weight to make it same as test set
weight = dtrain[:,31] * float(test_size) / len(label)
sum_wpos = sum( weight[i] for i in range(len(label)) if label[i] == 1.0 )
sum_wneg = sum( weight[i] for i in range(len(label)) if label[i] == 0.0 )
# print weight statistics
print ('weight statistics: wpos=%g, wneg=%g, ratio=%g' % ( sum_wpos, sum_wneg, sum_wneg/sum_wpos ))
# construct xgboost.DMatrix from numpy array, treat -999.0 as missing value
xgmat = xgb.DMatrix( data, label=label, missing = -999.0, weight=weight )
# setup parameters for xgboost
param = {}
# use logistic regression loss
param['objective'] = 'binary:logitraw'
# scale weight of positive examples
param['scale_pos_weight'] = sum_wneg/sum_wpos
param['bst:eta'] = 0.1
param['max_depth'] = 15
param['eval_metric'] = 'auc'
param['nthread'] = 16
plst = param.items()+[('eval_metric', 'ams@0.15')]
watchlist = [ (xgmat,'train') ]
num_round = 10
print ("training xgboost")
threads = [16]
for i in threads:
param['nthread'] = i
tmp = time.time()
plst = param.items()+[('eval_metric', 'ams@0.15')]
bst = xgb.train( plst, xgmat, num_round, watchlist );
print ("XGBoost with %d thread costs: %s seconds" % (i, str(time.time() - tmp)))
print ("training xgboost - gpu tree construction")
param['updater'] = 'grow_gpu'
tmp = time.time()
plst = param.items()+[('eval_metric', 'ams@0.15')]
bst = xgb.train( plst, xgmat, num_round, watchlist );
print ("XGBoost GPU: %s seconds" % (str(time.time() - tmp)))
print ('finish training')

View File

@ -0,0 +1,153 @@
/*!
* Copyright 2016 Rory Mitchell
*/
#pragma once
#include <vector>
#include "../../../src/common/random.h"
#include "../../../src/tree/param.h"
#include "device_helpers.cuh"
#include "types.cuh"
namespace xgboost {
namespace tree {
// When we split on a value which has no left neighbour, define its left
// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS
// This produces a split value slightly lower than the current instance
#define FVALUE_EPS 0.0001
__device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
const gpu_gpair& scan,
const gpu_gpair& missing,
const gpu_gpair& parent_sum,
const float& parent_gain,
bool missing_left) {
gpu_gpair left = scan;
if (missing_left) {
left += missing;
}
gpu_gpair right = parent_sum - left;
float left_gain = CalcGain(param, left.grad(), left.hess());
float right_gain = CalcGain(param, right.grad(), right.hess());
return left_gain + right_gain - parent_gain;
}
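// Choose the default direction for missing values by evaluating the split
// with the missing gradient sum sent left and then right, keeping whichever
// direction yields the larger loss reduction (reported via missing_left_out).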
__device__ float inline loss_chg_missing(const gpu_gpair& scan,
const gpu_gpair& missing,
const gpu_gpair& parent_sum,
const float& parent_gain,
const GPUTrainingParam& param,
bool& missing_left_out) { // NOLINT
float missing_left_loss =
device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true);
float missing_right_loss = device_calc_loss_chg(
param, scan, missing, parent_sum, parent_gain, false);
if (missing_left_loss >= missing_right_loss) {
missing_left_out = true;
return missing_left_loss;
} else {
missing_left_out = false;
return missing_right_loss;
}
}
// Total number of nodes in tree, given depth
__host__ __device__ inline int n_nodes(int depth) {
return (1 << (depth + 1)) - 1;
}
// Number of nodes at this level of the tree
__host__ __device__ inline int n_nodes_level(int depth) { return 1 << depth; }
enum NodeType {
NODE = 0,
LEAF = 1,
UNUSED = 2,
};
// Recursively label node types
inline void flag_nodes(const thrust::host_vector<Node>& nodes,
std::vector<NodeType>* node_flags, int nid,
NodeType type) {
if (nid >= nodes.size() || type == UNUSED) {
return;
}
const Node& n = nodes[nid];
// Current node and all children are valid
if (n.split.loss_chg > rt_eps) {
(*node_flags)[nid] = NODE;
flag_nodes(nodes, node_flags, nid * 2 + 1, NODE);
flag_nodes(nodes, node_flags, nid * 2 + 2, NODE);
} else {
// Current node is leaf, therefore is valid but all children are invalid
(*node_flags)[nid] = LEAF;
flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED);
flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED);
}
}
// Copy gpu dense representation of tree to xgboost sparse representation
inline void dense2sparse_tree(RegTree* p_tree,
thrust::device_ptr<Node> nodes_begin,
thrust::device_ptr<Node> nodes_end,
const TrainParam& param) {
RegTree & tree = *p_tree;
thrust::host_vector<Node> h_nodes(nodes_begin, nodes_end);
std::vector<NodeType> node_flags(h_nodes.size(), UNUSED);
flag_nodes(h_nodes, &node_flags, 0, NODE);
int nid = 0;
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
NodeType flag = node_flags[gpu_nid];
const Node& n = h_nodes[gpu_nid];
if (flag == NODE) {
tree.AddChilds(nid);
tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left);
tree.stat(nid).loss_chg = n.split.loss_chg;
tree.stat(nid).base_weight = n.weight;
tree.stat(nid).sum_hess = n.sum_gradients.hess();
tree[tree[nid].cleft()].set_leaf(0);
tree[tree[nid].cright()].set_leaf(0);
nid++;
} else if (flag == LEAF) {
tree[nid].set_leaf(n.weight * param.learning_rate);
tree.stat(nid).sum_hess = n.sum_gradients.hess();
nid++;
}
}
}
// Set gradient pair to 0 with p = 1 - subsample
inline void subsample_gpair(dh::dvec<gpu_gpair>* p_gpair, float subsample) {
if (subsample == 1.0) {
return;
}
dh::dvec<gpu_gpair>& gpair = *p_gpair;
auto d_gpair = gpair.data();
dh::BernoulliRng rng(subsample, common::GlobalRandom()());
dh::launch_n(gpair.size(), [=] __device__(int i) {
if (!rng(i)) {
d_gpair[i] = gpu_gpair();
}
});
}
inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
int n = colsample * features.size();
CHECK_GT(n, 0);
std::shuffle(features.begin(), features.end(), common::GlobalRandom());
features.resize(n);
return features;
}
} // namespace tree
} // namespace xgboost

View File

@ -5,14 +5,15 @@
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <thrust/device_vector.h>
#include <thrust/random.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include <thrust/random.h>
#include <algorithm>
#include <ctime>
#include <sstream>
#include <string>
#include <vector>
#include "cusparse_v2.h"
#ifdef _WIN32
#include <windows.h>
@ -30,7 +31,8 @@ namespace dh {
#define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__)
cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) {
inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file,
int line) {
if (code != cudaSuccess) {
std::stringstream ss;
ss << file << "(" << line << ")";
@ -41,6 +43,20 @@ cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) {
return code;
}
#define safe_cusparse(ans) throw_on_cusparse_error((ans), __FILE__, __LINE__)
inline cusparseStatus_t throw_on_cusparse_error(cusparseStatus_t status,
const char *file, int line) {
if (status != CUSPARSE_STATUS_SUCCESS) {
std::stringstream ss;
ss << "cusparse error: " << file << "(" << line << ")";
std::string error_text;
ss >> error_text;
throw error_text;
}
return status;
}
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
@ -49,8 +65,7 @@ inline void gpuAssert(cudaError_t code, const char *file, int line,
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line);
if (abort)
exit(code);
if (abort) exit(code);
}
}
@ -224,21 +239,22 @@ class range {
iterator end_;
};
template <typename T> __device__ range grid_stride_range(T begin, T end) {
template <typename T>
__device__ range grid_stride_range(T begin, T end) {
begin += blockDim.x * blockIdx.x + threadIdx.x;
range r(begin, end);
r.step(gridDim.x * blockDim.x);
return r;
}
template <typename T> __device__ range block_stride_range(T begin, T end) {
template <typename T>
__device__ range block_stride_range(T begin, T end) {
begin += threadIdx.x;
range r(begin, end);
r.step(blockDim.x);
return r;
}
// Threadblock iterates over range, filling with value
template <typename IterT, typename ValueT>
__device__ void block_fill(IterT begin, size_t n, ValueT value) {
@ -253,7 +269,8 @@ __device__ void block_fill(IterT begin, size_t n, ValueT value) {
class bulk_allocator;
template <typename T> class dvec {
template <typename T>
class dvec {
friend bulk_allocator;
private:
@ -302,7 +319,8 @@ template <typename T> class dvec {
return thrust::device_pointer_cast(_ptr + size());
}
template <typename T2> dvec &operator=(const std::vector<T2> &other) {
template <typename T2>
dvec &operator=(const std::vector<T2> &other) {
if (other.size() != size()) {
throw std::runtime_error(
"Cannot copy assign vector to dvec, sizes are different");
@ -331,7 +349,8 @@ class bulk_allocator {
const size_t align = 256;
template <typename SizeT> size_t align_round_up(SizeT n) {
template <typename SizeT>
size_t align_round_up(SizeT n) {
if (n % align == 0) {
return n;
} else {
@ -366,14 +385,15 @@ class bulk_allocator {
bulk_allocator() : _size(0), d_ptr(NULL) {}
~bulk_allocator() {
if (!d_ptr == NULL) {
if (!(d_ptr == nullptr)) {
safe_cuda(cudaFree(d_ptr));
}
}
size_t size() { return _size; }
template <typename... Args> void allocate(Args... args) {
template <typename... Args>
void allocate(Args... args) {
if (d_ptr != NULL) {
throw std::runtime_error("Bulk allocator already allocated");
}
@ -393,14 +413,19 @@ struct CubMemory {
CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {}
~CubMemory() {
~CubMemory() { Free(); }
void Free() {
if (d_temp_storage != NULL) {
safe_cuda(cudaFree(d_temp_storage));
}
}
void Allocate() {
safe_cuda(cudaMalloc(&d_temp_storage, temp_storage_bytes));
void LazyAllocate(size_t n_bytes) {
if (n_bytes > temp_storage_bytes) {
Free();
safe_cuda(cudaMalloc(&d_temp_storage, n_bytes));
temp_storage_bytes = n_bytes;
}
}
bool IsAllocated() { return d_temp_storage != NULL; }
@ -453,36 +478,47 @@ void print(char *label, const thrust::device_vector<T> &v,
std::cout << "\n";
}
template <typename T1, typename T2> T1 div_round_up(const T1 a, const T2 b) {
template <typename T1, typename T2>
T1 div_round_up(const T1 a, const T2 b) {
return static_cast<T1>(ceil(static_cast<double>(a) / b));
}
template <typename T> thrust::device_ptr<T> dptr(T *d_ptr) {
template <typename T>
thrust::device_ptr<T> dptr(T *d_ptr) {
return thrust::device_pointer_cast(d_ptr);
}
template <typename T> T *raw(thrust::device_vector<T> &v) { // NOLINT
template <typename T>
T *raw(thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
template <typename T> size_t size_bytes(const thrust::device_vector<T> &v) {
template <typename T>
const T *raw(const thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
template <typename T>
size_t size_bytes(const thrust::device_vector<T> &v) {
return sizeof(T) * v.size();
}
/*
* Kernel launcher
*/
template <typename L> __global__ void launch_n_kernel(size_t n, L lambda) {
template <typename L>
__global__ void launch_n_kernel(size_t n, L lambda) {
for (auto i : grid_stride_range(static_cast<size_t>(0), n)) {
lambda(i);
}
}
template <typename L, int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256>
template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
inline void launch_n(size_t n, L lambda) {
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
#if defined(__CUDACC__)
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(n, lambda);
#endif
}
/*
@ -504,5 +540,4 @@ struct BernoulliRng {
}
};
} // namespace dh

View File

@ -9,7 +9,7 @@
#include "find_split_multiscan.cuh"
#include "find_split_sorting.cuh"
#include "gpu_data.cuh"
#include "types_functions.cuh"
#include "types.cuh"
namespace xgboost {
namespace tree {

View File

@ -6,7 +6,8 @@
#include <xgboost/base.h>
#include "device_helpers.cuh"
#include "gpu_data.cuh"
#include "types_functions.cuh"
#include "types.cuh"
#include "common.cuh"
namespace xgboost {
namespace tree {

View File

@ -5,7 +5,8 @@
#include <cub/cub.cuh>
#include <xgboost/base.h>
#include "device_helpers.cuh"
#include "types_functions.cuh"
#include "types.cuh"
#include "common.cuh"
namespace xgboost {
namespace tree {

View File

@ -0,0 +1,15 @@
/*!
* Copyright 2016 Rory Mitchell
*/
#pragma once
#include "types.cuh"
#include "../../../src/tree/param.h"
#include "../../../src/common/random.h"
namespace xgboost {
namespace tree {
} // namespace tree
} // namespace xgboost

View File

@ -1,6 +1,7 @@
/*!
* Copyright 2016 Rory Mitchell
*/
#include "gpu_builder.cuh"
#include <cub/cub.cuh>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
@ -11,24 +12,28 @@
#include <thrust/host_vector.h>
#include <thrust/sequence.h>
#include <algorithm>
#include <random>
#include <numeric>
#include <random>
#include <vector>
#include "../../../src/common/random.h"
#include "common.cuh"
#include "device_helpers.cuh"
#include "find_split.cuh"
#include "gpu_builder.cuh"
#include "types_functions.cuh"
#include "gpu_data.cuh"
#include "types.cuh"
namespace xgboost {
namespace tree {
GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); }
void GPUBuilder::Init(const TrainParam& param_in) {
param = param_in;
CHECK(param.max_depth < 16) << "Tree depth too large.";
dh::safe_cuda(cudaSetDevice(param.gpu_id));
if (!param.silent) {
LOG(CONSOLE) << "Device: [" << param.gpu_id << "] " << dh::device_name();
}
}
GPUBuilder::~GPUBuilder() { delete gpu_data; }
@ -128,8 +133,8 @@ void GPUBuilder::Sort(int level) {
}
void GPUBuilder::ColsampleTree() {
unsigned n = static_cast<unsigned>(
param.colsample_bytree * gpu_data->n_features);
unsigned n =
static_cast<unsigned>(param.colsample_bytree * gpu_data->n_features);
CHECK_GT(n, 0);
feature_set_tree.resize(gpu_data->n_features);
@ -140,63 +145,40 @@ void GPUBuilder::ColsampleTree() {
void GPUBuilder::Update(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
RegTree* p_tree) {
try {
dh::Timer update;
dh::Timer t;
this->InitData(gpair, *p_fmat, *p_tree);
t.printElapsed("init data");
this->InitFirstNode();
this->ColsampleTree();
for (int level = 0; level < param.max_depth; level++) {
bool use_multiscan_algorithm = level < multiscan_levels;
t.reset();
if (level > 0) {
dh::Timer update_node;
this->UpdateNodeId(level);
update_node.printElapsed("node");
}
if (level > 0 && !use_multiscan_algorithm) {
dh::Timer s;
this->Sort(level);
s.printElapsed("sort");
}
dh::Timer split;
find_split(gpu_data, param, level, use_multiscan_algorithm,
feature_set_tree, &feature_set_level);
split.printElapsed("split");
t.printElapsed("level");
}
this->CopyTree(*p_tree);
update.printElapsed("update");
} catch (thrust::system_error &e) {
std::cerr << "CUDA error: " << e.what() << std::endl;
exit(-1);
} catch (const std::exception &e) {
std::cerr << "Error: " << e.what() << std::endl;
exit(-1);
} catch (...) {
std::cerr << "Unknown exception." << std::endl;
exit(-1);
}
dense2sparse_tree(p_tree, gpu_data->nodes.tbegin(), gpu_data->nodes.tend(),
param);
}
void GPUBuilder::InitData(const std::vector<bst_gpair>& gpair, DMatrix& fmat,
const RegTree& tree) {
CHECK(fmat.SingleColBlock()) << "GPUMaker: must have single column block";
CHECK(fmat.SingleColBlock()) << "grow_gpu: must have single column block. "
"Try setting 'tree_method' parameter to "
"'exact'";
if (gpu_data->IsAllocated()) {
gpu_data->Reset(gpair, param.subsample);
return;
}
dh::Timer t;
MetaInfo info = fmat.info();
std::vector<int> foffsets;
@ -227,12 +209,8 @@ void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
}
}
t.printElapsed("dmatrix");
t.reset();
gpu_data->Init(fvalues, foffsets, instance_id, feature_id, gpair,
info.num_row, info.num_col, param.max_depth, param);
t.printElapsed("gpu init");
}
void GPUBuilder::InitFirstNode() {
@ -248,60 +226,5 @@ void GPUBuilder::InitFirstNode() {
thrust::copy_n(&tmp, 1, gpu_data->nodes.tbegin());
}
enum NodeType {
NODE = 0,
LEAF = 1,
UNUSED = 2,
};
// Recursively label node types
void flag_nodes(const thrust::host_vector<Node> &nodes,
std::vector<NodeType> *node_flags, int nid, NodeType type) {
if (nid >= nodes.size() || type == UNUSED) {
return;
}
const Node &n = nodes[nid];
// Current node and all children are valid
if (n.split.loss_chg > rt_eps) {
(*node_flags)[nid] = NODE;
flag_nodes(nodes, node_flags, nid * 2 + 1, NODE);
flag_nodes(nodes, node_flags, nid * 2 + 2, NODE);
} else {
// Current node is leaf, therefore is valid but all children are invalid
(*node_flags)[nid] = LEAF;
flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED);
flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED);
}
}
// Copy gpu dense representation of tree to xgboost sparse representation
void GPUBuilder::CopyTree(RegTree &tree) {
std::vector<Node> h_nodes = gpu_data->nodes.as_vector();
std::vector<NodeType> node_flags(h_nodes.size(), UNUSED);
flag_nodes(h_nodes, &node_flags, 0, NODE);
int nid = 0;
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
NodeType flag = node_flags[gpu_nid];
const Node &n = h_nodes[gpu_nid];
if (flag == NODE) {
tree.AddChilds(nid);
tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left);
tree.stat(nid).loss_chg = n.split.loss_chg;
tree.stat(nid).base_weight = n.weight;
tree.stat(nid).sum_hess = n.sum_gradients.hess();
tree[tree[nid].cleft()].set_leaf(0);
tree[tree[nid].cright()].set_leaf(0);
nid++;
} else if (flag == LEAF) {
tree[nid].set_leaf(n.weight * param.learning_rate);
tree.stat(nid).sum_hess = n.sum_gradients.hess();
nid++;
}
}
}
} // namespace tree
} // namespace xgboost

View File

@ -19,10 +19,7 @@ class GPUBuilder {
void Init(const TrainParam &param);
~GPUBuilder();
void UpdateParam(const TrainParam &param)
{
this->param = param;
}
void UpdateParam(const TrainParam &param) { this->param = param; }
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
RegTree *p_tree);
@ -33,10 +30,8 @@ class GPUBuilder {
void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
const RegTree &tree);
float GetSubsamplingRate(MetaInfo info);
void Sort(int level);
void InitFirstNode();
void CopyTree(RegTree &tree); // NOLINT
void ColsampleTree();
TrainParam param;

View File

@ -2,13 +2,14 @@
* Copyright 2016 Rory Mitchell
*/
#pragma once
#include <cub/cub.cuh>
#include <xgboost/logging.h>
#include <thrust/sequence.h>
#include <xgboost/logging.h>
#include <cub/cub.cuh>
#include <vector>
#include "device_helpers.cuh"
#include "../../src/tree/param.h"
#include "types_functions.cuh"
#include "common.cuh"
#include "device_helpers.cuh"
#include "types.cuh"
namespace xgboost {
namespace tree {
@ -67,9 +68,8 @@ struct GPUData {
cub::DoubleBuffer<int> db_value;
cub::DeviceSegmentedRadixSort::SortPairs(
cub_mem.data(), cub_mem_size, db_key,
db_value, in_fvalues.size(), n_features,
foffsets.data(), foffsets.data() + 1);
cub_mem.data(), cub_mem_size, db_key, db_value, in_fvalues.size(),
n_features, foffsets.data(), foffsets.data() + 1);
// Allocate memory
size_t free_memory = dh::available_memory();
@ -100,7 +100,6 @@ struct GPUData {
param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
param_in.reg_alpha, param_in.max_delta_step);
allocated = true;
this->Reset(in_gpair, param_in.subsample);
@ -109,33 +108,16 @@ struct GPUData {
thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
fvalues.tbegin(), node_id.tbegin()));
dh::safe_cuda(cudaGetLastError());
}
~GPUData() {}
// Set gradient pair to 0 with p = 1 - subsample
void MarkSubsample(float subsample) {
if (subsample == 1.0) {
return;
}
auto d_gpair = gpair.data();
dh::BernoulliRng rng(subsample, common::GlobalRandom()());
dh::launch_n(n_instances, [=] __device__(int i) {
if (!rng(i)) {
d_gpair[i] = gpu_gpair();
}
});
}
// Reset memory for new boosting iteration
void Reset(const std::vector<bst_gpair> &in_gpair, float subsample) {
CHECK(allocated);
gpair = in_gpair;
this->MarkSubsample(subsample);
subsample_gpair(&gpair, subsample);
instance_id = instance_id_cached;
fvalues = fvalues_cached;
nodes.fill(Node());
@ -153,7 +135,6 @@ struct GPUData {
auto d_instance_id = instance_id.data();
dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
// Item item = d_items[i];
d_node_id[i] = d_node_id_instance[d_instance_id[i]];
});
}

View File

@ -0,0 +1,560 @@
/*!
* Copyright 2017 Rory Mitchell
*/
#include <thrust/binary_search.h>
#include <thrust/count.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <cub/cub.cuh>
#include <numeric>
#include <algorithm>
#include <functional>
#include "common.cuh"
#include "device_helpers.cuh"
#include "gpu_hist_builder.cuh"
namespace xgboost {
namespace tree {
void DeviceGMat::Init(const common::GHistIndexMatrix& gmat) {
CHECK_EQ(gidx.size(), gmat.index.size())
<< "gidx must be externally allocated";
CHECK_EQ(ridx.size(), gmat.index.size())
<< "ridx must be externally allocated";
gidx = gmat.index;
thrust::device_vector<int> row_ptr = gmat.row_ptr;
auto counting = thrust::make_counting_iterator(0);
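// Expand the CSR row pointers into a per-entry row index: for each entry
// position, upper_bound returns the index of the first row_ptr element
// greater than it, and the transform below subtracts one so that ridx[i]
// is the row containing entry i.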
thrust::upper_bound(row_ptr.begin(), row_ptr.end(), counting,
counting + gidx.size(), ridx.tbegin());
thrust::transform(ridx.tbegin(), ridx.tend(), ridx.tbegin(),
[=] __device__(int val) { return val - 1; });
}
void DeviceHist::Init(int n_bins_in) {
this->n_bins = n_bins_in;
CHECK(!hist.empty()) << "DeviceHist must be externally allocated";
}
void DeviceHist::Reset() { hist.fill(gpu_gpair()); }
gpu_gpair* DeviceHist::GetLevelPtr(int depth) {
return hist.data() + n_nodes(depth - 1) * n_bins;
}
int DeviceHist::LevelSize(int depth) { return n_bins * n_nodes_level(depth); }
HistBuilder DeviceHist::GetBuilder() {
return HistBuilder(hist.data(), n_bins);
}
HistBuilder::HistBuilder(gpu_gpair* ptr, int n_bins)
: d_hist(ptr), n_bins(n_bins) {}
__device__ void HistBuilder::Add(gpu_gpair gpair, int gidx, int nidx) const {
int hist_idx = nidx * n_bins + gidx;
atomicAdd(&(d_hist[hist_idx]._grad), gpair._grad);
atomicAdd(&(d_hist[hist_idx]._hess), gpair._hess);
}
__device__ gpu_gpair HistBuilder::Get(int gidx, int nidx) const {
return d_hist[nidx * n_bins + gidx];
}
GPUHistBuilder::GPUHistBuilder() {}
GPUHistBuilder::~GPUHistBuilder() {}
void GPUHistBuilder::Init(const TrainParam& param) {
CHECK(param.max_depth < 16) << "Tree depth too large.";
this->param = param;
initialised = false;
is_dense = false;
dh::safe_cuda(cudaSetDevice(param.gpu_id));
if (!param.silent) {
LOG(CONSOLE) << "Device: [" << param.gpu_id << "] " << dh::device_name();
}
}
template <typename ReductionOpT>
struct ReduceBySegmentOp {
ReductionOpT op;
__host__ __device__ __forceinline__ ReduceBySegmentOp() {}
__host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op)
: op(op) {}
template <typename KeyValuePairT>
__host__ __device__ __forceinline__ KeyValuePairT
operator()(const KeyValuePairT& first, const KeyValuePairT& second) {
KeyValuePairT retval;
retval.key = second.key;
retval.value =
first.key != second.key ? second.value : op(first.value, second.value);
return retval;
}
};
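// Per-tile histogram accumulation strategy (as implemented below): each
// thread block loads a tile of (ridx, gidx) entries, sorts them by their
// destination histogram bin, sums gradient pairs sharing a bin with a
// segmented inclusive scan, and issues one atomicAdd per bin at the tail
// of each run rather than one atomic per entry.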
template <int ITEMS_PER_THREAD, int BLOCK_THREADS>
__global__ void hist_kernel(gpu_gpair* d_dense_hist, int* d_ridx, int* d_gidx,
NodeIdT* d_position, gpu_gpair* d_gpair, int n_bins,
int depth, int n) {
typedef cub::KeyValuePair<int, gpu_gpair> OffsetValuePairT;
typedef cub::BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD,
cub::BLOCK_LOAD_VECTORIZE>
BlockLoadT;
typedef cub::BlockRadixSort<int, BLOCK_THREADS, ITEMS_PER_THREAD, int>
BlockRadixSortT;
typedef cub::BlockDiscontinuity<int, BLOCK_THREADS> BlockDiscontinuityKeysT;
typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;
typedef cub::BlockScan<OffsetValuePairT, BLOCK_THREADS,
cub::BLOCK_SCAN_WARP_SCANS>
BlockScanT;
union TempStorage {
typename BlockLoadT::TempStorage load;
typename BlockRadixSortT::TempStorage sort;
typename BlockScanT::TempStorage scan;
typename BlockDiscontinuityKeysT::TempStorage disc;
};
__shared__ TempStorage temp_storage;
int ridx[ITEMS_PER_THREAD];
int gidx[ITEMS_PER_THREAD];
const int TILE_SIZE = ITEMS_PER_THREAD * BLOCK_THREADS;
int block_offset = blockIdx.x * TILE_SIZE;
BlockLoadT(temp_storage.load)
.Load(d_ridx + block_offset, ridx, n - block_offset, -1);
BlockLoadT(temp_storage.load)
.Load(d_gidx + block_offset, gidx, n - block_offset, -1);
int hist_idx[ITEMS_PER_THREAD];
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
if (ridx[ITEM] > -1 && d_position[ridx[ITEM]] > -1) {
hist_idx[ITEM] =
(d_position[ridx[ITEM]] - n_nodes(depth - 1)) * n_bins + gidx[ITEM];
} else {
hist_idx[ITEM] = -1;
}
}
__syncthreads();
BlockRadixSortT(temp_storage.sort).Sort(hist_idx, ridx);
OffsetValuePairT kv[ITEMS_PER_THREAD];
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
kv[ITEM].key = hist_idx[ITEM];
if (ridx[ITEM] > -1) {
kv[ITEM].value = d_gpair[ridx[ITEM]];
}
}
__syncthreads();
// Scan
BlockScanT(temp_storage.scan).InclusiveScan(kv, kv, ReduceBySegmentOpT());
__syncthreads();
int flags[ITEMS_PER_THREAD];
BlockDiscontinuityKeysT(temp_storage.disc)
.FlagTails(flags, hist_idx, cub::Inequality());
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
if (flags[ITEM]) {
if (ridx[ITEM] > -1 && d_position[ridx[ITEM]] > -1) {
atomicAdd(&(d_dense_hist[hist_idx[ITEM]]._grad), kv[ITEM].value._grad);
atomicAdd(&(d_dense_hist[hist_idx[ITEM]]._hess), kv[ITEM].value._hess);
}
}
}
}
void GPUHistBuilder::BuildHist(int depth) {
auto d_ridx = device_matrix.ridx.data();
auto d_gidx = device_matrix.gidx.data();
auto d_position = position.data();
auto d_gpair = device_gpair.data();
auto hist_builder = hist.GetBuilder();
dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) {
int ridx = d_ridx[idx];
int pos = d_position[ridx];
// Only increment even nodes
if (pos < 0 || pos % 2 == 1) return;
int gidx = d_gidx[idx];
gpu_gpair gpair = d_gpair[ridx];
hist_builder.Add(gpair, gidx, pos);
});
dh::safe_cuda(cudaDeviceSynchronize());
// Subtraction trick
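// The pass above accumulated histograms only for the root and for nodes at
// even positions (right children). Each left child's histogram is recovered
// below as (parent histogram - right sibling histogram), so only half of
// the nodes at each level need direct accumulation.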
int n_sub_bins = (n_nodes_level(depth) / 2) * hist_builder.n_bins;
if (n_sub_bins > 0) {
dh::launch_n(n_sub_bins, [=] __device__(int idx) {
int nidx = n_nodes(depth - 1) + ((idx / hist_builder.n_bins) * 2);
int gidx = idx % hist_builder.n_bins;
gpu_gpair parent = hist_builder.Get(gidx, nidx / 2);
gpu_gpair right = hist_builder.Get(gidx, nidx + 1);
hist_builder.Add(parent - right, gidx, nidx);
});
}
dh::safe_cuda(cudaDeviceSynchronize());
}
void GPUHistBuilder::FindSplit(int depth) {
auto counting = thrust::make_counting_iterator(0);
auto d_gidx_feature_map = gidx_feature_map.data();
int n_bins = hmat_.row_ptr.back();
int n_features = hmat_.row_ptr.size() - 1;
auto feature_boundary = [=] __device__(int idx_a, int idx_b) {
int gidx_a = idx_a % n_bins;
int gidx_b = idx_b % n_bins;
return d_gidx_feature_map[gidx_a] == d_gidx_feature_map[gidx_b];
}; // NOLINT
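// feature_boundary is the key-equality predicate for the exclusive_scan_by_key
// below: two adjacent histogram entries belong to the same scan segment only
// if their bins map to the same feature, so the prefix sum restarts at every
// feature boundary within each node's histogram.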
// Reduce node sums
{
size_t temp_storage_bytes;
cub::DeviceSegmentedReduce::Reduce(
nullptr, temp_storage_bytes, hist.GetLevelPtr(depth), node_sums.data(),
n_nodes_level(depth) * n_features, feature_segments.data(),
feature_segments.data() + 1, cub::Sum(), gpu_gpair());
cub_mem.LazyAllocate(temp_storage_bytes);
cub::DeviceSegmentedReduce::Reduce(
cub_mem.d_temp_storage, cub_mem.temp_storage_bytes,
hist.GetLevelPtr(depth), node_sums.data(),
n_nodes_level(depth) * n_features, feature_segments.data(),
feature_segments.data() + 1, cub::Sum(), gpu_gpair());
}
// Scan
thrust::exclusive_scan_by_key(
counting, counting + hist.LevelSize(depth),
thrust::device_pointer_cast(hist.GetLevelPtr(depth)), hist_scan.tbegin(),
gpu_gpair(), feature_boundary);
// Calculate gain
auto d_gain = gain.data();
auto d_nodes = nodes.data();
auto d_node_sums = node_sums.data();
auto d_hist_scan = hist_scan.data();
GPUTrainingParam gpu_param_alias =
gpu_param; // Must be local variable to be used in device lambda
bool colsample =
param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
auto d_feature_flags = feature_flags.data();
dh::launch_n(hist.LevelSize(depth), [=] __device__(int idx) {
int node_segment = idx / n_bins;
int node_idx = n_nodes(depth - 1) + node_segment;
gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
float parent_gain = d_nodes[node_idx].root_gain;
int gidx = idx % n_bins;
int findex = d_gidx_feature_map[gidx];
// colsample
if (colsample && d_feature_flags[d_gidx_feature_map[gidx]] == 0) {
d_gain[idx] = 0;
} else {
gpu_gpair scan = d_hist_scan[idx];
gpu_gpair sum = d_node_sums[node_segment * n_features + findex];
gpu_gpair missing = parent_sum - sum;
bool missing_left;
d_gain[idx] = loss_chg_missing(scan, missing, parent_sum, parent_gain,
gpu_param_alias, missing_left);
}
});
dh::safe_cuda(cudaDeviceSynchronize());
// Find best gain
{
size_t temp_storage_bytes;
cub::DeviceSegmentedReduce::ArgMax(nullptr, temp_storage_bytes, gain.data(),
argmax.data(), n_nodes_level(depth),
hist_node_segments.data(),
hist_node_segments.data() + 1);
cub_mem.LazyAllocate(temp_storage_bytes);
cub::DeviceSegmentedReduce::ArgMax(
cub_mem.d_temp_storage, cub_mem.temp_storage_bytes, gain.data(),
argmax.data(), n_nodes_level(depth), hist_node_segments.data(),
hist_node_segments.data() + 1);
}
auto d_argmax = argmax.data();
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
auto d_fidx_min_map = fidx_min_map.data();
dh::launch_n(n_nodes_level(depth), [=] __device__(int idx) {
int max_idx = n_bins * idx + d_argmax[idx].key;
int gidx = max_idx % n_bins;
int fidx = d_gidx_feature_map[gidx];
int node_segment = max_idx / n_bins;
int node_idx = n_nodes(depth - 1) + node_segment;
gpu_gpair scan = d_hist_scan[max_idx];
gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
float parent_gain = d_nodes[node_idx].root_gain;
gpu_gpair sum = d_node_sums[node_segment * n_features + fidx];
gpu_gpair missing = parent_sum - sum;
bool missing_left;
float loss_chg = loss_chg_missing(scan, missing, parent_sum, parent_gain,
gpu_param_alias, missing_left);
float fvalue;
if (gidx == 0 || fidx != d_gidx_feature_map[gidx - 1]) {
fvalue = d_fidx_min_map[fidx];
} else {
fvalue = d_gidx_fvalue_map[gidx - 1];
}
gpu_gpair left = missing_left ? scan + missing : scan;
gpu_gpair right = parent_sum - left;
d_nodes[node_idx].split.Update(loss_chg, missing_left, fvalue, fidx, left,
right, gpu_param_alias);
int left_child_idx = n_nodes(depth) + idx * 2;
int right_child_idx = n_nodes(depth) + idx * 2 + 1;
d_nodes[left_child_idx] =
Node(left, CalcGain(gpu_param_alias, left.grad(), left.hess()),
CalcWeight(gpu_param_alias, left.grad(), left.hess()));
d_nodes[right_child_idx] =
Node(right, CalcGain(gpu_param_alias, right.grad(), right.hess()),
CalcWeight(gpu_param_alias, right.grad(), right.hess()));
});
dh::safe_cuda(cudaDeviceSynchronize());
}
void GPUHistBuilder::InitFirstNode() {
// Build the root node on the CPU and copy to device
gpu_gpair sum_gradients =
thrust::reduce(device_gpair.tbegin(), device_gpair.tend(),
gpu_gpair(0, 0), thrust::plus<gpu_gpair>());
Node tmp =
Node(sum_gradients,
CalcGain(param, sum_gradients.grad(), sum_gradients.hess()),
CalcWeight(param, sum_gradients.grad(), sum_gradients.hess()));
thrust::copy_n(&tmp, 1, nodes.tbegin());
}
void GPUHistBuilder::UpdatePosition() {
if (is_dense) {
this->UpdatePositionDense();
} else {
this->UpdatePositionSparse();
}
}
void GPUHistBuilder::UpdatePositionDense() {
auto d_position = position.data();
Node* d_nodes = nodes.data();
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
auto d_gidx = device_matrix.gidx.data();
int n_columns = info->num_col;
dh::launch_n(position.size(), [=] __device__(int idx) {
NodeIdT pos = d_position[idx];
if (pos < 0) {
return;
}
Node node = d_nodes[pos];
if (node.IsLeaf()) {
d_position[idx] = -1;
return;
}
int gidx = d_gidx[idx * n_columns + node.split.findex];
float fvalue = d_gidx_fvalue_map[gidx];
if (fvalue <= node.split.fvalue) {
d_position[idx] = pos * 2 + 1;
} else {
d_position[idx] = pos * 2 + 2;
}
});
}
void GPUHistBuilder::UpdatePositionSparse() {
auto d_position = position.data();
auto d_position_tmp = position_tmp.data();
Node* d_nodes = nodes.data();
auto d_gidx_feature_map = gidx_feature_map.data();
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
auto d_gidx = device_matrix.gidx.data();
auto d_ridx = device_matrix.ridx.data();
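// Two-pass update for sparse data: first route every row in the missing
// direction of its node's split, then overwrite the decision for rows that
// actually have a value for the split feature.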
// Update missing direction
dh::launch_n(position.size(), [=] __device__(int idx) {
NodeIdT pos = d_position[idx];
if (pos < 0) {
return;
}
Node node = d_nodes[pos];
if (node.IsLeaf()) {
d_position_tmp[idx] = -1;
} else if (node.split.missing_left) {
d_position_tmp[idx] = pos * 2 + 1;
} else {
d_position_tmp[idx] = pos * 2 + 2;
}
});
// Update node based on fvalue where exists
dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) {
int ridx = d_ridx[idx];
NodeIdT pos = d_position[ridx];
if (pos < 0) {
return;
}
Node node = d_nodes[pos];
if (node.IsLeaf()) {
return;
}
int gidx = d_gidx[idx];
int findex = d_gidx_feature_map[gidx];
if (findex == node.split.findex) {
float fvalue = d_gidx_fvalue_map[gidx];
if (fvalue <= node.split.fvalue) {
d_position_tmp[ridx] = pos * 2 + 1;
} else {
d_position_tmp[ridx] = pos * 2 + 2;
}
}
});
position = position_tmp;
}
void GPUHistBuilder::ColSampleTree() {
feature_set_tree.resize(info->num_col);
std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree);
}
void GPUHistBuilder::ColSampleLevel() {
feature_set_level.resize(feature_set_tree.size());
feature_set_level = col_sample(feature_set_tree, param.colsample_bylevel);
std::vector<int> h_feature_flags(info->num_col, 0);
for (auto fidx : feature_set_level) {
h_feature_flags[fidx] = 1;
}
feature_flags = h_feature_flags;
}
void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
DMatrix& fmat, // NOLINT
const RegTree& tree) {
if (!initialised) {
CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column "
"block. Try setting 'tree_method' "
"parameter to 'exact'";
info = &fmat.info();
is_dense = info->num_nonzero == info->num_col * info->num_row;
hmat_.Init(&fmat, param.max_bin);
gmat_.cut = &hmat_;
gmat_.Init(&fmat);
int n_bins = hmat_.row_ptr.back();
int n_features = hmat_.row_ptr.size() - 1;
// Build feature segments
std::vector<int> h_feature_segments;
for (int node = 0; node < n_nodes_level(param.max_depth - 1); node++) {
for (int fidx = 0; fidx < hmat_.row_ptr.size() - 1; fidx++) {
h_feature_segments.push_back(hmat_.row_ptr[fidx] + node * n_bins);
}
}
h_feature_segments.push_back(n_nodes_level(param.max_depth - 1) * n_bins);
int level_max_bins = n_nodes_level(param.max_depth - 1) * n_bins;
size_t free_memory = dh::available_memory();
ba.allocate(
&gidx_feature_map, n_bins, &hist_node_segments,
n_nodes_level(param.max_depth - 1) + 1, &feature_segments,
h_feature_segments.size(), &gain, level_max_bins, &position,
gpair.size(), &position_tmp, gpair.size(), &nodes,
n_nodes(param.max_depth), &gidx_fvalue_map, hmat_.cut.size(),
&fidx_min_map, hmat_.min_val.size(), &argmax,
n_nodes_level(param.max_depth - 1), &node_sums,
n_nodes_level(param.max_depth - 1) * n_features, &hist_scan,
level_max_bins, &device_gpair, gpair.size(), &device_matrix.gidx,
gmat_.index.size(), &device_matrix.ridx, gmat_.index.size(), &hist.hist,
n_nodes(param.max_depth - 1) * n_bins, &feature_flags, n_features);
if (!param.silent) {
const int mb_size = 1048576;
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
<< free_memory / mb_size << " MB on " << dh::device_name();
}
// Construct feature map
std::vector<int> h_gidx_feature_map(n_bins);
for (int row = 0; row < hmat_.row_ptr.size() - 1; row++) {
for (int i = hmat_.row_ptr[row]; i < hmat_.row_ptr[row + 1]; i++) {
h_gidx_feature_map[i] = row;
}
}
gidx_feature_map = h_gidx_feature_map;
// Construct device matrix
device_matrix.Init(gmat_);
gidx_fvalue_map = hmat_.cut;
fidx_min_map = hmat_.min_val;
thrust::sequence(hist_node_segments.tbegin(), hist_node_segments.tend(), 0,
n_bins);
feature_segments = h_feature_segments;
hist.Init(n_bins);
initialised = true;
}
nodes.fill(Node());
position.fill(0);
device_gpair = gpair;
subsample_gpair(&device_gpair, param.subsample);
hist.Reset();
}
void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
DMatrix* p_fmat, RegTree* p_tree) {
this->InitData(gpair, *p_fmat, *p_tree);
this->InitFirstNode();
this->ColSampleTree();
for (int depth = 0; depth < param.max_depth; depth++) {
this->ColSampleLevel();
this->BuildHist(depth);
this->FindSplit(depth);
this->UpdatePosition();
}
dense2sparse_tree(p_tree, nodes.tbegin(), nodes.tend(), param);
}
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,104 @@
/*!
* Copyright 2016 Rory Mitchell
*/
#pragma once
#include <cusparse.h>
#include <thrust/device_vector.h>
#include <xgboost/tree_updater.h>
#include <cub/util_type.cuh> // Need key value pair definition
#include <vector>
#include "../../src/common/hist_util.h"
#include "../../src/tree/param.h"
#include "device_helpers.cuh"
#include "types.cuh"
namespace xgboost {
namespace tree {
struct DeviceGMat {
dh::dvec<int> gidx;
dh::dvec<int> ridx;
void Init(const common::GHistIndexMatrix &gmat);
};
struct HistBuilder {
gpu_gpair *d_hist;
int n_bins;
__host__ __device__ HistBuilder(gpu_gpair *ptr, int n_bins);
__device__ void Add(gpu_gpair gpair, int gidx, int nidx) const;
__device__ gpu_gpair Get(int gidx, int nidx) const;
};
struct DeviceHist {
int n_bins;
dh::dvec<gpu_gpair> hist;
void Init(int max_depth);
void Reset();
HistBuilder GetBuilder();
gpu_gpair *GetLevelPtr(int depth);
int LevelSize(int depth);
};
class GPUHistBuilder {
public:
GPUHistBuilder();
~GPUHistBuilder();
void Init(const TrainParam &param);
void UpdateParam(const TrainParam &param) {
this->param = param;
this->gpu_param = GPUTrainingParam(param.min_child_weight, param.reg_lambda,
param.reg_alpha, param.max_delta_step);
}
void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
const RegTree &tree);
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
RegTree *p_tree);
void BuildHist(int depth);
void FindSplit(int depth);
void InitFirstNode();
void UpdatePosition();
void UpdatePositionDense();
void UpdatePositionSparse();
void ColSampleTree();
void ColSampleLevel();
TrainParam param;
GPUTrainingParam gpu_param;
common::HistCutMatrix hmat_;
common::GHistIndexMatrix gmat_;
MetaInfo *info;
bool initialised;
bool is_dense;
DeviceGMat device_matrix;
dh::bulk_allocator ba;
dh::CubMemory cub_mem;
dh::dvec<int> gidx_feature_map;
dh::dvec<int> hist_node_segments;
dh::dvec<int> feature_segments;
dh::dvec<float> gain;
dh::dvec<NodeIdT> position;
dh::dvec<NodeIdT> position_tmp;
dh::dvec<float> gidx_fvalue_map;
dh::dvec<float> fidx_min_map;
DeviceHist hist;
dh::dvec<cub::KeyValuePair<int, float>> argmax;
dh::dvec<gpu_gpair> node_sums;
dh::dvec<gpu_gpair> hist_scan;
dh::dvec<gpu_gpair> device_gpair;
dh::dvec<Node> nodes;
dh::dvec<int> feature_flags;
std::vector<int> feature_set_tree;
std::vector<int> feature_set_level;
};
} // namespace tree
} // namespace xgboost

View File

@ -1,52 +0,0 @@
/*!
* Copyright 2016 Rory Mitchell
*/
#pragma once
#include "types.cuh"
#include "../../../src/tree/param.h"
// When we split on a value which has no left neighbour, define its left
// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS
// This produces a split value slightly lower than the current instance
#define FVALUE_EPS 0.0001
namespace xgboost {
namespace tree {
__device__ __forceinline__ float
device_calc_loss_chg(const GPUTrainingParam &param, const gpu_gpair &scan,
const gpu_gpair &missing, const gpu_gpair &parent_sum,
const float &parent_gain, bool missing_left) {
gpu_gpair left = scan;
if (missing_left) {
left += missing;
}
gpu_gpair right = parent_sum - left;
float left_gain = CalcGain(param, left.grad(), left.hess());
float right_gain = CalcGain(param, right.grad(), right.hess());
return left_gain + right_gain - parent_gain;
}
__device__ __forceinline__ float
loss_chg_missing(const gpu_gpair &scan, const gpu_gpair &missing,
const gpu_gpair &parent_sum, const float &parent_gain,
const GPUTrainingParam &param, bool &missing_left_out) { // NOLINT
float missing_left_loss =
device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true);
float missing_right_loss = device_calc_loss_chg(
param, scan, missing, parent_sum, parent_gain, false);
if (missing_left_loss >= missing_right_loss) {
missing_left_out = true;
return missing_left_loss;
} else {
missing_left_out = false;
return missing_right_loss;
}
}
} // namespace tree
} // namespace xgboost

View File

@ -2,7 +2,10 @@
* Copyright 2016 Rory Mitchell
*/
#pragma once
#include <thrust/device_vector.h>
#include <xgboost/base.h>
#include <xgboost/tree_model.h>
#include <cfloat>
#include <tuple> // The linter is not very smart and thinks we need this
namespace xgboost {
@ -104,8 +107,10 @@ struct GPUTrainingParam {
__host__ __device__ GPUTrainingParam(float min_child_weight_in,
float reg_lambda_in, float reg_alpha_in,
float max_delta_step_in)
: min_child_weight(min_child_weight_in), reg_lambda(reg_lambda_in),
reg_alpha(reg_alpha_in), max_delta_step(max_delta_step_in) {}
: min_child_weight(min_child_weight_in),
reg_lambda(reg_lambda_in),
reg_alpha(reg_alpha_in),
max_delta_step(max_delta_step_in) {}
};
struct Split {
@ -123,8 +128,9 @@ struct Split {
float fvalue_in, int findex_in, gpu_gpair left_sum_in,
gpu_gpair right_sum_in,
const GPUTrainingParam &param) {
if (loss_chg_in > loss_chg && left_sum_in.hess() > param.min_child_weight &&
right_sum_in.hess() > param.min_child_weight) {
if (loss_chg_in > loss_chg &&
left_sum_in.hess() >= param.min_child_weight &&
right_sum_in.hess() >= param.min_child_weight) {
loss_chg = loss_chg_in;
missing_left = missing_left_in;
fvalue = fvalue_in;
@ -146,7 +152,7 @@ struct Split {
}
}
__device__ void Print() {
__host__ __device__ void Print() {
printf("Loss: %1.4f\n", loss_chg);
printf("Missing left: %d\n", missing_left);
printf("fvalue: %1.4f\n", fvalue);
@ -184,5 +190,6 @@ struct Node {
__host__ __device__ bool IsLeaf() { return split.loss_chg == -FLT_MAX; }
};
} // namespace tree
} // namespace xgboost

View File

@ -1,6 +0,0 @@
/*!
* Copyright 2016 Rory Mitchell
*/
#pragma once
#include "types.cuh"
#include "loss_functions.cuh"

View File

@ -7,16 +7,52 @@
#include "../../src/common/sync.h"
#include "../../src/tree/param.h"
#include "gpu_builder.cuh"
#include "gpu_hist_builder.cuh"
namespace xgboost {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
/*! \brief column-wise update to construct a tree */
template <typename TStats> class GPUMaker : public TreeUpdater {
template <typename TStats>
class GPUMaker : public TreeUpdater {
public:
void
Init(const std::vector<std::pair<std::string, std::string>> &args) override {
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override {
param.InitAllowUnknown(args);
builder.Init(param);
}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
TStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
builder.UpdateParam(param);
try {
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
builder.Update(gpair, dmat, trees[i]);
}
} catch (const std::exception& e) {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
}
protected:
// training parameter
TrainParam param;
GPUBuilder builder;
};
template <typename TStats>
class GPUHistMaker : public TreeUpdater {
public:
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override {
param.InitAllowUnknown(args);
builder.Init(param);
}
@ -29,21 +65,28 @@ template <typename TStats> class GPUMaker : public TreeUpdater {
param.learning_rate = lr / trees.size();
builder.UpdateParam(param);
// build tree
try {
for (size_t i = 0; i < trees.size(); ++i) {
builder.Update(gpair, dmat, trees[i]);
}
} catch (const std::exception& e) {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
}
protected:
// training parameter
TrainParam param;
GPUBuilder builder;
GPUHistBuilder builder;
};
XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUMaker<GradStats>(); });
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUHistMaker<GradStats>(); });
} // namespace tree
} // namespace xgboost

View File

@ -1,137 +0,0 @@
#pylint: skip-file
import numpy as np
import xgboost as xgb
import os
import pandas as pd
import urllib2
class bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
def get_last_eval_callback(result):
def callback(env):
result.append(env.evaluation_result_list[-1][1])
callback.after_iteration = True
return callback
def load_adult():
path = "../../demo/data/adult.data"
if(not os.path.isfile(path)):
data = urllib2.urlopen('http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data')
with open(path,'wb') as output:
output.write(data.read())
train_set = pd.read_csv( path, header=None)
train_set.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation',
'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
'wage_class']
train_nomissing = train_set.replace(' ?', np.nan).dropna()
for feature in train_nomissing.columns: # Loop through all columns in the dataframe
if train_nomissing[feature].dtype == 'object': # Only apply for columns with categorical strings
train_nomissing[feature] = pd.Categorical(train_nomissing[feature]).codes # Replace strings with an integer
y_train = train_nomissing.pop('wage_class')
return xgb.DMatrix( train_nomissing, label=y_train)
def load_higgs():
higgs_path = '../../demo/data/training.csv'
dtrain = np.loadtxt(higgs_path, delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s'.encode('utf-8')) } )
#dtrain = dtrain[0:200000,:]
label = dtrain[:,32]
data = dtrain[:,1:31]
weight = dtrain[:,31]
return xgb.DMatrix( data, label=label, missing = -999.0, weight=weight )
def load_dermatology():
data = np.loadtxt('../../demo/data/dermatology.data', delimiter=',',converters={33: lambda x:int(x == '?'), 34: lambda x:int(x)-1 } )
sz = data.shape
X = data[:,0:33]
Y = data[:, 34]
return xgb.DMatrix( X, label=Y)
def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
#Check GPU test evaluation is approximately equal to CPU test evaluation
def check_result(cpu_result, gpu_result):
for i in range(len(cpu_result)):
if not isclose(cpu_result[i], gpu_result[i], 0.1, 0.02):
return False
return True
#Get data
data = []
params = []
data.append(load_higgs())
params.append({})
data.append( load_adult())
params.append({})
data.append(xgb.DMatrix('../../demo/data/agaricus.txt.test'))
params.append({'objective':'binary:logistic'})
#if(os.path.isfile("../../demo/data/dermatology.data")):
data.append(load_dermatology())
params.append({'objective':'multi:softmax', 'num_class': 6})
num_round = 5
num_pass = 0
num_fail = 0
test_depth = [ 1, 6, 9, 11, 15 ]
#test_depth = [ 1 ]
for test in range(0, len(data)):
for depth in test_depth:
xgmat = data[test]
cpu_result = []
param = params[test]
param['max_depth'] = depth
param['updater'] = 'grow_colmaker'
xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(cpu_result)])
#bst = xgb.train( param, xgmat, 1);
#bst.dump_model('reference_model.txt','', True)
gpu_result = []
param['updater'] = 'grow_gpu'
xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(gpu_result)])
#bst = xgb.train( param, xgmat, 1);
#bst.dump_model('dump.raw.txt','', True)
if check_result(cpu_result, gpu_result):
print(bcolors.OKGREEN + "Pass" + bcolors.ENDC)
num_pass = num_pass + 1
else:
print(bcolors.FAIL + "Fail" + bcolors.ENDC)
num_fail = num_fail + 1
print("cpu rmse: "+str(cpu_result))
print("gpu rmse: "+str(gpu_result))
print(str(num_pass)+"/"+str(num_pass + num_fail)+" passed")

View File

@ -0,0 +1,221 @@
#pylint: skip-file
import sys
sys.path.append("../../tests/python")
import xgboost as xgb
import testing as tm
import numpy as np
import unittest
rng = np.random.RandomState(1994)
dpath = '../../demo/data/'
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
class TestGPU(unittest.TestCase):
def test_grow_gpu(self):
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
except:
from sklearn.cross_validation import train_test_split
ag_param = {'max_depth': 2,
'tree_method': 'exact',
'nthread': 1,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': 2,
'updater': 'grow_gpu',
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_res = {}
ag_res2 = {}
num_rounds = 10
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res)
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res2)
assert ag_res['train']['auc'] == ag_res2['train']['auc']
assert ag_res['test']['auc'] == ag_res2['test']['auc']
digits = load_digits(2)
X = digits['data']
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu',
'max_depth': 3,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert self.non_decreasing(res['test']['auc'])
# fail-safe test for dense data
from sklearn.datasets import load_svmlight_file
X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
X2 = X2.toarray()
dtrain2 = xgb.DMatrix(X2, label=y2)
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu',
'max_depth': 2,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in rng.choice(X2.shape[0], size=10, replace=False):
X2[i, j] = 2
dtrain3 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in np.random.choice(X2.shape[0], size=10, replace=False):
X2[i, j] = 3
dtrain4 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
def test_grow_gpu_hist(self):
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
except:
from sklearn.cross_validation import train_test_split
# regression test --- hist must be same as exact on all-categorical data
ag_param = {'max_depth': 2,
'tree_method': 'exact',
'nthread': 1,
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': 2,
'updater': 'grow_gpu_hist',
'eta': 1,
'silent': 1,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_res = {}
ag_res2 = {}
num_rounds = 10
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res)
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res2)
assert ag_res['train']['auc'] == ag_res2['train']['auc']
assert ag_res['test']['auc'] == ag_res2['test']['auc']
digits = load_digits(2)
X = digits['data']
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu_hist',
'max_depth': 3,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert self.non_decreasing(res['test']['auc'])
# fail-safe test for dense data
from sklearn.datasets import load_svmlight_file
X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
X2 = X2.toarray()
dtrain2 = xgb.DMatrix(X2, label=y2)
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu_hist',
'grow_policy': 'depthwise',
'max_depth': 2,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in rng.choice(X2.shape[0], size=10, replace=False):
X2[i, j] = 2
dtrain3 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in np.random.choice(X2.shape[0], size=10, replace=False):
X2[i, j] = 3
dtrain4 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
# fail-safe test for max_bin=2
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu_hist',
'max_depth': 2,
'eval_metric': 'auc',
'max_bin': 2}
res = {}
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
# subsampling
param = {'objective': 'binary:logistic',
'updater': 'grow_gpu_hist',
'max_depth': 3,
'eval_metric': 'auc',
'colsample_bytree': 0.5,
'colsample_bylevel': 0.5,
'subsample': 0.5
}
res = {}
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
def non_decreasing(self, L):
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))

View File

@ -79,6 +79,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
bool refresh_leaf;
// auxiliary data structure
std::vector<int> monotone_constraints;
// gpu to use for single gpu algorithms
int gpu_id;
// declare the parameters
DMLC_DECLARE_PARAMETER(TrainParam) {
DMLC_DECLARE_FIELD(learning_rate)
@ -186,6 +188,10 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
DMLC_DECLARE_FIELD(monotone_constraints)
.set_default(std::vector<int>())
.describe("Constraint of variable monotonicity");
DMLC_DECLARE_FIELD(gpu_id)
.set_lower_bound(0)
.set_default(0)
.describe("gpu to use for single gpu algorithms");
// add alias of parameters
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
DMLC_DECLARE_ALIAS(reg_alpha, alpha);