Added finding quantiles on the GPU. (#3393)
* Added finding quantiles on the GPU.
  - this includes datasets where weights are assigned to data rows
  - as the quantiles found by the new algorithm are not the same as those
    found by the old one, test thresholds in
    tests/python-gpu/test_gpu_updaters.py have been adjusted
* Adjustments and improved testing for finding quantiles on the GPU.
  - added C++ tests for the DeviceSketch() function
  - reduced one of the thresholds in test_gpu_updaters.py
  - adjusted the cuts found by the find_cuts_k kernel
parent e2f09db77a
commit cc6a5a3666
@ -111,7 +111,7 @@ class CompressedBufferWriter {
|
||||
symbol <<= 7 - ibit_end % 8;
|
||||
for (ptrdiff_t ibyte = ibyte_end; ibyte >= (ptrdiff_t)ibyte_start; --ibyte) {
|
||||
dh::AtomicOrByte(reinterpret_cast<unsigned int*>(buffer + detail::kPadding),
|
||||
ibyte, symbol & 0xff);
|
||||
ibyte, symbol & 0xff);
|
||||
symbol >>= 8;
|
||||
}
|
||||
}
|
||||
|
||||
@ -163,11 +163,41 @@ inline void CheckComputeCapability() {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, unsigned char b) {
|
||||
atomicOr(&buffer[ibyte / sizeof(unsigned int)], (unsigned int)b << (ibyte % (sizeof(unsigned int)) * 8));
|
||||
}
|
||||
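A minimal host-side sketch, not part of the patch, of what this helper does: OR a byte into the word that contains byte index ibyte, assuming a little-endian layout (the name OrByte is hypothetical). On the device the same update uses atomicOr so that concurrent threads writing different bytes of one word do not race.

#include <cstddef>
#include <cstdio>
#include <cstring>

// Host analogue of dh::AtomicOrByte; single-threaded, so a plain |= suffices.
void OrByte(unsigned int* buffer, std::size_t ibyte, unsigned char b) {
  buffer[ibyte / sizeof(unsigned int)] |=
      static_cast<unsigned int>(b) << (ibyte % sizeof(unsigned int) * 8);
}

int main() {
  unsigned int words[2] = {0, 0};
  OrByte(words, 5, 0xAB);  // byte 5 lives in words[1], shifted left by 8 bits
  unsigned char bytes[8];
  std::memcpy(bytes, words, sizeof(words));
  std::printf("%02x\n", bytes[5]);  // prints "ab" on a little-endian machine
  return 0;
}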
|
||||
/*!
|
||||
* \brief Find the strict upper bound for an element in a sorted array
|
||||
* using binary search.
|
||||
* \param cuts pointer to the first element of the sorted array
|
||||
* \param n length of the sorted array
|
||||
* \param v value for which to find the upper bound
|
||||
* \return the smallest index i such that v < cuts[i], or n if v is greater than or
|
||||
* equal to all elements of the array
|
||||
*/
|
||||
DEV_INLINE int UpperBound(const float* __restrict__ cuts, int n, float v) {
|
||||
if (n == 0) {
|
||||
return 0;
|
||||
}
|
||||
if (cuts[n - 1] <= v) {
|
||||
return n;
|
||||
}
|
||||
if (cuts[0] > v) {
|
||||
return 0;
|
||||
}
|
||||
int left = 0, right = n - 1;
|
||||
while (right - left > 1) {
|
||||
int middle = left + (right - left) / 2;
|
||||
if (cuts[middle] > v) {
|
||||
right = middle;
|
||||
} else {
|
||||
left = middle;
|
||||
}
|
||||
}
|
||||
return right;
|
||||
}
|
||||
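As a quick sanity check, not part of the patch, dh::UpperBound is intended to agree with std::upper_bound on a sorted array: the smallest index i with v < cuts[i], or n when no such index exists. A small host-side sketch (UpperBoundHost is a hypothetical helper used only here):

#include <algorithm>
#include <cassert>

int UpperBoundHost(const float* cuts, int n, float v) {
  // std::upper_bound returns the first element strictly greater than v.
  return static_cast<int>(std::upper_bound(cuts, cuts + n, v) - cuts);
}

int main() {
  const float cuts[] = {0.5f, 1.5f, 1.5f, 3.0f};
  assert(UpperBoundHost(cuts, 4, -1.0f) == 0);  // below every cut
  assert(UpperBoundHost(cuts, 4, 1.5f) == 3);   // duplicates are skipped
  assert(UpperBoundHost(cuts, 4, 5.0f) == 4);   // above every cut -> n
  return 0;
}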
|
||||
/*
|
||||
* Range iterator
|
||||
@ -252,6 +282,18 @@ T1 DivRoundUp(const T1 a, const T2 b) {
|
||||
return static_cast<T1>(ceil(static_cast<double>(a) / b));
|
||||
}
|
||||
|
||||
inline void RowSegments(size_t n_rows, size_t n_devices, std::vector<size_t>* segments) {
|
||||
segments->push_back(0);
|
||||
size_t row_begin = 0;
|
||||
size_t shard_size = DivRoundUp(n_rows, n_devices);
|
||||
for (size_t d_idx = 0; d_idx < n_devices; ++d_idx) {
|
||||
size_t row_end = std::min(row_begin + shard_size, n_rows);
|
||||
segments->push_back(row_end);
|
||||
row_begin = row_end;
|
||||
}
|
||||
}
|
||||
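A self-contained sketch, assuming it mirrors the RowSegments logic above: every device gets a contiguous block of ceil(n_rows / n_devices) rows and the last block absorbs the remainder (RowSegmentsSketch is a hypothetical stand-in used only for illustration).

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<std::size_t> RowSegmentsSketch(std::size_t n_rows, std::size_t n_devices) {
  std::vector<std::size_t> segments{0};
  std::size_t shard_size = (n_rows + n_devices - 1) / n_devices;  // DivRoundUp
  std::size_t row_begin = 0;
  for (std::size_t d = 0; d < n_devices; ++d) {
    std::size_t row_end = std::min(row_begin + shard_size, n_rows);
    segments.push_back(row_end);  // shard d owns rows [segments[d], segments[d + 1])
    row_begin = row_end;
  }
  return segments;
}

int main() {
  auto s = RowSegmentsSketch(10, 3);
  assert((s == std::vector<std::size_t>{0, 4, 8, 10}));  // shards of 4, 4 and 2 rows
  return 0;
}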
|
||||
|
||||
template <typename L>
|
||||
__global__ void LaunchNKernel(size_t begin, size_t end, L lambda) {
|
||||
for (auto i : GridStrideRange(begin, end)) {
|
||||
|
||||
@ -43,18 +43,28 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
|
||||
auto tid = static_cast<unsigned>(omp_get_thread_num());
|
||||
unsigned begin = std::min(nstep * tid, ncol);
|
||||
unsigned end = std::min(nstep * (tid + 1), ncol);
|
||||
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
|
||||
size_t ridx = batch.base_rowid + i;
|
||||
SparsePage::Inst inst = batch[i];
|
||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||
if (inst[j].index >= begin && inst[j].index < end) {
|
||||
sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx));
|
||||
// do not iterate if no columns are assigned to the thread
|
||||
if (begin < end && end <= ncol) {
|
||||
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
|
||||
size_t ridx = batch.base_rowid + i;
|
||||
SparsePage::Inst inst = batch[i];
|
||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||
if (inst[j].index >= begin && inst[j].index < end) {
|
||||
sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Init(&sketchs, max_num_bins);
|
||||
}
|
||||
|
||||
void HistCutMatrix::Init
|
||||
(std::vector<WXQSketch>* in_sketchs, uint32_t max_num_bins) {
|
||||
std::vector<WXQSketch>& sketchs = *in_sketchs;
|
||||
constexpr int kFactor = 8;
|
||||
// gather the histogram data
|
||||
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||
std::vector<WXQSketch::SummaryContainer> summary_array;
|
||||
@ -68,7 +78,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
|
||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
|
||||
this->min_val.resize(info.num_col_);
|
||||
this->min_val.resize(sketchs.size());
|
||||
row_ptr.push_back(0);
|
||||
for (size_t fid = 0; fid < summary_array.size(); ++fid) {
|
||||
WXQSketch::SummaryContainer a;
|
||||
|
||||
src/common/hist_util.cu (new file, 398 lines)
@@ -0,0 +1,398 @@
|
||||
/*!
|
||||
* Copyright 2018 XGBoost contributors
|
||||
*/
|
||||
|
||||
#include "./hist_util.h"
|
||||
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
#include <thrust/reduce.h>
|
||||
#include <thrust/sequence.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../tree/param.h"
|
||||
#include "./host_device_vector.h"
|
||||
#include "./device_helpers.cuh"
|
||||
#include "./quantile.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
using WXQSketch = HistCutMatrix::WXQSketch;
|
||||
|
||||
__global__ void find_cuts_k
|
||||
(WXQSketch::Entry* __restrict__ cuts, const bst_float* __restrict__ data,
|
||||
const float* __restrict__ cum_weights, int nsamples, int ncuts) {
|
||||
// ncuts < nsamples
|
||||
int icut = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (icut >= ncuts)
|
||||
return;
|
||||
WXQSketch::Entry v;
|
||||
int isample = 0;
|
||||
if (icut == 0) {
|
||||
isample = 0;
|
||||
} else if (icut == ncuts - 1) {
|
||||
isample = nsamples - 1;
|
||||
} else {
|
||||
bst_float rank = cum_weights[nsamples - 1] / static_cast<float>(ncuts - 1)
|
||||
* static_cast<float>(icut);
|
||||
// -1 is used because cum_weights is an inclusive sum
|
||||
isample = dh::UpperBound(cum_weights, nsamples, rank);
|
||||
isample = max(0, min(isample, nsamples - 1));
|
||||
}
|
||||
// repeated values will be filtered out on the CPU
|
||||
bst_float rmin = isample > 0 ? cum_weights[isample - 1] : 0;
|
||||
bst_float rmax = cum_weights[isample];
|
||||
cuts[icut] = WXQSketch::Entry(rmin, rmax, rmax - rmin, data[isample]);
|
||||
}
|
||||
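A host-side sketch, assuming it follows the interior-cut logic of find_cuts_k above: the cut index is mapped to a target rank in cumulative-weight space, the inclusive weight prefix sum is binary-searched for the owning sample, and the resulting entry carries (rmin, rmax, rmax - rmin, value). All concrete numbers below are illustrative.

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<float> data{1.f, 2.f, 3.f, 4.f};         // sorted feature values
  std::vector<float> cum_weights{1.f, 2.f, 3.f, 4.f};  // inclusive prefix sum of weights
  int nsamples = 4, ncuts = 3, icut = 1;               // a single interior cut
  float rank = cum_weights[nsamples - 1] / (ncuts - 1) * icut;  // 4 / 2 * 1 = 2.0
  int isample = static_cast<int>(
      std::upper_bound(cum_weights.begin(), cum_weights.end(), rank) -
      cum_weights.begin());                            // first cum_weight > 2.0 -> index 2
  isample = std::max(0, std::min(isample, nsamples - 1));
  float rmin = isample > 0 ? cum_weights[isample - 1] : 0.f;  // 2.0
  float rmax = cum_weights[isample];                          // 3.0
  assert(data[isample] == 3.f && rmax - rmin == 1.f);  // cut at value 3 with weight 1
  return 0;
}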
|
||||
// predicate for thrust filtering that returns true if the element is not a NaN
|
||||
struct IsNotNaN {
|
||||
__device__ bool operator()(float a) const { return !isnan(a); }
|
||||
};
|
||||
|
||||
__global__ void unpack_features_k
|
||||
(float* __restrict__ fvalues, float* __restrict__ feature_weights,
|
||||
const size_t* __restrict__ row_ptrs, const float* __restrict__ weights,
|
||||
Entry* entries, size_t nrows_array, int ncols, size_t row_begin_ptr,
|
||||
size_t nrows) {
|
||||
size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x;
|
||||
if (irow >= nrows) {
|
||||
return;
|
||||
}
|
||||
size_t row_length = row_ptrs[irow + 1] - row_ptrs[irow];
|
||||
int icol = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
if (icol >= row_length) {
|
||||
return;
|
||||
}
|
||||
Entry entry = entries[row_ptrs[irow] - row_begin_ptr + icol];
|
||||
size_t ind = entry.index * nrows_array + irow;
|
||||
// if weights are present, ensure that a non-NaN value is written to weights
|
||||
// if and only if it is also written to features
|
||||
if (!isnan(entry.fvalue) && (weights == nullptr || !isnan(weights[irow]))) {
|
||||
fvalues[ind] = entry.fvalue;
|
||||
if (feature_weights != nullptr) {
|
||||
feature_weights[ind] = weights[irow];
|
||||
}
|
||||
}
|
||||
}
|
||||
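A short sketch, not part of the patch, of the dense layout the kernel above writes: feature-major storage of size ncols * nrows_array, initialized to NaN for missing values, so each feature column is a contiguous range that can later be filtered, sorted, and scanned on its own. The numbers are illustrative.

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  const std::size_t nrows_array = 4;  // plays the role of gpu_batch_nrows
  const int ncols = 3;
  std::vector<float> fvalues(ncols * nrows_array, NAN);  // missing entries stay NaN

  // A sparse entry (feature 2, value 7.5) on row 1 of the batch:
  std::size_t irow = 1;
  int feature = 2;
  std::size_t ind = feature * nrows_array + irow;  // same index formula as the kernel
  fvalues[ind] = 7.5f;

  // Feature 2 occupies indices [2 * nrows_array, 3 * nrows_array).
  assert(fvalues[2 * nrows_array + 1] == 7.5f);
  assert(std::isnan(fvalues[2 * nrows_array + 0]));
  return 0;
}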
|
||||
// finds quantiles on the GPU
|
||||
struct GPUSketcher {
|
||||
// manage memory for a single GPU
|
||||
struct DeviceShard {
|
||||
int device_;
|
||||
bst_uint row_begin_; // The row offset for this shard
|
||||
bst_uint row_end_;
|
||||
bst_uint n_rows_;
|
||||
int num_cols_{0};
|
||||
size_t n_cuts_{0};
|
||||
size_t gpu_batch_nrows_{0};
|
||||
bool has_weights_{false};
|
||||
|
||||
tree::TrainParam param_;
|
||||
std::vector<WXQSketch> sketches_;
|
||||
thrust::device_vector<size_t> row_ptrs_;
|
||||
std::vector<WXQSketch::SummaryContainer> summaries_;
|
||||
thrust::device_vector<Entry> entries_;
|
||||
thrust::device_vector<bst_float> fvalues_;
|
||||
thrust::device_vector<bst_float> feature_weights_;
|
||||
thrust::device_vector<bst_float> fvalues_cur_;
|
||||
thrust::device_vector<WXQSketch::Entry> cuts_d_;
|
||||
thrust::host_vector<WXQSketch::Entry> cuts_h_;
|
||||
thrust::device_vector<bst_float> weights_;
|
||||
thrust::device_vector<bst_float> weights2_;
|
||||
std::vector<size_t> n_cuts_cur_;
|
||||
thrust::device_vector<size_t> num_elements_;
|
||||
thrust::device_vector<char> tmp_storage_;
|
||||
|
||||
DeviceShard(int device, bst_uint row_begin, bst_uint row_end,
|
||||
tree::TrainParam param) :
|
||||
device_(device), row_begin_(row_begin), row_end_(row_end),
|
||||
n_rows_(row_end - row_begin), param_(std::move(param)) {
|
||||
}
|
||||
|
||||
void Init(const SparsePage& row_batch, const MetaInfo& info) {
|
||||
num_cols_ = info.num_col_;
|
||||
has_weights_ = info.weights_.size() > 0;
|
||||
|
||||
// find the batch size
|
||||
if (param_.gpu_batch_nrows == 0) {
|
||||
// By default, use no more than 1/16th of GPU memory
|
||||
gpu_batch_nrows_ = dh::TotalMemory(device_) /
|
||||
(16 * num_cols_ * sizeof(Entry));
|
||||
} else if (param_.gpu_batch_nrows == -1) {
|
||||
gpu_batch_nrows_ = n_rows_;
|
||||
} else {
|
||||
gpu_batch_nrows_ = param_.gpu_batch_nrows;
|
||||
}
|
||||
if (gpu_batch_nrows_ > n_rows_) {
|
||||
gpu_batch_nrows_ = n_rows_;
|
||||
}
|
||||
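A compact restatement, assuming it matches the branches above: gpu_batch_nrows == 0 auto-sizes the batch to roughly 1/16th of device memory, -1 takes all rows of the shard, any positive value is used as given, and the result is always clamped to the shard's row count (PickBatchRows and its arguments are hypothetical).

#include <algorithm>
#include <cassert>
#include <cstddef>

std::size_t PickBatchRows(int gpu_batch_nrows, std::size_t total_mem_bytes,
                          std::size_t num_cols, std::size_t entry_bytes,
                          std::size_t n_rows) {
  std::size_t batch;
  if (gpu_batch_nrows == 0) {
    batch = total_mem_bytes / (16 * num_cols * entry_bytes);  // ~1/16th of memory
  } else if (gpu_batch_nrows == -1) {
    batch = n_rows;
  } else {
    batch = static_cast<std::size_t>(gpu_batch_nrows);
  }
  return std::min(batch, n_rows);
}

int main() {
  // 8 GiB device, 100 columns, 8-byte entries -> ~671k rows, clamped to 500k.
  assert(PickBatchRows(0, 8ULL << 30, 100, 8, 500000) == 500000);
  assert(PickBatchRows(-1, 8ULL << 30, 100, 8, 500000) == 500000);
  assert(PickBatchRows(1024, 8ULL << 30, 100, 8, 500000) == 1024);
  return 0;
}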
|
||||
// initialize sketches
|
||||
sketches_.resize(num_cols_);
|
||||
summaries_.resize(num_cols_);
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * param_.max_bin);
|
||||
size_t dummy_nlevel;
|
||||
WXQSketch::LimitSizeLevel(row_batch.Size(), eps, &dummy_nlevel, &n_cuts_);
|
||||
// double ncuts to be the same as the number of values
|
||||
// in the temporary buffers of the sketches
|
||||
n_cuts_ *= 2;
|
||||
for (int icol = 0; icol < num_cols_; ++icol) {
|
||||
sketches_[icol].Init(row_batch.Size(), eps);
|
||||
summaries_[icol].Reserve(n_cuts_);
|
||||
}
|
||||
|
||||
// allocate necessary GPU buffers
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
|
||||
entries_.resize(gpu_batch_nrows_ * num_cols_);
|
||||
fvalues_.resize(gpu_batch_nrows_ * num_cols_);
|
||||
fvalues_cur_.resize(gpu_batch_nrows_);
|
||||
cuts_d_.resize(n_cuts_ * num_cols_);
|
||||
cuts_h_.resize(n_cuts_ * num_cols_);
|
||||
weights_.resize(gpu_batch_nrows_);
|
||||
weights2_.resize(gpu_batch_nrows_);
|
||||
num_elements_.resize(1);
|
||||
|
||||
if (has_weights_) {
|
||||
feature_weights_.resize(gpu_batch_nrows_ * num_cols_);
|
||||
}
|
||||
n_cuts_cur_.resize(num_cols_);
|
||||
|
||||
// allocate storage for CUB algorithms; the size is the maximum of the sizes
|
||||
// required for the various algorithms
|
||||
size_t tmp_size = 0, cur_tmp_size = 0;
|
||||
// size for sorting
|
||||
if (has_weights_) {
|
||||
cub::DeviceRadixSort::SortPairs
|
||||
(nullptr, cur_tmp_size, fvalues_cur_.data().get(),
|
||||
fvalues_.data().get(), weights_.data().get(), weights2_.data().get(),
|
||||
gpu_batch_nrows_);
|
||||
} else {
|
||||
cub::DeviceRadixSort::SortKeys
|
||||
(nullptr, cur_tmp_size, fvalues_cur_.data().get(), fvalues_.data().get(),
|
||||
gpu_batch_nrows_);
|
||||
}
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
// size for inclusive scan
|
||||
if (has_weights_) {
|
||||
cub::DeviceScan::InclusiveSum
|
||||
(nullptr, cur_tmp_size, weights2_.begin(), weights_.begin(), gpu_batch_nrows_);
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
}
|
||||
// size for reduction by key
|
||||
cub::DeviceReduce::ReduceByKey
|
||||
(nullptr, cur_tmp_size, fvalues_.begin(),
|
||||
fvalues_cur_.begin(), weights_.begin(), weights2_.begin(),
|
||||
num_elements_.begin(), thrust::maximum<bst_float>(), gpu_batch_nrows_);
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
// size for filtering
|
||||
cub::DeviceSelect::If
|
||||
(nullptr, cur_tmp_size, fvalues_.begin(), fvalues_cur_.begin(),
|
||||
num_elements_.begin(), gpu_batch_nrows_, IsNotNaN());
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
|
||||
tmp_storage_.resize(tmp_size);
|
||||
}
|
||||
|
||||
void FindColumnCuts(size_t batch_nrows, size_t icol) {
|
||||
size_t tmp_size = tmp_storage_.size();
|
||||
// filter out NaNs in feature values
|
||||
auto fvalues_begin = fvalues_.data() + icol * gpu_batch_nrows_;
|
||||
cub::DeviceSelect::If
|
||||
(tmp_storage_.data().get(), tmp_size, fvalues_begin,
|
||||
fvalues_cur_.data(), num_elements_.begin(), batch_nrows, IsNotNaN());
|
||||
size_t nfvalues_cur = 0;
|
||||
thrust::copy_n(num_elements_.begin(), 1, &nfvalues_cur);
|
||||
|
||||
// compute cumulative weights using a prefix scan
|
||||
if (has_weights_) {
|
||||
// filter out NaNs in weights;
|
||||
// since cub::DeviceSelect::If performs stable filtering,
|
||||
// the weights are stored in the correct positions
|
||||
auto feature_weights_begin = feature_weights_.data() +
|
||||
icol * gpu_batch_nrows_;
|
||||
cub::DeviceSelect::If
|
||||
(tmp_storage_.data().get(), tmp_size, feature_weights_begin,
|
||||
weights_.data().get(), num_elements_.begin(), batch_nrows, IsNotNaN());
|
||||
|
||||
// sort the values and weights
|
||||
cub::DeviceRadixSort::SortPairs
|
||||
(tmp_storage_.data().get(), tmp_size, fvalues_cur_.data().get(),
|
||||
fvalues_begin.get(), weights_.data().get(), weights2_.data().get(),
|
||||
nfvalues_cur);
|
||||
|
||||
// sum the weights to get cumulative weight values
|
||||
cub::DeviceScan::InclusiveSum
|
||||
(tmp_storage_.data().get(), tmp_size, weights2_.begin(),
|
||||
weights_.begin(), nfvalues_cur);
|
||||
} else {
|
||||
// sort the batch values
|
||||
cub::DeviceRadixSort::SortKeys
|
||||
(tmp_storage_.data().get(), tmp_size,
|
||||
fvalues_cur_.data().get(), fvalues_begin.get(), nfvalues_cur);
|
||||
|
||||
// fill in cumulative weights using a counting iterator (equivalent to unit weights)
|
||||
thrust::copy_n(thrust::make_counting_iterator(1), nfvalues_cur,
|
||||
weights_.begin());
|
||||
}
|
||||
|
||||
// remove repeated items and sum the weights across them;
|
||||
// non-negative weights are assumed
|
||||
cub::DeviceReduce::ReduceByKey
|
||||
(tmp_storage_.data().get(), tmp_size, fvalues_begin,
|
||||
fvalues_cur_.begin(), weights_.begin(), weights2_.begin(),
|
||||
num_elements_.begin(), thrust::maximum<bst_float>(), nfvalues_cur);
|
||||
size_t n_unique = 0;
|
||||
thrust::copy_n(num_elements_.begin(), 1, &n_unique);
|
||||
|
||||
// extract cuts
|
||||
n_cuts_cur_[icol] = std::min(n_cuts_, n_unique);
|
||||
// if there are fewer elements than cuts: copy all elements with their weights
|
||||
if (n_cuts_ > n_unique) {
|
||||
auto weights2_iter = weights2_.begin();
|
||||
auto fvalues_iter = fvalues_cur_.begin();
|
||||
auto cuts_iter = cuts_d_.begin() + icol * n_cuts_;
|
||||
dh::LaunchN(device_, n_unique, [=]__device__(size_t i) {
|
||||
bst_float rmax = weights2_iter[i];
|
||||
bst_float rmin = i > 0 ? weights2_iter[i - 1] : 0;
|
||||
cuts_iter[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_iter[i]);
|
||||
});
|
||||
} else if (n_cuts_cur_[icol] > 0) {
|
||||
// if there are more elements than cuts: use binary search on cumulative weights
|
||||
int block = 256;
|
||||
find_cuts_k<<<dh::DivRoundUp(n_cuts_cur_[icol], block), block>>>
|
||||
(cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(),
|
||||
weights2_.data().get(), n_unique, n_cuts_cur_[icol]);
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
void SketchBatch(const SparsePage& row_batch, const MetaInfo& info,
|
||||
size_t gpu_batch) {
|
||||
// compute start and end indices
|
||||
size_t batch_row_begin = gpu_batch * gpu_batch_nrows_;
|
||||
size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_,
|
||||
static_cast<size_t>(n_rows_));
|
||||
size_t batch_nrows = batch_row_end - batch_row_begin;
|
||||
size_t n_entries =
|
||||
row_batch.offset[row_begin_ + batch_row_end] -
|
||||
row_batch.offset[row_begin_ + batch_row_begin];
|
||||
// copy the batch to the GPU
|
||||
dh::safe_cuda
|
||||
(cudaMemcpy(entries_.data().get(),
|
||||
&row_batch.data[row_batch.offset[row_begin_ + batch_row_begin]],
|
||||
n_entries * sizeof(Entry), cudaMemcpyDefault));
|
||||
// copy the weights if necessary
|
||||
if (has_weights_) {
|
||||
dh::safe_cuda
|
||||
(cudaMemcpy(weights_.data().get(),
|
||||
info.weights_.data() + row_begin_ + batch_row_begin,
|
||||
batch_nrows * sizeof(bst_float), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
// unpack the features; also unpack weights if present
|
||||
thrust::fill(fvalues_.begin(), fvalues_.end(), NAN);
|
||||
thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN);
|
||||
|
||||
dim3 block3(64, 4, 1);
|
||||
dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x),
|
||||
dh::DivRoundUp(num_cols_, block3.y), 1);
|
||||
unpack_features_k<<<grid3, block3>>>
|
||||
(fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr,
|
||||
row_ptrs_.data().get() + batch_row_begin,
|
||||
has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
|
||||
gpu_batch_nrows_, num_cols_,
|
||||
row_batch.offset[row_begin_ + batch_row_begin], batch_nrows);
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
|
||||
for (int icol = 0; icol < num_cols_; ++icol) {
|
||||
FindColumnCuts(batch_nrows, icol);
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
|
||||
// add cuts into sketches
|
||||
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
|
||||
for (int icol = 0; icol < num_cols_; ++icol) {
|
||||
summaries_[icol].MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);
|
||||
sketches_[icol].PushSummary(summaries_[icol]);
|
||||
}
|
||||
}
|
||||
|
||||
void Sketch(const SparsePage& row_batch, const MetaInfo& info) {
|
||||
// copy rows to the device
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
row_ptrs_.resize(n_rows_ + 1);
|
||||
thrust::copy(row_batch.offset.data() + row_begin_,
|
||||
row_batch.offset.data() + row_end_ + 1,
|
||||
row_ptrs_.begin());
|
||||
|
||||
size_t gpu_nbatches = dh::DivRoundUp(n_rows_, gpu_batch_nrows_);
|
||||
|
||||
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
|
||||
SketchBatch(row_batch, info, gpu_batch);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) {
|
||||
// partition input matrix into row segments
|
||||
std::vector<size_t> row_segments;
|
||||
dh::RowSegments(info.num_row_, devices_.Size(), &row_segments);
|
||||
|
||||
// create device shards
|
||||
shards_.resize(devices_.Size());
|
||||
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard = std::unique_ptr<DeviceShard>
|
||||
(new DeviceShard(devices_[i], row_segments[i], row_segments[i + 1], param_));
|
||||
});
|
||||
|
||||
// compute sketches for each shard
|
||||
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->Init(batch, info);
|
||||
shard->Sketch(batch, info);
|
||||
});
|
||||
|
||||
// merge the sketches from all shards
|
||||
// TODO(canonizer): do it in a tree-like reduction
|
||||
int num_cols = info.num_col_;
|
||||
std::vector<WXQSketch> sketches(num_cols);
|
||||
WXQSketch::SummaryContainer summary;
|
||||
for (int icol = 0; icol < num_cols; ++icol) {
|
||||
sketches[icol].Init(batch.Size(), 1.0 / (8 * param_.max_bin));
|
||||
for (int shard = 0; shard < shards_.size(); ++shard) {
|
||||
shards_[shard]->sketches_[icol].GetSummary(&summary);
|
||||
sketches[icol].PushSummary(summary);
|
||||
}
|
||||
}
|
||||
|
||||
hmat->Init(&sketches, param_.max_bin);
|
||||
}
|
||||
|
||||
GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) {
|
||||
devices_ = GPUSet::Range(param_.gpu_id, dh::NDevices(param_.n_gpus, n_rows));
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<DeviceShard>> shards_;
|
||||
tree::TrainParam param_;
|
||||
GPUSet devices_;
|
||||
};
|
||||
|
||||
void DeviceSketch
|
||||
(const SparsePage& batch, const MetaInfo& info,
|
||||
const tree::TrainParam& param, HistCutMatrix* hmat) {
|
||||
GPUSketcher sketcher(param, info.num_row_);
|
||||
sketcher.Sketch(batch, info, hmat);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -12,8 +12,11 @@
|
||||
#include <vector>
|
||||
#include "row_set.h"
|
||||
#include "../tree/fast_hist_param.h"
|
||||
#include "../tree/param.h"
|
||||
#include "./quantile.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
namespace common {
|
||||
|
||||
using tree::FastHistParam;
|
||||
@ -77,11 +80,20 @@ struct HistCutMatrix {
|
||||
return {dmlc::BeginPtr(cut) + row_ptr[fid],
|
||||
row_ptr[fid + 1] - row_ptr[fid]};
|
||||
}
|
||||
|
||||
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
|
||||
|
||||
// create histogram cut matrix given statistics from data
|
||||
// using approximate quantile sketch approach
|
||||
void Init(DMatrix* p_fmat, uint32_t max_num_bins);
|
||||
|
||||
void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
|
||||
};
|
||||
|
||||
/*! \brief Builds the cut matrix on the GPU */
|
||||
void DeviceSketch
|
||||
(const SparsePage& batch, const MetaInfo& info,
|
||||
const tree::TrainParam& param, HistCutMatrix* hmat);
|
||||
|
||||
/*!
|
||||
* \brief A single row in global histogram index.
|
||||
|
||||
@ -35,9 +35,9 @@ struct WQSummary {
|
||||
/*! \brief the value of data */
|
||||
DType value;
|
||||
// constructor
|
||||
Entry() = default;
|
||||
XGBOOST_DEVICE Entry() {} // NOLINT
|
||||
// constructor
|
||||
Entry(RType rmin, RType rmax, RType wmin, DType value)
|
||||
XGBOOST_DEVICE Entry(RType rmin, RType rmax, RType wmin, DType value)
|
||||
: rmin(rmin), rmax(rmax), wmin(wmin), value(value) {}
|
||||
/*!
|
||||
* \brief debug function, check Valid
|
||||
@ -48,11 +48,11 @@ struct WQSummary {
|
||||
CHECK(rmax- rmin - wmin > -eps) << "relation constraint: min/max";
|
||||
}
|
||||
/*! \return rmin estimation for v strictly bigger than value */
|
||||
inline RType RMinNext() const {
|
||||
XGBOOST_DEVICE inline RType RMinNext() const {
|
||||
return rmin + wmin;
|
||||
}
|
||||
/*! \return rmax estimation for v strictly smaller than value */
|
||||
inline RType RMaxPrev() const {
|
||||
XGBOOST_DEVICE inline RType RMaxPrev() const {
|
||||
return rmax - wmin;
|
||||
}
|
||||
};
|
||||
@ -158,6 +158,17 @@ struct WQSummary {
|
||||
size = src.size;
|
||||
std::memcpy(data, src.data, sizeof(Entry) * size);
|
||||
}
|
||||
inline void MakeFromSorted(const Entry* entries, size_t n) {
|
||||
size = 0;
|
||||
for (size_t i = 0; i < n;) {
|
||||
size_t j = i + 1;
|
||||
// ignore repeated values
|
||||
for (; j < n && entries[j].value == entries[i].value; ++j) {}
|
||||
data[size++] = Entry(entries[i].rmin, entries[i].rmax, entries[i].wmin,
|
||||
entries[i].value);
|
||||
i = j;
|
||||
}
|
||||
}
|
||||
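A host sketch, assuming it mirrors MakeFromSorted above: given entries sorted by value, keep only the first entry of each distinct value together with its rank statistics (the struct E and DedupSorted are hypothetical stand-ins for WQSummary::Entry and the loop above).

#include <cassert>
#include <cstddef>
#include <vector>

struct E { float rmin, rmax, wmin, value; };

std::vector<E> DedupSorted(const std::vector<E>& in) {
  std::vector<E> out;
  for (std::size_t i = 0; i < in.size();) {
    std::size_t j = i + 1;
    while (j < in.size() && in[j].value == in[i].value) ++j;  // skip repeated values
    out.push_back(in[i]);
    i = j;
  }
  return out;
}

int main() {
  std::vector<E> cuts{{0, 1, 1, 2.f}, {1, 2, 1, 2.f}, {2, 3, 1, 5.f}};
  auto summary = DedupSorted(cuts);
  assert(summary.size() == 2 && summary[0].value == 2.f && summary[1].value == 5.f);
  return 0;
}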
/*!
|
||||
* \brief debug function, validate whether the summary
|
||||
* run consistency check to check if it is a valid summary
|
||||
@ -676,6 +687,18 @@ class QuantileSketchTemplate {
|
||||
* \param eps accuracy level of summary
|
||||
*/
|
||||
inline void Init(size_t maxn, double eps) {
|
||||
LimitSizeLevel(maxn, eps, &nlevel, &limit_size);
|
||||
// lazy reserve the space, if there is only one value, no need to allocate space
|
||||
inqueue.queue.resize(1);
|
||||
inqueue.qtail = 0;
|
||||
data.clear();
|
||||
level.clear();
|
||||
}
|
||||
|
||||
inline static void LimitSizeLevel
|
||||
(size_t maxn, double eps, size_t* out_nlevel, size_t* out_limit_size) {
|
||||
size_t& nlevel = *out_nlevel;
|
||||
size_t& limit_size = *out_limit_size;
|
||||
nlevel = 1;
|
||||
while (true) {
|
||||
limit_size = static_cast<size_t>(ceil(nlevel / eps)) + 1;
|
||||
@ -687,12 +710,8 @@ class QuantileSketchTemplate {
|
||||
size_t n = (1ULL << nlevel);
|
||||
CHECK(n * limit_size >= maxn) << "invalid init parameter";
|
||||
CHECK(nlevel <= limit_size * eps) << "invalid init parameter";
|
||||
// lazy reserve the space, if there is only one value, no need to allocate space
|
||||
inqueue.queue.resize(1);
|
||||
inqueue.qtail = 0;
|
||||
data.clear();
|
||||
level.clear();
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief add an element to a sketch
|
||||
* \param x The element added to the sketch
|
||||
@ -714,6 +733,13 @@ class QuantileSketchTemplate {
|
||||
}
|
||||
inqueue.Push(x, w);
|
||||
}
|
||||
|
||||
inline void PushSummary(const Summary& summary) {
|
||||
temp.Reserve(limit_size * 2);
|
||||
temp.SetPrune(summary, limit_size * 2);
|
||||
PushTemp();
|
||||
}
|
||||
|
||||
/*! \brief push up temp */
|
||||
inline void PushTemp() {
|
||||
temp.Reserve(limit_size * 2);
|
||||
|
||||
@ -77,6 +77,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
int gpu_id;
|
||||
// number of GPUs to use
|
||||
int n_gpus;
|
||||
// number of rows in a single GPU batch
|
||||
int gpu_batch_nrows;
|
||||
// the criteria to use for ranking splits
|
||||
std::string split_evaluator;
|
||||
// declare the parameters
|
||||
@ -186,6 +188,11 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
.set_lower_bound(-1)
|
||||
.set_default(1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");
|
||||
DMLC_DECLARE_FIELD(gpu_batch_nrows)
|
||||
.set_lower_bound(-1)
|
||||
.set_default(0)
|
||||
.describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
|
||||
"-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
|
||||
DMLC_DECLARE_FIELD(split_evaluator)
|
||||
.set_default("elastic_net,monotonic")
|
||||
.describe("The criteria to use for ranking splits");
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
*/
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
@ -9,6 +9,7 @@
|
||||
#include <thrust/sequence.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
@ -227,26 +228,6 @@ struct CalcWeightTrainParam {
|
||||
learning_rate(p.learning_rate) {}
|
||||
};
|
||||
|
||||
// index of the first element in cuts greater than v, or n if none;
|
||||
// cuts are ordered, and binary search is used
|
||||
__device__ int upper_bound(const float* __restrict__ cuts, int n, float v) {
|
||||
if (n == 0)
|
||||
return 0;
|
||||
if (cuts[n - 1] <= v)
|
||||
return n;
|
||||
if (cuts[0] > v)
|
||||
return 0;
|
||||
int left = 0, right = n - 1;
|
||||
while (right - left > 1) {
|
||||
int middle = left + (right - left) / 2;
|
||||
if (cuts[middle] > v)
|
||||
right = middle;
|
||||
else
|
||||
left = middle;
|
||||
}
|
||||
return right;
|
||||
}
|
||||
|
||||
__global__ void compress_bin_ellpack_k
|
||||
(common::CompressedBufferWriter wr, common::CompressedByteT* __restrict__ buffer,
|
||||
const size_t* __restrict__ row_ptrs,
|
||||
@ -266,7 +247,7 @@ __global__ void compress_bin_ellpack_k
|
||||
float fvalue = entry.fvalue;
|
||||
const float *feature_cuts = &cuts[cut_rows[feature]];
|
||||
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
|
||||
bin = upper_bound(feature_cuts, ncuts, fvalue);
|
||||
bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
|
||||
if (bin >= ncuts)
|
||||
bin = ncuts - 1;
|
||||
bin += cut_rows[feature];
|
||||
@ -330,6 +311,7 @@ struct DeviceShard {
|
||||
dh::DVec<bst_float> prediction_cache;
|
||||
std::vector<GradientPair> node_sum_gradients;
|
||||
dh::DVec<GradientPair> node_sum_gradients_d;
|
||||
thrust::device_vector<size_t> row_ptrs;
|
||||
common::CompressedIterator<uint32_t> gidx;
|
||||
size_t row_stride;
|
||||
bst_uint row_begin_idx; // The row offset for this shard
|
||||
@ -348,41 +330,51 @@ struct DeviceShard {
|
||||
|
||||
dh::CubMemory temp_memory;
|
||||
|
||||
// TODO(canonizer): add support for multi-batch DMatrix here
|
||||
DeviceShard(int device_idx, int normalised_device_idx,
|
||||
bst_uint row_begin, bst_uint row_end, int n_bins, TrainParam param)
|
||||
bst_uint row_begin, bst_uint row_end, TrainParam param)
|
||||
: device_idx(device_idx),
|
||||
normalised_device_idx(normalised_device_idx),
|
||||
row_begin_idx(row_begin),
|
||||
row_end_idx(row_end),
|
||||
row_stride(0),
|
||||
n_rows(row_end - row_begin),
|
||||
n_bins(n_bins),
|
||||
null_gidx_value(n_bins),
|
||||
n_bins(0),
|
||||
null_gidx_value(0),
|
||||
param(param),
|
||||
prediction_cache_initialised(false),
|
||||
can_use_smem_atomics(false) {}
|
||||
|
||||
void Init(const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
|
||||
// copy cuts to the GPU
|
||||
void InitRowPtrs(const SparsePage& row_batch) {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
thrust::device_vector<float> cuts_d(hmat.cut);
|
||||
thrust::device_vector<size_t> cut_row_ptrs_d(hmat.row_ptr);
|
||||
|
||||
// find the maximum row size
|
||||
thrust::device_vector<size_t> row_ptr_d(
|
||||
row_batch.offset.data() + row_begin_idx, row_batch.offset.data() + row_end_idx + 1);
|
||||
|
||||
auto row_iter = row_ptr_d.begin();
|
||||
row_ptrs.resize(n_rows + 1);
|
||||
thrust::copy(row_batch.offset.data() + row_begin_idx,
|
||||
row_batch.offset.data() + row_end_idx + 1,
|
||||
row_ptrs.begin());
|
||||
auto row_iter = row_ptrs.begin();
|
||||
auto get_size = [=] __device__(size_t row) {
|
||||
return row_iter[row + 1] - row_iter[row];
|
||||
}; // NOLINT
|
||||
|
||||
auto counting = thrust::make_counting_iterator(size_t(0));
|
||||
using TransformT = thrust::transform_iterator<decltype(get_size),
|
||||
decltype(counting), size_t>;
|
||||
TransformT row_size_iter = TransformT(counting, get_size);
|
||||
row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
|
||||
thrust::maximum<size_t>());
|
||||
int num_symbols =
|
||||
n_bins + 1;
|
||||
thrust::maximum<size_t>());
|
||||
}
|
||||
|
||||
void InitCompressedData(const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
|
||||
n_bins = hmat.row_ptr.back();
|
||||
null_gidx_value = hmat.row_ptr.back();
|
||||
|
||||
// copy cuts to the GPU
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
thrust::device_vector<float> cuts_d(hmat.cut);
|
||||
thrust::device_vector<size_t> cut_row_ptrs_d(hmat.row_ptr);
|
||||
|
||||
// allocate compressed bin data
|
||||
int num_symbols = n_bins + 1;
|
||||
size_t compressed_size_bytes =
|
||||
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
|
||||
num_symbols);
|
||||
@ -391,17 +383,17 @@ struct DeviceShard {
|
||||
<< "Max leaves and max depth cannot both be unconstrained for "
|
||||
"gpu_hist.";
|
||||
ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes);
|
||||
|
||||
gidx_buffer.Fill(0);
|
||||
|
||||
int nbits = common::detail::SymbolBits(num_symbols);
|
||||
|
||||
// bin and compress entries in batches of rows
|
||||
// use no more than 1/16th of GPU memory per batch
|
||||
size_t gpu_batch_nrows = dh::TotalMemory(device_idx) /
|
||||
(16 * row_stride * sizeof(Entry));
|
||||
if (gpu_batch_nrows > n_rows) {
|
||||
gpu_batch_nrows = n_rows;
|
||||
}
|
||||
size_t gpu_batch_nrows = std::min
|
||||
(dh::TotalMemory(device_idx) / (16 * row_stride * sizeof(Entry)),
|
||||
static_cast<size_t>(n_rows));
|
||||
|
||||
thrust::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
|
||||
|
||||
size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
|
||||
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
|
||||
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
|
||||
@ -423,7 +415,7 @@ struct DeviceShard {
|
||||
dh::DivRoundUp(row_stride, block3.y), 1);
|
||||
compress_bin_ellpack_k<<<grid3, block3>>>
|
||||
(common::CompressedBufferWriter(num_symbols), gidx_buffer.Data(),
|
||||
row_ptr_d.data().get() + batch_row_begin,
|
||||
row_ptrs.data().get() + batch_row_begin,
|
||||
entries_d.data().get(), cuts_d.data().get(), cut_row_ptrs_d.data().get(),
|
||||
batch_row_begin, batch_nrows,
|
||||
row_batch.offset[row_begin_idx + batch_row_begin],
|
||||
@ -434,8 +426,8 @@ struct DeviceShard {
|
||||
}
|
||||
|
||||
// free the memory that is no longer needed
|
||||
row_ptr_d.resize(0);
|
||||
row_ptr_d.shrink_to_fit();
|
||||
row_ptrs.resize(0);
|
||||
row_ptrs.shrink_to_fit();
|
||||
entries_d.resize(0);
|
||||
entries_d.shrink_to_fit();
|
||||
|
||||
@ -741,17 +733,9 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
void InitDataOnce(DMatrix* dmat) {
|
||||
info_ = &dmat->Info();
|
||||
monitor_.Start("Quantiles", device_list_);
|
||||
hmat_.Init(dmat, param_.max_bin);
|
||||
monitor_.Stop("Quantiles", device_list_);
|
||||
n_bins_ = hmat_.row_ptr.back();
|
||||
|
||||
int n_devices = dh::NDevices(param_.n_gpus, info_->num_row_);
|
||||
|
||||
bst_uint row_begin = 0;
|
||||
bst_uint shard_size =
|
||||
std::ceil(static_cast<double>(info_->num_row_) / n_devices);
|
||||
|
||||
device_list_.resize(n_devices);
|
||||
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
|
||||
int device_idx = (param_.gpu_id + d_idx) % dh::NVisibleDevices();
|
||||
@ -762,32 +746,34 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
// Partition input matrix into row segments
|
||||
std::vector<size_t> row_segments;
|
||||
dh::RowSegments(info_->num_row_, n_devices, &row_segments);
|
||||
|
||||
dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
CHECK(iter->Next()) << "Empty batches are not supported";
|
||||
const SparsePage& batch = iter->Value();
|
||||
// Create device shards
|
||||
shards_.resize(n_devices);
|
||||
row_segments.push_back(0);
|
||||
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
|
||||
bst_uint row_end =
|
||||
std::min(static_cast<size_t>(row_begin + shard_size), info_->num_row_);
|
||||
row_segments.push_back(row_end);
|
||||
row_begin = row_end;
|
||||
}
|
||||
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard = std::unique_ptr<DeviceShard>
|
||||
(new DeviceShard(device_list_[i], i,
|
||||
row_segments[i], row_segments[i + 1], param_));
|
||||
shard->InitRowPtrs(batch);
|
||||
});
|
||||
|
||||
monitor_.Start("Quantiles", device_list_);
|
||||
common::DeviceSketch(batch, *info_, param_, &hmat_);
|
||||
n_bins_ = hmat_.row_ptr.back();
|
||||
monitor_.Stop("Quantiles", device_list_);
|
||||
|
||||
monitor_.Start("BinningCompression", device_list_);
|
||||
{
|
||||
dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
CHECK(iter->Next()) << "Empty batches are not supported";
|
||||
const SparsePage& batch = iter->Value();
|
||||
// Create device shards
|
||||
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard = std::unique_ptr<DeviceShard>
|
||||
(new DeviceShard(device_list_[i], i,
|
||||
row_segments[i], row_segments[i + 1], n_bins_, param_));
|
||||
shard->Init(hmat_, batch);
|
||||
});
|
||||
CHECK(!iter->Next()) << "External memory not supported";
|
||||
}
|
||||
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->InitCompressedData(hmat_, batch);
|
||||
});
|
||||
monitor_.Stop("BinningCompression", device_list_);
|
||||
|
||||
CHECK(!iter->Next()) << "External memory not supported";
|
||||
|
||||
p_last_fmat_ = dmat;
|
||||
initialised_ = true;
|
||||
}
|
||||
@ -1017,9 +1003,6 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
|
||||
RegTree* p_tree) {
|
||||
// Temporarily store number of threads so we can change it back later
|
||||
int nthread = omp_get_max_threads();
|
||||
|
||||
auto& tree = *p_tree;
|
||||
|
||||
monitor_.Start("InitData", device_list_);
|
||||
|
||||
tests/cpp/common/test_gpu_hist_util.cu (new file, 60 lines)
@@ -0,0 +1,60 @@
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
#include "../../../src/common/hist_util.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
TEST(gpu_hist_util, TestDeviceSketch) {
|
||||
// create the data
|
||||
int nrows = 10001;
|
||||
std::vector<float> test_data(nrows);
|
||||
auto count_iter = thrust::make_counting_iterator(0);
|
||||
// fill in reverse order
|
||||
std::copy(count_iter, count_iter + nrows, test_data.rbegin());
|
||||
|
||||
// create the DMatrix
|
||||
DMatrixHandle dmat_handle;
|
||||
XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1,
|
||||
&dmat_handle);
|
||||
auto dmat = *static_cast<std::shared_ptr<xgboost::DMatrix> *>(dmat_handle);
|
||||
|
||||
// parameters for finding quantiles
|
||||
tree::TrainParam p;
|
||||
p.max_bin = 20;
|
||||
p.gpu_id = 0;
|
||||
p.n_gpus = 1;
|
||||
// ensure that the exact quantiles are found
|
||||
p.gpu_batch_nrows = nrows * 10;
|
||||
|
||||
// find quantiles on the CPU
|
||||
HistCutMatrix hmat_cpu;
|
||||
hmat_cpu.Init(dmat.get(), p.max_bin);
|
||||
|
||||
// find the cuts on the GPU
|
||||
dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
CHECK(iter->Next());
|
||||
const SparsePage& batch = iter->Value();
|
||||
HistCutMatrix hmat_gpu;
|
||||
DeviceSketch(batch, dmat->Info(), p, &hmat_gpu);
|
||||
CHECK(!iter->Next());
|
||||
|
||||
// compare the cuts
|
||||
double eps = 1e-2;
|
||||
ASSERT_EQ(hmat_gpu.min_val.size(), 1);
|
||||
ASSERT_EQ(hmat_gpu.row_ptr.size(), 2);
|
||||
ASSERT_EQ(hmat_gpu.cut.size(), hmat_cpu.cut.size());
|
||||
ASSERT_LT(fabs(hmat_cpu.min_val[0] - hmat_gpu.min_val[0]), eps * nrows);
|
||||
for (int i = 0; i < hmat_gpu.cut.size(); ++i) {
|
||||
ASSERT_LT(fabs(hmat_cpu.cut[i] - hmat_gpu.cut[i]), eps * nrows);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -30,8 +30,9 @@ TEST(gpu_hist_experimental, TestSparseShard) {
|
||||
iter->BeforeFirst();
|
||||
CHECK(iter->Next());
|
||||
const SparsePage& batch = iter->Value();
|
||||
DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
|
||||
shard.Init(hmat, batch);
|
||||
DeviceShard shard(0, 0, 0, rows, p);
|
||||
shard.InitRowPtrs(batch);
|
||||
shard.InitCompressedData(hmat, batch);
|
||||
CHECK(!iter->Next());
|
||||
|
||||
ASSERT_LT(shard.row_stride, columns);
|
||||
@ -72,8 +73,9 @@ TEST(gpu_hist_experimental, TestDenseShard) {
|
||||
CHECK(iter->Next());
|
||||
const SparsePage& batch = iter->Value();
|
||||
|
||||
DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p);
|
||||
shard.Init(hmat, batch);
|
||||
DeviceShard shard(0, 0, 0, rows, p);
|
||||
shard.InitRowPtrs(batch);
|
||||
shard.InitCompressedData(hmat, batch);
|
||||
CHECK(!iter->Next());
|
||||
|
||||
ASSERT_EQ(shard.row_stride, columns);
|
||||
|
||||
@ -7,12 +7,26 @@ import unittest
|
||||
|
||||
|
||||
class TestGPULinear(unittest.TestCase):
|
||||
|
||||
datasets = ["Boston", "Digits", "Cancer", "Sparse regression",
|
||||
"Boston External Memory"]
|
||||
|
||||
def test_gpu_coordinate(self):
|
||||
tm._skip_if_no_sklearn()
|
||||
variable_param = {'booster': ['gblinear'], 'updater': ['coord_descent'], 'eta': [0.5],
|
||||
'top_k': [10], 'tolerance': [1e-5], 'nthread': [2], 'alpha': [.005, .1], 'lambda': [0.005],
|
||||
'coordinate_selection': ['cyclic', 'random', 'greedy'], 'n_gpus': [-1]}
|
||||
variable_param = {
|
||||
'booster': ['gblinear'],
|
||||
'updater': ['coord_descent'],
|
||||
'eta': [0.5],
|
||||
'top_k': [10],
|
||||
'tolerance': [1e-5],
|
||||
'nthread': [2],
|
||||
'alpha': [.005, .1],
|
||||
'lambda': [0.005],
|
||||
'coordinate_selection': ['cyclic', 'random', 'greedy'],
|
||||
'n_gpus': [-1]
|
||||
}
|
||||
for param in test_linear.parameter_combinations(variable_param):
|
||||
results = test_linear.run_suite(param, 200, None, scale_features=True)
|
||||
results = test_linear.run_suite(
|
||||
param, 200, self.datasets, scale_features=True)
|
||||
test_linear.assert_regression_result(results, 1e-2)
|
||||
test_linear.assert_classification_result(results)
|
||||
|
||||
@ -11,11 +11,10 @@ from regression_test_utilities import run_suite, parameter_combinations, \
|
||||
def assert_gpu_results(cpu_results, gpu_results):
|
||||
for cpu_res, gpu_res in zip(cpu_results, gpu_results):
|
||||
# Check final eval result roughly equivalent
|
||||
assert np.allclose(cpu_res["eval"][-1], gpu_res["eval"][-1], 1e-3, 1e-2)
|
||||
|
||||
|
||||
datasets = ["Boston", "Cancer", "Digits", "Sparse regression"]
|
||||
assert np.allclose(cpu_res["eval"][-1], gpu_res["eval"][-1], 1e-2, 1e-2)
|
||||
|
||||
datasets = ["Boston", "Cancer", "Digits", "Sparse regression",
|
||||
"Sparse regression with weights"]
|
||||
|
||||
class TestGPU(unittest.TestCase):
|
||||
def test_gpu_exact(self):
|
||||
|
||||
@ -15,11 +15,16 @@ except ImportError:
|
||||
|
||||
|
||||
class Dataset:
|
||||
def __init__(self, name, get_dataset, objective, metric, use_external_memory=False):
|
||||
def __init__(self, name, get_dataset, objective, metric,
|
||||
has_weights=False, use_external_memory=False):
|
||||
self.name = name
|
||||
self.objective = objective
|
||||
self.metric = metric
|
||||
self.X, self.y = get_dataset()
|
||||
if has_weights:
|
||||
self.X, self.y, self.w = get_dataset()
|
||||
else:
|
||||
self.X, self.y = get_dataset()
|
||||
self.w = None
|
||||
self.use_external_memory = use_external_memory
|
||||
|
||||
|
||||
@ -49,6 +54,16 @@ def get_sparse():
|
||||
return X, y
|
||||
|
||||
|
||||
def get_sparse_weights():
|
||||
rng = np.random.RandomState(199)
|
||||
n = 10000
|
||||
sparsity = 0.25
|
||||
X, y = datasets.make_regression(n, random_state=rng)
|
||||
X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x for x in x_row] for x_row in X])
|
||||
w = np.array([rng.uniform(1, 10) for i in range(n)])
|
||||
return X, y, w
|
||||
|
||||
|
||||
def train_dataset(dataset, param_in, num_rounds=10, scale_features=False):
|
||||
param = param_in.copy()
|
||||
param["objective"] = dataset.objective
|
||||
@ -64,9 +79,10 @@ def train_dataset(dataset, param_in, num_rounds=10, scale_features=False):
|
||||
if dataset.use_external_memory:
|
||||
np.savetxt('tmptmp_1234.csv', np.hstack((dataset.y.reshape(len(dataset.y), 1), X)),
|
||||
delimiter=',')
|
||||
dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_')
|
||||
dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_',
|
||||
weight=dataset.w)
|
||||
else:
|
||||
dtrain = xgb.DMatrix(X, dataset.y)
|
||||
dtrain = xgb.DMatrix(X, dataset.y, weight=dataset.w)
|
||||
|
||||
print("Training on dataset: " + dataset.name, file=sys.stderr)
|
||||
print("Using parameters: " + str(param), file=sys.stderr)
|
||||
@ -112,6 +128,8 @@ def run_suite(param, num_rounds=10, select_datasets=None, scale_features=False):
|
||||
Dataset("Digits", get_digits, "multi:softmax", "merror"),
|
||||
Dataset("Cancer", get_cancer, "binary:logistic", "error"),
|
||||
Dataset("Sparse regression", get_sparse, "reg:linear", "rmse"),
|
||||
Dataset("Sparse regression with weights", get_sparse_weights,
|
||||
"reg:linear", "rmse", has_weights=True),
|
||||
Dataset("Boston External Memory", get_boston, "reg:linear", "rmse",
|
||||
use_external_memory=True)
|
||||
]
|
||||
|
||||
@ -52,6 +52,10 @@ def assert_classification_result(results):
|
||||
|
||||
|
||||
class TestLinear(unittest.TestCase):
|
||||
|
||||
datasets = ["Boston", "Digits", "Cancer", "Sparse regression",
|
||||
"Boston External Memory"]
|
||||
|
||||
def test_coordinate(self):
|
||||
tm._skip_if_no_sklearn()
|
||||
variable_param = {'booster': ['gblinear'], 'updater': ['coord_descent'], 'eta': [0.5],
|
||||
@ -60,7 +64,7 @@ class TestLinear(unittest.TestCase):
|
||||
'feature_selector': ['cyclic', 'shuffle', 'greedy', 'thrifty']
|
||||
}
|
||||
for param in parameter_combinations(variable_param):
|
||||
results = run_suite(param, 200, None, scale_features=True)
|
||||
results = run_suite(param, 200, self.datasets, scale_features=True)
|
||||
assert_regression_result(results, 1e-2)
|
||||
assert_classification_result(results)
|
||||
|
||||
@ -72,6 +76,6 @@ class TestLinear(unittest.TestCase):
|
||||
'feature_selector': ['cyclic', 'shuffle']
|
||||
}
|
||||
for param in parameter_combinations(variable_param):
|
||||
results = run_suite(param, 200, None, True)
|
||||
results = run_suite(param, 200, self.datasets, True)
|
||||
assert_regression_result(results, 1e-2)
|
||||
assert_classification_result(results)
|
||||
|
||||