Use black on more Python files. (#8137)

Jiaming Yuan 2022-08-11 01:38:11 +08:00 committed by GitHub
parent bdb291f1c2
commit 570f8ae4ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 183 additions and 133 deletions

View File

@@ -19,13 +19,14 @@ Also, see the tutorial for using XGBoost with categorical data:
 """
 from __future__ import annotations
-from time import time
 import os
 from tempfile import TemporaryDirectory
+from time import time
 import pandas as pd
-from sklearn.model_selection import train_test_split
 from sklearn.metrics import roc_auc_score
+from sklearn.model_selection import train_test_split
 import xgboost as xgb

View File

@@ -16,11 +16,13 @@ categorical data.
 .. versionadded:: 1.5.0
 """
-import pandas as pd
-import numpy as np
-import xgboost as xgb
 from typing import Tuple
+import numpy as np
+import pandas as pd
+import xgboost as xgb
 def make_categorical(
     n_samples: int, n_features: int, n_categories: int, onehot: bool

View File

@@ -1,35 +1,34 @@
-'''
+"""
 Collection of examples for using xgboost.spark estimator interface
 ==================================================================
 @author: Weichen Xu
-'''
+"""
+import sklearn.datasets
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator, RegressionEvaluator
+from pyspark.ml.linalg import Vectors
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import rand
-from pyspark.ml.linalg import Vectors
-import sklearn.datasets
 from sklearn.model_selection import train_test_split
 from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
-from pyspark.ml.evaluation import RegressionEvaluator, MulticlassClassificationEvaluator
 spark = SparkSession.builder.master("local[*]").getOrCreate()
 def create_spark_df(X, y):
     return spark.createDataFrame(
-        spark.sparkContext.parallelize([
-            (Vectors.dense(features), float(label))
-            for features, label in zip(X, y)
-        ]),
-        ["features", "label"]
+        spark.sparkContext.parallelize(
+            [(Vectors.dense(features), float(label)) for features, label in zip(X, y)]
+        ),
+        ["features", "label"],
     )
 # load diabetes dataset (regression dataset)
 diabetes_X, diabetes_y = sklearn.datasets.load_diabetes(return_X_y=True)
-diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = \
-    train_test_split(diabetes_X, diabetes_y, test_size=0.3, shuffle=True)
+diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = train_test_split(
+    diabetes_X, diabetes_y, test_size=0.3, shuffle=True
+)
 diabetes_train_spark_df = create_spark_df(diabetes_X_train, diabetes_y_train)
 diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test)

@@ -38,25 +37,36 @@ diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test)
 xgb_regressor = SparkXGBRegressor(max_depth=5)
 xgb_regressor_model = xgb_regressor.fit(diabetes_train_spark_df)
-transformed_diabetes_test_spark_df = xgb_regressor_model.transform(diabetes_test_spark_df)
+transformed_diabetes_test_spark_df = xgb_regressor_model.transform(
+    diabetes_test_spark_df
+)
 regressor_evaluator = RegressionEvaluator(metricName="rmse")
-print(f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}")
+print(
+    f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}"
+)
 diabetes_train_spark_df2 = diabetes_train_spark_df.withColumn(
     "validationIndicatorCol", rand(1) > 0.7
 )
 # train xgboost regressor model with validation dataset
-xgb_regressor2 = SparkXGBRegressor(max_depth=5, validation_indicator_col="validationIndicatorCol")
+xgb_regressor2 = SparkXGBRegressor(
+    max_depth=5, validation_indicator_col="validationIndicatorCol"
+)
 xgb_regressor_model2 = xgb_regressor2.fit(diabetes_train_spark_df2)
-transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(diabetes_test_spark_df)
-print(f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}")
+transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(
+    diabetes_test_spark_df
+)
+print(
+    f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}"
+)
 # load iris dataset (classification dataset)
 iris_X, iris_y = sklearn.datasets.load_iris(return_X_y=True)
-iris_X_train, iris_X_test, iris_y_train, iris_y_test = \
-    train_test_split(iris_X, iris_y, test_size=0.3, shuffle=True)
+iris_X_train, iris_X_test, iris_y_train, iris_y_test = train_test_split(
+    iris_X, iris_y, test_size=0.3, shuffle=True
+)
 iris_train_spark_df = create_spark_df(iris_X_train, iris_y_train)
 iris_test_spark_df = create_spark_df(iris_X_test, iris_y_test)

@@ -74,9 +84,13 @@ iris_train_spark_df2 = iris_train_spark_df.withColumn(
 )
 # train xgboost classifier model with validation dataset
-xgb_classifier2 = SparkXGBClassifier(max_depth=5, validation_indicator_col="validationIndicatorCol")
+xgb_classifier2 = SparkXGBClassifier(
+    max_depth=5, validation_indicator_col="validationIndicatorCol"
+)
 xgb_classifier_model2 = xgb_classifier2.fit(iris_train_spark_df2)
 transformed_iris_test_spark_df2 = xgb_classifier_model2.transform(iris_test_spark_df)
-print(f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}")
+print(
+    f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}"
+)
 spark.stop()

View File

@@ -3,26 +3,32 @@
 Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
 """
-from .core import (
-    DMatrix,
-    DeviceQuantileDMatrix,
-    QuantileDMatrix,
-    Booster,
-    DataIter,
-    build_info,
-    _py_version,
-)
-from .training import train, cv
 from . import rabit  # noqa
 from . import tracker  # noqa
-from .tracker import RabitTracker  # noqa
 from . import dask
+from .core import (
+    Booster,
+    DataIter,
+    DeviceQuantileDMatrix,
+    DMatrix,
+    QuantileDMatrix,
+    _py_version,
+    build_info,
+)
+from .tracker import RabitTracker  # noqa
+from .training import cv, train
 try:
-    from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
-    from .sklearn import XGBRFClassifier, XGBRFRegressor
+    from .config import config_context, get_config, set_config
     from .plotting import plot_importance, plot_tree, to_graphviz
-    from .config import set_config, get_config, config_context
+    from .sklearn import (
+        XGBClassifier,
+        XGBModel,
+        XGBRanker,
+        XGBRegressor,
+        XGBRFClassifier,
+        XGBRFRegressor,
+    )
 except ImportError:
     pass

View File

@@ -1,7 +1,7 @@
 """Shared typing definition."""
 import ctypes
 import os
-from typing import Any, TypeVar, Union, Type, Sequence, Callable, List, Dict
+from typing import Any, Callable, Dict, List, Sequence, Type, TypeVar, Union
 # os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
 # cudf.DataFrame/cupy.array/dlpack

View File

@@ -1,20 +1,21 @@
 # pylint: disable= invalid-name, unused-import
 """For compatibility and optional dependencies."""
-from typing import Any, Type, Dict, Optional, List, Sequence, cast
-import sys
-import types
 import importlib.util
 import logging
+import sys
+import types
+from typing import Any, Dict, List, Optional, Sequence, Type, cast
 import numpy as np
 from ._typing import _T
-assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
+assert sys.version_info[0] == 3, "Python 2 is no longer supported."
 def py_str(x: bytes) -> str:
     """convert c string back to python string"""
-    return x.decode('utf-8')  # type: ignore
+    return x.decode("utf-8")  # type: ignore
 def lazy_isinstance(instance: Any, module: str, name: str) -> bool:

@@ -30,8 +31,7 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
 # pandas
 try:
-    from pandas import DataFrame, Series
-    from pandas import MultiIndex
+    from pandas import DataFrame, MultiIndex, Series
     from pandas import concat as pandas_concat
     PANDAS_INSTALLED = True

@@ -45,23 +45,17 @@ except ImportError:
 # sklearn
 try:
-    from sklearn.base import (
-        BaseEstimator as XGBModelBase,
-        RegressorMixin as XGBRegressorBase,
-        ClassifierMixin as XGBClassifierBase
-    )
+    from sklearn.base import BaseEstimator as XGBModelBase
+    from sklearn.base import ClassifierMixin as XGBClassifierBase
+    from sklearn.base import RegressorMixin as XGBRegressorBase
     from sklearn.preprocessing import LabelEncoder
     try:
-        from sklearn.model_selection import (
-            KFold as XGBKFold,
-            StratifiedKFold as XGBStratifiedKFold
-        )
+        from sklearn.model_selection import KFold as XGBKFold
+        from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold
     except ImportError:
-        from sklearn.cross_validation import (
-            KFold as XGBKFold,
-            StratifiedKFold as XGBStratifiedKFold
-        )
+        from sklearn.cross_validation import KFold as XGBKFold
+        from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold
     SKLEARN_INSTALLED = True

@@ -79,9 +73,10 @@ except ImportError:
     class XGBoostLabelEncoder(LabelEncoder):
-        '''Label encoder with JSON serialization methods.'''
+        """Label encoder with JSON serialization methods."""
         def to_json(self) -> Dict:
-            '''Returns a JSON compatible dictionary'''
+            """Returns a JSON compatible dictionary"""
             meta = {}
             for k, v in self.__dict__.items():
                 if isinstance(v, np.ndarray):

@@ -92,10 +87,10 @@ class XGBoostLabelEncoder(LabelEncoder):
         def from_json(self, doc: Dict) -> None:
             # pylint: disable=attribute-defined-outside-init
-            '''Load the encoder back from a JSON compatible dict.'''
+            """Load the encoder back from a JSON compatible dict."""
             meta = {}
             for k, v in doc.items():
-                if k == 'classes_':
+                if k == "classes_":
                     self.classes_ = np.array(v)
                     continue
                 meta[k] = v

@@ -159,15 +154,14 @@ def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statem
 # KIND, either express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 class LazyLoader(types.ModuleType):
-    """Lazily import a module, mainly to avoid pulling in large dependencies.
-    """
+    """Lazily import a module, mainly to avoid pulling in large dependencies."""
     def __init__(
         self,
         local_name: str,
         parent_module_globals: Dict,
         name: str,
-        warning: Optional[str] = None
+        warning: Optional[str] = None,
     ) -> None:
         self._local_name = local_name
         self._parent_module_globals = parent_module_globals

View File

@@ -4,10 +4,10 @@ import ctypes
 import json
 from contextlib import contextmanager
 from functools import wraps
-from typing import Optional, Callable, Any, Dict, cast, Iterator
+from typing import Any, Callable, Dict, Iterator, Optional, cast
-from .core import _LIB, _check_call, c_str, py_str
 from ._typing import _F
+from .core import _LIB, _check_call, c_str, py_str
 def config_doc(

@@ -90,22 +90,30 @@ def config_doc(
     """
     def none_to_str(value: Optional[str]) -> str:
-        return '' if value is None else value
+        return "" if value is None else value
     def config_doc_decorator(func: _F) -> _F:
-        func.__doc__ = (doc_template.format(header=none_to_str(header),
-                                            extra_note=none_to_str(extra_note))
-                        + none_to_str(parameters) + none_to_str(returns)
-                        + none_to_str(common_example) + none_to_str(see_also))
+        func.__doc__ = (
+            doc_template.format(
+                header=none_to_str(header), extra_note=none_to_str(extra_note)
+            )
+            + none_to_str(parameters)
+            + none_to_str(returns)
+            + none_to_str(common_example)
+            + none_to_str(see_also)
+        )
         @wraps(func)
         def wrap(*args: Any, **kwargs: Any) -> Any:
             return func(*args, **kwargs)
         return cast(_F, wrap)
     return config_doc_decorator
-@config_doc(header="""
+@config_doc(
+    header="""
 Set global configuration.
 """,
     parameters="""

@@ -113,7 +121,8 @@ def config_doc(
     ----------
     new_config: Dict[str, Any]
         Keyword arguments representing the parameters and their values
-    """)
+    """,
+)
 def set_config(**new_config: Any) -> None:
     not_none = {}
     for k, v in new_config.items():

@@ -123,7 +132,8 @@ def set_config(**new_config: Any) -> None:
     _check_call(_LIB.XGBSetGlobalConfig(c_str(config)))
-@config_doc(header="""
+@config_doc(
+    header="""
 Get current values of the global configuration.
 """,
     returns="""

@@ -131,7 +141,8 @@ def set_config(**new_config: Any) -> None:
     -------
     args: Dict[str, Any]
         The list of global parameters and their values
-    """)
+    """,
+)
 def get_config() -> Dict[str, Any]:
     config_str = ctypes.c_char_p()
     _check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(config_str)))

@@ -142,7 +153,8 @@ def get_config() -> Dict[str, Any]:
 @contextmanager
-@config_doc(header="""
+@config_doc(
+    header="""
 Context manager for global XGBoost configuration.
 """,
     parameters="""

@@ -162,7 +174,8 @@ def get_config() -> Dict[str, Any]:
     --------
     set_config: Set global XGBoost configuration
     get_config: Get current values of the global configuration
-    """)
+    """,
+)
 def config_context(**new_config: Any) -> Iterator[None]:
     old_config = get_config().copy()
     set_config(**new_config)
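
The decorator reshuffling above only changes formatting; the public configuration helpers keep their behaviour. A minimal usage sketch for them, assuming an installed xgboost and the `verbosity` global parameter:

    import xgboost as xgb

    xgb.set_config(verbosity=2)           # set a global parameter
    print(xgb.get_config()["verbosity"])  # read the current global configuration

    with xgb.config_context(verbosity=0):
        pass                              # settings are restored when the context exits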

View File

@@ -399,11 +399,10 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
     Parameters
     ----------
     cache_prefix:
-        Prefix to the cache files, only used in external memory. It can be either an URI
-        or a file path.
+        Prefix to the cache files, only used in external memory. It can be either an
+        URI or a file path.
     """
-    _T = TypeVar("_T")
     def __init__(self, cache_prefix: Optional[str] = None) -> None:
         self.cache_prefix = cache_prefix

@@ -1010,7 +1009,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         Returns
        -------
-        number of columns : int
+        number of columns
         """
         ret = c_bst_ulong()
         _check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
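
The docstring tweak above belongs to `DMatrix.num_col()`. A small illustration with synthetic data (the array shapes here are arbitrary):

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(8, 3)                          # 8 rows, 3 features
    dtrain = xgb.DMatrix(X, label=np.random.rand(8))
    print(dtrain.num_col())                           # prints 3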

View File

@@ -1,6 +1,6 @@
-# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
+# pylint: disable=too-many-arguments, too-many-locals
 # pylint: disable=missing-class-docstring, invalid-name
-# pylint: disable=too-many-lines, fixme
+# pylint: disable=too-many-lines
 # pylint: disable=too-few-public-methods
 # pylint: disable=import-error
 """

@@ -227,7 +227,7 @@ class RabitContext(rabit.RabitContext):
     )
-def dconcat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
+def dconcat(value: Sequence[_T]) -> _T:
     """Concatenate sequence of partitions."""
     try:
         return concat(value)

@@ -253,7 +253,7 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Clie
 class DaskDMatrix:
-    # pylint: disable=missing-docstring, too-many-instance-attributes
+    # pylint: disable=too-many-instance-attributes
     """DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
     `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input
     data explicitly if you want to see actual computation of constructing `DaskDMatrix`.

@@ -486,6 +486,12 @@ class DaskDMatrix:
         }
     def num_col(self) -> int:
+        """Get the number of columns (features) in the DMatrix.
+
+        Returns
+        -------
+        number of columns
+        """
         return self._n_cols

View File

@@ -1,13 +1,15 @@
 """XGBoost Federated Learning related API."""
-from .core import _LIB, _check_call, c_str, build_info, XGBoostError
+from .core import _LIB, XGBoostError, _check_call, build_info, c_str
-def run_federated_server(port: int,
+def run_federated_server(
+    port: int,
     world_size: int,
     server_key_path: str,
     server_cert_path: str,
-    client_cert_path: str) -> None:
+    client_cert_path: str,
+) -> None:
     """Run the Federated Learning server.
     Parameters

@@ -23,12 +25,16 @@ def run_federated_server(port: int,
     client_cert_path: str
         Path to the client certificate file.
     """
-    if build_info()['USE_FEDERATED']:
+    if build_info()["USE_FEDERATED"]:
-        _check_call(_LIB.XGBRunFederatedServer(port,
+        _check_call(
+            _LIB.XGBRunFederatedServer(
+                port,
                 world_size,
                 c_str(server_key_path),
                 c_str(server_cert_path),
-                c_str(client_cert_path)))
+                c_str(client_cert_path),
+            )
+        )
     else:
         raise XGBoostError(
             "XGBoost needs to be built with the federated learning plugin "

View File

@@ -112,13 +112,25 @@ if __name__ == "__main__":
     if not all(
         run_formatter(path)
         for path in [
+            # core
+            "python-package/xgboost/__init__.py",
+            "python-package/xgboost/_typing.py",
+            "python-package/xgboost/compat.py",
+            "python-package/xgboost/config.py",
             "python-package/xgboost/dask.py",
             "python-package/xgboost/sklearn.py",
             "python-package/xgboost/spark",
+            "python-package/xgboost/federated.py",
+            "python-package/xgboost/spark",
+            # tests
             "tests/python/test_config.py",
-            "tests/python/test_spark/test_data.py",
-            "tests/python-gpu/test_gpu_spark/test_data.py",
+            "tests/python/test_spark/",
+            "tests/python-gpu/test_gpu_spark/",
             "tests/ci_build/lint_python.py",
+            # demo
+            "demo/guide-python/cat_in_the_dat.py",
+            "demo/guide-python/categorical.py",
+            "demo/guide-python/spark_estimator_examples.py",
         ]
     ):
         sys.exit(-1)
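
The paths added above are the files newly covered by the formatter check. The `run_formatter` helper itself is not part of this hunk; a rough sketch of what such a check can look like (the helper body and tool flags here are assumptions, not the actual lint_python.py implementation) might be:

    import subprocess

    def run_formatter(path: str) -> bool:
        """Return True if `path` already satisfies black and isort."""
        # --check reports violations without rewriting files.
        black_rc = subprocess.run(["black", "--check", path]).returncode
        isort_rc = subprocess.run(["isort", "--check", "--profile=black", path]).returncode
        return black_rc == 0 and isort_rc == 0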

View File

@@ -1,11 +1,10 @@
-import sys
 import logging
 import random
+import sys
 import uuid
 import numpy as np
 import pytest
 import testing as tm
 if tm.no_spark()["condition"]:

@@ -13,26 +12,27 @@ if tm.no_spark()["condition"]:
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
-from pyspark.ml.functions import vector_to_array
-from pyspark.sql import functions as spark_sql_func
 from pyspark.ml import Pipeline, PipelineModel
 from pyspark.ml.evaluation import (
     BinaryClassificationEvaluator,
     MulticlassClassificationEvaluator,
 )
+from pyspark.ml.functions import vector_to_array
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
+from pyspark.sql import functions as spark_sql_func
 from xgboost.spark import (
     SparkXGBClassifier,
     SparkXGBClassifierModel,
     SparkXGBRegressor,
     SparkXGBRegressorModel,
 )
-from .utils import SparkTestCase
-from xgboost import XGBClassifier, XGBRegressor
 from xgboost.spark.core import _non_booster_params
+from xgboost import XGBClassifier, XGBRegressor
+from .utils import SparkTestCase
 logging.getLogger("py4j").setLevel(logging.INFO)

View File

@@ -1,11 +1,11 @@
-import sys
-import random
 import json
-import uuid
 import os
+import random
+import sys
+import uuid
-import pytest
 import numpy as np
+import pytest
 import testing as tm
 if tm.no_spark()["condition"]:

@@ -13,10 +13,11 @@ if tm.no_spark()["condition"]:
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
-from .utils import SparkLocalClusterTestCase
+from pyspark.ml.linalg import Vectors
 from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
 from xgboost.spark.utils import _get_max_num_concurrent_tasks
-from pyspark.ml.linalg import Vectors
+from .utils import SparkLocalClusterTestCase
 class XgboostLocalClusterTestCase(SparkLocalClusterTestCase):

View File

@@ -3,22 +3,18 @@ import logging
 import shutil
 import sys
 import tempfile
 import unittest
 import pytest
-from six import StringIO
 import testing as tm
+from six import StringIO
 if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
-from pyspark.sql import SQLContext
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession, SQLContext
 from xgboost.spark.utils import _get_default_params_from_func