[PySpark] add gpu support for spark local mode (#8068)
This commit is contained in:
parent
91bb9e2cb3
commit
a33f35eecf
@ -53,6 +53,8 @@ from .utils import (
|
|||||||
_get_rabit_args,
|
_get_rabit_args,
|
||||||
_get_args_from_message_list,
|
_get_args_from_message_list,
|
||||||
_get_spark_session,
|
_get_spark_session,
|
||||||
|
_is_local,
|
||||||
|
_get_gpu_id,
|
||||||
)
|
)
|
||||||
from .params import (
|
from .params import (
|
||||||
HasArbitraryParamsDict,
|
HasArbitraryParamsDict,
|
||||||
@ -281,6 +283,29 @@ class _SparkXGBParams(
|
|||||||
.get("spark.task.resource.gpu.amount")
|
.get("spark.task.resource.gpu.amount")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
is_local = _is_local(
|
||||||
|
_get_spark_session()
|
||||||
|
.sparkContext
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_local:
|
||||||
|
# checking spark local mode.
|
||||||
|
if gpu_per_task:
|
||||||
|
raise RuntimeError(
|
||||||
|
"The spark cluster does not support gpu configuration for local mode. "
|
||||||
|
"Please delete spark.executor.resource.gpu.amount and "
|
||||||
|
"spark.task.resource.gpu.amount"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Support GPU training in Spark local mode is just for debugging purposes,
|
||||||
|
# so it's okay for printing the below warning instead of checking the real
|
||||||
|
# gpu numbers and raising the exception.
|
||||||
|
get_logger(self.__class__.__name__).warning(
|
||||||
|
"You enabled use_gpu in spark local mode. Please make sure your local node "
|
||||||
|
"has at least %d GPUs", self.getOrDefault(self.num_workers)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# checking spark non-local mode.
|
||||||
if not gpu_per_task or int(gpu_per_task) < 1:
|
if not gpu_per_task or int(gpu_per_task) < 1:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"The spark cluster does not have the necessary GPU"
|
"The spark cluster does not have the necessary GPU"
|
||||||
@ -547,6 +572,11 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
booster_params["nthread"] = cpu_per_task
|
booster_params["nthread"] = cpu_per_task
|
||||||
use_gpu = self.getOrDefault(self.use_gpu)
|
use_gpu = self.getOrDefault(self.use_gpu)
|
||||||
|
|
||||||
|
is_local = _is_local(
|
||||||
|
_get_spark_session()
|
||||||
|
.sparkContext
|
||||||
|
)
|
||||||
|
|
||||||
def _train_booster(pandas_df_iter):
|
def _train_booster(pandas_df_iter):
|
||||||
"""
|
"""
|
||||||
Takes in an RDD partition and outputs a booster for that partition after going through
|
Takes in an RDD partition and outputs a booster for that partition after going through
|
||||||
@ -558,10 +588,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
context.barrier()
|
context.barrier()
|
||||||
|
|
||||||
if use_gpu:
|
if use_gpu:
|
||||||
# Set booster worker to use the first GPU allocated to the spark task.
|
booster_params["gpu_id"] = context.partitionId() if is_local \
|
||||||
booster_params["gpu_id"] = int(
|
else _get_gpu_id(context)
|
||||||
context._resources["gpu"].addresses[0].strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
_rabit_args = ""
|
_rabit_args = ""
|
||||||
if context.partitionId() == 0:
|
if context.partitionId() == 0:
|
||||||
|
|||||||
@ -128,3 +128,25 @@ def _get_max_num_concurrent_tasks(spark_context):
|
|||||||
spark_context._jsc.sc().resourceProfileManager().resourceProfileFromId(0)
|
spark_context._jsc.sc().resourceProfileManager().resourceProfileFromId(0)
|
||||||
)
|
)
|
||||||
return spark_context._jsc.sc().maxNumConcurrentTasks()
|
return spark_context._jsc.sc().maxNumConcurrentTasks()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_local(spark_context) -> bool:
|
||||||
|
"""Whether it is Spark local mode"""
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
return spark_context._jsc.sc().isLocal()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_gpu_id(task_context) -> int:
|
||||||
|
"""Get the gpu id from the task resources"""
|
||||||
|
if task_context is None:
|
||||||
|
# This is a safety check.
|
||||||
|
raise RuntimeError(
|
||||||
|
"_get_gpu_id should not be invoked from driver side."
|
||||||
|
)
|
||||||
|
resources = task_context.resources()
|
||||||
|
if 'gpu' not in resources:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Couldn't get the gpu id, Please check the GPU resource configuration"
|
||||||
|
)
|
||||||
|
# return the first gpu id.
|
||||||
|
return int(resources['gpu'].addresses[0].strip())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user