From 4e080928a8b6d62e57f5a08e7637790f26b76ace Mon Sep 17 00:00:00 2001 From: John Wittenauer Date: Fri, 15 May 2015 21:19:39 -0400 Subject: [PATCH] Added classes_ attribute to scikit-learn wrapper. --- wrapper/xgboost.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 7e3fa2dc4..46a229cd6 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -935,8 +935,8 @@ class XGBClassifier(XGBModel, XGBClassifier): base_score, seed) def fit(self, X, y, sample_weight=None): - y_values = list(np.unique(y)) - self.n_classes_ = len(y_values) + self.classes_ = list(np.unique(y)) + self.n_classes_ = len(self.classes_) if self.n_classes_ > 2: # Switch to using a multiclass objective in the underlying XGB instance self.objective = "multi:softprob"