[GPU-Plugin] Various fixes (#2579)

* Fix test large

* Add check for max_depth 0

* Update readme

* Add LBS specialisation for dense data

* Add bst_gpair_precise

* Temporarily disable accuracy tests on test_large.py

* Solve unused variable compiler warning

* Fix max_bin > 1024 error
Rory Mitchell 2017-08-05 22:16:23 +12:00 committed by GitHub
parent 03e213c7cd
commit eda9e180f0
8 changed files with 147 additions and 98 deletions

View File

@@ -82,49 +82,67 @@ typedef uint64_t bst_ulong;  // NOLINT(*)
 /*! \brief float type, used for storing statistics */
 typedef float bst_float;
-/*! \brief gradient statistics pair usually needed in gradient boosting */
-struct bst_gpair {
+/*! \brief Implementation of gradient statistics pair */
+template <typename T>
+struct bst_gpair_internal {
   /*! \brief gradient statistics */
-  bst_float grad;
+  T grad;
   /*! \brief second order gradient statistics */
-  bst_float hess;
+  T hess;
-  XGBOOST_DEVICE bst_gpair() : grad(0), hess(0) {}
+  XGBOOST_DEVICE bst_gpair_internal() : grad(0), hess(0) {}
-  XGBOOST_DEVICE bst_gpair(bst_float grad, bst_float hess)
+  XGBOOST_DEVICE bst_gpair_internal(T grad, T hess)
       : grad(grad), hess(hess) {}
-  XGBOOST_DEVICE bst_gpair &operator+=(const bst_gpair &rhs) {
+  template <typename T2>
+  XGBOOST_DEVICE bst_gpair_internal(bst_gpair_internal<T2> &g)
+      : grad(g.grad), hess(g.hess) {}
+  XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(const bst_gpair_internal<T> &rhs) {
     grad += rhs.grad;
     hess += rhs.hess;
    return *this;
   }
-  XGBOOST_DEVICE bst_gpair operator+(const bst_gpair &rhs) const {
-    bst_gpair g;
+  XGBOOST_DEVICE bst_gpair_internal<T> operator+(const bst_gpair_internal<T> &rhs) const {
+    bst_gpair_internal<T> g;
     g.grad = grad + rhs.grad;
     g.hess = hess + rhs.hess;
     return g;
   }
-  XGBOOST_DEVICE bst_gpair &operator-=(const bst_gpair &rhs) {
+  XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(const bst_gpair_internal<T> &rhs) {
     grad -= rhs.grad;
     hess -= rhs.hess;
     return *this;
   }
-  XGBOOST_DEVICE bst_gpair operator-(const bst_gpair &rhs) const {
-    bst_gpair g;
+  XGBOOST_DEVICE bst_gpair_internal<T> operator-(const bst_gpair_internal<T> &rhs) const {
+    bst_gpair_internal<T> g;
     g.grad = grad - rhs.grad;
     g.hess = hess - rhs.hess;
     return g;
   }
-  XGBOOST_DEVICE bst_gpair(int value) {
-    *this = bst_gpair(static_cast<float>(value), static_cast<float>(value));
+  XGBOOST_DEVICE bst_gpair_internal(int value) {
+    *this = bst_gpair_internal<T>(static_cast<float>(value), static_cast<float>(value));
   }
+  friend std::ostream &operator<<(std::ostream &os,
+                                  const bst_gpair_internal<T> &g) {
+    os << g.grad << "/" << g.hess;
+    return os;
+  }
 };
+/*! \brief gradient statistics pair usually needed in gradient boosting */
+typedef bst_gpair_internal<float> bst_gpair;
+/*! \brief High precision gradient statistics pair */
+typedef bst_gpair_internal<double> bst_gpair_precise;
 /*! \brief small eps gap for minimum split decision. */
 const bst_float rt_eps = 1e-6f;
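Note: a minimal host-side sketch, not part of this commit, of how the templated pair behaves once the two typedefs exist. The `gpair_internal` type below is a local stand-in for `bst_gpair_internal` (the real definition lives in `include/xgboost/base.h` behind `XGBOOST_DEVICE`), and the stream operator prints `grad/hess` as in the patch.

```cpp
#include <iostream>

// Local stand-in for the templated gradient pair (illustration only).
template <typename T>
struct gpair_internal {
  T grad{0}, hess{0};
  gpair_internal() = default;
  gpair_internal(T g, T h) : grad(g), hess(h) {}
  gpair_internal &operator+=(const gpair_internal &rhs) {
    grad += rhs.grad;
    hess += rhs.hess;
    return *this;
  }
  friend std::ostream &operator<<(std::ostream &os, const gpair_internal &p) {
    return os << p.grad << "/" << p.hess;  // same "grad/hess" format as the patch
  }
};

using gpair = gpair_internal<float>;           // mirrors bst_gpair
using gpair_precise = gpair_internal<double>;  // mirrors bst_gpair_precise

int main() {
  gpair g(0.1f, 1.0f);
  gpair_precise sum;                        // accumulate in double precision
  sum += gpair_precise(g.grad, g.hess);     // converting construction, as in the patch
  sum += gpair_precise(0.2, 1.0);
  std::cout << sum << "\n";                 // prints grad/hess, e.g. "0.3/2"
  return 0;
}
```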

View File

@@ -27,7 +27,7 @@ This plugin currently works with the CLI version and python version.
 Python example:
 ```python
-param['gpu_id'] = 1
+param['gpu_id'] = 0
 param['max_bin'] = 16
 param['tree_method'] = 'gpu_hist'
 ```

View File

@@ -15,28 +15,30 @@
 namespace xgboost {
 namespace tree {
+template <typename gpair_t>
 __device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
-                                             const bst_gpair& scan,
-                                             const bst_gpair& missing,
-                                             const bst_gpair& parent_sum,
+                                             const gpair_t& scan,
+                                             const gpair_t& missing,
+                                             const gpair_t& parent_sum,
                                              const float& parent_gain,
                                              bool missing_left) {
-  bst_gpair left = scan;
+  gpair_t left = scan;
   if (missing_left) {
     left += missing;
   }
-  bst_gpair right = parent_sum - left;
+  gpair_t right = parent_sum - left;
   float left_gain = CalcGain(param, left.grad, left.hess);
   float right_gain = CalcGain(param, right.grad, right.hess);
   return left_gain + right_gain - parent_gain;
 }
-__device__ float inline loss_chg_missing(const bst_gpair& scan,
-                                         const bst_gpair& missing,
-                                         const bst_gpair& parent_sum,
+template <typename gpair_t>
+__device__ float inline loss_chg_missing(const gpair_t& scan,
+                                         const gpair_t& missing,
+                                         const gpair_t& parent_sum,
                                          const float& parent_gain,
                                          const GPUTrainingParam& param,
                                          bool& missing_left_out) {  // NOLINT
@@ -177,11 +179,11 @@ inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
 }
 struct GpairCallbackOp {
   // Running prefix
-  bst_gpair running_total;
+  bst_gpair_precise running_total;
   // Constructor
-  __device__ GpairCallbackOp() : running_total(bst_gpair()) {}
-  __device__ bst_gpair operator()(bst_gpair block_aggregate) {
-    bst_gpair old_prefix = running_total;
+  __device__ GpairCallbackOp() : running_total(bst_gpair_precise()) {}
+  __device__ bst_gpair_precise operator()(bst_gpair_precise block_aggregate) {
+    bst_gpair_precise old_prefix = running_total;
     running_total += block_aggregate;
     return old_prefix;
   }
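Note: a sequential sketch, in plain C++ and for illustration only, of what this stateful callback does when CUB's BlockScan invokes it once per tile: it returns the total accumulated from earlier tiles (the exclusive prefix for the current tile) and folds in the current tile's aggregate, now carried in double precision.

```cpp
#include <iostream>
#include <vector>

// Stand-in for GpairCallbackOp, operating on plain doubles instead of gradient pairs.
struct PrefixOp {
  double running_total = 0.0;
  double operator()(double block_aggregate) {
    double old_prefix = running_total;   // prefix to apply to the current tile
    running_total += block_aggregate;    // carried forward to the next tile
    return old_prefix;
  }
};

int main() {
  PrefixOp op;
  std::vector<double> tile_aggregates = {4.0, 2.5, 1.0};  // one aggregate per scan tile
  for (double agg : tile_aggregates) {
    std::cout << "tile prefix = " << op(agg) << "\n";      // prints 0, 4, 6.5
  }
  return 0;
}
```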

View File

@@ -729,32 +729,8 @@ __global__ void LbsKernel(coordinate_t *d_coordinates,
   }
 }
-/**
- * \fn template <typename func_t, typename segments_iter, typename offset_t>
- * void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
- * segments_iter segments, offset_t num_segments, func_t f)
- *
- * \brief Load balancing search function. Reads a CSR type matrix description
- * and allows a function to be executed on each element. Search 'modern GPU load
- * balancing search' for more information.
- *
- * \author Rory
- * \date 7/9/2017
- *
- * \tparam func_t Type of the function t.
- * \tparam segments_iter Type of the segments iterator.
- * \tparam offset_t Type of the offset.
- * \tparam segments_t Type of the segments t.
- * \param device_idx Zero-based index of the device.
- * \param [in,out] temp_memory Temporary memory allocator.
- * \param count Number of elements.
- * \param segments Device pointer to segments.
- * \param num_segments Number of segments.
- * \param f Lambda to be executed on matrix elements.
- */
 template <typename func_t, typename segments_iter, typename offset_t>
-void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
+void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
                   segments_iter segments, offset_t num_segments, func_t f) {
   typedef typename cub::CubVector<offset_t, 2>::Type coordinate_t;
   dh::safe_cuda(cudaSetDevice(device_idx));
@@ -775,4 +751,47 @@ void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
       num_segments);
 }
+template <typename func_t, typename offset_t>
+void DenseTransformLbs(int device_idx, offset_t count, offset_t num_segments, func_t f) {
+  CHECK(count % num_segments == 0) << "Data is not dense.";
+  launch_n(device_idx, count, [=] __device__(offset_t idx) {
+    offset_t segment = idx / (count / num_segments);
+    f(idx, segment);
+  });
+}
+/**
+ * \fn template <typename func_t, typename segments_iter, typename offset_t> void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count, segments_iter segments, offset_t num_segments, bool is_dense, func_t f)
+ *
+ * \brief Load balancing search function. Reads a CSR type matrix description and allows a function
+ *        to be executed on each element. Search 'modern GPU load balancing search' for more
+ *        information.
+ *
+ * \author Rory
+ * \date 7/9/2017
+ *
+ * \tparam func_t Type of the function t.
+ * \tparam segments_iter Type of the segments iterator.
+ * \tparam offset_t Type of the offset.
+ * \param device_idx Zero-based index of the device.
+ * \param [in,out] temp_memory Temporary memory allocator.
+ * \param count Number of elements.
+ * \param segments Device pointer to segments.
+ * \param num_segments Number of segments.
+ * \param is_dense True if this object is dense.
+ * \param f Lambda to be executed on matrix elements.
+ */
+template <typename func_t, typename segments_iter, typename offset_t>
+void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
+                  segments_iter segments, offset_t num_segments, bool is_dense, func_t f) {
+  if (is_dense) {
+    DenseTransformLbs(device_idx, count, num_segments, f);
+  } else {
+    SparseTransformLbs(device_idx, temp_memory, count, segments, num_segments, f);
+  }
+}
 }  // namespace dh
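Note: the dense specialisation replaces the load-balancing search over the CSR row pointer with simple index arithmetic, which is why it needs no temporary memory. A host-side sketch of the idx-to-segment mapping it relies on; `dense_segment` below is a hypothetical helper for illustration, not part of the patch.

```cpp
#include <cassert>

// When every segment (row) holds the same number of elements, the owning
// segment of element idx is just idx / row_length; no search is required.
int dense_segment(int idx, int count, int num_segments) {
  int row_length = count / num_segments;  // assumes count % num_segments == 0, as CHECKed above
  return idx / row_length;
}

int main() {
  // 3 rows x 4 columns stored densely as 12 elements.
  assert(dense_segment(0, 12, 3) == 0);
  assert(dense_segment(5, 12, 3) == 1);
  assert(dense_segment(11, 12, 3) == 2);
  return 0;
}
```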

View File

@@ -49,10 +49,10 @@ void DeviceHist::Init(int n_bins_in) {
 void DeviceHist::Reset(int device_idx) {
   cudaSetDevice(device_idx);
-  data.fill(bst_gpair());
+  data.fill(bst_gpair_precise());
 }
-bst_gpair* DeviceHist::GetLevelPtr(int depth) {
+bst_gpair_precise* DeviceHist::GetLevelPtr(int depth) {
   return data.data() + n_nodes(depth - 1) * n_bins;
 }
@@ -62,10 +62,25 @@ HistBuilder DeviceHist::GetBuilder() {
   return HistBuilder(data.data(), n_bins);
 }
-HistBuilder::HistBuilder(bst_gpair* ptr, int n_bins)
+HistBuilder::HistBuilder(bst_gpair_precise* ptr, int n_bins)
     : d_hist(ptr), n_bins(n_bins) {}
-__device__ void HistBuilder::Add(bst_gpair gpair, int gidx, int nidx) const {
+// Define double precision atomic add for older architectures
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
+#else
+__device__ double atomicAdd(double* address, double val) {
+  unsigned long long int* address_as_ull = (unsigned long long int*)address;  // NOLINT
+  unsigned long long int old = *address_as_ull, assumed;  // NOLINT
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+                    __double_as_longlong(val + __longlong_as_double(assumed)));
+  } while (assumed != old);
+  return __longlong_as_double(old);
+}
+#endif
+__device__ void HistBuilder::Add(bst_gpair_precise gpair, int gidx, int nidx) const {
   int hist_idx = nidx * n_bins + gidx;
   atomicAdd(&(d_hist[hist_idx].grad), gpair.grad);  // OPTMARK: This and below
                                                     // line lead to about 3X
@@ -75,7 +90,7 @@ __device__ void HistBuilder::Add(bst_gpair gpair, int gidx, int nidx) const {
   atomicAdd(&(d_hist[hist_idx].hess), gpair.hess);
 }
-__device__ bst_gpair HistBuilder::Get(int gidx, int nidx) const {
+__device__ bst_gpair_precise HistBuilder::Get(int gidx, int nidx) const {
   return d_hist[nidx * n_bins + gidx];
 }
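Note: the fallback above emulates a double-precision `atomicAdd` with a compare-and-swap loop, because hardware double atomics only exist on compute capability 6.0 and newer. A CPU analogue of the same pattern, in plain C++ and for illustration only:

```cpp
#include <atomic>
#include <iostream>

// Read the current value, try to swap in the incremented value, and retry if
// another thread updated the target in between -- the same CAS loop idea as
// the __CUDA_ARCH__ < 600 fallback.
double atomic_add_double(std::atomic<double>* target, double val) {
  double expected = target->load();
  while (!target->compare_exchange_weak(expected, expected + val)) {
    // on failure, 'expected' is refreshed with the current value; retry
  }
  return expected;  // previous value, matching atomicAdd semantics
}

int main() {
  std::atomic<double> hist_bin{0.0};
  atomic_add_double(&hist_bin, 0.25);
  atomic_add_double(&hist_bin, 0.5);
  std::cout << hist_bin.load() << "\n";  // 0.75
  return 0;
}
```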
@@ -104,6 +119,7 @@ GPUHistBuilder::~GPUHistBuilder() {
 void GPUHistBuilder::Init(const TrainParam& param) {
   CHECK(param.max_depth < 16) << "Tree depth too large.";
+  CHECK(param.max_depth != 0) << "Tree depth cannot be 0.";
   CHECK(param.grow_policy != TrainParam::kLossGuide)
       << "Loss guided growth policy not supported. Use CPU algorithm.";
   this->param = param;
@@ -358,7 +374,7 @@ void GPUHistBuilder::BuildHist(int depth) {
     auto hist_builder = hist_vec[d_idx].GetBuilder();
     dh::TransformLbs(
         device_idx, &temp_memory[d_idx], end - begin, d_row_ptr,
-        row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
+        row_end - row_begin, is_dense, [=] __device__(size_t local_idx, int local_ridx) {
           int nidx = d_position[local_ridx];  // OPTMARK: latency
           if (!is_active(nidx, depth)) return;
@@ -396,8 +412,8 @@ void GPUHistBuilder::BuildHist(int depth) {
     dh::safe_nccl(ncclAllReduce(
         reinterpret_cast<const void*>(hist_vec[d_idx].GetLevelPtr(depth)),
         reinterpret_cast<void*>(hist_vec[d_idx].GetLevelPtr(depth)),
-        hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair) / sizeof(float),
-        ncclFloat, ncclSum, comms[d_idx], *(streams[d_idx])));
+        hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair_precise) / sizeof(double),
+        ncclDouble, ncclSum, comms[d_idx], *(streams[d_idx])));
   }
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
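Note: the histogram is now all-reduced as raw doubles, so the element count handed to `ncclAllReduce` becomes LevelSize × sizeof(bst_gpair_precise) / sizeof(double). A small sketch of that arithmetic; `gpair_precise` below is a local stand-in for illustration, and the level size is an example value.

```cpp
#include <cstddef>
#include <iostream>

// Local stand-in for bst_gpair_precise: two doubles per histogram bin.
struct gpair_precise { double grad, hess; };

int main() {
  std::size_t level_size = 256;  // histogram entries in one tree level (example)
  // Each pair contributes two double elements to the ncclDouble reduction.
  std::size_t nccl_count = level_size * sizeof(gpair_precise) / sizeof(double);
  std::cout << nccl_count << "\n";  // 512
  return 0;
}
```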
@@ -428,9 +444,9 @@ void GPUHistBuilder::BuildHist(int depth) {
           }
           int gidx = idx % hist_builder.n_bins;
-          bst_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx));
+          bst_gpair_precise parent = hist_builder.Get(gidx, parent_nidx(nidx));
           int other_nidx = left_smallest ? nidx - 1 : nidx + 1;
-          bst_gpair other = hist_builder.Get(gidx, other_nidx);
+          bst_gpair_precise other = hist_builder.Get(gidx, other_nidx);
           hist_builder.Add(parent - other, gidx,
                            nidx);  // OPTMARK: This is slow, could use shared
                                    // memory or cache results intead of writing to
@@ -443,16 +459,16 @@ void GPUHistBuilder::BuildHist(int depth) {
 template <int BLOCK_THREADS>
 __global__ void find_split_kernel(
-    const bst_gpair* d_level_hist, int* d_feature_segments, int depth,
+    const bst_gpair_precise* d_level_hist, int* d_feature_segments, int depth,
     int n_features, int n_bins, Node* d_nodes, Node* d_nodes_temp,
     Node* d_nodes_child_temp, int nodes_offset_device, float* d_fidx_min_map,
     float* d_gidx_fvalue_map, GPUTrainingParam gpu_param,
     bool* d_left_child_smallest_temp, bool colsample, int* d_feature_flags) {
   typedef cub::KeyValuePair<int, float> ArgMaxT;
-  typedef cub::BlockScan<bst_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
+  typedef cub::BlockScan<bst_gpair_precise, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
       BlockScanT;
   typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
-  typedef cub::BlockReduce<bst_gpair, BLOCK_THREADS> SumReduceT;
+  typedef cub::BlockReduce<bst_gpair_precise, BLOCK_THREADS> SumReduceT;
   union TempStorage {
     typename BlockScanT::TempStorage scan;
@@ -461,12 +477,12 @@ __global__ void find_split_kernel(
   };
   struct UninitializedSplit : cub::Uninitialized<Split> {};
-  struct UninitializedGpair : cub::Uninitialized<bst_gpair> {};
+  struct UninitializedGpair : cub::Uninitialized<bst_gpair_precise> {};
   __shared__ UninitializedSplit uninitialized_split;
   Split& split = uninitialized_split.Alias();
   __shared__ UninitializedGpair uninitialized_sum;
-  bst_gpair& shared_sum = uninitialized_sum.Alias();
+  bst_gpair_precise& shared_sum = uninitialized_sum.Alias();
   __shared__ ArgMaxT block_max;
   __shared__ TempStorage temp_storage;
@@ -486,15 +502,14 @@ __global__ void find_split_kernel(
     int begin = d_feature_segments[level_node_idx * n_features + fidx];
     int end = d_feature_segments[level_node_idx * n_features + fidx + 1];
-    int gidx = (begin - (level_node_idx * n_bins)) + threadIdx.x;
-    bool thread_active = threadIdx.x < end - begin;
-    bst_gpair feature_sum = bst_gpair();
+    bst_gpair_precise feature_sum = bst_gpair_precise();
     for (int reduce_begin = begin; reduce_begin < end;
          reduce_begin += BLOCK_THREADS) {
+      bool thread_active = reduce_begin + threadIdx.x < end;
       // Scan histogram
-      bst_gpair bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
-                                    : bst_gpair();
+      bst_gpair_precise bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
+                                            : bst_gpair_precise();
       feature_sum +=
           SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
@@ -508,17 +523,18 @@ __global__ void find_split_kernel(
     GpairCallbackOp prefix_op = GpairCallbackOp();
     for (int scan_begin = begin; scan_begin < end;
          scan_begin += BLOCK_THREADS) {
-      bst_gpair bin =
-          thread_active ? d_level_hist[scan_begin + threadIdx.x] : bst_gpair();
+      bool thread_active = scan_begin + threadIdx.x < end;
+      bst_gpair_precise bin =
+          thread_active ? d_level_hist[scan_begin + threadIdx.x] : bst_gpair_precise();
       BlockScanT(temp_storage.scan)
          .ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
       // Calculate gain
-      bst_gpair parent_sum = d_nodes[node_idx].sum_gradients;
+      bst_gpair_precise parent_sum = d_nodes[node_idx].sum_gradients;
       float parent_gain = d_nodes[node_idx].root_gain;
-      bst_gpair missing = parent_sum - shared_sum;
+      bst_gpair_precise missing = parent_sum - shared_sum;
       bool missing_left;
       float gain = thread_active
@@ -541,6 +557,7 @@ __global__ void find_split_kernel(
       // Best thread updates split
       if (threadIdx.x == block_max.key) {
         float fvalue;
+        int gidx = (scan_begin - (level_node_idx * n_bins)) + threadIdx.x;
         if (threadIdx.x == 0 &&
             begin == scan_begin) {  // check at start of first tile
           fvalue = d_fidx_min_map[fidx];
@@ -548,8 +565,8 @@ __global__ void find_split_kernel(
           fvalue = d_gidx_fvalue_map[gidx - 1];
         }
-        bst_gpair left = missing_left ? bin + missing : bin;
-        bst_gpair right = parent_sum - left;
+        bst_gpair_precise left = missing_left ? bin + missing : bin;
+        bst_gpair_precise right = parent_sum - left;
         split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
       }
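Note: recomputing `thread_active` and `gidx` per scan tile appears to be the "max_bin > 1024" fix from the commit message; the old values were derived once from `begin`, so only the first BLOCK_THREADS bins of a feature were ever covered. A plain C++ sketch, for illustration only and with the node offset omitted, of how the per-tile index now reaches every bin:

```cpp
#include <iostream>

int main() {
  const int BLOCK_THREADS = 1024;   // threads per block in find_split_kernel
  const int begin = 0, end = 2048;  // a feature with 2048 bins, i.e. max_bin > 1024

  // Each tile advances scan_begin, so the per-tile bin index covers the whole range.
  for (int scan_begin = begin; scan_begin < end; scan_begin += BLOCK_THREADS) {
    int tid = BLOCK_THREADS - 1;                  // look at the last thread of the block
    bool thread_active = scan_begin + tid < end;  // recomputed per tile
    if (thread_active) {
      int gidx = scan_begin + tid;                // per-tile bin index
      std::cout << "tile at " << scan_begin << " touches bin " << gidx << "\n";
    }
  }
  return 0;  // prints bins 1023 and 2047: both tiles are covered
}
```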
@@ -662,7 +679,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
       int nodes_offset_device = d_idx * num_nodes_device;
       find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-          (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+          (const bst_gpair_precise*)(hist_vec[d_idx].GetLevelPtr(depth)),
          feature_segments[d_idx].data(), depth, (info->num_col),
          (hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_temp[d_idx].data(),
          nodes_child_temp[d_idx].data(), nodes_offset_device,
@@ -759,7 +776,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
       int nodes_offset_device = d_idx * num_nodes_device;
       find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-          (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+          (const bst_gpair_precise*)(hist_vec[d_idx].GetLevelPtr(depth)),
          feature_segments[d_idx].data(), depth, (info->num_col),
          (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
          nodes_offset_device, fidx_min_map[d_idx].data(),
@@ -812,7 +829,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
     int nodes_offset_device = 0;
     find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-        (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+        (const bst_gpair_precise*)(hist_vec[d_idx].GetLevelPtr(depth)),
        feature_segments[d_idx].data(), depth, (info->num_col),
        (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
        nodes_offset_device, fidx_min_map[d_idx].data(),
@@ -958,7 +975,7 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
     dh::TransformLbs(
         device_idx, &temp_memory[d_idx], element_end - element_begin, d_row_ptr,
-        row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
+        row_end - row_begin, is_dense, [=] __device__(size_t local_idx, int local_ridx) {
          int pos = d_position[local_ridx];
          if (!is_active(pos, depth)) {
            return;

View File

@@ -25,16 +25,16 @@ struct DeviceGMat {
 };
 struct HistBuilder {
-  bst_gpair *d_hist;
+  bst_gpair_precise *d_hist;
   int n_bins;
-  __host__ __device__ HistBuilder(bst_gpair *ptr, int n_bins);
-  __device__ void Add(bst_gpair gpair, int gidx, int nidx) const;
-  __device__ bst_gpair Get(int gidx, int nidx) const;
+  __host__ __device__ HistBuilder(bst_gpair_precise *ptr, int n_bins);
+  __device__ void Add(bst_gpair_precise gpair, int gidx, int nidx) const;
+  __device__ bst_gpair_precise Get(int gidx, int nidx) const;
 };
 struct DeviceHist {
   int n_bins;
-  dh::dvec<bst_gpair> data;
+  dh::dvec<bst_gpair_precise> data;
   void Init(int max_depth);
@@ -42,7 +42,7 @@ struct DeviceHist {
   HistBuilder GetBuilder();
-  bst_gpair *GetLevelPtr(int depth);
+  bst_gpair_precise *GetLevelPtr(int depth);
   int LevelSize(int depth);
 };

View File

@@ -82,22 +82,16 @@ class TestGPU(unittest.TestCase):
         num_rounds = 1
         eprint("normal updater")
         xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
                   evals_result=ag_res)
         eprint("hist updater")
-        xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+        xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
                   evals_result=ag_resb)
         eprint("gpu_hist updater 1 gpu")
-        xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+        xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
                   evals_result=ag_res2)
         eprint("gpu_hist updater all gpus")
-        xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+        xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
                   evals_result=ag_res3)
         assert np.fabs(ag_res['train']['auc'][0] - ag_resb['train']['auc'][0])<0.001
         assert np.fabs(ag_res['train']['auc'][0] - ag_res2['train']['auc'][0])<0.001
         assert np.fabs(ag_res['train']['auc'][0] - ag_res3['train']['auc'][0])<0.001

View File

@@ -117,7 +117,6 @@ struct GBTreeModel {
   }
   void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
                    int bst_group) {
-    size_t old_ntree = trees.size();
     for (size_t i = 0; i < new_trees.size(); ++i) {
       trees.push_back(std::move(new_trees[i]));
       tree_info.push_back(bst_group);