PySpark XGBoost integration (#8020)
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
parent
8959622836
commit
176fec8789
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@ -141,7 +141,7 @@ jobs:
|
||||
- name: Install Python packages
|
||||
run: |
|
||||
python -m pip install wheel setuptools
|
||||
python -m pip install pylint cpplint numpy scipy scikit-learn
|
||||
python -m pip install pylint cpplint numpy scipy scikit-learn pyspark pandas cloudpickle
|
||||
- name: Run lint
|
||||
run: |
|
||||
make lint
|
||||
|
||||
1
.github/workflows/python_tests.yml
vendored
1
.github/workflows/python_tests.yml
vendored
@ -92,6 +92,7 @@ jobs:
|
||||
python-tests-on-macos:
|
||||
name: Test XGBoost Python package on ${{ matrix.config.os }}
|
||||
runs-on: ${{ matrix.config.os }}
|
||||
timeout-minutes: 90
|
||||
strategy:
|
||||
matrix:
|
||||
config:
|
||||
|
||||
@ -351,7 +351,8 @@ if __name__ == '__main__':
|
||||
'scikit-learn': ['scikit-learn'],
|
||||
'dask': ['dask', 'pandas', 'distributed'],
|
||||
'datatable': ['datatable'],
|
||||
'plotting': ['graphviz', 'matplotlib']
|
||||
'plotting': ['graphviz', 'matplotlib'],
|
||||
"pyspark": ["pyspark", "scikit-learn", "cloudpickle"],
|
||||
},
|
||||
maintainer='Hyunsu Cho',
|
||||
maintainer_email='chohyu01@cs.washington.edu',
|
||||
|
||||
22
python-package/xgboost/spark/__init__.py
Normal file
22
python-package/xgboost/spark/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
# type: ignore
|
||||
"""PySpark XGBoost integration interface
|
||||
"""
|
||||
|
||||
try:
|
||||
import pyspark
|
||||
except ImportError as e:
|
||||
raise ImportError("pyspark package needs to be installed to use this module") from e
|
||||
|
||||
from .estimator import (
|
||||
SparkXGBClassifier,
|
||||
SparkXGBClassifierModel,
|
||||
SparkXGBRegressor,
|
||||
SparkXGBRegressorModel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SparkXGBClassifier",
|
||||
"SparkXGBClassifierModel",
|
||||
"SparkXGBRegressor",
|
||||
"SparkXGBRegressorModel",
|
||||
]
|
||||
881
python-package/xgboost/spark/core.py
Normal file
881
python-package/xgboost/spark/core.py
Normal file
@ -0,0 +1,881 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for core code."""
|
||||
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
|
||||
# pylint: disable=too-few-public-methods
|
||||
import cloudpickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.special import expit, softmax # pylint: disable=no-name-in-module
|
||||
|
||||
from pyspark.ml.functions import array_to_vector, vector_to_array
|
||||
from pyspark.ml import Estimator, Model
|
||||
from pyspark.ml.linalg import VectorUDT
|
||||
from pyspark.ml.param.shared import (
|
||||
HasFeaturesCol,
|
||||
HasLabelCol,
|
||||
HasWeightCol,
|
||||
HasPredictionCol,
|
||||
HasProbabilityCol,
|
||||
HasRawPredictionCol,
|
||||
HasValidationIndicatorCol,
|
||||
)
|
||||
from pyspark.ml.param import Param, Params, TypeConverters
|
||||
from pyspark.ml.util import MLReadable, MLWritable
|
||||
from pyspark.sql.functions import col, pandas_udf, countDistinct, struct
|
||||
from pyspark.sql.types import (
|
||||
ArrayType,
|
||||
DoubleType,
|
||||
FloatType,
|
||||
IntegerType,
|
||||
LongType,
|
||||
ShortType,
|
||||
)
|
||||
|
||||
import xgboost
|
||||
from xgboost import XGBClassifier, XGBRegressor
|
||||
from xgboost.core import Booster
|
||||
from xgboost.training import train as worker_train
|
||||
|
||||
from .data import (
|
||||
_convert_partition_data_to_dmatrix,
|
||||
)
|
||||
from .model import (
|
||||
SparkXGBReader,
|
||||
SparkXGBWriter,
|
||||
SparkXGBModelReader,
|
||||
SparkXGBModelWriter,
|
||||
)
|
||||
from .utils import (
|
||||
get_logger, _get_max_num_concurrent_tasks,
|
||||
_get_default_params_from_func,
|
||||
get_class_name,
|
||||
RabitContext,
|
||||
_get_rabit_args,
|
||||
_get_args_from_message_list,
|
||||
_get_spark_session,
|
||||
)
|
||||
from .params import (
|
||||
HasArbitraryParamsDict,
|
||||
HasBaseMarginCol,
|
||||
)
|
||||
|
||||
# Put pyspark specific params here, they won't be passed to XGBoost.
|
||||
# like `validationIndicatorCol`, `base_margin_col`
|
||||
_pyspark_specific_params = [
|
||||
"featuresCol",
|
||||
"labelCol",
|
||||
"weightCol",
|
||||
"rawPredictionCol",
|
||||
"predictionCol",
|
||||
"probabilityCol",
|
||||
"validationIndicatorCol",
|
||||
"base_margin_col",
|
||||
"arbitrary_params_dict",
|
||||
"force_repartition",
|
||||
"num_workers",
|
||||
"use_gpu",
|
||||
"feature_names",
|
||||
]
|
||||
|
||||
_non_booster_params = [
|
||||
"missing",
|
||||
"n_estimators",
|
||||
"feature_types",
|
||||
"feature_weights",
|
||||
]
|
||||
|
||||
_pyspark_param_alias_map = {
|
||||
"features_col": "featuresCol",
|
||||
"label_col": "labelCol",
|
||||
"weight_col": "weightCol",
|
||||
"raw_prediction_ol": "rawPredictionCol",
|
||||
"prediction_col": "predictionCol",
|
||||
"probability_col": "probabilityCol",
|
||||
"validation_indicator_col": "validationIndicatorCol",
|
||||
}
|
||||
|
||||
_inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.items()}
|
||||
|
||||
_unsupported_xgb_params = [
|
||||
"gpu_id", # we have "use_gpu" pyspark param instead.
|
||||
"enable_categorical", # Use feature_types param to specify categorical feature instead
|
||||
"use_label_encoder",
|
||||
"n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead.
|
||||
"nthread", # Ditto
|
||||
]
|
||||
|
||||
_unsupported_fit_params = {
|
||||
"sample_weight", # Supported by spark param weightCol
|
||||
# Supported by spark param weightCol # and validationIndicatorCol
|
||||
"eval_set",
|
||||
"sample_weight_eval_set",
|
||||
"base_margin", # Supported by spark param base_margin_col
|
||||
}
|
||||
|
||||
_unsupported_predict_params = {
|
||||
# for classification, we can use rawPrediction as margin
|
||||
"output_margin",
|
||||
"validate_features", # TODO
|
||||
"base_margin", # Use pyspark base_margin_col param instead.
|
||||
}
|
||||
|
||||
|
||||
class _SparkXGBParams(
|
||||
HasFeaturesCol,
|
||||
HasLabelCol,
|
||||
HasWeightCol,
|
||||
HasPredictionCol,
|
||||
HasValidationIndicatorCol,
|
||||
HasArbitraryParamsDict,
|
||||
HasBaseMarginCol,
|
||||
):
|
||||
num_workers = Param(
|
||||
Params._dummy(),
|
||||
"num_workers",
|
||||
"The number of XGBoost workers. Each XGBoost worker corresponds to one spark task.",
|
||||
TypeConverters.toInt,
|
||||
)
|
||||
use_gpu = Param(
|
||||
Params._dummy(),
|
||||
"use_gpu",
|
||||
"A boolean variable. Set use_gpu=true if the executors "
|
||||
+ "are running on GPU instances. Currently, only one GPU per task is supported.",
|
||||
)
|
||||
force_repartition = Param(
|
||||
Params._dummy(),
|
||||
"force_repartition",
|
||||
"A boolean variable. Set force_repartition=true if you "
|
||||
+ "want to force the input dataset to be repartitioned before XGBoost training."
|
||||
+ "Note: The auto repartitioning judgement is not fully accurate, so it is recommended"
|
||||
+ "to have force_repartition be True.",
|
||||
)
|
||||
feature_names = Param(
|
||||
Params._dummy(), "feature_names", "A list of str to specify feature names."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _xgb_cls(cls):
|
||||
"""
|
||||
Subclasses should override this method and
|
||||
returns an xgboost.XGBModel subclass
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
# Parameters for xgboost.XGBModel()
|
||||
@classmethod
|
||||
def _get_xgb_params_default(cls):
|
||||
xgb_model_default = cls._xgb_cls()()
|
||||
params_dict = xgb_model_default.get_params()
|
||||
filtered_params_dict = {
|
||||
k: params_dict[k] for k in params_dict if k not in _unsupported_xgb_params
|
||||
}
|
||||
return filtered_params_dict
|
||||
|
||||
def _set_xgb_params_default(self):
|
||||
filtered_params_dict = self._get_xgb_params_default()
|
||||
self._setDefault(**filtered_params_dict)
|
||||
|
||||
def _gen_xgb_params_dict(self, gen_xgb_sklearn_estimator_param=False):
|
||||
xgb_params = {}
|
||||
non_xgb_params = (
|
||||
set(_pyspark_specific_params)
|
||||
| self._get_fit_params_default().keys()
|
||||
| self._get_predict_params_default().keys()
|
||||
)
|
||||
if not gen_xgb_sklearn_estimator_param:
|
||||
non_xgb_params |= set(_non_booster_params)
|
||||
for param in self.extractParamMap():
|
||||
if param.name not in non_xgb_params:
|
||||
xgb_params[param.name] = self.getOrDefault(param)
|
||||
|
||||
arbitrary_params_dict = self.getOrDefault(
|
||||
self.getParam("arbitrary_params_dict")
|
||||
)
|
||||
xgb_params.update(arbitrary_params_dict)
|
||||
return xgb_params
|
||||
|
||||
# Parameters for xgboost.XGBModel().fit()
|
||||
@classmethod
|
||||
def _get_fit_params_default(cls):
|
||||
fit_params = _get_default_params_from_func(
|
||||
cls._xgb_cls().fit, _unsupported_fit_params
|
||||
)
|
||||
return fit_params
|
||||
|
||||
def _set_fit_params_default(self):
|
||||
filtered_params_dict = self._get_fit_params_default()
|
||||
self._setDefault(**filtered_params_dict)
|
||||
|
||||
def _gen_fit_params_dict(self):
|
||||
"""
|
||||
Returns a dict of params for .fit()
|
||||
"""
|
||||
fit_params_keys = self._get_fit_params_default().keys()
|
||||
fit_params = {}
|
||||
for param in self.extractParamMap():
|
||||
if param.name in fit_params_keys:
|
||||
fit_params[param.name] = self.getOrDefault(param)
|
||||
return fit_params
|
||||
|
||||
# Parameters for xgboost.XGBModel().predict()
|
||||
@classmethod
|
||||
def _get_predict_params_default(cls):
|
||||
predict_params = _get_default_params_from_func(
|
||||
cls._xgb_cls().predict, _unsupported_predict_params
|
||||
)
|
||||
return predict_params
|
||||
|
||||
def _set_predict_params_default(self):
|
||||
filtered_params_dict = self._get_predict_params_default()
|
||||
self._setDefault(**filtered_params_dict)
|
||||
|
||||
def _gen_predict_params_dict(self):
|
||||
"""
|
||||
Returns a dict of params for .predict()
|
||||
"""
|
||||
predict_params_keys = self._get_predict_params_default().keys()
|
||||
predict_params = {}
|
||||
for param in self.extractParamMap():
|
||||
if param.name in predict_params_keys:
|
||||
predict_params[param.name] = self.getOrDefault(param)
|
||||
return predict_params
|
||||
|
||||
def _validate_params(self):
|
||||
init_model = self.getOrDefault(self.xgb_model)
|
||||
if init_model is not None:
|
||||
if init_model is not None and not isinstance(init_model, Booster):
|
||||
raise ValueError(
|
||||
"The xgb_model param must be set with a `xgboost.core.Booster` "
|
||||
"instance."
|
||||
)
|
||||
|
||||
if self.getOrDefault(self.num_workers) < 1:
|
||||
raise ValueError(
|
||||
f"Number of workers was {self.getOrDefault(self.num_workers)}."
|
||||
f"It cannot be less than 1 [Default is 1]"
|
||||
)
|
||||
|
||||
if (
|
||||
self.getOrDefault(self.force_repartition)
|
||||
and self.getOrDefault(self.num_workers) == 1
|
||||
):
|
||||
get_logger(self.__class__.__name__).warning(
|
||||
"You set force_repartition to true when there is no need for a repartition."
|
||||
"Therefore, that parameter will be ignored."
|
||||
)
|
||||
|
||||
if self.getOrDefault(self.use_gpu):
|
||||
tree_method = self.getParam("tree_method")
|
||||
if (
|
||||
self.getOrDefault(tree_method) is not None
|
||||
and self.getOrDefault(tree_method) != "gpu_hist"
|
||||
):
|
||||
raise ValueError(
|
||||
f"tree_method should be 'gpu_hist' or None when use_gpu is True,"
|
||||
f"found {self.getOrDefault(tree_method)}."
|
||||
)
|
||||
|
||||
gpu_per_task = (
|
||||
_get_spark_session()
|
||||
.sparkContext.getConf()
|
||||
.get("spark.task.resource.gpu.amount")
|
||||
)
|
||||
|
||||
if not gpu_per_task or int(gpu_per_task) < 1:
|
||||
raise RuntimeError(
|
||||
"The spark cluster does not have the necessary GPU"
|
||||
+ "configuration for the spark task. Therefore, we cannot"
|
||||
+ "run xgboost training using GPU."
|
||||
)
|
||||
|
||||
if int(gpu_per_task) > 1:
|
||||
get_logger(self.__class__.__name__).warning(
|
||||
"You configured %s GPU cores for each spark task, but in "
|
||||
"XGBoost training, every Spark task will only use one GPU core.",
|
||||
gpu_per_task
|
||||
)
|
||||
|
||||
|
||||
def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
|
||||
features_col_datatype = dataset.schema[features_col_name].dataType
|
||||
features_col = col(features_col_name)
|
||||
if isinstance(features_col_datatype, ArrayType):
|
||||
if not isinstance(
|
||||
features_col_datatype.elementType,
|
||||
(DoubleType, FloatType, LongType, IntegerType, ShortType),
|
||||
):
|
||||
raise ValueError(
|
||||
"If feature column is array type, its elements must be number type."
|
||||
)
|
||||
features_array_col = features_col.cast(ArrayType(FloatType())).alias("values")
|
||||
elif isinstance(features_col_datatype, VectorUDT):
|
||||
features_array_col = vector_to_array(features_col, dtype="float32").alias(
|
||||
"values"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"feature column must be array type or `pyspark.ml.linalg.Vector` type, "
|
||||
"if you want to use multiple numetric columns as features, please use "
|
||||
"`pyspark.ml.transform.VectorAssembler` to assemble them into a vector "
|
||||
"type column first."
|
||||
)
|
||||
return features_array_col
|
||||
|
||||
|
||||
class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._set_xgb_params_default()
|
||||
self._set_fit_params_default()
|
||||
self._set_predict_params_default()
|
||||
# Note: The default value for arbitrary_params_dict must always be empty dict.
|
||||
# For additional settings added into "arbitrary_params_dict" by default,
|
||||
# they are added in `setParams`.
|
||||
self._setDefault(
|
||||
num_workers=1,
|
||||
use_gpu=False,
|
||||
force_repartition=False,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
arbitrary_params_dict={},
|
||||
)
|
||||
|
||||
def setParams(self, **kwargs): # pylint: disable=invalid-name
|
||||
"""
|
||||
Set params for the estimator.
|
||||
"""
|
||||
_extra_params = {}
|
||||
if "arbitrary_params_dict" in kwargs:
|
||||
raise ValueError("Invalid param name: 'arbitrary_params_dict'.")
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if k in _inverse_pyspark_param_alias_map:
|
||||
raise ValueError(
|
||||
f"Please use param name {_inverse_pyspark_param_alias_map[k]} instead."
|
||||
)
|
||||
if k in _pyspark_param_alias_map:
|
||||
real_k = _pyspark_param_alias_map[k]
|
||||
if real_k in kwargs:
|
||||
raise ValueError(
|
||||
f"You should set only one of param '{k}' and '{real_k}'"
|
||||
)
|
||||
k = real_k
|
||||
|
||||
if self.hasParam(k):
|
||||
self._set(**{str(k): v})
|
||||
else:
|
||||
if (
|
||||
k in _unsupported_xgb_params
|
||||
or k in _unsupported_fit_params
|
||||
or k in _unsupported_predict_params
|
||||
):
|
||||
raise ValueError(f"Unsupported param '{k}'.")
|
||||
_extra_params[k] = v
|
||||
_existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
|
||||
self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
|
||||
|
||||
@classmethod
|
||||
def _pyspark_model_cls(cls):
|
||||
"""
|
||||
Subclasses should override this method and
|
||||
returns a _SparkXGBModel subclass
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _create_pyspark_model(self, xgb_model):
|
||||
return self._pyspark_model_cls()(xgb_model)
|
||||
|
||||
def _convert_to_sklearn_model(self, booster):
|
||||
xgb_sklearn_params = self._gen_xgb_params_dict(
|
||||
gen_xgb_sklearn_estimator_param=True
|
||||
)
|
||||
sklearn_model = self._xgb_cls()(**xgb_sklearn_params)
|
||||
sklearn_model._Booster = booster
|
||||
return sklearn_model
|
||||
|
||||
def _query_plan_contains_valid_repartition(self, dataset):
|
||||
"""
|
||||
Returns true if the latest element in the logical plan is a valid repartition
|
||||
The logic plan string format is like:
|
||||
|
||||
== Optimized Logical Plan ==
|
||||
Repartition 4, true
|
||||
+- LogicalRDD [features#12, label#13L], false
|
||||
|
||||
i.e., the top line in the logical plan is the last operation to execute.
|
||||
so, in this method, we check the first line, if it is a "Repartition" operation,
|
||||
and the result dataframe has the same partition number with num_workers param,
|
||||
then it means the dataframe is well repartitioned and we don't need to
|
||||
repartition the dataframe again.
|
||||
"""
|
||||
num_partitions = dataset.rdd.getNumPartitions()
|
||||
query_plan = dataset._sc._jvm.PythonSQLUtils.explainString(
|
||||
dataset._jdf.queryExecution(), "extended"
|
||||
)
|
||||
start = query_plan.index("== Optimized Logical Plan ==")
|
||||
start += len("== Optimized Logical Plan ==") + 1
|
||||
num_workers = self.getOrDefault(self.num_workers)
|
||||
if (
|
||||
query_plan[start : start + len("Repartition")] == "Repartition"
|
||||
and num_workers == num_partitions
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _repartition_needed(self, dataset):
|
||||
"""
|
||||
We repartition the dataset if the number of workers is not equal to the number of
|
||||
partitions. There is also a check to make sure there was "active partitioning"
|
||||
where either Round Robin or Hash partitioning was actively used before this stage.
|
||||
"""
|
||||
if self.getOrDefault(self.force_repartition):
|
||||
return True
|
||||
try:
|
||||
if self._query_plan_contains_valid_repartition(dataset):
|
||||
return False
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
return True
|
||||
|
||||
def _get_distributed_train_params(self, dataset):
|
||||
"""
|
||||
This just gets the configuration params for distributed xgboost
|
||||
"""
|
||||
params = self._gen_xgb_params_dict()
|
||||
fit_params = self._gen_fit_params_dict()
|
||||
verbose_eval = fit_params.pop("verbose", None)
|
||||
|
||||
params.update(fit_params)
|
||||
params["verbose_eval"] = verbose_eval
|
||||
classification = self._xgb_cls() == XGBClassifier
|
||||
num_classes = int(dataset.select(countDistinct("label")).collect()[0][0])
|
||||
if classification and num_classes == 2:
|
||||
params["objective"] = "binary:logistic"
|
||||
elif classification and num_classes > 2:
|
||||
params["objective"] = "multi:softprob"
|
||||
params["num_class"] = num_classes
|
||||
else:
|
||||
params["objective"] = "reg:squarederror"
|
||||
|
||||
# TODO: support "num_parallel_tree" for random forest
|
||||
params["num_boost_round"] = self.getOrDefault(self.n_estimators)
|
||||
|
||||
if self.getOrDefault(self.use_gpu):
|
||||
params["tree_method"] = "gpu_hist"
|
||||
|
||||
return params
|
||||
|
||||
@classmethod
|
||||
def _get_xgb_train_call_args(cls, train_params):
|
||||
xgb_train_default_args = _get_default_params_from_func(xgboost.train, {})
|
||||
booster_params, kwargs_params = {}, {}
|
||||
for key, value in train_params.items():
|
||||
if key in xgb_train_default_args:
|
||||
kwargs_params[key] = value
|
||||
else:
|
||||
booster_params[key] = value
|
||||
return booster_params, kwargs_params
|
||||
|
||||
def _fit(self, dataset):
|
||||
# pylint: disable=too-many-statements, too-many-locals
|
||||
self._validate_params()
|
||||
label_col = col(self.getOrDefault(self.labelCol)).alias("label")
|
||||
|
||||
features_array_col = _validate_and_convert_feature_col_as_array_col(
|
||||
dataset, self.getOrDefault(self.featuresCol)
|
||||
)
|
||||
select_cols = [features_array_col, label_col]
|
||||
|
||||
has_weight = False
|
||||
has_validation = False
|
||||
has_base_margin = False
|
||||
|
||||
if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol):
|
||||
has_weight = True
|
||||
select_cols.append(col(self.getOrDefault(self.weightCol)).alias("weight"))
|
||||
|
||||
if self.isDefined(self.validationIndicatorCol) and self.getOrDefault(
|
||||
self.validationIndicatorCol
|
||||
):
|
||||
has_validation = True
|
||||
select_cols.append(
|
||||
col(self.getOrDefault(self.validationIndicatorCol)).alias(
|
||||
"validationIndicator"
|
||||
)
|
||||
)
|
||||
|
||||
if self.isDefined(self.base_margin_col) and self.getOrDefault(
|
||||
self.base_margin_col
|
||||
):
|
||||
has_base_margin = True
|
||||
select_cols.append(
|
||||
col(self.getOrDefault(self.base_margin_col)).alias("baseMargin")
|
||||
)
|
||||
|
||||
dataset = dataset.select(*select_cols)
|
||||
|
||||
num_workers = self.getOrDefault(self.num_workers)
|
||||
sc = _get_spark_session().sparkContext
|
||||
max_concurrent_tasks = _get_max_num_concurrent_tasks(sc)
|
||||
|
||||
if num_workers > max_concurrent_tasks:
|
||||
get_logger(self.__class__.__name__).warning(
|
||||
"The num_workers %s set for xgboost distributed "
|
||||
"training is greater than current max number of concurrent "
|
||||
"spark task slots, you need wait until more task slots available "
|
||||
"or you need increase spark cluster workers.",
|
||||
num_workers
|
||||
)
|
||||
|
||||
if self._repartition_needed(dataset):
|
||||
dataset = dataset.repartition(num_workers)
|
||||
train_params = self._get_distributed_train_params(dataset)
|
||||
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
|
||||
train_params
|
||||
)
|
||||
|
||||
cpu_per_task = int(
|
||||
_get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1")
|
||||
)
|
||||
dmatrix_kwargs = {
|
||||
"nthread": cpu_per_task,
|
||||
"feature_types": self.getOrDefault(self.feature_types),
|
||||
"feature_names": self.getOrDefault(self.feature_names),
|
||||
"feature_weights": self.getOrDefault(self.feature_weights),
|
||||
"missing": self.getOrDefault(self.missing),
|
||||
}
|
||||
booster_params["nthread"] = cpu_per_task
|
||||
use_gpu = self.getOrDefault(self.use_gpu)
|
||||
|
||||
def _train_booster(pandas_df_iter):
|
||||
"""
|
||||
Takes in an RDD partition and outputs a booster for that partition after going through
|
||||
the Rabit Ring protocol
|
||||
"""
|
||||
from pyspark import BarrierTaskContext
|
||||
|
||||
context = BarrierTaskContext.get()
|
||||
context.barrier()
|
||||
|
||||
if use_gpu:
|
||||
# Set booster worker to use the first GPU allocated to the spark task.
|
||||
booster_params["gpu_id"] = int(
|
||||
context._resources["gpu"].addresses[0].strip()
|
||||
)
|
||||
|
||||
_rabit_args = ""
|
||||
if context.partitionId() == 0:
|
||||
_rabit_args = str(_get_rabit_args(context, num_workers))
|
||||
|
||||
messages = context.allGather(message=str(_rabit_args))
|
||||
_rabit_args = _get_args_from_message_list(messages)
|
||||
evals_result = {}
|
||||
with RabitContext(_rabit_args, context):
|
||||
dtrain, dval = None, []
|
||||
if has_validation:
|
||||
dtrain, dval = _convert_partition_data_to_dmatrix(
|
||||
pandas_df_iter,
|
||||
has_weight,
|
||||
has_validation,
|
||||
has_base_margin,
|
||||
dmatrix_kwargs=dmatrix_kwargs,
|
||||
)
|
||||
# TODO: Question: do we need to add dtrain to dval list ?
|
||||
dval = [(dtrain, "training"), (dval, "validation")]
|
||||
else:
|
||||
dtrain = _convert_partition_data_to_dmatrix(
|
||||
pandas_df_iter,
|
||||
has_weight,
|
||||
has_validation,
|
||||
has_base_margin,
|
||||
dmatrix_kwargs=dmatrix_kwargs,
|
||||
)
|
||||
|
||||
booster = worker_train(
|
||||
params=booster_params,
|
||||
dtrain=dtrain,
|
||||
evals=dval,
|
||||
evals_result=evals_result,
|
||||
**train_call_kwargs_params,
|
||||
)
|
||||
context.barrier()
|
||||
|
||||
if context.partitionId() == 0:
|
||||
yield pd.DataFrame(data={"booster_bytes": [cloudpickle.dumps(booster)]})
|
||||
|
||||
result_ser_booster = (
|
||||
dataset.mapInPandas(_train_booster, schema="booster_bytes binary")
|
||||
.rdd.barrier()
|
||||
.mapPartitions(lambda x: x)
|
||||
.collect()[0][0]
|
||||
)
|
||||
result_xgb_model = self._convert_to_sklearn_model(
|
||||
cloudpickle.loads(result_ser_booster)
|
||||
)
|
||||
return self._copyValues(self._create_pyspark_model(result_xgb_model))
|
||||
|
||||
def write(self):
|
||||
"""
|
||||
Return the writer for saving the estimator.
|
||||
"""
|
||||
return SparkXGBWriter(self)
|
||||
|
||||
@classmethod
|
||||
def read(cls):
|
||||
"""
|
||||
Return the reader for loading the estimator.
|
||||
"""
|
||||
return SparkXGBReader(cls)
|
||||
|
||||
|
||||
class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
||||
def __init__(self, xgb_sklearn_model=None):
|
||||
super().__init__()
|
||||
self._xgb_sklearn_model = xgb_sklearn_model
|
||||
|
||||
def get_booster(self):
|
||||
"""
|
||||
Return the `xgboost.core.Booster` instance.
|
||||
"""
|
||||
return self._xgb_sklearn_model.get_booster()
|
||||
|
||||
def get_feature_importances(self, importance_type="weight"):
|
||||
"""Get feature importance of each feature.
|
||||
Importance type can be defined as:
|
||||
|
||||
* 'weight': the number of times a feature is used to split the data across all trees.
|
||||
* 'gain': the average gain across all splits the feature is used in.
|
||||
* 'cover': the average coverage across all splits the feature is used in.
|
||||
* 'total_gain': the total gain across all splits the feature is used in.
|
||||
* 'total_cover': the total coverage across all splits the feature is used in.
|
||||
|
||||
.. note:: Feature importance is defined only for tree boosters
|
||||
|
||||
Feature importance is only defined when the decision tree model is chosen as base
|
||||
learner (`booster=gbtree`). It is not defined for other base learner types, such
|
||||
as linear learners (`booster=gblinear`).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
importance_type: str, default 'weight'
|
||||
One of the importance types defined above.
|
||||
"""
|
||||
return self.get_booster().get_score(importance_type=importance_type)
|
||||
|
||||
def write(self):
|
||||
"""
|
||||
Return the writer for saving the model.
|
||||
"""
|
||||
return SparkXGBModelWriter(self)
|
||||
|
||||
@classmethod
|
||||
def read(cls):
|
||||
"""
|
||||
Return the reader for loading the model.
|
||||
"""
|
||||
return SparkXGBModelReader(cls)
|
||||
|
||||
def _transform(self, dataset):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SparkXGBRegressorModel(_SparkXGBModel):
|
||||
"""
|
||||
The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
|
||||
|
||||
.. Note:: This API is experimental.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _xgb_cls(cls):
|
||||
return XGBRegressor
|
||||
|
||||
def _transform(self, dataset):
|
||||
# Save xgb_sklearn_model and predict_params to be local variable
|
||||
# to avoid the `self` object to be pickled to remote.
|
||||
xgb_sklearn_model = self._xgb_sklearn_model
|
||||
predict_params = self._gen_predict_params_dict()
|
||||
|
||||
has_base_margin = False
|
||||
if self.isDefined(self.base_margin_col) and self.getOrDefault(
|
||||
self.base_margin_col
|
||||
):
|
||||
has_base_margin = True
|
||||
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
|
||||
"baseMargin"
|
||||
)
|
||||
|
||||
@pandas_udf("double")
|
||||
def predict_udf(input_data: pd.DataFrame) -> pd.Series:
|
||||
X = np.array(input_data["values"].tolist())
|
||||
if has_base_margin:
|
||||
base_margin = input_data["baseMargin"].to_numpy()
|
||||
else:
|
||||
base_margin = None
|
||||
|
||||
preds = xgb_sklearn_model.predict(
|
||||
X, base_margin=base_margin, validate_features=False, **predict_params
|
||||
)
|
||||
return pd.Series(preds)
|
||||
|
||||
features_col = _validate_and_convert_feature_col_as_array_col(
|
||||
dataset, self.getOrDefault(self.featuresCol)
|
||||
)
|
||||
|
||||
if has_base_margin:
|
||||
pred_col = predict_udf(struct(features_col, base_margin_col))
|
||||
else:
|
||||
pred_col = predict_udf(struct(features_col))
|
||||
|
||||
predictionColName = self.getOrDefault(self.predictionCol)
|
||||
|
||||
return dataset.withColumn(predictionColName, pred_col)
|
||||
|
||||
|
||||
class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol):
|
||||
"""
|
||||
The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit`
|
||||
|
||||
.. Note:: This API is experimental.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _xgb_cls(cls):
|
||||
return XGBClassifier
|
||||
|
||||
def _transform(self, dataset):
|
||||
# Save xgb_sklearn_model and predict_params to be local variable
|
||||
# to avoid the `self` object to be pickled to remote.
|
||||
xgb_sklearn_model = self._xgb_sklearn_model
|
||||
predict_params = self._gen_predict_params_dict()
|
||||
|
||||
has_base_margin = False
|
||||
if self.isDefined(self.base_margin_col) and self.getOrDefault(
|
||||
self.base_margin_col
|
||||
):
|
||||
has_base_margin = True
|
||||
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
|
||||
"baseMargin"
|
||||
)
|
||||
|
||||
@pandas_udf(
|
||||
"rawPrediction array<double>, prediction double, probability array<double>"
|
||||
)
|
||||
def predict_udf(input_data: pd.DataFrame) -> pd.DataFrame:
|
||||
X = np.array(input_data["values"].tolist())
|
||||
if has_base_margin:
|
||||
base_margin = input_data["baseMargin"].to_numpy()
|
||||
else:
|
||||
base_margin = None
|
||||
|
||||
margins = xgb_sklearn_model.predict(
|
||||
X,
|
||||
base_margin=base_margin,
|
||||
output_margin=True,
|
||||
validate_features=False,
|
||||
**predict_params,
|
||||
)
|
||||
if margins.ndim == 1:
|
||||
# binomial case
|
||||
classone_probs = expit(margins)
|
||||
classzero_probs = 1.0 - classone_probs
|
||||
raw_preds = np.vstack((-margins, margins)).transpose()
|
||||
class_probs = np.vstack((classzero_probs, classone_probs)).transpose()
|
||||
else:
|
||||
# multinomial case
|
||||
raw_preds = margins
|
||||
class_probs = softmax(raw_preds, axis=1)
|
||||
|
||||
# It seems that they use argmax of class probs,
|
||||
# not of margin to get the prediction (Note: scala implementation)
|
||||
preds = np.argmax(class_probs, axis=1)
|
||||
return pd.DataFrame(
|
||||
data={
|
||||
"rawPrediction": pd.Series(raw_preds.tolist()),
|
||||
"prediction": pd.Series(preds),
|
||||
"probability": pd.Series(class_probs.tolist()),
|
||||
}
|
||||
)
|
||||
|
||||
features_col = _validate_and_convert_feature_col_as_array_col(
|
||||
dataset, self.getOrDefault(self.featuresCol)
|
||||
)
|
||||
|
||||
if has_base_margin:
|
||||
pred_struct = predict_udf(struct(features_col, base_margin_col))
|
||||
else:
|
||||
pred_struct = predict_udf(struct(features_col))
|
||||
|
||||
pred_struct_col = "_prediction_struct"
|
||||
|
||||
rawPredictionColName = self.getOrDefault(self.rawPredictionCol)
|
||||
predictionColName = self.getOrDefault(self.predictionCol)
|
||||
probabilityColName = self.getOrDefault(self.probabilityCol)
|
||||
dataset = dataset.withColumn(pred_struct_col, pred_struct)
|
||||
if rawPredictionColName:
|
||||
dataset = dataset.withColumn(
|
||||
rawPredictionColName,
|
||||
array_to_vector(col(pred_struct_col).rawPrediction),
|
||||
)
|
||||
if predictionColName:
|
||||
dataset = dataset.withColumn(
|
||||
predictionColName, col(pred_struct_col).prediction
|
||||
)
|
||||
if probabilityColName:
|
||||
dataset = dataset.withColumn(
|
||||
probabilityColName, array_to_vector(col(pred_struct_col).probability)
|
||||
)
|
||||
|
||||
return dataset.drop(pred_struct_col)
|
||||
|
||||
|
||||
def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, pyspark_model_class):
|
||||
params_dict = pyspark_estimator_class._get_xgb_params_default()
|
||||
|
||||
def param_value_converter(v):
|
||||
if isinstance(v, np.generic):
|
||||
# convert numpy scalar values to corresponding python scalar values
|
||||
return np.array(v).item()
|
||||
if isinstance(v, dict):
|
||||
return {k: param_value_converter(nv) for k, nv in v.items()}
|
||||
if isinstance(v, list):
|
||||
return [param_value_converter(nv) for nv in v]
|
||||
return v
|
||||
|
||||
def set_param_attrs(attr_name, param_obj_):
|
||||
param_obj_.typeConverter = param_value_converter
|
||||
setattr(pyspark_estimator_class, attr_name, param_obj_)
|
||||
setattr(pyspark_model_class, attr_name, param_obj_)
|
||||
|
||||
for name in params_dict.keys():
|
||||
doc = (
|
||||
f"Refer to XGBoost doc of "
|
||||
f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}"
|
||||
)
|
||||
|
||||
param_obj = Param(Params._dummy(), name=name, doc=doc)
|
||||
set_param_attrs(name, param_obj)
|
||||
|
||||
fit_params_dict = pyspark_estimator_class._get_fit_params_default()
|
||||
for name in fit_params_dict.keys():
|
||||
doc = (
|
||||
f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
|
||||
f".fit() for this param {name}"
|
||||
)
|
||||
if name == "callbacks":
|
||||
doc += (
|
||||
"The callbacks can be arbitrary functions. It is saved using cloudpickle "
|
||||
"which is not a fully self-contained format. It may fail to load with "
|
||||
"different versions of dependencies."
|
||||
)
|
||||
param_obj = Param(Params._dummy(), name=name, doc=doc)
|
||||
set_param_attrs(name, param_obj)
|
||||
|
||||
predict_params_dict = pyspark_estimator_class._get_predict_params_default()
|
||||
for name in predict_params_dict.keys():
|
||||
doc = (
|
||||
f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
|
||||
f".predict() for this param {name}"
|
||||
)
|
||||
param_obj = Param(Params._dummy(), name=name, doc=doc)
|
||||
set_param_attrs(name, param_obj)
|
||||
192
python-package/xgboost/spark/data.py
Normal file
192
python-package/xgboost/spark/data.py
Normal file
@ -0,0 +1,192 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for data related functions."""
|
||||
# pylint: disable=too-many-arguments
|
||||
from typing import Iterator
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from xgboost import DMatrix
|
||||
|
||||
|
||||
def _prepare_train_val_data(
|
||||
data_iterator, has_weight, has_validation, has_fit_base_margin
|
||||
):
|
||||
def gen_data_pdf():
|
||||
for pdf in data_iterator:
|
||||
yield pdf
|
||||
|
||||
return _process_data_iter(
|
||||
gen_data_pdf(),
|
||||
train=True,
|
||||
has_weight=has_weight,
|
||||
has_validation=has_validation,
|
||||
has_fit_base_margin=has_fit_base_margin,
|
||||
has_predict_base_margin=False,
|
||||
)
|
||||
|
||||
|
||||
def _check_feature_dims(num_dims, expected_dims):
|
||||
"""
|
||||
Check all feature vectors has the same dimension
|
||||
"""
|
||||
if expected_dims is None:
|
||||
return num_dims
|
||||
if num_dims != expected_dims:
|
||||
raise ValueError(
|
||||
f"Rows contain different feature dimensions: Expecting {expected_dims}, got {num_dims}."
|
||||
)
|
||||
return expected_dims
|
||||
|
||||
|
||||
def _row_tuple_list_to_feature_matrix_y_w(
|
||||
data_iterator,
|
||||
train,
|
||||
has_weight,
|
||||
has_fit_base_margin,
|
||||
has_predict_base_margin,
|
||||
has_validation: bool = False,
|
||||
):
|
||||
"""
|
||||
Construct a feature matrix in ndarray format, label array y and weight array w
|
||||
from the row_tuple_list.
|
||||
If train == False, y and w will be None.
|
||||
If has_weight == False, w will be None.
|
||||
If has_base_margin == False, b_m will be None.
|
||||
Note: the row_tuple_list will be cleared during
|
||||
executing for reducing peak memory consumption
|
||||
"""
|
||||
# pylint: disable=too-many-locals
|
||||
expected_feature_dims = None
|
||||
label_list, weight_list, base_margin_list = [], [], []
|
||||
label_val_list, weight_val_list, base_margin_val_list = [], [], []
|
||||
values_list, values_val_list = [], []
|
||||
|
||||
# Process rows
|
||||
for pdf in data_iterator:
|
||||
if len(pdf) == 0:
|
||||
continue
|
||||
if train and has_validation:
|
||||
pdf_val = pdf.loc[pdf["validationIndicator"], :]
|
||||
pdf = pdf.loc[~pdf["validationIndicator"], :]
|
||||
|
||||
num_feature_dims = len(pdf["values"].values[0])
|
||||
|
||||
expected_feature_dims = _check_feature_dims(
|
||||
num_feature_dims, expected_feature_dims
|
||||
)
|
||||
|
||||
# Note: each element in `pdf["values"]` is an numpy array.
|
||||
values_list.append(pdf["values"].to_list())
|
||||
if train:
|
||||
label_list.append(pdf["label"].to_numpy())
|
||||
if has_weight:
|
||||
weight_list.append(pdf["weight"].to_numpy())
|
||||
if has_fit_base_margin or has_predict_base_margin:
|
||||
base_margin_list.append(pdf["baseMargin"].to_numpy())
|
||||
if has_validation:
|
||||
values_val_list.append(pdf_val["values"].to_list())
|
||||
if train:
|
||||
label_val_list.append(pdf_val["label"].to_numpy())
|
||||
if has_weight:
|
||||
weight_val_list.append(pdf_val["weight"].to_numpy())
|
||||
if has_fit_base_margin or has_predict_base_margin:
|
||||
base_margin_val_list.append(pdf_val["baseMargin"].to_numpy())
|
||||
|
||||
# Construct feature_matrix
|
||||
if expected_feature_dims is None:
|
||||
return [], [], [], []
|
||||
|
||||
# Construct feature_matrix, y and w
|
||||
feature_matrix = np.concatenate(values_list)
|
||||
y = np.concatenate(label_list) if train else None
|
||||
w = np.concatenate(weight_list) if has_weight else None
|
||||
b_m = (
|
||||
np.concatenate(base_margin_list)
|
||||
if (has_fit_base_margin or has_predict_base_margin)
|
||||
else None
|
||||
)
|
||||
if has_validation:
|
||||
feature_matrix_val = np.concatenate(values_val_list)
|
||||
y_val = np.concatenate(label_val_list) if train else None
|
||||
w_val = np.concatenate(weight_val_list) if has_weight else None
|
||||
b_m_val = (
|
||||
np.concatenate(base_margin_val_list)
|
||||
if (has_fit_base_margin or has_predict_base_margin)
|
||||
else None
|
||||
)
|
||||
return feature_matrix, y, w, b_m, feature_matrix_val, y_val, w_val, b_m_val
|
||||
return feature_matrix, y, w, b_m
|
||||
|
||||
|
||||
def _process_data_iter(
|
||||
data_iterator: Iterator[pd.DataFrame],
|
||||
train: bool,
|
||||
has_weight: bool,
|
||||
has_validation: bool,
|
||||
has_fit_base_margin: bool = False,
|
||||
has_predict_base_margin: bool = False,
|
||||
):
|
||||
"""
|
||||
If input is for train and has_validation=True, it will split the train data into train dataset
|
||||
and validation dataset, and return (train_X, train_y, train_w, train_b_m <-
|
||||
train base margin, val_X, val_y, val_w, val_b_m <- validation base margin)
|
||||
otherwise return (X, y, w, b_m <- base margin)
|
||||
"""
|
||||
return _row_tuple_list_to_feature_matrix_y_w(
|
||||
data_iterator,
|
||||
train,
|
||||
has_weight,
|
||||
has_fit_base_margin,
|
||||
has_predict_base_margin,
|
||||
has_validation,
|
||||
)
|
||||
|
||||
|
||||
def _convert_partition_data_to_dmatrix(
|
||||
partition_data_iter,
|
||||
has_weight,
|
||||
has_validation,
|
||||
has_base_margin,
|
||||
dmatrix_kwargs=None,
|
||||
):
|
||||
# pylint: disable=too-many-locals, unbalanced-tuple-unpacking
|
||||
dmatrix_kwargs = dmatrix_kwargs or {}
|
||||
# if we are not using external storage, we use the standard method of parsing data.
|
||||
train_val_data = _prepare_train_val_data(
|
||||
partition_data_iter, has_weight, has_validation, has_base_margin
|
||||
)
|
||||
if has_validation:
|
||||
(
|
||||
train_x,
|
||||
train_y,
|
||||
train_w,
|
||||
train_b_m,
|
||||
val_x,
|
||||
val_y,
|
||||
val_w,
|
||||
val_b_m,
|
||||
) = train_val_data
|
||||
training_dmatrix = DMatrix(
|
||||
data=train_x,
|
||||
label=train_y,
|
||||
weight=train_w,
|
||||
base_margin=train_b_m,
|
||||
**dmatrix_kwargs,
|
||||
)
|
||||
val_dmatrix = DMatrix(
|
||||
data=val_x,
|
||||
label=val_y,
|
||||
weight=val_w,
|
||||
base_margin=val_b_m,
|
||||
**dmatrix_kwargs,
|
||||
)
|
||||
return training_dmatrix, val_dmatrix
|
||||
|
||||
train_x, train_y, train_w, train_b_m = train_val_data
|
||||
training_dmatrix = DMatrix(
|
||||
data=train_x,
|
||||
label=train_y,
|
||||
weight=train_w,
|
||||
base_margin=train_b_m,
|
||||
**dmatrix_kwargs,
|
||||
)
|
||||
return training_dmatrix
|
||||
203
python-package/xgboost/spark/estimator.py
Normal file
203
python-package/xgboost/spark/estimator.py
Normal file
@ -0,0 +1,203 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for estimator API."""
|
||||
# pylint: disable=too-many-ancestors
|
||||
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
|
||||
from xgboost import XGBClassifier, XGBRegressor
|
||||
from .core import (
|
||||
_SparkXGBEstimator,
|
||||
SparkXGBClassifierModel,
|
||||
SparkXGBRegressorModel,
|
||||
_set_pyspark_xgb_cls_param_attrs,
|
||||
)
|
||||
|
||||
|
||||
class SparkXGBRegressor(_SparkXGBEstimator):
|
||||
"""
|
||||
SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression
|
||||
algorithm based on XGBoost python library, and it can be used in PySpark Pipeline
|
||||
and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest.
|
||||
|
||||
SparkXGBRegressor automatically supports most of the parameters in
|
||||
`xgboost.XGBRegressor` constructor and most of the parameters used in
|
||||
`xgboost.XGBRegressor` fit and predict method (see `API docs <https://xgboost.readthedocs\
|
||||
.io/en/latest/python/python_api.html#xgboost.XGBRegressor>`_ for details).
|
||||
|
||||
SparkXGBRegressor doesn't support setting `gpu_id` but support another param `use_gpu`,
|
||||
see doc below for more details.
|
||||
|
||||
SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but support
|
||||
another param called `base_margin_col`. see doc below for more details.
|
||||
|
||||
SparkXGBRegressor doesn't support `validate_features` and `output_margin` param.
|
||||
|
||||
callbacks:
|
||||
The export and import of the callback functions are at best effort.
|
||||
For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc.
|
||||
validationIndicatorCol
|
||||
For params related to `xgboost.XGBRegressor` training
|
||||
with evaluation dataset's supervision, set
|
||||
:py:attr:`xgboost.spark.SparkXGBRegressor.validationIndicatorCol`
|
||||
parameter instead of setting the `eval_set` parameter in `xgboost.XGBRegressor`
|
||||
fit method.
|
||||
weightCol:
|
||||
To specify the weight of the training and validation dataset, set
|
||||
:py:attr:`xgboost.spark.SparkXGBRegressor.weightCol` parameter instead of setting
|
||||
`sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBRegressor`
|
||||
fit method.
|
||||
xgb_model:
|
||||
Set the value to be the instance returned by
|
||||
:func:`xgboost.spark.SparkXGBRegressorModel.get_booster`.
|
||||
num_workers:
|
||||
Integer that specifies the number of XGBoost workers to use.
|
||||
Each XGBoost worker corresponds to one spark task.
|
||||
use_gpu:
|
||||
Boolean that specifies whether the executors are running on GPU
|
||||
instances.
|
||||
base_margin_col:
|
||||
To specify the base margins of the training and validation
|
||||
dataset, set :py:attr:`xgboost.spark.SparkXGBRegressor.base_margin_col` parameter
|
||||
instead of setting `base_margin` and `base_margin_eval_set` in the
|
||||
`xgboost.XGBRegressor` fit method. Note: this isn't available for distributed
|
||||
training.
|
||||
|
||||
.. Note:: The Parameters chart above contains parameters that need special handling.
|
||||
For a full list of parameters, see entries with `Param(parent=...` below.
|
||||
|
||||
.. Note:: This API is experimental.
|
||||
|
||||
**Examples**
|
||||
|
||||
>>> from xgboost.spark import SparkXGBRegressor
|
||||
>>> from pyspark.ml.linalg import Vectors
|
||||
>>> df_train = spark.createDataFrame([
|
||||
... (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
|
||||
... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
|
||||
... (Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0),
|
||||
... (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0),
|
||||
... ], ["features", "label", "isVal", "weight"])
|
||||
>>> df_test = spark.createDataFrame([
|
||||
... (Vectors.dense(1.0, 2.0, 3.0), ),
|
||||
... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), )
|
||||
... ], ["features"])
|
||||
>>> xgb_regressor = SparkXGBRegressor(max_depth=5, missing=0.0,
|
||||
... validation_indicator_col='isVal', weight_col='weight',
|
||||
... early_stopping_rounds=1, eval_metric='rmse')
|
||||
>>> xgb_reg_model = xgb_regressor.fit(df_train)
|
||||
>>> xgb_reg_model.transform(df_test)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.setParams(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _xgb_cls(cls):
|
||||
return XGBRegressor
|
||||
|
||||
@classmethod
|
||||
def _pyspark_model_cls(cls):
|
||||
return SparkXGBRegressorModel
|
||||
|
||||
|
||||
_set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)
|
||||
|
||||
|
||||
class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
|
||||
"""
|
||||
SparkXGBClassifier is a PySpark ML estimator. It implements the XGBoost classification
|
||||
algorithm based on XGBoost python library, and it can be used in PySpark Pipeline
|
||||
and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest.
|
||||
|
||||
SparkXGBClassifier automatically supports most of the parameters in
|
||||
`xgboost.XGBClassifier` constructor and most of the parameters used in
|
||||
`xgboost.XGBClassifier` fit and predict method (see `API docs <https://xgboost.readthedocs\
|
||||
.io/en/latest/python/python_api.html#xgboost.XGBClassifier>`_ for details).
|
||||
|
||||
SparkXGBClassifier doesn't support setting `gpu_id` but support another param `use_gpu`,
|
||||
see doc below for more details.
|
||||
|
||||
SparkXGBClassifier doesn't support setting `base_margin` explicitly as well, but support
|
||||
another param called `base_margin_col`. see doc below for more details.
|
||||
|
||||
SparkXGBClassifier doesn't support setting `output_margin`, but we can get output margin
|
||||
from the raw prediction column. See `rawPredictionCol` param doc below for more details.
|
||||
|
||||
SparkXGBClassifier doesn't support `validate_features` and `output_margin` param.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
callbacks:
|
||||
The export and import of the callback functions are at best effort. For
|
||||
details, see :py:attr:`xgboost.spark.SparkXGBClassifier.callbacks` param doc.
|
||||
rawPredictionCol:
|
||||
The `output_margin=True` is implicitly supported by the
|
||||
`rawPredictionCol` output column, which is always returned with the predicted margin
|
||||
values.
|
||||
validationIndicatorCol:
|
||||
For params related to `xgboost.XGBClassifier` training with
|
||||
evaluation dataset's supervision,
|
||||
set :py:attr:`xgboost.spark.SparkXGBClassifier.validationIndicatorCol`
|
||||
parameter instead of setting the `eval_set` parameter in `xgboost.XGBClassifier`
|
||||
fit method.
|
||||
weightCol:
|
||||
To specify the weight of the training and validation dataset, set
|
||||
:py:attr:`xgboost.spark.SparkXGBClassifier.weightCol` parameter instead of setting
|
||||
`sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBClassifier`
|
||||
fit method.
|
||||
xgb_model:
|
||||
Set the value to be the instance returned by
|
||||
:func:`xgboost.spark.SparkXGBClassifierModel.get_booster`.
|
||||
num_workers:
|
||||
Integer that specifies the number of XGBoost workers to use.
|
||||
Each XGBoost worker corresponds to one spark task.
|
||||
use_gpu:
|
||||
Boolean that specifies whether the executors are running on GPU
|
||||
instances.
|
||||
base_margin_col:
|
||||
To specify the base margins of the training and validation
|
||||
dataset, set :py:attr:`xgboost.spark.SparkXGBClassifier.base_margin_col` parameter
|
||||
instead of setting `base_margin` and `base_margin_eval_set` in the
|
||||
`xgboost.XGBClassifier` fit method. Note: this isn't available for distributed
|
||||
training.
|
||||
|
||||
.. Note:: The Parameters chart above contains parameters that need special handling.
|
||||
For a full list of parameters, see entries with `Param(parent=...` below.
|
||||
|
||||
.. Note:: This API is experimental.
|
||||
|
||||
**Examples**
|
||||
|
||||
>>> from xgboost.spark import SparkXGBClassifier
|
||||
>>> from pyspark.ml.linalg import Vectors
|
||||
>>> df_train = spark.createDataFrame([
|
||||
... (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
|
||||
... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
|
||||
... (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
|
||||
... (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
|
||||
... ], ["features", "label", "isVal", "weight"])
|
||||
>>> df_test = spark.createDataFrame([
|
||||
... (Vectors.dense(1.0, 2.0, 3.0), ),
|
||||
... ], ["features"])
|
||||
>>> xgb_classifier = SparkXGBClassifier(max_depth=5, missing=0.0,
|
||||
... validation_indicator_col='isVal', weight_col='weight',
|
||||
... early_stopping_rounds=1, eval_metric='logloss')
|
||||
>>> xgb_clf_model = xgb_classifier.fit(df_train)
|
||||
>>> xgb_clf_model.transform(df_test).show()
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.setParams(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _xgb_cls(cls):
|
||||
return XGBClassifier
|
||||
|
||||
@classmethod
|
||||
def _pyspark_model_cls(cls):
|
||||
return SparkXGBClassifierModel
|
||||
|
||||
|
||||
_set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)
|
||||
270
python-package/xgboost/spark/model.py
Normal file
270
python-package/xgboost/spark/model.py
Normal file
@ -0,0 +1,270 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for model API."""
|
||||
# pylint: disable=fixme, invalid-name, protected-access, too-few-public-methods
|
||||
import base64
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from pyspark import cloudpickle
|
||||
from pyspark import SparkFiles
|
||||
from pyspark.sql import SparkSession
|
||||
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter
|
||||
from xgboost.core import Booster
|
||||
|
||||
from .utils import get_logger, get_class_name
|
||||
|
||||
|
||||
def _get_or_create_tmp_dir():
|
||||
root_dir = SparkFiles.getRootDirectory()
|
||||
xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp")
|
||||
if not os.path.exists(xgb_tmp_dir):
|
||||
os.makedirs(xgb_tmp_dir)
|
||||
return xgb_tmp_dir
|
||||
|
||||
|
||||
def serialize_xgb_model(model):
|
||||
"""
|
||||
Serialize the input model to a string.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model:
|
||||
an xgboost.XGBModel instance, such as
|
||||
xgboost.XGBClassifier or xgboost.XGBRegressor instance
|
||||
"""
|
||||
# TODO: change to use string io
|
||||
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
|
||||
model.save_model(tmp_file_name)
|
||||
with open(tmp_file_name, "r", encoding="utf-8") as f:
|
||||
ser_model_string = f.read()
|
||||
return ser_model_string
|
||||
|
||||
|
||||
def deserialize_xgb_model(ser_model_string, xgb_model_creator):
|
||||
"""
|
||||
Deserialize an xgboost.XGBModel instance from the input ser_model_string.
|
||||
"""
|
||||
xgb_model = xgb_model_creator()
|
||||
# TODO: change to use string io
|
||||
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
|
||||
with open(tmp_file_name, "w", encoding="utf-8") as f:
|
||||
f.write(ser_model_string)
|
||||
xgb_model.load_model(tmp_file_name)
|
||||
return xgb_model
|
||||
|
||||
|
||||
def serialize_booster(booster):
|
||||
"""
|
||||
Serialize the input booster to a string.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
booster:
|
||||
an xgboost.core.Booster instance
|
||||
"""
|
||||
# TODO: change to use string io
|
||||
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
|
||||
booster.save_model(tmp_file_name)
|
||||
with open(tmp_file_name, encoding="utf-8") as f:
|
||||
ser_model_string = f.read()
|
||||
return ser_model_string
|
||||
|
||||
|
||||
def deserialize_booster(ser_model_string):
|
||||
"""
|
||||
Deserialize an xgboost.core.Booster from the input ser_model_string.
|
||||
"""
|
||||
booster = Booster()
|
||||
# TODO: change to use string io
|
||||
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
|
||||
with open(tmp_file_name, "w", encoding="utf-8") as f:
|
||||
f.write(ser_model_string)
|
||||
booster.load_model(tmp_file_name)
|
||||
return booster
|
||||
|
||||
|
||||
_INIT_BOOSTER_SAVE_PATH = "init_booster.json"
|
||||
|
||||
|
||||
def _get_spark_session():
|
||||
return SparkSession.builder.getOrCreate()
|
||||
|
||||
|
||||
class _SparkXGBSharedReadWrite:
|
||||
@staticmethod
|
||||
def saveMetadata(instance, path, sc, logger, extraMetadata=None):
|
||||
"""
|
||||
Save the metadata of an xgboost.spark._SparkXGBEstimator or
|
||||
xgboost.spark._SparkXGBModel.
|
||||
"""
|
||||
instance._validate_params()
|
||||
skipParams = ["callbacks", "xgb_model"]
|
||||
jsonParams = {}
|
||||
for p, v in instance._paramMap.items(): # pylint: disable=protected-access
|
||||
if p.name not in skipParams:
|
||||
jsonParams[p.name] = v
|
||||
|
||||
extraMetadata = extraMetadata or {}
|
||||
callbacks = instance.getOrDefault(instance.callbacks)
|
||||
if callbacks is not None:
|
||||
logger.warning(
|
||||
"The callbacks parameter is saved using cloudpickle and it "
|
||||
"is not a fully self-contained format. It may fail to load "
|
||||
"with different versions of dependencies."
|
||||
)
|
||||
serialized_callbacks = base64.encodebytes(
|
||||
cloudpickle.dumps(callbacks)
|
||||
).decode("ascii")
|
||||
extraMetadata["serialized_callbacks"] = serialized_callbacks
|
||||
init_booster = instance.getOrDefault(instance.xgb_model)
|
||||
if init_booster is not None:
|
||||
extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
|
||||
DefaultParamsWriter.saveMetadata(
|
||||
instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
|
||||
)
|
||||
if init_booster is not None:
|
||||
ser_init_booster = serialize_booster(init_booster)
|
||||
save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
|
||||
_get_spark_session().createDataFrame(
|
||||
[(ser_init_booster,)], ["init_booster"]
|
||||
).write.parquet(save_path)
|
||||
|
||||
@staticmethod
|
||||
def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
|
||||
"""
|
||||
Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
|
||||
xgboost.spark._SparkXGBModel.
|
||||
|
||||
:return: a tuple of (metadata, instance)
|
||||
"""
|
||||
metadata = DefaultParamsReader.loadMetadata(
|
||||
path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
|
||||
)
|
||||
pyspark_xgb = pyspark_xgb_cls()
|
||||
DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
|
||||
|
||||
if "serialized_callbacks" in metadata:
|
||||
serialized_callbacks = metadata["serialized_callbacks"]
|
||||
try:
|
||||
callbacks = cloudpickle.loads(
|
||||
base64.decodebytes(serialized_callbacks.encode("ascii"))
|
||||
)
|
||||
pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.warning(
|
||||
f"Fails to load the callbacks param due to {e}. Please set the "
|
||||
"callbacks param manually for the loaded estimator."
|
||||
)
|
||||
|
||||
if "init_booster" in metadata:
|
||||
load_path = os.path.join(path, metadata["init_booster"])
|
||||
ser_init_booster = (
|
||||
_get_spark_session().read.parquet(load_path).collect()[0].init_booster
|
||||
)
|
||||
init_booster = deserialize_booster(ser_init_booster)
|
||||
pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
|
||||
|
||||
pyspark_xgb._resetUid(metadata["uid"]) # pylint: disable=protected-access
|
||||
return metadata, pyspark_xgb
|
||||
|
||||
|
||||
class SparkXGBWriter(MLWriter):
|
||||
"""
|
||||
Spark Xgboost estimator writer.
|
||||
"""
|
||||
|
||||
def __init__(self, instance):
|
||||
super().__init__()
|
||||
self.instance = instance
|
||||
self.logger = get_logger(self.__class__.__name__, level="WARN")
|
||||
|
||||
def saveImpl(self, path):
|
||||
"""
|
||||
save model.
|
||||
"""
|
||||
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
|
||||
|
||||
|
||||
class SparkXGBReader(MLReader):
|
||||
"""
|
||||
Spark Xgboost estimator reader.
|
||||
"""
|
||||
|
||||
def __init__(self, cls):
|
||||
super().__init__()
|
||||
self.cls = cls
|
||||
self.logger = get_logger(self.__class__.__name__, level="WARN")
|
||||
|
||||
def load(self, path):
|
||||
"""
|
||||
load model.
|
||||
"""
|
||||
_, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
|
||||
self.cls, path, self.sc, self.logger
|
||||
)
|
||||
return pyspark_xgb
|
||||
|
||||
|
||||
class SparkXGBModelWriter(MLWriter):
|
||||
"""
|
||||
Spark Xgboost model writer.
|
||||
"""
|
||||
|
||||
def __init__(self, instance):
|
||||
super().__init__()
|
||||
self.instance = instance
|
||||
self.logger = get_logger(self.__class__.__name__, level="WARN")
|
||||
|
||||
def saveImpl(self, path):
|
||||
"""
|
||||
Save metadata and model for a :py:class:`_SparkXGBModel`
|
||||
- save metadata to path/metadata
|
||||
- save model to path/model.json
|
||||
"""
|
||||
xgb_model = self.instance._xgb_sklearn_model
|
||||
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
|
||||
model_save_path = os.path.join(path, "model.json")
|
||||
ser_xgb_model = serialize_xgb_model(xgb_model)
|
||||
_get_spark_session().createDataFrame(
|
||||
[(ser_xgb_model,)], ["xgb_sklearn_model"]
|
||||
).write.parquet(model_save_path)
|
||||
|
||||
|
||||
class SparkXGBModelReader(MLReader):
|
||||
"""
|
||||
Spark Xgboost model reader.
|
||||
"""
|
||||
|
||||
def __init__(self, cls):
|
||||
super().__init__()
|
||||
self.cls = cls
|
||||
self.logger = get_logger(self.__class__.__name__, level="WARN")
|
||||
|
||||
def load(self, path):
|
||||
"""
|
||||
Load metadata and model for a :py:class:`_SparkXGBModel`
|
||||
|
||||
:return: SparkXGBRegressorModel or SparkXGBClassifierModel instance
|
||||
"""
|
||||
_, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
|
||||
self.cls, path, self.sc, self.logger
|
||||
)
|
||||
|
||||
xgb_sklearn_params = py_model._gen_xgb_params_dict(gen_xgb_sklearn_estimator_param=True)
|
||||
model_load_path = os.path.join(path, "model.json")
|
||||
|
||||
ser_xgb_model = (
|
||||
_get_spark_session()
|
||||
.read.parquet(model_load_path)
|
||||
.collect()[0]
|
||||
.xgb_sklearn_model
|
||||
)
|
||||
|
||||
def create_xgb_model():
|
||||
return self.cls._xgb_cls()(**xgb_sklearn_params)
|
||||
|
||||
xgb_model = deserialize_xgb_model(
|
||||
ser_xgb_model, create_xgb_model
|
||||
)
|
||||
py_model._xgb_sklearn_model = xgb_model
|
||||
return py_model
|
||||
33
python-package/xgboost/spark/params.py
Normal file
33
python-package/xgboost/spark/params.py
Normal file
@ -0,0 +1,33 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for params."""
|
||||
# pylint: disable=too-few-public-methods
|
||||
from pyspark.ml.param.shared import Param, Params
|
||||
|
||||
|
||||
class HasArbitraryParamsDict(Params):
|
||||
"""
|
||||
This is a Params based class that is extended by _SparkXGBParams
|
||||
and holds the variable to store the **kwargs parts of the XGBoost
|
||||
input.
|
||||
"""
|
||||
|
||||
arbitrary_params_dict = Param(
|
||||
Params._dummy(),
|
||||
"arbitrary_params_dict",
|
||||
"arbitrary_params_dict This parameter holds all of the additional parameters which are "
|
||||
"not exposed as the the XGBoost Spark estimator params but can be recognized by "
|
||||
"underlying XGBoost library. It is stored as a dictionary.",
|
||||
)
|
||||
|
||||
|
||||
class HasBaseMarginCol(Params):
|
||||
"""
|
||||
This is a Params based class that is extended by _SparkXGBParams
|
||||
and holds the variable to store the base margin column part of XGboost.
|
||||
"""
|
||||
|
||||
base_margin_col = Param(
|
||||
Params._dummy(),
|
||||
"base_margin_col",
|
||||
"This stores the name for the column of the base margin",
|
||||
)
|
||||
130
python-package/xgboost/spark/utils.py
Normal file
130
python-package/xgboost/spark/utils.py
Normal file
@ -0,0 +1,130 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for helper functions."""
|
||||
import inspect
|
||||
from threading import Thread
|
||||
import sys
|
||||
import logging
|
||||
|
||||
import pyspark
|
||||
from pyspark.sql.session import SparkSession
|
||||
|
||||
from xgboost import rabit
|
||||
from xgboost.tracker import RabitTracker
|
||||
|
||||
|
||||
def get_class_name(cls):
|
||||
"""
|
||||
Return the class name.
|
||||
"""
|
||||
return f"{cls.__module__}.{cls.__name__}"
|
||||
|
||||
|
||||
def _get_default_params_from_func(func, unsupported_set):
|
||||
"""
|
||||
Returns a dictionary of parameters and their default value of function fn.
|
||||
Only the parameters with a default value will be included.
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
filtered_params_dict = {}
|
||||
for parameter in sig.parameters.values():
|
||||
# Remove parameters without a default value and those in the unsupported_set
|
||||
if (
|
||||
parameter.default is not parameter.empty
|
||||
and parameter.name not in unsupported_set
|
||||
):
|
||||
filtered_params_dict[parameter.name] = parameter.default
|
||||
return filtered_params_dict
|
||||
|
||||
|
||||
class RabitContext:
|
||||
"""
|
||||
A context controlling rabit initialization and finalization.
|
||||
This isn't specificially necessary (note Part 3), but it is more understandable coding-wise.
|
||||
"""
|
||||
|
||||
def __init__(self, args, context):
|
||||
self.args = args
|
||||
self.args.append(("DMLC_TASK_ID=" + str(context.partitionId())).encode())
|
||||
|
||||
def __enter__(self):
|
||||
rabit.init(self.args)
|
||||
|
||||
def __exit__(self, *args):
|
||||
rabit.finalize()
|
||||
|
||||
|
||||
def _start_tracker(context, n_workers):
|
||||
"""
|
||||
Start Rabit tracker with n_workers
|
||||
"""
|
||||
env = {"DMLC_NUM_WORKER": n_workers}
|
||||
host = _get_host_ip(context)
|
||||
rabit_context = RabitTracker(host_ip=host, n_workers=n_workers)
|
||||
env.update(rabit_context.worker_envs())
|
||||
rabit_context.start(n_workers)
|
||||
thread = Thread(target=rabit_context.join)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
return env
|
||||
|
||||
|
||||
def _get_rabit_args(context, n_workers):
|
||||
"""
|
||||
Get rabit context arguments to send to each worker.
|
||||
"""
|
||||
# pylint: disable=consider-using-f-string
|
||||
env = _start_tracker(context, n_workers)
|
||||
rabit_args = [("%s=%s" % item).encode() for item in env.items()]
|
||||
return rabit_args
|
||||
|
||||
|
||||
def _get_host_ip(context):
|
||||
"""
|
||||
Gets the hostIP for Spark. This essentially gets the IP of the first worker.
|
||||
"""
|
||||
task_ip_list = [info.address.split(":")[0] for info in context.getTaskInfos()]
|
||||
return task_ip_list[0]
|
||||
|
||||
|
||||
def _get_args_from_message_list(messages):
|
||||
"""
|
||||
A function to send/recieve messages in barrier context mode
|
||||
"""
|
||||
output = ""
|
||||
for message in messages:
|
||||
if message != "":
|
||||
output = message
|
||||
break
|
||||
return [elem.split("'")[1].encode() for elem in output.strip("][").split(", ")]
|
||||
|
||||
|
||||
def _get_spark_session():
|
||||
"""Get or create spark session. Note: This function can only be invoked from driver side."""
|
||||
if pyspark.TaskContext.get() is not None:
|
||||
# This is a safety check.
|
||||
raise RuntimeError(
|
||||
"_get_spark_session should not be invoked from executor side."
|
||||
)
|
||||
return SparkSession.builder.getOrCreate()
|
||||
|
||||
|
||||
def get_logger(name, level="INFO"):
|
||||
"""Gets a logger by name, or creates and configures it for the first time."""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
# If the logger is configured, skip the configure
|
||||
if not logger.handlers and not logging.getLogger().handlers:
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
logger.addHandler(handler)
|
||||
return logger
|
||||
|
||||
|
||||
def _get_max_num_concurrent_tasks(spark_context):
|
||||
"""Gets the current max number of concurrent tasks."""
|
||||
# pylint: disable=protected-access
|
||||
# spark 3.1 and above has a different API for fetching max concurrent tasks
|
||||
if spark_context._jsc.sc().version() >= "3.1":
|
||||
return spark_context._jsc.sc().maxNumConcurrentTasks(
|
||||
spark_context._jsc.sc().resourceProfileManager().resourceProfileFromId(0)
|
||||
)
|
||||
return spark_context._jsc.sc().maxNumConcurrentTasks()
|
||||
@ -10,7 +10,7 @@ RUN \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:ubuntu-toolchain-r/test && \
|
||||
apt-get update && \
|
||||
apt-get install -y tar unzip wget git build-essential doxygen graphviz llvm libasan2 libidn11 ninja-build gcc-8 g++-8 && \
|
||||
apt-get install -y tar unzip wget git build-essential doxygen graphviz llvm libasan2 libidn11 ninja-build gcc-8 g++-8 openjdk-8-jdk-headless && \
|
||||
# CMake
|
||||
wget -nv -nc https://cmake.org/files/v3.14/cmake-3.14.0-Linux-x86_64.sh --no-check-certificate && \
|
||||
bash cmake-3.14.0-Linux-x86_64.sh --skip-license --prefix=/usr && \
|
||||
@ -24,6 +24,7 @@ ENV CXX=g++-8
|
||||
ENV CPP=cpp-8
|
||||
|
||||
ENV GOSU_VERSION 1.10
|
||||
ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
|
||||
|
||||
# Create new Conda environment
|
||||
COPY conda_env/cpu_test.yml /scripts/
|
||||
|
||||
@ -10,7 +10,7 @@ SHELL ["/bin/bash", "-c"] # Use Bash as shell
|
||||
RUN \
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub && \
|
||||
apt-get update && \
|
||||
apt-get install -y wget unzip bzip2 libgomp1 build-essential && \
|
||||
apt-get install -y wget unzip bzip2 libgomp1 build-essential openjdk-8-jdk-headless && \
|
||||
# Python
|
||||
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
||||
bash Miniconda3.sh -b -p /opt/python
|
||||
@ -19,11 +19,14 @@ ENV PATH=/opt/python/bin:$PATH
|
||||
|
||||
# Create new Conda environment with cuDF, Dask, and cuPy
|
||||
RUN \
|
||||
conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
|
||||
conda install -c conda-forge mamba && \
|
||||
mamba create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
|
||||
python=3.8 cudf=22.04* rmm=22.04* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda=22.04* dask-cudf=22.04* cupy \
|
||||
numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis
|
||||
numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis \
|
||||
pyspark cloudpickle cuda-python=11.7.0
|
||||
|
||||
ENV GOSU_VERSION 1.10
|
||||
ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
|
||||
|
||||
# Install lightweight sudo (not bound to TTY)
|
||||
RUN set -ex; \
|
||||
|
||||
@ -28,6 +28,8 @@ dependencies:
|
||||
- llvmlite
|
||||
- cffi
|
||||
- pyarrow
|
||||
- pyspark
|
||||
- cloudpickle
|
||||
- pip:
|
||||
- shap
|
||||
- awscli
|
||||
|
||||
@ -36,6 +36,8 @@ dependencies:
|
||||
- cffi
|
||||
- pyarrow
|
||||
- protobuf<=3.20
|
||||
- pyspark
|
||||
- cloudpickle
|
||||
- pip:
|
||||
- shap
|
||||
- ipython # required by shap at import time.
|
||||
|
||||
@ -35,6 +35,8 @@ dependencies:
|
||||
- py-ubjson
|
||||
- cffi
|
||||
- pyarrow
|
||||
- pyspark
|
||||
- cloudpickle
|
||||
- pip:
|
||||
- sphinx_rtd_theme
|
||||
- datatable
|
||||
|
||||
@ -34,6 +34,18 @@ function install_xgboost {
|
||||
fi
|
||||
}
|
||||
|
||||
function setup_pyspark_envs {
|
||||
export PYSPARK_DRIVER_PYTHON=`which python`
|
||||
export PYSPARK_PYTHON=`which python`
|
||||
export SPARK_TESTING=1
|
||||
}
|
||||
|
||||
function unset_pyspark_envs {
|
||||
unset PYSPARK_DRIVER_PYTHON
|
||||
unset PYSPARK_PYTHON
|
||||
unset SPARK_TESTING
|
||||
}
|
||||
|
||||
function uninstall_xgboost {
|
||||
pip uninstall -y xgboost
|
||||
}
|
||||
@ -43,14 +55,18 @@ case "$suite" in
|
||||
gpu)
|
||||
source activate gpu_test
|
||||
install_xgboost
|
||||
setup_pyspark_envs
|
||||
pytest -v -s -rxXs --fulltrace --durations=0 -m "not mgpu" ${args} tests/python-gpu
|
||||
unset_pyspark_envs
|
||||
uninstall_xgboost
|
||||
;;
|
||||
|
||||
mgpu)
|
||||
source activate gpu_test
|
||||
install_xgboost
|
||||
setup_pyspark_envs
|
||||
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu
|
||||
unset_pyspark_envs
|
||||
|
||||
cd tests/distributed
|
||||
./runtests-gpu.sh
|
||||
@ -61,7 +77,9 @@ case "$suite" in
|
||||
source activate cpu_test
|
||||
install_xgboost
|
||||
export RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE=1
|
||||
setup_pyspark_envs
|
||||
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python
|
||||
unset_pyspark_envs
|
||||
cd tests/distributed
|
||||
./runtests.sh
|
||||
uninstall_xgboost
|
||||
@ -70,7 +88,9 @@ case "$suite" in
|
||||
cpu-arm64)
|
||||
source activate aarch64_test
|
||||
install_xgboost
|
||||
setup_pyspark_envs
|
||||
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python/test_basic.py tests/python/test_basic_models.py tests/python/test_model_compatibility.py
|
||||
unset_pyspark_envs
|
||||
uninstall_xgboost
|
||||
;;
|
||||
|
||||
|
||||
@ -44,13 +44,15 @@ def pytest_addoption(parser):
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
if config.getoption('--use-rmm-pool'):
|
||||
if config.getoption("--use-rmm-pool"):
|
||||
blocklist = [
|
||||
'python-gpu/test_gpu_demos.py::test_dask_training',
|
||||
'python-gpu/test_gpu_prediction.py::TestGPUPredict::test_shap',
|
||||
'python-gpu/test_gpu_linear.py::TestGPULinear'
|
||||
"python-gpu/test_gpu_demos.py::test_dask_training",
|
||||
"python-gpu/test_gpu_prediction.py::TestGPUPredict::test_shap",
|
||||
"python-gpu/test_gpu_linear.py::TestGPULinear",
|
||||
]
|
||||
skip_mark = pytest.mark.skip(reason='This test is not run when --use-rmm-pool flag is active')
|
||||
skip_mark = pytest.mark.skip(
|
||||
reason="This test is not run when --use-rmm-pool flag is active"
|
||||
)
|
||||
for item in items:
|
||||
if any(item.nodeid.startswith(x) for x in blocklist):
|
||||
item.add_marker(skip_mark)
|
||||
@ -58,5 +60,9 @@ def pytest_collection_modifyitems(config, items):
|
||||
# mark dask tests as `mgpu`.
|
||||
mgpu_mark = pytest.mark.mgpu
|
||||
for item in items:
|
||||
if item.nodeid.startswith("python-gpu/test_gpu_with_dask.py"):
|
||||
if item.nodeid.startswith(
|
||||
"python-gpu/test_gpu_with_dask.py"
|
||||
) or item.nodeid.startswith(
|
||||
"python-gpu/test_spark_with_gpu/test_spark_with_gpu.py"
|
||||
):
|
||||
item.add_marker(mgpu_mark)
|
||||
|
||||
3
tests/python-gpu/test_spark_with_gpu/discover_gpu.sh
Executable file
3
tests/python-gpu/test_spark_with_gpu/discover_gpu.sh
Executable file
@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}"
|
||||
120
tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py
Normal file
120
tests/python-gpu/test_spark_with_gpu/test_spark_with_gpu.py
Normal file
@ -0,0 +1,120 @@
|
||||
import sys
|
||||
|
||||
import logging
|
||||
import pytest
|
||||
import sklearn
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
|
||||
if tm.no_dask()["condition"]:
|
||||
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
|
||||
|
||||
|
||||
from pyspark.sql import SparkSession
|
||||
from pyspark.ml.linalg import Vectors
|
||||
from xgboost.spark import SparkXGBRegressor, SparkXGBClassifier
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def spark_session_with_gpu():
|
||||
spark_config = {
|
||||
"spark.master": "local-cluster[1, 4, 1024]",
|
||||
"spark.python.worker.reuse": "false",
|
||||
"spark.driver.host": "127.0.0.1",
|
||||
"spark.task.maxFailures": "1",
|
||||
"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
|
||||
"spark.sql.pyspark.jvmStacktrace.enabled": "true",
|
||||
"spark.cores.max": "4",
|
||||
"spark.task.cpus": "1",
|
||||
"spark.executor.cores": "4",
|
||||
"spark.worker.resource.gpu.amount": "4",
|
||||
"spark.task.resource.gpu.amount": "1",
|
||||
"spark.executor.resource.gpu.amount": "4",
|
||||
"spark.worker.resource.gpu.discoveryScript": "tests/python-gpu/test_spark_with_gpu/discover_gpu.sh",
|
||||
}
|
||||
builder = SparkSession.builder.appName("xgboost spark python API Tests with GPU")
|
||||
for k, v in spark_config.items():
|
||||
builder.config(k, v)
|
||||
spark = builder.getOrCreate()
|
||||
logging.getLogger("pyspark").setLevel(logging.INFO)
|
||||
# We run a dummy job so that we block until the workers have connected to the master
|
||||
spark.sparkContext.parallelize(range(4), 4).barrier().mapPartitions(
|
||||
lambda _: []
|
||||
).collect()
|
||||
yield spark
|
||||
spark.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def spark_iris_dataset(spark_session_with_gpu):
|
||||
spark = spark_session_with_gpu
|
||||
data = sklearn.datasets.load_iris()
|
||||
train_rows = [
|
||||
(Vectors.dense(features), float(label))
|
||||
for features, label in zip(data.data[0::2], data.target[0::2])
|
||||
]
|
||||
train_df = spark.createDataFrame(
|
||||
spark.sparkContext.parallelize(train_rows, 4), ["features", "label"]
|
||||
)
|
||||
test_rows = [
|
||||
(Vectors.dense(features), float(label))
|
||||
for features, label in zip(data.data[1::2], data.target[1::2])
|
||||
]
|
||||
test_df = spark.createDataFrame(
|
||||
spark.sparkContext.parallelize(test_rows, 4), ["features", "label"]
|
||||
)
|
||||
return train_df, test_df
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def spark_diabetes_dataset(spark_session_with_gpu):
|
||||
spark = spark_session_with_gpu
|
||||
data = sklearn.datasets.load_diabetes()
|
||||
train_rows = [
|
||||
(Vectors.dense(features), float(label))
|
||||
for features, label in zip(data.data[0::2], data.target[0::2])
|
||||
]
|
||||
train_df = spark.createDataFrame(
|
||||
spark.sparkContext.parallelize(train_rows, 4), ["features", "label"]
|
||||
)
|
||||
test_rows = [
|
||||
(Vectors.dense(features), float(label))
|
||||
for features, label in zip(data.data[1::2], data.target[1::2])
|
||||
]
|
||||
test_df = spark.createDataFrame(
|
||||
spark.sparkContext.parallelize(test_rows, 4), ["features", "label"]
|
||||
)
|
||||
return train_df, test_df
|
||||
|
||||
|
||||
def test_sparkxgb_classifier_with_gpu(spark_iris_dataset):
|
||||
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
|
||||
|
||||
classifier = SparkXGBClassifier(
|
||||
use_gpu=True,
|
||||
num_workers=4,
|
||||
)
|
||||
train_df, test_df = spark_iris_dataset
|
||||
model = classifier.fit(train_df)
|
||||
pred_result_df = model.transform(test_df)
|
||||
evaluator = MulticlassClassificationEvaluator(metricName="f1")
|
||||
f1 = evaluator.evaluate(pred_result_df)
|
||||
assert f1 >= 0.97
|
||||
|
||||
|
||||
def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset):
|
||||
from pyspark.ml.evaluation import RegressionEvaluator
|
||||
|
||||
regressor = SparkXGBRegressor(
|
||||
use_gpu=True,
|
||||
num_workers=4,
|
||||
)
|
||||
train_df, test_df = spark_diabetes_dataset
|
||||
model = regressor.fit(train_df)
|
||||
pred_result_df = model.transform(test_df)
|
||||
evaluator = RegressionEvaluator(metricName="rmse")
|
||||
rmse = evaluator.evaluate(pred_result_df)
|
||||
assert rmse <= 65.0
|
||||
0
tests/python/test_spark/__init__.py
Normal file
0
tests/python/test_spark/__init__.py
Normal file
168
tests/python/test_spark/test_data.py
Normal file
168
tests/python/test_spark/test_data.py
Normal file
@ -0,0 +1,168 @@
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import testing as tm
|
||||
|
||||
if tm.no_spark()["condition"]:
|
||||
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
|
||||
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
|
||||
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
|
||||
|
||||
from xgboost.spark.data import (
|
||||
_row_tuple_list_to_feature_matrix_y_w,
|
||||
_convert_partition_data_to_dmatrix,
|
||||
)
|
||||
|
||||
from xgboost import DMatrix, XGBClassifier
|
||||
from xgboost.training import train as worker_train
|
||||
from .utils import SparkTestCase
|
||||
import logging
|
||||
|
||||
logging.getLogger("py4j").setLevel(logging.INFO)
|
||||
|
||||
|
||||
class DataTest(SparkTestCase):
|
||||
def test_sparse_dense_vector(self):
|
||||
def row_tup_iter(data):
|
||||
pdf = pd.DataFrame(data)
|
||||
yield pdf
|
||||
|
||||
expected_ndarray = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
|
||||
data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]}
|
||||
feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w(
|
||||
list(row_tup_iter(data)),
|
||||
train=False,
|
||||
has_weight=False,
|
||||
has_fit_base_margin=False,
|
||||
has_predict_base_margin=False,
|
||||
)
|
||||
self.assertIsNone(y)
|
||||
self.assertIsNone(w)
|
||||
self.assertTrue(np.allclose(feature_matrix, expected_ndarray))
|
||||
|
||||
data["label"] = [1, 0]
|
||||
feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w(
|
||||
row_tup_iter(data),
|
||||
train=True,
|
||||
has_weight=False,
|
||||
has_fit_base_margin=False,
|
||||
has_predict_base_margin=False,
|
||||
)
|
||||
self.assertIsNone(w)
|
||||
self.assertTrue(np.allclose(feature_matrix, expected_ndarray))
|
||||
self.assertTrue(np.array_equal(y, np.array(data["label"])))
|
||||
|
||||
data["weight"] = [0.2, 0.8]
|
||||
feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w(
|
||||
list(row_tup_iter(data)),
|
||||
train=True,
|
||||
has_weight=True,
|
||||
has_fit_base_margin=False,
|
||||
has_predict_base_margin=False,
|
||||
)
|
||||
self.assertTrue(np.allclose(feature_matrix, expected_ndarray))
|
||||
self.assertTrue(np.array_equal(y, np.array(data["label"])))
|
||||
self.assertTrue(np.array_equal(w, np.array(data["weight"])))
|
||||
|
||||
def test_dmatrix_creator(self):
|
||||
|
||||
# This function acts as a pseudo-itertools.chain()
|
||||
def row_tup_iter(data):
|
||||
pdf = pd.DataFrame(data)
|
||||
yield pdf
|
||||
|
||||
# Standard testing DMatrix creation
|
||||
expected_features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100)
|
||||
expected_labels = np.array([1, 0] * 100)
|
||||
expected_dmatrix = DMatrix(data=expected_features, label=expected_labels)
|
||||
|
||||
data = {
|
||||
"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100,
|
||||
"label": [1, 0] * 100,
|
||||
}
|
||||
output_dmatrix = _convert_partition_data_to_dmatrix(
|
||||
[pd.DataFrame(data)],
|
||||
has_weight=False,
|
||||
has_validation=False,
|
||||
has_base_margin=False,
|
||||
)
|
||||
# You can't compare DMatrix outputs, so the only way is to predict on the two seperate DMatrices using
|
||||
# the same classifier and making sure the outputs are equal
|
||||
model = XGBClassifier()
|
||||
model.fit(expected_features, expected_labels)
|
||||
expected_preds = model.get_booster().predict(expected_dmatrix)
|
||||
output_preds = model.get_booster().predict(output_dmatrix)
|
||||
self.assertTrue(np.allclose(expected_preds, output_preds, atol=1e-3))
|
||||
|
||||
# DMatrix creation with weights
|
||||
expected_weight = np.array([0.2, 0.8] * 100)
|
||||
expected_dmatrix = DMatrix(
|
||||
data=expected_features, label=expected_labels, weight=expected_weight
|
||||
)
|
||||
|
||||
data["weight"] = [0.2, 0.8] * 100
|
||||
output_dmatrix = _convert_partition_data_to_dmatrix(
|
||||
[pd.DataFrame(data)],
|
||||
has_weight=True,
|
||||
has_validation=False,
|
||||
has_base_margin=False,
|
||||
)
|
||||
|
||||
model.fit(expected_features, expected_labels, sample_weight=expected_weight)
|
||||
expected_preds = model.get_booster().predict(expected_dmatrix)
|
||||
output_preds = model.get_booster().predict(output_dmatrix)
|
||||
self.assertTrue(np.allclose(expected_preds, output_preds, atol=1e-3))
|
||||
|
||||
def test_external_storage(self):
|
||||
# Instantiating base data (features, labels)
|
||||
features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100)
|
||||
labels = np.array([1, 0] * 100)
|
||||
normal_dmatrix = DMatrix(features, labels)
|
||||
test_dmatrix = DMatrix(features)
|
||||
|
||||
data = {
|
||||
"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100,
|
||||
"label": [1, 0] * 100,
|
||||
}
|
||||
|
||||
# Creating the dmatrix based on storage
|
||||
temporary_path = tempfile.mkdtemp()
|
||||
storage_dmatrix = _convert_partition_data_to_dmatrix(
|
||||
[pd.DataFrame(data)],
|
||||
has_weight=False,
|
||||
has_validation=False,
|
||||
has_base_margin=False,
|
||||
)
|
||||
|
||||
# Testing without weights
|
||||
normal_booster = worker_train({}, normal_dmatrix)
|
||||
storage_booster = worker_train({}, storage_dmatrix)
|
||||
normal_preds = normal_booster.predict(test_dmatrix)
|
||||
storage_preds = storage_booster.predict(test_dmatrix)
|
||||
self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3))
|
||||
shutil.rmtree(temporary_path)
|
||||
|
||||
# Testing weights
|
||||
weights = np.array([0.2, 0.8] * 100)
|
||||
normal_dmatrix = DMatrix(data=features, label=labels, weight=weights)
|
||||
data["weight"] = [0.2, 0.8] * 100
|
||||
|
||||
temporary_path = tempfile.mkdtemp()
|
||||
storage_dmatrix = _convert_partition_data_to_dmatrix(
|
||||
[pd.DataFrame(data)],
|
||||
has_weight=True,
|
||||
has_validation=False,
|
||||
has_base_margin=False,
|
||||
)
|
||||
|
||||
normal_booster = worker_train({}, normal_dmatrix)
|
||||
storage_booster = worker_train({}, storage_dmatrix)
|
||||
normal_preds = normal_booster.predict(test_dmatrix)
|
||||
storage_preds = storage_booster.predict(test_dmatrix)
|
||||
self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3))
|
||||
shutil.rmtree(temporary_path)
|
||||
971
tests/python/test_spark/test_spark_local.py
Normal file
971
tests/python/test_spark/test_spark_local.py
Normal file
@ -0,0 +1,971 @@
|
||||
import sys
|
||||
import logging
|
||||
import random
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import testing as tm
|
||||
|
||||
if tm.no_spark()["condition"]:
|
||||
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
|
||||
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
|
||||
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
|
||||
|
||||
from pyspark.ml.functions import vector_to_array
|
||||
from pyspark.sql import functions as spark_sql_func
|
||||
from pyspark.ml import Pipeline, PipelineModel
|
||||
from pyspark.ml.evaluation import (
|
||||
BinaryClassificationEvaluator,
|
||||
MulticlassClassificationEvaluator,
|
||||
)
|
||||
from pyspark.ml.linalg import Vectors
|
||||
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
|
||||
|
||||
from xgboost.spark import (
|
||||
SparkXGBClassifier,
|
||||
SparkXGBClassifierModel,
|
||||
SparkXGBRegressor,
|
||||
SparkXGBRegressorModel,
|
||||
)
|
||||
from .utils import SparkTestCase
|
||||
from xgboost import XGBClassifier, XGBRegressor
|
||||
from xgboost.spark.core import _non_booster_params
|
||||
|
||||
logging.getLogger("py4j").setLevel(logging.INFO)
|
||||
|
||||
|
||||
class XgboostLocalTest(SparkTestCase):
|
||||
def setUp(self):
|
||||
logging.getLogger().setLevel("INFO")
|
||||
random.seed(2020)
|
||||
|
||||
# The following code use xgboost python library to train xgb model and predict.
|
||||
#
|
||||
# >>> import numpy as np
|
||||
# >>> import xgboost
|
||||
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
|
||||
# >>> y = np.array([0, 1])
|
||||
# >>> reg1 = xgboost.XGBRegressor()
|
||||
# >>> reg1.fit(X, y)
|
||||
# >>> reg1.predict(X)
|
||||
# array([8.8375784e-04, 9.9911624e-01], dtype=float32)
|
||||
# >>> def custom_lr(boosting_round):
|
||||
# ... return 1.0 / (boosting_round + 1)
|
||||
# ...
|
||||
# >>> reg1.fit(X, y, callbacks=[xgboost.callback.LearningRateScheduler(custom_lr)])
|
||||
# >>> reg1.predict(X)
|
||||
# array([0.02406844, 0.9759315 ], dtype=float32)
|
||||
# >>> reg2 = xgboost.XGBRegressor(max_depth=5, n_estimators=10)
|
||||
# >>> reg2.fit(X, y)
|
||||
# >>> reg2.predict(X, ntree_limit=5)
|
||||
# array([0.22185266, 0.77814734], dtype=float32)
|
||||
self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5}
|
||||
self.reg_df_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
|
||||
],
|
||||
["features", "label"],
|
||||
)
|
||||
self.reg_df_test = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759),
|
||||
],
|
||||
[
|
||||
"features",
|
||||
"expected_prediction",
|
||||
"expected_prediction_with_params",
|
||||
"expected_prediction_with_callbacks",
|
||||
],
|
||||
)
|
||||
|
||||
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
|
||||
# >>> y = np.array([0, 1])
|
||||
# >>> cl1 = xgboost.XGBClassifier()
|
||||
# >>> cl1.fit(X, y)
|
||||
# >>> cl1.predict(X)
|
||||
# array([0, 0])
|
||||
# >>> cl1.predict_proba(X)
|
||||
# array([[0.5, 0.5],
|
||||
# [0.5, 0.5]], dtype=float32)
|
||||
# >>> cl2 = xgboost.XGBClassifier(max_depth=5, n_estimators=10, scale_pos_weight=4)
|
||||
# >>> cl2.fit(X, y)
|
||||
# >>> cl2.predict(X)
|
||||
# array([1, 1])
|
||||
# >>> cl2.predict_proba(X)
|
||||
# array([[0.27574146, 0.72425854 ],
|
||||
# [0.27574146, 0.72425854 ]], dtype=float32)
|
||||
self.cls_params = {"max_depth": 5, "n_estimators": 10, "scale_pos_weight": 4}
|
||||
|
||||
cls_df_train_data = [
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
|
||||
]
|
||||
self.cls_df_train = self.session.createDataFrame(
|
||||
cls_df_train_data, ["features", "label"]
|
||||
)
|
||||
self.cls_df_train_large = self.session.createDataFrame(
|
||||
cls_df_train_data * 100, ["features", "label"]
|
||||
)
|
||||
self.cls_df_test = self.session.createDataFrame(
|
||||
[
|
||||
(
|
||||
Vectors.dense(1.0, 2.0, 3.0),
|
||||
0,
|
||||
[0.5, 0.5],
|
||||
1,
|
||||
[0.27574146, 0.72425854],
|
||||
),
|
||||
(
|
||||
Vectors.sparse(3, {1: 1.0, 2: 5.5}),
|
||||
0,
|
||||
[0.5, 0.5],
|
||||
1,
|
||||
[0.27574146, 0.72425854],
|
||||
),
|
||||
],
|
||||
[
|
||||
"features",
|
||||
"expected_prediction",
|
||||
"expected_probability",
|
||||
"expected_prediction_with_params",
|
||||
"expected_probability_with_params",
|
||||
],
|
||||
)
|
||||
|
||||
# kwargs test (using the above data, train, we get the same results)
|
||||
self.cls_params_kwargs = {"tree_method": "approx", "sketch_eps": 0.03}
|
||||
|
||||
# >>> X = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0], [0.0, 1.0, 5.5], [-1.0, -2.0, 1.0]])
|
||||
# >>> y = np.array([0, 0, 1, 2])
|
||||
# >>> cl = xgboost.XGBClassifier()
|
||||
# >>> cl.fit(X, y)
|
||||
# >>> cl.predict_proba(np.array([[1.0, 2.0, 3.0]]))
|
||||
# array([[0.5374299 , 0.23128504, 0.23128504]], dtype=float32)
|
||||
multi_cls_df_train_data = [
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0),
|
||||
(Vectors.dense(1.0, 2.0, 4.0), 0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
|
||||
(Vectors.dense(-1.0, -2.0, 1.0), 2),
|
||||
]
|
||||
self.multi_cls_df_train = self.session.createDataFrame(
|
||||
multi_cls_df_train_data, ["features", "label"]
|
||||
)
|
||||
self.multi_cls_df_train_large = self.session.createDataFrame(
|
||||
multi_cls_df_train_data * 100, ["features", "label"]
|
||||
)
|
||||
self.multi_cls_df_test = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), [0.5374, 0.2312, 0.2312]),
|
||||
],
|
||||
["features", "expected_probability"],
|
||||
)
|
||||
|
||||
# Test regressor with weight and eval set
|
||||
# >>> import numpy as np
|
||||
# >>> import xgboost
|
||||
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
|
||||
# >>> w = np.array([1.0, 2.0, 1.0, 2.0])
|
||||
# >>> y = np.array([0, 1, 2, 3])
|
||||
# >>> reg1 = xgboost.XGBRegressor()
|
||||
# >>> reg1.fit(X, y, sample_weight=w)
|
||||
# >>> reg1.predict(X)
|
||||
# >>> array([1.0679445e-03, 1.0000550e+00, ...
|
||||
# >>> X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
|
||||
# >>> X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
|
||||
# >>> y_train = np.array([0, 1])
|
||||
# >>> y_val = np.array([2, 3])
|
||||
# >>> w_train = np.array([1.0, 2.0])
|
||||
# >>> w_val = np.array([1.0, 2.0])
|
||||
# >>> reg2 = xgboost.XGBRegressor()
|
||||
# >>> reg2.fit(X_train, y_train, eval_set=[(X_val, y_val)],
|
||||
# >>> early_stopping_rounds=1, eval_metric='rmse')
|
||||
# >>> reg2.predict(X)
|
||||
# >>> array([8.8370638e-04, 9.9911624e-01, ...
|
||||
# >>> reg2.best_score
|
||||
# 2.0000002682208837
|
||||
# >>> reg3 = xgboost.XGBRegressor()
|
||||
# >>> reg3.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)],
|
||||
# >>> sample_weight_eval_set=[w_val],
|
||||
# >>> early_stopping_rounds=1, eval_metric='rmse')
|
||||
# >>> reg3.predict(X)
|
||||
# >>> array([0.03155671, 0.98874104,...
|
||||
# >>> reg3.best_score
|
||||
# 1.9970891552124017
|
||||
self.reg_df_train_with_eval_weight = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0),
|
||||
],
|
||||
["features", "label", "isVal", "weight"],
|
||||
)
|
||||
self.reg_params_with_eval = {
|
||||
"validation_indicator_col": "isVal",
|
||||
"early_stopping_rounds": 1,
|
||||
"eval_metric": "rmse",
|
||||
}
|
||||
self.reg_df_test_with_eval_weight = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0.001068, 0.00088, 0.03155),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.000055, 0.9991, 0.9887),
|
||||
],
|
||||
[
|
||||
"features",
|
||||
"expected_prediction_with_weight",
|
||||
"expected_prediction_with_eval",
|
||||
"expected_prediction_with_weight_and_eval",
|
||||
],
|
||||
)
|
||||
self.reg_with_eval_best_score = 2.0
|
||||
self.reg_with_eval_and_weight_best_score = 1.997
|
||||
|
||||
# Test classifier with weight and eval set
|
||||
# >>> import numpy as np
|
||||
# >>> import xgboost
|
||||
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
|
||||
# >>> w = np.array([1.0, 2.0, 1.0, 2.0])
|
||||
# >>> y = np.array([0, 1, 0, 1])
|
||||
# >>> cls1 = xgboost.XGBClassifier()
|
||||
# >>> cls1.fit(X, y, sample_weight=w)
|
||||
# >>> cls1.predict_proba(X)
|
||||
# array([[0.3333333, 0.6666667],...
|
||||
# >>> X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
|
||||
# >>> X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
|
||||
# >>> y_train = np.array([0, 1])
|
||||
# >>> y_val = np.array([0, 1])
|
||||
# >>> w_train = np.array([1.0, 2.0])
|
||||
# >>> w_val = np.array([1.0, 2.0])
|
||||
# >>> cls2 = xgboost.XGBClassifier()
|
||||
# >>> cls2.fit(X_train, y_train, eval_set=[(X_val, y_val)],
|
||||
# >>> early_stopping_rounds=1, eval_metric='logloss')
|
||||
# >>> cls2.predict_proba(X)
|
||||
# array([[0.5, 0.5],...
|
||||
# >>> cls2.best_score
|
||||
# 0.6931
|
||||
# >>> cls3 = xgboost.XGBClassifier()
|
||||
# >>> cls3.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)],
|
||||
# >>> sample_weight_eval_set=[w_val],
|
||||
# >>> early_stopping_rounds=1, eval_metric='logloss')
|
||||
# >>> cls3.predict_proba(X)
|
||||
# array([[0.3344962, 0.6655038],...
|
||||
# >>> cls3.best_score
|
||||
# 0.6365
|
||||
self.cls_df_train_with_eval_weight = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
|
||||
],
|
||||
["features", "label", "isVal", "weight"],
|
||||
)
|
||||
self.cls_params_with_eval = {
|
||||
"validation_indicator_col": "isVal",
|
||||
"early_stopping_rounds": 1,
|
||||
"eval_metric": "logloss",
|
||||
}
|
||||
self.cls_df_test_with_eval_weight = self.session.createDataFrame(
|
||||
[
|
||||
(
|
||||
Vectors.dense(1.0, 2.0, 3.0),
|
||||
[0.3333, 0.6666],
|
||||
[0.5, 0.5],
|
||||
[0.3097, 0.6903],
|
||||
),
|
||||
],
|
||||
[
|
||||
"features",
|
||||
"expected_prob_with_weight",
|
||||
"expected_prob_with_eval",
|
||||
"expected_prob_with_weight_and_eval",
|
||||
],
|
||||
)
|
||||
self.cls_with_eval_best_score = 0.6931
|
||||
self.cls_with_eval_and_weight_best_score = 0.6378
|
||||
|
||||
# Test classifier with both base margin and without
|
||||
# >>> import numpy as np
|
||||
# >>> import xgboost
|
||||
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
|
||||
# >>> w = np.array([1.0, 2.0, 1.0, 2.0])
|
||||
# >>> y = np.array([0, 1, 0, 1])
|
||||
# >>> base_margin = np.array([1,0,0,1])
|
||||
#
|
||||
# This is without the base margin
|
||||
# >>> cls1 = xgboost.XGBClassifier()
|
||||
# >>> cls1.fit(X, y, sample_weight=w)
|
||||
# >>> cls1.predict_proba(np.array([[1.0, 2.0, 3.0]]))
|
||||
# array([[0.3333333, 0.6666667]], dtype=float32)
|
||||
# >>> cls1.predict(np.array([[1.0, 2.0, 3.0]]))
|
||||
# array([1])
|
||||
#
|
||||
# This is with the same base margin for predict
|
||||
# >>> cls2 = xgboost.XGBClassifier()
|
||||
# >>> cls2.fit(X, y, sample_weight=w, base_margin=base_margin)
|
||||
# >>> cls2.predict_proba(np.array([[1.0, 2.0, 3.0]]), base_margin=[0])
|
||||
# array([[0.44142532, 0.5585747 ]], dtype=float32)
|
||||
# >>> cls2.predict(np.array([[1.0, 2.0, 3.0]]), base_margin=[0])
|
||||
# array([1])
|
||||
#
|
||||
# This is with a different base margin for predict
|
||||
# # >>> cls2 = xgboost.XGBClassifier()
|
||||
# >>> cls2.fit(X, y, sample_weight=w, base_margin=base_margin)
|
||||
# >>> cls2.predict_proba(np.array([[1.0, 2.0, 3.0]]), base_margin=[1])
|
||||
# array([[0.2252, 0.7747 ]], dtype=float32)
|
||||
# >>> cls2.predict(np.array([[1.0, 2.0, 3.0]]), base_margin=[0])
|
||||
# array([1])
|
||||
self.cls_df_train_without_base_margin = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, 1.0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0, 1.0),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0),
|
||||
],
|
||||
["features", "label", "weight"],
|
||||
)
|
||||
self.cls_df_test_without_base_margin = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), [0.3333, 0.6666], 1),
|
||||
],
|
||||
[
|
||||
"features",
|
||||
"expected_prob_without_base_margin",
|
||||
"expected_prediction_without_base_margin",
|
||||
],
|
||||
)
|
||||
|
||||
self.cls_df_train_with_same_base_margin = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, 1.0, 1),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0, 0),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1),
|
||||
],
|
||||
["features", "label", "weight", "base_margin"],
|
||||
)
|
||||
self.cls_df_test_with_same_base_margin = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, [0.4415, 0.5585], 1),
|
||||
],
|
||||
[
|
||||
"features",
|
||||
"base_margin",
|
||||
"expected_prob_with_base_margin",
|
||||
"expected_prediction_with_base_margin",
|
||||
],
|
||||
)
|
||||
|
||||
self.cls_df_train_with_different_base_margin = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, 1.0, 1),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0, 0),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1),
|
||||
],
|
||||
["features", "label", "weight", "base_margin"],
|
||||
)
|
||||
self.cls_df_test_with_different_base_margin = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 1, [0.2252, 0.7747], 1),
|
||||
],
|
||||
[
|
||||
"features",
|
||||
"base_margin",
|
||||
"expected_prob_with_base_margin",
|
||||
"expected_prediction_with_base_margin",
|
||||
],
|
||||
)
|
||||
|
||||
def get_local_tmp_dir(self):
|
||||
return self.tempdir + str(uuid.uuid4())
|
||||
|
||||
def test_regressor_params_basic(self):
|
||||
py_reg = SparkXGBRegressor()
|
||||
self.assertTrue(hasattr(py_reg, "n_estimators"))
|
||||
self.assertEqual(py_reg.n_estimators.parent, py_reg.uid)
|
||||
self.assertFalse(hasattr(py_reg, "gpu_id"))
|
||||
self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100)
|
||||
py_reg2 = SparkXGBRegressor(n_estimators=200)
|
||||
self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200)
|
||||
py_reg3 = py_reg2.copy({py_reg2.max_depth: 10})
|
||||
self.assertEqual(py_reg3.getOrDefault(py_reg3.n_estimators), 200)
|
||||
self.assertEqual(py_reg3.getOrDefault(py_reg3.max_depth), 10)
|
||||
|
||||
def test_classifier_params_basic(self):
|
||||
py_cls = SparkXGBClassifier()
|
||||
self.assertTrue(hasattr(py_cls, "n_estimators"))
|
||||
self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
|
||||
self.assertFalse(hasattr(py_cls, "gpu_id"))
|
||||
self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100)
|
||||
py_cls2 = SparkXGBClassifier(n_estimators=200)
|
||||
self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200)
|
||||
py_cls3 = py_cls2.copy({py_cls2.max_depth: 10})
|
||||
self.assertEqual(py_cls3.getOrDefault(py_cls3.n_estimators), 200)
|
||||
self.assertEqual(py_cls3.getOrDefault(py_cls3.max_depth), 10)
|
||||
|
||||
def test_classifier_kwargs_basic(self):
|
||||
py_cls = SparkXGBClassifier(**self.cls_params_kwargs)
|
||||
self.assertTrue(hasattr(py_cls, "n_estimators"))
|
||||
self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
|
||||
self.assertFalse(hasattr(py_cls, "gpu_id"))
|
||||
self.assertTrue(hasattr(py_cls, "arbitrary_params_dict"))
|
||||
expected_kwargs = {"sketch_eps": 0.03}
|
||||
self.assertEqual(
|
||||
py_cls.getOrDefault(py_cls.arbitrary_params_dict), expected_kwargs
|
||||
)
|
||||
|
||||
# Testing overwritten params
|
||||
py_cls = SparkXGBClassifier()
|
||||
py_cls.setParams(x=1, y=2)
|
||||
py_cls.setParams(y=3, z=4)
|
||||
xgb_params = py_cls._gen_xgb_params_dict()
|
||||
assert xgb_params["x"] == 1
|
||||
assert xgb_params["y"] == 3
|
||||
assert xgb_params["z"] == 4
|
||||
|
||||
def test_param_alias(self):
|
||||
py_cls = SparkXGBClassifier(features_col="f1", label_col="l1")
|
||||
self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1")
|
||||
self.assertEqual(py_cls.getOrDefault(py_cls.labelCol), "l1")
|
||||
with pytest.raises(
|
||||
ValueError, match="Please use param name features_col instead"
|
||||
):
|
||||
SparkXGBClassifier(featuresCol="f1")
|
||||
|
||||
def test_gpu_param_setting(self):
|
||||
py_cls = SparkXGBClassifier(use_gpu=True)
|
||||
train_params = py_cls._get_distributed_train_params(self.cls_df_train)
|
||||
assert train_params["tree_method"] == "gpu_hist"
|
||||
|
||||
@staticmethod
|
||||
def test_param_value_converter():
|
||||
py_cls = SparkXGBClassifier(missing=np.float64(1.0), sketch_eps=np.float64(0.3))
|
||||
# don't check by isintance(v, float) because for numpy scalar it will also return True
|
||||
assert py_cls.getOrDefault(py_cls.missing).__class__.__name__ == "float"
|
||||
assert (
|
||||
py_cls.getOrDefault(py_cls.arbitrary_params_dict)[
|
||||
"sketch_eps"
|
||||
].__class__.__name__
|
||||
== "float64"
|
||||
)
|
||||
|
||||
def test_regressor_basic(self):
|
||||
regressor = SparkXGBRegressor()
|
||||
model = regressor.fit(self.reg_df_train)
|
||||
pred_result = model.transform(self.reg_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(row.prediction, row.expected_prediction, atol=1e-3)
|
||||
)
|
||||
|
||||
def test_classifier_basic(self):
|
||||
classifier = SparkXGBClassifier()
|
||||
model = classifier.fit(self.cls_df_train)
|
||||
pred_result = model.transform(self.cls_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertEqual(row.prediction, row.expected_prediction)
|
||||
self.assertTrue(
|
||||
np.allclose(row.probability, row.expected_probability, rtol=1e-3)
|
||||
)
|
||||
|
||||
def test_multi_classifier(self):
|
||||
classifier = SparkXGBClassifier()
|
||||
model = classifier.fit(self.multi_cls_df_train)
|
||||
pred_result = model.transform(self.multi_cls_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.allclose(row.probability, row.expected_probability, rtol=1e-3)
|
||||
)
|
||||
|
||||
def _check_sub_dict_match(self, sub_dist, whole_dict, excluding_keys):
|
||||
for k in sub_dist:
|
||||
if k not in excluding_keys:
|
||||
self.assertTrue(k in whole_dict, f"check on {k} failed")
|
||||
self.assertEqual(sub_dist[k], whole_dict[k], f"check on {k} failed")
|
||||
|
||||
def test_regressor_with_params(self):
|
||||
regressor = SparkXGBRegressor(**self.reg_params)
|
||||
all_params = dict(
|
||||
**(regressor._gen_xgb_params_dict()),
|
||||
**(regressor._gen_fit_params_dict()),
|
||||
**(regressor._gen_predict_params_dict()),
|
||||
)
|
||||
self._check_sub_dict_match(
|
||||
self.reg_params, all_params, excluding_keys=_non_booster_params
|
||||
)
|
||||
|
||||
model = regressor.fit(self.reg_df_train)
|
||||
all_params = dict(
|
||||
**(model._gen_xgb_params_dict()),
|
||||
**(model._gen_fit_params_dict()),
|
||||
**(model._gen_predict_params_dict()),
|
||||
)
|
||||
self._check_sub_dict_match(
|
||||
self.reg_params, all_params, excluding_keys=_non_booster_params
|
||||
)
|
||||
pred_result = model.transform(self.reg_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_params, atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
def test_classifier_with_params(self):
|
||||
classifier = SparkXGBClassifier(**self.cls_params)
|
||||
all_params = dict(
|
||||
**(classifier._gen_xgb_params_dict()),
|
||||
**(classifier._gen_fit_params_dict()),
|
||||
**(classifier._gen_predict_params_dict()),
|
||||
)
|
||||
self._check_sub_dict_match(
|
||||
self.cls_params, all_params, excluding_keys=_non_booster_params
|
||||
)
|
||||
|
||||
model = classifier.fit(self.cls_df_train)
|
||||
all_params = dict(
|
||||
**(model._gen_xgb_params_dict()),
|
||||
**(model._gen_fit_params_dict()),
|
||||
**(model._gen_predict_params_dict()),
|
||||
)
|
||||
self._check_sub_dict_match(
|
||||
self.cls_params, all_params, excluding_keys=_non_booster_params
|
||||
)
|
||||
pred_result = model.transform(self.cls_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertEqual(row.prediction, row.expected_prediction_with_params)
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
row.probability, row.expected_probability_with_params, rtol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
def test_regressor_model_save_load(self):
|
||||
path = "file:" + self.get_local_tmp_dir()
|
||||
regressor = SparkXGBRegressor(**self.reg_params)
|
||||
model = regressor.fit(self.reg_df_train)
|
||||
model.save(path)
|
||||
loaded_model = SparkXGBRegressorModel.load(path)
|
||||
self.assertEqual(model.uid, loaded_model.uid)
|
||||
for k, v in self.reg_params.items():
|
||||
self.assertEqual(loaded_model.getOrDefault(k), v)
|
||||
|
||||
pred_result = loaded_model.transform(self.reg_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_params, atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "Expected class name"):
|
||||
SparkXGBClassifierModel.load(path)
|
||||
|
||||
def test_classifier_model_save_load(self):
|
||||
path = "file:" + self.get_local_tmp_dir()
|
||||
regressor = SparkXGBClassifier(**self.cls_params)
|
||||
model = regressor.fit(self.cls_df_train)
|
||||
model.save(path)
|
||||
loaded_model = SparkXGBClassifierModel.load(path)
|
||||
self.assertEqual(model.uid, loaded_model.uid)
|
||||
for k, v in self.cls_params.items():
|
||||
self.assertEqual(loaded_model.getOrDefault(k), v)
|
||||
|
||||
pred_result = loaded_model.transform(self.cls_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
row.probability, row.expected_probability_with_params, atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "Expected class name"):
|
||||
SparkXGBRegressorModel.load(path)
|
||||
|
||||
@staticmethod
|
||||
def _get_params_map(params_kv, estimator):
|
||||
return {getattr(estimator, k): v for k, v in params_kv.items()}
|
||||
|
||||
def test_regressor_model_pipeline_save_load(self):
|
||||
path = "file:" + self.get_local_tmp_dir()
|
||||
regressor = SparkXGBRegressor()
|
||||
pipeline = Pipeline(stages=[regressor])
|
||||
pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor))
|
||||
model = pipeline.fit(self.reg_df_train)
|
||||
model.save(path)
|
||||
|
||||
loaded_model = PipelineModel.load(path)
|
||||
for k, v in self.reg_params.items():
|
||||
self.assertEqual(loaded_model.stages[0].getOrDefault(k), v)
|
||||
|
||||
pred_result = loaded_model.transform(self.reg_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_params, atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
def test_classifier_model_pipeline_save_load(self):
|
||||
path = "file:" + self.get_local_tmp_dir()
|
||||
classifier = SparkXGBClassifier()
|
||||
pipeline = Pipeline(stages=[classifier])
|
||||
pipeline = pipeline.copy(
|
||||
extra=self._get_params_map(self.cls_params, classifier)
|
||||
)
|
||||
model = pipeline.fit(self.cls_df_train)
|
||||
model.save(path)
|
||||
|
||||
loaded_model = PipelineModel.load(path)
|
||||
for k, v in self.cls_params.items():
|
||||
self.assertEqual(loaded_model.stages[0].getOrDefault(k), v)
|
||||
|
||||
pred_result = loaded_model.transform(self.cls_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
row.probability, row.expected_probability_with_params, atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
def test_classifier_with_cross_validator(self):
|
||||
xgb_classifer = SparkXGBClassifier()
|
||||
paramMaps = ParamGridBuilder().addGrid(xgb_classifer.max_depth, [1, 2]).build()
|
||||
cvBin = CrossValidator(
|
||||
estimator=xgb_classifer,
|
||||
estimatorParamMaps=paramMaps,
|
||||
evaluator=BinaryClassificationEvaluator(),
|
||||
seed=1,
|
||||
)
|
||||
cvBinModel = cvBin.fit(self.cls_df_train_large)
|
||||
cvBinModel.transform(self.cls_df_test)
|
||||
cvMulti = CrossValidator(
|
||||
estimator=xgb_classifer,
|
||||
estimatorParamMaps=paramMaps,
|
||||
evaluator=MulticlassClassificationEvaluator(),
|
||||
seed=1,
|
||||
)
|
||||
cvMultiModel = cvMulti.fit(self.multi_cls_df_train_large)
|
||||
cvMultiModel.transform(self.multi_cls_df_test)
|
||||
|
||||
def test_callbacks(self):
|
||||
from xgboost.callback import LearningRateScheduler
|
||||
|
||||
path = self.get_local_tmp_dir()
|
||||
|
||||
def custom_learning_rate(boosting_round):
|
||||
return 1.0 / (boosting_round + 1)
|
||||
|
||||
cb = [LearningRateScheduler(custom_learning_rate)]
|
||||
regressor = SparkXGBRegressor(callbacks=cb)
|
||||
|
||||
# Test the save/load of the estimator instead of the model, since
|
||||
# the callbacks param only exists in the estimator but not in the model
|
||||
regressor.save(path)
|
||||
regressor = SparkXGBRegressor.load(path)
|
||||
|
||||
model = regressor.fit(self.reg_df_train)
|
||||
pred_result = model.transform(self.reg_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_callbacks, atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
def test_train_with_initial_model(self):
|
||||
path = self.get_local_tmp_dir()
|
||||
reg1 = SparkXGBRegressor(**self.reg_params)
|
||||
model = reg1.fit(self.reg_df_train)
|
||||
init_booster = model.get_booster()
|
||||
reg2 = SparkXGBRegressor(max_depth=2, n_estimators=2, xgb_model=init_booster)
|
||||
model21 = reg2.fit(self.reg_df_train)
|
||||
pred_res21 = model21.transform(self.reg_df_test).collect()
|
||||
reg2.save(path)
|
||||
reg2 = SparkXGBRegressor.load(path)
|
||||
self.assertTrue(reg2.getOrDefault(reg2.xgb_model) is not None)
|
||||
model22 = reg2.fit(self.reg_df_train)
|
||||
pred_res22 = model22.transform(self.reg_df_test).collect()
|
||||
# Test the transform result is the same for original and loaded model
|
||||
for row1, row2 in zip(pred_res21, pred_res22):
|
||||
self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3))
|
||||
|
||||
def test_classifier_with_base_margin(self):
|
||||
cls_without_base_margin = SparkXGBClassifier(weight_col="weight")
|
||||
model_without_base_margin = cls_without_base_margin.fit(
|
||||
self.cls_df_train_without_base_margin
|
||||
)
|
||||
pred_result_without_base_margin = model_without_base_margin.transform(
|
||||
self.cls_df_test_without_base_margin
|
||||
).collect()
|
||||
for row in pred_result_without_base_margin:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction,
|
||||
row.expected_prediction_without_base_margin,
|
||||
atol=1e-3,
|
||||
)
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
row.probability, row.expected_prob_without_base_margin, atol=1e-3
|
||||
)
|
||||
|
||||
cls_with_same_base_margin = SparkXGBClassifier(
|
||||
weight_col="weight", base_margin_col="base_margin"
|
||||
)
|
||||
model_with_same_base_margin = cls_with_same_base_margin.fit(
|
||||
self.cls_df_train_with_same_base_margin
|
||||
)
|
||||
pred_result_with_same_base_margin = model_with_same_base_margin.transform(
|
||||
self.cls_df_test_with_same_base_margin
|
||||
).collect()
|
||||
for row in pred_result_with_same_base_margin:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_base_margin, atol=1e-3
|
||||
)
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
row.probability, row.expected_prob_with_base_margin, atol=1e-3
|
||||
)
|
||||
|
||||
cls_with_different_base_margin = SparkXGBClassifier(
|
||||
weight_col="weight", base_margin_col="base_margin"
|
||||
)
|
||||
model_with_different_base_margin = cls_with_different_base_margin.fit(
|
||||
self.cls_df_train_with_different_base_margin
|
||||
)
|
||||
pred_result_with_different_base_margin = (
|
||||
model_with_different_base_margin.transform(
|
||||
self.cls_df_test_with_different_base_margin
|
||||
).collect()
|
||||
)
|
||||
for row in pred_result_with_different_base_margin:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_base_margin, atol=1e-3
|
||||
)
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
row.probability, row.expected_prob_with_base_margin, atol=1e-3
|
||||
)
|
||||
|
||||
def test_regressor_with_weight_eval(self):
|
||||
# with weight
|
||||
regressor_with_weight = SparkXGBRegressor(weight_col="weight")
|
||||
model_with_weight = regressor_with_weight.fit(
|
||||
self.reg_df_train_with_eval_weight
|
||||
)
|
||||
pred_result_with_weight = model_with_weight.transform(
|
||||
self.reg_df_test_with_eval_weight
|
||||
).collect()
|
||||
for row in pred_result_with_weight:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_weight, atol=1e-3
|
||||
)
|
||||
)
|
||||
# with eval
|
||||
regressor_with_eval = SparkXGBRegressor(**self.reg_params_with_eval)
|
||||
model_with_eval = regressor_with_eval.fit(self.reg_df_train_with_eval_weight)
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
model_with_eval._xgb_sklearn_model.best_score,
|
||||
self.reg_with_eval_best_score,
|
||||
atol=1e-3,
|
||||
),
|
||||
f"Expected best score: {self.reg_with_eval_best_score}, "
|
||||
f"but get {model_with_eval._xgb_sklearn_model.best_score}",
|
||||
)
|
||||
pred_result_with_eval = model_with_eval.transform(
|
||||
self.reg_df_test_with_eval_weight
|
||||
).collect()
|
||||
for row in pred_result_with_eval:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_eval, atol=1e-3
|
||||
),
|
||||
f"Expect prediction is {row.expected_prediction_with_eval},"
|
||||
f"but get {row.prediction}",
|
||||
)
|
||||
# with weight and eval
|
||||
regressor_with_weight_eval = SparkXGBRegressor(
|
||||
weight_col="weight", **self.reg_params_with_eval
|
||||
)
|
||||
model_with_weight_eval = regressor_with_weight_eval.fit(
|
||||
self.reg_df_train_with_eval_weight
|
||||
)
|
||||
pred_result_with_weight_eval = model_with_weight_eval.transform(
|
||||
self.reg_df_test_with_eval_weight
|
||||
).collect()
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
model_with_weight_eval._xgb_sklearn_model.best_score,
|
||||
self.reg_with_eval_and_weight_best_score,
|
||||
atol=1e-3,
|
||||
)
|
||||
)
|
||||
for row in pred_result_with_weight_eval:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction,
|
||||
row.expected_prediction_with_weight_and_eval,
|
||||
atol=1e-3,
|
||||
)
|
||||
)
|
||||
|
||||
def test_classifier_with_weight_eval(self):
|
||||
# with weight
|
||||
classifier_with_weight = SparkXGBClassifier(weight_col="weight")
|
||||
model_with_weight = classifier_with_weight.fit(
|
||||
self.cls_df_train_with_eval_weight
|
||||
)
|
||||
pred_result_with_weight = model_with_weight.transform(
|
||||
self.cls_df_test_with_eval_weight
|
||||
).collect()
|
||||
for row in pred_result_with_weight:
|
||||
self.assertTrue(
|
||||
np.allclose(row.probability, row.expected_prob_with_weight, atol=1e-3)
|
||||
)
|
||||
# with eval
|
||||
classifier_with_eval = SparkXGBClassifier(**self.cls_params_with_eval)
|
||||
model_with_eval = classifier_with_eval.fit(self.cls_df_train_with_eval_weight)
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
model_with_eval._xgb_sklearn_model.best_score,
|
||||
self.cls_with_eval_best_score,
|
||||
atol=1e-3,
|
||||
)
|
||||
)
|
||||
pred_result_with_eval = model_with_eval.transform(
|
||||
self.cls_df_test_with_eval_weight
|
||||
).collect()
|
||||
for row in pred_result_with_eval:
|
||||
self.assertTrue(
|
||||
np.allclose(row.probability, row.expected_prob_with_eval, atol=1e-3)
|
||||
)
|
||||
# with weight and eval
|
||||
# Added scale_pos_weight because in 1.4.2, the original answer returns 0.5 which
|
||||
# doesn't really indicate this working correctly.
|
||||
classifier_with_weight_eval = SparkXGBClassifier(
|
||||
weight_col="weight", scale_pos_weight=4, **self.cls_params_with_eval
|
||||
)
|
||||
model_with_weight_eval = classifier_with_weight_eval.fit(
|
||||
self.cls_df_train_with_eval_weight
|
||||
)
|
||||
pred_result_with_weight_eval = model_with_weight_eval.transform(
|
||||
self.cls_df_test_with_eval_weight
|
||||
).collect()
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
model_with_weight_eval._xgb_sklearn_model.best_score,
|
||||
self.cls_with_eval_and_weight_best_score,
|
||||
atol=1e-3,
|
||||
)
|
||||
)
|
||||
for row in pred_result_with_weight_eval:
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
def test_num_workers_param(self):
|
||||
regressor = SparkXGBRegressor(num_workers=-1)
|
||||
self.assertRaises(ValueError, regressor._validate_params)
|
||||
classifier = SparkXGBClassifier(num_workers=0)
|
||||
self.assertRaises(ValueError, classifier._validate_params)
|
||||
|
||||
def test_use_gpu_param(self):
|
||||
classifier = SparkXGBClassifier(use_gpu=True, tree_method="exact")
|
||||
self.assertRaises(ValueError, classifier._validate_params)
|
||||
regressor = SparkXGBRegressor(use_gpu=True, tree_method="exact")
|
||||
self.assertRaises(ValueError, regressor._validate_params)
|
||||
regressor = SparkXGBRegressor(use_gpu=True, tree_method="gpu_hist")
|
||||
regressor = SparkXGBRegressor(use_gpu=True)
|
||||
classifier = SparkXGBClassifier(use_gpu=True, tree_method="gpu_hist")
|
||||
classifier = SparkXGBClassifier(use_gpu=True)
|
||||
|
||||
def test_convert_to_sklearn_model(self):
|
||||
classifier = SparkXGBClassifier(
|
||||
n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5
|
||||
)
|
||||
clf_model = classifier.fit(self.cls_df_train)
|
||||
|
||||
regressor = SparkXGBRegressor(
|
||||
n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5
|
||||
)
|
||||
reg_model = regressor.fit(self.reg_df_train)
|
||||
|
||||
# Check that regardless of what booster, _convert_to_model converts to the correct class type
|
||||
sklearn_classifier = classifier._convert_to_sklearn_model(
|
||||
clf_model.get_booster()
|
||||
)
|
||||
assert isinstance(sklearn_classifier, XGBClassifier)
|
||||
assert sklearn_classifier.n_estimators == 200
|
||||
assert sklearn_classifier.missing == 2.0
|
||||
assert sklearn_classifier.max_depth == 3
|
||||
assert sklearn_classifier.get_params()["sketch_eps"] == 0.5
|
||||
|
||||
sklearn_regressor = regressor._convert_to_sklearn_model(reg_model.get_booster())
|
||||
assert isinstance(sklearn_regressor, XGBRegressor)
|
||||
assert sklearn_regressor.n_estimators == 200
|
||||
assert sklearn_regressor.missing == 2.0
|
||||
assert sklearn_regressor.max_depth == 3
|
||||
assert sklearn_classifier.get_params()["sketch_eps"] == 0.5
|
||||
|
||||
def test_feature_importances(self):
|
||||
reg1 = SparkXGBRegressor(**self.reg_params)
|
||||
model = reg1.fit(self.reg_df_train)
|
||||
booster = model.get_booster()
|
||||
self.assertEqual(model.get_feature_importances(), booster.get_score())
|
||||
self.assertEqual(
|
||||
model.get_feature_importances(importance_type="gain"),
|
||||
booster.get_score(importance_type="gain"),
|
||||
)
|
||||
|
||||
def test_regressor_array_col_as_feature(self):
|
||||
train_dataset = self.reg_df_train.withColumn(
|
||||
"features", vector_to_array(spark_sql_func.col("features"))
|
||||
)
|
||||
test_dataset = self.reg_df_test.withColumn(
|
||||
"features", vector_to_array(spark_sql_func.col("features"))
|
||||
)
|
||||
regressor = SparkXGBRegressor()
|
||||
model = regressor.fit(train_dataset)
|
||||
pred_result = model.transform(test_dataset).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(row.prediction, row.expected_prediction, atol=1e-3)
|
||||
)
|
||||
|
||||
def test_classifier_array_col_as_feature(self):
|
||||
train_dataset = self.cls_df_train.withColumn(
|
||||
"features", vector_to_array(spark_sql_func.col("features"))
|
||||
)
|
||||
test_dataset = self.cls_df_test.withColumn(
|
||||
"features", vector_to_array(spark_sql_func.col("features"))
|
||||
)
|
||||
classifier = SparkXGBClassifier()
|
||||
model = classifier.fit(train_dataset)
|
||||
|
||||
pred_result = model.transform(test_dataset).collect()
|
||||
for row in pred_result:
|
||||
self.assertEqual(row.prediction, row.expected_prediction)
|
||||
self.assertTrue(
|
||||
np.allclose(row.probability, row.expected_probability, rtol=1e-3)
|
||||
)
|
||||
|
||||
def test_classifier_with_feature_names_types_weights(self):
|
||||
classifier = SparkXGBClassifier(
|
||||
feature_names=["a1", "a2", "a3"],
|
||||
feature_types=["i", "int", "float"],
|
||||
feature_weights=[2.0, 5.0, 3.0],
|
||||
)
|
||||
model = classifier.fit(self.cls_df_train)
|
||||
model.transform(self.cls_df_test).collect()
|
||||
450
tests/python/test_spark/test_spark_local_cluster.py
Normal file
450
tests/python/test_spark/test_spark_local_cluster.py
Normal file
@ -0,0 +1,450 @@
|
||||
import sys
|
||||
import random
|
||||
import json
|
||||
import uuid
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import testing as tm
|
||||
|
||||
if tm.no_spark()["condition"]:
|
||||
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
|
||||
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
|
||||
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
|
||||
|
||||
from .utils import SparkLocalClusterTestCase
|
||||
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
|
||||
from xgboost.spark.utils import _get_max_num_concurrent_tasks
|
||||
from pyspark.ml.linalg import Vectors
|
||||
|
||||
|
||||
class XgboostLocalClusterTestCase(SparkLocalClusterTestCase):
|
||||
def setUp(self):
|
||||
random.seed(2020)
|
||||
|
||||
self.n_workers = _get_max_num_concurrent_tasks(self.session.sparkContext)
|
||||
# The following code use xgboost python library to train xgb model and predict.
|
||||
#
|
||||
# >>> import numpy as np
|
||||
# >>> import xgboost
|
||||
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
|
||||
# >>> y = np.array([0, 1])
|
||||
# >>> reg1 = xgboost.XGBRegressor()
|
||||
# >>> reg1.fit(X, y)
|
||||
# >>> reg1.predict(X)
|
||||
# array([8.8363886e-04, 9.9911636e-01], dtype=float32)
|
||||
# >>> def custom_lr(boosting_round, num_boost_round):
|
||||
# ... return 1.0 / (boosting_round + 1)
|
||||
# ...
|
||||
# >>> reg1.fit(X, y, callbacks=[xgboost.callback.reset_learning_rate(custom_lr)])
|
||||
# >>> reg1.predict(X)
|
||||
# array([0.02406833, 0.97593164], dtype=float32)
|
||||
# >>> reg2 = xgboost.XGBRegressor(max_depth=5, n_estimators=10)
|
||||
# >>> reg2.fit(X, y)
|
||||
# >>> reg2.predict(X, ntree_limit=5)
|
||||
# array([0.22185263, 0.77814734], dtype=float32)
|
||||
self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5}
|
||||
self.reg_df_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
|
||||
],
|
||||
["features", "label"],
|
||||
)
|
||||
self.reg_df_test = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759),
|
||||
],
|
||||
[
|
||||
"features",
|
||||
"expected_prediction",
|
||||
"expected_prediction_with_params",
|
||||
"expected_prediction_with_callbacks",
|
||||
],
|
||||
)
|
||||
|
||||
# Distributed section
|
||||
# Binary classification
|
||||
self.cls_df_train_distributed = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1),
|
||||
]
|
||||
* 100,
|
||||
["features", "label"],
|
||||
)
|
||||
self.cls_df_test_distributed = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, [0.9949826, 0.0050174]),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [0.0050174, 0.9949826]),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0, [0.9949826, 0.0050174]),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, [0.0050174, 0.9949826]),
|
||||
],
|
||||
["features", "expected_label", "expected_probability"],
|
||||
)
|
||||
# Binary classification with different num_estimators
|
||||
self.cls_df_test_distributed_lower_estimators = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, [0.9735, 0.0265]),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [0.0265, 0.9735]),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0, [0.9735, 0.0265]),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, [0.0265, 0.9735]),
|
||||
],
|
||||
["features", "expected_label", "expected_probability"],
|
||||
)
|
||||
|
||||
# Multiclass classification
|
||||
self.cls_df_train_distributed_multiclass = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 2),
|
||||
]
|
||||
* 100,
|
||||
["features", "label"],
|
||||
)
|
||||
self.cls_df_test_distributed_multiclass = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, [4.294563, -2.449409, -2.449409]),
|
||||
(
|
||||
Vectors.sparse(3, {1: 1.0, 2: 5.5}),
|
||||
1,
|
||||
[-2.3796105, 3.669014, -2.449409],
|
||||
),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0, [4.294563, -2.449409, -2.449409]),
|
||||
(
|
||||
Vectors.sparse(3, {1: 6.0, 2: 7.5}),
|
||||
2,
|
||||
[-2.3796105, -2.449409, 3.669014],
|
||||
),
|
||||
],
|
||||
["features", "expected_label", "expected_margins"],
|
||||
)
|
||||
|
||||
# Regression
|
||||
self.reg_df_train_distributed = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 2),
|
||||
]
|
||||
* 100,
|
||||
["features", "label"],
|
||||
)
|
||||
self.reg_df_test_distributed = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 1.533e-04),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 9.999e-01),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 1.533e-04),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1.999e00),
|
||||
],
|
||||
["features", "expected_label"],
|
||||
)
|
||||
|
||||
# Adding weight and validation
|
||||
self.clf_params_with_eval_dist = {
|
||||
"validation_indicator_col": "isVal",
|
||||
"early_stopping_rounds": 1,
|
||||
"eval_metric": "logloss",
|
||||
}
|
||||
self.clf_params_with_weight_dist = {"weight_col": "weight"}
|
||||
self.cls_df_train_distributed_with_eval_weight = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
|
||||
]
|
||||
* 100,
|
||||
["features", "label", "isVal", "weight"],
|
||||
)
|
||||
self.cls_df_test_distributed_with_eval_weight = self.session.createDataFrame(
|
||||
[
|
||||
(
|
||||
Vectors.dense(1.0, 2.0, 3.0),
|
||||
[0.9955, 0.0044],
|
||||
[0.9904, 0.0096],
|
||||
[0.9903, 0.0097],
|
||||
),
|
||||
],
|
||||
[
|
||||
"features",
|
||||
"expected_prob_with_weight",
|
||||
"expected_prob_with_eval",
|
||||
"expected_prob_with_weight_and_eval",
|
||||
],
|
||||
)
|
||||
self.clf_best_score_eval = 0.009677
|
||||
self.clf_best_score_weight_and_eval = 0.006626
|
||||
|
||||
self.reg_params_with_eval_dist = {
|
||||
"validation_indicator_col": "isVal",
|
||||
"early_stopping_rounds": 1,
|
||||
"eval_metric": "rmse",
|
||||
}
|
||||
self.reg_params_with_weight_dist = {"weight_col": "weight"}
|
||||
self.reg_df_train_distributed_with_eval_weight = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
|
||||
]
|
||||
* 100,
|
||||
["features", "label", "isVal", "weight"],
|
||||
)
|
||||
self.reg_df_test_distributed_with_eval_weight = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 4.583e-05, 5.239e-05, 6.03e-05),
|
||||
(
|
||||
Vectors.sparse(3, {1: 1.0, 2: 5.5}),
|
||||
9.9997e-01,
|
||||
9.99947e-01,
|
||||
9.9995e-01,
|
||||
),
|
||||
],
|
||||
[
|
||||
"features",
|
||||
"expected_prediction_with_weight",
|
||||
"expected_prediction_with_eval",
|
||||
"expected_prediction_with_weight_and_eval",
|
||||
],
|
||||
)
|
||||
self.reg_best_score_eval = 5.239e-05
|
||||
self.reg_best_score_weight_and_eval = 4.810e-05
|
||||
|
||||
def test_regressor_basic_with_params(self):
|
||||
regressor = SparkXGBRegressor(**self.reg_params)
|
||||
model = regressor.fit(self.reg_df_train)
|
||||
pred_result = model.transform(self.reg_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_params, atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
def test_callbacks(self):
|
||||
from xgboost.callback import LearningRateScheduler
|
||||
|
||||
path = os.path.join(self.tempdir, str(uuid.uuid4()))
|
||||
|
||||
def custom_learning_rate(boosting_round):
|
||||
return 1.0 / (boosting_round + 1)
|
||||
|
||||
cb = [LearningRateScheduler(custom_learning_rate)]
|
||||
regressor = SparkXGBRegressor(callbacks=cb)
|
||||
|
||||
# Test the save/load of the estimator instead of the model, since
|
||||
# the callbacks param only exists in the estimator but not in the model
|
||||
regressor.save(path)
|
||||
regressor = SparkXGBRegressor.load(path)
|
||||
|
||||
model = regressor.fit(self.reg_df_train)
|
||||
pred_result = model.transform(self.reg_df_test).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_callbacks, atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
def test_classifier_distributed_basic(self):
|
||||
classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=100)
|
||||
model = classifier.fit(self.cls_df_train_distributed)
|
||||
pred_result = model.transform(self.cls_df_test_distributed).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3))
|
||||
self.assertTrue(
|
||||
np.allclose(row.expected_probability, row.probability, atol=1e-3)
|
||||
)
|
||||
|
||||
def test_classifier_distributed_multiclass(self):
|
||||
# There is no built-in multiclass option for external storage
|
||||
classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=100)
|
||||
model = classifier.fit(self.cls_df_train_distributed_multiclass)
|
||||
pred_result = model.transform(self.cls_df_test_distributed_multiclass).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3))
|
||||
self.assertTrue(
|
||||
np.allclose(row.expected_margins, row.rawPrediction, atol=1e-3)
|
||||
)
|
||||
|
||||
def test_regressor_distributed_basic(self):
|
||||
regressor = SparkXGBRegressor(num_workers=self.n_workers, n_estimators=100)
|
||||
model = regressor.fit(self.reg_df_train_distributed)
|
||||
pred_result = model.transform(self.reg_df_test_distributed).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3))
|
||||
|
||||
def test_classifier_distributed_weight_eval(self):
|
||||
# with weight
|
||||
classifier = SparkXGBClassifier(
|
||||
num_workers=self.n_workers,
|
||||
n_estimators=100,
|
||||
**self.clf_params_with_weight_dist
|
||||
)
|
||||
model = classifier.fit(self.cls_df_train_distributed_with_eval_weight)
|
||||
pred_result = model.transform(
|
||||
self.cls_df_test_distributed_with_eval_weight
|
||||
).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.allclose(row.probability, row.expected_prob_with_weight, atol=1e-3)
|
||||
)
|
||||
|
||||
# with eval only
|
||||
classifier = SparkXGBClassifier(
|
||||
num_workers=self.n_workers,
|
||||
n_estimators=100,
|
||||
**self.clf_params_with_eval_dist
|
||||
)
|
||||
model = classifier.fit(self.cls_df_train_distributed_with_eval_weight)
|
||||
pred_result = model.transform(
|
||||
self.cls_df_test_distributed_with_eval_weight
|
||||
).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.allclose(row.probability, row.expected_prob_with_eval, atol=1e-3)
|
||||
)
|
||||
assert np.isclose(
|
||||
float(model.get_booster().attributes()["best_score"]),
|
||||
self.clf_best_score_eval,
|
||||
rtol=1e-3,
|
||||
)
|
||||
|
||||
# with both weight and eval
|
||||
classifier = SparkXGBClassifier(
|
||||
num_workers=self.n_workers,
|
||||
n_estimators=100,
|
||||
**self.clf_params_with_eval_dist,
|
||||
**self.clf_params_with_weight_dist
|
||||
)
|
||||
model = classifier.fit(self.cls_df_train_distributed_with_eval_weight)
|
||||
pred_result = model.transform(
|
||||
self.cls_df_test_distributed_with_eval_weight
|
||||
).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3
|
||||
)
|
||||
)
|
||||
np.isclose(
|
||||
float(model.get_booster().attributes()["best_score"]),
|
||||
self.clf_best_score_weight_and_eval,
|
||||
rtol=1e-3,
|
||||
)
|
||||
|
||||
def test_regressor_distributed_weight_eval(self):
|
||||
# with weight
|
||||
regressor = SparkXGBRegressor(
|
||||
num_workers=self.n_workers,
|
||||
n_estimators=100,
|
||||
**self.reg_params_with_weight_dist
|
||||
)
|
||||
model = regressor.fit(self.reg_df_train_distributed_with_eval_weight)
|
||||
pred_result = model.transform(
|
||||
self.reg_df_test_distributed_with_eval_weight
|
||||
).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction, row.expected_prediction_with_weight, atol=1e-3
|
||||
)
|
||||
)
|
||||
# with eval only
|
||||
regressor = SparkXGBRegressor(
|
||||
num_workers=self.n_workers,
|
||||
n_estimators=100,
|
||||
**self.reg_params_with_eval_dist
|
||||
)
|
||||
model = regressor.fit(self.reg_df_train_distributed_with_eval_weight)
|
||||
pred_result = model.transform(
|
||||
self.reg_df_test_distributed_with_eval_weight
|
||||
).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(row.prediction, row.expected_prediction_with_eval, atol=1e-3)
|
||||
)
|
||||
assert np.isclose(
|
||||
float(model.get_booster().attributes()["best_score"]),
|
||||
self.reg_best_score_eval,
|
||||
rtol=1e-3,
|
||||
)
|
||||
# with both weight and eval
|
||||
regressor = SparkXGBRegressor(
|
||||
num_workers=self.n_workers,
|
||||
n_estimators=100,
|
||||
use_external_storage=False,
|
||||
**self.reg_params_with_eval_dist,
|
||||
**self.reg_params_with_weight_dist
|
||||
)
|
||||
model = regressor.fit(self.reg_df_train_distributed_with_eval_weight)
|
||||
pred_result = model.transform(
|
||||
self.reg_df_test_distributed_with_eval_weight
|
||||
).collect()
|
||||
for row in pred_result:
|
||||
self.assertTrue(
|
||||
np.isclose(
|
||||
row.prediction,
|
||||
row.expected_prediction_with_weight_and_eval,
|
||||
atol=1e-3,
|
||||
)
|
||||
)
|
||||
assert np.isclose(
|
||||
float(model.get_booster().attributes()["best_score"]),
|
||||
self.reg_best_score_weight_and_eval,
|
||||
rtol=1e-3,
|
||||
)
|
||||
|
||||
def test_num_estimators(self):
|
||||
classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=10)
|
||||
model = classifier.fit(self.cls_df_train_distributed)
|
||||
pred_result = model.transform(
|
||||
self.cls_df_test_distributed_lower_estimators
|
||||
).collect()
|
||||
print(pred_result)
|
||||
for row in pred_result:
|
||||
self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3))
|
||||
self.assertTrue(
|
||||
np.allclose(row.expected_probability, row.probability, atol=1e-3)
|
||||
)
|
||||
|
||||
def test_distributed_params(self):
|
||||
classifier = SparkXGBClassifier(num_workers=self.n_workers, max_depth=7)
|
||||
model = classifier.fit(self.cls_df_train_distributed)
|
||||
self.assertTrue(hasattr(classifier, "max_depth"))
|
||||
self.assertEqual(classifier.getOrDefault(classifier.max_depth), 7)
|
||||
booster_config = json.loads(model.get_booster().save_config())
|
||||
max_depth = booster_config["learner"]["gradient_booster"]["updater"][
|
||||
"grow_histmaker"
|
||||
]["train_param"]["max_depth"]
|
||||
self.assertEqual(int(max_depth), 7)
|
||||
|
||||
def test_repartition(self):
|
||||
# The following test case has a few partitioned datasets that are either
|
||||
# well partitioned relative to the number of workers that the user wants
|
||||
# or poorly partitioned. We only want to repartition when the dataset
|
||||
# is poorly partitioned so _repartition_needed is true in those instances.
|
||||
|
||||
classifier = SparkXGBClassifier(num_workers=self.n_workers)
|
||||
basic = self.cls_df_train_distributed
|
||||
self.assertTrue(classifier._repartition_needed(basic))
|
||||
bad_repartitioned = basic.repartition(self.n_workers + 1)
|
||||
self.assertTrue(classifier._repartition_needed(bad_repartitioned))
|
||||
good_repartitioned = basic.repartition(self.n_workers)
|
||||
self.assertFalse(classifier._repartition_needed(good_repartitioned))
|
||||
|
||||
# Now testing if force_repartition returns True regardless of whether the data is well partitioned
|
||||
classifier = SparkXGBClassifier(
|
||||
num_workers=self.n_workers, force_repartition=True
|
||||
)
|
||||
good_repartitioned = basic.repartition(self.n_workers)
|
||||
self.assertTrue(classifier._repartition_needed(good_repartitioned))
|
||||
148
tests/python/test_spark/utils.py
Normal file
148
tests/python/test_spark/utils.py
Normal file
@ -0,0 +1,148 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from six import StringIO
|
||||
|
||||
import testing as tm
|
||||
|
||||
if tm.no_spark()["condition"]:
|
||||
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
|
||||
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
|
||||
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
|
||||
|
||||
from pyspark.sql import SQLContext
|
||||
from pyspark.sql import SparkSession
|
||||
|
||||
from xgboost.spark.utils import _get_default_params_from_func
|
||||
|
||||
|
||||
class UtilsTest(unittest.TestCase):
|
||||
def test_get_default_params(self):
|
||||
class Foo:
|
||||
def func1(self, x, y, key1=None, key2="val2", key3=0, key4=None):
|
||||
pass
|
||||
|
||||
unsupported_params = {"key2", "key4"}
|
||||
expected_default_params = {
|
||||
"key1": None,
|
||||
"key3": 0,
|
||||
}
|
||||
actual_default_params = _get_default_params_from_func(
|
||||
Foo.func1, unsupported_params
|
||||
)
|
||||
self.assertEqual(
|
||||
len(expected_default_params.keys()), len(actual_default_params.keys())
|
||||
)
|
||||
for k, v in actual_default_params.items():
|
||||
self.assertEqual(expected_default_params[k], v)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_stdout():
|
||||
"""patch stdout and give an output"""
|
||||
sys_stdout = sys.stdout
|
||||
io_out = StringIO()
|
||||
sys.stdout = io_out
|
||||
try:
|
||||
yield io_out
|
||||
finally:
|
||||
sys.stdout = sys_stdout
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_logger(name):
|
||||
"""patch logger and give an output"""
|
||||
io_out = StringIO()
|
||||
log = logging.getLogger(name)
|
||||
handler = logging.StreamHandler(io_out)
|
||||
log.addHandler(handler)
|
||||
try:
|
||||
yield io_out
|
||||
finally:
|
||||
log.removeHandler(handler)
|
||||
|
||||
|
||||
class TestTempDir(object):
|
||||
@classmethod
|
||||
def make_tempdir(cls):
|
||||
"""
|
||||
:param dir: Root directory in which to create the temp directory
|
||||
"""
|
||||
cls.tempdir = tempfile.mkdtemp(prefix="sparkdl_tests")
|
||||
|
||||
@classmethod
|
||||
def remove_tempdir(cls):
|
||||
shutil.rmtree(cls.tempdir)
|
||||
|
||||
|
||||
class TestSparkContext(object):
|
||||
@classmethod
|
||||
def setup_env(cls, spark_config):
|
||||
builder = SparkSession.builder.appName("xgboost spark python API Tests")
|
||||
for k, v in spark_config.items():
|
||||
builder.config(k, v)
|
||||
spark = builder.getOrCreate()
|
||||
logging.getLogger("pyspark").setLevel(logging.INFO)
|
||||
|
||||
cls.sc = spark.sparkContext
|
||||
cls.session = spark
|
||||
|
||||
@classmethod
|
||||
def tear_down_env(cls):
|
||||
cls.session.stop()
|
||||
cls.session = None
|
||||
cls.sc.stop()
|
||||
cls.sc = None
|
||||
|
||||
|
||||
class SparkTestCase(TestSparkContext, TestTempDir, unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.setup_env(
|
||||
{
|
||||
"spark.master": "local[2]",
|
||||
"spark.python.worker.reuse": "false",
|
||||
"spark.driver.host": "127.0.0.1",
|
||||
"spark.task.maxFailures": "1",
|
||||
"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
|
||||
"spark.sql.pyspark.jvmStacktrace.enabled": "true",
|
||||
}
|
||||
)
|
||||
cls.make_tempdir()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.remove_tempdir()
|
||||
cls.tear_down_env()
|
||||
|
||||
|
||||
class SparkLocalClusterTestCase(TestSparkContext, TestTempDir, unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.setup_env(
|
||||
{
|
||||
"spark.master": "local-cluster[2, 2, 1024]",
|
||||
"spark.python.worker.reuse": "false",
|
||||
"spark.driver.host": "127.0.0.1",
|
||||
"spark.task.maxFailures": "1",
|
||||
"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
|
||||
"spark.sql.pyspark.jvmStacktrace.enabled": "true",
|
||||
"spark.cores.max": "4",
|
||||
"spark.task.cpus": "1",
|
||||
"spark.executor.cores": "2",
|
||||
}
|
||||
)
|
||||
cls.make_tempdir()
|
||||
# We run a dummy job so that we block until the workers have connected to the master
|
||||
cls.sc.parallelize(range(4), 4).barrier().mapPartitions(lambda _: []).collect()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.remove_tempdir()
|
||||
cls.tear_down_env()
|
||||
@ -56,6 +56,15 @@ def no_dask():
|
||||
return {"condition": not DASK_INSTALLED, "reason": "Dask is not installed"}
|
||||
|
||||
|
||||
def no_spark():
|
||||
try:
|
||||
import pyspark # noqa
|
||||
SPARK_INSTALLED = True
|
||||
except ImportError:
|
||||
SPARK_INSTALLED = False
|
||||
return {"condition": not SPARK_INSTALLED, "reason": "Spark is not installed"}
|
||||
|
||||
|
||||
def no_pandas():
|
||||
return {'condition': not PANDAS_INSTALLED,
|
||||
'reason': 'Pandas is not installed.'}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user