From 9d101b47f902d36cb4a7ac37d5963246eddc16f2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 10 Nov 2014 21:18:37 -0800 Subject: [PATCH] optimize heavy hitter --- src/utils/quantile.h | 94 ++++++++++++++++++++++++++++++++++++++++-- test/mkquantest.py | 6 ++- test/test_quantile.cpp | 9 ++-- 3 files changed, 101 insertions(+), 8 deletions(-) diff --git a/src/utils/quantile.h b/src/utils/quantile.h index 5dd3d9059..5a002c4cb 100644 --- a/src/utils/quantile.h +++ b/src/utils/quantile.h @@ -1,7 +1,7 @@ #ifndef XGBOOST_UTILS_QUANTILE_H_ #define XGBOOST_UTILS_QUANTILE_H_ /*! - * \file quantile + * \file quantile.h * \brief util to compute quantiles * \author Tianqi Chen */ @@ -14,7 +14,6 @@ namespace xgboost { namespace utils { - /*! * \brief experimental wsummary * \tparam DType type of data content @@ -172,7 +171,7 @@ struct WQSummary { this->size = 1; // lastidx is used to avoid duplicated records size_t i = 1, lastidx = 0; - for (RType k = 1; k < n; ++k) { + for (size_t k = 1; k < n; ++k) { RType dx2 = 2 * ((k * range) / n + begin); // find first i such that d < (rmax[i+1] + rmin[i+1]) / 2 while (i < src.size - 1 && @@ -246,7 +245,84 @@ struct WQSummary { utils::Assert(size <= sa.size + sb.size, "bug in combine"); } }; - +/*! \brief try to do efficient prunning */ +template +struct WXQSummary : public WQSummary { + // redefine entry type + typedef typename WQSummary::Entry Entry; + // constructor + WXQSummary(Entry *data, size_t size) + : WQSummary(data, size) {} + // check if the block is large chunk + inline static bool CheckLarge(const Entry &e, RType chunk) { + return e.rmin_next() > e.rmax_prev() + chunk; + } + // set prune + inline void SetPrune(const WXQSummary &src, RType maxsize) { + if (src.size <= maxsize) { + this->CopyFrom(src); return; + } + RType begin = src.data[0].rmax; + size_t n = maxsize - 1, nbig = 0; + const RType range = src.data[src.size - 1].rmin - begin; + const RType chunk = 2 * range / n; + // minimized range + RType mrange = 0; + { + // first scan, grab all the big chunk + // moviing block index + size_t bid = 0; + for (size_t i = 1; i < src.size; ++i) { + if (CheckLarge(src.data[i], chunk)) { + if (bid != i - 1) { + mrange += src.data[i].rmax_prev() - src.data[bid].rmin_next(); + } + bid = i; ++nbig; + } + } + if (bid != src.size - 2) { + mrange += src.data[src.size-1].rmax_prev() - src.data[bid].rmin_next(); + } + } + utils::Assert(nbig < n - 1, "too many large chunk"); + this->data[0] = src.data[0]; + this->size = 1; + // use smaller size + n = n - nbig; + // find the rest of point + size_t bid = 0, k = 1, lastidx = 0; + for (size_t end = 1; end < src.size; ++end) { + if (end == src.size - 1 || CheckLarge(src.data[end], chunk)) { + if (bid != end - 1) { + size_t i = bid; + RType maxdx2 = src.data[end].rmax_prev() * 2; + for (; k < n; ++k) { + RType dx2 = 2 * ((k * mrange) / n + begin); + if (dx2 >= maxdx2) break; + while (i < end && + dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i; + if (dx2 < src.data[i].rmin_next() + src.data[i + 1].rmax_prev()) { + if (i != lastidx) { + this->data[this->size++] = src.data[i]; lastidx = i; + } + } else { + if (i + 1 != lastidx) { + this->data[this->size++] = src.data[i + 1]; lastidx = i + 1; + } + } + } + } + if (lastidx != end) { + this->data[this->size++] = src.data[end]; + lastidx = end; + } + bid = end; + // shift base by the gap + begin += src.data[bid].rmin_next() - src.data[bid].rmax_prev(); + } + } + } +}; /*! * \brief traditional GK summary */ @@ -564,6 +640,16 @@ template class WQuantileSketch : public QuantileSketchTemplate >{ }; + +/*! + * \brief Quantile sketch use WXQSummary + * \tparam DType type of data content + * \tparam RType type of rank + */ +template +class WXQuantileSketch : + public QuantileSketchTemplate >{ +}; /*! * \brief Quantile sketch use WQSummary * \tparam DType type of data content diff --git a/test/mkquantest.py b/test/mkquantest.py index f228dc1eb..48d837577 100755 --- a/test/mkquantest.py +++ b/test/mkquantest.py @@ -7,7 +7,11 @@ import subprocess funcs = { 'seq': 'lambda n: sorted([(x,1) for x in range(1,n+1)], key = lambda x:random.random())', 'seqlogw': 'lambda n: sorted([(x, math.log(x)) for x in range(1,n+1)], key = lambda x:random.random())', - 'lots0': 'lambda n: sorted([(max(x - n*3/4,0), 1) for x in range(1,n+1)], key = lambda x:random.random())' + 'lots0': 'lambda n: sorted([(max(x - n*3/4,0), 1) for x in range(1,n+1)], key = lambda x:random.random())', + 'lots9': 'lambda n: sorted([(9 if x > n / 4 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())', + 'lotsm': 'lambda n: sorted([(n/8 if x > n / 4 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())', + 'lotsmr': 'lambda n: sorted([( x * 4 / n + n / 20 if x > n / 10 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())', + 'lotsmr2': 'lambda n: sorted([( x * 10 / n + n / 20 if x > n / 10 else x, 1) for x in range(1,n+1)], key = lambda x:random.random())' } if len(sys.argv) < 3: diff --git a/test/test_quantile.cpp b/test/test_quantile.cpp index e6af5b1ec..0fed6bf49 100644 --- a/test/test_quantile.cpp +++ b/test/test_quantile.cpp @@ -2,7 +2,6 @@ #include using namespace xgboost; - template inline void test(void) { Sketch sketch; @@ -11,7 +10,6 @@ inline void test(void) { float eps, x, w; utils::Check(scanf("%lu%f", &n, &eps) == 2, "needs to start with n eps"); sketch.Init(n, eps); - printf("nlevel = %lu, limit_size=%lu\n", sketch.nlevel, sketch.limit_size); while (scanf("%f%f", &x, &w) == 2) { sketch.Push(x, static_cast(w)); wsum += w; @@ -20,8 +18,10 @@ inline void test(void) { typename Sketch::SummaryContainer out; sketch.GetSummary(&out); double maxerr = static_cast(out.MaxError()); - printf("MaxError=%g/%g = %g\n", maxerr, wsum, maxerr / wsum); out.Print(); + + printf("MaxError=%g/%g = %g\n", maxerr, wsum, maxerr / wsum); + printf("maxlevel = %lu, usedlevel=%lu, limit_size=%lu\n", sketch.nlevel, sketch.level.size(), sketch.limit_size); } int main(int argc, char *argv[]) { @@ -30,6 +30,9 @@ int main(int argc, char *argv[]) { if (!strcmp(method, "wq")) { test, float>(); } + if (!strcmp(method, "wx")) { + test, float>(); + } if (!strcmp(method, "gk")) { test, unsigned>(); }