[PySpark] add gpu support for spark local mode (#8068)

commit a33f35eecf (parent 91bb9e2cb3)

The change adds two helpers (_is_local, _get_gpu_id) to the spark utils module, makes parameter validation reject Spark GPU resource confs under local mode (warning about the GPU count instead of failing), and assigns each barrier task's gpu_id from its partition id in local mode, or from the task's allocated GPU resources on a real cluster.
@@ -53,6 +53,8 @@ from .utils import (
     _get_rabit_args,
     _get_args_from_message_list,
     _get_spark_session,
+    _is_local,
+    _get_gpu_id,
 )
 from .params import (
     HasArbitraryParamsDict,
@@ -281,6 +283,29 @@ class _SparkXGBParams(
                 .get("spark.task.resource.gpu.amount")
             )
-
-            if not gpu_per_task or int(gpu_per_task) < 1:
-                raise RuntimeError(
-                    "The spark cluster does not have the necessary GPU"
+
+            is_local = _is_local(
+                _get_spark_session()
+                .sparkContext
+            )
+
+            if is_local:
+                # checking spark local mode.
+                if gpu_per_task:
+                    raise RuntimeError(
+                        "The spark cluster does not support gpu configuration for local mode. "
+                        "Please delete spark.executor.resource.gpu.amount and "
+                        "spark.task.resource.gpu.amount"
+                    )
+
+                # Supporting GPU training in Spark local mode is only for debugging
+                # purposes, so it is okay to print the warning below instead of
+                # checking the real gpu numbers and raising an exception.
+                get_logger(self.__class__.__name__).warning(
+                    "You enabled use_gpu in spark local mode. Please make sure your local node "
+                    "has at least %d GPUs", self.getOrDefault(self.num_workers)
+                )
+            else:
+                # checking spark non-local mode.
+                if not gpu_per_task or int(gpu_per_task) < 1:
+                    raise RuntimeError(
+                        "The spark cluster does not have the necessary GPU"
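The validation above reduces to one rule: local mode must not declare Spark GPU resource confs (local mode does no resource scheduling), while cluster mode must grant every task at least one GPU. Below is a minimal, standalone sketch of that rule, assuming a live session; the helper name check_gpu_conf is hypothetical, while the conf key is the real Spark setting the diff reads.

from pyspark.sql import SparkSession

def check_gpu_conf(num_workers: int) -> None:
    sc = SparkSession.builder.getOrCreate().sparkContext
    # SparkConf.get returns None when the key is unset.
    gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
    if sc._jsc.sc().isLocal():
        # Local mode: reject the GPU confs, then only warn the user about
        # how many physical GPUs the node must have.
        if gpu_per_task:
            raise RuntimeError("GPU confs are not supported in local mode.")
        print(f"Make sure the local node has at least {num_workers} GPUs.")
    elif not gpu_per_task or int(gpu_per_task) < 1:
        # Cluster mode: every barrier task needs at least one whole GPU.
        raise RuntimeError("Set spark.task.resource.gpu.amount to >= 1.")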
@@ -547,6 +572,11 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         booster_params["nthread"] = cpu_per_task
         use_gpu = self.getOrDefault(self.use_gpu)
 
+        is_local = _is_local(
+            _get_spark_session()
+            .sparkContext
+        )
+
         def _train_booster(pandas_df_iter):
             """
             Takes in an RDD partition and outputs a booster for that partition after going through
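Note that is_local is evaluated on the driver, before _train_booster is defined: executors cannot call _get_spark_session(), so the flag travels to them inside the serialized closure. A toy demonstration of that pattern (hypothetical names, runnable against any local PySpark install):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
is_local = spark.sparkContext._jsc.sc().isLocal()  # evaluated on the driver

def tag_partition(index, rows):
    # Runs on executors; is_local arrives via the serialized closure,
    # just as it does in _train_booster above.
    yield f"partition {index}: is_local={is_local}, rows={sum(1 for _ in rows)}"

print(spark.sparkContext.parallelize(range(10), 2)
      .mapPartitionsWithIndex(tag_partition).collect())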
@@ -558,10 +588,8 @@
             context.barrier()
 
             if use_gpu:
-                # Set booster worker to use the first GPU allocated to the spark task.
-                booster_params["gpu_id"] = int(
-                    context._resources["gpu"].addresses[0].strip()
-                )
+                booster_params["gpu_id"] = context.partitionId() if is_local \
+                    else _get_gpu_id(context)
 
             _rabit_args = ""
             if context.partitionId() == 0:
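This hunk replaces the direct context._resources lookup with a two-way rule: in local mode all tasks share one host, so the partition id doubles as the CUDA device ordinal; on a cluster the id comes from the GPU addresses Spark allocated to the task. A hedged sketch of the same logic as a free function (pick_gpu_id is a hypothetical name; it must run inside a barrier stage, not on the driver):

from pyspark import BarrierTaskContext

def pick_gpu_id(is_local: bool) -> int:
    context = BarrierTaskContext.get()
    if is_local:
        # local mode: partition 0 -> GPU 0, partition 1 -> GPU 1, ...
        return context.partitionId()
    # cluster mode: first GPU address the scheduler granted this task
    return int(context.resources()["gpu"].addresses[0].strip())

The local-mode branch silently assumes num_workers is no larger than the number of physical GPUs, which is exactly what the warning added in _SparkXGBParams tells the user to ensure.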
@@ -128,3 +128,25 @@ def _get_max_num_concurrent_tasks(spark_context):
         spark_context._jsc.sc().resourceProfileManager().resourceProfileFromId(0)
     )
     return spark_context._jsc.sc().maxNumConcurrentTasks()
+
+
+def _is_local(spark_context) -> bool:
+    """Whether it is Spark local mode"""
+    # pylint: disable=protected-access
+    return spark_context._jsc.sc().isLocal()
+
+
+def _get_gpu_id(task_context) -> int:
+    """Get the gpu id from the task resources"""
+    if task_context is None:
+        # This is a safety check.
+        raise RuntimeError(
+            "_get_gpu_id should not be invoked from driver side."
+        )
+    resources = task_context.resources()
+    if 'gpu' not in resources:
+        raise RuntimeError(
+            "Couldn't get the gpu id. Please check the GPU resource configuration."
+        )
+    # return the first gpu id.
+    return int(resources['gpu'].addresses[0].strip())
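A quick usage sketch for the two helpers above: _is_local can be exercised with any throwaway local session, while _get_gpu_id only succeeds inside a task that was actually scheduled with spark.task.resource.gpu.amount set, which is precisely the cluster-mode requirement validated earlier. The import path is an assumption based on this diff.

from pyspark.sql import SparkSession
from xgboost.spark.utils import _is_local  # path assumed from this diff

spark = SparkSession.builder.master("local[2]").getOrCreate()
assert _is_local(spark.sparkContext)  # any local[...] master reports True
spark.stop()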