Update Python documents. (#6376)
This commit is contained in:
parent
c5645180a6
commit
c90f968d92
@ -14,15 +14,15 @@ print('running cross validation')
|
|||||||
# std_value is standard deviation of the metric
|
# std_value is standard deviation of the metric
|
||||||
xgb.cv(param, dtrain, num_round, nfold=5,
|
xgb.cv(param, dtrain, num_round, nfold=5,
|
||||||
metrics={'error'}, seed=0,
|
metrics={'error'}, seed=0,
|
||||||
callbacks=[xgb.callback.print_evaluation(show_stdv=True)])
|
callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)])
|
||||||
|
|
||||||
print('running cross validation, disable standard deviation display')
|
print('running cross validation, disable standard deviation display')
|
||||||
# do cross validation, this will print result out as
|
# do cross validation, this will print result out as
|
||||||
# [iteration] metric_name:mean_value
|
# [iteration] metric_name:mean_value
|
||||||
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
|
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
|
||||||
metrics={'error'}, seed=0,
|
metrics={'error'}, seed=0,
|
||||||
callbacks=[xgb.callback.print_evaluation(show_stdv=False),
|
callbacks=[xgb.callback.EvaluationMonitor(show_stdv=False),
|
||||||
xgb.callback.early_stop(3)])
|
xgb.callback.EarlyStopping(3)])
|
||||||
print(res)
|
print(res)
|
||||||
print('running cross validation, with preprocessing function')
|
print('running cross validation, with preprocessing function')
|
||||||
# define the preprocessing function
|
# define the preprocessing function
|
||||||
|
|||||||
@ -69,13 +69,15 @@ Plotting API
|
|||||||
|
|
||||||
Callback API
|
Callback API
|
||||||
------------
|
------------
|
||||||
.. autofunction:: xgboost.callback.print_evaluation
|
.. autofunction:: xgboost.callback.TrainingCallback
|
||||||
|
|
||||||
.. autofunction:: xgboost.callback.record_evaluation
|
.. autofunction:: xgboost.callback.EvaluationMonitor
|
||||||
|
|
||||||
.. autofunction:: xgboost.callback.reset_learning_rate
|
.. autofunction:: xgboost.callback.EarlyStopping
|
||||||
|
|
||||||
.. autofunction:: xgboost.callback.early_stop
|
.. autofunction:: xgboost.callback.LearningRateScheduler
|
||||||
|
|
||||||
|
.. autofunction:: xgboost.callback.TrainingCheckPoint
|
||||||
|
|
||||||
.. _dask_api:
|
.. _dask_api:
|
||||||
|
|
||||||
@ -91,6 +93,8 @@ Dask API
|
|||||||
|
|
||||||
.. autofunction:: xgboost.dask.predict
|
.. autofunction:: xgboost.dask.predict
|
||||||
|
|
||||||
|
.. autofunction:: xgboost.dask.inplace_predict
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.DaskXGBClassifier
|
.. autofunction:: xgboost.dask.DaskXGBClassifier
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.DaskXGBRegressor
|
.. autofunction:: xgboost.dask.DaskXGBRegressor
|
||||||
|
|||||||
@ -510,7 +510,8 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
[xgb.callback.reset_learning_rate(custom_rates)]
|
callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||||
|
save_best=True)]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.n_features_in_ = X.shape[1]
|
self.n_features_in_ = X.shape[1]
|
||||||
@ -1249,7 +1250,8 @@ class XGBRanker(XGBModel):
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
[xgb.callback.reset_learning_rate(custom_rates)]
|
callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||||
|
save_best=True)]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# check if group information is provided
|
# check if group information is provided
|
||||||
|
|||||||
@ -123,9 +123,10 @@ class TestCallbacks(unittest.TestCase):
|
|||||||
X, y = load_breast_cancer(return_X_y=True)
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
cls = xgb.XGBClassifier()
|
cls = xgb.XGBClassifier()
|
||||||
early_stopping_rounds = 5
|
early_stopping_rounds = 5
|
||||||
|
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds)
|
||||||
cls.fit(X, y, eval_set=[(X, y)],
|
cls.fit(X, y, eval_set=[(X, y)],
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
eval_metric=tm.eval_error_metric,
|
||||||
eval_metric=tm.eval_error_metric)
|
callbacks=[early_stop])
|
||||||
booster = cls.get_booster()
|
booster = cls.get_booster()
|
||||||
dump = booster.get_dump(dump_format='json')
|
dump = booster.get_dump(dump_format='json')
|
||||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user