Unify CPU hist sketching (#5880)

Jiaming Yuan 2020-08-12 01:33:06 +08:00 committed by GitHub
parent bd6b7f4aa7
commit ee70a2380b
18 changed files with 648 additions and 677 deletions

View File

@ -70,6 +70,7 @@
#include "../src/common/common.cc"
#include "../src/common/charconv.cc"
#include "../src/common/timer.cc"
#include "../src/common/quantile.cc"
#include "../src/common/host_device_vector.cc"
#include "../src/common/hist_util.cc"
#include "../src/common/json.cc"

View File

@ -239,6 +239,21 @@ struct BatchParam {
  }
};
struct HostSparsePageView {
using Inst = common::Span<Entry const>;
common::Span<bst_row_t const> offset;
common::Span<Entry const> data;
Inst operator[](size_t i) const {
auto size = *(offset.data() + i + 1) - *(offset.data() + i);
return {data.data() + *(offset.data() + i),
static_cast<Inst::index_type>(size)};
}
size_t Size() const { return offset.size() == 0 ? 0 : offset.size() - 1; }
};
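A minimal consumption sketch for the new view (the `page` object below is an assumption; `GetView()` is the accessor added to `SparsePage` just after this struct):
  // Assumed usage: take the host view once, then read rows from plain spans
  // without touching the underlying HostDeviceVector again.
  HostSparsePageView view = page.GetView();
  for (size_t i = 0; i < view.Size(); ++i) {
    auto inst = view[i];               // common::Span<Entry const> for row i
    for (auto const& entry : inst) {
      // entry.index is the column index, entry.fvalue the stored value.
    }
  }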
/*!
 * \brief In-memory storage unit of sparse batch, stored in CSR format.
 */
@ -270,6 +285,11 @@ class SparsePage {
static_cast<Inst::index_type>(size)};
}
HostSparsePageView GetView() const {
return {offset.ConstHostSpan(), data.ConstHostSpan()};
}
/*! \brief constructor */
SparsePage() {
this->Clear();

View File

@ -113,346 +113,12 @@ void GHistIndexMatrix::ResizeIndex(const size_t rbegin, const SparsePage& batch,
}
HistogramCuts::HistogramCuts() {
  monitor_.Init(__FUNCTION__);
  cut_ptrs_.HostVector().emplace_back(0);
}
// Dispatch to specific builder.
void HistogramCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) {
auto const& info = dmat->Info();
size_t const total = info.num_row_ * info.num_col_;
size_t const nnz = info.num_nonzero_;
float const sparsity = static_cast<float>(nnz) / static_cast<float>(total);
// Use a small number to avoid calling `dmat->GetColumnBatches'.
float constexpr kSparsityThreshold = 0.0005;
// FIXME(trivialfis): Distributed environment is not supported.
if (sparsity < kSparsityThreshold && (!rabit::IsDistributed())) {
LOG(INFO) << "Building quantile cut on a sparse dataset.";
SparseCuts cuts(this);
cuts.Build(dmat, max_num_bins);
} else {
LOG(INFO) << "Building quantile cut on a dense dataset or distributed environment.";
DenseCuts cuts(this);
cuts.Build(dmat, max_num_bins);
}
LOG(INFO) << "Total number of hist bins: " << cut_ptrs_.HostVector().back();
}
bool CutsBuilder::UseGroup(DMatrix* dmat) {
auto& info = dmat->Info();
return CutsBuilder::UseGroup(info);
}
bool CutsBuilder::UseGroup(MetaInfo const& info) {
size_t const num_groups = info.group_ptr_.size() == 0 ?
0 : info.group_ptr_.size() - 1;
// Use group index for weights?
bool const use_group_ind = num_groups != 0 &&
(info.weights_.Size() != info.num_row_);
return use_group_ind;
}
void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
uint32_t max_num_bins,
bool const use_group_ind,
uint32_t beg_col, uint32_t end_col,
uint32_t thread_id) {
CHECK_GE(end_col, beg_col);
// Data groups, used in ranking.
std::vector<bst_uint> const& group_ptr = info.group_ptr_;
auto &local_min_vals = p_cuts_->min_vals_.HostVector();
auto &local_cuts = p_cuts_->cut_values_.HostVector();
auto &local_ptrs = p_cuts_->cut_ptrs_.HostVector();
local_min_vals.resize(end_col - beg_col, 0);
for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) {
// Using a local variable makes things easier, but at the cost of memory thrashing.
WQSketch sketch;
common::Span<xgboost::Entry const> const column = page[col_id];
uint32_t const n_bins = std::min(static_cast<uint32_t>(column.size()),
max_num_bins);
if (n_bins == 0) {
// cut_ptrs_ is initialized with a zero, so there's always an element at the back
CHECK_GE(local_ptrs.size(), 1);
local_ptrs.emplace_back(local_ptrs.back());
continue;
}
sketch.Init(info.num_row_, 1.0 / (n_bins * WQSketch::kFactor));
for (auto const& entry : column) {
uint32_t weight_ind = 0;
if (use_group_ind) {
auto row_idx = entry.index;
uint32_t group_ind =
this->SearchGroupIndFromRow(group_ptr, page.base_rowid + row_idx);
weight_ind = group_ind;
} else {
weight_ind = entry.index;
}
sketch.Push(entry.fvalue, info.GetWeight(weight_ind));
}
WQSketch::SummaryContainer out_summary;
sketch.GetSummary(&out_summary);
WQSketch::SummaryContainer summary;
summary.Reserve(n_bins + 1);
summary.SetPrune(out_summary, n_bins + 1);
// Could we use data[1] as the min values so that we don't need to
// store another array?
float mval = summary.data[0].value;
local_min_vals[col_id - beg_col] = mval - (fabs(mval) + 1e-5);
this->AddCutPoint(summary, max_num_bins);
bst_float cpt = (summary.size > 0) ?
summary.data[summary.size - 1].value :
local_min_vals[col_id - beg_col];
cpt += fabs(cpt) + 1e-5;
local_cuts.emplace_back(cpt);
local_ptrs.emplace_back(local_cuts.size());
}
}
std::vector<size_t> SparseCuts::LoadBalance(SparsePage const& page,
size_t const nthreads) {
/* Some sparse datasets have their mass concentrated on a small
 * number of features.  To avoid waiting for a few threads to run
 * forever, we distribute different numbers of columns to
 * different threads according to the number of entries. */
size_t const total_entries = page.data.Size();
size_t const entries_per_thread = common::DivRoundUp(total_entries, nthreads);
std::vector<size_t> cols_ptr(nthreads+1, 0);
size_t count {0};
size_t current_thread {1};
for (size_t col_id = 0; col_id < page.Size(); ++col_id) {
auto const column = page[col_id];
cols_ptr[current_thread]++; // add one column to thread
count += column.size();
if (count > entries_per_thread + 1) {
current_thread++;
count = 0;
cols_ptr[current_thread] = cols_ptr[current_thread-1];
}
}
// Idle threads.
for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
cols_ptr[current_thread+1] = cols_ptr[current_thread];
}
return cols_ptr;
}
void SparseCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) {
monitor_.Start(__FUNCTION__);
// Use group index for weights?
auto use_group = UseGroup(dmat);
uint32_t nthreads = omp_get_max_threads();
CHECK_GT(nthreads, 0);
std::vector<HistogramCuts> cuts_containers(nthreads);
std::vector<std::unique_ptr<SparseCuts>> sparse_cuts(nthreads);
for (size_t i = 0; i < nthreads; ++i) {
sparse_cuts[i].reset(new SparseCuts(&cuts_containers[i]));
}
for (auto const& page : dmat->GetBatches<CSCPage>()) {
CHECK_LE(page.Size(), dmat->Info().num_col_);
monitor_.Start("Load balance");
std::vector<size_t> col_ptr = LoadBalance(page, nthreads);
monitor_.Stop("Load balance");
// We decouple the build logic from the parallelization here
// to simplify things a bit.
#pragma omp parallel for num_threads(nthreads) schedule(static)
for (omp_ulong i = 0; i < nthreads; ++i) {
common::Monitor t_monitor;
t_monitor.Init("SingleThreadBuild: " + std::to_string(i));
t_monitor.Start(std::to_string(i));
sparse_cuts[i]->SingleThreadBuild(page, dmat->Info(), max_num_bins, use_group,
col_ptr[i], col_ptr[i+1], i);
t_monitor.Stop(std::to_string(i));
}
this->Concat(sparse_cuts, dmat->Info().num_col_);
}
monitor_.Stop(__FUNCTION__);
}
void SparseCuts::Concat(
std::vector<std::unique_ptr<SparseCuts>> const& cuts, uint32_t n_cols) {
monitor_.Start(__FUNCTION__);
uint32_t nthreads = omp_get_max_threads();
auto &local_min_vals = p_cuts_->min_vals_.HostVector();
auto &local_cuts = p_cuts_->cut_values_.HostVector();
auto &local_ptrs = p_cuts_->cut_ptrs_.HostVector();
local_min_vals.resize(n_cols, std::numeric_limits<float>::max());
size_t min_vals_tail = 0;
for (uint32_t t = 0; t < nthreads; ++t) {
auto& thread_min_vals = cuts[t]->p_cuts_->min_vals_.HostVector();
auto& thread_cuts = cuts[t]->p_cuts_->cut_values_.HostVector();
auto& thread_ptrs = cuts[t]->p_cuts_->cut_ptrs_.HostVector();
// concat csc pointers.
size_t const old_ptr_size = local_ptrs.size();
local_ptrs.resize(
thread_ptrs.size() + local_ptrs.size() - 1);
size_t const new_icp_size = local_ptrs.size();
auto tail = local_ptrs[old_ptr_size-1];
for (size_t j = old_ptr_size; j < new_icp_size; ++j) {
local_ptrs[j] = tail + thread_ptrs[j-old_ptr_size+1];
}
// concat csc values
size_t const old_iv_size = local_cuts.size();
local_cuts.resize(
thread_cuts.size() + local_cuts.size());
size_t const new_iv_size = local_cuts.size();
for (size_t j = old_iv_size; j < new_iv_size; ++j) {
local_cuts[j] = thread_cuts[j-old_iv_size];
}
// merge min values
for (size_t j = 0; j < thread_min_vals.size(); ++j) {
local_min_vals.at(min_vals_tail + j) =
std::min(local_min_vals.at(min_vals_tail + j), thread_min_vals.at(j));
}
min_vals_tail += thread_min_vals.size();
}
monitor_.Stop(__FUNCTION__);
}
void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
monitor_.Start(__FUNCTION__);
const MetaInfo& info = p_fmat->Info();
// safe factor for better accuracy
std::vector<WQSketch> sketchs;
const int nthread = omp_get_max_threads();
unsigned const nstep =
static_cast<unsigned>((info.num_col_ + nthread - 1) / nthread);
unsigned const ncol = static_cast<unsigned>(info.num_col_);
sketchs.resize(info.num_col_);
for (auto& s : sketchs) {
s.Init(info.num_row_, 1.0 / (max_num_bins * WQSketch::kFactor));
}
// Data groups, used in ranking.
std::vector<bst_uint> const& group_ptr = info.group_ptr_;
size_t const num_groups = group_ptr.size() == 0 ? 0 : group_ptr.size() - 1;
// Use group index for weights?
bool const use_group = UseGroup(p_fmat);
const bool isDense = p_fmat->IsDense();
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
size_t group_ind = 0;
if (use_group) {
group_ind = this->SearchGroupIndFromRow(group_ptr, batch.base_rowid);
}
#pragma omp parallel num_threads(nthread) firstprivate(group_ind, use_group)
{
CHECK_EQ(nthread, omp_get_num_threads());
auto tid = static_cast<unsigned>(omp_get_thread_num());
unsigned begin = std::min(nstep * tid, ncol);
unsigned end = std::min(nstep * (tid + 1), ncol);
// do not iterate if no columns are assigned to the thread
if (begin < end && end <= ncol) {
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
size_t const ridx = batch.base_rowid + i;
SparsePage::Inst const inst = batch[i];
if (use_group &&
group_ptr[group_ind] == ridx &&
// maximum equals to weights.size() - 1
group_ind < num_groups - 1) {
// move to next group
group_ind++;
}
size_t w_idx = use_group ? group_ind : ridx;
auto w = info.GetWeight(w_idx);
if (isDense) {
auto data = inst.data();
for (size_t ii = begin; ii < end; ii++) {
sketchs[ii].Push(data[ii].fvalue, w);
}
} else {
for (auto const& entry : inst) {
if (entry.index >= begin && entry.index < end) {
sketchs[entry.index].Push(entry.fvalue, w);
}
}
}
}
}
}
}
Init(&sketchs, max_num_bins, info.num_row_);
monitor_.Stop(__FUNCTION__);
}
/**
* \param [in,out] in_sketchs
* \param max_num_bins The maximum number bins.
* \param max_rows Number of rows in this DMatrix.
*/
void DenseCuts::Init
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins, size_t max_rows) {
monitor_.Start(__func__);
std::vector<WQSketch>& sketchs = *in_sketchs;
// Compute how many cuts samples we need at each node
// Do not require more than the number of total rows in training data
// This allows efficient training on wide data
size_t global_max_rows = max_rows;
rabit::Allreduce<rabit::op::Sum>(&global_max_rows, 1);
size_t intermediate_num_cuts =
std::min(global_max_rows, static_cast<size_t>(max_num_bins * WQSketch::kFactor));
// gather the histogram data
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
std::vector<WQSketch::SummaryContainer> summary_array;
summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) {
WQSketch::SummaryContainer out;
sketchs[i].GetSummary(&out);
summary_array[i].Reserve(intermediate_num_cuts);
summary_array[i].SetPrune(out, intermediate_num_cuts);
}
CHECK_EQ(summary_array.size(), in_sketchs->size());
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
// TODO(chenqin): rabit failure recovery assumes no bootstrap one-time call after loadcheckpoint;
// we need to move this allreduce before the loadcheckpoint call in the future.
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
p_cuts_->min_vals_.HostVector().resize(sketchs.size());
for (size_t fid = 0; fid < summary_array.size(); ++fid) {
WQSketch::SummaryContainer a;
a.Reserve(max_num_bins + 1);
a.SetPrune(summary_array[fid], max_num_bins + 1);
const bst_float mval = a.data[0].value;
p_cuts_->min_vals_.HostVector()[fid] = mval - (fabs(mval) + 1e-5);
AddCutPoint(a, max_num_bins);
// push a value that is greater than anything
const bst_float cpt
= (a.size > 0) ? a.data[a.size - 1].value : p_cuts_->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5);
p_cuts_->cut_values_.HostVector().push_back(last);
// Ensure that every feature gets at least one quantile point
CHECK_LE(p_cuts_->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
auto cut_size = static_cast<uint32_t>(p_cuts_->cut_values_.HostVector().size());
CHECK_GT(cut_size, p_cuts_->cut_ptrs_.HostVector().back());
p_cuts_->cut_ptrs_.HostVector().push_back(cut_size);
}
monitor_.Stop(__func__);
}
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) {
  cut.Build(p_fmat, max_bins);
  cut = SketchOnDMatrix(p_fmat, max_bins);
  max_num_bins = max_bins;
  const int32_t nthread = omp_get_max_threads();
  const uint32_t nbins = cut.Ptrs().back();
@ -1049,11 +715,10 @@ void BuildHistKernel(const std::vector<GradientPair>& gpair,
}
template <typename GradientSumT>
void GHistBuilder<GradientSumT>::BuildHist(
    const std::vector<GradientPair> &gpair,
    const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
    GHistRowT hist, bool isDense) {
  const size_t nrows = row_indices.Size();
  const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);

View File

@ -313,7 +313,6 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
                               device, num_cuts_per_feature, has_weights);
  HistogramCuts cuts;
  DenseCuts dense_cuts(&cuts);
  SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
                                   dmat->Info().num_row_, device);
@ -324,7 +323,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
  for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) {
    size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
    if (has_weights) {
      bool is_ranking = CutsBuilder::UseGroup(dmat);
      bool is_ranking = HostSketchContainer::UseGroup(dmat->Info());
      dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
                                                 info.group_ptr_.cend());
      ProcessWeightedBatch(

View File

@ -306,7 +306,7 @@ void AdapterDeviceSketch(Batch batch, int num_bins,
    size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
    ProcessWeightedSlidingWindow(batch, info,
                                 num_cuts_per_feature,
                                 CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end,
                                 HostSketchContainer::UseGroup(info), missing, device, num_cols, begin, end,
                                 sketch_container);
    }
  } else {

View File

@ -17,6 +17,7 @@
#include <map>
#include "row_set.h"
#include "common.h"
#include "threading_utils.h"
#include "../tree/param.h"
#include "./quantile.h"
@ -34,15 +35,8 @@ using GHistIndexRow = Span<uint32_t const>;
// A CSC matrix representing histogram cuts, used in CPU quantile hist.
// The cut values represent upper bounds of bins containing approximately equal numbers of elements
class HistogramCuts {
// Using friends to avoid creating a virtual class, since HistogramCuts is used as value
// object in many places.
friend class SparseCuts;
friend class DenseCuts;
friend class CutsBuilder;
protected:
using BinIdx = uint32_t;
common::Monitor monitor_;
public:
HostDeviceVector<bst_float> cut_values_;  // NOLINT
@ -75,16 +69,12 @@ class HistogramCuts {
}
HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) {
monitor_ = std::move(that.monitor_);
cut_ptrs_ = std::move(that.cut_ptrs_);
cut_values_ = std::move(that.cut_values_);
min_vals_ = std::move(that.min_vals_);
return *this;
}
/* \brief Build histogram cuts. */
void Build(DMatrix* dmat, uint32_t const max_num_bins);
/* \brief How many bins a feature has. */
uint32_t FeatureBins(uint32_t feature) const {
  return cut_ptrs_.ConstHostVector().at(feature + 1) -
         cut_ptrs_.ConstHostVector()[feature];
@ -118,86 +108,42 @@ class HistogramCuts {
}
};
/* \brief An interface for building quantile cuts.
 *
 * `DenseCuts' always assumes there are `max_bins` for each feature, which makes it not
 * suitable for sparse dataset.  On the other hand `SparseCuts' uses `GetColumnBatches',
 * which doubles the memory usage, hence can not be applied to dense dataset.
 */
class CutsBuilder {
 public:
  using WQSketch = common::WQuantileSketch<bst_float, bst_float>;
  /* \brief return whether group for ranking is used. */
  static bool UseGroup(DMatrix* dmat);
  static bool UseGroup(MetaInfo const& info);

 protected:
  HistogramCuts* p_cuts_;

 public:
  explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {}
  virtual ~CutsBuilder() = default;

  static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
                                        size_t const base_rowid) {
    CHECK_LT(base_rowid, group_ptr.back())
        << "Row: " << base_rowid << " is not found in any group.";
    auto it =
        std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
    bst_group_t group_ind = it - group_ptr.cbegin() - 1;
    return group_ind;
  }

  void AddCutPoint(WQSketch::SummaryContainer const& summary, int max_bin) {
    size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
    for (size_t i = 1; i < required_cuts; ++i) {
      bst_float cpt = summary.data[i].value;
      if (i == 1 || cpt > p_cuts_->cut_values_.ConstHostVector().back()) {
        p_cuts_->cut_values_.HostVector().push_back(cpt);
      }
    }
  }

  /* \brief Build histogram indices. */
  virtual void Build(DMatrix* dmat, uint32_t const max_num_bins) = 0;
};

/*! \brief Cut configuration for sparse dataset. */
class SparseCuts : public CutsBuilder {
  /* \brief Distribute columns to each thread according to number of entries. */
  static std::vector<size_t> LoadBalance(SparsePage const& page, size_t const nthreads);
  Monitor monitor_;

 public:
  explicit SparseCuts(HistogramCuts* container) :
      CutsBuilder(container) {
    monitor_.Init(__FUNCTION__);
  }

  /* \brief Concatonate the built cuts in each thread. */
  void Concat(std::vector<std::unique_ptr<SparseCuts>> const& cuts, uint32_t n_cols);
  /* \brief Build histogram indices in single thread. */
  void SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
                         uint32_t max_num_bins,
                         bool const use_group_ind,
                         uint32_t beg, uint32_t end, uint32_t thread_id);
  void Build(DMatrix* dmat, uint32_t const max_num_bins) override;
};

/*! \brief Cut configuration for dense dataset. */
class DenseCuts : public CutsBuilder {
 protected:
  Monitor monitor_;

 public:
  explicit DenseCuts(HistogramCuts* container) :
      CutsBuilder(container) {
    monitor_.Init(__FUNCTION__);
  }
  void Init(std::vector<WQSketch>* sketchs, uint32_t max_num_bins, size_t max_rows);
  void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
};

inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins) {
  HistogramCuts out;
  auto const& info = m->Info();
  const auto threads = omp_get_max_threads();
  std::vector<std::vector<bst_row_t>> column_sizes(threads);
  for (auto& column : column_sizes) {
    column.resize(info.num_col_, 0);
  }
  for (auto const& page : m->GetBatches<SparsePage>()) {
    page.data.HostVector();
    page.offset.HostVector();
    ParallelFor(page.Size(), threads, [&](size_t i) {
      auto &local_column_sizes = column_sizes.at(omp_get_thread_num());
      auto row = page[i];
      auto const *p_row = row.data();
      for (size_t j = 0; j < row.size(); ++j) {
        local_column_sizes.at(p_row[j].index)++;
      }
    });
  }
  std::vector<bst_row_t> reduced(info.num_col_, 0);
  ParallelFor(info.num_col_, threads, [&](size_t i) {
    for (auto const &thread : column_sizes) {
      reduced[i] += thread[i];
    }
  });

  HostSketchContainer container(reduced, max_bins,
                                HostSketchContainer::UseGroup(info));
  for (auto const &page : m->GetBatches<SparsePage>()) {
    container.PushRowPage(page, info);
  }
  container.MakeCuts(&out);
  return out;
}
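A usage sketch for the new entry point (assuming a prepared `DMatrix* p_fmat`; this mirrors the call added to `GHistIndexMatrix::Init` earlier in this diff):
  // Assumed caller: build all per-feature cut points in one call.
  common::HistogramCuts cuts = common::SketchOnDMatrix(p_fmat, /*max_bins=*/256);
  // For feature f, cuts.Values()[cuts.Ptrs()[f] .. cuts.Ptrs()[f + 1]) holds the
  // bin upper bounds, and cuts.MinValues()[f] lies strictly below the feature minimum.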
enum BinTypeSize {
  kUint8BinsTypeSize = 1,

src/common/quantile.cc (new file, 193 lines)
View File

@ -0,0 +1,193 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#include <limits>
#include <utility>
#include "quantile.h"
#include "hist_util.h"
namespace xgboost {
namespace common {
HostSketchContainer::HostSketchContainer(std::vector<bst_row_t> columns_size,
int32_t max_bins, bool use_group)
: columns_size_{std::move(columns_size)}, max_bins_{max_bins},
use_group_ind_{use_group} {
monitor_.Init(__func__);
CHECK_NE(columns_size_.size(), 0);
sketches_.resize(columns_size_.size());
for (size_t i = 0; i < sketches_.size(); ++i) {
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
n_bins = std::max(n_bins, static_cast<decltype(n_bins)>(1));
auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
sketches_[i].Init(columns_size_[i], eps);
sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
}
}
std::vector<bst_feature_t> LoadBalance(SparsePage const &page,
std::vector<size_t> columns_size,
size_t const nthreads) {
/* Some sparse datasets have their mass concentrated on a small
 * number of features.  To avoid waiting for a few threads to run
 * forever, we distribute different numbers of columns to
 * different threads according to the number of entries. */
size_t const total_entries = page.data.Size();
size_t const entries_per_thread = common::DivRoundUp(total_entries, nthreads);
std::vector<bst_feature_t> cols_ptr(nthreads+1, 0);
size_t count {0};
size_t current_thread {1};
for (auto col : columns_size) {
cols_ptr[current_thread]++; // add one column to thread
count += col;
if (count > entries_per_thread + 1) {
current_thread++;
count = 0;
cols_ptr[current_thread] = cols_ptr[current_thread-1];
}
}
// Idle threads.
for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
cols_ptr[current_thread+1] = cols_ptr[current_thread];
}
return cols_ptr;
}
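A toy illustration of the partitioning above, restated as a standalone helper (the entry counts and thread count are assumptions, not values from this diff): with per-column entry counts {8, 1, 1, 1, 1} and two threads, entries_per_thread is 6, so the heavy first column goes to thread 0 alone and the remaining four columns go to thread 1.
#include <cstddef>
#include <vector>

// Minimal restatement of the column partitioning loop, for illustration only.
std::vector<size_t> ToyLoadBalance(std::vector<size_t> const& columns_size,
                                   size_t nthreads) {
  size_t total_entries = 0;
  for (auto c : columns_size) { total_entries += c; }
  size_t entries_per_thread = (total_entries + nthreads - 1) / nthreads;
  std::vector<size_t> cols_ptr(nthreads + 1, 0);
  size_t count{0};
  size_t current_thread{1};
  for (auto col : columns_size) {
    cols_ptr[current_thread]++;  // assign one more column to the current thread
    count += col;
    if (count > entries_per_thread + 1) {
      current_thread++;
      count = 0;
      cols_ptr[current_thread] = cols_ptr[current_thread - 1];
    }
  }
  for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
    cols_ptr[current_thread + 1] = cols_ptr[current_thread];
  }
  return cols_ptr;  // e.g. {8, 1, 1, 1, 1} with 2 threads -> {0, 1, 5}
}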
void HostSketchContainer::PushRowPage(SparsePage const &page,
MetaInfo const &info) {
monitor_.Start(__func__);
int nthread = omp_get_max_threads();
CHECK_EQ(sketches_.size(), info.num_col_);
// Data groups, used in ranking.
std::vector<bst_uint> const &group_ptr = info.group_ptr_;
// Use group index for weights?
auto batch = page.GetView();
dmlc::OMPException exec;
// Parallel over columns.  Assuming the data is dense, each thread owns a set of
// consecutive columns.
auto const ncol = static_cast<uint32_t>(info.num_col_);
auto const is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
auto thread_columns_ptr = LoadBalance(page, columns_size_, nthread);
#pragma omp parallel num_threads(nthread)
{
exec.Run([&]() {
auto tid = static_cast<uint32_t>(omp_get_thread_num());
auto const begin = thread_columns_ptr[tid];
auto const end = thread_columns_ptr[tid + 1];
size_t group_ind = 0;
// do not iterate if no columns are assigned to the thread
if (begin < end && end <= ncol) {
for (size_t i = 0; i < batch.Size(); ++i) {
size_t const ridx = page.base_rowid + i;
SparsePage::Inst const inst = batch[i];
if (use_group_ind_) {
group_ind = this->SearchGroupIndFromRow(group_ptr, i + page.base_rowid);
}
size_t w_idx = use_group_ind_ ? group_ind : ridx;
auto w = info.GetWeight(w_idx);
auto p_inst = inst.data();
if (is_dense) {
for (size_t ii = begin; ii < end; ii++) {
sketches_[ii].Push(p_inst[ii].fvalue, w);
}
} else {
for (size_t i = 0; i < inst.size(); ++i) {
auto const& entry = p_inst[i];
if (entry.index >= begin && entry.index < end) {
sketches_[entry.index].Push(entry.fvalue, w);
}
}
}
}
}
});
}
exec.Rethrow();
monitor_.Stop(__func__);
}
void AddCutPoint(WQuantileSketch<float, float>::SummaryContainer const &summary,
int max_bin, HistogramCuts *cuts) {
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
auto& cut_values = cuts->cut_values_.HostVector();
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cuts->cut_values_.ConstHostVector().back()) {
cut_values.push_back(cpt);
}
}
}
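For illustration, with assumed values: if the pruned summary holds {0.1, 0.5, 0.5, 0.9} and max_bin is 4, the loop starts at i = 1 and drops non-increasing entries, so only {0.5, 0.9} are appended here; the column minimum (0.1) and the final upper bound above 0.9 are handled separately in MakeCuts below.
// Assumed toy input:  summary values = {0.1, 0.5, 0.5, 0.9}, max_bin = 4
// required_cuts = 4; i = 1 pushes 0.5, i = 2 skips the duplicate 0.5,
// i = 3 pushes 0.9  =>  cut_values gains {0.5, 0.9}.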
void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
monitor_.Start(__func__);
rabit::Allreduce<rabit::op::Sum>(columns_size_.data(), columns_size_.size());
std::vector<WQSketch::SummaryContainer> reduced(sketches_.size());
std::vector<int32_t> num_cuts;
size_t nbytes = 0;
for (size_t i = 0; i < sketches_.size(); ++i) {
int32_t intermediate_num_cuts = static_cast<int32_t>(std::min(
columns_size_[i], static_cast<size_t>(max_bins_ * WQSketch::kFactor)));
if (columns_size_[i] != 0) {
WQSketch::SummaryContainer out;
sketches_[i].GetSummary(&out);
reduced[i].Reserve(intermediate_num_cuts);
CHECK(reduced[i].data);
reduced[i].SetPrune(out, intermediate_num_cuts);
}
num_cuts.push_back(intermediate_num_cuts);
nbytes = std::max(
WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts), nbytes);
}
if (rabit::IsDistributed()) {
// FIXME(trivialfis): This call will allocate nbytes * num_columns on rabit, which
// may generate oom error when data is sparse. To fix it, we need to:
// - gather the column offsets over all workers.
// - run rabit::allgather on sketch data to collect all data.
// - merge all gathered sketches based on worker offsets and column offsets of data
// from each worker.
// See GPU implementation for details.
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
sreducer.Allreduce(dmlc::BeginPtr(reduced), nbytes, reduced.size());
}
cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
for (size_t fid = 0; fid < reduced.size(); ++fid) {
WQSketch::SummaryContainer a;
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
a.Reserve(max_num_bins + 1);
CHECK(a.data);
if (columns_size_[fid] != 0) {
a.SetPrune(reduced[fid], max_num_bins + 1);
CHECK(a.data && reduced[fid].data);
const bst_float mval = a.data[0].value;
cuts->min_vals_.HostVector()[fid] = mval - fabs(mval) - 1e-5f;
} else {
// Empty column.
const float mval = 1e-5f;
cuts->min_vals_.HostVector()[fid] = mval;
}
AddCutPoint(a, max_num_bins, cuts);
// push a value that is greater than anything
const bst_float cpt
= (a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
cuts->cut_values_.HostVector().push_back(last);
// Ensure that every feature gets at least one quantile point
CHECK_LE(cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
auto cut_size = static_cast<uint32_t>(cuts->cut_values_.HostVector().size());
CHECK_GT(cut_size, cuts->cut_ptrs_.HostVector().back());
cuts->cut_ptrs_.HostVector().push_back(cut_size);
}
monitor_.Stop(__func__);
}
} // namespace common
} // namespace xgboost

View File

@ -20,7 +20,7 @@
namespace xgboost {
namespace common {
using WQSketch = DenseCuts::WQSketch;
using WQSketch = HostSketchContainer::WQSketch;
using SketchEntry = WQSketch::Entry;
// Algorithm 4 in XGBoost's paper, using binary search to find i.

View File

@ -9,12 +9,15 @@
#include <dmlc/base.h>
#include <xgboost/logging.h>
#include <xgboost/data.h>
#include <cmath>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
#include "timer.h"
namespace xgboost {
namespace common {
/*!
@ -682,6 +685,57 @@ template<typename DType, typename RType = unsigned>
class WXQuantileSketch :
    public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {
};
class HistogramCuts;
/*!
* A sketch matrix storing sketches for each feature.
*/
class HostSketchContainer {
public:
using WQSketch = WQuantileSketch<float, float>;
private:
std::vector<WQSketch> sketches_;
std::vector<bst_row_t> columns_size_;
int32_t max_bins_;
bool use_group_ind_{false};
Monitor monitor_;
public:
/* \brief Initialize necessary info.
*
* \param columns_size Size of each column.
* \param max_bins maximum number of bins for each feature.
* \param use_group Whether the data instances are assigned to query groups (used for ranking weights).
*/
HostSketchContainer(std::vector<bst_row_t> columns_size, int32_t max_bins,
bool use_group);
static bool UseGroup(MetaInfo const &info) {
size_t const num_groups =
info.group_ptr_.size() == 0 ? 0 : info.group_ptr_.size() - 1;
// Use group index for weights?
bool const use_group_ind =
num_groups != 0 && (info.weights_.Size() != info.num_row_);
return use_group_ind;
}
static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
size_t const base_rowid) {
CHECK_LT(base_rowid, group_ptr.back())
<< "Row: " << base_rowid << " is not found in any group.";
bst_group_t group_ind =
std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid) -
group_ptr.cbegin() - 1;
return group_ind;
}
/* \brief Push a CSR matrix. */
void PushRowPage(SparsePage const& page, MetaInfo const& info);
void MakeCuts(HistogramCuts* cuts);
};
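A minimal driver sketch (assuming `column_sizes`, `max_bins`, and a `DMatrix* m` prepared by the caller, as `SketchOnDMatrix` in hist_util.h does):
  // Assumed call sequence, mirroring SketchOnDMatrix.
  HostSketchContainer container(column_sizes, max_bins,
                                HostSketchContainer::UseGroup(m->Info()));
  for (auto const& page : m->GetBatches<SparsePage>()) {
    container.PushRowPage(page, m->Info());  // feed every CSR page
  }
  HistogramCuts cuts;
  container.MakeCuts(&cuts);                 // prune the sketches into final cut points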
}  // namespace common
}  // namespace xgboost
#endif  // XGBOOST_COMMON_QUANTILE_H_

View File

@ -6,9 +6,9 @@
#ifndef XGBOOST_COMMON_THREADING_UTILS_H_
#define XGBOOST_COMMON_THREADING_UTILS_H_
#include <dmlc/common.h>
#include <vector>
#include <algorithm>
#include "xgboost/logging.h"
namespace xgboost {
@ -115,17 +115,32 @@ void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
nthreads = std::min(nthreads, omp_get_max_threads());
nthreads = std::max(nthreads, 1);
dmlc::OMPException omp_exc;
#pragma omp parallel num_threads(nthreads)
{
omp_exc.Run([&]() {
size_t tid = omp_get_thread_num();
size_t chunck_size =
    num_blocks_in_space / nthreads + !!(num_blocks_in_space % nthreads);
size_t begin = chunck_size * tid;
size_t end = std::min(begin + chunck_size, num_blocks_in_space);
for (auto i = begin; i < end; i++) {
  func(space.GetFirstDimension(i), space.GetRange(i));
}
});
}
omp_exc.Rethrow();
}
template <typename Func>
void ParallelFor(size_t size, size_t nthreads, Func fn) {
dmlc::OMPException omp_exc;
#pragma omp parallel for num_threads(nthreads)
for (omp_ulong i = 0; i < size; ++i) {
omp_exc.Run(fn, i);
}
omp_exc.Rethrow();
}
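A usage sketch for the new helper (the vector and thread count are assumptions): any exception thrown inside the lambda is captured by dmlc::OMPException and rethrown on the calling thread once the loop finishes.
  // Assumed example: square each element in parallel.
  std::vector<double> values(1024, 2.0);
  common::ParallelFor(values.size(), omp_get_max_threads(), [&](size_t i) {
    values[i] *= values[i];
  });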
}  // namespace common

View File

@ -44,18 +44,16 @@ bst_float PredValue(const SparsePage::Inst &inst,
template <size_t kUnrollLen = 8>
struct SparsePageView {
SparsePage const* page;
bst_row_t base_rowid;
HostSparsePageView view;
static size_t constexpr kUnroll = kUnrollLen;
explicit SparsePageView(SparsePage const *p)
    : page{p}, base_rowid{page->base_rowid} {
  // Pull to host before entering omp block, as this is not thread safe.
  page->data.HostVector();
  page->offset.HostVector();
}
explicit SparsePageView(SparsePage const *p)
    : base_rowid{p->base_rowid} {
  view = p->GetView();
}
SparsePage::Inst operator[](size_t i) { return (*page)[i]; }
SparsePage::Inst operator[](size_t i) { return view[i]; }
size_t Size() const { return page->Size(); }
size_t Size() const { return view.Size(); }
};
template <typename Adapter, size_t kUnrollLen = 8> template <typename Adapter, size_t kUnrollLen = 8>

View File

@ -158,86 +158,20 @@ TEST(CutsBuilder, SearchGroupInd) {
  HistogramCuts hmat;
  size_t group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 0);
  size_t group_ind = HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 0);
  ASSERT_EQ(group_ind, 0);
  group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5);
  group_ind = HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5);
  ASSERT_EQ(group_ind, 2);
  EXPECT_ANY_THROW(HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17));
  p_mat->Info().Validate(-1);
  EXPECT_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17),
  EXPECT_THROW(HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17),
               dmlc::Error);
  std::vector<bst_uint> group_ptr {0, 1, 2};
  CHECK_EQ(CutsBuilder::SearchGroupIndFromRow(group_ptr, 1), 1);
  CHECK_EQ(HostSketchContainer::SearchGroupIndFromRow(group_ptr, 1), 1);
}
TEST(SparseCuts, SingleThreadedBuild) {
size_t constexpr kRows = 267;
size_t constexpr kCols = 31;
size_t constexpr kBins = 256;
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
common::GHistIndexMatrix hmat;
hmat.Init(p_fmat.get(), kBins);
HistogramCuts cuts;
SparseCuts indices(&cuts);
auto const& page = *(p_fmat->GetBatches<xgboost::CSCPage>().begin());
indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
ASSERT_EQ(hmat.cut.Ptrs().size(), cuts.Ptrs().size());
ASSERT_EQ(hmat.cut.Ptrs(), cuts.Ptrs());
ASSERT_EQ(hmat.cut.Values(), cuts.Values());
ASSERT_EQ(hmat.cut.MinValues(), cuts.MinValues());
}
TEST(SparseCuts, MultiThreadedBuild) {
size_t constexpr kRows = 17;
size_t constexpr kCols = 15;
size_t constexpr kBins = 255;
omp_ulong ori_nthreads = omp_get_max_threads();
omp_set_num_threads(16);
auto Compare =
#if defined(_MSC_VER) // msvc fails to capture
[kBins](DMatrix* p_fmat) {
#else
[](DMatrix* p_fmat) {
#endif
HistogramCuts threaded_container;
SparseCuts threaded_indices(&threaded_container);
threaded_indices.Build(p_fmat, kBins);
HistogramCuts container;
SparseCuts indices(&container);
auto const& page = *(p_fmat->GetBatches<xgboost::CSCPage>().begin());
indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
ASSERT_EQ(container.Ptrs().size(), threaded_container.Ptrs().size());
ASSERT_EQ(container.Values().size(), threaded_container.Values().size());
for (uint32_t i = 0; i < container.Ptrs().size(); ++i) {
ASSERT_EQ(container.Ptrs()[i], threaded_container.Ptrs()[i]);
}
for (uint32_t i = 0; i < container.Values().size(); ++i) {
ASSERT_EQ(container.Values()[i], threaded_container.Values()[i]);
}
};
{
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
Compare(p_fmat.get());
}
{
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.0001).GenerateDMatrix();
Compare(p_fmat.get());
}
omp_set_num_threads(ori_nthreads);
} }
TEST(HistUtil, DenseCutsCategorical) {
@ -250,9 +184,7 @@ TEST(HistUtil, DenseCutsCategorical) {
    std::vector<float> x_sorted(x);
    std::sort(x_sorted.begin(), x_sorted.end());
    auto dmat = GetDMatrixFromData(x, n, 1);
    HistogramCuts cuts;
    DenseCuts dense(&cuts);
    dense.Build(dmat.get(), num_bins);
    HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins);
    auto cuts_from_sketch = cuts.Values();
    EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
    EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
@ -264,15 +196,14 @@ TEST(HistUtil, DenseCutsCategorical) {
TEST(HistUtil, DenseCutsAccuracyTest) {
  int bin_sizes[] = {2, 16, 256, 512};
  int sizes[] = {100, 1000, 1500};
  int sizes[] = {100};
  // omp_set_num_threads(1);
  int num_columns = 5;
  for (auto num_rows : sizes) {
    auto x = GenerateRandom(num_rows, num_columns);
    auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
    for (auto num_bins : bin_sizes) {
      HistogramCuts cuts;
      DenseCuts dense(&cuts);
      dense.Build(dmat.get(), num_bins);
      HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
@ -288,9 +219,7 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
    auto w = GenerateRandomWeights(num_rows);
    dmat->Info().weights_.HostVector() = w;
    for (auto num_bins : bin_sizes) {
      HistogramCuts cuts;
      DenseCuts dense(&cuts);
      dense.Build(dmat.get(), num_bins);
      HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins);
      ValidateCuts(cuts, dmat.get(), num_bins);
    }
  }
@ -306,65 +235,7 @@ TEST(HistUtil, DenseCutsExternalMemory) {
    auto dmat =
        GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir);
    for (auto num_bins : bin_sizes) {
      HistogramCuts cuts;
      DenseCuts dense(&cuts);
      dense.Build(dmat.get(), num_bins);
      HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
TEST(HistUtil, SparseCutsAccuracyTest) {
int bin_sizes[] = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500};
int num_columns = 5;
for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
for (auto num_bins : bin_sizes) {
HistogramCuts cuts;
SparseCuts sparse(&cuts);
sparse.Build(dmat.get(), num_bins);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
}
TEST(HistUtil, SparseCutsCategorical) {
int categorical_sizes[] = {2, 6, 8, 12};
int num_bins = 256;
int sizes[] = {25, 100, 1000};
for (auto n : sizes) {
for (auto num_categories : categorical_sizes) {
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
std::vector<float> x_sorted(x);
std::sort(x_sorted.begin(), x_sorted.end());
auto dmat = GetDMatrixFromData(x, n, 1);
HistogramCuts cuts;
SparseCuts sparse(&cuts);
sparse.Build(dmat.get(), num_bins);
auto cuts_from_sketch = cuts.Values();
EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
EXPECT_GE(cuts_from_sketch.back(), x_sorted.back());
EXPECT_EQ(cuts_from_sketch.size(), num_categories);
}
}
}
TEST(HistUtil, SparseCutsExternalMemory) {
int bin_sizes[] = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500};
int num_columns = 5;
for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns);
dmlc::TemporaryDirectory tmpdir;
auto dmat =
GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir);
for (auto num_bins : bin_sizes) {
HistogramCuts cuts;
SparseCuts dense(&cuts);
dense.Build(dmat.get(), num_bins);
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
@ -391,25 +262,6 @@ TEST(HistUtil, IndexBinBound) {
}
}
TEST(HistUtil, SparseIndexBinBound) {
uint64_t bin_sizes[] = { static_cast<uint64_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
BinTypeSize expected_bin_type_sizes[] = { kUint32BinsTypeSize,
kUint32BinsTypeSize,
kUint32BinsTypeSize };
size_t constexpr kRows = 100;
size_t constexpr kCols = 10;
size_t bin_id = 0;
for (auto max_bin : bin_sizes) {
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatrix();
common::GHistIndexMatrix hmat;
hmat.Init(p_fmat.get(), max_bin);
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
}
}
template <typename T>
void CheckIndexData(T* data_ptr, uint32_t* offsets,
                    const common::GHistIndexMatrix& hmat, size_t n_cols) {
@ -448,25 +300,61 @@ TEST(HistUtil, IndexBinData) {
}
}
TEST(HistUtil, SparseIndexBinData) {
  uint64_t bin_sizes[] = { static_cast<uint64_t>(std::numeric_limits<uint8_t>::max()) + 1,
                           static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
                           static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
  size_t constexpr kRows = 100;
  size_t constexpr kCols = 10;
  for (auto max_bin : bin_sizes) {
    auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatrix();
    common::GHistIndexMatrix hmat;
    hmat.Init(p_fmat.get(), max_bin);
    EXPECT_EQ(hmat.index.Offset(), nullptr);

    uint32_t* data_ptr = hmat.index.data<uint32_t>();
    for (size_t i = 0; i < hmat.index.Size(); ++i) {
      EXPECT_EQ(data_ptr[i], hmat.index[i]);
    }
  }
}

void TestSketchFromWeights(bool with_group) {
  size_t constexpr kRows = 300, kCols = 20, kBins = 256;
  size_t constexpr kGroups = 10;
  auto m =
      RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix();
  common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins);

  MetaInfo info;
  auto& h_weights = info.weights_.HostVector();
  if (with_group) {
    h_weights.resize(kGroups);
  } else {
    h_weights.resize(kRows);
  }
  std::fill(h_weights.begin(), h_weights.end(), 1.0f);

  std::vector<bst_group_t> groups(kGroups);
  if (with_group) {
    for (size_t i = 0; i < kGroups; ++i) {
      groups[i] = kRows / kGroups;
    }
    info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
  }

  info.num_row_ = kRows;
  info.num_col_ = kCols;

  // Assign weights.
  if (with_group) {
    m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
  }
  m->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
  m->Info().num_col_ = kCols;
  m->Info().num_row_ = kRows;

  ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
  ValidateCuts(cuts, m.get(), kBins);

  if (with_group) {
    HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins);
    for (size_t i = 0; i < cuts.Values().size(); ++i) {
      EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
    }
    for (size_t i = 0; i < cuts.MinValues().size(); ++i) {
      ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]);
    }
    for (size_t i = 0; i < cuts.Ptrs().size(); ++i) {
      ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
    }
  }
}
TEST(HistUtil, SketchFromWeights) {
TestSketchFromWeights(true);
TestSketchFromWeights(false);
}
}  // namespace common
}  // namespace xgboost

View File

@ -24,10 +24,8 @@ namespace common {
template <typename AdapterT>
HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
  HistogramCuts cuts;
  DenseCuts builder(&cuts);
  data::SimpleDMatrix dmat(adapter, missing, 1);
  builder.Build(&dmat, num_bins);
  HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins);
  return cuts;
}
@ -39,9 +37,7 @@ TEST(HistUtil, DeviceSketch) {
  auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
  auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
  HistogramCuts host_cuts;
  DenseCuts builder(&host_cuts);
  builder.Build(dmat.get(), num_bins);
  HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins);
  EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
  EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
@ -460,7 +456,11 @@ void TestAdapterSketchFromWeights(bool with_group) {
                              &storage);
  MetaInfo info;
  auto& h_weights = info.weights_.HostVector();
  if (with_group) {
    h_weights.resize(kGroups);
  } else {
    h_weights.resize(kRows);
  }
  std::fill(h_weights.begin(), h_weights.end(), 1.0f);
  std::vector<bst_group_t> groups(kGroups);

View File

@ -0,0 +1,77 @@
#include <gtest/gtest.h>
#include "test_quantile.h"
#include "../../../src/common/quantile.h"
#include "../../../src/common/hist_util.h"
namespace xgboost {
namespace common {
TEST(Quantile, SameOnAllWorkers) {
std::string msg{"Skipping Quantile AllreduceBasic test"};
size_t constexpr kWorkers = 4;
InitRabitContext(msg, kWorkers);
auto world = rabit::GetWorldSize();
if (world != 1) {
CHECK_EQ(world, kWorkers);
} else {
return;
}
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) {
auto rank = rabit::GetRank();
HostDeviceVector<float> storage;
auto m = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(rank + seed)
.GenerateDMatrix();
auto cuts = SketchOnDMatrix(m.get(), n_bins);
std::vector<float> cut_values(cuts.Values().size() * world, 0);
std::vector<
typename std::remove_reference_t<decltype(cuts.Ptrs())>::value_type>
cut_ptrs(cuts.Ptrs().size() * world, 0);
std::vector<float> cut_min_values(cuts.MinValues().size() * world, 0);
size_t value_size = cuts.Values().size();
rabit::Allreduce<rabit::op::Max>(&value_size, 1);
size_t ptr_size = cuts.Ptrs().size();
rabit::Allreduce<rabit::op::Max>(&ptr_size, 1);
CHECK_EQ(ptr_size, kCols + 1);
size_t min_value_size = cuts.MinValues().size();
rabit::Allreduce<rabit::op::Max>(&min_value_size, 1);
CHECK_EQ(min_value_size, kCols);
size_t value_offset = value_size * rank;
std::copy(cuts.Values().begin(), cuts.Values().end(),
cut_values.begin() + value_offset);
size_t ptr_offset = ptr_size * rank;
std::copy(cuts.Ptrs().cbegin(), cuts.Ptrs().cend(),
cut_ptrs.begin() + ptr_offset);
size_t min_values_offset = min_value_size * rank;
std::copy(cuts.MinValues().cbegin(), cuts.MinValues().cend(),
cut_min_values.begin() + min_values_offset);
rabit::Allreduce<rabit::op::Sum>(cut_values.data(), cut_values.size());
rabit::Allreduce<rabit::op::Sum>(cut_ptrs.data(), cut_ptrs.size());
rabit::Allreduce<rabit::op::Sum>(cut_min_values.data(), cut_min_values.size());
for (int32_t i = 0; i < world; i++) {
for (size_t j = 0; j < value_size; ++j) {
size_t idx = i * value_size + j;
ASSERT_NEAR(cuts.Values().at(j), cut_values.at(idx), kRtEps);
}
for (size_t j = 0; j < ptr_size; ++j) {
size_t idx = i * ptr_size + j;
ASSERT_EQ(cuts.Ptrs().at(j), cut_ptrs.at(idx));
}
for (size_t j = 0; j < min_value_size; ++j) {
size_t idx = i * min_value_size + j;
ASSERT_EQ(cuts.MinValues().at(j), cut_min_values.at(idx));
}
}
});
}
} // namespace common
} // namespace xgboost

View File

@ -1,4 +1,5 @@
#include <gtest/gtest.h>
#include "test_quantile.h"
#include "../helpers.h"
#include "../../../src/common/hist_util.cuh"
#include "../../../src/common/quantile.cuh"
@ -16,32 +17,6 @@ TEST(GPUQuantile, Basic) {
  ASSERT_EQ(sketch.Data().size(), 0);
}
template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
std::vector<int32_t> seeds(4);
SimpleLCG lcg;
SimpleRealUniformDistribution<float> dist(3, 1000);
std::generate(seeds.begin(), seeds.end(), [&](){ return dist(&lcg); });
std::vector<size_t> bins(8);
for (size_t i = 0; i < bins.size() - 1; ++i) {
bins[i] = i * 35 + 2;
}
bins.back() = rows + 80; // provide a bin number greater than rows.
std::vector<MetaInfo> infos(2);
auto& h_weights = infos.front().weights_.HostVector();
h_weights.resize(rows);
std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
for (auto seed : seeds) {
for (auto n_bin : bins) {
for (auto const& info : infos) {
fn(seed, n_bin, info);
}
}
}
}
void TestSketchUnique(float sparsity) {
  constexpr size_t kRows = 1000, kCols = 100;
  RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](int32_t seed, size_t n_bins, MetaInfo const& info) {
@ -297,31 +272,12 @@ TEST(GPUQuantile, MergeDuplicated) {
  }
}
void InitRabitContext(std::string msg) {
auto n_gpus = AllVisibleGPUs();
auto port = std::getenv("DMLC_TRACKER_PORT");
std::string port_str;
if (port) {
port_str = port;
} else {
LOG(WARNING) << msg << " as `DMLC_TRACKER_PORT` is not set up.";
return;
}
std::vector<std::string> envs{
"DMLC_TRACKER_PORT=" + port_str,
"DMLC_TRACKER_URI=127.0.0.1",
"DMLC_NUM_WORKER=" + std::to_string(n_gpus)};
char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])};
rabit::Init(3, c_envs);
}
TEST(GPUQuantile, AllReduceBasic) {
  // This test is supposed to be run by a python test that sets up the environment.
  std::string msg {"Skipping AllReduce test"};
#if defined(__linux__) && defined(XGBOOST_USE_NCCL)
  InitRabitContext(msg);
  auto n_gpus = AllVisibleGPUs();
  InitRabitContext(msg, n_gpus);
  auto world = rabit::GetWorldSize();
  if (world != 1) {
    ASSERT_EQ(world, n_gpus);
@ -407,9 +363,9 @@ TEST(GPUQuantile, AllReduceBasic) {
TEST(GPUQuantile, SameOnAllWorkers) {
  std::string msg {"Skipping SameOnAllWorkers test"};
#if defined(__linux__) && defined(XGBOOST_USE_NCCL)
  InitRabitContext(msg);
  auto world = rabit::GetWorldSize();
  auto n_gpus = AllVisibleGPUs();
  InitRabitContext(msg, n_gpus);
  auto world = rabit::GetWorldSize();
  if (world != 1) {
    ASSERT_EQ(world, n_gpus);
  } else {

View File

@ -0,0 +1,54 @@
#include <rabit/rabit.h>
#include <algorithm>
#include <string>
#include <vector>
#include "../helpers.h"
namespace xgboost {
namespace common {
inline void InitRabitContext(std::string msg, size_t n_workers) {
auto port = std::getenv("DMLC_TRACKER_PORT");
std::string port_str;
if (port) {
port_str = port;
} else {
LOG(WARNING) << msg << " as `DMLC_TRACKER_PORT` is not set up.";
return;
}
std::vector<std::string> envs{
"DMLC_TRACKER_PORT=" + port_str,
"DMLC_TRACKER_URI=127.0.0.1",
"DMLC_NUM_WORKER=" + std::to_string(n_workers)};
char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])};
rabit::Init(3, c_envs);
}
template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
std::vector<int32_t> seeds(4);
SimpleLCG lcg;
SimpleRealUniformDistribution<float> dist(3, 1000);
std::generate(seeds.begin(), seeds.end(), [&](){ return dist(&lcg); });
std::vector<size_t> bins(8);
for (size_t i = 0; i < bins.size() - 1; ++i) {
bins[i] = i * 35 + 2;
}
bins.back() = rows + 80; // provide a bin number greater than rows.
std::vector<MetaInfo> infos(2);
auto& h_weights = infos.front().weights_.HostVector();
h_weights.resize(rows);
std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
for (auto seed : seeds) {
for (auto n_bin : bins) {
for (auto const& info : infos) {
fn(seed, n_bin, info);
}
}
}
}
} // namespace common
} // namespace xgboost

View File

@ -233,12 +233,14 @@ class TestDistributedGPU(unittest.TestCase):
assert ret.returncode == 0, msg
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
@pytest.mark.gtest
def test_quantile_basic(self):
self.run_quantile('AllReduceBasic')
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
@pytest.mark.gtest
def test_quantile_same_on_all_workers(self):

View File

@ -1,11 +1,16 @@
import testing as tm
import pytest
import unittest
import xgboost as xgb
import sys
import numpy as np
import json
import asyncio
from sklearn.datasets import make_classification
import os
import subprocess
from hypothesis import given, strategies, settings, note
from test_updaters import hist_parameter_strategy, exact_parameter_strategy
if sys.platform.startswith("win"):
    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@ -14,12 +19,16 @@ pytestmark = pytest.mark.skipif(**tm.no_dask())
try:
from distributed import LocalCluster, Client
from distributed.utils_test import client, loop, cluster_fixture
import dask.dataframe as dd
import dask.array as da
from xgboost.dask import DaskDMatrix
except ImportError:
LocalCluster = None
Client = None
client = None
loop = None
cluster_fixture = None
dd = None
da = None
DaskDMatrix = None
@ -461,3 +470,97 @@ def test_with_asyncio():
asyncio.run(run_dask_regressor_asyncio(address))
asyncio.run(run_dask_classifier_asyncio(address))
class TestWithDask:
def run_updater_test(self, client, params, num_rounds, dataset,
tree_method):
params['tree_method'] = tree_method
params = dataset.set_params(params)
# multi class doesn't handle an empty dataset well ("empty" here
# means at least one worker ends up with no data).
if params['objective'] == "multi:softmax":
return
# It doesn't make sense to distribute a completely
# empty dataset.
if dataset.X.shape[0] == 0:
return
chunk = 128
X = da.from_array(dataset.X,
chunks=(chunk, dataset.X.shape[1]))
y = da.from_array(dataset.y, chunks=(chunk, ))
if dataset.w is not None:
w = da.from_array(dataset.w, chunks=(chunk, ))
else:
w = None
m = xgb.dask.DaskDMatrix(
client, data=X, label=y, weight=w)
history = xgb.dask.train(client, params=params, dtrain=m,
num_boost_round=num_rounds,
evals=[(m, 'train')])['history']
note(history)
assert tm.non_increasing(history['train'][dataset.metric])
@given(params=hist_parameter_strategy,
num_rounds=strategies.integers(10, 20),
dataset=tm.dataset_strategy)
@settings(deadline=None)
def test_hist(self, params, num_rounds, dataset, client):
self.run_updater_test(client, params, num_rounds, dataset, 'hist')
@given(params=exact_parameter_strategy,
num_rounds=strategies.integers(10, 20),
dataset=tm.dataset_strategy)
@settings(deadline=None)
def test_approx(self, client, params, num_rounds, dataset):
self.run_updater_test(client, params, num_rounds, dataset, 'approx')
def run_quantile(self, name):
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows")
exe = None
for possible_path in {'./testxgboost', './build/testxgboost',
'../build/testxgboost',
'../cpu-build/testxgboost',
'../gpu-build/testxgboost'}:
if os.path.exists(possible_path):
exe = possible_path
if exe is None:
return
test = "--gtest_filter=Quantile." + name
def runit(worker_addr, rabit_args):
port = None
# setup environment for running the c++ part.
for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port = arg.decode('utf-8')
port = port.split('=')
env = os.environ.copy()
env[port[0]] = port[1]
return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
with LocalCluster(n_workers=4) as cluster:
with Client(cluster) as client:
workers = list(xgb.dask._get_client_workers(client).keys())
rabit_args = client.sync(
xgb.dask._get_rabit_args, workers, client)
futures = client.map(runit,
workers,
pure=False,
workers=workers,
rabit_args=rabit_args)
results = client.gather(futures)
for ret in results:
msg = ret.stdout.decode('utf-8')
assert msg.find('1 test from Quantile') != -1, msg
assert ret.returncode == 0, msg
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile_basic(self):
self.run_quantile('SameOnAllWorkers')