Added finding quantiles on GPU. (#3393)

* Added finding quantiles on GPU.

- this includes datasets where weights are assigned to data rows
- as the quantiles found by the new algorithm are not the same
  as those found by the old one, test thresholds in
    tests/python-gpu/test_gpu_updaters.py have been adjusted.

* Adjustments and improved testing for finding quantiles on the GPU.

- added C++ tests for the DeviceSketch() function
- reduced one of the thresholds in test_gpu_updaters.py
- adjusted the cuts found by the find_cuts_k kernel
This commit is contained in:
Andy Adinets 2018-07-27 04:03:16 +02:00 committed by Rory Mitchell
parent e2f09db77a
commit cc6a5a3666
14 changed files with 691 additions and 116 deletions

View File

@ -111,7 +111,7 @@ class CompressedBufferWriter {
symbol <<= 7 - ibit_end % 8;
for (ptrdiff_t ibyte = ibyte_end; ibyte >= (ptrdiff_t)ibyte_start; --ibyte) {
dh::AtomicOrByte(reinterpret_cast<unsigned int*>(buffer + detail::kPadding),
ibyte, symbol & 0xff);
ibyte, symbol & 0xff);
symbol >>= 8;
}
}

View File

@ -163,11 +163,41 @@ inline void CheckComputeCapability() {
}
}
DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, unsigned char b) {
atomicOr(&buffer[ibyte / sizeof(unsigned int)], (unsigned int)b << (ibyte % (sizeof(unsigned int)) * 8));
}
/*!
* \brief Find the strict upper bound for an element in a sorted array
* using binary search.
* \param cuts pointer to the first element of the sorted array
* \param n length of the sorted array
* \param v value for which to find the upper bound
* \return the smallest index i such that v < cuts[i], or n if v is greater or equal
* than all elements of the array
*/
DEV_INLINE int UpperBound(const float* __restrict__ cuts, int n, float v) {
if (n == 0) {
return 0;
}
if (cuts[n - 1] <= v) {
return n;
}
if (cuts[0] > v) {
return 0;
}
int left = 0, right = n - 1;
while (right - left > 1) {
int middle = left + (right - left) / 2;
if (cuts[middle] > v) {
right = middle;
} else {
left = middle;
}
}
return right;
}
/*
* Range iterator
@ -252,6 +282,18 @@ T1 DivRoundUp(const T1 a, const T2 b) {
return static_cast<T1>(ceil(static_cast<double>(a) / b));
}
inline void RowSegments(size_t n_rows, size_t n_devices, std::vector<size_t>* segments) {
segments->push_back(0);
size_t row_begin = 0;
size_t shard_size = DivRoundUp(n_rows, n_devices);
for (size_t d_idx = 0; d_idx < n_devices; ++d_idx) {
size_t row_end = std::min(row_begin + shard_size, n_rows);
segments->push_back(row_end);
row_begin = row_end;
}
}
template <typename L>
__global__ void LaunchNKernel(size_t begin, size_t end, L lambda) {
for (auto i : GridStrideRange(begin, end)) {

View File

@ -43,18 +43,28 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
auto tid = static_cast<unsigned>(omp_get_thread_num());
unsigned begin = std::min(nstep * tid, ncol);
unsigned end = std::min(nstep * (tid + 1), ncol);
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
size_t ridx = batch.base_rowid + i;
SparsePage::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
if (inst[j].index >= begin && inst[j].index < end) {
sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx));
// do not iterate if no columns are assigned to the thread
if (begin < end && end <= ncol) {
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
size_t ridx = batch.base_rowid + i;
SparsePage::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
if (inst[j].index >= begin && inst[j].index < end) {
sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx));
}
}
}
}
}
}
Init(&sketchs, max_num_bins);
}
void HistCutMatrix::Init
(std::vector<WXQSketch>* in_sketchs, uint32_t max_num_bins) {
std::vector<WXQSketch>& sketchs = *in_sketchs;
constexpr int kFactor = 8;
// gather the histogram data
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
std::vector<WXQSketch::SummaryContainer> summary_array;
@ -68,7 +78,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
this->min_val.resize(info.num_col_);
this->min_val.resize(sketchs.size());
row_ptr.push_back(0);
for (size_t fid = 0; fid < summary_array.size(); ++fid) {
WXQSketch::SummaryContainer a;

398
src/common/hist_util.cu Normal file
View File

@ -0,0 +1,398 @@
/*!
* Copyright 2018 XGBoost contributors
*/
#include "./hist_util.h"
#include <thrust/copy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <utility>
#include <vector>
#include "../tree/param.h"
#include "./host_device_vector.h"
#include "./device_helpers.cuh"
#include "./quantile.h"
namespace xgboost {
namespace common {
using WXQSketch = HistCutMatrix::WXQSketch;
__global__ void find_cuts_k
(WXQSketch::Entry* __restrict__ cuts, const bst_float* __restrict__ data,
const float* __restrict__ cum_weights, int nsamples, int ncuts) {
// ncuts < nsamples
int icut = threadIdx.x + blockIdx.x * blockDim.x;
if (icut >= ncuts)
return;
WXQSketch::Entry v;
int isample = 0;
if (icut == 0) {
isample = 0;
} else if (icut == ncuts - 1) {
isample = nsamples - 1;
} else {
bst_float rank = cum_weights[nsamples - 1] / static_cast<float>(ncuts - 1)
* static_cast<float>(icut);
// -1 is used because cum_weights is an inclusive sum
isample = dh::UpperBound(cum_weights, nsamples, rank);
isample = max(0, min(isample, nsamples - 1));
}
// repeated values will be filtered out on the CPU
bst_float rmin = isample > 0 ? cum_weights[isample - 1] : 0;
bst_float rmax = cum_weights[isample];
cuts[icut] = WXQSketch::Entry(rmin, rmax, rmax - rmin, data[isample]);
}
// predictate for thrust filtering that returns true if the element is not a NaN
struct IsNotNaN {
__device__ bool operator()(float a) const { return !isnan(a); }
};
__global__ void unpack_features_k
(float* __restrict__ fvalues, float* __restrict__ feature_weights,
const size_t* __restrict__ row_ptrs, const float* __restrict__ weights,
Entry* entries, size_t nrows_array, int ncols, size_t row_begin_ptr,
size_t nrows) {
size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x;
if (irow >= nrows) {
return;
}
size_t row_length = row_ptrs[irow + 1] - row_ptrs[irow];
int icol = threadIdx.y + blockIdx.y * blockDim.y;
if (icol >= row_length) {
return;
}
Entry entry = entries[row_ptrs[irow] - row_begin_ptr + icol];
size_t ind = entry.index * nrows_array + irow;
// if weights are present, ensure that a non-NaN value is written to weights
// if and only if it is also written to features
if (!isnan(entry.fvalue) && (weights == nullptr || !isnan(weights[irow]))) {
fvalues[ind] = entry.fvalue;
if (feature_weights != nullptr) {
feature_weights[ind] = weights[irow];
}
}
}
// finds quantiles on the GPU
struct GPUSketcher {
// manage memory for a single GPU
struct DeviceShard {
int device_;
bst_uint row_begin_; // The row offset for this shard
bst_uint row_end_;
bst_uint n_rows_;
int num_cols_{0};
size_t n_cuts_{0};
size_t gpu_batch_nrows_{0};
bool has_weights_{false};
tree::TrainParam param_;
std::vector<WXQSketch> sketches_;
thrust::device_vector<size_t> row_ptrs_;
std::vector<WXQSketch::SummaryContainer> summaries_;
thrust::device_vector<Entry> entries_;
thrust::device_vector<bst_float> fvalues_;
thrust::device_vector<bst_float> feature_weights_;
thrust::device_vector<bst_float> fvalues_cur_;
thrust::device_vector<WXQSketch::Entry> cuts_d_;
thrust::host_vector<WXQSketch::Entry> cuts_h_;
thrust::device_vector<bst_float> weights_;
thrust::device_vector<bst_float> weights2_;
std::vector<size_t> n_cuts_cur_;
thrust::device_vector<size_t> num_elements_;
thrust::device_vector<char> tmp_storage_;
DeviceShard(int device, bst_uint row_begin, bst_uint row_end,
tree::TrainParam param) :
device_(device), row_begin_(row_begin), row_end_(row_end),
n_rows_(row_end - row_begin), param_(std::move(param)) {
}
void Init(const SparsePage& row_batch, const MetaInfo& info) {
num_cols_ = info.num_col_;
has_weights_ = info.weights_.size() > 0;
// find the batch size
if (param_.gpu_batch_nrows == 0) {
// By default, use no more than 1/16th of GPU memory
gpu_batch_nrows_ = dh::TotalMemory(device_) /
(16 * num_cols_ * sizeof(Entry));
} else if (param_.gpu_batch_nrows == -1) {
gpu_batch_nrows_ = n_rows_;
} else {
gpu_batch_nrows_ = param_.gpu_batch_nrows;
}
if (gpu_batch_nrows_ > n_rows_) {
gpu_batch_nrows_ = n_rows_;
}
// initialize sketches
sketches_.resize(num_cols_);
summaries_.resize(num_cols_);
constexpr int kFactor = 8;
double eps = 1.0 / (kFactor * param_.max_bin);
size_t dummy_nlevel;
WXQSketch::LimitSizeLevel(row_batch.Size(), eps, &dummy_nlevel, &n_cuts_);
// double ncuts to be the same as the number of values
// in the temporary buffers of the sketches
n_cuts_ *= 2;
for (int icol = 0; icol < num_cols_; ++icol) {
sketches_[icol].Init(row_batch.Size(), eps);
summaries_[icol].Reserve(n_cuts_);
}
// allocate necessary GPU buffers
dh::safe_cuda(cudaSetDevice(device_));
entries_.resize(gpu_batch_nrows_ * num_cols_);
fvalues_.resize(gpu_batch_nrows_ * num_cols_);
fvalues_cur_.resize(gpu_batch_nrows_);
cuts_d_.resize(n_cuts_ * num_cols_);
cuts_h_.resize(n_cuts_ * num_cols_);
weights_.resize(gpu_batch_nrows_);
weights2_.resize(gpu_batch_nrows_);
num_elements_.resize(1);
if (has_weights_) {
feature_weights_.resize(gpu_batch_nrows_ * num_cols_);
}
n_cuts_cur_.resize(num_cols_);
// allocate storage for CUB algorithms; the size is the maximum of the sizes
// required for various algorithm
size_t tmp_size = 0, cur_tmp_size = 0;
// size for sorting
if (has_weights_) {
cub::DeviceRadixSort::SortPairs
(nullptr, cur_tmp_size, fvalues_cur_.data().get(),
fvalues_.data().get(), weights_.data().get(), weights2_.data().get(),
gpu_batch_nrows_);
} else {
cub::DeviceRadixSort::SortKeys
(nullptr, cur_tmp_size, fvalues_cur_.data().get(), fvalues_.data().get(),
gpu_batch_nrows_);
}
tmp_size = std::max(tmp_size, cur_tmp_size);
// size for inclusive scan
if (has_weights_) {
cub::DeviceScan::InclusiveSum
(nullptr, cur_tmp_size, weights2_.begin(), weights_.begin(), gpu_batch_nrows_);
tmp_size = std::max(tmp_size, cur_tmp_size);
}
// size for reduction by key
cub::DeviceReduce::ReduceByKey
(nullptr, cur_tmp_size, fvalues_.begin(),
fvalues_cur_.begin(), weights_.begin(), weights2_.begin(),
num_elements_.begin(), thrust::maximum<bst_float>(), gpu_batch_nrows_);
tmp_size = std::max(tmp_size, cur_tmp_size);
// size for filtering
cub::DeviceSelect::If
(nullptr, cur_tmp_size, fvalues_.begin(), fvalues_cur_.begin(),
num_elements_.begin(), gpu_batch_nrows_, IsNotNaN());
tmp_size = std::max(tmp_size, cur_tmp_size);
tmp_storage_.resize(tmp_size);
}
void FindColumnCuts(size_t batch_nrows, size_t icol) {
size_t tmp_size = tmp_storage_.size();
// filter out NaNs in feature values
auto fvalues_begin = fvalues_.data() + icol * gpu_batch_nrows_;
cub::DeviceSelect::If
(tmp_storage_.data().get(), tmp_size, fvalues_begin,
fvalues_cur_.data(), num_elements_.begin(), batch_nrows, IsNotNaN());
size_t nfvalues_cur = 0;
thrust::copy_n(num_elements_.begin(), 1, &nfvalues_cur);
// compute cumulative weights using a prefix scan
if (has_weights_) {
// filter out NaNs in weights;
// since cub::DeviceSelect::If performs stable filtering,
// the weights are stored in the correct positions
auto feature_weights_begin = feature_weights_.data() +
icol * gpu_batch_nrows_;
cub::DeviceSelect::If
(tmp_storage_.data().get(), tmp_size, feature_weights_begin,
weights_.data().get(), num_elements_.begin(), batch_nrows, IsNotNaN());
// sort the values and weights
cub::DeviceRadixSort::SortPairs
(tmp_storage_.data().get(), tmp_size, fvalues_cur_.data().get(),
fvalues_begin.get(), weights_.data().get(), weights2_.data().get(),
nfvalues_cur);
// sum the weights to get cumulative weight values
cub::DeviceScan::InclusiveSum
(tmp_storage_.data().get(), tmp_size, weights2_.begin(),
weights_.begin(), nfvalues_cur);
} else {
// sort the batch values
cub::DeviceRadixSort::SortKeys
(tmp_storage_.data().get(), tmp_size,
fvalues_cur_.data().get(), fvalues_begin.get(), nfvalues_cur);
// fill in cumulative weights with counting iterator
thrust::copy_n(thrust::make_counting_iterator(1), nfvalues_cur,
weights_.begin());
}
// remove repeated items and sum the weights across them;
// non-negative weights are assumed
cub::DeviceReduce::ReduceByKey
(tmp_storage_.data().get(), tmp_size, fvalues_begin,
fvalues_cur_.begin(), weights_.begin(), weights2_.begin(),
num_elements_.begin(), thrust::maximum<bst_float>(), nfvalues_cur);
size_t n_unique = 0;
thrust::copy_n(num_elements_.begin(), 1, &n_unique);
// extract cuts
n_cuts_cur_[icol] = std::min(n_cuts_, n_unique);
// if less elements than cuts: copy all elements with their weights
if (n_cuts_ > n_unique) {
auto weights2_iter = weights2_.begin();
auto fvalues_iter = fvalues_cur_.begin();
auto cuts_iter = cuts_d_.begin() + icol * n_cuts_;
dh::LaunchN(device_, n_unique, [=]__device__(size_t i) {
bst_float rmax = weights2_iter[i];
bst_float rmin = i > 0 ? weights2_iter[i - 1] : 0;
cuts_iter[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_iter[i]);
});
} else if (n_cuts_cur_[icol] > 0) {
// if more elements than cuts: use binary search on cumulative weights
int block = 256;
find_cuts_k<<<dh::DivRoundUp(n_cuts_cur_[icol], block), block>>>
(cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(),
weights2_.data().get(), n_unique, n_cuts_cur_[icol]);
dh::safe_cuda(cudaGetLastError());
}
}
void SketchBatch(const SparsePage& row_batch, const MetaInfo& info,
size_t gpu_batch) {
// compute start and end indices
size_t batch_row_begin = gpu_batch * gpu_batch_nrows_;
size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_,
static_cast<size_t>(n_rows_));
size_t batch_nrows = batch_row_end - batch_row_begin;
size_t n_entries =
row_batch.offset[row_begin_ + batch_row_end] -
row_batch.offset[row_begin_ + batch_row_begin];
// copy the batch to the GPU
dh::safe_cuda
(cudaMemcpy(entries_.data().get(),
&row_batch.data[row_batch.offset[row_begin_ + batch_row_begin]],
n_entries * sizeof(Entry), cudaMemcpyDefault));
// copy the weights if necessary
if (has_weights_) {
dh::safe_cuda
(cudaMemcpy(weights_.data().get(),
info.weights_.data() + row_begin_ + batch_row_begin,
batch_nrows * sizeof(bst_float), cudaMemcpyDefault));
}
// unpack the features; also unpack weights if present
thrust::fill(fvalues_.begin(), fvalues_.end(), NAN);
thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN);
dim3 block3(64, 4, 1);
dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x),
dh::DivRoundUp(num_cols_, block3.y), 1);
unpack_features_k<<<grid3, block3>>>
(fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr,
row_ptrs_.data().get() + batch_row_begin,
has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
gpu_batch_nrows_, num_cols_,
row_batch.offset[row_begin_ + batch_row_begin], batch_nrows);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
for (int icol = 0; icol < num_cols_; ++icol) {
FindColumnCuts(batch_nrows, icol);
}
dh::safe_cuda(cudaDeviceSynchronize());
// add cuts into sketches
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
for (int icol = 0; icol < num_cols_; ++icol) {
summaries_[icol].MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);
sketches_[icol].PushSummary(summaries_[icol]);
}
}
void Sketch(const SparsePage& row_batch, const MetaInfo& info) {
// copy rows to the device
dh::safe_cuda(cudaSetDevice(device_));
row_ptrs_.resize(n_rows_ + 1);
thrust::copy(row_batch.offset.data() + row_begin_,
row_batch.offset.data() + row_end_ + 1,
row_ptrs_.begin());
size_t gpu_nbatches = dh::DivRoundUp(n_rows_, gpu_batch_nrows_);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
SketchBatch(row_batch, info, gpu_batch);
}
}
};
void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) {
// partition input matrix into row segments
std::vector<size_t> row_segments;
dh::RowSegments(info.num_row_, devices_.Size(), &row_segments);
// create device shards
shards_.resize(devices_.Size());
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
shard = std::unique_ptr<DeviceShard>
(new DeviceShard(devices_[i], row_segments[i], row_segments[i + 1], param_));
});
// compute sketches for each shard
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
shard->Init(batch, info);
shard->Sketch(batch, info);
});
// merge the sketches from all shards
// TODO(canonizer): do it in a tree-like reduction
int num_cols = info.num_col_;
std::vector<WXQSketch> sketches(num_cols);
WXQSketch::SummaryContainer summary;
for (int icol = 0; icol < num_cols; ++icol) {
sketches[icol].Init(batch.Size(), 1.0 / (8 * param_.max_bin));
for (int shard = 0; shard < shards_.size(); ++shard) {
shards_[shard]->sketches_[icol].GetSummary(&summary);
sketches[icol].PushSummary(summary);
}
}
hmat->Init(&sketches, param_.max_bin);
}
GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) {
devices_ = GPUSet::Range(param_.gpu_id, dh::NDevices(param_.n_gpus, n_rows));
}
std::vector<std::unique_ptr<DeviceShard>> shards_;
tree::TrainParam param_;
GPUSet devices_;
};
void DeviceSketch
(const SparsePage& batch, const MetaInfo& info,
const tree::TrainParam& param, HistCutMatrix* hmat) {
GPUSketcher sketcher(param, info.num_row_);
sketcher.Sketch(batch, info, hmat);
}
} // namespace common
} // namespace xgboost

View File

@ -12,8 +12,11 @@
#include <vector>
#include "row_set.h"
#include "../tree/fast_hist_param.h"
#include "../tree/param.h"
#include "./quantile.h"
namespace xgboost {
namespace common {
using tree::FastHistParam;
@ -77,11 +80,20 @@ struct HistCutMatrix {
return {dmlc::BeginPtr(cut) + row_ptr[fid],
row_ptr[fid + 1] - row_ptr[fid]};
}
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
// create histogram cut matrix given statistics from data
// using approximate quantile sketch approach
void Init(DMatrix* p_fmat, uint32_t max_num_bins);
void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
};
/*! \brief Builds the cut matrix on the GPU */
void DeviceSketch
(const SparsePage& batch, const MetaInfo& info,
const tree::TrainParam& param, HistCutMatrix* hmat);
/*!
* \brief A single row in global histogram index.

View File

@ -35,9 +35,9 @@ struct WQSummary {
/*! \brief the value of data */
DType value;
// constructor
Entry() = default;
XGBOOST_DEVICE Entry() {} // NOLINT
// constructor
Entry(RType rmin, RType rmax, RType wmin, DType value)
XGBOOST_DEVICE Entry(RType rmin, RType rmax, RType wmin, DType value)
: rmin(rmin), rmax(rmax), wmin(wmin), value(value) {}
/*!
* \brief debug function, check Valid
@ -48,11 +48,11 @@ struct WQSummary {
CHECK(rmax- rmin - wmin > -eps) << "relation constraint: min/max";
}
/*! \return rmin estimation for v strictly bigger than value */
inline RType RMinNext() const {
XGBOOST_DEVICE inline RType RMinNext() const {
return rmin + wmin;
}
/*! \return rmax estimation for v strictly smaller than value */
inline RType RMaxPrev() const {
XGBOOST_DEVICE inline RType RMaxPrev() const {
return rmax - wmin;
}
};
@ -158,6 +158,17 @@ struct WQSummary {
size = src.size;
std::memcpy(data, src.data, sizeof(Entry) * size);
}
inline void MakeFromSorted(const Entry* entries, size_t n) {
size = 0;
for (size_t i = 0; i < n;) {
size_t j = i + 1;
// ignore repeated values
for (; j < n && entries[j].value == entries[i].value; ++j) {}
data[size++] = Entry(entries[i].rmin, entries[i].rmax, entries[i].wmin,
entries[i].value);
i = j;
}
}
/*!
* \brief debug function, validate whether the summary
* run consistency check to check if it is a valid summary
@ -676,6 +687,18 @@ class QuantileSketchTemplate {
* \param eps accuracy level of summary
*/
inline void Init(size_t maxn, double eps) {
LimitSizeLevel(maxn, eps, &nlevel, &limit_size);
// lazy reserve the space, if there is only one value, no need to allocate space
inqueue.queue.resize(1);
inqueue.qtail = 0;
data.clear();
level.clear();
}
inline static void LimitSizeLevel
(size_t maxn, double eps, size_t* out_nlevel, size_t* out_limit_size) {
size_t& nlevel = *out_nlevel;
size_t& limit_size = *out_limit_size;
nlevel = 1;
while (true) {
limit_size = static_cast<size_t>(ceil(nlevel / eps)) + 1;
@ -687,12 +710,8 @@ class QuantileSketchTemplate {
size_t n = (1ULL << nlevel);
CHECK(n * limit_size >= maxn) << "invalid init parameter";
CHECK(nlevel <= limit_size * eps) << "invalid init parameter";
// lazy reserve the space, if there is only one value, no need to allocate space
inqueue.queue.resize(1);
inqueue.qtail = 0;
data.clear();
level.clear();
}
/*!
* \brief add an element to a sketch
* \param x The element added to the sketch
@ -714,6 +733,13 @@ class QuantileSketchTemplate {
}
inqueue.Push(x, w);
}
inline void PushSummary(const Summary& summary) {
temp.Reserve(limit_size * 2);
temp.SetPrune(summary, limit_size * 2);
PushTemp();
}
/*! \brief push up temp */
inline void PushTemp() {
temp.Reserve(limit_size * 2);

View File

@ -77,6 +77,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
int gpu_id;
// number of GPUs to use
int n_gpus;
// number of rows in a single GPU batch
int gpu_batch_nrows;
// the criteria to use for ranking splits
std::string split_evaluator;
// declare the parameters
@ -186,6 +188,11 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
.set_lower_bound(-1)
.set_default(1)
.describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
DMLC_DECLARE_FIELD(gpu_batch_nrows)
.set_lower_bound(-1)
.set_default(0)
.describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
"-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
DMLC_DECLARE_FIELD(split_evaluator)
.set_default("elastic_net,monotonic")
.describe("The criteria to use for ranking splits");

View File

@ -1,7 +1,7 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include <thrust/execution_policy.h>
#include <thrust/copy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
@ -9,6 +9,7 @@
#include <thrust/sequence.h>
#include <xgboost/tree_updater.h>
#include <algorithm>
#include <cmath>
#include <memory>
#include <queue>
#include <utility>
@ -227,26 +228,6 @@ struct CalcWeightTrainParam {
learning_rate(p.learning_rate) {}
};
// index of the first element in cuts greater than v, or n if none;
// cuts are ordered, and binary search is used
__device__ int upper_bound(const float* __restrict__ cuts, int n, float v) {
if (n == 0)
return 0;
if (cuts[n - 1] <= v)
return n;
if (cuts[0] > v)
return 0;
int left = 0, right = n - 1;
while (right - left > 1) {
int middle = left + (right - left) / 2;
if (cuts[middle] > v)
right = middle;
else
left = middle;
}
return right;
}
__global__ void compress_bin_ellpack_k
(common::CompressedBufferWriter wr, common::CompressedByteT* __restrict__ buffer,
const size_t* __restrict__ row_ptrs,
@ -266,7 +247,7 @@ __global__ void compress_bin_ellpack_k
float fvalue = entry.fvalue;
const float *feature_cuts = &cuts[cut_rows[feature]];
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
bin = upper_bound(feature_cuts, ncuts, fvalue);
bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
if (bin >= ncuts)
bin = ncuts - 1;
bin += cut_rows[feature];
@ -330,6 +311,7 @@ struct DeviceShard {
dh::DVec<bst_float> prediction_cache;
std::vector<GradientPair> node_sum_gradients;
dh::DVec<GradientPair> node_sum_gradients_d;
thrust::device_vector<size_t> row_ptrs;
common::CompressedIterator<uint32_t> gidx;
size_t row_stride;
bst_uint row_begin_idx; // The row offset for this shard
@ -348,41 +330,51 @@ struct DeviceShard {
dh::CubMemory temp_memory;
// TODO(canonizer): do add support multi-batch DMatrix here
DeviceShard(int device_idx, int normalised_device_idx,
bst_uint row_begin, bst_uint row_end, int n_bins, TrainParam param)
bst_uint row_begin, bst_uint row_end, TrainParam param)
: device_idx(device_idx),
normalised_device_idx(normalised_device_idx),
row_begin_idx(row_begin),
row_end_idx(row_end),
row_stride(0),
n_rows(row_end - row_begin),
n_bins(n_bins),
null_gidx_value(n_bins),
n_bins(0),
null_gidx_value(0),
param(param),
prediction_cache_initialised(false),
can_use_smem_atomics(false) {}
void Init(const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
// copy cuts to the GPU
void InitRowPtrs(const SparsePage& row_batch) {
dh::safe_cuda(cudaSetDevice(device_idx));
thrust::device_vector<float> cuts_d(hmat.cut);
thrust::device_vector<size_t> cut_row_ptrs_d(hmat.row_ptr);
// find the maximum row size
thrust::device_vector<size_t> row_ptr_d(
row_batch.offset.data() + row_begin_idx, row_batch.offset.data() + row_end_idx + 1);
auto row_iter = row_ptr_d.begin();
row_ptrs.resize(n_rows + 1);
thrust::copy(row_batch.offset.data() + row_begin_idx,
row_batch.offset.data() + row_end_idx + 1,
row_ptrs.begin());
auto row_iter = row_ptrs.begin();
auto get_size = [=] __device__(size_t row) {
return row_iter[row + 1] - row_iter[row];
}; // NOLINT
auto counting = thrust::make_counting_iterator(size_t(0));
using TransformT = thrust::transform_iterator<decltype(get_size),
decltype(counting), size_t>;
TransformT row_size_iter = TransformT(counting, get_size);
row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
thrust::maximum<size_t>());
int num_symbols =
n_bins + 1;
thrust::maximum<size_t>());
}
void InitCompressedData(const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
n_bins = hmat.row_ptr.back();
null_gidx_value = hmat.row_ptr.back();
// copy cuts to the GPU
dh::safe_cuda(cudaSetDevice(device_idx));
thrust::device_vector<float> cuts_d(hmat.cut);
thrust::device_vector<size_t> cut_row_ptrs_d(hmat.row_ptr);
// allocate compressed bin data
int num_symbols = n_bins + 1;
size_t compressed_size_bytes =
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
num_symbols);
@ -391,17 +383,17 @@ struct DeviceShard {
<< "Max leaves and max depth cannot both be unconstrained for "
"gpu_hist.";
ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes);
gidx_buffer.Fill(0);
int nbits = common::detail::SymbolBits(num_symbols);
// bin and compress entries in batches of rows
// use no more than 1/16th of GPU memory per batch
size_t gpu_batch_nrows = dh::TotalMemory(device_idx) /
(16 * row_stride * sizeof(Entry));
if (gpu_batch_nrows > n_rows) {
gpu_batch_nrows = n_rows;
}
size_t gpu_batch_nrows = std::min
(dh::TotalMemory(device_idx) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(n_rows));
thrust::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
@ -423,7 +415,7 @@ struct DeviceShard {
dh::DivRoundUp(row_stride, block3.y), 1);
compress_bin_ellpack_k<<<grid3, block3>>>
(common::CompressedBufferWriter(num_symbols), gidx_buffer.Data(),
row_ptr_d.data().get() + batch_row_begin,
row_ptrs.data().get() + batch_row_begin,
entries_d.data().get(), cuts_d.data().get(), cut_row_ptrs_d.data().get(),
batch_row_begin, batch_nrows,
row_batch.offset[row_begin_idx + batch_row_begin],
@ -434,8 +426,8 @@ struct DeviceShard {
}
// free the memory that is no longer needed
row_ptr_d.resize(0);
row_ptr_d.shrink_to_fit();
row_ptrs.resize(0);
row_ptrs.shrink_to_fit();
entries_d.resize(0);
entries_d.shrink_to_fit();
@ -741,17 +733,9 @@ class GPUHistMaker : public TreeUpdater {
void InitDataOnce(DMatrix* dmat) {
info_ = &dmat->Info();
monitor_.Start("Quantiles", device_list_);
hmat_.Init(dmat, param_.max_bin);
monitor_.Stop("Quantiles", device_list_);
n_bins_ = hmat_.row_ptr.back();
int n_devices = dh::NDevices(param_.n_gpus, info_->num_row_);
bst_uint row_begin = 0;
bst_uint shard_size =
std::ceil(static_cast<double>(info_->num_row_) / n_devices);
device_list_.resize(n_devices);
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
int device_idx = (param_.gpu_id + d_idx) % dh::NVisibleDevices();
@ -762,32 +746,34 @@ class GPUHistMaker : public TreeUpdater {
// Partition input matrix into row segments
std::vector<size_t> row_segments;
dh::RowSegments(info_->num_row_, n_devices, &row_segments);
dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next()) << "Empty batches are not supported";
const SparsePage& batch = iter->Value();
// Create device shards
shards_.resize(n_devices);
row_segments.push_back(0);
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
bst_uint row_end =
std::min(static_cast<size_t>(row_begin + shard_size), info_->num_row_);
row_segments.push_back(row_end);
row_begin = row_end;
}
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
shard = std::unique_ptr<DeviceShard>
(new DeviceShard(device_list_[i], i,
row_segments[i], row_segments[i + 1], param_));
shard->InitRowPtrs(batch);
});
monitor_.Start("Quantiles", device_list_);
common::DeviceSketch(batch, *info_, param_, &hmat_);
n_bins_ = hmat_.row_ptr.back();
monitor_.Stop("Quantiles", device_list_);
monitor_.Start("BinningCompression", device_list_);
{
dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next()) << "Empty batches are not supported";
const SparsePage& batch = iter->Value();
// Create device shards
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
shard = std::unique_ptr<DeviceShard>
(new DeviceShard(device_list_[i], i,
row_segments[i], row_segments[i + 1], n_bins_, param_));
shard->Init(hmat_, batch);
});
CHECK(!iter->Next()) << "External memory not supported";
}
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
shard->InitCompressedData(hmat_, batch);
});
monitor_.Stop("BinningCompression", device_list_);
CHECK(!iter->Next()) << "External memory not supported";
p_last_fmat_ = dmat;
initialised_ = true;
}
@ -1017,9 +1003,6 @@ class GPUHistMaker : public TreeUpdater {
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
RegTree* p_tree) {
// Temporarily store number of threads so we can change it back later
int nthread = omp_get_max_threads();
auto& tree = *p_tree;
monitor_.Start("InitData", device_list_);

View File

@ -0,0 +1,60 @@
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/hist_util.h"
#include "gtest/gtest.h"
#include "xgboost/c_api.h"
#include <algorithm>
#include <cmath>
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
namespace xgboost {
namespace common {
TEST(gpu_hist_util, TestDeviceSketch) {
// create the data
int nrows = 10001;
std::vector<float> test_data(nrows);
auto count_iter = thrust::make_counting_iterator(0);
// fill in reverse order
std::copy(count_iter, count_iter + nrows, test_data.rbegin());
// create the DMatrix
DMatrixHandle dmat_handle;
XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1,
&dmat_handle);
auto dmat = *static_cast<std::shared_ptr<xgboost::DMatrix> *>(dmat_handle);
// parameters for finding quantiles
tree::TrainParam p;
p.max_bin = 20;
p.gpu_id = 0;
p.n_gpus = 1;
// ensure that the exact quantiles are found
p.gpu_batch_nrows = nrows * 10;
// find quantiles on the CPU
HistCutMatrix hmat_cpu;
hmat_cpu.Init(dmat.get(), p.max_bin);
// find the cuts on the GPU
dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next());
const SparsePage& batch = iter->Value();
HistCutMatrix hmat_gpu;
DeviceSketch(batch, dmat->Info(), p, &hmat_gpu);
CHECK(!iter->Next());
// compare the cuts
double eps = 1e-2;
ASSERT_EQ(hmat_gpu.min_val.size(), 1);
ASSERT_EQ(hmat_gpu.row_ptr.size(), 2);
ASSERT_EQ(hmat_gpu.cut.size(), hmat_cpu.cut.size());
ASSERT_LT(fabs(hmat_cpu.min_val[0] - hmat_gpu.min_val[0]), eps * nrows);
for (int i = 0; i < hmat_gpu.cut.size(); ++i) {
ASSERT_LT(fabs(hmat_cpu.cut[i] - hmat_gpu.cut[i]), eps * nrows);
}
}
} // namespace common
} // namespace xgboost

View File

@ -30,8 +30,9 @@ TEST(gpu_hist_experimental, TestSparseShard) {
iter->BeforeFirst();
CHECK(iter->Next());
const SparsePage& batch = iter->Value();
DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
shard.Init(hmat, batch);
DeviceShard shard(0, 0, 0, rows, p);
shard.InitRowPtrs(batch);
shard.InitCompressedData(hmat, batch);
CHECK(!iter->Next());
ASSERT_LT(shard.row_stride, columns);
@ -72,8 +73,9 @@ TEST(gpu_hist_experimental, TestDenseShard) {
CHECK(iter->Next());
const SparsePage& batch = iter->Value();
DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
shard.Init(hmat, batch);
DeviceShard shard(0, 0, 0, rows, p);
shard.InitRowPtrs(batch);
shard.InitCompressedData(hmat, batch);
CHECK(!iter->Next());
ASSERT_EQ(shard.row_stride, columns);

View File

@ -7,12 +7,26 @@ import unittest
class TestGPULinear(unittest.TestCase):
datasets = ["Boston", "Digits", "Cancer", "Sparse regression",
"Boston External Memory"]
def test_gpu_coordinate(self):
tm._skip_if_no_sklearn()
variable_param = {'booster': ['gblinear'], 'updater': ['coord_descent'], 'eta': [0.5],
'top_k': [10], 'tolerance': [1e-5], 'nthread': [2], 'alpha': [.005, .1], 'lambda': [0.005],
'coordinate_selection': ['cyclic', 'random', 'greedy'], 'n_gpus': [-1]}
variable_param = {
'booster': ['gblinear'],
'updater': ['coord_descent'],
'eta': [0.5],
'top_k': [10],
'tolerance': [1e-5],
'nthread': [2],
'alpha': [.005, .1],
'lambda': [0.005],
'coordinate_selection': ['cyclic', 'random', 'greedy'],
'n_gpus': [-1]
}
for param in test_linear.parameter_combinations(variable_param):
results = test_linear.run_suite(param, 200, None, scale_features=True)
results = test_linear.run_suite(
param, 200, self.datasets, scale_features=True)
test_linear.assert_regression_result(results, 1e-2)
test_linear.assert_classification_result(results)

View File

@ -11,11 +11,10 @@ from regression_test_utilities import run_suite, parameter_combinations, \
def assert_gpu_results(cpu_results, gpu_results):
for cpu_res, gpu_res in zip(cpu_results, gpu_results):
# Check final eval result roughly equivalent
assert np.allclose(cpu_res["eval"][-1], gpu_res["eval"][-1], 1e-3, 1e-2)
datasets = ["Boston", "Cancer", "Digits", "Sparse regression"]
assert np.allclose(cpu_res["eval"][-1], gpu_res["eval"][-1], 1e-2, 1e-2)
datasets = ["Boston", "Cancer", "Digits", "Sparse regression",
"Sparse regression with weights"]
class TestGPU(unittest.TestCase):
def test_gpu_exact(self):

View File

@ -15,11 +15,16 @@ except ImportError:
class Dataset:
def __init__(self, name, get_dataset, objective, metric, use_external_memory=False):
def __init__(self, name, get_dataset, objective, metric,
has_weights=False, use_external_memory=False):
self.name = name
self.objective = objective
self.metric = metric
self.X, self.y = get_dataset()
if has_weights:
self.X, self.y, self.w = get_dataset()
else:
self.X, self.y = get_dataset()
self.w = None
self.use_external_memory = use_external_memory
@ -49,6 +54,16 @@ def get_sparse():
return X, y
def get_sparse_weights():
rng = np.random.RandomState(199)
n = 10000
sparsity = 0.25
X, y = datasets.make_regression(n, random_state=rng)
X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x for x in x_row] for x_row in X])
w = np.array([rng.uniform(1, 10) for i in range(n)])
return X, y, w
def train_dataset(dataset, param_in, num_rounds=10, scale_features=False):
param = param_in.copy()
param["objective"] = dataset.objective
@ -64,9 +79,10 @@ def train_dataset(dataset, param_in, num_rounds=10, scale_features=False):
if dataset.use_external_memory:
np.savetxt('tmptmp_1234.csv', np.hstack((dataset.y.reshape(len(dataset.y), 1), X)),
delimiter=',')
dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_')
dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_',
weight=dataset.w)
else:
dtrain = xgb.DMatrix(X, dataset.y)
dtrain = xgb.DMatrix(X, dataset.y, weight=dataset.w)
print("Training on dataset: " + dataset.name, file=sys.stderr)
print("Using parameters: " + str(param), file=sys.stderr)
@ -112,6 +128,8 @@ def run_suite(param, num_rounds=10, select_datasets=None, scale_features=False):
Dataset("Digits", get_digits, "multi:softmax", "merror"),
Dataset("Cancer", get_cancer, "binary:logistic", "error"),
Dataset("Sparse regression", get_sparse, "reg:linear", "rmse"),
Dataset("Sparse regression with weights", get_sparse_weights,
"reg:linear", "rmse", has_weights=True),
Dataset("Boston External Memory", get_boston, "reg:linear", "rmse",
use_external_memory=True)
]

View File

@ -52,6 +52,10 @@ def assert_classification_result(results):
class TestLinear(unittest.TestCase):
datasets = ["Boston", "Digits", "Cancer", "Sparse regression",
"Boston External Memory"]
def test_coordinate(self):
tm._skip_if_no_sklearn()
variable_param = {'booster': ['gblinear'], 'updater': ['coord_descent'], 'eta': [0.5],
@ -60,7 +64,7 @@ class TestLinear(unittest.TestCase):
'feature_selector': ['cyclic', 'shuffle', 'greedy', 'thrifty']
}
for param in parameter_combinations(variable_param):
results = run_suite(param, 200, None, scale_features=True)
results = run_suite(param, 200, self.datasets, scale_features=True)
assert_regression_result(results, 1e-2)
assert_classification_result(results)
@ -72,6 +76,6 @@ class TestLinear(unittest.TestCase):
'feature_selector': ['cyclic', 'shuffle']
}
for param in parameter_combinations(variable_param):
results = run_suite(param, 200, None, True)
results = run_suite(param, 200, self.datasets, True)
assert_regression_result(results, 1e-2)
assert_classification_result(results)