Add GPU accelerated tree construction plugin (#1679)

This commit is contained in:
RAMitchell 2016-10-21 16:14:47 +13:00 committed by Tianqi Chen
parent 9b2e41340b
commit ac41845d4b
16 changed files with 3040 additions and 275 deletions

View File

@ -2,7 +2,12 @@ cmake_minimum_required (VERSION 2.6)
project (xgboost) project (xgboost)
find_package(OpenMP) find_package(OpenMP)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS} -fPIC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
endif()
# Make sure we are using C++11 # Make sure we are using C++11
# Visual Studio 12.0 and newer supports enough c++11 to make this work # Visual Studio 12.0 and newer supports enough c++11 to make this work
@ -25,7 +30,6 @@ else()
endif() endif()
endif() endif()
#Make sure we are using the static runtime #Make sure we are using the static runtime
if(MSVC) if(MSVC)
set(variables set(variables
@ -69,22 +73,51 @@ set(RABIT_SOURCES
rabit/src/c_api.cc rabit/src/c_api.cc
) )
add_subdirectory(dmlc-core) add_subdirectory(dmlc-core)
add_library(rabit STATIC ${RABIT_SOURCES}) add_library(rabit STATIC ${RABIT_SOURCES})
#Set library output directories
if(MSVC) if(MSVC)
#With MSVC shared library is considered runtime
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/lib)
else()
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR})
#Prevent shared library being called liblibxgboost.so on Linux
set(CMAKE_SHARED_LIBRARY_PREFIX "")
endif()
option(PLUGIN_UPDATER_GPU "Build GPU accelerated tree construction plugin")
if(PLUGIN_UPDATER_GPU)
#Find cub
set(CUB_DIRECTORY "" CACHE PATH "CUB 1.5.4 directory")
include_directories(${CUB_DIRECTORY})
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35")
if(NOT MSVC)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-std=c++11; -Xcompiler -fPIC")
endif()
set(SOURCES ${SOURCES}
plugin/updater_gpu/src/updater_gpu.cc
)
find_package(CUDA QUIET REQUIRED)
cuda_add_library(updater_gpu STATIC
plugin/updater_gpu/src/gpu_builder.cu
)
endif()
add_executable(xgboost ${SOURCES}) add_executable(xgboost ${SOURCES})
add_library(libxgboost SHARED ${SOURCES}) add_library(libxgboost SHARED ${SOURCES})
target_link_libraries(xgboost dmlccore rabit) target_link_libraries(xgboost dmlccore rabit)
target_link_libraries(libxgboost dmlccore rabit) target_link_libraries(libxgboost dmlccore rabit)
else()
add_executable(xgboost-bin ${SOURCES})
SET_TARGET_PROPERTIES(xgboost-bin PROPERTIES OUTPUT_NAME xgboost)
add_library(xgboost SHARED ${SOURCES})
target_link_libraries(xgboost-bin dmlccore rabit)
target_link_libraries(xgboost dmlccore rabit) if(PLUGIN_UPDATER_GPU)
target_link_libraries(xgboost updater_gpu)
target_link_libraries(libxgboost updater_gpu)
endif() endif()

View File

@ -0,0 +1,31 @@
# CUDA Accelerated Tree Construction Algorithm
## Usage
Specify the updater parameter as 'grow_gpu'.
Python example:
```python
param['updater'] = 'grow_gpu'
```
## Dependencies
A CUDA capable GPU with at least compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler).
The plugin also depends on CUB 1.5.4 - http://nvlabs.github.io/cub/index.html.
CUB is a header only cuda library which provides sort/reduce/scan primitives.
## Build
The plugin can be built using cmake and specifying the option PLUGIN_UPDATER_GPU=ON.
Specify the location of the CUB library with the cmake variable CUB_DIRECTORY.
It is recommended to build with Cuda Toolkit 7.5 or greater.
## Author
Rory Mitchell
Report any bugs to r.a.mitchell.nz at google mail.

View File

@ -0,0 +1,64 @@
#!/usr/bin/pytho#!/usr/bin/python
#pylint: skip-file
# this is the example script to use xgboost to train
import numpy as np
import xgboost as xgb
import time
test_size = 550000
# path to where the data lies
dpath = '../../demo/data'
# load in training data, directly use numpy
dtrain = np.loadtxt( dpath+'/training.csv', delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s') } )
dtrain = np.concatenate((dtrain, np.copy(dtrain)))
dtrain = np.concatenate((dtrain, np.copy(dtrain)))
print(len(dtrain))
print ('finish loading from csv ')
label = dtrain[:,32]
data = dtrain[:,1:31]
# rescale weight to make it same as test set
weight = dtrain[:,31] * float(test_size) / len(label)
sum_wpos = sum( weight[i] for i in range(len(label)) if label[i] == 1.0 )
sum_wneg = sum( weight[i] for i in range(len(label)) if label[i] == 0.0 )
# print weight statistics
print ('weight statistics: wpos=%g, wneg=%g, ratio=%g' % ( sum_wpos, sum_wneg, sum_wneg/sum_wpos ))
# construct xgboost.DMatrix from numpy array, treat -999.0 as missing value
xgmat = xgb.DMatrix( data, label=label, missing = -999.0, weight=weight )
# setup parameters for xgboost
param = {}
# use logistic regression loss
param['objective'] = 'binary:logitraw'
# scale weight of positive examples
param['scale_pos_weight'] = sum_wneg/sum_wpos
param['bst:eta'] = 0.1
param['max_depth'] = 16
param['eval_metric'] = 'auc'
param['silent'] = 1
param['nthread'] = 4
plst = param.items()+[('eval_metric', 'ams@0.15')]
watchlist = [ (xgmat,'train') ]
num_round = 10
print ("training xgboost")
threads = [16]
for i in threads:
param['nthread'] = i
tmp = time.time()
plst = param.items()+[('eval_metric', 'ams@0.15')]
bst = xgb.train( plst, xgmat, num_round, watchlist );
print ("XGBoost with %d thread costs: %s seconds" % (i, str(time.time() - tmp)))
print ("training xgboost - gpu tree construction")
param['updater'] = 'grow_gpu'
tmp = time.time()
plst = param.items()+[('eval_metric', 'ams@0.15')]
bst = xgb.train( plst, xgmat, num_round, watchlist );
print ("XGBoost GPU: %s seconds" % (str(time.time() - tmp)))
print ('finish training')

View File

@ -0,0 +1,276 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <thrust/device_vector.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include <ctime>
#include <algorithm>
#include <sstream>
#include <string>
#ifdef _WIN32
#include <windows.h>
#endif
#define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__)
cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) {
if (code != cudaSuccess) {
std::cout << file;
std::cout << line;
std::stringstream ss;
ss << file << "(" << line << ")";
std::string file_and_line;
ss >> file_and_line;
throw thrust::system_error(code, thrust::cuda_category(), file_and_line);
}
return code;
}
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line);
if (abort)
exit(code);
}
}
// Keep track of cub library device allocation
struct CubMemory {
void *d_temp_storage;
size_t temp_storage_bytes;
CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {}
~CubMemory() {
if (d_temp_storage != NULL) {
safe_cuda(cudaFree(d_temp_storage));
}
}
void Allocate() {
safe_cuda(cudaMalloc(&d_temp_storage, temp_storage_bytes));
}
bool IsAllocated() { return d_temp_storage != NULL; }
};
// Utility function: rounds up integer division.
template <typename T> T div_round_up(const T a, const T b) {
return static_cast<T>(ceil(static_cast<double>(a) / b));
}
template <typename T> thrust::device_ptr<T> dptr(T *d_ptr) {
return thrust::device_pointer_cast(d_ptr);
}
// #define DEVICE_TIMER
#define MAX_WARPS 32 // Maximum number of warps to time
#define MAX_SLOTS 10
#define TIMER_BLOCKID 0 // Block to time
struct DeviceTimerGlobal {
#ifdef DEVICE_TIMER
clock_t total_clocks[MAX_SLOTS][MAX_WARPS];
int64_t count[MAX_SLOTS][MAX_WARPS];
#endif
// Clear device memory. Call at start of kernel.
__device__ void Init() {
#ifdef DEVICE_TIMER
if (blockIdx.x == TIMER_BLOCKID && threadIdx.x < MAX_WARPS) {
for (int SLOT = 0; SLOT < MAX_SLOTS; SLOT++) {
total_clocks[SLOT][threadIdx.x] = 0;
count[SLOT][threadIdx.x] = 0;
}
}
#endif
}
void HostPrint() {
#ifdef DEVICE_TIMER
DeviceTimerGlobal h_timer;
safe_cuda(
cudaMemcpyFromSymbol(&h_timer, (*this), sizeof(DeviceTimerGlobal)));
for (int SLOT = 0; SLOT < MAX_SLOTS; SLOT++) {
if (h_timer.count[SLOT][0] == 0) {
continue;
}
clock_t sum_clocks = 0;
int64_t sum_count = 0;
for (int WARP = 0; WARP < MAX_WARPS; WARP++) {
if (h_timer.count[SLOT][WARP] == 0) {
continue;
}
sum_clocks += h_timer.total_clocks[SLOT][WARP];
sum_count += h_timer.count[SLOT][WARP];
}
printf("Slot %d: %d clocks per call, called %d times.\n", SLOT,
sum_clocks / sum_count, h_timer.count[SLOT][0]);
}
#endif
}
};
struct DeviceTimer {
#ifdef DEVICE_TIMER
clock_t start;
int slot;
DeviceTimerGlobal &GTimer;
#endif
#ifdef DEVICE_TIMER
__device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) // NOLINT
:
GTimer(GTimer),
start(clock()), slot(slot) {}
#else
__device__ DeviceTimer(DeviceTimerGlobal &GTimer, int slot) {} // NOLINT
#endif
__device__ void End() {
#ifdef DEVICE_TIMER
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
if (blockIdx.x == TIMER_BLOCKID && lane_id == 0) {
GTimer.count[slot][warp_id] += 1;
GTimer.total_clocks[slot][warp_id] += clock() - start;
}
#endif
}
};
// #define TIMERS
struct Timer {
volatile double start;
Timer() { reset(); }
double seconds_now() {
#ifdef _WIN32
static LARGE_INTEGER s_frequency;
QueryPerformanceFrequency(&s_frequency);
LARGE_INTEGER now;
QueryPerformanceCounter(&now);
return static_cast<double>(now.QuadPart) / s_frequency.QuadPart;
#endif
}
void reset() {
#ifdef _WIN32
_ReadWriteBarrier();
start = seconds_now();
#endif
}
double elapsed() {
#ifdef _WIN32
_ReadWriteBarrier();
return seconds_now() - start;
#endif
}
void printElapsed(char *label) {
#ifdef TIMERS
safe_cuda(cudaDeviceSynchronize());
printf("%s:\t %1.4fs\n", label, elapsed());
#endif
}
};
template <typename T>
void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
thrust::host_vector<T> h = v;
for (int i = 0; i < std::min(max_items, h.size()); i++) {
std::cout << " " << h[i];
}
std::cout << "\n";
}
template <typename T>
void print(char *label, const thrust::device_vector<T> &v,
const char *format = "%d ", int max = 10) {
thrust::host_vector<T> h_v = v;
std::cout << label << ":\n";
for (int i = 0; i < std::min(static_cast<int>(h_v.size()), max); i++) {
printf(format, h_v[i]);
}
std::cout << "\n";
}
class range {
public:
class iterator {
friend class range;
public:
__host__ __device__ int64_t operator*() const { return i_; }
__host__ __device__ const iterator &operator++() {
i_ += step_;
return *this;
}
__host__ __device__ iterator operator++(int) {
iterator copy(*this);
i_ += step_;
return copy;
}
__host__ __device__ bool operator==(const iterator &other) const {
return i_ >= other.i_;
}
__host__ __device__ bool operator!=(const iterator &other) const {
return i_ < other.i_;
}
__host__ __device__ void step(int s) { step_ = s; }
protected:
__host__ __device__ explicit iterator(int64_t start) : i_(start) {}
public:
uint64_t i_;
int step_ = 1;
};
__host__ __device__ iterator begin() const { return begin_; }
__host__ __device__ iterator end() const { return end_; }
__host__ __device__ range(int64_t begin, int64_t end)
: begin_(begin), end_(end) {}
__host__ __device__ void step(int s) { begin_.step(s); }
private:
iterator begin_;
iterator end_;
};
template <typename T> __device__ range grid_stride_range(T begin, T end) {
begin += blockDim.x * blockIdx.x + threadIdx.x;
range r(begin, end);
r.step(gridDim.x * blockDim.x);
return r;
}
template <typename T> __device__ range block_stride_range(T begin, T end) {
begin += threadIdx.x;
range r(begin, end);
r.step(blockDim.x);
return r;
}
// Converts device_vector to raw pointer
template <typename T> T *raw(thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}

View File

@ -0,0 +1,87 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include <cub/cub.cuh>
#include <xgboost/base.h>
#include "cuda_helpers.cuh"
#include "find_split_multiscan.cuh"
#include "find_split_sorting.cuh"
#include "types_functions.cuh"
namespace xgboost {
namespace tree {
__global__ void
reduce_split_candidates_kernel(Split *d_split_candidates, Node *d_current_nodes,
Node *d_new_nodes, int n_current_nodes,
int n_features, const GPUTrainingParam param) {
int nid = blockIdx.x * blockDim.x + threadIdx.x;
if (nid >= n_current_nodes) {
return;
}
// Find best split for each node
Split best;
for (int i = 0; i < n_features; i++) {
best.Update(d_split_candidates[n_current_nodes * i + nid]);
}
// Update current node
d_current_nodes[nid].split = best;
// Generate new nodes
d_new_nodes[nid * 2] =
Node(best.left_sum,
CalcGain(param, best.left_sum.grad(), best.left_sum.hess()),
CalcWeight(param, best.left_sum.grad(), best.left_sum.hess()));
d_new_nodes[nid * 2 + 1] =
Node(best.right_sum,
CalcGain(param, best.right_sum.grad(), best.right_sum.hess()),
CalcWeight(param, best.right_sum.grad(), best.right_sum.hess()));
}
void reduce_split_candidates(Split *d_split_candidates, Node *d_nodes,
int level, int n_features,
const GPUTrainingParam param) {
// Current level nodes (before split)
Node *d_current_nodes = d_nodes + (1 << (level)) - 1;
// Next level nodes (after split)
Node *d_new_nodes = d_nodes + (1 << (level + 1)) - 1;
// Number of existing nodes on this level
int n_current_nodes = 1 << level;
const int BLOCK_THREADS = 512;
const int GRID_SIZE = div_round_up(n_current_nodes, BLOCK_THREADS);
reduce_split_candidates_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(
d_split_candidates, d_current_nodes, d_new_nodes, n_current_nodes,
n_features, param);
safe_cuda(cudaDeviceSynchronize());
}
void find_split(const Item *d_items, Split *d_split_candidates,
const NodeIdT *d_node_id, Node *d_nodes, bst_uint num_items,
int num_features, const int *d_feature_offsets,
gpu_gpair *d_node_sums, int *d_node_offsets,
const GPUTrainingParam param, const int level,
bool multiscan_algorithm) {
if (multiscan_algorithm) {
find_split_candidates_multiscan(d_items, d_split_candidates, d_node_id,
d_nodes, num_items, num_features,
d_feature_offsets, param, level);
} else {
find_split_candidates_sorted(d_items, d_split_candidates, d_node_id,
d_nodes, num_items, num_features,
d_feature_offsets, d_node_sums, d_node_offsets,
param, level);
}
// Find the best split for each node
reduce_split_candidates(d_split_candidates, d_nodes, level, num_features,
param);
}
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,835 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include <cub/cub.cuh>
#include <xgboost/base.h>
#include "cuda_helpers.cuh"
#include "types_functions.cuh"
namespace xgboost {
namespace tree {
typedef uint64_t BitFlagSet;
__device__ __inline__ void set_bit(BitFlagSet &bf, int index) { // NOLINT
bf |= 1 << index;
}
__device__ __inline__ bool check_bit(BitFlagSet bf, int index) {
return (bf >> index) & 1;
}
// Carryover prefix for scanning multiple tiles of bit flags
struct FlagPrefixCallbackOp {
BitFlagSet tile_carry;
__device__ FlagPrefixCallbackOp() : tile_carry(0) {}
__device__ BitFlagSet operator()(BitFlagSet block_aggregate) {
BitFlagSet old_prefix = tile_carry;
tile_carry |= block_aggregate;
return old_prefix;
}
};
// Scan op for bit flags that resets if the final bit is set
struct FlagScanOp {
__device__ __forceinline__ BitFlagSet operator()(const BitFlagSet &a,
const BitFlagSet &b) {
if (check_bit(b, 63)) {
return b;
} else {
return a | b;
}
}
};
template <int _BLOCK_THREADS, int _N_NODES, bool _DEBUG_VALIDATE>
struct FindSplitParamsMultiscan {
enum {
BLOCK_THREADS = _BLOCK_THREADS,
TILE_ITEMS = BLOCK_THREADS,
N_NODES = _N_NODES,
N_WARPS = _BLOCK_THREADS / 32,
DEBUG_VALIDATE = _DEBUG_VALIDATE,
ITEMS_PER_THREAD = 1
};
};
template <int _BLOCK_THREADS, int _N_NODES, bool _DEBUG_VALIDATE>
struct ReduceParamsMultiscan {
enum {
BLOCK_THREADS = _BLOCK_THREADS,
ITEMS_PER_THREAD = 1,
TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD,
N_NODES = _N_NODES,
N_WARPS = _BLOCK_THREADS / 32,
DEBUG_VALIDATE = _DEBUG_VALIDATE
};
};
template <typename ParamsT> struct ReduceEnactorMultiscan {
typedef cub::WarpReduce<gpu_gpair> WarpReduceT;
struct _TempStorage {
typename WarpReduceT::TempStorage warp_reduce[ParamsT::N_WARPS];
gpu_gpair partial_sums[ParamsT::N_NODES][ParamsT::N_WARPS];
};
struct TempStorage : cub::Uninitialized<_TempStorage> {};
struct _Reduction {
gpu_gpair node_sums[ParamsT::N_NODES];
};
struct Reduction : cub::Uninitialized<_Reduction> {};
// Thread local member variables
const Item *d_items;
const NodeIdT *d_node_id;
_TempStorage &temp_storage;
_Reduction &reduction;
gpu_gpair gpair;
NodeIdT node_id;
NodeIdT node_id_adjusted;
const int node_begin;
__device__ __forceinline__ ReduceEnactorMultiscan(
TempStorage &temp_storage, // NOLINT
Reduction &reduction, // NOLINT
const Item *d_items, const NodeIdT *d_node_id, const int node_begin)
: temp_storage(temp_storage.Alias()), reduction(reduction.Alias()),
d_items(d_items), d_node_id(d_node_id), node_begin(node_begin) {}
__device__ __forceinline__ void ResetPartials() {
if (threadIdx.x < ParamsT::N_WARPS) {
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
temp_storage.partial_sums[NODE][threadIdx.x] = gpu_gpair();
}
}
}
__device__ __forceinline__ void ResetReduction() {
if (threadIdx.x < ParamsT::N_NODES) {
reduction.node_sums[threadIdx.x] = gpu_gpair();
}
}
__device__ __forceinline__ void LoadTile(const bst_uint &offset,
const bst_uint &num_remaining) {
if (threadIdx.x < num_remaining) {
gpair = d_items[offset + threadIdx.x].gpair;
node_id = d_node_id[offset + threadIdx.x];
node_id_adjusted = node_id - node_begin;
} else {
gpair = gpu_gpair();
node_id = -1;
node_id_adjusted = -1;
}
}
__device__ __forceinline__ void ProcessTile(const bst_uint &offset,
const bst_uint &num_remaining) {
LoadTile(offset, num_remaining);
// Warp synchronous reduction
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
bool active = node_id_adjusted == NODE;
unsigned int ballot = __ballot(active);
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
if (ballot == 0) {
continue;
} else if (__popc(ballot) == 1) {
if (active) {
temp_storage.partial_sums[NODE][warp_id] += gpair;
}
} else {
gpu_gpair sum = WarpReduceT(temp_storage.warp_reduce[warp_id])
.Sum(active ? gpair : gpu_gpair());
if (lane_id == 0) {
temp_storage.partial_sums[NODE][warp_id] += sum;
}
}
}
}
__device__ __forceinline__ void ReducePartials() {
// Use single warp to reduce partials
if (threadIdx.x < 32) {
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
gpu_gpair sum =
WarpReduceT(temp_storage.warp_reduce[0])
.Sum(threadIdx.x < ParamsT::N_WARPS
? temp_storage.partial_sums[NODE][threadIdx.x]
: gpu_gpair());
if (threadIdx.x == 0) {
reduction.node_sums[NODE] = sum;
}
}
}
}
__device__ __forceinline__ void ProcessRegion(const bst_uint &segment_begin,
const bst_uint &segment_end) {
// Current position
bst_uint offset = segment_begin;
ResetReduction();
ResetPartials();
__syncthreads();
// Process full tiles
while (offset < segment_end) {
ProcessTile(offset, segment_end - offset);
offset += ParamsT::TILE_ITEMS;
}
__syncthreads();
ReducePartials();
__syncthreads();
}
};
template <typename ParamsT, typename ReductionT>
struct FindSplitEnactorMultiscan {
typedef cub::BlockScan<BitFlagSet, ParamsT::BLOCK_THREADS> FlagsBlockScanT;
typedef cub::WarpReduce<Split> WarpSplitReduceT;
typedef cub::WarpReduce<float> WarpReduceT;
typedef cub::WarpScan<gpu_gpair> WarpScanT;
struct _TempStorage {
union {
typename WarpSplitReduceT::TempStorage warp_split_reduce;
typename FlagsBlockScanT::TempStorage flags_scan;
typename WarpScanT::TempStorage warp_gpair_scan[ParamsT::N_WARPS];
typename WarpReduceT::TempStorage warp_reduce[ParamsT::N_WARPS];
};
Split warp_best_splits[ParamsT::N_NODES][ParamsT::N_WARPS];
gpu_gpair partial_sums[ParamsT::N_NODES][ParamsT::N_WARPS];
gpu_gpair top_level_sum[ParamsT::N_NODES]; // Sum of current partial sums
gpu_gpair tile_carry[ParamsT::N_NODES]; // Contains top-level sums from
// previous tiles
Split best_splits[ParamsT::N_NODES];
// Cache current level nodes into shared memory
float node_root_gain[ParamsT::N_NODES];
gpu_gpair node_parent_sum[ParamsT::N_NODES];
};
struct TempStorage : cub::Uninitialized<_TempStorage> {};
// Thread local member variables
const Item *d_items;
Split *d_split_candidates_out;
const NodeIdT *d_node_id;
const Node *d_nodes;
_TempStorage &temp_storage;
Item item;
NodeIdT node_id;
NodeIdT node_id_adjusted;
const NodeIdT node_begin;
const GPUTrainingParam &param;
const ReductionT &reduction;
const int level;
FlagPrefixCallbackOp flag_prefix_op;
__device__ __forceinline__ FindSplitEnactorMultiscan(
TempStorage &temp_storage, const Item *d_items, // NOLINT
Split *d_split_candidates_out, const NodeIdT *d_node_id,
const Node *d_nodes, const NodeIdT node_begin,
const GPUTrainingParam &param, const ReductionT reduction,
const int level)
: temp_storage(temp_storage.Alias()), d_items(d_items),
d_split_candidates_out(d_split_candidates_out), d_node_id(d_node_id),
d_nodes(d_nodes), node_begin(node_begin), param(param),
reduction(reduction), level(level), flag_prefix_op() {}
__device__ __forceinline__ void UpdateTileCarry() {
if (threadIdx.x < ParamsT::N_NODES) {
temp_storage.tile_carry[threadIdx.x] +=
temp_storage.top_level_sum[threadIdx.x];
}
}
__device__ __forceinline__ void ResetTileCarry() {
if (threadIdx.x < ParamsT::N_NODES) {
temp_storage.tile_carry[threadIdx.x] = gpu_gpair();
}
}
__device__ __forceinline__ void ResetPartials() {
if (threadIdx.x < ParamsT::N_WARPS) {
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
temp_storage.partial_sums[NODE][threadIdx.x] = gpu_gpair();
}
}
if (threadIdx.x < ParamsT::N_NODES) {
temp_storage.top_level_sum[threadIdx.x] = gpu_gpair();
}
}
__device__ __forceinline__ void ResetSplits() {
if (threadIdx.x < ParamsT::N_WARPS) {
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
temp_storage.warp_best_splits[NODE][threadIdx.x] = Split();
}
}
if (threadIdx.x < ParamsT::N_NODES) {
temp_storage.best_splits[threadIdx.x] = Split();
}
}
// Cache d_nodes array for this level into shared memory
__device__ __forceinline__ void CacheNodes() {
// Get pointer to nodes on the current level
const Node *d_nodes_level = d_nodes + node_begin;
if (threadIdx.x < ParamsT::N_NODES) {
temp_storage.node_root_gain[threadIdx.x] =
d_nodes_level[threadIdx.x].root_gain;
temp_storage.node_parent_sum[threadIdx.x] =
d_nodes_level[threadIdx.x].sum_gradients;
}
}
__device__ __forceinline__ void LoadTile(bst_uint offset,
bst_uint num_remaining) {
bst_uint index = offset + threadIdx.x;
if (threadIdx.x < num_remaining) {
item = d_items[index];
node_id = d_node_id[index];
node_id_adjusted = node_id - node_begin;
} else {
node_id = -1;
node_id_adjusted = -1;
item.fvalue = -FLT_MAX;
item.gpair = gpu_gpair();
}
}
// Is this node being processed by current kernel iteration?
__device__ __forceinline__ bool NodeActive() {
return node_id_adjusted < ParamsT::N_NODES && node_id_adjusted >= 0;
}
// Is current fvalue different from left fvalue
__device__ __forceinline__ bool
LeftMostFvalue(const bst_uint &segment_begin, const bst_uint &offset,
const bst_uint &num_remaining) {
int left_index = offset + threadIdx.x - 1;
float left_fvalue = left_index >= static_cast<int>(segment_begin) &&
threadIdx.x < num_remaining
? d_items[left_index].fvalue
: -FLT_MAX;
return left_fvalue != item.fvalue;
}
// Prevent splitting in the middle of same valued instances
__device__ __forceinline__ bool
CheckSplitValid(const bst_uint &segment_begin, const bst_uint &offset,
const bst_uint &num_remaining) {
BitFlagSet bit_flag = 0;
bool valid_split = false;
if (LeftMostFvalue(segment_begin, offset, num_remaining)) {
valid_split = true;
// Use MSB bit to flag if fvalue is leftmost
set_bit(bit_flag, 63);
}
// Flag nodeid
if (NodeActive()) {
set_bit(bit_flag, node_id_adjusted);
}
FlagsBlockScanT(temp_storage.flags_scan)
.ExclusiveScan(bit_flag, bit_flag, FlagScanOp(), flag_prefix_op);
__syncthreads();
if (!valid_split && NodeActive()) {
if (!check_bit(bit_flag, node_id_adjusted)) {
valid_split = true;
}
}
return valid_split;
}
// Perform warp reduction to find if this lane contains best loss_chg in warp
__device__ __forceinline__ bool QueryLaneBestLoss(const float &loss_chg) {
int lane_id = threadIdx.x % 32;
int warp_id = threadIdx.x / 32;
// Possible source of bugs. Not all threads in warp are active here. Not
// sure if reduce function will behave correctly
float best = WarpReduceT(temp_storage.warp_reduce[warp_id])
.Reduce(loss_chg, cub::Max());
// Its possible for more than one lane to contain the best value, so make
// sure only one lane returns true
unsigned int ballot = __ballot(loss_chg == best);
if (lane_id == (__ffs(ballot) - 1)) {
return true;
} else {
return false;
}
}
// Which thread in this warp should update the current best split, if any
// Returns true for one thread or none
__device__ __forceinline__ bool
QueryUpdateWarpSplit(const float &loss_chg,
volatile const float &warp_best_loss) {
bool update = false;
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
bool active = node_id_adjusted == NODE;
unsigned int ballot = __ballot(loss_chg > warp_best_loss && active);
// No lane has improved loss_chg
if (__popc(ballot) == 0) {
continue;
} else if (__popc(ballot) == 1) {
// A single lane has improved loss_chg, set true for this lane
int lane_id = threadIdx.x % 32;
if (lane_id == __ffs(ballot) - 1) {
update = true;
}
} else {
// More than one lane has improved loss_chg, perform a reduction.
if (QueryLaneBestLoss(active ? loss_chg : -FLT_MAX)) {
update = true;
}
}
}
return update;
}
__device__ void PrintTileScan(int block_id, bool thread_active,
float loss_chg, gpu_gpair missing) {
if (blockIdx.x != block_id) {
return;
}
for (int warp = 0; warp < ParamsT::N_WARPS; warp++) {
if (threadIdx.x / 32 == warp) {
for (int lane = 0; lane < 32; lane++) {
gpu_gpair g = cub::ShuffleIndex(item.gpair, lane);
gpu_gpair missing_broadcast = cub::ShuffleIndex(missing, lane);
float fvalue_broadcast = __shfl(item.fvalue, lane);
bool thread_active_broadcast = __shfl(thread_active, lane);
float loss_chg_broadcast = __shfl(loss_chg, lane);
NodeIdT node_id_broadcast = cub::ShuffleIndex(node_id, lane);
if (threadIdx.x == 32 * warp) {
printf("tid %d, nid %d, fvalue %1.2f, active %c, loss %1.2f, scan ",
threadIdx.x + lane, node_id_broadcast, fvalue_broadcast,
thread_active_broadcast ? 'y' : 'n',
loss_chg_broadcast < 0.0f ? 0 : loss_chg_broadcast);
g.print();
}
}
}
__syncthreads();
}
}
__device__ __forceinline__ void
EvaluateSplits(const bst_uint &segment_begin, const bst_uint &offset,
const bst_uint &num_remaining) {
bool valid_split = CheckSplitValid(segment_begin, offset, num_remaining);
bool thread_active =
NodeActive() && valid_split && threadIdx.x < num_remaining;
const int warp_id = threadIdx.x / 32;
gpu_gpair parent_sum = thread_active
? temp_storage.node_parent_sum[node_id_adjusted]
: gpu_gpair();
float parent_gain =
thread_active ? temp_storage.node_root_gain[node_id_adjusted] : 0.0f;
gpu_gpair missing = thread_active
? parent_sum - reduction.node_sums[node_id_adjusted]
: gpu_gpair();
bool missing_left;
float loss_chg = thread_active
? loss_chg_missing(item.gpair, missing, parent_sum,
parent_gain, param, missing_left)
: -FLT_MAX;
// PrintTileScan(64, thread_active, loss_chg, missing);
float warp_best_loss =
thread_active
? temp_storage.warp_best_splits[node_id_adjusted][warp_id].loss_chg
: 0.0f;
if (QueryUpdateWarpSplit(loss_chg, warp_best_loss)) {
float fvalue_split = item.fvalue - FVALUE_EPS;
if (missing_left) {
gpu_gpair left_sum = missing + item.gpair;
gpu_gpair right_sum = parent_sum - left_sum;
temp_storage.warp_best_splits[node_id_adjusted][warp_id].Update(
loss_chg, missing_left, fvalue_split, blockIdx.x, left_sum,
right_sum, param);
} else {
gpu_gpair left_sum = item.gpair;
gpu_gpair right_sum = parent_sum - left_sum;
temp_storage.warp_best_splits[node_id_adjusted][warp_id].Update(
loss_chg, missing_left, fvalue_split, blockIdx.x, left_sum,
right_sum, param);
}
}
}
/*
__device__ __forceinline__ void WarpExclusiveScan(bool active, gpu_gpair
input, gpu_gpair &output, gpu_gpair &sum)
{
output = input;
for (int offset = 1; offset < 32; offset <<= 1){
float tmp1 = __shfl_up(output.grad(), offset);
float tmp2 = __shfl_up(output.hess(), offset);
if (cub::LaneId() >= offset)
{
output.grad += tmp1;
output.hess += tmp2;
}
}
sum.grad = __shfl(output.grad, 31);
sum.hess = __shfl(output.hess, 31);
output -= input;
}
*/
__device__ __forceinline__ void BlockExclusiveScan() {
ResetPartials();
__syncthreads();
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
bool node_active = node_id_adjusted == NODE;
unsigned int ballot = __ballot(node_active);
gpu_gpair warp_sum = gpu_gpair();
gpu_gpair scan_result = gpu_gpair();
if (ballot > 0) {
WarpScanT(temp_storage.warp_gpair_scan[warp_id])
.InclusiveScan(node_active ? item.gpair : gpu_gpair(), scan_result,
cub::Sum(), warp_sum);
// WarpExclusiveScan( node_active, node_active ? item.gpair :
// gpu_gpair(), scan_result, warp_sum);
}
if (node_active) {
item.gpair = scan_result - item.gpair;
}
if (lane_id == 0) {
temp_storage.partial_sums[NODE][warp_id] = warp_sum;
}
}
__syncthreads();
if (threadIdx.x < 32) {
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
gpu_gpair top_level_sum;
bool warp_active = threadIdx.x < ParamsT::N_WARPS;
gpu_gpair scan_result;
WarpScanT(temp_storage.warp_gpair_scan[warp_id])
.InclusiveScan(warp_active
? temp_storage.partial_sums[NODE][threadIdx.x]
: gpu_gpair(),
scan_result, cub::Sum(), top_level_sum);
if (warp_active) {
temp_storage.partial_sums[NODE][threadIdx.x] =
scan_result - temp_storage.partial_sums[NODE][threadIdx.x];
}
if (threadIdx.x == 0) {
temp_storage.top_level_sum[NODE] = top_level_sum;
}
}
}
__syncthreads();
if (NodeActive()) {
item.gpair += temp_storage.partial_sums[node_id_adjusted][warp_id] +
temp_storage.tile_carry[node_id_adjusted];
}
__syncthreads();
UpdateTileCarry();
__syncthreads();
}
__device__ __forceinline__ void ProcessTile(const bst_uint &segment_begin,
const bst_uint &offset,
const bst_uint &num_remaining) {
LoadTile(offset, num_remaining);
BlockExclusiveScan();
EvaluateSplits(segment_begin, offset, num_remaining);
}
__device__ __forceinline__ void ReduceSplits() {
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
if (threadIdx.x < 32) {
Split s = Split();
if (threadIdx.x < ParamsT::N_WARPS) {
s = temp_storage.warp_best_splits[NODE][threadIdx.x];
}
Split best = WarpSplitReduceT(temp_storage.warp_split_reduce)
.Reduce(s, split_reduce_op());
if (threadIdx.x == 0) {
temp_storage.best_splits[NODE] = best;
}
}
}
}
__device__ __forceinline__ void WriteBestSplits() {
const int nodes_level = 1 << level;
if (threadIdx.x < ParamsT::N_NODES) {
d_split_candidates_out[blockIdx.x * nodes_level + threadIdx.x] =
temp_storage.best_splits[threadIdx.x];
}
}
/*
__device__ void SequentialAlgorithm(bst_uint segment_begin,
bst_uint segment_end) {
if (threadIdx.x != 0) {
return;
}
__shared__ Split best_split[ParamsT::N_NODES];
__shared__ gpu_gpair scan[ParamsT::N_NODES];
__shared__ Node nodes[ParamsT::N_NODES];
__shared__ gpu_gpair missing[ParamsT::N_NODES];
float previous_fvalue[ParamsT::N_NODES];
// Initialise counts
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
best_split[NODE] = Split();
scan[NODE] = gpu_gpair();
nodes[NODE] = d_nodes[node_begin + NODE];
missing[NODE] = nodes[NODE].sum_gradients - reduction.node_sums[NODE];
previous_fvalue[NODE] = FLT_MAX;
}
for (bst_uint i = segment_begin; i < segment_end; i++) {
int8_t nodeid_adjusted = d_node_id[i] - node_begin;
float fvalue = d_items[i].fvalue;
if (NodeActive(nodeid_adjusted)) {
if (fvalue != previous_fvalue[nodeid_adjusted]) {
float f_split;
if (previous_fvalue[nodeid_adjusted] != FLT_MAX) {
f_split = (previous_fvalue[nodeid_adjusted] + fvalue) * 0.5;
} else {
f_split = fvalue;
}
best_split[nodeid_adjusted].UpdateCalcLoss(
f_split, scan[nodeid_adjusted], missing[nodeid_adjusted],
nodes[nodeid_adjusted], param);
}
scan[nodeid_adjusted] += d_items[i].gpair;
previous_fvalue[nodeid_adjusted] = fvalue;
}
}
for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) {
temp_storage.best_splits[NODE] = best_split[NODE];
}
}
*/
__device__ __forceinline__ void ResetSplitCandidates() {
const int max_nodes = 1 << level;
const int begin = blockIdx.x * max_nodes;
const int end = begin + max_nodes;
for (auto i : block_stride_range(begin, end)) {
d_split_candidates_out[i] = Split();
}
}
__device__ __forceinline__ void ProcessRegion(const bst_uint &segment_begin,
const bst_uint &segment_end) {
// Current position
bst_uint offset = segment_begin;
ResetSplitCandidates();
ResetTileCarry();
ResetSplits();
CacheNodes();
__syncthreads();
// Process full tiles
while (offset < segment_end) {
ProcessTile(segment_begin, offset, segment_end - offset);
__syncthreads();
offset += ParamsT::TILE_ITEMS;
}
__syncthreads();
ReduceSplits();
__syncthreads();
WriteBestSplits();
}
};
template <typename FindSplitParamsT, typename ReduceParamsT>
__global__ void
#if __CUDA_ARCH__ <= 530
__launch_bounds__(1024, 2)
#endif
find_split_candidates_multiscan_kernel(
const Item *d_items, Split *d_split_candidates_out,
const NodeIdT *d_node_id, const Node *d_nodes, const int node_begin,
bst_uint num_items, int num_features, const int *d_feature_offsets,
const GPUTrainingParam param, const int level) {
if (num_items <= 0) {
return;
}
int segment_begin = d_feature_offsets[blockIdx.x];
int segment_end = d_feature_offsets[blockIdx.x + 1];
typedef ReduceEnactorMultiscan<ReduceParamsT> ReduceT;
typedef FindSplitEnactorMultiscan<FindSplitParamsT,
typename ReduceT::_Reduction>
FindSplitT;
__shared__ union {
typename ReduceT::TempStorage reduce;
typename FindSplitT::TempStorage find_split;
} temp_storage;
__shared__ typename ReduceT::Reduction reduction;
ReduceT(temp_storage.reduce, reduction, d_items, d_node_id, node_begin)
.ProcessRegion(segment_begin, segment_end);
__syncthreads();
FindSplitT find_split(temp_storage.find_split, d_items,
d_split_candidates_out, d_node_id, d_nodes, node_begin,
param, reduction.Alias(), level);
find_split.ProcessRegion(segment_begin, segment_end);
}
template <int N_NODES>
void find_split_candidates_multiscan_variation(
const Item *d_items, Split *d_split_candidates, const NodeIdT *d_node_id,
const Node *d_nodes, int node_begin, int node_end, bst_uint num_items,
int num_features, const int *d_feature_offsets,
const GPUTrainingParam param, const int level) {
const int BLOCK_THREADS = 512;
CHECK((node_end - node_begin) <= N_NODES) << "Multiscan: N_NODES template "
"parameter too small for given "
"node range.";
CHECK(BLOCK_THREADS / 32 < 32)
<< "Too many active warps. See FindSplitEnactor - ReduceSplits.";
typedef FindSplitParamsMultiscan<BLOCK_THREADS, N_NODES, false>
find_split_params;
typedef ReduceParamsMultiscan<BLOCK_THREADS, N_NODES, false> reduce_params;
int grid_size = num_features;
find_split_candidates_multiscan_kernel<
find_split_params,
reduce_params><<<grid_size, find_split_params::BLOCK_THREADS>>>(
d_items, d_split_candidates, d_node_id, d_nodes, node_begin, num_items,
num_features, d_feature_offsets, param, level);
safe_cuda(cudaDeviceSynchronize());
}
void find_split_candidates_multiscan(
const Item *d_items, Split *d_split_candidates, const NodeIdT *d_node_id,
const Node *d_nodes, bst_uint num_items, int num_features,
const int *d_feature_offsets, const GPUTrainingParam param,
const int level) {
// Select templated variation of split finding algorithm
switch (level) {
case 0:
find_split_candidates_multiscan_variation<1>(
d_items, d_split_candidates, d_node_id, d_nodes, 0, 1, num_items,
num_features, d_feature_offsets, param, level);
break;
case 1:
find_split_candidates_multiscan_variation<2>(
d_items, d_split_candidates, d_node_id, d_nodes, 1, 3, num_items,
num_features, d_feature_offsets, param, level);
break;
case 2:
find_split_candidates_multiscan_variation<4>(
d_items, d_split_candidates, d_node_id, d_nodes, 3, 7, num_items,
num_features, d_feature_offsets, param, level);
break;
case 3:
find_split_candidates_multiscan_variation<8>(
d_items, d_split_candidates, d_node_id, d_nodes, 7, 15, num_items,
num_features, d_feature_offsets, param, level);
break;
case 4:
find_split_candidates_multiscan_variation<16>(
d_items, d_split_candidates, d_node_id, d_nodes, 15, 31, num_items,
num_features, d_feature_offsets, param, level);
break;
case 5:
find_split_candidates_multiscan_variation<32>(
d_items, d_split_candidates, d_node_id, d_nodes, 31, 63, num_items,
num_features, d_feature_offsets, param, level);
break;
}
}
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,474 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include <cub/cub.cuh>
#include <xgboost/base.h>
#include "cuda_helpers.cuh"
#include "types_functions.cuh"
namespace xgboost {
namespace tree {
struct ScanTuple {
gpu_gpair gpair;
NodeIdT node_id;
__device__ ScanTuple() {}
__device__ ScanTuple(gpu_gpair gpair, NodeIdT node_id)
: gpair(gpair), node_id(node_id) {}
__device__ ScanTuple operator+=(const ScanTuple &rhs) {
if (node_id != rhs.node_id) {
*this = rhs;
return *this;
} else {
gpair += rhs.gpair;
return *this;
}
}
__device__ ScanTuple operator+(const ScanTuple &rhs) const {
ScanTuple t = *this;
return t += rhs;
}
};
struct GpairTupleCallbackOp {
// Running prefix
ScanTuple running_total;
// Constructor
__device__ GpairTupleCallbackOp()
: running_total(ScanTuple(gpu_gpair(), -1)) {}
__device__ ScanTuple operator()(ScanTuple block_aggregate) {
ScanTuple old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
struct GpairCallbackOp {
// Running prefix
gpu_gpair running_total;
// Constructor
__device__ GpairCallbackOp() : running_total(gpu_gpair()) {}
__device__ gpu_gpair operator()(gpu_gpair block_aggregate) {
gpu_gpair old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
template <int _BLOCK_THREADS, bool _DEBUG_VALIDATE>
struct FindSplitParamsSorting {
enum {
BLOCK_THREADS = _BLOCK_THREADS,
TILE_ITEMS = BLOCK_THREADS,
N_WARPS = _BLOCK_THREADS / 32,
DEBUG_VALIDATE = _DEBUG_VALIDATE,
ITEMS_PER_THREAD = 1
};
};
template <int _BLOCK_THREADS, bool _DEBUG_VALIDATE> struct ReduceParamsSorting {
enum {
BLOCK_THREADS = _BLOCK_THREADS,
ITEMS_PER_THREAD = 1,
TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD,
N_WARPS = _BLOCK_THREADS / 32,
DEBUG_VALIDATE = _DEBUG_VALIDATE
};
};
template <typename ParamsT> struct ReduceEnactorSorting {
typedef cub::BlockScan<ScanTuple, ParamsT::BLOCK_THREADS> GpairScanT;
struct _TempStorage {
typename GpairScanT::TempStorage gpair_scan;
};
struct TempStorage : cub::Uninitialized<_TempStorage> {};
// Thread local member variables
gpu_gpair *d_block_node_sums;
int *d_block_node_offsets;
const NodeIdT *d_node_id;
const Item *d_items;
_TempStorage &temp_storage;
Item item;
NodeIdT node_id;
NodeIdT right_node_id;
// Contains node_id relative to the current level only
NodeIdT node_id_adjusted;
GpairTupleCallbackOp callback_op;
const int level;
__device__ __forceinline__ ReduceEnactorSorting(
TempStorage &temp_storage, // NOLINT
gpu_gpair *d_block_node_sums, int *d_block_node_offsets,
const Item *d_items, const NodeIdT *d_node_id, const int level)
: temp_storage(temp_storage.Alias()),
d_block_node_sums(d_block_node_sums),
d_block_node_offsets(d_block_node_offsets), d_items(d_items),
d_node_id(d_node_id), callback_op(), level(level) {}
__device__ __forceinline__ void ResetSumsOffsets() {
const int max_nodes = 1 << level;
for (auto i : block_stride_range(0, max_nodes)) {
d_block_node_sums[i] = gpu_gpair();
d_block_node_offsets[i] = -1;
}
}
__device__ __forceinline__ void LoadTile(const bst_uint &offset,
const bst_uint &num_remaining) {
if (threadIdx.x < num_remaining) {
item = d_items[offset + threadIdx.x];
node_id = d_node_id[offset + threadIdx.x];
right_node_id = threadIdx.x == num_remaining - 1
? -1
: d_node_id[offset + threadIdx.x + 1];
// Prevent overflow
const int level_begin = (1 << level) - 1;
node_id_adjusted =
max(static_cast<int>(node_id) - level_begin, -1); // NOLINT
}
}
__device__ __forceinline__ void ProcessTile(const bst_uint &offset,
const bst_uint &num_remaining) {
LoadTile(offset, num_remaining);
ScanTuple t(item.gpair, node_id);
GpairScanT(temp_storage.gpair_scan).InclusiveSum(t, t, callback_op);
__syncthreads();
// If tail of segment
if (node_id != right_node_id && node_id_adjusted >= 0 &&
threadIdx.x < num_remaining) {
// Write sum
d_block_node_sums[node_id_adjusted] = t.gpair;
// Write offset
d_block_node_offsets[node_id_adjusted] = offset + threadIdx.x + 1;
}
}
__device__ __forceinline__ void ProcessRegion(const bst_uint &segment_begin,
const bst_uint &segment_end) {
// Current position
bst_uint offset = segment_begin;
ResetSumsOffsets();
__syncthreads();
// Process full tiles
while (offset < segment_end) {
ProcessTile(offset, segment_end - offset);
offset += ParamsT::TILE_ITEMS;
}
}
};
template <typename ParamsT> struct FindSplitEnactorSorting {
typedef cub::BlockScan<gpu_gpair, ParamsT::BLOCK_THREADS> GpairScanT;
typedef cub::BlockReduce<Split, ParamsT::BLOCK_THREADS> SplitReduceT;
typedef cub::WarpReduce<float> WarpLossReduceT;
struct _TempStorage {
union {
typename GpairScanT::TempStorage gpair_scan;
typename SplitReduceT::TempStorage split_reduce;
typename WarpLossReduceT::TempStorage loss_reduce[ParamsT::N_WARPS];
};
Split warp_best_splits[ParamsT::N_WARPS];
};
struct TempStorage : cub::Uninitialized<_TempStorage> {};
// Thread local member variables
_TempStorage &temp_storage;
gpu_gpair *d_block_node_sums;
int *d_block_node_offsets;
const Item *d_items;
const NodeIdT *d_node_id;
const Node *d_nodes;
Item item;
NodeIdT node_id;
float left_fvalue;
const GPUTrainingParam &param;
Split *d_split_candidates_out;
const int level;
__device__ __forceinline__ FindSplitEnactorSorting(
TempStorage &temp_storage, gpu_gpair *d_block_node_sums, // NOLINT
int *d_block_node_offsets, const Item *d_items, const NodeIdT *d_node_id,
const Node *d_nodes, const GPUTrainingParam &param,
Split *d_split_candidates_out, const int level)
: temp_storage(temp_storage.Alias()),
d_block_node_sums(d_block_node_sums),
d_block_node_offsets(d_block_node_offsets), d_items(d_items),
d_node_id(d_node_id), d_nodes(d_nodes),
d_split_candidates_out(d_split_candidates_out), level(level),
param(param) {}
__device__ __forceinline__ void LoadTile(NodeIdT node_id_adjusted,
const bst_uint &node_begin,
const bst_uint &offset,
const bst_uint &num_remaining) {
if (threadIdx.x < num_remaining) {
node_id = d_node_id[offset + threadIdx.x];
item = d_items[offset + threadIdx.x];
bool first_item = offset + threadIdx.x == node_begin;
left_fvalue = first_item ? item.fvalue - FVALUE_EPS
: d_items[offset + threadIdx.x - 1].fvalue;
}
}
__device__ void PrintTileScan(int block_id, bool thread_active,
float loss_chg, gpu_gpair missing) {
if (blockIdx.x != block_id) {
return;
}
for (int warp = 0; warp < ParamsT::N_WARPS; warp++) {
if (threadIdx.x / 32 == warp) {
for (int lane = 0; lane < 32; lane++) {
gpu_gpair g = cub::ShuffleIndex(item.gpair, lane);
gpu_gpair missing_broadcast = cub::ShuffleIndex(missing, lane);
float fvalue_broadcast = __shfl(item.fvalue, lane);
bool thread_active_broadcast = __shfl(thread_active, lane);
float loss_chg_broadcast = __shfl(loss_chg, lane);
if (threadIdx.x == 32 * warp) {
printf("tid %d, fvalue %1.2f, active %c, loss %1.2f, scan ",
threadIdx.x + lane, fvalue_broadcast,
thread_active_broadcast ? 'y' : 'n',
loss_chg_broadcast < 0.0f ? 0 : loss_chg_broadcast);
g.print();
}
}
}
__syncthreads();
}
}
__device__ __forceinline__ bool QueryUpdateWarpSplit(float loss_chg,
float warp_best_loss,
bool thread_active) {
int warp_id = threadIdx.x / 32;
int ballot = __ballot(loss_chg > warp_best_loss && thread_active);
if (ballot == 0) {
return false;
} else {
// Warp reduce best loss
float best = WarpLossReduceT(temp_storage.loss_reduce[warp_id])
.Reduce(loss_chg, cub::Max());
// Broadcast
best = cub::ShuffleIndex(best, 0);
if (loss_chg == best) {
return true;
}
}
return false;
}
__device__ __forceinline__ bool LeftmostFvalue() {
return item.fvalue != left_fvalue;
}
__device__ __forceinline__ void
EvaluateSplits(const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
const bst_uint &offset, const bst_uint &num_remaining) {
bool thread_active = LeftmostFvalue() && threadIdx.x < num_remaining &&
node_id_adjusted >= 0 && node_id >= 0;
Node n = thread_active ? d_nodes[node_id] : Node();
gpu_gpair missing =
thread_active ? n.sum_gradients - d_block_node_sums[node_id_adjusted]
: gpu_gpair();
bool missing_left;
float loss_chg =
thread_active ? loss_chg_missing(item.gpair, missing, n.sum_gradients,
n.root_gain, param, missing_left)
: -FLT_MAX;
int warp_id = threadIdx.x / 32;
volatile float warp_best_loss =
temp_storage.warp_best_splits[warp_id].loss_chg;
if (QueryUpdateWarpSplit(loss_chg, warp_best_loss, thread_active)) {
float fvalue_split = (item.fvalue + left_fvalue) / 2.0f;
gpu_gpair left_sum = item.gpair;
if (missing_left) {
left_sum += missing;
}
gpu_gpair right_sum = n.sum_gradients - left_sum;
temp_storage.warp_best_splits[warp_id].Update(loss_chg, missing_left,
fvalue_split, blockIdx.x,
left_sum, right_sum, param);
}
}
__device__ __forceinline__ void
ProcessTile(const NodeIdT &node_id_adjusted, const bst_uint &node_begin,
const bst_uint &offset, const bst_uint &num_remaining,
GpairCallbackOp &callback_op) { // NOLINT
LoadTile(node_id_adjusted, node_begin, offset, num_remaining);
// Scan gpair
const bool thread_active = threadIdx.x < num_remaining && node_id >= 0;
GpairScanT(temp_storage.gpair_scan)
.ExclusiveSum(thread_active ? item.gpair : gpu_gpair(), item.gpair,
callback_op);
__syncthreads();
// Evaluate split
EvaluateSplits(node_id_adjusted, node_begin, offset, num_remaining);
}
__device__ __forceinline__ void ResetWarpSplits() {
if (threadIdx.x < ParamsT::N_WARPS) {
temp_storage.warp_best_splits[threadIdx.x] = Split();
}
}
__device__ __forceinline__ void
WriteBestSplit(const NodeIdT &node_id_adjusted) {
if (threadIdx.x < 32) {
bool active = threadIdx.x < ParamsT::N_WARPS;
float warp_loss =
active ? temp_storage.warp_best_splits[threadIdx.x].loss_chg
: -FLT_MAX;
if (QueryUpdateWarpSplit(warp_loss, 0, active)) {
const int max_nodes = 1 << level;
d_split_candidates_out[blockIdx.x * max_nodes + node_id_adjusted] =
temp_storage.warp_best_splits[threadIdx.x];
}
}
}
__device__ __forceinline__ void ProcessNode(const NodeIdT &node_id_adjusted,
const bst_uint &node_begin,
const bst_uint &node_end) {
ResetWarpSplits();
GpairCallbackOp callback_op = GpairCallbackOp();
bst_uint offset = node_begin;
while (offset < node_end) {
ProcessTile(node_id_adjusted, node_begin, offset, node_end - offset,
callback_op);
offset += ParamsT::TILE_ITEMS;
__syncthreads();
}
WriteBestSplit(node_id_adjusted);
}
__device__ __forceinline__ void ResetSplitCandidates() {
const int max_nodes = 1 << level;
const int begin = blockIdx.x * max_nodes;
const int end = begin + max_nodes;
for (auto i : block_stride_range(begin, end)) {
d_split_candidates_out[i] = Split();
}
}
__device__ __forceinline__ void ProcessFeature(const bst_uint &segment_begin,
const bst_uint &segment_end) {
ResetSplitCandidates();
int node_begin = segment_begin;
const int max_nodes = 1 << level;
// Iterate through nodes
int active_nodes = 0;
for (int i = 0; i < max_nodes; i++) {
int node_end = d_block_node_offsets[i];
if (node_end == -1) {
continue;
}
active_nodes++;
ProcessNode(i, node_begin, node_end);
__syncthreads();
node_begin = node_end;
}
}
};
template <typename ReduceParamsT, typename FindSplitParamsT>
__global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel(
const Item *d_items, Split *d_split_candidates_out,
const NodeIdT *d_node_id, const Node *d_nodes, bst_uint num_items,
const int num_features, const int *d_feature_offsets,
gpu_gpair *d_node_sums, int *d_node_offsets, const GPUTrainingParam param,
const int level) {
if (num_items <= 0) {
return;
}
bst_uint segment_begin = d_feature_offsets[blockIdx.x];
bst_uint segment_end = d_feature_offsets[blockIdx.x + 1];
typedef ReduceEnactorSorting<ReduceParamsT> ReduceT;
typedef FindSplitEnactorSorting<FindSplitParamsT> FindSplitT;
__shared__ union {
typename ReduceT::TempStorage reduce;
typename FindSplitT::TempStorage find_split;
} temp_storage;
const int max_modes_level = 1 << level;
gpu_gpair *d_block_node_sums = d_node_sums + blockIdx.x * max_modes_level;
int *d_block_node_offsets = d_node_offsets + blockIdx.x * max_modes_level;
ReduceT(temp_storage.reduce, d_block_node_sums, d_block_node_offsets, d_items,
d_node_id, level)
.ProcessRegion(segment_begin, segment_end);
__syncthreads();
FindSplitT(temp_storage.find_split, d_block_node_sums, d_block_node_offsets,
d_items, d_node_id, d_nodes, param, d_split_candidates_out, level)
.ProcessFeature(segment_begin, segment_end);
}
void find_split_candidates_sorted(
const Item *d_items, Split *d_split_candidates, const NodeIdT *d_node_id,
Node *d_nodes, bst_uint num_items, int num_features,
const int *d_feature_offsets, gpu_gpair *d_node_sums, int *d_node_offsets,
const GPUTrainingParam param, const int level) {
const int BLOCK_THREADS = 512;
CHECK(BLOCK_THREADS / 32 < 32) << "Too many active warps.";
typedef FindSplitParamsSorting<BLOCK_THREADS, false> find_split_params;
typedef ReduceParamsSorting<BLOCK_THREADS, false> reduce_params;
int grid_size = num_features;
find_split_candidates_sorted_kernel<
reduce_params, find_split_params><<<grid_size, BLOCK_THREADS>>>(
d_items, d_split_candidates, d_node_id, d_nodes, num_items, num_features,
d_feature_offsets, d_node_sums, d_node_offsets, param, level);
safe_cuda(cudaGetLastError());
safe_cuda(cudaDeviceSynchronize());
}
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,457 @@
/*!
* Copyright 2016 Rory mitchell
*/
#include "gpu_builder.cuh"
#include <stdio.h>
#include <thrust/count.h>
#include <thrust/device_vector.h>
#include <thrust/gather.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>
#include <cub/cub.cuh>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <vector>
#include "cuda_helpers.cuh"
#include "find_split.cuh"
#include "types_functions.cuh"
namespace xgboost {
namespace tree {
struct GPUData {
GPUData() : allocated(false), n_features(0), n_instances(0) {}
bool allocated;
int n_features;
int n_instances;
GPUTrainingParam param;
CubMemory cub_mem;
thrust::device_vector<float> fvalues;
thrust::device_vector<int> foffsets;
thrust::device_vector<bst_uint> instance_id;
thrust::device_vector<int> feature_id;
thrust::device_vector<NodeIdT> node_id;
thrust::device_vector<NodeIdT> node_id_temp;
thrust::device_vector<NodeIdT> node_id_instance;
thrust::device_vector<NodeIdT> node_id_instance_temp;
thrust::device_vector<gpu_gpair> gpair;
thrust::device_vector<Node> nodes;
thrust::device_vector<Split> split_candidates;
thrust::device_vector<Item> items;
thrust::device_vector<Item> items_temp;
thrust::device_vector<gpu_gpair> node_sums;
thrust::device_vector<int> node_offsets;
thrust::device_vector<int> sort_index_in;
thrust::device_vector<int> sort_index_out;
void Init(const std::vector<float> &in_fvalues,
const std::vector<int> &in_foffsets,
const std::vector<bst_uint> &in_instance_id,
const std::vector<int> &in_feature_id,
const std::vector<bst_gpair> &in_gpair, bst_uint n_instances_in,
bst_uint n_features_in, int max_depth, const TrainParam &param_in) {
Timer t;
n_features = n_features_in;
n_instances = n_instances_in;
fvalues = in_fvalues;
foffsets = in_foffsets;
instance_id = in_instance_id;
feature_id = in_feature_id;
param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
param_in.reg_alpha, param_in.max_delta_step);
gpair = thrust::device_vector<gpu_gpair>(in_gpair.begin(), in_gpair.end());
uint32_t max_nodes_level = 1 << max_depth;
node_sums = thrust::device_vector<gpu_gpair>(max_nodes_level * n_features);
node_offsets = thrust::device_vector<int>(max_nodes_level * n_features);
node_id_instance = thrust::device_vector<NodeIdT>(n_instances, 0);
node_id = thrust::device_vector<NodeIdT>(fvalues.size(), 0);
node_id_temp = thrust::device_vector<NodeIdT>(fvalues.size());
uint32_t max_nodes = (1 << (max_depth + 1)) - 1;
nodes = thrust::device_vector<Node>(max_nodes);
split_candidates =
thrust::device_vector<Split>(max_nodes_level * n_features);
allocated = true;
// Init items
items = thrust::device_vector<Item>(fvalues.size());
items_temp = thrust::device_vector<Item>(fvalues.size());
sort_index_in = thrust::device_vector<int>(fvalues.size());
sort_index_out = thrust::device_vector<int>(fvalues.size());
this->CreateItems();
}
~GPUData() {}
// Create items array using gpair, instaoce_id, fvalue
void CreateItems() {
auto d_items = items.data();
auto d_instance_id = instance_id.data();
auto d_gpair = gpair.data();
auto d_fvalue = fvalues.data();
auto counting = thrust::make_counting_iterator<bst_uint>(0);
thrust::for_each(counting, counting + fvalues.size(),
[=] __device__(bst_uint i) {
Item item;
item.instance_id = d_instance_id[i];
item.fvalue = d_fvalue[i];
item.gpair = d_gpair[item.instance_id];
d_items[i] = item;
});
}
// Reset memory for new boosting iteration
void Reset(const std::vector<bst_gpair> &in_gpair,
const std::vector<float> &in_fvalues,
const std::vector<bst_uint> &in_instance_id) {
CHECK(allocated);
thrust::copy(in_gpair.begin(), in_gpair.end(), gpair.begin());
thrust::fill(nodes.begin(), nodes.end(), Node());
thrust::fill(node_id_instance.begin(), node_id_instance.end(), 0);
thrust::fill(node_id.begin(), node_id.end(), 0);
this->CreateItems();
}
bool IsAllocated() { return allocated; }
// Gather from node_id_instance into node_id according to instance_id
void GatherNodeId() {
// Update node_id for each item
auto d_items = items.data();
auto d_node_id = node_id.data();
auto d_node_id_instance = node_id_instance.data();
auto counting = thrust::make_counting_iterator<bst_uint>(0);
thrust::for_each(counting, counting + fvalues.size(),
[=] __device__(bst_uint i) {
Item item = d_items[i];
d_node_id[i] = d_node_id_instance[item.instance_id];
});
}
};
GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); }
void GPUBuilder::Init(const TrainParam &param_in) { param = param_in; }
GPUBuilder::~GPUBuilder() { delete gpu_data; }
template <int ITEMS_PER_THREAD, typename OffsetT>
__global__ void update_nodeid_missing_kernel(NodeIdT *d_node_id_instance,
Node *d_nodes, const OffsetT n) {
for (auto i : grid_stride_range(OffsetT(0), n)) {
NodeIdT item_node_id = d_node_id_instance[i];
if (item_node_id < 0) {
continue;
}
Node node = d_nodes[item_node_id];
if (node.IsLeaf()) {
d_node_id_instance[i] = -1;
} else if (node.split.missing_left) {
d_node_id_instance[i] = item_node_id * 2 + 1;
} else {
d_node_id_instance[i] = item_node_id * 2 + 2;
}
}
}
__device__ void load_as_words(const int n_nodes, Node *d_nodes, Node *s_nodes) {
const int upper_range = n_nodes * (sizeof(Node) / sizeof(int));
for (auto i : block_stride_range(0, upper_range)) {
reinterpret_cast<int *>(s_nodes)[i] = reinterpret_cast<int *>(d_nodes)[i];
}
}
template <int ITEMS_PER_THREAD>
__global__ void
update_nodeid_fvalue_kernel(NodeIdT *d_node_id, NodeIdT *d_node_id_instance,
Item *d_items, Node *d_nodes, const int n_nodes,
const int *d_feature_id, const size_t n,
const int n_features, bool cache_nodes) {
// Load nodes into shared memory
extern __shared__ Node s_nodes[];
if (cache_nodes) {
load_as_words(n_nodes, d_nodes, s_nodes);
__syncthreads();
}
for (auto i : grid_stride_range(size_t(0), n)) {
Item item = d_items[i];
NodeIdT item_node_id = d_node_id[i];
if (item_node_id < 0) {
continue;
}
Node node = cache_nodes ? s_nodes[item_node_id] : d_nodes[item_node_id];
if (node.IsLeaf()) {
continue;
}
int feature_id = d_feature_id[i];
if (feature_id == node.split.findex) {
if (item.fvalue < node.split.fvalue) {
d_node_id_instance[item.instance_id] = item_node_id * 2 + 1;
} else {
d_node_id_instance[item.instance_id] = item_node_id * 2 + 2;
}
}
}
}
void GPUBuilder::UpdateNodeId(int level) {
// Update all nodes based on missing direction
{
const bst_uint n = gpu_data->node_id_instance.size();
const bst_uint ITEMS_PER_THREAD = 8;
const bst_uint BLOCK_THREADS = 256;
const bst_uint GRID_SIZE =
div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
update_nodeid_missing_kernel<
ITEMS_PER_THREAD><<<GRID_SIZE, BLOCK_THREADS>>>(
raw(gpu_data->node_id_instance), raw(gpu_data->nodes), n);
safe_cuda(cudaDeviceSynchronize());
}
// Update node based on fvalue where exists
{
const bst_uint n = gpu_data->fvalues.size();
const bst_uint ITEMS_PER_THREAD = 4;
const bst_uint BLOCK_THREADS = 256;
const bst_uint GRID_SIZE =
div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
// Use smem cache version if possible
const bool cache_nodes = level < 7;
int n_nodes = (1 << (level + 1)) - 1;
int smem_size = cache_nodes ? sizeof(Node) * n_nodes : 0;
update_nodeid_fvalue_kernel<
ITEMS_PER_THREAD><<<GRID_SIZE, BLOCK_THREADS, smem_size>>>(
raw(gpu_data->node_id), raw(gpu_data->node_id_instance),
raw(gpu_data->items), raw(gpu_data->nodes), n_nodes,
raw(gpu_data->feature_id), gpu_data->fvalues.size(),
gpu_data->n_features, cache_nodes);
safe_cuda(cudaGetLastError());
safe_cuda(cudaDeviceSynchronize());
}
gpu_data->GatherNodeId();
}
void GPUBuilder::Sort(int level) {
thrust::sequence(gpu_data->sort_index_in.begin(),
gpu_data->sort_index_in.end());
if (!gpu_data->cub_mem.IsAllocated()) {
cub::DeviceSegmentedRadixSort::SortPairs(
gpu_data->cub_mem.d_temp_storage, gpu_data->cub_mem.temp_storage_bytes,
raw(gpu_data->node_id), raw(gpu_data->node_id_temp),
raw(gpu_data->sort_index_in), raw(gpu_data->sort_index_out),
gpu_data->fvalues.size(), gpu_data->n_features, raw(gpu_data->foffsets),
raw(gpu_data->foffsets) + 1);
gpu_data->cub_mem.Allocate();
}
cub::DeviceSegmentedRadixSort::SortPairs(
gpu_data->cub_mem.d_temp_storage, gpu_data->cub_mem.temp_storage_bytes,
raw(gpu_data->node_id), raw(gpu_data->node_id_temp),
raw(gpu_data->sort_index_in), raw(gpu_data->sort_index_out),
gpu_data->fvalues.size(), gpu_data->n_features, raw(gpu_data->foffsets),
raw(gpu_data->foffsets) + 1);
thrust::gather(gpu_data->sort_index_out.begin(),
gpu_data->sort_index_out.end(), gpu_data->items.begin(),
gpu_data->items_temp.begin());
thrust::copy(gpu_data->items_temp.begin(), gpu_data->items_temp.end(),
gpu_data->items.begin());
thrust::copy(gpu_data->node_id_temp.begin(), gpu_data->node_id_temp.end(),
gpu_data->node_id.begin());
}
void GPUBuilder::Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
RegTree *p_tree) {
try {
Timer update;
Timer t;
this->InitData(gpair, *p_fmat, *p_tree);
t.printElapsed("init data");
this->InitFirstNode();
for (int level = 0; level < param.max_depth; level++) {
bool use_multiscan_algorithm = level < multiscan_levels;
t.reset();
if (level > 0) {
Timer update_node;
this->UpdateNodeId(level);
update_node.printElapsed("node");
}
if (level > 0 && !use_multiscan_algorithm) {
Timer s;
this->Sort(level);
s.printElapsed("sort");
}
Timer split;
find_split(raw(gpu_data->items), raw(gpu_data->split_candidates),
raw(gpu_data->node_id), raw(gpu_data->nodes),
(bst_uint)gpu_data->fvalues.size(), gpu_data->n_features,
raw(gpu_data->foffsets), raw(gpu_data->node_sums),
raw(gpu_data->node_offsets), gpu_data->param, level,
use_multiscan_algorithm);
split.printElapsed("split");
t.printElapsed("level");
}
this->CopyTree(*p_tree);
update.printElapsed("update");
} catch (thrust::system_error &e) {
std::cerr << "CUDA error: " << e.what() << std::endl;
exit(-1);
}
}
void GPUBuilder::InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,
const RegTree &tree) {
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
<< "ColMaker: can only grow new tree";
CHECK(fmat.SingleColBlock()) << "GPUMaker: must have single column block";
if (gpu_data->IsAllocated()) {
gpu_data->Reset(gpair, fvalues, instance_id);
return;
}
Timer t;
MetaInfo info = fmat.info();
dmlc::DataIter<ColBatch> *iter = fmat.ColIterator();
std::vector<int> foffsets;
foffsets.push_back(0);
std::vector<int> feature_id;
fvalues.reserve(info.num_col * info.num_row);
instance_id.reserve(info.num_col * info.num_row);
feature_id.reserve(info.num_col * info.num_row);
while (iter->Next()) {
const ColBatch &batch = iter->Value();
for (int i = 0; i < batch.size; i++) {
const ColBatch::Inst &col = batch[i];
for (const ColBatch::Entry *it = col.data; it != col.data + col.length;
it++) {
fvalues.push_back(it->fvalue);
instance_id.push_back(it->index);
feature_id.push_back(i);
}
foffsets.push_back(fvalues.size());
}
}
t.printElapsed("dmatrix");
t.reset();
gpu_data->Init(fvalues, foffsets, instance_id, feature_id, gpair,
info.num_row, info.num_col, param.max_depth, param);
t.printElapsed("gpu init");
}
void GPUBuilder::InitFirstNode() {
// Build the root node on the CPU and copy to device
gpu_gpair sum_gradients =
thrust::reduce(gpu_data->gpair.begin(), gpu_data->gpair.end(),
gpu_gpair(0, 0), cub::Sum());
gpu_data->nodes[0] = Node(
sum_gradients,
CalcGain(gpu_data->param, sum_gradients.grad(), sum_gradients.hess()),
CalcWeight(gpu_data->param, sum_gradients.grad(), sum_gradients.hess()));
}
enum NodeType {
NODE = 0,
LEAF = 1,
UNUSED = 2,
};
// Recursively label node types
void flag_nodes(const thrust::host_vector<Node> &nodes,
std::vector<NodeType> *node_flags, int nid, NodeType type) {
if (nid >= nodes.size() || type == UNUSED) {
return;
}
const Node &n = nodes[nid];
// Current node and all children are valid
if (n.split.loss_chg > rt_eps) {
(*node_flags)[nid] = NODE;
flag_nodes(nodes, node_flags, nid * 2 + 1, NODE);
flag_nodes(nodes, node_flags, nid * 2 + 2, NODE);
} else {
// Current node is leaf, therefore is valid but all children are invalid
(*node_flags)[nid] = LEAF;
flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED);
flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED);
}
}
// Copy gpu dense representation of tree to xgboost sparse representation
void GPUBuilder::CopyTree(RegTree &tree) {
thrust::host_vector<Node> h_nodes = gpu_data->nodes;
std::vector<NodeType> node_flags(h_nodes.size(), UNUSED);
flag_nodes(h_nodes, &node_flags, 0, NODE);
int nid = 0;
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
NodeType flag = node_flags[gpu_nid];
const Node &n = h_nodes[gpu_nid];
if (flag == NODE) {
tree.AddChilds(nid);
tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left);
tree.stat(nid).loss_chg = n.split.loss_chg;
tree.stat(nid).base_weight = n.weight;
tree.stat(nid).sum_hess = n.sum_gradients.hess();
tree[tree[nid].cleft()].set_leaf(0);
tree[tree[nid].cright()].set_leaf(0);
nid++;
} else if (flag == LEAF) {
tree[nid].set_leaf(n.weight * param.learning_rate);
tree.stat(nid).sum_hess = n.sum_gradients.hess();
nid++;
}
}
}
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,46 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include <xgboost/tree_updater.h>
#include <vector>
#include "../../src/tree/param.h"
namespace xgboost {
namespace tree {
struct gpu_gpair;
struct GPUData;
class GPUBuilder {
public:
GPUBuilder();
void Init(const TrainParam &param);
~GPUBuilder();
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
RegTree *p_tree);
private:
void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat, // NOLINT
const RegTree &tree);
void UpdateNodeId(int level);
void Sort(int level);
void InitFirstNode();
void CopyTree(RegTree &tree); // NOLINT
TrainParam param;
GPUData *gpu_data;
// Keep host copies of these arrays as the device versions change between
// boosting iterations
std::vector<float> fvalues;
std::vector<bst_uint> instance_id;
int multiscan_levels =
5; // Number of levels before switching to sorting algorithm
};
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,52 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include "types.cuh"
#include "../../../src/tree/param.h"
// When we split on a value which has no left neighbour, define its left
// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS
// This produces a split value slightly lower than the current instance
#define FVALUE_EPS 0.0001
namespace xgboost {
namespace tree {
__device__ __forceinline__ float
device_calc_loss_chg(const GPUTrainingParam &param, const gpu_gpair &scan,
const gpu_gpair &missing, const gpu_gpair &parent_sum,
const float &parent_gain, bool missing_left) {
gpu_gpair left = scan;
if (missing_left) {
left += missing;
}
gpu_gpair right = parent_sum - left;
float left_gain = CalcGain(param, left.grad(), left.hess());
float right_gain = CalcGain(param, right.grad(), right.hess());
return left_gain + right_gain - parent_gain;
}
__device__ __forceinline__ float
loss_chg_missing(const gpu_gpair &scan, const gpu_gpair &missing,
const gpu_gpair &parent_sum, const float &parent_gain,
const GPUTrainingParam &param, bool &missing_left_out) { // NOLINT
float missing_left_loss =
device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true);
float missing_right_loss = device_calc_loss_chg(
param, scan, missing, parent_sum, parent_gain, false);
if (missing_left_loss >= missing_right_loss) {
missing_left_out = true;
return missing_left_loss;
} else {
missing_left_out = false;
return missing_right_loss;
}
}
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,185 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include <xgboost/base.h>
namespace xgboost {
namespace tree {
typedef int32_t NodeIdT;
// gpair type defined with device accessible functions
struct gpu_gpair {
float _grad;
float _hess;
__host__ __device__ __forceinline__ float grad() const { return _grad; }
__host__ __device__ __forceinline__ float hess() const { return _hess; }
__host__ __device__ gpu_gpair() : _grad(0), _hess(0) {}
__host__ __device__ gpu_gpair(float g, float h) : _grad(g), _hess(h) {}
__host__ __device__ gpu_gpair(bst_gpair gpair)
: _grad(gpair.grad), _hess(gpair.hess) {}
__host__ __device__ bool operator==(const gpu_gpair &rhs) const {
return (_grad == rhs._grad) && (_hess == rhs._hess);
}
__host__ __device__ bool operator!=(const gpu_gpair &rhs) const {
return !(*this == rhs);
}
__host__ __device__ gpu_gpair &operator+=(const gpu_gpair &rhs) {
_grad += rhs._grad;
_hess += rhs._hess;
return *this;
}
__host__ __device__ gpu_gpair operator+(const gpu_gpair &rhs) const {
gpu_gpair g;
g._grad = _grad + rhs._grad;
g._hess = _hess + rhs._hess;
return g;
}
__host__ __device__ gpu_gpair &operator-=(const gpu_gpair &rhs) {
_grad -= rhs._grad;
_hess -= rhs._hess;
return *this;
}
__host__ __device__ gpu_gpair operator-(const gpu_gpair &rhs) const {
gpu_gpair g;
g._grad = _grad - rhs._grad;
g._hess = _hess - rhs._hess;
return g;
}
friend std::ostream &operator<<(std::ostream &os, const gpu_gpair &g) {
os << g.grad() << "/" << g.hess();
return os;
}
__host__ __device__ void print() const {
printf("%1.4f/%1.4f\n", grad(), hess());
}
__host__ __device__ bool approximate_compare(const gpu_gpair &b,
float g_eps = 0.1,
float h_eps = 0.1) const {
float gdiff = abs(this->grad() - b.grad());
float hdiff = abs(this->hess() - b.hess());
return (gdiff <= g_eps) && (hdiff <= h_eps);
}
};
struct Item {
bst_uint instance_id;
float fvalue;
gpu_gpair gpair;
};
struct GPUTrainingParam {
// minimum amount of hessian(weight) allowed in a child
float min_child_weight;
// L2 regularization factor
float reg_lambda;
// L1 regularization factor
float reg_alpha;
// maximum delta update we can add in weight estimation
// this parameter can be used to stabilize update
// default=0 means no constraint on weight delta
float max_delta_step;
__host__ __device__ GPUTrainingParam() {}
__host__ __device__ GPUTrainingParam(float min_child_weight_in,
float reg_lambda_in, float reg_alpha_in,
float max_delta_step_in)
: min_child_weight(min_child_weight_in), reg_lambda(reg_lambda_in),
reg_alpha(reg_alpha_in), max_delta_step(max_delta_step_in) {}
};
struct Split {
float loss_chg;
bool missing_left;
float fvalue;
int findex;
gpu_gpair left_sum;
gpu_gpair right_sum;
__host__ __device__ Split()
: loss_chg(-FLT_MAX), missing_left(true), fvalue(0) {}
__device__ void Update(float loss_chg_in, bool missing_left_in,
float fvalue_in, int findex_in, gpu_gpair left_sum_in,
gpu_gpair right_sum_in,
const GPUTrainingParam &param) {
if (loss_chg_in > loss_chg && left_sum_in.hess() > param.min_child_weight &&
right_sum_in.hess() > param.min_child_weight) {
loss_chg = loss_chg_in;
missing_left = missing_left_in;
fvalue = fvalue_in;
left_sum = left_sum_in;
right_sum = right_sum_in;
findex = findex_in;
}
}
// Does not check minimum weight
__device__ void Update(Split &s) { // NOLINT
if (s.loss_chg > loss_chg) {
loss_chg = s.loss_chg;
missing_left = s.missing_left;
fvalue = s.fvalue;
findex = s.findex;
left_sum = s.left_sum;
right_sum = s.right_sum;
}
}
__device__ void Print() {
printf("Loss: %1.4f\n", loss_chg);
printf("Missing left: %d\n", missing_left);
printf("fvalue: %1.4f\n", fvalue);
printf("Left sum: ");
left_sum.print();
printf("Right sum: ");
right_sum.print();
}
};
struct split_reduce_op {
template <typename T>
__device__ __forceinline__ T operator()(T &a, T b) { // NOLINT
b.Update(a);
return b;
}
};
struct Node {
gpu_gpair sum_gradients;
float root_gain;
float weight;
Split split;
__host__ __device__ Node() : weight(0), root_gain(0) {}
__host__ __device__ Node(gpu_gpair sum_gradients_in, float root_gain_in,
float weight_in) {
sum_gradients = sum_gradients_in;
root_gain = root_gain_in;
weight = weight_in;
}
__host__ __device__ bool IsLeaf() { return split.loss_chg == -FLT_MAX; }
};
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,6 @@
/*!
* Copyright 2016 Rory mitchell
*/
#pragma once
#include "types.cuh"
#include "loss_functions.cuh"

View File

@ -0,0 +1,48 @@
/*!
* Copyright 2016 Rory Mitchell
*/
#include <xgboost/tree_updater.h>
#include <vector>
#include "../../src/common/random.h"
#include "../../src/common/sync.h"
#include "../../src/tree/param.h"
#include "gpu_builder.cuh"
namespace xgboost {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
/*! \brief column-wise update to construct a tree */
template <typename TStats> class GPUMaker : public TreeUpdater {
public:
void
Init(const std::vector<std::pair<std::string, std::string>> &args) override {
param.InitAllowUnknown(args);
builder.Init(param);
}
void Update(const std::vector<bst_gpair> &gpair, DMatrix *dmat,
const std::vector<RegTree *> &trees) override {
TStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
builder.Update(gpair, dmat, trees[i]);
}
param.learning_rate = lr;
}
protected:
// training parameter
TrainParam param;
GPUBuilder builder;
};
XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUMaker<GradStats>(); });
} // namespace tree
} // namespace xgboost

137
plugin/updater_gpu/test.py Normal file
View File

@ -0,0 +1,137 @@
#pylint: skip-file
import numpy as np
import xgboost as xgb
import os
import pandas as pd
import urllib2
class bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
def get_last_eval_callback(result):
def callback(env):
result.append(env.evaluation_result_list[-1][1])
callback.after_iteration = True
return callback
def load_adult():
path = "../../demo/data/adult.data"
if(not os.path.isfile(path)):
data = urllib2.urlopen('http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data')
with open(path,'wb') as output:
output.write(data.read())
train_set = pd.read_csv( path, header=None)
train_set.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation',
'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
'wage_class']
train_nomissing = train_set.replace(' ?', np.nan).dropna()
for feature in train_nomissing.columns: # Loop through all columns in the dataframe
if train_nomissing[feature].dtype == 'object': # Only apply for columns with categorical strings
train_nomissing[feature] = pd.Categorical(train_nomissing[feature]).codes # Replace strings with an integer
y_train = train_nomissing.pop('wage_class')
return xgb.DMatrix( train_nomissing, label=y_train)
def load_higgs():
higgs_path = '../../demo/data/training.csv'
dtrain = np.loadtxt(higgs_path, delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s'.encode('utf-8')) } )
#dtrain = dtrain[0:200000,:]
label = dtrain[:,32]
data = dtrain[:,1:31]
weight = dtrain[:,31]
return xgb.DMatrix( data, label=label, missing = -999.0, weight=weight )
def load_dermatology():
data = np.loadtxt('../../demo/data/dermatology.data', delimiter=',',converters={33: lambda x:int(x == '?'), 34: lambda x:int(x)-1 } )
sz = data.shape
X = data[:,0:33]
Y = data[:, 34]
return xgb.DMatrix( X, label=Y)
def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
#Check GPU test evaluation is approximately equal to CPU test evaluation
def check_result(cpu_result, gpu_result):
for i in range(len(cpu_result)):
if not isclose(cpu_result[i], gpu_result[i], 0.1, 0.02):
return False
return True
#Get data
data = []
params = []
data.append(load_higgs())
params.append({})
data.append( load_adult())
params.append({})
data.append(xgb.DMatrix('../../demo/data/agaricus.txt.test'))
params.append({'objective':'binary:logistic'})
#if(os.path.isfile("../../demo/data/dermatology.data")):
data.append(load_dermatology())
params.append({'objective':'multi:softmax', 'num_class': 6})
num_round = 5
num_pass = 0
num_fail = 0
test_depth = [ 1, 6, 9, 11, 15 ]
#test_depth = [ 1 ]
for test in range(0, len(data)):
for depth in test_depth:
xgmat = data[test]
cpu_result = []
param = params[test]
param['max_depth'] = depth
param['updater'] = 'grow_colmaker'
xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(cpu_result)])
#bst = xgb.train( param, xgmat, 1);
#bst.dump_model('reference_model.txt','', True)
gpu_result = []
param['updater'] = 'grow_gpu'
xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(gpu_result)])
#bst = xgb.train( param, xgmat, 1);
#bst.dump_model('dump.raw.txt','', True)
if check_result(cpu_result, gpu_result):
print(bcolors.OKGREEN + "Pass" + bcolors.ENDC)
num_pass = num_pass + 1
else:
print(bcolors.FAIL + "Fail" + bcolors.ENDC)
num_fail = num_fail + 1
print("cpu rmse: "+str(cpu_result))
print("gpu rmse: "+str(gpu_result))
print(str(num_pass)+"/"+str(num_pass + num_fail)+" passed")

View File

@ -7,10 +7,16 @@
#ifndef XGBOOST_TREE_PARAM_H_ #ifndef XGBOOST_TREE_PARAM_H_
#define XGBOOST_TREE_PARAM_H_ #define XGBOOST_TREE_PARAM_H_
#include <vector> #include <cmath>
#include <cstring> #include <cstring>
#include <limits> #include <limits>
#include <cmath> #include <vector>
#ifdef __NVCC__
#define XGB_DEVICE __host__ __device__
#else
#define XGB_DEVICE
#endif
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -60,47 +66,80 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
std::vector<int> monotone_constraints; std::vector<int> monotone_constraints;
// declare the parameters // declare the parameters
DMLC_DECLARE_PARAMETER(TrainParam) { DMLC_DECLARE_PARAMETER(TrainParam) {
DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f) DMLC_DECLARE_FIELD(learning_rate)
.set_lower_bound(0.0f)
.set_default(0.3f)
.describe("Learning rate(step size) of update."); .describe("Learning rate(step size) of update.");
DMLC_DECLARE_FIELD(min_split_loss).set_lower_bound(0.0f).set_default(0.0f) DMLC_DECLARE_FIELD(min_split_loss)
.describe("Minimum loss reduction required to make a further partition."); .set_lower_bound(0.0f)
DMLC_DECLARE_FIELD(max_depth).set_lower_bound(0).set_default(6) .set_default(0.0f)
.describe("Maximum depth of the tree."); .describe(
DMLC_DECLARE_FIELD(min_child_weight).set_lower_bound(0.0f).set_default(1.0f) "Minimum loss reduction required to make a further partition.");
DMLC_DECLARE_FIELD(max_depth).set_lower_bound(0).set_default(6).describe(
"Maximum depth of the tree.");
DMLC_DECLARE_FIELD(min_child_weight)
.set_lower_bound(0.0f)
.set_default(1.0f)
.describe("Minimum sum of instance weight(hessian) needed in a child."); .describe("Minimum sum of instance weight(hessian) needed in a child.");
DMLC_DECLARE_FIELD(reg_lambda).set_lower_bound(0.0f).set_default(1.0f) DMLC_DECLARE_FIELD(reg_lambda)
.set_lower_bound(0.0f)
.set_default(1.0f)
.describe("L2 regularization on leaf weight"); .describe("L2 regularization on leaf weight");
DMLC_DECLARE_FIELD(reg_alpha).set_lower_bound(0.0f).set_default(0.0f) DMLC_DECLARE_FIELD(reg_alpha)
.set_lower_bound(0.0f)
.set_default(0.0f)
.describe("L1 regularization on leaf weight"); .describe("L1 regularization on leaf weight");
DMLC_DECLARE_FIELD(default_direction).set_default(0) DMLC_DECLARE_FIELD(default_direction)
.set_default(0)
.add_enum("learn", 0) .add_enum("learn", 0)
.add_enum("left", 1) .add_enum("left", 1)
.add_enum("right", 2) .add_enum("right", 2)
.describe("Default direction choice when encountering a missing value"); .describe("Default direction choice when encountering a missing value");
DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.0f) DMLC_DECLARE_FIELD(max_delta_step)
.describe("Maximum delta step we allow each tree's weight estimate to be. "\ .set_lower_bound(0.0f)
.set_default(0.0f)
.describe(
"Maximum delta step we allow each tree's weight estimate to be. "
"If the value is set to 0, it means there is no constraint"); "If the value is set to 0, it means there is no constraint");
DMLC_DECLARE_FIELD(subsample).set_range(0.0f, 1.0f).set_default(1.0f) DMLC_DECLARE_FIELD(subsample)
.set_range(0.0f, 1.0f)
.set_default(1.0f)
.describe("Row subsample ratio of training instance."); .describe("Row subsample ratio of training instance.");
DMLC_DECLARE_FIELD(colsample_bylevel).set_range(0.0f, 1.0f).set_default(1.0f) DMLC_DECLARE_FIELD(colsample_bylevel)
.set_range(0.0f, 1.0f)
.set_default(1.0f)
.describe("Subsample ratio of columns, resample on each level."); .describe("Subsample ratio of columns, resample on each level.");
DMLC_DECLARE_FIELD(colsample_bytree).set_range(0.0f, 1.0f).set_default(1.0f) DMLC_DECLARE_FIELD(colsample_bytree)
.describe("Subsample ratio of columns, resample on each tree construction."); .set_range(0.0f, 1.0f)
DMLC_DECLARE_FIELD(opt_dense_col).set_range(0.0f, 1.0f).set_default(1.0f) .set_default(1.0f)
.describe(
"Subsample ratio of columns, resample on each tree construction.");
DMLC_DECLARE_FIELD(opt_dense_col)
.set_range(0.0f, 1.0f)
.set_default(1.0f)
.describe("EXP Param: speed optimization for dense column."); .describe("EXP Param: speed optimization for dense column.");
DMLC_DECLARE_FIELD(sketch_eps).set_range(0.0f, 1.0f).set_default(0.03f) DMLC_DECLARE_FIELD(sketch_eps)
.set_range(0.0f, 1.0f)
.set_default(0.03f)
.describe("EXP Param: Sketch accuracy of approximate algorithm."); .describe("EXP Param: Sketch accuracy of approximate algorithm.");
DMLC_DECLARE_FIELD(sketch_ratio).set_lower_bound(0.0f).set_default(2.0f) DMLC_DECLARE_FIELD(sketch_ratio)
.describe("EXP Param: Sketch accuracy related parameter of approximate algorithm."); .set_lower_bound(0.0f)
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0) .set_default(2.0f)
.describe("EXP Param: Sketch accuracy related parameter of approximate "
"algorithm.");
DMLC_DECLARE_FIELD(size_leaf_vector)
.set_lower_bound(0)
.set_default(0)
.describe("Size of leaf vectors, reserved for vector trees"); .describe("Size of leaf vectors, reserved for vector trees");
DMLC_DECLARE_FIELD(parallel_option).set_default(0) DMLC_DECLARE_FIELD(parallel_option)
.set_default(0)
.describe("Different types of parallelization algorithm."); .describe("Different types of parallelization algorithm.");
DMLC_DECLARE_FIELD(cache_opt).set_default(true) DMLC_DECLARE_FIELD(cache_opt).set_default(true).describe(
.describe("EXP Param: Cache aware optimization."); "EXP Param: Cache aware optimization.");
DMLC_DECLARE_FIELD(silent).set_default(false) DMLC_DECLARE_FIELD(silent).set_default(false).describe(
.describe("Do not print information during trainig."); "Do not print information during trainig.");
DMLC_DECLARE_FIELD(monotone_constraints).set_default(std::vector<int>()) DMLC_DECLARE_FIELD(monotone_constraints)
.set_default(std::vector<int>())
.describe("Constraint of variable monotinicity"); .describe("Constraint of variable monotinicity");
// add alias of parameters // add alias of parameters
DMLC_DECLARE_ALIAS(reg_lambda, lambda); DMLC_DECLARE_ALIAS(reg_lambda, lambda);
@ -108,61 +147,11 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
DMLC_DECLARE_ALIAS(min_split_loss, gamma); DMLC_DECLARE_ALIAS(min_split_loss, gamma);
DMLC_DECLARE_ALIAS(learning_rate, eta); DMLC_DECLARE_ALIAS(learning_rate, eta);
} }
// calculate the cost of loss function
inline double CalcGainGivenWeight(double sum_grad,
double sum_hess,
double w) const {
return -(2.0 * sum_grad * w + (sum_hess + reg_lambda) * Sqr(w));
}
// calculate the cost of loss function
inline double CalcGain(double sum_grad, double sum_hess) const {
if (sum_hess < min_child_weight) return 0.0;
if (max_delta_step == 0.0f) {
if (reg_alpha == 0.0f) {
return Sqr(sum_grad) / (sum_hess + reg_lambda);
} else {
return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
}
} else {
double w = CalcWeight(sum_grad, sum_hess);
double ret = sum_grad * w + 0.5 * (sum_hess + reg_lambda) * Sqr(w);
if (reg_alpha == 0.0f) {
return - 2.0 * ret;
} else {
return - 2.0 * (ret + reg_alpha * std::abs(w));
}
}
}
// calculate cost of loss function with four statistics
inline double CalcGain(double sum_grad, double sum_hess,
double test_grad, double test_hess) const {
double w = CalcWeight(sum_grad, sum_hess);
double ret = test_grad * w + 0.5 * (test_hess + reg_lambda) * Sqr(w);
if (reg_alpha == 0.0f) {
return - 2.0 * ret;
} else {
return - 2.0 * (ret + reg_alpha * std::abs(w));
}
}
// calculate weight given the statistics
inline double CalcWeight(double sum_grad, double sum_hess) const {
if (sum_hess < min_child_weight) return 0.0;
double dw;
if (reg_alpha == 0.0f) {
dw = -sum_grad / (sum_hess + reg_lambda);
} else {
dw = -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
}
if (max_delta_step != 0.0f) {
if (dw > max_delta_step) dw = max_delta_step;
if (dw < -max_delta_step) dw = -max_delta_step;
}
return dw;
}
/*! \brief whether need forward small to big search: default right */ /*! \brief whether need forward small to big search: default right */
inline bool need_forward_search(float col_density, bool indicator) const { inline bool need_forward_search(float col_density, bool indicator) const {
return this->default_direction == 2 || return this->default_direction == 2 ||
(default_direction == 0 && (col_density < opt_dense_col) && !indicator); (default_direction == 0 && (col_density < opt_dense_col) &&
!indicator);
} }
/*! \brief whether need backward big to small search: default left */ /*! \brief whether need backward big to small search: default left */
inline bool need_backward_search(float col_density, bool indicator) const { inline bool need_backward_search(float col_density, bool indicator) const {
@ -182,18 +171,84 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
CHECK_GT(ret, 0); CHECK_GT(ret, 0);
return ret; return ret;
} }
};
/*! \brief Loss functions */
protected:
// functions for L1 cost // functions for L1 cost
inline static double ThresholdL1(double w, double lambda) { template <typename T1, typename T2>
if (w > +lambda) return w - lambda; XGB_DEVICE inline static T1 ThresholdL1(T1 w, T2 lambda) {
if (w < -lambda) return w + lambda; if (w > +lambda)
return w - lambda;
if (w < -lambda)
return w + lambda;
return 0.0; return 0.0;
} }
inline static double Sqr(double a) {
return a * a; template <typename T>
XGB_DEVICE inline static T Sqr(T a) { return a * a; }
// calculate the cost of loss function
template <typename TrainingParams, typename T>
XGB_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, T sum_grad,
T sum_hess, T w) {
return -(2.0 * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w));
}
// calculate the cost of loss function
template <typename TrainingParams, typename T>
XGB_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) {
if (sum_hess < p.min_child_weight)
return 0.0;
if (p.max_delta_step == 0.0f) {
if (p.reg_alpha == 0.0f) {
return Sqr(sum_grad) / (sum_hess + p.reg_lambda);
} else {
return Sqr(ThresholdL1(sum_grad, p.reg_alpha)) /
(sum_hess + p.reg_lambda);
}
} else {
T w = CalcWeight(p, sum_grad, sum_hess);
T ret = sum_grad * w + 0.5 * (sum_hess + p.reg_lambda) * Sqr(w);
if (p.reg_alpha == 0.0f) {
return -2.0 * ret;
} else {
return -2.0 * (ret + p.reg_alpha * std::abs(w));
}
}
}
// calculate cost of loss function with four statistics
template <typename TrainingParams, typename T>
XGB_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
T test_grad, T test_hess) {
T w = CalcWeight(sum_grad, sum_hess);
T ret = test_grad * w + 0.5 * (test_hess + p.reg_lambda) * Sqr(w);
if (p.reg_alpha == 0.0f) {
return -2.0 * ret;
} else {
return -2.0 * (ret + p.reg_alpha * std::abs(w));
}
}
// calculate weight given the statistics
template <typename TrainingParams, typename T>
XGB_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
T sum_hess) {
if (sum_hess < p.min_child_weight)
return 0.0;
T dw;
if (p.reg_alpha == 0.0f) {
dw = -sum_grad / (sum_hess + p.reg_lambda);
} else {
dw = -ThresholdL1(sum_grad, p.reg_alpha) / (sum_hess + p.reg_lambda);
}
if (p.max_delta_step != 0.0f) {
if (dw > p.max_delta_step)
dw = p.max_delta_step;
if (dw < -p.max_delta_step)
dw = -p.max_delta_step;
}
return dw;
} }
};
/*! \brief core statistics used for tree construction */ /*! \brief core statistics used for tree construction */
struct GradStats { struct GradStats {
@ -207,47 +262,37 @@ struct GradStats {
*/ */
static const int kSimpleStats = 1; static const int kSimpleStats = 1;
/*! \brief constructor, the object must be cleared during construction */ /*! \brief constructor, the object must be cleared during construction */
explicit GradStats(const TrainParam& param) { explicit GradStats(const TrainParam &param) { this->Clear(); }
this->Clear();
}
/*! \brief clear the statistics */ /*! \brief clear the statistics */
inline void Clear() { inline void Clear() { sum_grad = sum_hess = 0.0f; }
sum_grad = sum_hess = 0.0f;
}
/*! \brief check if necessary information is ready */ /*! \brief check if necessary information is ready */
inline static void CheckInfo(const MetaInfo& info) { inline static void CheckInfo(const MetaInfo &info) {}
}
/*! /*!
* \brief accumulate statistics * \brief accumulate statistics
* \param p the gradient pair * \param p the gradient pair
*/ */
inline void Add(bst_gpair p) { inline void Add(bst_gpair p) { this->Add(p.grad, p.hess); }
this->Add(p.grad, p.hess);
}
/*! /*!
* \brief accumulate statistics, more complicated version * \brief accumulate statistics, more complicated version
* \param gpair the vector storing the gradient statistics * \param gpair the vector storing the gradient statistics
* \param info the additional information * \param info the additional information
* \param ridx instance index of this instance * \param ridx instance index of this instance
*/ */
inline void Add(const std::vector<bst_gpair>& gpair, inline void Add(const std::vector<bst_gpair> &gpair, const MetaInfo &info,
const MetaInfo& info,
bst_uint ridx) { bst_uint ridx) {
const bst_gpair &b = gpair[ridx]; const bst_gpair &b = gpair[ridx];
this->Add(b.grad, b.hess); this->Add(b.grad, b.hess);
} }
/*! \brief calculate leaf weight */ /*! \brief calculate leaf weight */
inline double CalcWeight(const TrainParam &param) const { inline double CalcWeight(const TrainParam &param) const {
return param.CalcWeight(sum_grad, sum_hess); return xgboost::tree::CalcWeight(param, sum_grad, sum_hess);
} }
/*! \brief calculate gain of the solution */ /*! \brief calculate gain of the solution */
inline double CalcGain(const TrainParam &param) const { inline double CalcGain(const TrainParam &param) const {
return param.CalcGain(sum_grad, sum_hess); return xgboost::tree::CalcGain(param, sum_grad, sum_hess);
} }
/*! \brief add statistics to the data */ /*! \brief add statistics to the data */
inline void Add(const GradStats& b) { inline void Add(const GradStats &b) { this->Add(b.sum_grad, b.sum_hess); }
this->Add(b.sum_grad, b.sum_hess);
}
/*! \brief same as add, reduce is used in All Reduce */ /*! \brief same as add, reduce is used in All Reduce */
inline static void Reduce(GradStats &a, const GradStats &b) { // NOLINT(*) inline static void Reduce(GradStats &a, const GradStats &b) { // NOLINT(*)
a.Add(b); a.Add(b);
@ -258,57 +303,45 @@ struct GradStats {
sum_hess = a.sum_hess - b.sum_hess; sum_hess = a.sum_hess - b.sum_hess;
} }
/*! \return whether the statistics is not used yet */ /*! \return whether the statistics is not used yet */
inline bool Empty() const { inline bool Empty() const { return sum_hess == 0.0; }
return sum_hess == 0.0;
}
/*! \brief set leaf vector value based on statistics */ /*! \brief set leaf vector value based on statistics */
inline void SetLeafVec(const TrainParam& param, bst_float *vec) const { inline void SetLeafVec(const TrainParam &param, bst_float *vec) const {}
}
// constructor to allow inheritance // constructor to allow inheritance
GradStats() {} GradStats() {}
/*! \brief add statistics to the data */ /*! \brief add statistics to the data */
inline void Add(double grad, double hess) { inline void Add(double grad, double hess) {
sum_grad += grad; sum_hess += hess; sum_grad += grad;
sum_hess += hess;
} }
}; };
struct NoConstraint { struct NoConstraint {
inline static void Init(TrainParam* param, unsigned num_feature) { inline static void Init(TrainParam *param, unsigned num_feature) {}
} inline double CalcSplitGain(const TrainParam &param, bst_uint split_index,
inline double CalcSplitGain(
const TrainParam& param, bst_uint split_index,
GradStats left, GradStats right) const { GradStats left, GradStats right) const {
return left.CalcGain(param) + right.CalcGain(param); return left.CalcGain(param) + right.CalcGain(param);
} }
inline double CalcWeight( inline double CalcWeight(const TrainParam &param, GradStats stats) const {
const TrainParam& param,
GradStats stats) const {
return stats.CalcWeight(param); return stats.CalcWeight(param);
} }
inline double CalcGain(const TrainParam& param, inline double CalcGain(const TrainParam &param, GradStats stats) const {
GradStats stats) const {
return stats.CalcGain(param); return stats.CalcGain(param);
} }
inline void SetChild( inline void SetChild(const TrainParam &param, bst_uint split_index,
const TrainParam& param, bst_uint split_index, GradStats left, GradStats right, NoConstraint *cleft,
GradStats left, GradStats right, NoConstraint *cright) {}
NoConstraint* cleft, NoConstraint* cright) {
}
}; };
struct ValueConstraint { struct ValueConstraint {
double lower_bound; double lower_bound;
double upper_bound; double upper_bound;
ValueConstraint() : ValueConstraint()
lower_bound(-std::numeric_limits<double>::max()), : lower_bound(-std::numeric_limits<double>::max()),
upper_bound(std::numeric_limits<double>::max()) { upper_bound(std::numeric_limits<double>::max()) {}
}
inline static void Init(TrainParam *param, unsigned num_feature) { inline static void Init(TrainParam *param, unsigned num_feature) {
param->monotone_constraints.resize(num_feature, 1); param->monotone_constraints.resize(num_feature, 1);
} }
inline double CalcWeight( inline double CalcWeight(const TrainParam &param, GradStats stats) const {
const TrainParam& param,
GradStats stats) const {
double w = stats.CalcWeight(param); double w = stats.CalcWeight(param);
if (w < lower_bound) { if (w < lower_bound) {
return lower_bound; return lower_bound;
@ -319,23 +352,19 @@ struct ValueConstraint {
return w; return w;
} }
inline double CalcGain(const TrainParam& param, inline double CalcGain(const TrainParam &param, GradStats stats) const {
GradStats stats) const { return CalcGainGivenWeight(param, stats.sum_grad, stats.sum_hess,
return param.CalcGainGivenWeight(
stats.sum_grad, stats.sum_hess,
CalcWeight(param, stats)); CalcWeight(param, stats));
} }
inline double CalcSplitGain( inline double CalcSplitGain(const TrainParam &param, bst_uint split_index,
const TrainParam& param,
bst_uint split_index,
GradStats left, GradStats right) const { GradStats left, GradStats right) const {
double wleft = CalcWeight(param, left); double wleft = CalcWeight(param, left);
double wright = CalcWeight(param, right); double wright = CalcWeight(param, right);
int c = param.monotone_constraints[split_index]; int c = param.monotone_constraints[split_index];
double gain = double gain =
param.CalcGainGivenWeight(left.sum_grad, left.sum_hess, wleft) + CalcGainGivenWeight(param, left.sum_grad, left.sum_hess, wleft) +
param.CalcGainGivenWeight(right.sum_grad, right.sum_hess, wright); CalcGainGivenWeight(param, right.sum_grad, right.sum_hess, wright);
if (c == 0) { if (c == 0) {
return gain; return gain;
} else if (c > 0) { } else if (c > 0) {
@ -345,15 +374,14 @@ struct ValueConstraint {
} }
} }
inline void SetChild( inline void SetChild(const TrainParam &param, bst_uint split_index,
const TrainParam& param, GradStats left, GradStats right, ValueConstraint *cleft,
bst_uint split_index, ValueConstraint *cright) {
GradStats left, GradStats right,
ValueConstraint* cleft, ValueConstraint *cright) {
int c = param.monotone_constraints.at(split_index); int c = param.monotone_constraints.at(split_index);
*cleft = *this; *cleft = *this;
*cright = *this; *cright = *this;
if (c == 0) return; if (c == 0)
return;
double wleft = CalcWeight(param, left); double wleft = CalcWeight(param, left);
double wright = CalcWeight(param, right); double wright = CalcWeight(param, right);
double mid = (wleft + wright) / 2; double mid = (wleft + wright) / 2;
@ -382,9 +410,12 @@ struct SplitEntry {
/*! \brief constructor */ /*! \brief constructor */
SplitEntry() : loss_chg(0.0f), sindex(0), split_value(0.0f) {} SplitEntry() : loss_chg(0.0f), sindex(0), split_value(0.0f) {}
/*! /*!
* \brief decides whether we can replace current entry with the given statistics * \brief decides whether we can replace current entry with the given
* This function gives better priority to lower index when loss_chg == new_loss_chg. * statistics
* Not the best way, but helps to give consistent result during multi-thread execution. * This function gives better priority to lower index when loss_chg ==
* new_loss_chg.
* Not the best way, but helps to give consistent result during multi-thread
* execution.
* \param new_loss_chg the loss reduction get through the split * \param new_loss_chg the loss reduction get through the split
* \param split_index the feature index where the split is on * \param split_index the feature index where the split is on
*/ */
@ -422,7 +453,8 @@ struct SplitEntry {
float new_split_value, bool default_left) { float new_split_value, bool default_left) {
if (this->NeedReplace(new_loss_chg, split_index)) { if (this->NeedReplace(new_loss_chg, split_index)) {
this->loss_chg = new_loss_chg; this->loss_chg = new_loss_chg;
if (default_left) split_index |= (1U << 31); if (default_left)
split_index |= (1U << 31);
this->sindex = split_index; this->sindex = split_index;
this->split_value = new_split_value; this->split_value = new_split_value;
return true; return true;
@ -431,17 +463,14 @@ struct SplitEntry {
} }
} }
/*! \brief same as update, used by AllReduce*/ /*! \brief same as update, used by AllReduce*/
inline static void Reduce(SplitEntry& dst, const SplitEntry& src) { // NOLINT(*) inline static void Reduce(SplitEntry &dst, // NOLINT(*)
const SplitEntry &src) { // NOLINT(*)
dst.Update(src); dst.Update(src);
} }
/*!\return feature index to split on */ /*!\return feature index to split on */
inline unsigned split_index() const { inline unsigned split_index() const { return sindex & ((1U << 31) - 1U); }
return sindex & ((1U << 31) - 1U);
}
/*!\return whether missing value goes to left branch */ /*!\return whether missing value goes to left branch */
inline bool default_left() const { inline bool default_left() const { return (sindex >> 31) != 0; }
return (sindex >> 31) != 0;
}
}; };
} // namespace tree } // namespace tree
@ -451,13 +480,14 @@ struct SplitEntry {
namespace std { namespace std {
inline std::ostream &operator<<(std::ostream &os, const std::vector<int> &t) { inline std::ostream &operator<<(std::ostream &os, const std::vector<int> &t) {
os << '('; os << '(';
for (std::vector<int>::const_iterator for (std::vector<int>::const_iterator it = t.begin(); it != t.end(); ++it) {
it = t.begin(); it != t.end(); ++it) { if (it != t.begin())
if (it != t.begin()) os << ','; os << ',';
os << *it; os << *it;
} }
// python style tuple // python style tuple
if (t.size() == 1) os << ','; if (t.size() == 1)
os << ',';
os << ')'; os << ')';
return os; return os;
} }
@ -474,7 +504,8 @@ inline std::istream &operator>>(std::istream &is, std::vector<int> &t) {
return is; return is;
} }
is.get(); is.get();
if (ch == '(') break; if (ch == '(')
break;
if (!isspace(ch)) { if (!isspace(ch)) {
is.setstate(std::ios::failbit); is.setstate(std::ios::failbit);
return is; return is;
@ -495,14 +526,17 @@ inline std::istream &operator>>(std::istream &is, std::vector<int> &t) {
while (true) { while (true) {
ch = is.peek(); ch = is.peek();
if (isspace(ch)) { if (isspace(ch)) {
is.get(); continue; is.get();
continue;
} }
if (ch == ')') { if (ch == ')') {
is.get(); break; is.get();
break;
} }
break; break;
} }
if (ch == ')') break; if (ch == ')')
break;
} else if (ch == ')') { } else if (ch == ')') {
break; break;
} else { } else {

View File

@ -107,7 +107,7 @@ class SketchMaker: public BaseMaker {
} }
/*! \brief calculate gain of the solution */ /*! \brief calculate gain of the solution */
inline double CalcGain(const TrainParam &param) const { inline double CalcGain(const TrainParam &param) const {
return param.CalcGain(pos_grad - neg_grad, sum_hess); return xgboost::tree::CalcGain(param, pos_grad - neg_grad, sum_hess);
} }
/*! \brief set current value to a - b */ /*! \brief set current value to a - b */
inline void SetSubstract(const SKStats &a, const SKStats &b) { inline void SetSubstract(const SKStats &a, const SKStats &b) {
@ -117,7 +117,7 @@ class SketchMaker: public BaseMaker {
} }
// calculate leaf weight // calculate leaf weight
inline double CalcWeight(const TrainParam &param) const { inline double CalcWeight(const TrainParam &param) const {
return param.CalcWeight(pos_grad - neg_grad, sum_hess); return xgboost::tree::CalcWeight(param, pos_grad - neg_grad, sum_hess);
} }
/*! \brief add statistics to the data */ /*! \brief add statistics to the data */
inline void Add(const SKStats &b) { inline void Add(const SKStats &b) {