Assert matching length of evaluation inputs. (#5540)

This commit is contained in:
Jiaming Yuan
2020-04-18 06:52:55 +08:00
committed by GitHub
parent c69a19e2b1
commit 93df871c8c
2 changed files with 15 additions and 0 deletions

View File

@@ -514,6 +514,8 @@ class XGBModel(XGBModelBase):
raise TypeError('Unexpected input type for `eval_set`')
if sample_weight_eval_set is None:
sample_weight_eval_set = [None] * len(eval_set)
else:
assert len(eval_set) == len(sample_weight_eval_set)
evals = list(
DMatrix(eval_set[i][0], label=eval_set[i][1], missing=self.missing,
weight=sample_weight_eval_set[i], nthread=self.n_jobs)
@@ -792,6 +794,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
if eval_set is not None:
if sample_weight_eval_set is None:
sample_weight_eval_set = [None] * len(eval_set)
else:
assert len(sample_weight_eval_set) == len(eval_set)
evals = list(
DMatrix(eval_set[i][0],
label=self._le.transform(eval_set[i][1]),