Merge pull request #3 from dmlc/master
Getting latest version from dmlc
This commit is contained in:
commit
f116722e68
@ -50,3 +50,4 @@ List of Contributors
|
|||||||
* [Hongliang Liu](https://github.com/phunterlau)
|
* [Hongliang Liu](https://github.com/phunterlau)
|
||||||
- Hongliang is the maintainer of xgboost python PyPI package for pip installation.
|
- Hongliang is the maintainer of xgboost python PyPI package for pip installation.
|
||||||
* [Huayi Zhang](https://github.com/irachex)
|
* [Huayi Zhang](https://github.com/irachex)
|
||||||
|
* [Johan Manders](https://github.com/johanmanders)
|
||||||
|
|||||||
@ -9,4 +9,6 @@ XGBoost Python Feature Walkthrough
|
|||||||
* [Predicting leaf indices](predict_leaf_indices.py)
|
* [Predicting leaf indices](predict_leaf_indices.py)
|
||||||
* [Sklearn Wrapper](sklearn_examples.py)
|
* [Sklearn Wrapper](sklearn_examples.py)
|
||||||
* [Sklearn Parallel](sklearn_parallel.py)
|
* [Sklearn Parallel](sklearn_parallel.py)
|
||||||
|
* [Sklearn access evals result](sklearn_evals_result.py)
|
||||||
|
* [Access evals result](evals_result.py)
|
||||||
* [External Memory](external_memory.py)
|
* [External Memory](external_memory.py)
|
||||||
|
|||||||
30
demo/guide-python/evals_result.py
Normal file
30
demo/guide-python/evals_result.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
##
|
||||||
|
# This script demonstrates how to access the eval metrics in xgboost
|
||||||
|
##
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
dtrain = xgb.DMatrix('../data/agaricus.txt.train', silent=True)
|
||||||
|
dtest = xgb.DMatrix('../data/agaricus.txt.test', silent=True)
|
||||||
|
|
||||||
|
param = [('max_depth', 2), ('objective', 'binary:logistic'), ('eval_metric', 'logloss'), ('eval_metric', 'error')]
|
||||||
|
|
||||||
|
num_round = 2
|
||||||
|
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||||
|
|
||||||
|
evals_result = {}
|
||||||
|
bst = xgb.train(param, dtrain, num_round, watchlist, evals_result=evals_result)
|
||||||
|
|
||||||
|
print('Access logloss metric directly from evals_result:')
|
||||||
|
print(evals_result['eval']['logloss'])
|
||||||
|
|
||||||
|
print('')
|
||||||
|
print('Access metrics through a loop:')
|
||||||
|
for e_name, e_mtrs in evals_result.items():
|
||||||
|
print('- {}'.format(e_name))
|
||||||
|
for e_mtr_name, e_mtr_vals in e_mtrs.items():
|
||||||
|
print(' - {}'.format(e_mtr_name))
|
||||||
|
print(' - {}'.format(e_mtr_vals))
|
||||||
|
|
||||||
|
print('')
|
||||||
|
print('Access complete dictionary:')
|
||||||
|
print(evals_result)
|
||||||
43
demo/guide-python/sklearn_evals_result.py
Normal file
43
demo/guide-python/sklearn_evals_result.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
##
|
||||||
|
# This script demonstrates how to access the xgboost eval metrics by using sklearn
|
||||||
|
##
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.datasets import make_hastie_10_2
|
||||||
|
|
||||||
|
X, y = make_hastie_10_2(n_samples=2000, random_state=42)
|
||||||
|
|
||||||
|
# Map labels from {-1, 1} to {0, 1}
|
||||||
|
labels, y = np.unique(y, return_inverse=True)
|
||||||
|
|
||||||
|
X_train, X_test = X[:1600], X[1600:]
|
||||||
|
y_train, y_test = y[:1600], y[1600:]
|
||||||
|
|
||||||
|
param_dist = {'objective':'binary:logistic', 'n_estimators':2}
|
||||||
|
|
||||||
|
clf = xgb.XGBModel(**param_dist)
|
||||||
|
# Or you can use: clf = xgb.XGBClassifier(**param_dist)
|
||||||
|
|
||||||
|
clf.fit(X_train, y_train,
|
||||||
|
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||||
|
eval_metric='logloss',
|
||||||
|
verbose=True)
|
||||||
|
|
||||||
|
# Load evals result by calling the evals_result() function
|
||||||
|
evals_result = clf.evals_result()
|
||||||
|
|
||||||
|
print('Access logloss metric directly from validation_0:')
|
||||||
|
print(evals_result['validation_0']['logloss'])
|
||||||
|
|
||||||
|
print('')
|
||||||
|
print('Access metrics through a loop:')
|
||||||
|
for e_name, e_mtrs in evals_result.items():
|
||||||
|
print('- {}'.format(e_name))
|
||||||
|
for e_mtr_name, e_mtr_vals in e_mtrs.items():
|
||||||
|
print(' - {}'.format(e_mtr_name))
|
||||||
|
print(' - {}'.format(e_mtr_vals))
|
||||||
|
|
||||||
|
print('')
|
||||||
|
print('Access complete dict:')
|
||||||
|
print(evals_result)
|
||||||
@ -165,7 +165,7 @@ class XGBModel(XGBModelBase):
|
|||||||
"""
|
"""
|
||||||
trainDmatrix = DMatrix(X, label=y, missing=self.missing)
|
trainDmatrix = DMatrix(X, label=y, missing=self.missing)
|
||||||
|
|
||||||
eval_results = {}
|
evals_result = {}
|
||||||
if eval_set is not None:
|
if eval_set is not None:
|
||||||
evals = list(DMatrix(x[0], label=x[1]) for x in eval_set)
|
evals = list(DMatrix(x[0], label=x[1]) for x in eval_set)
|
||||||
evals = list(zip(evals, ["validation_{}".format(i) for i in
|
evals = list(zip(evals, ["validation_{}".format(i) for i in
|
||||||
@ -185,13 +185,14 @@ class XGBModel(XGBModelBase):
|
|||||||
self._Booster = train(params, trainDmatrix,
|
self._Booster = train(params, trainDmatrix,
|
||||||
self.n_estimators, evals=evals,
|
self.n_estimators, evals=evals,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=eval_results, feval=feval,
|
evals_result=evals_result, feval=feval,
|
||||||
verbose_eval=verbose)
|
verbose_eval=verbose)
|
||||||
if eval_results:
|
|
||||||
eval_results = {k: np.array(v, dtype=float)
|
if evals_result:
|
||||||
for k, v in eval_results.items()}
|
for val in evals_result.items():
|
||||||
eval_results = {k: np.array(v) for k, v in eval_results.items()}
|
evals_result_key = val[1].keys()[0]
|
||||||
self.eval_results = eval_results
|
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
|
||||||
|
self.evals_result_ = evals_result
|
||||||
|
|
||||||
if early_stopping_rounds is not None:
|
if early_stopping_rounds is not None:
|
||||||
self.best_score = self._Booster.best_score
|
self.best_score = self._Booster.best_score
|
||||||
@ -203,6 +204,42 @@ class XGBModel(XGBModelBase):
|
|||||||
test_dmatrix = DMatrix(data, missing=self.missing)
|
test_dmatrix = DMatrix(data, missing=self.missing)
|
||||||
return self.booster().predict(test_dmatrix)
|
return self.booster().predict(test_dmatrix)
|
||||||
|
|
||||||
|
def evals_result(self):
|
||||||
|
"""Return the evaluation results.
|
||||||
|
|
||||||
|
If eval_set is passed to the `fit` function, you can call evals_result() to
|
||||||
|
get evaluation results for all passed eval_sets. When eval_metric is also
|
||||||
|
passed to the `fit` function, the evals_result will contain the eval_metrics
|
||||||
|
passed to the `fit` function
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
evals_result : dictionary
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
param_dist = {'objective':'binary:logistic', 'n_estimators':2}
|
||||||
|
|
||||||
|
clf = xgb.XGBModel(**param_dist)
|
||||||
|
|
||||||
|
clf.fit(X_train, y_train,
|
||||||
|
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||||
|
eval_metric='logloss',
|
||||||
|
verbose=True)
|
||||||
|
|
||||||
|
evals_result = clf.evals_result()
|
||||||
|
|
||||||
|
The variable evals_result will contain:
|
||||||
|
{'validation_0': {'logloss': ['0.604835', '0.531479']},
|
||||||
|
'validation_1': {'logloss': ['0.41965', '0.17686']}}
|
||||||
|
"""
|
||||||
|
if self.evals_result_:
|
||||||
|
evals_result = self.evals_result_
|
||||||
|
else:
|
||||||
|
raise XGBoostError('No results.')
|
||||||
|
|
||||||
|
return evals_result
|
||||||
|
|
||||||
|
|
||||||
class XGBClassifier(XGBModel, XGBClassifierBase):
|
class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||||
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
||||||
@ -259,7 +296,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
If `verbose` and an evaluation set is used, writes the evaluation
|
If `verbose` and an evaluation set is used, writes the evaluation
|
||||||
metric measured on the validation set to stderr.
|
metric measured on the validation set to stderr.
|
||||||
"""
|
"""
|
||||||
eval_results = {}
|
evals_result = {}
|
||||||
self.classes_ = list(np.unique(y))
|
self.classes_ = list(np.unique(y))
|
||||||
self.n_classes_ = len(self.classes_)
|
self.n_classes_ = len(self.classes_)
|
||||||
if self.n_classes_ > 2:
|
if self.n_classes_ > 2:
|
||||||
@ -299,13 +336,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
|
self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
|
||||||
evals=evals,
|
evals=evals,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=eval_results, feval=feval,
|
evals_result=evals_result, feval=feval,
|
||||||
verbose_eval=verbose)
|
verbose_eval=verbose)
|
||||||
|
|
||||||
if eval_results:
|
if evals_result:
|
||||||
eval_results = {k: np.array(v, dtype=float)
|
for val in evals_result.items():
|
||||||
for k, v in eval_results.items()}
|
evals_result_key = val[1].keys()[0]
|
||||||
self.eval_results = eval_results
|
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
|
||||||
|
self.evals_result_ = evals_result
|
||||||
|
|
||||||
if early_stopping_rounds is not None:
|
if early_stopping_rounds is not None:
|
||||||
self.best_score = self._Booster.best_score
|
self.best_score = self._Booster.best_score
|
||||||
@ -333,6 +371,42 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
classzero_probs = 1.0 - classone_probs
|
classzero_probs = 1.0 - classone_probs
|
||||||
return np.vstack((classzero_probs, classone_probs)).transpose()
|
return np.vstack((classzero_probs, classone_probs)).transpose()
|
||||||
|
|
||||||
|
def evals_result(self):
|
||||||
|
"""Return the evaluation results.
|
||||||
|
|
||||||
|
If eval_set is passed to the `fit` function, you can call evals_result() to
|
||||||
|
get evaluation results for all passed eval_sets. When eval_metric is also
|
||||||
|
passed to the `fit` function, the evals_result will contain the eval_metrics
|
||||||
|
passed to the `fit` function
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
evals_result : dictionary
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
param_dist = {'objective':'binary:logistic', 'n_estimators':2}
|
||||||
|
|
||||||
|
clf = xgb.XGBClassifier(**param_dist)
|
||||||
|
|
||||||
|
clf.fit(X_train, y_train,
|
||||||
|
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||||
|
eval_metric='logloss',
|
||||||
|
verbose=True)
|
||||||
|
|
||||||
|
evals_result = clf.evals_result()
|
||||||
|
|
||||||
|
The variable evals_result will contain:
|
||||||
|
{'validation_0': {'logloss': ['0.604835', '0.531479']},
|
||||||
|
'validation_1': {'logloss': ['0.41965', '0.17686']}}
|
||||||
|
"""
|
||||||
|
if self.evals_result_:
|
||||||
|
evals_result = self.evals_result_
|
||||||
|
else:
|
||||||
|
raise XGBoostError('No results.')
|
||||||
|
|
||||||
|
return evals_result
|
||||||
|
|
||||||
class XGBRegressor(XGBModel, XGBRegressorBase):
|
class XGBRegressor(XGBModel, XGBRegressorBase):
|
||||||
# pylint: disable=missing-docstring
|
# pylint: disable=missing-docstring
|
||||||
__doc__ = """Implementation of the scikit-learn API for XGBoost regression.
|
__doc__ = """Implementation of the scikit-learn API for XGBoost regression.
|
||||||
|
|||||||
@ -38,7 +38,11 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
If early stopping occurs, the model will have two additional fields:
|
If early stopping occurs, the model will have two additional fields:
|
||||||
bst.best_score and bst.best_iteration.
|
bst.best_score and bst.best_iteration.
|
||||||
evals_result: dict
|
evals_result: dict
|
||||||
This dictionary stores the evaluation results of all the items in watchlist
|
This dictionary stores the evaluation results of all the items in watchlist.
|
||||||
|
Example: with a watchlist containing [(dtest,'eval'), (dtrain,'train')] and
|
||||||
|
a parameter containing ('eval_metric', 'logloss')
|
||||||
|
Returns: {'train': {'logloss': ['0.48253', '0.35953']},
|
||||||
|
'eval': {'logloss': ['0.480385', '0.357756']}}
|
||||||
verbose_eval : bool
|
verbose_eval : bool
|
||||||
If `verbose_eval` then the evaluation metric on the validation set, if
|
If `verbose_eval` then the evaluation metric on the validation set, if
|
||||||
given, is printed at each boosting stage.
|
given, is printed at each boosting stage.
|
||||||
@ -317,4 +321,3 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
|
|||||||
results = np.array(results)
|
results = np.array(results)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user