Added classes_ attribute to scikit-learn wrapper.

This commit is contained in:
John Wittenauer 2015-05-15 21:19:39 -04:00
parent 9c52fc8e22
commit 4e080928a8

View File

@ -935,8 +935,8 @@ class XGBClassifier(XGBModel, XGBClassifier):
base_score, seed) base_score, seed)
def fit(self, X, y, sample_weight=None): def fit(self, X, y, sample_weight=None):
y_values = list(np.unique(y)) self.classes_ = list(np.unique(y))
self.n_classes_ = len(y_values) self.n_classes_ = len(self.classes_)
if self.n_classes_ > 2: if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying XGB instance # Switch to using a multiclass objective in the underlying XGB instance
self.objective = "multi:softprob" self.objective = "multi:softprob"