From c1a24c0fb173cc6a6009a8c9933566eb981efcc5 Mon Sep 17 00:00:00 2001 From: Skipper Seabold Date: Mon, 27 Apr 2015 10:46:30 -0500 Subject: [PATCH] ENH: Add n_classes_ to fitted classifier. --- wrapper/xgboost.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index d6651c26c..65117c36c 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -872,11 +872,12 @@ class XGBClassifier(XGBModel, XGBClassifier): def fit(self, X, y, sample_weight=None): y_values = list(np.unique(y)) - if len(y_values) > 2: + self.n_classes_ = len(y_values) + if self.n_classes_ > 2: # Switch to using a multiclass objective in the underlying XGB instance self.objective = "multi:softprob" xgb_options = self.get_xgb_params() - xgb_options['num_class'] = len(y_values) + xgb_options['num_class'] = self.n_classes_ else: xgb_options = self.get_xgb_params()