Use config_context in sklearn interface. (#8141)
This commit is contained in:
parent
03cc3b359c
commit
9ae547f994
@ -1,3 +1,5 @@
|
|||||||
|
.. _dask-examples:
|
||||||
|
|
||||||
XGBoost Dask Feature Walkthrough
|
XGBoost Dask Feature Walkthrough
|
||||||
================================
|
================================
|
||||||
|
|
||||||
|
|||||||
@ -126,7 +126,7 @@ master_doc = 'index'
|
|||||||
#
|
#
|
||||||
# This is also used if you do content translation via gettext catalogs.
|
# This is also used if you do content translation via gettext catalogs.
|
||||||
# Usually you set "language" from the command line for these cases.
|
# Usually you set "language" from the command line for these cases.
|
||||||
language = None
|
language = "en"
|
||||||
|
|
||||||
autoclass_content = 'both'
|
autoclass_content = 'both'
|
||||||
|
|
||||||
|
|||||||
@ -115,7 +115,7 @@ Alternatively, XGBoost also implements the Scikit-Learn interface with
|
|||||||
:py:class:`~xgboost.dask.DaskXGBRanker` and 2 random forest variances. This wrapper is
|
:py:class:`~xgboost.dask.DaskXGBRanker` and 2 random forest variances. This wrapper is
|
||||||
similar to the single node Scikit-Learn interface in xgboost, with dask collection as
|
similar to the single node Scikit-Learn interface in xgboost, with dask collection as
|
||||||
inputs and has an additional ``client`` attribute. See following sections and
|
inputs and has an additional ``client`` attribute. See following sections and
|
||||||
:ref:`sphx_glr_python_dask-examples` for more examples.
|
:ref:`dask-examples` for more examples.
|
||||||
|
|
||||||
|
|
||||||
******************
|
******************
|
||||||
|
|||||||
@ -16,6 +16,7 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
|
|||||||
Distributed XGBoost with XGBoost4J-Spark <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html>
|
Distributed XGBoost with XGBoost4J-Spark <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html>
|
||||||
Distributed XGBoost with XGBoost4J-Spark-GPU <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_gpu_tutorial.html>
|
Distributed XGBoost with XGBoost4J-Spark-GPU <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_gpu_tutorial.html>
|
||||||
dask
|
dask
|
||||||
|
spark_estimator
|
||||||
ray
|
ray
|
||||||
dart
|
dart
|
||||||
monotonic
|
monotonic
|
||||||
|
|||||||
@ -70,6 +70,23 @@ def config_doc(
|
|||||||
# Suppress warning caused by model generated with XGBoost version < 1.0.0
|
# Suppress warning caused by model generated with XGBoost version < 1.0.0
|
||||||
bst = xgb.Booster(model_file='./old_model.bin')
|
bst = xgb.Booster(model_file='./old_model.bin')
|
||||||
assert xgb.get_config()['verbosity'] == 2 # old value restored
|
assert xgb.get_config()['verbosity'] == 2 # old value restored
|
||||||
|
|
||||||
|
Nested configuration context is also supported:
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
with xgb.config_context(verbosity=3):
|
||||||
|
assert xgb.get_config()["verbosity"] == 3
|
||||||
|
with xgb.config_context(verbosity=2):
|
||||||
|
assert xgb.get_config()["verbosity"] == 2
|
||||||
|
|
||||||
|
xgb.set_config(verbosity=2)
|
||||||
|
assert xgb.get_config()["verbosity"] == 2
|
||||||
|
with xgb.config_context(verbosity=3):
|
||||||
|
assert xgb.get_config()["verbosity"] == 3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def none_to_str(value: Optional[str]) -> str:
|
def none_to_str(value: Optional[str]) -> str:
|
||||||
@ -98,7 +115,11 @@ def config_doc(
|
|||||||
Keyword arguments representing the parameters and their values
|
Keyword arguments representing the parameters and their values
|
||||||
""")
|
""")
|
||||||
def set_config(**new_config: Any) -> None:
|
def set_config(**new_config: Any) -> None:
|
||||||
config = json.dumps(new_config)
|
not_none = {}
|
||||||
|
for k, v in new_config.items():
|
||||||
|
if v is not None:
|
||||||
|
not_none[k] = v
|
||||||
|
config = json.dumps(not_none)
|
||||||
_check_call(_LIB.XGBSetGlobalConfig(c_str(config)))
|
_check_call(_LIB.XGBSetGlobalConfig(c_str(config)))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -113,7 +113,9 @@ if __name__ == "__main__":
|
|||||||
run_formatter(path)
|
run_formatter(path)
|
||||||
for path in [
|
for path in [
|
||||||
"python-package/xgboost/dask.py",
|
"python-package/xgboost/dask.py",
|
||||||
|
"python-package/xgboost/sklearn.py",
|
||||||
"python-package/xgboost/spark",
|
"python-package/xgboost/spark",
|
||||||
|
"tests/python/test_config.py",
|
||||||
"tests/python/test_spark/test_data.py",
|
"tests/python/test_spark/test_data.py",
|
||||||
"tests/python-gpu/test_gpu_spark/test_data.py",
|
"tests/python-gpu/test_gpu_spark/test_data.py",
|
||||||
"tests/ci_build/lint_python.py",
|
"tests/ci_build/lint_python.py",
|
||||||
|
|||||||
@ -1,13 +1,15 @@
|
|||||||
# -*- coding: utf-8 -*-
|
import multiprocessing
|
||||||
import xgboost as xgb
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import testing as tm
|
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('verbosity_level', [0, 1, 2, 3])
|
@pytest.mark.parametrize("verbosity_level", [0, 1, 2, 3])
|
||||||
def test_global_config_verbosity(verbosity_level):
|
def test_global_config_verbosity(verbosity_level):
|
||||||
def get_current_verbosity():
|
def get_current_verbosity():
|
||||||
return xgb.get_config()['verbosity']
|
return xgb.get_config()["verbosity"]
|
||||||
|
|
||||||
old_verbosity = get_current_verbosity()
|
old_verbosity = get_current_verbosity()
|
||||||
with xgb.config_context(verbosity=verbosity_level):
|
with xgb.config_context(verbosity=verbosity_level):
|
||||||
@ -16,13 +18,48 @@ def test_global_config_verbosity(verbosity_level):
|
|||||||
assert old_verbosity == get_current_verbosity()
|
assert old_verbosity == get_current_verbosity()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('use_rmm', [False, True])
|
@pytest.mark.parametrize("use_rmm", [False, True])
|
||||||
def test_global_config_use_rmm(use_rmm):
|
def test_global_config_use_rmm(use_rmm):
|
||||||
def get_current_use_rmm_flag():
|
def get_current_use_rmm_flag():
|
||||||
return xgb.get_config()['use_rmm']
|
return xgb.get_config()["use_rmm"]
|
||||||
|
|
||||||
old_use_rmm_flag = get_current_use_rmm_flag()
|
old_use_rmm_flag = get_current_use_rmm_flag()
|
||||||
with xgb.config_context(use_rmm=use_rmm):
|
with xgb.config_context(use_rmm=use_rmm):
|
||||||
new_use_rmm_flag = get_current_use_rmm_flag()
|
new_use_rmm_flag = get_current_use_rmm_flag()
|
||||||
assert new_use_rmm_flag == use_rmm
|
assert new_use_rmm_flag == use_rmm
|
||||||
assert old_use_rmm_flag == get_current_use_rmm_flag()
|
assert old_use_rmm_flag == get_current_use_rmm_flag()
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_config():
|
||||||
|
with xgb.config_context(verbosity=3):
|
||||||
|
assert xgb.get_config()["verbosity"] == 3
|
||||||
|
with xgb.config_context(verbosity=2):
|
||||||
|
assert xgb.get_config()["verbosity"] == 2
|
||||||
|
with xgb.config_context(verbosity=1):
|
||||||
|
assert xgb.get_config()["verbosity"] == 1
|
||||||
|
assert xgb.get_config()["verbosity"] == 2
|
||||||
|
assert xgb.get_config()["verbosity"] == 3
|
||||||
|
|
||||||
|
with xgb.config_context(verbosity=3):
|
||||||
|
assert xgb.get_config()["verbosity"] == 3
|
||||||
|
with xgb.config_context(verbosity=None):
|
||||||
|
assert xgb.get_config()["verbosity"] == 3 # None has no effect
|
||||||
|
|
||||||
|
verbosity = xgb.get_config()["verbosity"]
|
||||||
|
xgb.set_config(verbosity=2)
|
||||||
|
assert xgb.get_config()["verbosity"] == 2
|
||||||
|
with xgb.config_context(verbosity=3):
|
||||||
|
assert xgb.get_config()["verbosity"] == 3
|
||||||
|
xgb.set_config(verbosity=verbosity) # reset
|
||||||
|
|
||||||
|
|
||||||
|
def test_thread_safty():
|
||||||
|
n_threads = multiprocessing.cpu_count()
|
||||||
|
futures = []
|
||||||
|
with ThreadPoolExecutor(max_workers=n_threads) as executor:
|
||||||
|
for i in range(256):
|
||||||
|
f = executor.submit(test_nested_config)
|
||||||
|
futures.append(f)
|
||||||
|
|
||||||
|
for f in futures:
|
||||||
|
f.result()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user