cap second order gradient
This commit is contained in:
parent
53c9a7b66b
commit
08fb205102
@ -82,11 +82,13 @@ struct LossType {
|
|||||||
* \return second order gradient
|
* \return second order gradient
|
||||||
*/
|
*/
|
||||||
inline float SecondOrderGradient(float predt, float label) const {
|
inline float SecondOrderGradient(float predt, float label) const {
|
||||||
|
// cap second order gradient to postive value
|
||||||
|
const float eps = 1e-16f;
|
||||||
switch (loss_type) {
|
switch (loss_type) {
|
||||||
case kLinearSquare: return 1.0f;
|
case kLinearSquare: return 1.0f;
|
||||||
case kLogisticRaw: predt = 1.0f / (1.0f + std::exp(-predt));
|
case kLogisticRaw: predt = 1.0f / (1.0f + std::exp(-predt));
|
||||||
case kLogisticClassify:
|
case kLogisticClassify:
|
||||||
case kLogisticNeglik: return predt * (1 - predt);
|
case kLogisticNeglik: return std::max(predt * (1.0f - predt), eps);
|
||||||
default: utils::Error("unknown loss_type"); return 0.0f;
|
default: utils::Error("unknown loss_type"); return 0.0f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -296,14 +296,16 @@ struct WXQSummary : public WQSummary<DType, RType> {
|
|||||||
}
|
}
|
||||||
RType begin = src.data[0].rmax;
|
RType begin = src.data[0].rmax;
|
||||||
size_t n = maxsize - 1, nbig = 0;
|
size_t n = maxsize - 1, nbig = 0;
|
||||||
const RType range = src.data[src.size - 1].rmin - begin;
|
RType range = src.data[src.size - 1].rmin - begin;
|
||||||
// prune off zero weights
|
// prune off zero weights
|
||||||
if (range == 0) {
|
if (range == 0.0f) {
|
||||||
// special case, contain only two effective data pts
|
// special case, contain only two effective data pts
|
||||||
this->data[0] = src.data[0];
|
this->data[0] = src.data[0];
|
||||||
this->data[1] = src.data[src.size - 1];
|
this->data[1] = src.data[src.size - 1];
|
||||||
this->size = 2;
|
this->size = 2;
|
||||||
return;
|
return;
|
||||||
|
} else {
|
||||||
|
range = std::max(range, static_cast<RType>(1e-3f));
|
||||||
}
|
}
|
||||||
const RType chunk = 2 * range / n;
|
const RType chunk = 2 * range / n;
|
||||||
// minimized range
|
// minimized range
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user