[pyspark] support a list of feature column names (#8117)
Parent: bcc8679a05
Commit: 03cc3b359c
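For orientation, a minimal usage sketch of what this commit enables, modelled on the new GPU tests further down: pass a list of column names as features_col together with use_gpu=True instead of assembling a single vector column. The column names, toy data, and local session are illustrative, and actually running it needs a CUDA-capable Spark worker.

from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.master("local[*]").getOrCreate()
feature_names = ["f0", "f1", "f2", "f3"]  # illustrative column names
df = spark.createDataFrame(
    [(0.1, 1.2, 0.3, 2.0, 0.0), (1.5, 0.2, 2.3, 0.1, 1.0)] * 50,
    [*feature_names, "label"],
)
# features_col now accepts a list of column names; this path requires use_gpu=True.
classifier = SparkXGBClassifier(
    features_col=feature_names, label_col="label", use_gpu=True, num_workers=1
)
model = classifier.fit(df)
model.transform(df).select("prediction").show(5)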
@@ -2,7 +2,7 @@
"""Xgboost pyspark integration submodule for core code."""
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods
from typing import Iterator, Tuple
from typing import Iterator, Optional, Tuple

import numpy as np
import pandas as pd
@@ -26,6 +26,7 @@ from pyspark.sql.types import (
    DoubleType,
    FloatType,
    IntegerType,
    IntegralType,
    LongType,
    ShortType,
)
@@ -43,7 +44,7 @@ from .model import (
    SparkXGBReader,
    SparkXGBWriter,
)
from .params import HasArbitraryParamsDict, HasBaseMarginCol
from .params import HasArbitraryParamsDict, HasBaseMarginCol, HasFeaturesCols
from .utils import (
    RabitContext,
    _get_args_from_message_list,
@@ -73,14 +74,10 @@ _pyspark_specific_params = [
    "num_workers",
    "use_gpu",
    "feature_names",
    "features_cols",
]

_non_booster_params = [
    "missing",
    "n_estimators",
    "feature_types",
    "feature_weights",
]
_non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]

_pyspark_param_alias_map = {
    "features_col": "featuresCol",
@@ -126,6 +123,7 @@ class _SparkXGBParams(
    HasValidationIndicatorCol,
    HasArbitraryParamsDict,
    HasBaseMarginCol,
    HasFeaturesCols,
):
    num_workers = Param(
        Params._dummy(),
@@ -240,7 +238,6 @@ class _SparkXGBParams(

    def _validate_params(self):
        init_model = self.getOrDefault(self.xgb_model)
        if init_model is not None:
        if init_model is not None and not isinstance(init_model, Booster):
            raise ValueError(
                "The xgb_model param must be set with a `xgboost.core.Booster` "
@@ -262,6 +259,14 @@ class _SparkXGBParams(
                "Therefore, that parameter will be ignored."
            )

        if self.getOrDefault(self.features_cols):
            if not self.getOrDefault(self.use_gpu):
                raise ValueError("features_cols param requires enabling use_gpu.")

            get_logger(self.__class__.__name__).warning(
                "If features_cols param set, then features_col param is ignored."
            )

        if self.getOrDefault(self.use_gpu):
            tree_method = self.getParam("tree_method")
            if (
@@ -315,6 +320,23 @@ class _SparkXGBParams(
            )


def _validate_and_convert_feature_col_as_float_col_list(
    dataset, features_col_names: list
) -> list:
    """Values in feature columns must be integral types or float/double types"""
    feature_cols = []
    for c in features_col_names:
        if isinstance(dataset.schema[c].dataType, DoubleType):
            feature_cols.append(col(c).cast(FloatType()).alias(c))
        elif isinstance(dataset.schema[c].dataType, (FloatType, IntegralType)):
            feature_cols.append(col(c))
        else:
            raise ValueError(
                "Values in feature columns must be integral types or float/double types."
            )
    return feature_cols

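A small illustration of the helper above (column names and the local session are assumed): double columns are cast down to float, float and integral columns pass through, and any other type raises.

from pyspark.sql import SparkSession
from xgboost.spark.core import _validate_and_convert_feature_col_as_float_col_list

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.createDataFrame([(1.0, 3), (2.5, 7)], ["x_dbl", "x_int"])  # double, bigint

cols = _validate_and_convert_feature_col_as_float_col_list(df, ["x_dbl", "x_int"])
print(df.select(cols).dtypes)  # [('x_dbl', 'float'), ('x_int', 'bigint')]
# A string column in the list would raise the ValueError above.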
def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
    features_col_datatype = dataset.schema[features_col_name].dataType
    features_col = col(features_col_name)
@@ -373,6 +395,12 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                    f"Please use param name {_inverse_pyspark_param_alias_map[k]} instead."
                )
            if k in _pyspark_param_alias_map:
                if k == _inverse_pyspark_param_alias_map[
                    self.featuresCol.name
                ] and isinstance(v, list):
                    real_k = self.features_cols.name
                    k = real_k
                else:
                    real_k = _pyspark_param_alias_map[k]
                    k = real_k

@@ -497,10 +525,19 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
        self._validate_params()
        label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)

        select_cols = [label_col]
        features_cols_names = None
        if self.getOrDefault(self.features_cols):
            features_cols_names = self.getOrDefault(self.features_cols)
            features_cols = _validate_and_convert_feature_col_as_float_col_list(
                dataset, features_cols_names
            )
            select_cols.extend(features_cols)
        else:
            features_array_col = _validate_and_convert_feature_col_as_array_col(
                dataset, self.getOrDefault(self.featuresCol)
            )
            select_cols = [features_array_col, label_col]
            select_cols.append(features_array_col)

        if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol):
            select_cols.append(
@@ -569,10 +606,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            context = BarrierTaskContext.get()
            context.barrier()

            gpu_id = None
            if use_gpu:
                booster_params["gpu_id"] = (
                    context.partitionId() if is_local else _get_gpu_id(context)
                )
                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
                booster_params["gpu_id"] = gpu_id

            # max_bin is needed for qdm
            if (
                features_cols_names is not None
                and booster_params.get("max_bin", None) is not None
            ):
                dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

            _rabit_args = ""
            if context.partitionId() == 0:
@@ -583,9 +627,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            evals_result = {}
            with RabitContext(_rabit_args, context):
                dtrain, dvalid = create_dmatrix_from_partitions(
                    pandas_df_iter,
                    None,
                    dmatrix_kwargs,
                    pandas_df_iter, features_cols_names, gpu_id, dmatrix_kwargs
                )
                if dvalid is not None:
                    dval = [(dtrain, "training"), (dvalid, "validation")]
@@ -685,6 +727,34 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
    def _transform(self, dataset):
        raise NotImplementedError()

    def _get_feature_col(self, dataset) -> (list, Optional[list]):
        """XGBoost model trained with features_cols parameter can also predict
        vector or array feature type. But first we need to check features_cols
        and then featuresCol
        """

        feature_col_names = self.getOrDefault(self.features_cols)
        features_col = []
        if feature_col_names and set(feature_col_names).issubset(set(dataset.columns)):
            # The model is trained with features_cols and the predicted dataset
            # also contains all the columns specified by features_cols.
            features_col = _validate_and_convert_feature_col_as_float_col_list(
                dataset, feature_col_names
            )
        else:
            # 1. The model was trained by features_cols, but the dataset doesn't contain
            # all the columns specified by features_cols, so we need to check if
            # the dataframe has the featuresCol
            # 2. The model was trained by featuresCol, and the predicted dataset must contain
            # featuresCol column.
            feature_col_names = None
            features_col.append(
                _validate_and_convert_feature_col_as_array_col(
                    dataset, self.getOrDefault(self.featuresCol)
                )
            )
        return features_col, feature_col_names

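As the docstring above says, a model fit with features_cols can score a DataFrame that still carries the individual numeric columns, or fall back to the assembled featuresCol when they are absent. A rough sketch, reusing model, df, and feature_names from the usage example near the top of this page (the assembled column name assumes the default featuresCol of "features"):

from pyspark.ml.feature import VectorAssembler

# 1) The dataset keeps the raw numeric columns: the features_cols path is used.
preds_from_columns = model.transform(df)

# 2) The columns are assembled into one vector column named "features":
assembled = (
    VectorAssembler(inputCols=feature_names, outputCol="features")
    .transform(df)
    .drop(*feature_names)
)
preds_from_vector = model.transform(assembled)  # featuresCol fallback path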
class SparkXGBRegressorModel(_SparkXGBModel):
    """
@@ -712,11 +782,17 @@ class SparkXGBRegressorModel(_SparkXGBModel):
                alias.margin
            )

        features_col, feature_col_names = self._get_feature_col(dataset)

        @pandas_udf("double")
        def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
            model = xgb_sklearn_model
            for data in iterator:
                if feature_col_names is not None:
                    X = data[feature_col_names]
                else:
                    X = stack_series(data[alias.data])

                if has_base_margin:
                    base_margin = data[alias.margin].to_numpy()
                else:
@@ -730,14 +806,10 @@ class SparkXGBRegressorModel(_SparkXGBModel):
                )
                yield pd.Series(preds)

        features_col = _validate_and_convert_feature_col_as_array_col(
            dataset, self.getOrDefault(self.featuresCol)
        )

        if has_base_margin:
            pred_col = predict_udf(struct(features_col, base_margin_col))
            pred_col = predict_udf(struct(*features_col, base_margin_col))
        else:
            pred_col = predict_udf(struct(features_col))
            pred_col = predict_udf(struct(*features_col))

        predictionColName = self.getOrDefault(self.predictionCol)

@@ -783,6 +855,8 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictio
            class_probs = softmax(raw_preds, axis=1)
            return raw_preds, class_probs

        features_col, feature_col_names = self._get_feature_col(dataset)

        @pandas_udf(
            "rawPrediction array<double>, prediction double, probability array<double>"
        )
@@ -791,7 +865,11 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictio
        ) -> Iterator[pd.DataFrame]:
            model = xgb_sklearn_model
            for data in iterator:
                if feature_col_names is not None:
                    X = data[feature_col_names]
                else:
                    X = stack_series(data[alias.data])

                if has_base_margin:
                    base_margin = stack_series(data[alias.margin])
                else:
@@ -817,14 +895,10 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictio
                }
            )

        features_col = _validate_and_convert_feature_col_as_array_col(
            dataset, self.getOrDefault(self.featuresCol)
        )

        if has_base_margin:
            pred_struct = predict_udf(struct(features_col, base_margin_col))
            pred_struct = predict_udf(struct(*features_col, base_margin_col))
        else:
            pred_struct = predict_udf(struct(features_col))
            pred_struct = predict_udf(struct(*features_col))

        pred_struct_col = "_prediction_struct"


@@ -63,9 +63,9 @@ def cache_partitions(
class PartIter(DataIter):
    """Iterator for creating Quantile DMatrix from partitions."""

    def __init__(self, data: Dict[str, List], on_device: bool) -> None:
    def __init__(self, data: Dict[str, List], device_id: Optional[int]) -> None:
        self._iter = 0
        self._cuda = on_device
        self._device_id = device_id
        self._data = data

        super().__init__()
@@ -74,9 +74,13 @@ class PartIter(DataIter):
        if not data:
            return None

        if self._cuda:
        if self._device_id is not None:
            import cudf  # pylint: disable=import-error
            import cupy as cp  # pylint: disable=import-error

            # We must set the device after import cudf, which will change the device id to 0
            # See https://github.com/rapidsai/cudf/issues/11386
            cp.cuda.runtime.setDevice(self._device_id)
            return cudf.DataFrame(data[self._iter])

        return data[self._iter]
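The GPU branch above follows a device-pinning pattern that is easy to get wrong; here is a hedged standalone sketch of it (function name and ids are illustrative, and it needs a CUDA machine with cudf/cupy installed):

import pandas as pd

def to_device_frame(pdf: pd.DataFrame, device_id):
    if device_id is None:
        return pdf  # CPU path: hand the pandas partition over unchanged
    import cudf  # pylint: disable=import-error
    import cupy as cp  # pylint: disable=import-error
    # Pin the device only after importing cudf: the import can reset the
    # current device to 0 (see rapidsai/cudf#11386, cited above).
    cp.cuda.runtime.setDevice(device_id)
    return cudf.DataFrame(pdf)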
@@ -100,6 +104,7 @@
def create_dmatrix_from_partitions(
    iterator: Iterator[pd.DataFrame],
    feature_cols: Optional[Sequence[str]],
    gpu_id: Optional[int],
    kwargs: Dict[str, Any],  # use dict to make sure this parameter is passed.
) -> Tuple[DMatrix, Optional[DMatrix]]:
    """Create DMatrix from spark data partitions. This is not particularly efficient as
@@ -169,7 +174,7 @@ def create_dmatrix_from_partitions(
        dtrain = make(train_data, kwargs)
    else:
        cache_partitions(iterator, append_dqm)
        it = PartIter(train_data, True)
        it = PartIter(train_data, gpu_id)
        dtrain = DeviceQuantileDMatrix(it, **kwargs)

    dvalid = make(valid_data, kwargs) if len(valid_data) != 0 else None

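Conceptually, the CPU fallback path of the function above concatenates the cached pandas partitions and builds a regular DMatrix, while the GPU path streams the same partitions through PartIter into DeviceQuantileDMatrix. A toy, CPU-only approximation of the fallback (column names and data are made up):

import numpy as np
import pandas as pd
import xgboost as xgb

rng = np.random.default_rng(0)
partitions = [
    pd.DataFrame(
        {"f0": rng.random(32), "f1": rng.random(32), "label": rng.integers(0, 2, 32)}
    )
    for _ in range(4)
]
train = pd.concat(partitions, ignore_index=True)
dtrain = xgb.DMatrix(train[["f0", "f1"]], label=train["label"])
print(dtrain.num_row(), dtrain.num_col())  # 128 2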
@@ -1,6 +1,7 @@
# type: ignore
"""Xgboost pyspark integration submodule for params."""
# pylint: disable=too-few-public-methods
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import Param, Params


@@ -31,3 +32,21 @@ class HasBaseMarginCol(Params):
        "base_margin_col",
        "This stores the name for the column of the base margin",
    )


class HasFeaturesCols(Params):
    """
    Mixin for param featuresCols: a list of feature column names.
    This parameter is taken effect only when use_gpu is enabled.
    """

    features_cols = Param(
        Params._dummy(),
        "features_cols",
        "feature column names.",
        typeConverter=TypeConverters.toListString,
    )

    def __init__(self):
        super().__init__()
        self._setDefault(features_cols=[])

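A small sketch of how the new mixin behaves (the demo class is illustrative): the Param uses TypeConverters.toListString, so list-like values are normalised to a list of strings, and the default is an empty list.

from xgboost.spark.params import HasFeaturesCols

class _DemoParams(HasFeaturesCols):
    pass

p = _DemoParams()
print(p.getOrDefault(p.features_cols))  # [] by default
p.set(p.features_cols, ("sepal_length", "sepal_width"))
print(p.getOrDefault(p.features_cols))  # ['sepal_length', 'sepal_width']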
@@ -115,7 +115,7 @@ if __name__ == "__main__":
            "python-package/xgboost/dask.py",
            "python-package/xgboost/spark",
            "tests/python/test_spark/test_data.py",
            "tests/python-gpu/test_spark_with_gpu/test_data.py",
            "tests/python-gpu/test_gpu_spark/test_data.py",
            "tests/ci_build/lint_python.py",
        ]
    ):
@@ -130,9 +130,9 @@ if __name__ == "__main__":
            "demo/guide-python/cat_in_the_dat.py",
            "tests/python/test_data_iterator.py",
            "tests/python/test_spark/test_data.py",
            "tests/python-gpu/test_gpu_with_dask.py",
            "tests/python-gpu/test_gpu_with_dask/test_gpu_with_dask.py",
            "tests/python-gpu/test_gpu_data_iterator.py",
            "tests/python-gpu/test_spark_with_gpu/test_data.py",
            "tests/python-gpu/test_gpu_spark/test_data.py",
            "tests/ci_build/lint_python.py",
        ]
    ):

@@ -61,8 +61,8 @@ def pytest_collection_modifyitems(config, items):
    mgpu_mark = pytest.mark.mgpu
    for item in items:
        if item.nodeid.startswith(
            "python-gpu/test_gpu_with_dask.py"
            "python-gpu/test_gpu_with_dask/test_gpu_with_dask.py"
        ) or item.nodeid.startswith(
            "python-gpu/test_spark_with_gpu/test_spark_with_gpu.py"
            "python-gpu/test_gpu_spark/test_gpu_spark.py"
        ):
            item.add_marker(mgpu_mark)

tests/python-gpu/test_gpu_spark/test_gpu_spark.py (new file, 215 lines)
@@ -0,0 +1,215 @@
import logging
import sys

import pytest
import sklearn

sys.path.append("tests/python")
import testing as tm

if tm.no_dask()["condition"]:
    pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win"):
    pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)

from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor

gpu_discovery_script_path = "tests/python-gpu/test_gpu_spark/discover_gpu.sh"
executor_gpu_amount = 4
executor_cores = 4
num_workers = executor_gpu_amount


@pytest.fixture(scope="module", autouse=True)
def spark_session_with_gpu():
    spark_config = {
        "spark.master": f"local-cluster[1, {executor_gpu_amount}, 1024]",
        "spark.python.worker.reuse": "false",
        "spark.driver.host": "127.0.0.1",
        "spark.task.maxFailures": "1",
        "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
        "spark.sql.pyspark.jvmStacktrace.enabled": "true",
        "spark.cores.max": executor_cores,
        "spark.task.cpus": "1",
        "spark.executor.cores": executor_cores,
        "spark.worker.resource.gpu.amount": executor_gpu_amount,
        "spark.task.resource.gpu.amount": "1",
        "spark.executor.resource.gpu.amount": executor_gpu_amount,
        "spark.worker.resource.gpu.discoveryScript": gpu_discovery_script_path,
    }
    builder = SparkSession.builder.appName("xgboost spark python API Tests with GPU")
    for k, v in spark_config.items():
        builder.config(k, v)
    spark = builder.getOrCreate()
    logging.getLogger("pyspark").setLevel(logging.INFO)
    # We run a dummy job so that we block until the workers have connected to the master
    spark.sparkContext.parallelize(
        range(num_workers), num_workers
    ).barrier().mapPartitions(lambda _: []).collect()
    yield spark
    spark.stop()


@pytest.fixture
def spark_iris_dataset(spark_session_with_gpu):
    spark = spark_session_with_gpu
    data = sklearn.datasets.load_iris()
    train_rows = [
        (Vectors.dense(features), float(label))
        for features, label in zip(data.data[0::2], data.target[0::2])
    ]
    train_df = spark.createDataFrame(
        spark.sparkContext.parallelize(train_rows, num_workers), ["features", "label"]
    )
    test_rows = [
        (Vectors.dense(features), float(label))
        for features, label in zip(data.data[1::2], data.target[1::2])
    ]
    test_df = spark.createDataFrame(
        spark.sparkContext.parallelize(test_rows, num_workers), ["features", "label"]
    )
    return train_df, test_df


@pytest.fixture
def spark_iris_dataset_feature_cols(spark_session_with_gpu):
    spark = spark_session_with_gpu
    data = sklearn.datasets.load_iris()
    train_rows = [
        (*features.tolist(), float(label))
        for features, label in zip(data.data[0::2], data.target[0::2])
    ]
    train_df = spark.createDataFrame(
        spark.sparkContext.parallelize(train_rows, num_workers),
        [*data.feature_names, "label"],
    )
    test_rows = [
        (*features.tolist(), float(label))
        for features, label in zip(data.data[1::2], data.target[1::2])
    ]
    test_df = spark.createDataFrame(
        spark.sparkContext.parallelize(test_rows, num_workers),
        [*data.feature_names, "label"],
    )
    return train_df, test_df, data.feature_names


@pytest.fixture
def spark_diabetes_dataset(spark_session_with_gpu):
    spark = spark_session_with_gpu
    data = sklearn.datasets.load_diabetes()
    train_rows = [
        (Vectors.dense(features), float(label))
        for features, label in zip(data.data[0::2], data.target[0::2])
    ]
    train_df = spark.createDataFrame(
        spark.sparkContext.parallelize(train_rows, num_workers), ["features", "label"]
    )
    test_rows = [
        (Vectors.dense(features), float(label))
        for features, label in zip(data.data[1::2], data.target[1::2])
    ]
    test_df = spark.createDataFrame(
        spark.sparkContext.parallelize(test_rows, num_workers), ["features", "label"]
    )
    return train_df, test_df


@pytest.fixture
def spark_diabetes_dataset_feature_cols(spark_session_with_gpu):
    spark = spark_session_with_gpu
    data = sklearn.datasets.load_diabetes()
    train_rows = [
        (*features.tolist(), float(label))
        for features, label in zip(data.data[0::2], data.target[0::2])
    ]
    train_df = spark.createDataFrame(
        spark.sparkContext.parallelize(train_rows, num_workers),
        [*data.feature_names, "label"],
    )
    test_rows = [
        (*features.tolist(), float(label))
        for features, label in zip(data.data[1::2], data.target[1::2])
    ]
    test_df = spark.createDataFrame(
        spark.sparkContext.parallelize(test_rows, num_workers),
        [*data.feature_names, "label"],
    )
    return train_df, test_df, data.feature_names


def test_sparkxgb_classifier_with_gpu(spark_iris_dataset):
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator

    classifier = SparkXGBClassifier(use_gpu=True, num_workers=num_workers)
    train_df, test_df = spark_iris_dataset
    model = classifier.fit(train_df)
    pred_result_df = model.transform(test_df)
    evaluator = MulticlassClassificationEvaluator(metricName="f1")
    f1 = evaluator.evaluate(pred_result_df)
    assert f1 >= 0.97


def test_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature_cols):
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator

    train_df, test_df, feature_names = spark_iris_dataset_feature_cols

    classifier = SparkXGBClassifier(
        features_col=feature_names, use_gpu=True, num_workers=num_workers
    )

    model = classifier.fit(train_df)
    pred_result_df = model.transform(test_df)
    evaluator = MulticlassClassificationEvaluator(metricName="f1")
    f1 = evaluator.evaluate(pred_result_df)
    assert f1 >= 0.97


def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature_cols):
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator

    train_df, test_df, feature_names = spark_iris_dataset_feature_cols

    classifier = SparkXGBClassifier(
        features_col=feature_names, use_gpu=True, num_workers=num_workers
    )
    grid = ParamGridBuilder().addGrid(classifier.max_depth, [6, 8]).build()
    evaluator = MulticlassClassificationEvaluator(metricName="f1")
    cv = CrossValidator(
        estimator=classifier, evaluator=evaluator, estimatorParamMaps=grid, numFolds=3
    )
    cvModel = cv.fit(train_df)
    pred_result_df = cvModel.transform(test_df)
    f1 = evaluator.evaluate(pred_result_df)
    assert f1 >= 0.97


def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset):
    from pyspark.ml.evaluation import RegressionEvaluator

    regressor = SparkXGBRegressor(use_gpu=True, num_workers=num_workers)
    train_df, test_df = spark_diabetes_dataset
    model = regressor.fit(train_df)
    pred_result_df = model.transform(test_df)
    evaluator = RegressionEvaluator(metricName="rmse")
    rmse = evaluator.evaluate(pred_result_df)
    assert rmse <= 65.0


def test_sparkxgb_regressor_feature_cols_with_gpu(spark_diabetes_dataset_feature_cols):
    from pyspark.ml.evaluation import RegressionEvaluator

    train_df, test_df, feature_names = spark_diabetes_dataset_feature_cols
    regressor = SparkXGBRegressor(
        features_col=feature_names, use_gpu=True, num_workers=num_workers
    )

    model = regressor.fit(train_df)
    pred_result_df = model.transform(test_df)
    evaluator = RegressionEvaluator(metricName="rmse")
    rmse = evaluator.evaluate(pred_result_df)
    assert rmse <= 65.0
@@ -1,120 +0,0 @@
import sys

import logging
import pytest
import sklearn

sys.path.append("tests/python")
import testing as tm

if tm.no_dask()["condition"]:
    pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win"):
    pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)


from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from xgboost.spark import SparkXGBRegressor, SparkXGBClassifier


@pytest.fixture(scope="module", autouse=True)
def spark_session_with_gpu():
    spark_config = {
        "spark.master": "local-cluster[1, 4, 1024]",
        "spark.python.worker.reuse": "false",
        "spark.driver.host": "127.0.0.1",
        "spark.task.maxFailures": "1",
        "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
        "spark.sql.pyspark.jvmStacktrace.enabled": "true",
        "spark.cores.max": "4",
        "spark.task.cpus": "1",
        "spark.executor.cores": "4",
        "spark.worker.resource.gpu.amount": "4",
        "spark.task.resource.gpu.amount": "1",
        "spark.executor.resource.gpu.amount": "4",
        "spark.worker.resource.gpu.discoveryScript": "tests/python-gpu/test_spark_with_gpu/discover_gpu.sh",
    }
    builder = SparkSession.builder.appName("xgboost spark python API Tests with GPU")
    for k, v in spark_config.items():
        builder.config(k, v)
    spark = builder.getOrCreate()
    logging.getLogger("pyspark").setLevel(logging.INFO)
    # We run a dummy job so that we block until the workers have connected to the master
    spark.sparkContext.parallelize(range(4), 4).barrier().mapPartitions(
        lambda _: []
    ).collect()
    yield spark
    spark.stop()


@pytest.fixture
def spark_iris_dataset(spark_session_with_gpu):
    spark = spark_session_with_gpu
    data = sklearn.datasets.load_iris()
    train_rows = [
        (Vectors.dense(features), float(label))
        for features, label in zip(data.data[0::2], data.target[0::2])
    ]
    train_df = spark.createDataFrame(
        spark.sparkContext.parallelize(train_rows, 4), ["features", "label"]
    )
    test_rows = [
        (Vectors.dense(features), float(label))
        for features, label in zip(data.data[1::2], data.target[1::2])
    ]
    test_df = spark.createDataFrame(
        spark.sparkContext.parallelize(test_rows, 4), ["features", "label"]
    )
    return train_df, test_df


@pytest.fixture
def spark_diabetes_dataset(spark_session_with_gpu):
    spark = spark_session_with_gpu
    data = sklearn.datasets.load_diabetes()
    train_rows = [
        (Vectors.dense(features), float(label))
        for features, label in zip(data.data[0::2], data.target[0::2])
    ]
    train_df = spark.createDataFrame(
        spark.sparkContext.parallelize(train_rows, 4), ["features", "label"]
    )
    test_rows = [
        (Vectors.dense(features), float(label))
        for features, label in zip(data.data[1::2], data.target[1::2])
    ]
    test_df = spark.createDataFrame(
        spark.sparkContext.parallelize(test_rows, 4), ["features", "label"]
    )
    return train_df, test_df


def test_sparkxgb_classifier_with_gpu(spark_iris_dataset):
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator

    classifier = SparkXGBClassifier(
        use_gpu=True,
        num_workers=4,
    )
    train_df, test_df = spark_iris_dataset
    model = classifier.fit(train_df)
    pred_result_df = model.transform(test_df)
    evaluator = MulticlassClassificationEvaluator(metricName="f1")
    f1 = evaluator.evaluate(pred_result_df)
    assert f1 >= 0.97


def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset):
    from pyspark.ml.evaluation import RegressionEvaluator

    regressor = SparkXGBRegressor(
        use_gpu=True,
        num_workers=4,
    )
    train_df, test_df = spark_diabetes_dataset
    model = regressor.fit(train_df)
    pred_result_df = model.transform(test_df)
    evaluator = RegressionEvaluator(metricName="rmse")
    rmse = evaluator.evaluate(pred_result_df)
    assert rmse <= 65.0
@@ -62,9 +62,11 @@ def run_dmatrix_ctor(is_dqm: bool) -> None:
    kwargs = {"feature_types": feature_types}
    if is_dqm:
        cols = [f"feat-{i}" for i in range(n_features)]
        train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), cols, kwargs)
        train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), cols, 0, kwargs)
    else:
        train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), None, kwargs)
        train_Xy, valid_Xy = create_dmatrix_from_partitions(
            iter(dfs), None, None, kwargs
        )

    assert valid_Xy is not None
    assert valid_Xy.num_row() + train_Xy.num_row() == n_samples_per_batch * n_batches