[GPU Plugin] Fast histogram speed improvements. Updated benchmarks. (#2258)
This commit is contained in:
parent
98ea461532
commit
6bf968efe6
@ -5,7 +5,7 @@ find_package(OpenMP)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
|
||||
if(NOT MSVC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -O3 -funroll-loops -msse2")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -O3 -funroll-loops -msse2 -D_MWAITXINTRIN_H_INCLUDED -D_FORCE_INLINES")
|
||||
endif()
|
||||
|
||||
|
||||
@ -97,7 +97,7 @@ if(PLUGIN_UPDATER_GPU)
|
||||
#Find cub
|
||||
set(CUB_DIRECTORY "" CACHE PATH "CUB 1.5.4 directory")
|
||||
include_directories(${CUB_DIRECTORY})
|
||||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35;")
|
||||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35;-lineinfo;")
|
||||
if(NOT MSVC)
|
||||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC")
|
||||
endif()
|
||||
|
||||
@ -57,7 +57,7 @@ class TreeUpdater {
|
||||
* updated by the time this function returns.
|
||||
*/
|
||||
virtual bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* out_preds) const {
|
||||
std::vector<bst_float>* out_preds) {
|
||||
return false;
|
||||
}
|
||||
/*!
|
||||
|
||||
@ -2,10 +2,21 @@
|
||||
This plugin adds GPU accelerated tree construction algorithms to XGBoost.
|
||||
## Usage
|
||||
Specify the 'updater' parameter as one of the following algorithms.
|
||||
updater | Description
|
||||
--- | ---
|
||||
grow_gpu | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'grow_gpu_hist'
|
||||
grow_gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Faster and uses considerably less memory. Splits may be less accurate.
|
||||
|
||||
### Algorithms
|
||||
| updater | Description |
|
||||
| --- | --- |
|
||||
grow_gpu | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'grow_gpu_hist' |
|
||||
grow_gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Faster and uses considerably less memory. Splits may be less accurate. |
|
||||
|
||||
### Supported parameters
|
||||
| parameter | grow_gpu | grow_gpu_hist |
|
||||
| --- | --- | --- |
|
||||
subsample | ✔ | ✔ |
|
||||
colsample_bytree | ✔ | ✔|
|
||||
colsample_bylevel | ✔ | ✔ |
|
||||
max_bin | ✖ | ✔ |
|
||||
gpu_id | ✔ | ✔ |
|
||||
|
||||
All algorithms currently use only a single GPU. The device ordinal can be selected using the 'gpu_id' parameter, which defaults to 0.
|
||||
|
||||
@ -14,13 +25,32 @@ This plugin currently works with the CLI version and python version.
|
||||
Python example:
|
||||
```python
|
||||
param['gpu_id'] = 1
|
||||
param['updater'] = 'grow_gpu'
|
||||
param['max_bin'] = 16
|
||||
param['updater'] = 'grow_gpu_hist'
|
||||
```
|
||||
## Benchmarks
|
||||
To run benchmarks on synthetic data for binary classification:
|
||||
```bash
|
||||
$ python benchmark/benchmark.py
|
||||
```
|
||||
|
||||
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for performance benchmarks of the 'grow_gpu' updater.
|
||||
Training time time on 1000000 rows x 50 columns with 500 boosting iterations on i7-6700K CPU @ 4.00GHz and Pascal Titan X.
|
||||
|
||||
| Updater | Time (s) |
|
||||
| --- | --- |
|
||||
| grow_gpu_hist | 11.09 |
|
||||
| grow_fast_histmaker (histogram XGBoost - CPU) | 41.75 |
|
||||
| grow_gpu | 193.90 |
|
||||
| grow_colmaker (standard XGBoost - CPU) | 720.12 |
|
||||
|
||||
|
||||
[See here](http://dmlc.ml/2016/12/14/GPU-accelerated-xgboost.html) for additional performance benchmarks of the 'grow_gpu' updater.
|
||||
|
||||
## Test
|
||||
To run tests:
|
||||
```bash
|
||||
$ python -m nose test/
|
||||
```
|
||||
## Dependencies
|
||||
A CUDA capable GPU with at least compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler).
|
||||
|
||||
@ -48,13 +78,17 @@ $ cmake .. -G"Visual Studio 12 2013 Win64" -DPLUGIN_UPDATER_GPU=ON -DCUB_DIRECTO
|
||||
```
|
||||
You may also be able to use a later version of visual studio depending on whether the CUDA toolkit supports it.
|
||||
|
||||
On an linux cmake will generate a Makefile in the build directory. Invoking the command 'make' from this directory will build the project. If the build fails try invoking make again. There can sometimes be problems with the order items are built.
|
||||
On linux cmake will generate a Makefile in the build directory. Invoking the command 'make' from this directory will build the project. If the build fails try invoking make again. There can sometimes be problems with the order items are built.
|
||||
|
||||
On Windows cmake will generate an xgboost.sln solution file in the build directory. Build this solution in release mode. This is also a good time to check it is being built as x64. If not make sure the cmake generator is set correctly.
|
||||
|
||||
The build process generates an xgboost library and executable as normal but containing the GPU tree construction algorithm.
|
||||
|
||||
## Changelog
|
||||
##### 2017/5/5
|
||||
* Histogram performance improvements
|
||||
* Fix gcc build issues
|
||||
|
||||
##### 2017/4/25
|
||||
* Add fast histogram algorithm
|
||||
* Fix Linux build
|
||||
|
||||
@ -1,32 +1,50 @@
|
||||
#pylint: skip-file
|
||||
# pylint: skip-file
|
||||
import sys, argparse
|
||||
import xgboost as xgb
|
||||
import numpy as np
|
||||
from sklearn.datasets import make_classification
|
||||
from sklearn.model_selection import train_test_split
|
||||
import time
|
||||
|
||||
n = 1000000
|
||||
num_rounds = 100
|
||||
num_rounds = 500
|
||||
|
||||
X,y = make_classification(n, n_features=50, random_state=7)
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
|
||||
dtrain = xgb.DMatrix(X_train, y_train)
|
||||
dtest = xgb.DMatrix(X_test, y_test)
|
||||
def run_benchmark(args, gpu_algorithm, cpu_algorithm):
|
||||
print("Generating dataset: {} rows * {} columns".format(args.rows,args.columns))
|
||||
X, y = make_classification(args.rows, n_features=args.columns, random_state=7)
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
|
||||
param = {'objective': 'binary:logistic',
|
||||
'tree_method': 'exact',
|
||||
'updater': 'grow_gpu_hist',
|
||||
'max_depth': 8,
|
||||
'silent': 1,
|
||||
'eval_metric': 'auc'}
|
||||
res = {}
|
||||
tmp = time.time()
|
||||
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
|
||||
evals_result=res)
|
||||
print ("GPU: %s seconds" % (str(time.time() - tmp)))
|
||||
param = {'objective': 'binary:logistic',
|
||||
'tree_method': 'exact',
|
||||
'max_depth': 6,
|
||||
'silent': 1,
|
||||
'eval_metric': 'auc'}
|
||||
|
||||
tmp = time.time()
|
||||
param['updater'] = 'grow_fast_histmaker'
|
||||
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')], evals_result=res)
|
||||
print ("CPU: %s seconds" % (str(time.time() - tmp)))
|
||||
param['updater'] = gpu_algorithm
|
||||
print("Training with '%s'" % param['updater'])
|
||||
tmp = time.time()
|
||||
xgb.train(param, dtrain, args.iterations)
|
||||
print ("Time: %s seconds" % (str(time.time() - tmp)))
|
||||
|
||||
param['updater'] = cpu_algorithm
|
||||
print("Training with '%s'" % param['updater'])
|
||||
tmp = time.time()
|
||||
xgb.train(param, dtrain, args.iterations)
|
||||
print ("Time: %s seconds" % (str(time.time() - tmp)))
|
||||
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--algorithm', choices=['all', 'grow_gpu', 'grow_gpu_hist'], required=True)
|
||||
parser.add_argument('--rows',type=int,default=1000000)
|
||||
parser.add_argument('--columns',type=int,default=50)
|
||||
parser.add_argument('--iterations',type=int,default=500)
|
||||
args = parser.parse_args()
|
||||
|
||||
if 'grow_gpu_hist' in args.algorithm:
|
||||
run_benchmark(args, args.algorithm, 'grow_fast_histmaker')
|
||||
if 'grow_gpu ' in args.algorithm:
|
||||
run_benchmark(args, args.algorithm, 'grow_colmaker')
|
||||
if 'all' in args.algorithm:
|
||||
run_benchmark(args, 'grow_gpu', 'grow_colmaker')
|
||||
run_benchmark(args, 'grow_gpu_hist', 'grow_fast_histmaker')
|
||||
|
||||
|
||||
@ -62,6 +62,25 @@ __host__ __device__ inline int n_nodes(int depth) {
|
||||
// Number of nodes at this level of the tree
|
||||
__host__ __device__ inline int n_nodes_level(int depth) { return 1 << depth; }
|
||||
|
||||
// Whether a node is currently being processed at current depth
|
||||
__host__ __device__ inline bool is_active(int nidx, int depth) {
|
||||
return nidx >= n_nodes(depth - 1);
|
||||
}
|
||||
|
||||
__host__ __device__ inline int parent_nidx(int nidx) { return (nidx - 1) / 2; }
|
||||
|
||||
__host__ __device__ inline int left_child_nidx(int nidx) {
|
||||
return nidx * 2 + 1;
|
||||
}
|
||||
|
||||
__host__ __device__ inline int right_child_nidx(int nidx) {
|
||||
return nidx * 2 + 2;
|
||||
}
|
||||
|
||||
__host__ __device__ inline bool is_left_child(int nidx) {
|
||||
return nidx % 2 == 1;
|
||||
}
|
||||
|
||||
enum NodeType {
|
||||
NODE = 0,
|
||||
LEAF = 1,
|
||||
@ -96,7 +115,7 @@ inline void dense2sparse_tree(RegTree* p_tree,
|
||||
thrust::device_ptr<Node> nodes_begin,
|
||||
thrust::device_ptr<Node> nodes_end,
|
||||
const TrainParam& param) {
|
||||
RegTree & tree = *p_tree;
|
||||
RegTree& tree = *p_tree;
|
||||
thrust::host_vector<Node> h_nodes(nodes_begin, nodes_end);
|
||||
std::vector<NodeType> node_flags(h_nodes.size(), UNUSED);
|
||||
flag_nodes(h_nodes, &node_flags, 0, NODE);
|
||||
|
||||
@ -9,15 +9,11 @@
|
||||
#include <thrust/system/cuda/error.h>
|
||||
#include <thrust/system_error.h>
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <ctime>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "cusparse_v2.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
// Uncomment to enable
|
||||
// #define DEVICE_TIMER
|
||||
@ -43,20 +39,6 @@ inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file,
|
||||
|
||||
return code;
|
||||
}
|
||||
#define safe_cusparse(ans) throw_on_cusparse_error((ans), __FILE__, __LINE__)
|
||||
|
||||
inline cusparseStatus_t throw_on_cusparse_error(cusparseStatus_t status,
|
||||
const char *file, int line) {
|
||||
if (status != CUSPARSE_STATUS_SUCCESS) {
|
||||
std::stringstream ss;
|
||||
ss << "cusparse error: " << file << "(" << line << ")";
|
||||
std::string error_text;
|
||||
ss >> error_text;
|
||||
throw error_text;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
#define gpuErrchk(ans) \
|
||||
{ gpuAssert((ans), __FILE__, __LINE__); }
|
||||
@ -153,40 +135,18 @@ struct DeviceTimer {
|
||||
};
|
||||
|
||||
struct Timer {
|
||||
volatile double start;
|
||||
typedef std::chrono::high_resolution_clock ClockT;
|
||||
|
||||
typedef std::chrono::high_resolution_clock::time_point TimePointT;
|
||||
TimePointT start;
|
||||
Timer() { reset(); }
|
||||
|
||||
double seconds_now() {
|
||||
#ifdef _WIN32
|
||||
static LARGE_INTEGER s_frequency;
|
||||
QueryPerformanceFrequency(&s_frequency);
|
||||
LARGE_INTEGER now;
|
||||
QueryPerformanceCounter(&now);
|
||||
return static_cast<double>(now.QuadPart) / s_frequency.QuadPart;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
void reset() {
|
||||
#ifdef _WIN32
|
||||
_ReadWriteBarrier();
|
||||
start = seconds_now();
|
||||
#endif
|
||||
}
|
||||
double elapsed() {
|
||||
#ifdef _WIN32
|
||||
_ReadWriteBarrier();
|
||||
return seconds_now() - start;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
void reset() { start = ClockT::now(); }
|
||||
int64_t elapsed() const { return (ClockT::now() - start).count(); }
|
||||
void printElapsed(std::string label) {
|
||||
#ifdef TIMERS
|
||||
safe_cuda(cudaDeviceSynchronize());
|
||||
printf("%s:\t %1.4fs\n", label.c_str(), elapsed());
|
||||
#endif
|
||||
printf("%s:\t %lld\n", label.c_str(), elapsed());
|
||||
reset();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -1,20 +1,21 @@
|
||||
/*!
|
||||
* Copyright 2017 Rory mitchell
|
||||
*/
|
||||
*/
|
||||
#include <thrust/binary_search.h>
|
||||
#include <thrust/count.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include "common.cuh"
|
||||
#include "device_helpers.cuh"
|
||||
#include "gpu_hist_builder.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
void DeviceGMat::Init(const common::GHistIndexMatrix& gmat) {
|
||||
CHECK_EQ(gidx.size(), gmat.index.size())
|
||||
<< "gidx must be externally allocated";
|
||||
@ -61,15 +62,19 @@ __device__ gpu_gpair HistBuilder::Get(int gidx, int nidx) const {
|
||||
return d_hist[nidx * n_bins + gidx];
|
||||
}
|
||||
|
||||
GPUHistBuilder::GPUHistBuilder() {}
|
||||
GPUHistBuilder::GPUHistBuilder()
|
||||
: initialised(false),
|
||||
is_dense(false),
|
||||
p_last_fmat_(nullptr),
|
||||
prediction_cache_initialised(false) {}
|
||||
|
||||
GPUHistBuilder::~GPUHistBuilder() {}
|
||||
|
||||
void GPUHistBuilder::Init(const TrainParam& param) {
|
||||
CHECK(param.max_depth < 16) << "Tree depth too large.";
|
||||
CHECK(param.grow_policy != TrainParam::kLossGuide)
|
||||
<< "Loss guided growth policy not supported. Use CPU algorithm.";
|
||||
this->param = param;
|
||||
initialised = false;
|
||||
is_dense = false;
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(param.gpu_id));
|
||||
if (!param.silent) {
|
||||
@ -77,117 +82,24 @@ void GPUHistBuilder::Init(const TrainParam& param) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ReductionOpT>
|
||||
struct ReduceBySegmentOp {
|
||||
ReductionOpT op;
|
||||
|
||||
__host__ __device__ __forceinline__ ReduceBySegmentOp() {}
|
||||
|
||||
__host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op)
|
||||
: op(op) {}
|
||||
|
||||
template <typename KeyValuePairT>
|
||||
__host__ __device__ __forceinline__ KeyValuePairT
|
||||
operator()(const KeyValuePairT& first, const KeyValuePairT& second) {
|
||||
KeyValuePairT retval;
|
||||
retval.key = second.key;
|
||||
retval.value =
|
||||
first.key != second.key ? second.value : op(first.value, second.value);
|
||||
return retval;
|
||||
}
|
||||
};
|
||||
|
||||
template <int ITEMS_PER_THREAD, int BLOCK_THREADS>
|
||||
__global__ void hist_kernel(gpu_gpair* d_dense_hist, int* d_ridx, int* d_gidx,
|
||||
NodeIdT* d_position, gpu_gpair* d_gpair, int n_bins,
|
||||
int depth, int n) {
|
||||
typedef cub::KeyValuePair<int, gpu_gpair> OffsetValuePairT;
|
||||
typedef cub::BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD,
|
||||
cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoadT;
|
||||
typedef cub::BlockRadixSort<int, BLOCK_THREADS, ITEMS_PER_THREAD, int>
|
||||
BlockRadixSortT;
|
||||
typedef cub::BlockDiscontinuity<int, BLOCK_THREADS> BlockDiscontinuityKeysT;
|
||||
typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;
|
||||
typedef cub::BlockScan<OffsetValuePairT, BLOCK_THREADS,
|
||||
cub::BLOCK_SCAN_WARP_SCANS>
|
||||
BlockScanT;
|
||||
|
||||
union TempStorage {
|
||||
typename BlockLoadT::TempStorage load;
|
||||
typename BlockRadixSortT::TempStorage sort;
|
||||
typename BlockScanT::TempStorage scan;
|
||||
typename BlockDiscontinuityKeysT::TempStorage disc;
|
||||
};
|
||||
|
||||
__shared__ TempStorage temp_storage;
|
||||
|
||||
int ridx[ITEMS_PER_THREAD];
|
||||
int gidx[ITEMS_PER_THREAD];
|
||||
|
||||
const int TILE_SIZE = ITEMS_PER_THREAD * BLOCK_THREADS;
|
||||
int block_offset = blockIdx.x * TILE_SIZE;
|
||||
|
||||
BlockLoadT(temp_storage.load)
|
||||
.Load(d_ridx + block_offset, ridx, n - block_offset, -1);
|
||||
BlockLoadT(temp_storage.load)
|
||||
.Load(d_gidx + block_offset, gidx, n - block_offset, -1);
|
||||
|
||||
int hist_idx[ITEMS_PER_THREAD];
|
||||
|
||||
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
|
||||
if (ridx[ITEM] > -1 && d_position[ridx[ITEM]] > -1) {
|
||||
hist_idx[ITEM] =
|
||||
(d_position[ridx[ITEM]] - n_nodes(depth - 1)) * n_bins + gidx[ITEM];
|
||||
} else {
|
||||
hist_idx[ITEM] = -1;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
BlockRadixSortT(temp_storage.sort).Sort(hist_idx, ridx);
|
||||
|
||||
OffsetValuePairT kv[ITEMS_PER_THREAD];
|
||||
|
||||
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
|
||||
kv[ITEM].key = hist_idx[ITEM];
|
||||
if (ridx[ITEM] > -1) {
|
||||
kv[ITEM].value = d_gpair[ridx[ITEM]];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
// Scan
|
||||
BlockScanT(temp_storage.scan).InclusiveScan(kv, kv, ReduceBySegmentOpT());
|
||||
|
||||
__syncthreads();
|
||||
int flags[ITEMS_PER_THREAD];
|
||||
BlockDiscontinuityKeysT(temp_storage.disc)
|
||||
.FlagTails(flags, hist_idx, cub::Inequality());
|
||||
|
||||
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) {
|
||||
if (flags[ITEM]) {
|
||||
if (ridx[ITEM] > -1 && d_position[ridx[ITEM]] > -1) {
|
||||
atomicAdd(&(d_dense_hist[hist_idx[ITEM]]._grad), kv[ITEM].value._grad);
|
||||
atomicAdd(&(d_dense_hist[hist_idx[ITEM]]._hess), kv[ITEM].value._hess);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GPUHistBuilder::BuildHist(int depth) {
|
||||
auto d_ridx = device_matrix.ridx.data();
|
||||
auto d_gidx = device_matrix.gidx.data();
|
||||
auto d_position = position.data();
|
||||
auto d_gpair = device_gpair.data();
|
||||
auto hist_builder = hist.GetBuilder();
|
||||
auto d_left_child_smallest = left_child_smallest.data();
|
||||
|
||||
dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) {
|
||||
int ridx = d_ridx[idx];
|
||||
int pos = d_position[ridx];
|
||||
if (!is_active(pos, depth)) return;
|
||||
|
||||
// Only increment even nodes
|
||||
if (pos < 0 || pos % 2 == 1) return;
|
||||
// Only increment smallest node
|
||||
bool is_smallest =
|
||||
(d_left_child_smallest[parent_nidx(pos)] && is_left_child(pos)) ||
|
||||
(!d_left_child_smallest[parent_nidx(pos)] && !is_left_child(pos));
|
||||
if (!is_smallest && depth > 0) return;
|
||||
|
||||
int gidx = d_gidx[idx];
|
||||
gpu_gpair gpair = d_gpair[ridx];
|
||||
@ -199,19 +111,181 @@ void GPUHistBuilder::BuildHist(int depth) {
|
||||
|
||||
// Subtraction trick
|
||||
int n_sub_bins = (n_nodes_level(depth) / 2) * hist_builder.n_bins;
|
||||
if (n_sub_bins > 0) {
|
||||
if (depth > 0) {
|
||||
dh::launch_n(n_sub_bins, [=] __device__(int idx) {
|
||||
int nidx = n_nodes(depth - 1) + ((idx / hist_builder.n_bins) * 2);
|
||||
bool left_smallest = d_left_child_smallest[parent_nidx(nidx)];
|
||||
if (left_smallest) {
|
||||
nidx++; // If left is smallest switch to right child
|
||||
}
|
||||
|
||||
int gidx = idx % hist_builder.n_bins;
|
||||
gpu_gpair parent = hist_builder.Get(gidx, nidx / 2);
|
||||
gpu_gpair right = hist_builder.Get(gidx, nidx + 1);
|
||||
hist_builder.Add(parent - right, gidx, nidx);
|
||||
gpu_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx));
|
||||
int other_nidx = left_smallest ? nidx - 1 : nidx + 1;
|
||||
gpu_gpair other = hist_builder.Get(gidx, other_nidx);
|
||||
hist_builder.Add(parent - other, gidx, nidx);
|
||||
});
|
||||
}
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
template <int BLOCK_THREADS>
|
||||
__global__ void find_split_kernel(
|
||||
const gpu_gpair* d_level_hist, int* d_feature_segments, int depth,
|
||||
int n_features, int n_bins, Node* d_nodes, float* d_fidx_min_map,
|
||||
float* d_gidx_fvalue_map, GPUTrainingParam gpu_param,
|
||||
bool* d_left_child_smallest, bool colsample, int* d_feature_flags) {
|
||||
typedef cub::KeyValuePair<int, float> ArgMaxT;
|
||||
typedef cub::BlockScan<gpu_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
|
||||
BlockScanT;
|
||||
typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
|
||||
typedef cub::BlockReduce<gpu_gpair, BLOCK_THREADS> SumReduceT;
|
||||
|
||||
union TempStorage {
|
||||
typename BlockScanT::TempStorage scan;
|
||||
typename MaxReduceT::TempStorage max_reduce;
|
||||
typename SumReduceT::TempStorage sum_reduce;
|
||||
};
|
||||
|
||||
struct UninitializedSplit : cub::Uninitialized<Split> {};
|
||||
struct UninitializedGpair : cub::Uninitialized<gpu_gpair> {};
|
||||
|
||||
__shared__ UninitializedSplit uninitialized_split;
|
||||
Split& split = uninitialized_split.Alias();
|
||||
__shared__ ArgMaxT block_max;
|
||||
__shared__ TempStorage temp_storage;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
split = Split();
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int node_idx = n_nodes(depth - 1) + blockIdx.x;
|
||||
|
||||
for (int fidx = 0; fidx < n_features; fidx++) {
|
||||
if (colsample && d_feature_flags[fidx] == 0) continue;
|
||||
|
||||
int begin = d_feature_segments[blockIdx.x * n_features + fidx];
|
||||
int end = d_feature_segments[blockIdx.x * n_features + fidx + 1];
|
||||
int gidx = (begin - (blockIdx.x * n_bins)) + threadIdx.x;
|
||||
bool thread_active = threadIdx.x < end - begin;
|
||||
|
||||
// Scan histogram
|
||||
gpu_gpair bin =
|
||||
thread_active ? d_level_hist[begin + threadIdx.x] : gpu_gpair();
|
||||
|
||||
gpu_gpair feature_sum;
|
||||
BlockScanT(temp_storage.scan)
|
||||
.ExclusiveScan(bin, bin, gpu_gpair(), cub::Sum(), feature_sum);
|
||||
|
||||
// Calculate gain
|
||||
gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
|
||||
float parent_gain = d_nodes[node_idx].root_gain;
|
||||
|
||||
gpu_gpair missing = parent_sum - feature_sum;
|
||||
|
||||
bool missing_left;
|
||||
float gain = thread_active
|
||||
? loss_chg_missing(bin, missing, parent_sum, parent_gain,
|
||||
gpu_param, missing_left)
|
||||
: -FLT_MAX;
|
||||
__syncthreads();
|
||||
|
||||
// Find thread with best gain
|
||||
ArgMaxT tuple(threadIdx.x, gain);
|
||||
ArgMaxT best = MaxReduceT(temp_storage.max_reduce)
|
||||
.Reduce(tuple, cub::ArgMax(), end - begin);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
block_max = best;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Best thread updates split
|
||||
if (threadIdx.x == block_max.key) {
|
||||
float fvalue;
|
||||
if (threadIdx.x == 0) {
|
||||
fvalue = d_fidx_min_map[fidx];
|
||||
} else {
|
||||
fvalue = d_gidx_fvalue_map[gidx - 1];
|
||||
}
|
||||
|
||||
gpu_gpair left = missing_left ? bin + missing : bin;
|
||||
gpu_gpair right = parent_sum - left;
|
||||
|
||||
split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Create node
|
||||
if (threadIdx.x == 0) {
|
||||
d_nodes[node_idx].split = split;
|
||||
if (depth == 0) {
|
||||
// split.Print();
|
||||
}
|
||||
|
||||
d_nodes[left_child_nidx(node_idx)] = Node(
|
||||
split.left_sum,
|
||||
CalcGain(gpu_param, split.left_sum.grad(), split.left_sum.hess()),
|
||||
CalcWeight(gpu_param, split.left_sum.grad(), split.left_sum.hess()));
|
||||
|
||||
d_nodes[right_child_nidx(node_idx)] = Node(
|
||||
split.right_sum,
|
||||
CalcGain(gpu_param, split.right_sum.grad(), split.right_sum.hess()),
|
||||
CalcWeight(gpu_param, split.right_sum.grad(), split.right_sum.hess()));
|
||||
|
||||
// Record smallest node
|
||||
if (split.left_sum.hess() <= split.right_sum.hess()) {
|
||||
d_left_child_smallest[node_idx] = true;
|
||||
} else {
|
||||
d_left_child_smallest[node_idx] = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GPUHistBuilder::FindSplit(int depth) {
|
||||
// Specialised based on max_bins
|
||||
if (param.max_bin <= 256) {
|
||||
this->FindSplit256(depth);
|
||||
} else if (param.max_bin <= 1024) {
|
||||
this->FindSplit1024(depth);
|
||||
} else {
|
||||
this->FindSplitLarge(depth);
|
||||
}
|
||||
}
|
||||
|
||||
void GPUHistBuilder::FindSplit256(int depth) {
|
||||
CHECK_LE(param.max_bin, 256);
|
||||
const int BLOCK_THREADS = 256;
|
||||
const int GRID_SIZE = n_nodes_level(depth);
|
||||
bool colsample =
|
||||
param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
|
||||
find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
|
||||
hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
|
||||
hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
|
||||
gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
|
||||
feature_flags.data());
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
void GPUHistBuilder::FindSplit1024(int depth) {
|
||||
CHECK_LE(param.max_bin, 1024);
|
||||
const int BLOCK_THREADS = 1024;
|
||||
const int GRID_SIZE = n_nodes_level(depth);
|
||||
bool colsample =
|
||||
param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
|
||||
find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
|
||||
hist.GetLevelPtr(depth), feature_segments.data(), depth, info->num_col,
|
||||
hmat_.row_ptr.back(), nodes.data(), fidx_min_map.data(),
|
||||
gidx_fvalue_map.data(), gpu_param, left_child_smallest.data(), colsample,
|
||||
feature_flags.data());
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
void GPUHistBuilder::FindSplitLarge(int depth) {
|
||||
auto counting = thrust::make_counting_iterator(0);
|
||||
auto d_gidx_feature_map = gidx_feature_map.data();
|
||||
int n_bins = hmat_.row_ptr.back();
|
||||
@ -295,6 +369,8 @@ void GPUHistBuilder::FindSplit(int depth) {
|
||||
auto d_argmax = argmax.data();
|
||||
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
|
||||
auto d_fidx_min_map = fidx_min_map.data();
|
||||
auto d_left_child_smallest = left_child_smallest.data();
|
||||
|
||||
dh::launch_n(n_nodes_level(depth), [=] __device__(int idx) {
|
||||
int max_idx = n_bins * idx + d_argmax[idx].key;
|
||||
int gidx = max_idx % n_bins;
|
||||
@ -317,64 +393,78 @@ void GPUHistBuilder::FindSplit(int depth) {
|
||||
} else {
|
||||
fvalue = d_gidx_fvalue_map[gidx - 1];
|
||||
}
|
||||
|
||||
gpu_gpair left = missing_left ? scan + missing : scan;
|
||||
gpu_gpair right = parent_sum - left;
|
||||
d_nodes[node_idx].split.Update(loss_chg, missing_left, fvalue, fidx, left,
|
||||
right, gpu_param_alias);
|
||||
|
||||
int left_child_idx = n_nodes(depth) + idx * 2;
|
||||
int right_child_idx = n_nodes(depth) + idx * 2 + 1;
|
||||
d_nodes[left_child_idx] =
|
||||
d_nodes[left_child_nidx(node_idx)] =
|
||||
Node(left, CalcGain(gpu_param_alias, left.grad(), left.hess()),
|
||||
CalcWeight(gpu_param_alias, left.grad(), left.hess()));
|
||||
|
||||
d_nodes[right_child_idx] =
|
||||
d_nodes[right_child_nidx(node_idx)] =
|
||||
Node(right, CalcGain(gpu_param_alias, right.grad(), right.hess()),
|
||||
CalcWeight(gpu_param_alias, right.grad(), right.hess()));
|
||||
|
||||
// Record smallest node
|
||||
if (left.hess() <= right.hess()) {
|
||||
d_left_child_smallest[node_idx] = true;
|
||||
} else {
|
||||
d_left_child_smallest[node_idx] = false;
|
||||
}
|
||||
});
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
void GPUHistBuilder::InitFirstNode() {
|
||||
// Build the root node on the CPU and copy to device
|
||||
gpu_gpair sum_gradients =
|
||||
thrust::reduce(device_gpair.tbegin(), device_gpair.tend(),
|
||||
gpu_gpair(0, 0), thrust::plus<gpu_gpair>());
|
||||
auto d_gpair = device_gpair.data();
|
||||
auto d_node_sums = node_sums.data();
|
||||
auto d_nodes = nodes.data();
|
||||
auto gpu_param_alias = gpu_param;
|
||||
|
||||
Node tmp =
|
||||
Node(sum_gradients,
|
||||
CalcGain(param, sum_gradients.grad(), sum_gradients.hess()),
|
||||
CalcWeight(param, sum_gradients.grad(), sum_gradients.hess()));
|
||||
size_t temp_storage_bytes;
|
||||
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, d_gpair, d_node_sums,
|
||||
device_gpair.size(), cub::Sum(), gpu_gpair());
|
||||
cub_mem.LazyAllocate(temp_storage_bytes);
|
||||
cub::DeviceReduce::Reduce(cub_mem.d_temp_storage, cub_mem.temp_storage_bytes,
|
||||
d_gpair, d_node_sums, device_gpair.size(),
|
||||
cub::Sum(), gpu_gpair());
|
||||
|
||||
thrust::copy_n(&tmp, 1, nodes.tbegin());
|
||||
dh::launch_n(1, [=] __device__(int idx) {
|
||||
gpu_gpair sum_gradients = d_node_sums[idx];
|
||||
d_nodes[idx] = Node(
|
||||
sum_gradients,
|
||||
CalcGain(gpu_param_alias, sum_gradients.grad(), sum_gradients.hess()),
|
||||
CalcWeight(gpu_param_alias, sum_gradients.grad(),
|
||||
sum_gradients.hess()));
|
||||
});
|
||||
}
|
||||
|
||||
void GPUHistBuilder::UpdatePosition() {
|
||||
void GPUHistBuilder::UpdatePosition(int depth) {
|
||||
if (is_dense) {
|
||||
this->UpdatePositionDense();
|
||||
this->UpdatePositionDense(depth);
|
||||
} else {
|
||||
this->UpdatePositionSparse();
|
||||
this->UpdatePositionSparse(depth);
|
||||
}
|
||||
}
|
||||
|
||||
void GPUHistBuilder::UpdatePositionDense() {
|
||||
void GPUHistBuilder::UpdatePositionDense(int depth) {
|
||||
auto d_position = position.data();
|
||||
Node* d_nodes = nodes.data();
|
||||
auto d_gidx_fvalue_map = gidx_fvalue_map.data();
|
||||
auto d_gidx = device_matrix.gidx.data();
|
||||
int n_columns = info->num_col;
|
||||
|
||||
int gidx_size = device_matrix.gidx.size();
|
||||
|
||||
dh::launch_n(position.size(), [=] __device__(int idx) {
|
||||
NodeIdT pos = d_position[idx];
|
||||
if (pos < 0) {
|
||||
if (!is_active(pos, depth)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Node node = d_nodes[pos];
|
||||
|
||||
if (node.IsLeaf()) {
|
||||
d_position[idx] = -1;
|
||||
return;
|
||||
}
|
||||
|
||||
@ -383,14 +473,16 @@ void GPUHistBuilder::UpdatePositionDense() {
|
||||
float fvalue = d_gidx_fvalue_map[gidx];
|
||||
|
||||
if (fvalue <= node.split.fvalue) {
|
||||
d_position[idx] = pos * 2 + 1;
|
||||
d_position[idx] = left_child_nidx(pos);
|
||||
} else {
|
||||
d_position[idx] = pos * 2 + 2;
|
||||
d_position[idx] = right_child_nidx(pos);
|
||||
}
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
void GPUHistBuilder::UpdatePositionSparse() {
|
||||
void GPUHistBuilder::UpdatePositionSparse(int depth) {
|
||||
auto d_position = position.data();
|
||||
auto d_position_tmp = position_tmp.data();
|
||||
Node* d_nodes = nodes.data();
|
||||
@ -402,14 +494,16 @@ void GPUHistBuilder::UpdatePositionSparse() {
|
||||
// Update missing direction
|
||||
dh::launch_n(position.size(), [=] __device__(int idx) {
|
||||
NodeIdT pos = d_position[idx];
|
||||
if (pos < 0) {
|
||||
if (!is_active(pos, depth)) {
|
||||
d_position_tmp[idx] = pos;
|
||||
return;
|
||||
}
|
||||
|
||||
Node node = d_nodes[pos];
|
||||
|
||||
if (node.IsLeaf()) {
|
||||
d_position_tmp[idx] = -1;
|
||||
d_position_tmp[idx] = pos;
|
||||
return;
|
||||
} else if (node.split.missing_left) {
|
||||
d_position_tmp[idx] = pos * 2 + 1;
|
||||
} else {
|
||||
@ -417,11 +511,13 @@ void GPUHistBuilder::UpdatePositionSparse() {
|
||||
}
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
|
||||
// Update node based on fvalue where exists
|
||||
dh::launch_n(device_matrix.gidx.size(), [=] __device__(int idx) {
|
||||
int ridx = d_ridx[idx];
|
||||
NodeIdT pos = d_position[ridx];
|
||||
if (pos < 0) {
|
||||
if (!is_active(pos, depth)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -438,23 +534,29 @@ void GPUHistBuilder::UpdatePositionSparse() {
|
||||
float fvalue = d_gidx_fvalue_map[gidx];
|
||||
|
||||
if (fvalue <= node.split.fvalue) {
|
||||
d_position_tmp[ridx] = pos * 2 + 1;
|
||||
d_position_tmp[ridx] = left_child_nidx(pos);
|
||||
} else {
|
||||
d_position_tmp[ridx] = pos * 2 + 2;
|
||||
d_position_tmp[ridx] = right_child_nidx(pos);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
|
||||
position = position_tmp;
|
||||
}
|
||||
|
||||
void GPUHistBuilder::ColSampleTree() {
|
||||
if (param.colsample_bylevel == 1.0 && param.colsample_bytree == 1.0) return;
|
||||
|
||||
feature_set_tree.resize(info->num_col);
|
||||
std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
|
||||
feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree);
|
||||
}
|
||||
|
||||
void GPUHistBuilder::ColSampleLevel() {
|
||||
if (param.colsample_bylevel == 1.0 && param.colsample_bytree == 1.0) return;
|
||||
|
||||
feature_set_level.resize(feature_set_tree.size());
|
||||
feature_set_level = col_sample(feature_set_tree, param.colsample_bylevel);
|
||||
std::vector<int> h_feature_flags(info->num_col, 0);
|
||||
@ -491,18 +593,20 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||
int level_max_bins = n_nodes_level(param.max_depth - 1) * n_bins;
|
||||
|
||||
size_t free_memory = dh::available_memory();
|
||||
ba.allocate(
|
||||
&gidx_feature_map, n_bins, &hist_node_segments,
|
||||
n_nodes_level(param.max_depth - 1) + 1, &feature_segments,
|
||||
h_feature_segments.size(), &gain, level_max_bins, &position,
|
||||
gpair.size(), &position_tmp, gpair.size(), &nodes,
|
||||
n_nodes(param.max_depth), &gidx_fvalue_map, hmat_.cut.size(),
|
||||
&fidx_min_map, hmat_.min_val.size(), &argmax,
|
||||
n_nodes_level(param.max_depth - 1), &node_sums,
|
||||
n_nodes_level(param.max_depth - 1) * n_features, &hist_scan,
|
||||
level_max_bins, &device_gpair, gpair.size(), &device_matrix.gidx,
|
||||
gmat_.index.size(), &device_matrix.ridx, gmat_.index.size(), &hist.hist,
|
||||
n_nodes(param.max_depth - 1) * n_bins, &feature_flags, n_features);
|
||||
ba.allocate(&gidx_feature_map, n_bins, &hist_node_segments,
|
||||
n_nodes_level(param.max_depth - 1) + 1, &feature_segments,
|
||||
h_feature_segments.size(), &gain, level_max_bins, &position,
|
||||
gpair.size(), &position_tmp, gpair.size(), &nodes,
|
||||
n_nodes(param.max_depth), &gidx_fvalue_map, hmat_.cut.size(),
|
||||
&fidx_min_map, hmat_.min_val.size(), &argmax,
|
||||
n_nodes_level(param.max_depth - 1), &node_sums,
|
||||
n_nodes_level(param.max_depth - 1) * n_features, &hist_scan,
|
||||
level_max_bins, &device_gpair, gpair.size(),
|
||||
&device_matrix.gidx, gmat_.index.size(), &device_matrix.ridx,
|
||||
gmat_.index.size(), &hist.hist,
|
||||
n_nodes(param.max_depth - 1) * n_bins, &feature_flags,
|
||||
n_features, &left_child_smallest, n_nodes(param.max_depth - 1),
|
||||
&prediction_cache, gpair.size());
|
||||
|
||||
if (!param.silent) {
|
||||
const int mb_size = 1048576;
|
||||
@ -529,10 +633,14 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||
thrust::sequence(hist_node_segments.tbegin(), hist_node_segments.tend(), 0,
|
||||
n_bins);
|
||||
|
||||
feature_flags.fill(1);
|
||||
|
||||
feature_segments = h_feature_segments;
|
||||
|
||||
hist.Init(n_bins);
|
||||
|
||||
prediction_cache.fill(0);
|
||||
|
||||
initialised = true;
|
||||
}
|
||||
nodes.fill(Node());
|
||||
@ -540,6 +648,37 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||
device_gpair = gpair;
|
||||
subsample_gpair(&device_gpair, param.subsample);
|
||||
hist.Reset();
|
||||
p_last_fmat_ = &fmat;
|
||||
}
|
||||
|
||||
bool GPUHistBuilder::UpdatePredictionCache(
|
||||
const DMatrix* data, std::vector<bst_float>* p_out_preds) {
|
||||
std::vector<bst_float>& out_preds = *p_out_preds;
|
||||
|
||||
if (nodes.empty() || !p_last_fmat_ || data != p_last_fmat_) {
|
||||
return false;
|
||||
}
|
||||
CHECK_EQ(prediction_cache.size(), out_preds.size());
|
||||
|
||||
if (!prediction_cache_initialised) {
|
||||
prediction_cache = out_preds;
|
||||
prediction_cache_initialised = true;
|
||||
}
|
||||
|
||||
auto d_nodes = nodes.data();
|
||||
auto d_position = position.data();
|
||||
auto d_prediction_cache = prediction_cache.data();
|
||||
float eps = param.learning_rate;
|
||||
|
||||
dh::launch_n(prediction_cache.size(), [=] __device__(int idx) {
|
||||
int pos = d_position[idx];
|
||||
d_prediction_cache[idx] += d_nodes[pos].weight * eps;
|
||||
});
|
||||
|
||||
thrust::copy(prediction_cache.tbegin(), prediction_cache.tend(),
|
||||
out_preds.data());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
|
||||
@ -547,12 +686,11 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
|
||||
this->InitData(gpair, *p_fmat, *p_tree);
|
||||
this->InitFirstNode();
|
||||
this->ColSampleTree();
|
||||
|
||||
for (int depth = 0; depth < param.max_depth; depth++) {
|
||||
this->ColSampleLevel();
|
||||
this->BuildHist(depth);
|
||||
this->FindSplit(depth);
|
||||
this->UpdatePosition();
|
||||
this->UpdatePosition(depth);
|
||||
}
|
||||
dense2sparse_tree(p_tree, nodes.tbegin(), nodes.tend(), param);
|
||||
}
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
* Copyright 2016 Rory mitchell
|
||||
*/
|
||||
#pragma once
|
||||
#include <cusparse.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <cub/util_type.cuh> // Need key value pair definition
|
||||
@ -63,12 +62,17 @@ class GPUHistBuilder {
|
||||
RegTree *p_tree);
|
||||
void BuildHist(int depth);
|
||||
void FindSplit(int depth);
|
||||
void FindSplit256(int depth);
|
||||
void FindSplit1024(int depth);
|
||||
void FindSplitLarge(int depth);
|
||||
void InitFirstNode();
|
||||
void UpdatePosition();
|
||||
void UpdatePositionDense();
|
||||
void UpdatePositionSparse();
|
||||
void UpdatePosition(int depth);
|
||||
void UpdatePositionDense(int depth);
|
||||
void UpdatePositionSparse(int depth);
|
||||
void ColSampleTree();
|
||||
void ColSampleLevel();
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* p_out_preds);
|
||||
|
||||
TrainParam param;
|
||||
GPUTrainingParam gpu_param;
|
||||
@ -78,6 +82,7 @@ class GPUHistBuilder {
|
||||
bool initialised;
|
||||
bool is_dense;
|
||||
DeviceGMat device_matrix;
|
||||
const DMatrix* p_last_fmat_;
|
||||
|
||||
dh::bulk_allocator ba;
|
||||
dh::CubMemory cub_mem;
|
||||
@ -96,6 +101,9 @@ class GPUHistBuilder {
|
||||
dh::dvec<gpu_gpair> device_gpair;
|
||||
dh::dvec<Node> nodes;
|
||||
dh::dvec<int> feature_flags;
|
||||
dh::dvec<bool> left_child_smallest;
|
||||
dh::dvec<bst_float> prediction_cache;
|
||||
bool prediction_cache_initialised;
|
||||
|
||||
std::vector<int> feature_set_tree;
|
||||
std::vector<int> feature_set_level;
|
||||
|
||||
@ -122,7 +122,7 @@ struct Split {
|
||||
gpu_gpair right_sum;
|
||||
|
||||
__host__ __device__ Split()
|
||||
: loss_chg(-FLT_MAX), missing_left(true), fvalue(0) {}
|
||||
: loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {}
|
||||
|
||||
__device__ void Update(float loss_chg_in, bool missing_left_in,
|
||||
float fvalue_in, int findex_in, gpu_gpair left_sum_in,
|
||||
|
||||
@ -75,6 +75,11 @@ class GPUHistMaker : public TreeUpdater {
|
||||
param.learning_rate = lr;
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* out_preds) override {
|
||||
return builder.UpdatePredictionCache(data, out_preds);
|
||||
}
|
||||
|
||||
protected:
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
|
||||
@ -160,7 +160,6 @@ class TestGPU(unittest.TestCase):
|
||||
|
||||
param = {'objective': 'binary:logistic',
|
||||
'updater': 'grow_gpu_hist',
|
||||
'grow_policy': 'depthwise',
|
||||
'max_depth': 2,
|
||||
'eval_metric': 'auc'}
|
||||
res = {}
|
||||
@ -216,6 +215,17 @@ class TestGPU(unittest.TestCase):
|
||||
assert self.non_decreasing(res['train']['auc'])
|
||||
assert res['train']['auc'][0] >= 0.85
|
||||
|
||||
# max_bin = 2048
|
||||
param = {'objective': 'binary:logistic',
|
||||
'updater': 'grow_gpu_hist',
|
||||
'max_depth': 3,
|
||||
'eval_metric': 'auc',
|
||||
'max_bin': 2048
|
||||
}
|
||||
res = {}
|
||||
xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
|
||||
assert self.non_decreasing(res['train']['auc'])
|
||||
assert res['train']['auc'][0] >= 0.85
|
||||
|
||||
def non_decreasing(self, L):
|
||||
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
|
||||
|
||||
@ -80,7 +80,7 @@ class FastHistMaker: public TreeUpdater {
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* out_preds) const override {
|
||||
std::vector<bst_float>* out_preds) override {
|
||||
if (!builder_ || param.subsample < 1.0f) {
|
||||
return false;
|
||||
} else {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user