-Add experimental GPU algorithm for lossguided mode (#2755)

-Improved GPU algorithm unit tests
-Removed some thrust code to improve compile times
This commit is contained in:
Rory Mitchell 2017-10-01 00:18:35 +13:00 committed by GitHub
parent 69c3b78a29
commit 4cb2f7598b
14 changed files with 1291 additions and 593 deletions

View File

@ -8,6 +8,7 @@
#include <dmlc/base.h>
#include <dmlc/omp.h>
#include <cmath>
/*!
* \brief string flag for R library, to leave hooks when needed.
@ -163,7 +164,7 @@ class bst_gpair_internal {
friend std::ostream &operator<<(std::ostream &os,
const bst_gpair_internal<T> &g) {
os << g.grad_ << "/" << g.hess_;
os << g.GetGrad() << "/" << g.GetHess();
return os;
}
};
@ -178,11 +179,11 @@ inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetHess() const {
}
template<>
inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetGrad(float g) {
grad_ = g * 1e5;
grad_ = std::round(g * 1e5);
}
template<>
inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetHess(float h) {
hess_ = h * 1e5;
hess_ = std::round(h * 1e5);
}
} // namespace detail

View File

@ -2,9 +2,9 @@
* Copyright 2017 XGBoost contributors
*/
#pragma once
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/system_error.h>
#include <xgboost/logging.h>
#include <algorithm>
@ -58,10 +58,20 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
return code;
}
template <typename T>
T *raw(thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
template <typename T>
const T *raw(const thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
inline int n_visible_devices() {
int n_visgpus = 0;
cudaGetDeviceCount(&n_visgpus);
dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
return n_visgpus;
}
@ -127,29 +137,6 @@ inline int get_device_idx(int gpu_id) {
return (std::abs(gpu_id) + 0) % dh::n_visible_devices();
}
/*
* Timers
*/
struct Timer {
typedef std::chrono::high_resolution_clock ClockT;
typedef std::chrono::high_resolution_clock::time_point TimePointT;
TimePointT start;
Timer() { reset(); }
void reset() { start = ClockT::now(); }
int64_t elapsed() const { return (ClockT::now() - start).count(); }
double elapsedSeconds() const {
return elapsed() * ((double)ClockT::period::num / ClockT::period::den);
}
void printElapsed(std::string label) {
// synchronize_n_devices(n_devices, dList);
printf("%s:\t %fs\n", label.c_str(), elapsedSeconds());
reset();
}
};
/*
* Range iterator
*/
@ -224,6 +211,68 @@ __device__ void block_fill(IterT begin, size_t n, ValueT value) {
}
}
/*
* Kernel launcher
*/
template <typename T1, typename T2>
T1 div_round_up(const T1 a, const T2 b) {
return static_cast<T1>(ceil(static_cast<double>(a) / b));
}
template <typename L>
__global__ void launch_n_kernel(size_t begin, size_t end, L lambda) {
for (auto i : grid_stride_range(begin, end)) {
lambda(i);
}
}
template <typename L>
__global__ void launch_n_kernel(int device_idx, size_t begin, size_t end,
L lambda) {
for (auto i : grid_stride_range(begin, end)) {
lambda(i, device_idx);
}
}
template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
inline void launch_n(int device_idx, size_t n, L lambda) {
if (n == 0) {
return;
}
safe_cuda(cudaSetDevice(device_idx));
// TODO: Template on n so GRID_SIZE always fits into int.
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(static_cast<size_t>(0), n,
lambda);
}
/*
* Timers
*/
struct Timer {
typedef std::chrono::high_resolution_clock ClockT;
typedef std::chrono::high_resolution_clock::time_point TimePointT;
typedef std::chrono::high_resolution_clock::duration DurationT;
typedef std::chrono::duration<double> SecondsT;
TimePointT start;
DurationT elapsed;
Timer() { Reset(); }
void Reset() {
elapsed = DurationT::zero();
Start();
}
void Start() { start = ClockT::now(); }
void Stop() { elapsed += ClockT::now() - start; }
double ElapsedSeconds() const { return SecondsT(elapsed).count(); }
void PrintElapsed(std::string label) {
printf("%s:\t %fs\n", label.c_str(), SecondsT(elapsed).count());
Reset();
}
};
/*
* Memory
*/
@ -273,8 +322,9 @@ class dvec {
}
void fill(T value) {
safe_cuda(cudaSetDevice(_device_idx));
thrust::fill_n(thrust::device_pointer_cast(_ptr), size(), value);
auto d_ptr = _ptr;
launch_n(_device_idx, size(),
[=] __device__(size_t idx) { d_ptr[idx] = value; });
}
void print() {
@ -304,7 +354,9 @@ class dvec {
}
safe_cuda(cudaSetDevice(this->device_idx()));
if (other.device_idx() == this->device_idx()) {
thrust::copy(other.tbegin(), other.tend(), this->tbegin());
dh::safe_cuda(cudaMemcpy(this->data(), other.data(),
other.size() * sizeof(T),
cudaMemcpyDeviceToDevice));
} else {
std::cout << "deviceother: " << other.device_idx()
<< " devicethis: " << this->device_idx() << std::endl;
@ -496,6 +548,12 @@ struct CubMemory {
~CubMemory() { Free(); }
template <typename T>
T* Pointer()
{
return static_cast<T*>(d_temp_storage);
}
void Free() {
if (this->IsAllocated()) {
safe_cuda(cudaFree(d_temp_storage));
@ -527,15 +585,6 @@ struct CubMemory {
* Utility functions
*/
template <typename T>
void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
thrust::host_vector<T> h = v;
for (size_t i = 0; i < std::min(max_items, h.size()); i++) {
std::cout << " " << h[i];
}
std::cout << "\n";
}
template <typename T>
void print(const dvec<T> &v, size_t max_items = 10) {
std::vector<T> h = v.as_vector();
@ -545,91 +594,6 @@ void print(const dvec<T> &v, size_t max_items = 10) {
std::cout << "\n";
}
template <typename T>
void print(char *label, const thrust::device_vector<T> &v,
const char *format = "%d ", size_t max = 10) {
thrust::host_vector<T> h_v = v;
std::cout << label << ":\n";
for (size_t i = 0; i < std::min(static_cast<size_t>(h_v.size()), max); i++) {
printf(format, h_v[i]);
}
std::cout << "\n";
}
template <typename T1, typename T2>
T1 div_round_up(const T1 a, const T2 b) {
return static_cast<T1>(ceil(static_cast<double>(a) / b));
}
template <typename T>
thrust::device_ptr<T> dptr(T *d_ptr) {
return thrust::device_pointer_cast(d_ptr);
}
template <typename T>
T *raw(thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
template <typename T>
const T *raw(const thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
template <typename T>
size_t size_bytes(const thrust::device_vector<T> &v) {
return sizeof(T) * v.size();
}
/*
* Kernel launcher
*/
template <typename L>
__global__ void launch_n_kernel(size_t begin, size_t end, L lambda) {
for (auto i : grid_stride_range(begin, end)) {
lambda(i);
}
}
template <typename L>
__global__ void launch_n_kernel(int device_idx, size_t begin, size_t end,
L lambda) {
for (auto i : grid_stride_range(begin, end)) {
lambda(i, device_idx);
}
}
template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
inline void launch_n(int device_idx, size_t n, L lambda) {
safe_cuda(cudaSetDevice(device_idx));
// TODO: Template on n so GRID_SIZE always fits into int.
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
#if defined(__CUDACC__)
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(static_cast<size_t>(0), n,
lambda);
#endif
}
// if n_devices=-1, then use all visible devices
template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
inline void multi_launch_n(size_t n, int n_devices, L lambda) {
n_devices = n_devices < 0 ? n_visible_devices() : n_devices;
CHECK_LE(n_devices, n_visible_devices()) << "Number of devices requested "
"needs to be less than equal to "
"number of visible devices.";
// TODO: Template on n so GRID_SIZE always fits into int.
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
#if defined(__CUDACC__)
n_devices = n_devices > n ? n : n_devices;
for (int device_idx = 0; device_idx < n_devices; device_idx++) {
safe_cuda(cudaSetDevice(device_idx));
size_t begin = (n / n_devices) * device_idx;
size_t end = std::min((n / n_devices) * (device_idx + 1), n);
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(device_idx, begin, end,
lambda);
}
#endif
}
/**
* @brief Helper macro to measure timing on GPU
* @param call the GPU call

View File

@ -110,6 +110,7 @@ struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
.add_enum("hist", 3)
.add_enum("gpu_exact", 4)
.add_enum("gpu_hist", 5)
.add_enum("gpu_hist_experimental", 6)
.describe("Choice of tree construction method.");
DMLC_DECLARE_FIELD(test_flag).set_default("").describe(
"Internal test flag");
@ -178,6 +179,13 @@ class LearnerImpl : public Learner {
if (cfg_.count("predictor") == 0) {
cfg_["predictor"] = "gpu_predictor";
}
} else if (tparam.tree_method == 6) {
if (cfg_.count("updater") == 0) {
cfg_["updater"] = "grow_gpu_hist_experimental,prune";
}
if (cfg_.count("predictor") == 0) {
cfg_["predictor"] = "gpu_predictor";
}
}
}

View File

@ -3,6 +3,7 @@
*/
#include <dmlc/parameter.h>
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <xgboost/data.h>
#include <xgboost/predictor.h>
#include <xgboost/tree_model.h>

View File

@ -271,6 +271,7 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess
return -2.0 * (ret + p.reg_alpha * std::abs(w));
}
}
// calculate weight given the statistics
template <typename TrainingParams, typename T>
XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
@ -292,6 +293,11 @@ XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
return dw;
}
template <typename TrainingParams, typename gpair_t>
XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, gpair_t sum_grad) {
return CalcWeight(p, sum_grad.GetGrad(), sum_grad.GetHess());
}
/*! \brief core statistics used for tree construction */
struct XGBOOST_ALIGNAS(16) GradStats {
/*! \brief sum gradient statistics */

View File

@ -35,6 +35,7 @@ DMLC_REGISTRY_LINK_TAG(updater_sync);
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(updater_gpu);
DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);
DMLC_REGISTRY_LINK_TAG(updater_gpu_hist_experimental);
#endif
} // namespace tree
} // namespace xgboost

View File

@ -23,14 +23,13 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu);
*/
static HOST_DEV_INLINE node_id_t abs2uniqKey(int tid, const node_id_t* abs,
const int* colIds, node_id_t nodeStart,
int nKeys) {
const int* colIds,
node_id_t nodeStart, int nKeys) {
int a = abs[tid];
if (a == UNUSED_NODE) return a;
return ((a - nodeStart) + (colIds[tid] * nKeys));
}
/**
* @struct Pair
* @brief Pair used for key basd scan operations on bst_gpair
@ -284,7 +283,7 @@ DEV_INLINE void atomicArgMax(ExactSplitCandidate* address,
DEV_INLINE void argMaxWithAtomics(
int id, ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals, const int* colIds,
const node_id_t* nodeAssigns, const DeviceDenseNode* nodes, int nUniqKeys,
const node_id_t* nodeAssigns, const DeviceNodeStats* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const GPUTrainingParam& param) {
int nodeId = nodeAssigns[id];
// @todo: this is really a bad check! but will be fixed when we move
@ -296,7 +295,7 @@ DEV_INLINE void argMaxWithAtomics(
int sumId = abs2uniqKey(id, nodeAssigns, colIds, nodeStart, nUniqKeys);
bst_gpair colSum = gradSums[sumId];
int uid = nodeId - nodeStart;
DeviceDenseNode n = nodes[nodeId];
DeviceNodeStats n = nodes[nodeId];
bst_gpair parentSum = n.sum_gradients;
float parentGain = n.root_gain;
bool tmp;
@ -313,7 +312,7 @@ DEV_INLINE void argMaxWithAtomics(
__global__ void atomicArgMaxByKeyGmem(
ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals, const int* colIds,
const node_id_t* nodeAssigns, const DeviceDenseNode* nodes, int nUniqKeys,
const node_id_t* nodeAssigns, const DeviceNodeStats* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const TrainParam param) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
const int stride = blockDim.x * gridDim.x;
@ -327,7 +326,7 @@ __global__ void atomicArgMaxByKeyGmem(
__global__ void atomicArgMaxByKeySmem(
ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals, const int* colIds,
const node_id_t* nodeAssigns, const DeviceDenseNode* nodes, int nUniqKeys,
const node_id_t* nodeAssigns, const DeviceNodeStats* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const TrainParam param) {
extern __shared__ char sArr[];
ExactSplitCandidate* sNodeSplits =
@ -372,7 +371,7 @@ template <int BLKDIM = 256, int ITEMS_PER_THREAD = 4>
void argMaxByKey(ExactSplitCandidate* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals,
const int* colIds, const node_id_t* nodeAssigns,
const DeviceDenseNode* nodes, int nUniqKeys,
const DeviceNodeStats* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const TrainParam param,
ArgMaxByKeyAlgo algo) {
dh::fillConst<ExactSplitCandidate, BLKDIM, ITEMS_PER_THREAD>(
@ -406,7 +405,7 @@ __global__ void assignColIds(int* colIds, const int* colOffsets) {
}
__global__ void fillDefaultNodeIds(node_id_t* nodeIdsPerInst,
const DeviceDenseNode* nodes, int nRows) {
const DeviceNodeStats* nodes, int nRows) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
if (id >= nRows) {
return;
@ -416,7 +415,7 @@ __global__ void fillDefaultNodeIds(node_id_t* nodeIdsPerInst,
if (nId == UNUSED_NODE) {
return;
}
const DeviceDenseNode n = nodes[nId];
const DeviceNodeStats n = nodes[nId];
node_id_t result;
if (n.IsLeaf() || n.IsUnused()) {
result = UNUSED_NODE;
@ -430,7 +429,7 @@ __global__ void fillDefaultNodeIds(node_id_t* nodeIdsPerInst,
__global__ void assignNodeIds(node_id_t* nodeIdsPerInst, int* nodeLocations,
const node_id_t* nodeIds, const int* instId,
const DeviceDenseNode* nodes,
const DeviceNodeStats* nodes,
const int* colOffsets, const float* vals,
int nVals, int nCols) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
@ -443,7 +442,7 @@ __global__ void assignNodeIds(node_id_t* nodeIdsPerInst, int* nodeLocations,
int nId = nodeIds[id];
// if this element belongs to none of the currently active node-id's
if (nId != UNUSED_NODE) {
const DeviceDenseNode n = nodes[nId];
const DeviceNodeStats n = nodes[nId];
int colId = n.fidx;
// printf("nid=%d colId=%d id=%d\n", nId, colId, id);
int start = colOffsets[colId];
@ -457,7 +456,7 @@ __global__ void assignNodeIds(node_id_t* nodeIdsPerInst, int* nodeLocations,
}
}
__global__ void markLeavesKernel(DeviceDenseNode* nodes, int len) {
__global__ void markLeavesKernel(DeviceNodeStats* nodes, int len) {
int id = (blockIdx.x * blockDim.x) + threadIdx.x;
if ((id < len) && !nodes[id].IsUnused()) {
int lid = (id << 1) + 1;
@ -486,7 +485,7 @@ class GPUMaker : public TreeUpdater {
dh::dvec<bst_gpair> gradsInst;
dh::dvec2<node_id_t> nodeAssigns;
dh::dvec2<int> nodeLocations;
dh::dvec<DeviceDenseNode> nodes;
dh::dvec<DeviceNodeStats> nodes;
dh::dvec<node_id_t> nodeAssignsPerInst;
dh::dvec<bst_gpair> gradSums;
dh::dvec<bst_gpair> gradScans;
@ -573,7 +572,7 @@ class GPUMaker : public TreeUpdater {
int nodeInstId =
abs2uniqKey(idx, d_nodeAssigns, d_colIds, nodeStart, nUniqKeys);
bool missingLeft = true;
const DeviceDenseNode& n = d_nodes[absNodeId];
const DeviceNodeStats& n = d_nodes[absNodeId];
bst_gpair gradScan = d_gradScans[idx];
bst_gpair gradSum = d_gradSums[nodeInstId];
float thresh = d_vals[idx];
@ -588,12 +587,13 @@ class GPUMaker : public TreeUpdater {
// Create children
d_nodes[left_child_nidx(absNodeId)] =
DeviceDenseNode(lGradSum, left_child_nidx(absNodeId), gpu_param);
DeviceNodeStats(lGradSum, left_child_nidx(absNodeId), gpu_param);
d_nodes[right_child_nidx(absNodeId)] =
DeviceDenseNode(rGradSum, right_child_nidx(absNodeId), gpu_param);
DeviceNodeStats(rGradSum, right_child_nidx(absNodeId), gpu_param);
// Set split for parent
d_nodes[absNodeId].SetSplit(thresh, colId,
missingLeft ? LeftDir : RightDir);
missingLeft ? LeftDir : RightDir, lGradSum,
rGradSum);
} else {
// cannot be split further, so this node is a leaf!
d_nodes[absNodeId].root_gain = -FLT_MAX;
@ -677,7 +677,7 @@ class GPUMaker : public TreeUpdater {
instIds.current_dvec() = fId;
colOffsets = offset;
dh::segmentedSort<float, int>(&tmp_mem, &vals, &instIds, nVals, nCols,
colOffsets);
colOffsets);
vals_cached = vals.current_dvec();
instIds_cached = instIds.current_dvec();
assignColIds<<<nCols, 512>>>(colIds.data(), colOffsets.data());
@ -695,7 +695,7 @@ class GPUMaker : public TreeUpdater {
void initNodeData(int level, node_id_t nodeStart, int nNodes) {
// all instances belong to root node at the beginning!
if (level == 0) {
nodes.fill(DeviceDenseNode());
nodes.fill(DeviceNodeStats());
nodeAssigns.current_dvec().fill(0);
nodeAssignsPerInst.fill(0);
// for root node, just update the gradient/score/weight/id info
@ -705,7 +705,7 @@ class GPUMaker : public TreeUpdater {
auto d_sums = gradSums.data();
auto gpu_params = GPUTrainingParam(param);
dh::launch_n(param.gpu_id, 1, [=] __device__(int idx) {
d_nodes[0] = DeviceDenseNode(d_sums[0], 0, gpu_params);
d_nodes[0] = DeviceNodeStats(d_sums[0], 0, gpu_params);
});
} else {
const int BlkDim = 256;
@ -722,7 +722,7 @@ class GPUMaker : public TreeUpdater {
colOffsets.data(), vals.current(), nVals, nCols);
// gather the node assignments across all other columns too
dh::gather(dh::get_device_idx(param.gpu_id), nodeAssigns.current(),
nodeAssignsPerInst.data(), instIds.current(), nVals);
nodeAssignsPerInst.data(), instIds.current(), nVals);
sortKeys(level);
}
}
@ -733,8 +733,8 @@ class GPUMaker : public TreeUpdater {
segmentedSort(&tmp_mem, &nodeAssigns, &nodeLocations, nVals, nCols,
colOffsets, 0, level + 1);
dh::gather<float, int>(dh::get_device_idx(param.gpu_id), vals.other(),
vals.current(), instIds.other(), instIds.current(),
nodeLocations.current(), nVals);
vals.current(), instIds.other(), instIds.current(),
nodeLocations.current(), nVals);
vals.buff().selector ^= 1;
instIds.buff().selector ^= 1;
}

View File

@ -4,13 +4,13 @@
#pragma once
#include <thrust/random.h>
#include <cstdio>
#include <cub/cub.cuh>
#include <stdexcept>
#include <string>
#include <vector>
#include "../common/device_helpers.cuh"
#include "../common/random.h"
#include "param.h"
#include <cub/cub.cuh>
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace tree {
@ -52,7 +52,47 @@ enum DefaultDirection {
RightDir
};
struct DeviceDenseNode {
struct DeviceSplitCandidate {
float loss_chg;
DefaultDirection dir;
float fvalue;
int findex;
bst_gpair_integer left_sum;
bst_gpair_integer right_sum;
__host__ __device__ DeviceSplitCandidate()
: loss_chg(-FLT_MAX), dir(LeftDir), fvalue(0), findex(-1) {}
template <typename param_t>
__host__ __device__ void Update(const DeviceSplitCandidate &other,
const param_t& param) {
if (other.loss_chg > loss_chg &&
other.left_sum.GetHess() >= param.min_child_weight &&
other.right_sum.GetHess() >= param.min_child_weight) {
*this = other;
}
}
__device__ void Update(float loss_chg_in, DefaultDirection dir_in,
float fvalue_in, int findex_in,
bst_gpair_integer left_sum_in,
bst_gpair_integer right_sum_in,
const GPUTrainingParam& param) {
if (loss_chg_in > loss_chg &&
left_sum_in.GetHess() >= param.min_child_weight &&
right_sum_in.GetHess() >= param.min_child_weight) {
loss_chg = loss_chg_in;
dir = dir_in;
fvalue = fvalue_in;
left_sum = left_sum_in;
right_sum = right_sum_in;
findex = findex_in;
}
}
__device__ bool IsValid() const { return loss_chg > 0.0f; }
};
struct DeviceNodeStats {
bst_gpair sum_gradients;
float root_gain;
float weight;
@ -61,35 +101,50 @@ struct DeviceDenseNode {
DefaultDirection dir;
/** threshold value for comparison */
float fvalue;
bst_gpair left_sum;
bst_gpair right_sum;
/** \brief The feature index. */
int fidx;
/** node id (used as key for reduce/scan) */
node_id_t idx;
HOST_DEV_INLINE DeviceDenseNode()
HOST_DEV_INLINE DeviceNodeStats()
: sum_gradients(),
root_gain(-FLT_MAX),
weight(-FLT_MAX),
dir(LeftDir),
fvalue(0.f),
left_sum(),
right_sum(),
fidx(UNUSED_NODE),
idx(UNUSED_NODE) {}
HOST_DEV_INLINE DeviceDenseNode(bst_gpair sum_gradients, node_id_t nidx,
const GPUTrainingParam& param)
template <typename param_t>
HOST_DEV_INLINE DeviceNodeStats(bst_gpair sum_gradients, node_id_t nidx,
const param_t& param)
: sum_gradients(sum_gradients),
dir(LeftDir),
fvalue(0.f),
fidx(UNUSED_NODE),
idx(nidx) {
this->root_gain = CalcGain(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
this->weight = CalcWeight(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
this->root_gain =
CalcGain(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
this->weight =
CalcWeight(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
}
HOST_DEV_INLINE void SetSplit(float fvalue, int fidx, DefaultDirection dir) {
HOST_DEV_INLINE void SetSplit(float fvalue, int fidx, DefaultDirection dir,
bst_gpair left_sum, bst_gpair right_sum) {
this->fvalue = fvalue;
this->fidx = fidx;
this->dir = dir;
this->left_sum = left_sum;
this->right_sum = right_sum;
}
HOST_DEV_INLINE void SetSplit(const DeviceSplitCandidate& split) {
this->SetSplit(split.fvalue, split.findex, split.dir, split.left_sum,
split.right_sum);
}
/** Tells whether this node is part of the decision tree */
@ -101,18 +156,23 @@ struct DeviceDenseNode {
}
};
template <typename T>
struct SumCallbackOp {
// Running prefix
T running_total;
// Constructor
__device__ SumCallbackOp() : running_total(T()) {}
__device__ T operator()(T block_aggregate) {
T old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
template <typename gpair_t>
__device__ inline float device_calc_loss_chg(
const GPUTrainingParam& param, const gpair_t& scan, const gpair_t& missing,
const gpair_t& parent_sum, const float& parent_gain, bool missing_left) {
gpair_t left = scan;
if (missing_left) {
left += missing;
}
const GPUTrainingParam& param, const gpair_t& left, const gpair_t& parent_sum, const float& parent_gain) {
gpair_t right = parent_sum - left;
float left_gain = CalcGain(param, left.GetGrad(), left.GetHess());
float right_gain = CalcGain(param, right.GetGrad(), right.GetHess());
return left_gain + right_gain - parent_gain;
@ -126,9 +186,9 @@ __device__ float inline loss_chg_missing(const gpair_t& scan,
const GPUTrainingParam& param,
bool& missing_left_out) { // NOLINT
float missing_left_loss =
device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true);
device_calc_loss_chg(param, scan + missing, parent_sum, parent_gain);
float missing_right_loss = device_calc_loss_chg(
param, scan, missing, parent_sum, parent_gain, false);
param, scan, parent_sum, parent_gain);
if (missing_left_loss >= missing_right_loss) {
missing_left_out = true;
@ -168,14 +228,14 @@ __host__ __device__ inline bool is_left_child(int nidx) {
// Copy gpu dense representation of tree to xgboost sparse representation
inline void dense2sparse_tree(RegTree* p_tree,
const dh::dvec<DeviceDenseNode>& nodes,
const dh::dvec<DeviceNodeStats>& nodes,
const TrainParam& param) {
RegTree& tree = *p_tree;
std::vector<DeviceDenseNode> h_nodes = nodes.as_vector();
std::vector<DeviceNodeStats> h_nodes = nodes.as_vector();
int nid = 0;
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
const DeviceDenseNode& n = h_nodes[gpu_nid];
const DeviceNodeStats& n = h_nodes[gpu_nid];
if (!n.IsUnused() && !n.IsLeaf()) {
tree.AddChilds(nid);
tree[nid].set_split(n.fidx, n.fvalue, n.dir == LeftDir);

View File

@ -43,12 +43,14 @@ struct DeviceGMat {
gidx = common::CompressedIterator<uint32_t>(gidx_buffer.data(), n_bins);
// row_ptr
thrust::copy(gmat.row_ptr.data() + row_begin,
gmat.row_ptr.data() + row_end + 1, row_ptr.tbegin());
dh::safe_cuda(cudaMemcpy(row_ptr.data(), gmat.row_ptr.data() + row_begin,
row_ptr.size() * sizeof(size_t),
cudaMemcpyHostToDevice));
// normalise row_ptr
size_t start = gmat.row_ptr[row_begin];
thrust::transform(row_ptr.tbegin(), row_ptr.tend(), row_ptr.tbegin(),
[=] __device__(size_t val) { return val - start; });
auto d_row_ptr = row_ptr.data();
dh::launch_n(row_ptr.device_idx(), row_ptr.size(),
[=] __device__(size_t idx) { d_row_ptr[idx] -= start; });
}
};
@ -61,12 +63,15 @@ struct HistHelper {
__device__ void Add(bst_gpair gpair, int gidx, int nidx) const {
int hist_idx = nidx * n_bins + gidx;
auto dst_ptr = reinterpret_cast<unsigned long long int*>(&d_hist[hist_idx]); // NOLINT
auto dst_ptr =
reinterpret_cast<unsigned long long int*>(&d_hist[hist_idx]); // NOLINT
gpair_sum_t tmp(gpair.GetGrad(), gpair.GetHess());
auto src_ptr = reinterpret_cast<gpair_sum_t::value_t*>(&tmp);
atomicAdd(dst_ptr, static_cast<unsigned long long int>(*src_ptr)); // NOLINT
atomicAdd(dst_ptr + 1, static_cast<unsigned long long int>(*(src_ptr + 1))); // NOLINT
atomicAdd(dst_ptr,
static_cast<unsigned long long int>(*src_ptr)); // NOLINT
atomicAdd(dst_ptr + 1,
static_cast<unsigned long long int>(*(src_ptr + 1))); // NOLINT
}
__device__ gpair_sum_t Get(int gidx, int nidx) const {
return d_hist[nidx * n_bins + gidx];
@ -96,51 +101,10 @@ struct DeviceHist {
int LevelSize(int depth) { return n_bins * n_nodes_level(depth); }
};
struct SplitCandidate {
float loss_chg;
bool missing_left;
float fvalue;
int findex;
gpair_sum_t left_sum;
gpair_sum_t right_sum;
__host__ __device__ SplitCandidate()
: loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {}
__device__ void Update(float loss_chg_in, bool missing_left_in,
float fvalue_in, int findex_in,
gpair_sum_t left_sum_in, gpair_sum_t right_sum_in,
const GPUTrainingParam& param) {
if (loss_chg_in > loss_chg &&
left_sum_in.GetHess() >= param.min_child_weight &&
right_sum_in.GetHess() >= param.min_child_weight) {
loss_chg = loss_chg_in;
missing_left = missing_left_in;
fvalue = fvalue_in;
left_sum = left_sum_in;
right_sum = right_sum_in;
findex = findex_in;
}
}
__device__ bool IsValid() const { return loss_chg > 0.0f; }
};
struct GpairCallbackOp {
// Running prefix
gpair_sum_t running_total;
// Constructor
__device__ GpairCallbackOp() : running_total(gpair_sum_t()) {}
__device__ bst_gpair operator()(bst_gpair block_aggregate) {
gpair_sum_t old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
template <int BLOCK_THREADS>
__global__ void find_split_kernel(
const gpair_sum_t* d_level_hist, int* d_feature_segments, int depth,
int n_features, int n_bins, DeviceDenseNode* d_nodes,
int n_features, int n_bins, DeviceNodeStats* d_nodes,
int nodes_offset_device, float* d_fidx_min_map, float* d_gidx_fvalue_map,
GPUTrainingParam gpu_param, bool* d_left_child_smallest_temp,
bool colsample, int* d_feature_flags) {
@ -156,15 +120,15 @@ __global__ void find_split_kernel(
typename SumReduceT::TempStorage sum_reduce;
};
__shared__ cub::Uninitialized<SplitCandidate> uninitialized_split;
SplitCandidate& split = uninitialized_split.Alias();
__shared__ cub::Uninitialized<DeviceSplitCandidate> uninitialized_split;
DeviceSplitCandidate& split = uninitialized_split.Alias();
__shared__ cub::Uninitialized<gpair_sum_t> uninitialized_sum;
gpair_sum_t& shared_sum = uninitialized_sum.Alias();
__shared__ ArgMaxT block_max;
__shared__ TempStorage temp_storage;
if (threadIdx.x == 0) {
split = SplitCandidate();
split = DeviceSplitCandidate();
}
__syncthreads();
@ -197,7 +161,7 @@ __global__ void find_split_kernel(
}
// __syncthreads(); // no need to synch because below there is a Scan
GpairCallbackOp prefix_op = GpairCallbackOp();
auto prefix_op = SumCallbackOp<gpair_sum_t>();
for (int scan_begin = begin; scan_begin < end;
scan_begin += BLOCK_THREADS) {
bool thread_active = scan_begin + threadIdx.x < end;
@ -245,7 +209,8 @@ __global__ void find_split_kernel(
gpair_sum_t left = missing_left ? bin + missing : bin;
gpair_sum_t right = parent_sum - left;
split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
split.Update(gain, missing_left ? LeftDir : RightDir, fvalue, fidx,
left, right, gpu_param);
}
__syncthreads();
} // end scan
@ -253,17 +218,16 @@ __global__ void find_split_kernel(
// Create node
if (threadIdx.x == 0 && split.IsValid()) {
d_nodes[node_idx].SetSplit(split.fvalue, split.findex,
split.missing_left ? LeftDir : RightDir);
d_nodes[node_idx].SetSplit(split);
DeviceDenseNode& left_child = d_nodes[left_child_nidx(node_idx)];
DeviceDenseNode& right_child = d_nodes[right_child_nidx(node_idx)];
DeviceNodeStats& left_child = d_nodes[left_child_nidx(node_idx)];
DeviceNodeStats& right_child = d_nodes[right_child_nidx(node_idx)];
bool& left_child_smallest = d_left_child_smallest_temp[node_idx];
left_child =
DeviceDenseNode(split.left_sum, left_child_nidx(node_idx), gpu_param);
DeviceNodeStats(split.left_sum, left_child_nidx(node_idx), gpu_param);
right_child =
DeviceDenseNode(split.right_sum, right_child_nidx(node_idx), gpu_param);
DeviceNodeStats(split.right_sum, right_child_nidx(node_idx), gpu_param);
// Record smallest node
if (split.left_sum.GetHess() <= split.right_sum.GetHess()) {
@ -336,7 +300,7 @@ class GPUHistMaker : public TreeUpdater {
// reset static timers used across iterations
cpu_init_time = 0;
gpu_init_time = 0;
cpu_time.reset();
cpu_time.Reset();
gpu_time = 0;
// set dList member
@ -399,31 +363,31 @@ class GPUHistMaker : public TreeUpdater {
is_dense = info->num_nonzero == info->num_col * info->num_row;
dh::Timer time0;
hmat_.Init(&fmat, param.max_bin);
cpu_init_time += time0.elapsedSeconds();
cpu_init_time += time0.ElapsedSeconds();
if (param.debug_verbose) { // Only done once for each training session
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init "
<< time0.elapsedSeconds() << " sec";
<< time0.ElapsedSeconds() << " sec";
fflush(stdout);
}
time0.reset();
time0.Reset();
gmat_.cut = &hmat_;
cpu_init_time += time0.elapsedSeconds();
cpu_init_time += time0.ElapsedSeconds();
if (param.debug_verbose) { // Only done once for each training session
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.cut "
<< time0.elapsedSeconds() << " sec";
<< time0.ElapsedSeconds() << " sec";
fflush(stdout);
}
time0.reset();
time0.Reset();
gmat_.Init(&fmat);
cpu_init_time += time0.elapsedSeconds();
cpu_init_time += time0.ElapsedSeconds();
if (param.debug_verbose) { // Only done once for each training session
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.Init() "
<< time0.elapsedSeconds() << " sec";
<< time0.ElapsedSeconds() << " sec";
fflush(stdout);
}
time0.reset();
time0.Reset();
if (param.debug_verbose) { // Only done once for each training session
LOG(CONSOLE)
@ -563,9 +527,9 @@ class GPUHistMaker : public TreeUpdater {
int device_idx = dList[d_idx];
dh::safe_cuda(cudaSetDevice(device_idx));
nodes[d_idx].fill(DeviceDenseNode());
nodes_temp[d_idx].fill(DeviceDenseNode());
nodes_child_temp[d_idx].fill(DeviceDenseNode());
nodes[d_idx].fill(DeviceNodeStats());
nodes_temp[d_idx].fill(DeviceNodeStats());
nodes_child_temp[d_idx].fill(DeviceNodeStats());
position[d_idx].fill(0);
@ -584,7 +548,7 @@ class GPUHistMaker : public TreeUpdater {
dh::synchronize_n_devices(n_devices, dList);
if (!initialised) {
gpu_init_time = time1.elapsedSeconds() - cpu_init_time;
gpu_init_time = time1.ElapsedSeconds() - cpu_init_time;
gpu_time = -cpu_init_time;
if (param.debug_verbose) { // Only done once for each training session
LOG(CONSOLE) << "[GPU Plug-in] Time for GPU operations during First "
@ -701,12 +665,12 @@ class GPUHistMaker : public TreeUpdater {
dh::synchronize_n_devices(n_devices, dList);
}
}
#define MIN_BLOCK_THREADS 32
#define CHUNK_BLOCK_THREADS 32
#define MIN_BLOCK_THREADS 128
#define CHUNK_BLOCK_THREADS 128
// MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due
// to CUDA capability 35 and above requirement
// for Maximum number of threads per block
#define MAX_BLOCK_THREADS 1024
#define MAX_BLOCK_THREADS 512
void FindSplit(int depth) {
// Specialised based on max_bins
@ -783,7 +747,7 @@ class GPUHistMaker : public TreeUpdater {
dh::launch_n(device_idx, 1, [=] __device__(int idx) {
bst_gpair sum_gradients = sum;
d_nodes[idx] = DeviceDenseNode(sum_gradients, 0, gpu_param);
d_nodes[idx] = DeviceNodeStats(sum_gradients, 0, gpu_param);
});
}
// synch all devices to host before moving on (No, can avoid because
@ -802,7 +766,7 @@ class GPUHistMaker : public TreeUpdater {
int device_idx = dList[d_idx];
auto d_position = position[d_idx].data();
DeviceDenseNode* d_nodes = nodes[d_idx].data();
DeviceNodeStats* d_nodes = nodes[d_idx].data();
auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data();
auto d_gidx = device_matrix[d_idx].gidx;
int n_columns = info->num_col;
@ -814,7 +778,7 @@ class GPUHistMaker : public TreeUpdater {
if (!is_active(pos, depth)) {
return;
}
DeviceDenseNode node = d_nodes[pos];
DeviceNodeStats node = d_nodes[pos];
if (node.IsLeaf()) {
return;
@ -842,7 +806,7 @@ class GPUHistMaker : public TreeUpdater {
auto d_position = position[d_idx].data();
auto d_position_tmp = position_tmp[d_idx].data();
DeviceDenseNode* d_nodes = nodes[d_idx].data();
DeviceNodeStats* d_nodes = nodes[d_idx].data();
auto d_gidx_feature_map = gidx_feature_map[d_idx].data();
auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data();
auto d_gidx = device_matrix[d_idx].gidx;
@ -862,7 +826,7 @@ class GPUHistMaker : public TreeUpdater {
return;
}
DeviceDenseNode node = d_nodes[pos];
DeviceNodeStats node = d_nodes[pos];
if (node.IsLeaf()) {
d_position_tmp[local_idx] = pos;
@ -887,7 +851,7 @@ class GPUHistMaker : public TreeUpdater {
return;
}
DeviceDenseNode node = d_nodes[pos];
DeviceNodeStats node = d_nodes[pos];
if (node.IsLeaf()) {
return;
@ -976,8 +940,10 @@ class GPUHistMaker : public TreeUpdater {
d_prediction_cache[local_idx] += d_nodes[pos].weight * eps;
});
thrust::copy(prediction_cache[d_idx].tbegin(),
prediction_cache[d_idx].tend(), &out_preds[row_begin]);
dh::safe_cuda(
cudaMemcpy(&out_preds[row_begin], prediction_cache[d_idx].data(),
prediction_cache[d_idx].size() * sizeof(bst_float),
cudaMemcpyDeviceToHost));
}
dh::synchronize_n_devices(n_devices, dList);
@ -1003,7 +969,7 @@ class GPUHistMaker : public TreeUpdater {
dh::safe_cuda(cudaSetDevice(master_device));
dense2sparse_tree(p_tree, nodes[0], param);
gpu_time += time0.elapsedSeconds();
gpu_time += time0.ElapsedSeconds();
if (param.debug_verbose) {
LOG(CONSOLE)
@ -1014,10 +980,10 @@ class GPUHistMaker : public TreeUpdater {
if (param.debug_verbose) {
LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time "
<< cpu_time.elapsedSeconds() << " sec";
<< cpu_time.ElapsedSeconds() << " sec";
LOG(CONSOLE)
<< "[GPU Plug-in] Cumulative CPU Time excluding initial time "
<< (cpu_time.elapsedSeconds() - cpu_init_time - gpu_time) << " sec";
<< (cpu_time.ElapsedSeconds() - cpu_init_time - gpu_time) << " sec";
fflush(stdout);
}
}
@ -1048,9 +1014,9 @@ class GPUHistMaker : public TreeUpdater {
std::vector<dh::CubMemory> temp_memory;
std::vector<DeviceHist> hist_vec;
std::vector<dh::dvec<DeviceDenseNode>> nodes;
std::vector<dh::dvec<DeviceDenseNode>> nodes_temp;
std::vector<dh::dvec<DeviceDenseNode>> nodes_child_temp;
std::vector<dh::dvec<DeviceNodeStats>> nodes;
std::vector<dh::dvec<DeviceNodeStats>> nodes_temp;
std::vector<dh::dvec<DeviceNodeStats>> nodes_child_temp;
std::vector<dh::dvec<bool>> left_child_smallest;
std::vector<dh::dvec<bool>> left_child_smallest_temp;
std::vector<dh::dvec<int>> feature_flags;

View File

@ -0,0 +1,833 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include <thrust/count.h>
#include <thrust/sort.h>
#include <xgboost/tree_updater.h>
#include <algorithm>
#include <memory>
#include <queue>
#include <utility>
#include <vector>
#include "../common/compressed_iterator.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
#include "param.h"
#include "updater_gpu_common.cuh"
namespace xgboost {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist_experimental);
template <int BLOCK_THREADS, typename reduce_t, typename temp_storage_t>
__device__ bst_gpair_integer ReduceFeature(const bst_gpair_integer* begin,
const bst_gpair_integer* end,
temp_storage_t* temp_storage) {
__shared__ cub::Uninitialized<bst_gpair_integer> uninitialized_sum;
bst_gpair_integer& shared_sum = uninitialized_sum.Alias();
bst_gpair_integer local_sum = bst_gpair_integer();
for (auto itr = begin; itr < end; itr += BLOCK_THREADS) {
bool thread_active = itr + threadIdx.x < end;
// Scan histogram
bst_gpair_integer bin =
thread_active ? *(itr + threadIdx.x) : bst_gpair_integer();
local_sum += reduce_t(temp_storage->sum_reduce).Reduce(bin, cub::Sum());
}
if (threadIdx.x == 0) {
shared_sum = local_sum;
}
__syncthreads();
return shared_sum;
}
template <int BLOCK_THREADS, typename reduce_t, typename scan_t,
typename max_reduce_t, typename temp_storage_t>
__device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist,
const int* feature_segments, float min_fvalue,
const float* gidx_fvalue_map,
DeviceSplitCandidate* best_split,
const DeviceNodeStats& node,
const GPUTrainingParam& param,
temp_storage_t* temp_storage) {
int gidx_begin = feature_segments[fidx];
int gidx_end = feature_segments[fidx + 1];
bst_gpair_integer feature_sum = ReduceFeature<BLOCK_THREADS, reduce_t>(
hist + gidx_begin, hist + gidx_end, temp_storage);
auto prefix_op = SumCallbackOp<bst_gpair_integer>();
for (int scan_begin = gidx_begin; scan_begin < gidx_end;
scan_begin += BLOCK_THREADS) {
bool thread_active = scan_begin + threadIdx.x < gidx_end;
bst_gpair_integer bin =
thread_active ? hist[scan_begin + threadIdx.x] : bst_gpair_integer();
scan_t(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
// Calculate gain
bst_gpair_integer parent_sum = bst_gpair_integer(node.sum_gradients);
bst_gpair_integer missing = parent_sum - feature_sum;
bool missing_left = true;
const float null_gain = -FLT_MAX;
float gain = null_gain;
if (thread_active) {
gain = loss_chg_missing(bin, missing, parent_sum, node.root_gain, param,
missing_left);
}
__syncthreads();
// Find thread with best gain
cub::KeyValuePair<int, float> tuple(threadIdx.x, gain);
cub::KeyValuePair<int, float> best =
max_reduce_t(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax());
__shared__ cub::KeyValuePair<int, float> block_max;
if (threadIdx.x == 0) {
block_max = best;
}
__syncthreads();
// Best thread updates split
if (threadIdx.x == block_max.key) {
int gidx = scan_begin + threadIdx.x;
float fvalue =
gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1];
bst_gpair_integer left = missing_left ? bin + missing : bin;
bst_gpair_integer right = parent_sum - left;
best_split->Update(gain, missing_left ? LeftDir : RightDir, fvalue, fidx,
left, right, param);
}
__syncthreads();
}
}
template <int BLOCK_THREADS>
__global__ void evaluate_split_kernel(const bst_gpair_integer* d_hist, int nidx,
int n_features, DeviceNodeStats nodes,
const int* d_feature_segments,
const float* d_fidx_min_map,
const float* d_gidx_fvalue_map,
GPUTrainingParam gpu_param,
DeviceSplitCandidate* d_split) {
typedef cub::KeyValuePair<int, float> ArgMaxT;
typedef cub::BlockScan<bst_gpair_integer, BLOCK_THREADS,
cub::BLOCK_SCAN_WARP_SCANS>
BlockScanT;
typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
typedef cub::BlockReduce<bst_gpair_integer, BLOCK_THREADS> SumReduceT;
union TempStorage {
typename BlockScanT::TempStorage scan;
typename MaxReduceT::TempStorage max_reduce;
typename SumReduceT::TempStorage sum_reduce;
};
__shared__ cub::Uninitialized<DeviceSplitCandidate> uninitialized_split;
DeviceSplitCandidate& best_split = uninitialized_split.Alias();
__shared__ TempStorage temp_storage;
if (threadIdx.x == 0) {
best_split = DeviceSplitCandidate();
}
__syncthreads();
auto fidx = blockIdx.x;
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT>(
fidx, d_hist, d_feature_segments, d_fidx_min_map[fidx], d_gidx_fvalue_map,
&best_split, nodes, gpu_param, &temp_storage);
__syncthreads();
if (threadIdx.x == 0) {
// Record best loss
d_split[fidx] = best_split;
}
}
// Find a gidx value for a given feature otherwise return -1 if not found
template <typename gidx_iter_t>
__device__ int BinarySearchRow(bst_uint begin, bst_uint end, gidx_iter_t data,
int fidx_begin, int fidx_end) {
// for(auto i = begin; i < end; i++)
//{
// auto gidx = data[i];
// if (gidx >= fidx_begin&&gidx < fidx_end) return gidx;
//}
// return -1;
bst_uint previous_middle = UINT32_MAX;
while (end != begin) {
auto middle = begin + (end - begin) / 2;
if (middle == previous_middle) {
break;
}
previous_middle = middle;
auto gidx = data[middle];
if (gidx >= fidx_begin && gidx < fidx_end) {
return gidx;
} else if (gidx < fidx_begin) {
begin = middle;
} else {
end = middle;
}
}
// Value is missing
return -1;
}
template <int BLOCK_THREADS>
__global__ void RadixSortSmall(bst_uint* d_ridx, int* d_position, bst_uint n) {
typedef cub::BlockRadixSort<int, BLOCK_THREADS, 1, bst_uint> BlockRadixSort;
__shared__ typename BlockRadixSort::TempStorage temp_storage;
bool thread_active = threadIdx.x < n;
int thread_key[1];
bst_uint thread_value[1];
thread_key[0] = thread_active ? d_position[threadIdx.x] : INT_MAX;
thread_value[0] = thread_active ? d_ridx[threadIdx.x] : UINT_MAX;
BlockRadixSort(temp_storage).Sort(thread_key, thread_value);
if (thread_active) {
d_position[threadIdx.x] = thread_key[0];
d_ridx[threadIdx.x] = thread_value[0];
}
}
struct DeviceHistogram {
dh::bulk_allocator<dh::memory_type::DEVICE> ba;
dh::dvec<bst_gpair_integer> data;
std::map<int, bst_gpair_integer*> node_map;
int n_bins;
void Init(int device_idx, int max_nodes, int n_bins, bool silent) {
this->n_bins = n_bins;
ba.allocate(device_idx, silent, &data, max_nodes * n_bins);
}
void Reset() {
data.fill(bst_gpair_integer());
node_map.clear();
}
void AddNode(int nidx) {
CHECK_EQ(node_map.count(nidx), 0)
<< nidx << " already exists in the histogram.";
node_map[nidx] = data.data() + n_bins * node_map.size();
}
};
// Manage memory for a single GPU
struct DeviceShard {
int device_idx;
int normalised_device_idx; // Device index counting from param.gpu_id
dh::bulk_allocator<dh::memory_type::DEVICE> ba;
dh::dvec<common::compressed_byte_t> gidx_buffer;
dh::dvec<bst_gpair> gpair;
dh::dvec2<bst_uint> ridx;
dh::dvec2<int> position;
std::vector<std::pair<int64_t, int64_t>> ridx_segments;
dh::dvec<int> feature_segments;
dh::dvec<float> gidx_fvalue_map;
dh::dvec<float> min_fvalue;
std::vector<bst_gpair> node_sum_gradients;
common::CompressedIterator<uint32_t> gidx;
int row_stride;
bst_uint row_start_idx;
bst_uint row_end_idx;
bst_uint n_rows;
int n_bins;
int null_gidx_value;
DeviceHistogram hist;
std::vector<cudaStream_t> streams;
dh::CubMemory temp_memory;
DeviceShard(int device_idx, int normalised_device_idx,
const common::GHistIndexMatrix& gmat, bst_uint row_begin,
bst_uint row_end, int n_bins, TrainParam param)
: device_idx(device_idx),
normalised_device_idx(normalised_device_idx),
row_start_idx(row_begin),
row_end_idx(row_end),
n_rows(row_end - row_begin),
n_bins(n_bins),
null_gidx_value(n_bins) {
// Convert to ELLPACK matrix representation
int max_elements_row = 0;
for (int i = row_begin; i < row_end; i++) {
max_elements_row =
(std::max)(max_elements_row,
static_cast<int>(gmat.row_ptr[i + 1] - gmat.row_ptr[i]));
}
row_stride = max_elements_row;
std::vector<int> ellpack_matrix(row_stride * n_rows, null_gidx_value);
for (int i = row_begin; i < row_end; i++) {
int row_count = 0;
for (int j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
ellpack_matrix[i * row_stride + row_count] = gmat.index[j];
row_count++;
}
}
// Allocate
int num_symbols = n_bins + 1;
size_t compressed_size_bytes =
common::CompressedBufferWriter::CalculateBufferSize(
ellpack_matrix.size(), num_symbols);
int max_nodes =
param.max_leaves > 0 ? param.max_leaves * 2 : n_nodes(param.max_depth);
ba.allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes,
&gpair, n_rows, &ridx, n_rows, &position, n_rows,
&feature_segments, gmat.cut->row_ptr.size(), &gidx_fvalue_map,
gmat.cut->cut.size(), &min_fvalue, gmat.cut->min_val.size());
gidx_fvalue_map = gmat.cut->cut;
min_fvalue = gmat.cut->min_val;
feature_segments = gmat.cut->row_ptr;
node_sum_gradients.resize(max_nodes);
ridx_segments.resize(max_nodes);
// Compress gidx
common::CompressedBufferWriter cbw(num_symbols);
std::vector<common::compressed_byte_t> host_buffer(gidx_buffer.size());
cbw.Write(host_buffer.data(), ellpack_matrix.begin(), ellpack_matrix.end());
gidx_buffer = host_buffer;
gidx =
common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols);
common::CompressedIterator<uint32_t> ci_host(host_buffer.data(),
num_symbols);
// Init histogram
hist.Init(device_idx, max_nodes, gmat.cut->row_ptr.back(), param.silent);
}
~DeviceShard() {
for (auto& stream : streams) {
dh::safe_cuda(cudaStreamDestroy(stream));
}
}
// Get vector of at least n initialised streams
std::vector<cudaStream_t>& GetStreams(int n) {
if (n > streams.size()) {
for (auto& stream : streams) {
dh::safe_cuda(cudaStreamDestroy(stream));
}
streams.clear();
streams.resize(n);
for (auto& stream : streams) {
dh::safe_cuda(cudaStreamCreate(&stream));
}
}
return streams;
}
// Reset values for each update iteration
void Reset(const std::vector<bst_gpair>& host_gpair) {
position.current_dvec().fill(0);
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
bst_gpair());
// TODO(rory): support subsampling
thrust::sequence(ridx.current_dvec().tbegin(), ridx.current_dvec().tend(),
row_start_idx);
std::fill(ridx_segments.begin(), ridx_segments.end(), std::make_pair(0, 0));
ridx_segments.front() = std::make_pair(0, ridx.size());
this->gpair.copy(host_gpair.begin() + row_start_idx,
host_gpair.begin() + row_end_idx);
hist.Reset();
}
__device__ void IncrementHist(bst_gpair gpair, int gidx,
bst_gpair_integer* node_hist) const {
auto dst_ptr =
reinterpret_cast<unsigned long long int*>(&node_hist[gidx]); // NOLINT
bst_gpair_integer tmp(gpair.GetGrad(), gpair.GetHess());
auto src_ptr = reinterpret_cast<bst_gpair_integer::value_t*>(&tmp);
atomicAdd(dst_ptr,
static_cast<unsigned long long int>(*src_ptr)); // NOLINT
atomicAdd(dst_ptr + 1,
static_cast<unsigned long long int>(*(src_ptr + 1))); // NOLINT
}
void BuildHist(int nidx) {
hist.AddNode(nidx);
auto d_node_hist = hist.node_map[nidx];
auto d_gidx = gidx;
auto d_ridx = ridx.current();
auto d_gpair = gpair.data();
auto row_stride = this->row_stride;
auto null_gidx_value = this->null_gidx_value;
auto segment = ridx_segments[nidx];
auto n_elements = (segment.second - segment.first) * row_stride;
dh::launch_n(device_idx, n_elements, [=] __device__(size_t idx) {
int relative_ridx = d_ridx[(idx / row_stride) + segment.first];
int gidx = d_gidx[relative_ridx * row_stride + idx % row_stride];
if (gidx != null_gidx_value) {
bst_gpair gpair = d_gpair[relative_ridx];
IncrementHist(gpair, gidx, d_node_hist);
}
});
}
void SortPosition(const std::pair<bst_uint, bst_uint>& segment, int left_nidx,
int right_nidx) {
auto n = segment.second - segment.first;
int min_bits = 0;
int max_bits = std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1));
// const int SINGLE_TILE_SIZE = 1024;
// if (n < SINGLE_TILE_SIZE) {
// RadixSortSmall<SINGLE_TILE_SIZE>
// <<<1, SINGLE_TILE_SIZE>>>(ridx.current() + segment.first,
// position.current() + segment.first, n);
//} else {
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs(
nullptr, temp_storage_bytes, position.current() + segment.first,
position.other() + segment.first, ridx.current() + segment.first,
ridx.other() + segment.first, n, min_bits, max_bits);
temp_memory.LazyAllocate(temp_storage_bytes);
cub::DeviceRadixSort::SortPairs(
temp_memory.d_temp_storage, temp_memory.temp_storage_bytes,
position.current() + segment.first, position.other() + segment.first,
ridx.current() + segment.first, ridx.other() + segment.first, n,
min_bits, max_bits);
dh::safe_cuda(cudaMemcpy(position.current() + segment.first,
position.other() + segment.first, n * sizeof(int),
cudaMemcpyDeviceToDevice));
dh::safe_cuda(cudaMemcpy(ridx.current() + segment.first,
ridx.other() + segment.first, n * sizeof(bst_uint),
cudaMemcpyDeviceToDevice));
//}
}
};
class GPUHistMakerExperimental : public TreeUpdater {
public:
struct ExpandEntry;
GPUHistMakerExperimental() : initialised(false) {}
~GPUHistMakerExperimental() {}
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override {
param.InitAllowUnknown(args);
CHECK(param.n_gpus != 0) << "Must have at least one device";
CHECK(param.n_gpus <= 1 && param.n_gpus != -1)
<< "Only one GPU currently supported";
n_devices = param.n_gpus;
if (param.grow_policy == TrainParam::kLossGuide) {
qexpand_.reset(new ExpandQueue(loss_guide));
} else {
qexpand_.reset(new ExpandQueue(depth_wise));
}
monitor.Init("updater_gpu_hist_experimental", param.debug_verbose);
}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
GradStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
// build tree
try {
for (size_t i = 0; i < trees.size(); ++i) {
this->UpdateTree(gpair, dmat, trees[i]);
}
} catch (const std::exception& e) {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
}
void InitDataOnce(DMatrix* dmat) {
info = &dmat->info();
hmat_.Init(dmat, param.max_bin);
gmat_.cut = &hmat_;
gmat_.Init(dmat);
n_bins = hmat_.row_ptr.back();
shards.emplace_back(param.gpu_id, 0, gmat_, 0, info->num_row, n_bins,
param);
initialised = true;
}
void InitData(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const RegTree& tree) {
if (!initialised) {
this->InitDataOnce(dmat);
}
this->ColSampleTree();
// Copy gpair & reset memory
for (auto& shard : shards) {
shard.Reset(gpair);
}
}
void BuildHist(int nidx) {
for (auto& shard : shards) {
shard.BuildHist(nidx);
}
}
// Returns best loss
std::vector<DeviceSplitCandidate> EvaluateSplits(
const std::vector<int>& nidx_set, RegTree* p_tree) {
auto columns = info->num_col;
std::vector<DeviceSplitCandidate> best_splits(nidx_set.size());
std::vector<DeviceSplitCandidate> candidate_splits(nidx_set.size() *
columns);
// Use first device
auto& shard = shards.front();
dh::safe_cuda(cudaSetDevice(shard.device_idx));
shard.temp_memory.LazyAllocate(sizeof(DeviceSplitCandidate) * columns *
nidx_set.size());
auto d_split = shard.temp_memory.Pointer<DeviceSplitCandidate>();
auto& streams = shard.GetStreams(nidx_set.size());
// Use streams to process nodes concurrently
for (auto i = 0; i < nidx_set.size(); i++) {
auto nidx = nidx_set[i];
DeviceNodeStats node(shard.node_sum_gradients[nidx], nidx, param);
const int BLOCK_THREADS = 256;
evaluate_split_kernel<BLOCK_THREADS>
<<<columns, BLOCK_THREADS, 0, streams[i]>>>(
shard.hist.node_map[nidx], nidx, info->num_col, node,
shard.feature_segments.data(), shard.min_fvalue.data(),
shard.gidx_fvalue_map.data(), GPUTrainingParam(param),
d_split + i * columns);
}
dh::safe_cuda(
cudaMemcpy(candidate_splits.data(), shard.temp_memory.d_temp_storage,
sizeof(DeviceSplitCandidate) * columns * nidx_set.size(),
cudaMemcpyDeviceToHost));
for (auto i = 0; i < nidx_set.size(); i++) {
DeviceSplitCandidate nidx_best;
for (auto fidx = 0; fidx < columns; fidx++) {
nidx_best.Update(candidate_splits[i * columns + fidx], param);
}
best_splits[i] = nidx_best;
}
return std::move(best_splits);
}
void InitRoot(const std::vector<bst_gpair>& gpair, RegTree* p_tree) {
int root_nidx = 0;
BuildHist(root_nidx);
// TODO(rory): support sub sampling
// TODO(rory): not asynchronous
bst_gpair sum_gradient;
for (auto& shard : shards) {
sum_gradient += thrust::reduce(shard.gpair.tbegin(), shard.gpair.tend());
}
// Remember root stats
p_tree->stat(root_nidx).sum_hess = sum_gradient.GetHess();
p_tree->stat(root_nidx).base_weight = CalcWeight(param, sum_gradient);
// Store sum gradients
for (auto& shard : shards) {
shard.node_sum_gradients[root_nidx] = sum_gradient;
}
auto splits = this->EvaluateSplits({root_nidx}, p_tree);
// Generate candidate
qexpand_->push(
ExpandEntry(root_nidx, p_tree->GetDepth(root_nidx), splits.front(), 0));
}
struct MatchingFunctor : public thrust::unary_function<int, int> {
int val;
__host__ __device__ MatchingFunctor(int val) : val(val) {}
__host__ __device__ int operator()(int x) const { return x == val; }
};
__device__ void CountLeft(bst_uint* d_count, int val, int left_nidx) {
unsigned ballot = __ballot(val == left_nidx);
if (threadIdx.x % 32 == 0) {
atomicAdd(d_count, __popc(ballot));
}
}
void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) {
auto nidx = candidate.nid;
auto is_dense = info->num_nonzero == info->num_row * info->num_col;
auto left_nidx = (*p_tree)[nidx].cleft();
auto right_nidx = (*p_tree)[nidx].cright();
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
auto split_gidx = -1;
auto fidx = candidate.split.findex;
auto default_dir_left = candidate.split.dir == LeftDir;
auto fidx_begin = hmat_.row_ptr[fidx];
auto fidx_end = hmat_.row_ptr[fidx + 1];
for (auto i = fidx_begin; i < fidx_end; ++i) {
if (candidate.split.fvalue == hmat_.cut[i]) {
split_gidx = static_cast<int32_t>(i);
}
}
for (auto& shard : shards) {
monitor.Start("update position kernel");
shard.temp_memory.LazyAllocate(sizeof(bst_uint));
auto d_left_count = shard.temp_memory.Pointer<bst_uint>();
dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(bst_uint)));
dh::safe_cuda(cudaSetDevice(shard.device_idx));
auto segment = shard.ridx_segments[nidx];
CHECK_GT(segment.second - segment.first, 0);
auto d_ridx = shard.ridx.current();
auto d_position = shard.position.current();
auto d_gidx = shard.gidx;
auto row_stride = shard.row_stride;
dh::launch_n<1, 512>(
shard.device_idx, segment.second - segment.first,
[=] __device__(bst_uint idx) {
idx += segment.first;
auto ridx = d_ridx[idx];
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
auto gidx = -1;
if (is_dense) {
gidx = d_gidx[row_begin + fidx];
} else {
gidx = BinarySearchRow(row_begin, row_end, d_gidx, fidx_begin,
fidx_end);
}
int position;
if (gidx >= 0) {
// Feature is found
position = gidx <= split_gidx ? left_nidx : right_nidx;
} else {
// Feature is missing
position = default_dir_left ? left_nidx : right_nidx;
}
CountLeft(d_left_count, position, left_nidx);
d_position[idx] = position;
});
bst_uint left_count;
dh::safe_cuda(cudaMemcpy(&left_count, d_left_count, sizeof(bst_uint),
cudaMemcpyDeviceToHost));
monitor.Stop("update position kernel");
monitor.Start("sort");
shard.SortPosition(segment, left_nidx, right_nidx);
monitor.Stop("sort");
shard.ridx_segments[left_nidx] =
std::make_pair(segment.first, segment.first + left_count);
shard.ridx_segments[right_nidx] =
std::make_pair(segment.first + left_count, segment.second);
}
}
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
// Add new leaves
RegTree& tree = *p_tree;
tree.AddChilds(candidate.nid);
auto& parent = tree[candidate.nid];
parent.set_split(candidate.split.findex, candidate.split.fvalue,
candidate.split.dir == LeftDir);
tree.stat(candidate.nid).loss_chg = candidate.split.loss_chg;
// Configure left child
auto left_weight = CalcWeight(param, candidate.split.left_sum);
tree[parent.cleft()].set_leaf(left_weight * param.learning_rate, 0);
tree.stat(parent.cleft()).base_weight = left_weight;
tree.stat(parent.cleft()).sum_hess = candidate.split.left_sum.GetHess();
// Configure right child
auto right_weight = CalcWeight(param, candidate.split.right_sum);
tree[parent.cright()].set_leaf(right_weight * param.learning_rate, 0);
tree.stat(parent.cright()).base_weight = right_weight;
tree.stat(parent.cright()).sum_hess = candidate.split.right_sum.GetHess();
// Store sum gradients
for (auto& shard : shards) {
shard.node_sum_gradients[parent.cleft()] = candidate.split.left_sum;
shard.node_sum_gradients[parent.cright()] = candidate.split.right_sum;
}
this->UpdatePosition(candidate, p_tree);
}
void ColSampleTree() {
if (param.colsample_bylevel == 1.0 && param.colsample_bytree == 1.0) return;
feature_set_tree.resize(info->num_col);
std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree);
}
struct Monitor {
bool debug_verbose = false;
std::string label = "";
std::map<std::string, dh::Timer> timer_map;
~Monitor() {
if (!debug_verbose) return;
std::cout << "Monitor: " << label << "\n";
for (auto& kv : timer_map) {
kv.second.PrintElapsed(kv.first);
}
}
void Init(std::string label, bool debug_verbose) {
this->debug_verbose = debug_verbose;
this->label = label;
}
void Start(const std::string& name) { timer_map[name].Start(); }
void Stop(const std::string& name) { timer_map[name].Stop(); }
};
void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
RegTree* p_tree) {
auto& tree = *p_tree;
monitor.Start("InitData");
this->InitData(gpair, p_fmat, *p_tree);
monitor.Stop("InitData");
monitor.Start("InitRoot");
this->InitRoot(gpair, p_tree);
monitor.Stop("InitRoot");
unsigned timestamp = qexpand_->size();
auto num_leaves = 1;
while (!qexpand_->empty()) {
auto candidate = qexpand_->top();
qexpand_->pop();
if (!candidate.IsValid(param, num_leaves)) continue;
// std::cout << candidate;
monitor.Start("ApplySplit");
this->ApplySplit(candidate, p_tree);
monitor.Stop("ApplySplit");
num_leaves++;
auto left_child_nidx = tree[candidate.nid].cleft();
auto right_child_nidx = tree[candidate.nid].cright();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.Start("BuildHist");
this->BuildHist(left_child_nidx);
this->BuildHist(right_child_nidx);
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
auto splits =
this->EvaluateSplits({left_child_nidx, right_child_nidx}, p_tree);
qexpand_->push(ExpandEntry(left_child_nidx,
tree.GetDepth(left_child_nidx), splits[0],
timestamp++));
qexpand_->push(ExpandEntry(right_child_nidx,
tree.GetDepth(right_child_nidx), splits[1],
timestamp++));
monitor.Stop("EvaluateSplits");
}
}
}
struct ExpandEntry {
int nid;
int depth;
DeviceSplitCandidate split;
unsigned timestamp;
ExpandEntry(int nid, int depth, const DeviceSplitCandidate& split,
unsigned timestamp)
: nid(nid), depth(depth), split(split), timestamp(timestamp) {}
bool IsValid(const TrainParam& param, int num_leaves) const {
if (split.loss_chg <= rt_eps) return false;
if (param.max_depth > 0 && depth == param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
return true;
}
static bool ChildIsValid(const TrainParam& param, int depth,
int num_leaves) {
if (param.max_depth > 0 && depth == param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
return true;
}
friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
os << "ExpandEntry: \n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
return os;
}
};
inline static bool depth_wise(ExpandEntry lhs, ExpandEntry rhs) {
if (lhs.depth == rhs.depth) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.depth > rhs.depth; // favor small depth
}
}
inline static bool loss_guide(ExpandEntry lhs, ExpandEntry rhs) {
if (lhs.split.loss_chg == rhs.split.loss_chg) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg
}
}
TrainParam param;
common::HistCutMatrix hmat_;
common::GHistIndexMatrix gmat_;
MetaInfo* info;
bool initialised;
int n_devices;
int n_bins;
std::vector<DeviceShard> shards;
std::vector<int> feature_set_tree;
std::vector<int> feature_set_level;
typedef std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>
ExpandQueue;
std::unique_ptr<ExpandQueue> qexpand_;
Monitor monitor;
};
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMakerExperimental,
"grow_gpu_hist_experimental")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUHistMakerExperimental(); });
} // namespace tree
} // namespace xgboost

View File

@ -5,57 +5,45 @@ import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import time
import ast
rng = np.random.RandomState(1994)
def run_benchmark(args, gpu_algorithm, cpu_algorithm):
def run_benchmark(args):
print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns))
print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size))
tmp = time.time()
X, y = make_classification(args.rows, n_features=args.columns, random_state=7)
if args.sparsity < 1.0:
X = np.array([[np.nan if rng.uniform(0, 1) < args.sparsity else x for x in x_row] for x_row in X])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_size, random_state=7)
print ("Generate Time: %s seconds" % (str(time.time() - tmp)))
tmp = time.time()
print ("DMatrix Start")
# omp way
dtrain = xgb.DMatrix(X_train, y_train, nthread=-1)
dtest = xgb.DMatrix(X_test, y_test, nthread=-1)
print ("DMatrix Time: %s seconds" % (str(time.time() - tmp)))
param = {'objective': 'binary:logistic',
'max_depth': 6,
'silent': 0,
'n_gpus': 1,
'gpu_id': 0,
'eval_metric': 'error',
'debug_verbose': 0,
}
param = {'objective': 'binary:logistic'}
if args.params is not '':
param.update(ast.literal_eval(args.params))
param['tree_method'] = gpu_algorithm
param['tree_method'] = args.tree_method
print("Training with '%s'" % param['tree_method'])
tmp = time.time()
xgb.train(param, dtrain, args.iterations, evals=[(dtest, "test")])
print ("Train Time: %s seconds" % (str(time.time() - tmp)))
param['silent'] = 1
param['tree_method'] = cpu_algorithm
print("Training with '%s'" % param['tree_method'])
tmp = time.time()
xgb.train(param, dtrain, args.iterations, evals=[(dtest, "test")])
print ("Time: %s seconds" % (str(time.time() - tmp)))
parser = argparse.ArgumentParser()
parser.add_argument('--algorithm', choices=['all', 'gpu_exact', 'gpu_hist'], default='all')
parser.add_argument('--tree_method', default='gpu_hist')
parser.add_argument('--sparsity', type=float, default=0.0)
parser.add_argument('--rows', type=int, default=1000000)
parser.add_argument('--columns', type=int, default=50)
parser.add_argument('--iterations', type=int, default=500)
parser.add_argument('--test_size', type=float, default=0.25)
parser.add_argument('--params', default='', help='Provide additional parameters as a Python dict string, e.g. --params \"{\'max_depth\':2}\"')
args = parser.parse_args()
if 'gpu_hist' in args.algorithm:
run_benchmark(args, args.algorithm, 'hist')
elif 'gpu_exact' in args.algorithm:
run_benchmark(args, args.algorithm, 'exact')
elif 'all' in args.algorithm:
run_benchmark(args, 'gpu_exact', 'exact')
run_benchmark(args, 'gpu_hist', 'hist')
run_benchmark(args)

View File

@ -41,7 +41,7 @@ void SpeedTest() {
[=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; });
dh::safe_cuda(cudaDeviceSynchronize());
double time = t.elapsedSeconds();
double time = t.ElapsedSeconds();
const int mb_size = 1048576;
size_t size = (sizeof(int) * h_rows.size()) / mb_size;
printf("size: %llumb, time: %fs, bandwidth: %fmb/s\n", size, time,

View File

@ -0,0 +1,72 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include <thrust/device_vector.h>
#include <xgboost/base.h>
#include "../helpers.h"
#include "gtest/gtest.h"
#include "../../../src/tree/updater_gpu_hist_experimental.cu"
#include "../../../src/gbm/gbtree_model.h"
namespace xgboost {
namespace tree {
TEST(gpu_hist_experimental, TestSparseShard) {
int rows = 100;
int columns = 80;
int max_bins = 4;
auto dmat = CreateDMatrix(rows, columns, 0.9);
common::HistCutMatrix hmat;
common::GHistIndexMatrix gmat;
hmat.Init(dmat.get(), max_bins);
gmat.cut = &hmat;
gmat.Init(dmat.get());
DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(), TrainParam());
ASSERT_LT(shard.row_stride, columns);
auto host_gidx_buffer = shard.gidx_buffer.as_vector();
common::CompressedIterator<uint32_t> gidx(host_gidx_buffer.data(),
hmat.row_ptr.back() + 1);
for (int i = 0; i < rows; i++) {
int row_offset = 0;
for (int j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
ASSERT_EQ(gidx[i * shard.row_stride + row_offset], gmat.index[j]);
row_offset++;
}
for (; row_offset < shard.row_stride; row_offset++) {
ASSERT_EQ(gidx[i * shard.row_stride + row_offset], shard.null_gidx_value);
}
}
}
TEST(gpu_hist_experimental, TestDenseShard) {
int rows = 100;
int columns = 80;
int max_bins = 4;
auto dmat = CreateDMatrix(rows, columns, 0);
common::HistCutMatrix hmat;
common::GHistIndexMatrix gmat;
hmat.Init(dmat.get(), max_bins);
gmat.cut = &hmat;
gmat.Init(dmat.get());
DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(), TrainParam());
ASSERT_EQ(shard.row_stride, columns);
auto host_gidx_buffer = shard.gidx_buffer.as_vector();
common::CompressedIterator<uint32_t> gidx(host_gidx_buffer.data(),
hmat.row_ptr.back() + 1);
for (int i = 0; i < gmat.index.size(); i++) {
ASSERT_EQ(gidx[i], gmat.index[i]);
}
}
} // namespace tree
} // namespace xgboost

View File

@ -7,316 +7,114 @@ import xgboost as xgb
import numpy as np
import unittest
from nose.plugins.attrib import attr
from sklearn.datasets import load_digits, load_boston, load_breast_cancer, make_regression
rng = np.random.RandomState(1994)
dpath = 'demo/data/'
def non_increasing(L, tolerance):
return all((y - x) < tolerance for x, y in zip(L, L[1:]))
#Check result is always decreasing and final accuracy is within tolerance
def assert_accuracy(res, tree_method, comparison_tree_method, tolerance):
assert non_increasing(res[tree_method], tolerance)
assert np.allclose(res[tree_method][-1], res[comparison_tree_method][-1], 1e-3, 1e-2)
def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)
print(*args, file=sys.stdout, **kwargs)
def train_boston(param_in, comparison_tree_method):
data = load_boston()
dtrain = xgb.DMatrix(data.data, label=data.target)
param = {}
param.update(param_in)
res_tmp = {}
res = {}
num_rounds = 10
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
res[param['tree_method']] = res_tmp['train']['rmse']
param["tree_method"] = comparison_tree_method
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
res[comparison_tree_method] = res_tmp['train']['rmse']
return res
def train_digits(param_in, comparison_tree_method):
data = load_digits()
dtrain = xgb.DMatrix(data.data, label=data.target)
param = {}
param['objective'] = 'multi:softmax'
param['num_class'] = 10
param.update(param_in)
res_tmp = {}
res = {}
num_rounds = 10
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
res[param['tree_method']] = res_tmp['train']['merror']
param["tree_method"] = comparison_tree_method
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
res[comparison_tree_method] = res_tmp['train']['merror']
return res
def train_cancer(param_in, comparison_tree_method):
data = load_breast_cancer()
dtrain = xgb.DMatrix(data.data, label=data.target)
param = {}
param['objective'] = 'binary:logistic'
param.update(param_in)
res_tmp = {}
res = {}
num_rounds = 10
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
res[param['tree_method']] = res_tmp['train']['error']
param["tree_method"] = comparison_tree_method
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
res[comparison_tree_method] = res_tmp['train']['error']
return res
def train_sparse(param_in, comparison_tree_method):
n = 5000
sparsity = 0.75
X, y = make_regression(n, random_state=rng)
X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x for x in x_row] for x_row in X])
dtrain = xgb.DMatrix(X, label=y)
param = {}
param.update(param_in)
res_tmp = {}
res = {}
num_rounds = 10
bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
res[param['tree_method']] = res_tmp['train']['rmse']
param["tree_method"] = comparison_tree_method
bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
res[comparison_tree_method] = res_tmp['train']['rmse']
return res
def assert_updater_accuracy(tree_method, comparison_tree_method, variable_param, tolerance):
param = {'tree_method': tree_method}
for k, set in variable_param.items():
for val in set:
param_tmp = param.copy()
param_tmp[k] = val
print(param_tmp, file=sys.stderr)
assert_accuracy(train_boston(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
assert_accuracy(train_digits(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
assert_accuracy(train_cancer(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
assert_accuracy(train_sparse(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
@attr('gpu')
class TestGPU(unittest.TestCase):
def test_grow_gpu(self):
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
except:
from sklearn.cross_validation import train_test_split
def test_gpu_hist(self):
variable_param = {'max_depth': [2, 6, 11], 'max_bin': [2, 16, 1024], 'n_gpus': [1, -1]}
assert_updater_accuracy('gpu_hist', 'hist', variable_param, 0.02)
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
def test_gpu_exact(self):
variable_param = {'max_depth': [2, 6, 15]}
assert_updater_accuracy('gpu_exact', 'exact', variable_param, 0.02)
ag_param = {'max_depth': 2,
'tree_method': 'exact',
'nthread': 0,
'eta': 1,
'silent': 1,
'debug_verbose': 0,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': 2,
'tree_method': 'gpu_exact',
'nthread': 0,
'eta': 1,
'silent': 1,
'debug_verbose': 0,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_res = {}
ag_res2 = {}
num_rounds = 10
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res)
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res2)
assert ag_res['train']['auc'] == ag_res2['train']['auc']
assert ag_res['test']['auc'] == ag_res2['test']['auc']
digits = load_digits(2)
X = digits['data']
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_exact',
'max_depth': 3,
'debug_verbose': 0,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert self.non_decreasing(res['test']['auc'])
# fail-safe test for dense data
from sklearn.datasets import load_svmlight_file
X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
X2 = X2.toarray()
dtrain2 = xgb.DMatrix(X2, label=y2)
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_exact',
'max_depth': 2,
'debug_verbose': 0,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in rng.choice(X2.shape[0], size=num_rounds, replace=False):
X2[i, j] = 2
dtrain3 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in np.random.choice(X2.shape[0], size=num_rounds, replace=False):
X2[i, j] = 3
dtrain4 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain4, num_rounds, [(dtrain4, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
assert res['train']['auc'][0] >= 0.85
def test_grow_gpu_hist(self):
n_gpus = -1
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
except:
from sklearn.cross_validation import train_test_split
ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
for max_depth in range(3, 10): # TODO: Doesn't work with 2 for some tests
# eprint("max_depth=%d" % (max_depth))
for max_bin_i in range(3, 11):
max_bin = np.power(2, max_bin_i)
# eprint("max_bin=%d" % (max_bin))
# regression test --- hist must be same as exact on all-categorial data
ag_param = {'max_depth': max_depth,
'tree_method': 'exact',
'nthread': 0,
'eta': 1,
'silent': 1,
'debug_verbose': 0,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
ag_param2 = {'max_depth': max_depth,
'nthread': 0,
'tree_method': 'gpu_hist',
'eta': 1,
'silent': 1,
'debug_verbose': 0,
'n_gpus': 1,
'objective': 'binary:logistic',
'max_bin': max_bin,
'eval_metric': 'auc'}
ag_param3 = {'max_depth': max_depth,
'nthread': 0,
'tree_method': 'gpu_hist',
'eta': 1,
'silent': 1,
'debug_verbose': 0,
'n_gpus': n_gpus,
'objective': 'binary:logistic',
'max_bin': max_bin,
'eval_metric': 'auc'}
ag_res = {}
ag_res2 = {}
ag_res3 = {}
num_rounds = 10
# eprint("normal updater");
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res)
# eprint("grow_gpu_hist updater 1 gpu");
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res2)
# eprint("grow_gpu_hist updater %d gpus" % (n_gpus));
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=ag_res3)
# assert 1==0
assert ag_res['train']['auc'] == ag_res2['train']['auc']
assert ag_res['test']['auc'] == ag_res2['test']['auc']
assert ag_res['test']['auc'] == ag_res3['test']['auc']
######################################################################
digits = load_digits(2)
X = digits['data']
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)
param = {'objective': 'binary:logistic',
'tree_method': 'gpu_hist',
'nthread': 0,
'max_depth': max_depth,
'n_gpus': 1,
'max_bin': max_bin,
'debug_verbose': 0,
'eval_metric': 'auc'}
res = {}
# eprint("digits: grow_gpu_hist updater 1 gpu");
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res)
assert self.non_decreasing(res['train']['auc'])
# assert self.non_decreasing(res['test']['auc'])
param2 = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
'max_depth': max_depth,
'n_gpus': n_gpus,
'max_bin': max_bin,
'debug_verbose': 0,
'eval_metric': 'auc'}
res2 = {}
# eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
xgb.train(param2, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
evals_result=res2)
assert self.non_decreasing(res2['train']['auc'])
# assert self.non_decreasing(res2['test']['auc'])
assert res['train']['auc'] == res2['train']['auc']
# assert res['test']['auc'] == res2['test']['auc']
######################################################################
# fail-safe test for dense data
from sklearn.datasets import load_svmlight_file
X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
X2 = X2.toarray()
dtrain2 = xgb.DMatrix(X2, label=y2)
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
'max_depth': max_depth,
'n_gpus': n_gpus,
'max_bin': max_bin,
'debug_verbose': 0,
'eval_metric': 'auc'}
res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in rng.choice(X2.shape[0], size=num_rounds, replace=False):
X2[i, j] = 2
dtrain3 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain3, num_rounds, [(dtrain3, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
for j in range(X2.shape[1]):
for i in np.random.choice(X2.shape[0], size=num_rounds, replace=False):
X2[i, j] = 3
dtrain4 = xgb.DMatrix(X2, label=y2)
res = {}
xgb.train(param, dtrain4, num_rounds, [(dtrain4, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
######################################################################
# fail-safe test for max_bin
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
'max_depth': max_depth,
'n_gpus': n_gpus,
'debug_verbose': 0,
'eval_metric': 'auc',
'max_bin': max_bin}
res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
######################################################################
# subsampling
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
'max_depth': max_depth,
'n_gpus': n_gpus,
'eval_metric': 'auc',
'colsample_bytree': 0.5,
'colsample_bylevel': 0.5,
'subsample': 0.5,
'debug_verbose': 0,
'max_bin': max_bin}
res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
######################################################################
# fail-safe test for max_bin=2
param = {'objective': 'binary:logistic',
'nthread': 0,
'tree_method': 'gpu_hist',
'max_depth': 2,
'n_gpus': n_gpus,
'debug_verbose': 0,
'eval_metric': 'auc',
'max_bin': 2}
res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
assert self.non_decreasing(res['train']['auc'])
if max_bin > 32:
assert res['train']['auc'][0] >= 0.85
def non_decreasing(self, L):
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
def test_gpu_hist_experimental(self):
variable_param = {'max_depth': [2, 6], 'max_leaves': [255, 4], 'max_bin': [2, 16, 1024]}
assert_updater_accuracy('gpu_hist_experimental', 'hist', variable_param, 0.01)