[pyspark][doc] Test and doc for stage-level scheduling. (#9786)
commit 178cfe70a8
parent ada377c57e
@@ -215,6 +215,22 @@ and the prediction for each instance.
 Submit the application
 **********************
 
+We assume you have configured the Spark standalone cluster with GPU support; otherwise, please
+refer to `spark standalone configuration with GPU support <https://nvidia.github.io/spark-rapids/docs/get-started/getting-started-on-prem.html#spark-standalone-cluster>`_.
+
+Starting from XGBoost 2.1.0, stage-level scheduling is automatically enabled. Therefore,
+if you are using a Spark standalone cluster of version 3.4.0 or higher, we strongly recommend
+configuring ``"spark.task.resource.gpu.amount"`` as a fractional value. This will
+enable running multiple tasks in parallel during the ETL phase. An example configuration
+would be ``"spark.task.resource.gpu.amount=1/spark.executor.cores"``. However, if you are
+using an XGBoost version earlier than 2.1.0 or a Spark standalone cluster version below 3.4.0,
+you still need to set ``"spark.task.resource.gpu.amount"`` equal to ``"spark.executor.resource.gpu.amount"``.
+
+.. note::
+
+   As of now, the stage-level scheduling feature in XGBoost is limited to the Spark standalone cluster mode.
+   However, we plan to expand its compatibility to YARN and Kubernetes once Spark 3.5.1 is officially released.
+
 Assuming that the application main class is "Iris" and the application jar is "iris-1.0.0.jar",
 provided below is an instance demonstrating how to submit the xgboost application to an Apache
 Spark Standalone cluster.
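To make the recommended fractional setting concrete, here is a minimal sketch of the equivalent programmatic session configuration; the master URL and the 12-core executor size are assumptions for illustration, not part of this commit:

    from pyspark.sql import SparkSession

    executor_cores = 12  # assumed executor size for this example

    # With 1/12 of a GPU per task, all 12 executor cores can run ETL tasks
    # concurrently, while training later requests the full GPU through
    # stage-level scheduling.
    spark = (
        SparkSession.builder.master("spark://master:7077")  # assumed standalone master
        .config("spark.executor.cores", str(executor_cores))
        .config("spark.executor.resource.gpu.amount", "1")
        .config("spark.task.resource.gpu.amount", str(1 / executor_cores))
        .getOrCreate()
    )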
@@ -230,9 +246,9 @@ Spark Standalone cluster.
   --master $master \
   --packages com.nvidia:rapids-4-spark_2.12:${rapids_version},ml.dmlc:xgboost4j-gpu_2.12:${xgboost_version},ml.dmlc:xgboost4j-spark-gpu_2.12:${xgboost_version} \
   --conf spark.executor.cores=12 \
-  --conf spark.task.cpus=12 \
+  --conf spark.task.cpus=1 \
   --conf spark.executor.resource.gpu.amount=1 \
-  --conf spark.task.resource.gpu.amount=1 \
+  --conf spark.task.resource.gpu.amount=0.08 \
   --conf spark.rapids.sql.csv.read.double.enabled=true \
   --conf spark.rapids.sql.hasNans=false \
   --conf spark.plugins=com.nvidia.spark.SQLPlugin \
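A quick sanity check of the new values (an inference from the configuration, not spelled out in the diff): with ``spark.executor.cores=12`` and one GPU per executor, an even per-task share would be 1/12 ≈ 0.083, which the example rounds down to 0.08. Since 12 × 0.08 = 0.96 ≤ 1, all 12 cores can run ETL tasks against the shared GPU at once, and ``spark.task.cpus=1`` (previously 12) keeps each ETL task on a single core instead of monopolizing the executor.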
@@ -22,7 +22,7 @@ from typing import (
 
 import numpy as np
 import pandas as pd
-from pyspark import RDD, SparkContext, cloudpickle
+from pyspark import RDD, SparkConf, SparkContext, cloudpickle
 from pyspark.ml import Estimator, Model
 from pyspark.ml.functions import array_to_vector, vector_to_array
 from pyspark.ml.linalg import VectorUDT
@@ -368,7 +368,10 @@ class _SparkXGBParams(
                     " on GPU."
                 )
 
-            if not (ss.version >= "3.4.0" and _is_standalone_or_localcluster(sc)):
+            if not (
+                ss.version >= "3.4.0"
+                and _is_standalone_or_localcluster(sc.getConf())
+            ):
                 # We will enable stage-level scheduling in spark 3.4.0+ which doesn't
                 # require spark.task.resource.gpu.amount to be set explicitly
                 gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
@@ -907,30 +910,27 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
 
         return booster_params, train_call_kwargs_params, dmatrix_kwargs
 
-    def _skip_stage_level_scheduling(self) -> bool:
+    def _skip_stage_level_scheduling(self, spark_version: str, conf: SparkConf) -> bool:
         # pylint: disable=too-many-return-statements
         """Check if stage-level scheduling is not needed,
         return true to skip stage-level scheduling"""
 
         if self._run_on_gpu():
-            ss = _get_spark_session()
-            sc = ss.sparkContext
-
-            if ss.version < "3.4.0":
+            if spark_version < "3.4.0":
                 self.logger.info(
                     "Stage-level scheduling in xgboost requires spark version 3.4.0+"
                 )
                 return True
 
-            if not _is_standalone_or_localcluster(sc):
+            if not _is_standalone_or_localcluster(conf):
                 self.logger.info(
                     "Stage-level scheduling in xgboost requires spark standalone or "
                     "local-cluster mode"
                 )
                 return True
 
-            executor_cores = sc.getConf().get("spark.executor.cores")
-            executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
+            executor_cores = conf.get("spark.executor.cores")
+            executor_gpus = conf.get("spark.executor.resource.gpu.amount")
             if executor_cores is None or executor_gpus is None:
                 self.logger.info(
                     "Stage-level scheduling in xgboost requires spark.executor.cores, "
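Passing ``spark_version`` and a plain ``SparkConf`` instead of reading them from the live session is what makes this gate testable without a cluster, since a ``SparkConf`` can be built driver-side with no JVM running. A minimal sketch of such a call (the values mirror the test added later in this commit):

    from pyspark import SparkConf
    from xgboost.spark import SparkXGBClassifier

    conf = (
        SparkConf()
        .setMaster("spark://foo")  # standalone master, as in the new test
        .set("spark.executor.cores", "12")
        .set("spark.executor.resource.gpu.amount", "1")
        .set("spark.task.resource.gpu.amount", "0.08")
    )
    clf = SparkXGBClassifier(use_gpu=True)
    assert not clf._skip_stage_level_scheduling("3.4.0", conf)  # eligible, so no skip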
@@ -955,7 +955,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 )
                 return True
 
-            task_gpu_amount = sc.getConf().get("spark.task.resource.gpu.amount")
+            task_gpu_amount = conf.get("spark.task.resource.gpu.amount")
 
             if task_gpu_amount is None:
                 # The ETL tasks will not grab a gpu when spark.task.resource.gpu.amount is not set,
@@ -975,14 +975,13 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
 
     def _try_stage_level_scheduling(self, rdd: RDD) -> RDD:
         """Try to enable stage-level scheduling"""
-
-        if self._skip_stage_level_scheduling():
+        ss = _get_spark_session()
+        conf = ss.sparkContext.getConf()
+        if self._skip_stage_level_scheduling(ss.version, conf):
             return rdd
 
-        ss = _get_spark_session()
-
         # executor_cores will not be None
-        executor_cores = ss.sparkContext.getConf().get("spark.executor.cores")
+        executor_cores = conf.get("spark.executor.cores")
         assert executor_cores is not None
 
         # Spark-rapids is a project to leverage GPUs to accelerate spark SQL.
@@ -10,7 +10,7 @@ from threading import Thread
 from typing import Any, Callable, Dict, Optional, Set, Type
 
 import pyspark
-from pyspark import BarrierTaskContext, SparkContext, SparkFiles, TaskContext
+from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
 from pyspark.sql.session import SparkSession
 
 from xgboost import Booster, XGBModel, collective
@@ -129,8 +129,8 @@ def _is_local(spark_context: SparkContext) -> bool:
     return spark_context._jsc.sc().isLocal()
 
 
-def _is_standalone_or_localcluster(spark_context: SparkContext) -> bool:
-    master = spark_context.getConf().get("spark.master")
+def _is_standalone_or_localcluster(conf: SparkConf) -> bool:
+    master = conf.get("spark.master")
     return master is not None and (
         master.startswith("spark://") or master.startswith("local-cluster")
     )
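For reference, the rewritten helper treats only standalone (``spark://``) and ``local-cluster`` masters as eligible. A small illustration with assumed master URLs (and an assumed import path of ``xgboost.spark.utils`` for this private helper):

    from pyspark import SparkConf
    from xgboost.spark.utils import _is_standalone_or_localcluster  # assumed module path

    for master in ("spark://host:7077", "local-cluster[2,1,1024]", "yarn", "k8s://host", "local[4]"):
        conf = SparkConf().set("spark.master", master)
        print(master, _is_standalone_or_localcluster(conf))
    # -> True, True, False, False, False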
@@ -8,6 +8,7 @@ from typing import Generator, Sequence, Type
 
 import numpy as np
 import pytest
+from pyspark import SparkConf
 
 import xgboost as xgb
 from xgboost import testing as tm
@@ -932,6 +933,113 @@ class TestPySparkLocal:
         model_loaded.set_device("cuda")
         assert model_loaded._run_on_gpu()
 
+    def test_skip_stage_level_scheduling(self) -> None:
+        conf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "0.08")
+        )
+
+        classifer_on_cpu = SparkXGBClassifier(use_gpu=False)
+        classifer_on_gpu = SparkXGBClassifier(use_gpu=True)
+
+        # the correct configurations should not skip stage-level scheduling
+        assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", conf)
+
+        # spark version < 3.4.0
+        assert classifer_on_gpu._skip_stage_level_scheduling("3.3.0", conf)
+
+        # not run on GPU
+        assert classifer_on_cpu._skip_stage_level_scheduling("3.4.0", conf)
+
+        # spark.executor.cores is not set
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "0.08")
+        )
+        assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # spark.executor.cores=1
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "1")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "0.08")
+        )
+        assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # spark.executor.resource.gpu.amount is not set
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.task.resource.gpu.amount", "0.08")
+        )
+        assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # spark.executor.resource.gpu.amount>1
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "2")
+            .set("spark.task.resource.gpu.amount", "0.08")
+        )
+        assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # spark.task.resource.gpu.amount is not set
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+        )
+        assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # spark.task.resource.gpu.amount=1
+        badConf = (
+            SparkConf()
+            .setMaster("spark://foo")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "1")
+        )
+        assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # yarn
+        badConf = (
+            SparkConf()
+            .setMaster("yarn")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "1")
+        )
+        assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
+        # k8s
+        badConf = (
+            SparkConf()
+            .setMaster("k8s://")
+            .set("spark.executor.cores", "12")
+            .set("spark.task.cpus", "1")
+            .set("spark.executor.resource.gpu.amount", "1")
+            .set("spark.task.resource.gpu.amount", "1")
+        )
+        assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
+
 
 class XgboostLocalTest(SparkTestCase):
     def setUp(self):