*Fix Sklearn.grid_search error
This commit is contained in:
parent
e626b62daa
commit
b4545df0e3
@ -763,26 +763,30 @@ class XGBModel(BaseEstimator):
|
|||||||
if not SKLEARN_INSTALLED:
|
if not SKLEARN_INSTALLED:
|
||||||
raise Exception('sklearn needs to be installed in order to use this module')
|
raise Exception('sklearn needs to be installed in order to use this module')
|
||||||
self.max_depth = max_depth
|
self.max_depth = max_depth
|
||||||
self.eta = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.silent = 1 if silent else 0
|
self.silent = silent
|
||||||
self.n_rounds = n_estimators
|
self.n_estimators = n_estimators
|
||||||
self.objective = objective
|
self.objective = objective
|
||||||
self._Booster = Booster()
|
self._Booster = Booster()
|
||||||
|
|
||||||
def get_params(self, deep=True):
|
def get_params(self, deep=True):
|
||||||
return {'max_depth': self.max_depth,
|
return {'max_depth': self.max_depth,
|
||||||
'learning_rate': self.eta,
|
'learning_rate': self.learning_rate,
|
||||||
'n_estimators': self.n_rounds,
|
'n_estimators': self.n_estimators,
|
||||||
'silent': True if self.silent == 1 else False,
|
'silent': self.silent,
|
||||||
'objective': self.objective
|
'objective': self.objective
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_xgb_params(self):
|
def get_xgb_params(self):
|
||||||
return {'eta': self.eta, 'max_depth': self.max_depth, 'silent': self.silent, 'objective': self.objective}
|
return {'eta': self.learning_rate,
|
||||||
|
'max_depth': self.max_depth,
|
||||||
|
'silent': 1 if self.silent else 0,
|
||||||
|
'objective': self.objective
|
||||||
|
}
|
||||||
|
|
||||||
def fit(self, X, y):
|
def fit(self, X, y):
|
||||||
trainDmatrix = DMatrix(X, label=y)
|
trainDmatrix = DMatrix(X, label=y)
|
||||||
self._Booster = train(self.get_xgb_params(), trainDmatrix, self.n_rounds)
|
self._Booster = train(self.get_xgb_params(), trainDmatrix, self.n_estimators)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
@ -791,8 +795,8 @@ class XGBModel(BaseEstimator):
|
|||||||
|
|
||||||
|
|
||||||
class XGBClassifier(XGBModel, ClassifierMixin):
|
class XGBClassifier(XGBModel, ClassifierMixin):
|
||||||
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True):
|
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True, objective="binary:logistic"):
|
||||||
super(XGBClassifier, self).__init__(max_depth, learning_rate, n_estimators, silent, objective="binary:logistic")
|
super(XGBClassifier, self).__init__(max_depth, learning_rate, n_estimators, silent, objective)
|
||||||
|
|
||||||
def fit(self, X, y, sample_weight=None):
|
def fit(self, X, y, sample_weight=None):
|
||||||
y_values = list(np.unique(y))
|
y_values = list(np.unique(y))
|
||||||
@ -812,7 +816,7 @@ class XGBClassifier(XGBModel, ClassifierMixin):
|
|||||||
else:
|
else:
|
||||||
trainDmatrix = DMatrix(X, label=training_labels)
|
trainDmatrix = DMatrix(X, label=training_labels)
|
||||||
|
|
||||||
self._Booster = train(xgb_options, trainDmatrix, self.n_rounds)
|
self._Booster = train(xgb_options, trainDmatrix, self.n_estimators)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user