ENH: Add n_classes_ to fitted classifier.
This commit is contained in:
parent
8ac89b290e
commit
c1a24c0fb1
@ -872,11 +872,12 @@ class XGBClassifier(XGBModel, XGBClassifier):
|
|||||||
|
|
||||||
def fit(self, X, y, sample_weight=None):
|
def fit(self, X, y, sample_weight=None):
|
||||||
y_values = list(np.unique(y))
|
y_values = list(np.unique(y))
|
||||||
if len(y_values) > 2:
|
self.n_classes_ = len(y_values)
|
||||||
|
if self.n_classes_ > 2:
|
||||||
# Switch to using a multiclass objective in the underlying XGB instance
|
# Switch to using a multiclass objective in the underlying XGB instance
|
||||||
self.objective = "multi:softprob"
|
self.objective = "multi:softprob"
|
||||||
xgb_options = self.get_xgb_params()
|
xgb_options = self.get_xgb_params()
|
||||||
xgb_options['num_class'] = len(y_values)
|
xgb_options['num_class'] = self.n_classes_
|
||||||
else:
|
else:
|
||||||
xgb_options = self.get_xgb_params()
|
xgb_options = self.get_xgb_params()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user