Use black on more Python files. (#8137)

This commit is contained in:
Jiaming Yuan
2022-08-11 01:38:11 +08:00
committed by GitHub
parent bdb291f1c2
commit 570f8ae4ba
14 changed files with 183 additions and 133 deletions

View File

@@ -3,26 +3,32 @@
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
"""
from .core import (
DMatrix,
DeviceQuantileDMatrix,
QuantileDMatrix,
Booster,
DataIter,
build_info,
_py_version,
)
from .training import train, cv
from . import rabit # noqa
from . import tracker # noqa
from .tracker import RabitTracker # noqa
from . import dask
from .core import (
Booster,
DataIter,
DeviceQuantileDMatrix,
DMatrix,
QuantileDMatrix,
_py_version,
build_info,
)
from .tracker import RabitTracker # noqa
from .training import cv, train
try:
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
from .sklearn import XGBRFClassifier, XGBRFRegressor
from .config import config_context, get_config, set_config
from .plotting import plot_importance, plot_tree, to_graphviz
from .config import set_config, get_config, config_context
from .sklearn import (
XGBClassifier,
XGBModel,
XGBRanker,
XGBRegressor,
XGBRFClassifier,
XGBRFRegressor,
)
except ImportError:
pass

View File

@@ -1,7 +1,7 @@
"""Shared typing definition."""
import ctypes
import os
from typing import Any, TypeVar, Union, Type, Sequence, Callable, List, Dict
from typing import Any, Callable, Dict, List, Sequence, Type, TypeVar, Union
# os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
# cudf.DataFrame/cupy.array/dlpack

View File

@@ -1,20 +1,21 @@
# pylint: disable= invalid-name, unused-import
"""For compatibility and optional dependencies."""
from typing import Any, Type, Dict, Optional, List, Sequence, cast
import sys
import types
import importlib.util
import logging
import sys
import types
from typing import Any, Dict, List, Optional, Sequence, Type, cast
import numpy as np
from ._typing import _T
assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
assert sys.version_info[0] == 3, "Python 2 is no longer supported."
def py_str(x: bytes) -> str:
"""convert c string back to python string"""
return x.decode('utf-8') # type: ignore
return x.decode("utf-8") # type: ignore
def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
@@ -30,8 +31,7 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
# pandas
try:
from pandas import DataFrame, Series
from pandas import MultiIndex
from pandas import DataFrame, MultiIndex, Series
from pandas import concat as pandas_concat
PANDAS_INSTALLED = True
@@ -45,23 +45,17 @@ except ImportError:
# sklearn
try:
from sklearn.base import (
BaseEstimator as XGBModelBase,
RegressorMixin as XGBRegressorBase,
ClassifierMixin as XGBClassifierBase
)
from sklearn.base import BaseEstimator as XGBModelBase
from sklearn.base import ClassifierMixin as XGBClassifierBase
from sklearn.base import RegressorMixin as XGBRegressorBase
from sklearn.preprocessing import LabelEncoder
try:
from sklearn.model_selection import (
KFold as XGBKFold,
StratifiedKFold as XGBStratifiedKFold
)
from sklearn.model_selection import KFold as XGBKFold
from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold
except ImportError:
from sklearn.cross_validation import (
KFold as XGBKFold,
StratifiedKFold as XGBStratifiedKFold
)
from sklearn.cross_validation import KFold as XGBKFold
from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold
SKLEARN_INSTALLED = True
@@ -79,9 +73,10 @@ except ImportError:
class XGBoostLabelEncoder(LabelEncoder):
'''Label encoder with JSON serialization methods.'''
"""Label encoder with JSON serialization methods."""
def to_json(self) -> Dict:
'''Returns a JSON compatible dictionary'''
"""Returns a JSON compatible dictionary"""
meta = {}
for k, v in self.__dict__.items():
if isinstance(v, np.ndarray):
@@ -92,10 +87,10 @@ class XGBoostLabelEncoder(LabelEncoder):
def from_json(self, doc: Dict) -> None:
# pylint: disable=attribute-defined-outside-init
'''Load the encoder back from a JSON compatible dict.'''
"""Load the encoder back from a JSON compatible dict."""
meta = {}
for k, v in doc.items():
if k == 'classes_':
if k == "classes_":
self.classes_ = np.array(v)
continue
meta[k] = v
@@ -159,15 +154,14 @@ def concat(value: Sequence[_T]) -> _T: # pylint: disable=too-many-return-statem
# KIND, either express or implied. See the License for the specific language governing
# permissions and limitations under the License.
class LazyLoader(types.ModuleType):
"""Lazily import a module, mainly to avoid pulling in large dependencies.
"""
"""Lazily import a module, mainly to avoid pulling in large dependencies."""
def __init__(
self,
local_name: str,
parent_module_globals: Dict,
name: str,
warning: Optional[str] = None
self,
local_name: str,
parent_module_globals: Dict,
name: str,
warning: Optional[str] = None,
) -> None:
self._local_name = local_name
self._parent_module_globals = parent_module_globals

View File

@@ -4,10 +4,10 @@ import ctypes
import json
from contextlib import contextmanager
from functools import wraps
from typing import Optional, Callable, Any, Dict, cast, Iterator
from typing import Any, Callable, Dict, Iterator, Optional, cast
from .core import _LIB, _check_call, c_str, py_str
from ._typing import _F
from .core import _LIB, _check_call, c_str, py_str
def config_doc(
@@ -90,30 +90,39 @@ def config_doc(
"""
def none_to_str(value: Optional[str]) -> str:
return '' if value is None else value
return "" if value is None else value
def config_doc_decorator(func: _F) -> _F:
func.__doc__ = (doc_template.format(header=none_to_str(header),
extra_note=none_to_str(extra_note))
+ none_to_str(parameters) + none_to_str(returns)
+ none_to_str(common_example) + none_to_str(see_also))
func.__doc__ = (
doc_template.format(
header=none_to_str(header), extra_note=none_to_str(extra_note)
)
+ none_to_str(parameters)
+ none_to_str(returns)
+ none_to_str(common_example)
+ none_to_str(see_also)
)
@wraps(func)
def wrap(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)
return cast(_F, wrap)
return config_doc_decorator
@config_doc(header="""
@config_doc(
header="""
Set global configuration.
""",
parameters="""
parameters="""
Parameters
----------
new_config: Dict[str, Any]
Keyword arguments representing the parameters and their values
""")
""",
)
def set_config(**new_config: Any) -> None:
not_none = {}
for k, v in new_config.items():
@@ -123,15 +132,17 @@ def set_config(**new_config: Any) -> None:
_check_call(_LIB.XGBSetGlobalConfig(c_str(config)))
@config_doc(header="""
@config_doc(
header="""
Get current values of the global configuration.
""",
returns="""
returns="""
Returns
-------
args: Dict[str, Any]
The list of global parameters and their values
""")
""",
)
def get_config() -> Dict[str, Any]:
config_str = ctypes.c_char_p()
_check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(config_str)))
@@ -142,27 +153,29 @@ def get_config() -> Dict[str, Any]:
@contextmanager
@config_doc(header="""
@config_doc(
header="""
Context manager for global XGBoost configuration.
""",
parameters="""
parameters="""
Parameters
----------
new_config: Dict[str, Any]
Keyword arguments representing the parameters and their values
""",
extra_note="""
extra_note="""
.. note::
All settings, not just those presently modified, will be returned to their
previous values when the context manager is exited. This is not thread-safe.
""",
see_also="""
see_also="""
See Also
--------
set_config: Set global XGBoost configuration
get_config: Get current values of the global configuration
""")
""",
)
def config_context(**new_config: Any) -> Iterator[None]:
old_config = get_config().copy()
set_config(**new_config)

View File

@@ -399,11 +399,10 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
Parameters
----------
cache_prefix:
Prefix to the cache files, only used in external memory. It can be either an URI
or a file path.
Prefix to the cache files, only used in external memory. It can be either an
URI or a file path.
"""
_T = TypeVar("_T")
def __init__(self, cache_prefix: Optional[str] = None) -> None:
self.cache_prefix = cache_prefix
@@ -1010,7 +1009,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
Returns
-------
number of columns : int
number of columns
"""
ret = c_bst_ulong()
_check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))

View File

@@ -1,6 +1,6 @@
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
# pylint: disable=too-many-arguments, too-many-locals
# pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines, fixme
# pylint: disable=too-many-lines
# pylint: disable=too-few-public-methods
# pylint: disable=import-error
"""
@@ -227,7 +227,7 @@ class RabitContext(rabit.RabitContext):
)
def dconcat(value: Sequence[_T]) -> _T: # pylint: disable=too-many-return-statements
def dconcat(value: Sequence[_T]) -> _T:
"""Concatenate sequence of partitions."""
try:
return concat(value)
@@ -253,7 +253,7 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Clie
class DaskDMatrix:
# pylint: disable=missing-docstring, too-many-instance-attributes
# pylint: disable=too-many-instance-attributes
"""DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
`DaskDMatrix` forces all lazy computation to be carried out. Wait for the input
data explicitly if you want to see actual computation of constructing `DaskDMatrix`.
@@ -486,6 +486,12 @@ class DaskDMatrix:
}
def num_col(self) -> int:
"""Get the number of columns (features) in the DMatrix.
Returns
-------
number of columns
"""
return self._n_cols

View File

@@ -1,13 +1,15 @@
"""XGBoost Federated Learning related API."""
from .core import _LIB, _check_call, c_str, build_info, XGBoostError
from .core import _LIB, XGBoostError, _check_call, build_info, c_str
def run_federated_server(port: int,
world_size: int,
server_key_path: str,
server_cert_path: str,
client_cert_path: str) -> None:
def run_federated_server(
port: int,
world_size: int,
server_key_path: str,
server_cert_path: str,
client_cert_path: str,
) -> None:
"""Run the Federated Learning server.
Parameters
@@ -23,12 +25,16 @@ def run_federated_server(port: int,
client_cert_path: str
Path to the client certificate file.
"""
if build_info()['USE_FEDERATED']:
_check_call(_LIB.XGBRunFederatedServer(port,
world_size,
c_str(server_key_path),
c_str(server_cert_path),
c_str(client_cert_path)))
if build_info()["USE_FEDERATED"]:
_check_call(
_LIB.XGBRunFederatedServer(
port,
world_size,
c_str(server_key_path),
c_str(server_cert_path),
c_str(client_cert_path),
)
)
else:
raise XGBoostError(
"XGBoost needs to be built with the federated learning plugin "