Assert matching length of evaluation inputs. (#5540)

This commit is contained in:
Jiaming Yuan 2020-04-18 06:52:55 +08:00 committed by GitHub
parent c69a19e2b1
commit 93df871c8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 0 deletions

View File

@ -514,6 +514,8 @@ class XGBModel(XGBModelBase):
raise TypeError('Unexpected input type for `eval_set`')
if sample_weight_eval_set is None:
sample_weight_eval_set = [None] * len(eval_set)
else:
assert len(eval_set) == len(sample_weight_eval_set)
evals = list(
DMatrix(eval_set[i][0], label=eval_set[i][1], missing=self.missing,
weight=sample_weight_eval_set[i], nthread=self.n_jobs)
@ -792,6 +794,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
if eval_set is not None:
if sample_weight_eval_set is None:
sample_weight_eval_set = [None] * len(eval_set)
else:
assert len(sample_weight_eval_set) == len(eval_set)
evals = list(
DMatrix(eval_set[i][0],
label=self._le.transform(eval_set[i][1]),

View File

@ -596,6 +596,17 @@ def test_validation_weights_xgbmodel():
assert all((logloss_with_weights[i] != logloss_without_weights[i]
for i in [0, 1]))
with pytest.raises(AssertionError):
# length of eval set and sample weight doesn't match.
clf.fit(X_train, y_train, sample_weight=weights_train,
eval_set=[(X_train, y_train), (X_test, y_test)],
sample_weight_eval_set=[weights_train])
with pytest.raises(AssertionError):
cls = xgb.XGBClassifier()
cls.fit(X_train, y_train, sample_weight=weights_train,
eval_set=[(X_train, y_train), (X_test, y_test)],
sample_weight_eval_set=[weights_train])
def test_validation_weights_xgbclassifier():
from sklearn.datasets import make_hastie_10_2