Update Python API doc (#3619)

* Add XGBRanker to Python API doc

* Show inherited members of XGBRegressor in API doc, since XGBRegressor uses default methods from XGBModel

* Add table of contents to Python API doc

* Skip JVM doc download if not available

* Show inherited members for XGBRegressor and XGBRanker

* Expose XGBRanker in the Python XGBoost module

* Add docstring to XGBRegressor.predict() and XGBRanker.predict()

* Fix rendering errors in Python docstrings

* Fix lint
Philip Hyunsu Cho 2018-08-22 18:59:30 -07:00 committed by GitHub
parent 4912c1f9c6
commit 4ed8a88240
8 changed files with 168 additions and 74 deletions

View File

@@ -14,6 +14,7 @@
from subprocess import call
from sh.contrib import git
import urllib.request
from urllib.error import HTTPError
from recommonmark.parser import CommonMarkParser
import sys
import re
@@ -24,8 +25,11 @@ import guzzle_sphinx_theme
git_branch = [re.sub(r'origin/', '', x.lstrip(' ')) for x in str(git.branch('-r', '--contains', 'HEAD')).rstrip('\n').split('\n')]
git_branch = [x for x in git_branch if 'HEAD' not in x]
print('git_branch = {}'.format(git_branch[0]))
filename, _ = urllib.request.urlretrieve('https://s3-us-west-2.amazonaws.com/xgboost-docs/{}.tar.bz2'.format(git_branch[0]))
call('if [ -d tmp ]; then rm -rf tmp; fi; mkdir -p tmp/jvm; cd tmp/jvm; tar xvf {}'.format(filename), shell=True)
try:
    filename, _ = urllib.request.urlretrieve('https://s3-us-west-2.amazonaws.com/xgboost-docs/{}.tar.bz2'.format(git_branch[0]))
    call('if [ -d tmp ]; then rm -rf tmp; fi; mkdir -p tmp/jvm; cd tmp/jvm; tar xvf {}'.format(filename), shell=True)
except HTTPError:
    print('JVM doc not found. Skipping...')
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
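
For reference, the download-or-skip pattern introduced in this hunk can be exercised on its own. A minimal standalone sketch, with a hypothetical helper name and branch name:

.. code-block:: python

    import urllib.request
    from urllib.error import HTTPError

    def fetch_if_available(url):
        # Download url to a temporary file; return None when the server
        # reports an error (e.g. the tarball for this branch was never uploaded).
        try:
            filename, _ = urllib.request.urlretrieve(url)
            return filename
        except HTTPError:
            print('Doc not found. Skipping...')
            return None

    # Hypothetical branch name, mirroring the conf.py logic above.
    fetch_if_available('https://s3-us-west-2.amazonaws.com/xgboost-docs/master.tar.bz2')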

View File

@@ -274,7 +274,7 @@ and then loading the model in another session:
With regards to ML pipeline save and load, please refer to the next section.
Interact with Other Bindings of XGBoost
------------------------------------
---------------------------------------
After we train a model with XGBoost4j-Spark on a massive dataset, sometimes we want to do model serving on a single machine, or to integrate the model with other single-node libraries for further processing. XGBoost4j-Spark supports exporting the model to a local file by:
.. code-block:: scala

View File

@@ -2,6 +2,10 @@ Python API Reference
====================
This page gives the Python API reference of xgboost; please also refer to the Python Package Introduction for more information about the Python package.
.. contents::
  :backlinks: none
  :local:
Core Data Structure
-------------------
.. automodule:: xgboost.core
@@ -29,9 +33,15 @@ Scikit-Learn API
.. automodule:: xgboost.sklearn
.. autoclass:: xgboost.XGBRegressor
    :members:
    :inherited-members:
    :show-inheritance:
.. autoclass:: xgboost.XGBClassifier
    :members:
    :inherited-members:
    :show-inheritance:
.. autoclass:: xgboost.XGBRanker
    :members:
    :inherited-members:
    :show-inheritance:
Plotting API

View File

@@ -12,7 +12,7 @@ from .core import DMatrix, Booster
from .training import train, cv
from . import rabit # noqa
try:
    from .sklearn import XGBModel, XGBClassifier, XGBRegressor
    from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
    from .plotting import plot_importance, plot_tree, to_graphviz
except ImportError:
    pass
@@ -23,5 +23,5 @@ with open(VERSION_FILE) as f:
__all__ = ['DMatrix', 'Booster',
           'train', 'cv',
           'XGBModel', 'XGBClassifier', 'XGBRegressor',
           'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
           'plot_importance', 'plot_tree', 'to_graphviz']
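
With XGBRanker added to the package exports above, it becomes importable directly from the top-level module. A quick sketch (the training data names are hypothetical):

.. code-block:: python

    from xgboost import XGBRanker  # resolves after this change

    ranker = XGBRanker(n_estimators=50, max_depth=3)
    # X, y, group are hypothetical: feature matrix, relevance labels,
    # and per-query group sizes (see the XGBRanker docstring later in this commit).
    # ranker.fit(X, y, group)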

View File

@@ -1376,11 +1376,12 @@ class Booster(object):
def get_score(self, fmap='', importance_type='weight'):
"""Get feature importance of each feature.
Importance type can be defined as:
'weight' - the number of times a feature is used to split the data across all trees.
'gain' - the average gain across all splits the feature is used in.
'cover' - the average coverage across all splits the feature is used in.
'total_gain' - the total gain across all splits the feature is used in.
'total_cover' - the total coverage across all splits the feature is used in.
* 'weight': the number of times a feature is used to split the data across all trees.
* 'gain': the average gain across all splits the feature is used in.
* 'cover': the average coverage across all splits the feature is used in.
* 'total_gain': the total gain across all splits the feature is used in.
* 'total_cover': the total coverage across all splits the feature is used in.
Parameters
----------
@@ -1496,6 +1497,7 @@ class Booster(object):
def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True):
"""Get split value histogram of a feature
Parameters
----------
feature: str
@@ -1506,7 +1508,7 @@
The maximum number of bins.
Number of bins equals number of unique split values n_unique,
if bins == None or bins > n_unique.
as_pandas : bool, default True
as_pandas: bool, default True
Return pd.DataFrame when pandas is installed.
If False or pandas is not installed, return numpy ndarray.
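
The five importance types enumerated in the hunk above all go through the same ``importance_type`` argument. A sketch, assuming ``bst`` is an already-trained Booster:

.. code-block:: python

    # bst is assumed to be a trained xgboost.Booster.
    for imp_type in ('weight', 'gain', 'cover', 'total_gain', 'total_cover'):
        print(imp_type, bst.get_score(importance_type=imp_type))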

View File

@@ -28,10 +28,11 @@ def plot_importance(booster, ax=None, height=0.2,
grid : bool, Turn the axes grids on or off. Default is True (On).
importance_type : str, default "weight"
How the importance is calculated: either "weight", "gain", or "cover"
"weight" is the number of times a feature appears in a tree
"gain" is the average gain of splits which use the feature
"cover" is the average coverage of splits which use the feature
where coverage is defined as the number of samples affected by the split
* "weight" is the number of times a feature appears in a tree
* "gain" is the average gain of splits which use the feature
* "cover" is the average coverage of splits which use the feature
  where coverage is defined as the number of samples affected by the split
max_num_features : int, default None
Maximum number of top features displayed on plot. If None, all features will be displayed.
height : float, default 0.2
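
The importance types listed above are selected via the ``importance_type`` argument of ``plot_importance``. A sketch, assuming ``bst`` is a trained booster and matplotlib is installed:

.. code-block:: python

    import matplotlib.pyplot as plt
    import xgboost as xgb

    # bst is assumed to be a trained Booster (or a fitted sklearn wrapper).
    ax = xgb.plot_importance(bst, importance_type='gain', max_num_features=10)
    plt.show()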

View File

@@ -99,14 +99,16 @@ class XGBModel(XGBModelBase):
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
**kwargs : dict, optional
\*\*kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
Attempting to set a parameter via the constructor args and \*\*kwargs dict simultaneously
will result in a TypeError.
Note:
**kwargs is unsupported by Sklearn. We do not guarantee that parameters passed via
this argument will interact properly with Sklearn.
.. note:: \*\*kwargs unsupported by scikit-learn

    \*\*kwargs is unsupported by scikit-learn. We do not guarantee that parameters
    passed via this argument will interact properly with scikit-learn.
Note
----
@@ -217,6 +219,7 @@ class XGBModel(XGBModelBase):
def save_model(self, fname):
"""
Save the model to a file.
Parameters
----------
fname : string
@@ -227,6 +230,7 @@ class XGBModel(XGBModelBase):
def load_model(self, fname):
"""
Load the model from a file.
Parameters
----------
fname : string or a memory buffer
@@ -336,6 +340,39 @@ class XGBModel(XGBModelBase):
return self
def predict(self, data, output_margin=False, ntree_limit=None):
"""
Predict with `data`.
.. note:: This function is not thread safe.

    For each booster object, predict can only be called from one thread.
    If you want to run prediction using multiple threads, call ``xgb.copy()`` to make
    copies of the model object and then call ``predict()``.

.. note:: Using ``predict()`` with DART booster

    If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
    some of the trees will be evaluated. This will produce incorrect results if ``data`` is
    not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
    a nonzero value, e.g.

    .. code-block:: python

        preds = bst.predict(dtest, ntree_limit=num_round)
Parameters
----------
data : DMatrix
The dmatrix storing the input.
output_margin : bool
Whether to output the raw untransformed margin value.
ntree_limit : int
Limit number of trees in the prediction; defaults to best_ntree_limit if defined
(i.e. it has been trained with early stopping), otherwise 0 (use all trees).
Returns
-------
prediction : numpy array
"""
# pylint: disable=missing-docstring,invalid-name
test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
# get ntree_limit to use - if none specified, default to
@@ -372,10 +409,10 @@ class XGBModel(XGBModelBase):
def evals_result(self):
"""Return the evaluation results.
If eval_set is passed to the `fit` function, you can call evals_result() to
get evaluation results for all passed eval_sets. When eval_metric is also
passed to the `fit` function, the evals_result will contain the eval_metrics
passed to the `fit` function
If ``eval_set`` is passed to the `fit` function, you can call ``evals_result()`` to
get evaluation results for all passed eval_sets. When ``eval_metric`` is also
passed to the ``fit`` function, the ``evals_result`` will contain the ``eval_metrics``
passed to the ``fit`` function
Returns
-------
@@ -383,20 +420,26 @@ class XGBModel(XGBModelBase):
Example
-------
param_dist = {'objective':'binary:logistic', 'n_estimators':2}
clf = xgb.XGBModel(**param_dist)
clf.fit(X_train, y_train,
        eval_set=[(X_train, y_train), (X_test, y_test)],
        eval_metric='logloss',
        verbose=True)
evals_result = clf.evals_result()
.. code-block:: python

    param_dist = {'objective':'binary:logistic', 'n_estimators':2}
    clf = xgb.XGBModel(**param_dist)
    clf.fit(X_train, y_train,
            eval_set=[(X_train, y_train), (X_test, y_test)],
            eval_metric='logloss',
            verbose=True)
    evals_result = clf.evals_result()
The variable evals_result will contain:
{'validation_0': {'logloss': ['0.604835', '0.531479']},
 'validation_1': {'logloss': ['0.41965', '0.17686']}}
.. code-block:: none

    {'validation_0': {'logloss': ['0.604835', '0.531479']},
     'validation_1': {'logloss': ['0.41965', '0.17686']}}
"""
if self.evals_result_:
evals_result = self.evals_result_
@@ -408,9 +451,11 @@ class XGBModel(XGBModelBase):
@property
def feature_importances_(self):
"""
Feature importances property
Returns
-------
feature_importances_ : array of shape = [n_features]
feature_importances_ : array of shape ``[n_features]``
"""
b = self.get_booster()
@@ -422,9 +467,8 @@ class XGBModel(XGBModelBase):
class XGBClassifier(XGBModel, XGBClassifierBase):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
__doc__ = """Implementation of the scikit-learn API for XGBoost classification.
""" + '\n'.join(XGBModel.__doc__.split('\n')[2:])
__doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=0.1,
n_estimators=100, silent=True,
@@ -610,10 +654,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
def predict_proba(self, data, ntree_limit=None):
"""
Predict the probability of each `data` example being of a given class.
NOTE: This function is not thread safe.
For each booster object, predict can only be called from one thread.
If you want to run prediction using multiple thread, call xgb.copy() to make copies
of model object and then call predict
.. note:: This function is not thread safe

    For each booster object, predict can only be called from one thread.
    If you want to run prediction using multiple threads, call ``xgb.copy()`` to make
    copies of the model object and then call ``predict()``.
Parameters
----------
data : DMatrix
@@ -621,6 +668,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
ntree_limit : int
Limit number of trees in the prediction; defaults to best_ntree_limit if defined
(i.e. it has been trained with early stopping), otherwise 0 (use all trees).
Returns
-------
prediction : numpy array
@@ -652,20 +700,26 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
Example
-------
param_dist = {'objective':'binary:logistic', 'n_estimators':2}
clf = xgb.XGBClassifier(**param_dist)
clf.fit(X_train, y_train,
        eval_set=[(X_train, y_train), (X_test, y_test)],
        eval_metric='logloss',
        verbose=True)
evals_result = clf.evals_result()
The variable evals_result will contain:
{'validation_0': {'logloss': ['0.604835', '0.531479']},
 'validation_1': {'logloss': ['0.41965', '0.17686']}}
.. code-block:: python

    param_dist = {'objective':'binary:logistic', 'n_estimators':2}
    clf = xgb.XGBClassifier(**param_dist)
    clf.fit(X_train, y_train,
            eval_set=[(X_train, y_train), (X_test, y_test)],
            eval_metric='logloss',
            verbose=True)
    evals_result = clf.evals_result()
The variable ``evals_result`` will contain
.. code-block:: none

    {'validation_0': {'logloss': ['0.604835', '0.531479']},
     'validation_1': {'logloss': ['0.41965', '0.17686']}}
"""
if self.evals_result_:
evals_result = self.evals_result_
@@ -677,8 +731,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
class XGBRegressor(XGBModel, XGBRegressorBase):
# pylint: disable=missing-docstring
__doc__ = """Implementation of the scikit-learn API for XGBoost regression.
""" + '\n'.join(XGBModel.__doc__.split('\n')[2:])
__doc__ = "Implementation of the scikit-learn API for XGBoost regression.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
class XGBRanker(XGBModel):
@@ -731,14 +785,16 @@ class XGBRanker(XGBModel):
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
**kwargs : dict, optional
\*\*kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
will result in a TypeError.
Note:
**kwargs is unsupported by Sklearn. We do not guarantee that parameters passed via
this argument will interact properly with Sklearn.
Attempting to set a parameter via the constructor args and \*\*kwargs dict
simultaneously will result in a TypeError.
.. note:: \*\*kwargs unsupported by scikit-learn

    \*\*kwargs is unsupported by scikit-learn. We do not guarantee that parameters
    passed via this argument will interact properly with scikit-learn.
Note
----
@@ -750,16 +806,25 @@ class XGBRanker(XGBModel):
For example, if your original data look like:
+-------+-----------+---------------+
| qid | label | features |
+-------+-----------+---------------+
| 1 | 0 | x_1 |
+-------+-----------+---------------+
| 1 | 1 | x_2 |
+-------+-----------+---------------+
| 1 | 0 | x_3 |
+-------+-----------+---------------+
| 2 | 0 | x_4 |
+-------+-----------+---------------+
| 2 | 1 | x_5 |
+-------+-----------+---------------+
| 2 | 1 | x_6 |
+-------+-----------+---------------+
| 2 | 1 | x_7 |
+-------+-----------+---------------+
then your group array should be [3, 4].
then your group array should be ``[3, 4]``.
"""
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
@@ -908,3 +973,5 @@ class XGBRanker(XGBModel):
return self.get_booster().predict(test_dmatrix,
output_margin=output_margin,
ntree_limit=ntree_limit)
predict.__doc__ = XGBModel.predict.__doc__
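
The group array described in the qid table of the XGBRanker docstring above translates to code as follows. A hedged sketch with tiny synthetic data, assuming the ``fit(X, y, group)`` signature defined in this file:

.. code-block:: python

    import numpy as np
    from xgboost import XGBRanker

    # Seven rows matching the qid table in the docstring above:
    # three rows for query 1 and four rows for query 2, hence group = [3, 4].
    X = np.random.rand(7, 5)             # synthetic features (assumption)
    y = np.array([0, 1, 0, 0, 1, 1, 1])  # relevance labels taken from the table
    group = [3, 4]

    ranker = XGBRanker(n_estimators=10)
    ranker.fit(X, y, group)
    preds = ranker.predict(X)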

View File

@@ -147,18 +147,24 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
and/or num_class appears in the parameters)
evals_result: dict
This dictionary stores the evaluation results of all the items in watchlist.
Example: with a watchlist containing [(dtest,'eval'), (dtrain,'train')] and
a parameter containing ('eval_metric': 'logloss')
Returns: {'train': {'logloss': ['0.48253', '0.35953']},
'eval': {'logloss': ['0.480385', '0.357756']}}
a parameter containing ('eval_metric': 'logloss'), the **evals_result**
returns
.. code-block:: none

    {'train': {'logloss': ['0.48253', '0.35953']},
     'eval': {'logloss': ['0.480385', '0.357756']}}
verbose_eval : bool or int
Requires at least one item in evals.
If `verbose_eval` is True then the evaluation metric on the validation set is
If **verbose_eval** is True then the evaluation metric on the validation set is
printed at each boosting stage.
If `verbose_eval` is an integer then the evaluation metric on the validation set
is printed at every given `verbose_eval` boosting stage. The last boosting stage
/ the boosting stage found by using `early_stopping_rounds` is also printed.
Example: with verbose_eval=4 and at least one item in evals, an evaluation metric
If **verbose_eval** is an integer then the evaluation metric on the validation set
is printed at every given **verbose_eval** boosting stage. The last boosting stage
/ the boosting stage found by using **early_stopping_rounds** is also printed.
Example: with ``verbose_eval=4`` and at least one item in evals, an evaluation metric
is printed every 4 boosting stages, instead of every boosting stage.
learning_rates: list or function (deprecated - use callback API instead)
List of learning rate for each boosting round
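
The watchlist/``evals_result``/``verbose_eval`` interaction described in this hunk, as a sketch; ``dtrain`` and ``dtest`` are hypothetical DMatrix objects:

.. code-block:: python

    import xgboost as xgb

    # dtrain and dtest are assumed to be existing DMatrix objects.
    evals_result = {}
    bst = xgb.train({'objective': 'binary:logistic', 'eval_metric': 'logloss'},
                    dtrain, num_boost_round=10,
                    evals=[(dtest, 'eval'), (dtrain, 'train')],
                    evals_result=evals_result, verbose_eval=4)
    print(evals_result['eval']['logloss'])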
@@ -328,10 +334,10 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
folds : a KFold or StratifiedKFold instance or list of fold indices
Sklearn KFolds or StratifiedKFolds object.
Alternatively may explicitly pass sample indices for each fold.
For `n` folds, `folds` should be a length `n` list of tuples.
Each tuple is `(in,out)` where `in` is a list of indices to be used
as the training samples for the `n`th fold and `out` is a list of
indices to be used as the testing samples for the `n`th fold.
For ``n`` folds, ``folds`` should be a length ``n`` list of tuples.
Each tuple is ``(in,out)`` where ``in`` is a list of indices to be used
as the training samples for the ``n`` th fold and ``out`` is a list of
indices to be used as the testing samples for the ``n`` th fold.
metrics : string or list of strings
Evaluation metrics to be watched in CV.
obj : function
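
The explicit-folds option described a few lines above accepts hand-built index lists. A sketch, assuming ``dtrain`` is a DMatrix with six rows:

.. code-block:: python

    import xgboost as xgb

    # Two folds over a hypothetical six-row dtrain;
    # each tuple is (train_indices, test_indices).
    folds = [([0, 1, 2, 3], [4, 5]),
             ([2, 3, 4, 5], [0, 1])]
    cv_results = xgb.cv({'objective': 'binary:logistic'}, dtrain,
                        num_boost_round=5, folds=folds, metrics='logloss')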
@@ -363,8 +369,12 @@
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using xgb.callback module.
Example: [xgb.callback.reset_learning_rate(custom_rates)]
shuffle : bool
Example:

.. code-block:: none

    [xgb.callback.reset_learning_rate(custom_rates)]
shuffle : bool
Shuffle data before creating folds.
Returns