[py] fix label encoding of eval sets in sklearn API (#1244)
This commit is contained in:
parent
197b4c6b18
commit
75d9be55de
@ -414,9 +414,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
else:
|
else:
|
||||||
xgb_options.update({"eval_metric": eval_metric})
|
xgb_options.update({"eval_metric": eval_metric})
|
||||||
|
|
||||||
|
self._le = XGBLabelEncoder().fit(y)
|
||||||
|
training_labels = self._le.transform(y)
|
||||||
|
|
||||||
if eval_set is not None:
|
if eval_set is not None:
|
||||||
# TODO: use sample_weight if given?
|
# TODO: use sample_weight if given?
|
||||||
evals = list(DMatrix(x[0], label=x[1], missing=self.missing) for x in eval_set)
|
evals = list(
|
||||||
|
DMatrix(x[0], label=self._le.transform(x[1]), missing=self.missing)
|
||||||
|
for x in eval_set
|
||||||
|
)
|
||||||
nevals = len(evals)
|
nevals = len(evals)
|
||||||
eval_names = ["validation_{}".format(i) for i in range(nevals)]
|
eval_names = ["validation_{}".format(i) for i in range(nevals)]
|
||||||
evals = list(zip(evals, eval_names))
|
evals = list(zip(evals, eval_names))
|
||||||
@ -425,9 +431,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
|
|
||||||
self._features_count = X.shape[1]
|
self._features_count = X.shape[1]
|
||||||
|
|
||||||
self._le = XGBLabelEncoder().fit(y)
|
|
||||||
training_labels = self._le.transform(y)
|
|
||||||
|
|
||||||
if sample_weight is not None:
|
if sample_weight is not None:
|
||||||
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
|
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
|
||||||
missing=self.missing)
|
missing=self.missing)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user