[pyspark][doc] Test and doc for stage-level scheduling. (#9786)

Bobby Wang 2023-11-16 18:15:59 +08:00 committed by GitHub
parent ada377c57e
commit 178cfe70a8
4 changed files with 144 additions and 21 deletions


@@ -215,6 +215,22 @@ and the prediction for each instance.
 Submit the application
 **********************
 
+This assumes you have already configured the Spark standalone cluster with GPU support; otherwise, please
+refer to `Spark standalone configuration with GPU support <https://nvidia.github.io/spark-rapids/docs/get-started/getting-started-on-prem.html#spark-standalone-cluster>`_.
+
+Starting from XGBoost 2.1.0, stage-level scheduling is automatically enabled. Therefore,
+if you are using a Spark standalone cluster of version 3.4.0 or higher, we strongly recommend
+configuring ``"spark.task.resource.gpu.amount"`` as a fractional value. This will
+enable running multiple tasks in parallel during the ETL phase. An example configuration
+would be ``"spark.task.resource.gpu.amount=1/spark.executor.cores"``. However, if you are
+using an XGBoost version earlier than 2.1.0 or a Spark standalone cluster version below 3.4.0,
+you still need to set ``"spark.task.resource.gpu.amount"`` equal to ``"spark.executor.resource.gpu.amount"``.
+
+.. note::
+
+   As of now, the stage-level scheduling feature in XGBoost is limited to the Spark standalone cluster mode.
+   However, we plan to expand its compatibility to YARN and Kubernetes once Spark 3.5.1 is officially released.
+
 Assuming that the application main class is "Iris" and the application jar is "iris-1.0.0.jar",
 provided below is an example demonstrating how to submit the XGBoost application to an Apache
 Spark Standalone cluster.
@@ -230,9 +246,9 @@ Spark Standalone cluster.
   --master $master \
   --packages com.nvidia:rapids-4-spark_2.12:${rapids_version},ml.dmlc:xgboost4j-gpu_2.12:${xgboost_version},ml.dmlc:xgboost4j-spark-gpu_2.12:${xgboost_version} \
   --conf spark.executor.cores=12 \
-  --conf spark.task.cpus=12 \
+  --conf spark.task.cpus=1 \
   --conf spark.executor.resource.gpu.amount=1 \
-  --conf spark.task.resource.gpu.amount=1 \
+  --conf spark.task.resource.gpu.amount=0.08 \
   --conf spark.rapids.sql.csv.read.double.enabled=true \
   --conf spark.rapids.sql.hasNans=false \
   --conf spark.plugins=com.nvidia.spark.SQLPlugin \
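For reference, the fractional value above follows the ``1/spark.executor.cores`` rule: with ``spark.executor.cores=12``, each ETL task requests 1/12 ≈ 0.08 of a GPU, so up to twelve tasks can share one GPU executor, while XGBoost's stage-level scheduling re-requests a full GPU per task for the training stage. Below is a minimal PySpark sketch of the same settings applied programmatically; the application name and master URL are placeholders, not values taken from this patch.

```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.master("spark://master:7077")  # placeholder master URL
    .appName("Iris")
    .config("spark.executor.cores", "12")
    .config("spark.task.cpus", "1")
    .config("spark.executor.resource.gpu.amount", "1")
    # 1 / spark.executor.cores = 1 / 12 ~= 0.08: lets twelve ETL tasks share one GPU.
    .config("spark.task.resource.gpu.amount", "0.08")
    .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
    .getOrCreate()
)
```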


@@ -22,7 +22,7 @@ from typing import (
 import numpy as np
 import pandas as pd
-from pyspark import RDD, SparkContext, cloudpickle
+from pyspark import RDD, SparkConf, SparkContext, cloudpickle
 from pyspark.ml import Estimator, Model
 from pyspark.ml.functions import array_to_vector, vector_to_array
 from pyspark.ml.linalg import VectorUDT
@@ -368,7 +368,10 @@ class _SparkXGBParams(
                 " on GPU."
             )
-        if not (ss.version >= "3.4.0" and _is_standalone_or_localcluster(sc)):
+        if not (
+            ss.version >= "3.4.0"
+            and _is_standalone_or_localcluster(sc.getConf())
+        ):
             # We will enable stage-level scheduling in spark 3.4.0+ which doesn't
             # require spark.task.resource.gpu.amount to be set explicitly
             gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
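When stage-level scheduling cannot be used (Spark below 3.4.0 or a non-standalone master), the comment above implies that every barrier training task must already own a whole GPU through the regular task configuration. The following is only an illustrative sketch of that constraint, not the code that follows in the actual patch:

```python
# Illustrative sketch: without stage-level scheduling, the plain task-level
# configuration has to hand every training task a full GPU.
gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
if gpu_per_task is None or float(gpu_per_task) < 1.0:
    raise ValueError(
        "XGBoost GPU training requires spark.task.resource.gpu.amount >= 1 "
        "when stage-level scheduling is unavailable."
    )
```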
@@ -907,30 +910,27 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         return booster_params, train_call_kwargs_params, dmatrix_kwargs
 
-    def _skip_stage_level_scheduling(self) -> bool:
+    def _skip_stage_level_scheduling(self, spark_version: str, conf: SparkConf) -> bool:
         # pylint: disable=too-many-return-statements
         """Check if stage-level scheduling is not needed,
         return true to skip stage-level scheduling"""
         if self._run_on_gpu():
-            ss = _get_spark_session()
-            sc = ss.sparkContext
-
-            if ss.version < "3.4.0":
+            if spark_version < "3.4.0":
                 self.logger.info(
                     "Stage-level scheduling in xgboost requires spark version 3.4.0+"
                 )
                 return True
 
-            if not _is_standalone_or_localcluster(sc):
+            if not _is_standalone_or_localcluster(conf):
                 self.logger.info(
                     "Stage-level scheduling in xgboost requires spark standalone or "
                     "local-cluster mode"
                 )
                 return True
 
-            executor_cores = sc.getConf().get("spark.executor.cores")
-            executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
+            executor_cores = conf.get("spark.executor.cores")
+            executor_gpus = conf.get("spark.executor.resource.gpu.amount")
             if executor_cores is None or executor_gpus is None:
                 self.logger.info(
                     "Stage-level scheduling in xgboost requires spark.executor.cores, "
@@ -955,7 +955,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 )
                 return True
 
-            task_gpu_amount = sc.getConf().get("spark.task.resource.gpu.amount")
+            task_gpu_amount = conf.get("spark.task.resource.gpu.amount")
 
             if task_gpu_amount is None:
                 # The ETL tasks will not grab a gpu when spark.task.resource.gpu.amount is not set,
@@ -975,14 +975,13 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     def _try_stage_level_scheduling(self, rdd: RDD) -> RDD:
         """Try to enable stage-level scheduling"""
-        if self._skip_stage_level_scheduling():
+        ss = _get_spark_session()
+        conf = ss.sparkContext.getConf()
+        if self._skip_stage_level_scheduling(ss.version, conf):
             return rdd
 
-        ss = _get_spark_session()
-
         # executor_cores will not be None
-        executor_cores = ss.sparkContext.getConf().get("spark.executor.cores")
+        executor_cores = conf.get("spark.executor.cores")
         assert executor_cores is not None
 
         # Spark-rapids is a project to leverage GPUs to accelerate spark SQL.
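The remainder of ``_try_stage_level_scheduling`` is not shown in this hunk; it builds a task-level resource request so that the training stage, unlike the ETL stage, gets a full GPU (and enough CPU cores that only one training task fits per executor) and applies it to the barrier RDD. A minimal sketch of that idea using the public stage-level scheduling API; the exact core/GPU amounts and error handling in the real implementation may differ:

```python
from pyspark import RDD
from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests


def _training_stage_resources(rdd: RDD, executor_cores: int) -> RDD:
    # Request one full GPU and (roughly) all executor cores per training task,
    # so only a single XGBoost task is scheduled on each executor for this stage.
    task_requests = TaskResourceRequests().cpus(executor_cores).resource("gpu", 1.0)
    profile = ResourceProfileBuilder().require(task_requests).build
    return rdd.withResources(profile)
```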


@@ -10,7 +10,7 @@ from threading import Thread
 from typing import Any, Callable, Dict, Optional, Set, Type
 
 import pyspark
-from pyspark import BarrierTaskContext, SparkContext, SparkFiles, TaskContext
+from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
 from pyspark.sql.session import SparkSession
 
 from xgboost import Booster, XGBModel, collective
@@ -129,8 +129,8 @@ def _is_local(spark_context: SparkContext) -> bool:
     return spark_context._jsc.sc().isLocal()
 
-def _is_standalone_or_localcluster(spark_context: SparkContext) -> bool:
-    master = spark_context.getConf().get("spark.master")
+def _is_standalone_or_localcluster(conf: SparkConf) -> bool:
+    master = conf.get("spark.master")
     return master is not None and (
         master.startswith("spark://") or master.startswith("local-cluster")
     )
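Taking a ``SparkConf`` instead of a ``SparkContext`` is what makes this helper (and ``_skip_stage_level_scheduling``) testable without a running cluster. A quick usage sketch with illustrative master URLs:

```python
from pyspark import SparkConf

# Assuming the helper is importable from the utils module shown above.
from xgboost.spark.utils import _is_standalone_or_localcluster

# Standalone and local-cluster masters qualify for stage-level scheduling...
assert _is_standalone_or_localcluster(SparkConf().setMaster("spark://host:7077"))
assert _is_standalone_or_localcluster(SparkConf().setMaster("local-cluster[2, 1, 1024]"))
# ...while YARN, Kubernetes, and plain local mode do not.
assert not _is_standalone_or_localcluster(SparkConf().setMaster("yarn"))
assert not _is_standalone_or_localcluster(SparkConf().setMaster("local[4]"))
```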


@@ -8,6 +8,7 @@ from typing import Generator, Sequence, Type
 import numpy as np
 import pytest
+from pyspark import SparkConf
 
 import xgboost as xgb
 from xgboost import testing as tm
@@ -932,6 +933,113 @@ class TestPySparkLocal:
         model_loaded.set_device("cuda")
         assert model_loaded._run_on_gpu()
 
+    def test_skip_stage_level_scheduling(self) -> None:
+        conf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "0.08")
+        )
+        classifier_on_cpu = SparkXGBClassifier(use_gpu=False)
+        classifier_on_gpu = SparkXGBClassifier(use_gpu=True)
+
+        # the correct configurations should not skip stage-level scheduling
+        assert not classifier_on_gpu._skip_stage_level_scheduling("3.4.0", conf)
+
+        # spark version < 3.4.0
+        assert classifier_on_gpu._skip_stage_level_scheduling("3.3.0", conf)
+
+        # not run on GPU
+        assert classifier_on_cpu._skip_stage_level_scheduling("3.4.0", conf)
+
+        # spark.executor.cores is not set
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "0.08")
+        )
+        assert classifier_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # spark.executor.cores=1
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "1")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "0.08")
+        )
+        assert classifier_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # spark.executor.resource.gpu.amount is not set
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.task.resource.gpu.amount", "0.08")
+        )
+        assert classifier_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # spark.executor.resource.gpu.amount > 1
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "2")
+            .set("spark.task.resource.gpu.amount", "0.08")
+        )
+        assert classifier_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # spark.task.resource.gpu.amount is not set
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+        )
+        assert not classifier_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # spark.task.resource.gpu.amount=1
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "1")
+        )
+        assert classifier_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # yarn
+        badConf = (
+            SparkConf()
+            .setMaster("yarn")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "1")
+        )
+        assert classifier_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # k8s
+        badConf = (
+            SparkConf()
+            .setMaster("k8s://")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "1")
+        )
+        assert classifier_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
 
 class XgboostLocalTest(SparkTestCase):
     def setUp(self):