[pyspark] Support stage-level scheduling for training (#9519)

This commit is contained in:
Bobby Wang 2023-10-17 10:35:39 +08:00 committed by GitHub
parent 83191f0839
commit 4d1607eefd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 176 additions and 51 deletions

View File

@ -22,7 +22,7 @@ from typing import (
import numpy as np
import pandas as pd
from pyspark import SparkContext, cloudpickle
from pyspark import RDD, SparkContext, cloudpickle
from pyspark.ml import Estimator, Model
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.linalg import VectorUDT
@ -44,6 +44,7 @@ from pyspark.ml.util import (
MLWritable,
MLWriter,
)
from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests
from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
from pyspark.sql.types import (
@ -88,6 +89,7 @@ from .utils import (
_get_rabit_args,
_get_spark_session,
_is_local,
_is_standalone_or_localcluster,
deserialize_booster,
deserialize_xgb_model,
get_class_name,
@ -342,6 +344,54 @@ class _SparkXGBParams(
predict_params[param.name] = self.getOrDefault(param)
return predict_params
def _validate_gpu_params(self) -> None:
"""Validate the gpu parameters and gpu configurations"""
if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
ss = _get_spark_session()
sc = ss.sparkContext
if _is_local(sc):
# Support GPU training in Spark local mode is just for debugging
# purposes, so it's okay for printing the below warning instead of
# checking the real gpu numbers and raising the exception.
get_logger(self.__class__.__name__).warning(
"You have enabled GPU in spark local mode. Please make sure your"
" local node has at least %d GPUs",
self.getOrDefault(self.num_workers),
)
else:
executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
if executor_gpus is None:
raise ValueError(
"The `spark.executor.resource.gpu.amount` is required for training"
" on GPU."
)
if not (ss.version >= "3.4.0" and _is_standalone_or_localcluster(sc)):
# We will enable stage-level scheduling in spark 3.4.0+ which doesn't
# require spark.task.resource.gpu.amount to be set explicitly
gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
if gpu_per_task is not None:
if float(gpu_per_task) < 1.0:
raise ValueError(
"XGBoost doesn't support GPU fractional configurations. "
"Please set `spark.task.resource.gpu.amount=spark.executor"
".resource.gpu.amount`"
)
if float(gpu_per_task) > 1.0:
get_logger(self.__class__.__name__).warning(
"%s GPUs for each Spark task is configured, but each "
"XGBoost training task uses only 1 GPU.",
gpu_per_task,
)
else:
raise ValueError(
"The `spark.task.resource.gpu.amount` is required for training"
" on GPU."
)
def _validate_params(self) -> None:
# pylint: disable=too-many-branches
init_model = self.getOrDefault("xgb_model")
@ -421,53 +471,7 @@ class _SparkXGBParams(
"`pyspark.ml.linalg.Vector` type."
)
if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
gpu_per_task = (
_get_spark_session()
.sparkContext.getConf()
.get("spark.task.resource.gpu.amount")
)
is_local = _is_local(_get_spark_session().sparkContext)
if is_local:
# checking spark local mode.
if gpu_per_task is not None:
raise RuntimeError(
"The spark local mode does not support gpu configuration."
"Please remove spark.executor.resource.gpu.amount and "
"spark.task.resource.gpu.amount"
)
# Support GPU training in Spark local mode is just for debugging
# purposes, so it's okay for printing the below warning instead of
# checking the real gpu numbers and raising the exception.
get_logger(self.__class__.__name__).warning(
"You have enabled GPU in spark local mode. Please make sure your"
" local node has at least %d GPUs",
self.getOrDefault(self.num_workers),
)
else:
# checking spark non-local mode.
if gpu_per_task is not None:
if float(gpu_per_task) < 1.0:
raise ValueError(
"XGBoost doesn't support GPU fractional configurations. "
"Please set `spark.task.resource.gpu.amount=spark.executor"
".resource.gpu.amount`"
)
if float(gpu_per_task) > 1.0:
get_logger(self.__class__.__name__).warning(
"%s GPUs for each Spark task is configured, but each "
"XGBoost training task uses only 1 GPU.",
gpu_per_task,
)
else:
raise ValueError(
"The `spark.task.resource.gpu.amount` is required for training"
" on GPU."
)
self._validate_gpu_params()
def _validate_and_convert_feature_col_as_float_col_list(
@ -592,6 +596,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
arbitrary_params_dict={},
)
self.logger = get_logger(self.__class__.__name__)
def setParams(self, **kwargs: Any) -> None: # pylint: disable=invalid-name
"""
Set params for the estimator.
@ -894,6 +900,116 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
return booster_params, train_call_kwargs_params, dmatrix_kwargs
def _skip_stage_level_scheduling(self) -> bool:
# pylint: disable=too-many-return-statements
"""Check if stage-level scheduling is not needed,
return true to skip stage-level scheduling"""
if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
ss = _get_spark_session()
sc = ss.sparkContext
if ss.version < "3.4.0":
self.logger.info(
"Stage-level scheduling in xgboost requires spark version 3.4.0+"
)
return True
if not _is_standalone_or_localcluster(sc):
self.logger.info(
"Stage-level scheduling in xgboost requires spark standalone or "
"local-cluster mode"
)
return True
executor_cores = sc.getConf().get("spark.executor.cores")
executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
if executor_cores is None or executor_gpus is None:
self.logger.info(
"Stage-level scheduling in xgboost requires spark.executor.cores, "
"spark.executor.resource.gpu.amount to be set."
)
return True
if int(executor_cores) == 1:
# there will be only 1 task running at any time.
self.logger.info(
"Stage-level scheduling in xgboost requires spark.executor.cores > 1 "
)
return True
if int(executor_gpus) > 1:
# For spark.executor.resource.gpu.amount > 1, we suppose user knows how to configure
# to make xgboost run successfully.
#
self.logger.info(
"Stage-level scheduling in xgboost will not work "
"when spark.executor.resource.gpu.amount>1"
)
return True
task_gpu_amount = sc.getConf().get("spark.task.resource.gpu.amount")
if task_gpu_amount is None:
# The ETL tasks will not grab a gpu when spark.task.resource.gpu.amount is not set,
# but with stage-level scheduling, we can make training task grab the gpu.
return False
if float(task_gpu_amount) == float(executor_gpus):
# spark.executor.resource.gpu.amount=spark.task.resource.gpu.amount "
# results in only 1 task running at a time, which may cause perf issue.
return True
# We can enable stage-level scheduling
return False
# CPU training doesn't require stage-level scheduling
return True
def _try_stage_level_scheduling(self, rdd: RDD) -> RDD:
"""Try to enable stage-level scheduling"""
if self._skip_stage_level_scheduling():
return rdd
ss = _get_spark_session()
# executor_cores will not be None
executor_cores = ss.sparkContext.getConf().get("spark.executor.cores")
assert executor_cores is not None
# Spark-rapids is a project to leverage GPUs to accelerate spark SQL.
# If spark-rapids is enabled, to avoid GPU OOM, we don't allow other
# ETL gpu tasks running alongside training tasks.
spark_plugins = ss.conf.get("spark.plugins", " ")
assert spark_plugins is not None
spark_rapids_sql_enabled = ss.conf.get("spark.rapids.sql.enabled", "true")
assert spark_rapids_sql_enabled is not None
task_cores = (
int(executor_cores)
if "com.nvidia.spark.SQLPlugin" in spark_plugins
and "true" == spark_rapids_sql_enabled.lower()
else (int(executor_cores) // 2) + 1
)
# Each training task requires cpu cores > total executor cores//2 + 1 which can
# make sure the tasks be sent to different executors.
#
# Please note that we can't use GPU to limit the concurrent tasks because of
# https://issues.apache.org/jira/browse/SPARK-45527.
task_gpus = 1.0
treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
rp = ResourceProfileBuilder().require(treqs).build
self.logger.info(
"XGBoost training tasks require the resource(cores=%s, gpu=%s).",
task_cores,
task_gpus,
)
return rdd.withResources(rp)
def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
# pylint: disable=too-many-statements, too-many-locals
self._validate_params()
@ -994,14 +1110,16 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
)
def _run_job() -> Tuple[str, str]:
ret = (
rdd = (
dataset.mapInPandas(
_train_booster, schema="config string, booster string" # type: ignore
_train_booster, # type: ignore
schema="config string, booster string",
)
.rdd.barrier()
.mapPartitions(lambda x: x)
.collect()[0]
)
rdd_with_resource = self._try_stage_level_scheduling(rdd)
ret = rdd_with_resource.collect()[0]
return ret[0], ret[1]
get_logger("XGBoost-PySpark").info(

View File

@ -129,6 +129,13 @@ def _is_local(spark_context: SparkContext) -> bool:
return spark_context._jsc.sc().isLocal()
def _is_standalone_or_localcluster(spark_context: SparkContext) -> bool:
master = spark_context.getConf().get("spark.master")
return master is not None and (
master.startswith("spark://") or master.startswith("local-cluster")
)
def _get_gpu_id(task_context: TaskContext) -> int:
"""Get the gpu id from the task resources"""
if task_context is None: