[pyspark] support stage-level for yarn/k8s (#10209)
This commit is contained in:
parent
bb212bf33c
commit
8fb05c8c95
@ -347,15 +347,14 @@ class _SparkXGBParams(
|
|||||||
predict_params[param.name] = self.getOrDefault(param)
|
predict_params[param.name] = self.getOrDefault(param)
|
||||||
return predict_params
|
return predict_params
|
||||||
|
|
||||||
def _validate_gpu_params(self) -> None:
|
def _validate_gpu_params(
|
||||||
|
self, spark_version: str, conf: SparkConf, is_local: bool = False
|
||||||
|
) -> None:
|
||||||
"""Validate the gpu parameters and gpu configurations"""
|
"""Validate the gpu parameters and gpu configurations"""
|
||||||
|
|
||||||
if self._run_on_gpu():
|
if self._run_on_gpu():
|
||||||
ss = _get_spark_session()
|
if is_local:
|
||||||
sc = ss.sparkContext
|
# Supporting GPU training in Spark local mode is just for debugging
|
||||||
|
|
||||||
if _is_local(sc):
|
|
||||||
# Support GPU training in Spark local mode is just for debugging
|
|
||||||
# purposes, so it's okay for printing the below warning instead of
|
# purposes, so it's okay for printing the below warning instead of
|
||||||
# checking the real gpu numbers and raising the exception.
|
# checking the real gpu numbers and raising the exception.
|
||||||
get_logger(self.__class__.__name__).warning(
|
get_logger(self.__class__.__name__).warning(
|
||||||
@ -364,33 +363,41 @@ class _SparkXGBParams(
|
|||||||
self.getOrDefault(self.num_workers),
|
self.getOrDefault(self.num_workers),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
|
executor_gpus = conf.get("spark.executor.resource.gpu.amount")
|
||||||
if executor_gpus is None:
|
if executor_gpus is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The `spark.executor.resource.gpu.amount` is required for training"
|
"The `spark.executor.resource.gpu.amount` is required for training"
|
||||||
" on GPU."
|
" on GPU."
|
||||||
)
|
)
|
||||||
|
gpu_per_task = conf.get("spark.task.resource.gpu.amount")
|
||||||
if not (
|
if gpu_per_task is not None and float(gpu_per_task) > 1.0:
|
||||||
ss.version >= "3.4.0"
|
get_logger(self.__class__.__name__).warning(
|
||||||
and _is_standalone_or_localcluster(sc.getConf())
|
"The configuration assigns %s GPUs to each Spark task, but each "
|
||||||
|
"XGBoost training task only utilizes 1 GPU, which will lead to "
|
||||||
|
"unnecessary GPU waste",
|
||||||
|
gpu_per_task,
|
||||||
|
)
|
||||||
|
# For 3.5.1+, Spark supports task stage-level scheduling for
|
||||||
|
# Yarn/K8s/Standalone/Local cluster
|
||||||
|
# From 3.4.0 ~ 3.5.0, Spark only supports task stage-level scheduing for
|
||||||
|
# Standalone/Local cluster
|
||||||
|
# For spark below 3.4.0, Task stage-level scheduling is not supported.
|
||||||
|
#
|
||||||
|
# With stage-level scheduling, spark.task.resource.gpu.amount is not required
|
||||||
|
# to be set explicitly. Or else, spark.task.resource.gpu.amount is a must-have and
|
||||||
|
# must be set to 1.0
|
||||||
|
if spark_version < "3.4.0" or (
|
||||||
|
"3.4.0" <= spark_version < "3.5.1"
|
||||||
|
and not _is_standalone_or_localcluster(conf)
|
||||||
):
|
):
|
||||||
# We will enable stage-level scheduling in spark 3.4.0+ which doesn't
|
|
||||||
# require spark.task.resource.gpu.amount to be set explicitly
|
|
||||||
gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
|
|
||||||
if gpu_per_task is not None:
|
if gpu_per_task is not None:
|
||||||
if float(gpu_per_task) < 1.0:
|
if float(gpu_per_task) < 1.0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"XGBoost doesn't support GPU fractional configurations. "
|
"XGBoost doesn't support GPU fractional configurations. Please set "
|
||||||
"Please set `spark.task.resource.gpu.amount=spark.executor"
|
"`spark.task.resource.gpu.amount=spark.executor.resource.gpu."
|
||||||
".resource.gpu.amount`"
|
"amount`. To enable GPU fractional configurations, you can try "
|
||||||
)
|
"standalone/localcluster with spark 3.4.0+ and"
|
||||||
|
"YARN/K8S with spark 3.5.1+"
|
||||||
if float(gpu_per_task) > 1.0:
|
|
||||||
get_logger(self.__class__.__name__).warning(
|
|
||||||
"%s GPUs for each Spark task is configured, but each "
|
|
||||||
"XGBoost training task uses only 1 GPU.",
|
|
||||||
gpu_per_task,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -475,7 +482,9 @@ class _SparkXGBParams(
|
|||||||
"`pyspark.ml.linalg.Vector` type."
|
"`pyspark.ml.linalg.Vector` type."
|
||||||
)
|
)
|
||||||
|
|
||||||
self._validate_gpu_params()
|
ss = _get_spark_session()
|
||||||
|
sc = ss.sparkContext
|
||||||
|
self._validate_gpu_params(ss.version, sc.getConf(), _is_local(sc))
|
||||||
|
|
||||||
def _run_on_gpu(self) -> bool:
|
def _run_on_gpu(self) -> bool:
|
||||||
"""If train or transform on the gpu according to the parameters"""
|
"""If train or transform on the gpu according to the parameters"""
|
||||||
@ -925,10 +934,14 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if not _is_standalone_or_localcluster(conf):
|
if (
|
||||||
|
"3.4.0" <= spark_version < "3.5.1"
|
||||||
|
and not _is_standalone_or_localcluster(conf)
|
||||||
|
):
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"Stage-level scheduling in xgboost requires spark standalone or "
|
"For %s, Stage-level scheduling in xgboost requires spark standalone "
|
||||||
"local-cluster mode"
|
"or local-cluster mode",
|
||||||
|
spark_version,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -980,7 +993,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
"""Try to enable stage-level scheduling"""
|
"""Try to enable stage-level scheduling"""
|
||||||
ss = _get_spark_session()
|
ss = _get_spark_session()
|
||||||
conf = ss.sparkContext.getConf()
|
conf = ss.sparkContext.getConf()
|
||||||
if self._skip_stage_level_scheduling(ss.version, conf):
|
if _is_local(ss.sparkContext) or self._skip_stage_level_scheduling(
|
||||||
|
ss.version, conf
|
||||||
|
):
|
||||||
return rdd
|
return rdd
|
||||||
|
|
||||||
# executor_cores will not be None
|
# executor_cores will not be None
|
||||||
|
|||||||
@ -929,8 +929,127 @@ class TestPySparkLocal:
|
|||||||
model_loaded.set_device("cuda")
|
model_loaded.set_device("cuda")
|
||||||
assert model_loaded._run_on_gpu()
|
assert model_loaded._run_on_gpu()
|
||||||
|
|
||||||
|
def test_validate_gpu_params(self) -> None:
|
||||||
|
# Standalone
|
||||||
|
standalone_conf = (
|
||||||
|
SparkConf()
|
||||||
|
.setMaster("spark://foo")
|
||||||
|
.set("spark.executor.cores", "12")
|
||||||
|
.set("spark.task.cpus", "1")
|
||||||
|
.set("spark.executor.resource.gpu.amount", "1")
|
||||||
|
.set("spark.task.resource.gpu.amount", "0.08")
|
||||||
|
)
|
||||||
|
classifer_on_cpu = SparkXGBClassifier(use_gpu=False)
|
||||||
|
classifer_on_gpu = SparkXGBClassifier(use_gpu=True)
|
||||||
|
|
||||||
|
# No exception for classifier on CPU
|
||||||
|
classifer_on_cpu._validate_gpu_params("3.4.0", standalone_conf)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="XGBoost doesn't support GPU fractional configurations"
|
||||||
|
):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.3.0", standalone_conf)
|
||||||
|
|
||||||
|
# No issues
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.4.0", standalone_conf)
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.4.1", standalone_conf)
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.5.0", standalone_conf)
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.5.1", standalone_conf)
|
||||||
|
|
||||||
|
# no spark.executor.resource.gpu.amount
|
||||||
|
standalone_bad_conf = (
|
||||||
|
SparkConf()
|
||||||
|
.setMaster("spark://foo")
|
||||||
|
.set("spark.executor.cores", "12")
|
||||||
|
.set("spark.task.cpus", "1")
|
||||||
|
.set("spark.task.resource.gpu.amount", "0.08")
|
||||||
|
)
|
||||||
|
msg_match = (
|
||||||
|
"The `spark.executor.resource.gpu.amount` is required for training on GPU"
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match=msg_match):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.3.0", standalone_bad_conf)
|
||||||
|
with pytest.raises(ValueError, match=msg_match):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.4.0", standalone_bad_conf)
|
||||||
|
with pytest.raises(ValueError, match=msg_match):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.4.1", standalone_bad_conf)
|
||||||
|
with pytest.raises(ValueError, match=msg_match):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.5.0", standalone_bad_conf)
|
||||||
|
with pytest.raises(ValueError, match=msg_match):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.5.1", standalone_bad_conf)
|
||||||
|
|
||||||
|
standalone_bad_conf = (
|
||||||
|
SparkConf()
|
||||||
|
.setMaster("spark://foo")
|
||||||
|
.set("spark.executor.cores", "12")
|
||||||
|
.set("spark.task.cpus", "1")
|
||||||
|
.set("spark.executor.resource.gpu.amount", "1")
|
||||||
|
)
|
||||||
|
msg_match = (
|
||||||
|
"The `spark.task.resource.gpu.amount` is required for training on GPU"
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match=msg_match):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.3.0", standalone_bad_conf)
|
||||||
|
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.4.0", standalone_bad_conf)
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.5.0", standalone_bad_conf)
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.5.1", standalone_bad_conf)
|
||||||
|
|
||||||
|
# Yarn and K8s mode
|
||||||
|
for mode in ["yarn", "k8s://"]:
|
||||||
|
conf = (
|
||||||
|
SparkConf()
|
||||||
|
.setMaster(mode)
|
||||||
|
.set("spark.executor.cores", "12")
|
||||||
|
.set("spark.task.cpus", "1")
|
||||||
|
.set("spark.executor.resource.gpu.amount", "1")
|
||||||
|
.set("spark.task.resource.gpu.amount", "0.08")
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="XGBoost doesn't support GPU fractional configurations",
|
||||||
|
):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.3.0", conf)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="XGBoost doesn't support GPU fractional configurations",
|
||||||
|
):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.4.0", conf)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="XGBoost doesn't support GPU fractional configurations",
|
||||||
|
):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.4.1", conf)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="XGBoost doesn't support GPU fractional configurations",
|
||||||
|
):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.5.0", conf)
|
||||||
|
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.5.1", conf)
|
||||||
|
|
||||||
|
for mode in ["yarn", "k8s://"]:
|
||||||
|
bad_conf = (
|
||||||
|
SparkConf()
|
||||||
|
.setMaster(mode)
|
||||||
|
.set("spark.executor.cores", "12")
|
||||||
|
.set("spark.task.cpus", "1")
|
||||||
|
.set("spark.executor.resource.gpu.amount", "1")
|
||||||
|
)
|
||||||
|
msg_match = (
|
||||||
|
"The `spark.task.resource.gpu.amount` is required for training on GPU"
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match=msg_match):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.3.0", bad_conf)
|
||||||
|
with pytest.raises(ValueError, match=msg_match):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.4.0", bad_conf)
|
||||||
|
with pytest.raises(ValueError, match=msg_match):
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.5.0", bad_conf)
|
||||||
|
|
||||||
|
classifer_on_gpu._validate_gpu_params("3.5.1", bad_conf)
|
||||||
|
|
||||||
def test_skip_stage_level_scheduling(self) -> None:
|
def test_skip_stage_level_scheduling(self) -> None:
|
||||||
conf = (
|
standalone_conf = (
|
||||||
SparkConf()
|
SparkConf()
|
||||||
.setMaster("spark://foo")
|
.setMaster("spark://foo")
|
||||||
.set("spark.executor.cores", "12")
|
.set("spark.executor.cores", "12")
|
||||||
@ -943,26 +1062,36 @@ class TestPySparkLocal:
|
|||||||
classifer_on_gpu = SparkXGBClassifier(use_gpu=True)
|
classifer_on_gpu = SparkXGBClassifier(use_gpu=True)
|
||||||
|
|
||||||
# the correct configurations should not skip stage-level scheduling
|
# the correct configurations should not skip stage-level scheduling
|
||||||
assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", conf)
|
assert not classifer_on_gpu._skip_stage_level_scheduling(
|
||||||
|
"3.4.0", standalone_conf
|
||||||
|
)
|
||||||
|
assert not classifer_on_gpu._skip_stage_level_scheduling(
|
||||||
|
"3.4.1", standalone_conf
|
||||||
|
)
|
||||||
|
assert not classifer_on_gpu._skip_stage_level_scheduling(
|
||||||
|
"3.5.0", standalone_conf
|
||||||
|
)
|
||||||
|
assert not classifer_on_gpu._skip_stage_level_scheduling(
|
||||||
|
"3.5.1", standalone_conf
|
||||||
|
)
|
||||||
|
|
||||||
# spark version < 3.4.0
|
# spark version < 3.4.0
|
||||||
assert classifer_on_gpu._skip_stage_level_scheduling("3.3.0", conf)
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.3.0", standalone_conf)
|
||||||
|
|
||||||
# not run on GPU
|
# not run on GPU
|
||||||
assert classifer_on_cpu._skip_stage_level_scheduling("3.4.0", conf)
|
assert classifer_on_cpu._skip_stage_level_scheduling("3.4.0", standalone_conf)
|
||||||
|
|
||||||
# spark.executor.cores is not set
|
# spark.executor.cores is not set
|
||||||
badConf = (
|
bad_conf = (
|
||||||
SparkConf()
|
SparkConf()
|
||||||
.setMaster("spark://foo")
|
.setMaster("spark://foo")
|
||||||
.set("spark.task.cpus", "1")
|
.set("spark.task.cpus", "1")
|
||||||
.set("spark.executor.resource.gpu.amount", "1")
|
.set("spark.executor.resource.gpu.amount", "1")
|
||||||
.set("spark.task.resource.gpu.amount", "0.08")
|
.set("spark.task.resource.gpu.amount", "0.08")
|
||||||
)
|
)
|
||||||
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf)
|
||||||
|
|
||||||
# spark.executor.cores=1
|
# spark.executor.cores=1
|
||||||
badConf = (
|
bad_conf = (
|
||||||
SparkConf()
|
SparkConf()
|
||||||
.setMaster("spark://foo")
|
.setMaster("spark://foo")
|
||||||
.set("spark.executor.cores", "1")
|
.set("spark.executor.cores", "1")
|
||||||
@ -970,20 +1099,20 @@ class TestPySparkLocal:
|
|||||||
.set("spark.executor.resource.gpu.amount", "1")
|
.set("spark.executor.resource.gpu.amount", "1")
|
||||||
.set("spark.task.resource.gpu.amount", "0.08")
|
.set("spark.task.resource.gpu.amount", "0.08")
|
||||||
)
|
)
|
||||||
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf)
|
||||||
|
|
||||||
# spark.executor.resource.gpu.amount is not set
|
# spark.executor.resource.gpu.amount is not set
|
||||||
badConf = (
|
bad_conf = (
|
||||||
SparkConf()
|
SparkConf()
|
||||||
.setMaster("spark://foo")
|
.setMaster("spark://foo")
|
||||||
.set("spark.executor.cores", "12")
|
.set("spark.executor.cores", "12")
|
||||||
.set("spark.task.cpus", "1")
|
.set("spark.task.cpus", "1")
|
||||||
.set("spark.task.resource.gpu.amount", "0.08")
|
.set("spark.task.resource.gpu.amount", "0.08")
|
||||||
)
|
)
|
||||||
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf)
|
||||||
|
|
||||||
# spark.executor.resource.gpu.amount>1
|
# spark.executor.resource.gpu.amount>1
|
||||||
badConf = (
|
bad_conf = (
|
||||||
SparkConf()
|
SparkConf()
|
||||||
.setMaster("spark://foo")
|
.setMaster("spark://foo")
|
||||||
.set("spark.executor.cores", "12")
|
.set("spark.executor.cores", "12")
|
||||||
@ -991,20 +1120,20 @@ class TestPySparkLocal:
|
|||||||
.set("spark.executor.resource.gpu.amount", "2")
|
.set("spark.executor.resource.gpu.amount", "2")
|
||||||
.set("spark.task.resource.gpu.amount", "0.08")
|
.set("spark.task.resource.gpu.amount", "0.08")
|
||||||
)
|
)
|
||||||
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf)
|
||||||
|
|
||||||
# spark.task.resource.gpu.amount is not set
|
# spark.task.resource.gpu.amount is not set
|
||||||
badConf = (
|
bad_conf = (
|
||||||
SparkConf()
|
SparkConf()
|
||||||
.setMaster("spark://foo")
|
.setMaster("spark://foo")
|
||||||
.set("spark.executor.cores", "12")
|
.set("spark.executor.cores", "12")
|
||||||
.set("spark.task.cpus", "1")
|
.set("spark.task.cpus", "1")
|
||||||
.set("spark.executor.resource.gpu.amount", "1")
|
.set("spark.executor.resource.gpu.amount", "1")
|
||||||
)
|
)
|
||||||
assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
|
assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf)
|
||||||
|
|
||||||
# spark.task.resource.gpu.amount=1
|
# spark.task.resource.gpu.amount=1
|
||||||
badConf = (
|
bad_conf = (
|
||||||
SparkConf()
|
SparkConf()
|
||||||
.setMaster("spark://foo")
|
.setMaster("spark://foo")
|
||||||
.set("spark.executor.cores", "12")
|
.set("spark.executor.cores", "12")
|
||||||
@ -1012,29 +1141,32 @@ class TestPySparkLocal:
|
|||||||
.set("spark.executor.resource.gpu.amount", "1")
|
.set("spark.executor.resource.gpu.amount", "1")
|
||||||
.set("spark.task.resource.gpu.amount", "1")
|
.set("spark.task.resource.gpu.amount", "1")
|
||||||
)
|
)
|
||||||
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf)
|
||||||
|
|
||||||
# yarn
|
# For Yarn and K8S
|
||||||
badConf = (
|
for mode in ["yarn", "k8s://"]:
|
||||||
SparkConf()
|
for gpu_amount in ["0.08", "0.2", "1.0"]:
|
||||||
.setMaster("yarn")
|
conf = (
|
||||||
.set("spark.executor.cores", "12")
|
SparkConf()
|
||||||
.set("spark.task.cpus", "1")
|
.setMaster(mode)
|
||||||
.set("spark.executor.resource.gpu.amount", "1")
|
.set("spark.executor.cores", "12")
|
||||||
.set("spark.task.resource.gpu.amount", "1")
|
.set("spark.task.cpus", "1")
|
||||||
)
|
.set("spark.executor.resource.gpu.amount", "1")
|
||||||
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
|
.set("spark.task.resource.gpu.amount", gpu_amount)
|
||||||
|
)
|
||||||
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.3.0", conf)
|
||||||
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", conf)
|
||||||
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.1", conf)
|
||||||
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.5.0", conf)
|
||||||
|
|
||||||
# k8s
|
# This will be fixed when spark 4.0.0 is released.
|
||||||
badConf = (
|
if gpu_amount == "1.0":
|
||||||
SparkConf()
|
assert classifer_on_gpu._skip_stage_level_scheduling("3.5.1", conf)
|
||||||
.setMaster("k8s://")
|
else:
|
||||||
.set("spark.executor.cores", "12")
|
# Starting from 3.5.1+, stage-level scheduling is working for Yarn and K8s
|
||||||
.set("spark.task.cpus", "1")
|
assert not classifer_on_gpu._skip_stage_level_scheduling(
|
||||||
.set("spark.executor.resource.gpu.amount", "1")
|
"3.5.1", conf
|
||||||
.set("spark.task.resource.gpu.amount", "1")
|
)
|
||||||
)
|
|
||||||
assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf)
|
|
||||||
|
|
||||||
|
|
||||||
class XgboostLocalTest(SparkTestCase):
|
class XgboostLocalTest(SparkTestCase):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user