optimize heavy hitter
This commit is contained in:
parent
b426eef527
commit
9d101b47f9
@ -1,7 +1,7 @@
|
||||
#ifndef XGBOOST_UTILS_QUANTILE_H_
|
||||
#define XGBOOST_UTILS_QUANTILE_H_
|
||||
/*!
|
||||
* \file quantile
|
||||
* \file quantile.h
|
||||
* \brief util to compute quantiles
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
@ -14,7 +14,6 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace utils {
|
||||
|
||||
/*!
|
||||
* \brief experimental wsummary
|
||||
* \tparam DType type of data content
|
||||
@ -172,7 +171,7 @@ struct WQSummary {
|
||||
this->size = 1;
|
||||
// lastidx is used to avoid duplicated records
|
||||
size_t i = 1, lastidx = 0;
|
||||
for (RType k = 1; k < n; ++k) {
|
||||
for (size_t k = 1; k < n; ++k) {
|
||||
RType dx2 = 2 * ((k * range) / n + begin);
|
||||
// find first i such that d < (rmax[i+1] + rmin[i+1]) / 2
|
||||
while (i < src.size - 1 &&
|
||||
@ -246,7 +245,84 @@ struct WQSummary {
|
||||
utils::Assert(size <= sa.size + sb.size, "bug in combine");
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief try to do efficient prunning */
|
||||
template<typename DType, typename RType>
|
||||
struct WXQSummary : public WQSummary<DType, RType> {
|
||||
// redefine entry type
|
||||
typedef typename WQSummary<DType, RType>::Entry Entry;
|
||||
// constructor
|
||||
WXQSummary(Entry *data, size_t size)
|
||||
: WQSummary<DType, RType>(data, size) {}
|
||||
// check if the block is large chunk
|
||||
inline static bool CheckLarge(const Entry &e, RType chunk) {
|
||||
return e.rmin_next() > e.rmax_prev() + chunk;
|
||||
}
|
||||
// set prune
|
||||
inline void SetPrune(const WXQSummary &src, RType maxsize) {
|
||||
if (src.size <= maxsize) {
|
||||
this->CopyFrom(src); return;
|
||||
}
|
||||
RType begin = src.data[0].rmax;
|
||||
size_t n = maxsize - 1, nbig = 0;
|
||||
const RType range = src.data[src.size - 1].rmin - begin;
|
||||
const RType chunk = 2 * range / n;
|
||||
// minimized range
|
||||
RType mrange = 0;
|
||||
{
|
||||
// first scan, grab all the big chunk
|
||||
// moviing block index
|
||||
size_t bid = 0;
|
||||
for (size_t i = 1; i < src.size; ++i) {
|
||||
if (CheckLarge(src.data[i], chunk)) {
|
||||
if (bid != i - 1) {
|
||||
mrange += src.data[i].rmax_prev() - src.data[bid].rmin_next();
|
||||
}
|
||||
bid = i; ++nbig;
|
||||
}
|
||||
}
|
||||
if (bid != src.size - 2) {
|
||||
mrange += src.data[src.size-1].rmax_prev() - src.data[bid].rmin_next();
|
||||
}
|
||||
}
|
||||
utils::Assert(nbig < n - 1, "too many large chunk");
|
||||
this->data[0] = src.data[0];
|
||||
this->size = 1;
|
||||
// use smaller size
|
||||
n = n - nbig;
|
||||
// find the rest of point
|
||||
size_t bid = 0, k = 1, lastidx = 0;
|
||||
for (size_t end = 1; end < src.size; ++end) {
|
||||
if (end == src.size - 1 || CheckLarge(src.data[end], chunk)) {
|
||||
if (bid != end - 1) {
|
||||
size_t i = bid;
|
||||
RType maxdx2 = src.data[end].rmax_prev() * 2;
|
||||
for (; k < n; ++k) {
|
||||
RType dx2 = 2 * ((k * mrange) / n + begin);
|
||||
if (dx2 >= maxdx2) break;
|
||||
while (i < end &&
|
||||
dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
|
||||
if (dx2 < src.data[i].rmin_next() + src.data[i + 1].rmax_prev()) {
|
||||
if (i != lastidx) {
|
||||
this->data[this->size++] = src.data[i]; lastidx = i;
|
||||
}
|
||||
} else {
|
||||
if (i + 1 != lastidx) {
|
||||
this->data[this->size++] = src.data[i + 1]; lastidx = i + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (lastidx != end) {
|
||||
this->data[this->size++] = src.data[end];
|
||||
lastidx = end;
|
||||
}
|
||||
bid = end;
|
||||
// shift base by the gap
|
||||
begin += src.data[bid].rmin_next() - src.data[bid].rmax_prev();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
/*!
|
||||
* \brief traditional GK summary
|
||||
*/
|
||||
@ -564,6 +640,16 @@ template<typename DType, typename RType=unsigned>
|
||||
class WQuantileSketch :
|
||||
public QuantileSketchTemplate<DType, RType, WQSummary<DType, RType> >{
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Quantile sketch use WXQSummary
|
||||
* \tparam DType type of data content
|
||||
* \tparam RType type of rank
|
||||
*/
|
||||
template<typename DType, typename RType=unsigned>
|
||||
class WXQuantileSketch :
|
||||
public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> >{
|
||||
};
|
||||
/*!
|
||||
* \brief Quantile sketch use WQSummary
|
||||
* \tparam DType type of data content
|
||||
|
||||
@ -7,7 +7,11 @@ import subprocess
|
||||
funcs = {
|
||||
'seq': 'lambda n: sorted([(x,1) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||
'seqlogw': 'lambda n: sorted([(x, math.log(x)) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||
'lots0': 'lambda n: sorted([(max(x - n*3/4,0), 1) for x in range(1,n+1)], key = lambda x:random.random())'
|
||||
'lots0': 'lambda n: sorted([(max(x - n*3/4,0), 1) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||
'lots9': 'lambda n: sorted([(9 if x > n / 4 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||
'lotsm': 'lambda n: sorted([(n/8 if x > n / 4 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||
'lotsmr': 'lambda n: sorted([( x * 4 / n + n / 20 if x > n / 10 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||
'lotsmr2': 'lambda n: sorted([( x * 10 / n + n / 20 if x > n / 10 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())'
|
||||
}
|
||||
|
||||
if len(sys.argv) < 3:
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
#include <utils/quantile.h>
|
||||
using namespace xgboost;
|
||||
|
||||
|
||||
template<typename Sketch, typename RType>
|
||||
inline void test(void) {
|
||||
Sketch sketch;
|
||||
@ -11,7 +10,6 @@ inline void test(void) {
|
||||
float eps, x, w;
|
||||
utils::Check(scanf("%lu%f", &n, &eps) == 2, "needs to start with n eps");
|
||||
sketch.Init(n, eps);
|
||||
printf("nlevel = %lu, limit_size=%lu\n", sketch.nlevel, sketch.limit_size);
|
||||
while (scanf("%f%f", &x, &w) == 2) {
|
||||
sketch.Push(x, static_cast<RType>(w));
|
||||
wsum += w;
|
||||
@ -20,8 +18,10 @@ inline void test(void) {
|
||||
typename Sketch::SummaryContainer out;
|
||||
sketch.GetSummary(&out);
|
||||
double maxerr = static_cast<double>(out.MaxError());
|
||||
printf("MaxError=%g/%g = %g\n", maxerr, wsum, maxerr / wsum);
|
||||
out.Print();
|
||||
|
||||
printf("MaxError=%g/%g = %g\n", maxerr, wsum, maxerr / wsum);
|
||||
printf("maxlevel = %lu, usedlevel=%lu, limit_size=%lu\n", sketch.nlevel, sketch.level.size(), sketch.limit_size);
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
@ -30,6 +30,9 @@ int main(int argc, char *argv[]) {
|
||||
if (!strcmp(method, "wq")) {
|
||||
test<utils::WQuantileSketch<float, float>, float>();
|
||||
}
|
||||
if (!strcmp(method, "wx")) {
|
||||
test<utils::WXQuantileSketch<float, float>, float>();
|
||||
}
|
||||
if (!strcmp(method, "gk")) {
|
||||
test<utils::GKQuantileSketch<float, unsigned>, unsigned>();
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user