ok for now
This commit is contained in:
parent
9d101b47f9
commit
e7ea87b5fd
@ -174,9 +174,9 @@ struct WQSummary {
|
|||||||
for (size_t k = 1; k < n; ++k) {
|
for (size_t k = 1; k < n; ++k) {
|
||||||
RType dx2 = 2 * ((k * range) / n + begin);
|
RType dx2 = 2 * ((k * range) / n + begin);
|
||||||
// find first i such that d < (rmax[i+1] + rmin[i+1]) / 2
|
// find first i such that d < (rmax[i+1] + rmin[i+1]) / 2
|
||||||
while (i < src.size - 1 &&
|
while (i < src.size - 1
|
||||||
dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
|
&& dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
|
||||||
if (i == src.size - 1) break;
|
utils::Assert(i != src.size - 1, "this cannot happen");
|
||||||
if (dx2 < src.data[i].rmin_next() + src.data[i + 1].rmax_prev()) {
|
if (dx2 < src.data[i].rmin_next() + src.data[i + 1].rmax_prev()) {
|
||||||
if (i != lastidx) {
|
if (i != lastidx) {
|
||||||
data[size++] = src.data[i]; lastidx = i;
|
data[size++] = src.data[i]; lastidx = i;
|
||||||
|
|||||||
@ -1,40 +1,92 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utils/quantile.h>
|
#include <utils/quantile.h>
|
||||||
|
#include <ctime>
|
||||||
using namespace xgboost;
|
using namespace xgboost;
|
||||||
|
|
||||||
|
|
||||||
|
struct Entry {
|
||||||
|
double x, w, rmin;
|
||||||
|
inline bool operator<(const Entry &e) const {
|
||||||
|
return x < e.x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline void MakeQuantile(std::vector<Entry> &dat) {
|
||||||
|
std::sort(dat.begin(), dat.end());
|
||||||
|
size_t top = 0;
|
||||||
|
double wsum = 0.0;
|
||||||
|
for (size_t i = 0; i < dat.size();) {
|
||||||
|
size_t j = i + 1;
|
||||||
|
for (;j < dat.size() && dat[i].x == dat[j].x; ++j) {
|
||||||
|
dat[i].w += dat[j].w;
|
||||||
|
}
|
||||||
|
dat[top] = dat[i];
|
||||||
|
dat[top].rmin = wsum;
|
||||||
|
wsum += dat[top].w;
|
||||||
|
++top;
|
||||||
|
i = j;
|
||||||
|
}
|
||||||
|
dat.resize(top);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Summary>
|
||||||
|
inline void verifyWQ(std::vector<Entry> &dat, Summary out) {
|
||||||
|
MakeQuantile(dat);
|
||||||
|
size_t j = 0;
|
||||||
|
double err = 0.0;
|
||||||
|
const double eps = 1e-4;
|
||||||
|
for (size_t i = 0; i < out.size; ++i) {
|
||||||
|
while (j < dat.size() && dat[j].x < out.data[i].value) ++j;
|
||||||
|
utils::Assert(j < dat.size() && fabs(dat[j].x - out.data[i].value) < eps, "bug");
|
||||||
|
err = std::min(dat[j].rmin - out.data[i].rmin, err);
|
||||||
|
err = std::min(out.data[i].rmax - dat[j].rmin + dat[j].w, err);
|
||||||
|
err = std::min(dat[j].w - out.data[i].wmin, err);
|
||||||
|
}
|
||||||
|
if (err < 0.0) err = -err;
|
||||||
|
printf("verify correctness, max-constraint-violation=%g (0 means perfect, coubld be nonzero due to floating point)\n", err);
|
||||||
|
}
|
||||||
|
|
||||||
template<typename Sketch, typename RType>
|
template<typename Sketch, typename RType>
|
||||||
inline void test(void) {
|
inline typename Sketch::SummaryContainer test(std::vector<Entry> &dat) {
|
||||||
Sketch sketch;
|
Sketch sketch;
|
||||||
size_t n;
|
size_t n;
|
||||||
double wsum = 0.0;
|
double wsum = 0.0;
|
||||||
float eps, x, w;
|
float eps;
|
||||||
utils::Check(scanf("%lu%f", &n, &eps) == 2, "needs to start with n eps");
|
utils::Check(scanf("%lu%f", &n, &eps) == 2, "needs to start with n eps");
|
||||||
sketch.Init(n, eps);
|
sketch.Init(n, eps);
|
||||||
while (scanf("%f%f", &x, &w) == 2) {
|
Entry e;
|
||||||
sketch.Push(x, static_cast<RType>(w));
|
while (scanf("%lf%lf", &e.x, &e.w) == 2) {
|
||||||
wsum += w;
|
dat.push_back(e);
|
||||||
|
wsum += e.w;
|
||||||
}
|
}
|
||||||
sketch.CheckValid(static_cast<RType>(0.1));
|
clock_t start = clock();
|
||||||
|
for (size_t i = 0; i < dat.size(); ++i) {
|
||||||
|
sketch.Push(dat[i].x, dat[i].w);
|
||||||
|
}
|
||||||
|
double tcost = static_cast<double>(clock() - start) / CLOCKS_PER_SEC;
|
||||||
typename Sketch::SummaryContainer out;
|
typename Sketch::SummaryContainer out;
|
||||||
sketch.GetSummary(&out);
|
sketch.GetSummary(&out);
|
||||||
double maxerr = static_cast<double>(out.MaxError());
|
double maxerr = static_cast<double>(out.MaxError());
|
||||||
out.Print();
|
out.Print();
|
||||||
|
printf("-------------------------\n");
|
||||||
|
printf("timecost=%g sec\n", tcost);
|
||||||
printf("MaxError=%g/%g = %g\n", maxerr, wsum, maxerr / wsum);
|
printf("MaxError=%g/%g = %g\n", maxerr, wsum, maxerr / wsum);
|
||||||
printf("maxlevel = %lu, usedlevel=%lu, limit_size=%lu\n", sketch.nlevel, sketch.level.size(), sketch.limit_size);
|
printf("maxlevel = %lu, usedlevel=%lu, limit_size=%lu\n", sketch.nlevel, sketch.level.size(), sketch.limit_size);
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
const char *method = "wq";
|
const char *method = "wq";
|
||||||
if (argc > 1) method = argv[1];
|
if (argc > 1) method = argv[1];
|
||||||
|
std::vector<Entry> dat;
|
||||||
if (!strcmp(method, "wq")) {
|
if (!strcmp(method, "wq")) {
|
||||||
test<utils::WQuantileSketch<float, float>, float>();
|
verifyWQ(dat, test<utils::WQuantileSketch<float, float>, float>(dat));
|
||||||
}
|
}
|
||||||
if (!strcmp(method, "wx")) {
|
if (!strcmp(method, "wx")) {
|
||||||
test<utils::WXQuantileSketch<float, float>, float>();
|
verifyWQ(dat, test<utils::WXQuantileSketch<float, float>, float>(dat));
|
||||||
}
|
}
|
||||||
if (!strcmp(method, "gk")) {
|
if (!strcmp(method, "gk")) {
|
||||||
test<utils::GKQuantileSketch<float, unsigned>, unsigned>();
|
test<utils::GKQuantileSketch<float, unsigned>, unsigned>(dat);
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user