optimize heavy hitter
This commit is contained in:
parent
b426eef527
commit
9d101b47f9
@ -1,7 +1,7 @@
|
|||||||
#ifndef XGBOOST_UTILS_QUANTILE_H_
|
#ifndef XGBOOST_UTILS_QUANTILE_H_
|
||||||
#define XGBOOST_UTILS_QUANTILE_H_
|
#define XGBOOST_UTILS_QUANTILE_H_
|
||||||
/*!
|
/*!
|
||||||
* \file quantile
|
* \file quantile.h
|
||||||
* \brief util to compute quantiles
|
* \brief util to compute quantiles
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace utils {
|
namespace utils {
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief experimental wsummary
|
* \brief experimental wsummary
|
||||||
* \tparam DType type of data content
|
* \tparam DType type of data content
|
||||||
@ -172,7 +171,7 @@ struct WQSummary {
|
|||||||
this->size = 1;
|
this->size = 1;
|
||||||
// lastidx is used to avoid duplicated records
|
// lastidx is used to avoid duplicated records
|
||||||
size_t i = 1, lastidx = 0;
|
size_t i = 1, lastidx = 0;
|
||||||
for (RType k = 1; k < n; ++k) {
|
for (size_t k = 1; k < n; ++k) {
|
||||||
RType dx2 = 2 * ((k * range) / n + begin);
|
RType dx2 = 2 * ((k * range) / n + begin);
|
||||||
// find first i such that d < (rmax[i+1] + rmin[i+1]) / 2
|
// find first i such that d < (rmax[i+1] + rmin[i+1]) / 2
|
||||||
while (i < src.size - 1 &&
|
while (i < src.size - 1 &&
|
||||||
@ -246,7 +245,84 @@ struct WQSummary {
|
|||||||
utils::Assert(size <= sa.size + sb.size, "bug in combine");
|
utils::Assert(size <= sa.size + sb.size, "bug in combine");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
/*! \brief try to do efficient prunning */
|
||||||
|
template<typename DType, typename RType>
|
||||||
|
struct WXQSummary : public WQSummary<DType, RType> {
|
||||||
|
// redefine entry type
|
||||||
|
typedef typename WQSummary<DType, RType>::Entry Entry;
|
||||||
|
// constructor
|
||||||
|
WXQSummary(Entry *data, size_t size)
|
||||||
|
: WQSummary<DType, RType>(data, size) {}
|
||||||
|
// check if the block is large chunk
|
||||||
|
inline static bool CheckLarge(const Entry &e, RType chunk) {
|
||||||
|
return e.rmin_next() > e.rmax_prev() + chunk;
|
||||||
|
}
|
||||||
|
// set prune
|
||||||
|
inline void SetPrune(const WXQSummary &src, RType maxsize) {
|
||||||
|
if (src.size <= maxsize) {
|
||||||
|
this->CopyFrom(src); return;
|
||||||
|
}
|
||||||
|
RType begin = src.data[0].rmax;
|
||||||
|
size_t n = maxsize - 1, nbig = 0;
|
||||||
|
const RType range = src.data[src.size - 1].rmin - begin;
|
||||||
|
const RType chunk = 2 * range / n;
|
||||||
|
// minimized range
|
||||||
|
RType mrange = 0;
|
||||||
|
{
|
||||||
|
// first scan, grab all the big chunk
|
||||||
|
// moviing block index
|
||||||
|
size_t bid = 0;
|
||||||
|
for (size_t i = 1; i < src.size; ++i) {
|
||||||
|
if (CheckLarge(src.data[i], chunk)) {
|
||||||
|
if (bid != i - 1) {
|
||||||
|
mrange += src.data[i].rmax_prev() - src.data[bid].rmin_next();
|
||||||
|
}
|
||||||
|
bid = i; ++nbig;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (bid != src.size - 2) {
|
||||||
|
mrange += src.data[src.size-1].rmax_prev() - src.data[bid].rmin_next();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
utils::Assert(nbig < n - 1, "too many large chunk");
|
||||||
|
this->data[0] = src.data[0];
|
||||||
|
this->size = 1;
|
||||||
|
// use smaller size
|
||||||
|
n = n - nbig;
|
||||||
|
// find the rest of point
|
||||||
|
size_t bid = 0, k = 1, lastidx = 0;
|
||||||
|
for (size_t end = 1; end < src.size; ++end) {
|
||||||
|
if (end == src.size - 1 || CheckLarge(src.data[end], chunk)) {
|
||||||
|
if (bid != end - 1) {
|
||||||
|
size_t i = bid;
|
||||||
|
RType maxdx2 = src.data[end].rmax_prev() * 2;
|
||||||
|
for (; k < n; ++k) {
|
||||||
|
RType dx2 = 2 * ((k * mrange) / n + begin);
|
||||||
|
if (dx2 >= maxdx2) break;
|
||||||
|
while (i < end &&
|
||||||
|
dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
|
||||||
|
if (dx2 < src.data[i].rmin_next() + src.data[i + 1].rmax_prev()) {
|
||||||
|
if (i != lastidx) {
|
||||||
|
this->data[this->size++] = src.data[i]; lastidx = i;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (i + 1 != lastidx) {
|
||||||
|
this->data[this->size++] = src.data[i + 1]; lastidx = i + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (lastidx != end) {
|
||||||
|
this->data[this->size++] = src.data[end];
|
||||||
|
lastidx = end;
|
||||||
|
}
|
||||||
|
bid = end;
|
||||||
|
// shift base by the gap
|
||||||
|
begin += src.data[bid].rmin_next() - src.data[bid].rmax_prev();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
/*!
|
/*!
|
||||||
* \brief traditional GK summary
|
* \brief traditional GK summary
|
||||||
*/
|
*/
|
||||||
@ -564,6 +640,16 @@ template<typename DType, typename RType=unsigned>
|
|||||||
class WQuantileSketch :
|
class WQuantileSketch :
|
||||||
public QuantileSketchTemplate<DType, RType, WQSummary<DType, RType> >{
|
public QuantileSketchTemplate<DType, RType, WQSummary<DType, RType> >{
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Quantile sketch use WXQSummary
|
||||||
|
* \tparam DType type of data content
|
||||||
|
* \tparam RType type of rank
|
||||||
|
*/
|
||||||
|
template<typename DType, typename RType=unsigned>
|
||||||
|
class WXQuantileSketch :
|
||||||
|
public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> >{
|
||||||
|
};
|
||||||
/*!
|
/*!
|
||||||
* \brief Quantile sketch use WQSummary
|
* \brief Quantile sketch use WQSummary
|
||||||
* \tparam DType type of data content
|
* \tparam DType type of data content
|
||||||
|
|||||||
@ -7,7 +7,11 @@ import subprocess
|
|||||||
funcs = {
|
funcs = {
|
||||||
'seq': 'lambda n: sorted([(x,1) for x in range(1,n+1)], key = lambda x:random.random())',
|
'seq': 'lambda n: sorted([(x,1) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||||
'seqlogw': 'lambda n: sorted([(x, math.log(x)) for x in range(1,n+1)], key = lambda x:random.random())',
|
'seqlogw': 'lambda n: sorted([(x, math.log(x)) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||||
'lots0': 'lambda n: sorted([(max(x - n*3/4,0), 1) for x in range(1,n+1)], key = lambda x:random.random())'
|
'lots0': 'lambda n: sorted([(max(x - n*3/4,0), 1) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||||
|
'lots9': 'lambda n: sorted([(9 if x > n / 4 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||||
|
'lotsm': 'lambda n: sorted([(n/8 if x > n / 4 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||||
|
'lotsmr': 'lambda n: sorted([( x * 4 / n + n / 20 if x > n / 10 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())',
|
||||||
|
'lotsmr2': 'lambda n: sorted([( x * 10 / n + n / 20 if x > n / 10 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())'
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(sys.argv) < 3:
|
if len(sys.argv) < 3:
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
#include <utils/quantile.h>
|
#include <utils/quantile.h>
|
||||||
using namespace xgboost;
|
using namespace xgboost;
|
||||||
|
|
||||||
|
|
||||||
template<typename Sketch, typename RType>
|
template<typename Sketch, typename RType>
|
||||||
inline void test(void) {
|
inline void test(void) {
|
||||||
Sketch sketch;
|
Sketch sketch;
|
||||||
@ -11,7 +10,6 @@ inline void test(void) {
|
|||||||
float eps, x, w;
|
float eps, x, w;
|
||||||
utils::Check(scanf("%lu%f", &n, &eps) == 2, "needs to start with n eps");
|
utils::Check(scanf("%lu%f", &n, &eps) == 2, "needs to start with n eps");
|
||||||
sketch.Init(n, eps);
|
sketch.Init(n, eps);
|
||||||
printf("nlevel = %lu, limit_size=%lu\n", sketch.nlevel, sketch.limit_size);
|
|
||||||
while (scanf("%f%f", &x, &w) == 2) {
|
while (scanf("%f%f", &x, &w) == 2) {
|
||||||
sketch.Push(x, static_cast<RType>(w));
|
sketch.Push(x, static_cast<RType>(w));
|
||||||
wsum += w;
|
wsum += w;
|
||||||
@ -20,8 +18,10 @@ inline void test(void) {
|
|||||||
typename Sketch::SummaryContainer out;
|
typename Sketch::SummaryContainer out;
|
||||||
sketch.GetSummary(&out);
|
sketch.GetSummary(&out);
|
||||||
double maxerr = static_cast<double>(out.MaxError());
|
double maxerr = static_cast<double>(out.MaxError());
|
||||||
printf("MaxError=%g/%g = %g\n", maxerr, wsum, maxerr / wsum);
|
|
||||||
out.Print();
|
out.Print();
|
||||||
|
|
||||||
|
printf("MaxError=%g/%g = %g\n", maxerr, wsum, maxerr / wsum);
|
||||||
|
printf("maxlevel = %lu, usedlevel=%lu, limit_size=%lu\n", sketch.nlevel, sketch.level.size(), sketch.limit_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
@ -30,6 +30,9 @@ int main(int argc, char *argv[]) {
|
|||||||
if (!strcmp(method, "wq")) {
|
if (!strcmp(method, "wq")) {
|
||||||
test<utils::WQuantileSketch<float, float>, float>();
|
test<utils::WQuantileSketch<float, float>, float>();
|
||||||
}
|
}
|
||||||
|
if (!strcmp(method, "wx")) {
|
||||||
|
test<utils::WXQuantileSketch<float, float>, float>();
|
||||||
|
}
|
||||||
if (!strcmp(method, "gk")) {
|
if (!strcmp(method, "gk")) {
|
||||||
test<utils::GKQuantileSketch<float, unsigned>, unsigned>();
|
test<utils::GKQuantileSketch<float, unsigned>, unsigned>();
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user