- Use std::uint64_t instead of size_t to avoid implementation-defined type. - Rename to bst_idx_t, to account for other types of indexing. - Small cleanup to the base header.
1021 lines
34 KiB
C++
1021 lines
34 KiB
C++
/**
|
|
* Copyright 2014-2024, XGBoost Contributors
|
|
* \file quantile.h
|
|
* \brief util to compute quantiles
|
|
* \author Tianqi Chen
|
|
*/
|
|
#ifndef XGBOOST_COMMON_QUANTILE_H_
|
|
#define XGBOOST_COMMON_QUANTILE_H_
|
|
|
|
#include <xgboost/data.h>
|
|
#include <xgboost/logging.h>
|
|
|
|
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <set>
#include <utility>  // std::move
#include <vector>
|
|
|
|
#include "categorical.h"
|
|
#include "common.h"
|
|
#include "error_msg.h" // GroupWeight
|
|
#include "optional_weight.h" // OptionalWeights
|
|
#include "threading_utils.h"
|
|
#include "timer.h"
|
|
|
|
namespace xgboost::common {
|
|
/*!
|
|
* \brief experimental wsummary
|
|
* \tparam DType type of data content
|
|
* \tparam RType type of rank
|
|
*/
|
|
template<typename DType, typename RType>
|
|
struct WQSummary {
|
|
/*! \brief an entry in the sketch summary */
|
|
struct Entry {
|
|
/*! \brief minimum rank */
|
|
RType rmin{};
|
|
/*! \brief maximum rank */
|
|
RType rmax{};
|
|
/*! \brief maximum weight */
|
|
RType wmin{};
|
|
/*! \brief the value of data */
|
|
DType value{};
|
|
// constructor
|
|
XGBOOST_DEVICE Entry() {} // NOLINT
|
|
// constructor
|
|
XGBOOST_DEVICE Entry(RType rmin, RType rmax, RType wmin, DType value)
|
|
: rmin(rmin), rmax(rmax), wmin(wmin), value(value) {}
|
|
/*!
|
|
* \brief debug function, check Valid
|
|
* \param eps the tolerate level for violating the relation
|
|
*/
|
|
inline void CheckValid(RType eps = 0) const {
|
|
CHECK(rmin >= 0 && rmax >= 0 && wmin >= 0) << "nonneg constraint";
|
|
CHECK(rmax- rmin - wmin > -eps) << "relation constraint: min/max";
|
|
}
|
|
/*! \return rmin estimation for v strictly bigger than value */
|
|
XGBOOST_DEVICE inline RType RMinNext() const {
|
|
return rmin + wmin;
|
|
}
|
|
/*! \return rmax estimation for v strictly smaller than value */
|
|
XGBOOST_DEVICE inline RType RMaxPrev() const {
|
|
return rmax - wmin;
|
|
}
|
|
|
|
friend std::ostream& operator<<(std::ostream& os, Entry const& e) {
|
|
os << "rmin: " << e.rmin << ", "
|
|
<< "rmax: " << e.rmax << ", "
|
|
<< "wmin: " << e.wmin << ", "
|
|
<< "value: " << e.value;
|
|
return os;
|
|
}
|
|
};
|
|
/*! \brief input data queue before entering the summary */
|
|
struct Queue {
|
|
// entry in the queue
|
|
struct QEntry {
|
|
// value of the instance
|
|
DType value;
|
|
// weight of instance
|
|
RType weight;
|
|
// default constructor
|
|
QEntry() = default;
|
|
// constructor
|
|
QEntry(DType value, RType weight)
|
|
: value(value), weight(weight) {}
|
|
// comparator on value
|
|
inline bool operator<(const QEntry &b) const {
|
|
return value < b.value;
|
|
}
|
|
};
|
|
// the input queue
|
|
std::vector<QEntry> queue;
|
|
// end of the queue
|
|
size_t qtail;
|
|
// push data to the queue
|
|
inline void Push(DType x, RType w) {
|
|
if (qtail == 0 || queue[qtail - 1].value != x) {
|
|
queue[qtail++] = QEntry(x, w);
|
|
} else {
|
|
queue[qtail - 1].weight += w;
|
|
}
|
|
}
|
|
inline void MakeSummary(WQSummary *out) {
|
|
std::sort(queue.begin(), queue.begin() + qtail);
|
|
out->size = 0;
|
|
// start update sketch
|
|
RType wsum = 0;
|
|
// construct data with unique weights
|
|
for (size_t i = 0; i < qtail;) {
|
|
size_t j = i + 1;
|
|
RType w = queue[i].weight;
|
|
while (j < qtail && queue[j].value == queue[i].value) {
|
|
w += queue[j].weight; ++j;
|
|
}
|
|
out->data[out->size++] = Entry(wsum, wsum + w, w, queue[i].value);
|
|
wsum += w; i = j;
|
|
}
|
|
}
|
|
};
|
|
/*! \brief data field */
|
|
Entry *data;
|
|
/*! \brief number of elements in the summary */
|
|
size_t size;
|
|
// constructor
|
|
WQSummary(Entry *data, size_t size)
|
|
: data(data), size(size) {}
|
|
/*!
|
|
* \return the maximum error of the Summary
|
|
*/
|
|
inline RType MaxError() const {
|
|
RType res = data[0].rmax - data[0].rmin - data[0].wmin;
|
|
for (size_t i = 1; i < size; ++i) {
|
|
res = std::max(data[i].RMaxPrev() - data[i - 1].RMinNext(), res);
|
|
res = std::max(data[i].rmax - data[i].rmin - data[i].wmin, res);
|
|
}
|
|
return res;
|
|
}
|
|
/*!
|
|
* \brief query qvalue, start from istart
|
|
* \param qvalue the value we query for
|
|
* \param istart starting position
|
|
*/
|
|
inline Entry Query(DType qvalue, size_t &istart) const { // NOLINT(*)
|
|
while (istart < size && qvalue > data[istart].value) {
|
|
++istart;
|
|
}
|
|
if (istart == size) {
|
|
RType rmax = data[size - 1].rmax;
|
|
return Entry(rmax, rmax, 0.0f, qvalue);
|
|
}
|
|
if (qvalue == data[istart].value) {
|
|
return data[istart];
|
|
} else {
|
|
if (istart == 0) {
|
|
return Entry(0.0f, 0.0f, 0.0f, qvalue);
|
|
} else {
|
|
return Entry(data[istart - 1].RMinNext(),
|
|
data[istart].RMaxPrev(),
|
|
0.0f, qvalue);
|
|
}
|
|
}
|
|
}
|
|
/*! \return maximum rank in the summary */
|
|
inline RType MaxRank() const {
|
|
return data[size - 1].rmax;
|
|
}
|
|
/*!
|
|
* \brief copy content from src
|
|
* \param src source sketch
|
|
*/
|
|
inline void CopyFrom(const WQSummary &src) {
|
|
if (!src.data) {
|
|
CHECK_EQ(src.size, 0);
|
|
size = 0;
|
|
return;
|
|
}
|
|
if (!data) {
|
|
CHECK_EQ(this->size, 0);
|
|
CHECK_EQ(src.size, 0);
|
|
return;
|
|
}
|
|
size = src.size;
|
|
std::memcpy(data, src.data, sizeof(Entry) * size);
|
|
}
|
|
inline void MakeFromSorted(const Entry* entries, size_t n) {
|
|
size = 0;
|
|
for (size_t i = 0; i < n;) {
|
|
size_t j = i + 1;
|
|
// ignore repeated values
|
|
for (; j < n && entries[j].value == entries[i].value; ++j) {}
|
|
data[size++] = Entry(entries[i].rmin, entries[i].rmax, entries[i].wmin,
|
|
entries[i].value);
|
|
i = j;
|
|
}
|
|
}
|
|
/*!
|
|
* \brief debug function, validate whether the summary
|
|
* run consistency check to check if it is a valid summary
|
|
* \param eps the tolerate error level, used when RType is floating point and
|
|
* some inconsistency could occur due to rounding error
|
|
*/
|
|
inline void CheckValid(RType eps) const {
|
|
for (size_t i = 0; i < size; ++i) {
|
|
data[i].CheckValid(eps);
|
|
if (i != 0) {
|
|
CHECK(data[i].rmin >= data[i - 1].rmin + data[i - 1].wmin) << "rmin range constraint";
|
|
CHECK(data[i].rmax >= data[i - 1].rmax + data[i].wmin) << "rmax range constraint";
|
|
}
|
|
}
|
|
}
|
|
|
|
/*!
|
|
* \brief set current summary to be pruned summary of src
|
|
* assume data field is already allocated to be at least maxsize
|
|
* \param src source summary
|
|
* \param maxsize size we can afford in the pruned sketch
|
|
*/
|
|
void SetPrune(const WQSummary &src, size_t maxsize) {
|
|
if (src.size <= maxsize) {
|
|
this->CopyFrom(src); return;
|
|
}
|
|
const RType begin = src.data[0].rmax;
|
|
const RType range = src.data[src.size - 1].rmin - src.data[0].rmax;
|
|
const size_t n = maxsize - 1;
|
|
data[0] = src.data[0];
|
|
this->size = 1;
|
|
// lastidx is used to avoid duplicated records
|
|
size_t i = 1, lastidx = 0;
|
|
for (size_t k = 1; k < n; ++k) {
|
|
RType dx2 = 2 * ((k * range) / n + begin);
|
|
// find first i such that d < (rmax[i+1] + rmin[i+1]) / 2
|
|
while (i < src.size - 1
|
|
&& dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
|
|
if (i == src.size - 1) break;
|
|
if (dx2 < src.data[i].RMinNext() + src.data[i + 1].RMaxPrev()) {
|
|
if (i != lastidx) {
|
|
data[size++] = src.data[i]; lastidx = i;
|
|
}
|
|
} else {
|
|
if (i + 1 != lastidx) {
|
|
data[size++] = src.data[i + 1]; lastidx = i + 1;
|
|
}
|
|
}
|
|
}
|
|
if (lastidx != src.size - 1) {
|
|
data[size++] = src.data[src.size - 1];
|
|
}
|
|
}
|
|
/*!
|
|
* \brief set current summary to be merged summary of sa and sb
|
|
* \param sa first input summary to be merged
|
|
* \param sb second input summary to be merged
|
|
*/
|
|
inline void SetCombine(const WQSummary &sa,
|
|
const WQSummary &sb) {
|
|
if (sa.size == 0) {
|
|
this->CopyFrom(sb); return;
|
|
}
|
|
if (sb.size == 0) {
|
|
this->CopyFrom(sa); return;
|
|
}
|
|
CHECK(sa.size > 0 && sb.size > 0);
|
|
const Entry *a = sa.data, *a_end = sa.data + sa.size;
|
|
const Entry *b = sb.data, *b_end = sb.data + sb.size;
|
|
// extended rmin value
|
|
RType aprev_rmin = 0, bprev_rmin = 0;
|
|
Entry *dst = this->data;
|
|
while (a != a_end && b != b_end) {
|
|
// duplicated value entry
|
|
if (a->value == b->value) {
|
|
*dst = Entry(a->rmin + b->rmin,
|
|
a->rmax + b->rmax,
|
|
a->wmin + b->wmin, a->value);
|
|
aprev_rmin = a->RMinNext();
|
|
bprev_rmin = b->RMinNext();
|
|
++dst; ++a; ++b;
|
|
} else if (a->value < b->value) {
|
|
*dst = Entry(a->rmin + bprev_rmin,
|
|
a->rmax + b->RMaxPrev(),
|
|
a->wmin, a->value);
|
|
aprev_rmin = a->RMinNext();
|
|
++dst; ++a;
|
|
} else {
|
|
*dst = Entry(b->rmin + aprev_rmin,
|
|
b->rmax + a->RMaxPrev(),
|
|
b->wmin, b->value);
|
|
bprev_rmin = b->RMinNext();
|
|
++dst; ++b;
|
|
}
|
|
}
|
|
if (a != a_end) {
|
|
RType brmax = (b_end - 1)->rmax;
|
|
do {
|
|
*dst = Entry(a->rmin + bprev_rmin, a->rmax + brmax, a->wmin, a->value);
|
|
++dst; ++a;
|
|
} while (a != a_end);
|
|
}
|
|
if (b != b_end) {
|
|
RType armax = (a_end - 1)->rmax;
|
|
do {
|
|
*dst = Entry(b->rmin + aprev_rmin, b->rmax + armax, b->wmin, b->value);
|
|
++dst; ++b;
|
|
} while (b != b_end);
|
|
}
|
|
this->size = dst - data;
|
|
const RType tol = 10;
|
|
RType err_mingap, err_maxgap, err_wgap;
|
|
this->FixError(&err_mingap, &err_maxgap, &err_wgap);
|
|
if (err_mingap > tol || err_maxgap > tol || err_wgap > tol) {
|
|
LOG(INFO) << "mingap=" << err_mingap
|
|
<< ", maxgap=" << err_maxgap
|
|
<< ", wgap=" << err_wgap;
|
|
}
|
|
CHECK(size <= sa.size + sb.size) << "bug in combine";
|
|
}
|
|
// helper function to print the current content of sketch
|
|
inline void Print() const {
|
|
for (size_t i = 0; i < this->size; ++i) {
|
|
LOG(CONSOLE) << "[" << i << "] rmin=" << data[i].rmin
|
|
<< ", rmax=" << data[i].rmax
|
|
<< ", wmin=" << data[i].wmin
|
|
<< ", v=" << data[i].value;
|
|
}
|
|
}
|
|
// try to fix rounding error
|
|
// and re-establish invariance
|
|
inline void FixError(RType *err_mingap,
|
|
RType *err_maxgap,
|
|
RType *err_wgap) const {
|
|
*err_mingap = 0;
|
|
*err_maxgap = 0;
|
|
*err_wgap = 0;
|
|
RType prev_rmin = 0, prev_rmax = 0;
|
|
for (size_t i = 0; i < this->size; ++i) {
|
|
if (data[i].rmin < prev_rmin) {
|
|
data[i].rmin = prev_rmin;
|
|
*err_mingap = std::max(*err_mingap, prev_rmin - data[i].rmin);
|
|
} else {
|
|
prev_rmin = data[i].rmin;
|
|
}
|
|
if (data[i].rmax < prev_rmax) {
|
|
data[i].rmax = prev_rmax;
|
|
*err_maxgap = std::max(*err_maxgap, prev_rmax - data[i].rmax);
|
|
}
|
|
RType rmin_next = data[i].RMinNext();
|
|
if (data[i].rmax < rmin_next) {
|
|
data[i].rmax = rmin_next;
|
|
*err_wgap = std::max(*err_wgap, data[i].rmax - rmin_next);
|
|
}
|
|
prev_rmax = data[i].rmax;
|
|
}
|
|
}
|
|
};
|
|
|
|
/*!
 * \brief try to do efficient pruning
 *
 *  Same storage and merge behaviour as WQSummary; only SetPrune differs.  It always keeps
 *  entries whose own weight makes them a "large chunk" (see CheckLarge).  It then spreads
 *  the remaining budget evenly over the rank range left between those entries.
 */
template<typename DType, typename RType>
struct WXQSummary : public WQSummary<DType, RType> {
  // redefine entry type
  using Entry = typename WQSummary<DType, RType>::Entry;
  // constructor
  WXQSummary(Entry *data, size_t size)
      : WQSummary<DType, RType>(data, size) {}
  // check if the block is large chunk: RMinNext() exceeds RMaxPrev() by more than `chunk`
  inline static bool CheckLarge(const Entry &e, RType chunk) {
    return e.RMinNext() > e.RMaxPrev() + chunk;
  }
  /*!
   * \brief set this summary to a pruned version of `src` (no-op copy if it already fits).
   *
   *  The first and last entries of `src` are always kept.  When the rank range is zero or
   *  maxsize <= 2, exactly those two entries are stored.  Otherwise large-chunk entries are
   *  kept unconditionally, and the remaining (n - nbig) points are taken at evenly spaced
   *  rank targets over the "small" ranges between them.
   *  `this->data` must already have room for the result; note the special case writes two
   *  entries regardless of maxsize.
   * \param src source summary
   * \param maxsize size we can afford in the pruned sketch
   */
  inline void SetPrune(const WQSummary<DType, RType> &src, size_t maxsize) {
    if (src.size <= maxsize) {
      this->CopyFrom(src); return;
    }
    RType begin = src.data[0].rmax;
    // n is number of points exclude the min/max points
    // (underflows when maxsize < 2, but is unused on that path)
    size_t n = maxsize - 2, nbig = 0;
    // this is the range of data excluding the min/max point
    RType range = src.data[src.size - 1].rmin - begin;
    // prune off zero weights
    if (range == 0.0f || maxsize <= 2) {
      // special case, contain only two effective data pts
      this->data[0] = src.data[0];
      this->data[1] = src.data[src.size - 1];
      this->size = 2;
      return;
    } else {
      // keep the range strictly positive so chunk below is non-zero
      range = std::max(range, static_cast<RType>(1e-3f));
    }
    // Get a big enough chunk size, bigger than range / n
    // (multiply by 2 is a safe factor)
    const RType chunk = 2 * range / n;
    // minimized range: total rank range covered by the gaps between large chunks
    RType mrange = 0;
    {
      // first scan, grab all the big chunk
      // moving block index, exclude the two ends.
      size_t bid = 0;
      for (size_t i = 1; i < src.size - 1; ++i) {
        // detect big chunk data point in the middle
        // always save these data points.
        if (CheckLarge(src.data[i], chunk)) {
          if (bid != i - 1) {
            // accumulate the range of the rest points
            mrange += src.data[i].RMaxPrev() - src.data[bid].RMinNext();
          }
          bid = i; ++nbig;
        }
      }
      // gap between the last large chunk (or the first entry) and the final entry
      if (bid != src.size - 2) {
        mrange += src.data[src.size-1].RMaxPrev() - src.data[bid].RMinNext();
      }
    }
    // assert: there cannot be more than n big data points
    if (nbig >= n) {
      // see what was the case
      LOG(INFO) << " check quantile stats, nbig=" << nbig << ", n=" << n;
      LOG(INFO) << " srcsize=" << src.size << ", maxsize=" << maxsize
                << ", range=" << range << ", chunk=" << chunk;
      src.Print();
      CHECK(nbig < n) << "quantile: too many large chunk";
    }
    this->data[0] = src.data[0];
    this->size = 1;
    // The counter on the rest of points, to be selected equally from small chunks.
    n = n - nbig;
    // find the rest of point
    // bid: index of the previous kept boundary; k: next rank target (shared across gaps);
    // lastidx: index of the last entry written, used to avoid duplicates.
    size_t bid = 0, k = 1, lastidx = 0;
    for (size_t end = 1; end < src.size; ++end) {
      // `end` closes a gap when it is the last entry or a large chunk
      if (end == src.size - 1 || CheckLarge(src.data[end], chunk)) {
        if (bid != end - 1) {
          // there are small entries strictly between bid and end: sample them
          size_t i = bid;
          RType maxdx2 = src.data[end].RMaxPrev() * 2;
          for (; k < n; ++k) {
            // doubled rank target, in the coordinate system shifted by `begin`
            RType dx2 = 2 * ((k * mrange) / n + begin);
            // targets beyond this gap are left for the next gap (k is not reset)
            if (dx2 >= maxdx2) break;
            while (i < end &&
                   dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
            if (i == end) break;
            // pick whichever of i / i + 1 is closer to the target
            if (dx2 < src.data[i].RMinNext() + src.data[i + 1].RMaxPrev()) {
              if (i != lastidx) {
                this->data[this->size++] = src.data[i]; lastidx = i;
              }
            } else {
              if (i + 1 != lastidx) {
                this->data[this->size++] = src.data[i + 1]; lastidx = i + 1;
              }
            }
          }
        }
        // always keep the gap-closing entry itself
        if (lastidx != end) {
          this->data[this->size++] = src.data[end];
          lastidx = end;
        }
        bid = end;
        // shift base by the gap
        begin += src.data[bid].RMinNext() - src.data[bid].RMaxPrev();
      }
    }
  }
};
|
|
/*!
|
|
* \brief template for all quantile sketch algorithm
|
|
* that uses merge/prune scheme
|
|
* \tparam DType type of data content
|
|
* \tparam RType type of rank
|
|
* \tparam TSummary actual summary data structure it uses
|
|
*/
|
|
template<typename DType, typename RType, class TSummary>
|
|
class QuantileSketchTemplate {
|
|
public:
|
|
static float constexpr kFactor = 8.0;
|
|
|
|
public:
|
|
/*! \brief type of summary type */
|
|
using Summary = TSummary;
|
|
/*! \brief the entry type */
|
|
using Entry = typename Summary::Entry;
|
|
/*! \brief same as summary, but use STL to backup the space */
|
|
struct SummaryContainer : public Summary {
|
|
std::vector<Entry> space;
|
|
SummaryContainer(const SummaryContainer &src) : Summary(nullptr, src.size) {
|
|
this->space = src.space;
|
|
this->data = dmlc::BeginPtr(this->space);
|
|
}
|
|
SummaryContainer() : Summary(nullptr, 0) {
|
|
}
|
|
/*! \brief reserve space for summary */
|
|
inline void Reserve(size_t size) {
|
|
if (size > space.size()) {
|
|
space.resize(size);
|
|
this->data = dmlc::BeginPtr(space);
|
|
}
|
|
}
|
|
/*!
|
|
* \brief do elementwise combination of summary array
|
|
* this[i] = combine(this[i], src[i]) for each i
|
|
* \param src the source summary
|
|
* \param max_nbyte maximum number of byte allowed in here
|
|
*/
|
|
inline void Reduce(const Summary &src, size_t max_nbyte) {
|
|
this->Reserve((max_nbyte - sizeof(this->size)) / sizeof(Entry));
|
|
SummaryContainer temp;
|
|
temp.Reserve(this->size + src.size);
|
|
temp.SetCombine(*this, src);
|
|
this->SetPrune(temp, space.size());
|
|
}
|
|
/*! \brief return the number of bytes this data structure cost in serialization */
|
|
inline static size_t CalcMemCost(size_t nentry) {
|
|
return sizeof(size_t) + sizeof(Entry) * nentry;
|
|
}
|
|
/*! \brief save the data structure into stream */
|
|
template<typename TStream>
|
|
inline void Save(TStream &fo) const { // NOLINT(*)
|
|
fo.Write(&(this->size), sizeof(this->size));
|
|
if (this->size != 0) {
|
|
fo.Write(this->data, this->size * sizeof(Entry));
|
|
}
|
|
}
|
|
/*! \brief load data structure from input stream */
|
|
template<typename TStream>
|
|
inline void Load(TStream &fi) { // NOLINT(*)
|
|
CHECK_EQ(fi.Read(&this->size, sizeof(this->size)), sizeof(this->size));
|
|
this->Reserve(this->size);
|
|
if (this->size != 0) {
|
|
CHECK_EQ(fi.Read(this->data, this->size * sizeof(Entry)),
|
|
this->size * sizeof(Entry));
|
|
}
|
|
}
|
|
};
|
|
/*!
|
|
* \brief initialize the quantile sketch, given the performance specification
|
|
* \param maxn maximum number of data points can be feed into sketch
|
|
* \param eps accuracy level of summary
|
|
*/
|
|
inline void Init(size_t maxn, double eps) {
|
|
LimitSizeLevel(maxn, eps, &nlevel, &limit_size);
|
|
// lazy reserve the space, if there is only one value, no need to allocate space
|
|
inqueue.queue.resize(1);
|
|
inqueue.qtail = 0;
|
|
data.clear();
|
|
level.clear();
|
|
}
|
|
|
|
inline static void LimitSizeLevel
|
|
(size_t maxn, double eps, size_t* out_nlevel, size_t* out_limit_size) {
|
|
size_t& nlevel = *out_nlevel;
|
|
size_t& limit_size = *out_limit_size;
|
|
nlevel = 1;
|
|
while (true) {
|
|
limit_size = static_cast<size_t>(ceil(nlevel / eps)) + 1;
|
|
limit_size = std::min(maxn, limit_size);
|
|
size_t n = (1ULL << nlevel);
|
|
if (n * limit_size >= maxn) break;
|
|
++nlevel;
|
|
}
|
|
// check invariant
|
|
size_t n = (1ULL << nlevel);
|
|
CHECK(n * limit_size >= maxn) << "invalid init parameter";
|
|
CHECK(nlevel <= std::max(static_cast<size_t>(1), static_cast<size_t>(limit_size * eps)))
|
|
<< "invalid init parameter";
|
|
}
|
|
|
|
/*!
|
|
* \brief add an element to a sketch
|
|
* \param x The element added to the sketch
|
|
* \param w The weight of the element.
|
|
*/
|
|
inline void Push(DType x, RType w = 1) {
|
|
if (w == static_cast<RType>(0)) return;
|
|
if (inqueue.qtail == inqueue.queue.size() && inqueue.queue[inqueue.qtail - 1].value != x) {
|
|
// jump from lazy one value to limit_size * 2
|
|
if (inqueue.queue.size() == 1) {
|
|
inqueue.queue.resize(limit_size * 2);
|
|
} else {
|
|
temp.Reserve(limit_size * 2);
|
|
inqueue.MakeSummary(&temp);
|
|
// cleanup queue
|
|
inqueue.qtail = 0;
|
|
this->PushTemp();
|
|
}
|
|
}
|
|
inqueue.Push(x, w);
|
|
}
|
|
|
|
inline void PushSummary(const Summary& summary) {
|
|
temp.Reserve(limit_size * 2);
|
|
temp.SetPrune(summary, limit_size * 2);
|
|
PushTemp();
|
|
}
|
|
|
|
/*! \brief push up temp */
|
|
inline void PushTemp() {
|
|
temp.Reserve(limit_size * 2);
|
|
for (size_t l = 1; true; ++l) {
|
|
this->InitLevel(l + 1);
|
|
// check if level l is empty
|
|
if (level[l].size == 0) {
|
|
level[l].SetPrune(temp, limit_size);
|
|
break;
|
|
} else {
|
|
// level 0 is actually temp space
|
|
level[0].SetPrune(temp, limit_size);
|
|
temp.SetCombine(level[0], level[l]);
|
|
if (temp.size > limit_size) {
|
|
// try next level
|
|
level[l].size = 0;
|
|
} else {
|
|
// if merged record is still smaller, no need to send to next level
|
|
level[l].CopyFrom(temp); break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
/*! \brief get the summary after finalize */
|
|
inline void GetSummary(SummaryContainer *out) {
|
|
if (level.size() != 0) {
|
|
out->Reserve(limit_size * 2);
|
|
} else {
|
|
out->Reserve(inqueue.queue.size());
|
|
}
|
|
inqueue.MakeSummary(out);
|
|
if (level.size() != 0) {
|
|
level[0].SetPrune(*out, limit_size);
|
|
for (size_t l = 1; l < level.size(); ++l) {
|
|
if (level[l].size == 0) continue;
|
|
if (level[0].size == 0) {
|
|
level[0].CopyFrom(level[l]);
|
|
} else {
|
|
out->SetCombine(level[0], level[l]);
|
|
level[0].SetPrune(*out, limit_size);
|
|
}
|
|
}
|
|
out->CopyFrom(level[0]);
|
|
} else {
|
|
if (out->size > limit_size) {
|
|
temp.Reserve(limit_size);
|
|
temp.SetPrune(*out, limit_size);
|
|
out->CopyFrom(temp);
|
|
}
|
|
}
|
|
}
|
|
// used for debug, check if the sketch is valid
|
|
inline void CheckValid(RType eps) const {
|
|
for (size_t l = 1; l < level.size(); ++l) {
|
|
level[l].CheckValid(eps);
|
|
}
|
|
}
|
|
// initialize level space to at least nlevel
|
|
inline void InitLevel(size_t nlevel) {
|
|
if (level.size() >= nlevel) return;
|
|
data.resize(limit_size * nlevel);
|
|
level.resize(nlevel, Summary(nullptr, 0));
|
|
for (size_t l = 0; l < level.size(); ++l) {
|
|
level[l].data = dmlc::BeginPtr(data) + l * limit_size;
|
|
}
|
|
}
|
|
// input data queue
|
|
typename Summary::Queue inqueue;
|
|
// number of levels
|
|
size_t nlevel;
|
|
// size of summary in each level
|
|
size_t limit_size;
|
|
// the level of each summaries
|
|
std::vector<Summary> level;
|
|
// content of the summary
|
|
std::vector<Entry> data;
|
|
// temporal summary, used for temp-merge
|
|
SummaryContainer temp;
|
|
};
|
|
|
|
/*!
|
|
* \brief Quantile sketch use WQSummary
|
|
* \tparam DType type of data content
|
|
* \tparam RType type of rank
|
|
*/
|
|
template<typename DType, typename RType = unsigned>
|
|
class WQuantileSketch :
|
|
public QuantileSketchTemplate<DType, RType, WQSummary<DType, RType> > {
|
|
};
|
|
|
|
/*!
|
|
* \brief Quantile sketch use WXQSummary
|
|
* \tparam DType type of data content
|
|
* \tparam RType type of rank
|
|
*/
|
|
template<typename DType, typename RType = unsigned>
|
|
class WXQuantileSketch :
|
|
public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {
|
|
};
|
|
|
|
namespace detail {
|
|
inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
|
std::vector<float> const &group_weights = info.weights_.HostVector();
|
|
if (group_weights.empty()) {
|
|
return group_weights;
|
|
}
|
|
|
|
auto const &group_ptr = info.group_ptr_;
|
|
CHECK_GE(group_ptr.size(), 2);
|
|
|
|
auto n_groups = group_ptr.size() - 1;
|
|
CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
|
|
|
|
bst_idx_t n_samples = info.num_row_;
|
|
std::vector<float> results(n_samples);
|
|
CHECK_EQ(group_ptr.back(), n_samples)
|
|
<< error::GroupSize() << " the number of rows from the data.";
|
|
size_t cur_group = 0;
|
|
for (bst_idx_t i = 0; i < n_samples; ++i) {
|
|
results[i] = group_weights[cur_group];
|
|
if (i == group_ptr[cur_group + 1]) {
|
|
cur_group++;
|
|
}
|
|
}
|
|
return results;
|
|
}
|
|
} // namespace detail
|
|
|
|
class HistogramCuts;
|
|
|
|
template <typename Batch, typename IsValid>
|
|
std::vector<bst_idx_t> CalcColumnSize(Batch const &batch, bst_feature_t const n_columns,
|
|
size_t const n_threads, IsValid &&is_valid) {
|
|
std::vector<std::vector<bst_idx_t>> column_sizes_tloc(n_threads);
|
|
for (auto &column : column_sizes_tloc) {
|
|
column.resize(n_columns, 0);
|
|
}
|
|
|
|
ParallelFor(batch.Size(), n_threads, [&](omp_ulong i) {
|
|
auto &local_column_sizes = column_sizes_tloc.at(omp_get_thread_num());
|
|
auto const &line = batch.GetLine(i);
|
|
for (size_t j = 0; j < line.Size(); ++j) {
|
|
auto elem = line.GetElement(j);
|
|
if (is_valid(elem)) {
|
|
local_column_sizes[elem.column_idx]++;
|
|
}
|
|
}
|
|
});
|
|
// reduce to first thread
|
|
auto &entries_per_columns = column_sizes_tloc.front();
|
|
CHECK_EQ(entries_per_columns.size(), static_cast<size_t>(n_columns));
|
|
for (size_t i = 1; i < n_threads; ++i) {
|
|
CHECK_EQ(column_sizes_tloc[i].size(), static_cast<size_t>(n_columns));
|
|
for (size_t j = 0; j < n_columns; ++j) {
|
|
entries_per_columns[j] += column_sizes_tloc[i][j];
|
|
}
|
|
}
|
|
return entries_per_columns;
|
|
}
|
|
|
|
template <typename Batch, typename IsValid>
|
|
std::vector<bst_feature_t> LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns,
|
|
size_t const nthreads, IsValid&& is_valid) {
|
|
/* Some sparse datasets have their mass concentrating on small number of features. To
|
|
* avoid waiting for a few threads running forever, we here distribute different number
|
|
* of columns to different threads according to number of entries.
|
|
*/
|
|
size_t const total_entries = nnz;
|
|
size_t const entries_per_thread = DivRoundUp(total_entries, nthreads);
|
|
|
|
// Need to calculate the size for each batch.
|
|
std::vector<bst_idx_t> entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid);
|
|
std::vector<bst_feature_t> cols_ptr(nthreads + 1, 0);
|
|
size_t count{0};
|
|
size_t current_thread{1};
|
|
|
|
for (auto col : entries_per_columns) {
|
|
cols_ptr.at(current_thread)++; // add one column to thread
|
|
count += col;
|
|
CHECK_LE(count, total_entries);
|
|
if (count > entries_per_thread) {
|
|
current_thread++;
|
|
count = 0;
|
|
cols_ptr.at(current_thread) = cols_ptr[current_thread - 1];
|
|
}
|
|
}
|
|
// Idle threads.
|
|
for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
|
|
cols_ptr[current_thread + 1] = cols_ptr[current_thread];
|
|
}
|
|
return cols_ptr;
|
|
}
|
|
|
|
/*!
|
|
* A sketch matrix storing sketches for each feature.
|
|
*/
|
|
template <typename WQSketch>
|
|
class SketchContainerImpl {
|
|
protected:
|
|
std::vector<WQSketch> sketches_;
|
|
std::vector<std::set<float>> categories_;
|
|
std::vector<FeatureType> const feature_types_;
|
|
|
|
std::vector<bst_idx_t> columns_size_;
|
|
bst_bin_t max_bins_;
|
|
bool use_group_ind_{false};
|
|
int32_t n_threads_;
|
|
bool has_categorical_{false};
|
|
Monitor monitor_;
|
|
|
|
public:
|
|
/* \brief Initialize necessary info.
|
|
*
|
|
* \param columns_size Size of each column.
|
|
* \param max_bins maximum number of bins for each feature.
|
|
* \param use_group whether is assigned to group to data instance.
|
|
*/
|
|
SketchContainerImpl(Context const *ctx, std::vector<bst_idx_t> columns_size, bst_bin_t max_bins,
|
|
common::Span<FeatureType const> feature_types, bool use_group);
|
|
|
|
static bool UseGroup(MetaInfo const &info) {
|
|
size_t const num_groups =
|
|
info.group_ptr_.size() == 0 ? 0 : info.group_ptr_.size() - 1;
|
|
// Use group index for weights?
|
|
bool const use_group_ind =
|
|
num_groups != 0 && (info.weights_.Size() != info.num_row_);
|
|
return use_group_ind;
|
|
}
|
|
|
|
static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
|
|
size_t const base_rowid) {
|
|
CHECK_LT(base_rowid, group_ptr.back())
|
|
<< "Row: " << base_rowid << " is not found in any group.";
|
|
bst_group_t group_ind =
|
|
std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid) -
|
|
group_ptr.cbegin() - 1;
|
|
return group_ind;
|
|
}
|
|
// Gather sketches from all workers.
|
|
void GatherSketchInfo(Context const *ctx, MetaInfo const &info,
|
|
std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
|
std::vector<bst_idx_t> *p_worker_segments,
|
|
std::vector<bst_idx_t> *p_sketches_scan,
|
|
std::vector<typename WQSketch::Entry> *p_global_sketches);
|
|
// Merge sketches from all workers.
|
|
void AllReduce(Context const *ctx, MetaInfo const &info,
|
|
std::vector<typename WQSketch::SummaryContainer> *p_reduced,
|
|
std::vector<int32_t> *p_num_cuts);
|
|
|
|
template <typename Batch, typename IsValid>
|
|
void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
|
|
size_t n_features, bool is_dense, IsValid is_valid) {
|
|
auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
|
|
|
|
dmlc::OMPException exc;
|
|
#pragma omp parallel num_threads(n_threads_)
|
|
{
|
|
exc.Run([&]() {
|
|
auto tid = static_cast<uint32_t>(omp_get_thread_num());
|
|
auto const begin = thread_columns_ptr[tid];
|
|
auto const end = thread_columns_ptr[tid + 1];
|
|
|
|
// do not iterate if no columns are assigned to the thread
|
|
if (begin < end && end <= n_features) {
|
|
for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
|
|
auto const &line = batch.GetLine(ridx);
|
|
auto w = weights[ridx + base_rowid];
|
|
if (is_dense) {
|
|
for (size_t ii = begin; ii < end; ii++) {
|
|
auto elem = line.GetElement(ii);
|
|
if (is_valid(elem)) {
|
|
if (IsCat(feature_types_, ii)) {
|
|
categories_[ii].emplace(elem.value);
|
|
} else {
|
|
sketches_[ii].Push(elem.value, w);
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
for (size_t i = 0; i < line.Size(); ++i) {
|
|
auto const &elem = line.GetElement(i);
|
|
if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
|
|
if (IsCat(feature_types_, elem.column_idx)) {
|
|
categories_[elem.column_idx].emplace(elem.value);
|
|
} else {
|
|
sketches_[elem.column_idx].Push(elem.value, w);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|
|
exc.Rethrow();
|
|
}
|
|
|
|
/* \brief Push a CSR matrix. */
|
|
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
|
|
|
|
void MakeCuts(Context const *ctx, MetaInfo const &info, HistogramCuts *cuts);
|
|
|
|
private:
|
|
// Merge all categories from other workers.
|
|
void AllreduceCategories(Context const* ctx, MetaInfo const& info);
|
|
};
|
|
|
|
/*!
 * \brief Host-side sketch container built on WQuantileSketch<float, float>.
 */
class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
 public:
  /*! \brief the sketch type held for every feature */
  using WQSketch = WQuantileSketch<float, float>;

  HostSketchContainer(Context const *ctx, bst_bin_t max_bins, common::Span<FeatureType const> ft,
                      std::vector<bst_idx_t> columns_size, bool use_group);

  /*!
   * \brief Push a batch of rows from a data adapter (defined out of line).
   * NOTE(review): `missing` presumably marks absent values — confirm against the definition.
   */
  template <typename Batch>
  void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
};
|
|
|
|
/**
|
|
* \brief Quantile structure accepts sorted data, extracted from histmaker.
|
|
*/
|
|
struct SortedQuantile {
|
|
/*! \brief total sum of amount to be met */
|
|
double sum_total{0.0};
|
|
/*! \brief statistics used in the sketch */
|
|
double rmin, wmin;
|
|
/*! \brief last seen feature value */
|
|
bst_float last_fvalue;
|
|
/*! \brief current size of sketch */
|
|
double next_goal;
|
|
// pointer to the sketch to put things in
|
|
common::WXQuantileSketch<bst_float, bst_float>* sketch;
|
|
// initialize the space
|
|
inline void Init(unsigned max_size) {
|
|
next_goal = -1.0f;
|
|
rmin = wmin = 0.0f;
|
|
sketch->temp.Reserve(max_size + 1);
|
|
sketch->temp.size = 0;
|
|
}
|
|
/*!
|
|
* \brief push a new element to sketch
|
|
* \param fvalue feature value, comes in sorted ascending order
|
|
* \param w weight
|
|
* \param max_size
|
|
*/
|
|
inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
|
|
if (next_goal == -1.0f) {
|
|
next_goal = 0.0f;
|
|
last_fvalue = fvalue;
|
|
wmin = w;
|
|
return;
|
|
}
|
|
if (last_fvalue != fvalue) {
|
|
double rmax = rmin + wmin;
|
|
if (rmax >= next_goal && sketch->temp.size != max_size) {
|
|
if (sketch->temp.size == 0 ||
|
|
last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
|
|
// push to sketch
|
|
sketch->temp.data[sketch->temp.size] =
|
|
common::WXQuantileSketch<bst_float, bst_float>::Entry(
|
|
static_cast<bst_float>(rmin), static_cast<bst_float>(rmax),
|
|
static_cast<bst_float>(wmin), last_fvalue);
|
|
CHECK_LT(sketch->temp.size, max_size) << "invalid maximum size max_size=" << max_size
|
|
<< ", stemp.size" << sketch->temp.size;
|
|
++sketch->temp.size;
|
|
}
|
|
if (sketch->temp.size == max_size) {
|
|
next_goal = sum_total * 2.0f + 1e-5f;
|
|
} else {
|
|
next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
|
|
}
|
|
} else {
|
|
if (rmax >= next_goal) {
|
|
LOG(DEBUG) << "INFO: rmax=" << rmax << ", sum_total=" << sum_total
|
|
<< ", naxt_goal=" << next_goal << ", size=" << sketch->temp.size;
|
|
}
|
|
}
|
|
rmin = rmax;
|
|
wmin = w;
|
|
last_fvalue = fvalue;
|
|
} else {
|
|
wmin += w;
|
|
}
|
|
}
|
|
|
|
/*! \brief push final unfinished value to the sketch */
|
|
inline void Finalize(unsigned max_size) {
|
|
double rmax = rmin + wmin;
|
|
if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
|
|
CHECK_LE(sketch->temp.size, max_size)
|
|
<< "Finalize: invalid maximum size, max_size=" << max_size
|
|
<< ", stemp.size=" << sketch->temp.size;
|
|
// push to sketch
|
|
sketch->temp.data[sketch->temp.size] = common::WXQuantileSketch<bst_float, bst_float>::Entry(
|
|
static_cast<bst_float>(rmin), static_cast<bst_float>(rmax), static_cast<bst_float>(wmin),
|
|
last_fvalue);
|
|
++sketch->temp.size;
|
|
}
|
|
sketch->PushTemp();
|
|
}
|
|
};
|
|
|
|
/*!
 * \brief Sketch container for pre-sorted (CSC) input.  Each feature gets a SortedQuantile
 *  that writes into the matching WXQuantileSketch of the base class.
 */
class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float, float>> {
  // one sorted-input helper per feature; intentionally shadows the base-class sketches_
  std::vector<SortedQuantile> sketches_;
  using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;

 public:
  explicit SortedSketchContainer(Context const *ctx, int32_t max_bins,
                                 common::Span<FeatureType const> ft,
                                 std::vector<bst_idx_t> columns_size, bool use_group)
      : SketchContainerImpl{ctx, columns_size, max_bins, ft, use_group} {
    monitor_.Init(__func__);
    sketches_.resize(columns_size.size());
    // Accuracy requested from every per-feature sketch; identical for all features.
    auto const eps = 2.0 / max_bins;
    for (size_t i = 0; i < sketches_.size(); ++i) {
      auto &sorted = sketches_[i];
      sorted.sketch = &Super::sketches_[i];
      sorted.Init(max_bins_);
      sorted.sketch->Init(columns_size_[i], eps);
    }
  }
  /**
   * \brief Push a sorted CSC page.
   */
  void PushColPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian);
};
|
|
} // namespace xgboost::common
|
|
#endif // XGBOOST_COMMON_QUANTILE_H_
|