Fix gamma deviance (#6761)
This commit is contained in:
@@ -103,3 +103,23 @@ class TestEvalMetrics:
|
||||
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
|
||||
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
|
||||
assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0]
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_gamma_deviance(self):
|
||||
from sklearn.metrics import mean_gamma_deviance
|
||||
rng = np.random.RandomState(1994)
|
||||
n_samples = 100
|
||||
n_features = 30
|
||||
|
||||
X = rng.randn(n_samples, n_features)
|
||||
y = rng.randn(n_samples)
|
||||
y = y - y.min() * 100
|
||||
|
||||
reg = xgb.XGBRegressor(tree_method="hist", objective="reg:gamma", n_estimators=10)
|
||||
reg.fit(X, y, eval_metric="gamma-deviance")
|
||||
|
||||
booster = reg.get_booster()
|
||||
score = reg.predict(X)
|
||||
gamma_dev = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1].split(":")[0])
|
||||
skl_gamma_dev = mean_gamma_deviance(y, score)
|
||||
np.testing.assert_allclose(gamma_dev, skl_gamma_dev, rtol=1e-6)
|
||||
|
||||
Reference in New Issue
Block a user