Fix gamma deviance (#6761)

This commit is contained in:
Jiaming Yuan
2021-03-20 01:56:17 +08:00
committed by GitHub
parent c2b6b80600
commit 23b4165a6b
2 changed files with 37 additions and 8 deletions

View File

@@ -103,3 +103,23 @@ class TestEvalMetrics:
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0]
@pytest.mark.skipif(**tm.no_sklearn())
def test_gamma_deviance(self):
from sklearn.metrics import mean_gamma_deviance
rng = np.random.RandomState(1994)
n_samples = 100
n_features = 30
X = rng.randn(n_samples, n_features)
y = rng.randn(n_samples)
y = y - y.min() * 100
reg = xgb.XGBRegressor(tree_method="hist", objective="reg:gamma", n_estimators=10)
reg.fit(X, y, eval_metric="gamma-deviance")
booster = reg.get_booster()
score = reg.predict(X)
gamma_dev = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1].split(":")[0])
skl_gamma_dev = mean_gamma_deviance(y, score)
np.testing.assert_allclose(gamma_dev, skl_gamma_dev, rtol=1e-6)