Assert matching length of evaluation inputs. (#5540)
This commit is contained in:
parent
c69a19e2b1
commit
93df871c8c
@ -514,6 +514,8 @@ class XGBModel(XGBModelBase):
|
||||
raise TypeError('Unexpected input type for `eval_set`')
|
||||
if sample_weight_eval_set is None:
|
||||
sample_weight_eval_set = [None] * len(eval_set)
|
||||
else:
|
||||
assert len(eval_set) == len(sample_weight_eval_set)
|
||||
evals = list(
|
||||
DMatrix(eval_set[i][0], label=eval_set[i][1], missing=self.missing,
|
||||
weight=sample_weight_eval_set[i], nthread=self.n_jobs)
|
||||
@ -792,6 +794,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
if eval_set is not None:
|
||||
if sample_weight_eval_set is None:
|
||||
sample_weight_eval_set = [None] * len(eval_set)
|
||||
else:
|
||||
assert len(sample_weight_eval_set) == len(eval_set)
|
||||
evals = list(
|
||||
DMatrix(eval_set[i][0],
|
||||
label=self._le.transform(eval_set[i][1]),
|
||||
|
||||
@ -596,6 +596,17 @@ def test_validation_weights_xgbmodel():
|
||||
assert all((logloss_with_weights[i] != logloss_without_weights[i]
|
||||
for i in [0, 1]))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# length of eval set and sample weight doesn't match.
|
||||
clf.fit(X_train, y_train, sample_weight=weights_train,
|
||||
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||
sample_weight_eval_set=[weights_train])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
cls = xgb.XGBClassifier()
|
||||
cls.fit(X_train, y_train, sample_weight=weights_train,
|
||||
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||
sample_weight_eval_set=[weights_train])
|
||||
|
||||
def test_validation_weights_xgbclassifier():
|
||||
from sklearn.datasets import make_hastie_10_2
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user