[pyspark][doc] Test and doc for stage-level scheduling. (#9786)
This commit is contained in:
@@ -22,7 +22,7 @@ from typing import (
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pyspark import RDD, SparkContext, cloudpickle
|
||||
from pyspark import RDD, SparkConf, SparkContext, cloudpickle
|
||||
from pyspark.ml import Estimator, Model
|
||||
from pyspark.ml.functions import array_to_vector, vector_to_array
|
||||
from pyspark.ml.linalg import VectorUDT
|
||||
@@ -368,7 +368,10 @@ class _SparkXGBParams(
|
||||
" on GPU."
|
||||
)
|
||||
|
||||
if not (ss.version >= "3.4.0" and _is_standalone_or_localcluster(sc)):
|
||||
if not (
|
||||
ss.version >= "3.4.0"
|
||||
and _is_standalone_or_localcluster(sc.getConf())
|
||||
):
|
||||
# We will enable stage-level scheduling in spark 3.4.0+ which doesn't
|
||||
# require spark.task.resource.gpu.amount to be set explicitly
|
||||
gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
|
||||
@@ -907,30 +910,27 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
|
||||
return booster_params, train_call_kwargs_params, dmatrix_kwargs
|
||||
|
||||
def _skip_stage_level_scheduling(self) -> bool:
|
||||
def _skip_stage_level_scheduling(self, spark_version: str, conf: SparkConf) -> bool:
|
||||
# pylint: disable=too-many-return-statements
|
||||
"""Check if stage-level scheduling is not needed,
|
||||
return true to skip stage-level scheduling"""
|
||||
|
||||
if self._run_on_gpu():
|
||||
ss = _get_spark_session()
|
||||
sc = ss.sparkContext
|
||||
|
||||
if ss.version < "3.4.0":
|
||||
if spark_version < "3.4.0":
|
||||
self.logger.info(
|
||||
"Stage-level scheduling in xgboost requires spark version 3.4.0+"
|
||||
)
|
||||
return True
|
||||
|
||||
if not _is_standalone_or_localcluster(sc):
|
||||
if not _is_standalone_or_localcluster(conf):
|
||||
self.logger.info(
|
||||
"Stage-level scheduling in xgboost requires spark standalone or "
|
||||
"local-cluster mode"
|
||||
)
|
||||
return True
|
||||
|
||||
executor_cores = sc.getConf().get("spark.executor.cores")
|
||||
executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
|
||||
executor_cores = conf.get("spark.executor.cores")
|
||||
executor_gpus = conf.get("spark.executor.resource.gpu.amount")
|
||||
if executor_cores is None or executor_gpus is None:
|
||||
self.logger.info(
|
||||
"Stage-level scheduling in xgboost requires spark.executor.cores, "
|
||||
@@ -955,7 +955,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
)
|
||||
return True
|
||||
|
||||
task_gpu_amount = sc.getConf().get("spark.task.resource.gpu.amount")
|
||||
task_gpu_amount = conf.get("spark.task.resource.gpu.amount")
|
||||
|
||||
if task_gpu_amount is None:
|
||||
# The ETL tasks will not grab a gpu when spark.task.resource.gpu.amount is not set,
|
||||
@@ -975,14 +975,13 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
|
||||
def _try_stage_level_scheduling(self, rdd: RDD) -> RDD:
|
||||
"""Try to enable stage-level scheduling"""
|
||||
|
||||
if self._skip_stage_level_scheduling():
|
||||
ss = _get_spark_session()
|
||||
conf = ss.sparkContext.getConf()
|
||||
if self._skip_stage_level_scheduling(ss.version, conf):
|
||||
return rdd
|
||||
|
||||
ss = _get_spark_session()
|
||||
|
||||
# executor_cores will not be None
|
||||
executor_cores = ss.sparkContext.getConf().get("spark.executor.cores")
|
||||
executor_cores = conf.get("spark.executor.cores")
|
||||
assert executor_cores is not None
|
||||
|
||||
# Spark-rapids is a project to leverage GPUs to accelerate spark SQL.
|
||||
|
||||
@@ -10,7 +10,7 @@ from threading import Thread
|
||||
from typing import Any, Callable, Dict, Optional, Set, Type
|
||||
|
||||
import pyspark
|
||||
from pyspark import BarrierTaskContext, SparkContext, SparkFiles, TaskContext
|
||||
from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
|
||||
from pyspark.sql.session import SparkSession
|
||||
|
||||
from xgboost import Booster, XGBModel, collective
|
||||
@@ -129,8 +129,8 @@ def _is_local(spark_context: SparkContext) -> bool:
|
||||
return spark_context._jsc.sc().isLocal()
|
||||
|
||||
|
||||
def _is_standalone_or_localcluster(spark_context: SparkContext) -> bool:
|
||||
master = spark_context.getConf().get("spark.master")
|
||||
def _is_standalone_or_localcluster(conf: SparkConf) -> bool:
|
||||
master = conf.get("spark.master")
|
||||
return master is not None and (
|
||||
master.startswith("spark://") or master.startswith("local-cluster")
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user