From e7ea87b5fd75b96b36111d0f6c2ad20fb549d3c1 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 10 Nov 2014 22:03:42 -0800 Subject: [PATCH] ok for now --- src/utils/quantile.h | 6 ++-- test/test_quantile.cpp | 74 +++++++++++++++++++++++++++++++++++------- 2 files changed, 66 insertions(+), 14 deletions(-) diff --git a/src/utils/quantile.h b/src/utils/quantile.h index 5a002c4cb..62dc36e6c 100644 --- a/src/utils/quantile.h +++ b/src/utils/quantile.h @@ -174,9 +174,9 @@ struct WQSummary { for (size_t k = 1; k < n; ++k) { RType dx2 = 2 * ((k * range) / n + begin); // find first i such that d < (rmax[i+1] + rmin[i+1]) / 2 - while (i < src.size - 1 && - dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i; - if (i == src.size - 1) break; + while (i < src.size - 1 + && dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i; + utils::Assert(i != src.size - 1, "this cannot happen"); if (dx2 < src.data[i].rmin_next() + src.data[i + 1].rmax_prev()) { if (i != lastidx) { data[size++] = src.data[i]; lastidx = i; diff --git a/test/test_quantile.cpp b/test/test_quantile.cpp index 0fed6bf49..c1b85668d 100644 --- a/test/test_quantile.cpp +++ b/test/test_quantile.cpp @@ -1,40 +1,92 @@ #include #include +#include using namespace xgboost; + +struct Entry { + double x, w, rmin; + inline bool operator<(const Entry &e) const { + return x < e.x; + } +}; + +inline void MakeQuantile(std::vector &dat) { + std::sort(dat.begin(), dat.end()); + size_t top = 0; + double wsum = 0.0; + for (size_t i = 0; i < dat.size();) { + size_t j = i + 1; + for (;j < dat.size() && dat[i].x == dat[j].x; ++j) { + dat[i].w += dat[j].w; + } + dat[top] = dat[i]; + dat[top].rmin = wsum; + wsum += dat[top].w; + ++top; + i = j; + } + dat.resize(top); +} + +template +inline void verifyWQ(std::vector &dat, Summary out) { + MakeQuantile(dat); + size_t j = 0; + double err = 0.0; + const double eps = 1e-4; + for (size_t i = 0; i < out.size; ++i) { + while (j < dat.size() && dat[j].x < out.data[i].value) ++j; + utils::Assert(j < dat.size() && fabs(dat[j].x - out.data[i].value) < eps, "bug"); + err = std::min(dat[j].rmin - out.data[i].rmin, err); + err = std::min(out.data[i].rmax - dat[j].rmin + dat[j].w, err); + err = std::min(dat[j].w - out.data[i].wmin, err); + } + if (err < 0.0) err = -err; + printf("verify correctness, max-constraint-violation=%g (0 means perfect, coubld be nonzero due to floating point)\n", err); +} + template -inline void test(void) { +inline typename Sketch::SummaryContainer test(std::vector &dat) { Sketch sketch; size_t n; double wsum = 0.0; - float eps, x, w; + float eps; utils::Check(scanf("%lu%f", &n, &eps) == 2, "needs to start with n eps"); sketch.Init(n, eps); - while (scanf("%f%f", &x, &w) == 2) { - sketch.Push(x, static_cast(w)); - wsum += w; + Entry e; + while (scanf("%lf%lf", &e.x, &e.w) == 2) { + dat.push_back(e); + wsum += e.w; } - sketch.CheckValid(static_cast(0.1)); + clock_t start = clock(); + for (size_t i = 0; i < dat.size(); ++i) { + sketch.Push(dat[i].x, dat[i].w); + } + double tcost = static_cast(clock() - start) / CLOCKS_PER_SEC; typename Sketch::SummaryContainer out; - sketch.GetSummary(&out); + sketch.GetSummary(&out); double maxerr = static_cast(out.MaxError()); out.Print(); - + printf("-------------------------\n"); + printf("timecost=%g sec\n", tcost); printf("MaxError=%g/%g = %g\n", maxerr, wsum, maxerr / wsum); printf("maxlevel = %lu, usedlevel=%lu, limit_size=%lu\n", sketch.nlevel, sketch.level.size(), sketch.limit_size); + return out; } int main(int argc, char *argv[]) { const char *method = "wq"; if (argc > 1) method = argv[1]; + std::vector dat; if (!strcmp(method, "wq")) { - test, float>(); + verifyWQ(dat, test, float>(dat)); } if (!strcmp(method, "wx")) { - test, float>(); + verifyWQ(dat, test, float>(dat)); } if (!strcmp(method, "gk")) { - test, unsigned>(); + test, unsigned>(dat); } return 0; }