PySpark XGBoost integration (#8020)

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
WeichenXu 2022-07-13 13:11:18 +08:00 committed by GitHub
parent 8959622836
commit 176fec8789
25 changed files with 3650 additions and 12 deletions

View File

@@ -141,7 +141,7 @@ jobs:
       - name: Install Python packages
         run: |
           python -m pip install wheel setuptools
-          python -m pip install pylint cpplint numpy scipy scikit-learn
+          python -m pip install pylint cpplint numpy scipy scikit-learn pyspark pandas cloudpickle
       - name: Run lint
         run: |
           make lint

View File

@@ -92,6 +92,7 @@ jobs:
   python-tests-on-macos:
     name: Test XGBoost Python package on ${{ matrix.config.os }}
     runs-on: ${{ matrix.config.os }}
+    timeout-minutes: 90
     strategy:
       matrix:
         config:

View File

@@ -351,7 +351,8 @@ if __name__ == '__main__':
           'scikit-learn': ['scikit-learn'],
           'dask': ['dask', 'pandas', 'distributed'],
           'datatable': ['datatable'],
-          'plotting': ['graphviz', 'matplotlib']
+          'plotting': ['graphviz', 'matplotlib'],
+          "pyspark": ["pyspark", "scikit-learn", "cloudpickle"],
       },
       maintainer='Hyunsu Cho',
       maintainer_email='chohyu01@cs.washington.edu',

View File

@@ -0,0 +1,22 @@
# type: ignore
"""PySpark XGBoost integration interface
"""
try:
import pyspark
except ImportError as e:
raise ImportError("pyspark package needs to be installed to use this module") from e
from .estimator import (
SparkXGBClassifier,
SparkXGBClassifierModel,
SparkXGBRegressor,
SparkXGBRegressorModel,
)
__all__ = [
"SparkXGBClassifier",
"SparkXGBClassifierModel",
"SparkXGBRegressor",
"SparkXGBRegressorModel",
]
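For orientation, a minimal usage sketch of the new module (a sketch only, assuming a running SparkSession named `spark` and that the optional PySpark dependencies are installed); the toy data mirrors the docstring examples further down in this commit:

from pyspark.ml.linalg import Vectors
from xgboost.spark import SparkXGBClassifier

# Tiny illustrative DataFrame using the estimator's default column names.
df_train = spark.createDataFrame(
    [
        (Vectors.dense(1.0, 2.0, 3.0), 0),
        (Vectors.dense(4.0, 5.0, 6.0), 1),
    ],
    ["features", "label"],
)

clf = SparkXGBClassifier(num_workers=1, max_depth=3)
model = clf.fit(df_train)
model.transform(df_train).show()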

View File

@@ -0,0 +1,881 @@
# type: ignore
"""Xgboost pyspark integration submodule for core code."""
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods
import cloudpickle
import numpy as np
import pandas as pd
from scipy.special import expit, softmax # pylint: disable=no-name-in-module
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml import Estimator, Model
from pyspark.ml.linalg import VectorUDT
from pyspark.ml.param.shared import (
HasFeaturesCol,
HasLabelCol,
HasWeightCol,
HasPredictionCol,
HasProbabilityCol,
HasRawPredictionCol,
HasValidationIndicatorCol,
)
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import MLReadable, MLWritable
from pyspark.sql.functions import col, pandas_udf, countDistinct, struct
from pyspark.sql.types import (
ArrayType,
DoubleType,
FloatType,
IntegerType,
LongType,
ShortType,
)
import xgboost
from xgboost import XGBClassifier, XGBRegressor
from xgboost.core import Booster
from xgboost.training import train as worker_train
from .data import (
_convert_partition_data_to_dmatrix,
)
from .model import (
SparkXGBReader,
SparkXGBWriter,
SparkXGBModelReader,
SparkXGBModelWriter,
)
from .utils import (
get_logger, _get_max_num_concurrent_tasks,
_get_default_params_from_func,
get_class_name,
RabitContext,
_get_rabit_args,
_get_args_from_message_list,
_get_spark_session,
)
from .params import (
HasArbitraryParamsDict,
HasBaseMarginCol,
)
# Put pyspark specific params here; they won't be passed to XGBoost,
# e.g. `validationIndicatorCol`, `base_margin_col`.
_pyspark_specific_params = [
"featuresCol",
"labelCol",
"weightCol",
"rawPredictionCol",
"predictionCol",
"probabilityCol",
"validationIndicatorCol",
"base_margin_col",
"arbitrary_params_dict",
"force_repartition",
"num_workers",
"use_gpu",
"feature_names",
]
_non_booster_params = [
"missing",
"n_estimators",
"feature_types",
"feature_weights",
]
_pyspark_param_alias_map = {
"features_col": "featuresCol",
"label_col": "labelCol",
"weight_col": "weightCol",
"raw_prediction_ol": "rawPredictionCol",
"prediction_col": "predictionCol",
"probability_col": "probabilityCol",
"validation_indicator_col": "validationIndicatorCol",
}
_inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.items()}
_unsupported_xgb_params = [
"gpu_id", # we have "use_gpu" pyspark param instead.
"enable_categorical", # Use feature_types param to specify categorical feature instead
"use_label_encoder",
"n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead.
"nthread", # Ditto
]
_unsupported_fit_params = {
"sample_weight", # Supported by spark param weightCol
"eval_set", # Supported by spark param validationIndicatorCol
"sample_weight_eval_set", # Supported by spark params weightCol and validationIndicatorCol
"base_margin", # Supported by spark param base_margin_col
}
_unsupported_predict_params = {
# for classification, we can use rawPrediction as margin
"output_margin",
"validate_features", # TODO
"base_margin", # Use pyspark base_margin_col param instead.
}
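As a sketch of how the alias and block lists above take effect once `setParams` (defined below) applies them, assuming the SparkXGBClassifier added later in this commit: snake_case aliases resolve to the Spark-style names, while names listed as unsupported are rejected with a ValueError.

from xgboost.spark import SparkXGBClassifier

clf = SparkXGBClassifier(features_col="feats", label_col="target")
assert clf.getOrDefault(clf.featuresCol) == "feats"
assert clf.getOrDefault(clf.labelCol) == "target"
# Passing an unsupported XGBoost param such as gpu_id raises ValueError("Unsupported param 'gpu_id'.").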
class _SparkXGBParams(
HasFeaturesCol,
HasLabelCol,
HasWeightCol,
HasPredictionCol,
HasValidationIndicatorCol,
HasArbitraryParamsDict,
HasBaseMarginCol,
):
num_workers = Param(
Params._dummy(),
"num_workers",
"The number of XGBoost workers. Each XGBoost worker corresponds to one spark task.",
TypeConverters.toInt,
)
use_gpu = Param(
Params._dummy(),
"use_gpu",
"A boolean variable. Set use_gpu=true if the executors "
+ "are running on GPU instances. Currently, only one GPU per task is supported.",
)
force_repartition = Param(
Params._dummy(),
"force_repartition",
"A boolean variable. Set force_repartition=true if you "
+ "want to force the input dataset to be repartitioned before XGBoost training."
+ "Note: The auto repartitioning judgement is not fully accurate, so it is recommended"
+ "to have force_repartition be True.",
)
feature_names = Param(
Params._dummy(), "feature_names", "A list of str to specify feature names."
)
@classmethod
def _xgb_cls(cls):
"""
Subclasses should override this method and
return an xgboost.XGBModel subclass.
"""
raise NotImplementedError()
# Parameters for xgboost.XGBModel()
@classmethod
def _get_xgb_params_default(cls):
xgb_model_default = cls._xgb_cls()()
params_dict = xgb_model_default.get_params()
filtered_params_dict = {
k: params_dict[k] for k in params_dict if k not in _unsupported_xgb_params
}
return filtered_params_dict
def _set_xgb_params_default(self):
filtered_params_dict = self._get_xgb_params_default()
self._setDefault(**filtered_params_dict)
def _gen_xgb_params_dict(self, gen_xgb_sklearn_estimator_param=False):
xgb_params = {}
non_xgb_params = (
set(_pyspark_specific_params)
| self._get_fit_params_default().keys()
| self._get_predict_params_default().keys()
)
if not gen_xgb_sklearn_estimator_param:
non_xgb_params |= set(_non_booster_params)
for param in self.extractParamMap():
if param.name not in non_xgb_params:
xgb_params[param.name] = self.getOrDefault(param)
arbitrary_params_dict = self.getOrDefault(
self.getParam("arbitrary_params_dict")
)
xgb_params.update(arbitrary_params_dict)
return xgb_params
# Parameters for xgboost.XGBModel().fit()
@classmethod
def _get_fit_params_default(cls):
fit_params = _get_default_params_from_func(
cls._xgb_cls().fit, _unsupported_fit_params
)
return fit_params
def _set_fit_params_default(self):
filtered_params_dict = self._get_fit_params_default()
self._setDefault(**filtered_params_dict)
def _gen_fit_params_dict(self):
"""
Returns a dict of params for .fit()
"""
fit_params_keys = self._get_fit_params_default().keys()
fit_params = {}
for param in self.extractParamMap():
if param.name in fit_params_keys:
fit_params[param.name] = self.getOrDefault(param)
return fit_params
# Parameters for xgboost.XGBModel().predict()
@classmethod
def _get_predict_params_default(cls):
predict_params = _get_default_params_from_func(
cls._xgb_cls().predict, _unsupported_predict_params
)
return predict_params
def _set_predict_params_default(self):
filtered_params_dict = self._get_predict_params_default()
self._setDefault(**filtered_params_dict)
def _gen_predict_params_dict(self):
"""
Returns a dict of params for .predict()
"""
predict_params_keys = self._get_predict_params_default().keys()
predict_params = {}
for param in self.extractParamMap():
if param.name in predict_params_keys:
predict_params[param.name] = self.getOrDefault(param)
return predict_params
def _validate_params(self):
init_model = self.getOrDefault(self.xgb_model)
if init_model is not None and not isinstance(init_model, Booster):
raise ValueError(
"The xgb_model param must be set with a `xgboost.core.Booster` "
"instance."
)
if self.getOrDefault(self.num_workers) < 1:
raise ValueError(
f"Number of workers was {self.getOrDefault(self.num_workers)}."
f"It cannot be less than 1 [Default is 1]"
)
if (
self.getOrDefault(self.force_repartition)
and self.getOrDefault(self.num_workers) == 1
):
get_logger(self.__class__.__name__).warning(
"You set force_repartition to true when there is no need for a repartition."
"Therefore, that parameter will be ignored."
)
if self.getOrDefault(self.use_gpu):
tree_method = self.getParam("tree_method")
if (
self.getOrDefault(tree_method) is not None
and self.getOrDefault(tree_method) != "gpu_hist"
):
raise ValueError(
f"tree_method should be 'gpu_hist' or None when use_gpu is True,"
f"found {self.getOrDefault(tree_method)}."
)
gpu_per_task = (
_get_spark_session()
.sparkContext.getConf()
.get("spark.task.resource.gpu.amount")
)
if not gpu_per_task or int(gpu_per_task) < 1:
raise RuntimeError(
"The spark cluster does not have the necessary GPU"
+ "configuration for the spark task. Therefore, we cannot"
+ "run xgboost training using GPU."
)
if int(gpu_per_task) > 1:
get_logger(self.__class__.__name__).warning(
"You configured %s GPU cores for each spark task, but in "
"XGBoost training, every Spark task will only use one GPU core.",
gpu_per_task
)
def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
features_col_datatype = dataset.schema[features_col_name].dataType
features_col = col(features_col_name)
if isinstance(features_col_datatype, ArrayType):
if not isinstance(
features_col_datatype.elementType,
(DoubleType, FloatType, LongType, IntegerType, ShortType),
):
raise ValueError(
"If feature column is array type, its elements must be number type."
)
features_array_col = features_col.cast(ArrayType(FloatType())).alias("values")
elif isinstance(features_col_datatype, VectorUDT):
features_array_col = vector_to_array(features_col, dtype="float32").alias(
"values"
)
else:
raise ValueError(
"feature column must be array type or `pyspark.ml.linalg.Vector` type, "
"if you want to use multiple numetric columns as features, please use "
"`pyspark.ml.transform.VectorAssembler` to assemble them into a vector "
"type column first."
)
return features_array_col
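The error message above points users at VectorAssembler for combining several numeric columns; a hedged sketch of that preprocessing step (the DataFrame `df_raw` and its column names are illustrative only):

from pyspark.ml.feature import VectorAssembler

# Assemble raw numeric columns into the single vector column expected by featuresCol.
assembler = VectorAssembler(inputCols=["f1", "f2", "f3"], outputCol="features")
df_assembled = assembler.transform(df_raw)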
class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
def __init__(self):
super().__init__()
self._set_xgb_params_default()
self._set_fit_params_default()
self._set_predict_params_default()
# Note: The default value for arbitrary_params_dict must always be empty dict.
# For additional settings added into "arbitrary_params_dict" by default,
# they are added in `setParams`.
self._setDefault(
num_workers=1,
use_gpu=False,
force_repartition=False,
feature_names=None,
feature_types=None,
arbitrary_params_dict={},
)
def setParams(self, **kwargs): # pylint: disable=invalid-name
"""
Set params for the estimator.
"""
_extra_params = {}
if "arbitrary_params_dict" in kwargs:
raise ValueError("Invalid param name: 'arbitrary_params_dict'.")
for k, v in kwargs.items():
if k in _inverse_pyspark_param_alias_map:
raise ValueError(
f"Please use param name {_inverse_pyspark_param_alias_map[k]} instead."
)
if k in _pyspark_param_alias_map:
real_k = _pyspark_param_alias_map[k]
if real_k in kwargs:
raise ValueError(
f"You should set only one of param '{k}' and '{real_k}'"
)
k = real_k
if self.hasParam(k):
self._set(**{str(k): v})
else:
if (
k in _unsupported_xgb_params
or k in _unsupported_fit_params
or k in _unsupported_predict_params
):
raise ValueError(f"Unsupported param '{k}'.")
_extra_params[k] = v
_existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
@classmethod
def _pyspark_model_cls(cls):
"""
Subclasses should override this method and
return a _SparkXGBModel subclass.
"""
raise NotImplementedError()
def _create_pyspark_model(self, xgb_model):
return self._pyspark_model_cls()(xgb_model)
def _convert_to_sklearn_model(self, booster):
xgb_sklearn_params = self._gen_xgb_params_dict(
gen_xgb_sklearn_estimator_param=True
)
sklearn_model = self._xgb_cls()(**xgb_sklearn_params)
sklearn_model._Booster = booster
return sklearn_model
def _query_plan_contains_valid_repartition(self, dataset):
"""
Returns true if the latest element in the logical plan is a valid repartition.
The logical plan string format is like:
== Optimized Logical Plan ==
Repartition 4, true
+- LogicalRDD [features#12, label#13L], false
i.e., the top line in the logical plan is the last operation to execute.
So, in this method, we check whether the first line is a "Repartition" operation
and whether the resulting dataframe has the same number of partitions as the
num_workers param. If both hold, the dataframe is already well repartitioned
and we don't need to repartition it again.
"""
num_partitions = dataset.rdd.getNumPartitions()
query_plan = dataset._sc._jvm.PythonSQLUtils.explainString(
dataset._jdf.queryExecution(), "extended"
)
start = query_plan.index("== Optimized Logical Plan ==")
start += len("== Optimized Logical Plan ==") + 1
num_workers = self.getOrDefault(self.num_workers)
if (
query_plan[start : start + len("Repartition")] == "Repartition"
and num_workers == num_partitions
):
return True
return False
def _repartition_needed(self, dataset):
"""
We repartition the dataset if the number of workers is not equal to the number of
partitions. There is also a check to make sure there was "active partitioning"
where either Round Robin or Hash partitioning was actively used before this stage.
"""
if self.getOrDefault(self.force_repartition):
return True
try:
if self._query_plan_contains_valid_repartition(dataset):
return False
except Exception: # pylint: disable=broad-except
pass
return True
def _get_distributed_train_params(self, dataset):
"""
This just gets the configuration params for distributed xgboost
"""
params = self._gen_xgb_params_dict()
fit_params = self._gen_fit_params_dict()
verbose_eval = fit_params.pop("verbose", None)
params.update(fit_params)
params["verbose_eval"] = verbose_eval
classification = self._xgb_cls() == XGBClassifier
num_classes = int(dataset.select(countDistinct("label")).collect()[0][0])
if classification and num_classes == 2:
params["objective"] = "binary:logistic"
elif classification and num_classes > 2:
params["objective"] = "multi:softprob"
params["num_class"] = num_classes
else:
params["objective"] = "reg:squarederror"
# TODO: support "num_parallel_tree" for random forest
params["num_boost_round"] = self.getOrDefault(self.n_estimators)
if self.getOrDefault(self.use_gpu):
params["tree_method"] = "gpu_hist"
return params
@classmethod
def _get_xgb_train_call_args(cls, train_params):
xgb_train_default_args = _get_default_params_from_func(xgboost.train, {})
booster_params, kwargs_params = {}, {}
for key, value in train_params.items():
if key in xgb_train_default_args:
kwargs_params[key] = value
else:
booster_params[key] = value
return booster_params, kwargs_params
def _fit(self, dataset):
# pylint: disable=too-many-statements, too-many-locals
self._validate_params()
label_col = col(self.getOrDefault(self.labelCol)).alias("label")
features_array_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)
select_cols = [features_array_col, label_col]
has_weight = False
has_validation = False
has_base_margin = False
if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol):
has_weight = True
select_cols.append(col(self.getOrDefault(self.weightCol)).alias("weight"))
if self.isDefined(self.validationIndicatorCol) and self.getOrDefault(
self.validationIndicatorCol
):
has_validation = True
select_cols.append(
col(self.getOrDefault(self.validationIndicatorCol)).alias(
"validationIndicator"
)
)
if self.isDefined(self.base_margin_col) and self.getOrDefault(
self.base_margin_col
):
has_base_margin = True
select_cols.append(
col(self.getOrDefault(self.base_margin_col)).alias("baseMargin")
)
dataset = dataset.select(*select_cols)
num_workers = self.getOrDefault(self.num_workers)
sc = _get_spark_session().sparkContext
max_concurrent_tasks = _get_max_num_concurrent_tasks(sc)
if num_workers > max_concurrent_tasks:
get_logger(self.__class__.__name__).warning(
"The num_workers %s set for xgboost distributed "
"training is greater than current max number of concurrent "
"spark task slots, you need wait until more task slots available "
"or you need increase spark cluster workers.",
num_workers
)
if self._repartition_needed(dataset):
dataset = dataset.repartition(num_workers)
train_params = self._get_distributed_train_params(dataset)
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
train_params
)
cpu_per_task = int(
_get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1")
)
dmatrix_kwargs = {
"nthread": cpu_per_task,
"feature_types": self.getOrDefault(self.feature_types),
"feature_names": self.getOrDefault(self.feature_names),
"feature_weights": self.getOrDefault(self.feature_weights),
"missing": self.getOrDefault(self.missing),
}
booster_params["nthread"] = cpu_per_task
use_gpu = self.getOrDefault(self.use_gpu)
def _train_booster(pandas_df_iter):
"""
Takes in an RDD partition and outputs a booster for that partition after going through
the Rabit Ring protocol
"""
from pyspark import BarrierTaskContext
context = BarrierTaskContext.get()
context.barrier()
if use_gpu:
# Set booster worker to use the first GPU allocated to the spark task.
booster_params["gpu_id"] = int(
context._resources["gpu"].addresses[0].strip()
)
_rabit_args = ""
if context.partitionId() == 0:
_rabit_args = str(_get_rabit_args(context, num_workers))
messages = context.allGather(message=str(_rabit_args))
_rabit_args = _get_args_from_message_list(messages)
evals_result = {}
with RabitContext(_rabit_args, context):
dtrain, dval = None, []
if has_validation:
dtrain, dval = _convert_partition_data_to_dmatrix(
pandas_df_iter,
has_weight,
has_validation,
has_base_margin,
dmatrix_kwargs=dmatrix_kwargs,
)
# TODO: Question: do we need to add dtrain to dval list ?
dval = [(dtrain, "training"), (dval, "validation")]
else:
dtrain = _convert_partition_data_to_dmatrix(
pandas_df_iter,
has_weight,
has_validation,
has_base_margin,
dmatrix_kwargs=dmatrix_kwargs,
)
booster = worker_train(
params=booster_params,
dtrain=dtrain,
evals=dval,
evals_result=evals_result,
**train_call_kwargs_params,
)
context.barrier()
if context.partitionId() == 0:
yield pd.DataFrame(data={"booster_bytes": [cloudpickle.dumps(booster)]})
result_ser_booster = (
dataset.mapInPandas(_train_booster, schema="booster_bytes binary")
.rdd.barrier()
.mapPartitions(lambda x: x)
.collect()[0][0]
)
result_xgb_model = self._convert_to_sklearn_model(
cloudpickle.loads(result_ser_booster)
)
return self._copyValues(self._create_pyspark_model(result_xgb_model))
def write(self):
"""
Return the writer for saving the estimator.
"""
return SparkXGBWriter(self)
@classmethod
def read(cls):
"""
Return the reader for loading the estimator.
"""
return SparkXGBReader(cls)
class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
def __init__(self, xgb_sklearn_model=None):
super().__init__()
self._xgb_sklearn_model = xgb_sklearn_model
def get_booster(self):
"""
Return the `xgboost.core.Booster` instance.
"""
return self._xgb_sklearn_model.get_booster()
def get_feature_importances(self, importance_type="weight"):
"""Get feature importance of each feature.
Importance type can be defined as:
* 'weight': the number of times a feature is used to split the data across all trees.
* 'gain': the average gain across all splits the feature is used in.
* 'cover': the average coverage across all splits the feature is used in.
* 'total_gain': the total gain across all splits the feature is used in.
* 'total_cover': the total coverage across all splits the feature is used in.
.. note:: Feature importance is defined only for tree boosters
Feature importance is only defined when the decision tree model is chosen as base
learner (`booster=gbtree`). It is not defined for other base learner types, such
as linear learners (`booster=gblinear`).
Parameters
----------
importance_type: str, default 'weight'
One of the importance types defined above.
"""
return self.get_booster().get_score(importance_type=importance_type)
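For example, on a fitted model the importances come back as a plain dict (a sketch; `model` is assumed to be a fitted SparkXGBRegressorModel or SparkXGBClassifierModel):

importances = model.get_feature_importances(importance_type="gain")
print(importances)  # e.g. {"f0": 123.4, "f2": 56.7}; keys follow the booster's feature names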
def write(self):
"""
Return the writer for saving the model.
"""
return SparkXGBModelWriter(self)
@classmethod
def read(cls):
"""
Return the reader for loading the model.
"""
return SparkXGBModelReader(cls)
def _transform(self, dataset):
raise NotImplementedError()
class SparkXGBRegressorModel(_SparkXGBModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls):
return XGBRegressor
def _transform(self, dataset):
# Save xgb_sklearn_model and predict_params as local variables
# to avoid pickling the `self` object and shipping it to remote workers.
xgb_sklearn_model = self._xgb_sklearn_model
predict_params = self._gen_predict_params_dict()
has_base_margin = False
if self.isDefined(self.base_margin_col) and self.getOrDefault(
self.base_margin_col
):
has_base_margin = True
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
"baseMargin"
)
@pandas_udf("double")
def predict_udf(input_data: pd.DataFrame) -> pd.Series:
X = np.array(input_data["values"].tolist())
if has_base_margin:
base_margin = input_data["baseMargin"].to_numpy()
else:
base_margin = None
preds = xgb_sklearn_model.predict(
X, base_margin=base_margin, validate_features=False, **predict_params
)
return pd.Series(preds)
features_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)
if has_base_margin:
pred_col = predict_udf(struct(features_col, base_margin_col))
else:
pred_col = predict_udf(struct(features_col))
predictionColName = self.getOrDefault(self.predictionCol)
return dataset.withColumn(predictionColName, pred_col)
class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol):
"""
The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit`
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls):
return XGBClassifier
def _transform(self, dataset):
# Save xgb_sklearn_model and predict_params as local variables
# to avoid pickling the `self` object and shipping it to remote workers.
xgb_sklearn_model = self._xgb_sklearn_model
predict_params = self._gen_predict_params_dict()
has_base_margin = False
if self.isDefined(self.base_margin_col) and self.getOrDefault(
self.base_margin_col
):
has_base_margin = True
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
"baseMargin"
)
@pandas_udf(
"rawPrediction array<double>, prediction double, probability array<double>"
)
def predict_udf(input_data: pd.DataFrame) -> pd.DataFrame:
X = np.array(input_data["values"].tolist())
if has_base_margin:
base_margin = input_data["baseMargin"].to_numpy()
else:
base_margin = None
margins = xgb_sklearn_model.predict(
X,
base_margin=base_margin,
output_margin=True,
validate_features=False,
**predict_params,
)
if margins.ndim == 1:
# binomial case
classone_probs = expit(margins)
classzero_probs = 1.0 - classone_probs
raw_preds = np.vstack((-margins, margins)).transpose()
class_probs = np.vstack((classzero_probs, classone_probs)).transpose()
else:
# multinomial case
raw_preds = margins
class_probs = softmax(raw_preds, axis=1)
# It seems the Scala implementation uses argmax of class probabilities,
# not of margins, to get the prediction.
preds = np.argmax(class_probs, axis=1)
return pd.DataFrame(
data={
"rawPrediction": pd.Series(raw_preds.tolist()),
"prediction": pd.Series(preds),
"probability": pd.Series(class_probs.tolist()),
}
)
features_col = _validate_and_convert_feature_col_as_array_col(
dataset, self.getOrDefault(self.featuresCol)
)
if has_base_margin:
pred_struct = predict_udf(struct(features_col, base_margin_col))
else:
pred_struct = predict_udf(struct(features_col))
pred_struct_col = "_prediction_struct"
rawPredictionColName = self.getOrDefault(self.rawPredictionCol)
predictionColName = self.getOrDefault(self.predictionCol)
probabilityColName = self.getOrDefault(self.probabilityCol)
dataset = dataset.withColumn(pred_struct_col, pred_struct)
if rawPredictionColName:
dataset = dataset.withColumn(
rawPredictionColName,
array_to_vector(col(pred_struct_col).rawPrediction),
)
if predictionColName:
dataset = dataset.withColumn(
predictionColName, col(pred_struct_col).prediction
)
if probabilityColName:
dataset = dataset.withColumn(
probabilityColName, array_to_vector(col(pred_struct_col).probability)
)
return dataset.drop(pred_struct_col)
def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, pyspark_model_class):
params_dict = pyspark_estimator_class._get_xgb_params_default()
def param_value_converter(v):
if isinstance(v, np.generic):
# convert numpy scalar values to corresponding python scalar values
return np.array(v).item()
if isinstance(v, dict):
return {k: param_value_converter(nv) for k, nv in v.items()}
if isinstance(v, list):
return [param_value_converter(nv) for nv in v]
return v
def set_param_attrs(attr_name, param_obj_):
param_obj_.typeConverter = param_value_converter
setattr(pyspark_estimator_class, attr_name, param_obj_)
setattr(pyspark_model_class, attr_name, param_obj_)
for name in params_dict.keys():
doc = (
f"Refer to XGBoost doc of "
f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}"
)
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)
fit_params_dict = pyspark_estimator_class._get_fit_params_default()
for name in fit_params_dict.keys():
doc = (
f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
f".fit() for this param {name}"
)
if name == "callbacks":
doc += (
"The callbacks can be arbitrary functions. It is saved using cloudpickle "
"which is not a fully self-contained format. It may fail to load with "
"different versions of dependencies."
)
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)
predict_params_dict = pyspark_estimator_class._get_predict_params_default()
for name in predict_params_dict.keys():
doc = (
f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
f".predict() for this param {name}"
)
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)
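The net effect of `_set_pyspark_xgb_cls_param_attrs` is that every supported XGBoost keyword becomes a regular Spark `Param` on the estimator and model classes, so standard Params introspection works on them. A small sketch, assuming the estimator classes from the estimator submodule:

from xgboost.spark import SparkXGBRegressor

reg = SparkXGBRegressor(max_depth=4)
print(reg.explainParam("max_depth"))  # generated doc string plus current/default value
print(reg.getOrDefault(reg.getParam("learning_rate")))  # XGBoost default surfaced as a Param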

View File

@@ -0,0 +1,192 @@
# type: ignore
"""Xgboost pyspark integration submodule for data related functions."""
# pylint: disable=too-many-arguments
from typing import Iterator
import numpy as np
import pandas as pd
from xgboost import DMatrix
def _prepare_train_val_data(
data_iterator, has_weight, has_validation, has_fit_base_margin
):
def gen_data_pdf():
for pdf in data_iterator:
yield pdf
return _process_data_iter(
gen_data_pdf(),
train=True,
has_weight=has_weight,
has_validation=has_validation,
has_fit_base_margin=has_fit_base_margin,
has_predict_base_margin=False,
)
def _check_feature_dims(num_dims, expected_dims):
"""
Check that all feature vectors have the same dimension.
"""
if expected_dims is None:
return num_dims
if num_dims != expected_dims:
raise ValueError(
f"Rows contain different feature dimensions: Expecting {expected_dims}, got {num_dims}."
)
return expected_dims
def _row_tuple_list_to_feature_matrix_y_w(
data_iterator,
train,
has_weight,
has_fit_base_margin,
has_predict_base_margin,
has_validation: bool = False,
):
"""
Construct a feature matrix in ndarray format, a label array y and a weight array w
from the row_tuple_list.
If train == False, y and w will be None.
If has_weight == False, w will be None.
If has_base_margin == False, b_m will be None.
Note: the row_tuple_list will be cleared during execution
to reduce peak memory consumption.
"""
# pylint: disable=too-many-locals
expected_feature_dims = None
label_list, weight_list, base_margin_list = [], [], []
label_val_list, weight_val_list, base_margin_val_list = [], [], []
values_list, values_val_list = [], []
# Process rows
for pdf in data_iterator:
if len(pdf) == 0:
continue
if train and has_validation:
pdf_val = pdf.loc[pdf["validationIndicator"], :]
pdf = pdf.loc[~pdf["validationIndicator"], :]
num_feature_dims = len(pdf["values"].values[0])
expected_feature_dims = _check_feature_dims(
num_feature_dims, expected_feature_dims
)
# Note: each element in `pdf["values"]` is a numpy array.
values_list.append(pdf["values"].to_list())
if train:
label_list.append(pdf["label"].to_numpy())
if has_weight:
weight_list.append(pdf["weight"].to_numpy())
if has_fit_base_margin or has_predict_base_margin:
base_margin_list.append(pdf["baseMargin"].to_numpy())
if has_validation:
values_val_list.append(pdf_val["values"].to_list())
if train:
label_val_list.append(pdf_val["label"].to_numpy())
if has_weight:
weight_val_list.append(pdf_val["weight"].to_numpy())
if has_fit_base_margin or has_predict_base_margin:
base_margin_val_list.append(pdf_val["baseMargin"].to_numpy())
# Construct feature_matrix
if expected_feature_dims is None:
return [], [], [], []
# Construct feature_matrix, y and w
feature_matrix = np.concatenate(values_list)
y = np.concatenate(label_list) if train else None
w = np.concatenate(weight_list) if has_weight else None
b_m = (
np.concatenate(base_margin_list)
if (has_fit_base_margin or has_predict_base_margin)
else None
)
if has_validation:
feature_matrix_val = np.concatenate(values_val_list)
y_val = np.concatenate(label_val_list) if train else None
w_val = np.concatenate(weight_val_list) if has_weight else None
b_m_val = (
np.concatenate(base_margin_val_list)
if (has_fit_base_margin or has_predict_base_margin)
else None
)
return feature_matrix, y, w, b_m, feature_matrix_val, y_val, w_val, b_m_val
return feature_matrix, y, w, b_m
def _process_data_iter(
data_iterator: Iterator[pd.DataFrame],
train: bool,
has_weight: bool,
has_validation: bool,
has_fit_base_margin: bool = False,
has_predict_base_margin: bool = False,
):
"""
If the input is for training and has_validation=True, it will split the train data into
a train dataset and a validation dataset, and return (train_X, train_y, train_w,
train_b_m <- train base margin, val_X, val_y, val_w, val_b_m <- validation base margin);
otherwise it returns (X, y, w, b_m <- base margin).
"""
return _row_tuple_list_to_feature_matrix_y_w(
data_iterator,
train,
has_weight,
has_fit_base_margin,
has_predict_base_margin,
has_validation,
)
def _convert_partition_data_to_dmatrix(
partition_data_iter,
has_weight,
has_validation,
has_base_margin,
dmatrix_kwargs=None,
):
# pylint: disable=too-many-locals, unbalanced-tuple-unpacking
dmatrix_kwargs = dmatrix_kwargs or {}
# if we are not using external storage, we use the standard method of parsing data.
train_val_data = _prepare_train_val_data(
partition_data_iter, has_weight, has_validation, has_base_margin
)
if has_validation:
(
train_x,
train_y,
train_w,
train_b_m,
val_x,
val_y,
val_w,
val_b_m,
) = train_val_data
training_dmatrix = DMatrix(
data=train_x,
label=train_y,
weight=train_w,
base_margin=train_b_m,
**dmatrix_kwargs,
)
val_dmatrix = DMatrix(
data=val_x,
label=val_y,
weight=val_w,
base_margin=val_b_m,
**dmatrix_kwargs,
)
return training_dmatrix, val_dmatrix
train_x, train_y, train_w, train_b_m = train_val_data
training_dmatrix = DMatrix(
data=train_x,
label=train_y,
weight=train_w,
base_margin=train_b_m,
**dmatrix_kwargs,
)
return training_dmatrix
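A hedged sketch of exercising `_convert_partition_data_to_dmatrix` on its own, feeding it an iterator of pandas DataFrames shaped like the columns selected in `_fit` (the toy values are illustrative):

import pandas as pd

pdf = pd.DataFrame(
    {
        "values": [[1.0, 2.0], [3.0, 4.0]],  # one feature array per row
        "label": [0.0, 1.0],
    }
)
dtrain = _convert_partition_data_to_dmatrix(
    iter([pdf]), has_weight=False, has_validation=False, has_base_margin=False
)
print(dtrain.num_row(), dtrain.num_col())  # 2 2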

View File

@@ -0,0 +1,203 @@
# type: ignore
"""Xgboost pyspark integration submodule for estimator API."""
# pylint: disable=too-many-ancestors
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
from xgboost import XGBClassifier, XGBRegressor
from .core import (
_SparkXGBEstimator,
SparkXGBClassifierModel,
SparkXGBRegressorModel,
_set_pyspark_xgb_cls_param_attrs,
)
class SparkXGBRegressor(_SparkXGBEstimator):
"""
SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression
algorithm based on the XGBoost python library, and it can be used in PySpark Pipelines
and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest.
SparkXGBRegressor automatically supports most of the parameters in the
`xgboost.XGBRegressor` constructor and most of the parameters used in the
`xgboost.XGBRegressor` fit and predict methods (see `API docs <https://xgboost.readthedocs\
.io/en/latest/python/python_api.html#xgboost.XGBRegressor>`_ for details).
SparkXGBRegressor doesn't support setting `gpu_id` but supports another param `use_gpu`;
see doc below for more details.
SparkXGBRegressor doesn't support setting `base_margin` explicitly either, but supports
another param called `base_margin_col`; see doc below for more details.
SparkXGBRegressor doesn't support the `validate_features` and `output_margin` params.
Parameters
----------
callbacks:
The export and import of the callback functions are at best effort.
For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc.
validationIndicatorCol:
For params related to `xgboost.XGBRegressor` training
with evaluation dataset's supervision, set
:py:attr:`xgboost.spark.SparkXGBRegressor.validationIndicatorCol`
parameter instead of setting the `eval_set` parameter in `xgboost.XGBRegressor`
fit method.
weightCol:
To specify the weight of the training and validation dataset, set
:py:attr:`xgboost.spark.SparkXGBRegressor.weightCol` parameter instead of setting
`sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBRegressor`
fit method.
xgb_model:
Set the value to be the instance returned by
:func:`xgboost.spark.SparkXGBRegressorModel.get_booster`.
num_workers:
Integer that specifies the number of XGBoost workers to use.
Each XGBoost worker corresponds to one spark task.
use_gpu:
Boolean that specifies whether the executors are running on GPU
instances.
base_margin_col:
To specify the base margins of the training and validation
dataset, set :py:attr:`xgboost.spark.SparkXGBRegressor.base_margin_col` parameter
instead of setting `base_margin` and `base_margin_eval_set` in the
`xgboost.XGBRegressor` fit method. Note: this isn't available for distributed
training.
.. Note:: The Parameters chart above contains parameters that need special handling.
For a full list of parameters, see entries with `Param(parent=...` below.
.. Note:: This API is experimental.
**Examples**
>>> from xgboost.spark import SparkXGBRegressor
>>> from pyspark.ml.linalg import Vectors
>>> df_train = spark.createDataFrame([
... (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
... (Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0),
... (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0),
... ], ["features", "label", "isVal", "weight"])
>>> df_test = spark.createDataFrame([
... (Vectors.dense(1.0, 2.0, 3.0), ),
... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), )
... ], ["features"])
>>> xgb_regressor = SparkXGBRegressor(max_depth=5, missing=0.0,
... validation_indicator_col='isVal', weight_col='weight',
... early_stopping_rounds=1, eval_metric='rmse')
>>> xgb_reg_model = xgb_regressor.fit(df_train)
>>> xgb_reg_model.transform(df_test)
"""
def __init__(self, **kwargs):
super().__init__()
self.setParams(**kwargs)
@classmethod
def _xgb_cls(cls):
return XGBRegressor
@classmethod
def _pyspark_model_cls(cls):
return SparkXGBRegressorModel
_set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)
class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
"""
SparkXGBClassifier is a PySpark ML estimator. It implements the XGBoost classification
algorithm based on the XGBoost python library, and it can be used in PySpark Pipelines
and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest.
SparkXGBClassifier automatically supports most of the parameters in the
`xgboost.XGBClassifier` constructor and most of the parameters used in the
`xgboost.XGBClassifier` fit and predict methods (see `API docs <https://xgboost.readthedocs\
.io/en/latest/python/python_api.html#xgboost.XGBClassifier>`_ for details).
SparkXGBClassifier doesn't support setting `gpu_id` but supports another param `use_gpu`;
see doc below for more details.
SparkXGBClassifier doesn't support setting `base_margin` explicitly either, but supports
another param called `base_margin_col`; see doc below for more details.
SparkXGBClassifier doesn't support setting `output_margin`, but the output margin is
available from the raw prediction column; see the `rawPredictionCol` param doc below
for more details.
SparkXGBClassifier doesn't support the `validate_features` param.
Parameters
----------
callbacks:
The export and import of the callback functions are at best effort. For
details, see :py:attr:`xgboost.spark.SparkXGBClassifier.callbacks` param doc.
rawPredictionCol:
The `output_margin=True` is implicitly supported by the
`rawPredictionCol` output column, which is always returned with the predicted margin
values.
validationIndicatorCol:
For params related to `xgboost.XGBClassifier` training with
evaluation dataset's supervision,
set :py:attr:`xgboost.spark.SparkXGBClassifier.validationIndicatorCol`
parameter instead of setting the `eval_set` parameter in `xgboost.XGBClassifier`
fit method.
weightCol:
To specify the weight of the training and validation dataset, set
:py:attr:`xgboost.spark.SparkXGBClassifier.weightCol` parameter instead of setting
`sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBClassifier`
fit method.
xgb_model:
Set the value to be the instance returned by
:func:`xgboost.spark.SparkXGBClassifierModel.get_booster`.
num_workers:
Integer that specifies the number of XGBoost workers to use.
Each XGBoost worker corresponds to one spark task.
use_gpu:
Boolean that specifies whether the executors are running on GPU
instances.
base_margin_col:
To specify the base margins of the training and validation
dataset, set :py:attr:`xgboost.spark.SparkXGBClassifier.base_margin_col` parameter
instead of setting `base_margin` and `base_margin_eval_set` in the
`xgboost.XGBClassifier` fit method. Note: this isn't available for distributed
training.
.. Note:: The Parameters chart above contains parameters that need special handling.
For a full list of parameters, see entries with `Param(parent=...` below.
.. Note:: This API is experimental.
**Examples**
>>> from xgboost.spark import SparkXGBClassifier
>>> from pyspark.ml.linalg import Vectors
>>> df_train = spark.createDataFrame([
... (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
... (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
... (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
... ], ["features", "label", "isVal", "weight"])
>>> df_test = spark.createDataFrame([
... (Vectors.dense(1.0, 2.0, 3.0), ),
... ], ["features"])
>>> xgb_classifier = SparkXGBClassifier(max_depth=5, missing=0.0,
... validation_indicator_col='isVal', weight_col='weight',
... early_stopping_rounds=1, eval_metric='logloss')
>>> xgb_clf_model = xgb_classifier.fit(df_train)
>>> xgb_clf_model.transform(df_test).show()
"""
def __init__(self, **kwargs):
super().__init__()
self.setParams(**kwargs)
@classmethod
def _xgb_cls(cls):
return XGBClassifier
@classmethod
def _pyspark_model_cls(cls):
return SparkXGBClassifierModel
_set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)

View File

@@ -0,0 +1,270 @@
# type: ignore
"""Xgboost pyspark integration submodule for model API."""
# pylint: disable=fixme, invalid-name, protected-access, too-few-public-methods
import base64
import os
import uuid
from pyspark import cloudpickle
from pyspark import SparkFiles
from pyspark.sql import SparkSession
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter
from xgboost.core import Booster
from .utils import get_logger, get_class_name
def _get_or_create_tmp_dir():
root_dir = SparkFiles.getRootDirectory()
xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp")
if not os.path.exists(xgb_tmp_dir):
os.makedirs(xgb_tmp_dir)
return xgb_tmp_dir
def serialize_xgb_model(model):
"""
Serialize the input model to a string.
Parameters
----------
model:
an xgboost.XGBModel instance, such as
xgboost.XGBClassifier or xgboost.XGBRegressor instance
"""
# TODO: change to use string io
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
model.save_model(tmp_file_name)
with open(tmp_file_name, "r", encoding="utf-8") as f:
ser_model_string = f.read()
return ser_model_string
def deserialize_xgb_model(ser_model_string, xgb_model_creator):
"""
Deserialize an xgboost.XGBModel instance from the input ser_model_string.
"""
xgb_model = xgb_model_creator()
# TODO: change to use string io
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
with open(tmp_file_name, "w", encoding="utf-8") as f:
f.write(ser_model_string)
xgb_model.load_model(tmp_file_name)
return xgb_model
def serialize_booster(booster):
"""
Serialize the input booster to a string.
Parameters
----------
booster:
an xgboost.core.Booster instance
"""
# TODO: change to use string io
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
booster.save_model(tmp_file_name)
with open(tmp_file_name, encoding="utf-8") as f:
ser_model_string = f.read()
return ser_model_string
def deserialize_booster(ser_model_string):
"""
Deserialize an xgboost.core.Booster from the input ser_model_string.
"""
booster = Booster()
# TODO: change to use string io
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
with open(tmp_file_name, "w", encoding="utf-8") as f:
f.write(ser_model_string)
booster.load_model(tmp_file_name)
return booster
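The four helpers above are symmetric, so a round trip recovers the original model; a quick sketch (assuming an active SparkSession, since the temp directory comes from SparkFiles, and a trained `xgboost.XGBRegressor` named `reg`):

from xgboost import XGBRegressor

ser = serialize_xgb_model(reg)                    # JSON text
clone = deserialize_xgb_model(ser, XGBRegressor)  # fresh sklearn-style model loaded from that text
booster_text = serialize_booster(reg.get_booster())
booster = deserialize_booster(booster_text)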
_INIT_BOOSTER_SAVE_PATH = "init_booster.json"
def _get_spark_session():
return SparkSession.builder.getOrCreate()
class _SparkXGBSharedReadWrite:
@staticmethod
def saveMetadata(instance, path, sc, logger, extraMetadata=None):
"""
Save the metadata of an xgboost.spark._SparkXGBEstimator or
xgboost.spark._SparkXGBModel.
"""
instance._validate_params()
skipParams = ["callbacks", "xgb_model"]
jsonParams = {}
for p, v in instance._paramMap.items(): # pylint: disable=protected-access
if p.name not in skipParams:
jsonParams[p.name] = v
extraMetadata = extraMetadata or {}
callbacks = instance.getOrDefault(instance.callbacks)
if callbacks is not None:
logger.warning(
"The callbacks parameter is saved using cloudpickle and it "
"is not a fully self-contained format. It may fail to load "
"with different versions of dependencies."
)
serialized_callbacks = base64.encodebytes(
cloudpickle.dumps(callbacks)
).decode("ascii")
extraMetadata["serialized_callbacks"] = serialized_callbacks
init_booster = instance.getOrDefault(instance.xgb_model)
if init_booster is not None:
extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
DefaultParamsWriter.saveMetadata(
instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
)
if init_booster is not None:
ser_init_booster = serialize_booster(init_booster)
save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
_get_spark_session().createDataFrame(
[(ser_init_booster,)], ["init_booster"]
).write.parquet(save_path)
@staticmethod
def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
"""
Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
xgboost.spark._SparkXGBModel.
:return: a tuple of (metadata, instance)
"""
metadata = DefaultParamsReader.loadMetadata(
path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
)
pyspark_xgb = pyspark_xgb_cls()
DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
if "serialized_callbacks" in metadata:
serialized_callbacks = metadata["serialized_callbacks"]
try:
callbacks = cloudpickle.loads(
base64.decodebytes(serialized_callbacks.encode("ascii"))
)
pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
except Exception as e: # pylint: disable=W0703
logger.warning(
f"Fails to load the callbacks param due to {e}. Please set the "
"callbacks param manually for the loaded estimator."
)
if "init_booster" in metadata:
load_path = os.path.join(path, metadata["init_booster"])
ser_init_booster = (
_get_spark_session().read.parquet(load_path).collect()[0].init_booster
)
init_booster = deserialize_booster(ser_init_booster)
pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
pyspark_xgb._resetUid(metadata["uid"]) # pylint: disable=protected-access
return metadata, pyspark_xgb
class SparkXGBWriter(MLWriter):
"""
Spark Xgboost estimator writer.
"""
def __init__(self, instance):
super().__init__()
self.instance = instance
self.logger = get_logger(self.__class__.__name__, level="WARN")
def saveImpl(self, path):
"""
Save the estimator.
"""
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
class SparkXGBReader(MLReader):
"""
Spark Xgboost estimator reader.
"""
def __init__(self, cls):
super().__init__()
self.cls = cls
self.logger = get_logger(self.__class__.__name__, level="WARN")
def load(self, path):
"""
Load the estimator.
"""
_, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
self.cls, path, self.sc, self.logger
)
return pyspark_xgb
class SparkXGBModelWriter(MLWriter):
"""
Spark Xgboost model writer.
"""
def __init__(self, instance):
super().__init__()
self.instance = instance
self.logger = get_logger(self.__class__.__name__, level="WARN")
def saveImpl(self, path):
"""
Save metadata and model for a :py:class:`_SparkXGBModel`
- save metadata to path/metadata
- save model to path/model.json
"""
xgb_model = self.instance._xgb_sklearn_model
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
model_save_path = os.path.join(path, "model.json")
ser_xgb_model = serialize_xgb_model(xgb_model)
_get_spark_session().createDataFrame(
[(ser_xgb_model,)], ["xgb_sklearn_model"]
).write.parquet(model_save_path)
class SparkXGBModelReader(MLReader):
"""
Spark Xgboost model reader.
"""
def __init__(self, cls):
super().__init__()
self.cls = cls
self.logger = get_logger(self.__class__.__name__, level="WARN")
def load(self, path):
"""
Load metadata and model for a :py:class:`_SparkXGBModel`
:return: SparkXGBRegressorModel or SparkXGBClassifierModel instance
"""
_, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
self.cls, path, self.sc, self.logger
)
xgb_sklearn_params = py_model._gen_xgb_params_dict(gen_xgb_sklearn_estimator_param=True)
model_load_path = os.path.join(path, "model.json")
ser_xgb_model = (
_get_spark_session()
.read.parquet(model_load_path)
.collect()[0]
.xgb_sklearn_model
)
def create_xgb_model():
return self.cls._xgb_cls()(**xgb_sklearn_params)
xgb_model = deserialize_xgb_model(
ser_xgb_model, create_xgb_model
)
py_model._xgb_sklearn_model = xgb_model
return py_model
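Putting the writer/reader pairs together, persistence follows the usual PySpark ML pattern; a hedged sketch (the path and the fitted `model` are illustrative, and SparkXGBRegressorModel comes from xgboost.spark):

model.write().overwrite().save("/tmp/xgb_spark_model")
loaded = SparkXGBRegressorModel.load("/tmp/xgb_spark_model")  # MLReadable.load -> SparkXGBModelReader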

View File

@@ -0,0 +1,33 @@
# type: ignore
"""Xgboost pyspark integration submodule for params."""
# pylint: disable=too-few-public-methods
from pyspark.ml.param.shared import Param, Params
class HasArbitraryParamsDict(Params):
"""
This is a Params based class that is extended by _SparkXGBParams
and holds the variable to store the **kwargs part of the XGBoost
input.
"""
arbitrary_params_dict = Param(
Params._dummy(),
"arbitrary_params_dict",
"arbitrary_params_dict This parameter holds all of the additional parameters which are "
"not exposed as the the XGBoost Spark estimator params but can be recognized by "
"underlying XGBoost library. It is stored as a dictionary.",
)
class HasBaseMarginCol(Params):
"""
This is a Params based class that is extended by _SparkXGBParams
and holds the variable to store the base margin column part of XGBoost.
"""
base_margin_col = Param(
Params._dummy(),
"base_margin_col",
"This stores the name for the column of the base margin",
)
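A sketch of how `arbitrary_params_dict` gets populated in practice, per `setParams` in the core submodule; the keyword below is purely hypothetical and stands in for any parameter the estimator does not expose as an explicit Spark param:

from xgboost.spark import SparkXGBClassifier

clf = SparkXGBClassifier(some_future_booster_param=123)  # hypothetical keyword argument
assert clf.getOrDefault(clf.arbitrary_params_dict) == {"some_future_booster_param": 123}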

View File

@@ -0,0 +1,130 @@
# type: ignore
"""Xgboost pyspark integration submodule for helper functions."""
import inspect
from threading import Thread
import sys
import logging
import pyspark
from pyspark.sql.session import SparkSession
from xgboost import rabit
from xgboost.tracker import RabitTracker
def get_class_name(cls):
"""
Return the class name.
"""
return f"{cls.__module__}.{cls.__name__}"
def _get_default_params_from_func(func, unsupported_set):
"""
Returns a dictionary of the parameters of `func` and their default values.
Only the parameters with a default value will be included.
"""
sig = inspect.signature(func)
filtered_params_dict = {}
for parameter in sig.parameters.values():
# Remove parameters without a default value and those in the unsupported_set
if (
parameter.default is not parameter.empty
and parameter.name not in unsupported_set
):
filtered_params_dict[parameter.name] = parameter.default
return filtered_params_dict
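For illustration, a hedged sketch of what this helper returns for a plain function (the function below is made up for the example):

def _demo(a, b=1, c="x"):
    return a

print(_get_default_params_from_func(_demo, unsupported_set={"c"}))  # {'b': 1}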
class RabitContext:
"""
A context controlling rabit initialization and finalization.
This isn't specifically necessary (note Part 3), but it is more understandable coding-wise.
"""
def __init__(self, args, context):
self.args = args
self.args.append(("DMLC_TASK_ID=" + str(context.partitionId())).encode())
def __enter__(self):
rabit.init(self.args)
def __exit__(self, *args):
rabit.finalize()
def _start_tracker(context, n_workers):
"""
Start Rabit tracker with n_workers
"""
env = {"DMLC_NUM_WORKER": n_workers}
host = _get_host_ip(context)
rabit_context = RabitTracker(host_ip=host, n_workers=n_workers)
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join)
thread.daemon = True
thread.start()
return env
def _get_rabit_args(context, n_workers):
"""
Get rabit context arguments to send to each worker.
"""
# pylint: disable=consider-using-f-string
env = _start_tracker(context, n_workers)
rabit_args = [("%s=%s" % item).encode() for item in env.items()]
return rabit_args
def _get_host_ip(context):
"""
Gets the host IP for Spark. This essentially gets the IP of the first worker.
"""
task_ip_list = [info.address.split(":")[0] for info in context.getTaskInfos()]
return task_ip_list[0]
def _get_args_from_message_list(messages):
"""
A function to send/receive messages in barrier context mode.
"""
output = ""
for message in messages:
if message != "":
output = message
break
return [elem.split("'")[1].encode() for elem in output.strip("][").split(", ")]
def _get_spark_session():
"""Get or create spark session. Note: This function can only be invoked from driver side."""
if pyspark.TaskContext.get() is not None:
# This is a safety check.
raise RuntimeError(
"_get_spark_session should not be invoked from executor side."
)
return SparkSession.builder.getOrCreate()
def get_logger(name, level="INFO"):
"""Gets a logger by name, or creates and configures it for the first time."""
logger = logging.getLogger(name)
logger.setLevel(level)
# If the logger is already configured, skip configuring it again.
if not logger.handlers and not logging.getLogger().handlers:
handler = logging.StreamHandler(sys.stderr)
logger.addHandler(handler)
return logger
def _get_max_num_concurrent_tasks(spark_context):
"""Gets the current max number of concurrent tasks."""
# pylint: disable=protected-access
# spark 3.1 and above has a different API for fetching max concurrent tasks
if spark_context._jsc.sc().version() >= "3.1":
return spark_context._jsc.sc().maxNumConcurrentTasks(
spark_context._jsc.sc().resourceProfileManager().resourceProfileFromId(0)
)
return spark_context._jsc.sc().maxNumConcurrentTasks()

View File

@@ -10,7 +10,7 @@ RUN \
     apt-get install -y software-properties-common && \
     add-apt-repository ppa:ubuntu-toolchain-r/test && \
     apt-get update && \
-    apt-get install -y tar unzip wget git build-essential doxygen graphviz llvm libasan2 libidn11 ninja-build gcc-8 g++-8 && \
+    apt-get install -y tar unzip wget git build-essential doxygen graphviz llvm libasan2 libidn11 ninja-build gcc-8 g++-8 openjdk-8-jdk-headless && \
     # CMake
     wget -nv -nc https://cmake.org/files/v3.14/cmake-3.14.0-Linux-x86_64.sh --no-check-certificate && \
     bash cmake-3.14.0-Linux-x86_64.sh --skip-license --prefix=/usr && \
@@ -24,6 +24,7 @@ ENV CXX=g++-8
 ENV CPP=cpp-8
 ENV GOSU_VERSION 1.10
+ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
 # Create new Conda environment
 COPY conda_env/cpu_test.yml /scripts/

View File

@@ -10,7 +10,7 @@ SHELL ["/bin/bash", "-c"] # Use Bash as shell
 RUN \
     apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub && \
     apt-get update && \
-    apt-get install -y wget unzip bzip2 libgomp1 build-essential && \
+    apt-get install -y wget unzip bzip2 libgomp1 build-essential openjdk-8-jdk-headless && \
     # Python
     wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
     bash Miniconda3.sh -b -p /opt/python
@@ -19,11 +19,14 @@ ENV PATH=/opt/python/bin:$PATH
 # Create new Conda environment with cuDF, Dask, and cuPy
 RUN \
-    conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
+    conda install -c conda-forge mamba && \
+    mamba create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
         python=3.8 cudf=22.04* rmm=22.04* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda=22.04* dask-cudf=22.04* cupy \
-        numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis
+        numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis \
+        pyspark cloudpickle cuda-python=11.7.0
 ENV GOSU_VERSION 1.10
+ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
 # Install lightweight sudo (not bound to TTY)
 RUN set -ex; \

View File

@@ -28,6 +28,8 @@ dependencies:
 - llvmlite
 - cffi
 - pyarrow
+- pyspark
+- cloudpickle
 - pip:
   - shap
   - awscli

View File

@@ -36,6 +36,8 @@ dependencies:
 - cffi
 - pyarrow
 - protobuf<=3.20
+- pyspark
+- cloudpickle
 - pip:
   - shap
   - ipython # required by shap at import time.

View File

@@ -35,6 +35,8 @@ dependencies:
 - py-ubjson
 - cffi
 - pyarrow
+- pyspark
+- cloudpickle
 - pip:
   - sphinx_rtd_theme
   - datatable

View File

@ -34,6 +34,18 @@ function install_xgboost {
fi fi
} }
function setup_pyspark_envs {
export PYSPARK_DRIVER_PYTHON=`which python`
export PYSPARK_PYTHON=`which python`
export SPARK_TESTING=1
}
function unset_pyspark_envs {
unset PYSPARK_DRIVER_PYTHON
unset PYSPARK_PYTHON
unset SPARK_TESTING
}
function uninstall_xgboost { function uninstall_xgboost {
pip uninstall -y xgboost pip uninstall -y xgboost
} }
@ -43,14 +55,18 @@ case "$suite" in
gpu)
source activate gpu_test
install_xgboost
setup_pyspark_envs
pytest -v -s -rxXs --fulltrace --durations=0 -m "not mgpu" ${args} tests/python-gpu
unset_pyspark_envs
uninstall_xgboost
;;
mgpu)
source activate gpu_test
install_xgboost
setup_pyspark_envs
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu
unset_pyspark_envs
cd tests/distributed
./runtests-gpu.sh
@ -61,7 +77,9 @@ case "$suite" in
source activate cpu_test
install_xgboost
export RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE=1
setup_pyspark_envs
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python
unset_pyspark_envs
cd tests/distributed
./runtests.sh
uninstall_xgboost
@ -70,7 +88,9 @@ case "$suite" in
cpu-arm64)
source activate aarch64_test
install_xgboost
setup_pyspark_envs
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python/test_basic.py tests/python/test_basic_models.py tests/python/test_model_compatibility.py
unset_pyspark_envs
uninstall_xgboost
;;


@ -44,13 +44,15 @@ def pytest_addoption(parser):
def pytest_collection_modifyitems(config, items):
if config.getoption("--use-rmm-pool"):
blocklist = [
"python-gpu/test_gpu_demos.py::test_dask_training",
"python-gpu/test_gpu_prediction.py::TestGPUPredict::test_shap",
"python-gpu/test_gpu_linear.py::TestGPULinear",
]
skip_mark = pytest.mark.skip(
reason="This test is not run when --use-rmm-pool flag is active"
)
for item in items:
if any(item.nodeid.startswith(x) for x in blocklist):
item.add_marker(skip_mark)
@ -58,5 +60,9 @@ def pytest_collection_modifyitems(config, items):
# mark dask tests as `mgpu`.
mgpu_mark = pytest.mark.mgpu
for item in items:
if item.nodeid.startswith(
"python-gpu/test_gpu_with_dask.py"
) or item.nodeid.startswith(
"python-gpu/test_spark_with_gpu/test_spark_with_gpu.py"
):
item.add_marker(mgpu_mark)


@ -0,0 +1,3 @@
#!/bin/bash
echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}"


@ -0,0 +1,120 @@
import sys
import logging
import pytest
import sklearn
sys.path.append("tests/python")
import testing as tm
if tm.no_dask()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from xgboost.spark import SparkXGBRegressor, SparkXGBClassifier
@pytest.fixture(scope="module", autouse=True)
def spark_session_with_gpu():
spark_config = {
"spark.master": "local-cluster[1, 4, 1024]",
"spark.python.worker.reuse": "false",
"spark.driver.host": "127.0.0.1",
"spark.task.maxFailures": "1",
"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
"spark.sql.pyspark.jvmStacktrace.enabled": "true",
"spark.cores.max": "4",
"spark.task.cpus": "1",
"spark.executor.cores": "4",
"spark.worker.resource.gpu.amount": "4",
"spark.task.resource.gpu.amount": "1",
"spark.executor.resource.gpu.amount": "4",
"spark.worker.resource.gpu.discoveryScript": "tests/python-gpu/test_spark_with_gpu/discover_gpu.sh",
}
builder = SparkSession.builder.appName("xgboost spark python API Tests with GPU")
for k, v in spark_config.items():
builder.config(k, v)
spark = builder.getOrCreate()
logging.getLogger("pyspark").setLevel(logging.INFO)
# We run a dummy job so that we block until the workers have connected to the master
spark.sparkContext.parallelize(range(4), 4).barrier().mapPartitions(
lambda _: []
).collect()
yield spark
spark.stop()
@pytest.fixture
def spark_iris_dataset(spark_session_with_gpu):
spark = spark_session_with_gpu
data = sklearn.datasets.load_iris()
train_rows = [
(Vectors.dense(features), float(label))
for features, label in zip(data.data[0::2], data.target[0::2])
]
train_df = spark.createDataFrame(
spark.sparkContext.parallelize(train_rows, 4), ["features", "label"]
)
test_rows = [
(Vectors.dense(features), float(label))
for features, label in zip(data.data[1::2], data.target[1::2])
]
test_df = spark.createDataFrame(
spark.sparkContext.parallelize(test_rows, 4), ["features", "label"]
)
return train_df, test_df
@pytest.fixture
def spark_diabetes_dataset(spark_session_with_gpu):
spark = spark_session_with_gpu
data = sklearn.datasets.load_diabetes()
train_rows = [
(Vectors.dense(features), float(label))
for features, label in zip(data.data[0::2], data.target[0::2])
]
train_df = spark.createDataFrame(
spark.sparkContext.parallelize(train_rows, 4), ["features", "label"]
)
test_rows = [
(Vectors.dense(features), float(label))
for features, label in zip(data.data[1::2], data.target[1::2])
]
test_df = spark.createDataFrame(
spark.sparkContext.parallelize(test_rows, 4), ["features", "label"]
)
return train_df, test_df
def test_sparkxgb_classifier_with_gpu(spark_iris_dataset):
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
classifier = SparkXGBClassifier(
use_gpu=True,
num_workers=4,
)
train_df, test_df = spark_iris_dataset
model = classifier.fit(train_df)
pred_result_df = model.transform(test_df)
evaluator = MulticlassClassificationEvaluator(metricName="f1")
f1 = evaluator.evaluate(pred_result_df)
assert f1 >= 0.97
def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset):
from pyspark.ml.evaluation import RegressionEvaluator
regressor = SparkXGBRegressor(
use_gpu=True,
num_workers=4,
)
train_df, test_df = spark_diabetes_dataset
model = regressor.fit(train_df)
pred_result_df = model.transform(test_df)
evaluator = RegressionEvaluator(metricName="rmse")
rmse = evaluator.evaluate(pred_result_df)
assert rmse <= 65.0


@ -0,0 +1,168 @@
import sys
import tempfile
import shutil
import pytest
import numpy as np
import pandas as pd
import testing as tm
if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
from xgboost.spark.data import (
_row_tuple_list_to_feature_matrix_y_w,
_convert_partition_data_to_dmatrix,
)
from xgboost import DMatrix, XGBClassifier
from xgboost.training import train as worker_train
from .utils import SparkTestCase
import logging
logging.getLogger("py4j").setLevel(logging.INFO)
class DataTest(SparkTestCase):
def test_sparse_dense_vector(self):
def row_tup_iter(data):
pdf = pd.DataFrame(data)
yield pdf
expected_ndarray = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]}
feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w(
list(row_tup_iter(data)),
train=False,
has_weight=False,
has_fit_base_margin=False,
has_predict_base_margin=False,
)
self.assertIsNone(y)
self.assertIsNone(w)
self.assertTrue(np.allclose(feature_matrix, expected_ndarray))
data["label"] = [1, 0]
feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w(
row_tup_iter(data),
train=True,
has_weight=False,
has_fit_base_margin=False,
has_predict_base_margin=False,
)
self.assertIsNone(w)
self.assertTrue(np.allclose(feature_matrix, expected_ndarray))
self.assertTrue(np.array_equal(y, np.array(data["label"])))
data["weight"] = [0.2, 0.8]
feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w(
list(row_tup_iter(data)),
train=True,
has_weight=True,
has_fit_base_margin=False,
has_predict_base_margin=False,
)
self.assertTrue(np.allclose(feature_matrix, expected_ndarray))
self.assertTrue(np.array_equal(y, np.array(data["label"])))
self.assertTrue(np.array_equal(w, np.array(data["weight"])))
def test_dmatrix_creator(self):
# This function acts as a pseudo-itertools.chain()
def row_tup_iter(data):
pdf = pd.DataFrame(data)
yield pdf
# Standard testing DMatrix creation
expected_features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100)
expected_labels = np.array([1, 0] * 100)
expected_dmatrix = DMatrix(data=expected_features, label=expected_labels)
data = {
"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100,
"label": [1, 0] * 100,
}
output_dmatrix = _convert_partition_data_to_dmatrix(
[pd.DataFrame(data)],
has_weight=False,
has_validation=False,
has_base_margin=False,
)
# You can't compare DMatrix objects directly, so the only way is to predict on the two separate
# DMatrices using the same classifier and make sure the outputs are equal
model = XGBClassifier()
model.fit(expected_features, expected_labels)
expected_preds = model.get_booster().predict(expected_dmatrix)
output_preds = model.get_booster().predict(output_dmatrix)
self.assertTrue(np.allclose(expected_preds, output_preds, atol=1e-3))
# DMatrix creation with weights
expected_weight = np.array([0.2, 0.8] * 100)
expected_dmatrix = DMatrix(
data=expected_features, label=expected_labels, weight=expected_weight
)
data["weight"] = [0.2, 0.8] * 100
output_dmatrix = _convert_partition_data_to_dmatrix(
[pd.DataFrame(data)],
has_weight=True,
has_validation=False,
has_base_margin=False,
)
model.fit(expected_features, expected_labels, sample_weight=expected_weight)
expected_preds = model.get_booster().predict(expected_dmatrix)
output_preds = model.get_booster().predict(output_dmatrix)
self.assertTrue(np.allclose(expected_preds, output_preds, atol=1e-3))
def test_external_storage(self):
# Instantiating base data (features, labels)
features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100)
labels = np.array([1, 0] * 100)
normal_dmatrix = DMatrix(features, labels)
test_dmatrix = DMatrix(features)
data = {
"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100,
"label": [1, 0] * 100,
}
# Creating the dmatrix based on storage
temporary_path = tempfile.mkdtemp()
storage_dmatrix = _convert_partition_data_to_dmatrix(
[pd.DataFrame(data)],
has_weight=False,
has_validation=False,
has_base_margin=False,
)
# Testing without weights
normal_booster = worker_train({}, normal_dmatrix)
storage_booster = worker_train({}, storage_dmatrix)
normal_preds = normal_booster.predict(test_dmatrix)
storage_preds = storage_booster.predict(test_dmatrix)
self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3))
shutil.rmtree(temporary_path)
# Testing weights
weights = np.array([0.2, 0.8] * 100)
normal_dmatrix = DMatrix(data=features, label=labels, weight=weights)
data["weight"] = [0.2, 0.8] * 100
temporary_path = tempfile.mkdtemp()
storage_dmatrix = _convert_partition_data_to_dmatrix(
[pd.DataFrame(data)],
has_weight=True,
has_validation=False,
has_base_margin=False,
)
normal_booster = worker_train({}, normal_dmatrix)
storage_booster = worker_train({}, storage_dmatrix)
normal_preds = normal_booster.predict(test_dmatrix)
storage_preds = storage_booster.predict(test_dmatrix)
self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3))
shutil.rmtree(temporary_path)


@ -0,0 +1,971 @@
import sys
import logging
import random
import uuid
import numpy as np
import pytest
import testing as tm
if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
from pyspark.ml.functions import vector_to_array
from pyspark.sql import functions as spark_sql_func
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.evaluation import (
BinaryClassificationEvaluator,
MulticlassClassificationEvaluator,
)
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from xgboost.spark import (
SparkXGBClassifier,
SparkXGBClassifierModel,
SparkXGBRegressor,
SparkXGBRegressorModel,
)
from .utils import SparkTestCase
from xgboost import XGBClassifier, XGBRegressor
from xgboost.spark.core import _non_booster_params
logging.getLogger("py4j").setLevel(logging.INFO)
class XgboostLocalTest(SparkTestCase):
def setUp(self):
logging.getLogger().setLevel("INFO")
random.seed(2020)
# The following code uses the xgboost Python library to train an xgb model and predict.
#
# >>> import numpy as np
# >>> import xgboost
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
# >>> y = np.array([0, 1])
# >>> reg1 = xgboost.XGBRegressor()
# >>> reg1.fit(X, y)
# >>> reg1.predict(X)
# array([8.8375784e-04, 9.9911624e-01], dtype=float32)
# >>> def custom_lr(boosting_round):
# ... return 1.0 / (boosting_round + 1)
# ...
# >>> reg1.fit(X, y, callbacks=[xgboost.callback.LearningRateScheduler(custom_lr)])
# >>> reg1.predict(X)
# array([0.02406844, 0.9759315 ], dtype=float32)
# >>> reg2 = xgboost.XGBRegressor(max_depth=5, n_estimators=10)
# >>> reg2.fit(X, y)
# >>> reg2.predict(X, ntree_limit=5)
# array([0.22185266, 0.77814734], dtype=float32)
self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5}
self.reg_df_train = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
],
["features", "label"],
)
self.reg_df_test = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759),
],
[
"features",
"expected_prediction",
"expected_prediction_with_params",
"expected_prediction_with_callbacks",
],
)
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
# >>> y = np.array([0, 1])
# >>> cl1 = xgboost.XGBClassifier()
# >>> cl1.fit(X, y)
# >>> cl1.predict(X)
# array([0, 0])
# >>> cl1.predict_proba(X)
# array([[0.5, 0.5],
# [0.5, 0.5]], dtype=float32)
# >>> cl2 = xgboost.XGBClassifier(max_depth=5, n_estimators=10, scale_pos_weight=4)
# >>> cl2.fit(X, y)
# >>> cl2.predict(X)
# array([1, 1])
# >>> cl2.predict_proba(X)
# array([[0.27574146, 0.72425854 ],
# [0.27574146, 0.72425854 ]], dtype=float32)
self.cls_params = {"max_depth": 5, "n_estimators": 10, "scale_pos_weight": 4}
cls_df_train_data = [
(Vectors.dense(1.0, 2.0, 3.0), 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
]
self.cls_df_train = self.session.createDataFrame(
cls_df_train_data, ["features", "label"]
)
self.cls_df_train_large = self.session.createDataFrame(
cls_df_train_data * 100, ["features", "label"]
)
self.cls_df_test = self.session.createDataFrame(
[
(
Vectors.dense(1.0, 2.0, 3.0),
0,
[0.5, 0.5],
1,
[0.27574146, 0.72425854],
),
(
Vectors.sparse(3, {1: 1.0, 2: 5.5}),
0,
[0.5, 0.5],
1,
[0.27574146, 0.72425854],
),
],
[
"features",
"expected_prediction",
"expected_probability",
"expected_prediction_with_params",
"expected_probability_with_params",
],
)
# kwargs test (training on the above data, we get the same results)
self.cls_params_kwargs = {"tree_method": "approx", "sketch_eps": 0.03}
# >>> X = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0], [0.0, 1.0, 5.5], [-1.0, -2.0, 1.0]])
# >>> y = np.array([0, 0, 1, 2])
# >>> cl = xgboost.XGBClassifier()
# >>> cl.fit(X, y)
# >>> cl.predict_proba(np.array([[1.0, 2.0, 3.0]]))
# array([[0.5374299 , 0.23128504, 0.23128504]], dtype=float32)
multi_cls_df_train_data = [
(Vectors.dense(1.0, 2.0, 3.0), 0),
(Vectors.dense(1.0, 2.0, 4.0), 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
(Vectors.dense(-1.0, -2.0, 1.0), 2),
]
self.multi_cls_df_train = self.session.createDataFrame(
multi_cls_df_train_data, ["features", "label"]
)
self.multi_cls_df_train_large = self.session.createDataFrame(
multi_cls_df_train_data * 100, ["features", "label"]
)
self.multi_cls_df_test = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), [0.5374, 0.2312, 0.2312]),
],
["features", "expected_probability"],
)
# Test regressor with weight and eval set
# >>> import numpy as np
# >>> import xgboost
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
# >>> w = np.array([1.0, 2.0, 1.0, 2.0])
# >>> y = np.array([0, 1, 2, 3])
# >>> reg1 = xgboost.XGBRegressor()
# >>> reg1.fit(X, y, sample_weight=w)
# >>> reg1.predict(X)
# >>> array([1.0679445e-03, 1.0000550e+00, ...
# >>> X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
# >>> X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
# >>> y_train = np.array([0, 1])
# >>> y_val = np.array([2, 3])
# >>> w_train = np.array([1.0, 2.0])
# >>> w_val = np.array([1.0, 2.0])
# >>> reg2 = xgboost.XGBRegressor()
# >>> reg2.fit(X_train, y_train, eval_set=[(X_val, y_val)],
# >>> early_stopping_rounds=1, eval_metric='rmse')
# >>> reg2.predict(X)
# >>> array([8.8370638e-04, 9.9911624e-01, ...
# >>> reg2.best_score
# 2.0000002682208837
# >>> reg3 = xgboost.XGBRegressor()
# >>> reg3.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)],
# >>> sample_weight_eval_set=[w_val],
# >>> early_stopping_rounds=1, eval_metric='rmse')
# >>> reg3.predict(X)
# >>> array([0.03155671, 0.98874104,...
# >>> reg3.best_score
# 1.9970891552124017
self.reg_df_train_with_eval_weight = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
(Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0),
],
["features", "label", "isVal", "weight"],
)
self.reg_params_with_eval = {
"validation_indicator_col": "isVal",
"early_stopping_rounds": 1,
"eval_metric": "rmse",
}
self.reg_df_test_with_eval_weight = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0.001068, 0.00088, 0.03155),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.000055, 0.9991, 0.9887),
],
[
"features",
"expected_prediction_with_weight",
"expected_prediction_with_eval",
"expected_prediction_with_weight_and_eval",
],
)
self.reg_with_eval_best_score = 2.0
self.reg_with_eval_and_weight_best_score = 1.997
# Test classifier with weight and eval set
# >>> import numpy as np
# >>> import xgboost
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
# >>> w = np.array([1.0, 2.0, 1.0, 2.0])
# >>> y = np.array([0, 1, 0, 1])
# >>> cls1 = xgboost.XGBClassifier()
# >>> cls1.fit(X, y, sample_weight=w)
# >>> cls1.predict_proba(X)
# array([[0.3333333, 0.6666667],...
# >>> X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
# >>> X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
# >>> y_train = np.array([0, 1])
# >>> y_val = np.array([0, 1])
# >>> w_train = np.array([1.0, 2.0])
# >>> w_val = np.array([1.0, 2.0])
# >>> cls2 = xgboost.XGBClassifier()
# >>> cls2.fit(X_train, y_train, eval_set=[(X_val, y_val)],
# >>> early_stopping_rounds=1, eval_metric='logloss')
# >>> cls2.predict_proba(X)
# array([[0.5, 0.5],...
# >>> cls2.best_score
# 0.6931
# >>> cls3 = xgboost.XGBClassifier()
# >>> cls3.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)],
# >>> sample_weight_eval_set=[w_val],
# >>> early_stopping_rounds=1, eval_metric='logloss')
# >>> cls3.predict_proba(X)
# array([[0.3344962, 0.6655038],...
# >>> cls3.best_score
# 0.6365
self.cls_df_train_with_eval_weight = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
(Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
],
["features", "label", "isVal", "weight"],
)
self.cls_params_with_eval = {
"validation_indicator_col": "isVal",
"early_stopping_rounds": 1,
"eval_metric": "logloss",
}
self.cls_df_test_with_eval_weight = self.session.createDataFrame(
[
(
Vectors.dense(1.0, 2.0, 3.0),
[0.3333, 0.6666],
[0.5, 0.5],
[0.3097, 0.6903],
),
],
[
"features",
"expected_prob_with_weight",
"expected_prob_with_eval",
"expected_prob_with_weight_and_eval",
],
)
self.cls_with_eval_best_score = 0.6931
self.cls_with_eval_and_weight_best_score = 0.6378
# Test classifier with both base margin and without
# >>> import numpy as np
# >>> import xgboost
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
# >>> w = np.array([1.0, 2.0, 1.0, 2.0])
# >>> y = np.array([0, 1, 0, 1])
# >>> base_margin = np.array([1,0,0,1])
#
# This is without the base margin
# >>> cls1 = xgboost.XGBClassifier()
# >>> cls1.fit(X, y, sample_weight=w)
# >>> cls1.predict_proba(np.array([[1.0, 2.0, 3.0]]))
# array([[0.3333333, 0.6666667]], dtype=float32)
# >>> cls1.predict(np.array([[1.0, 2.0, 3.0]]))
# array([1])
#
# This is with the same base margin for predict
# >>> cls2 = xgboost.XGBClassifier()
# >>> cls2.fit(X, y, sample_weight=w, base_margin=base_margin)
# >>> cls2.predict_proba(np.array([[1.0, 2.0, 3.0]]), base_margin=[0])
# array([[0.44142532, 0.5585747 ]], dtype=float32)
# >>> cls2.predict(np.array([[1.0, 2.0, 3.0]]), base_margin=[0])
# array([1])
#
# This is with a different base margin for predict
# # >>> cls2 = xgboost.XGBClassifier()
# >>> cls2.fit(X, y, sample_weight=w, base_margin=base_margin)
# >>> cls2.predict_proba(np.array([[1.0, 2.0, 3.0]]), base_margin=[1])
# array([[0.2252, 0.7747 ]], dtype=float32)
# >>> cls2.predict(np.array([[1.0, 2.0, 3.0]]), base_margin=[0])
# array([1])
self.cls_df_train_without_base_margin = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 1.0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0),
(Vectors.dense(4.0, 5.0, 6.0), 0, 1.0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0),
],
["features", "label", "weight"],
)
self.cls_df_test_without_base_margin = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), [0.3333, 0.6666], 1),
],
[
"features",
"expected_prob_without_base_margin",
"expected_prediction_without_base_margin",
],
)
self.cls_df_train_with_same_base_margin = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 1.0, 1),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0, 0),
(Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1),
],
["features", "label", "weight", "base_margin"],
)
self.cls_df_test_with_same_base_margin = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, [0.4415, 0.5585], 1),
],
[
"features",
"base_margin",
"expected_prob_with_base_margin",
"expected_prediction_with_base_margin",
],
)
self.cls_df_train_with_different_base_margin = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 1.0, 1),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0, 0),
(Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1),
],
["features", "label", "weight", "base_margin"],
)
self.cls_df_test_with_different_base_margin = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 1, [0.2252, 0.7747], 1),
],
[
"features",
"base_margin",
"expected_prob_with_base_margin",
"expected_prediction_with_base_margin",
],
)
def get_local_tmp_dir(self):
return self.tempdir + str(uuid.uuid4())
def test_regressor_params_basic(self):
py_reg = SparkXGBRegressor()
self.assertTrue(hasattr(py_reg, "n_estimators"))
self.assertEqual(py_reg.n_estimators.parent, py_reg.uid)
self.assertFalse(hasattr(py_reg, "gpu_id"))
self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100)
py_reg2 = SparkXGBRegressor(n_estimators=200)
self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200)
py_reg3 = py_reg2.copy({py_reg2.max_depth: 10})
self.assertEqual(py_reg3.getOrDefault(py_reg3.n_estimators), 200)
self.assertEqual(py_reg3.getOrDefault(py_reg3.max_depth), 10)
def test_classifier_params_basic(self):
py_cls = SparkXGBClassifier()
self.assertTrue(hasattr(py_cls, "n_estimators"))
self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
self.assertFalse(hasattr(py_cls, "gpu_id"))
self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100)
py_cls2 = SparkXGBClassifier(n_estimators=200)
self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200)
py_cls3 = py_cls2.copy({py_cls2.max_depth: 10})
self.assertEqual(py_cls3.getOrDefault(py_cls3.n_estimators), 200)
self.assertEqual(py_cls3.getOrDefault(py_cls3.max_depth), 10)
def test_classifier_kwargs_basic(self):
py_cls = SparkXGBClassifier(**self.cls_params_kwargs)
self.assertTrue(hasattr(py_cls, "n_estimators"))
self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
self.assertFalse(hasattr(py_cls, "gpu_id"))
self.assertTrue(hasattr(py_cls, "arbitrary_params_dict"))
expected_kwargs = {"sketch_eps": 0.03}
self.assertEqual(
py_cls.getOrDefault(py_cls.arbitrary_params_dict), expected_kwargs
)
# Testing overwritten params
py_cls = SparkXGBClassifier()
py_cls.setParams(x=1, y=2)
py_cls.setParams(y=3, z=4)
xgb_params = py_cls._gen_xgb_params_dict()
assert xgb_params["x"] == 1
assert xgb_params["y"] == 3
assert xgb_params["z"] == 4
def test_param_alias(self):
py_cls = SparkXGBClassifier(features_col="f1", label_col="l1")
self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1")
self.assertEqual(py_cls.getOrDefault(py_cls.labelCol), "l1")
with pytest.raises(
ValueError, match="Please use param name features_col instead"
):
SparkXGBClassifier(featuresCol="f1")
def test_gpu_param_setting(self):
py_cls = SparkXGBClassifier(use_gpu=True)
train_params = py_cls._get_distributed_train_params(self.cls_df_train)
assert train_params["tree_method"] == "gpu_hist"
@staticmethod
def test_param_value_converter():
py_cls = SparkXGBClassifier(missing=np.float64(1.0), sketch_eps=np.float64(0.3))
# don't check with isinstance(v, float) because it also returns True for numpy scalars
assert py_cls.getOrDefault(py_cls.missing).__class__.__name__ == "float"
assert (
py_cls.getOrDefault(py_cls.arbitrary_params_dict)[
"sketch_eps"
].__class__.__name__
== "float64"
)
def test_regressor_basic(self):
regressor = SparkXGBRegressor()
model = regressor.fit(self.reg_df_train)
pred_result = model.transform(self.reg_df_test).collect()
for row in pred_result:
self.assertTrue(
np.isclose(row.prediction, row.expected_prediction, atol=1e-3)
)
def test_classifier_basic(self):
classifier = SparkXGBClassifier()
model = classifier.fit(self.cls_df_train)
pred_result = model.transform(self.cls_df_test).collect()
for row in pred_result:
self.assertEqual(row.prediction, row.expected_prediction)
self.assertTrue(
np.allclose(row.probability, row.expected_probability, rtol=1e-3)
)
def test_multi_classifier(self):
classifier = SparkXGBClassifier()
model = classifier.fit(self.multi_cls_df_train)
pred_result = model.transform(self.multi_cls_df_test).collect()
for row in pred_result:
self.assertTrue(
np.allclose(row.probability, row.expected_probability, rtol=1e-3)
)
def _check_sub_dict_match(self, sub_dist, whole_dict, excluding_keys):
for k in sub_dist:
if k not in excluding_keys:
self.assertTrue(k in whole_dict, f"check on {k} failed")
self.assertEqual(sub_dist[k], whole_dict[k], f"check on {k} failed")
def test_regressor_with_params(self):
regressor = SparkXGBRegressor(**self.reg_params)
all_params = dict(
**(regressor._gen_xgb_params_dict()),
**(regressor._gen_fit_params_dict()),
**(regressor._gen_predict_params_dict()),
)
self._check_sub_dict_match(
self.reg_params, all_params, excluding_keys=_non_booster_params
)
model = regressor.fit(self.reg_df_train)
all_params = dict(
**(model._gen_xgb_params_dict()),
**(model._gen_fit_params_dict()),
**(model._gen_predict_params_dict()),
)
self._check_sub_dict_match(
self.reg_params, all_params, excluding_keys=_non_booster_params
)
pred_result = model.transform(self.reg_df_test).collect()
for row in pred_result:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_params, atol=1e-3
)
)
def test_classifier_with_params(self):
classifier = SparkXGBClassifier(**self.cls_params)
all_params = dict(
**(classifier._gen_xgb_params_dict()),
**(classifier._gen_fit_params_dict()),
**(classifier._gen_predict_params_dict()),
)
self._check_sub_dict_match(
self.cls_params, all_params, excluding_keys=_non_booster_params
)
model = classifier.fit(self.cls_df_train)
all_params = dict(
**(model._gen_xgb_params_dict()),
**(model._gen_fit_params_dict()),
**(model._gen_predict_params_dict()),
)
self._check_sub_dict_match(
self.cls_params, all_params, excluding_keys=_non_booster_params
)
pred_result = model.transform(self.cls_df_test).collect()
for row in pred_result:
self.assertEqual(row.prediction, row.expected_prediction_with_params)
self.assertTrue(
np.allclose(
row.probability, row.expected_probability_with_params, rtol=1e-3
)
)
def test_regressor_model_save_load(self):
path = "file:" + self.get_local_tmp_dir()
regressor = SparkXGBRegressor(**self.reg_params)
model = regressor.fit(self.reg_df_train)
model.save(path)
loaded_model = SparkXGBRegressorModel.load(path)
self.assertEqual(model.uid, loaded_model.uid)
for k, v in self.reg_params.items():
self.assertEqual(loaded_model.getOrDefault(k), v)
pred_result = loaded_model.transform(self.reg_df_test).collect()
for row in pred_result:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_params, atol=1e-3
)
)
with self.assertRaisesRegex(AssertionError, "Expected class name"):
SparkXGBClassifierModel.load(path)
def test_classifier_model_save_load(self):
path = "file:" + self.get_local_tmp_dir()
regressor = SparkXGBClassifier(**self.cls_params)
model = regressor.fit(self.cls_df_train)
model.save(path)
loaded_model = SparkXGBClassifierModel.load(path)
self.assertEqual(model.uid, loaded_model.uid)
for k, v in self.cls_params.items():
self.assertEqual(loaded_model.getOrDefault(k), v)
pred_result = loaded_model.transform(self.cls_df_test).collect()
for row in pred_result:
self.assertTrue(
np.allclose(
row.probability, row.expected_probability_with_params, atol=1e-3
)
)
with self.assertRaisesRegex(AssertionError, "Expected class name"):
SparkXGBRegressorModel.load(path)
@staticmethod
def _get_params_map(params_kv, estimator):
return {getattr(estimator, k): v for k, v in params_kv.items()}
def test_regressor_model_pipeline_save_load(self):
path = "file:" + self.get_local_tmp_dir()
regressor = SparkXGBRegressor()
pipeline = Pipeline(stages=[regressor])
pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor))
model = pipeline.fit(self.reg_df_train)
model.save(path)
loaded_model = PipelineModel.load(path)
for k, v in self.reg_params.items():
self.assertEqual(loaded_model.stages[0].getOrDefault(k), v)
pred_result = loaded_model.transform(self.reg_df_test).collect()
for row in pred_result:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_params, atol=1e-3
)
)
def test_classifier_model_pipeline_save_load(self):
path = "file:" + self.get_local_tmp_dir()
classifier = SparkXGBClassifier()
pipeline = Pipeline(stages=[classifier])
pipeline = pipeline.copy(
extra=self._get_params_map(self.cls_params, classifier)
)
model = pipeline.fit(self.cls_df_train)
model.save(path)
loaded_model = PipelineModel.load(path)
for k, v in self.cls_params.items():
self.assertEqual(loaded_model.stages[0].getOrDefault(k), v)
pred_result = loaded_model.transform(self.cls_df_test).collect()
for row in pred_result:
self.assertTrue(
np.allclose(
row.probability, row.expected_probability_with_params, atol=1e-3
)
)
def test_classifier_with_cross_validator(self):
xgb_classifer = SparkXGBClassifier()
paramMaps = ParamGridBuilder().addGrid(xgb_classifer.max_depth, [1, 2]).build()
cvBin = CrossValidator(
estimator=xgb_classifer,
estimatorParamMaps=paramMaps,
evaluator=BinaryClassificationEvaluator(),
seed=1,
)
cvBinModel = cvBin.fit(self.cls_df_train_large)
cvBinModel.transform(self.cls_df_test)
cvMulti = CrossValidator(
estimator=xgb_classifer,
estimatorParamMaps=paramMaps,
evaluator=MulticlassClassificationEvaluator(),
seed=1,
)
cvMultiModel = cvMulti.fit(self.multi_cls_df_train_large)
cvMultiModel.transform(self.multi_cls_df_test)
def test_callbacks(self):
from xgboost.callback import LearningRateScheduler
path = self.get_local_tmp_dir()
def custom_learning_rate(boosting_round):
return 1.0 / (boosting_round + 1)
cb = [LearningRateScheduler(custom_learning_rate)]
regressor = SparkXGBRegressor(callbacks=cb)
# Test the save/load of the estimator instead of the model, since
# the callbacks param only exists in the estimator but not in the model
regressor.save(path)
regressor = SparkXGBRegressor.load(path)
model = regressor.fit(self.reg_df_train)
pred_result = model.transform(self.reg_df_test).collect()
for row in pred_result:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_callbacks, atol=1e-3
)
)
def test_train_with_initial_model(self):
path = self.get_local_tmp_dir()
reg1 = SparkXGBRegressor(**self.reg_params)
model = reg1.fit(self.reg_df_train)
init_booster = model.get_booster()
reg2 = SparkXGBRegressor(max_depth=2, n_estimators=2, xgb_model=init_booster)
model21 = reg2.fit(self.reg_df_train)
pred_res21 = model21.transform(self.reg_df_test).collect()
reg2.save(path)
reg2 = SparkXGBRegressor.load(path)
self.assertTrue(reg2.getOrDefault(reg2.xgb_model) is not None)
model22 = reg2.fit(self.reg_df_train)
pred_res22 = model22.transform(self.reg_df_test).collect()
# Test the transform result is the same for original and loaded model
for row1, row2 in zip(pred_res21, pred_res22):
self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3))
def test_classifier_with_base_margin(self):
cls_without_base_margin = SparkXGBClassifier(weight_col="weight")
model_without_base_margin = cls_without_base_margin.fit(
self.cls_df_train_without_base_margin
)
pred_result_without_base_margin = model_without_base_margin.transform(
self.cls_df_test_without_base_margin
).collect()
for row in pred_result_without_base_margin:
self.assertTrue(
np.isclose(
row.prediction,
row.expected_prediction_without_base_margin,
atol=1e-3,
)
)
np.testing.assert_allclose(
row.probability, row.expected_prob_without_base_margin, atol=1e-3
)
cls_with_same_base_margin = SparkXGBClassifier(
weight_col="weight", base_margin_col="base_margin"
)
model_with_same_base_margin = cls_with_same_base_margin.fit(
self.cls_df_train_with_same_base_margin
)
pred_result_with_same_base_margin = model_with_same_base_margin.transform(
self.cls_df_test_with_same_base_margin
).collect()
for row in pred_result_with_same_base_margin:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_base_margin, atol=1e-3
)
)
np.testing.assert_allclose(
row.probability, row.expected_prob_with_base_margin, atol=1e-3
)
cls_with_different_base_margin = SparkXGBClassifier(
weight_col="weight", base_margin_col="base_margin"
)
model_with_different_base_margin = cls_with_different_base_margin.fit(
self.cls_df_train_with_different_base_margin
)
pred_result_with_different_base_margin = (
model_with_different_base_margin.transform(
self.cls_df_test_with_different_base_margin
).collect()
)
for row in pred_result_with_different_base_margin:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_base_margin, atol=1e-3
)
)
np.testing.assert_allclose(
row.probability, row.expected_prob_with_base_margin, atol=1e-3
)
def test_regressor_with_weight_eval(self):
# with weight
regressor_with_weight = SparkXGBRegressor(weight_col="weight")
model_with_weight = regressor_with_weight.fit(
self.reg_df_train_with_eval_weight
)
pred_result_with_weight = model_with_weight.transform(
self.reg_df_test_with_eval_weight
).collect()
for row in pred_result_with_weight:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_weight, atol=1e-3
)
)
# with eval
regressor_with_eval = SparkXGBRegressor(**self.reg_params_with_eval)
model_with_eval = regressor_with_eval.fit(self.reg_df_train_with_eval_weight)
self.assertTrue(
np.isclose(
model_with_eval._xgb_sklearn_model.best_score,
self.reg_with_eval_best_score,
atol=1e-3,
),
f"Expected best score: {self.reg_with_eval_best_score}, "
f"but get {model_with_eval._xgb_sklearn_model.best_score}",
)
pred_result_with_eval = model_with_eval.transform(
self.reg_df_test_with_eval_weight
).collect()
for row in pred_result_with_eval:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_eval, atol=1e-3
),
f"Expect prediction is {row.expected_prediction_with_eval},"
f"but get {row.prediction}",
)
# with weight and eval
regressor_with_weight_eval = SparkXGBRegressor(
weight_col="weight", **self.reg_params_with_eval
)
model_with_weight_eval = regressor_with_weight_eval.fit(
self.reg_df_train_with_eval_weight
)
pred_result_with_weight_eval = model_with_weight_eval.transform(
self.reg_df_test_with_eval_weight
).collect()
self.assertTrue(
np.isclose(
model_with_weight_eval._xgb_sklearn_model.best_score,
self.reg_with_eval_and_weight_best_score,
atol=1e-3,
)
)
for row in pred_result_with_weight_eval:
self.assertTrue(
np.isclose(
row.prediction,
row.expected_prediction_with_weight_and_eval,
atol=1e-3,
)
)
def test_classifier_with_weight_eval(self):
# with weight
classifier_with_weight = SparkXGBClassifier(weight_col="weight")
model_with_weight = classifier_with_weight.fit(
self.cls_df_train_with_eval_weight
)
pred_result_with_weight = model_with_weight.transform(
self.cls_df_test_with_eval_weight
).collect()
for row in pred_result_with_weight:
self.assertTrue(
np.allclose(row.probability, row.expected_prob_with_weight, atol=1e-3)
)
# with eval
classifier_with_eval = SparkXGBClassifier(**self.cls_params_with_eval)
model_with_eval = classifier_with_eval.fit(self.cls_df_train_with_eval_weight)
self.assertTrue(
np.isclose(
model_with_eval._xgb_sklearn_model.best_score,
self.cls_with_eval_best_score,
atol=1e-3,
)
)
pred_result_with_eval = model_with_eval.transform(
self.cls_df_test_with_eval_weight
).collect()
for row in pred_result_with_eval:
self.assertTrue(
np.allclose(row.probability, row.expected_prob_with_eval, atol=1e-3)
)
# with weight and eval
# Added scale_pos_weight because in 1.4.2 the original answer returns 0.5, which
# doesn't really show that this is working correctly.
classifier_with_weight_eval = SparkXGBClassifier(
weight_col="weight", scale_pos_weight=4, **self.cls_params_with_eval
)
model_with_weight_eval = classifier_with_weight_eval.fit(
self.cls_df_train_with_eval_weight
)
pred_result_with_weight_eval = model_with_weight_eval.transform(
self.cls_df_test_with_eval_weight
).collect()
self.assertTrue(
np.isclose(
model_with_weight_eval._xgb_sklearn_model.best_score,
self.cls_with_eval_and_weight_best_score,
atol=1e-3,
)
)
for row in pred_result_with_weight_eval:
self.assertTrue(
np.allclose(
row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3
)
)
def test_num_workers_param(self):
regressor = SparkXGBRegressor(num_workers=-1)
self.assertRaises(ValueError, regressor._validate_params)
classifier = SparkXGBClassifier(num_workers=0)
self.assertRaises(ValueError, classifier._validate_params)
def test_use_gpu_param(self):
classifier = SparkXGBClassifier(use_gpu=True, tree_method="exact")
self.assertRaises(ValueError, classifier._validate_params)
regressor = SparkXGBRegressor(use_gpu=True, tree_method="exact")
self.assertRaises(ValueError, regressor._validate_params)
regressor = SparkXGBRegressor(use_gpu=True, tree_method="gpu_hist")
regressor = SparkXGBRegressor(use_gpu=True)
classifier = SparkXGBClassifier(use_gpu=True, tree_method="gpu_hist")
classifier = SparkXGBClassifier(use_gpu=True)
def test_convert_to_sklearn_model(self):
classifier = SparkXGBClassifier(
n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5
)
clf_model = classifier.fit(self.cls_df_train)
regressor = SparkXGBRegressor(
n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5
)
reg_model = regressor.fit(self.reg_df_train)
# Check that regardless of the booster, _convert_to_sklearn_model converts to the correct class type
sklearn_classifier = classifier._convert_to_sklearn_model(
clf_model.get_booster()
)
assert isinstance(sklearn_classifier, XGBClassifier)
assert sklearn_classifier.n_estimators == 200
assert sklearn_classifier.missing == 2.0
assert sklearn_classifier.max_depth == 3
assert sklearn_classifier.get_params()["sketch_eps"] == 0.5
sklearn_regressor = regressor._convert_to_sklearn_model(reg_model.get_booster())
assert isinstance(sklearn_regressor, XGBRegressor)
assert sklearn_regressor.n_estimators == 200
assert sklearn_regressor.missing == 2.0
assert sklearn_regressor.max_depth == 3
assert sklearn_regressor.get_params()["sketch_eps"] == 0.5
def test_feature_importances(self):
reg1 = SparkXGBRegressor(**self.reg_params)
model = reg1.fit(self.reg_df_train)
booster = model.get_booster()
self.assertEqual(model.get_feature_importances(), booster.get_score())
self.assertEqual(
model.get_feature_importances(importance_type="gain"),
booster.get_score(importance_type="gain"),
)
def test_regressor_array_col_as_feature(self):
train_dataset = self.reg_df_train.withColumn(
"features", vector_to_array(spark_sql_func.col("features"))
)
test_dataset = self.reg_df_test.withColumn(
"features", vector_to_array(spark_sql_func.col("features"))
)
regressor = SparkXGBRegressor()
model = regressor.fit(train_dataset)
pred_result = model.transform(test_dataset).collect()
for row in pred_result:
self.assertTrue(
np.isclose(row.prediction, row.expected_prediction, atol=1e-3)
)
def test_classifier_array_col_as_feature(self):
train_dataset = self.cls_df_train.withColumn(
"features", vector_to_array(spark_sql_func.col("features"))
)
test_dataset = self.cls_df_test.withColumn(
"features", vector_to_array(spark_sql_func.col("features"))
)
classifier = SparkXGBClassifier()
model = classifier.fit(train_dataset)
pred_result = model.transform(test_dataset).collect()
for row in pred_result:
self.assertEqual(row.prediction, row.expected_prediction)
self.assertTrue(
np.allclose(row.probability, row.expected_probability, rtol=1e-3)
)
def test_classifier_with_feature_names_types_weights(self):
classifier = SparkXGBClassifier(
feature_names=["a1", "a2", "a3"],
feature_types=["i", "int", "float"],
feature_weights=[2.0, 5.0, 3.0],
)
model = classifier.fit(self.cls_df_train)
model.transform(self.cls_df_test).collect()
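These local tests all exercise the same estimator workflow; condensed into a minimal sketch (assuming an existing SparkSession named spark, and using the default "features"/"label" column names as in the fixtures above):

from pyspark.ml.linalg import Vectors
from xgboost.spark import SparkXGBClassifier

# Tiny two-row training frame mirroring the test fixtures.
df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0, 3.0), 0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1)],
    ["features", "label"],
)
model = SparkXGBClassifier(n_estimators=10).fit(df)
# transform() appends prediction, probability and rawPrediction columns.
model.transform(df).select("prediction", "probability").show()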


@ -0,0 +1,450 @@
import sys
import random
import json
import uuid
import os
import pytest
import numpy as np
import testing as tm
if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
from .utils import SparkLocalClusterTestCase
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
from xgboost.spark.utils import _get_max_num_concurrent_tasks
from pyspark.ml.linalg import Vectors
class XgboostLocalClusterTestCase(SparkLocalClusterTestCase):
def setUp(self):
random.seed(2020)
self.n_workers = _get_max_num_concurrent_tasks(self.session.sparkContext)
# The following code uses the xgboost Python library to train an xgb model and predict.
#
# >>> import numpy as np
# >>> import xgboost
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
# >>> y = np.array([0, 1])
# >>> reg1 = xgboost.XGBRegressor()
# >>> reg1.fit(X, y)
# >>> reg1.predict(X)
# array([8.8363886e-04, 9.9911636e-01], dtype=float32)
# >>> def custom_lr(boosting_round, num_boost_round):
# ... return 1.0 / (boosting_round + 1)
# ...
# >>> reg1.fit(X, y, callbacks=[xgboost.callback.reset_learning_rate(custom_lr)])
# >>> reg1.predict(X)
# array([0.02406833, 0.97593164], dtype=float32)
# >>> reg2 = xgboost.XGBRegressor(max_depth=5, n_estimators=10)
# >>> reg2.fit(X, y)
# >>> reg2.predict(X, ntree_limit=5)
# array([0.22185263, 0.77814734], dtype=float32)
self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5}
self.reg_df_train = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
],
["features", "label"],
)
self.reg_df_test = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759),
],
[
"features",
"expected_prediction",
"expected_prediction_with_params",
"expected_prediction_with_callbacks",
],
)
# Distributed section
# Binary classification
self.cls_df_train_distributed = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
(Vectors.dense(4.0, 5.0, 6.0), 0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1),
]
* 100,
["features", "label"],
)
self.cls_df_test_distributed = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, [0.9949826, 0.0050174]),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [0.0050174, 0.9949826]),
(Vectors.dense(4.0, 5.0, 6.0), 0, [0.9949826, 0.0050174]),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, [0.0050174, 0.9949826]),
],
["features", "expected_label", "expected_probability"],
)
# Binary classification with different num_estimators
self.cls_df_test_distributed_lower_estimators = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, [0.9735, 0.0265]),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, [0.0265, 0.9735]),
(Vectors.dense(4.0, 5.0, 6.0), 0, [0.9735, 0.0265]),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, [0.0265, 0.9735]),
],
["features", "expected_label", "expected_probability"],
)
# Multiclass classification
self.cls_df_train_distributed_multiclass = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
(Vectors.dense(4.0, 5.0, 6.0), 0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 2),
]
* 100,
["features", "label"],
)
self.cls_df_test_distributed_multiclass = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, [4.294563, -2.449409, -2.449409]),
(
Vectors.sparse(3, {1: 1.0, 2: 5.5}),
1,
[-2.3796105, 3.669014, -2.449409],
),
(Vectors.dense(4.0, 5.0, 6.0), 0, [4.294563, -2.449409, -2.449409]),
(
Vectors.sparse(3, {1: 6.0, 2: 7.5}),
2,
[-2.3796105, -2.449409, 3.669014],
),
],
["features", "expected_label", "expected_margins"],
)
# Regression
self.reg_df_train_distributed = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1),
(Vectors.dense(4.0, 5.0, 6.0), 0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 2),
]
* 100,
["features", "label"],
)
self.reg_df_test_distributed = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 1.533e-04),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 9.999e-01),
(Vectors.dense(4.0, 5.0, 6.0), 1.533e-04),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1.999e00),
],
["features", "expected_label"],
)
# Adding weight and validation
self.clf_params_with_eval_dist = {
"validation_indicator_col": "isVal",
"early_stopping_rounds": 1,
"eval_metric": "logloss",
}
self.clf_params_with_weight_dist = {"weight_col": "weight"}
self.cls_df_train_distributed_with_eval_weight = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
(Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
]
* 100,
["features", "label", "isVal", "weight"],
)
self.cls_df_test_distributed_with_eval_weight = self.session.createDataFrame(
[
(
Vectors.dense(1.0, 2.0, 3.0),
[0.9955, 0.0044],
[0.9904, 0.0096],
[0.9903, 0.0097],
),
],
[
"features",
"expected_prob_with_weight",
"expected_prob_with_eval",
"expected_prob_with_weight_and_eval",
],
)
self.clf_best_score_eval = 0.009677
self.clf_best_score_weight_and_eval = 0.006626
self.reg_params_with_eval_dist = {
"validation_indicator_col": "isVal",
"early_stopping_rounds": 1,
"eval_metric": "rmse",
}
self.reg_params_with_weight_dist = {"weight_col": "weight"}
self.reg_df_train_distributed_with_eval_weight = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
(Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
]
* 100,
["features", "label", "isVal", "weight"],
)
self.reg_df_test_distributed_with_eval_weight = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 4.583e-05, 5.239e-05, 6.03e-05),
(
Vectors.sparse(3, {1: 1.0, 2: 5.5}),
9.9997e-01,
9.99947e-01,
9.9995e-01,
),
],
[
"features",
"expected_prediction_with_weight",
"expected_prediction_with_eval",
"expected_prediction_with_weight_and_eval",
],
)
self.reg_best_score_eval = 5.239e-05
self.reg_best_score_weight_and_eval = 4.810e-05
def test_regressor_basic_with_params(self):
regressor = SparkXGBRegressor(**self.reg_params)
model = regressor.fit(self.reg_df_train)
pred_result = model.transform(self.reg_df_test).collect()
for row in pred_result:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_params, atol=1e-3
)
)
def test_callbacks(self):
from xgboost.callback import LearningRateScheduler
path = os.path.join(self.tempdir, str(uuid.uuid4()))
def custom_learning_rate(boosting_round):
return 1.0 / (boosting_round + 1)
cb = [LearningRateScheduler(custom_learning_rate)]
regressor = SparkXGBRegressor(callbacks=cb)
# Test the save/load of the estimator instead of the model, since
# the callbacks param only exists in the estimator but not in the model
regressor.save(path)
regressor = SparkXGBRegressor.load(path)
model = regressor.fit(self.reg_df_train)
pred_result = model.transform(self.reg_df_test).collect()
for row in pred_result:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_callbacks, atol=1e-3
)
)
def test_classifier_distributed_basic(self):
classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=100)
model = classifier.fit(self.cls_df_train_distributed)
pred_result = model.transform(self.cls_df_test_distributed).collect()
for row in pred_result:
self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3))
self.assertTrue(
np.allclose(row.expected_probability, row.probability, atol=1e-3)
)
def test_classifier_distributed_multiclass(self):
# There is no built-in multiclass option for external storage
classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=100)
model = classifier.fit(self.cls_df_train_distributed_multiclass)
pred_result = model.transform(self.cls_df_test_distributed_multiclass).collect()
for row in pred_result:
self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3))
self.assertTrue(
np.allclose(row.expected_margins, row.rawPrediction, atol=1e-3)
)
def test_regressor_distributed_basic(self):
regressor = SparkXGBRegressor(num_workers=self.n_workers, n_estimators=100)
model = regressor.fit(self.reg_df_train_distributed)
pred_result = model.transform(self.reg_df_test_distributed).collect()
for row in pred_result:
self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3))
def test_classifier_distributed_weight_eval(self):
# with weight
classifier = SparkXGBClassifier(
num_workers=self.n_workers,
n_estimators=100,
**self.clf_params_with_weight_dist
)
model = classifier.fit(self.cls_df_train_distributed_with_eval_weight)
pred_result = model.transform(
self.cls_df_test_distributed_with_eval_weight
).collect()
for row in pred_result:
self.assertTrue(
np.allclose(row.probability, row.expected_prob_with_weight, atol=1e-3)
)
# with eval only
classifier = SparkXGBClassifier(
num_workers=self.n_workers,
n_estimators=100,
**self.clf_params_with_eval_dist
)
model = classifier.fit(self.cls_df_train_distributed_with_eval_weight)
pred_result = model.transform(
self.cls_df_test_distributed_with_eval_weight
).collect()
for row in pred_result:
self.assertTrue(
np.allclose(row.probability, row.expected_prob_with_eval, atol=1e-3)
)
assert np.isclose(
float(model.get_booster().attributes()["best_score"]),
self.clf_best_score_eval,
rtol=1e-3,
)
# with both weight and eval
classifier = SparkXGBClassifier(
num_workers=self.n_workers,
n_estimators=100,
**self.clf_params_with_eval_dist,
**self.clf_params_with_weight_dist
)
model = classifier.fit(self.cls_df_train_distributed_with_eval_weight)
pred_result = model.transform(
self.cls_df_test_distributed_with_eval_weight
).collect()
for row in pred_result:
self.assertTrue(
np.allclose(
row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3
)
)
assert np.isclose(
float(model.get_booster().attributes()["best_score"]),
self.clf_best_score_weight_and_eval,
rtol=1e-3,
)
def test_regressor_distributed_weight_eval(self):
# with weight
regressor = SparkXGBRegressor(
num_workers=self.n_workers,
n_estimators=100,
**self.reg_params_with_weight_dist
)
model = regressor.fit(self.reg_df_train_distributed_with_eval_weight)
pred_result = model.transform(
self.reg_df_test_distributed_with_eval_weight
).collect()
for row in pred_result:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_weight, atol=1e-3
)
)
# with eval only
regressor = SparkXGBRegressor(
num_workers=self.n_workers,
n_estimators=100,
**self.reg_params_with_eval_dist
)
model = regressor.fit(self.reg_df_train_distributed_with_eval_weight)
pred_result = model.transform(
self.reg_df_test_distributed_with_eval_weight
).collect()
for row in pred_result:
self.assertTrue(
np.isclose(row.prediction, row.expected_prediction_with_eval, atol=1e-3)
)
assert np.isclose(
float(model.get_booster().attributes()["best_score"]),
self.reg_best_score_eval,
rtol=1e-3,
)
# with both weight and eval
regressor = SparkXGBRegressor(
num_workers=self.n_workers,
n_estimators=100,
use_external_storage=False,
**self.reg_params_with_eval_dist,
**self.reg_params_with_weight_dist
)
model = regressor.fit(self.reg_df_train_distributed_with_eval_weight)
pred_result = model.transform(
self.reg_df_test_distributed_with_eval_weight
).collect()
for row in pred_result:
self.assertTrue(
np.isclose(
row.prediction,
row.expected_prediction_with_weight_and_eval,
atol=1e-3,
)
)
assert np.isclose(
float(model.get_booster().attributes()["best_score"]),
self.reg_best_score_weight_and_eval,
rtol=1e-3,
)
def test_num_estimators(self):
classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=10)
model = classifier.fit(self.cls_df_train_distributed)
pred_result = model.transform(
self.cls_df_test_distributed_lower_estimators
).collect()
print(pred_result)
for row in pred_result:
self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3))
self.assertTrue(
np.allclose(row.expected_probability, row.probability, atol=1e-3)
)
def test_distributed_params(self):
classifier = SparkXGBClassifier(num_workers=self.n_workers, max_depth=7)
model = classifier.fit(self.cls_df_train_distributed)
self.assertTrue(hasattr(classifier, "max_depth"))
self.assertEqual(classifier.getOrDefault(classifier.max_depth), 7)
booster_config = json.loads(model.get_booster().save_config())
max_depth = booster_config["learner"]["gradient_booster"]["updater"][
"grow_histmaker"
]["train_param"]["max_depth"]
self.assertEqual(int(max_depth), 7)
def test_repartition(self):
# The following test case has a few partitioned datasets that are either
# well partitioned relative to the number of workers that the user wants
# or poorly partitioned. We only want to repartition when the dataset
# is poorly partitioned so _repartition_needed is true in those instances.
classifier = SparkXGBClassifier(num_workers=self.n_workers)
basic = self.cls_df_train_distributed
self.assertTrue(classifier._repartition_needed(basic))
bad_repartitioned = basic.repartition(self.n_workers + 1)
self.assertTrue(classifier._repartition_needed(bad_repartitioned))
good_repartitioned = basic.repartition(self.n_workers)
self.assertFalse(classifier._repartition_needed(good_repartitioned))
# Now testing if force_repartition returns True regardless of whether the data is well partitioned
classifier = SparkXGBClassifier(
num_workers=self.n_workers, force_repartition=True
)
good_repartitioned = basic.repartition(self.n_workers)
self.assertTrue(classifier._repartition_needed(good_repartitioned))


@ -0,0 +1,148 @@
import contextlib
import logging
import shutil
import sys
import tempfile
import unittest
import pytest
from six import StringIO
import testing as tm
if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
from xgboost.spark.utils import _get_default_params_from_func
class UtilsTest(unittest.TestCase):
def test_get_default_params(self):
class Foo:
def func1(self, x, y, key1=None, key2="val2", key3=0, key4=None):
pass
unsupported_params = {"key2", "key4"}
expected_default_params = {
"key1": None,
"key3": 0,
}
actual_default_params = _get_default_params_from_func(
Foo.func1, unsupported_params
)
self.assertEqual(
len(expected_default_params.keys()), len(actual_default_params.keys())
)
for k, v in actual_default_params.items():
self.assertEqual(expected_default_params[k], v)
@contextlib.contextmanager
def patch_stdout():
"""patch stdout and give an output"""
sys_stdout = sys.stdout
io_out = StringIO()
sys.stdout = io_out
try:
yield io_out
finally:
sys.stdout = sys_stdout
@contextlib.contextmanager
def patch_logger(name):
"""patch logger and give an output"""
io_out = StringIO()
log = logging.getLogger(name)
handler = logging.StreamHandler(io_out)
log.addHandler(handler)
try:
yield io_out
finally:
log.removeHandler(handler)
class TestTempDir(object):
@classmethod
def make_tempdir(cls):
"""
:param dir: Root directory in which to create the temp directory
"""
cls.tempdir = tempfile.mkdtemp(prefix="sparkdl_tests")
@classmethod
def remove_tempdir(cls):
shutil.rmtree(cls.tempdir)
class TestSparkContext(object):
@classmethod
def setup_env(cls, spark_config):
builder = SparkSession.builder.appName("xgboost spark python API Tests")
for k, v in spark_config.items():
builder.config(k, v)
spark = builder.getOrCreate()
logging.getLogger("pyspark").setLevel(logging.INFO)
cls.sc = spark.sparkContext
cls.session = spark
@classmethod
def tear_down_env(cls):
cls.session.stop()
cls.session = None
cls.sc.stop()
cls.sc = None
class SparkTestCase(TestSparkContext, TestTempDir, unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.setup_env(
{
"spark.master": "local[2]",
"spark.python.worker.reuse": "false",
"spark.driver.host": "127.0.0.1",
"spark.task.maxFailures": "1",
"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
"spark.sql.pyspark.jvmStacktrace.enabled": "true",
}
)
cls.make_tempdir()
@classmethod
def tearDownClass(cls):
cls.remove_tempdir()
cls.tear_down_env()
class SparkLocalClusterTestCase(TestSparkContext, TestTempDir, unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.setup_env(
{
"spark.master": "local-cluster[2, 2, 1024]",
"spark.python.worker.reuse": "false",
"spark.driver.host": "127.0.0.1",
"spark.task.maxFailures": "1",
"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
"spark.sql.pyspark.jvmStacktrace.enabled": "true",
"spark.cores.max": "4",
"spark.task.cpus": "1",
"spark.executor.cores": "2",
}
)
cls.make_tempdir()
# We run a dummy job so that we block until the workers have connected to the master
cls.sc.parallelize(range(4), 4).barrier().mapPartitions(lambda _: []).collect()
@classmethod
def tearDownClass(cls):
cls.remove_tempdir()
cls.tear_down_env()


@ -56,6 +56,15 @@ def no_dask():
return {"condition": not DASK_INSTALLED, "reason": "Dask is not installed"} return {"condition": not DASK_INSTALLED, "reason": "Dask is not installed"}
def no_spark():
try:
import pyspark # noqa
SPARK_INSTALLED = True
except ImportError:
SPARK_INSTALLED = False
return {"condition": not SPARK_INSTALLED, "reason": "Spark is not installed"}
def no_pandas():
return {'condition': not PANDAS_INSTALLED,
'reason': 'Pandas is not installed.'}
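The new no_spark() helper follows the same shape as the other capability checks: a dict with "condition" and "reason" keys. The new test modules above branch on it explicitly; the same dict can also be splatted into a skip marker. A small sketch of the marker form (assuming the module is importable as testing/tm, as in the tests above):

import pytest
import testing as tm

# The dict's "condition"/"reason" keys line up with pytest.mark.skipif's keyword arguments.
@pytest.mark.skipif(**tm.no_spark())
def test_something_with_spark():
    from pyspark.sql import SparkSession  # only imported when Spark is available
    assert SparkSession is not None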