[py] fix label encoding of eval sets in sklearn API (#1244)
This commit is contained in:
parent
197b4c6b18
commit
75d9be55de
@ -414,9 +414,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
else:
|
||||
xgb_options.update({"eval_metric": eval_metric})
|
||||
|
||||
self._le = XGBLabelEncoder().fit(y)
|
||||
training_labels = self._le.transform(y)
|
||||
|
||||
if eval_set is not None:
|
||||
# TODO: use sample_weight if given?
|
||||
evals = list(DMatrix(x[0], label=x[1], missing=self.missing) for x in eval_set)
|
||||
evals = list(
|
||||
DMatrix(x[0], label=self._le.transform(x[1]), missing=self.missing)
|
||||
for x in eval_set
|
||||
)
|
||||
nevals = len(evals)
|
||||
eval_names = ["validation_{}".format(i) for i in range(nevals)]
|
||||
evals = list(zip(evals, eval_names))
|
||||
@ -425,9 +431,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
|
||||
self._features_count = X.shape[1]
|
||||
|
||||
self._le = XGBLabelEncoder().fit(y)
|
||||
training_labels = self._le.transform(y)
|
||||
|
||||
if sample_weight is not None:
|
||||
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
|
||||
missing=self.missing)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user