diff --git a/src/learner/evaluation-inl.hpp b/src/learner/evaluation-inl.hpp index ab68a6ec6..8e63e83ec 100644 --- a/src/learner/evaluation-inl.hpp +++ b/src/learner/evaluation-inl.hpp @@ -83,7 +83,15 @@ struct EvalLogLoss : public EvalEWiseBase { return "logloss"; } inline static float EvalRow(float y, float py) { - return - y * std::log(py) - (1.0f - y) * std::log(1 - py); + const float eps = 1e-16f; + const float pneg = 1.0f - py; + if (py < eps) { + return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps); + } else if (pneg < eps) { + return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps); + } else { + return -y * std::log(py) - (1.0f - y) * std::log(pneg); + } } }; diff --git a/src/learner/helper_utils.h b/src/learner/helper_utils.h index ac1ec745e..d318cf8bd 100644 --- a/src/learner/helper_utils.h +++ b/src/learner/helper_utils.h @@ -43,6 +43,26 @@ inline static int FindMaxIndex(const std::vector& rec) { return FindMaxIndex(BeginPtr(rec), rec.size()); } +// perform numerical safe logsum +inline float LogSum(float x, float y) { + if (x < y) { + return y + std::log(std::exp(x - y) + 1.0f); + } else { + return x + std::log(std::exp(y - x) + 1.0f); + } +} +// numerical safe logsum +inline float LogSum(const float *rec, size_t size) { + float mx = rec[0]; + for (size_t i = 1; i < size; ++i) { + mx = std::max(mx, rec[i]); + } + float sum = 0.0f; + for (size_t i = 0; i < size; ++i) { + sum += std::exp(rec[i] - mx); + } + return mx + std::log(sum); +} inline static bool CmpFirst(const std::pair &a, const std::pair &b) { diff --git a/src/utils/quantile.h b/src/utils/quantile.h index fe8589ad8..47ada7210 100644 --- a/src/utils/quantile.h +++ b/src/utils/quantile.h @@ -297,6 +297,14 @@ struct WXQSummary : public WQSummary { RType begin = src.data[0].rmax; size_t n = maxsize - 1, nbig = 0; const RType range = src.data[src.size - 1].rmin - begin; + // prune off zero weights + if (range == 0) { + // special case, contain only two effective data pts + this->data[0] = src.data[0]; + this->data[1] = src.data[src.size - 1]; + this->size = 2; + return; + } const RType chunk = 2 * range / n; // minimized range RType mrange = 0; @@ -323,9 +331,9 @@ struct WXQSummary : public WQSummary { src.size, maxsize, static_cast(range), static_cast(chunk)); for (size_t i = 0; i < src.size; ++i) { - printf("[%lu] rmin=%g, rmax=%g, wmin=%g, isbig=%d\n", i, - src.data[i].rmin, src.data[i].rmax, src.data[i].wmin, - CheckLarge(src.data[i], chunk)); + printf("[%lu] rmin=%g, rmax=%g, wmin=%g, v=%g, isbig=%d\n", i, + src.data[i].rmin, src.data[i].rmax, src.data[i].wmin, + src.data[i].value, CheckLarge(src.data[i], chunk)); } utils::Assert(nbig < n - 1, "quantile: too many large chunk"); } @@ -631,6 +639,7 @@ class QuantileSketchTemplate { * \param x the elemented added to the sketch */ inline void Push(DType x, RType w = 1) { + if (w == static_cast(0)) return; if (inqueue.qtail == inqueue.queue.size()) { // jump from lazy one value to limit_size * 2 if (inqueue.queue.size() == 1) {