[jvm-packages] support stage-level scheduling (#9775)

Bobby Wang 2023-11-14 08:59:45 +08:00 committed by GitHub
parent 162da7b52b
commit 36a552ac98
4 changed files with 298 additions and 13 deletions

View File

@ -206,7 +206,7 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
.setDevice("cuda:1")
.fit(trainingDf)
}
-assert(thrown.getMessage.contains("`cuda` or `gpu`"))
+assert(thrown.getMessage.contains("device given invalid value cuda:1"))
}
}
}

View File

@ -31,7 +31,8 @@ import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.rdd.RDD
-import org.apache.spark.{SparkContext, TaskContext}
+import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}
+import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.sql.SparkSession
/**
@ -72,7 +73,8 @@ private[scala] case class XGBoostExecutionParams(
device: Option[String],
isLocal: Boolean,
featureNames: Option[Array[String]],
-featureTypes: Option[Array[String]]) {
+featureTypes: Option[Array[String]],
+runOnGpu: Boolean) {
private var rawParamMap: Map[String, Any] = _
@ -186,14 +188,15 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
.asInstanceOf[Boolean]
val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
-// back-compatible with "gpu_hist"
-val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
-  Some("cuda")
-} else overridedParams.get("device").map(_.toString)
+val device: Option[String] = overridedParams.get("device").map(_.toString)
+val deviceIsGpu = device.exists(_ == "cuda")
-require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")),
+require(!(treeMethod.exists(_ == "approx") && deviceIsGpu),
  "The tree method \"approx\" is not yet supported for Spark GPU cluster")
+// back-compatible with "gpu_hist"
+val runOnGpu = treeMethod.exists(_ == "gpu_hist") || deviceIsGpu
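// For illustration (hypothetical parameter maps; the new XGBoostSuite below
// asserts these same combinations):
//   Map("device" -> "cuda")          yields runOnGpu = true
//   Map("tree_method" -> "gpu_hist") yields runOnGpu = true (back-compat path)
//   Map("device" -> "cpu")           yields runOnGpu = false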
val trackerConf = overridedParams.get("tracker_conf") match {
case None => TrackerConf()
case Some(conf: TrackerConf) => conf
@ -228,7 +231,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
device,
isLocal,
featureNames,
-featureTypes
+featureTypes,
+runOnGpu
)
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
@ -253,7 +257,132 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
)
}
object XGBoost extends Serializable {
/**
* A trait to manage stage-level scheduling
*/
private[spark] trait XGBoostStageLevel extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
private[spark] def isStandaloneOrLocalCluster(conf: SparkConf): Boolean = {
val master = conf.get("spark.master")
master != null && (master.startsWith("spark://") || master.startsWith("local-cluster"))
}
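// For example (illustrative master URLs, not part of this diff): "spark://host:7077"
// and "local-cluster[1, 4, 1024]" are treated as standalone/local-cluster,
// while "yarn" and "k8s://..." are not.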
/**
* Determines whether stage-level scheduling should be skipped, based on the Spark
* version and the Spark configurations.
*
* @param sparkVersion the Spark version
* @param runOnGpu whether XGBoost training runs on GPUs
* @param conf the Spark configurations
* @return true if stage-level scheduling should be skipped, false otherwise
*/
private[spark] def skipStageLevelScheduling(
sparkVersion: String,
runOnGpu: Boolean,
conf: SparkConf): Boolean = {
if (runOnGpu) {
if (sparkVersion < "3.4.0") {
logger.info("Stage-level scheduling in xgboost requires spark version 3.4.0+")
return true
}
if (!isStandaloneOrLocalCluster(conf)) {
logger.info("Stage-level scheduling in xgboost requires spark standalone or " +
"local-cluster mode")
return true
}
val executorCores = conf.getInt("spark.executor.cores", -1)
val executorGpus = conf.getInt("spark.executor.resource.gpu.amount", -1)
if (executorCores == -1 || executorGpus == -1) {
logger.info("Stage-level scheduling in xgboost requires spark.executor.cores, " +
"spark.executor.resource.gpu.amount to be set.")
return true
}
if (executorCores == 1) {
logger.info("Stage-level scheduling in xgboost requires spark.executor.cores > 1")
return true
}
if (executorGpus > 1) {
logger.info("Stage-level scheduling in xgboost will not work " +
"when spark.executor.resource.gpu.amount > 1")
return true
}
val taskGpuAmount = conf.getDouble("spark.task.resource.gpu.amount", -1.0).toFloat
if (taskGpuAmount == -1.0) {
// The ETL tasks will not grab a GPU when spark.task.resource.gpu.amount is not set,
// but with stage-level scheduling, we can make the training tasks grab the GPU.
return false
}
if (taskGpuAmount == executorGpus.toFloat) {
// spark.executor.resource.gpu.amount = spark.task.resource.gpu.amount
// results in only 1 task running at a time, which may cause perf issue.
return true
}
// We can enable stage-level scheduling
false
} else true // Skip stage-level scheduling for CPU training.
}
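Taken together, the checks above describe a setup that keeps stage-level scheduling enabled. A minimal sketch, assuming illustrative values (the new XGBoostSuite below asserts the same combinations):
val conf = new SparkConf()
  .setMaster("spark://host:7077")                 // standalone mode
  .set("spark.executor.cores", "12")              // more than one core
  .set("spark.executor.resource.gpu.amount", "1") // exactly one GPU per executor
  .set("spark.task.resource.gpu.amount", "0.08")  // strictly less than the executor amount
// skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, conf) returns false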
/**
* Attempts to modify the task resources so that only one training task runs on
* each executor at a time.
*
* @param sc the Spark context
* @param xgbExecParams the XGBoost execution parameters
* @param rdd the RDD to apply the new resource profile to
* @return the original RDD, or the RDD with the new resource profile applied
*/
private[spark] def tryStageLevelScheduling(
sc: SparkContext,
xgbExecParams: XGBoostExecutionParams,
rdd: RDD[(Booster, Map[String, Array[Float]])]
): RDD[(Booster, Map[String, Array[Float]])] = {
val conf = sc.getConf
if (skipStageLevelScheduling(sc.version, xgbExecParams.runOnGpu, conf)) {
return rdd
}
// Ensure spark.executor.cores is set
val executor_cores = conf.getInt("spark.executor.cores", -1)
if (executor_cores == -1) {
  throw new RuntimeException("spark.executor.cores is not set")
}
// Spark-rapids is a GPU-acceleration project for Spark SQL.
// When spark-rapids is enabled, we prevent concurrent execution of other ETL tasks
// that utilize GPUs alongside training tasks in order to avoid GPU out-of-memory errors.
val spark_plugins = conf.get("spark.plugins", " ")
val spark_rapids_sql_enabled = conf.get("spark.rapids.sql.enabled", "true")
// Determine the number of cores required for each task.
val task_cores = if (spark_plugins.contains("com.nvidia.spark.SQLPlugin") &&
spark_rapids_sql_enabled.toLowerCase == "true") {
executor_cores
} else {
(executor_cores / 2) + 1
}
// Each training task requires more than half of the executor cores
// (executor_cores / 2 + 1), which guarantees that at most one training task
// runs on each executor.
// Note: we cannot use GPUs to limit concurrent tasks
// due to https://issues.apache.org/jira/browse/SPARK-45527.
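// Worked example (hypothetical numbers): with spark.executor.cores = 12 and no
// spark-rapids SQL plugin, task_cores = 12 / 2 + 1 = 7; two concurrent tasks
// would need 14 cores on a 12-core executor, so only one training task fits.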
val task_gpus = 1.0
val treqs = new TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
val rp = new ResourceProfileBuilder().require(treqs).build()
logger.info(s"XGBoost training tasks require the resource(cores=$task_cores, gpu=$task_gpus).")
rdd.withResources(rp)
}
}
object XGBoost extends XGBoostStageLevel {
private val logger = LogFactory.getLog("XGBoostSpark")
def getGPUAddrFromResources: Int = {
@ -315,7 +444,7 @@ object XGBoost extends Serializable {
val externalCheckpointParams = xgbExecutionParam.checkpointParam
var params = xgbExecutionParam.toMap
-if (xgbExecutionParam.device.exists(m => (m == "cuda" || m == "gpu"))) {
+if (xgbExecutionParam.runOnGpu) {
val gpuId = if (xgbExecutionParam.isLocal) {
// For local mode, force gpu id to primary device
0
@ -413,10 +542,12 @@ object XGBoost extends Serializable {
}}
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams,
boostersAndMetrics)
// The repartition step turns the training stage into a ShuffleMapStage, so that
// when one of the training tasks fails, the training stage can be retried.
// A ResultStage is not retried when it fails.
-val (booster, metrics) = boostersAndMetrics.repartition(1).collect()(0)
+val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
if (trackerReturnVal != 0) {

View File

@ -154,11 +154,13 @@ private[spark] trait BoosterParams extends Params {
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
final def getTreeMethod: String = $(treeMethod)
/**
* The device for running XGBoost algorithms, options: cpu, cuda
*/
final val device = new Param[String](
this, "device", "The device for running XGBoost algorithms, options: cpu, cuda"
this, "device", "The device for running XGBoost algorithms, options: cpu, cuda",
(value: String) => BoosterParams.supportedDevices.contains(value)
)
final def getDevice: String = $(device)
@ -288,4 +290,6 @@ private[scala] object BoosterParams {
val supportedSampleType = HashSet("uniform", "weighted")
val supportedNormalizeType = HashSet("tree", "forest")
val supportedDevices = HashSet("cpu", "cuda")
}
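A minimal sketch of what the new validator buys (the predicate mirrors the one added above; names and assertions are illustrative): an unsupported value such as "cuda:1" is now rejected at parameter validation time, which is exactly what the updated GpuXGBoostGeneralSuite assertion at the top of this commit expects.
val isSupportedDevice: String => Boolean =
  (value: String) => HashSet("cpu", "cuda").contains(value)
assert(isSupportedDevice("cuda"))
assert(!isSupportedDevice("cuda:1")) // fails fast instead of deep inside training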

View File

@ -0,0 +1,150 @@
/*
Copyright (c) 2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.scalatest.funsuite.AnyFunSuite
class XGBoostSuite extends AnyFunSuite with PerTest {
// Do not create spark context
override def beforeEach(): Unit = {}
test("XGBoost execution parameters") {
var xgbExecutionParams = new XGBoostExecutionParamsFactory(
Map("device" -> "cpu", "num_workers" -> 1, "num_round" -> 1), sc)
.buildXGBRuntimeParams
assert(!xgbExecutionParams.runOnGpu)
xgbExecutionParams = new XGBoostExecutionParamsFactory(
Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1), sc)
.buildXGBRuntimeParams
assert(xgbExecutionParams.runOnGpu)
xgbExecutionParams = new XGBoostExecutionParamsFactory(
Map("device" -> "cpu", "tree_method" -> "gpu_hist", "num_workers" -> 1, "num_round" -> 1), sc)
.buildXGBRuntimeParams
assert(xgbExecutionParams.runOnGpu)
xgbExecutionParams = new XGBoostExecutionParamsFactory(
Map("device" -> "cuda", "tree_method" -> "gpu_hist",
"num_workers" -> 1, "num_round" -> 1), sc)
.buildXGBRuntimeParams
assert(xgbExecutionParams.runOnGpu)
}
test("skip stage-level scheduling") {
val conf = new SparkConf()
.setMaster("spark://foo")
.set("spark.executor.cores", "12")
.set("spark.task.cpus", "1")
.set("spark.executor.resource.gpu.amount", "1")
.set("spark.task.resource.gpu.amount", "0.08")
// the correct configurations should not skip stage-level scheduling
assert(!XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, conf))
// spark version < 3.4.0
assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.3.0", runOnGpu = true, conf))
// not run on GPU
assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = false, conf))
// spark.executor.cores is not set
var badConf = conf.clone().remove("spark.executor.cores")
assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
// spark.executor.cores=1
badConf = conf.clone().set("spark.executor.cores", "1")
assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
// spark.executor.resource.gpu.amount is not set
badConf = conf.clone().remove("spark.executor.resource.gpu.amount")
assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
// spark.executor.resource.gpu.amount>1
badConf = conf.clone().set("spark.executor.resource.gpu.amount", "2")
assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
// spark.task.resource.gpu.amount is not set
badConf = conf.clone().remove("spark.task.resource.gpu.amount")
assert(!XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
// spark.task.resource.gpu.amount=1
badConf = conf.clone().set("spark.task.resource.gpu.amount", "1")
assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
// yarn
badConf = conf.clone().setMaster("yarn")
assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
// k8s
badConf = conf.clone().setMaster("k8s://")
assert(XGBoost.skipStageLevelScheduling(sparkVersion = "3.4.0", runOnGpu = true, badConf))
}
object FakedXGBoost extends XGBoostStageLevel {
// Never skip stage-level scheduling, so that the resource-profile path can be tested.
override private[spark] def skipStageLevelScheduling(
sparkVersion: String,
runOnGpu: Boolean,
conf: SparkConf) = false
}
test("try stage-level scheduling without spark-rapids") {
val builder = SparkSession.builder()
.master(s"local-cluster[1, 4, 1024]")
.appName("XGBoostSuite")
.config("spark.ui.enabled", false)
.config("spark.driver.memory", "512m")
.config("spark.barrier.sync.timeout", 10)
.config("spark.task.cpus", 1)
.config("spark.executor.cores", 4)
.config("spark.executor.resource.gpu.amount", 1)
.config("spark.task.resource.gpu.amount", 0.25)
val ss = builder.getOrCreate()
try {
val df = ss.range(1, 10)
val rdd = df.rdd
val xgbExecutionParams = new XGBoostExecutionParamsFactory(
Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1), sc)
.buildXGBRuntimeParams
assert(xgbExecutionParams.runOnGpu)
val finalRDD = FakedXGBoost.tryStageLevelScheduling(ss.sparkContext, xgbExecutionParams,
rdd.asInstanceOf[RDD[(Booster, Map[String, Array[Float]])]])
val taskResources = finalRDD.getResourceProfile().taskResources
assert(taskResources.contains("cpus"))
assert(taskResources.get("cpus").get.amount == 3)
assert(taskResources.contains("gpu"))
assert(taskResources.get("gpu").get.amount == 1.0)
} finally {
ss.stop()
}
}
}