Use black on more Python files. (#8137)
This commit is contained in:
@@ -3,26 +3,32 @@
|
||||
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
||||
"""
|
||||
|
||||
from .core import (
|
||||
DMatrix,
|
||||
DeviceQuantileDMatrix,
|
||||
QuantileDMatrix,
|
||||
Booster,
|
||||
DataIter,
|
||||
build_info,
|
||||
_py_version,
|
||||
)
|
||||
from .training import train, cv
|
||||
from . import rabit # noqa
|
||||
from . import tracker # noqa
|
||||
from .tracker import RabitTracker # noqa
|
||||
from . import dask
|
||||
from .core import (
|
||||
Booster,
|
||||
DataIter,
|
||||
DeviceQuantileDMatrix,
|
||||
DMatrix,
|
||||
QuantileDMatrix,
|
||||
_py_version,
|
||||
build_info,
|
||||
)
|
||||
from .tracker import RabitTracker # noqa
|
||||
from .training import cv, train
|
||||
|
||||
try:
|
||||
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
|
||||
from .sklearn import XGBRFClassifier, XGBRFRegressor
|
||||
from .config import config_context, get_config, set_config
|
||||
from .plotting import plot_importance, plot_tree, to_graphviz
|
||||
from .config import set_config, get_config, config_context
|
||||
from .sklearn import (
|
||||
XGBClassifier,
|
||||
XGBModel,
|
||||
XGBRanker,
|
||||
XGBRegressor,
|
||||
XGBRFClassifier,
|
||||
XGBRFRegressor,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Shared typing definition."""
|
||||
import ctypes
|
||||
import os
|
||||
from typing import Any, TypeVar, Union, Type, Sequence, Callable, List, Dict
|
||||
from typing import Any, Callable, Dict, List, Sequence, Type, TypeVar, Union
|
||||
|
||||
# os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
|
||||
# cudf.DataFrame/cupy.array/dlpack
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
# pylint: disable= invalid-name, unused-import
|
||||
"""For compatibility and optional dependencies."""
|
||||
from typing import Any, Type, Dict, Optional, List, Sequence, cast
|
||||
import sys
|
||||
import types
|
||||
import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional, Sequence, Type, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._typing import _T
|
||||
|
||||
assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
|
||||
assert sys.version_info[0] == 3, "Python 2 is no longer supported."
|
||||
|
||||
|
||||
def py_str(x: bytes) -> str:
|
||||
"""convert c string back to python string"""
|
||||
return x.decode('utf-8') # type: ignore
|
||||
return x.decode("utf-8") # type: ignore
|
||||
|
||||
|
||||
def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
|
||||
@@ -30,8 +31,7 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
|
||||
|
||||
# pandas
|
||||
try:
|
||||
from pandas import DataFrame, Series
|
||||
from pandas import MultiIndex
|
||||
from pandas import DataFrame, MultiIndex, Series
|
||||
from pandas import concat as pandas_concat
|
||||
|
||||
PANDAS_INSTALLED = True
|
||||
@@ -45,23 +45,17 @@ except ImportError:
|
||||
|
||||
# sklearn
|
||||
try:
|
||||
from sklearn.base import (
|
||||
BaseEstimator as XGBModelBase,
|
||||
RegressorMixin as XGBRegressorBase,
|
||||
ClassifierMixin as XGBClassifierBase
|
||||
)
|
||||
from sklearn.base import BaseEstimator as XGBModelBase
|
||||
from sklearn.base import ClassifierMixin as XGBClassifierBase
|
||||
from sklearn.base import RegressorMixin as XGBRegressorBase
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
|
||||
try:
|
||||
from sklearn.model_selection import (
|
||||
KFold as XGBKFold,
|
||||
StratifiedKFold as XGBStratifiedKFold
|
||||
)
|
||||
from sklearn.model_selection import KFold as XGBKFold
|
||||
from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold
|
||||
except ImportError:
|
||||
from sklearn.cross_validation import (
|
||||
KFold as XGBKFold,
|
||||
StratifiedKFold as XGBStratifiedKFold
|
||||
)
|
||||
from sklearn.cross_validation import KFold as XGBKFold
|
||||
from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold
|
||||
|
||||
SKLEARN_INSTALLED = True
|
||||
|
||||
@@ -79,9 +73,10 @@ except ImportError:
|
||||
|
||||
|
||||
class XGBoostLabelEncoder(LabelEncoder):
|
||||
'''Label encoder with JSON serialization methods.'''
|
||||
"""Label encoder with JSON serialization methods."""
|
||||
|
||||
def to_json(self) -> Dict:
|
||||
'''Returns a JSON compatible dictionary'''
|
||||
"""Returns a JSON compatible dictionary"""
|
||||
meta = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, np.ndarray):
|
||||
@@ -92,10 +87,10 @@ class XGBoostLabelEncoder(LabelEncoder):
|
||||
|
||||
def from_json(self, doc: Dict) -> None:
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
'''Load the encoder back from a JSON compatible dict.'''
|
||||
"""Load the encoder back from a JSON compatible dict."""
|
||||
meta = {}
|
||||
for k, v in doc.items():
|
||||
if k == 'classes_':
|
||||
if k == "classes_":
|
||||
self.classes_ = np.array(v)
|
||||
continue
|
||||
meta[k] = v
|
||||
@@ -159,15 +154,14 @@ def concat(value: Sequence[_T]) -> _T: # pylint: disable=too-many-return-statem
|
||||
# KIND, either express or implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
class LazyLoader(types.ModuleType):
|
||||
"""Lazily import a module, mainly to avoid pulling in large dependencies.
|
||||
"""
|
||||
"""Lazily import a module, mainly to avoid pulling in large dependencies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
local_name: str,
|
||||
parent_module_globals: Dict,
|
||||
name: str,
|
||||
warning: Optional[str] = None
|
||||
self,
|
||||
local_name: str,
|
||||
parent_module_globals: Dict,
|
||||
name: str,
|
||||
warning: Optional[str] = None,
|
||||
) -> None:
|
||||
self._local_name = local_name
|
||||
self._parent_module_globals = parent_module_globals
|
||||
|
||||
@@ -4,10 +4,10 @@ import ctypes
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Optional, Callable, Any, Dict, cast, Iterator
|
||||
from typing import Any, Callable, Dict, Iterator, Optional, cast
|
||||
|
||||
from .core import _LIB, _check_call, c_str, py_str
|
||||
from ._typing import _F
|
||||
from .core import _LIB, _check_call, c_str, py_str
|
||||
|
||||
|
||||
def config_doc(
|
||||
@@ -90,30 +90,39 @@ def config_doc(
|
||||
"""
|
||||
|
||||
def none_to_str(value: Optional[str]) -> str:
|
||||
return '' if value is None else value
|
||||
return "" if value is None else value
|
||||
|
||||
def config_doc_decorator(func: _F) -> _F:
|
||||
func.__doc__ = (doc_template.format(header=none_to_str(header),
|
||||
extra_note=none_to_str(extra_note))
|
||||
+ none_to_str(parameters) + none_to_str(returns)
|
||||
+ none_to_str(common_example) + none_to_str(see_also))
|
||||
func.__doc__ = (
|
||||
doc_template.format(
|
||||
header=none_to_str(header), extra_note=none_to_str(extra_note)
|
||||
)
|
||||
+ none_to_str(parameters)
|
||||
+ none_to_str(returns)
|
||||
+ none_to_str(common_example)
|
||||
+ none_to_str(see_also)
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
def wrap(*args: Any, **kwargs: Any) -> Any:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return cast(_F, wrap)
|
||||
|
||||
return config_doc_decorator
|
||||
|
||||
|
||||
@config_doc(header="""
|
||||
@config_doc(
|
||||
header="""
|
||||
Set global configuration.
|
||||
""",
|
||||
parameters="""
|
||||
parameters="""
|
||||
Parameters
|
||||
----------
|
||||
new_config: Dict[str, Any]
|
||||
Keyword arguments representing the parameters and their values
|
||||
""")
|
||||
""",
|
||||
)
|
||||
def set_config(**new_config: Any) -> None:
|
||||
not_none = {}
|
||||
for k, v in new_config.items():
|
||||
@@ -123,15 +132,17 @@ def set_config(**new_config: Any) -> None:
|
||||
_check_call(_LIB.XGBSetGlobalConfig(c_str(config)))
|
||||
|
||||
|
||||
@config_doc(header="""
|
||||
@config_doc(
|
||||
header="""
|
||||
Get current values of the global configuration.
|
||||
""",
|
||||
returns="""
|
||||
returns="""
|
||||
Returns
|
||||
-------
|
||||
args: Dict[str, Any]
|
||||
The list of global parameters and their values
|
||||
""")
|
||||
""",
|
||||
)
|
||||
def get_config() -> Dict[str, Any]:
|
||||
config_str = ctypes.c_char_p()
|
||||
_check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(config_str)))
|
||||
@@ -142,27 +153,29 @@ def get_config() -> Dict[str, Any]:
|
||||
|
||||
|
||||
@contextmanager
|
||||
@config_doc(header="""
|
||||
@config_doc(
|
||||
header="""
|
||||
Context manager for global XGBoost configuration.
|
||||
""",
|
||||
parameters="""
|
||||
parameters="""
|
||||
Parameters
|
||||
----------
|
||||
new_config: Dict[str, Any]
|
||||
Keyword arguments representing the parameters and their values
|
||||
""",
|
||||
extra_note="""
|
||||
extra_note="""
|
||||
.. note::
|
||||
|
||||
All settings, not just those presently modified, will be returned to their
|
||||
previous values when the context manager is exited. This is not thread-safe.
|
||||
""",
|
||||
see_also="""
|
||||
see_also="""
|
||||
See Also
|
||||
--------
|
||||
set_config: Set global XGBoost configuration
|
||||
get_config: Get current values of the global configuration
|
||||
""")
|
||||
""",
|
||||
)
|
||||
def config_context(**new_config: Any) -> Iterator[None]:
|
||||
old_config = get_config().copy()
|
||||
set_config(**new_config)
|
||||
|
||||
@@ -399,11 +399,10 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
Parameters
|
||||
----------
|
||||
cache_prefix:
|
||||
Prefix to the cache files, only used in external memory. It can be either an URI
|
||||
or a file path.
|
||||
Prefix to the cache files, only used in external memory. It can be either an
|
||||
URI or a file path.
|
||||
|
||||
"""
|
||||
_T = TypeVar("_T")
|
||||
|
||||
def __init__(self, cache_prefix: Optional[str] = None) -> None:
|
||||
self.cache_prefix = cache_prefix
|
||||
@@ -1010,7 +1009,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
Returns
|
||||
-------
|
||||
number of columns : int
|
||||
number of columns
|
||||
"""
|
||||
ret = c_bst_ulong()
|
||||
_check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
|
||||
# pylint: disable=too-many-arguments, too-many-locals
|
||||
# pylint: disable=missing-class-docstring, invalid-name
|
||||
# pylint: disable=too-many-lines, fixme
|
||||
# pylint: disable=too-many-lines
|
||||
# pylint: disable=too-few-public-methods
|
||||
# pylint: disable=import-error
|
||||
"""
|
||||
@@ -227,7 +227,7 @@ class RabitContext(rabit.RabitContext):
|
||||
)
|
||||
|
||||
|
||||
def dconcat(value: Sequence[_T]) -> _T: # pylint: disable=too-many-return-statements
|
||||
def dconcat(value: Sequence[_T]) -> _T:
|
||||
"""Concatenate sequence of partitions."""
|
||||
try:
|
||||
return concat(value)
|
||||
@@ -253,7 +253,7 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Clie
|
||||
|
||||
|
||||
class DaskDMatrix:
|
||||
# pylint: disable=missing-docstring, too-many-instance-attributes
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
"""DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
|
||||
`DaskDMatrix` forces all lazy computation to be carried out. Wait for the input
|
||||
data explicitly if you want to see actual computation of constructing `DaskDMatrix`.
|
||||
@@ -486,6 +486,12 @@ class DaskDMatrix:
|
||||
}
|
||||
|
||||
def num_col(self) -> int:
|
||||
"""Get the number of columns (features) in the DMatrix.
|
||||
|
||||
Returns
|
||||
-------
|
||||
number of columns
|
||||
"""
|
||||
return self._n_cols
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
"""XGBoost Federated Learning related API."""
|
||||
|
||||
from .core import _LIB, _check_call, c_str, build_info, XGBoostError
|
||||
from .core import _LIB, XGBoostError, _check_call, build_info, c_str
|
||||
|
||||
|
||||
def run_federated_server(port: int,
|
||||
world_size: int,
|
||||
server_key_path: str,
|
||||
server_cert_path: str,
|
||||
client_cert_path: str) -> None:
|
||||
def run_federated_server(
|
||||
port: int,
|
||||
world_size: int,
|
||||
server_key_path: str,
|
||||
server_cert_path: str,
|
||||
client_cert_path: str,
|
||||
) -> None:
|
||||
"""Run the Federated Learning server.
|
||||
|
||||
Parameters
|
||||
@@ -23,12 +25,16 @@ def run_federated_server(port: int,
|
||||
client_cert_path: str
|
||||
Path to the client certificate file.
|
||||
"""
|
||||
if build_info()['USE_FEDERATED']:
|
||||
_check_call(_LIB.XGBRunFederatedServer(port,
|
||||
world_size,
|
||||
c_str(server_key_path),
|
||||
c_str(server_cert_path),
|
||||
c_str(client_cert_path)))
|
||||
if build_info()["USE_FEDERATED"]:
|
||||
_check_call(
|
||||
_LIB.XGBRunFederatedServer(
|
||||
port,
|
||||
world_size,
|
||||
c_str(server_key_path),
|
||||
c_str(server_cert_path),
|
||||
c_str(client_cert_path),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise XGBoostError(
|
||||
"XGBoost needs to be built with the federated learning plugin "
|
||||
|
||||
Reference in New Issue
Block a user