cap second order gradient

tqchen 2015-03-25 12:08:53 -07:00
parent 53c9a7b66b
commit 08fb205102
2 changed files with 7 additions and 3 deletions


@@ -82,11 +82,13 @@ struct LossType {
    * \return second order gradient
    */
   inline float SecondOrderGradient(float predt, float label) const {
+    // cap second order gradient to positive value
+    const float eps = 1e-16f;
     switch (loss_type) {
       case kLinearSquare: return 1.0f;
       case kLogisticRaw: predt = 1.0f / (1.0f + std::exp(-predt));
       case kLogisticClassify:
-      case kLogisticNeglik: return predt * (1 - predt);
+      case kLogisticNeglik: return std::max(predt * (1.0f - predt), eps);
       default: utils::Error("unknown loss_type"); return 0.0f;
     }
   }
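The cap matters because boosting's Newton step sets each leaf weight to roughly -G / H: for the logistic losses the hessian predt * (1 - predt) underflows to exactly zero in float once a prediction saturates, so an uncapped value turns that step into a division by zero. Below is a minimal, self-contained sketch of the failure mode, assuming only the capped formula from this hunk; the standalone program and its variable names are illustrative, not code from the repository.

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const float eps = 1e-16f;  // same floor as the commit
  // A confidently wrong prediction: raw margin of 20 on a label-0 example.
  float predt = 1.0f / (1.0f + std::exp(-20.0f));  // rounds to exactly 1.0f
  float grad = predt - 0.0f;                       // first-order gradient, label = 0
  float hess_raw = predt * (1.0f - predt);         // underflows to 0.0f
  float hess_cap = std::max(hess_raw, eps);        // what the patched code returns
  std::printf("raw    hessian %g -> step %g\n", hess_raw, -grad / hess_raw);  // -inf
  std::printf("capped hessian %g -> step %g\n", hess_cap, -grad / hess_cap);  // finite
  return 0;
}

With the cap, the denominator stays strictly positive, and the L2 regularization term normally added to H bounds the step further.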


@@ -296,14 +296,16 @@ struct WXQSummary : public WQSummary<DType, RType> {
     }
     RType begin = src.data[0].rmax;
     size_t n = maxsize - 1, nbig = 0;
-    const RType range = src.data[src.size - 1].rmin - begin;
+    RType range = src.data[src.size - 1].rmin - begin;
     // prune off zero weights
-    if (range == 0) {
+    if (range == 0.0f) {
       // special case, contain only two effective data pts
       this->data[0] = src.data[0];
       this->data[1] = src.data[src.size - 1];
       this->size = 2;
       return;
+    } else {
+      range = std::max(range, static_cast<RType>(1e-3f));
     }
     const RType chunk = 2 * range / n;
     // minimized range
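This second hunk hardens the pruning path in WXQSummary: the summary is scanned in chunks of width 2 * range / n, and when nearly all rank mass sits on the two endpoint entries, range can be a tiny positive value that slips past the range == 0 special case and drives the chunk width toward zero. Flooring range at 1e-3 keeps the chunk width usable. A small sketch of the arithmetic, assuming nothing beyond this hunk; the concrete numbers are made up for illustration.

#include <algorithm>
#include <cstdio>

int main() {
  typedef float RType;
  const size_t maxsize = 16;
  const size_t n = maxsize - 1;
  // Nearly all weight on the two endpoints: range is positive but tiny,
  // so it escapes the range == 0 branch above.
  RType range = 1e-7f;
  RType floored = std::max(range, static_cast<RType>(1e-3f));
  std::printf("chunk without floor = %g\n", 2 * range / n);    // ~1.3e-8
  std::printf("chunk with floor    = %g\n", 2 * floored / n);  // ~1.3e-4
  return 0;
}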