Remove STRING_TYPES. (#7827)

This commit is contained in:
Jiaming Yuan 2022-04-22 19:07:51 +08:00 committed by GitHub
parent c13a2a3114
commit f0f76259c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 10 additions and 14 deletions

View File

@ -16,7 +16,6 @@ import numpy
from . import rabit from . import rabit
from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
from .compat import STRING_TYPES
__all__ = [ __all__ = [
@ -82,7 +81,7 @@ def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
results = [] results = []
for (_, name), s in sorted(cvmap.items(), key=lambda x: x[0][0]): for (_, name), s in sorted(cvmap.items(), key=lambda x: x[0][0]):
as_arr = numpy.array(s) as_arr = numpy.array(s)
if not isinstance(msg, STRING_TYPES): if not isinstance(msg, str):
msg = msg.decode() msg = msg.decode()
mean, std = numpy.mean(as_arr), numpy.std(as_arr) mean, std = numpy.mean(as_arr), numpy.std(as_arr)
results.extend([(name, mean, std)]) results.extend([(name, mean, std)])

View File

@ -10,9 +10,6 @@ import numpy as np
assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.' assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
# pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = (str,)
def py_str(x): def py_str(x):
"""convert c string back to python string""" """convert c string back to python string"""

View File

@ -17,7 +17,7 @@ from inspect import signature, Parameter
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
from .compat import STRING_TYPES, DataFrame, py_str, PANDAS_INSTALLED from .compat import DataFrame, py_str, PANDAS_INSTALLED
from .libpath import find_lib_path from .libpath import find_lib_path
from ._typing import ( from ._typing import (
CStrPptr, CStrPptr,
@ -1387,7 +1387,7 @@ class Booster:
_check_call( _check_call(
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length)) _LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
self.__dict__.update(state) self.__dict__.update(state)
elif isinstance(model_file, (STRING_TYPES, os.PathLike, bytearray)): elif isinstance(model_file, (str, os.PathLike, bytearray)):
self.load_model(model_file) self.load_model(model_file)
elif model_file is None: elif model_file is None:
pass pass
@ -1629,7 +1629,7 @@ class Booster:
""" """
for key, value in kwargs.items(): for key, value in kwargs.items():
if value is not None: if value is not None:
if not isinstance(value, STRING_TYPES): if not isinstance(value, str):
raise ValueError("Set Attr only accepts string values") raise ValueError("Set Attr only accepts string values")
value = c_str(str(value)) value = c_str(str(value))
_check_call(_LIB.XGBoosterSetAttr( _check_call(_LIB.XGBoosterSetAttr(
@ -1705,7 +1705,7 @@ class Booster:
""" """
if isinstance(params, Mapping): if isinstance(params, Mapping):
params = params.items() params = params.items()
elif isinstance(params, STRING_TYPES) and value is not None: elif isinstance(params, str) and value is not None:
params = [(params, value)] params = [(params, value)]
for key, val in params: for key, val in params:
if val is not None: if val is not None:
@ -1796,7 +1796,7 @@ class Booster:
for d in evals: for d in evals:
if not isinstance(d[0], DMatrix): if not isinstance(d[0], DMatrix):
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}") raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
if not isinstance(d[1], STRING_TYPES): if not isinstance(d[1], str):
raise TypeError(f"expected string, got {type(d[1]).__name__}") raise TypeError(f"expected string, got {type(d[1]).__name__}")
self._validate_features(d[0]) self._validate_features(d[0])
@ -2192,7 +2192,7 @@ class Booster:
Output file name Output file name
""" """
if isinstance(fname, (STRING_TYPES, os.PathLike)): # assume file name if isinstance(fname, (str, os.PathLike)): # assume file name
fname = os.fspath(os.path.expanduser(fname)) fname = os.fspath(os.path.expanduser(fname))
_check_call(_LIB.XGBoosterSaveModel( _check_call(_LIB.XGBoosterSaveModel(
self.handle, c_str(fname))) self.handle, c_str(fname)))
@ -2301,7 +2301,7 @@ class Booster:
dump_format : string, optional dump_format : string, optional
Format of model dump file. Can be 'text' or 'json'. Format of model dump file. Can be 'text' or 'json'.
""" """
if isinstance(fout, (STRING_TYPES, os.PathLike)): if isinstance(fout, (str, os.PathLike)):
fout = os.fspath(os.path.expanduser(fout)) fout = os.fspath(os.path.expanduser(fout))
# pylint: disable=consider-using-with # pylint: disable=consider-using-with
fout_obj = open(fout, 'w', encoding="utf-8") fout_obj = open(fout, 'w', encoding="utf-8")

View File

@ -6,7 +6,7 @@ from typing import Any, TypeVar, Callable, Optional, cast, List, Union
import numpy as np import numpy as np
from .core import _LIB, c_str, STRING_TYPES, _check_call from .core import _LIB, c_str, _check_call
def _init_rabit() -> None: def _init_rabit() -> None:
@ -73,7 +73,7 @@ def tracker_print(msg: Any) -> None:
msg : str msg : str
The message to be printed to tracker. The message to be printed to tracker.
""" """
if not isinstance(msg, STRING_TYPES): if not isinstance(msg, str):
msg = str(msg) msg = str(msg)
is_dist = _LIB.RabitIsDistributed() is_dist = _LIB.RabitIsDistributed()
if is_dist != 0: if is_dist != 0: