Fix gamma deviance (#6761)
This commit is contained in:
parent
c2b6b80600
commit
23b4165a6b
@ -274,18 +274,27 @@ struct EvalPoissonNegLogLik {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Gamma deviance
|
||||
*
|
||||
* Expected input:
|
||||
* label >= 0
|
||||
* predt >= 0
|
||||
*/
|
||||
struct EvalGammaDeviance {
|
||||
const char *Name() const {
|
||||
return "gamma-deviance";
|
||||
const char *Name() const { return "gamma-deviance"; }
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float predt) const {
|
||||
predt += kRtEps;
|
||||
label += kRtEps;
|
||||
return std::log(predt / label) + label / predt - 1;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
bst_float epsilon = 1.0e-9;
|
||||
bst_float tmp = label / (pred + epsilon);
|
||||
return tmp - std::log(tmp) - 1;
|
||||
}
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return 2 * esum;
|
||||
if (wsum <= 0) {
|
||||
wsum = kRtEps;
|
||||
}
|
||||
return 2 * esum / wsum;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -103,3 +103,23 @@ class TestEvalMetrics:
|
||||
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
|
||||
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
|
||||
assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0]
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_gamma_deviance(self):
|
||||
from sklearn.metrics import mean_gamma_deviance
|
||||
rng = np.random.RandomState(1994)
|
||||
n_samples = 100
|
||||
n_features = 30
|
||||
|
||||
X = rng.randn(n_samples, n_features)
|
||||
y = rng.randn(n_samples)
|
||||
y = y - y.min() * 100
|
||||
|
||||
reg = xgb.XGBRegressor(tree_method="hist", objective="reg:gamma", n_estimators=10)
|
||||
reg.fit(X, y, eval_metric="gamma-deviance")
|
||||
|
||||
booster = reg.get_booster()
|
||||
score = reg.predict(X)
|
||||
gamma_dev = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1].split(":")[0])
|
||||
skl_gamma_dev = mean_gamma_deviance(y, score)
|
||||
np.testing.assert_allclose(gamma_dev, skl_gamma_dev, rtol=1e-6)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user