Merge pull request #267 from jseabold/add-n-classes

Add n_classes_ to fitted XGBClassifier
This commit is contained in:
Tianqi Chen 2015-04-27 09:10:17 -07:00
commit f271af488b

View File

@ -872,11 +872,12 @@ class XGBClassifier(XGBModel, XGBClassifier):
def fit(self, X, y, sample_weight=None):
y_values = list(np.unique(y))
if len(y_values) > 2:
self.n_classes_ = len(y_values)
if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying XGB instance
self.objective = "multi:softprob"
xgb_options = self.get_xgb_params()
xgb_options['num_class'] = len(y_values)
xgb_options['num_class'] = self.n_classes_
else:
xgb_options = self.get_xgb_params()