Sketching from adapters (#5365)
* Sketching from adapters
* Add weights test
parent 0dd97c206b
commit a38e7bd19c
@@ -9,85 +9,29 @@
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>

#include <utility>
#include <vector>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

#include "hist_util.h"
#include "xgboost/host_device_vector.h"
#include "../data/adapter.h"
#include "../data/device_adapter.cuh"
#include "device_helpers.cuh"
#include "hist_util.h"
#include "math.h"  // NOLINT
#include "quantile.h"
#include "../tree/param.h"
#include "xgboost/host_device_vector.h"


namespace xgboost {
namespace common {

using WQSketch = DenseCuts::WQSketch;

__global__ void FindCutsK(WQSketch::Entry* __restrict__ cuts,
                          const bst_float* __restrict__ data,
                          const float* __restrict__ cum_weights,
                          int nsamples,
                          int ncuts) {
  // ncuts < nsamples
  int icut = threadIdx.x + blockIdx.x * blockDim.x;
  if (icut >= ncuts) {
    return;
  }
  int isample = 0;
  if (icut == 0) {
    isample = 0;
  } else if (icut == ncuts - 1) {
    isample = nsamples - 1;
  } else {
    bst_float rank = cum_weights[nsamples - 1] / static_cast<float>(ncuts - 1)
                     * static_cast<float>(icut);
    // -1 is used because cum_weights is an inclusive sum
    isample = dh::UpperBound(cum_weights, nsamples, rank);
    isample = max(0, min(isample, nsamples - 1));
  }
  // repeated values will be filtered out on the CPU
  bst_float rmin = isample > 0 ? cum_weights[isample - 1] : 0;
  bst_float rmax = cum_weights[isample];
  cuts[icut] = WQSketch::Entry(rmin, rmax, rmax - rmin, data[isample]);
}
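FindCutsK maps cut index icut to the target rank total_weight * icut / (ncuts - 1), then locates the sample holding that rank with an upper bound on the inclusive cumulative-weight array. A minimal host-side C++ sketch of the same mapping, assuming dh::UpperBound behaves like std::upper_bound returning an index (the data below is made up):

#include <algorithm>
#include <iostream>
#include <vector>

// Host-side illustration of the rank-to-sample mapping in FindCutsK.
// cum_weights is an inclusive prefix sum of per-sample weights.
int FindSampleForCut(const std::vector<float>& cum_weights, int ncuts, int icut) {
  int nsamples = static_cast<int>(cum_weights.size());
  if (icut == 0) return 0;
  if (icut == ncuts - 1) return nsamples - 1;
  float rank = cum_weights[nsamples - 1] / (ncuts - 1) * icut;
  // index of the first element greater than rank, like dh::UpperBound
  int isample = static_cast<int>(
      std::upper_bound(cum_weights.begin(), cum_weights.end(), rank) -
      cum_weights.begin());
  return std::max(0, std::min(isample, nsamples - 1));
}

int main() {
  // five samples with unit weights -> inclusive sums 1..5 (hypothetical data)
  std::vector<float> cum_weights = {1, 2, 3, 4, 5};
  for (int icut = 0; icut < 4; ++icut) {
    std::cout << "cut " << icut << " -> sample "
              << FindSampleForCut(cum_weights, 4, icut) << "\n";
  }
  return 0;
}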

// predicate for thrust filtering that returns true if the element is not a NaN
struct IsNotNaN {
  __device__ bool operator()(float a) const { return !isnan(a); }
};

__global__ void UnpackFeaturesK(float* __restrict__ fvalues,
                                float* __restrict__ feature_weights,
                                const size_t* __restrict__ row_ptrs,
                                const float* __restrict__ weights,
                                Entry* entries,
                                size_t nrows_array,
                                size_t row_begin_ptr,
                                size_t nrows) {
  size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x;
  if (irow >= nrows) {
    return;
  }
  size_t row_length = row_ptrs[irow + 1] - row_ptrs[irow];
  int icol = threadIdx.y + blockIdx.y * blockDim.y;
  if (icol >= row_length) {
    return;
  }
  Entry entry = entries[row_ptrs[irow] - row_begin_ptr + icol];
  size_t ind = entry.index * nrows_array + irow;
  // if weights are present, ensure that a non-NaN value is written to weights
  // if and only if it is also written to features
  if (!isnan(entry.fvalue) && (weights == nullptr || !isnan(weights[irow]))) {
    fvalues[ind] = entry.fvalue;
    if (feature_weights != nullptr && weights != nullptr) {
      feature_weights[ind] = weights[irow];
    }
  }
}
using SketchEntry = WQSketch::Entry;

/*!
 * \brief A container that holds the device sketches across all
@@ -98,379 +42,410 @@ __global__ void UnpackFeaturesK(float* __restrict__ fvalues,
 */
struct SketchContainer {
  std::vector<DenseCuts::WQSketch> sketches_;  // NOLINT
  std::vector<std::mutex> col_locks_;  // NOLINT
  static constexpr int kOmpNumColsParallelizeLimit = 1000;

  SketchContainer(int max_bin, DMatrix* dmat) : col_locks_(dmat->Info().num_col_) {
    const MetaInfo& info = dmat->Info();
  SketchContainer(int max_bin, size_t num_columns, size_t num_rows) {
    // Initialize Sketches for this dmatrix
    sketches_.resize(info.num_col_);
#pragma omp parallel for default(none) shared(info, max_bin) schedule(static) \
if (info.num_col_ > kOmpNumColsParallelizeLimit)  // NOLINT
    for (int icol = 0; icol < info.num_col_; ++icol) {  // NOLINT
      sketches_[icol].Init(info.num_row_, 1.0 / (8 * max_bin));
    sketches_.resize(num_columns);
#pragma omp parallel for schedule(static) if (num_columns > kOmpNumColsParallelizeLimit)  // NOLINT
    for (int icol = 0; icol < num_columns; ++icol) {  // NOLINT
      sketches_[icol].Init(num_rows, 1.0 / (8 * max_bin));
    }
  }

  // Prevent copying/assigning/moving this as its internals can't be assigned/copied/moved
  SketchContainer(const SketchContainer &) = delete;
  SketchContainer(const SketchContainer &&) = delete;
  SketchContainer &operator=(const SketchContainer &) = delete;
  SketchContainer &operator=(const SketchContainer &&) = delete;
  /**
   * \brief Pushes cuts to the sketches.
   *
   * \param entries_per_column The entries per column.
   * \param entries Vector of cuts from all columns, length
   *                entries_per_column * num_columns.
   * \param column_scan Exclusive scan of column sizes. Used to detect cases
   *                where there are fewer entries than we have storage for.
   */
  void Push(size_t entries_per_column,
            const thrust::host_vector<SketchEntry>& entries,
            const thrust::host_vector<size_t>& column_scan) {
#pragma omp parallel for schedule(static) if (sketches_.size() > SketchContainer::kOmpNumColsParallelizeLimit)  // NOLINT
    for (int icol = 0; icol < sketches_.size(); ++icol) {
      size_t column_size = column_scan[icol + 1] - column_scan[icol];
      if (column_size == 0) continue;
      WQuantileSketch<bst_float, bst_float>::SummaryContainer summary;
      size_t num_available_cuts =
          std::min(size_t(entries_per_column), column_size);
      summary.Reserve(num_available_cuts);
      summary.MakeFromSorted(&entries[entries_per_column * icol],
                             num_available_cuts);

      sketches_[icol].PushSummary(summary);
    }
  }

  // Prevent copying/assigning/moving this as its internals can't be
  // assigned/copied/moved
  SketchContainer(const SketchContainer&) = delete;
  SketchContainer(const SketchContainer&&) = delete;
  SketchContainer& operator=(const SketchContainer&) = delete;
  SketchContainer& operator=(const SketchContainer&&) = delete;
};

// finds quantiles on the GPU
class GPUSketcher {
 public:
  GPUSketcher(int device, int max_bin, int gpu_nrows)
      : device_(device), max_bin_(max_bin), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {}

  ~GPUSketcher() {  // NOLINT
    dh::safe_cuda(cudaSetDevice(device_));
struct EntryCompareOp {
  __device__ bool operator()(const Entry& a, const Entry& b) {
    if (a.index == b.index) {
      return a.fvalue < b.fvalue;
    }
    return a.index < b.index;
  }

  void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
    n_rows_ = batch.Size();

    Init(batch, info, gpu_batch_nrows_);
    Sketch(batch, info);
    ComputeRowStride();
  }

  /* Builds the sketches on the GPU for the dmatrix and returns the row stride
   * for the entire dataset */
  size_t Sketch(DMatrix *dmat, DenseCuts *hmat) {
    const MetaInfo& info = dmat->Info();

    row_stride_ = 0;
    sketch_container_.reset(new SketchContainer(max_bin_, dmat));
    for (const auto& batch : dmat->GetBatches<SparsePage>()) {
      this->SketchBatch(batch, info);
    }

    hmat->Init(&sketch_container_->sketches_, max_bin_, info.num_row_);
    return row_stride_;
  }

  // This needs to be public because of the __device__ lambda.
  void ComputeRowStride() {
    // Find the row stride for this batch
    auto row_iter = row_ptrs_.begin();
    // Functor for finding the maximum row size for this batch
    auto get_size = [=] __device__(size_t row) {
      return row_iter[row + 1] - row_iter[row];
    };  // NOLINT

    auto counting = thrust::make_counting_iterator(size_t(0));
    using TransformT = thrust::transform_iterator<decltype(get_size), decltype(counting), size_t>;
    TransformT row_size_iter = TransformT(counting, get_size);
    size_t batch_row_stride =
        thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0, thrust::maximum<size_t>());
    row_stride_ = std::max(row_stride_, batch_row_stride);
  }

  // This needs to be public because of the __device__ lambda.
  void FindColumnCuts(size_t batch_nrows, size_t icol) {
    size_t tmp_size = tmp_storage_.size();
    // filter out NaNs in feature values
    auto fvalues_begin = fvalues_.data() + icol * gpu_batch_nrows_;
    cub::DeviceSelect::If(tmp_storage_.data().get(),
                          tmp_size,
                          fvalues_begin,
                          fvalues_cur_.data(),
                          num_elements_.begin(),
                          batch_nrows,
                          IsNotNaN());
    size_t nfvalues_cur = 0;
    thrust::copy_n(num_elements_.begin(), 1, &nfvalues_cur);

    // compute cumulative weights using a prefix scan
    if (has_weights_) {
      // filter out NaNs in weights;
      // since cub::DeviceSelect::If performs stable filtering,
      // the weights are stored in the correct positions
      auto feature_weights_begin = feature_weights_.data() + icol * gpu_batch_nrows_;
      cub::DeviceSelect::If(tmp_storage_.data().get(),
                            tmp_size,
                            feature_weights_begin,
                            weights_.data().get(),
                            num_elements_.begin(),
                            batch_nrows,
                            IsNotNaN());

      // sort the values and weights
      cub::DeviceRadixSort::SortPairs(tmp_storage_.data().get(),
                                      tmp_size,
                                      fvalues_cur_.data().get(),
                                      fvalues_begin.get(),
                                      weights_.data().get(),
                                      weights2_.data().get(),
                                      nfvalues_cur);

      // sum the weights to get cumulative weight values
      cub::DeviceScan::InclusiveSum(tmp_storage_.data().get(),
                                    tmp_size,
                                    weights2_.begin(),
                                    weights_.begin(),
                                    nfvalues_cur);
    } else {
      // sort the batch values
      cub::DeviceRadixSort::SortKeys(tmp_storage_.data().get(),
                                     tmp_size,
                                     fvalues_cur_.data().get(),
                                     fvalues_begin.get(),
                                     nfvalues_cur);

      // fill in cumulative weights with counting iterator
      thrust::copy_n(thrust::make_counting_iterator(1), nfvalues_cur, weights_.begin());
    }

    // remove repeated items and sum the weights across them;
    // non-negative weights are assumed
    cub::DeviceReduce::ReduceByKey(tmp_storage_.data().get(),
                                   tmp_size,
                                   fvalues_begin,
                                   fvalues_cur_.begin(),
                                   weights_.begin(),
                                   weights2_.begin(),
                                   num_elements_.begin(),
                                   thrust::maximum<bst_float>(),
                                   nfvalues_cur);
    size_t n_unique = 0;
    thrust::copy_n(num_elements_.begin(), 1, &n_unique);

    // extract cuts
    n_cuts_cur_[icol] = std::min(n_cuts_, n_unique);
    // if fewer elements than cuts: copy all elements with their weights
    if (n_cuts_ > n_unique) {
      float* weights2_ptr = weights2_.data().get();
      float* fvalues_ptr = fvalues_cur_.data().get();
      WQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_;
      dh::LaunchN(device_, n_unique, [=]__device__(size_t i) {
        bst_float rmax = weights2_ptr[i];
        bst_float rmin = i > 0 ? weights2_ptr[i - 1] : 0;
        cuts_ptr[i] = WQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]);
      });
    } else if (n_cuts_cur_[icol] > 0) {
      // if more elements than cuts: use binary search on cumulative weights
      uint32_t constexpr kBlockThreads = 256;
      uint32_t const kGrids = common::DivRoundUp(n_cuts_cur_[icol], kBlockThreads);
      dh::LaunchKernel {kGrids, kBlockThreads} (
          FindCutsK,
          cuts_d_.data().get() + icol * n_cuts_,
          fvalues_cur_.data().get(),
          weights2_.data().get(),
          n_unique,
          n_cuts_cur_[icol]);
      dh::safe_cuda(cudaGetLastError());  // NOLINT
    }
  }

 private:
  void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) {
    num_cols_ = info.num_col_;
    has_weights_ = info.weights_.Size() > 0;

    // find the batch size
    if (gpu_batch_nrows == 0) {
      // By default, use no more than 1/16th of GPU memory
      gpu_batch_nrows_ = dh::TotalMemory(device_) / (16 * num_cols_ * sizeof(Entry));
    } else if (gpu_batch_nrows == -1) {
      gpu_batch_nrows_ = n_rows_;
    } else {
      gpu_batch_nrows_ = gpu_batch_nrows;
    }
    if (gpu_batch_nrows_ > n_rows_) {
      gpu_batch_nrows_ = n_rows_;
    }

    constexpr int kFactor = 8;
    double eps = 1.0 / (kFactor * max_bin_);
    size_t dummy_nlevel;
    WQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);

    // allocate necessary GPU buffers
    dh::safe_cuda(cudaSetDevice(device_));

    entries_.resize(gpu_batch_nrows_ * num_cols_);
    fvalues_.resize(gpu_batch_nrows_ * num_cols_);
    fvalues_cur_.resize(gpu_batch_nrows_);
    cuts_d_.resize(n_cuts_ * num_cols_);
    cuts_h_.resize(n_cuts_ * num_cols_);
    weights_.resize(gpu_batch_nrows_);
    weights2_.resize(gpu_batch_nrows_);
    num_elements_.resize(1);

    if (has_weights_) {
      feature_weights_.resize(gpu_batch_nrows_ * num_cols_);
    }
    n_cuts_cur_.resize(num_cols_);

    // allocate storage for CUB algorithms; the size is the maximum of the sizes
    // required for the various algorithms
    size_t tmp_size = 0, cur_tmp_size = 0;
    // size for sorting
    if (has_weights_) {
      cub::DeviceRadixSort::SortPairs(nullptr,
                                      cur_tmp_size,
                                      fvalues_cur_.data().get(),
                                      fvalues_.data().get(),
                                      weights_.data().get(),
                                      weights2_.data().get(),
                                      gpu_batch_nrows_);
    } else {
      cub::DeviceRadixSort::SortKeys(nullptr,
                                     cur_tmp_size,
                                     fvalues_cur_.data().get(),
                                     fvalues_.data().get(),
                                     gpu_batch_nrows_);
    }
    tmp_size = std::max(tmp_size, cur_tmp_size);
    // size for inclusive scan
    if (has_weights_) {
      cub::DeviceScan::InclusiveSum(nullptr,
                                    cur_tmp_size,
                                    weights2_.begin(),
                                    weights_.begin(),
                                    gpu_batch_nrows_);
      tmp_size = std::max(tmp_size, cur_tmp_size);
    }
    // size for reduction by key
    cub::DeviceReduce::ReduceByKey(nullptr,
                                   cur_tmp_size,
                                   fvalues_.begin(),
                                   fvalues_cur_.begin(),
                                   weights_.begin(),
                                   weights2_.begin(),
                                   num_elements_.begin(),
                                   thrust::maximum<bst_float>(),
                                   gpu_batch_nrows_);
    tmp_size = std::max(tmp_size, cur_tmp_size);
    // size for filtering
    cub::DeviceSelect::If(nullptr,
                          cur_tmp_size,
                          fvalues_.begin(),
                          fvalues_cur_.begin(),
                          num_elements_.begin(),
                          gpu_batch_nrows_,
                          IsNotNaN());
    tmp_size = std::max(tmp_size, cur_tmp_size);

    tmp_storage_.resize(tmp_size);
  }

  void Sketch(const SparsePage& row_batch, const MetaInfo& info) {
    // copy rows to the device
    dh::safe_cuda(cudaSetDevice(device_));
    const auto& offset_vec = row_batch.offset.HostVector();
    row_ptrs_.resize(n_rows_ + 1);
    thrust::copy(offset_vec.data(), offset_vec.data() + n_rows_ + 1, row_ptrs_.begin());
    size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_);
    for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
      SketchBatch(row_batch, info, gpu_batch);
    }
  }

  void SketchBatch(const SparsePage& row_batch, const MetaInfo& info, size_t gpu_batch) {
    // compute start and end indices
    size_t batch_row_begin = gpu_batch * gpu_batch_nrows_;
    size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_,
                                    static_cast<size_t>(n_rows_));
    size_t batch_nrows = batch_row_end - batch_row_begin;

    const auto& offset_vec = row_batch.offset.HostVector();
    const auto& data_vec = row_batch.data.HostVector();

    size_t n_entries = offset_vec[batch_row_end] - offset_vec[batch_row_begin];
    // copy the batch to the GPU
    dh::safe_cuda(cudaMemcpyAsync(entries_.data().get(),
                                  data_vec.data() + offset_vec[batch_row_begin],
                                  n_entries * sizeof(Entry),
                                  cudaMemcpyDefault));
    // copy the weights if necessary
    if (has_weights_) {
      const auto& weights_vec = info.weights_.HostVector();
      dh::safe_cuda(cudaMemcpyAsync(weights_.data().get(),
                                    weights_vec.data() + batch_row_begin,
                                    batch_nrows * sizeof(bst_float),
                                    cudaMemcpyDefault));
    }

    // unpack the features; also unpack weights if present
    thrust::fill(fvalues_.begin(), fvalues_.end(), NAN);
    if (has_weights_) {
      thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN);
    }

    dim3 block3(16, 64, 1);
    // NOTE: This will typically support ~4M features (64K * 64)
    dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
               common::DivRoundUp(num_cols_, block3.y), 1);
    dh::LaunchKernel {grid3, block3} (
        UnpackFeaturesK,
        fvalues_.data().get(),
        has_weights_ ? feature_weights_.data().get() : nullptr,
        row_ptrs_.data().get() + batch_row_begin,
        has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
        gpu_batch_nrows_,
        offset_vec[batch_row_begin],
        batch_nrows);

    for (int icol = 0; icol < num_cols_; ++icol) {
      FindColumnCuts(batch_nrows, icol);
    }

    // add cuts into sketches
    thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
#pragma omp parallel for default(none) schedule(static) \
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit)  // NOLINT
    for (int icol = 0; icol < num_cols_; ++icol) {
      WQSketch::SummaryContainer summary;
      summary.Reserve(n_cuts_);
      summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);

      std::lock_guard<std::mutex> lock(sketch_container_->col_locks_[icol]);
      sketch_container_->sketches_[icol].PushSummary(summary);
    }
  }

  const int device_;
  const int max_bin_;
  int gpu_batch_nrows_;
  size_t row_stride_;
  std::unique_ptr<SketchContainer> sketch_container_;

  bst_uint n_rows_{};
  int num_cols_{0};
  size_t n_cuts_{0};
  bool has_weights_{false};

  dh::device_vector<size_t> row_ptrs_{};
  dh::device_vector<Entry> entries_{};
  dh::device_vector<bst_float> fvalues_{};
  dh::device_vector<bst_float> feature_weights_{};
  dh::device_vector<bst_float> fvalues_cur_{};
  dh::device_vector<WQSketch::Entry> cuts_d_{};
  thrust::host_vector<WQSketch::Entry> cuts_h_{};
  dh::device_vector<bst_float> weights_{};
  dh::device_vector<bst_float> weights2_{};
  std::vector<size_t> n_cuts_cur_{};
  dh::device_vector<size_t> num_elements_{};
  dh::device_vector<char> tmp_storage_{};
};

size_t DeviceSketch(int device,
                    int max_bin,
                    int gpu_batch_nrows,
                    DMatrix* dmat,
                    HistogramCuts* hmat) {
  GPUSketcher sketcher(device, max_bin, gpu_batch_nrows);
  // We only need to return the result in the HistogramCuts container, so it is
  // safe to use a pointer to the local DenseCuts.
  DenseCuts dense_cuts(hmat);
  auto res = sketcher.Sketch(dmat, &dense_cuts);
  return res;
// Count the entries in each column and exclusive scan
void GetColumnSizesScan(int device,
                        dh::caching_device_vector<size_t>* column_sizes_scan,
                        Span<const Entry> entries, size_t num_columns) {
  column_sizes_scan->resize(num_columns + 1, 0);
  auto d_column_sizes_scan = column_sizes_scan->data().get();
  auto d_entries = entries.data();
  dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) {
    auto& e = d_entries[idx];
    atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
                  &d_column_sizes_scan[e.index]),
              static_cast<unsigned long long>(1));  // NOLINT
  });
  dh::XGBCachingDeviceAllocator<char> alloc;
  thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
                         column_sizes_scan->end(), column_sizes_scan->begin());
}
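GetColumnSizesScan produces column segment boundaries in two steps: an atomic count per column followed by an exclusive scan. The same computation on the host, as a hedged sketch with made-up entries:

#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Column index of each (sorted) entry; data is hypothetical.
  std::vector<std::size_t> entry_columns = {0, 0, 1, 2, 2, 2};
  std::size_t num_columns = 3;
  std::vector<std::size_t> column_sizes_scan(num_columns + 1, 0);
  for (std::size_t c : entry_columns) {
    ++column_sizes_scan[c];  // the CUDA version uses atomicAdd here
  }
  // In-place exclusive scan, like thrust::exclusive_scan on the device.
  std::exclusive_scan(column_sizes_scan.begin(), column_sizes_scan.end(),
                      column_sizes_scan.begin(), std::size_t{0});
  for (std::size_t v : column_sizes_scan) std::cout << v << " ";  // 0 2 3 6
  std::cout << "\n";
  return 0;
}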

/**
 * \brief Extracts the cuts from sorted data.
 *
 * \param device The device.
 * \param cuts Output cuts
 * \param num_cuts_per_feature Number of cuts per feature.
 * \param sorted_data Sorted entries in segments of columns
 * \param column_sizes_scan Describes the boundaries of column segments in
 *   sorted data
 */
void ExtractCuts(int device, Span<SketchEntry> cuts,
                 size_t num_cuts_per_feature, Span<Entry> sorted_data,
                 Span<size_t> column_sizes_scan) {
  dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
    // Each thread is responsible for obtaining one cut from the sorted input
    size_t column_idx = idx / num_cuts_per_feature;
    size_t column_size =
        column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
    size_t num_available_cuts =
        min(size_t(num_cuts_per_feature), column_size);
    size_t cut_idx = idx % num_cuts_per_feature;
    if (cut_idx >= num_available_cuts) return;

    Span<Entry> column_entries =
        sorted_data.subspan(column_sizes_scan[column_idx], column_size);

    size_t rank = (column_entries.size() * cut_idx) / num_available_cuts;
    auto value = column_entries[rank].fvalue;
    cuts[idx] = SketchEntry(rank, rank + 1, 1, value);
  });
}
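In the unweighted case each cut takes the element at rank column_size * cut_idx / num_available_cuts, so a 10-element column asked for 4 cuts samples ranks 0, 2, 5 and 7. A quick hedged check of that arithmetic:

#include <cstddef>
#include <iostream>

int main() {
  // Hypothetical column: 10 sorted values, 4 cuts requested.
  std::size_t column_size = 10, num_available_cuts = 4;
  for (std::size_t cut_idx = 0; cut_idx < num_available_cuts; ++cut_idx) {
    std::size_t rank = (column_size * cut_idx) / num_available_cuts;
    std::cout << "cut " << cut_idx << " takes rank " << rank << "\n";  // 0 2 5 7
  }
  return 0;
}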

/**
 * \brief Extracts the cuts from sorted data, considering weights.
 *
 * \param device The device.
 * \param cuts Output cuts.
 * \param num_cuts_per_feature Number of cuts per feature.
 * \param sorted_data Sorted entries in segments of columns.
 * \param weights_scan Inclusive scan of weights for each entry in sorted_data.
 * \param column_sizes_scan Describes the boundaries of column segments in sorted data.
 */
void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
                         size_t num_cuts_per_feature, Span<Entry> sorted_data,
                         Span<float> weights_scan,
                         Span<size_t> column_sizes_scan) {
  dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
    // Each thread is responsible for obtaining one cut from the sorted input
    size_t column_idx = idx / num_cuts_per_feature;
    size_t column_size =
        column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
    size_t num_available_cuts =
        min(size_t(num_cuts_per_feature), column_size);
    size_t cut_idx = idx % num_cuts_per_feature;
    if (cut_idx >= num_available_cuts) return;

    Span<Entry> column_entries =
        sorted_data.subspan(column_sizes_scan[column_idx], column_size);
    Span<float> column_weights =
        weights_scan.subspan(column_sizes_scan[column_idx], column_size);

    float total_column_weight = column_weights.back();
    size_t sample_idx = 0;
    if (cut_idx == 0) {
      // First cut
      sample_idx = 0;
    } else if (cut_idx == num_available_cuts - 1) {
      // Last cut
      sample_idx = column_entries.size() - 1;
    } else if (num_available_cuts == column_size) {
      // There are fewer samples available than our buffer
      // Take every available sample
      sample_idx = cut_idx;
    } else {
      bst_float rank = (total_column_weight * cut_idx) /
                       static_cast<float>(num_available_cuts);
      sample_idx = thrust::upper_bound(thrust::seq, column_weights.begin(),
                                       column_weights.end(), rank) -
                   column_weights.begin() - 1;
      sample_idx =
          max(size_t(0), min(sample_idx, column_entries.size() - 1));
    }
    // repeated values will be filtered out on the CPU
    bst_float rmin = sample_idx > 0 ? column_weights[sample_idx - 1] : 0;
    bst_float rmax = column_weights[sample_idx];
    cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
                                column_entries[sample_idx].fvalue);
  });
}

void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
                  SketchContainer* sketch_container, int num_cuts,
                  size_t num_columns) {
  dh::XGBCachingDeviceAllocator<char> alloc;
  const auto& host_data = page.data.ConstHostVector();
  dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
                                          host_data.begin() + end);
  thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
               sorted_entries.end(), EntryCompareOp());

  dh::caching_device_vector<size_t> column_sizes_scan;
  GetColumnSizesScan(device, &column_sizes_scan,
                     {sorted_entries.data().get(), sorted_entries.size()},
                     num_columns);
  thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);

  dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
  ExtractCuts(device, {cuts.data().get(), cuts.size()}, num_cuts,
              {sorted_entries.data().get(), sorted_entries.size()},
              {column_sizes_scan.data().get(), column_sizes_scan.size()});

  // add cuts into sketches
  thrust::host_vector<SketchEntry> host_cuts(cuts);
  sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}

void ProcessWeightedBatch(int device, const SparsePage& page,
                          Span<const float> weights, size_t begin, size_t end,
                          SketchContainer* sketch_container, int num_cuts,
                          size_t num_columns) {
  dh::XGBCachingDeviceAllocator<char> alloc;
  const auto& host_data = page.data.ConstHostVector();
  dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
                                          host_data.begin() + end);

  // Binary search to assign weights to each element
  dh::device_vector<float> temp_weights(sorted_entries.size());
  auto d_temp_weights = temp_weights.data().get();
  page.offset.SetDevice(device);
  auto row_ptrs = page.offset.ConstDeviceSpan();
  size_t base_rowid = page.base_rowid;
  dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
    size_t element_idx = idx + begin;
    size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
                                      row_ptrs.end(), element_idx) -
                  row_ptrs.begin() - 1;
    d_temp_weights[idx] = weights[ridx + base_rowid];
  });

  // Sort
  thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
                      sorted_entries.end(), temp_weights.begin(),
                      EntryCompareOp());

  // Scan weights
  thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
                                sorted_entries.begin(), sorted_entries.end(),
                                temp_weights.begin(), temp_weights.begin(),
                                [=] __device__(const Entry& a, const Entry& b) {
                                  return a.index == b.index;
                                });

  dh::caching_device_vector<size_t> column_sizes_scan;
  GetColumnSizesScan(device, &column_sizes_scan,
                     {sorted_entries.data().get(), sorted_entries.size()},
                     num_columns);
  thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);

  // Extract cuts
  dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
  ExtractWeightedCuts(
      device, {cuts.data().get(), cuts.size()}, num_cuts,
      {sorted_entries.data().get(), sorted_entries.size()},
      {temp_weights.data().get(), temp_weights.size()},
      {column_sizes_scan.data().get(), column_sizes_scan.size()});

  // add cuts into sketches
  thrust::host_vector<SketchEntry> host_cuts(cuts);
  sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}
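ProcessWeightedBatch first expands per-row weights to per-element weights: each nonzero element recovers its row index by binary search on the CSR row pointers, mirroring the thrust::upper_bound lookup in the device lambda above. A hedged host-side version of that lookup with toy CSR data:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Toy CSR layout: rows [0,2), [2,3), [3,6) -- data is hypothetical.
  std::vector<std::size_t> row_ptrs = {0, 2, 3, 6};
  std::vector<float> row_weights = {0.5f, 2.0f, 1.0f};
  std::vector<float> element_weights(6);
  for (std::size_t element_idx = 0; element_idx < 6; ++element_idx) {
    // Index of the row whose [begin, end) range contains element_idx.
    std::size_t ridx = std::upper_bound(row_ptrs.begin(), row_ptrs.end(),
                                        element_idx) -
                       row_ptrs.begin() - 1;
    element_weights[element_idx] = row_weights[ridx];
  }
  for (float w : element_weights) std::cout << w << " ";  // 0.5 0.5 2 1 1 1
  std::cout << "\n";
  return 0;
}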

HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
                           size_t sketch_batch_num_elements) {
  HistogramCuts cuts;
  DenseCuts dense_cuts(&cuts);
  SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
                                   dmat->Info().num_row_);

  constexpr int kFactor = 8;
  double eps = 1.0 / (kFactor * max_bins);
  size_t dummy_nlevel;
  size_t num_cuts;
  WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
      dmat->Info().num_row_, eps, &dummy_nlevel, &num_cuts);
  num_cuts = std::min(num_cuts, dmat->Info().num_row_);
  if (sketch_batch_num_elements == 0) {
    sketch_batch_num_elements = dmat->Info().num_nonzero_;
  }
  dmat->Info().weights_.SetDevice(device);
  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
    size_t batch_nnz = batch.data.Size();
    for (auto begin = 0ull; begin < batch_nnz;
         begin += sketch_batch_num_elements) {
      size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
      if (dmat->Info().weights_.Size() > 0) {
        ProcessWeightedBatch(
            device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
            &sketch_container, num_cuts, dmat->Info().num_col_);
      } else {
        ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts,
                     dmat->Info().num_col_);
      }
    }
  }

  dense_cuts.Init(&sketch_container.sketches_, max_bins, dmat->Info().num_row_);
  return cuts;
}
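The replacement DeviceSketch entry point returns the HistogramCuts by value instead of filling an out-parameter. A hedged usage sketch based on the tests added in this commit (GetDMatrixFromData and ValidateCuts come from tests/cpp/common/test_hist_util.h; the wrapper function name is invented):

#include "../../../src/common/hist_util.h"  // DeviceSketch, HistogramCuts
#include "test_hist_util.h"                 // GetDMatrixFromData, ValidateCuts

void ExampleDeviceSketch() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  auto dmat = GetDMatrixFromData(x, /*num_rows=*/5, /*num_columns=*/1);
  // Batch size 0 means "sketch the whole dataset in one pass".
  auto cuts = xgboost::common::DeviceSketch(/*device=*/0, dmat.get(),
                                            /*max_bins=*/4,
                                            /*sketch_batch_num_elements=*/0);
  ValidateCuts(cuts, dmat.get(), /*num_bins=*/4);
}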

struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
  explicit IsValidFunctor(float missing) : missing(missing) {}

  float missing;
  __device__ bool operator()(const data::COOTuple& e) const {
    if (common::CheckNAN(e.value) || e.value == missing) {
      return false;
    }
    return true;
  }
  __device__ bool operator()(const Entry& e) const {
    if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
      return false;
    }
    return true;
  }
};

// The Thrust version of this function causes an error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
    IterT iter, FuncT func) {
  return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}

template <typename AdapterT>
void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
                  SketchContainer* sketch_container, int num_cuts) {
  dh::XGBCachingDeviceAllocator<char> alloc;
  adapter->BeforeFirst();
  adapter->Next();
  auto& batch = adapter->Value();
  // Enforce single batch
  CHECK(!adapter->Next());
  auto batch_iter = MakeTransformIterator<data::COOTuple>(
      thrust::make_counting_iterator(0llu),
      [=] __device__(size_t idx) { return batch.GetElement(idx); });
  auto entry_iter = MakeTransformIterator<Entry>(
      thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
        return Entry(batch.GetElement(idx).column_idx,
                     batch.GetElement(idx).value);
      });
  // Work out how many valid entries we have in each column
  dh::caching_device_vector<size_t> column_sizes_scan(adapter->NumColumns() + 1,
                                                      0);
  auto d_column_sizes_scan = column_sizes_scan.data().get();
  IsValidFunctor is_valid(missing);
  dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
    auto e = batch_iter[begin + idx];
    if (is_valid(e)) {
      atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
                    &d_column_sizes_scan[e.column_idx]),
                static_cast<unsigned long long>(1));  // NOLINT
    }
  });
  thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan.begin(),
                         column_sizes_scan.end(), column_sizes_scan.begin());
  thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
  size_t num_valid = host_column_sizes_scan.back();

  // Copy current subset of valid elements into temporary storage and sort
  thrust::device_vector<Entry> sorted_entries(num_valid);
  thrust::copy_if(thrust::cuda::par(alloc), entry_iter + begin,
                  entry_iter + end, sorted_entries.begin(), is_valid);
  thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
               sorted_entries.end(), EntryCompareOp());

  // Extract the cuts from all columns concurrently
  dh::caching_device_vector<SketchEntry> cuts(adapter->NumColumns() * num_cuts);
  ExtractCuts(adapter->DeviceIdx(), {cuts.data().get(), cuts.size()}, num_cuts,
              {sorted_entries.data().get(), sorted_entries.size()},
              {column_sizes_scan.data().get(), column_sizes_scan.size()});

  // Push cuts into sketches stored in host memory
  thrust::host_vector<SketchEntry> host_cuts(cuts);
  sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}

template <typename AdapterT>
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
                                  float missing,
                                  size_t sketch_batch_num_elements) {
  CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
  CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);

  adapter->BeforeFirst();
  adapter->Next();
  auto& batch = adapter->Value();

  // Enforce single batch
  CHECK(!adapter->Next());

  HistogramCuts cuts;
  DenseCuts dense_cuts(&cuts);
  SketchContainer sketch_container(num_bins, adapter->NumColumns(),
                                   adapter->NumRows());

  constexpr int kFactor = 8;
  double eps = 1.0 / (kFactor * num_bins);
  size_t dummy_nlevel;
  size_t num_cuts;
  WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
      adapter->NumRows(), eps, &dummy_nlevel, &num_cuts);
  num_cuts = std::min(num_cuts, adapter->NumRows());
  if (sketch_batch_num_elements == 0) {
    sketch_batch_num_elements = batch.Size();
  }
  for (auto begin = 0ull; begin < batch.Size();
       begin += sketch_batch_num_elements) {
    size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
    ProcessBatch(adapter, begin, end, missing, &sketch_container, num_cuts);
  }

  dense_cuts.Init(&sketch_container.sketches_, num_bins, adapter->NumRows());
  return cuts;
}
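AdapterDeviceSketch builds the same cuts directly from a device array adapter, without materializing a DMatrix. A hedged sketch mirroring the new tests (GenerateRandom, AdapterFromData, GetDMatrixFromData and ValidateCuts are the test helpers added below; the wrapper function name is invented):

#include <thrust/device_vector.h>
#include <limits>
#include "test_hist_util.h"  // AdapterFromData, GetDMatrixFromData, ValidateCuts

void ExampleAdapterSketch() {
  auto x = GenerateRandom(/*num_rows=*/100, /*num_columns=*/5);
  thrust::device_vector<float> x_device(x);
  auto adapter = AdapterFromData(x_device, 100, 5);  // wraps a CupyAdapter
  auto cuts = xgboost::common::AdapterDeviceSketch(
      &adapter, /*num_bins=*/256, std::numeric_limits<float>::quiet_NaN());
  // Validate against the same data in DMatrix form.
  auto dmat = GetDMatrixFromData(x, 100, 5);
  ValidateCuts(cuts, dmat.get(), /*num_bins=*/256);
}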

template HistogramCuts AdapterDeviceSketch(data::CudfAdapter* adapter,
                                           int num_bins, float missing,
                                           size_t sketch_batch_size);
template HistogramCuts AdapterDeviceSketch(data::CupyAdapter* adapter,
                                           int num_bins, float missing,
                                           size_t sketch_batch_size);
}  // namespace common
}  // namespace xgboost

@@ -179,16 +179,14 @@ class DenseCuts : public CutsBuilder {
  void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
};

// FIXME(trivialfis): Merge this into generic cut builder.
/*! \brief Builds the cut matrix on the GPU.
 *
 * \return The row stride across the entire dataset.
 */
size_t DeviceSketch(int device,
                    int max_bin,
                    int gpu_batch_nrows,
                    DMatrix* dmat,
                    HistogramCuts* hmat);

HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
                           size_t sketch_batch_num_elements = 10000000);

template <typename AdapterT>
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
                                  float missing,
                                  size_t sketch_batch_num_elements = 10000000);

/*!
 * \brief preprocessed global index matrix, in CSR format

@@ -71,6 +71,7 @@ namespace data {
constexpr size_t kAdapterUnknownSize = std::numeric_limits<size_t>::max();

struct COOTuple {
  COOTuple() = default;
  XGBOOST_DEVICE COOTuple(size_t row_idx, size_t column_idx, float value)
      : row_idx(row_idx), column_idx(column_idx), value(value) {}

@@ -78,6 +78,20 @@ EllpackPageImpl::EllpackPageImpl(int device, EllpackInfo info, size_t n_rows) {
  monitor_.StopCuda("InitCompressedData");
}

size_t GetRowStride(DMatrix* dmat) {
  if (dmat->IsDense()) return dmat->Info().num_col_;

  size_t row_stride = 0;
  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
    const auto& row_offset = batch.offset.ConstHostVector();
    for (auto i = 1ull; i < row_offset.size(); i++) {
      row_stride = std::max(
          row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
    }
  }
  return row_stride;
}
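GetRowStride is simply the maximum row length over the CSR offsets. A small standalone check of the arithmetic, with toy offsets:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Offsets for rows of lengths 2, 5, 1 (hypothetical batch).
  std::vector<std::size_t> row_offset = {0, 2, 7, 8};
  std::size_t row_stride = 0;
  for (std::size_t i = 1; i < row_offset.size(); i++) {
    row_stride = std::max(row_stride, row_offset[i] - row_offset[i - 1]);
  }
  std::cout << row_stride << "\n";  // 5
  return 0;
}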

// Construct an ELLPACK matrix in memory.
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
  monitor_.Init("ellpack_page");
@@ -87,13 +101,13 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {

  monitor_.StartCuda("Quantiles");
  // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
  common::HistogramCuts hmat;
  size_t row_stride =
      common::DeviceSketch(param.gpu_id, param.max_bin, param.gpu_batch_nrows, dmat, &hmat);
  size_t row_stride = GetRowStride(dmat);
  auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin,
                                   param.gpu_batch_nrows);
  monitor_.StopCuda("Quantiles");

  monitor_.StartCuda("InitEllpackInfo");
  InitInfo(param.gpu_id, dmat->IsDense(), row_stride, hmat);
  InitInfo(param.gpu_id, dmat->IsDense(), row_stride, cuts);
  monitor_.StopCuda("InitEllpackInfo");

  monitor_.StartCuda("InitCompressedData");

@@ -70,6 +70,20 @@ const EllpackPage& EllpackPageSource::Value() const {
  return impl_->Value();
}

size_t GetRowStride(DMatrix* dmat) {
  if (dmat->IsDense()) return dmat->Info().num_col_;

  size_t row_stride = 0;
  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
    const auto& row_offset = batch.offset.ConstHostVector();
    for (auto i = 1ull; i < row_offset.size(); i++) {
      row_stride = std::max(
          row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
    }
  }
  return row_stride;
}

// Build the quantile sketch across the whole input data, then use the histogram cuts to compress
// each CSR page, and write the accumulated ELLPACK pages to disk.
EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
@@ -85,13 +99,13 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
  dh::safe_cuda(cudaSetDevice(device_));

  monitor_.StartCuda("Quantiles");
  common::HistogramCuts hmat;
  size_t row_stride =
      common::DeviceSketch(device_, param.max_bin, param.gpu_batch_nrows, dmat, &hmat);
  size_t row_stride = GetRowStride(dmat);
  auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin,
                                   param.gpu_batch_nrows);
  monitor_.StopCuda("Quantiles");

  monitor_.StartCuda("CreateEllpackInfo");
  ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, &ba_);
  ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, cuts, &ba_);
  monitor_.StopCuda("CreateEllpackInfo");

  monitor_.StartCuda("WriteEllpackPages");

@@ -1,105 +0,0 @@
#include <dmlc/filesystem.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <cmath>

#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>

#include "xgboost/c_api.h"

#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/hist_util.h"

#include "../helpers.h"

namespace xgboost {
namespace common {

void TestDeviceSketch(bool use_external_memory) {
  // create the data
  int nrows = 10001;
  std::shared_ptr<xgboost::DMatrix> *dmat = nullptr;

  size_t num_cols = 1;
  dmlc::TemporaryDirectory tmpdir;
  std::string file = tmpdir.path + "/big.libsvm";
  if (use_external_memory) {
    auto sp_dmat = CreateSparsePageDMatrix(nrows * 3, 128UL, file);  // 3 entries/row
    dmat = new std::shared_ptr<xgboost::DMatrix>(std::move(sp_dmat));
    num_cols = 5;
  } else {
    std::vector<float> test_data(nrows);
    auto count_iter = thrust::make_counting_iterator(0);
    // fill in reverse order
    std::copy(count_iter, count_iter + nrows, test_data.rbegin());

    // create the DMatrix
    DMatrixHandle dmat_handle;
    XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1,
                           &dmat_handle);
    dmat = static_cast<std::shared_ptr<xgboost::DMatrix> *>(dmat_handle);
  }

  int device{0};
  int max_bin{20};
  int gpu_batch_nrows{0};

  // find quantiles on the CPU
  HistogramCuts hmat_cpu;
  hmat_cpu.Build((*dmat).get(), max_bin);

  // find the cuts on the GPU
  HistogramCuts hmat_gpu;
  size_t row_stride = DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &hmat_gpu);

  // compare the row stride with the one obtained from the dmatrix
  bst_row_t expected_row_stride = 0;
  for (const auto &batch : dmat->get()->GetBatches<xgboost::SparsePage>()) {
    const auto &offset_vec = batch.offset.ConstHostVector();
    for (int i = 1; i <= offset_vec.size() - 1; ++i) {
      expected_row_stride = std::max(expected_row_stride, offset_vec[i] - offset_vec[i - 1]);
    }
  }

  ASSERT_EQ(expected_row_stride, row_stride);

  // compare the cuts
  double eps = 1e-2;
  ASSERT_EQ(hmat_gpu.MinValues().size(), num_cols);
  ASSERT_EQ(hmat_gpu.Ptrs().size(), num_cols + 1);
  ASSERT_EQ(hmat_gpu.Values().size(), hmat_cpu.Values().size());
  ASSERT_LT(fabs(hmat_cpu.MinValues()[0] - hmat_gpu.MinValues()[0]), eps * nrows);
  for (int i = 0; i < hmat_gpu.Values().size(); ++i) {
    ASSERT_LT(fabs(hmat_cpu.Values()[i] - hmat_gpu.Values()[i]), eps * nrows);
  }

  // Deterministic
  size_t constexpr kRounds { 100 };
  for (size_t r = 0; r < kRounds; ++r) {
    HistogramCuts new_sketch;
    DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &new_sketch);
    ASSERT_EQ(hmat_gpu.Values().size(), new_sketch.Values().size());
    for (size_t i = 0; i < hmat_gpu.Values().size(); ++i) {
      ASSERT_EQ(hmat_gpu.Values()[i], new_sketch.Values()[i]);
    }
    for (size_t i = 0; i < hmat_gpu.MinValues().size(); ++i) {
      ASSERT_EQ(hmat_gpu.MinValues()[i], new_sketch.MinValues()[i]);
    }
  }

  delete dmat;
}

TEST(gpu_hist_util, DeviceSketch) {
  TestDeviceSketch(false);
}

TEST(gpu_hist_util, DeviceSketch_ExternalMemory) {
  TestDeviceSketch(true);
}

}  // namespace common
}  // namespace xgboost
@@ -261,7 +261,25 @@ TEST(hist_util, DenseCutsAccuracyTest) {
      HistogramCuts cuts;
      DenseCuts dense(&cuts);
      dense.Build(dmat.get(), num_bins);
      ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}

TEST(hist_util, DenseCutsAccuracyTestWeights) {
  int bin_sizes[] = {2, 16, 256, 512};
  int sizes[] = {100, 1000, 1500};
  int num_columns = 5;
  for (auto num_rows : sizes) {
    auto x = GenerateRandom(num_rows, num_columns);
    auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
    auto w = GenerateRandomWeights(num_rows);
    dmat->Info().weights_.HostVector() = w;
    for (auto num_bins : bin_sizes) {
      HistogramCuts cuts;
      DenseCuts dense(&cuts);
      dense.Build(dmat.get(), num_bins);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}
@@ -279,7 +297,7 @@ TEST(hist_util, DenseCutsExternalMemory) {
      HistogramCuts cuts;
      DenseCuts dense(&cuts);
      dense.Build(dmat.get(), num_bins);
      ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}
@@ -295,7 +313,7 @@ TEST(hist_util, SparseCutsAccuracyTest) {
      HistogramCuts cuts;
      SparseCuts sparse(&cuts);
      sparse.Build(dmat.get(), num_bins);
      ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}
@@ -335,7 +353,7 @@ TEST(hist_util, SparseCutsExternalMemory) {
      HistogramCuts cuts;
      SparseCuts dense(&cuts);
      dense.Build(dmat.get(), num_bins);
      ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}

tests/cpp/common/test_hist_util.cu (new file, 212 lines)
@@ -0,0 +1,212 @@
#include <dmlc/filesystem.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <cmath>

#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>

#include "xgboost/c_api.h"

#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/hist_util.h"

#include "../helpers.h"
#include <xgboost/data.h>
#include "../../../src/data/device_adapter.cuh"
#include "../data/test_array_interface.h"
#include "../../../src/common/math.h"
#include "../../../src/data/simple_dmatrix.h"
#include "test_hist_util.h"

namespace xgboost {
namespace common {

template <typename AdapterT>
HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
  HistogramCuts cuts;
  DenseCuts builder(&cuts);
  data::SimpleDMatrix dmat(adapter, missing, 1);
  builder.Build(&dmat, num_bins);
  return cuts;
}

TEST(hist_util, DeviceSketch) {
  int num_rows = 5;
  int num_columns = 1;
  int num_bins = 4;
  std::vector<float> x = {1.0, 2.0, 3.0, 4.0, 5.0};
  auto dmat = GetDMatrixFromData(x, num_rows, num_columns);

  auto device_cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
  HistogramCuts host_cuts;
  DenseCuts builder(&host_cuts);
  builder.Build(dmat.get(), num_bins);

  EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
  EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
  EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
}

TEST(hist_util, DeviceSketchDeterminism) {
  int num_rows = 500;
  int num_columns = 5;
  int num_bins = 256;
  auto x = GenerateRandom(num_rows, num_columns);
  auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
  auto reference_sketch = DeviceSketch(0, dmat.get(), num_bins);
  size_t constexpr kRounds{ 100 };
  for (size_t r = 0; r < kRounds; ++r) {
    auto new_sketch = DeviceSketch(0, dmat.get(), num_bins);
    ASSERT_EQ(reference_sketch.Values(), new_sketch.Values());
    ASSERT_EQ(reference_sketch.MinValues(), new_sketch.MinValues());
  }
}

TEST(hist_util, DeviceSketchCategorical) {
  int categorical_sizes[] = {2, 6, 8, 12};
  int num_bins = 256;
  int sizes[] = {25, 100, 1000};
  for (auto n : sizes) {
    for (auto num_categories : categorical_sizes) {
      auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
      auto dmat = GetDMatrixFromData(x, n, 1);
      auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}
TEST(hist_util, DeviceSketchMultipleColumns) {
  int bin_sizes[] = {2, 16, 256, 512};
  int sizes[] = {100, 1000, 1500};
  int num_columns = 5;
  for (auto num_rows : sizes) {
    auto x = GenerateRandom(num_rows, num_columns);
    auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
    for (auto num_bins : bin_sizes) {
      auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}

TEST(hist_util, DeviceSketchMultipleColumnsWeights) {
  int bin_sizes[] = {2, 16, 256, 512};
  int sizes[] = {100, 1000, 1500};
  int num_columns = 5;
  for (auto num_rows : sizes) {
    auto x = GenerateRandom(num_rows, num_columns);
    auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
    dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows);
    for (auto num_bins : bin_sizes) {
      auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}

TEST(hist_util, DeviceSketchBatches) {
  int num_bins = 256;
  int num_rows = 5000;
  int batch_sizes[] = {0, 100, 1500, 6000};
  int num_columns = 5;
  for (auto batch_size : batch_sizes) {
    auto x = GenerateRandom(num_rows, num_columns);
    auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
    auto cuts = DeviceSketch(0, dmat.get(), num_bins, batch_size);
    ValidateCuts(cuts, dmat.get(), num_bins);
  }
}

TEST(hist_util, DeviceSketchMultipleColumnsExternal) {
  int bin_sizes[] = {2, 16, 256, 512};
  int sizes[] = {100, 1000, 1500};
  int num_columns = 5;
  for (auto num_rows : sizes) {
    auto x = GenerateRandom(num_rows, num_columns);
    dmlc::TemporaryDirectory temp;
    auto dmat =
        GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 100, temp);
    for (auto num_bins : bin_sizes) {
      auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}

TEST(hist_util, AdapterDeviceSketch) {
  int rows = 5;
  int cols = 1;
  int num_bins = 4;
  float missing = -1.0;
  thrust::device_vector<float> data(rows * cols);
  auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
  data = std::vector<float>{ 1.0, 2.0, 3.0, 4.0, 5.0 };
  std::stringstream ss;
  Json::Dump(json_array_interface, &ss);
  std::string str = ss.str();
  data::CupyAdapter adapter(str);

  auto device_cuts = AdapterDeviceSketch(&adapter, num_bins, missing);
  auto host_cuts = GetHostCuts(&adapter, num_bins, missing);

  EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
  EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
  EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
}
TEST(hist_util, AdapterDeviceSketchCategorical) {
  int categorical_sizes[] = {2, 6, 8, 12};
  int num_bins = 256;
  int sizes[] = {25, 100, 1000};
  for (auto n : sizes) {
    for (auto num_categories : categorical_sizes) {
      auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
      auto dmat = GetDMatrixFromData(x, n, 1);
      auto x_device = thrust::device_vector<float>(x);
      auto adapter = AdapterFromData(x_device, n, 1);
      auto cuts = AdapterDeviceSketch(&adapter, num_bins,
                                      std::numeric_limits<float>::quiet_NaN());
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}

TEST(hist_util, AdapterDeviceSketchMultipleColumns) {
  int bin_sizes[] = {2, 16, 256, 512};
  int sizes[] = {100, 1000, 1500};
  int num_columns = 5;
  for (auto num_rows : sizes) {
    auto x = GenerateRandom(num_rows, num_columns);
    auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
    auto x_device = thrust::device_vector<float>(x);
    for (auto num_bins : bin_sizes) {
      auto adapter = AdapterFromData(x_device, num_rows, num_columns);
      auto cuts = AdapterDeviceSketch(&adapter, num_bins,
                                      std::numeric_limits<float>::quiet_NaN());
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
}

TEST(hist_util, AdapterDeviceSketchBatches) {
  int num_bins = 256;
  int num_rows = 5000;
  int batch_sizes[] = {0, 100, 1500, 6000};
  int num_columns = 5;
  for (auto batch_size : batch_sizes) {
    auto x = GenerateRandom(num_rows, num_columns);
    auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
    auto x_device = thrust::device_vector<float>(x);
    auto adapter = AdapterFromData(x_device, num_rows, num_columns);
    auto cuts = AdapterDeviceSketch(&adapter, num_bins,
                                    std::numeric_limits<float>::quiet_NaN(),
                                    batch_size);
    ValidateCuts(cuts, dmat.get(), num_bins);
  }
}
}  // namespace common
}  // namespace xgboost
@@ -28,6 +28,34 @@ inline std::vector<float> GenerateRandom(int num_rows, int num_columns) {
  return x;
}

inline std::vector<float> GenerateRandomWeights(int num_rows) {
  std::vector<float> w(num_rows);
  std::mt19937 rng(1);
  std::uniform_real_distribution<float> dist(0.0, 1.0);
  std::generate(w.begin(), w.end(), [&]() { return dist(rng); });
  return w;
}

#ifdef __CUDACC__
inline data::CupyAdapter AdapterFromData(const thrust::device_vector<float> &x,
                                         int num_rows, int num_columns) {
  Json array_interface{Object()};
  std::vector<Json> shape = {Json(static_cast<Integer::Int>(num_rows)),
                             Json(static_cast<Integer::Int>(num_columns))};
  array_interface["shape"] = Array(shape);
  std::vector<Json> j_data{
      Json(Integer(reinterpret_cast<Integer::Int>(x.data().get()))),
      Json(Boolean(false))};
  array_interface["data"] = j_data;
  array_interface["version"] = Integer(static_cast<Integer::Int>(1));
  array_interface["typestr"] = String("<f4");
  std::stringstream ss;
  Json::Dump(array_interface, &ss);
  std::string str = ss.str();
  return data::CupyAdapter(str);
}
#endif
|
||||
inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n,
|
||||
int num_categories) {
|
||||
std::vector<float> x(n);
|
||||
@ -69,21 +97,22 @@ inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(
|
||||
|
||||
// Test that elements are approximately equally distributed among bins
|
||||
inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx,
|
||||
const std::vector<float>& column,
|
||||
const std::vector<float>& sorted_column,const std::vector<float >&sorted_weights,
|
||||
int num_bins) {
|
||||
std::map<int, int> counts;
|
||||
for (auto& v : column) {
|
||||
counts[cuts.SearchBin(v, column_idx)]++;
|
||||
std::map<int, int> bin_weights;
|
||||
for (auto i = 0ull; i < sorted_column.size(); i++) {
|
||||
bin_weights[cuts.SearchBin(sorted_column[i], column_idx)] += sorted_weights[i];
|
||||
}
|
||||
int local_num_bins = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
|
||||
int expected_num_elements = column.size() / local_num_bins;
|
||||
// Allow about 30% deviation. This test is not very strict, it only ensures
|
||||
auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(),0);
|
||||
int expected_bin_weight = total_weight / local_num_bins;
|
||||
// Allow up to 30% deviation. This test is not very strict, it only ensures
|
||||
// roughly equal distribution
|
||||
int allowable_error = std::max(2, int(expected_num_elements * 0.3));
|
||||
int allowable_error = std::max(2, int(expected_bin_weight * 0.3));
|
||||

  // The first and last bins can have a smaller count/weight.
  for (auto& kv : counts) {
    EXPECT_LE(std::abs(counts[kv.first] - expected_num_elements),
  for (auto& kv : bin_weights) {
    EXPECT_LE(std::abs(bin_weights[kv.first] - expected_bin_weight),
              allowable_error);
  }
}

@ -91,26 +120,29 @@ inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx,

// Test sketch quantiles against the real quantiles
// Not a very strict test
inline void TestRank(const std::vector<float>& cuts,
                     const std::vector<float>& sorted_x) {
  float eps = 0.05;
                     const std::vector<float>& sorted_x,
                     const std::vector<float>& sorted_weights) {
  double eps = 0.05;
  auto total_weight =
      std::accumulate(sorted_weights.begin(), sorted_weights.end(), 0.0);
  // Ignore the last cut; it's special
  double sum_weight = 0.0;
  size_t j = 0;
  for (auto i = 0; i < cuts.size() - 1; i++) {
    int expected_rank = ((i + 1) * sorted_x.size()) / cuts.size();
    while (cuts[i] > sorted_x[j]) {
      sum_weight += sorted_weights[j];
      j++;
    }
    int actual_rank = j;
    int acceptable_error = std::max(2, int(sorted_x.size() * eps));
    ASSERT_LE(std::abs(expected_rank - actual_rank), acceptable_error);
    double expected_rank = ((i + 1) * total_weight) / cuts.size();
    double acceptable_error = std::max(2.0, total_weight * eps);
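    // Example with illustrative numbers: five cuts and total_weight = 100
    // give expected_rank = 20, 40, 60, 80 for i = 0..3, each allowed to
    // differ from the weight accumulated below the cut by
    // max(2.0, 100 * 0.05) = 5.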
    ASSERT_LE(std::abs(expected_rank - sum_weight), acceptable_error);
  }
}

inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
                           const std::vector<float>& column,
                           const std::vector<float>& sorted_column,
                           const std::vector<float>& sorted_weights,
                           int num_bins) {
  std::vector<float> sorted_column(column);
  std::sort(sorted_column.begin(), sorted_column.end());

  // Check the endpoints are correct
  EXPECT_LT(cuts.MinValues()[column_idx], sorted_column.front());
@ -126,40 +158,60 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
  EXPECT_EQ(std::set<float>(cuts_begin, cuts_end).size(),
            cuts_end - cuts_begin);

  if (sorted_column.size() <= num_bins) {
  auto unique = std::set<float>(sorted_column.begin(), sorted_column.end());
  if (unique.size() <= num_bins) {
    // Fewer unique values than bins:
    // each value should get its own bin

    // First check the inputs are unique
    int num_unique =
        std::set<float>(sorted_column.begin(), sorted_column.end()).size();
    EXPECT_EQ(num_unique, sorted_column.size());
    for (auto i = 0ull; i < sorted_column.size(); i++) {
      ASSERT_EQ(cuts.SearchBin(sorted_column[i], column_idx),
                cuts.Ptrs()[column_idx] + i);
    int i = 0;
    for (auto v : unique) {
      ASSERT_EQ(cuts.SearchBin(v, column_idx), cuts.Ptrs()[column_idx] + i);
      i++;
    }
  }
  int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
  std::vector<float> column_cuts(num_cuts_column);
  std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx],
            cuts.Values().begin() + cuts.Ptrs()[column_idx + 1],
            column_cuts.begin());
  TestBinDistribution(cuts, column_idx, sorted_column, num_bins);
  TestRank(column_cuts, sorted_column);
  else {
    int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
    std::vector<float> column_cuts(num_cuts_column);
    std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx],
              cuts.Values().begin() + cuts.Ptrs()[column_idx + 1],
              column_cuts.begin());
    TestBinDistribution(cuts, column_idx, sorted_column, sorted_weights, num_bins);
    TestRank(column_cuts, sorted_column, sorted_weights);
  }
}

// x is dense and row major
inline void ValidateCuts(const HistogramCuts& cuts, std::vector<float>& x,
                         int num_rows, int num_columns,
inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
                         int num_bins) {
  for (auto i = 0; i < num_columns; i++) {
    // Extract the column
    std::vector<float> column(num_rows);
    for (auto j = 0; j < num_rows; j++) {
      column[j] = x[j * num_columns + i];
    }
    ValidateColumn(cuts, i, column, num_bins);
  }
  // Collect the data into columns
  std::vector<std::vector<float>> columns(dmat->Info().num_col_);
  for (auto& batch : dmat->GetBatches<SparsePage>()) {
    for (auto i = 0ull; i < batch.Size(); i++) {
      for (auto e : batch[i]) {
        columns[e.index].push_back(e.fvalue);
      }
    }
  }
  // Sort each column, keeping its weights aligned with the values
  for (auto i = 0ull; i < columns.size(); i++) {
    auto& col = columns.at(i);
    const auto& w = dmat->Info().weights_.HostVector();
    std::vector<size_t> index(col.size());
    std::iota(index.begin(), index.end(), 0);
    std::sort(index.begin(), index.end(),
              [=](size_t a, size_t b) { return col[a] < col[b]; });
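
    // Gather values and their weights through the sorted index (an argsort);
    // a DMatrix without explicit weights falls back to the unit weights
    // initialised below.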
    std::vector<float> sorted_column(col.size());
    std::vector<float> sorted_weights(col.size(), 1.0);
    for (auto j = 0ull; j < col.size(); j++) {
      sorted_column[j] = col[index[j]];
      if (w.size() == col.size()) {
        sorted_weights[j] = w[index[j]];
      }
    }

    ValidateColumn(cuts, i, sorted_column, sorted_weights, num_bins);
  }
}

} // namespace common

@ -158,28 +158,6 @@ TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
    EXPECT_EQ(impl_ext->matrix.base_rowid, current_row);
    current_row += impl_ext->matrix.n_rows;
  }

  current_row = 0;
  thrust::device_vector<bst_float> row_d(kCols);
  thrust::device_vector<bst_float> row_ext_d(kCols);
  std::vector<bst_float> row(kCols);
  std::vector<bst_float> row_ext(kCols);
  for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
    auto impl_ext = page.Impl();
    EXPECT_EQ(impl_ext->matrix.base_rowid, current_row);

    for (size_t i = 0; i < impl_ext->Size(); i++) {
      dh::LaunchN(0, kCols,
                  ReadRowFunction(impl->matrix, current_row, row_d.data().get()));
      thrust::copy(row_d.begin(), row_d.end(), row.begin());

      dh::LaunchN(0, kCols,
                  ReadRowFunction(impl_ext->matrix, current_row, row_ext_d.data().get()));
      thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin());

      EXPECT_EQ(row, row_ext) << "for row " << current_row;

      current_row++;
    }
  }
}

} // namespace xgboost

@ -284,7 +284,6 @@ void TestHistogramIndexImpl() {
  ASSERT_EQ(maker->page->matrix.info.n_bins, maker_ext->page->matrix.info.n_bins);
  ASSERT_EQ(maker->page->gidx_buffer.size(), maker_ext->page->gidx_buffer.size());

  ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
}

TEST(GpuHist, TestHistogramIndex) {