[py] fix label encoding of eval sets in sklearn API (#1244)

This commit is contained in:
Titouan Lorieul 2016-07-11 12:29:46 +02:00 committed by Yuan (Terry) Tang
parent 197b4c6b18
commit 75d9be55de

View File

@@ -414,9 +414,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
else:
xgb_options.update({"eval_metric": eval_metric})
self._le = XGBLabelEncoder().fit(y)
training_labels = self._le.transform(y)
if eval_set is not None:
# TODO: use sample_weight if given?
evals = list(DMatrix(x[0], label=x[1], missing=self.missing) for x in eval_set)
evals = list(
DMatrix(x[0], label=self._le.transform(x[1]), missing=self.missing)
for x in eval_set
)
nevals = len(evals)
eval_names = ["validation_{}".format(i) for i in range(nevals)]
evals = list(zip(evals, eval_names))
@@ -425,9 +431,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
self._features_count = X.shape[1]
self._le = XGBLabelEncoder().fit(y)
training_labels = self._le.transform(y)
if sample_weight is not None:
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
missing=self.missing)