[pyspark] Support stage-level scheduling for training (#9519)
This commit is contained in:
parent 83191f0839
commit 4d1607eefd
@@ -22,7 +22,7 @@ from typing import (
 import numpy as np
 import pandas as pd
-from pyspark import SparkContext, cloudpickle
+from pyspark import RDD, SparkContext, cloudpickle
 from pyspark.ml import Estimator, Model
 from pyspark.ml.functions import array_to_vector, vector_to_array
 from pyspark.ml.linalg import VectorUDT
@@ -44,6 +44,7 @@ from pyspark.ml.util import (
     MLWritable,
     MLWriter,
 )
+from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests
 from pyspark.sql import Column, DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
@@ -88,6 +89,7 @@ from .utils import (
     _get_rabit_args,
     _get_spark_session,
     _is_local,
+    _is_standalone_or_localcluster,
     deserialize_booster,
     deserialize_xgb_model,
     get_class_name,
@@ -342,6 +344,54 @@ class _SparkXGBParams(
             predict_params[param.name] = self.getOrDefault(param)
         return predict_params
 
+    def _validate_gpu_params(self) -> None:
+        """Validate the gpu parameters and gpu configurations"""
+
+        if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
+            ss = _get_spark_session()
+            sc = ss.sparkContext
+
+            if _is_local(sc):
+                # Supporting GPU training in Spark local mode is only for debugging
+                # purposes, so it's okay to print the warning below instead of
+                # checking the real GPU numbers and raising an exception.
+                get_logger(self.__class__.__name__).warning(
+                    "You have enabled GPU in spark local mode. Please make sure your"
+                    " local node has at least %d GPUs",
+                    self.getOrDefault(self.num_workers),
+                )
+            else:
+                executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
+                if executor_gpus is None:
+                    raise ValueError(
+                        "The `spark.executor.resource.gpu.amount` is required for training"
+                        " on GPU."
+                    )
+
+                if not (ss.version >= "3.4.0" and _is_standalone_or_localcluster(sc)):
+                    # Stage-level scheduling is enabled on spark 3.4.0+ in standalone or
+                    # local-cluster mode, which doesn't require
+                    # spark.task.resource.gpu.amount to be set explicitly.
+                    gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
+                    if gpu_per_task is not None:
+                        if float(gpu_per_task) < 1.0:
+                            raise ValueError(
+                                "XGBoost doesn't support GPU fractional configurations. "
+                                "Please set `spark.task.resource.gpu.amount=spark.executor"
+                                ".resource.gpu.amount`"
+                            )
+
+                        if float(gpu_per_task) > 1.0:
+                            get_logger(self.__class__.__name__).warning(
+                                "%s GPUs for each Spark task are configured, but each "
+                                "XGBoost training task uses only 1 GPU.",
+                                gpu_per_task,
+                            )
+                    else:
+                        raise ValueError(
+                            "The `spark.task.resource.gpu.amount` is required for training"
+                            " on GPU."
+                        )
+
     def _validate_params(self) -> None:
         # pylint: disable=too-many-branches
         init_model = self.getOrDefault("xgb_model")
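For reference, a minimal sketch (with a hypothetical master URL and values) of a SparkSession configuration that passes the checks in `_validate_gpu_params` on a non-local cluster: GPUs are declared per executor, and each task gets a whole GPU.

from pyspark.sql import SparkSession

# Hypothetical standalone-cluster configuration that satisfies the validation
# above: one GPU per executor, and no fractional GPU per task.
spark = (
    SparkSession.builder.master("spark://host:7077")  # hypothetical master URL
    .config("spark.executor.resource.gpu.amount", "1")
    .config("spark.task.resource.gpu.amount", "1")  # fractional values are rejected
    .getOrCreate()
)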
@@ -421,53 +471,7 @@ class _SparkXGBParams(
                 "`pyspark.ml.linalg.Vector` type."
             )
 
-        if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
-            gpu_per_task = (
-                _get_spark_session()
-                .sparkContext.getConf()
-                .get("spark.task.resource.gpu.amount")
-            )
-
-            is_local = _is_local(_get_spark_session().sparkContext)
-
-            if is_local:
-                # checking spark local mode.
-                if gpu_per_task is not None:
-                    raise RuntimeError(
-                        "The spark local mode does not support gpu configuration."
-                        "Please remove spark.executor.resource.gpu.amount and "
-                        "spark.task.resource.gpu.amount"
-                    )
-
-                # Support GPU training in Spark local mode is just for debugging
-                # purposes, so it's okay for printing the below warning instead of
-                # checking the real gpu numbers and raising the exception.
-                get_logger(self.__class__.__name__).warning(
-                    "You have enabled GPU in spark local mode. Please make sure your"
-                    " local node has at least %d GPUs",
-                    self.getOrDefault(self.num_workers),
-                )
-            else:
-                # checking spark non-local mode.
-                if gpu_per_task is not None:
-                    if float(gpu_per_task) < 1.0:
-                        raise ValueError(
-                            "XGBoost doesn't support GPU fractional configurations. "
-                            "Please set `spark.task.resource.gpu.amount=spark.executor"
-                            ".resource.gpu.amount`"
-                        )
-
-                    if float(gpu_per_task) > 1.0:
-                        get_logger(self.__class__.__name__).warning(
-                            "%s GPUs for each Spark task is configured, but each "
-                            "XGBoost training task uses only 1 GPU.",
-                            gpu_per_task,
-                        )
-                else:
-                    raise ValueError(
-                        "The `spark.task.resource.gpu.amount` is required for training"
-                        " on GPU."
-                    )
+        self._validate_gpu_params()
 
 
     def _validate_and_convert_feature_col_as_float_col_list(
@@ -592,6 +596,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             arbitrary_params_dict={},
         )
 
+        self.logger = get_logger(self.__class__.__name__)
+
     def setParams(self, **kwargs: Any) -> None:  # pylint: disable=invalid-name
         """
         Set params for the estimator.
@@ -894,6 +900,116 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
 
         return booster_params, train_call_kwargs_params, dmatrix_kwargs
 
+    def _skip_stage_level_scheduling(self) -> bool:
+        # pylint: disable=too-many-return-statements
+        """Check whether stage-level scheduling is needed;
+        return True to skip stage-level scheduling."""
+
+        if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
+            ss = _get_spark_session()
+            sc = ss.sparkContext
+
+            if ss.version < "3.4.0":
+                self.logger.info(
+                    "Stage-level scheduling in xgboost requires spark version 3.4.0+"
+                )
+                return True
+
+            if not _is_standalone_or_localcluster(sc):
+                self.logger.info(
+                    "Stage-level scheduling in xgboost requires spark standalone or "
+                    "local-cluster mode"
+                )
+                return True
+
+            executor_cores = sc.getConf().get("spark.executor.cores")
+            executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
+            if executor_cores is None or executor_gpus is None:
+                self.logger.info(
+                    "Stage-level scheduling in xgboost requires spark.executor.cores "
+                    "and spark.executor.resource.gpu.amount to be set."
+                )
+                return True
+
+            if int(executor_cores) == 1:
+                # There will be only 1 task running at any time.
+                self.logger.info(
+                    "Stage-level scheduling in xgboost requires spark.executor.cores > 1"
+                )
+                return True
+
+            if int(executor_gpus) > 1:
+                # For spark.executor.resource.gpu.amount > 1, we assume the user knows
+                # how to configure the cluster to make xgboost run successfully.
+                self.logger.info(
+                    "Stage-level scheduling in xgboost will not work "
+                    "when spark.executor.resource.gpu.amount > 1"
+                )
+                return True
+
+            task_gpu_amount = sc.getConf().get("spark.task.resource.gpu.amount")
+
+            if task_gpu_amount is None:
+                # ETL tasks will not grab a GPU when spark.task.resource.gpu.amount is
+                # not set, but with stage-level scheduling we can make the training
+                # tasks grab the GPU.
+                return False
+
+            if float(task_gpu_amount) == float(executor_gpus):
+                # spark.executor.resource.gpu.amount == spark.task.resource.gpu.amount
+                # results in only 1 task running at a time, which may cause a perf issue.
+                return True
+
+            # We can enable stage-level scheduling
+            return False
+
+        # CPU training doesn't require stage-level scheduling
+        return True
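Illustratively, here is a sketch (hypothetical values) of a configuration under which every check above passes, so the training stage gets its own resource profile:

# Hypothetical configuration for which _skip_stage_level_scheduling() returns
# False, i.e. stage-level scheduling is applied (assumes Spark 3.4.0+):
conf = {
    "spark.master": "spark://host:7077",        # standalone mode
    "spark.executor.cores": "12",               # > 1, so tasks can run concurrently
    "spark.executor.resource.gpu.amount": "1",  # exactly one GPU per executor
    # "spark.task.resource.gpu.amount" is left unset: ETL tasks stay on CPU,
    # and only the training stage requests the GPU.
}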
+
+    def _try_stage_level_scheduling(self, rdd: RDD) -> RDD:
+        """Try to enable stage-level scheduling"""
+
+        if self._skip_stage_level_scheduling():
+            return rdd
+
+        ss = _get_spark_session()
+
+        # executor_cores will not be None
+        executor_cores = ss.sparkContext.getConf().get("spark.executor.cores")
+        assert executor_cores is not None
+
+        # Spark-rapids is a project that leverages GPUs to accelerate spark SQL.
+        # If spark-rapids is enabled, we don't allow other ETL gpu tasks to run
+        # alongside training tasks, in order to avoid GPU OOM.
+        spark_plugins = ss.conf.get("spark.plugins", " ")
+        assert spark_plugins is not None
+        spark_rapids_sql_enabled = ss.conf.get("spark.rapids.sql.enabled", "true")
+        assert spark_rapids_sql_enabled is not None
+
+        task_cores = (
+            int(executor_cores)
+            if "com.nvidia.spark.SQLPlugin" in spark_plugins
+            and "true" == spark_rapids_sql_enabled.lower()
+            else (int(executor_cores) // 2) + 1
+        )
+
+        # Each training task requires executor_cores // 2 + 1 CPU cores, which
+        # ensures at most one training task per executor, i.e. the tasks are
+        # sent to different executors.
+        #
+        # Please note that we can't use the GPU to limit the concurrent tasks
+        # because of https://issues.apache.org/jira/browse/SPARK-45527.
+
+        task_gpus = 1.0
+        treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
+        rp = ResourceProfileBuilder().require(treqs).build
+
+        self.logger.info(
+            "XGBoost training tasks require the resources (cores=%s, gpu=%s).",
+            task_cores,
+            task_gpus,
+        )
+        return rdd.withResources(rp)
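A worked example of the core arithmetic, with hypothetical numbers:

# With spark.executor.cores=12 and spark-rapids disabled:
executor_cores = 12
task_cores = executor_cores // 2 + 1  # 7
# Two such tasks would need 14 cores, more than the 12 available, so Spark
# can schedule at most one training task on each executor.
assert 2 * task_cores > executor_cores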
+
 
     def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
         # pylint: disable=too-many-statements, too-many-locals
         self._validate_params()
@@ -994,14 +1110,16 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             )
 
         def _run_job() -> Tuple[str, str]:
-            ret = (
+            rdd = (
                 dataset.mapInPandas(
-                    _train_booster, schema="config string, booster string"  # type: ignore
+                    _train_booster,  # type: ignore
+                    schema="config string, booster string",
                 )
                 .rdd.barrier()
                 .mapPartitions(lambda x: x)
-                .collect()[0]
             )
+            rdd_with_resource = self._try_stage_level_scheduling(rdd)
+            ret = rdd_with_resource.collect()[0]
            return ret[0], ret[1]
 
         get_logger("XGBoost-PySpark").info(
@@ -129,6 +129,13 @@ def _is_local(spark_context: SparkContext) -> bool:
     return spark_context._jsc.sc().isLocal()
 
 
+def _is_standalone_or_localcluster(spark_context: SparkContext) -> bool:
+    master = spark_context.getConf().get("spark.master")
+    return master is not None and (
+        master.startswith("spark://") or master.startswith("local-cluster")
+    )
+
+
 def _get_gpu_id(task_context: TaskContext) -> int:
     """Get the gpu id from the task resources"""
     if task_context is None:
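For illustration, how `_is_standalone_or_localcluster` behaves on a few representative (hypothetical) master URLs:

# Standalone and local-cluster masters qualify; local and YARN masters do not.
for master, expected in [
    ("spark://host:7077", True),        # standalone
    ("local-cluster[2,1,1024]", True),  # local-cluster
    ("local[*]", False),                # local mode
    ("yarn", False),
]:
    is_match = master.startswith("spark://") or master.startswith("local-cluster")
    assert is_match == expected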