Use black on more Python files. (#8137)
parent bdb291f1c2
commit 570f8ae4ba
@@ -19,13 +19,14 @@ Also, see the tutorial for using XGBoost with categorical data:
 """

+from __future__ import annotations
-from time import time

 import os
 from tempfile import TemporaryDirectory
+from time import time

 import pandas as pd
-from sklearn.model_selection import train_test_split
 from sklearn.metrics import roc_auc_score
+from sklearn.model_selection import train_test_split

 import xgboost as xgb

@@ -16,11 +16,13 @@ categorical data.

 .. versionadded:: 1.5.0

 """
-import pandas as pd
-import numpy as np
-import xgboost as xgb
 from typing import Tuple
+
+import numpy as np
+import pandas as pd
+
+import xgboost as xgb


 def make_categorical(
     n_samples: int, n_features: int, n_categories: int, onehot: bool
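This hunk only re-sorts the demo's imports; the demo itself is about native categorical support. For orientation, a minimal sketch of the API this demo exercises (data and parameters are invented here; enable_categorical=True is the switch described in the categorical-data tutorial referenced above):

import numpy as np
import pandas as pd

import xgboost as xgb

rng = np.random.default_rng(1994)
X = pd.DataFrame(
    {
        # the "category" dtype is what enable_categorical keys on
        "cat": pd.Series(rng.integers(0, 4, size=64)).astype("category"),
        "num": rng.normal(size=64),
    }
)
y = rng.normal(size=64)

Xy = xgb.DMatrix(X, y, enable_categorical=True)
booster = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4)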
@@ -1,35 +1,34 @@
-'''
+"""
 Collection of examples for using xgboost.spark estimator interface
 ==================================================================

 @author: Weichen Xu
-'''
+"""
+import sklearn.datasets
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator, RegressionEvaluator
+from pyspark.ml.linalg import Vectors
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import rand
-from pyspark.ml.linalg import Vectors
-import sklearn.datasets
 from sklearn.model_selection import train_test_split
 from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
-from pyspark.ml.evaluation import RegressionEvaluator, MulticlassClassificationEvaluator


 spark = SparkSession.builder.master("local[*]").getOrCreate()


 def create_spark_df(X, y):
     return spark.createDataFrame(
-        spark.sparkContext.parallelize([
-            (Vectors.dense(features), float(label))
-            for features, label in zip(X, y)
-        ]),
-        ["features", "label"]
+        spark.sparkContext.parallelize(
+            [(Vectors.dense(features), float(label)) for features, label in zip(X, y)]
+        ),
+        ["features", "label"],
     )


 # load diabetes dataset (regression dataset)
 diabetes_X, diabetes_y = sklearn.datasets.load_diabetes(return_X_y=True)
-diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = \
-    train_test_split(diabetes_X, diabetes_y, test_size=0.3, shuffle=True)
+diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = train_test_split(
+    diabetes_X, diabetes_y, test_size=0.3, shuffle=True
+)

 diabetes_train_spark_df = create_spark_df(diabetes_X_train, diabetes_y_train)
 diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test)
@@ -38,25 +37,36 @@ diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test)
 xgb_regressor = SparkXGBRegressor(max_depth=5)
 xgb_regressor_model = xgb_regressor.fit(diabetes_train_spark_df)

-transformed_diabetes_test_spark_df = xgb_regressor_model.transform(diabetes_test_spark_df)
+transformed_diabetes_test_spark_df = xgb_regressor_model.transform(
+    diabetes_test_spark_df
+)
 regressor_evaluator = RegressionEvaluator(metricName="rmse")
-print(f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}")
+print(
+    f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}"
+)

 diabetes_train_spark_df2 = diabetes_train_spark_df.withColumn(
     "validationIndicatorCol", rand(1) > 0.7
 )

 # train xgboost regressor model with validation dataset
-xgb_regressor2 = SparkXGBRegressor(max_depth=5, validation_indicator_col="validationIndicatorCol")
+xgb_regressor2 = SparkXGBRegressor(
+    max_depth=5, validation_indicator_col="validationIndicatorCol"
+)
 xgb_regressor_model2 = xgb_regressor2.fit(diabetes_train_spark_df2)
-transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(diabetes_test_spark_df)
-print(f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}")
+transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(
+    diabetes_test_spark_df
+)
+print(
+    f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}"
+)


 # load iris dataset (classification dataset)
 iris_X, iris_y = sklearn.datasets.load_iris(return_X_y=True)
-iris_X_train, iris_X_test, iris_y_train, iris_y_test = \
-    train_test_split(iris_X, iris_y, test_size=0.3, shuffle=True)
+iris_X_train, iris_X_test, iris_y_train, iris_y_test = train_test_split(
+    iris_X, iris_y, test_size=0.3, shuffle=True
+)

 iris_train_spark_df = create_spark_df(iris_X_train, iris_y_train)
 iris_test_spark_df = create_spark_df(iris_X_test, iris_y_test)
@@ -74,9 +84,13 @@ iris_train_spark_df2 = iris_train_spark_df.withColumn(
 )

 # train xgboost classifier model with validation dataset
-xgb_classifier2 = SparkXGBClassifier(max_depth=5, validation_indicator_col="validationIndicatorCol")
+xgb_classifier2 = SparkXGBClassifier(
+    max_depth=5, validation_indicator_col="validationIndicatorCol"
+)
 xgb_classifier_model2 = xgb_classifier2.fit(iris_train_spark_df2)
 transformed_iris_test_spark_df2 = xgb_classifier_model2.transform(iris_test_spark_df)
-print(f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}")
+print(
+    f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}"
+)

 spark.stop()
@@ -3,26 +3,32 @@
 Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
 """

-from .core import (
-    DMatrix,
-    DeviceQuantileDMatrix,
-    QuantileDMatrix,
-    Booster,
-    DataIter,
-    build_info,
-    _py_version,
-)
-from .training import train, cv
 from . import rabit  # noqa
 from . import tracker  # noqa
-from .tracker import RabitTracker  # noqa
 from . import dask
+from .core import (
+    Booster,
+    DataIter,
+    DeviceQuantileDMatrix,
+    DMatrix,
+    QuantileDMatrix,
+    _py_version,
+    build_info,
+)
+from .tracker import RabitTracker  # noqa
+from .training import cv, train

 try:
-    from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
-    from .sklearn import XGBRFClassifier, XGBRFRegressor
+    from .config import config_context, get_config, set_config
     from .plotting import plot_importance, plot_tree, to_graphviz
-    from .config import set_config, get_config, config_context
+    from .sklearn import (
+        XGBClassifier,
+        XGBModel,
+        XGBRanker,
+        XGBRegressor,
+        XGBRFClassifier,
+        XGBRFRegressor,
+    )
 except ImportError:
     pass

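The reordering above is behavior-preserving: every re-exported name stays importable from the package root. A quick smoke test of that flat namespace (toy data, illustrative only):

import numpy as np

import xgboost as xgb

X = np.random.rand(32, 4)
y = np.random.rand(32)

# DMatrix, train, and build_info are all re-exports shown in the hunk above.
booster = xgb.train({"max_depth": 2}, xgb.DMatrix(X, y), num_boost_round=2)
print(xgb.build_info())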
@@ -1,7 +1,7 @@
 """Shared typing definition."""
 import ctypes
 import os
-from typing import Any, TypeVar, Union, Type, Sequence, Callable, List, Dict
+from typing import Any, Callable, Dict, List, Sequence, Type, TypeVar, Union

 # os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
 # cudf.DataFrame/cupy.array/dlpack
@@ -1,20 +1,21 @@
 # pylint: disable= invalid-name, unused-import
 """For compatibility and optional dependencies."""
-from typing import Any, Type, Dict, Optional, List, Sequence, cast
-import sys
-import types
 import importlib.util
 import logging
+import sys
+import types
+from typing import Any, Dict, List, Optional, Sequence, Type, cast

 import numpy as np

 from ._typing import _T

-assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
+assert sys.version_info[0] == 3, "Python 2 is no longer supported."


 def py_str(x: bytes) -> str:
     """convert c string back to python string"""
-    return x.decode('utf-8')  # type: ignore
+    return x.decode("utf-8")  # type: ignore


 def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
@@ -30,8 +31,7 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:

 # pandas
 try:
-    from pandas import DataFrame, Series
-    from pandas import MultiIndex
+    from pandas import DataFrame, MultiIndex, Series
     from pandas import concat as pandas_concat

     PANDAS_INSTALLED = True
@@ -45,23 +45,17 @@ except ImportError:

 # sklearn
 try:
-    from sklearn.base import (
-        BaseEstimator as XGBModelBase,
-        RegressorMixin as XGBRegressorBase,
-        ClassifierMixin as XGBClassifierBase
-    )
+    from sklearn.base import BaseEstimator as XGBModelBase
+    from sklearn.base import ClassifierMixin as XGBClassifierBase
+    from sklearn.base import RegressorMixin as XGBRegressorBase
     from sklearn.preprocessing import LabelEncoder

     try:
-        from sklearn.model_selection import (
-            KFold as XGBKFold,
-            StratifiedKFold as XGBStratifiedKFold
-        )
+        from sklearn.model_selection import KFold as XGBKFold
+        from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold
     except ImportError:
-        from sklearn.cross_validation import (
-            KFold as XGBKFold,
-            StratifiedKFold as XGBStratifiedKFold
-        )
+        from sklearn.cross_validation import KFold as XGBKFold
+        from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold

     SKLEARN_INSTALLED = True
@@ -79,9 +73,10 @@ except ImportError:


 class XGBoostLabelEncoder(LabelEncoder):
-    '''Label encoder with JSON serialization methods.'''
+    """Label encoder with JSON serialization methods."""
+
     def to_json(self) -> Dict:
-        '''Returns a JSON compatible dictionary'''
+        """Returns a JSON compatible dictionary"""
         meta = {}
         for k, v in self.__dict__.items():
             if isinstance(v, np.ndarray):
@@ -92,10 +87,10 @@ class XGBoostLabelEncoder(LabelEncoder):

     def from_json(self, doc: Dict) -> None:
         # pylint: disable=attribute-defined-outside-init
-        '''Load the encoder back from a JSON compatible dict.'''
+        """Load the encoder back from a JSON compatible dict."""
         meta = {}
         for k, v in doc.items():
-            if k == 'classes_':
+            if k == "classes_":
                 self.classes_ = np.array(v)
                 continue
             meta[k] = v
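Only the docstring quoting changes here; the JSON round-trip itself is untouched. A small usage sketch, assuming scikit-learn is installed and that to_json returns the meta dict built above (with ndarray attributes serialized as lists, as the isinstance check suggests):

import json

import numpy as np

from xgboost.compat import XGBoostLabelEncoder

enc = XGBoostLabelEncoder().fit(np.array(["a", "b", "a"]))
blob = json.dumps(enc.to_json())  # ndarray attributes became plain lists

enc2 = XGBoostLabelEncoder()
enc2.from_json(json.loads(blob))  # classes_ is rebuilt via np.array(v)
assert list(enc2.classes_) == ["a", "b"]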
@@ -159,15 +154,14 @@ def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statem
 # KIND, either express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 class LazyLoader(types.ModuleType):
-    """Lazily import a module, mainly to avoid pulling in large dependencies.
-    """
+    """Lazily import a module, mainly to avoid pulling in large dependencies."""

     def __init__(
         self,
         local_name: str,
         parent_module_globals: Dict,
         name: str,
-        warning: Optional[str] = None
+        warning: Optional[str] = None,
     ) -> None:
         self._local_name = local_name
         self._parent_module_globals = parent_module_globals

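For readers unfamiliar with the class being reformatted: a LazyLoader-style module defers the real import until the first attribute access. A self-contained sketch of the same pattern (not xgboost's exact implementation; dask.dataframe is just an example target):

import importlib
import types
from typing import Any, Dict


class LazyModule(types.ModuleType):
    """Defer importing `name` until an attribute is first touched."""

    def __init__(self, local_name: str, parent_globals: Dict, name: str) -> None:
        super().__init__(name)
        self._local_name = local_name
        self._parent_globals = parent_globals

    def _load(self) -> types.ModuleType:
        module = importlib.import_module(self.__name__)
        # Swap the stub for the real module in the caller's namespace.
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str) -> Any:
        return getattr(self._load(), item)


dd = LazyModule("dd", globals(), "dask.dataframe")  # nothing imported yet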
@@ -4,10 +4,10 @@ import ctypes
 import json
 from contextlib import contextmanager
 from functools import wraps
-from typing import Optional, Callable, Any, Dict, cast, Iterator
+from typing import Any, Callable, Dict, Iterator, Optional, cast

-from .core import _LIB, _check_call, c_str, py_str
 from ._typing import _F
+from .core import _LIB, _check_call, c_str, py_str


 def config_doc(
@@ -90,22 +90,30 @@ def config_doc(
     """

     def none_to_str(value: Optional[str]) -> str:
-        return '' if value is None else value
+        return "" if value is None else value

     def config_doc_decorator(func: _F) -> _F:
-        func.__doc__ = (doc_template.format(header=none_to_str(header),
-                                            extra_note=none_to_str(extra_note))
-                        + none_to_str(parameters) + none_to_str(returns)
-                        + none_to_str(common_example) + none_to_str(see_also))
+        func.__doc__ = (
+            doc_template.format(
+                header=none_to_str(header), extra_note=none_to_str(extra_note)
+            )
+            + none_to_str(parameters)
+            + none_to_str(returns)
+            + none_to_str(common_example)
+            + none_to_str(see_also)
+        )

         @wraps(func)
         def wrap(*args: Any, **kwargs: Any) -> Any:
             return func(*args, **kwargs)

         return cast(_F, wrap)

     return config_doc_decorator


-@config_doc(header="""
+@config_doc(
+    header="""
     Set global configuration.
     """,
     parameters="""
@@ -113,7 +121,8 @@ def config_doc(
     ----------
     new_config: Dict[str, Any]
         Keyword arguments representing the parameters and their values
-    """)
+    """,
+)
 def set_config(**new_config: Any) -> None:
     not_none = {}
     for k, v in new_config.items():
@@ -123,7 +132,8 @@ def set_config(**new_config: Any) -> None:
     _check_call(_LIB.XGBSetGlobalConfig(c_str(config)))


-@config_doc(header="""
+@config_doc(
+    header="""
     Get current values of the global configuration.
     """,
     returns="""
@@ -131,7 +141,8 @@ def set_config(**new_config: Any) -> None:
     -------
     args: Dict[str, Any]
         The list of global parameters and their values
-    """)
+    """,
+)
 def get_config() -> Dict[str, Any]:
     config_str = ctypes.c_char_p()
     _check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(config_str)))
@@ -142,7 +153,8 @@ def get_config() -> Dict[str, Any]:


 @contextmanager
-@config_doc(header="""
+@config_doc(
+    header="""
     Context manager for global XGBoost configuration.
     """,
     parameters="""
@@ -162,7 +174,8 @@ def get_config() -> Dict[str, Any]:
     --------
     set_config: Set global XGBoost configuration
     get_config: Get current values of the global configuration
-    """)
+    """,
+)
 def config_context(**new_config: Any) -> Iterator[None]:
     old_config = get_config().copy()
     set_config(**new_config)

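These hunks only reformat the docstring-assembling decorator; the configuration API it documents behaves as before. Typical usage, with verbosity as one of the documented global parameters:

import xgboost as xgb

xgb.set_config(verbosity=2)            # set a global parameter
print(xgb.get_config()["verbosity"])   # -> 2

with xgb.config_context(verbosity=0):  # override within the block only
    pass  # train quietly here

print(xgb.get_config()["verbosity"])   # restored to 2 on exit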
@@ -399,11 +399,10 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
     Parameters
     ----------
     cache_prefix:
-        Prefix to the cache files, only used in external memory. It can be either an URI
-        or a file path.
+        Prefix to the cache files, only used in external memory. It can be either an
+        URI or a file path.

     """
-    _T = TypeVar("_T")

     def __init__(self, cache_prefix: Optional[str] = None) -> None:
         self.cache_prefix = cache_prefix
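cache_prefix comes into play when subclassing DataIter for external-memory training. A rough sketch of the intended shape, assuming the iterator protocol of this era (next receives a callback and returns 1 while batches remain, 0 when exhausted, as in the external-memory demo):

import numpy as np

import xgboost as xgb


class BatchIter(xgb.DataIter):
    def __init__(self, batches, cache_prefix="./cache"):
        self._batches = batches
        self._pos = 0
        super().__init__(cache_prefix=cache_prefix)

    def next(self, input_data):
        # Feed one batch through the callback; signal exhaustion with 0.
        if self._pos == len(self._batches):
            return 0
        X, y = self._batches[self._pos]
        input_data(data=X, label=y)
        self._pos += 1
        return 1

    def reset(self):
        self._pos = 0


batches = [(np.random.rand(256, 8), np.random.rand(256)) for _ in range(4)]
Xy = xgb.DMatrix(BatchIter(batches))  # pages batches to cache_prefix on disk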
@@ -1010,7 +1009,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes

         Returns
         -------
-        number of columns : int
+        number of columns
         """
         ret = c_bst_ulong()
         _check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
@@ -1,6 +1,6 @@
-# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
+# pylint: disable=too-many-arguments, too-many-locals
 # pylint: disable=missing-class-docstring, invalid-name
-# pylint: disable=too-many-lines, fixme
+# pylint: disable=too-many-lines
 # pylint: disable=too-few-public-methods
 # pylint: disable=import-error
 """
@@ -227,7 +227,7 @@ class RabitContext(rabit.RabitContext):
         )


-def dconcat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
+def dconcat(value: Sequence[_T]) -> _T:
     """Concatenate sequence of partitions."""
     try:
         return concat(value)
@@ -253,7 +253,7 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Clie


 class DaskDMatrix:
-    # pylint: disable=missing-docstring, too-many-instance-attributes
+    # pylint: disable=too-many-instance-attributes
     """DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
     `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input
     data explicitly if you want to see actual computation of constructing `DaskDMatrix`.
@@ -486,6 +486,12 @@ class DaskDMatrix:
         }

+    def num_col(self) -> int:
+        """Get the number of columns (features) in the DMatrix.
+
+        Returns
+        -------
+        number of columns
+        """
+        return self._n_cols

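A usage sketch for the new num_col accessor, assuming a local dask.distributed cluster and random data:

import dask.array as da
from distributed import Client, LocalCluster

import xgboost as xgb

with Client(LocalCluster(n_workers=2)) as client:
    X = da.random.random((1000, 10), chunks=(250, 10))
    y = da.random.random(1000, chunks=250)

    dtrain = xgb.dask.DaskDMatrix(client, X, y)
    print(dtrain.num_col())  # -> 10, served from the cached _n_cols

    output = xgb.dask.train(
        client, {"tree_method": "hist"}, dtrain, num_boost_round=4
    )
    booster = output["booster"]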
@@ -1,13 +1,15 @@
 """XGBoost Federated Learning related API."""

-from .core import _LIB, _check_call, c_str, build_info, XGBoostError
+from .core import _LIB, XGBoostError, _check_call, build_info, c_str


-def run_federated_server(port: int,
+def run_federated_server(
+    port: int,
     world_size: int,
     server_key_path: str,
     server_cert_path: str,
-    client_cert_path: str) -> None:
+    client_cert_path: str,
+) -> None:
     """Run the Federated Learning server.

     Parameters
@@ -23,12 +25,16 @@ def run_federated_server(port: int,
     client_cert_path: str
         Path to the client certificate file.
     """
-    if build_info()['USE_FEDERATED']:
-        _check_call(_LIB.XGBRunFederatedServer(port,
+    if build_info()["USE_FEDERATED"]:
+        _check_call(
+            _LIB.XGBRunFederatedServer(
+                port,
                 world_size,
                 c_str(server_key_path),
                 c_str(server_cert_path),
-                c_str(client_cert_path)))
+                c_str(client_cert_path),
+            )
+        )
     else:
         raise XGBoostError(
             "XGBoost needs to be built with the federated learning plugin "
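Calling the reformatted function is unchanged. A hedged sketch (the certificate paths are placeholders, and per the else branch above the call raises XGBoostError unless XGBoost was built with the federated plugin):

from xgboost.federated import run_federated_server

# Blockingly serve a 2-party federated training session over TLS.
run_federated_server(
    port=9091,
    world_size=2,
    server_key_path="server-key.pem",
    server_cert_path="server-cert.pem",
    client_cert_path="client-cert.pem",
)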
@@ -112,13 +112,25 @@ if __name__ == "__main__":
     if not all(
         run_formatter(path)
         for path in [
+            # core
+            "python-package/xgboost/__init__.py",
+            "python-package/xgboost/_typing.py",
+            "python-package/xgboost/compat.py",
             "python-package/xgboost/config.py",
             "python-package/xgboost/dask.py",
             "python-package/xgboost/sklearn.py",
-            "python-package/xgboost/spark",
+            "python-package/xgboost/federated.py",
+            "python-package/xgboost/spark",
+            # tests
+            "tests/python/test_config.py",
-            "tests/python/test_spark/test_data.py",
-            "tests/python-gpu/test_gpu_spark/test_data.py",
+            "tests/python/test_spark/",
+            "tests/python-gpu/test_gpu_spark/",
             "tests/ci_build/lint_python.py",
+            # demo
+            "demo/guide-python/cat_in_the_dat.py",
+            "demo/guide-python/categorical.py",
+            "demo/guide-python/spark_estimator_examples.py",
         ]
     ):
         sys.exit(-1)
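run_formatter itself lies outside this hunk. For orientation, a plausible sketch of such a helper using black's and isort's real --check flags (the helper body here is an assumption, not the repository's code):

import subprocess


def run_formatter(path: str) -> bool:
    # Check formatting only; drop "--check" to rewrite files in place.
    black = subprocess.run(["black", "--check", path])
    isort = subprocess.run(["isort", "--check", "--profile=black", path])
    return black.returncode == 0 and isort.returncode == 0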
@@ -1,11 +1,10 @@
-import sys
 import logging
 import random
+import sys
 import uuid

 import numpy as np
 import pytest
-
 import testing as tm

 if tm.no_spark()["condition"]:
@@ -13,26 +12,27 @@ if tm.no_spark()["condition"]:
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)

-from pyspark.ml.functions import vector_to_array
-from pyspark.sql import functions as spark_sql_func
 from pyspark.ml import Pipeline, PipelineModel
 from pyspark.ml.evaluation import (
     BinaryClassificationEvaluator,
     MulticlassClassificationEvaluator,
 )
+from pyspark.ml.functions import vector_to_array
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

+from pyspark.sql import functions as spark_sql_func
 from xgboost.spark import (
     SparkXGBClassifier,
     SparkXGBClassifierModel,
     SparkXGBRegressor,
     SparkXGBRegressorModel,
 )
-from .utils import SparkTestCase
-from xgboost import XGBClassifier, XGBRegressor
 from xgboost.spark.core import _non_booster_params

+from xgboost import XGBClassifier, XGBRegressor
+
+from .utils import SparkTestCase
+
 logging.getLogger("py4j").setLevel(logging.INFO)

@@ -1,11 +1,11 @@
-import sys
-import random
 import json
-import uuid
 import os
+import random
+import sys
+import uuid

-import pytest
 import numpy as np
+import pytest
 import testing as tm

 if tm.no_spark()["condition"]:
@@ -13,10 +13,11 @@ if tm.no_spark()["condition"]:
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)

-from .utils import SparkLocalClusterTestCase
+from pyspark.ml.linalg import Vectors
 from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
 from xgboost.spark.utils import _get_max_num_concurrent_tasks
-from pyspark.ml.linalg import Vectors
+
+from .utils import SparkLocalClusterTestCase


 class XgboostLocalClusterTestCase(SparkLocalClusterTestCase):
@@ -3,22 +3,18 @@ import logging
 import shutil
 import sys
 import tempfile
-
 import unittest

 import pytest
-
-from six import StringIO
-
 import testing as tm
+from six import StringIO

 if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)

-from pyspark.sql import SQLContext
-from pyspark.sql import SparkSession

+from pyspark.sql import SparkSession, SQLContext
 from xgboost.spark.utils import _get_default_params_from_func