[jvm-packages] support stage-level scheduling (#9775)
parent 162da7b52b, commit 36a552ac98
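
Background for this diff: Spark's stage-level scheduling lets a job request different task resources for particular stages by attaching a ResourceProfile to an RDD. The sketch below shows only the underlying Spark API this commit builds on; the helper name `withTrainingResources` and the concrete numbers are illustrative, not part of the commit.

    import org.apache.spark.rdd.RDD
    import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}

    // Minimal sketch of Spark's stage-level scheduling API (names and numbers are illustrative).
    // The commit below applies the same pattern to XGBoost's training stage.
    def withTrainingResources[T](trainingRdd: RDD[T]): RDD[T] = {
      // Request 3 cpus and a whole GPU per task, but only for stages computed from this RDD.
      val treqs = new TaskResourceRequests().cpus(3).resource("gpu", 1.0)
      val rp = new ResourceProfileBuilder().require(treqs).build()
      trainingRdd.withResources(rp) // other stages keep the default resource profile
    }

The diff itself requires Spark 3.4.0+ and standalone or local-cluster mode before it applies such a profile, as the checks below show.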
@@ -206,7 +206,7 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
           .setDevice("cuda:1")
           .fit(trainingDf)
       }
-      assert(thrown.getMessage.contains("`cuda` or `gpu`"))
+      assert(thrown.getMessage.contains("device given invalid value cuda:1"))
     }
   }
 }
@@ -31,7 +31,8 @@ import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.FileSystem

 import org.apache.spark.rdd.RDD
-import org.apache.spark.{SparkContext, TaskContext}
+import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}
+import org.apache.spark.{SparkConf, SparkContext, TaskContext}
 import org.apache.spark.sql.SparkSession

 /**
@@ -72,7 +73,8 @@ private[scala] case class XGBoostExecutionParams(
     device: Option[String],
     isLocal: Boolean,
     featureNames: Option[Array[String]],
-    featureTypes: Option[Array[String]]) {
+    featureTypes: Option[Array[String]],
+    runOnGpu: Boolean) {

   private var rawParamMap: Map[String, Any] = _

@@ -186,14 +188,15 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       .asInstanceOf[Boolean]

     val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
-    // back-compatible with "gpu_hist"
-    val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
-      Some("cuda")
-    } else overridedParams.get("device").map(_.toString)
+    val device: Option[String] = overridedParams.get("device").map(_.toString)
+    val deviceIsGpu = device.exists(_ == "cuda")

-    require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")),
+    require(!(treeMethod.exists(_ == "approx") && deviceIsGpu),
       "The tree method \"approx\" is not yet supported for Spark GPU cluster")

+    // back-compatible with "gpu_hist"
+    val runOnGpu = treeMethod.exists(_ == "gpu_hist") || deviceIsGpu
+
     val trackerConf = overridedParams.get("tracker_conf") match {
       case None => TrackerConf()
       case Some(conf: TrackerConf) => conf
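
In practice this hunk means either spelling turns on GPU training; a tiny illustration (parameter maps only, mirroring the new XGBoostSuite test at the end of this diff):

    // Both maps resolve to runOnGpu == true after this change:
    Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1)          // new spelling
    Map("tree_method" -> "gpu_hist", "num_workers" -> 1, "num_round" -> 1) // legacy spelling
    // whereas "device" -> "cpu" without "gpu_hist" resolves to runOnGpu == false.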
@@ -228,7 +231,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       device,
       isLocal,
       featureNames,
-      featureTypes
+      featureTypes,
+      runOnGpu
     )
     xgbExecParam.setRawParamMap(overridedParams)
     xgbExecParam
@@ -253,7 +257,132 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
     )
   }

-object XGBoost extends Serializable {
+/**
+ * A trait to manage stage-level scheduling
+ */
+private[spark] trait XGBoostStageLevel extends Serializable {
+  private val logger = LogFactory.getLog("XGBoostSpark")
+
+  private[spark] def isStandaloneOrLocalCluster(conf: SparkConf): Boolean = {
+    val master = conf.get("spark.master")
+    master != null && (master.startsWith("spark://") || master.startsWith("local-cluster"))
+  }
+
+  /**
+   * Determine whether stage-level scheduling should be skipped, based on the Spark version
+   * and the Spark configuration.
+   *
+   * @param sparkVersion the Spark version
+   * @param runOnGpu     whether XGBoost training runs on GPUs
+   * @param conf         the Spark configuration
+   * @return true if stage-level scheduling should be skipped
+   */
+  private[spark] def skipStageLevelScheduling(
+      sparkVersion: String,
+      runOnGpu: Boolean,
+      conf: SparkConf): Boolean = {
+    if (runOnGpu) {
+      if (sparkVersion < "3.4.0") {
+        logger.info("Stage-level scheduling in xgboost requires spark version 3.4.0+")
+        return true
+      }
+
+      if (!isStandaloneOrLocalCluster(conf)) {
+        logger.info("Stage-level scheduling in xgboost requires spark standalone or " +
+          "local-cluster mode")
+        return true
+      }
+
+      val executorCores = conf.getInt("spark.executor.cores", -1)
+      val executorGpus = conf.getInt("spark.executor.resource.gpu.amount", -1)
+      if (executorCores == -1 || executorGpus == -1) {
+        logger.info("Stage-level scheduling in xgboost requires spark.executor.cores, " +
+          "spark.executor.resource.gpu.amount to be set.")
+        return true
+      }
+
+      if (executorCores == 1) {
+        logger.info("Stage-level scheduling in xgboost requires spark.executor.cores > 1")
+        return true
+      }
+
+      if (executorGpus > 1) {
+        logger.info("Stage-level scheduling in xgboost will not work " +
+          "when spark.executor.resource.gpu.amount > 1")
+        return true
+      }
+
+      val taskGpuAmount = conf.getDouble("spark.task.resource.gpu.amount", -1.0).toFloat
+
+      if (taskGpuAmount == -1.0) {
+        // The ETL tasks will not grab a gpu when spark.task.resource.gpu.amount is not set,
+        // but with stage-level scheduling we can still make the training tasks grab the gpu.
+        return false
+      }
+
+      if (taskGpuAmount == executorGpus.toFloat) {
+        // spark.executor.resource.gpu.amount == spark.task.resource.gpu.amount
+        // results in only 1 task running at a time, which may cause performance issues.
+        return true
+      }
+      // We can enable stage-level scheduling
+      false
+    } else true // Skip stage-level scheduling for cpu training.
+  }
+
+  /**
+   * Attempt to modify the task resources so that only one training task can be executed
+   * on a single executor at a time.
+   *
+   * @param sc            the spark context
+   * @param xgbExecParams the xgboost execution parameters
+   * @param rdd           the rdd to apply the new resource profile to
+   * @return the original rdd or the changed rdd
+   */
+  private[spark] def tryStageLevelScheduling(
+      sc: SparkContext,
+      xgbExecParams: XGBoostExecutionParams,
+      rdd: RDD[(Booster, Map[String, Array[Float]])]
+    ): RDD[(Booster, Map[String, Array[Float]])] = {
+
+    val conf = sc.getConf
+    if (skipStageLevelScheduling(sc.version, xgbExecParams.runOnGpu, conf)) {
+      return rdd
+    }
+
+    // Ensure spark.executor.cores is set
+    val executor_cores = conf.getInt("spark.executor.cores", -1)
+    if (executor_cores == -1) {
+      throw new RuntimeException("Wrong spark.executor.cores")
+    }
+
+    // Spark-rapids is a GPU-acceleration plugin for Spark SQL.
+    // When spark-rapids is enabled, we prevent other GPU-using ETL tasks from running
+    // concurrently with training tasks in order to avoid GPU out-of-memory errors.
+    val spark_plugins = conf.get("spark.plugins", " ")
+    val spark_rapids_sql_enabled = conf.get("spark.rapids.sql.enabled", "true")
+
+    // Determine the number of cores required for each task.
+    val task_cores = if (spark_plugins.contains("com.nvidia.spark.SQLPlugin") &&
+      spark_rapids_sql_enabled.toLowerCase == "true") {
+      executor_cores
+    } else {
+      (executor_cores / 2) + 1
+    }
+
+    // Each training task requires (executor_cores / 2) + 1 cpus, i.e. more than half of
+    // the executor cores, to ensure the tasks are sent to different executors.
+    // Note: we cannot use GPUs to limit concurrent tasks
+    // due to https://issues.apache.org/jira/browse/SPARK-45527.
+    val task_gpus = 1.0
+    val treqs = new TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
+    val rp = new ResourceProfileBuilder().require(treqs).build()
+
+    logger.info(s"XGBoost training tasks require the resource(cores=$task_cores, gpu=$task_gpus).")
+    rdd.withResources(rp)
+  }
+}
+
+object XGBoost extends XGBoostStageLevel {
   private val logger = LogFactory.getLog("XGBoostSpark")

   def getGPUAddrFromResources: Int = {
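
To make the core-count rule in tryStageLevelScheduling above concrete, a small worked example (the numbers are illustrative, not from the commit): with 12 executor cores and spark-rapids disabled, each training task asks for 12 / 2 + 1 = 7 cpus, and since 2 × 7 > 12 no executor can host two training tasks at once; with spark-rapids enabled, each task asks for all 12 cores, so GPU-using ETL tasks cannot run alongside it either.

    // Worked example of the task_cores rule (illustrative numbers only).
    val executorCores = 12
    val rapidsEnabled = false
    val taskCores = if (rapidsEnabled) executorCores else executorCores / 2 + 1 // 7
    assert(2 * taskCores > executorCores) // two such tasks can never share one executor

The unit test added at the end of this diff checks the same arithmetic for a 4-core executor (4 / 2 + 1 = 3 cpus per training task).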
@@ -315,7 +444,7 @@ object XGBoost extends Serializable {
     val externalCheckpointParams = xgbExecutionParam.checkpointParam

     var params = xgbExecutionParam.toMap
-    if (xgbExecutionParam.device.exists(m => (m == "cuda" || m == "gpu"))) {
+    if (xgbExecutionParam.runOnGpu) {
       val gpuId = if (xgbExecutionParam.isLocal) {
         // For local mode, force gpu id to primary device
         0
@@ -413,10 +542,12 @@ object XGBoost extends Serializable {

       }}

+      val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams,
+        boostersAndMetrics)
       // The repartition step is to make the training stage a ShuffleMapStage, so that when
       // one of the training tasks fails, the stage can retry. A ResultStage won't retry
       // when it fails.
-      val (booster, metrics) = boostersAndMetrics.repartition(1).collect()(0)
+      val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
       val trackerReturnVal = tracker.waitFor(0L)
       logger.info(s"Rabit returns with exit code $trackerReturnVal")
       if (trackerReturnVal != 0) {
@@ -154,11 +154,13 @@ private[spark] trait BoosterParams extends Params {
     (value: String) => BoosterParams.supportedTreeMethods.contains(value))

   final def getTreeMethod: String = $(treeMethod)

   /**
    * The device for running XGBoost algorithms, options: cpu, cuda
    */
   final val device = new Param[String](
-    this, "device", "The device for running XGBoost algorithms, options: cpu, cuda"
+    this, "device", "The device for running XGBoost algorithms, options: cpu, cuda",
+    (value: String) => BoosterParams.supportedDevices.contains(value)
+  )

   final def getDevice: String = $(device)
@@ -288,4 +290,6 @@ private[scala] object BoosterParams {
   val supportedSampleType = HashSet("uniform", "weighted")

   val supportedNormalizeType = HashSet("tree", "forest")
+
+  val supportedDevices = HashSet("cpu", "cuda")
 }
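
Together with this validator, the GpuXGBoostGeneralSuite change at the top of the diff shows the user-visible effect: malformed device strings are now rejected by Spark's Param validation. A hedged usage sketch (the estimator construction is illustrative; only setDevice and the quoted error text come from this diff):

    // Values in BoosterParams.supportedDevices pass validation:
    new XGBoostClassifier().setDevice("cuda") // GPU training
    new XGBoostClassifier().setDevice("cpu")  // CPU training

    // Anything else, e.g. an ordinal-qualified device, is rejected by the Param validator
    // with an IllegalArgumentException whose message contains
    // "device given invalid value cuda:1".
    new XGBoostClassifier().setDevice("cuda:1")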
@@ -0,0 +1,150 @@
+/*
+ Copyright (c) 2023 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.scala.Booster
+import org.apache.spark.SparkConf
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
+import org.scalatest.funsuite.AnyFunSuite
+
+class XGBoostSuite extends AnyFunSuite with PerTest {
+
+  // Do not create spark context
+  override def beforeEach(): Unit = {}
+
+  test("XGBoost execution parameters") {
+    var xgbExecutionParams = new XGBoostExecutionParamsFactory(
+      Map("device" -> "cpu", "num_workers" -> 1, "num_round" -> 1), sc)
+      .buildXGBRuntimeParams
+    assert(!xgbExecutionParams.runOnGpu)
+
+    xgbExecutionParams = new XGBoostExecutionParamsFactory(
+      Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1), sc)
+      .buildXGBRuntimeParams
+    assert(xgbExecutionParams.runOnGpu)
+
+    xgbExecutionParams = new XGBoostExecutionParamsFactory(
+      Map("device" -> "cpu", "tree_method" -> "gpu_hist", "num_workers" -> 1, "num_round" -> 1), sc)
+      .buildXGBRuntimeParams
+    assert(xgbExecutionParams.runOnGpu)
+
+    xgbExecutionParams = new XGBoostExecutionParamsFactory(
+      Map("device" -> "cuda", "tree_method" -> "gpu_hist",
+        "num_workers" -> 1, "num_round" -> 1), sc)
+      .buildXGBRuntimeParams
+    assert(xgbExecutionParams.runOnGpu)
+  }
+
+  test("skip stage-level scheduling") {
+    val conf = new SparkConf()
+      .setMaster("spark://foo")
+      .set("spark.executor.cores", "12")
+      .set("spark.task.cpus", "1")
+      .set("spark.executor.resource.gpu.amount", "1")
+      .set("spark.task.resource.gpu.amount", "0.08")
+
+    // the correct configurations should not skip stage-level scheduling
+    assert(!XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, conf))
+
+    // spark version < 3.4.0
+    assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.3.0", runOnGpu = true, conf))
+
+    // not run on GPU
+    assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = false, conf))
+
+    // spark.executor.cores is not set
+    var badConf = conf.clone().remove("spark.executor.cores")
+    assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
+
+    // spark.executor.cores=1
+    badConf = conf.clone().set("spark.executor.cores", "1")
+    assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
+
+    // spark.executor.resource.gpu.amount is not set
+    badConf = conf.clone().remove("spark.executor.resource.gpu.amount")
+    assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
+
+    // spark.executor.resource.gpu.amount > 1
+    badConf = conf.clone().set("spark.executor.resource.gpu.amount", "2")
+    assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
+
+    // spark.task.resource.gpu.amount is not set
+    badConf = conf.clone().remove("spark.task.resource.gpu.amount")
+    assert(!XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
+
+    // spark.task.resource.gpu.amount = 1
+    badConf = conf.clone().set("spark.task.resource.gpu.amount", "1")
+    assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
+
+    // yarn
+    badConf = conf.clone().setMaster("yarn")
+    assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
+
+    // k8s
+    badConf = conf.clone().setMaster("k8s://")
+    assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
+  }
+
+
+  object FakedXGBoost extends XGBoostStageLevel {
+
+    // Do not skip stage-level scheduling for testing purposes.
+    override private[spark] def skipStageLevelScheduling(
+        sparkVersion: String,
+        runOnGpu: Boolean,
+        conf: SparkConf) = false
+  }
+
+  test("try stage-level scheduling without spark-rapids") {
+
+    val builder = SparkSession.builder()
+      .master(s"local-cluster[1, 4, 1024]")
+      .appName("XGBoostSuite")
+      .config("spark.ui.enabled", false)
+      .config("spark.driver.memory", "512m")
+      .config("spark.barrier.sync.timeout", 10)
+      .config("spark.task.cpus", 1)
+      .config("spark.executor.cores", 4)
+      .config("spark.executor.resource.gpu.amount", 1)
+      .config("spark.task.resource.gpu.amount", 0.25)
+
+    val ss = builder.getOrCreate()
+
+    try {
+      val df = ss.range(1, 10)
+      val rdd = df.rdd
+
+      val xgbExecutionParams = new XGBoostExecutionParamsFactory(
+        Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1), sc)
+        .buildXGBRuntimeParams
+      assert(xgbExecutionParams.runOnGpu)
+
+      val finalRDD = FakedXGBoost.tryStageLevelScheduling(ss.sparkContext, xgbExecutionParams,
+        rdd.asInstanceOf[RDD[(Booster, Map[String, Array[Float]])]])
+
+      val taskResources = finalRDD.getResourceProfile().taskResources
+      assert(taskResources.contains("cpus"))
+      assert(taskResources.get("cpus").get.amount == 3)
+
+      assert(taskResources.contains("gpu"))
+      assert(taskResources.get("gpu").get.amount == 1.0)
+    } finally {
+      ss.stop()
+    }
+  }
+}
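
Putting the preconditions from skipStageLevelScheduling together, here is a hedged sketch of a session configuration under which the resource profile above would actually be applied during GPU training (the master URL and sizes are placeholders, not part of this commit):

    import org.apache.spark.sql.SparkSession

    // Illustrative configuration satisfying every check in skipStageLevelScheduling:
    // Spark 3.4+, standalone (or local-cluster) master, more than one executor core,
    // exactly one GPU per executor, and a fractional per-task GPU share for ETL stages.
    // (GPU discovery settings for the cluster are omitted here.)
    val spark = SparkSession.builder()
      .master("spark://host:7077")
      .config("spark.executor.cores", "12")
      .config("spark.task.cpus", "1")
      .config("spark.executor.resource.gpu.amount", "1")
      .config("spark.task.resource.gpu.amount", "0.08")
      .getOrCreate()

    // Training with device=cuda then goes through tryStageLevelScheduling, which (without
    // spark-rapids) requests executor_cores / 2 + 1 cpus and one whole GPU per training task.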