Add sample_weight to eval_metric (#8706)
This commit is contained in:
parent
dd79ab846f
commit
213b5602d9
@ -119,7 +119,11 @@ def _metric_decorator(func: Callable) -> Metric:
|
||||
|
||||
def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
|
||||
y_true = dmatrix.get_label()
|
||||
return func.__name__, func(y_true, y_score)
|
||||
weight = dmatrix.get_weight()
|
||||
if weight.size == 0:
|
||||
return func.__name__, func(y_true, y_score)
|
||||
else:
|
||||
return func.__name__, func(y_true, y_score, sample_weight=weight)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@ -1417,3 +1417,41 @@ def test_evaluation_metric():
|
||||
with pytest.raises(AssertionError):
|
||||
# shape check inside the `merror` function
|
||||
clf.fit(X, y, eval_set=[(X, y)])
|
||||
|
||||
def test_weighted_evaluation_metric():
|
||||
from sklearn.datasets import make_hastie_10_2
|
||||
from sklearn.metrics import log_loss
|
||||
X, y = make_hastie_10_2(n_samples=2000, random_state=42)
|
||||
labels, y = np.unique(y, return_inverse=True)
|
||||
X_train, X_test = X[:1600], X[1600:]
|
||||
y_train, y_test = y[:1600], y[1600:]
|
||||
weights_eval_set = np.random.choice([1, 2], len(X_test))
|
||||
|
||||
np.random.seed(0)
|
||||
weights_train = np.random.choice([1, 2], len(X_train))
|
||||
|
||||
clf = xgb.XGBClassifier(
|
||||
tree_method="hist",
|
||||
eval_metric=log_loss,
|
||||
n_estimators=16,
|
||||
objective="binary:logistic",
|
||||
)
|
||||
clf.fit(X_train, y_train, sample_weight=weights_train, eval_set=[(X_test, y_test)],
|
||||
sample_weight_eval_set=[weights_eval_set])
|
||||
custom = clf.evals_result()
|
||||
|
||||
clf = xgb.XGBClassifier(
|
||||
tree_method="hist",
|
||||
eval_metric="logloss",
|
||||
n_estimators=16,
|
||||
objective="binary:logistic"
|
||||
)
|
||||
clf.fit(X_train, y_train, sample_weight=weights_train, eval_set=[(X_test, y_test)],
|
||||
sample_weight_eval_set=[weights_eval_set])
|
||||
internal = clf.evals_result()
|
||||
|
||||
np.testing.assert_allclose(
|
||||
custom["validation_0"]["log_loss"],
|
||||
internal["validation_0"]["logloss"],
|
||||
atol=1e-6
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user