Use config_context in sklearn interface. (#8141)

Jiaming Yuan 2022-08-09 14:48:54 +08:00 committed by GitHub
parent 03cc3b359c
commit 9ae547f994
8 changed files with 560 additions and 438 deletions

View File

@@ -1,3 +1,5 @@
+.. _dask-examples:
+
 XGBoost Dask Feature Walkthrough
 ================================

View File

@@ -126,7 +126,7 @@ master_doc = 'index'
 #
 # This is also used if you do content translation via gettext catalogs.
 # Usually you set "language" from the command line for these cases.
-language = None
+language = "en"
 
 autoclass_content = 'both'

View File

@@ -115,7 +115,7 @@ Alternatively, XGBoost also implements the Scikit-Learn interface with
 :py:class:`~xgboost.dask.DaskXGBRanker` and 2 random forest variances. This wrapper is
 similar to the single node Scikit-Learn interface in xgboost, with dask collection as
 inputs and has an additional ``client`` attribute. See following sections and
-:ref:`sphx_glr_python_dask-examples` for more examples.
+:ref:`dask-examples` for more examples.
 
 ******************

View File

@@ -16,6 +16,7 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
 Distributed XGBoost with XGBoost4J-Spark <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html>
 Distributed XGBoost with XGBoost4J-Spark-GPU <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_gpu_tutorial.html>
 dask
+spark_estimator
 ray
 dart
 monotonic

View File

@@ -70,6 +70,23 @@ def config_doc(
             # Suppress warning caused by model generated with XGBoost version < 1.0.0
             bst = xgb.Booster(model_file='./old_model.bin')
         assert xgb.get_config()['verbosity'] == 2  # old value restored
+
+    Nested configuration context is also supported:
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        with xgb.config_context(verbosity=3):
+            assert xgb.get_config()["verbosity"] == 3
+            with xgb.config_context(verbosity=2):
+                assert xgb.get_config()["verbosity"] == 2
+
+        xgb.set_config(verbosity=2)
+        assert xgb.get_config()["verbosity"] == 2
+        with xgb.config_context(verbosity=3):
+            assert xgb.get_config()["verbosity"] == 3
     """
 
     def none_to_str(value: Optional[str]) -> str:
@@ -98,7 +115,11 @@ def config_doc(
         Keyword arguments representing the parameters and their values
     """)
 def set_config(**new_config: Any) -> None:
-    config = json.dumps(new_config)
+    not_none = {}
+    for k, v in new_config.items():
+        if v is not None:
+            not_none[k] = v
+    config = json.dumps(not_none)
     _check_call(_LIB.XGBSetGlobalConfig(c_str(config)))
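
The ``None`` filtering above lets callers pass ``None`` to mean "leave this setting unchanged", which is what the new docstring example ("None has no effect") relies on. A minimal illustration of the resulting behaviour, outside the diff:

    import xgboost as xgb

    # A None value is dropped before the config is serialized, so the current
    # global setting is kept instead of being overwritten.
    before = xgb.get_config()["verbosity"]
    xgb.set_config(verbosity=None)
    assert xgb.get_config()["verbosity"] == before

    # The same applies when entering a configuration context.
    with xgb.config_context(verbosity=None):
        assert xgb.get_config()["verbosity"] == before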

File diff suppressed because it is too large
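
The suppressed diff is the scikit-learn wrapper change that gives the commit its title. Since its contents are not shown here, the following is only a hedged sketch of the pattern the title describes: scoping global settings such as ``verbosity`` with ``config_context`` for the duration of an estimator call, where ``None`` leaves the global configuration untouched. The class and parameter names below are illustrative, not the actual XGBoost API.

    import numpy as np
    import xgboost as xgb

    class SketchRegressor:
        # Illustrative wrapper only, not the real XGBRegressor implementation.
        def __init__(self, verbosity=None, **params):
            self.verbosity = verbosity
            self.params = params

        def fit(self, X, y):
            # Apply the user's verbosity only while this call runs;
            # verbosity=None keeps whatever is configured globally.
            with xgb.config_context(verbosity=self.verbosity):
                self._booster = xgb.train(
                    {"objective": "reg:squarederror", **self.params},
                    xgb.DMatrix(X, label=y),
                    num_boost_round=10,
                )
            return self

        def predict(self, X):
            with xgb.config_context(verbosity=self.verbosity):
                return self._booster.predict(xgb.DMatrix(X))

    rng = np.random.RandomState(0)
    X, y = rng.randn(100, 4), rng.randn(100)
    SketchRegressor(verbosity=None).fit(X, y).predict(X)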

View File

@@ -113,7 +113,9 @@ if __name__ == "__main__":
         run_formatter(path)
     for path in [
         "python-package/xgboost/dask.py",
+        "python-package/xgboost/sklearn.py",
         "python-package/xgboost/spark",
+        "tests/python/test_config.py",
         "tests/python/test_spark/test_data.py",
         "tests/python-gpu/test_gpu_spark/test_data.py",
         "tests/ci_build/lint_python.py",

View File

@@ -1,13 +1,15 @@
-# -*- coding: utf-8 -*-
-import xgboost as xgb
+import multiprocessing
+from concurrent.futures import ThreadPoolExecutor
+
 import pytest
-import testing as tm
+
+import xgboost as xgb
 
 
-@pytest.mark.parametrize('verbosity_level', [0, 1, 2, 3])
+@pytest.mark.parametrize("verbosity_level", [0, 1, 2, 3])
 def test_global_config_verbosity(verbosity_level):
     def get_current_verbosity():
-        return xgb.get_config()['verbosity']
+        return xgb.get_config()["verbosity"]
 
     old_verbosity = get_current_verbosity()
     with xgb.config_context(verbosity=verbosity_level):
@@ -16,13 +18,48 @@ def test_global_config_verbosity(verbosity_level):
     assert old_verbosity == get_current_verbosity()
 
 
-@pytest.mark.parametrize('use_rmm', [False, True])
+@pytest.mark.parametrize("use_rmm", [False, True])
 def test_global_config_use_rmm(use_rmm):
     def get_current_use_rmm_flag():
-        return xgb.get_config()['use_rmm']
+        return xgb.get_config()["use_rmm"]
 
     old_use_rmm_flag = get_current_use_rmm_flag()
     with xgb.config_context(use_rmm=use_rmm):
         new_use_rmm_flag = get_current_use_rmm_flag()
         assert new_use_rmm_flag == use_rmm
     assert old_use_rmm_flag == get_current_use_rmm_flag()
+
+
+def test_nested_config():
+    with xgb.config_context(verbosity=3):
+        assert xgb.get_config()["verbosity"] == 3
+        with xgb.config_context(verbosity=2):
+            assert xgb.get_config()["verbosity"] == 2
+            with xgb.config_context(verbosity=1):
+                assert xgb.get_config()["verbosity"] == 1
+            assert xgb.get_config()["verbosity"] == 2
+        assert xgb.get_config()["verbosity"] == 3
+
+    with xgb.config_context(verbosity=3):
+        assert xgb.get_config()["verbosity"] == 3
+        with xgb.config_context(verbosity=None):
+            assert xgb.get_config()["verbosity"] == 3  # None has no effect
+
+    verbosity = xgb.get_config()["verbosity"]
+    xgb.set_config(verbosity=2)
+    assert xgb.get_config()["verbosity"] == 2
+    with xgb.config_context(verbosity=3):
+        assert xgb.get_config()["verbosity"] == 3
+    xgb.set_config(verbosity=verbosity)  # reset
+
+
+def test_thread_safty():
+    n_threads = multiprocessing.cpu_count()
+    futures = []
+    with ThreadPoolExecutor(max_workers=n_threads) as executor:
+        for i in range(256):
+            f = executor.submit(test_nested_config)
+            futures.append(f)
+
+    for f in futures:
+        f.result()
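
To run the new tests locally, pytest can be invoked on the file directly; this assumes pytest and an installed xgboost package, with the repository root as the working directory:

    # Equivalent to running "pytest tests/python/test_config.py -v" from a shell.
    import pytest

    pytest.main(["tests/python/test_config.py", "-v"])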