Use black on more Python files. (#8137)

commit 570f8ae4ba
parent bdb291f1c2
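The diff below is mechanical: black normalizes quoting, line wrapping, and trailing commas, and the import blocks are regrouped and sorted isort-style. As a rough sketch of how such a pass can be reproduced locally (the commit's actual driver is tests/ci_build/lint_python.py, whose exact options may differ):

    # Hedged sketch, not the project's own tooling: reformat a few of the
    # touched paths in place with black and isort.
    import subprocess

    for path in [
        "python-package/xgboost/config.py",
        "demo/guide-python/categorical.py",
    ]:
        subprocess.run(["black", path], check=True)  # code style
        subprocess.run(["isort", path], check=True)  # import order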
@@ -19,13 +19,14 @@ Also, see the tutorial for using XGBoost with categorical data:
 """
 
 from __future__ import annotations
-from time import time
+
 import os
 from tempfile import TemporaryDirectory
+from time import time
 
 import pandas as pd
-from sklearn.model_selection import train_test_split
 from sklearn.metrics import roc_auc_score
+from sklearn.model_selection import train_test_split
 
 import xgboost as xgb
 
@@ -16,11 +16,13 @@ categorical data.
 .. versionadded:: 1.5.0
 
 """
-import pandas as pd
-import numpy as np
-import xgboost as xgb
 from typing import Tuple
+
+import numpy as np
+import pandas as pd
+
+import xgboost as xgb
 
 
 def make_categorical(
     n_samples: int, n_features: int, n_categories: int, onehot: bool
@@ -1,35 +1,34 @@
-'''
+"""
 Collection of examples for using xgboost.spark estimator interface
 ==================================================================
 
 @author: Weichen Xu
-'''
+"""
+import sklearn.datasets
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator, RegressionEvaluator
+from pyspark.ml.linalg import Vectors
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import rand
-from pyspark.ml.linalg import Vectors
-import sklearn.datasets
 from sklearn.model_selection import train_test_split
 from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
-from pyspark.ml.evaluation import RegressionEvaluator, MulticlassClassificationEvaluator
-
 
 spark = SparkSession.builder.master("local[*]").getOrCreate()
 
 
 def create_spark_df(X, y):
     return spark.createDataFrame(
-        spark.sparkContext.parallelize([
-            (Vectors.dense(features), float(label))
-            for features, label in zip(X, y)
-        ]),
-        ["features", "label"]
+        spark.sparkContext.parallelize(
+            [(Vectors.dense(features), float(label)) for features, label in zip(X, y)]
+        ),
+        ["features", "label"],
     )
 
 
 # load diabetes dataset (regression dataset)
 diabetes_X, diabetes_y = sklearn.datasets.load_diabetes(return_X_y=True)
-diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = \
-    train_test_split(diabetes_X, diabetes_y, test_size=0.3, shuffle=True)
+diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = train_test_split(
+    diabetes_X, diabetes_y, test_size=0.3, shuffle=True
+)
 
 diabetes_train_spark_df = create_spark_df(diabetes_X_train, diabetes_y_train)
 diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test)
@@ -38,25 +37,36 @@ diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test)
 xgb_regressor = SparkXGBRegressor(max_depth=5)
 xgb_regressor_model = xgb_regressor.fit(diabetes_train_spark_df)
 
-transformed_diabetes_test_spark_df = xgb_regressor_model.transform(diabetes_test_spark_df)
+transformed_diabetes_test_spark_df = xgb_regressor_model.transform(
+    diabetes_test_spark_df
+)
 regressor_evaluator = RegressionEvaluator(metricName="rmse")
-print(f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}")
+print(
+    f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}"
+)
 
 diabetes_train_spark_df2 = diabetes_train_spark_df.withColumn(
     "validationIndicatorCol", rand(1) > 0.7
 )
 
 # train xgboost regressor model with validation dataset
-xgb_regressor2 = SparkXGBRegressor(max_depth=5, validation_indicator_col="validationIndicatorCol")
+xgb_regressor2 = SparkXGBRegressor(
+    max_depth=5, validation_indicator_col="validationIndicatorCol"
+)
 xgb_regressor_model2 = xgb_regressor2.fit(diabetes_train_spark_df2)
-transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(diabetes_test_spark_df)
-print(f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}")
+transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(
+    diabetes_test_spark_df
+)
+print(
+    f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}"
+)
 
 
 # load iris dataset (classification dataset)
 iris_X, iris_y = sklearn.datasets.load_iris(return_X_y=True)
-iris_X_train, iris_X_test, iris_y_train, iris_y_test = \
-    train_test_split(iris_X, iris_y, test_size=0.3, shuffle=True)
+iris_X_train, iris_X_test, iris_y_train, iris_y_test = train_test_split(
+    iris_X, iris_y, test_size=0.3, shuffle=True
+)
 
 iris_train_spark_df = create_spark_df(iris_X_train, iris_y_train)
 iris_test_spark_df = create_spark_df(iris_X_test, iris_y_test)
@@ -74,9 +84,13 @@ iris_train_spark_df2 = iris_train_spark_df.withColumn(
 )
 
 # train xgboost classifier model with validation dataset
-xgb_classifier2 = SparkXGBClassifier(max_depth=5, validation_indicator_col="validationIndicatorCol")
+xgb_classifier2 = SparkXGBClassifier(
+    max_depth=5, validation_indicator_col="validationIndicatorCol"
+)
 xgb_classifier_model2 = xgb_classifier2.fit(iris_train_spark_df2)
 transformed_iris_test_spark_df2 = xgb_classifier_model2.transform(iris_test_spark_df)
-print(f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}")
+print(
+    f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}"
+)
 
 spark.stop()
@@ -3,26 +3,32 @@
 Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
 """
 
-from .core import (
-    DMatrix,
-    DeviceQuantileDMatrix,
-    QuantileDMatrix,
-    Booster,
-    DataIter,
-    build_info,
-    _py_version,
-)
-from .training import train, cv
 from . import rabit  # noqa
 from . import tracker  # noqa
-from .tracker import RabitTracker  # noqa
 from . import dask
+from .core import (
+    Booster,
+    DataIter,
+    DeviceQuantileDMatrix,
+    DMatrix,
+    QuantileDMatrix,
+    _py_version,
+    build_info,
+)
+from .tracker import RabitTracker  # noqa
+from .training import cv, train
 
 try:
-    from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
-    from .sklearn import XGBRFClassifier, XGBRFRegressor
+    from .config import config_context, get_config, set_config
     from .plotting import plot_importance, plot_tree, to_graphviz
-    from .config import set_config, get_config, config_context
+    from .sklearn import (
+        XGBClassifier,
+        XGBModel,
+        XGBRanker,
+        XGBRegressor,
+        XGBRFClassifier,
+        XGBRFRegressor,
+    )
 except ImportError:
     pass
 
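The grouped import above is sorted case-sensitively, so capitalized names come before underscore-prefixed ones, which come before lowercase ones; this matches Python's default string ordering. A quick illustration (not part of the commit):

    # ASCII sort: uppercase < underscore < lowercase.
    names = ["DMatrix", "build_info", "_py_version", "Booster"]
    assert sorted(names) == ["Booster", "DMatrix", "_py_version", "build_info"]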
@@ -1,7 +1,7 @@
 """Shared typing definition."""
 import ctypes
 import os
-from typing import Any, TypeVar, Union, Type, Sequence, Callable, List, Dict
+from typing import Any, Callable, Dict, List, Sequence, Type, TypeVar, Union
 
 # os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
 # cudf.DataFrame/cupy.array/dlpack
@@ -1,20 +1,21 @@
 # pylint: disable= invalid-name, unused-import
 """For compatibility and optional dependencies."""
-from typing import Any, Type, Dict, Optional, List, Sequence, cast
-import sys
-import types
 import importlib.util
 import logging
+import sys
+import types
+from typing import Any, Dict, List, Optional, Sequence, Type, cast
+
 import numpy as np
 
 from ._typing import _T
 
-assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
+assert sys.version_info[0] == 3, "Python 2 is no longer supported."
 
 
 def py_str(x: bytes) -> str:
     """convert c string back to python string"""
-    return x.decode('utf-8')  # type: ignore
+    return x.decode("utf-8")  # type: ignore
 
 
 def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
@@ -30,8 +31,7 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
 
 # pandas
 try:
-    from pandas import DataFrame, Series
-    from pandas import MultiIndex
+    from pandas import DataFrame, MultiIndex, Series
     from pandas import concat as pandas_concat
 
     PANDAS_INSTALLED = True
@@ -45,23 +45,17 @@ except ImportError:
 
 # sklearn
 try:
-    from sklearn.base import (
-        BaseEstimator as XGBModelBase,
-        RegressorMixin as XGBRegressorBase,
-        ClassifierMixin as XGBClassifierBase
-    )
+    from sklearn.base import BaseEstimator as XGBModelBase
+    from sklearn.base import ClassifierMixin as XGBClassifierBase
+    from sklearn.base import RegressorMixin as XGBRegressorBase
     from sklearn.preprocessing import LabelEncoder
 
     try:
-        from sklearn.model_selection import (
-            KFold as XGBKFold,
-            StratifiedKFold as XGBStratifiedKFold
-        )
+        from sklearn.model_selection import KFold as XGBKFold
+        from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold
     except ImportError:
-        from sklearn.cross_validation import (
-            KFold as XGBKFold,
-            StratifiedKFold as XGBStratifiedKFold
-        )
+        from sklearn.cross_validation import KFold as XGBKFold
+        from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold
 
     SKLEARN_INSTALLED = True
 
@@ -79,9 +73,10 @@ except ImportError:
 
 
 class XGBoostLabelEncoder(LabelEncoder):
-    '''Label encoder with JSON serialization methods.'''
+    """Label encoder with JSON serialization methods."""
+
     def to_json(self) -> Dict:
-        '''Returns a JSON compatible dictionary'''
+        """Returns a JSON compatible dictionary"""
         meta = {}
         for k, v in self.__dict__.items():
             if isinstance(v, np.ndarray):
@@ -92,10 +87,10 @@ class XGBoostLabelEncoder(LabelEncoder):
 
     def from_json(self, doc: Dict) -> None:
         # pylint: disable=attribute-defined-outside-init
-        '''Load the encoder back from a JSON compatible dict.'''
+        """Load the encoder back from a JSON compatible dict."""
         meta = {}
         for k, v in doc.items():
-            if k == 'classes_':
+            if k == "classes_":
                 self.classes_ = np.array(v)
                 continue
             meta[k] = v
@@ -159,15 +154,14 @@ def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statem
 # KIND, either express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 class LazyLoader(types.ModuleType):
-    """Lazily import a module, mainly to avoid pulling in large dependencies.
-    """
+    """Lazily import a module, mainly to avoid pulling in large dependencies."""
 
     def __init__(
         self,
         local_name: str,
         parent_module_globals: Dict,
         name: str,
-        warning: Optional[str] = None
+        warning: Optional[str] = None,
     ) -> None:
         self._local_name = local_name
         self._parent_module_globals = parent_module_globals
@@ -4,10 +4,10 @@ import ctypes
 import json
 from contextlib import contextmanager
 from functools import wraps
-from typing import Optional, Callable, Any, Dict, cast, Iterator
+from typing import Any, Callable, Dict, Iterator, Optional, cast
 
-from .core import _LIB, _check_call, c_str, py_str
 from ._typing import _F
+from .core import _LIB, _check_call, c_str, py_str
 
 
 def config_doc(
@@ -90,22 +90,30 @@ def config_doc(
     """
 
     def none_to_str(value: Optional[str]) -> str:
-        return '' if value is None else value
+        return "" if value is None else value
 
     def config_doc_decorator(func: _F) -> _F:
-        func.__doc__ = (doc_template.format(header=none_to_str(header),
-                                            extra_note=none_to_str(extra_note))
-                        + none_to_str(parameters) + none_to_str(returns)
-                        + none_to_str(common_example) + none_to_str(see_also))
+        func.__doc__ = (
+            doc_template.format(
+                header=none_to_str(header), extra_note=none_to_str(extra_note)
+            )
+            + none_to_str(parameters)
+            + none_to_str(returns)
+            + none_to_str(common_example)
+            + none_to_str(see_also)
+        )
 
         @wraps(func)
         def wrap(*args: Any, **kwargs: Any) -> Any:
             return func(*args, **kwargs)
 
         return cast(_F, wrap)
 
     return config_doc_decorator
 
 
-@config_doc(header="""
+@config_doc(
+    header="""
     Set global configuration.
     """,
     parameters="""
@@ -113,7 +121,8 @@ def config_doc(
     ----------
     new_config: Dict[str, Any]
         Keyword arguments representing the parameters and their values
-    """)
+    """,
+)
 def set_config(**new_config: Any) -> None:
     not_none = {}
     for k, v in new_config.items():
@@ -123,7 +132,8 @@ def set_config(**new_config: Any) -> None:
     _check_call(_LIB.XGBSetGlobalConfig(c_str(config)))
 
 
-@config_doc(header="""
+@config_doc(
+    header="""
     Get current values of the global configuration.
     """,
     returns="""
@@ -131,7 +141,8 @@ def set_config(**new_config: Any) -> None:
     -------
     args: Dict[str, Any]
         The list of global parameters and their values
-    """)
+    """,
+)
 def get_config() -> Dict[str, Any]:
     config_str = ctypes.c_char_p()
     _check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(config_str)))
@@ -142,7 +153,8 @@ def get_config() -> Dict[str, Any]:
 
 
 @contextmanager
-@config_doc(header="""
+@config_doc(
+    header="""
     Context manager for global XGBoost configuration.
     """,
     parameters="""
@@ -162,7 +174,8 @@ def get_config() -> Dict[str, Any]:
     --------
     set_config: Set global XGBoost configuration
     get_config: Get current values of the global configuration
-    """)
+    """,
+)
 def config_context(**new_config: Any) -> Iterator[None]:
     old_config = get_config().copy()
     set_config(**new_config)
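The decorator call sites above change only in layout; set_config, get_config, and config_context keep their documented behavior. For reference, a minimal use of the context manager (verbosity is one of XGBoost's global configuration parameters):

    import numpy as np
    import xgboost as xgb

    # Overrides apply inside the block; previous values are restored on exit.
    with xgb.config_context(verbosity=0):
        dtrain = xgb.DMatrix(np.ones((2, 3)), label=np.array([0.0, 1.0]))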
@@ -399,11 +399,10 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
     Parameters
     ----------
     cache_prefix:
-        Prefix to the cache files, only used in external memory. It can be either an URI
-        or a file path.
+        Prefix to the cache files, only used in external memory. It can be either an
+        URI or a file path.
 
     """
-    _T = TypeVar("_T")
 
     def __init__(self, cache_prefix: Optional[str] = None) -> None:
         self.cache_prefix = cache_prefix
@@ -1010,7 +1009,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
 
         Returns
         -------
-        number of columns : int
+        number of columns
         """
         ret = c_bst_ulong()
         _check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
@@ -1,6 +1,6 @@
-# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
+# pylint: disable=too-many-arguments, too-many-locals
 # pylint: disable=missing-class-docstring, invalid-name
-# pylint: disable=too-many-lines, fixme
+# pylint: disable=too-many-lines
 # pylint: disable=too-few-public-methods
 # pylint: disable=import-error
 """
@@ -227,7 +227,7 @@ class RabitContext(rabit.RabitContext):
         )
 
 
-def dconcat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
+def dconcat(value: Sequence[_T]) -> _T:
    """Concatenate sequence of partitions."""
    try:
        return concat(value)
@@ -253,7 +253,7 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Clie
 
 
 class DaskDMatrix:
-    # pylint: disable=missing-docstring, too-many-instance-attributes
+    # pylint: disable=too-many-instance-attributes
     """DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
     `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input
     data explicitly if you want to see actual computation of constructing `DaskDMatrix`.
@@ -486,6 +486,12 @@ class DaskDMatrix:
         }
 
     def num_col(self) -> int:
+        """Get the number of columns (features) in the DMatrix.
+
+        Returns
+        -------
+        number of columns
+        """
         return self._n_cols
 
 
@@ -1,13 +1,15 @@
 """XGBoost Federated Learning related API."""
 
-from .core import _LIB, _check_call, c_str, build_info, XGBoostError
+from .core import _LIB, XGBoostError, _check_call, build_info, c_str
 
 
-def run_federated_server(port: int,
-                         world_size: int,
-                         server_key_path: str,
-                         server_cert_path: str,
-                         client_cert_path: str) -> None:
+def run_federated_server(
+    port: int,
+    world_size: int,
+    server_key_path: str,
+    server_cert_path: str,
+    client_cert_path: str,
+) -> None:
     """Run the Federated Learning server.
 
     Parameters
@@ -23,12 +25,16 @@ def run_federated_server(port: int,
     client_cert_path: str
         Path to the client certificate file.
     """
-    if build_info()['USE_FEDERATED']:
-        _check_call(_LIB.XGBRunFederatedServer(port,
-                                               world_size,
-                                               c_str(server_key_path),
-                                               c_str(server_cert_path),
-                                               c_str(client_cert_path)))
+    if build_info()["USE_FEDERATED"]:
+        _check_call(
+            _LIB.XGBRunFederatedServer(
+                port,
+                world_size,
+                c_str(server_key_path),
+                c_str(server_cert_path),
+                c_str(client_cert_path),
+            )
+        )
     else:
         raise XGBoostError(
             "XGBoost needs to be built with the federated learning plugin "
@@ -112,13 +112,25 @@ if __name__ == "__main__":
     if not all(
         run_formatter(path)
         for path in [
+            # core
+            "python-package/xgboost/__init__.py",
+            "python-package/xgboost/_typing.py",
+            "python-package/xgboost/compat.py",
+            "python-package/xgboost/config.py",
             "python-package/xgboost/dask.py",
             "python-package/xgboost/sklearn.py",
             "python-package/xgboost/spark",
+            "python-package/xgboost/federated.py",
+            "python-package/xgboost/spark",
+            # tests
             "tests/python/test_config.py",
-            "tests/python/test_spark/test_data.py",
-            "tests/python-gpu/test_gpu_spark/test_data.py",
+            "tests/python/test_spark/",
+            "tests/python-gpu/test_gpu_spark/",
             "tests/ci_build/lint_python.py",
+            # demo
+            "demo/guide-python/cat_in_the_dat.py",
+            "demo/guide-python/categorical.py",
+            "demo/guide-python/spark_estimator_examples.py",
         ]
     ):
         sys.exit(-1)
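The hunk above extends only the path list; run_formatter itself is defined earlier in lint_python.py and is not shown here. A minimal sketch of such a helper, assuming it shells out to black in check mode and reports success as a boolean (the real implementation may differ):

    import subprocess

    def run_formatter(path: str) -> bool:
        # "--check" exits non-zero when the file would be reformatted,
        # which the lint driver treats as failure.
        ret = subprocess.run(["black", "--check", path])
        return ret.returncode == 0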
@@ -1,11 +1,10 @@
-import sys
 import logging
 import random
+import sys
 import uuid
 
 import numpy as np
 import pytest
-
 import testing as tm
 
 if tm.no_spark()["condition"]:
@@ -13,26 +12,27 @@ if tm.no_spark()["condition"]:
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
 
-from pyspark.ml.functions import vector_to_array
-from pyspark.sql import functions as spark_sql_func
 from pyspark.ml import Pipeline, PipelineModel
 from pyspark.ml.evaluation import (
     BinaryClassificationEvaluator,
     MulticlassClassificationEvaluator,
 )
+from pyspark.ml.functions import vector_to_array
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
+from pyspark.sql import functions as spark_sql_func
 from xgboost.spark import (
     SparkXGBClassifier,
     SparkXGBClassifierModel,
     SparkXGBRegressor,
     SparkXGBRegressorModel,
 )
-from .utils import SparkTestCase
-from xgboost import XGBClassifier, XGBRegressor
 from xgboost.spark.core import _non_booster_params
 
+from xgboost import XGBClassifier, XGBRegressor
+
+from .utils import SparkTestCase
 
 logging.getLogger("py4j").setLevel(logging.INFO)
 
 
@@ -1,11 +1,11 @@
-import sys
-import random
 import json
-import uuid
 import os
+import random
+import sys
+import uuid
 
-import pytest
 import numpy as np
+import pytest
 import testing as tm
 
 if tm.no_spark()["condition"]:
@@ -13,10 +13,11 @@ if tm.no_spark()["condition"]:
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
 
-from .utils import SparkLocalClusterTestCase
+from pyspark.ml.linalg import Vectors
 from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
 from xgboost.spark.utils import _get_max_num_concurrent_tasks
-from pyspark.ml.linalg import Vectors
+
+from .utils import SparkLocalClusterTestCase
 
 
 class XgboostLocalClusterTestCase(SparkLocalClusterTestCase):
@@ -3,22 +3,18 @@ import logging
 import shutil
 import sys
 import tempfile
-
 import unittest
 
 import pytest
-
-from six import StringIO
-
 import testing as tm
+from six import StringIO
 
 if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
 
-from pyspark.sql import SQLContext
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession, SQLContext
 
 from xgboost.spark.utils import _get_default_params_from_func
 