diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 7e3fa2dc4..46a229cd6 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -935,8 +935,8 @@ class XGBClassifier(XGBModel, XGBClassifier): base_score, seed) def fit(self, X, y, sample_weight=None): - y_values = list(np.unique(y)) - self.n_classes_ = len(y_values) + self.classes_ = list(np.unique(y)) + self.n_classes_ = len(self.classes_) if self.n_classes_ > 2: # Switch to using a multiclass objective in the underlying XGB instance self.objective = "multi:softprob"