From ac41845d4bff1d57533334029463171131debfc0 Mon Sep 17 00:00:00 2001 From: RAMitchell Date: Fri, 21 Oct 2016 16:14:47 +1300 Subject: [PATCH] Add GPU accelerated tree construction plugin (#1679) --- CMakeLists.txt | 213 +++-- plugin/updater_gpu/README.md | 31 + plugin/updater_gpu/speed_test.py | 64 ++ plugin/updater_gpu/src/cuda_helpers.cuh | 276 ++++++ plugin/updater_gpu/src/find_split.cuh | 87 ++ .../updater_gpu/src/find_split_multiscan.cuh | 835 ++++++++++++++++++ plugin/updater_gpu/src/find_split_sorting.cuh | 474 ++++++++++ plugin/updater_gpu/src/gpu_builder.cu | 457 ++++++++++ plugin/updater_gpu/src/gpu_builder.cuh | 46 + plugin/updater_gpu/src/loss_functions.cuh | 52 ++ plugin/updater_gpu/src/types.cuh | 185 ++++ plugin/updater_gpu/src/types_functions.cuh | 6 + plugin/updater_gpu/src/updater_gpu.cc | 48 + plugin/updater_gpu/test.py | 137 +++ src/tree/param.h | 400 +++++---- src/tree/updater_skmaker.cc | 4 +- 16 files changed, 3040 insertions(+), 275 deletions(-) create mode 100644 plugin/updater_gpu/README.md create mode 100644 plugin/updater_gpu/speed_test.py create mode 100644 plugin/updater_gpu/src/cuda_helpers.cuh create mode 100644 plugin/updater_gpu/src/find_split.cuh create mode 100644 plugin/updater_gpu/src/find_split_multiscan.cuh create mode 100644 plugin/updater_gpu/src/find_split_sorting.cuh create mode 100644 plugin/updater_gpu/src/gpu_builder.cu create mode 100644 plugin/updater_gpu/src/gpu_builder.cuh create mode 100644 plugin/updater_gpu/src/loss_functions.cuh create mode 100644 plugin/updater_gpu/src/types.cuh create mode 100644 plugin/updater_gpu/src/types_functions.cuh create mode 100644 plugin/updater_gpu/src/updater_gpu.cc create mode 100644 plugin/updater_gpu/test.py diff --git a/CMakeLists.txt b/CMakeLists.txt index baf66aad5..fec99aed0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,90 +1,123 @@ -cmake_minimum_required (VERSION 2.6) -project (xgboost) -find_package(OpenMP) - -set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS} -fPIC") - -# Make sure we are using C++11 -# Visual Studio 12.0 and newer supports enough c++11 to make this work -if(MSVC) - if(MSVC_VERSION LESS 1800) - message(STATUS "The compiler ${CMAKE_CXX_COMPILER} has no C++11 support. Please use a different C++ compiler.") - endif() -else() - # GCC 4.6 with c++0x supports enough to make this work - include(CheckCXXCompilerFlag) - CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORTS_CXX11) - CHECK_CXX_COMPILER_FLAG("-std=c++0x" COMPILER_SUPPORTS_CXX0X) - - if(COMPILER_SUPPORTS_CXX11) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") - elseif(COMPILER_SUPPORTS_CXX0X) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x") - else() - message(STATUS "The compiler ${CMAKE_CXX_COMPILER} has no C++11 support. 
Please use a different C++ compiler.") - endif() -endif() - - -#Make sure we are using the static runtime -if(MSVC) - set(variables - CMAKE_C_FLAGS_DEBUG - CMAKE_C_FLAGS_MINSIZEREL - CMAKE_C_FLAGS_RELEASE - CMAKE_C_FLAGS_RELWITHDEBINFO - CMAKE_CXX_FLAGS_DEBUG - CMAKE_CXX_FLAGS_MINSIZEREL - CMAKE_CXX_FLAGS_RELEASE - CMAKE_CXX_FLAGS_RELWITHDEBINFO - ) - foreach(variable ${variables}) - if(${variable} MATCHES "/MD") - string(REGEX REPLACE "/MD" "/MT" ${variable} "${${variable}}") - endif() - endforeach() -endif() - -include_directories ( - ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/dmlc-core/include - ${PROJECT_SOURCE_DIR}/rabit/include - ) - -file(GLOB SOURCES - src/c_api/*.cc - src/common/*.cc - src/data/*.cc - src/gbm/*.cc - src/metric/*.cc - src/objective/*.cc - src/tree/*.cc - src/*.cc -) - -set(RABIT_SOURCES - rabit/src/allreduce_base.cc - rabit/src/allreduce_robust.cc - rabit/src/engine.cc - rabit/src/c_api.cc -) - - -add_subdirectory(dmlc-core) - -add_library(rabit STATIC ${RABIT_SOURCES}) - -if(MSVC) - add_executable(xgboost ${SOURCES}) - add_library(libxgboost SHARED ${SOURCES}) - - target_link_libraries(xgboost dmlccore rabit) - target_link_libraries(libxgboost dmlccore rabit) -else() - add_executable(xgboost-bin ${SOURCES}) - SET_TARGET_PROPERTIES(xgboost-bin PROPERTIES OUTPUT_NAME xgboost) - add_library(xgboost SHARED ${SOURCES}) - - target_link_libraries(xgboost-bin dmlccore rabit) - target_link_libraries(xgboost dmlccore rabit) -endif() +cmake_minimum_required (VERSION 2.6) +project (xgboost) +find_package(OpenMP) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + +if(NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") +endif() + + +# Make sure we are using C++11 +# Visual Studio 12.0 and newer supports enough c++11 to make this work +if(MSVC) + if(MSVC_VERSION LESS 1800) + message(STATUS "The compiler ${CMAKE_CXX_COMPILER} has no C++11 support. Please use a different C++ compiler.") + endif() +else() + # GCC 4.6 with c++0x supports enough to make this work + include(CheckCXXCompilerFlag) + CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORTS_CXX11) + CHECK_CXX_COMPILER_FLAG("-std=c++0x" COMPILER_SUPPORTS_CXX0X) + + if(COMPILER_SUPPORTS_CXX11) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + elseif(COMPILER_SUPPORTS_CXX0X) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x") + else() + message(STATUS "The compiler ${CMAKE_CXX_COMPILER} has no C++11 support. 
Please use a different C++ compiler.") + endif() +endif() + +#Make sure we are using the static runtime +if(MSVC) + set(variables + CMAKE_C_FLAGS_DEBUG + CMAKE_C_FLAGS_MINSIZEREL + CMAKE_C_FLAGS_RELEASE + CMAKE_C_FLAGS_RELWITHDEBINFO + CMAKE_CXX_FLAGS_DEBUG + CMAKE_CXX_FLAGS_MINSIZEREL + CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_RELWITHDEBINFO + ) + foreach(variable ${variables}) + if(${variable} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${variable} "${${variable}}") + endif() + endforeach() +endif() + +include_directories ( + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dmlc-core/include + ${PROJECT_SOURCE_DIR}/rabit/include + ) + +file(GLOB SOURCES + src/c_api/*.cc + src/common/*.cc + src/data/*.cc + src/gbm/*.cc + src/metric/*.cc + src/objective/*.cc + src/tree/*.cc + src/*.cc +) + +set(RABIT_SOURCES + rabit/src/allreduce_base.cc + rabit/src/allreduce_robust.cc + rabit/src/engine.cc + rabit/src/c_api.cc +) + +add_subdirectory(dmlc-core) + +add_library(rabit STATIC ${RABIT_SOURCES}) + +#Set library output directories +if(MSVC) + #With MSVC shared library is considered runtime + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/lib) + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/lib) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/lib) + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/lib) +else() + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/lib) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}) + #Prevent shared library being called liblibxgboost.so on Linux + set(CMAKE_SHARED_LIBRARY_PREFIX "") +endif() + +option(PLUGIN_UPDATER_GPU "Build GPU accelerated tree construction plugin") + +if(PLUGIN_UPDATER_GPU) + #Find cub + set(CUB_DIRECTORY "" CACHE PATH "CUB 1.5.4 directory") + include_directories(${CUB_DIRECTORY}) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;-arch=compute_35") + if(NOT MSVC) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-std=c++11; -Xcompiler -fPIC") + endif() + set(SOURCES ${SOURCES} + plugin/updater_gpu/src/updater_gpu.cc + ) + find_package(CUDA QUIET REQUIRED) + cuda_add_library(updater_gpu STATIC + plugin/updater_gpu/src/gpu_builder.cu + ) +endif() + +add_executable(xgboost ${SOURCES}) +add_library(libxgboost SHARED ${SOURCES}) +target_link_libraries(xgboost dmlccore rabit) +target_link_libraries(libxgboost dmlccore rabit) + + +if(PLUGIN_UPDATER_GPU) + target_link_libraries(xgboost updater_gpu) + target_link_libraries(libxgboost updater_gpu) +endif() + diff --git a/plugin/updater_gpu/README.md b/plugin/updater_gpu/README.md new file mode 100644 index 000000000..e801fc354 --- /dev/null +++ b/plugin/updater_gpu/README.md @@ -0,0 +1,31 @@ +# CUDA Accelerated Tree Construction Algorithm + +## Usage +Specify the updater parameter as 'grow_gpu'. + +Python example: +```python +param['updater'] = 'grow_gpu' +``` + +## Dependencies +A CUDA capable GPU with at least compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler). + +The plugin also depends on CUB 1.5.4 - http://nvlabs.github.io/cub/index.html. + +CUB is a header only cuda library which provides sort/reduce/scan primitives. + + +## Build +The plugin can be built using cmake and specifying the option PLUGIN_UPDATER_GPU=ON. + +Specify the location of the CUB library with the cmake variable CUB_DIRECTORY. + +It is recommended to build with Cuda Toolkit 7.5 or greater. + +## Author +Rory Mitchell + +Report any bugs to r.a.mitchell.nz at google mail. 
+ + diff --git a/plugin/updater_gpu/speed_test.py b/plugin/updater_gpu/speed_test.py new file mode 100644 index 000000000..eaf8111b5 --- /dev/null +++ b/plugin/updater_gpu/speed_test.py @@ -0,0 +1,64 @@ +#!/usr/bin/pytho#!/usr/bin/python +#pylint: skip-file +# this is the example script to use xgboost to train +import numpy as np +import xgboost as xgb +import time +test_size = 550000 + +# path to where the data lies +dpath = '../../demo/data' + +# load in training data, directly use numpy +dtrain = np.loadtxt( dpath+'/training.csv', delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s') } ) +dtrain = np.concatenate((dtrain, np.copy(dtrain))) +dtrain = np.concatenate((dtrain, np.copy(dtrain))) +print(len(dtrain)) +print ('finish loading from csv ') + +label = dtrain[:,32] +data = dtrain[:,1:31] +# rescale weight to make it same as test set +weight = dtrain[:,31] * float(test_size) / len(label) + +sum_wpos = sum( weight[i] for i in range(len(label)) if label[i] == 1.0 ) +sum_wneg = sum( weight[i] for i in range(len(label)) if label[i] == 0.0 ) + +# print weight statistics +print ('weight statistics: wpos=%g, wneg=%g, ratio=%g' % ( sum_wpos, sum_wneg, sum_wneg/sum_wpos )) + +# construct xgboost.DMatrix from numpy array, treat -999.0 as missing value +xgmat = xgb.DMatrix( data, label=label, missing = -999.0, weight=weight ) + +# setup parameters for xgboost +param = {} +# use logistic regression loss +param['objective'] = 'binary:logitraw' +# scale weight of positive examples +param['scale_pos_weight'] = sum_wneg/sum_wpos +param['bst:eta'] = 0.1 +param['max_depth'] = 16 +param['eval_metric'] = 'auc' +param['silent'] = 1 +param['nthread'] = 4 + +plst = param.items()+[('eval_metric', 'ams@0.15')] + +watchlist = [ (xgmat,'train') ] +num_round = 10 +print ("training xgboost") +threads = [16] +for i in threads: + param['nthread'] = i + tmp = time.time() + plst = param.items()+[('eval_metric', 'ams@0.15')] + bst = xgb.train( plst, xgmat, num_round, watchlist ); + print ("XGBoost with %d thread costs: %s seconds" % (i, str(time.time() - tmp))) + +print ("training xgboost - gpu tree construction") +param['updater'] = 'grow_gpu' +tmp = time.time() +plst = param.items()+[('eval_metric', 'ams@0.15')] +bst = xgb.train( plst, xgmat, num_round, watchlist ); +print ("XGBoost GPU: %s seconds" % (str(time.time() - tmp))) +print ('finish training') diff --git a/plugin/updater_gpu/src/cuda_helpers.cuh b/plugin/updater_gpu/src/cuda_helpers.cuh new file mode 100644 index 000000000..d46814e93 --- /dev/null +++ b/plugin/updater_gpu/src/cuda_helpers.cuh @@ -0,0 +1,276 @@ +/*! 
+ * Copyright 2016 Rory mitchell +*/ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#endif + +#define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__) + +cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) { + if (code != cudaSuccess) { + std::cout << file; + std::cout << line; + std::stringstream ss; + ss << file << "(" << line << ")"; + std::string file_and_line; + ss >> file_and_line; + throw thrust::system_error(code, thrust::cuda_category(), file_and_line); + } + + return code; +} + +#define gpuErrchk(ans) \ + { gpuAssert((ans), __FILE__, __LINE__); } +inline void gpuAssert(cudaError_t code, const char *file, int line, + bool abort = true) { + if (code != cudaSuccess) { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, + line); + if (abort) + exit(code); + } +} + +// Keep track of cub library device allocation +struct CubMemory { + void *d_temp_storage; + size_t temp_storage_bytes; + + CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {} + + ~CubMemory() { + if (d_temp_storage != NULL) { + safe_cuda(cudaFree(d_temp_storage)); + } + } + + void Allocate() { + safe_cuda(cudaMalloc(&d_temp_storage, temp_storage_bytes)); + } + + bool IsAllocated() { return d_temp_storage != NULL; } +}; + +// Utility function: rounds up integer division. +template T div_round_up(const T a, const T b) { + return static_cast(ceil(static_cast(a) / b)); +} + +template thrust::device_ptr dptr(T *d_ptr) { + return thrust::device_pointer_cast(d_ptr); +} + +// #define DEVICE_TIMER +#define MAX_WARPS 32 // Maximum number of warps to time +#define MAX_SLOTS 10 +#define TIMER_BLOCKID 0 // Block to time +struct DeviceTimerGlobal { +#ifdef DEVICE_TIMER + + clock_t total_clocks[MAX_SLOTS][MAX_WARPS]; + int64_t count[MAX_SLOTS][MAX_WARPS]; + +#endif + + // Clear device memory. Call at start of kernel. 
+ __device__ void Init() { +#ifdef DEVICE_TIMER + if (blockIdx.x == TIMER_BLOCKID && threadIdx.x < MAX_WARPS) { + for (int SLOT = 0; SLOT < MAX_SLOTS; SLOT++) { + total_clocks[SLOT][threadIdx.x] = 0; + count[SLOT][threadIdx.x] = 0; + } + } +#endif + } + + void HostPrint() { +#ifdef DEVICE_TIMER + DeviceTimerGlobal h_timer; + safe_cuda( + cudaMemcpyFromSymbol(&h_timer, (*this), sizeof(DeviceTimerGlobal))); + + for (int SLOT = 0; SLOT < MAX_SLOTS; SLOT++) { + if (h_timer.count[SLOT][0] == 0) { + continue; + } + + clock_t sum_clocks = 0; + int64_t sum_count = 0; + + for (int WARP = 0; WARP < MAX_WARPS; WARP++) { + if (h_timer.count[SLOT][WARP] == 0) { + continue; + } + + sum_clocks += h_timer.total_clocks[SLOT][WARP]; + sum_count += h_timer.count[SLOT][WARP]; + } + + printf("Slot %d: %d clocks per call, called %d times.\n", SLOT, + sum_clocks / sum_count, h_timer.count[SLOT][0]); + } +#endif + } +}; + +struct DeviceTimer { +#ifdef DEVICE_TIMER + clock_t start; + int slot; + DeviceTimerGlobal >imer; +#endif + +#ifdef DEVICE_TIMER + __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) // NOLINT + : + GTimer(GTimer), + start(clock()), slot(slot) {} +#else + __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) {} // NOLINT +#endif + + __device__ void End() { +#ifdef DEVICE_TIMER + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + if (blockIdx.x == TIMER_BLOCKID && lane_id == 0) { + GTimer.count[slot][warp_id] += 1; + GTimer.total_clocks[slot][warp_id] += clock() - start; + } +#endif + } +}; + +// #define TIMERS +struct Timer { + volatile double start; + Timer() { reset(); } + + double seconds_now() { +#ifdef _WIN32 + static LARGE_INTEGER s_frequency; + QueryPerformanceFrequency(&s_frequency); + LARGE_INTEGER now; + QueryPerformanceCounter(&now); + return static_cast(now.QuadPart) / s_frequency.QuadPart; +#endif + } + + void reset() { +#ifdef _WIN32 + _ReadWriteBarrier(); + start = seconds_now(); +#endif + } + double elapsed() { +#ifdef _WIN32 + _ReadWriteBarrier(); + return seconds_now() - start; +#endif + } + void printElapsed(char *label) { +#ifdef TIMERS + safe_cuda(cudaDeviceSynchronize()); + printf("%s:\t %1.4fs\n", label, elapsed()); +#endif + } +}; + +template +void print(const thrust::device_vector &v, size_t max_items = 10) { + thrust::host_vector h = v; + for (int i = 0; i < std::min(max_items, h.size()); i++) { + std::cout << " " << h[i]; + } + std::cout << "\n"; +} + +template +void print(char *label, const thrust::device_vector &v, + const char *format = "%d ", int max = 10) { + thrust::host_vector h_v = v; + + std::cout << label << ":\n"; + for (int i = 0; i < std::min(static_cast(h_v.size()), max); i++) { + printf(format, h_v[i]); + } + std::cout << "\n"; +} + +class range { + public: + class iterator { + friend class range; + + public: + __host__ __device__ int64_t operator*() const { return i_; } + __host__ __device__ const iterator &operator++() { + i_ += step_; + return *this; + } + __host__ __device__ iterator operator++(int) { + iterator copy(*this); + i_ += step_; + return copy; + } + + __host__ __device__ bool operator==(const iterator &other) const { + return i_ >= other.i_; + } + __host__ __device__ bool operator!=(const iterator &other) const { + return i_ < other.i_; + } + + __host__ __device__ void step(int s) { step_ = s; } + + protected: + __host__ __device__ explicit iterator(int64_t start) : i_(start) {} + + public: + uint64_t i_; + int step_ = 1; + }; + + __host__ __device__ iterator begin() const { return begin_; } + __host__ 
__device__ iterator end() const { return end_; } + __host__ __device__ range(int64_t begin, int64_t end) + : begin_(begin), end_(end) {} + __host__ __device__ void step(int s) { begin_.step(s); } + + private: + iterator begin_; + iterator end_; +}; + +template __device__ range grid_stride_range(T begin, T end) { + begin += blockDim.x * blockIdx.x + threadIdx.x; + range r(begin, end); + r.step(gridDim.x * blockDim.x); + return r; +} + +template __device__ range block_stride_range(T begin, T end) { + begin += threadIdx.x; + range r(begin, end); + r.step(blockDim.x); + return r; +} + +// Converts device_vector to raw pointer +template T *raw(thrust::device_vector &v) { // NOLINT + return raw_pointer_cast(v.data()); +} diff --git a/plugin/updater_gpu/src/find_split.cuh b/plugin/updater_gpu/src/find_split.cuh new file mode 100644 index 000000000..d1b7958d8 --- /dev/null +++ b/plugin/updater_gpu/src/find_split.cuh @@ -0,0 +1,87 @@ +/*! + * Copyright 2016 Rory mitchell +*/ +#pragma once +#include +#include +#include "cuda_helpers.cuh" +#include "find_split_multiscan.cuh" +#include "find_split_sorting.cuh" +#include "types_functions.cuh" + +namespace xgboost { +namespace tree { + +__global__ void +reduce_split_candidates_kernel(Split *d_split_candidates, Node *d_current_nodes, + Node *d_new_nodes, int n_current_nodes, + int n_features, const GPUTrainingParam param) { + int nid = blockIdx.x * blockDim.x + threadIdx.x; + + if (nid >= n_current_nodes) { + return; + } + + // Find best split for each node + Split best; + + for (int i = 0; i < n_features; i++) { + best.Update(d_split_candidates[n_current_nodes * i + nid]); + } + + // Update current node + d_current_nodes[nid].split = best; + + // Generate new nodes + d_new_nodes[nid * 2] = + Node(best.left_sum, + CalcGain(param, best.left_sum.grad(), best.left_sum.hess()), + CalcWeight(param, best.left_sum.grad(), best.left_sum.hess())); + d_new_nodes[nid * 2 + 1] = + Node(best.right_sum, + CalcGain(param, best.right_sum.grad(), best.right_sum.hess()), + CalcWeight(param, best.right_sum.grad(), best.right_sum.hess())); +} + +void reduce_split_candidates(Split *d_split_candidates, Node *d_nodes, + int level, int n_features, + const GPUTrainingParam param) { + // Current level nodes (before split) + Node *d_current_nodes = d_nodes + (1 << (level)) - 1; + // Next level nodes (after split) + Node *d_new_nodes = d_nodes + (1 << (level + 1)) - 1; + // Number of existing nodes on this level + int n_current_nodes = 1 << level; + + const int BLOCK_THREADS = 512; + const int GRID_SIZE = div_round_up(n_current_nodes, BLOCK_THREADS); + + reduce_split_candidates_kernel<<>>( + d_split_candidates, d_current_nodes, d_new_nodes, n_current_nodes, + n_features, param); + safe_cuda(cudaDeviceSynchronize()); +} + +void find_split(const Item *d_items, Split *d_split_candidates, + const NodeIdT *d_node_id, Node *d_nodes, bst_uint num_items, + int num_features, const int *d_feature_offsets, + gpu_gpair *d_node_sums, int *d_node_offsets, + const GPUTrainingParam param, const int level, + bool multiscan_algorithm) { + if (multiscan_algorithm) { + find_split_candidates_multiscan(d_items, d_split_candidates, d_node_id, + d_nodes, num_items, num_features, + d_feature_offsets, param, level); + } else { + find_split_candidates_sorted(d_items, d_split_candidates, d_node_id, + d_nodes, num_items, num_features, + d_feature_offsets, d_node_sums, d_node_offsets, + param, level); + } + + // Find the best split for each node + reduce_split_candidates(d_split_candidates, d_nodes, level, 
num_features, + param); +} +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/find_split_multiscan.cuh b/plugin/updater_gpu/src/find_split_multiscan.cuh new file mode 100644 index 000000000..fc49946d2 --- /dev/null +++ b/plugin/updater_gpu/src/find_split_multiscan.cuh @@ -0,0 +1,835 @@ +/*! + * Copyright 2016 Rory mitchell +*/ +#pragma once +#include +#include +#include "cuda_helpers.cuh" +#include "types_functions.cuh" + +namespace xgboost { +namespace tree { + +typedef uint64_t BitFlagSet; + +__device__ __inline__ void set_bit(BitFlagSet &bf, int index) { // NOLINT + bf |= 1 << index; +} + +__device__ __inline__ bool check_bit(BitFlagSet bf, int index) { + return (bf >> index) & 1; +} + +// Carryover prefix for scanning multiple tiles of bit flags +struct FlagPrefixCallbackOp { + BitFlagSet tile_carry; + + __device__ FlagPrefixCallbackOp() : tile_carry(0) {} + + __device__ BitFlagSet operator()(BitFlagSet block_aggregate) { + BitFlagSet old_prefix = tile_carry; + tile_carry |= block_aggregate; + return old_prefix; + } +}; + +// Scan op for bit flags that resets if the final bit is set +struct FlagScanOp { + __device__ __forceinline__ BitFlagSet operator()(const BitFlagSet &a, + const BitFlagSet &b) { + if (check_bit(b, 63)) { + return b; + } else { + return a | b; + } + } +}; + +template +struct FindSplitParamsMultiscan { + enum { + BLOCK_THREADS = _BLOCK_THREADS, + TILE_ITEMS = BLOCK_THREADS, + N_NODES = _N_NODES, + N_WARPS = _BLOCK_THREADS / 32, + DEBUG_VALIDATE = _DEBUG_VALIDATE, + ITEMS_PER_THREAD = 1 + }; +}; + +template +struct ReduceParamsMultiscan { + enum { + BLOCK_THREADS = _BLOCK_THREADS, + ITEMS_PER_THREAD = 1, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + N_NODES = _N_NODES, + N_WARPS = _BLOCK_THREADS / 32, + DEBUG_VALIDATE = _DEBUG_VALIDATE + }; +}; + +template struct ReduceEnactorMultiscan { + typedef cub::WarpReduce WarpReduceT; + + struct _TempStorage { + typename WarpReduceT::TempStorage warp_reduce[ParamsT::N_WARPS]; + gpu_gpair partial_sums[ParamsT::N_NODES][ParamsT::N_WARPS]; + }; + + struct TempStorage : cub::Uninitialized<_TempStorage> {}; + + struct _Reduction { + gpu_gpair node_sums[ParamsT::N_NODES]; + }; + + struct Reduction : cub::Uninitialized<_Reduction> {}; + + // Thread local member variables + const Item *d_items; + const NodeIdT *d_node_id; + _TempStorage &temp_storage; + _Reduction &reduction; + gpu_gpair gpair; + NodeIdT node_id; + NodeIdT node_id_adjusted; + const int node_begin; + + __device__ __forceinline__ ReduceEnactorMultiscan( + TempStorage &temp_storage, // NOLINT + Reduction &reduction, // NOLINT + const Item *d_items, const NodeIdT *d_node_id, const int node_begin) + : temp_storage(temp_storage.Alias()), reduction(reduction.Alias()), + d_items(d_items), d_node_id(d_node_id), node_begin(node_begin) {} + + __device__ __forceinline__ void ResetPartials() { + if (threadIdx.x < ParamsT::N_WARPS) { + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + temp_storage.partial_sums[NODE][threadIdx.x] = gpu_gpair(); + } + } + } + + __device__ __forceinline__ void ResetReduction() { + if (threadIdx.x < ParamsT::N_NODES) { + reduction.node_sums[threadIdx.x] = gpu_gpair(); + } + } + + __device__ __forceinline__ void LoadTile(const bst_uint &offset, + const bst_uint &num_remaining) { + if (threadIdx.x < num_remaining) { + gpair = d_items[offset + threadIdx.x].gpair; + node_id = d_node_id[offset + threadIdx.x]; + node_id_adjusted = node_id - node_begin; + } else { + gpair = gpu_gpair(); + node_id = -1; + 
node_id_adjusted = -1; + } + } + + __device__ __forceinline__ void ProcessTile(const bst_uint &offset, + const bst_uint &num_remaining) { + LoadTile(offset, num_remaining); + + // Warp synchronous reduction + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + bool active = node_id_adjusted == NODE; + + unsigned int ballot = __ballot(active); + + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + + if (ballot == 0) { + continue; + } else if (__popc(ballot) == 1) { + if (active) { + temp_storage.partial_sums[NODE][warp_id] += gpair; + } + } else { + gpu_gpair sum = WarpReduceT(temp_storage.warp_reduce[warp_id]) + .Sum(active ? gpair : gpu_gpair()); + if (lane_id == 0) { + temp_storage.partial_sums[NODE][warp_id] += sum; + } + } + } + } + + __device__ __forceinline__ void ReducePartials() { + // Use single warp to reduce partials + if (threadIdx.x < 32) { + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + gpu_gpair sum = + WarpReduceT(temp_storage.warp_reduce[0]) + .Sum(threadIdx.x < ParamsT::N_WARPS + ? temp_storage.partial_sums[NODE][threadIdx.x] + : gpu_gpair()); + + if (threadIdx.x == 0) { + reduction.node_sums[NODE] = sum; + } + } + } + } + + __device__ __forceinline__ void ProcessRegion(const bst_uint &segment_begin, + const bst_uint &segment_end) { + // Current position + bst_uint offset = segment_begin; + + ResetReduction(); + ResetPartials(); + + __syncthreads(); + + // Process full tiles + while (offset < segment_end) { + ProcessTile(offset, segment_end - offset); + offset += ParamsT::TILE_ITEMS; + } + + __syncthreads(); + + ReducePartials(); + + __syncthreads(); + } +}; + +template +struct FindSplitEnactorMultiscan { + typedef cub::BlockScan FlagsBlockScanT; + + typedef cub::WarpReduce WarpSplitReduceT; + + typedef cub::WarpReduce WarpReduceT; + + typedef cub::WarpScan WarpScanT; + + struct _TempStorage { + union { + typename WarpSplitReduceT::TempStorage warp_split_reduce; + typename FlagsBlockScanT::TempStorage flags_scan; + typename WarpScanT::TempStorage warp_gpair_scan[ParamsT::N_WARPS]; + typename WarpReduceT::TempStorage warp_reduce[ParamsT::N_WARPS]; + }; + + Split warp_best_splits[ParamsT::N_NODES][ParamsT::N_WARPS]; + gpu_gpair partial_sums[ParamsT::N_NODES][ParamsT::N_WARPS]; + gpu_gpair top_level_sum[ParamsT::N_NODES]; // Sum of current partial sums + gpu_gpair tile_carry[ParamsT::N_NODES]; // Contains top-level sums from + // previous tiles + Split best_splits[ParamsT::N_NODES]; + // Cache current level nodes into shared memory + float node_root_gain[ParamsT::N_NODES]; + gpu_gpair node_parent_sum[ParamsT::N_NODES]; + }; + + struct TempStorage : cub::Uninitialized<_TempStorage> {}; + + // Thread local member variables + const Item *d_items; + Split *d_split_candidates_out; + const NodeIdT *d_node_id; + const Node *d_nodes; + _TempStorage &temp_storage; + Item item; + NodeIdT node_id; + NodeIdT node_id_adjusted; + const NodeIdT node_begin; + const GPUTrainingParam ¶m; + const ReductionT &reduction; + const int level; + FlagPrefixCallbackOp flag_prefix_op; + + __device__ __forceinline__ FindSplitEnactorMultiscan( + TempStorage &temp_storage, const Item *d_items, // NOLINT + Split *d_split_candidates_out, const NodeIdT *d_node_id, + const Node *d_nodes, const NodeIdT node_begin, + const GPUTrainingParam ¶m, const ReductionT reduction, + const int level) + : temp_storage(temp_storage.Alias()), d_items(d_items), + d_split_candidates_out(d_split_candidates_out), d_node_id(d_node_id), + d_nodes(d_nodes), node_begin(node_begin), param(param), + 
reduction(reduction), level(level), flag_prefix_op() {} + + __device__ __forceinline__ void UpdateTileCarry() { + if (threadIdx.x < ParamsT::N_NODES) { + temp_storage.tile_carry[threadIdx.x] += + temp_storage.top_level_sum[threadIdx.x]; + } + } + + __device__ __forceinline__ void ResetTileCarry() { + if (threadIdx.x < ParamsT::N_NODES) { + temp_storage.tile_carry[threadIdx.x] = gpu_gpair(); + } + } + + __device__ __forceinline__ void ResetPartials() { + if (threadIdx.x < ParamsT::N_WARPS) { + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + temp_storage.partial_sums[NODE][threadIdx.x] = gpu_gpair(); + } + } + + if (threadIdx.x < ParamsT::N_NODES) { + temp_storage.top_level_sum[threadIdx.x] = gpu_gpair(); + } + } + + __device__ __forceinline__ void ResetSplits() { + if (threadIdx.x < ParamsT::N_WARPS) { + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + temp_storage.warp_best_splits[NODE][threadIdx.x] = Split(); + } + } + + if (threadIdx.x < ParamsT::N_NODES) { + temp_storage.best_splits[threadIdx.x] = Split(); + } + } + + // Cache d_nodes array for this level into shared memory + __device__ __forceinline__ void CacheNodes() { + // Get pointer to nodes on the current level + const Node *d_nodes_level = d_nodes + node_begin; + + if (threadIdx.x < ParamsT::N_NODES) { + temp_storage.node_root_gain[threadIdx.x] = + d_nodes_level[threadIdx.x].root_gain; + temp_storage.node_parent_sum[threadIdx.x] = + d_nodes_level[threadIdx.x].sum_gradients; + } + } + + __device__ __forceinline__ void LoadTile(bst_uint offset, + bst_uint num_remaining) { + bst_uint index = offset + threadIdx.x; + if (threadIdx.x < num_remaining) { + item = d_items[index]; + node_id = d_node_id[index]; + node_id_adjusted = node_id - node_begin; + } else { + node_id = -1; + node_id_adjusted = -1; + item.fvalue = -FLT_MAX; + item.gpair = gpu_gpair(); + } + } + + // Is this node being processed by current kernel iteration? + __device__ __forceinline__ bool NodeActive() { + return node_id_adjusted < ParamsT::N_NODES && node_id_adjusted >= 0; + } + + // Is current fvalue different from left fvalue + __device__ __forceinline__ bool + LeftMostFvalue(const bst_uint &segment_begin, const bst_uint &offset, + const bst_uint &num_remaining) { + int left_index = offset + threadIdx.x - 1; + float left_fvalue = left_index >= static_cast(segment_begin) && + threadIdx.x < num_remaining + ? d_items[left_index].fvalue + : -FLT_MAX; + + return left_fvalue != item.fvalue; + } + + // Prevent splitting in the middle of same valued instances + __device__ __forceinline__ bool + CheckSplitValid(const bst_uint &segment_begin, const bst_uint &offset, + const bst_uint &num_remaining) { + BitFlagSet bit_flag = 0; + + bool valid_split = false; + + if (LeftMostFvalue(segment_begin, offset, num_remaining)) { + valid_split = true; + // Use MSB bit to flag if fvalue is leftmost + set_bit(bit_flag, 63); + } + + // Flag nodeid + if (NodeActive()) { + set_bit(bit_flag, node_id_adjusted); + } + + FlagsBlockScanT(temp_storage.flags_scan) + .ExclusiveScan(bit_flag, bit_flag, FlagScanOp(), flag_prefix_op); + __syncthreads(); + + if (!valid_split && NodeActive()) { + if (!check_bit(bit_flag, node_id_adjusted)) { + valid_split = true; + } + } + + return valid_split; + } + + // Perform warp reduction to find if this lane contains best loss_chg in warp + __device__ __forceinline__ bool QueryLaneBestLoss(const float &loss_chg) { + int lane_id = threadIdx.x % 32; + int warp_id = threadIdx.x / 32; + + // Possible source of bugs. 
Not all threads in warp are active here. Not + // sure if reduce function will behave correctly + float best = WarpReduceT(temp_storage.warp_reduce[warp_id]) + .Reduce(loss_chg, cub::Max()); + + // Its possible for more than one lane to contain the best value, so make + // sure only one lane returns true + unsigned int ballot = __ballot(loss_chg == best); + + if (lane_id == (__ffs(ballot) - 1)) { + return true; + } else { + return false; + } + } + + // Which thread in this warp should update the current best split, if any + // Returns true for one thread or none + __device__ __forceinline__ bool + QueryUpdateWarpSplit(const float &loss_chg, + volatile const float &warp_best_loss) { + bool update = false; + + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + bool active = node_id_adjusted == NODE; + + unsigned int ballot = __ballot(loss_chg > warp_best_loss && active); + + // No lane has improved loss_chg + if (__popc(ballot) == 0) { + continue; + } else if (__popc(ballot) == 1) { + // A single lane has improved loss_chg, set true for this lane + int lane_id = threadIdx.x % 32; + + if (lane_id == __ffs(ballot) - 1) { + update = true; + } + } else { + // More than one lane has improved loss_chg, perform a reduction. + if (QueryLaneBestLoss(active ? loss_chg : -FLT_MAX)) { + update = true; + } + } + } + + return update; + } + + __device__ void PrintTileScan(int block_id, bool thread_active, + float loss_chg, gpu_gpair missing) { + if (blockIdx.x != block_id) { + return; + } + + for (int warp = 0; warp < ParamsT::N_WARPS; warp++) { + if (threadIdx.x / 32 == warp) { + for (int lane = 0; lane < 32; lane++) { + gpu_gpair g = cub::ShuffleIndex(item.gpair, lane); + gpu_gpair missing_broadcast = cub::ShuffleIndex(missing, lane); + float fvalue_broadcast = __shfl(item.fvalue, lane); + bool thread_active_broadcast = __shfl(thread_active, lane); + float loss_chg_broadcast = __shfl(loss_chg, lane); + NodeIdT node_id_broadcast = cub::ShuffleIndex(node_id, lane); + if (threadIdx.x == 32 * warp) { + printf("tid %d, nid %d, fvalue %1.2f, active %c, loss %1.2f, scan ", + threadIdx.x + lane, node_id_broadcast, fvalue_broadcast, + thread_active_broadcast ? 'y' : 'n', + loss_chg_broadcast < 0.0f ? 0 : loss_chg_broadcast); + g.print(); + } + } + } + + __syncthreads(); + } + } + + __device__ __forceinline__ void + EvaluateSplits(const bst_uint &segment_begin, const bst_uint &offset, + const bst_uint &num_remaining) { + bool valid_split = CheckSplitValid(segment_begin, offset, num_remaining); + + bool thread_active = + NodeActive() && valid_split && threadIdx.x < num_remaining; + + const int warp_id = threadIdx.x / 32; + + gpu_gpair parent_sum = thread_active + ? temp_storage.node_parent_sum[node_id_adjusted] + : gpu_gpair(); + float parent_gain = + thread_active ? temp_storage.node_root_gain[node_id_adjusted] : 0.0f; + gpu_gpair missing = thread_active + ? parent_sum - reduction.node_sums[node_id_adjusted] + : gpu_gpair(); + + bool missing_left; + + float loss_chg = thread_active + ? loss_chg_missing(item.gpair, missing, parent_sum, + parent_gain, param, missing_left) + : -FLT_MAX; + + // PrintTileScan(64, thread_active, loss_chg, missing); + + float warp_best_loss = + thread_active + ? 
temp_storage.warp_best_splits[node_id_adjusted][warp_id].loss_chg + : 0.0f; + + if (QueryUpdateWarpSplit(loss_chg, warp_best_loss)) { + float fvalue_split = item.fvalue - FVALUE_EPS; + + if (missing_left) { + gpu_gpair left_sum = missing + item.gpair; + gpu_gpair right_sum = parent_sum - left_sum; + temp_storage.warp_best_splits[node_id_adjusted][warp_id].Update( + loss_chg, missing_left, fvalue_split, blockIdx.x, left_sum, + right_sum, param); + } else { + gpu_gpair left_sum = item.gpair; + gpu_gpair right_sum = parent_sum - left_sum; + temp_storage.warp_best_splits[node_id_adjusted][warp_id].Update( + loss_chg, missing_left, fvalue_split, blockIdx.x, left_sum, + right_sum, param); + } + } + } + + /* + __device__ __forceinline__ void WarpExclusiveScan(bool active, gpu_gpair + input, gpu_gpair &output, gpu_gpair &sum) + { + + output = input; + + for (int offset = 1; offset < 32; offset <<= 1){ + float tmp1 = __shfl_up(output.grad(), offset); + + float tmp2 = __shfl_up(output.hess(), offset); + if (cub::LaneId() >= offset) + { + output.grad += tmp1; + output.hess += tmp2; + } + } + + sum.grad = __shfl(output.grad, 31); + sum.hess = __shfl(output.hess, 31); + + output -= input; + } + */ + __device__ __forceinline__ void BlockExclusiveScan() { + ResetPartials(); + + __syncthreads(); + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + bool node_active = node_id_adjusted == NODE; + + unsigned int ballot = __ballot(node_active); + + gpu_gpair warp_sum = gpu_gpair(); + gpu_gpair scan_result = gpu_gpair(); + + if (ballot > 0) { + WarpScanT(temp_storage.warp_gpair_scan[warp_id]) + .InclusiveScan(node_active ? item.gpair : gpu_gpair(), scan_result, + cub::Sum(), warp_sum); + // WarpExclusiveScan( node_active, node_active ? item.gpair : + // gpu_gpair(), scan_result, warp_sum); + } + + if (node_active) { + item.gpair = scan_result - item.gpair; + } + + if (lane_id == 0) { + temp_storage.partial_sums[NODE][warp_id] = warp_sum; + } + } + + __syncthreads(); + + if (threadIdx.x < 32) { + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + gpu_gpair top_level_sum; + bool warp_active = threadIdx.x < ParamsT::N_WARPS; + gpu_gpair scan_result; + WarpScanT(temp_storage.warp_gpair_scan[warp_id]) + .InclusiveScan(warp_active + ? 
temp_storage.partial_sums[NODE][threadIdx.x] + : gpu_gpair(), + scan_result, cub::Sum(), top_level_sum); + + if (warp_active) { + temp_storage.partial_sums[NODE][threadIdx.x] = + scan_result - temp_storage.partial_sums[NODE][threadIdx.x]; + } + + if (threadIdx.x == 0) { + temp_storage.top_level_sum[NODE] = top_level_sum; + } + } + } + + __syncthreads(); + + if (NodeActive()) { + item.gpair += temp_storage.partial_sums[node_id_adjusted][warp_id] + + temp_storage.tile_carry[node_id_adjusted]; + } + + __syncthreads(); + + UpdateTileCarry(); + + __syncthreads(); + } + + __device__ __forceinline__ void ProcessTile(const bst_uint &segment_begin, + const bst_uint &offset, + const bst_uint &num_remaining) { + LoadTile(offset, num_remaining); + BlockExclusiveScan(); + EvaluateSplits(segment_begin, offset, num_remaining); + } + + __device__ __forceinline__ void ReduceSplits() { + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + if (threadIdx.x < 32) { + Split s = Split(); + if (threadIdx.x < ParamsT::N_WARPS) { + s = temp_storage.warp_best_splits[NODE][threadIdx.x]; + } + Split best = WarpSplitReduceT(temp_storage.warp_split_reduce) + .Reduce(s, split_reduce_op()); + if (threadIdx.x == 0) { + temp_storage.best_splits[NODE] = best; + } + } + } + } + + __device__ __forceinline__ void WriteBestSplits() { + const int nodes_level = 1 << level; + + if (threadIdx.x < ParamsT::N_NODES) { + d_split_candidates_out[blockIdx.x * nodes_level + threadIdx.x] = + temp_storage.best_splits[threadIdx.x]; + } + } + + /* + __device__ void SequentialAlgorithm(bst_uint segment_begin, + bst_uint segment_end) { + if (threadIdx.x != 0) { + return; + } + + __shared__ Split best_split[ParamsT::N_NODES]; + + __shared__ gpu_gpair scan[ParamsT::N_NODES]; + + __shared__ Node nodes[ParamsT::N_NODES]; + + __shared__ gpu_gpair missing[ParamsT::N_NODES]; + + float previous_fvalue[ParamsT::N_NODES]; + + // Initialise counts + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + best_split[NODE] = Split(); + scan[NODE] = gpu_gpair(); + nodes[NODE] = d_nodes[node_begin + NODE]; + missing[NODE] = nodes[NODE].sum_gradients - reduction.node_sums[NODE]; + previous_fvalue[NODE] = FLT_MAX; + } + + for (bst_uint i = segment_begin; i < segment_end; i++) { + int8_t nodeid_adjusted = d_node_id[i] - node_begin; + float fvalue = d_items[i].fvalue; + + if (NodeActive(nodeid_adjusted)) { + if (fvalue != previous_fvalue[nodeid_adjusted]) { + float f_split; + if (previous_fvalue[nodeid_adjusted] != FLT_MAX) { + f_split = (previous_fvalue[nodeid_adjusted] + fvalue) * 0.5; + } else { + f_split = fvalue; + } + + best_split[nodeid_adjusted].UpdateCalcLoss( + f_split, scan[nodeid_adjusted], missing[nodeid_adjusted], + nodes[nodeid_adjusted], param); + } + + scan[nodeid_adjusted] += d_items[i].gpair; + previous_fvalue[nodeid_adjusted] = fvalue; + } + } + + for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { + temp_storage.best_splits[NODE] = best_split[NODE]; + } + } + */ + + __device__ __forceinline__ void ResetSplitCandidates() { + const int max_nodes = 1 << level; + const int begin = blockIdx.x * max_nodes; + const int end = begin + max_nodes; + + for (auto i : block_stride_range(begin, end)) { + d_split_candidates_out[i] = Split(); + } + } + + __device__ __forceinline__ void ProcessRegion(const bst_uint &segment_begin, + const bst_uint &segment_end) { + // Current position + bst_uint offset = segment_begin; + + ResetSplitCandidates(); + ResetTileCarry(); + ResetSplits(); + CacheNodes(); + __syncthreads(); + + // Process full tiles + while 
(offset < segment_end) { + ProcessTile(segment_begin, offset, segment_end - offset); + __syncthreads(); + offset += ParamsT::TILE_ITEMS; + } + + __syncthreads(); + ReduceSplits(); + + __syncthreads(); + WriteBestSplits(); + } +}; + +template +__global__ void +#if __CUDA_ARCH__ <= 530 +__launch_bounds__(1024, 2) +#endif + find_split_candidates_multiscan_kernel( + const Item *d_items, Split *d_split_candidates_out, + const NodeIdT *d_node_id, const Node *d_nodes, const int node_begin, + bst_uint num_items, int num_features, const int *d_feature_offsets, + const GPUTrainingParam param, const int level) { + if (num_items <= 0) { + return; + } + + int segment_begin = d_feature_offsets[blockIdx.x]; + int segment_end = d_feature_offsets[blockIdx.x + 1]; + + typedef ReduceEnactorMultiscan ReduceT; + typedef FindSplitEnactorMultiscan + FindSplitT; + + __shared__ union { + typename ReduceT::TempStorage reduce; + typename FindSplitT::TempStorage find_split; + } temp_storage; + + __shared__ typename ReduceT::Reduction reduction; + + ReduceT(temp_storage.reduce, reduction, d_items, d_node_id, node_begin) + .ProcessRegion(segment_begin, segment_end); + __syncthreads(); + + FindSplitT find_split(temp_storage.find_split, d_items, + d_split_candidates_out, d_node_id, d_nodes, node_begin, + param, reduction.Alias(), level); + find_split.ProcessRegion(segment_begin, segment_end); +} + +template +void find_split_candidates_multiscan_variation( + const Item *d_items, Split *d_split_candidates, const NodeIdT *d_node_id, + const Node *d_nodes, int node_begin, int node_end, bst_uint num_items, + int num_features, const int *d_feature_offsets, + const GPUTrainingParam param, const int level) { + + const int BLOCK_THREADS = 512; + + CHECK((node_end - node_begin) <= N_NODES) << "Multiscan: N_NODES template " + "parameter too small for given " + "node range."; + CHECK(BLOCK_THREADS / 32 < 32) + << "Too many active warps. 
See FindSplitEnactor - ReduceSplits."; + + typedef FindSplitParamsMultiscan + find_split_params; + typedef ReduceParamsMultiscan reduce_params; + int grid_size = num_features; + + find_split_candidates_multiscan_kernel< + find_split_params, + reduce_params><<>>( + d_items, d_split_candidates, d_node_id, d_nodes, node_begin, num_items, + num_features, d_feature_offsets, param, level); + + safe_cuda(cudaDeviceSynchronize()); +} + +void find_split_candidates_multiscan( + const Item *d_items, Split *d_split_candidates, const NodeIdT *d_node_id, + const Node *d_nodes, bst_uint num_items, int num_features, + const int *d_feature_offsets, const GPUTrainingParam param, + const int level) { + // Select templated variation of split finding algorithm + switch (level) { + case 0: + find_split_candidates_multiscan_variation<1>( + d_items, d_split_candidates, d_node_id, d_nodes, 0, 1, num_items, + num_features, d_feature_offsets, param, level); + break; + case 1: + find_split_candidates_multiscan_variation<2>( + d_items, d_split_candidates, d_node_id, d_nodes, 1, 3, num_items, + num_features, d_feature_offsets, param, level); + break; + case 2: + find_split_candidates_multiscan_variation<4>( + d_items, d_split_candidates, d_node_id, d_nodes, 3, 7, num_items, + num_features, d_feature_offsets, param, level); + break; + case 3: + find_split_candidates_multiscan_variation<8>( + d_items, d_split_candidates, d_node_id, d_nodes, 7, 15, num_items, + num_features, d_feature_offsets, param, level); + break; + case 4: + find_split_candidates_multiscan_variation<16>( + d_items, d_split_candidates, d_node_id, d_nodes, 15, 31, num_items, + num_features, d_feature_offsets, param, level); + break; + case 5: + find_split_candidates_multiscan_variation<32>( + d_items, d_split_candidates, d_node_id, d_nodes, 31, 63, num_items, + num_features, d_feature_offsets, param, level); + break; + } +} +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/find_split_sorting.cuh b/plugin/updater_gpu/src/find_split_sorting.cuh new file mode 100644 index 000000000..0cf422e45 --- /dev/null +++ b/plugin/updater_gpu/src/find_split_sorting.cuh @@ -0,0 +1,474 @@ +/*! 
+ * Copyright 2016 Rory mitchell +*/ +#pragma once +#include +#include +#include "cuda_helpers.cuh" +#include "types_functions.cuh" + +namespace xgboost { +namespace tree { + +struct ScanTuple { + gpu_gpair gpair; + NodeIdT node_id; + + __device__ ScanTuple() {} + + __device__ ScanTuple(gpu_gpair gpair, NodeIdT node_id) + : gpair(gpair), node_id(node_id) {} + + __device__ ScanTuple operator+=(const ScanTuple &rhs) { + if (node_id != rhs.node_id) { + *this = rhs; + return *this; + } else { + gpair += rhs.gpair; + return *this; + } + } + __device__ ScanTuple operator+(const ScanTuple &rhs) const { + ScanTuple t = *this; + return t += rhs; + } +}; + +struct GpairTupleCallbackOp { + // Running prefix + ScanTuple running_total; + // Constructor + __device__ GpairTupleCallbackOp() + : running_total(ScanTuple(gpu_gpair(), -1)) {} + __device__ ScanTuple operator()(ScanTuple block_aggregate) { + ScanTuple old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + +struct GpairCallbackOp { + // Running prefix + gpu_gpair running_total; + // Constructor + __device__ GpairCallbackOp() : running_total(gpu_gpair()) {} + __device__ gpu_gpair operator()(gpu_gpair block_aggregate) { + gpu_gpair old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + +template +struct FindSplitParamsSorting { + enum { + BLOCK_THREADS = _BLOCK_THREADS, + TILE_ITEMS = BLOCK_THREADS, + N_WARPS = _BLOCK_THREADS / 32, + DEBUG_VALIDATE = _DEBUG_VALIDATE, + ITEMS_PER_THREAD = 1 + }; +}; + +template struct ReduceParamsSorting { + enum { + BLOCK_THREADS = _BLOCK_THREADS, + ITEMS_PER_THREAD = 1, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + N_WARPS = _BLOCK_THREADS / 32, + DEBUG_VALIDATE = _DEBUG_VALIDATE + }; +}; + +template struct ReduceEnactorSorting { + typedef cub::BlockScan GpairScanT; + + struct _TempStorage { + typename GpairScanT::TempStorage gpair_scan; + }; + + struct TempStorage : cub::Uninitialized<_TempStorage> {}; + + // Thread local member variables + gpu_gpair *d_block_node_sums; + int *d_block_node_offsets; + const NodeIdT *d_node_id; + const Item *d_items; + _TempStorage &temp_storage; + Item item; + NodeIdT node_id; + NodeIdT right_node_id; + // Contains node_id relative to the current level only + NodeIdT node_id_adjusted; + GpairTupleCallbackOp callback_op; + const int level; + + __device__ __forceinline__ ReduceEnactorSorting( + TempStorage &temp_storage, // NOLINT + gpu_gpair *d_block_node_sums, int *d_block_node_offsets, + const Item *d_items, const NodeIdT *d_node_id, const int level) + : temp_storage(temp_storage.Alias()), + d_block_node_sums(d_block_node_sums), + d_block_node_offsets(d_block_node_offsets), d_items(d_items), + d_node_id(d_node_id), callback_op(), level(level) {} + + __device__ __forceinline__ void ResetSumsOffsets() { + const int max_nodes = 1 << level; + + for (auto i : block_stride_range(0, max_nodes)) { + d_block_node_sums[i] = gpu_gpair(); + d_block_node_offsets[i] = -1; + } + } + + __device__ __forceinline__ void LoadTile(const bst_uint &offset, + const bst_uint &num_remaining) { + if (threadIdx.x < num_remaining) { + item = d_items[offset + threadIdx.x]; + node_id = d_node_id[offset + threadIdx.x]; + right_node_id = threadIdx.x == num_remaining - 1 + ? 
-1 + : d_node_id[offset + threadIdx.x + 1]; + // Prevent overflow + const int level_begin = (1 << level) - 1; + node_id_adjusted = + max(static_cast(node_id) - level_begin, -1); // NOLINT + } + } + + __device__ __forceinline__ void ProcessTile(const bst_uint &offset, + const bst_uint &num_remaining) { + LoadTile(offset, num_remaining); + + ScanTuple t(item.gpair, node_id); + GpairScanT(temp_storage.gpair_scan).InclusiveSum(t, t, callback_op); + __syncthreads(); + + // If tail of segment + if (node_id != right_node_id && node_id_adjusted >= 0 && + threadIdx.x < num_remaining) { + // Write sum + d_block_node_sums[node_id_adjusted] = t.gpair; + // Write offset + d_block_node_offsets[node_id_adjusted] = offset + threadIdx.x + 1; + } + } + + __device__ __forceinline__ void ProcessRegion(const bst_uint &segment_begin, + const bst_uint &segment_end) { + // Current position + bst_uint offset = segment_begin; + + ResetSumsOffsets(); + + __syncthreads(); + + // Process full tiles + while (offset < segment_end) { + ProcessTile(offset, segment_end - offset); + offset += ParamsT::TILE_ITEMS; + } + } +}; + +template struct FindSplitEnactorSorting { + typedef cub::BlockScan GpairScanT; + typedef cub::BlockReduce SplitReduceT; + typedef cub::WarpReduce WarpLossReduceT; + + struct _TempStorage { + union { + typename GpairScanT::TempStorage gpair_scan; + typename SplitReduceT::TempStorage split_reduce; + typename WarpLossReduceT::TempStorage loss_reduce[ParamsT::N_WARPS]; + }; + Split warp_best_splits[ParamsT::N_WARPS]; + }; + + struct TempStorage : cub::Uninitialized<_TempStorage> {}; + + // Thread local member variables + _TempStorage &temp_storage; + gpu_gpair *d_block_node_sums; + int *d_block_node_offsets; + const Item *d_items; + const NodeIdT *d_node_id; + const Node *d_nodes; + Item item; + NodeIdT node_id; + float left_fvalue; + const GPUTrainingParam ¶m; + Split *d_split_candidates_out; + const int level; + + __device__ __forceinline__ FindSplitEnactorSorting( + TempStorage &temp_storage, gpu_gpair *d_block_node_sums, // NOLINT + int *d_block_node_offsets, const Item *d_items, const NodeIdT *d_node_id, + const Node *d_nodes, const GPUTrainingParam ¶m, + Split *d_split_candidates_out, const int level) + : temp_storage(temp_storage.Alias()), + d_block_node_sums(d_block_node_sums), + d_block_node_offsets(d_block_node_offsets), d_items(d_items), + d_node_id(d_node_id), d_nodes(d_nodes), + d_split_candidates_out(d_split_candidates_out), level(level), + param(param) {} + + __device__ __forceinline__ void LoadTile(NodeIdT node_id_adjusted, + const bst_uint &node_begin, + const bst_uint &offset, + const bst_uint &num_remaining) { + if (threadIdx.x < num_remaining) { + node_id = d_node_id[offset + threadIdx.x]; + + item = d_items[offset + threadIdx.x]; + bool first_item = offset + threadIdx.x == node_begin; + left_fvalue = first_item ? 
item.fvalue - FVALUE_EPS + : d_items[offset + threadIdx.x - 1].fvalue; + } + } + + __device__ void PrintTileScan(int block_id, bool thread_active, + float loss_chg, gpu_gpair missing) { + if (blockIdx.x != block_id) { + return; + } + + for (int warp = 0; warp < ParamsT::N_WARPS; warp++) { + if (threadIdx.x / 32 == warp) { + for (int lane = 0; lane < 32; lane++) { + gpu_gpair g = cub::ShuffleIndex(item.gpair, lane); + gpu_gpair missing_broadcast = cub::ShuffleIndex(missing, lane); + float fvalue_broadcast = __shfl(item.fvalue, lane); + bool thread_active_broadcast = __shfl(thread_active, lane); + float loss_chg_broadcast = __shfl(loss_chg, lane); + if (threadIdx.x == 32 * warp) { + printf("tid %d, fvalue %1.2f, active %c, loss %1.2f, scan ", + threadIdx.x + lane, fvalue_broadcast, + thread_active_broadcast ? 'y' : 'n', + loss_chg_broadcast < 0.0f ? 0 : loss_chg_broadcast); + g.print(); + } + } + } + + __syncthreads(); + } + } + + __device__ __forceinline__ bool QueryUpdateWarpSplit(float loss_chg, + float warp_best_loss, + bool thread_active) { + int warp_id = threadIdx.x / 32; + int ballot = __ballot(loss_chg > warp_best_loss && thread_active); + if (ballot == 0) { + return false; + } else { + // Warp reduce best loss + float best = WarpLossReduceT(temp_storage.loss_reduce[warp_id]) + .Reduce(loss_chg, cub::Max()); + // Broadcast + best = cub::ShuffleIndex(best, 0); + + if (loss_chg == best) { + return true; + } + } + + return false; + } + + __device__ __forceinline__ bool LeftmostFvalue() { + return item.fvalue != left_fvalue; + } + + __device__ __forceinline__ void + EvaluateSplits(const NodeIdT &node_id_adjusted, const bst_uint &node_begin, + const bst_uint &offset, const bst_uint &num_remaining) { + bool thread_active = LeftmostFvalue() && threadIdx.x < num_remaining && + node_id_adjusted >= 0 && node_id >= 0; + + Node n = thread_active ? d_nodes[node_id] : Node(); + gpu_gpair missing = + thread_active ? n.sum_gradients - d_block_node_sums[node_id_adjusted] + : gpu_gpair(); + + bool missing_left; + float loss_chg = + thread_active ? loss_chg_missing(item.gpair, missing, n.sum_gradients, + n.root_gain, param, missing_left) + : -FLT_MAX; + + int warp_id = threadIdx.x / 32; + volatile float warp_best_loss = + temp_storage.warp_best_splits[warp_id].loss_chg; + + if (QueryUpdateWarpSplit(loss_chg, warp_best_loss, thread_active)) { + float fvalue_split = (item.fvalue + left_fvalue) / 2.0f; + + gpu_gpair left_sum = item.gpair; + if (missing_left) { + left_sum += missing; + } + gpu_gpair right_sum = n.sum_gradients - left_sum; + temp_storage.warp_best_splits[warp_id].Update(loss_chg, missing_left, + fvalue_split, blockIdx.x, + left_sum, right_sum, param); + } + } + + __device__ __forceinline__ void + ProcessTile(const NodeIdT &node_id_adjusted, const bst_uint &node_begin, + const bst_uint &offset, const bst_uint &num_remaining, + GpairCallbackOp &callback_op) { // NOLINT + LoadTile(node_id_adjusted, node_begin, offset, num_remaining); + + // Scan gpair + const bool thread_active = threadIdx.x < num_remaining && node_id >= 0; + GpairScanT(temp_storage.gpair_scan) + .ExclusiveSum(thread_active ? 
item.gpair : gpu_gpair(), item.gpair, + callback_op); + __syncthreads(); + // Evaluate split + EvaluateSplits(node_id_adjusted, node_begin, offset, num_remaining); + } + + __device__ __forceinline__ void ResetWarpSplits() { + if (threadIdx.x < ParamsT::N_WARPS) { + temp_storage.warp_best_splits[threadIdx.x] = Split(); + } + } + + __device__ __forceinline__ void + WriteBestSplit(const NodeIdT &node_id_adjusted) { + if (threadIdx.x < 32) { + bool active = threadIdx.x < ParamsT::N_WARPS; + float warp_loss = + active ? temp_storage.warp_best_splits[threadIdx.x].loss_chg + : -FLT_MAX; + if (QueryUpdateWarpSplit(warp_loss, 0, active)) { + const int max_nodes = 1 << level; + d_split_candidates_out[blockIdx.x * max_nodes + node_id_adjusted] = + temp_storage.warp_best_splits[threadIdx.x]; + } + } + } + + __device__ __forceinline__ void ProcessNode(const NodeIdT &node_id_adjusted, + const bst_uint &node_begin, + const bst_uint &node_end) { + ResetWarpSplits(); + + GpairCallbackOp callback_op = GpairCallbackOp(); + + bst_uint offset = node_begin; + + while (offset < node_end) { + ProcessTile(node_id_adjusted, node_begin, offset, node_end - offset, + callback_op); + offset += ParamsT::TILE_ITEMS; + __syncthreads(); + } + + WriteBestSplit(node_id_adjusted); + } + + __device__ __forceinline__ void ResetSplitCandidates() { + const int max_nodes = 1 << level; + const int begin = blockIdx.x * max_nodes; + const int end = begin + max_nodes; + + for (auto i : block_stride_range(begin, end)) { + d_split_candidates_out[i] = Split(); + } + } + + __device__ __forceinline__ void ProcessFeature(const bst_uint &segment_begin, + const bst_uint &segment_end) { + ResetSplitCandidates(); + + int node_begin = segment_begin; + + const int max_nodes = 1 << level; + + // Iterate through nodes + int active_nodes = 0; + for (int i = 0; i < max_nodes; i++) { + int node_end = d_block_node_offsets[i]; + + if (node_end == -1) { + continue; + } + + active_nodes++; + + ProcessNode(i, node_begin, node_end); + + __syncthreads(); + + node_begin = node_end; + } + } +}; + +template +__global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel( + const Item *d_items, Split *d_split_candidates_out, + const NodeIdT *d_node_id, const Node *d_nodes, bst_uint num_items, + const int num_features, const int *d_feature_offsets, + gpu_gpair *d_node_sums, int *d_node_offsets, const GPUTrainingParam param, + const int level) { + + if (num_items <= 0) { + return; + } + + bst_uint segment_begin = d_feature_offsets[blockIdx.x]; + bst_uint segment_end = d_feature_offsets[blockIdx.x + 1]; + + typedef ReduceEnactorSorting ReduceT; + typedef FindSplitEnactorSorting FindSplitT; + + __shared__ union { + typename ReduceT::TempStorage reduce; + typename FindSplitT::TempStorage find_split; + } temp_storage; + + + const int max_modes_level = 1 << level; + gpu_gpair *d_block_node_sums = d_node_sums + blockIdx.x * max_modes_level; + int *d_block_node_offsets = d_node_offsets + blockIdx.x * max_modes_level; + + ReduceT(temp_storage.reduce, d_block_node_sums, d_block_node_offsets, d_items, + d_node_id, level) + .ProcessRegion(segment_begin, segment_end); + __syncthreads(); + + FindSplitT(temp_storage.find_split, d_block_node_sums, d_block_node_offsets, + d_items, d_node_id, d_nodes, param, d_split_candidates_out, level) + .ProcessFeature(segment_begin, segment_end); +} + +void find_split_candidates_sorted( + const Item *d_items, Split *d_split_candidates, const NodeIdT *d_node_id, + Node *d_nodes, bst_uint num_items, int num_features, + const int 
*d_feature_offsets, gpu_gpair *d_node_sums, int *d_node_offsets, + const GPUTrainingParam param, const int level) { + + const int BLOCK_THREADS = 512; + + CHECK(BLOCK_THREADS / 32 < 32) << "Too many active warps."; + + typedef FindSplitParamsSorting find_split_params; + typedef ReduceParamsSorting reduce_params; + int grid_size = num_features; + + find_split_candidates_sorted_kernel< + reduce_params, find_split_params><<>>( + d_items, d_split_candidates, d_node_id, d_nodes, num_items, num_features, + d_feature_offsets, d_node_sums, d_node_offsets, param, level); + + safe_cuda(cudaGetLastError()); + safe_cuda(cudaDeviceSynchronize()); +} +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/gpu_builder.cu b/plugin/updater_gpu/src/gpu_builder.cu new file mode 100644 index 000000000..3e8583942 --- /dev/null +++ b/plugin/updater_gpu/src/gpu_builder.cu @@ -0,0 +1,457 @@ +/*! + * Copyright 2016 Rory mitchell +*/ +#include "gpu_builder.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "cuda_helpers.cuh" +#include "find_split.cuh" +#include "types_functions.cuh" + +namespace xgboost { +namespace tree { +struct GPUData { + GPUData() : allocated(false), n_features(0), n_instances(0) {} + + bool allocated; + int n_features; + int n_instances; + + GPUTrainingParam param; + + CubMemory cub_mem; + + thrust::device_vector fvalues; + thrust::device_vector foffsets; + thrust::device_vector instance_id; + thrust::device_vector feature_id; + thrust::device_vector node_id; + thrust::device_vector node_id_temp; + thrust::device_vector node_id_instance; + thrust::device_vector node_id_instance_temp; + thrust::device_vector gpair; + thrust::device_vector nodes; + thrust::device_vector split_candidates; + + thrust::device_vector items; + thrust::device_vector items_temp; + + thrust::device_vector node_sums; + thrust::device_vector node_offsets; + thrust::device_vector sort_index_in; + thrust::device_vector sort_index_out; + + void Init(const std::vector &in_fvalues, + const std::vector &in_foffsets, + const std::vector &in_instance_id, + const std::vector &in_feature_id, + const std::vector &in_gpair, bst_uint n_instances_in, + bst_uint n_features_in, int max_depth, const TrainParam ¶m_in) { + Timer t; + n_features = n_features_in; + n_instances = n_instances_in; + + fvalues = in_fvalues; + foffsets = in_foffsets; + instance_id = in_instance_id; + feature_id = in_feature_id; + + param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda, + param_in.reg_alpha, param_in.max_delta_step); + + gpair = thrust::device_vector(in_gpair.begin(), in_gpair.end()); + + uint32_t max_nodes_level = 1 << max_depth; + + node_sums = thrust::device_vector(max_nodes_level * n_features); + node_offsets = thrust::device_vector(max_nodes_level * n_features); + + node_id_instance = thrust::device_vector(n_instances, 0); + + node_id = thrust::device_vector(fvalues.size(), 0); + node_id_temp = thrust::device_vector(fvalues.size()); + + uint32_t max_nodes = (1 << (max_depth + 1)) - 1; + nodes = thrust::device_vector(max_nodes); + + split_candidates = + thrust::device_vector(max_nodes_level * n_features); + allocated = true; + + // Init items + items = thrust::device_vector(fvalues.size()); + items_temp = thrust::device_vector(fvalues.size()); + + sort_index_in = thrust::device_vector(fvalues.size()); + sort_index_out = thrust::device_vector(fvalues.size()); + + this->CreateItems(); + } + + ~GPUData() {} + + // Create items 
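
// Illustrative sketch (not from the patch): GPUData::Init() sizes its buffers for a
// dense, heap-style binary tree. A level at depth d holds at most 2^d nodes and a
// tree grown to max_depth holds at most 2^(max_depth+1) - 1 nodes, matching the
// max_nodes_level and max_nodes values above. Helper names are hypothetical.
#include <cstdint>
inline uint32_t nodes_in_level(int depth)    { return uint32_t(1) << depth; }
inline uint32_t nodes_in_tree(int max_depth) { return (uint32_t(1) << (max_depth + 1)) - 1; }
// Example: max_depth = 6 -> 64 nodes on the deepest level, 127 nodes in the whole tree.
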
array using gpair, instaoce_id, fvalue + void CreateItems() { + auto d_items = items.data(); + auto d_instance_id = instance_id.data(); + auto d_gpair = gpair.data(); + auto d_fvalue = fvalues.data(); + + auto counting = thrust::make_counting_iterator(0); + thrust::for_each(counting, counting + fvalues.size(), + [=] __device__(bst_uint i) { + Item item; + item.instance_id = d_instance_id[i]; + item.fvalue = d_fvalue[i]; + item.gpair = d_gpair[item.instance_id]; + d_items[i] = item; + }); + } + + // Reset memory for new boosting iteration + void Reset(const std::vector &in_gpair, + const std::vector &in_fvalues, + const std::vector &in_instance_id) { + CHECK(allocated); + thrust::copy(in_gpair.begin(), in_gpair.end(), gpair.begin()); + thrust::fill(nodes.begin(), nodes.end(), Node()); + thrust::fill(node_id_instance.begin(), node_id_instance.end(), 0); + thrust::fill(node_id.begin(), node_id.end(), 0); + + this->CreateItems(); + } + + bool IsAllocated() { return allocated; } + + // Gather from node_id_instance into node_id according to instance_id + void GatherNodeId() { + // Update node_id for each item + auto d_items = items.data(); + auto d_node_id = node_id.data(); + auto d_node_id_instance = node_id_instance.data(); + + auto counting = thrust::make_counting_iterator(0); + thrust::for_each(counting, counting + fvalues.size(), + [=] __device__(bst_uint i) { + Item item = d_items[i]; + d_node_id[i] = d_node_id_instance[item.instance_id]; + }); + } +}; + +GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); } + +void GPUBuilder::Init(const TrainParam ¶m_in) { param = param_in; } + +GPUBuilder::~GPUBuilder() { delete gpu_data; } + +template +__global__ void update_nodeid_missing_kernel(NodeIdT *d_node_id_instance, + Node *d_nodes, const OffsetT n) { + for (auto i : grid_stride_range(OffsetT(0), n)) { + NodeIdT item_node_id = d_node_id_instance[i]; + + if (item_node_id < 0) { + continue; + } + Node node = d_nodes[item_node_id]; + + if (node.IsLeaf()) { + d_node_id_instance[i] = -1; + } else if (node.split.missing_left) { + d_node_id_instance[i] = item_node_id * 2 + 1; + } else { + d_node_id_instance[i] = item_node_id * 2 + 2; + } + } +} + +__device__ void load_as_words(const int n_nodes, Node *d_nodes, Node *s_nodes) { + const int upper_range = n_nodes * (sizeof(Node) / sizeof(int)); + for (auto i : block_stride_range(0, upper_range)) { + reinterpret_cast(s_nodes)[i] = reinterpret_cast(d_nodes)[i]; + } +} + +template +__global__ void +update_nodeid_fvalue_kernel(NodeIdT *d_node_id, NodeIdT *d_node_id_instance, + Item *d_items, Node *d_nodes, const int n_nodes, + const int *d_feature_id, const size_t n, + const int n_features, bool cache_nodes) { + // Load nodes into shared memory + extern __shared__ Node s_nodes[]; + + if (cache_nodes) { + load_as_words(n_nodes, d_nodes, s_nodes); + __syncthreads(); + } + + for (auto i : grid_stride_range(size_t(0), n)) { + Item item = d_items[i]; + NodeIdT item_node_id = d_node_id[i]; + + if (item_node_id < 0) { + continue; + } + + Node node = cache_nodes ? 
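
// Illustrative sketch (not from the patch): update_nodeid_missing_kernel relies on
// the dense heap layout, where node i has children 2*i+1 (left) and 2*i+2 (right).
// The same routing rule on the host, with a hypothetical NodeView standing in for
// the real Node type:
#include <vector>
struct NodeView { bool is_leaf; bool missing_left; };
inline void route_by_missing_direction(std::vector<int> *node_id_instance,
                                       const std::vector<NodeView> &nodes) {
  for (int &nid : *node_id_instance) {
    if (nid < 0) continue;                    // instance already finished at a leaf
    const NodeView &n = nodes[nid];
    if (n.is_leaf)           nid = -1;        // stop tracking this instance
    else if (n.missing_left) nid = nid * 2 + 1;
    else                     nid = nid * 2 + 2;
  }
}
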
s_nodes[item_node_id] : d_nodes[item_node_id]; + + if (node.IsLeaf()) { + continue; + } + + int feature_id = d_feature_id[i]; + + if (feature_id == node.split.findex) { + if (item.fvalue < node.split.fvalue) { + d_node_id_instance[item.instance_id] = item_node_id * 2 + 1; + } else { + d_node_id_instance[item.instance_id] = item_node_id * 2 + 2; + } + } + } +} + +void GPUBuilder::UpdateNodeId(int level) { + // Update all nodes based on missing direction + { + const bst_uint n = gpu_data->node_id_instance.size(); + const bst_uint ITEMS_PER_THREAD = 8; + const bst_uint BLOCK_THREADS = 256; + const bst_uint GRID_SIZE = + div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS); + + update_nodeid_missing_kernel< + ITEMS_PER_THREAD><<>>( + raw(gpu_data->node_id_instance), raw(gpu_data->nodes), n); + + safe_cuda(cudaDeviceSynchronize()); + } + + // Update node based on fvalue where exists + { + const bst_uint n = gpu_data->fvalues.size(); + const bst_uint ITEMS_PER_THREAD = 4; + const bst_uint BLOCK_THREADS = 256; + const bst_uint GRID_SIZE = + div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS); + + // Use smem cache version if possible + const bool cache_nodes = level < 7; + int n_nodes = (1 << (level + 1)) - 1; + int smem_size = cache_nodes ? sizeof(Node) * n_nodes : 0; + update_nodeid_fvalue_kernel< + ITEMS_PER_THREAD><<>>( + raw(gpu_data->node_id), raw(gpu_data->node_id_instance), + raw(gpu_data->items), raw(gpu_data->nodes), n_nodes, + raw(gpu_data->feature_id), gpu_data->fvalues.size(), + gpu_data->n_features, cache_nodes); + + safe_cuda(cudaGetLastError()); + safe_cuda(cudaDeviceSynchronize()); + } + + gpu_data->GatherNodeId(); +} + +void GPUBuilder::Sort(int level) { + thrust::sequence(gpu_data->sort_index_in.begin(), + gpu_data->sort_index_in.end()); + if (!gpu_data->cub_mem.IsAllocated()) { + cub::DeviceSegmentedRadixSort::SortPairs( + gpu_data->cub_mem.d_temp_storage, gpu_data->cub_mem.temp_storage_bytes, + raw(gpu_data->node_id), raw(gpu_data->node_id_temp), + raw(gpu_data->sort_index_in), raw(gpu_data->sort_index_out), + gpu_data->fvalues.size(), gpu_data->n_features, raw(gpu_data->foffsets), + raw(gpu_data->foffsets) + 1); + gpu_data->cub_mem.Allocate(); + } + + cub::DeviceSegmentedRadixSort::SortPairs( + gpu_data->cub_mem.d_temp_storage, gpu_data->cub_mem.temp_storage_bytes, + raw(gpu_data->node_id), raw(gpu_data->node_id_temp), + raw(gpu_data->sort_index_in), raw(gpu_data->sort_index_out), + gpu_data->fvalues.size(), gpu_data->n_features, raw(gpu_data->foffsets), + raw(gpu_data->foffsets) + 1); + + thrust::gather(gpu_data->sort_index_out.begin(), + gpu_data->sort_index_out.end(), gpu_data->items.begin(), + gpu_data->items_temp.begin()); + + thrust::copy(gpu_data->items_temp.begin(), gpu_data->items_temp.end(), + gpu_data->items.begin()); + thrust::copy(gpu_data->node_id_temp.begin(), gpu_data->node_id_temp.end(), + gpu_data->node_id.begin()); +} + +void GPUBuilder::Update(const std::vector &gpair, DMatrix *p_fmat, + RegTree *p_tree) { + try { + Timer update; + Timer t; + this->InitData(gpair, *p_fmat, *p_tree); + t.printElapsed("init data"); + this->InitFirstNode(); + + for (int level = 0; level < param.max_depth; level++) { + bool use_multiscan_algorithm = level < multiscan_levels; + + t.reset(); + if (level > 0) { + Timer update_node; + this->UpdateNodeId(level); + update_node.printElapsed("node"); + } + + if (level > 0 && !use_multiscan_algorithm) { + Timer s; + this->Sort(level); + s.printElapsed("sort"); + } + + Timer split; + find_split(raw(gpu_data->items), 
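
// Illustrative sketch (not from the patch): GPUBuilder::Sort() regroups every
// feature segment by node_id so each node's entries are contiguous before the
// sorting-based split search; the real code does this with
// cub::DeviceSegmentedRadixSort plus a gather. A host equivalent using a stable
// sort per segment (int stands in for the Item payload; names are hypothetical):
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>
inline void sort_segments_by_node(std::vector<int> *node_id, std::vector<int> *items,
                                  const std::vector<int> &foffsets) {
  for (std::size_t f = 0; f + 1 < foffsets.size(); ++f) {
    const int begin = foffsets[f], end = foffsets[f + 1];
    std::vector<int> perm(end - begin);
    std::iota(perm.begin(), perm.end(), begin);
    // stable sort keeps the within-node feature-value order intact
    std::stable_sort(perm.begin(), perm.end(),
                     [&](int a, int b) { return (*node_id)[a] < (*node_id)[b]; });
    std::vector<int> nid_tmp, item_tmp;
    for (int p : perm) { nid_tmp.push_back((*node_id)[p]); item_tmp.push_back((*items)[p]); }
    std::copy(nid_tmp.begin(), nid_tmp.end(), node_id->begin() + begin);
    std::copy(item_tmp.begin(), item_tmp.end(), items->begin() + begin);
  }
}
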
raw(gpu_data->split_candidates), + raw(gpu_data->node_id), raw(gpu_data->nodes), + (bst_uint)gpu_data->fvalues.size(), gpu_data->n_features, + raw(gpu_data->foffsets), raw(gpu_data->node_sums), + raw(gpu_data->node_offsets), gpu_data->param, level, + use_multiscan_algorithm); + + split.printElapsed("split"); + + t.printElapsed("level"); + } + this->CopyTree(*p_tree); + update.printElapsed("update"); + } catch (thrust::system_error &e) { + std::cerr << "CUDA error: " << e.what() << std::endl; + exit(-1); + } +} + +void GPUBuilder::InitData(const std::vector &gpair, DMatrix &fmat, + const RegTree &tree) { + CHECK_EQ(tree.param.num_nodes, tree.param.num_roots) + << "ColMaker: can only grow new tree"; + + CHECK(fmat.SingleColBlock()) << "GPUMaker: must have single column block"; + + if (gpu_data->IsAllocated()) { + gpu_data->Reset(gpair, fvalues, instance_id); + return; + } + + Timer t; + + MetaInfo info = fmat.info(); + dmlc::DataIter *iter = fmat.ColIterator(); + + std::vector foffsets; + foffsets.push_back(0); + std::vector feature_id; + fvalues.reserve(info.num_col * info.num_row); + instance_id.reserve(info.num_col * info.num_row); + feature_id.reserve(info.num_col * info.num_row); + + while (iter->Next()) { + const ColBatch &batch = iter->Value(); + + for (int i = 0; i < batch.size; i++) { + const ColBatch::Inst &col = batch[i]; + + for (const ColBatch::Entry *it = col.data; it != col.data + col.length; + it++) { + fvalues.push_back(it->fvalue); + instance_id.push_back(it->index); + feature_id.push_back(i); + } + foffsets.push_back(fvalues.size()); + } + } + + t.printElapsed("dmatrix"); + t.reset(); + gpu_data->Init(fvalues, foffsets, instance_id, feature_id, gpair, + info.num_row, info.num_col, param.max_depth, param); + + t.printElapsed("gpu init"); +} + +void GPUBuilder::InitFirstNode() { + // Build the root node on the CPU and copy to device + gpu_gpair sum_gradients = + thrust::reduce(gpu_data->gpair.begin(), gpu_data->gpair.end(), + gpu_gpair(0, 0), cub::Sum()); + + gpu_data->nodes[0] = Node( + sum_gradients, + CalcGain(gpu_data->param, sum_gradients.grad(), sum_gradients.hess()), + CalcWeight(gpu_data->param, sum_gradients.grad(), sum_gradients.hess())); +} + +enum NodeType { + NODE = 0, + LEAF = 1, + UNUSED = 2, +}; + +// Recursively label node types +void flag_nodes(const thrust::host_vector &nodes, + std::vector *node_flags, int nid, NodeType type) { + if (nid >= nodes.size() || type == UNUSED) { + return; + } + + const Node &n = nodes[nid]; + + // Current node and all children are valid + if (n.split.loss_chg > rt_eps) { + (*node_flags)[nid] = NODE; + flag_nodes(nodes, node_flags, nid * 2 + 1, NODE); + flag_nodes(nodes, node_flags, nid * 2 + 2, NODE); + } else { + // Current node is leaf, therefore is valid but all children are invalid + (*node_flags)[nid] = LEAF; + flag_nodes(nodes, node_flags, nid * 2 + 1, UNUSED); + flag_nodes(nodes, node_flags, nid * 2 + 2, UNUSED); + } +} + +// Copy gpu dense representation of tree to xgboost sparse representation +void GPUBuilder::CopyTree(RegTree &tree) { + thrust::host_vector h_nodes = gpu_data->nodes; + std::vector node_flags(h_nodes.size(), UNUSED); + flag_nodes(h_nodes, &node_flags, 0, NODE); + + int nid = 0; + for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) { + NodeType flag = node_flags[gpu_nid]; + const Node &n = h_nodes[gpu_nid]; + if (flag == NODE) { + tree.AddChilds(nid); + tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left); + tree.stat(nid).loss_chg = n.split.loss_chg; + 
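
// Illustrative sketch (not from the patch): InitFirstNode() reduces every gradient
// pair and derives the root statistics with CalcWeight/CalcGain. Ignoring the L1
// term, max_delta_step and the min_child_weight early-out, those reduce to the
// closed forms below (standalone helpers for illustration only):
inline double root_weight(double sum_grad, double sum_hess, double reg_lambda) {
  return -sum_grad / (sum_hess + reg_lambda);             // w* = -G / (H + lambda)
}
inline double root_gain(double sum_grad, double sum_hess, double reg_lambda) {
  return sum_grad * sum_grad / (sum_hess + reg_lambda);   // G^2 / (H + lambda)
}
// Example: G = -10, H = 25, lambda = 1 -> weight ~= 0.385, gain ~= 3.85.
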
tree.stat(nid).base_weight = n.weight; + tree.stat(nid).sum_hess = n.sum_gradients.hess(); + tree[tree[nid].cleft()].set_leaf(0); + tree[tree[nid].cright()].set_leaf(0); + nid++; + } else if (flag == LEAF) { + tree[nid].set_leaf(n.weight * param.learning_rate); + tree.stat(nid).sum_hess = n.sum_gradients.hess(); + nid++; + } + } +} +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/gpu_builder.cuh b/plugin/updater_gpu/src/gpu_builder.cuh new file mode 100644 index 000000000..abfecefcd --- /dev/null +++ b/plugin/updater_gpu/src/gpu_builder.cuh @@ -0,0 +1,46 @@ +/*! + * Copyright 2016 Rory mitchell +*/ +#pragma once +#include +#include +#include "../../src/tree/param.h" + +namespace xgboost { + +namespace tree { + +struct gpu_gpair; +struct GPUData; + +class GPUBuilder { + public: + GPUBuilder(); + void Init(const TrainParam ¶m); + ~GPUBuilder(); + + void Update(const std::vector &gpair, DMatrix *p_fmat, + RegTree *p_tree); + + private: + void InitData(const std::vector &gpair, DMatrix &fmat, // NOLINT + const RegTree &tree); + + void UpdateNodeId(int level); + void Sort(int level); + void InitFirstNode(); + void CopyTree(RegTree &tree); // NOLINT + + TrainParam param; + GPUData *gpu_data; + + // Keep host copies of these arrays as the device versions change between + // boosting iterations + std::vector fvalues; + std::vector instance_id; + + int multiscan_levels = + 5; // Number of levels before switching to sorting algorithm +}; +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/loss_functions.cuh b/plugin/updater_gpu/src/loss_functions.cuh new file mode 100644 index 000000000..00796d051 --- /dev/null +++ b/plugin/updater_gpu/src/loss_functions.cuh @@ -0,0 +1,52 @@ +/*! + * Copyright 2016 Rory mitchell +*/ +#pragma once +#include "types.cuh" +#include "../../../src/tree/param.h" + +// When we split on a value which has no left neighbour, define its left +// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS +// This produces a split value slightly lower than the current instance +#define FVALUE_EPS 0.0001 + +namespace xgboost { +namespace tree { + + +__device__ __forceinline__ float +device_calc_loss_chg(const GPUTrainingParam ¶m, const gpu_gpair &scan, + const gpu_gpair &missing, const gpu_gpair &parent_sum, + const float &parent_gain, bool missing_left) { + gpu_gpair left = scan; + + if (missing_left) { + left += missing; + } + + gpu_gpair right = parent_sum - left; + + float left_gain = CalcGain(param, left.grad(), left.hess()); + float right_gain = CalcGain(param, right.grad(), right.hess()); + return left_gain + right_gain - parent_gain; +} + +__device__ __forceinline__ float +loss_chg_missing(const gpu_gpair &scan, const gpu_gpair &missing, + const gpu_gpair &parent_sum, const float &parent_gain, + const GPUTrainingParam ¶m, bool &missing_left_out) { // NOLINT + float missing_left_loss = + device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true); + float missing_right_loss = device_calc_loss_chg( + param, scan, missing, parent_sum, parent_gain, false); + + if (missing_left_loss >= missing_right_loss) { + missing_left_out = true; + return missing_left_loss; + } else { + missing_left_out = false; + return missing_right_loss; + } +} +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/types.cuh b/plugin/updater_gpu/src/types.cuh new file mode 100644 index 000000000..8a9984416 --- /dev/null +++ b/plugin/updater_gpu/src/types.cuh @@ -0,0 +1,185 @@ +/*! 
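
// Illustrative sketch (not from the patch): loss_chg_missing() above evaluates both
// default directions for the rows whose feature value is missing and keeps the
// better one. The same decision on the host, with gain() standing in for the
// simplified (reg_alpha = 0) form of CalcGain; every name here is hypothetical.
#include <utility>
struct GH { double g, h; };
inline double gain(const GH &s, double lambda) { return s.g * s.g / (s.h + lambda); }
inline std::pair<double, bool>   // returns (loss change, missing values go left?)
best_missing_direction(GH scan, GH missing, GH parent, double parent_gain, double lambda) {
  GH left_l  = {scan.g + missing.g, scan.h + missing.h};   // missing sent left
  GH right_l = {parent.g - left_l.g, parent.h - left_l.h};
  GH left_r  = scan;                                       // missing sent right
  GH right_r = {parent.g - left_r.g, parent.h - left_r.h};
  double chg_left  = gain(left_l, lambda) + gain(right_l, lambda) - parent_gain;
  double chg_right = gain(left_r, lambda) + gain(right_r, lambda) - parent_gain;
  return chg_left >= chg_right ? std::make_pair(chg_left, true)
                               : std::make_pair(chg_right, false);
}
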
+ * Copyright 2016 Rory mitchell +*/ +#pragma once +#include + +namespace xgboost { +namespace tree { + +typedef int32_t NodeIdT; + +// gpair type defined with device accessible functions +struct gpu_gpair { + float _grad; + float _hess; + + __host__ __device__ __forceinline__ float grad() const { return _grad; } + + __host__ __device__ __forceinline__ float hess() const { return _hess; } + + __host__ __device__ gpu_gpair() : _grad(0), _hess(0) {} + + __host__ __device__ gpu_gpair(float g, float h) : _grad(g), _hess(h) {} + + __host__ __device__ gpu_gpair(bst_gpair gpair) + : _grad(gpair.grad), _hess(gpair.hess) {} + + __host__ __device__ bool operator==(const gpu_gpair &rhs) const { + return (_grad == rhs._grad) && (_hess == rhs._hess); + } + + __host__ __device__ bool operator!=(const gpu_gpair &rhs) const { + return !(*this == rhs); + } + + __host__ __device__ gpu_gpair &operator+=(const gpu_gpair &rhs) { + _grad += rhs._grad; + _hess += rhs._hess; + return *this; + } + + __host__ __device__ gpu_gpair operator+(const gpu_gpair &rhs) const { + gpu_gpair g; + g._grad = _grad + rhs._grad; + g._hess = _hess + rhs._hess; + return g; + } + + __host__ __device__ gpu_gpair &operator-=(const gpu_gpair &rhs) { + _grad -= rhs._grad; + _hess -= rhs._hess; + return *this; + } + + __host__ __device__ gpu_gpair operator-(const gpu_gpair &rhs) const { + gpu_gpair g; + g._grad = _grad - rhs._grad; + g._hess = _hess - rhs._hess; + return g; + } + + friend std::ostream &operator<<(std::ostream &os, const gpu_gpair &g) { + os << g.grad() << "/" << g.hess(); + return os; + } + + __host__ __device__ void print() const { + printf("%1.4f/%1.4f\n", grad(), hess()); + } + + __host__ __device__ bool approximate_compare(const gpu_gpair &b, + float g_eps = 0.1, + float h_eps = 0.1) const { + float gdiff = abs(this->grad() - b.grad()); + float hdiff = abs(this->hess() - b.hess()); + + return (gdiff <= g_eps) && (hdiff <= h_eps); + } +}; + +struct Item { + bst_uint instance_id; + float fvalue; + gpu_gpair gpair; +}; + +struct GPUTrainingParam { + // minimum amount of hessian(weight) allowed in a child + float min_child_weight; + // L2 regularization factor + float reg_lambda; + // L1 regularization factor + float reg_alpha; + // maximum delta update we can add in weight estimation + // this parameter can be used to stabilize update + // default=0 means no constraint on weight delta + float max_delta_step; + + __host__ __device__ GPUTrainingParam() {} + + __host__ __device__ GPUTrainingParam(float min_child_weight_in, + float reg_lambda_in, float reg_alpha_in, + float max_delta_step_in) + : min_child_weight(min_child_weight_in), reg_lambda(reg_lambda_in), + reg_alpha(reg_alpha_in), max_delta_step(max_delta_step_in) {} +}; + +struct Split { + float loss_chg; + bool missing_left; + float fvalue; + int findex; + gpu_gpair left_sum; + gpu_gpair right_sum; + + __host__ __device__ Split() + : loss_chg(-FLT_MAX), missing_left(true), fvalue(0) {} + + __device__ void Update(float loss_chg_in, bool missing_left_in, + float fvalue_in, int findex_in, gpu_gpair left_sum_in, + gpu_gpair right_sum_in, + const GPUTrainingParam ¶m) { + if (loss_chg_in > loss_chg && left_sum_in.hess() > param.min_child_weight && + right_sum_in.hess() > param.min_child_weight) { + loss_chg = loss_chg_in; + missing_left = missing_left_in; + fvalue = fvalue_in; + left_sum = left_sum_in; + right_sum = right_sum_in; + findex = findex_in; + } + } + + // Does not check minimum weight + __device__ void Update(Split &s) { // NOLINT + if (s.loss_chg > loss_chg) 
{ + loss_chg = s.loss_chg; + missing_left = s.missing_left; + fvalue = s.fvalue; + findex = s.findex; + left_sum = s.left_sum; + right_sum = s.right_sum; + } + } + + __device__ void Print() { + printf("Loss: %1.4f\n", loss_chg); + printf("Missing left: %d\n", missing_left); + printf("fvalue: %1.4f\n", fvalue); + printf("Left sum: "); + left_sum.print(); + + printf("Right sum: "); + right_sum.print(); + } +}; + +struct split_reduce_op { + template + __device__ __forceinline__ T operator()(T &a, T b) { // NOLINT + b.Update(a); + return b; + } +}; + +struct Node { + gpu_gpair sum_gradients; + float root_gain; + float weight; + + Split split; + + __host__ __device__ Node() : weight(0), root_gain(0) {} + + __host__ __device__ Node(gpu_gpair sum_gradients_in, float root_gain_in, + float weight_in) { + sum_gradients = sum_gradients_in; + root_gain = root_gain_in; + weight = weight_in; + } + + __host__ __device__ bool IsLeaf() { return split.loss_chg == -FLT_MAX; } +}; +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/src/types_functions.cuh b/plugin/updater_gpu/src/types_functions.cuh new file mode 100644 index 000000000..f7bd8e65f --- /dev/null +++ b/plugin/updater_gpu/src/types_functions.cuh @@ -0,0 +1,6 @@ +/*! + * Copyright 2016 Rory mitchell +*/ +#pragma once +#include "types.cuh" +#include "loss_functions.cuh" diff --git a/plugin/updater_gpu/src/updater_gpu.cc b/plugin/updater_gpu/src/updater_gpu.cc new file mode 100644 index 000000000..4083b8bd5 --- /dev/null +++ b/plugin/updater_gpu/src/updater_gpu.cc @@ -0,0 +1,48 @@ +/*! + * Copyright 2016 Rory Mitchell + */ +#include +#include +#include "../../src/common/random.h" +#include "../../src/common/sync.h" +#include "../../src/tree/param.h" +#include "gpu_builder.cuh" + +namespace xgboost { +namespace tree { +DMLC_REGISTRY_FILE_TAG(updater_gpumaker); + +/*! 
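
// Illustrative sketch (not from the patch): a Split is default-constructed with
// loss_chg == -FLT_MAX, which is exactly the sentinel Node::IsLeaf() tests for.
// Choosing the best candidate is then a plain max-by-loss reduction, which is what
// split_reduce_op expresses for CUB; a host rendition over hypothetical candidates:
#include <cfloat>
#include <vector>
struct SplitLite { float loss_chg = -FLT_MAX; int findex = -1; float fvalue = 0.f; };
inline SplitLite best_split(const std::vector<SplitLite> &candidates) {
  SplitLite best;                               // starts at the -FLT_MAX sentinel
  for (const SplitLite &s : candidates)
    if (s.loss_chg > best.loss_chg) best = s;   // same rule as Split::Update(Split&)
  return best;                                  // loss_chg still -FLT_MAX -> node stays a leaf
}
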
\brief column-wise update to construct a tree */ +template class GPUMaker : public TreeUpdater { + public: + void + Init(const std::vector> &args) override { + param.InitAllowUnknown(args); + builder.Init(param); + } + + void Update(const std::vector &gpair, DMatrix *dmat, + const std::vector &trees) override { + TStats::CheckInfo(dmat->info()); + // rescale learning rate according to size of trees + float lr = param.learning_rate; + param.learning_rate = lr / trees.size(); + // build tree + for (size_t i = 0; i < trees.size(); ++i) { + builder.Update(gpair, dmat, trees[i]); + } + param.learning_rate = lr; + } + + protected: + // training parameter + TrainParam param; + GPUBuilder builder; +}; + +XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu") + .describe("Grow tree with GPU.") + .set_body([]() { return new GPUMaker(); }); + +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_gpu/test.py b/plugin/updater_gpu/test.py new file mode 100644 index 000000000..4258b8d02 --- /dev/null +++ b/plugin/updater_gpu/test.py @@ -0,0 +1,137 @@ +#pylint: skip-file +import numpy as np +import xgboost as xgb +import os +import pandas as pd +import urllib2 + +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +def get_last_eval_callback(result): + + def callback(env): + result.append(env.evaluation_result_list[-1][1]) + + callback.after_iteration = True + return callback + + +def load_adult(): + path = "../../demo/data/adult.data" + + if(not os.path.isfile(path)): + data = urllib2.urlopen('http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data') + with open(path,'wb') as output: + output.write(data.read()) + + train_set = pd.read_csv( path, header=None) + + train_set.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation', + 'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country', + 'wage_class'] + train_nomissing = train_set.replace(' ?', np.nan).dropna() + for feature in train_nomissing.columns: # Loop through all columns in the dataframe + if train_nomissing[feature].dtype == 'object': # Only apply for columns with categorical strings + train_nomissing[feature] = pd.Categorical(train_nomissing[feature]).codes # Replace strings with an integer + + y_train = train_nomissing.pop('wage_class') + + return xgb.DMatrix( train_nomissing, label=y_train) + + +def load_higgs(): + higgs_path = '../../demo/data/training.csv' + dtrain = np.loadtxt(higgs_path, delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s'.encode('utf-8')) } ) + + #dtrain = dtrain[0:200000,:] + label = dtrain[:,32] + data = dtrain[:,1:31] + weight = dtrain[:,31] + + return xgb.DMatrix( data, label=label, missing = -999.0, weight=weight ) + +def load_dermatology(): + data = np.loadtxt('../../demo/data/dermatology.data', delimiter=',',converters={33: lambda x:int(x == '?'), 34: lambda x:int(x)-1 } ) + sz = data.shape + + X = data[:,0:33] + Y = data[:, 34] + + return xgb.DMatrix( X, label=Y) + +def isclose(a, b, rel_tol=1e-09, abs_tol=0.0): + return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) + +#Check GPU test evaluation is approximately equal to CPU test evaluation +def check_result(cpu_result, gpu_result): + for i in range(len(cpu_result)): + if not isclose(cpu_result[i], gpu_result[i], 0.1, 0.02): + return False + + return True + + +#Get data +data = [] 
+params = [] +data.append(load_higgs()) +params.append({}) + + +data.append( load_adult()) +params.append({}) + +data.append(xgb.DMatrix('../../demo/data/agaricus.txt.test')) +params.append({'objective':'binary:logistic'}) + +#if(os.path.isfile("../../demo/data/dermatology.data")): +data.append(load_dermatology()) +params.append({'objective':'multi:softmax', 'num_class': 6}) + +num_round = 5 + +num_pass = 0 +num_fail = 0 + +test_depth = [ 1, 6, 9, 11, 15 ] +#test_depth = [ 1 ] + +for test in range(0, len(data)): + for depth in test_depth: + xgmat = data[test] + cpu_result = [] + param = params[test] + param['max_depth'] = depth + param['updater'] = 'grow_colmaker' + xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(cpu_result)]) + + #bst = xgb.train( param, xgmat, 1); + #bst.dump_model('reference_model.txt','', True) + + gpu_result = [] + param['updater'] = 'grow_gpu' + xgb.cv(param, xgmat, num_round, verbose_eval=False, nfold=5, callbacks=[get_last_eval_callback(gpu_result)]) + + #bst = xgb.train( param, xgmat, 1); + #bst.dump_model('dump.raw.txt','', True) + + if check_result(cpu_result, gpu_result): + print(bcolors.OKGREEN + "Pass" + bcolors.ENDC) + num_pass = num_pass + 1 + else: + print(bcolors.FAIL + "Fail" + bcolors.ENDC) + num_fail = num_fail + 1 + + print("cpu rmse: "+str(cpu_result)) + print("gpu rmse: "+str(gpu_result)) + +print(str(num_pass)+"/"+str(num_pass + num_fail)+" passed") diff --git a/src/tree/param.h b/src/tree/param.h index d4254d84f..ea61bd14f 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -7,10 +7,16 @@ #ifndef XGBOOST_TREE_PARAM_H_ #define XGBOOST_TREE_PARAM_H_ -#include +#include #include #include -#include +#include + +#ifdef __NVCC__ +#define XGB_DEVICE __host__ __device__ +#else +#define XGB_DEVICE +#endif namespace xgboost { namespace tree { @@ -60,47 +66,80 @@ struct TrainParam : public dmlc::Parameter { std::vector monotone_constraints; // declare the parameters DMLC_DECLARE_PARAMETER(TrainParam) { - DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f) + DMLC_DECLARE_FIELD(learning_rate) + .set_lower_bound(0.0f) + .set_default(0.3f) .describe("Learning rate(step size) of update."); - DMLC_DECLARE_FIELD(min_split_loss).set_lower_bound(0.0f).set_default(0.0f) - .describe("Minimum loss reduction required to make a further partition."); - DMLC_DECLARE_FIELD(max_depth).set_lower_bound(0).set_default(6) - .describe("Maximum depth of the tree."); - DMLC_DECLARE_FIELD(min_child_weight).set_lower_bound(0.0f).set_default(1.0f) + DMLC_DECLARE_FIELD(min_split_loss) + .set_lower_bound(0.0f) + .set_default(0.0f) + .describe( + "Minimum loss reduction required to make a further partition."); + DMLC_DECLARE_FIELD(max_depth).set_lower_bound(0).set_default(6).describe( + "Maximum depth of the tree."); + DMLC_DECLARE_FIELD(min_child_weight) + .set_lower_bound(0.0f) + .set_default(1.0f) .describe("Minimum sum of instance weight(hessian) needed in a child."); - DMLC_DECLARE_FIELD(reg_lambda).set_lower_bound(0.0f).set_default(1.0f) + DMLC_DECLARE_FIELD(reg_lambda) + .set_lower_bound(0.0f) + .set_default(1.0f) .describe("L2 regularization on leaf weight"); - DMLC_DECLARE_FIELD(reg_alpha).set_lower_bound(0.0f).set_default(0.0f) + DMLC_DECLARE_FIELD(reg_alpha) + .set_lower_bound(0.0f) + .set_default(0.0f) .describe("L1 regularization on leaf weight"); - DMLC_DECLARE_FIELD(default_direction).set_default(0) + DMLC_DECLARE_FIELD(default_direction) + .set_default(0) .add_enum("learn", 0) .add_enum("left", 1) 
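
// Illustrative sketch (not from the patch): the XGB_DEVICE macro above lets the
// gain/weight helpers in param.h compile both as plain host code and, under nvcc,
// as __host__ __device__ functions usable from the CUDA plugin. The same pattern
// in miniature, with a hypothetical helper:
#ifdef __NVCC__
#define EXAMPLE_DEVICE __host__ __device__
#else
#define EXAMPLE_DEVICE
#endif
template <typename T>
EXAMPLE_DEVICE inline T clamp_abs(T v, T bound) {   // callable from CPU or GPU code
  if (v >  bound) return  bound;
  if (v < -bound) return -bound;
  return v;
}
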
.add_enum("right", 2) .describe("Default direction choice when encountering a missing value"); - DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.0f) - .describe("Maximum delta step we allow each tree's weight estimate to be. "\ - "If the value is set to 0, it means there is no constraint"); - DMLC_DECLARE_FIELD(subsample).set_range(0.0f, 1.0f).set_default(1.0f) + DMLC_DECLARE_FIELD(max_delta_step) + .set_lower_bound(0.0f) + .set_default(0.0f) + .describe( + "Maximum delta step we allow each tree's weight estimate to be. " + "If the value is set to 0, it means there is no constraint"); + DMLC_DECLARE_FIELD(subsample) + .set_range(0.0f, 1.0f) + .set_default(1.0f) .describe("Row subsample ratio of training instance."); - DMLC_DECLARE_FIELD(colsample_bylevel).set_range(0.0f, 1.0f).set_default(1.0f) + DMLC_DECLARE_FIELD(colsample_bylevel) + .set_range(0.0f, 1.0f) + .set_default(1.0f) .describe("Subsample ratio of columns, resample on each level."); - DMLC_DECLARE_FIELD(colsample_bytree).set_range(0.0f, 1.0f).set_default(1.0f) - .describe("Subsample ratio of columns, resample on each tree construction."); - DMLC_DECLARE_FIELD(opt_dense_col).set_range(0.0f, 1.0f).set_default(1.0f) + DMLC_DECLARE_FIELD(colsample_bytree) + .set_range(0.0f, 1.0f) + .set_default(1.0f) + .describe( + "Subsample ratio of columns, resample on each tree construction."); + DMLC_DECLARE_FIELD(opt_dense_col) + .set_range(0.0f, 1.0f) + .set_default(1.0f) .describe("EXP Param: speed optimization for dense column."); - DMLC_DECLARE_FIELD(sketch_eps).set_range(0.0f, 1.0f).set_default(0.03f) + DMLC_DECLARE_FIELD(sketch_eps) + .set_range(0.0f, 1.0f) + .set_default(0.03f) .describe("EXP Param: Sketch accuracy of approximate algorithm."); - DMLC_DECLARE_FIELD(sketch_ratio).set_lower_bound(0.0f).set_default(2.0f) - .describe("EXP Param: Sketch accuracy related parameter of approximate algorithm."); - DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0) + DMLC_DECLARE_FIELD(sketch_ratio) + .set_lower_bound(0.0f) + .set_default(2.0f) + .describe("EXP Param: Sketch accuracy related parameter of approximate " + "algorithm."); + DMLC_DECLARE_FIELD(size_leaf_vector) + .set_lower_bound(0) + .set_default(0) .describe("Size of leaf vectors, reserved for vector trees"); - DMLC_DECLARE_FIELD(parallel_option).set_default(0) + DMLC_DECLARE_FIELD(parallel_option) + .set_default(0) .describe("Different types of parallelization algorithm."); - DMLC_DECLARE_FIELD(cache_opt).set_default(true) - .describe("EXP Param: Cache aware optimization."); - DMLC_DECLARE_FIELD(silent).set_default(false) - .describe("Do not print information during trainig."); - DMLC_DECLARE_FIELD(monotone_constraints).set_default(std::vector()) + DMLC_DECLARE_FIELD(cache_opt).set_default(true).describe( + "EXP Param: Cache aware optimization."); + DMLC_DECLARE_FIELD(silent).set_default(false).describe( + "Do not print information during trainig."); + DMLC_DECLARE_FIELD(monotone_constraints) + .set_default(std::vector()) .describe("Constraint of variable monotinicity"); // add alias of parameters DMLC_DECLARE_ALIAS(reg_lambda, lambda); @@ -108,61 +147,11 @@ struct TrainParam : public dmlc::Parameter { DMLC_DECLARE_ALIAS(min_split_loss, gamma); DMLC_DECLARE_ALIAS(learning_rate, eta); } - // calculate the cost of loss function - inline double CalcGainGivenWeight(double sum_grad, - double sum_hess, - double w) const { - return -(2.0 * sum_grad * w + (sum_hess + reg_lambda) * Sqr(w)); - } - // calculate the cost of loss function - inline 
double CalcGain(double sum_grad, double sum_hess) const { - if (sum_hess < min_child_weight) return 0.0; - if (max_delta_step == 0.0f) { - if (reg_alpha == 0.0f) { - return Sqr(sum_grad) / (sum_hess + reg_lambda); - } else { - return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda); - } - } else { - double w = CalcWeight(sum_grad, sum_hess); - double ret = sum_grad * w + 0.5 * (sum_hess + reg_lambda) * Sqr(w); - if (reg_alpha == 0.0f) { - return - 2.0 * ret; - } else { - return - 2.0 * (ret + reg_alpha * std::abs(w)); - } - } - } - // calculate cost of loss function with four statistics - inline double CalcGain(double sum_grad, double sum_hess, - double test_grad, double test_hess) const { - double w = CalcWeight(sum_grad, sum_hess); - double ret = test_grad * w + 0.5 * (test_hess + reg_lambda) * Sqr(w); - if (reg_alpha == 0.0f) { - return - 2.0 * ret; - } else { - return - 2.0 * (ret + reg_alpha * std::abs(w)); - } - } - // calculate weight given the statistics - inline double CalcWeight(double sum_grad, double sum_hess) const { - if (sum_hess < min_child_weight) return 0.0; - double dw; - if (reg_alpha == 0.0f) { - dw = -sum_grad / (sum_hess + reg_lambda); - } else { - dw = -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda); - } - if (max_delta_step != 0.0f) { - if (dw > max_delta_step) dw = max_delta_step; - if (dw < -max_delta_step) dw = -max_delta_step; - } - return dw; - } /*! \brief whether need forward small to big search: default right */ inline bool need_forward_search(float col_density, bool indicator) const { return this->default_direction == 2 || - (default_direction == 0 && (col_density < opt_dense_col) && !indicator); + (default_direction == 0 && (col_density < opt_dense_col) && + !indicator); } /*! \brief whether need backward big to small search: default left */ inline bool need_backward_search(float col_density, bool indicator) const { @@ -182,19 +171,85 @@ struct TrainParam : public dmlc::Parameter { CHECK_GT(ret, 0); return ret; } - - protected: - // functions for L1 cost - inline static double ThresholdL1(double w, double lambda) { - if (w > +lambda) return w - lambda; - if (w < -lambda) return w + lambda; - return 0.0; - } - inline static double Sqr(double a) { - return a * a; - } }; +/*! 
\brief Loss functions */ + +// functions for L1 cost +template +XGB_DEVICE inline static T1 ThresholdL1(T1 w, T2 lambda) { + if (w > +lambda) + return w - lambda; + if (w < -lambda) + return w + lambda; + return 0.0; +} + +template +XGB_DEVICE inline static T Sqr(T a) { return a * a; } + +// calculate the cost of loss function +template +XGB_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, T sum_grad, + T sum_hess, T w) { + return -(2.0 * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w)); +} + +// calculate the cost of loss function +template +XGB_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) { + if (sum_hess < p.min_child_weight) + return 0.0; + if (p.max_delta_step == 0.0f) { + if (p.reg_alpha == 0.0f) { + return Sqr(sum_grad) / (sum_hess + p.reg_lambda); + } else { + return Sqr(ThresholdL1(sum_grad, p.reg_alpha)) / + (sum_hess + p.reg_lambda); + } + } else { + T w = CalcWeight(p, sum_grad, sum_hess); + T ret = sum_grad * w + 0.5 * (sum_hess + p.reg_lambda) * Sqr(w); + if (p.reg_alpha == 0.0f) { + return -2.0 * ret; + } else { + return -2.0 * (ret + p.reg_alpha * std::abs(w)); + } + } +} +// calculate cost of loss function with four statistics +template +XGB_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess, + T test_grad, T test_hess) { + T w = CalcWeight(sum_grad, sum_hess); + T ret = test_grad * w + 0.5 * (test_hess + p.reg_lambda) * Sqr(w); + if (p.reg_alpha == 0.0f) { + return -2.0 * ret; + } else { + return -2.0 * (ret + p.reg_alpha * std::abs(w)); + } +} +// calculate weight given the statistics +template +XGB_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad, + T sum_hess) { + if (sum_hess < p.min_child_weight) + return 0.0; + T dw; + if (p.reg_alpha == 0.0f) { + dw = -sum_grad / (sum_hess + p.reg_lambda); + } else { + dw = -ThresholdL1(sum_grad, p.reg_alpha) / (sum_hess + p.reg_lambda); + } + if (p.max_delta_step != 0.0f) { + if (dw > p.max_delta_step) + dw = p.max_delta_step; + if (dw < -p.max_delta_step) + dw = -p.max_delta_step; + } + return dw; +} + /*! \brief core statistics used for tree construction */ struct GradStats { /*! \brief sum gradient statistics */ @@ -207,109 +262,87 @@ struct GradStats { */ static const int kSimpleStats = 1; /*! \brief constructor, the object must be cleared during construction */ - explicit GradStats(const TrainParam& param) { - this->Clear(); - } + explicit GradStats(const TrainParam ¶m) { this->Clear(); } /*! \brief clear the statistics */ - inline void Clear() { - sum_grad = sum_hess = 0.0f; - } + inline void Clear() { sum_grad = sum_hess = 0.0f; } /*! \brief check if necessary information is ready */ - inline static void CheckInfo(const MetaInfo& info) { - } + inline static void CheckInfo(const MetaInfo &info) {} /*! * \brief accumulate statistics * \param p the gradient pair */ - inline void Add(bst_gpair p) { - this->Add(p.grad, p.hess); - } + inline void Add(bst_gpair p) { this->Add(p.grad, p.hess); } /*! * \brief accumulate statistics, more complicated version * \param gpair the vector storing the gradient statistics * \param info the additional information * \param ridx instance index of this instance */ - inline void Add(const std::vector& gpair, - const MetaInfo& info, + inline void Add(const std::vector &gpair, const MetaInfo &info, bst_uint ridx) { - const bst_gpair& b = gpair[ridx]; + const bst_gpair &b = gpair[ridx]; this->Add(b.grad, b.hess); } /*! 
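
// Illustrative sketch (not from the patch): a worked example of the free-function
// CalcWeight above. The gradient sum is soft-thresholded by reg_alpha, divided by
// (hess + reg_lambda), then clipped to +/- max_delta_step; the min_child_weight
// early-out is skipped here. Standalone copies for illustration only:
inline double threshold_l1(double w, double alpha) {
  if (w > +alpha) return w - alpha;
  if (w < -alpha) return w + alpha;
  return 0.0;
}
inline double leaf_weight(double G, double H, double lambda, double alpha, double max_step) {
  double w = -threshold_l1(G, alpha) / (H + lambda);
  if (max_step != 0.0) {                        // optional clipping, as in the patch
    if (w >  max_step) w =  max_step;
    if (w < -max_step) w = -max_step;
  }
  return w;
}
// leaf_weight(-12, 20, 1, 2, 0)   ~= 0.476  (alpha shrinks |G| from 12 to 10)
// leaf_weight(-12, 20, 1, 2, 0.3) == 0.3    (clipped by max_delta_step)
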
\brief calculate leaf weight */ - inline double CalcWeight(const TrainParam& param) const { - return param.CalcWeight(sum_grad, sum_hess); + inline double CalcWeight(const TrainParam ¶m) const { + return xgboost::tree::CalcWeight(param, sum_grad, sum_hess); } /*! \brief calculate gain of the solution */ - inline double CalcGain(const TrainParam& param) const { - return param.CalcGain(sum_grad, sum_hess); + inline double CalcGain(const TrainParam ¶m) const { + return xgboost::tree::CalcGain(param, sum_grad, sum_hess); } /*! \brief add statistics to the data */ - inline void Add(const GradStats& b) { - this->Add(b.sum_grad, b.sum_hess); - } + inline void Add(const GradStats &b) { this->Add(b.sum_grad, b.sum_hess); } /*! \brief same as add, reduce is used in All Reduce */ - inline static void Reduce(GradStats& a, const GradStats& b) { // NOLINT(*) + inline static void Reduce(GradStats &a, const GradStats &b) { // NOLINT(*) a.Add(b); } /*! \brief set current value to a - b */ - inline void SetSubstract(const GradStats& a, const GradStats& b) { + inline void SetSubstract(const GradStats &a, const GradStats &b) { sum_grad = a.sum_grad - b.sum_grad; sum_hess = a.sum_hess - b.sum_hess; } /*! \return whether the statistics is not used yet */ - inline bool Empty() const { - return sum_hess == 0.0; - } + inline bool Empty() const { return sum_hess == 0.0; } /*! \brief set leaf vector value based on statistics */ - inline void SetLeafVec(const TrainParam& param, bst_float *vec) const { - } + inline void SetLeafVec(const TrainParam ¶m, bst_float *vec) const {} // constructor to allow inheritance GradStats() {} /*! \brief add statistics to the data */ inline void Add(double grad, double hess) { - sum_grad += grad; sum_hess += hess; + sum_grad += grad; + sum_hess += hess; } }; struct NoConstraint { - inline static void Init(TrainParam* param, unsigned num_feature) { - } - inline double CalcSplitGain( - const TrainParam& param, bst_uint split_index, - GradStats left, GradStats right) const { + inline static void Init(TrainParam *param, unsigned num_feature) {} + inline double CalcSplitGain(const TrainParam ¶m, bst_uint split_index, + GradStats left, GradStats right) const { return left.CalcGain(param) + right.CalcGain(param); } - inline double CalcWeight( - const TrainParam& param, - GradStats stats) const { + inline double CalcWeight(const TrainParam ¶m, GradStats stats) const { return stats.CalcWeight(param); } - inline double CalcGain(const TrainParam& param, - GradStats stats) const { + inline double CalcGain(const TrainParam ¶m, GradStats stats) const { return stats.CalcGain(param); } - inline void SetChild( - const TrainParam& param, bst_uint split_index, - GradStats left, GradStats right, - NoConstraint* cleft, NoConstraint* cright) { - } + inline void SetChild(const TrainParam ¶m, bst_uint split_index, + GradStats left, GradStats right, NoConstraint *cleft, + NoConstraint *cright) {} }; struct ValueConstraint { double lower_bound; double upper_bound; - ValueConstraint() : - lower_bound(-std::numeric_limits::max()), - upper_bound(std::numeric_limits::max()) { - } - inline static void Init(TrainParam* param, unsigned num_feature) { + ValueConstraint() + : lower_bound(-std::numeric_limits::max()), + upper_bound(std::numeric_limits::max()) {} + inline static void Init(TrainParam *param, unsigned num_feature) { param->monotone_constraints.resize(num_feature, 1); } - inline double CalcWeight( - const TrainParam& param, - GradStats stats) const { - double w = stats.CalcWeight(param); + inline double 
CalcWeight(const TrainParam ¶m, GradStats stats) const { + double w = stats.CalcWeight(param); if (w < lower_bound) { return lower_bound; } @@ -319,41 +352,36 @@ struct ValueConstraint { return w; } - inline double CalcGain(const TrainParam& param, - GradStats stats) const { - return param.CalcGainGivenWeight( - stats.sum_grad, stats.sum_hess, - CalcWeight(param, stats)); + inline double CalcGain(const TrainParam ¶m, GradStats stats) const { + return CalcGainGivenWeight(param, stats.sum_grad, stats.sum_hess, + CalcWeight(param, stats)); } - inline double CalcSplitGain( - const TrainParam& param, - bst_uint split_index, - GradStats left, GradStats right) const { + inline double CalcSplitGain(const TrainParam ¶m, bst_uint split_index, + GradStats left, GradStats right) const { double wleft = CalcWeight(param, left); double wright = CalcWeight(param, right); int c = param.monotone_constraints[split_index]; double gain = - param.CalcGainGivenWeight(left.sum_grad, left.sum_hess, wleft) + - param.CalcGainGivenWeight(right.sum_grad, right.sum_hess, wright); + CalcGainGivenWeight(param, left.sum_grad, left.sum_hess, wleft) + + CalcGainGivenWeight(param, right.sum_grad, right.sum_hess, wright); if (c == 0) { return gain; - } else if (c > 0) { + } else if (c > 0) { return wleft < wright ? gain : 0.0; } else { return wleft > wright ? gain : 0.0; } } - inline void SetChild( - const TrainParam& param, - bst_uint split_index, - GradStats left, GradStats right, - ValueConstraint* cleft, ValueConstraint *cright) { + inline void SetChild(const TrainParam ¶m, bst_uint split_index, + GradStats left, GradStats right, ValueConstraint *cleft, + ValueConstraint *cright) { int c = param.monotone_constraints.at(split_index); *cleft = *this; *cright = *this; - if (c == 0) return; + if (c == 0) + return; double wleft = CalcWeight(param, left); double wright = CalcWeight(param, right); double mid = (wleft + wright) / 2; @@ -382,9 +410,12 @@ struct SplitEntry { /*! \brief constructor */ SplitEntry() : loss_chg(0.0f), sindex(0), split_value(0.0f) {} /*! - * \brief decides whether we can replace current entry with the given statistics - * This function gives better priority to lower index when loss_chg == new_loss_chg. - * Not the best way, but helps to give consistent result during multi-thread execution. + * \brief decides whether we can replace current entry with the given + * statistics + * This function gives better priority to lower index when loss_chg == + * new_loss_chg. + * Not the best way, but helps to give consistent result during multi-thread + * execution. * \param new_loss_chg the loss reduction get through the split * \param split_index the feature index where the split is on */ @@ -400,7 +431,7 @@ struct SplitEntry { * \param e candidate split solution * \return whether the proposed split is better and can replace current split */ - inline bool Update(const SplitEntry& e) { + inline bool Update(const SplitEntry &e) { if (this->NeedReplace(e.loss_chg, e.split_index())) { this->loss_chg = e.loss_chg; this->sindex = e.sindex; @@ -422,7 +453,8 @@ struct SplitEntry { float new_split_value, bool default_left) { if (this->NeedReplace(new_loss_chg, split_index)) { this->loss_chg = new_loss_chg; - if (default_left) split_index |= (1U << 31); + if (default_left) + split_index |= (1U << 31); this->sindex = split_index; this->split_value = new_split_value; return true; @@ -431,17 +463,14 @@ struct SplitEntry { } } /*! 
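
// Illustrative sketch (not from the patch): ValueConstraint::CalcSplitGain only
// credits a split that respects the declared monotone direction. With c > 0 the
// left child's weight must stay below the right child's, with c < 0 the opposite,
// and c == 0 leaves the gain untouched. The same gate in isolation:
inline double monotone_gate(double gain, double wleft, double wright, int c) {
  if (c == 0) return gain;
  if (c > 0)  return wleft < wright ? gain : 0.0;
  return wleft > wright ? gain : 0.0;
}
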
\brief same as update, used by AllReduce*/ - inline static void Reduce(SplitEntry& dst, const SplitEntry& src) { // NOLINT(*) + inline static void Reduce(SplitEntry &dst, // NOLINT(*) + const SplitEntry &src) { // NOLINT(*) dst.Update(src); } /*!\return feature index to split on */ - inline unsigned split_index() const { - return sindex & ((1U << 31) - 1U); - } + inline unsigned split_index() const { return sindex & ((1U << 31) - 1U); } /*!\return whether missing value goes to left branch */ - inline bool default_left() const { - return (sindex >> 31) != 0; - } + inline bool default_left() const { return (sindex >> 31) != 0; } }; } // namespace tree @@ -451,13 +480,14 @@ struct SplitEntry { namespace std { inline std::ostream &operator<<(std::ostream &os, const std::vector &t) { os << '('; - for (std::vector::const_iterator - it = t.begin(); it != t.end(); ++it) { - if (it != t.begin()) os << ','; + for (std::vector::const_iterator it = t.begin(); it != t.end(); ++it) { + if (it != t.begin()) + os << ','; os << *it; } // python style tuple - if (t.size() == 1) os << ','; + if (t.size() == 1) + os << ','; os << ')'; return os; } @@ -474,7 +504,8 @@ inline std::istream &operator>>(std::istream &is, std::vector &t) { return is; } is.get(); - if (ch == '(') break; + if (ch == '(') + break; if (!isspace(ch)) { is.setstate(std::ios::failbit); return is; @@ -495,14 +526,17 @@ inline std::istream &operator>>(std::istream &is, std::vector &t) { while (true) { ch = is.peek(); if (isspace(ch)) { - is.get(); continue; + is.get(); + continue; } if (ch == ')') { - is.get(); break; + is.get(); + break; } break; } - if (ch == ')') break; + if (ch == ')') + break; } else if (ch == ')') { break; } else { diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc index c0d62ce5e..19e54e79f 100644 --- a/src/tree/updater_skmaker.cc +++ b/src/tree/updater_skmaker.cc @@ -107,7 +107,7 @@ class SketchMaker: public BaseMaker { } /*! \brief calculate gain of the solution */ inline double CalcGain(const TrainParam ¶m) const { - return param.CalcGain(pos_grad - neg_grad, sum_hess); + return xgboost::tree::CalcGain(param, pos_grad - neg_grad, sum_hess); } /*! \brief set current value to a - b */ inline void SetSubstract(const SKStats &a, const SKStats &b) { @@ -117,7 +117,7 @@ class SketchMaker: public BaseMaker { } // calculate leaf weight inline double CalcWeight(const TrainParam ¶m) const { - return param.CalcWeight(pos_grad - neg_grad, sum_hess); + return xgboost::tree::CalcWeight(param, pos_grad - neg_grad, sum_hess); } /*! \brief add statistics to the data */ inline void Add(const SKStats &b) {
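
// Illustrative sketch (not from the patch): SplitEntry packs the split feature and
// the default (missing-value) direction into one 32-bit word: the low 31 bits hold
// the feature index and the top bit is set when missing values go left, which is
// what split_index() and default_left() above decode. The same convention in
// isolation (helper names are hypothetical):
#include <cstdint>
inline uint32_t pack_sindex(uint32_t split_index, bool default_left) {
  return default_left ? (split_index | (1U << 31)) : split_index;
}
inline uint32_t unpack_split_index(uint32_t sindex)  { return sindex & ((1U << 31) - 1U); }
inline bool     unpack_default_left(uint32_t sindex) { return (sindex >> 31) != 0; }
// Example: pack_sindex(42, true) sets the top bit; unpack_* recover index 42 and "left".
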