Merge pull request #220 from white1033/master
Fix XGBClassifier super()
commit e626b62daa
@@ -26,7 +26,6 @@ except ImportError:
     SKLEARN_INSTALLED = False
 
 
-
 __all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
 
 if sys.version_info[0] == 3:
@@ -552,20 +551,20 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
     early_stopping_rounds: int
         Activates early stopping. Validation error needs to decrease at least
         every <early_stopping_rounds> round(s) to continue training.
         Requires at least one item in evals.
         If there's more than one, will use the last.
         Returns the model from the last iteration (not the best one).
         If early stopping occurs, the model will have two additional fields:
         bst.best_score and bst.best_iteration.
 
     Returns
     -------
     booster : a trained booster model
     """
 
     evals = list(evals)
     bst = Booster(params, [dtrain] + [d[0] for d in evals])
 
     if not early_stopping_rounds:
         for i in range(num_boost_round):
             bst.update(dtrain, i, obj)
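The early-stopping contract documented above is easiest to see in a call. A minimal usage sketch, assuming a pair of DMatrix objects (the file names and parameter values here are illustrative, not from this commit):

```python
import xgboost as xgb

# Illustrative data files; any DMatrix pair works.
dtrain = xgb.DMatrix('train.svm.txt')
dvalid = xgb.DMatrix('valid.svm.txt')

params = {'objective': 'binary:logistic', 'eval_metric': 'error'}

# The *last* entry in evals drives early stopping: training stops once
# its error has not improved for 10 consecutive rounds.
bst = xgb.train(params, dtrain, num_boost_round=100,
                evals=[(dtrain, 'train'), (dvalid, 'valid')],
                early_stopping_rounds=10)

# Per the docstring, an early stop leaves two extra fields on the model:
print(bst.best_score, bst.best_iteration)
```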
@@ -576,15 +575,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
                 else:
                     sys.stderr.write(bst_eval_set.decode() + '\n')
         return bst
 
     else:
         # early stopping
 
         if len(evals) < 1:
             raise ValueError('For early stopping you need at least one set in evals.')
 
         sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], early_stopping_rounds))
 
         # is params a list of tuples? are we using multiple eval metrics?
         if type(params) == list:
             if len(params) != len(dict(params).items()):
@@ -597,29 +596,29 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
         maximize_metrics = ('auc', 'map', 'ndcg')
         if filter(lambda x: params['eval_metric'].startswith(x), maximize_metrics):
             maximize_score = True
 
         if maximize_score:
             best_score = 0.0
         else:
             best_score = float('inf')
 
         best_msg = ''
         best_score_i = 0
 
         for i in range(num_boost_round):
             bst.update(dtrain, i, obj)
             bst_eval_set = bst.eval_set(evals, i, feval)
 
             if isinstance(bst_eval_set, string_types):
                 msg = bst_eval_set
             else:
                 msg = bst_eval_set.decode()
 
             sys.stderr.write(msg + '\n')
 
             score = float(msg.rsplit(':', 1)[1])
             if (maximize_score and score > best_score) or \
                     (not maximize_score and score < best_score):
                 best_score = score
                 best_score_i = i
                 best_msg = msg
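A side note on the maximize-metric detection in the context lines above: on Python 3, filter() returns a lazy iterator, which is truthy even when nothing matches, so the `if filter(...)` test would set maximize_score for any eval_metric. A sketch of an equivalent check that behaves the same on both interpreters (an illustrative rewrite, not part of this commit):

```python
maximize_metrics = ('auc', 'map', 'ndcg')

# any() consumes the generator and returns an actual boolean,
# so the test is identical on Python 2 and Python 3.
maximize_score = any(params['eval_metric'].startswith(x)
                     for x in maximize_metrics)
```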
@@ -628,10 +627,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
         bst.best_score = best_score
         bst.best_iteration = best_score_i
         return bst
 
     return bst
 
 
-
 class CVPack(object):
     def __init__(self, dtrain, dtest, param):
@@ -770,7 +768,7 @@ class XGBModel(BaseEstimator):
         self.n_rounds = n_estimators
         self.objective = objective
         self._Booster = Booster()
 
     def get_params(self, deep=True):
         return {'max_depth': self.max_depth,
                 'learning_rate': self.eta,
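Note the naming split in the scikit-learn wrapper: get_params() reports the constructor's sklearn-style learning_rate even though the value is stored internally as self.eta, while get_xgb_params() in the next hunk emits the native names that train() expects. A quick sketch of the two views (hypothetical values):

```python
model = XGBClassifier(max_depth=4, learning_rate=0.05)

model.get_params()      # sklearn-style keys: {'max_depth': 4, 'learning_rate': 0.05, ...}
model.get_xgb_params()  # native keys fed to train(): {'eta': 0.05, 'max_depth': 4, ...}
```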
@@ -778,22 +776,24 @@ class XGBModel(BaseEstimator):
                 'silent': True if self.silent == 1 else False,
                 'objective': self.objective
                 }
 
     def get_xgb_params(self):
         return {'eta': self.eta, 'max_depth': self.max_depth, 'silent': self.silent, 'objective': self.objective}
 
     def fit(self, X, y):
         trainDmatrix = DMatrix(X, label=y)
         self._Booster = train(self.get_xgb_params(), trainDmatrix, self.n_rounds)
         return self
 
     def predict(self, X):
         testDmatrix = DMatrix(X)
         return self._Booster.predict(testDmatrix)
 
+
 class XGBClassifier(XGBModel, ClassifierMixin):
     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True):
-        super().__init__(max_depth, learning_rate, n_estimators, silent, objective="binary:logistic")
+        super(XGBClassifier, self).__init__(max_depth, learning_rate, n_estimators, silent, objective="binary:logistic")
 
     def fit(self, X, y, sample_weight=None):
         y_values = list(np.unique(y))
         if len(y_values) > 2:
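This hunk is the fix the PR title refers to: zero-argument super() is Python 3-only syntax, so on Python 2 instantiating XGBClassifier raised a TypeError; the explicit two-argument form works on both interpreters. A standalone sketch of the difference (illustrative class names):

```python
class Base(object):
    def __init__(self, objective):
        self.objective = objective


class Py3Only(Base):
    def __init__(self):
        # Fine on Python 3; on Python 2 this raises
        # "TypeError: super() takes at least 1 argument (0 given)".
        super().__init__(objective='binary:logistic')


class Portable(Base):
    def __init__(self):
        # Explicit (class, instance) arguments work on both 2 and 3.
        super(Portable, self).__init__(objective='binary:logistic')
```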
@@ -803,19 +803,19 @@ class XGBClassifier(XGBModel, ClassifierMixin):
             xgb_options['num_class'] = len(y_values)
         else:
             xgb_options = self.get_xgb_params()
 
         self._le = LabelEncoder().fit(y)
         training_labels = self._le.transform(y)
 
         if sample_weight is not None:
             trainDmatrix = DMatrix(X, label=training_labels, weight=sample_weight)
         else:
             trainDmatrix = DMatrix(X, label=training_labels)
 
         self._Booster = train(xgb_options, trainDmatrix, self.n_rounds)
 
         return self
 
     def predict(self, X):
         testDmatrix = DMatrix(X)
         class_probs = self._Booster.predict(testDmatrix)
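The fit() body above routes labels through scikit-learn's LabelEncoder so the booster always sees integer class indices 0..n-1, and predict() maps them back afterwards. A minimal sketch of that round trip (made-up labels):

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

y = np.array(['spam', 'ham', 'spam', 'eggs'])

le = LabelEncoder().fit(y)               # classes_ sorted: eggs, ham, spam
encoded = le.transform(y)                # array([2, 1, 2, 0])
decoded = le.inverse_transform(encoded)  # back to the original strings
```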
@@ -825,7 +825,7 @@ class XGBClassifier(XGBModel, ClassifierMixin):
             column_indexes = np.repeat(0, X.shape[0])
             column_indexes[class_probs > 0.5] = 1
         return self._le.inverse_transform(column_indexes)
 
     def predict_proba(self, X):
         testDmatrix = DMatrix(X)
         class_probs = self._Booster.predict(testDmatrix)
@@ -834,9 +834,8 @@ class XGBClassifier(XGBModel, ClassifierMixin):
         else:
             classone_probs = class_probs
         classzero_probs = 1.0 - classone_probs
-        return np.vstack((classzero_probs,classone_probs)).transpose()
+        return np.vstack((classzero_probs, classone_probs)).transpose()
 
+
 class XGBRegressor(XGBModel, RegressorMixin):
     pass
-
-
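Besides the whitespace touch-up on the vstack line, the surrounding code shows how predict_proba rebuilds the two-column layout scikit-learn expects: for binary problems the booster returns only the positive-class probability, so the zero-class column is derived as 1 - p. A sketch with made-up probabilities:

```python
import numpy as np

classone_probs = np.array([0.9, 0.2, 0.6])   # P(class 1) from the booster
classzero_probs = 1.0 - classone_probs

# vstack gives shape (2, n); transpose yields the (n, 2) array
# that scikit-learn's predict_proba contract expects.
probs = np.vstack((classzero_probs, classone_probs)).transpose()
# array([[0.1, 0.9],
#        [0.8, 0.2],
#        [0.4, 0.6]])
```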
|
|||||||
Loading…
x
Reference in New Issue
Block a user