[pyspark] User guide doc and tutorials (#8082)
Co-authored-by: Bobby Wang <wbo4958@gmail.com>
This commit is contained in:
parent
f801d3cf15
commit
f23cc92130
82
demo/guide-python/spark_estimator_examples.py
Normal file
82
demo/guide-python/spark_estimator_examples.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
'''
|
||||||
|
Collection of examples for using xgboost.spark estimator interface
|
||||||
|
==================================================================
|
||||||
|
|
||||||
|
@author: Weichen Xu
|
||||||
|
'''
|
||||||
|
from pyspark.sql import SparkSession
|
||||||
|
from pyspark.sql.functions import rand
|
||||||
|
from pyspark.ml.linalg import Vectors
|
||||||
|
import sklearn.datasets
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
|
||||||
|
from pyspark.ml.evaluation import RegressionEvaluator, MulticlassClassificationEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
spark = SparkSession.builder.master("local[*]").getOrCreate()
|
||||||
|
|
||||||
|
|
||||||
|
def create_spark_df(X, y):
    """Build a Spark DataFrame with "features" and "label" columns.

    ``X`` and ``y`` are iterated in lockstep with ``zip``; each row of ``X``
    is wrapped in ``Vectors.dense`` and each entry of ``y`` is cast to
    ``float``. The rows are distributed through the module-level ``spark``
    session.
    """
    records = []
    for features, label in zip(X, y):
        records.append((Vectors.dense(features), float(label)))
    rdd = spark.sparkContext.parallelize(records)
    return spark.createDataFrame(rdd, ["features", "label"])
|
||||||
|
|
||||||
|
|
||||||
|
# load diabetes dataset (regression dataset)
|
||||||
|
diabetes_X, diabetes_y = sklearn.datasets.load_diabetes(return_X_y=True)
|
||||||
|
diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = \
|
||||||
|
train_test_split(diabetes_X, diabetes_y, test_size=0.3, shuffle=True)
|
||||||
|
|
||||||
|
diabetes_train_spark_df = create_spark_df(diabetes_X_train, diabetes_y_train)
|
||||||
|
diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test)
|
||||||
|
|
||||||
|
# train xgboost regressor model
|
||||||
|
xgb_regressor = SparkXGBRegressor(max_depth=5)
|
||||||
|
xgb_regressor_model = xgb_regressor.fit(diabetes_train_spark_df)
|
||||||
|
|
||||||
|
transformed_diabetes_test_spark_df = xgb_regressor_model.transform(diabetes_test_spark_df)
|
||||||
|
regressor_evaluator = RegressionEvaluator(metricName="rmse")
|
||||||
|
print(f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}")
|
||||||
|
|
||||||
|
diabetes_train_spark_df2 = diabetes_train_spark_df.withColumn(
|
||||||
|
"validationIndicatorCol", rand(1) > 0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train an xgboost regressor model with a validation dataset: the boolean
# "validationIndicatorCol" column is passed via `validation_indicator_col`.
xgb_regressor2 = SparkXGBRegressor(max_depth=5, validation_indicator_col="validationIndicatorCol")
# Fit the estimator configured above. Fitting the earlier `xgb_regressor`
# here would ignore the validation indicator column entirely.
xgb_regressor_model2 = xgb_regressor2.fit(diabetes_train_spark_df2)
transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(diabetes_test_spark_df)
print(f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}")
|
||||||
|
|
||||||
|
|
||||||
|
# load iris dataset (classification dataset)
|
||||||
|
iris_X, iris_y = sklearn.datasets.load_iris(return_X_y=True)
|
||||||
|
iris_X_train, iris_X_test, iris_y_train, iris_y_test = \
|
||||||
|
train_test_split(iris_X, iris_y, test_size=0.3, shuffle=True)
|
||||||
|
|
||||||
|
iris_train_spark_df = create_spark_df(iris_X_train, iris_y_train)
|
||||||
|
iris_test_spark_df = create_spark_df(iris_X_test, iris_y_test)
|
||||||
|
|
||||||
|
# train xgboost classifier model
|
||||||
|
xgb_classifier = SparkXGBClassifier(max_depth=5)
|
||||||
|
xgb_classifier_model = xgb_classifier.fit(iris_train_spark_df)
|
||||||
|
|
||||||
|
transformed_iris_test_spark_df = xgb_classifier_model.transform(iris_test_spark_df)
|
||||||
|
classifier_evaluator = MulticlassClassificationEvaluator(metricName="f1")
|
||||||
|
print(f"classifier f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df)}")
|
||||||
|
|
||||||
|
iris_train_spark_df2 = iris_train_spark_df.withColumn(
|
||||||
|
"validationIndicatorCol", rand(1) > 0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train an xgboost classifier model with a validation dataset: the boolean
# "validationIndicatorCol" column is passed via `validation_indicator_col`.
xgb_classifier2 = SparkXGBClassifier(max_depth=5, validation_indicator_col="validationIndicatorCol")
# Fit the estimator configured above. Fitting the earlier `xgb_classifier`
# here would ignore the validation indicator column entirely.
xgb_classifier_model2 = xgb_classifier2.fit(iris_train_spark_df2)
transformed_iris_test_spark_df2 = xgb_classifier_model2.transform(iris_test_spark_df)
print(f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}")
|
||||||
|
|
||||||
|
spark.stop()
|
||||||
66
doc/tutorials/spark_estimator.rst
Normal file
66
doc/tutorials/spark_estimator.rst
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
###############################
|
||||||
|
Using XGBoost PySpark Estimator
|
||||||
|
###############################
|
||||||
|
Starting from version 2.0, xgboost supports pyspark estimator APIs.
|
||||||
|
The feature is still experimental and not yet ready for production use.
|
||||||
|
|
||||||
|
*****************
|
||||||
|
SparkXGBRegressor
|
||||||
|
*****************
|
||||||
|
|
||||||
|
SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression
|
||||||
|
algorithm based on XGBoost python library, and it can be used in PySpark Pipeline
|
||||||
|
and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest.
|
||||||
|
|
||||||
|
We can create a `SparkXGBRegressor` estimator like:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from xgboost.spark import SparkXGBRegressor
|
||||||
|
spark_reg_estimator = SparkXGBRegressor(num_workers=2, max_depth=5)
|
||||||
|
|
||||||
|
|
||||||
|
The above snippet creates a spark estimator which can fit on a spark dataset,
|
||||||
|
and returns a spark model that can transform a spark dataset and generate a dataset
|
||||||
|
with prediction column. We can set almost all of xgboost sklearn estimator parameters
|
||||||
|
as `SparkXGBRegressor` parameters, but some parameters such as `nthread` are forbidden
|
||||||
|
in spark estimator, and some parameters are replaced with pyspark specific parameters
|
||||||
|
such as `weight_col`, `validation_indicator_col`, `use_gpu`, for details please see
|
||||||
|
`SparkXGBRegressor` doc.
|
||||||
|
|
||||||
|
The following code snippet shows how to train a spark xgboost regressor model,
|
||||||
|
first we need to prepare a training dataset as a spark dataframe that contains
|
||||||
|
"features" and "label" column, the "features" column must be `pyspark.ml.linalg.Vector`
|
||||||
|
type or spark array type.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
xgb_regressor_model = spark_reg_estimator.fit(train_spark_dataframe)
|
||||||
|
|
||||||
|
|
||||||
|
The following code snippet shows how to predict test data using a spark xgboost regressor model,
|
||||||
|
first we need to prepare a test dataset as a spark dataframe that contains
|
||||||
|
"features" and "label" column, the "features" column must be `pyspark.ml.linalg.Vector`
|
||||||
|
type or spark array type.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
transformed_test_spark_dataframe = xgb_regressor_model.transform(test_spark_dataframe)
|
||||||
|
|
||||||
|
|
||||||
|
The above snippet code returns a `transformed_test_spark_dataframe` that contains the input
|
||||||
|
dataset columns and an appended column "prediction" representing the prediction results.
|
||||||
|
|
||||||
|
|
||||||
|
******************
|
||||||
|
SparkXGBClassifier
|
||||||
|
******************
|
||||||
|
|
||||||
|
|
||||||
|
`SparkXGBClassifier` estimator has a similar API to `SparkXGBRegressor`, but it has some
|
||||||
|
pyspark classifier specific params, e.g. `raw_prediction_col` and `probability_col` parameters.
|
||||||
|
Correspondingly, by default, `SparkXGBClassifierModel` transforming test dataset will
|
||||||
|
generate result dataset with 3 new columns:
|
||||||
|
- "prediction": represents the predicted label.
|
||||||
|
- "rawPrediction": represents the output margin values.
|
||||||
|
- "probability": represents the prediction probability on each label.
|
||||||
@ -379,10 +379,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
)
|
)
|
||||||
if k in _pyspark_param_alias_map:
|
if k in _pyspark_param_alias_map:
|
||||||
real_k = _pyspark_param_alias_map[k]
|
real_k = _pyspark_param_alias_map[k]
|
||||||
if real_k in kwargs:
|
|
||||||
raise ValueError(
|
|
||||||
f"You should set only one of param '{k}' and '{real_k}'"
|
|
||||||
)
|
|
||||||
k = real_k
|
k = real_k
|
||||||
|
|
||||||
if self.hasParam(k):
|
if self.hasParam(k):
|
||||||
|
|||||||
@ -31,6 +31,9 @@ class SparkXGBRegressor(_SparkXGBEstimator):
|
|||||||
|
|
||||||
SparkXGBRegressor doesn't support `validate_features` and `output_margin` param.
|
SparkXGBRegressor doesn't support `validate_features` and `output_margin` param.
|
||||||
|
|
||||||
|
SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the `nthread`
|
||||||
|
param for each xgboost worker will be set equal to `spark.task.cpus` config value.
|
||||||
|
|
||||||
callbacks:
|
callbacks:
|
||||||
The export and import of the callback functions are at best effort.
|
The export and import of the callback functions are at best effort.
|
||||||
For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc.
|
For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc.
|
||||||
@ -128,6 +131,10 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
|
|||||||
|
|
||||||
SparkXGBClassifier doesn't support `validate_features` and `output_margin` param.
|
SparkXGBClassifier doesn't support `validate_features` and `output_margin` param.
|
||||||
|
|
||||||
|
SparkXGBClassifier doesn't support setting `nthread` xgboost param, instead, the `nthread`
|
||||||
|
param for each xgboost worker will be set equal to `spark.task.cpus` config value.
|
||||||
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user