Sketching from adapters (#5365)

* Sketching from adapters

* Add weights test
This commit is contained in:
Rory Mitchell 2020-03-07 21:07:58 +13:00 committed by GitHub
parent 0dd97c206b
commit a38e7bd19c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 780 additions and 624 deletions

View File

@ -9,85 +9,29 @@
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include <utility>
#include <vector>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "hist_util.h"
#include "xgboost/host_device_vector.h"
#include "../data/adapter.h"
#include "../data/device_adapter.cuh"
#include "device_helpers.cuh"
#include "hist_util.h"
#include "math.h" // NOLINT
#include "quantile.h"
#include "../tree/param.h"
#include "xgboost/host_device_vector.h"
namespace xgboost {
namespace common {
using WQSketch = DenseCuts::WQSketch;
__global__ void FindCutsK(WQSketch::Entry* __restrict__ cuts,
const bst_float* __restrict__ data,
const float* __restrict__ cum_weights,
int nsamples,
int ncuts) {
// ncuts < nsamples
int icut = threadIdx.x + blockIdx.x * blockDim.x;
if (icut >= ncuts) {
return;
}
int isample = 0;
if (icut == 0) {
isample = 0;
} else if (icut == ncuts - 1) {
isample = nsamples - 1;
} else {
bst_float rank = cum_weights[nsamples - 1] / static_cast<float>(ncuts - 1)
* static_cast<float>(icut);
// -1 is used because cum_weights is an inclusive sum
isample = dh::UpperBound(cum_weights, nsamples, rank);
isample = max(0, min(isample, nsamples - 1));
}
// repeated values will be filtered out on the CPU
bst_float rmin = isample > 0 ? cum_weights[isample - 1] : 0;
bst_float rmax = cum_weights[isample];
cuts[icut] = WQSketch::Entry(rmin, rmax, rmax - rmin, data[isample]);
}
// predictate for thrust filtering that returns true if the element is not a NaN
struct IsNotNaN {
__device__ bool operator()(float a) const { return !isnan(a); }
};
__global__ void UnpackFeaturesK(float* __restrict__ fvalues,
float* __restrict__ feature_weights,
const size_t* __restrict__ row_ptrs,
const float* __restrict__ weights,
Entry* entries,
size_t nrows_array,
size_t row_begin_ptr,
size_t nrows) {
size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x;
if (irow >= nrows) {
return;
}
size_t row_length = row_ptrs[irow + 1] - row_ptrs[irow];
int icol = threadIdx.y + blockIdx.y * blockDim.y;
if (icol >= row_length) {
return;
}
Entry entry = entries[row_ptrs[irow] - row_begin_ptr + icol];
size_t ind = entry.index * nrows_array + irow;
// if weights are present, ensure that a non-NaN value is written to weights
// if and only if it is also written to features
if (!isnan(entry.fvalue) && (weights == nullptr || !isnan(weights[irow]))) {
fvalues[ind] = entry.fvalue;
if (feature_weights != nullptr && weights != nullptr) {
feature_weights[ind] = weights[irow];
}
}
}
using SketchEntry = WQSketch::Entry;
/*!
* \brief A container that holds the device sketches across all
@ -98,379 +42,410 @@ __global__ void UnpackFeaturesK(float* __restrict__ fvalues,
*/
struct SketchContainer {
std::vector<DenseCuts::WQSketch> sketches_; // NOLINT
std::vector<std::mutex> col_locks_; // NOLINT
static constexpr int kOmpNumColsParallelizeLimit = 1000;
SketchContainer(int max_bin, DMatrix* dmat) : col_locks_(dmat->Info().num_col_) {
const MetaInfo& info = dmat->Info();
SketchContainer(int max_bin, size_t num_columns, size_t num_rows) {
// Initialize Sketches for this dmatrix
sketches_.resize(info.num_col_);
#pragma omp parallel for default(none) shared(info, max_bin) schedule(static) \
if (info.num_col_ > kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < info.num_col_; ++icol) { // NOLINT
sketches_[icol].Init(info.num_row_, 1.0 / (8 * max_bin));
sketches_.resize(num_columns);
#pragma omp parallel for schedule(static) if (num_columns > kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < num_columns; ++icol) { // NOLINT
sketches_[icol].Init(num_rows, 1.0 / (8 * max_bin));
}
}
// Prevent copying/assigning/moving this as its internals can't be assigned/copied/moved
SketchContainer(const SketchContainer &) = delete;
SketchContainer(const SketchContainer &&) = delete;
SketchContainer &operator=(const SketchContainer &) = delete;
SketchContainer &operator=(const SketchContainer &&) = delete;
/**
* \brief Pushes cuts to the sketches.
*
* \param entries_per_column The entries per column.
* \param entries Vector of cuts from all columns, length
* entries_per_column * num_columns. \param column_scan Exclusive scan
* of column sizes. Used to detect cases where there are fewer entries than we
* have storage for.
*/
void Push(size_t entries_per_column,
const thrust::host_vector<SketchEntry>& entries,
const thrust::host_vector<size_t>& column_scan) {
#pragma omp parallel for schedule(static) if (sketches_.size() > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < sketches_.size(); ++icol) {
size_t column_size = column_scan[icol + 1] - column_scan[icol];
if (column_size == 0) continue;
WQuantileSketch<bst_float, bst_float>::SummaryContainer summary;
size_t num_available_cuts =
std::min(size_t(entries_per_column), column_size);
summary.Reserve(num_available_cuts);
summary.MakeFromSorted(&entries[entries_per_column * icol],
num_available_cuts);
sketches_[icol].PushSummary(summary);
}
}
// Prevent copying/assigning/moving this as its internals can't be
// assigned/copied/moved
SketchContainer(const SketchContainer&) = delete;
SketchContainer(const SketchContainer&&) = delete;
SketchContainer& operator=(const SketchContainer&) = delete;
SketchContainer& operator=(const SketchContainer&&) = delete;
};
// finds quantiles on the GPU
class GPUSketcher {
public:
GPUSketcher(int device, int max_bin, int gpu_nrows)
: device_(device), max_bin_(max_bin), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {}
~GPUSketcher() { // NOLINT
dh::safe_cuda(cudaSetDevice(device_));
struct EntryCompareOp {
__device__ bool operator()(const Entry& a, const Entry& b) {
if (a.index == b.index) {
return a.fvalue < b.fvalue;
}
return a.index < b.index;
}
void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
n_rows_ = batch.Size();
Init(batch, info, gpu_batch_nrows_);
Sketch(batch, info);
ComputeRowStride();
}
/* Builds the sketches on the GPU for the dmatrix and returns the row stride
* for the entire dataset */
size_t Sketch(DMatrix *dmat, DenseCuts *hmat) {
const MetaInfo& info = dmat->Info();
row_stride_ = 0;
sketch_container_.reset(new SketchContainer(max_bin_, dmat));
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
this->SketchBatch(batch, info);
}
hmat->Init(&sketch_container_->sketches_, max_bin_, info.num_row_);
return row_stride_;
}
// This needs to be public because of the __device__ lambda.
void ComputeRowStride() {
// Find the row stride for this batch
auto row_iter = row_ptrs_.begin();
// Functor for finding the maximum row size for this batch
auto get_size = [=] __device__(size_t row) {
return row_iter[row + 1] - row_iter[row];
}; // NOLINT
auto counting = thrust::make_counting_iterator(size_t(0));
using TransformT = thrust::transform_iterator<decltype(get_size), decltype(counting), size_t>;
TransformT row_size_iter = TransformT(counting, get_size);
size_t batch_row_stride =
thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0, thrust::maximum<size_t>());
row_stride_ = std::max(row_stride_, batch_row_stride);
}
// This needs to be public because of the __device__ lambda.
void FindColumnCuts(size_t batch_nrows, size_t icol) {
size_t tmp_size = tmp_storage_.size();
// filter out NaNs in feature values
auto fvalues_begin = fvalues_.data() + icol * gpu_batch_nrows_;
cub::DeviceSelect::If(tmp_storage_.data().get(),
tmp_size,
fvalues_begin,
fvalues_cur_.data(),
num_elements_.begin(),
batch_nrows,
IsNotNaN());
size_t nfvalues_cur = 0;
thrust::copy_n(num_elements_.begin(), 1, &nfvalues_cur);
// compute cumulative weights using a prefix scan
if (has_weights_) {
// filter out NaNs in weights;
// since cub::DeviceSelect::If performs stable filtering,
// the weights are stored in the correct positions
auto feature_weights_begin = feature_weights_.data() + icol * gpu_batch_nrows_;
cub::DeviceSelect::If(tmp_storage_.data().get(),
tmp_size,
feature_weights_begin,
weights_.data().get(),
num_elements_.begin(),
batch_nrows,
IsNotNaN());
// sort the values and weights
cub::DeviceRadixSort::SortPairs(tmp_storage_.data().get(),
tmp_size,
fvalues_cur_.data().get(),
fvalues_begin.get(),
weights_.data().get(),
weights2_.data().get(),
nfvalues_cur);
// sum the weights to get cumulative weight values
cub::DeviceScan::InclusiveSum(tmp_storage_.data().get(),
tmp_size,
weights2_.begin(),
weights_.begin(),
nfvalues_cur);
} else {
// sort the batch values
cub::DeviceRadixSort::SortKeys(tmp_storage_.data().get(),
tmp_size,
fvalues_cur_.data().get(),
fvalues_begin.get(),
nfvalues_cur);
// fill in cumulative weights with counting iterator
thrust::copy_n(thrust::make_counting_iterator(1), nfvalues_cur, weights_.begin());
}
// remove repeated items and sum the weights across them;
// non-negative weights are assumed
cub::DeviceReduce::ReduceByKey(tmp_storage_.data().get(),
tmp_size,
fvalues_begin,
fvalues_cur_.begin(),
weights_.begin(),
weights2_.begin(),
num_elements_.begin(),
thrust::maximum<bst_float>(),
nfvalues_cur);
size_t n_unique = 0;
thrust::copy_n(num_elements_.begin(), 1, &n_unique);
// extract cuts
n_cuts_cur_[icol] = std::min(n_cuts_, n_unique);
// if less elements than cuts: copy all elements with their weights
if (n_cuts_ > n_unique) {
float* weights2_ptr = weights2_.data().get();
float* fvalues_ptr = fvalues_cur_.data().get();
WQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_;
dh::LaunchN(device_, n_unique, [=]__device__(size_t i) {
bst_float rmax = weights2_ptr[i];
bst_float rmin = i > 0 ? weights2_ptr[i - 1] : 0;
cuts_ptr[i] = WQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]);
});
} else if (n_cuts_cur_[icol] > 0) {
// if more elements than cuts: use binary search on cumulative weights
uint32_t constexpr kBlockThreads = 256;
uint32_t const kGrids = common::DivRoundUp(n_cuts_cur_[icol], kBlockThreads);
dh::LaunchKernel {kGrids, kBlockThreads} (
FindCutsK,
cuts_d_.data().get() + icol * n_cuts_,
fvalues_cur_.data().get(),
weights2_.data().get(),
n_unique,
n_cuts_cur_[icol]);
dh::safe_cuda(cudaGetLastError()); // NOLINT
}
}
private:
void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) {
num_cols_ = info.num_col_;
has_weights_ = info.weights_.Size() > 0;
// find the batch size
if (gpu_batch_nrows == 0) {
// By default, use no more than 1/16th of GPU memory
gpu_batch_nrows_ = dh::TotalMemory(device_) / (16 * num_cols_ * sizeof(Entry));
} else if (gpu_batch_nrows == -1) {
gpu_batch_nrows_ = n_rows_;
} else {
gpu_batch_nrows_ = gpu_batch_nrows;
}
if (gpu_batch_nrows_ > n_rows_) {
gpu_batch_nrows_ = n_rows_;
}
constexpr int kFactor = 8;
double eps = 1.0 / (kFactor * max_bin_);
size_t dummy_nlevel;
WQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);
// allocate necessary GPU buffers
dh::safe_cuda(cudaSetDevice(device_));
entries_.resize(gpu_batch_nrows_ * num_cols_);
fvalues_.resize(gpu_batch_nrows_ * num_cols_);
fvalues_cur_.resize(gpu_batch_nrows_);
cuts_d_.resize(n_cuts_ * num_cols_);
cuts_h_.resize(n_cuts_ * num_cols_);
weights_.resize(gpu_batch_nrows_);
weights2_.resize(gpu_batch_nrows_);
num_elements_.resize(1);
if (has_weights_) {
feature_weights_.resize(gpu_batch_nrows_ * num_cols_);
}
n_cuts_cur_.resize(num_cols_);
// allocate storage for CUB algorithms; the size is the maximum of the sizes
// required for various algorithm
size_t tmp_size = 0, cur_tmp_size = 0;
// size for sorting
if (has_weights_) {
cub::DeviceRadixSort::SortPairs(nullptr,
cur_tmp_size,
fvalues_cur_.data().get(),
fvalues_.data().get(),
weights_.data().get(),
weights2_.data().get(),
gpu_batch_nrows_);
} else {
cub::DeviceRadixSort::SortKeys(nullptr,
cur_tmp_size,
fvalues_cur_.data().get(),
fvalues_.data().get(),
gpu_batch_nrows_);
}
tmp_size = std::max(tmp_size, cur_tmp_size);
// size for inclusive scan
if (has_weights_) {
cub::DeviceScan::InclusiveSum(nullptr,
cur_tmp_size,
weights2_.begin(),
weights_.begin(),
gpu_batch_nrows_);
tmp_size = std::max(tmp_size, cur_tmp_size);
}
// size for reduction by key
cub::DeviceReduce::ReduceByKey(nullptr,
cur_tmp_size,
fvalues_.begin(),
fvalues_cur_.begin(),
weights_.begin(),
weights2_.begin(),
num_elements_.begin(),
thrust::maximum<bst_float>(),
gpu_batch_nrows_);
tmp_size = std::max(tmp_size, cur_tmp_size);
// size for filtering
cub::DeviceSelect::If(nullptr,
cur_tmp_size,
fvalues_.begin(),
fvalues_cur_.begin(),
num_elements_.begin(),
gpu_batch_nrows_,
IsNotNaN());
tmp_size = std::max(tmp_size, cur_tmp_size);
tmp_storage_.resize(tmp_size);
}
void Sketch(const SparsePage& row_batch, const MetaInfo& info) {
// copy rows to the device
dh::safe_cuda(cudaSetDevice(device_));
const auto& offset_vec = row_batch.offset.HostVector();
row_ptrs_.resize(n_rows_ + 1);
thrust::copy(offset_vec.data(), offset_vec.data() + n_rows_ + 1, row_ptrs_.begin());
size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
SketchBatch(row_batch, info, gpu_batch);
}
}
void SketchBatch(const SparsePage& row_batch, const MetaInfo& info, size_t gpu_batch) {
// compute start and end indices
size_t batch_row_begin = gpu_batch * gpu_batch_nrows_;
size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_,
static_cast<size_t>(n_rows_));
size_t batch_nrows = batch_row_end - batch_row_begin;
const auto& offset_vec = row_batch.offset.HostVector();
const auto& data_vec = row_batch.data.HostVector();
size_t n_entries = offset_vec[batch_row_end] - offset_vec[batch_row_begin];
// copy the batch to the GPU
dh::safe_cuda(cudaMemcpyAsync(entries_.data().get(),
data_vec.data() + offset_vec[batch_row_begin],
n_entries * sizeof(Entry),
cudaMemcpyDefault));
// copy the weights if necessary
if (has_weights_) {
const auto& weights_vec = info.weights_.HostVector();
dh::safe_cuda(cudaMemcpyAsync(weights_.data().get(),
weights_vec.data() + batch_row_begin,
batch_nrows * sizeof(bst_float),
cudaMemcpyDefault));
}
// unpack the features; also unpack weights if present
thrust::fill(fvalues_.begin(), fvalues_.end(), NAN);
if (has_weights_) {
thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN);
}
dim3 block3(16, 64, 1);
// NOTE: This will typically support ~ 4M features - 64K*64
dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(num_cols_, block3.y), 1);
dh::LaunchKernel {grid3, block3} (
UnpackFeaturesK,
fvalues_.data().get(),
has_weights_ ? feature_weights_.data().get() : nullptr,
row_ptrs_.data().get() + batch_row_begin,
has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
gpu_batch_nrows_,
offset_vec[batch_row_begin],
batch_nrows);
for (int icol = 0; icol < num_cols_; ++icol) {
FindColumnCuts(batch_nrows, icol);
}
// add cuts into sketches
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
#pragma omp parallel for default(none) schedule(static) \
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < num_cols_; ++icol) {
WQSketch::SummaryContainer summary;
summary.Reserve(n_cuts_);
summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);
std::lock_guard<std::mutex> lock(sketch_container_->col_locks_[icol]);
sketch_container_->sketches_[icol].PushSummary(summary);
}
}
const int device_;
const int max_bin_;
int gpu_batch_nrows_;
size_t row_stride_;
std::unique_ptr<SketchContainer> sketch_container_;
bst_uint n_rows_{};
int num_cols_{0};
size_t n_cuts_{0};
bool has_weights_{false};
dh::device_vector<size_t> row_ptrs_{};
dh::device_vector<Entry> entries_{};
dh::device_vector<bst_float> fvalues_{};
dh::device_vector<bst_float> feature_weights_{};
dh::device_vector<bst_float> fvalues_cur_{};
dh::device_vector<WQSketch::Entry> cuts_d_{};
thrust::host_vector<WQSketch::Entry> cuts_h_{};
dh::device_vector<bst_float> weights_{};
dh::device_vector<bst_float> weights2_{};
std::vector<size_t> n_cuts_cur_{};
dh::device_vector<size_t> num_elements_{};
dh::device_vector<char> tmp_storage_{};
};
size_t DeviceSketch(int device,
int max_bin,
int gpu_batch_nrows,
DMatrix* dmat,
HistogramCuts* hmat) {
GPUSketcher sketcher(device, max_bin, gpu_batch_nrows);
// We only need to return the result in HistogramCuts container, so it is safe to
// use a pointer of local HistogramCutsDense
DenseCuts dense_cuts(hmat);
auto res = sketcher.Sketch(dmat, &dense_cuts);
return res;
// Count the entries in each column and exclusive scan
void GetColumnSizesScan(int device,
dh::caching_device_vector<size_t>* column_sizes_scan,
Span<const Entry> entries, size_t num_columns) {
column_sizes_scan->resize(num_columns + 1, 0);
auto d_column_sizes_scan = column_sizes_scan->data().get();
auto d_entries = entries.data();
dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) {
auto& e = d_entries[idx];
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&d_column_sizes_scan[e.index]),
static_cast<unsigned long long>(1)); // NOLINT
});
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
}
/**
* \brief Extracts the cuts from sorted data.
*
* \param device The device.
* \param cuts Output cuts
* \param num_cuts_per_feature Number of cuts per feature.
* \param sorted_data Sorted entries in segments of columns
* \param column_sizes_scan Describes the boundaries of column segments in
* sorted data
*/
void ExtractCuts(int device, Span<SketchEntry> cuts,
size_t num_cuts_per_feature, Span<Entry> sorted_data,
Span<size_t> column_sizes_scan) {
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature;
size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts =
min(size_t(num_cuts_per_feature), column_size);
size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return;
Span<Entry> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
size_t rank = (column_entries.size() * cut_idx) / num_available_cuts;
auto value = column_entries[rank].fvalue;
cuts[idx] = SketchEntry(rank, rank + 1, 1, value);
});
}
/**
* \brief Extracts the cuts from sorted data, considering weights.
*
* \param device The device.
* \param cuts Output cuts.
* \param num_cuts_per_feature Number of cuts per feature.
* \param sorted_data Sorted entries in segments of columns.
* \param weights_scan Inclusive scan of weights for each entry in sorted_data.
* \param column_sizes_scan Describes the boundaries of column segments in sorted data.
*/
void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
size_t num_cuts_per_feature, Span<Entry> sorted_data,
Span<float> weights_scan,
Span<size_t> column_sizes_scan) {
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature;
size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts =
min(size_t(num_cuts_per_feature), column_size);
size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return;
Span<Entry> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
Span<float> column_weights =
weights_scan.subspan(column_sizes_scan[column_idx], column_size);
float total_column_weight = column_weights.back();
size_t sample_idx = 0;
if (cut_idx == 0) {
// First cut
sample_idx = 0;
} else if (cut_idx == num_available_cuts - 1) {
// Last cut
sample_idx = column_entries.size() - 1;
} else if (num_available_cuts == column_size) {
// There are less samples available than our buffer
// Take every available sample
sample_idx = cut_idx;
} else {
bst_float rank = (total_column_weight * cut_idx) /
static_cast<float>(num_available_cuts);
sample_idx = thrust::upper_bound(thrust::seq, column_weights.begin(),
column_weights.end(), rank) -
column_weights.begin() - 1;
sample_idx =
max(size_t(0), min(sample_idx, column_entries.size() - 1));
}
// repeated values will be filtered out on the CPU
bst_float rmin = sample_idx > 0 ? column_weights[sample_idx - 1] : 0;
bst_float rmax = column_weights[sample_idx];
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
column_entries[sample_idx].fvalue);
});
}
void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts,
size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), EntryCompareOp());
dh::caching_device_vector<size_t> column_sizes_scan;
GetColumnSizesScan(device, &column_sizes_scan,
{sorted_entries.data().get(), sorted_entries.size()},
num_columns);
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
ExtractCuts(device, {cuts.data().get(), cuts.size()}, num_cuts,
{sorted_entries.data().get(), sorted_entries.size()},
{column_sizes_scan.data().get(), column_sizes_scan.size()});
// add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts);
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}
void ProcessWeightedBatch(int device, const SparsePage& page,
Span<const float> weights, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts,
size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
// Binary search to assign weights to each element
dh::device_vector<float> temp_weights(sorted_entries.size());
auto d_temp_weights = temp_weights.data().get();
page.offset.SetDevice(device);
auto row_ptrs = page.offset.ConstDeviceSpan();
size_t base_rowid = page.base_rowid;
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
d_temp_weights[idx] = weights[ridx + base_rowid];
});
// Sort
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), temp_weights.begin(),
EntryCompareOp());
// Scan weights
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
sorted_entries.begin(), sorted_entries.end(),
temp_weights.begin(), temp_weights.begin(),
[=] __device__(const Entry& a, const Entry& b) {
return a.index == b.index;
});
dh::caching_device_vector<size_t> column_sizes_scan;
GetColumnSizesScan(device, &column_sizes_scan,
{sorted_entries.data().get(), sorted_entries.size()},
num_columns);
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
// Extract cuts
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
ExtractWeightedCuts(
device, {cuts.data().get(), cuts.size()}, num_cuts,
{sorted_entries.data().get(), sorted_entries.size()},
{temp_weights.data().get(), temp_weights.size()},
{column_sizes_scan.data().get(), column_sizes_scan.size()});
// add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts);
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements) {
HistogramCuts cuts;
DenseCuts dense_cuts(&cuts);
SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
dmat->Info().num_row_);
constexpr int kFactor = 8;
double eps = 1.0 / (kFactor * max_bins);
size_t dummy_nlevel;
size_t num_cuts;
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
dmat->Info().num_row_, eps, &dummy_nlevel, &num_cuts);
num_cuts = std::min(num_cuts, dmat->Info().num_row_);
if (sketch_batch_num_elements == 0) {
sketch_batch_num_elements = dmat->Info().num_nonzero_;
}
dmat->Info().weights_.SetDevice(device);
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
size_t batch_nnz = batch.data.Size();
for (auto begin = 0ull; begin < batch_nnz;
begin += sketch_batch_num_elements) {
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
if (dmat->Info().weights_.Size() > 0) {
ProcessWeightedBatch(
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
&sketch_container, num_cuts, dmat->Info().num_col_);
} else {
ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts,
dmat->Info().num_col_);
}
}
}
dense_cuts.Init(&sketch_container.sketches_, max_bins, dmat->Info().num_row_);
return cuts;
}
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
explicit IsValidFunctor(float missing) : missing(missing) {}
float missing;
__device__ bool operator()(const data::COOTuple& e) const {
if (common::CheckNAN(e.value) || e.value == missing) {
return false;
}
return true;
}
__device__ bool operator()(const Entry& e) const {
if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
return false;
}
return true;
}
};
// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}
template <typename AdapterT>
void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
SketchContainer* sketch_container, int num_cuts) {
dh::XGBCachingDeviceAllocator<char> alloc;
adapter->BeforeFirst();
adapter->Next();
auto &batch = adapter->Value();
// Enforce single batch
CHECK(!adapter->Next());
auto batch_iter = MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
auto entry_iter = MakeTransformIterator<Entry>(
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
return Entry(batch.GetElement(idx).column_idx,
batch.GetElement(idx).value);
});
// Work out how many valid entries we have in each column
dh::caching_device_vector<size_t> column_sizes_scan(adapter->NumColumns() + 1,
0);
auto d_column_sizes_scan = column_sizes_scan.data().get();
IsValidFunctor is_valid(missing);
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
auto e = batch_iter[begin + idx];
if (is_valid(e)) {
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&d_column_sizes_scan[e.column_idx]),
static_cast<unsigned long long>(1)); // NOLINT
}
});
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan.begin(),
column_sizes_scan.end(), column_sizes_scan.begin());
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
size_t num_valid = host_column_sizes_scan.back();
// Copy current subset of valid elements into temporary storage and sort
thrust::device_vector<Entry> sorted_entries(num_valid);
thrust::copy_if(thrust::cuda::par(alloc), entry_iter + begin,
entry_iter + end, sorted_entries.begin(), is_valid);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), EntryCompareOp());
// Extract the cuts from all columns concurrently
dh::caching_device_vector<SketchEntry> cuts(adapter->NumColumns() * num_cuts);
ExtractCuts(adapter->DeviceIdx(), {cuts.data().get(), cuts.size()}, num_cuts,
{sorted_entries.data().get(), sorted_entries.size()},
{column_sizes_scan.data().get(), column_sizes_scan.size()});
// Push cuts into sketches stored in host memory
thrust::host_vector<SketchEntry> host_cuts(cuts);
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}
template <typename AdapterT>
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
float missing,
size_t sketch_batch_num_elements) {
CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);
adapter->BeforeFirst();
adapter->Next();
auto& batch = adapter->Value();
// Enforce single batch
CHECK(!adapter->Next());
HistogramCuts cuts;
DenseCuts dense_cuts(&cuts);
SketchContainer sketch_container(num_bins, adapter->NumColumns(),
adapter->NumRows());
constexpr int kFactor = 8;
double eps = 1.0 / (kFactor * num_bins);
size_t dummy_nlevel;
size_t num_cuts;
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
adapter->NumRows(), eps, &dummy_nlevel, &num_cuts);
num_cuts = std::min(num_cuts, adapter->NumRows());
if (sketch_batch_num_elements == 0) {
sketch_batch_num_elements = batch.Size();
}
for (auto begin = 0ull; begin < batch.Size();
begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessBatch(adapter, begin, end, missing, &sketch_container, num_cuts);
}
dense_cuts.Init(&sketch_container.sketches_, num_bins, adapter->NumRows());
return cuts;
}
template HistogramCuts AdapterDeviceSketch(data::CudfAdapter* adapter,
int num_bins, float missing,
size_t sketch_batch_size);
template HistogramCuts AdapterDeviceSketch(data::CupyAdapter* adapter,
int num_bins, float missing,
size_t sketch_batch_size);
} // namespace common
} // namespace xgboost

View File

@ -179,16 +179,14 @@ class DenseCuts : public CutsBuilder {
void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
};
// FIXME(trivialfis): Merge this into generic cut builder.
/*! \brief Builds the cut matrix on the GPU.
*
* \return The row stride across the entire dataset.
*/
size_t DeviceSketch(int device,
int max_bin,
int gpu_batch_nrows,
DMatrix* dmat,
HistogramCuts* hmat);
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements = 10000000);
template <typename AdapterT>
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
float missing,
size_t sketch_batch_num_elements = 10000000);
/*!
* \brief preprocessed global index matrix, in CSR format

View File

@ -71,6 +71,7 @@ namespace data {
constexpr size_t kAdapterUnknownSize = std::numeric_limits<size_t >::max();
struct COOTuple {
COOTuple() = default;
XGBOOST_DEVICE COOTuple(size_t row_idx, size_t column_idx, float value)
: row_idx(row_idx), column_idx(column_idx), value(value) {}

View File

@ -78,6 +78,20 @@ EllpackPageImpl::EllpackPageImpl(int device, EllpackInfo info, size_t n_rows) {
monitor_.StopCuda("InitCompressedData");
}
size_t GetRowStride(DMatrix* dmat) {
if (dmat->IsDense()) return dmat->Info().num_col_;
size_t row_stride = 0;
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
const auto& row_offset = batch.offset.ConstHostVector();
for (auto i = 1ull; i < row_offset.size(); i++) {
row_stride = std::max(
row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
}
}
return row_stride;
}
// Construct an ELLPACK matrix in memory.
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
monitor_.Init("ellpack_page");
@ -87,13 +101,13 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
monitor_.StartCuda("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
common::HistogramCuts hmat;
size_t row_stride =
common::DeviceSketch(param.gpu_id, param.max_bin, param.gpu_batch_nrows, dmat, &hmat);
size_t row_stride = GetRowStride(dmat);
auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin,
param.gpu_batch_nrows);
monitor_.StopCuda("Quantiles");
monitor_.StartCuda("InitEllpackInfo");
InitInfo(param.gpu_id, dmat->IsDense(), row_stride, hmat);
InitInfo(param.gpu_id, dmat->IsDense(), row_stride, cuts);
monitor_.StopCuda("InitEllpackInfo");
monitor_.StartCuda("InitCompressedData");

View File

@ -70,6 +70,20 @@ const EllpackPage& EllpackPageSource::Value() const {
return impl_->Value();
}
size_t GetRowStride(DMatrix* dmat) {
if (dmat->IsDense()) return dmat->Info().num_col_;
size_t row_stride = 0;
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
const auto& row_offset = batch.offset.ConstHostVector();
for (auto i = 1ull; i < row_offset.size(); i++) {
row_stride = std::max(
row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
}
}
return row_stride;
}
// Build the quantile sketch across the whole input data, then use the histogram cuts to compress
// each CSR page, and write the accumulated ELLPACK pages to disk.
EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
@ -85,13 +99,13 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
dh::safe_cuda(cudaSetDevice(device_));
monitor_.StartCuda("Quantiles");
common::HistogramCuts hmat;
size_t row_stride =
common::DeviceSketch(device_, param.max_bin, param.gpu_batch_nrows, dmat, &hmat);
size_t row_stride = GetRowStride(dmat);
auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin,
param.gpu_batch_nrows);
monitor_.StopCuda("Quantiles");
monitor_.StartCuda("CreateEllpackInfo");
ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, &ba_);
ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, cuts, &ba_);
monitor_.StopCuda("CreateEllpackInfo");
monitor_.StartCuda("WriteEllpackPages");

View File

@ -1,105 +0,0 @@
#include <dmlc/filesystem.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <cmath>
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include "xgboost/c_api.h"
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/hist_util.h"
#include "../helpers.h"
namespace xgboost {
namespace common {
void TestDeviceSketch(bool use_external_memory) {
// create the data
int nrows = 10001;
std::shared_ptr<xgboost::DMatrix> *dmat = nullptr;
size_t num_cols = 1;
dmlc::TemporaryDirectory tmpdir;
std::string file = tmpdir.path + "/big.libsvm";
if (use_external_memory) {
auto sp_dmat = CreateSparsePageDMatrix(nrows * 3, 128UL, file); // 3 entries/row
dmat = new std::shared_ptr<xgboost::DMatrix>(std::move(sp_dmat));
num_cols = 5;
} else {
std::vector<float> test_data(nrows);
auto count_iter = thrust::make_counting_iterator(0);
// fill in reverse order
std::copy(count_iter, count_iter + nrows, test_data.rbegin());
// create the DMatrix
DMatrixHandle dmat_handle;
XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1,
&dmat_handle);
dmat = static_cast<std::shared_ptr<xgboost::DMatrix> *>(dmat_handle);
}
int device{0};
int max_bin{20};
int gpu_batch_nrows{0};
// find quantiles on the CPU
HistogramCuts hmat_cpu;
hmat_cpu.Build((*dmat).get(), max_bin);
// find the cuts on the GPU
HistogramCuts hmat_gpu;
size_t row_stride = DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &hmat_gpu);
// compare the row stride with the one obtained from the dmatrix
bst_row_t expected_row_stride = 0;
for (const auto &batch : dmat->get()->GetBatches<xgboost::SparsePage>()) {
const auto &offset_vec = batch.offset.ConstHostVector();
for (int i = 1; i <= offset_vec.size() -1; ++i) {
expected_row_stride = std::max(expected_row_stride, offset_vec[i] - offset_vec[i-1]);
}
}
ASSERT_EQ(expected_row_stride, row_stride);
// compare the cuts
double eps = 1e-2;
ASSERT_EQ(hmat_gpu.MinValues().size(), num_cols);
ASSERT_EQ(hmat_gpu.Ptrs().size(), num_cols + 1);
ASSERT_EQ(hmat_gpu.Values().size(), hmat_cpu.Values().size());
ASSERT_LT(fabs(hmat_cpu.MinValues()[0] - hmat_gpu.MinValues()[0]), eps * nrows);
for (int i = 0; i < hmat_gpu.Values().size(); ++i) {
ASSERT_LT(fabs(hmat_cpu.Values()[i] - hmat_gpu.Values()[i]), eps * nrows);
}
// Determinstic
size_t constexpr kRounds { 100 };
for (size_t r = 0; r < kRounds; ++r) {
HistogramCuts new_sketch;
DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &new_sketch);
ASSERT_EQ(hmat_gpu.Values().size(), new_sketch.Values().size());
for (size_t i = 0; i < hmat_gpu.Values().size(); ++i) {
ASSERT_EQ(hmat_gpu.Values()[i], new_sketch.Values()[i]);
}
for (size_t i = 0; i < hmat_gpu.MinValues().size(); ++i) {
ASSERT_EQ(hmat_gpu.MinValues()[i], new_sketch.MinValues()[i]);
}
}
delete dmat;
}
TEST(gpu_hist_util, DeviceSketch) {
TestDeviceSketch(false);
}
TEST(gpu_hist_util, DeviceSketch_ExternalMemory) {
TestDeviceSketch(true);
}
} // namespace common
} // namespace xgboost

View File

@ -261,7 +261,25 @@ TEST(hist_util, DenseCutsAccuracyTest) {
HistogramCuts cuts;
DenseCuts dense(&cuts);
dense.Build(dmat.get(), num_bins);
ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
TEST(hist_util, DenseCutsAccuracyTestWeights) {
int bin_sizes[] = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500};
int num_columns = 5;
for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto w = GenerateRandomWeights(num_rows);
dmat->Info().weights_.HostVector() = w;
for (auto num_bins : bin_sizes) {
HistogramCuts cuts;
DenseCuts dense(&cuts);
dense.Build(dmat.get(), num_bins);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
@ -279,7 +297,7 @@ TEST(hist_util, DenseCutsExternalMemory) {
HistogramCuts cuts;
DenseCuts dense(&cuts);
dense.Build(dmat.get(), num_bins);
ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
@ -295,7 +313,7 @@ TEST(hist_util, SparseCutsAccuracyTest) {
HistogramCuts cuts;
SparseCuts sparse(&cuts);
sparse.Build(dmat.get(), num_bins);
ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
@ -335,7 +353,7 @@ TEST(hist_util, SparseCutsExternalMemory) {
HistogramCuts cuts;
SparseCuts dense(&cuts);
dense.Build(dmat.get(), num_bins);
ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}

View File

@ -0,0 +1,212 @@
#include <dmlc/filesystem.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <cmath>
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include "xgboost/c_api.h"
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/hist_util.h"
#include "../helpers.h"
#include <xgboost/data.h>
#include "../../../src/data/device_adapter.cuh"
#include "../data/test_array_interface.h"
#include "../../../src/common/math.h"
#include "../../../src/data/simple_dmatrix.h"
#include "test_hist_util.h"
namespace xgboost {
namespace common {
template <typename AdapterT>
HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
HistogramCuts cuts;
DenseCuts builder(&cuts);
data::SimpleDMatrix dmat(adapter, missing, 1);
builder.Build(&dmat, num_bins);
return cuts;
}
TEST(hist_util, DeviceSketch) {
int num_rows = 5;
int num_columns = 1;
int num_bins = 4;
std::vector<float> x = {1.0, 2.0, 3.0, 4.0, 5.0};
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
HistogramCuts host_cuts;
DenseCuts builder(&host_cuts);
builder.Build(dmat.get(), num_bins);
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
}
TEST(hist_util, DeviceSketchDeterminism) {
int num_rows = 500;
int num_columns = 5;
int num_bins = 256;
auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto reference_sketch = DeviceSketch(0, dmat.get(), num_bins);
size_t constexpr kRounds{ 100 };
for (size_t r = 0; r < kRounds; ++r) {
auto new_sketch = DeviceSketch(0, dmat.get(), num_bins);
ASSERT_EQ(reference_sketch.Values(), new_sketch.Values());
ASSERT_EQ(reference_sketch.MinValues(), new_sketch.MinValues());
}
}
TEST(hist_util, DeviceSketchCategorical) {
int categorical_sizes[] = {2, 6, 8, 12};
int num_bins = 256;
int sizes[] = {25, 100, 1000};
for (auto n : sizes) {
for (auto num_categories : categorical_sizes) {
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
auto dmat = GetDMatrixFromData(x, n, 1);
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
TEST(hist_util, DeviceSketchMultipleColumns) {
int bin_sizes[] = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500};
int num_columns = 5;
for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
for (auto num_bins : bin_sizes) {
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
TEST(hist_util, DeviceSketchMultipleColumnsWeights) {
int bin_sizes[] = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500};
int num_columns = 5;
for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows);
for (auto num_bins : bin_sizes) {
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
TEST(hist_util, DeviceSketchBatches) {
int num_bins = 256;
int num_rows = 5000;
int batch_sizes[] = {0, 100, 1500, 6000};
int num_columns = 5;
for (auto batch_size : batch_sizes) {
auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto cuts = DeviceSketch(0, dmat.get(), num_bins, batch_size);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
TEST(hist_util, DeviceSketchMultipleColumnsExternal) {
int bin_sizes[] = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500};
int num_columns =5;
for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns);
dmlc::TemporaryDirectory temp;
auto dmat =
GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 100, temp);
for (auto num_bins : bin_sizes) {
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
TEST(hist_util, AdapterDeviceSketch)
{
int rows = 5;
int cols = 1;
int num_bins = 4;
float missing = - 1.0;
thrust::device_vector< float> data(rows*cols);
auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
data = std::vector<float >{ 1.0,2.0,3.0,4.0,5.0 };
std::stringstream ss;
Json::Dump(json_array_interface, &ss);
std::string str = ss.str();
data::CupyAdapter adapter(str);
auto device_cuts = AdapterDeviceSketch(&adapter, num_bins, missing);
auto host_cuts = GetHostCuts(&adapter, num_bins, missing);
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
}
TEST(hist_util, AdapterDeviceSketchCategorical) {
int categorical_sizes[] = {2, 6, 8, 12};
int num_bins = 256;
int sizes[] = {25, 100, 1000};
for (auto n : sizes) {
for (auto num_categories : categorical_sizes) {
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
auto dmat = GetDMatrixFromData(x, n, 1);
auto x_device = thrust::device_vector<float>(x);
auto adapter = AdapterFromData(x_device, n, 1);
auto cuts = AdapterDeviceSketch(&adapter, num_bins,
std::numeric_limits<float>::quiet_NaN());
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
TEST(hist_util, AdapterDeviceSketchMultipleColumns) {
int bin_sizes[] = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500};
int num_columns = 5;
for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto x_device = thrust::device_vector<float>(x);
for (auto num_bins : bin_sizes) {
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
auto cuts = AdapterDeviceSketch(&adapter, num_bins,
std::numeric_limits<float>::quiet_NaN());
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
TEST(hist_util, AdapterDeviceSketchBatches) {
int num_bins = 256;
int num_rows = 5000;
int batch_sizes[] = {0, 100, 1500, 6000};
int num_columns = 5;
for (auto batch_size : batch_sizes) {
auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto x_device = thrust::device_vector<float>(x);
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
auto cuts = AdapterDeviceSketch(&adapter, num_bins,
std::numeric_limits<float>::quiet_NaN(),
batch_size);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
} // namespace common
} // namespace xgboost

View File

@ -28,6 +28,34 @@ inline std::vector<float> GenerateRandom(int num_rows, int num_columns) {
return x;
}
inline std::vector<float> GenerateRandomWeights(int num_rows) {
std::vector<float> w(num_rows);
std::mt19937 rng(1);
std::uniform_real_distribution<float> dist(0.0, 1.0);
std::generate(w.begin(), w.end(), [&]() { return dist(rng); });
return w;
}
#ifdef __CUDACC__
inline data::CupyAdapter AdapterFromData(const thrust::device_vector<float> &x,
int num_rows, int num_columns) {
Json array_interface{Object()};
std::vector<Json> shape = {Json(static_cast<Integer::Int>(num_rows)),
Json(static_cast<Integer::Int>(num_columns))};
array_interface["shape"] = Array(shape);
std::vector<Json> j_data{
Json(Integer(reinterpret_cast<Integer::Int>(x.data().get()))),
Json(Boolean(false))};
array_interface["data"] = j_data;
array_interface["version"] = Integer(static_cast<Integer::Int>(1));
array_interface["typestr"] = String("<f4");
std::stringstream ss;
Json::Dump(array_interface, &ss);
std::string str = ss.str();
return data::CupyAdapter(str);
}
#endif
inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n,
int num_categories) {
std::vector<float> x(n);
@ -69,21 +97,22 @@ inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(
// Test that elements are approximately equally distributed among bins
inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx,
const std::vector<float>& column,
const std::vector<float>& sorted_column,const std::vector<float >&sorted_weights,
int num_bins) {
std::map<int, int> counts;
for (auto& v : column) {
counts[cuts.SearchBin(v, column_idx)]++;
std::map<int, int> bin_weights;
for (auto i = 0ull; i < sorted_column.size(); i++) {
bin_weights[cuts.SearchBin(sorted_column[i], column_idx)] += sorted_weights[i];
}
int local_num_bins = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
int expected_num_elements = column.size() / local_num_bins;
// Allow about 30% deviation. This test is not very strict, it only ensures
auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(),0);
int expected_bin_weight = total_weight / local_num_bins;
// Allow up to 30% deviation. This test is not very strict, it only ensures
// roughly equal distribution
int allowable_error = std::max(2, int(expected_num_elements * 0.3));
int allowable_error = std::max(2, int(expected_bin_weight * 0.3));
// First and last bin can have smaller
for (auto& kv : counts) {
EXPECT_LE(std::abs(counts[kv.first] - expected_num_elements),
for (auto& kv : bin_weights) {
EXPECT_LE(std::abs(bin_weights[kv.first] - expected_bin_weight),
allowable_error );
}
}
@ -91,26 +120,29 @@ inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx,
// Test sketch quantiles against the real quantiles
// Not a very strict test
inline void TestRank(const std::vector<float>& cuts,
const std::vector<float>& sorted_x) {
float eps = 0.05;
const std::vector<float>& sorted_x,
const std::vector<float>& sorted_weights) {
double eps = 0.05;
auto total_weight =
std::accumulate(sorted_weights.begin(), sorted_weights.end(), 0.0);
// Ignore the last cut, its special
double sum_weight = 0.0;
size_t j = 0;
for (auto i = 0; i < cuts.size() - 1; i++) {
int expected_rank = ((i+1) * sorted_x.size()) / cuts.size();
while (cuts[i] > sorted_x[j]) {
sum_weight += sorted_weights[j];
j++;
}
int actual_rank = j;
int acceptable_error = std::max(2, int(sorted_x.size() * eps));
ASSERT_LE(std::abs(expected_rank - actual_rank), acceptable_error);
double expected_rank = ((i + 1) * total_weight) / cuts.size();
double acceptable_error = std::max(2.0, total_weight * eps);
ASSERT_LE(std::abs(expected_rank - sum_weight), acceptable_error);
}
}
inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
const std::vector<float>& column,
const std::vector<float>& sorted_column,
const std::vector<float>& sorted_weights,
int num_bins) {
std::vector<float> sorted_column(column);
std::sort(sorted_column.begin(), sorted_column.end());
// Check the endpoints are correct
EXPECT_LT(cuts.MinValues()[column_idx], sorted_column.front());
@ -126,40 +158,60 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
EXPECT_EQ(std::set<float>(cuts_begin, cuts_end).size(),
cuts_end - cuts_begin);
if (sorted_column.size() <= num_bins) {
auto unique = std::set<float>(sorted_column.begin(), sorted_column.end());
if (unique.size() <= num_bins) {
// Less unique values than number of bins
// Each value should get its own bin
// First check the inputs are unique
int num_unique =
std::set<float>(sorted_column.begin(), sorted_column.end()).size();
EXPECT_EQ(num_unique, sorted_column.size());
for (auto i = 0ull; i < sorted_column.size(); i++) {
ASSERT_EQ(cuts.SearchBin(sorted_column[i], column_idx),
cuts.Ptrs()[column_idx] + i);
int i = 0;
for (auto v : unique) {
ASSERT_EQ(cuts.SearchBin(v, column_idx), cuts.Ptrs()[column_idx] + i);
i++;
}
}
int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
std::vector<float> column_cuts(num_cuts_column);
std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx],
cuts.Values().begin() + cuts.Ptrs()[column_idx + 1],
column_cuts.begin());
TestBinDistribution(cuts, column_idx, sorted_column, num_bins);
TestRank(column_cuts, sorted_column);
else {
int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
std::vector<float> column_cuts(num_cuts_column);
std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx],
cuts.Values().begin() + cuts.Ptrs()[column_idx + 1],
column_cuts.begin());
TestBinDistribution(cuts, column_idx, sorted_column,sorted_weights, num_bins);
TestRank(column_cuts, sorted_column,sorted_weights);
}
}
// x is dense and row major
inline void ValidateCuts(const HistogramCuts& cuts, std::vector<float>& x,
int num_rows, int num_columns,
inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
int num_bins) {
for (auto i = 0; i < num_columns; i++) {
// Extract the column
std::vector<float> column(num_rows);
for (auto j = 0; j < num_rows; j++) {
column[j] = x[j*num_columns + i];
}
ValidateColumn(cuts,i, column, num_bins);
}
// Collect data into columns
std::vector<std::vector<float>> columns(dmat->Info().num_col_);
for (auto& batch : dmat->GetBatches<SparsePage>()) {
for (auto i = 0ull; i < batch.Size(); i++) {
for (auto e : batch[i]) {
columns[e.index].push_back(e.fvalue);
}
}
}
// Sort
for (auto i = 0ull; i < columns.size(); i++) {
auto& col = columns.at(i);
const auto& w = dmat->Info().weights_.HostVector();
std::vector<size_t > index(col.size());
std::iota(index.begin(), index.end(), 0);
std::sort(index.begin(), index.end(),[=](size_t a,size_t b)
{
return col[a] < col[b];
});
std::vector<float> sorted_column(col.size());
std::vector<float> sorted_weights(col.size(), 1.0);
for (auto j = 0ull; j < col.size(); j++) {
sorted_column[j] = col[index[j]];
if (w.size() == col.size()) {
sorted_weights[j] = w[index[j]];
}
}
ValidateColumn(cuts, i, sorted_column, sorted_weights, num_bins);
}
}
} // namespace common

View File

@ -158,28 +158,6 @@ TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
EXPECT_EQ(impl_ext->matrix.base_rowid, current_row);
current_row += impl_ext->matrix.n_rows;
}
current_row = 0;
thrust::device_vector<bst_float> row_d(kCols);
thrust::device_vector<bst_float> row_ext_d(kCols);
std::vector<bst_float> row(kCols);
std::vector<bst_float> row_ext(kCols);
for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
auto impl_ext = page.Impl();
EXPECT_EQ(impl_ext->matrix.base_rowid, current_row);
for (size_t i = 0; i < impl_ext->Size(); i++) {
dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get()));
thrust::copy(row_d.begin(), row_d.end(), row.begin());
dh::LaunchN(0, kCols, ReadRowFunction(impl_ext->matrix, current_row, row_ext_d.data().get()));
thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin());
EXPECT_EQ(row, row_ext) << "for row " << current_row;
current_row++;
}
}
}
} // namespace xgboost

View File

@ -284,7 +284,6 @@ void TestHistogramIndexImpl() {
ASSERT_EQ(maker->page->matrix.info.n_bins, maker_ext->page->matrix.info.n_bins);
ASSERT_EQ(maker->page->gidx_buffer.size(), maker_ext->page->gidx_buffer.size());
ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
}
TEST(GpuHist, TestHistogramIndex) {