Assert matching length of evaluation inputs. (#5540)
This commit is contained in:
parent
c69a19e2b1
commit
93df871c8c
@ -514,6 +514,8 @@ class XGBModel(XGBModelBase):
|
|||||||
raise TypeError('Unexpected input type for `eval_set`')
|
raise TypeError('Unexpected input type for `eval_set`')
|
||||||
if sample_weight_eval_set is None:
|
if sample_weight_eval_set is None:
|
||||||
sample_weight_eval_set = [None] * len(eval_set)
|
sample_weight_eval_set = [None] * len(eval_set)
|
||||||
|
else:
|
||||||
|
assert len(eval_set) == len(sample_weight_eval_set)
|
||||||
evals = list(
|
evals = list(
|
||||||
DMatrix(eval_set[i][0], label=eval_set[i][1], missing=self.missing,
|
DMatrix(eval_set[i][0], label=eval_set[i][1], missing=self.missing,
|
||||||
weight=sample_weight_eval_set[i], nthread=self.n_jobs)
|
weight=sample_weight_eval_set[i], nthread=self.n_jobs)
|
||||||
@ -792,6 +794,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
if eval_set is not None:
|
if eval_set is not None:
|
||||||
if sample_weight_eval_set is None:
|
if sample_weight_eval_set is None:
|
||||||
sample_weight_eval_set = [None] * len(eval_set)
|
sample_weight_eval_set = [None] * len(eval_set)
|
||||||
|
else:
|
||||||
|
assert len(sample_weight_eval_set) == len(eval_set)
|
||||||
evals = list(
|
evals = list(
|
||||||
DMatrix(eval_set[i][0],
|
DMatrix(eval_set[i][0],
|
||||||
label=self._le.transform(eval_set[i][1]),
|
label=self._le.transform(eval_set[i][1]),
|
||||||
|
|||||||
@ -596,6 +596,17 @@ def test_validation_weights_xgbmodel():
|
|||||||
assert all((logloss_with_weights[i] != logloss_without_weights[i]
|
assert all((logloss_with_weights[i] != logloss_without_weights[i]
|
||||||
for i in [0, 1]))
|
for i in [0, 1]))
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# length of eval set and sample weight doesn't match.
|
||||||
|
clf.fit(X_train, y_train, sample_weight=weights_train,
|
||||||
|
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||||
|
sample_weight_eval_set=[weights_train])
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
cls = xgb.XGBClassifier()
|
||||||
|
cls.fit(X_train, y_train, sample_weight=weights_train,
|
||||||
|
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||||
|
sample_weight_eval_set=[weights_train])
|
||||||
|
|
||||||
def test_validation_weights_xgbclassifier():
|
def test_validation_weights_xgbclassifier():
|
||||||
from sklearn.datasets import make_hastie_10_2
|
from sklearn.datasets import make_hastie_10_2
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user