checkin all python

parent c2fec29bfa
commit 60217a2c02

.gitignore (vendored) | 7 +++----
@@ -48,10 +48,9 @@ Debug
 *.cpage.col
 *.cpage
 *.Rproj
-xgboost
-xgboost.mpi
-xgboost.mock
-train*
+./xgboost
+./xgboost.mpi
+./xgboost.mock
 rabit
 #.Rbuildignore
 R-package.Rproj

python-package/xgboost/__init__.py | 12 (new file)
@@ -0,0 +1,12 @@
# coding: utf-8
"""XGBoost: eXtreme Gradient Boosting library.

Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
"""

from __future__ import absolute_import
from .core import DMatrix, Booster
from .training import train, cv
from .sklearn import XGBModel, XGBClassifier, XGBRegressor

__version__ = '0.4'
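
For orientation, a minimal sketch of the package-level API these exports expose. The toy data here is hypothetical, and it assumes the python-package directory is installed (or on sys.path) and the native libxgboost has been built:

    import numpy as np
    import xgboost as xgb

    data = np.random.rand(100, 5)             # hypothetical toy features
    label = np.random.randint(2, size=100)    # hypothetical binary labels
    dtrain = xgb.DMatrix(data, label=label)   # DMatrix re-exported from .core
    bst = xgb.train({'objective': 'binary:logistic'}, dtrain, num_boost_round=5)
    preds = bst.predict(xgb.DMatrix(data))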

python-package/xgboost/sklearn.py | 341 (new file)
@@ -0,0 +1,341 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme
"""Scikit-Learn Wrapper interface for XGBoost."""
from __future__ import absolute_import

import numpy as np
from .core import Booster, DMatrix, XGBoostError
from .training import train

try:
    from sklearn.base import BaseEstimator
    from sklearn.base import RegressorMixin, ClassifierMixin
    from sklearn.preprocessing import LabelEncoder
    SKLEARN_INSTALLED = True
except ImportError:
    SKLEARN_INSTALLED = False

# used for compatibility without sklearn
XGBModelBase = object
XGBClassifierBase = object
XGBRegressorBase = object

if SKLEARN_INSTALLED:
    XGBModelBase = BaseEstimator
    XGBRegressorBase = RegressorMixin
    XGBClassifierBase = ClassifierMixin


class XGBModel(XGBModelBase):
    # pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
    """Implementation of the Scikit-Learn API for XGBoost.

    Parameters
    ----------
    max_depth : int
        Maximum tree depth for base learners.
    learning_rate : float
        Boosting learning rate (xgb's "eta").
    n_estimators : int
        Number of boosted trees to fit.
    silent : boolean
        Whether to print messages while running boosting.
    objective : string
        Specify the learning task and the corresponding learning objective.

    nthread : int
        Number of parallel threads used to run xgboost.
    gamma : float
        Minimum loss reduction required to make a further partition on a leaf node of the tree.
    min_child_weight : int
        Minimum sum of instance weight (hessian) needed in a child.
    max_delta_step : int
        Maximum delta step we allow each tree's weight estimation to be.
    subsample : float
        Subsample ratio of the training instances.
    colsample_bytree : float
        Subsample ratio of columns when constructing each tree.

    base_score : float
        The initial prediction score of all instances, global bias.
    seed : int
        Random number seed.
    missing : float, optional
        Value in the data to be treated as missing. If None, defaults to np.nan.
    """
    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
                 silent=True, objective="reg:linear",
                 nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
                 subsample=1, colsample_bytree=1,
                 base_score=0.5, seed=0, missing=None):
        if not SKLEARN_INSTALLED:
            raise XGBoostError('sklearn needs to be installed in order to use this module')
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.n_estimators = n_estimators
        self.silent = silent
        self.objective = objective

        self.nthread = nthread
        self.gamma = gamma
        self.min_child_weight = min_child_weight
        self.max_delta_step = max_delta_step
        self.subsample = subsample
        self.colsample_bytree = colsample_bytree

        self.base_score = base_score
        self.seed = seed
        self.missing = missing if missing is not None else np.nan
        self._Booster = None

    def __setstate__(self, state):
        # Backward compatibility: older pickles stored the booster as raw
        # bytes; the Booster object itself now supports pickling.
        bst = state["_Booster"]
        if bst is not None and not isinstance(bst, Booster):
            state["_Booster"] = Booster(model_file=bst)
        self.__dict__.update(state)

    def booster(self):
        """Get the underlying xgboost Booster of this model.

        Raises an exception if fit has not been called.

        Returns
        -------
        booster : an xgboost Booster of the underlying model
        """
        if self._Booster is None:
            raise XGBoostError('need to call fit beforehand')
        return self._Booster

    def get_params(self, deep=False):
        """Get parameters."""
        params = super(XGBModel, self).get_params(deep=deep)
        if params['missing'] is np.nan:
            params['missing'] = None  # sklearn doesn't handle nan. see #4725
        if not params.get('eval_metric', True):
            del params['eval_metric']  # don't give as None param to Booster
        return params

    def get_xgb_params(self):
        """Get xgboost type parameters."""
        xgb_params = self.get_params()

        xgb_params['silent'] = 1 if self.silent else 0

        if self.nthread <= 0:
            xgb_params.pop('nthread', None)
        return xgb_params

    def fit(self, X, y, eval_set=None, eval_metric=None,
            early_stopping_rounds=None, verbose=True):
        # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
        """
        Fit the gradient boosting model.

        Parameters
        ----------
        X : array_like
            Feature matrix.
        y : array_like
            Labels.
        eval_set : list, optional
            A list of (X, y) tuple pairs to use as a validation set for
            early stopping.
        eval_metric : str, callable, optional
            If a str, should be a built-in evaluation metric to use; see
            doc/parameter.md. If callable, a custom evaluation metric with
            signature func(y_predicted, y_true), where y_true is a DMatrix
            object, so you may need to call its get_label method. It must
            return a (str, value) pair where the str names the evaluation
            and value is its result. The metric is always minimized.
        early_stopping_rounds : int
            Activates early stopping. Validation error needs to decrease at
            least every <early_stopping_rounds> round(s) to continue training.
            Requires at least one item in eval_set. If there's more than one,
            the last is used. Returns the model from the last iteration
            (not the best one). If early stopping occurs, the model will
            have two additional fields: bst.best_score and bst.best_iteration.
        verbose : bool
            If `verbose` and an evaluation set is used, writes the evaluation
            metric measured on the validation set to stderr.
        """
        train_dmatrix = DMatrix(X, label=y, missing=self.missing)

        eval_results = {}
        if eval_set is not None:
            evals = list(DMatrix(x[0], label=x[1]) for x in eval_set)
            evals = list(zip(evals, ["validation_{}".format(i) for i in
                                     range(len(evals))]))
        else:
            evals = ()

        params = self.get_xgb_params()

        feval = eval_metric if callable(eval_metric) else None
        if eval_metric is not None:
            if callable(eval_metric):
                eval_metric = None
            else:
                params.update({'eval_metric': eval_metric})

        self._Booster = train(params, train_dmatrix,
                              self.n_estimators, evals=evals,
                              early_stopping_rounds=early_stopping_rounds,
                              evals_result=eval_results, feval=feval,
                              verbose_eval=verbose)
        if eval_results:
            # convert the per-round metric lists to float arrays
            eval_results = {k: np.array(v, dtype=float)
                            for k, v in eval_results.items()}
            self.eval_results = eval_results

        if early_stopping_rounds is not None:
            self.best_score = self._Booster.best_score
            self.best_iteration = self._Booster.best_iteration
        return self

    def predict(self, data):
        # pylint: disable=missing-docstring,invalid-name
        test_dmatrix = DMatrix(data, missing=self.missing)
        return self.booster().predict(test_dmatrix)


class XGBClassifier(XGBModel, XGBClassifierBase):
    # pylint: disable=missing-docstring,too-many-arguments,invalid-name
    __doc__ = """
    Implementation of the scikit-learn API for XGBoost classification.
    """ + "\n".join(XGBModel.__doc__.split('\n')[2:])

    def __init__(self, max_depth=3, learning_rate=0.1,
                 n_estimators=100, silent=True,
                 objective="binary:logistic",
                 nthread=-1, gamma=0, min_child_weight=1,
                 max_delta_step=0, subsample=1, colsample_bytree=1,
                 base_score=0.5, seed=0, missing=None):
        super(XGBClassifier, self).__init__(max_depth, learning_rate,
                                            n_estimators, silent, objective,
                                            nthread, gamma, min_child_weight,
                                            max_delta_step, subsample,
                                            colsample_bytree,
                                            base_score, seed, missing)

    def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
            early_stopping_rounds=None, verbose=True):
        # pylint: disable = attribute-defined-outside-init,arguments-differ
        """
        Fit gradient boosting classifier.

        Parameters
        ----------
        X : array_like
            Feature matrix.
        y : array_like
            Labels.
        sample_weight : array_like
            Weight for each instance.
        eval_set : list, optional
            A list of (X, y) pairs to use as a validation set for
            early stopping.
        eval_metric : str, callable, optional
            If a str, should be a built-in evaluation metric to use; see
            doc/parameter.md. If callable, a custom evaluation metric with
            signature func(y_predicted, y_true), where y_true is a DMatrix
            object, so you may need to call its get_label method. It must
            return a (str, value) pair where the str names the evaluation
            and value is its result. The metric is always minimized.
        early_stopping_rounds : int, optional
            Activates early stopping. Validation error needs to decrease at
            least every <early_stopping_rounds> round(s) to continue training.
            Requires at least one item in eval_set. If there's more than one,
            the last is used. Returns the model from the last iteration
            (not the best one). If early stopping occurs, the model will
            have two additional fields: bst.best_score and bst.best_iteration.
        verbose : bool
            If `verbose` and an evaluation set is used, writes the evaluation
            metric measured on the validation set to stderr.
        """
        eval_results = {}
        self.classes_ = list(np.unique(y))
        self.n_classes_ = len(self.classes_)
        if self.n_classes_ > 2:
            # Switch to using a multiclass objective in the underlying XGB instance
            self.objective = "multi:softprob"
            xgb_options = self.get_xgb_params()
            xgb_options['num_class'] = self.n_classes_
        else:
            xgb_options = self.get_xgb_params()

        feval = eval_metric if callable(eval_metric) else None
        if eval_metric is not None:
            if callable(eval_metric):
                eval_metric = None
            else:
                xgb_options.update({"eval_metric": eval_metric})

        if eval_set is not None:
            # TODO: use sample_weight if given?
            evals = list(DMatrix(x[0], label=x[1]) for x in eval_set)
            nevals = len(evals)
            eval_names = ["validation_{}".format(i) for i in range(nevals)]
            evals = list(zip(evals, eval_names))
        else:
            evals = ()

        self._le = LabelEncoder().fit(y)
        training_labels = self._le.transform(y)

        if sample_weight is not None:
            train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
                                    missing=self.missing)
        else:
            train_dmatrix = DMatrix(X, label=training_labels,
                                    missing=self.missing)

        self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
                              evals=evals,
                              early_stopping_rounds=early_stopping_rounds,
                              evals_result=eval_results, feval=feval,
                              verbose_eval=verbose)

        if eval_results:
            eval_results = {k: np.array(v, dtype=float)
                            for k, v in eval_results.items()}
            self.eval_results = eval_results

        if early_stopping_rounds is not None:
            self.best_score = self._Booster.best_score
            self.best_iteration = self._Booster.best_iteration

        return self

    def predict(self, data):
        test_dmatrix = DMatrix(data, missing=self.missing)
        class_probs = self.booster().predict(test_dmatrix)
        if len(class_probs.shape) > 1:
            # multiclass: pick the most probable class per row
            column_indexes = np.argmax(class_probs, axis=1)
        else:
            # binary: threshold the positive-class probability at 0.5
            column_indexes = np.repeat(0, data.shape[0])
            column_indexes[class_probs > 0.5] = 1
        return self._le.inverse_transform(column_indexes)

    def predict_proba(self, data):
        test_dmatrix = DMatrix(data, missing=self.missing)
        class_probs = self.booster().predict(test_dmatrix)
        if self.objective == "multi:softprob":
            return class_probs
        else:
            classone_probs = class_probs
            classzero_probs = 1.0 - classone_probs
            return np.vstack((classzero_probs, classone_probs)).transpose()


class XGBRegressor(XGBModel, XGBRegressorBase):
    # pylint: disable=missing-docstring
    __doc__ = """
    Implementation of the scikit-learn API for XGBoost regression.
    """ + "\n".join(XGBModel.__doc__.split('\n')[2:])
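
A hedged sketch of the new scikit-learn wrapper in use. The arrays are hypothetical toy data, and sklearn must be importable (see the SKLEARN_INSTALLED guard above):

    import numpy as np
    from xgboost.sklearn import XGBClassifier

    X = np.random.rand(200, 4)                # hypothetical features
    y = np.random.randint(2, size=200)        # hypothetical binary labels

    clf = XGBClassifier(max_depth=3, learning_rate=0.1, n_estimators=50)
    clf.fit(X, y,
            eval_set=[(X, y)],                # watched as "validation_0"
            eval_metric='logloss',            # built-in metric, passed to the Booster
            early_stopping_rounds=10)         # sets clf.best_score / clf.best_iteration
    probs = clf.predict_proba(X[:5])          # (n, 2) array in the binary case
    labels = clf.predict(X[:5])               # decoded via the fitted LabelEncoder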

python-package/xgboost/training.py | 252 (new file)
@@ -0,0 +1,252 @@
# coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
"""Training Library containing training routines."""
from __future__ import absolute_import

import sys
import re
import numpy as np
from .core import Booster, STRING_TYPES


def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
          early_stopping_rounds=None, evals_result=None, verbose_eval=True):
    # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
    """Train a booster with given parameters.

    Parameters
    ----------
    params : dict
        Booster params.
    dtrain : DMatrix
        Data to be trained.
    num_boost_round : int
        Number of boosting iterations.
    evals : list of pairs (DMatrix, string)
        Watchlist of items to be evaluated during training; this allows the
        user to watch performance on the validation set.
    obj : function
        Customized objective function.
    feval : function
        Customized evaluation function.
    early_stopping_rounds : int
        Activates early stopping. Validation error needs to decrease at least
        every <early_stopping_rounds> round(s) to continue training.
        Requires at least one item in evals.
        If there's more than one, will use the last.
        Returns the model from the last iteration (not the best one).
        If early stopping occurs, the model will have two additional fields:
        bst.best_score and bst.best_iteration.
    evals_result : dict
        This dictionary stores the evaluation results of all the items in evals.
    verbose_eval : bool
        If `verbose_eval` is True, the evaluation metric on the validation
        set, if given, is printed at each boosting stage.

    Returns
    -------
    booster : a trained booster model
    """
    evals = list(evals)
    bst = Booster(params, [dtrain] + [d[0] for d in evals])

    if evals_result is not None:
        if not isinstance(evals_result, dict):
            raise TypeError('evals_result has to be a dictionary')
        else:
            evals_name = [d[1] for d in evals]
            evals_result.clear()
            evals_result.update({key: [] for key in evals_name})

    if not early_stopping_rounds:
        for i in range(num_boost_round):
            bst.update(dtrain, i, obj)
            if len(evals) != 0:
                bst_eval_set = bst.eval_set(evals, i, feval)
                if isinstance(bst_eval_set, STRING_TYPES):
                    msg = bst_eval_set
                else:
                    msg = bst_eval_set.decode()

                if verbose_eval:
                    sys.stderr.write(msg + '\n')
                if evals_result is not None:
                    res = re.findall(":-?([0-9.]+).", msg)
                    for key, val in zip(evals_name, res):
                        evals_result[key].append(val)
        return bst

    else:
        # early stopping
        if len(evals) < 1:
            raise ValueError('For early stopping you need at least one set in evals.')

        sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(
            evals[-1][1], early_stopping_rounds))

        # is params a list of tuples? are we using multiple eval metrics?
        if isinstance(params, list):
            if len(params) != len(dict(params).items()):
                raise ValueError('Check your params. '
                                 'Early stopping works with single eval metric only.')
            params = dict(params)

        # either minimize loss or maximize AUC/MAP/NDCG
        maximize_score = False
        if 'eval_metric' in params:
            maximize_metrics = ('auc', 'map', 'ndcg')
            if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
                maximize_score = True

        if maximize_score:
            best_score = 0.0
        else:
            best_score = float('inf')

        best_msg = ''
        best_score_i = 0

        for i in range(num_boost_round):
            bst.update(dtrain, i, obj)
            bst_eval_set = bst.eval_set(evals, i, feval)

            if isinstance(bst_eval_set, STRING_TYPES):
                msg = bst_eval_set
            else:
                msg = bst_eval_set.decode()

            if verbose_eval:
                sys.stderr.write(msg + '\n')

            if evals_result is not None:
                res = re.findall(":-?([0-9.]+).", msg)
                for key, val in zip(evals_name, res):
                    evals_result[key].append(val)

            # the score of the last item in evals is the one watched
            score = float(msg.rsplit(':', 1)[1])
            if (maximize_score and score > best_score) or \
                    (not maximize_score and score < best_score):
                best_score = score
                best_score_i = i
                best_msg = msg
            elif i - best_score_i >= early_stopping_rounds:
                sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
                break
        bst.best_score = best_score
        bst.best_iteration = best_score_i
        return bst


class CVPack(object):
    """Auxiliary data structure to hold one fold of CV."""
    def __init__(self, dtrain, dtest, param):
        """Initialize the CVPack."""
        self.dtrain = dtrain
        self.dtest = dtest
        self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
        self.bst = Booster(param, [dtrain, dtest])

    def update(self, iteration, fobj):
        """Update the boosters for one iteration."""
        self.bst.update(self.dtrain, iteration, fobj)

    def eval(self, iteration, feval):
        """Evaluate the CVPack for one iteration."""
        return self.bst.eval_set(self.watchlist, iteration, feval)


def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None):
    """
    Make an n-fold list of CVPack from random indices.
    """
    evals = list(evals)
    np.random.seed(seed)
    randidx = np.random.permutation(dall.num_row())
    kstep = len(randidx) // nfold  # integer division so slice indices stay ints
    idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
    ret = []
    for k in range(nfold):
        dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
        dtest = dall.slice(idset[k])
        # run preprocessing on the data set if needed
        if fpreproc is not None:
            dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
        else:
            tparam = param
        plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals]
        ret.append(CVPack(dtrain, dtest, plst))
    return ret


def aggcv(rlist, show_stdv=True):
    # pylint: disable=invalid-name
    """
    Aggregate cross-validation results.
    """
    cvmap = {}
    ret = rlist[0].split()[0]
    for line in rlist:
        arr = line.split()
        assert ret == arr[0]
        for it in arr[1:]:
            if not isinstance(it, STRING_TYPES):
                it = it.decode()
            k, v = it.split(':')
            if k not in cvmap:
                cvmap[k] = []
            cvmap[k].append(float(v))
    for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
        v = np.array(v)
        if not isinstance(ret, STRING_TYPES):
            ret = ret.decode()
        if show_stdv:
            ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v))
        else:
            ret += '\tcv-%s:%f' % (k, np.mean(v))
    return ret


def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
       obj=None, feval=None, fpreproc=None, show_stdv=True, seed=0):
    # pylint: disable = invalid-name
    """Cross-validation with given parameters.

    Parameters
    ----------
    params : dict
        Booster params.
    dtrain : DMatrix
        Data to be trained.
    num_boost_round : int
        Number of boosting iterations.
    nfold : int
        Number of folds in CV.
    metrics : list of strings
        Evaluation metrics to be watched in CV.
    obj : function
        Custom objective function.
    feval : function
        Custom evaluation function.
    fpreproc : function
        Preprocessing function that takes (dtrain, dtest, param) and returns
        transformed versions of those.
    show_stdv : bool
        Whether to display the standard deviation.
    seed : int
        Seed used to generate the folds (passed to numpy.random.seed).

    Returns
    -------
    evaluation history : list(string)
    """
    results = []
    cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc)
    for i in range(num_boost_round):
        for fold in cvfolds:
            fold.update(i, obj)
        res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv)
        sys.stderr.write(res + '\n')
        results.append(res)
    return results
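
A sketch of the lower-level training routines defined here, under the same hypothetical-data assumption; every name used is defined in this module or in .core:

    import numpy as np
    from xgboost.core import DMatrix
    from xgboost.training import train, cv

    X = np.random.rand(300, 4)                # hypothetical features
    y = np.random.randint(2, size=300)        # hypothetical binary labels
    dtrain = DMatrix(X[:200], label=y[:200])
    dvalid = DMatrix(X[200:], label=y[200:])

    params = {'objective': 'binary:logistic', 'eval_metric': 'error', 'silent': 1}
    evals_result = {}
    bst = train(params, dtrain, num_boost_round=20,
                evals=[(dvalid, 'eval')],
                early_stopping_rounds=5,      # needs at least one item in evals
                evals_result=evals_result)    # per-round 'eval' scores land here
    print(bst.best_iteration, bst.best_score)

    # 3-fold CV; returns one aggregated "cv-..." result string per round.
    history = cv(params, dtrain, num_boost_round=20, nfold=3, metrics=('error',))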