From cc6a5a3666bced8c5090d9f9a25c1325b7d2cf57 Mon Sep 17 00:00:00 2001 From: Andy Adinets Date: Fri, 27 Jul 2018 04:03:16 +0200 Subject: [PATCH] Added finding quantiles on GPU. (#3393) * Added finding quantiles on GPU. - this includes datasets where weights are assigned to data rows - as the quantiles found by the new algorithm are not the same as those found by the old one, test thresholds in tests/python-gpu/test_gpu_updaters.py have been adjusted. * Adjustments and improved testing for finding quantiles on the GPU. - added C++ tests for the DeviceSketch() function - reduced one of the thresholds in test_gpu_updaters.py - adjusted the cuts found by the find_cuts_k kernel --- src/common/compressed_iterator.h | 2 +- src/common/device_helpers.cuh | 44 ++- src/common/hist_util.cc | 24 +- src/common/hist_util.cu | 398 ++++++++++++++++++++++ src/common/hist_util.h | 12 + src/common/quantile.h | 44 ++- src/tree/param.h | 7 + src/tree/updater_gpu_hist.cu | 143 ++++---- tests/cpp/common/test_gpu_hist_util.cu | 60 ++++ tests/cpp/tree/test_gpu_hist.cu | 10 +- tests/python-gpu/test_gpu_linear.py | 22 +- tests/python-gpu/test_gpu_updaters.py | 7 +- tests/python/regression_test_utilities.py | 26 +- tests/python/test_linear.py | 8 +- 14 files changed, 691 insertions(+), 116 deletions(-) create mode 100644 src/common/hist_util.cu create mode 100644 tests/cpp/common/test_gpu_hist_util.cu diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h index 4b2ee45b6..2d834aae6 100644 --- a/src/common/compressed_iterator.h +++ b/src/common/compressed_iterator.h @@ -111,7 +111,7 @@ class CompressedBufferWriter { symbol <<= 7 - ibit_end % 8; for (ptrdiff_t ibyte = ibyte_end; ibyte >= (ptrdiff_t)ibyte_start; --ibyte) { dh::AtomicOrByte(reinterpret_cast(buffer + detail::kPadding), - ibyte, symbol & 0xff); + ibyte, symbol & 0xff); symbol >>= 8; } } diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index aa109d9e4..de5155364 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -163,11 +163,41 @@ inline void CheckComputeCapability() { } } - + DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, unsigned char b) { atomicOr(&buffer[ibyte / sizeof(unsigned int)], (unsigned int)b << (ibyte % (sizeof(unsigned int)) * 8)); } +/*! + * \brief Find the strict upper bound for an element in a sorted array + * using binary search. 
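+ * Example: for cuts = {0.5, 1.5, 2.5} and v = 1.5 the result is 2,
+ * because cuts[2] = 2.5 is the first element strictly greater than v.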
+ * \param cuts pointer to the first element of the sorted array + * \param n length of the sorted array + * \param v value for which to find the upper bound + * \return the smallest index i such that v < cuts[i], or n if v is greater or equal + * than all elements of the array +*/ +DEV_INLINE int UpperBound(const float* __restrict__ cuts, int n, float v) { + if (n == 0) { + return 0; + } + if (cuts[n - 1] <= v) { + return n; + } + if (cuts[0] > v) { + return 0; + } + int left = 0, right = n - 1; + while (right - left > 1) { + int middle = left + (right - left) / 2; + if (cuts[middle] > v) { + right = middle; + } else { + left = middle; + } + } + return right; +} /* * Range iterator @@ -252,6 +282,18 @@ T1 DivRoundUp(const T1 a, const T2 b) { return static_cast(ceil(static_cast(a) / b)); } +inline void RowSegments(size_t n_rows, size_t n_devices, std::vector* segments) { + segments->push_back(0); + size_t row_begin = 0; + size_t shard_size = DivRoundUp(n_rows, n_devices); + for (size_t d_idx = 0; d_idx < n_devices; ++d_idx) { + size_t row_end = std::min(row_begin + shard_size, n_rows); + segments->push_back(row_end); + row_begin = row_end; + } +} + + template __global__ void LaunchNKernel(size_t begin, size_t end, L lambda) { for (auto i : GridStrideRange(begin, end)) { diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 7c733caee..afd06eca6 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -43,18 +43,28 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) { auto tid = static_cast(omp_get_thread_num()); unsigned begin = std::min(nstep * tid, ncol); unsigned end = std::min(nstep * (tid + 1), ncol); - for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*) - size_t ridx = batch.base_rowid + i; - SparsePage::Inst inst = batch[i]; - for (bst_uint j = 0; j < inst.length; ++j) { - if (inst[j].index >= begin && inst[j].index < end) { - sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx)); + // do not iterate if no columns are assigned to the thread + if (begin < end && end <= ncol) { + for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*) + size_t ridx = batch.base_rowid + i; + SparsePage::Inst inst = batch[i]; + for (bst_uint j = 0; j < inst.length; ++j) { + if (inst[j].index >= begin && inst[j].index < end) { + sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx)); + } } } } } } + Init(&sketchs, max_num_bins); +} + +void HistCutMatrix::Init +(std::vector* in_sketchs, uint32_t max_num_bins) { + std::vector& sketchs = *in_sketchs; + constexpr int kFactor = 8; // gather the histogram data rabit::SerializeReducer sreducer; std::vector summary_array; @@ -68,7 +78,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) { size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor); sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); - this->min_val.resize(info.num_col_); + this->min_val.resize(sketchs.size()); row_ptr.push_back(0); for (size_t fid = 0; fid < summary_array.size(); ++fid) { WXQSketch::SummaryContainer a; diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu new file mode 100644 index 000000000..bb4453f57 --- /dev/null +++ b/src/common/hist_util.cu @@ -0,0 +1,398 @@ +/*! 
+ * Copyright 2018 XGBoost contributors + */ + +#include "./hist_util.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../tree/param.h" +#include "./host_device_vector.h" +#include "./device_helpers.cuh" +#include "./quantile.h" + +namespace xgboost { +namespace common { + +using WXQSketch = HistCutMatrix::WXQSketch; + +__global__ void find_cuts_k +(WXQSketch::Entry* __restrict__ cuts, const bst_float* __restrict__ data, + const float* __restrict__ cum_weights, int nsamples, int ncuts) { + // ncuts < nsamples + int icut = threadIdx.x + blockIdx.x * blockDim.x; + if (icut >= ncuts) + return; + WXQSketch::Entry v; + int isample = 0; + if (icut == 0) { + isample = 0; + } else if (icut == ncuts - 1) { + isample = nsamples - 1; + } else { + bst_float rank = cum_weights[nsamples - 1] / static_cast(ncuts - 1) + * static_cast(icut); + // -1 is used because cum_weights is an inclusive sum + isample = dh::UpperBound(cum_weights, nsamples, rank); + isample = max(0, min(isample, nsamples - 1)); + } + // repeated values will be filtered out on the CPU + bst_float rmin = isample > 0 ? cum_weights[isample - 1] : 0; + bst_float rmax = cum_weights[isample]; + cuts[icut] = WXQSketch::Entry(rmin, rmax, rmax - rmin, data[isample]); +} + +// predictate for thrust filtering that returns true if the element is not a NaN +struct IsNotNaN { + __device__ bool operator()(float a) const { return !isnan(a); } +}; + +__global__ void unpack_features_k +(float* __restrict__ fvalues, float* __restrict__ feature_weights, + const size_t* __restrict__ row_ptrs, const float* __restrict__ weights, + Entry* entries, size_t nrows_array, int ncols, size_t row_begin_ptr, + size_t nrows) { + size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x; + if (irow >= nrows) { + return; + } + size_t row_length = row_ptrs[irow + 1] - row_ptrs[irow]; + int icol = threadIdx.y + blockIdx.y * blockDim.y; + if (icol >= row_length) { + return; + } + Entry entry = entries[row_ptrs[irow] - row_begin_ptr + icol]; + size_t ind = entry.index * nrows_array + irow; + // if weights are present, ensure that a non-NaN value is written to weights + // if and only if it is also written to features + if (!isnan(entry.fvalue) && (weights == nullptr || !isnan(weights[irow]))) { + fvalues[ind] = entry.fvalue; + if (feature_weights != nullptr) { + feature_weights[ind] = weights[irow]; + } + } +} + +// finds quantiles on the GPU +struct GPUSketcher { + // manage memory for a single GPU + struct DeviceShard { + int device_; + bst_uint row_begin_; // The row offset for this shard + bst_uint row_end_; + bst_uint n_rows_; + int num_cols_{0}; + size_t n_cuts_{0}; + size_t gpu_batch_nrows_{0}; + bool has_weights_{false}; + + tree::TrainParam param_; + std::vector sketches_; + thrust::device_vector row_ptrs_; + std::vector summaries_; + thrust::device_vector entries_; + thrust::device_vector fvalues_; + thrust::device_vector feature_weights_; + thrust::device_vector fvalues_cur_; + thrust::device_vector cuts_d_; + thrust::host_vector cuts_h_; + thrust::device_vector weights_; + thrust::device_vector weights2_; + std::vector n_cuts_cur_; + thrust::device_vector num_elements_; + thrust::device_vector tmp_storage_; + + DeviceShard(int device, bst_uint row_begin, bst_uint row_end, + tree::TrainParam param) : + device_(device), row_begin_(row_begin), row_end_(row_end), + n_rows_(row_end - row_begin), param_(std::move(param)) { + } + + void Init(const SparsePage& row_batch, const MetaInfo& info) { + num_cols_ = 
info.num_col_; + has_weights_ = info.weights_.size() > 0; + + // find the batch size + if (param_.gpu_batch_nrows == 0) { + // By default, use no more than 1/16th of GPU memory + gpu_batch_nrows_ = dh::TotalMemory(device_) / + (16 * num_cols_ * sizeof(Entry)); + } else if (param_.gpu_batch_nrows == -1) { + gpu_batch_nrows_ = n_rows_; + } else { + gpu_batch_nrows_ = param_.gpu_batch_nrows; + } + if (gpu_batch_nrows_ > n_rows_) { + gpu_batch_nrows_ = n_rows_; + } + + // initialize sketches + sketches_.resize(num_cols_); + summaries_.resize(num_cols_); + constexpr int kFactor = 8; + double eps = 1.0 / (kFactor * param_.max_bin); + size_t dummy_nlevel; + WXQSketch::LimitSizeLevel(row_batch.Size(), eps, &dummy_nlevel, &n_cuts_); + // double ncuts to be the same as the number of values + // in the temporary buffers of the sketches + n_cuts_ *= 2; + for (int icol = 0; icol < num_cols_; ++icol) { + sketches_[icol].Init(row_batch.Size(), eps); + summaries_[icol].Reserve(n_cuts_); + } + + // allocate necessary GPU buffers + dh::safe_cuda(cudaSetDevice(device_)); + + entries_.resize(gpu_batch_nrows_ * num_cols_); + fvalues_.resize(gpu_batch_nrows_ * num_cols_); + fvalues_cur_.resize(gpu_batch_nrows_); + cuts_d_.resize(n_cuts_ * num_cols_); + cuts_h_.resize(n_cuts_ * num_cols_); + weights_.resize(gpu_batch_nrows_); + weights2_.resize(gpu_batch_nrows_); + num_elements_.resize(1); + + if (has_weights_) { + feature_weights_.resize(gpu_batch_nrows_ * num_cols_); + } + n_cuts_cur_.resize(num_cols_); + + // allocate storage for CUB algorithms; the size is the maximum of the sizes + // required for various algorithm + size_t tmp_size = 0, cur_tmp_size = 0; + // size for sorting + if (has_weights_) { + cub::DeviceRadixSort::SortPairs + (nullptr, cur_tmp_size, fvalues_cur_.data().get(), + fvalues_.data().get(), weights_.data().get(), weights2_.data().get(), + gpu_batch_nrows_); + } else { + cub::DeviceRadixSort::SortKeys + (nullptr, cur_tmp_size, fvalues_cur_.data().get(), fvalues_.data().get(), + gpu_batch_nrows_); + } + tmp_size = std::max(tmp_size, cur_tmp_size); + // size for inclusive scan + if (has_weights_) { + cub::DeviceScan::InclusiveSum + (nullptr, cur_tmp_size, weights2_.begin(), weights_.begin(), gpu_batch_nrows_); + tmp_size = std::max(tmp_size, cur_tmp_size); + } + // size for reduction by key + cub::DeviceReduce::ReduceByKey + (nullptr, cur_tmp_size, fvalues_.begin(), + fvalues_cur_.begin(), weights_.begin(), weights2_.begin(), + num_elements_.begin(), thrust::maximum(), gpu_batch_nrows_); + tmp_size = std::max(tmp_size, cur_tmp_size); + // size for filtering + cub::DeviceSelect::If + (nullptr, cur_tmp_size, fvalues_.begin(), fvalues_cur_.begin(), + num_elements_.begin(), gpu_batch_nrows_, IsNotNaN()); + tmp_size = std::max(tmp_size, cur_tmp_size); + + tmp_storage_.resize(tmp_size); + } + + void FindColumnCuts(size_t batch_nrows, size_t icol) { + size_t tmp_size = tmp_storage_.size(); + // filter out NaNs in feature values + auto fvalues_begin = fvalues_.data() + icol * gpu_batch_nrows_; + cub::DeviceSelect::If + (tmp_storage_.data().get(), tmp_size, fvalues_begin, + fvalues_cur_.data(), num_elements_.begin(), batch_nrows, IsNotNaN()); + size_t nfvalues_cur = 0; + thrust::copy_n(num_elements_.begin(), 1, &nfvalues_cur); + + // compute cumulative weights using a prefix scan + if (has_weights_) { + // filter out NaNs in weights; + // since cub::DeviceSelect::If performs stable filtering, + // the weights are stored in the correct positions + auto feature_weights_begin = feature_weights_.data() + 
+ icol * gpu_batch_nrows_; + cub::DeviceSelect::If + (tmp_storage_.data().get(), tmp_size, feature_weights_begin, + weights_.data().get(), num_elements_.begin(), batch_nrows, IsNotNaN()); + + // sort the values and weights + cub::DeviceRadixSort::SortPairs + (tmp_storage_.data().get(), tmp_size, fvalues_cur_.data().get(), + fvalues_begin.get(), weights_.data().get(), weights2_.data().get(), + nfvalues_cur); + + // sum the weights to get cumulative weight values + cub::DeviceScan::InclusiveSum + (tmp_storage_.data().get(), tmp_size, weights2_.begin(), + weights_.begin(), nfvalues_cur); + } else { + // sort the batch values + cub::DeviceRadixSort::SortKeys + (tmp_storage_.data().get(), tmp_size, + fvalues_cur_.data().get(), fvalues_begin.get(), nfvalues_cur); + + // fill in cumulative weights with counting iterator + thrust::copy_n(thrust::make_counting_iterator(1), nfvalues_cur, + weights_.begin()); + } + + // remove repeated items and sum the weights across them; + // non-negative weights are assumed + cub::DeviceReduce::ReduceByKey + (tmp_storage_.data().get(), tmp_size, fvalues_begin, + fvalues_cur_.begin(), weights_.begin(), weights2_.begin(), + num_elements_.begin(), thrust::maximum(), nfvalues_cur); + size_t n_unique = 0; + thrust::copy_n(num_elements_.begin(), 1, &n_unique); + + // extract cuts + n_cuts_cur_[icol] = std::min(n_cuts_, n_unique); + // if less elements than cuts: copy all elements with their weights + if (n_cuts_ > n_unique) { + auto weights2_iter = weights2_.begin(); + auto fvalues_iter = fvalues_cur_.begin(); + auto cuts_iter = cuts_d_.begin() + icol * n_cuts_; + dh::LaunchN(device_, n_unique, [=]__device__(size_t i) { + bst_float rmax = weights2_iter[i]; + bst_float rmin = i > 0 ? weights2_iter[i - 1] : 0; + cuts_iter[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_iter[i]); + }); + } else if (n_cuts_cur_[icol] > 0) { + // if more elements than cuts: use binary search on cumulative weights + int block = 256; + find_cuts_k<<>> + (cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(), + weights2_.data().get(), n_unique, n_cuts_cur_[icol]); + dh::safe_cuda(cudaGetLastError()); + } + } + + void SketchBatch(const SparsePage& row_batch, const MetaInfo& info, + size_t gpu_batch) { + // compute start and end indices + size_t batch_row_begin = gpu_batch * gpu_batch_nrows_; + size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_, + static_cast(n_rows_)); + size_t batch_nrows = batch_row_end - batch_row_begin; + size_t n_entries = + row_batch.offset[row_begin_ + batch_row_end] - + row_batch.offset[row_begin_ + batch_row_begin]; + // copy the batch to the GPU + dh::safe_cuda + (cudaMemcpy(entries_.data().get(), + &row_batch.data[row_batch.offset[row_begin_ + batch_row_begin]], + n_entries * sizeof(Entry), cudaMemcpyDefault)); + // copy the weights if necessary + if (has_weights_) { + dh::safe_cuda + (cudaMemcpy(weights_.data().get(), + info.weights_.data() + row_begin_ + batch_row_begin, + batch_nrows * sizeof(bst_float), cudaMemcpyDefault)); + } + + // unpack the features; also unpack weights if present + thrust::fill(fvalues_.begin(), fvalues_.end(), NAN); + thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN); + + dim3 block3(64, 4, 1); + dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x), + dh::DivRoundUp(num_cols_, block3.y), 1); + unpack_features_k<<>> + (fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr, + row_ptrs_.data().get() + batch_row_begin, + has_weights_ ? 
weights_.data().get() : nullptr, entries_.data().get(), + gpu_batch_nrows_, num_cols_, + row_batch.offset[row_begin_ + batch_row_begin], batch_nrows); + dh::safe_cuda(cudaGetLastError()); + dh::safe_cuda(cudaDeviceSynchronize()); + + for (int icol = 0; icol < num_cols_; ++icol) { + FindColumnCuts(batch_nrows, icol); + } + + dh::safe_cuda(cudaDeviceSynchronize()); + + // add cuts into sketches + thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin()); + for (int icol = 0; icol < num_cols_; ++icol) { + summaries_[icol].MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]); + sketches_[icol].PushSummary(summaries_[icol]); + } + } + + void Sketch(const SparsePage& row_batch, const MetaInfo& info) { + // copy rows to the device + dh::safe_cuda(cudaSetDevice(device_)); + row_ptrs_.resize(n_rows_ + 1); + thrust::copy(row_batch.offset.data() + row_begin_, + row_batch.offset.data() + row_end_ + 1, + row_ptrs_.begin()); + + size_t gpu_nbatches = dh::DivRoundUp(n_rows_, gpu_batch_nrows_); + + for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) { + SketchBatch(row_batch, info, gpu_batch); + } + } + }; + + void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) { + // partition input matrix into row segments + std::vector row_segments; + dh::RowSegments(info.num_row_, devices_.Size(), &row_segments); + + // create device shards + shards_.resize(devices_.Size()); + dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) { + shard = std::unique_ptr + (new DeviceShard(devices_[i], row_segments[i], row_segments[i + 1], param_)); + }); + + // compute sketches for each shard + dh::ExecuteShards(&shards_, [&](std::unique_ptr& shard) { + shard->Init(batch, info); + shard->Sketch(batch, info); + }); + + // merge the sketches from all shards + // TODO(canonizer): do it in a tree-like reduction + int num_cols = info.num_col_; + std::vector sketches(num_cols); + WXQSketch::SummaryContainer summary; + for (int icol = 0; icol < num_cols; ++icol) { + sketches[icol].Init(batch.Size(), 1.0 / (8 * param_.max_bin)); + for (int shard = 0; shard < shards_.size(); ++shard) { + shards_[shard]->sketches_[icol].GetSummary(&summary); + sketches[icol].PushSummary(summary); + } + } + + hmat->Init(&sketches, param_.max_bin); + } + + GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) { + devices_ = GPUSet::Range(param_.gpu_id, dh::NDevices(param_.n_gpus, n_rows)); + } + + std::vector> shards_; + tree::TrainParam param_; + GPUSet devices_; +}; + +void DeviceSketch + (const SparsePage& batch, const MetaInfo& info, + const tree::TrainParam& param, HistCutMatrix* hmat) { + GPUSketcher sketcher(param, info.num_row_); + sketcher.Sketch(batch, info, hmat); +} + +} // namespace common +} // namespace xgboost diff --git a/src/common/hist_util.h b/src/common/hist_util.h index a416d87fa..bc5eaeb58 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -12,8 +12,11 @@ #include #include "row_set.h" #include "../tree/fast_hist_param.h" +#include "../tree/param.h" +#include "./quantile.h" namespace xgboost { + namespace common { using tree::FastHistParam; @@ -77,11 +80,20 @@ struct HistCutMatrix { return {dmlc::BeginPtr(cut) + row_ptr[fid], row_ptr[fid + 1] - row_ptr[fid]}; } + + using WXQSketch = common::WXQuantileSketch; + // create histogram cut matrix given statistics from data // using approximate quantile sketch approach void Init(DMatrix* p_fmat, uint32_t max_num_bins); + + void Init(std::vector* sketchs, uint32_t max_num_bins); }; +/*! 
\brief Builds the cut matrix on the GPU */ +void DeviceSketch + (const SparsePage& batch, const MetaInfo& info, + const tree::TrainParam& param, HistCutMatrix* hmat); /*! * \brief A single row in global histogram index. diff --git a/src/common/quantile.h b/src/common/quantile.h index 9372581a9..9ad8aa253 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -35,9 +35,9 @@ struct WQSummary { /*! \brief the value of data */ DType value; // constructor - Entry() = default; + XGBOOST_DEVICE Entry() {} // NOLINT // constructor - Entry(RType rmin, RType rmax, RType wmin, DType value) + XGBOOST_DEVICE Entry(RType rmin, RType rmax, RType wmin, DType value) : rmin(rmin), rmax(rmax), wmin(wmin), value(value) {} /*! * \brief debug function, check Valid @@ -48,11 +48,11 @@ struct WQSummary { CHECK(rmax- rmin - wmin > -eps) << "relation constraint: min/max"; } /*! \return rmin estimation for v strictly bigger than value */ - inline RType RMinNext() const { + XGBOOST_DEVICE inline RType RMinNext() const { return rmin + wmin; } /*! \return rmax estimation for v strictly smaller than value */ - inline RType RMaxPrev() const { + XGBOOST_DEVICE inline RType RMaxPrev() const { return rmax - wmin; } }; @@ -158,6 +158,17 @@ struct WQSummary { size = src.size; std::memcpy(data, src.data, sizeof(Entry) * size); } + inline void MakeFromSorted(const Entry* entries, size_t n) { + size = 0; + for (size_t i = 0; i < n;) { + size_t j = i + 1; + // ignore repeated values + for (; j < n && entries[j].value == entries[i].value; ++j) {} + data[size++] = Entry(entries[i].rmin, entries[i].rmax, entries[i].wmin, + entries[i].value); + i = j; + } + } /*! * \brief debug function, validate whether the summary * run consistency check to check if it is a valid summary @@ -676,6 +687,18 @@ class QuantileSketchTemplate { * \param eps accuracy level of summary */ inline void Init(size_t maxn, double eps) { + LimitSizeLevel(maxn, eps, &nlevel, &limit_size); + // lazy reserve the space, if there is only one value, no need to allocate space + inqueue.queue.resize(1); + inqueue.qtail = 0; + data.clear(); + level.clear(); + } + + inline static void LimitSizeLevel + (size_t maxn, double eps, size_t* out_nlevel, size_t* out_limit_size) { + size_t& nlevel = *out_nlevel; + size_t& limit_size = *out_limit_size; nlevel = 1; while (true) { limit_size = static_cast(ceil(nlevel / eps)) + 1; @@ -687,12 +710,8 @@ class QuantileSketchTemplate { size_t n = (1ULL << nlevel); CHECK(n * limit_size >= maxn) << "invalid init parameter"; CHECK(nlevel <= limit_size * eps) << "invalid init parameter"; - // lazy reserve the space, if there is only one value, no need to allocate space - inqueue.queue.resize(1); - inqueue.qtail = 0; - data.clear(); - level.clear(); } + /*! * \brief add an element to a sketch * \param x The element added to the sketch @@ -714,6 +733,13 @@ class QuantileSketchTemplate { } inqueue.Push(x, w); } + + inline void PushSummary(const Summary& summary) { + temp.Reserve(limit_size * 2); + temp.SetPrune(summary, limit_size * 2); + PushTemp(); + } + /*! 
\brief push up temp */ inline void PushTemp() { temp.Reserve(limit_size * 2); diff --git a/src/tree/param.h b/src/tree/param.h index 5621c3e8d..43d653e9b 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -77,6 +77,8 @@ struct TrainParam : public dmlc::Parameter { int gpu_id; // number of GPUs to use int n_gpus; + // number of rows in a single GPU batch + int gpu_batch_nrows; // the criteria to use for ranking splits std::string split_evaluator; // declare the parameters @@ -186,6 +188,11 @@ struct TrainParam : public dmlc::Parameter { .set_lower_bound(-1) .set_default(1) .describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs"); + DMLC_DECLARE_FIELD(gpu_batch_nrows) + .set_lower_bound(-1) + .set_default(0) + .describe("Number of rows in a GPU batch, used for finding quantiles on GPU; " + "-1 to use all rows assignted to a GPU, and 0 to auto-deduce"); DMLC_DECLARE_FIELD(split_evaluator) .set_default("elastic_net,monotonic") .describe("The criteria to use for ranking splits"); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index e096a07c8..17b775937 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -1,7 +1,7 @@ /*! * Copyright 2017 XGBoost contributors */ -#include +#include #include #include #include @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -227,26 +228,6 @@ struct CalcWeightTrainParam { learning_rate(p.learning_rate) {} }; -// index of the first element in cuts greater than v, or n if none; -// cuts are ordered, and binary search is used -__device__ int upper_bound(const float* __restrict__ cuts, int n, float v) { - if (n == 0) - return 0; - if (cuts[n - 1] <= v) - return n; - if (cuts[0] > v) - return 0; - int left = 0, right = n - 1; - while (right - left > 1) { - int middle = left + (right - left) / 2; - if (cuts[middle] > v) - right = middle; - else - left = middle; - } - return right; -} - __global__ void compress_bin_ellpack_k (common::CompressedBufferWriter wr, common::CompressedByteT* __restrict__ buffer, const size_t* __restrict__ row_ptrs, @@ -266,7 +247,7 @@ __global__ void compress_bin_ellpack_k float fvalue = entry.fvalue; const float *feature_cuts = &cuts[cut_rows[feature]]; int ncuts = cut_rows[feature + 1] - cut_rows[feature]; - bin = upper_bound(feature_cuts, ncuts, fvalue); + bin = dh::UpperBound(feature_cuts, ncuts, fvalue); if (bin >= ncuts) bin = ncuts - 1; bin += cut_rows[feature]; @@ -330,6 +311,7 @@ struct DeviceShard { dh::DVec prediction_cache; std::vector node_sum_gradients; dh::DVec node_sum_gradients_d; + thrust::device_vector row_ptrs; common::CompressedIterator gidx; size_t row_stride; bst_uint row_begin_idx; // The row offset for this shard @@ -348,41 +330,51 @@ struct DeviceShard { dh::CubMemory temp_memory; + // TODO(canonizer): do add support multi-batch DMatrix here DeviceShard(int device_idx, int normalised_device_idx, - bst_uint row_begin, bst_uint row_end, int n_bins, TrainParam param) + bst_uint row_begin, bst_uint row_end, TrainParam param) : device_idx(device_idx), normalised_device_idx(normalised_device_idx), row_begin_idx(row_begin), row_end_idx(row_end), + row_stride(0), n_rows(row_end - row_begin), - n_bins(n_bins), - null_gidx_value(n_bins), + n_bins(0), + null_gidx_value(0), param(param), prediction_cache_initialised(false), can_use_smem_atomics(false) {} - void Init(const common::HistCutMatrix& hmat, const SparsePage& row_batch) { - // copy cuts to the GPU + void InitRowPtrs(const SparsePage& row_batch) { 
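+    // Copy this shard's row offsets to the device and compute row_stride,
+    // the maximum number of entries in any row of the shard; row_stride is
+    // later used to size the compressed ELLPACK buffer in InitCompressedData().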
dh::safe_cuda(cudaSetDevice(device_idx)); - thrust::device_vector cuts_d(hmat.cut); - thrust::device_vector cut_row_ptrs_d(hmat.row_ptr); - - // find the maximum row size - thrust::device_vector row_ptr_d( - row_batch.offset.data() + row_begin_idx, row_batch.offset.data() + row_end_idx + 1); - - auto row_iter = row_ptr_d.begin(); + row_ptrs.resize(n_rows + 1); + thrust::copy(row_batch.offset.data() + row_begin_idx, + row_batch.offset.data() + row_end_idx + 1, + row_ptrs.begin()); + auto row_iter = row_ptrs.begin(); auto get_size = [=] __device__(size_t row) { return row_iter[row + 1] - row_iter[row]; }; // NOLINT + auto counting = thrust::make_counting_iterator(size_t(0)); using TransformT = thrust::transform_iterator; TransformT row_size_iter = TransformT(counting, get_size); row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0, - thrust::maximum()); - int num_symbols = - n_bins + 1; + thrust::maximum()); + } + + void InitCompressedData(const common::HistCutMatrix& hmat, const SparsePage& row_batch) { + n_bins = hmat.row_ptr.back(); + null_gidx_value = hmat.row_ptr.back(); + + // copy cuts to the GPU + dh::safe_cuda(cudaSetDevice(device_idx)); + thrust::device_vector cuts_d(hmat.cut); + thrust::device_vector cut_row_ptrs_d(hmat.row_ptr); + + // allocate compressed bin data + int num_symbols = n_bins + 1; size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows, num_symbols); @@ -391,17 +383,17 @@ struct DeviceShard { << "Max leaves and max depth cannot both be unconstrained for " "gpu_hist."; ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes); - gidx_buffer.Fill(0); + int nbits = common::detail::SymbolBits(num_symbols); + // bin and compress entries in batches of rows - // use no more than 1/16th of GPU memory per batch - size_t gpu_batch_nrows = dh::TotalMemory(device_idx) / - (16 * row_stride * sizeof(Entry)); - if (gpu_batch_nrows > n_rows) { - gpu_batch_nrows = n_rows; - } + size_t gpu_batch_nrows = std::min + (dh::TotalMemory(device_idx) / (16 * row_stride * sizeof(Entry)), + static_cast(n_rows)); + thrust::device_vector entries_d(gpu_batch_nrows * row_stride); + size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows); for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) { size_t batch_row_begin = gpu_batch * gpu_batch_nrows; @@ -423,7 +415,7 @@ struct DeviceShard { dh::DivRoundUp(row_stride, block3.y), 1); compress_bin_ellpack_k<<>> (common::CompressedBufferWriter(num_symbols), gidx_buffer.Data(), - row_ptr_d.data().get() + batch_row_begin, + row_ptrs.data().get() + batch_row_begin, entries_d.data().get(), cuts_d.data().get(), cut_row_ptrs_d.data().get(), batch_row_begin, batch_nrows, row_batch.offset[row_begin_idx + batch_row_begin], @@ -434,8 +426,8 @@ struct DeviceShard { } // free the memory that is no longer needed - row_ptr_d.resize(0); - row_ptr_d.shrink_to_fit(); + row_ptrs.resize(0); + row_ptrs.shrink_to_fit(); entries_d.resize(0); entries_d.shrink_to_fit(); @@ -741,17 +733,9 @@ class GPUHistMaker : public TreeUpdater { void InitDataOnce(DMatrix* dmat) { info_ = &dmat->Info(); - monitor_.Start("Quantiles", device_list_); - hmat_.Init(dmat, param_.max_bin); - monitor_.Stop("Quantiles", device_list_); - n_bins_ = hmat_.row_ptr.back(); int n_devices = dh::NDevices(param_.n_gpus, info_->num_row_); - bst_uint row_begin = 0; - bst_uint shard_size = - std::ceil(static_cast(info_->num_row_) / n_devices); - device_list_.resize(n_devices); for (int d_idx = 0; d_idx < n_devices; 
++d_idx) { int device_idx = (param_.gpu_id + d_idx) % dh::NVisibleDevices(); @@ -762,32 +746,34 @@ class GPUHistMaker : public TreeUpdater { // Partition input matrix into row segments std::vector row_segments; + dh::RowSegments(info_->num_row_, n_devices, &row_segments); + + dmlc::DataIter* iter = dmat->RowIterator(); + iter->BeforeFirst(); + CHECK(iter->Next()) << "Empty batches are not supported"; + const SparsePage& batch = iter->Value(); + // Create device shards shards_.resize(n_devices); - row_segments.push_back(0); - for (int d_idx = 0; d_idx < n_devices; ++d_idx) { - bst_uint row_end = - std::min(static_cast(row_begin + shard_size), info_->num_row_); - row_segments.push_back(row_end); - row_begin = row_end; - } + dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) { + shard = std::unique_ptr + (new DeviceShard(device_list_[i], i, + row_segments[i], row_segments[i + 1], param_)); + shard->InitRowPtrs(batch); + }); + + monitor_.Start("Quantiles", device_list_); + common::DeviceSketch(batch, *info_, param_, &hmat_); + n_bins_ = hmat_.row_ptr.back(); + monitor_.Stop("Quantiles", device_list_); monitor_.Start("BinningCompression", device_list_); - { - dmlc::DataIter* iter = dmat->RowIterator(); - iter->BeforeFirst(); - CHECK(iter->Next()) << "Empty batches are not supported"; - const SparsePage& batch = iter->Value(); - // Create device shards - dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) { - shard = std::unique_ptr - (new DeviceShard(device_list_[i], i, - row_segments[i], row_segments[i + 1], n_bins_, param_)); - shard->Init(hmat_, batch); - }); - CHECK(!iter->Next()) << "External memory not supported"; - } + dh::ExecuteShards(&shards_, [&](std::unique_ptr& shard) { + shard->InitCompressedData(hmat_, batch); + }); monitor_.Stop("BinningCompression", device_list_); + CHECK(!iter->Next()) << "External memory not supported"; + p_last_fmat_ = dmat; initialised_ = true; } @@ -1017,9 +1003,6 @@ class GPUHistMaker : public TreeUpdater { void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree) { - // Temporarily store number of threads so we can change it back later - int nthread = omp_get_max_threads(); - auto& tree = *p_tree; monitor_.Start("InitData", device_list_); diff --git a/tests/cpp/common/test_gpu_hist_util.cu b/tests/cpp/common/test_gpu_hist_util.cu new file mode 100644 index 000000000..8d0117e72 --- /dev/null +++ b/tests/cpp/common/test_gpu_hist_util.cu @@ -0,0 +1,60 @@ +#include "../../../src/common/device_helpers.cuh" +#include "../../../src/common/hist_util.h" +#include "gtest/gtest.h" +#include "xgboost/c_api.h" +#include +#include +#include +#include + +namespace xgboost { +namespace common { + +TEST(gpu_hist_util, TestDeviceSketch) { + // create the data + int nrows = 10001; + std::vector test_data(nrows); + auto count_iter = thrust::make_counting_iterator(0); + // fill in reverse order + std::copy(count_iter, count_iter + nrows, test_data.rbegin()); + + // create the DMatrix + DMatrixHandle dmat_handle; + XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1, + &dmat_handle); + auto dmat = *static_cast *>(dmat_handle); + + // parameters for finding quantiles + tree::TrainParam p; + p.max_bin = 20; + p.gpu_id = 0; + p.n_gpus = 1; + // ensure that the exact quantiles are found + p.gpu_batch_nrows = nrows * 10; + + // find quantiles on the CPU + HistCutMatrix hmat_cpu; + hmat_cpu.Init(dmat.get(), p.max_bin); + + // find the cuts on the GPU + dmlc::DataIter* iter = dmat->RowIterator(); + iter->BeforeFirst(); + 
CHECK(iter->Next()); + const SparsePage& batch = iter->Value(); + HistCutMatrix hmat_gpu; + DeviceSketch(batch, dmat->Info(), p, &hmat_gpu); + CHECK(!iter->Next()); + + // compare the cuts + double eps = 1e-2; + ASSERT_EQ(hmat_gpu.min_val.size(), 1); + ASSERT_EQ(hmat_gpu.row_ptr.size(), 2); + ASSERT_EQ(hmat_gpu.cut.size(), hmat_cpu.cut.size()); + ASSERT_LT(fabs(hmat_cpu.min_val[0] - hmat_gpu.min_val[0]), eps * nrows); + for (int i = 0; i < hmat_gpu.cut.size(); ++i) { + ASSERT_LT(fabs(hmat_cpu.cut[i] - hmat_gpu.cut[i]), eps * nrows); + } +} + +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 4a33ad02b..2c2022d10 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -30,8 +30,9 @@ TEST(gpu_hist_experimental, TestSparseShard) { iter->BeforeFirst(); CHECK(iter->Next()); const SparsePage& batch = iter->Value(); - DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p); - shard.Init(hmat, batch); + DeviceShard shard(0, 0, 0, rows, p); + shard.InitRowPtrs(batch); + shard.InitCompressedData(hmat, batch); CHECK(!iter->Next()); ASSERT_LT(shard.row_stride, columns); @@ -72,8 +73,9 @@ TEST(gpu_hist_experimental, TestDenseShard) { CHECK(iter->Next()); const SparsePage& batch = iter->Value(); - DeviceShard shard(0, 0, 0, rows, hmat.row_ptr.back(), p); - shard.Init(hmat, batch); + DeviceShard shard(0, 0, 0, rows, p); + shard.InitRowPtrs(batch); + shard.InitCompressedData(hmat, batch); CHECK(!iter->Next()); ASSERT_EQ(shard.row_stride, columns); diff --git a/tests/python-gpu/test_gpu_linear.py b/tests/python-gpu/test_gpu_linear.py index a56230463..25b042a37 100644 --- a/tests/python-gpu/test_gpu_linear.py +++ b/tests/python-gpu/test_gpu_linear.py @@ -7,12 +7,26 @@ import unittest class TestGPULinear(unittest.TestCase): + + datasets = ["Boston", "Digits", "Cancer", "Sparse regression", + "Boston External Memory"] + def test_gpu_coordinate(self): tm._skip_if_no_sklearn() - variable_param = {'booster': ['gblinear'], 'updater': ['coord_descent'], 'eta': [0.5], - 'top_k': [10], 'tolerance': [1e-5], 'nthread': [2], 'alpha': [.005, .1], 'lambda': [0.005], - 'coordinate_selection': ['cyclic', 'random', 'greedy'], 'n_gpus': [-1]} + variable_param = { + 'booster': ['gblinear'], + 'updater': ['coord_descent'], + 'eta': [0.5], + 'top_k': [10], + 'tolerance': [1e-5], + 'nthread': [2], + 'alpha': [.005, .1], + 'lambda': [0.005], + 'coordinate_selection': ['cyclic', 'random', 'greedy'], + 'n_gpus': [-1] + } for param in test_linear.parameter_combinations(variable_param): - results = test_linear.run_suite(param, 200, None, scale_features=True) + results = test_linear.run_suite( + param, 200, self.datasets, scale_features=True) test_linear.assert_regression_result(results, 1e-2) test_linear.assert_classification_result(results) diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index df0fa9958..2fdbfaee4 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -11,11 +11,10 @@ from regression_test_utilities import run_suite, parameter_combinations, \ def assert_gpu_results(cpu_results, gpu_results): for cpu_res, gpu_res in zip(cpu_results, gpu_results): # Check final eval result roughly equivalent - assert np.allclose(cpu_res["eval"][-1], gpu_res["eval"][-1], 1e-3, 1e-2) - - -datasets = ["Boston", "Cancer", "Digits", "Sparse regression"] + assert np.allclose(cpu_res["eval"][-1], gpu_res["eval"][-1], 1e-2, 1e-2) +datasets = 
["Boston", "Cancer", "Digits", "Sparse regression", + "Sparse regression with weights"] class TestGPU(unittest.TestCase): def test_gpu_exact(self): diff --git a/tests/python/regression_test_utilities.py b/tests/python/regression_test_utilities.py index bd5192b55..6918ed6a7 100644 --- a/tests/python/regression_test_utilities.py +++ b/tests/python/regression_test_utilities.py @@ -15,11 +15,16 @@ except ImportError: class Dataset: - def __init__(self, name, get_dataset, objective, metric, use_external_memory=False): + def __init__(self, name, get_dataset, objective, metric, + has_weights=False, use_external_memory=False): self.name = name self.objective = objective self.metric = metric - self.X, self.y = get_dataset() + if has_weights: + self.X, self.y, self.w = get_dataset() + else: + self.X, self.y = get_dataset() + self.w = None self.use_external_memory = use_external_memory @@ -49,6 +54,16 @@ def get_sparse(): return X, y +def get_sparse_weights(): + rng = np.random.RandomState(199) + n = 10000 + sparsity = 0.25 + X, y = datasets.make_regression(n, random_state=rng) + X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x for x in x_row] for x_row in X]) + w = np.array([rng.uniform(1, 10) for i in range(n)]) + return X, y, w + + def train_dataset(dataset, param_in, num_rounds=10, scale_features=False): param = param_in.copy() param["objective"] = dataset.objective @@ -64,9 +79,10 @@ def train_dataset(dataset, param_in, num_rounds=10, scale_features=False): if dataset.use_external_memory: np.savetxt('tmptmp_1234.csv', np.hstack((dataset.y.reshape(len(dataset.y), 1), X)), delimiter=',') - dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_') + dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_', + weight=dataset.w) else: - dtrain = xgb.DMatrix(X, dataset.y) + dtrain = xgb.DMatrix(X, dataset.y, weight=dataset.w) print("Training on dataset: " + dataset.name, file=sys.stderr) print("Using parameters: " + str(param), file=sys.stderr) @@ -112,6 +128,8 @@ def run_suite(param, num_rounds=10, select_datasets=None, scale_features=False): Dataset("Digits", get_digits, "multi:softmax", "merror"), Dataset("Cancer", get_cancer, "binary:logistic", "error"), Dataset("Sparse regression", get_sparse, "reg:linear", "rmse"), + Dataset("Sparse regression with weights", get_sparse_weights, + "reg:linear", "rmse", has_weights=True), Dataset("Boston External Memory", get_boston, "reg:linear", "rmse", use_external_memory=True) ] diff --git a/tests/python/test_linear.py b/tests/python/test_linear.py index a20e32724..5cfd42687 100644 --- a/tests/python/test_linear.py +++ b/tests/python/test_linear.py @@ -52,6 +52,10 @@ def assert_classification_result(results): class TestLinear(unittest.TestCase): + + datasets = ["Boston", "Digits", "Cancer", "Sparse regression", + "Boston External Memory"] + def test_coordinate(self): tm._skip_if_no_sklearn() variable_param = {'booster': ['gblinear'], 'updater': ['coord_descent'], 'eta': [0.5], @@ -60,7 +64,7 @@ class TestLinear(unittest.TestCase): 'feature_selector': ['cyclic', 'shuffle', 'greedy', 'thrifty'] } for param in parameter_combinations(variable_param): - results = run_suite(param, 200, None, scale_features=True) + results = run_suite(param, 200, self.datasets, scale_features=True) assert_regression_result(results, 1e-2) assert_classification_result(results) @@ -72,6 +76,6 @@ class TestLinear(unittest.TestCase): 'feature_selector': ['cyclic', 'shuffle'] } for param in parameter_combinations(variable_param): - results = 
run_suite(param, 200, None, True) + results = run_suite(param, 200, self.datasets, True) assert_regression_result(results, 1e-2) assert_classification_result(results)
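Usage sketch (illustrative, not part of the diff): the GPU quantile path and the
gpu_batch_nrows parameter introduced above could be exercised from Python roughly
as follows. The 'gpu_hist' tree method spelling is an assumption based on the
existing GPU tests; gpu_batch_nrows itself is the parameter added in src/tree/param.h.

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(0)
    X = rng.randn(1000, 10)
    y = rng.randn(1000)
    w = rng.uniform(1, 10, size=1000)   # row weights are now handled by the GPU sketcher

    dtrain = xgb.DMatrix(X, label=y, weight=w)
    params = {
        'tree_method': 'gpu_hist',      # assumption: routes to the GPUHistMaker updater
        'max_bin': 16,
        'gpu_batch_nrows': 0,           # 0 = auto-deduce batch size, -1 = use all rows assigned to a GPU
    }
    bst = xgb.train(params, dtrain, num_boost_round=10)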