Added classes_ attribute to scikit-learn wrapper.
This commit is contained in:
parent
9c52fc8e22
commit
4e080928a8
@ -935,8 +935,8 @@ class XGBClassifier(XGBModel, XGBClassifier):
|
||||
base_score, seed)
|
||||
|
||||
def fit(self, X, y, sample_weight=None):
|
||||
y_values = list(np.unique(y))
|
||||
self.n_classes_ = len(y_values)
|
||||
self.classes_ = list(np.unique(y))
|
||||
self.n_classes_ = len(self.classes_)
|
||||
if self.n_classes_ > 2:
|
||||
# Switch to using a multiclass objective in the underlying XGB instance
|
||||
self.objective = "multi:softprob"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user