Fix gamma deviance (#6761)

This commit is contained in:
Jiaming Yuan 2021-03-20 01:56:17 +08:00 committed by GitHub
parent c2b6b80600
commit 23b4165a6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 8 deletions

View File

@ -274,18 +274,27 @@ struct EvalPoissonNegLogLik {
}
};
/**
* Gamma deviance
*
* Expected input:
* label >= 0
* predt >= 0
*/
struct EvalGammaDeviance {
const char *Name() const {
return "gamma-deviance";
const char *Name() const { return "gamma-deviance"; }
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float predt) const {
predt += kRtEps;
label += kRtEps;
return std::log(predt / label) + label / predt - 1;
}
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
bst_float epsilon = 1.0e-9;
bst_float tmp = label / (pred + epsilon);
return tmp - std::log(tmp) - 1;
}
static bst_float GetFinal(bst_float esum, bst_float wsum) {
return 2 * esum;
if (wsum <= 0) {
wsum = kRtEps;
}
return 2 * esum / wsum;
}
};

View File

@ -103,3 +103,23 @@ class TestEvalMetrics:
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0]
@pytest.mark.skipif(**tm.no_sklearn())
def test_gamma_deviance(self):
from sklearn.metrics import mean_gamma_deviance
rng = np.random.RandomState(1994)
n_samples = 100
n_features = 30
X = rng.randn(n_samples, n_features)
y = rng.randn(n_samples)
y = y - y.min() * 100
reg = xgb.XGBRegressor(tree_method="hist", objective="reg:gamma", n_estimators=10)
reg.fit(X, y, eval_metric="gamma-deviance")
booster = reg.get_booster()
score = reg.predict(X)
gamma_dev = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1].split(":")[0])
skl_gamma_dev = mean_gamma_deviance(y, score)
np.testing.assert_allclose(gamma_dev, skl_gamma_dev, rtol=1e-6)