Merge pull request #220 from white1033/master
Fix XGBClassifier super()
Commit: e626b62daa
@@ -26,7 +26,6 @@ except ImportError:
     SKLEARN_INSTALLED = False
 
 
-
 __all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
 
 if sys.version_info[0] == 3:
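For context, the SKLEARN_INSTALLED flag in the hunk above is set by an optional-import guard: the scikit-learn base classes used later (BaseEstimator, ClassifierMixin, RegressorMixin) are imported inside a try/except ImportError so the wrapper still loads when scikit-learn is absent. A minimal sketch of that pattern, not the wrapper's exact import list:

# Optional-dependency guard (illustrative sketch, assuming only the sklearn
# names that appear in the class definitions below are needed).
try:
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
    SKLEARN_INSTALLED = True
except ImportError:
    SKLEARN_INSTALLED = False

# Code that defines sklearn-style estimators can then check the flag first
# and fail with a clear message instead of a NameError.
if not SKLEARN_INSTALLED:
    pass  # e.g. skip or guard the XGBModel / XGBClassifier definitions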
@@ -632,7 +631,6 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
     return bst
 
 
-
 class CVPack(object):
     def __init__(self, dtrain, dtest, param):
         self.dtrain = dtrain
@@ -778,6 +776,7 @@ class XGBModel(BaseEstimator):
                 'silent': True if self.silent == 1 else False,
                 'objective': self.objective
                 }
+
     def get_xgb_params(self):
         return {'eta': self.eta, 'max_depth': self.max_depth, 'silent': self.silent, 'objective': self.objective}
 
@@ -790,9 +789,10 @@ class XGBModel(BaseEstimator):
         testDmatrix = DMatrix(X)
         return self._Booster.predict(testDmatrix)
 
+
 class XGBClassifier(XGBModel, ClassifierMixin):
     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True):
-        super().__init__(max_depth, learning_rate, n_estimators, silent, objective="binary:logistic")
+        super(XGBClassifier, self).__init__(max_depth, learning_rate, n_estimators, silent, objective="binary:logistic")
 
     def fit(self, X, y, sample_weight=None):
         y_values = list(np.unique(y))
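The substantive change is in this hunk: zero-argument super() is Python 3 syntax only, so XGBClassifier.__init__ raised a TypeError under Python 2, which the wrapper still supports (note the sys.version_info check in the first hunk). The explicit super(XGBClassifier, self) form works on both interpreters. A minimal standalone sketch of the difference, with illustrative class names:

class Model(object):
    def __init__(self, max_depth=3):
        self.max_depth = max_depth

class Classifier(Model):
    def __init__(self, max_depth=3):
        # super().__init__(max_depth) would only work on Python 3;
        # the two-argument form is valid on Python 2 and 3 alike.
        super(Classifier, self).__init__(max_depth)

clf = Classifier(max_depth=5)
assert clf.max_depth == 5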
@@ -836,7 +836,6 @@ class XGBClassifier(XGBModel, ClassifierMixin):
         classzero_probs = 1.0 - classone_probs
         return np.vstack((classzero_probs, classone_probs)).transpose()
 
 
-
 class XGBRegressor(XGBModel, RegressorMixin):
     pass
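With the fix in place, the wrapper classes construct correctly on either Python version and can be used through the scikit-learn style API visible in the hunks above (the constructor signature, fit, and predict). A hedged usage sketch, assuming the wrapper module is importable as xgboost and using random toy data purely for illustration:

import numpy as np

# Assumption: the Python wrapper that defines XGBClassifier is importable
# under this name; constructor arguments mirror the signature in the diff.
from xgboost import XGBClassifier

X = np.random.rand(100, 4)
y = (X[:, 0] > 0.5).astype(int)  # toy binary labels

clf = XGBClassifier(max_depth=3, learning_rate=0.1, n_estimators=100, silent=True)
clf.fit(X, y)

preds = clf.predict(X)  # model output for the training rows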