[jvm-packages] fix executor crashing issue when transforming on xgboost4j-spark-gpu (#8025)

* [jvm-packages] fix executor crashing issue when transforming on xgboost4j-spark-gpu

the API XGBoosterSetParam is not thread-safe. Dring the phase of transforming,
XGBoost runs several transforming tasks at a time, and each of them will set
the "gpu_id" and "predictor" parameters, so if several tasks (multi-threads)
all XGBoosterSetParam simultaneously, it may cause the memory to be corrupted
and cause SIGSEGV.

This PR first get the booster from broadcast and set to the correct gpu_id
and predictor, and then all transforming taskes will use the same booster to
do the transforming.
This commit is contained in:
Bobby Wang 2022-06-24 01:18:41 +08:00 committed by GitHub
parent f0c1b842bf
commit a68580e2a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 50 additions and 30 deletions

View File

@ -27,7 +27,6 @@ import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor} import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor}
import org.apache.commons.logging.LogFactory import org.apache.commons.logging.LogFactory
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
@ -90,6 +89,11 @@ class GpuPreXGBoost extends PreXGBoostProvider {
} }
} }
class BoosterFlag extends Serializable {
// indicate if the GPU parameters are set.
var isGpuParamsSet = false
}
object GpuPreXGBoost extends PreXGBoostProvider { object GpuPreXGBoost extends PreXGBoostProvider {
private val logger = LogFactory.getLog("XGBoostSpark") private val logger = LogFactory.getLog("XGBoostSpark")
@ -187,9 +191,9 @@ object GpuPreXGBoost extends PreXGBoostProvider {
// predict and turn to Row // predict and turn to Row
val predictFunc = val predictFunc =
(broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => { (booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) = val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
m.producePredictionItrs(broadcastBooster, dm) m.producePredictionItrs(booster, dm)
m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr, m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr,
predLeafItr, predContribItr) predLeafItr, predContribItr)
} }
@ -218,9 +222,9 @@ object GpuPreXGBoost extends PreXGBoostProvider {
// predict and turn to Row // predict and turn to Row
val predictFunc = val predictFunc =
(broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => { (booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, predLeafItr, predContribItr) = val Array(rawPredictionItr, predLeafItr, predContribItr) =
m.producePredictionItrs(broadcastBooster, dm) m.producePredictionItrs(booster, dm)
m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr, m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr,
predContribItr) predContribItr)
} }
@ -248,6 +252,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
val bOrigSchema = sc.broadcast(dataset.schema) val bOrigSchema = sc.broadcast(dataset.schema)
val bRowSchema = sc.broadcast(schema) val bRowSchema = sc.broadcast(schema)
val bBooster = sc.broadcast(booster) val bBooster = sc.broadcast(booster)
val bBoosterFlag = sc.broadcast(new BoosterFlag)
// Small vars so don't need to broadcast them // Small vars so don't need to broadcast them
val isLocal = sc.isLocal val isLocal = sc.isLocal
@ -259,6 +264,31 @@ object GpuPreXGBoost extends PreXGBoostProvider {
// UnsafeProjection is not serializable so do it on the executor side // UnsafeProjection is not serializable so do it on the executor side
val toUnsafe = UnsafeProjection.create(bOrigSchema.value) val toUnsafe = UnsafeProjection.create(bOrigSchema.value)
// booster is visible for all spark tasks in the same executor
val booster = bBooster.value
val boosterFlag = bBoosterFlag.value
synchronized {
// there are two kind of race conditions,
// 1. multi-taskes set parameters at a time
// 2. one task sets parameter and another task reads the parameter
// both of them can cause potential un-expected behavior, moreover,
// it may cause executor crash
// So add synchronized to allow only one task to set parameter if it is not set.
// and rely on BlockManager to ensure the same booster only be called once to
// set parameter.
if (!boosterFlag.isGpuParamsSet) {
// set some params of gpu related to booster
// - gpu id
// - predictor: Force to gpu predictor since native doesn't save predictor.
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
booster.setParam("gpu_id", gpuId.toString)
booster.setParam("predictor", "gpu_predictor")
logger.info("GPU transform on device: " + gpuId)
boosterFlag.isGpuParamsSet = true;
}
}
// Iterator on Row // Iterator on Row
new Iterator[Row] { new Iterator[Row] {
// Convert InternalRow to Row // Convert InternalRow to Row
@ -271,14 +301,6 @@ object GpuPreXGBoost extends PreXGBoostProvider {
// Iterator on Row // Iterator on Row
var iter: Iterator[Row] = null var iter: Iterator[Row] = null
// set some params of gpu related to booster
// - gpu id
// - predictor: Force to gpu predictor since native doesn't save predictor.
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
bBooster.value.setParam("gpu_id", gpuId.toString)
bBooster.value.setParam("predictor", "gpu_predictor")
logger.info("GPU transform on device: " + gpuId)
TaskContext.get().addTaskCompletionListener[Unit](_ => { TaskContext.get().addTaskCompletionListener[Unit](_ => {
closeCurrentBatch() // close the last ColumnarBatch closeCurrentBatch() // close the last ColumnarBatch
}) })
@ -314,7 +336,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
val rowIterator = currentBatch.rowIterator().asScala val rowIterator = currentBatch.rowIterator().asScala
.map(toUnsafe) .map(toUnsafe)
.map(converter(_)) .map(converter(_))
predictFunc(bBooster, dm, rowIterator) predictFunc(booster, dm, rowIterator)
} finally { } finally {
dm.delete() dm.delete()

View File

@ -201,9 +201,9 @@ object PreXGBoost extends PreXGBoostProvider {
val (xgbInput, featuresName) = m.vectorize(dataset) val (xgbInput, featuresName) = m.vectorize(dataset)
// predict and turn to Row // predict and turn to Row
val predictFunc = val predictFunc =
(broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => { (booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) = val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
m.producePredictionItrs(broadcastBooster, dm) m.producePredictionItrs(booster, dm)
m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr, m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr,
predLeafItr, predContribItr) predLeafItr, predContribItr)
} }
@ -231,9 +231,9 @@ object PreXGBoost extends PreXGBoostProvider {
// predict and turn to Row // predict and turn to Row
val (xgbInput, featuresName) = m.vectorize(dataset) val (xgbInput, featuresName) = m.vectorize(dataset)
val predictFunc = val predictFunc =
(broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => { (booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, predLeafItr, predContribItr) = val Array(rawPredictionItr, predLeafItr, predContribItr) =
m.producePredictionItrs(broadcastBooster, dm) m.producePredictionItrs(booster, dm)
m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr, predContribItr) m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr, predContribItr)
} }
@ -286,7 +286,7 @@ object PreXGBoost extends PreXGBoostProvider {
cacheInfo) cacheInfo)
try { try {
predictFunc(bBooster, dm, batchRow.iterator) predictFunc(bBooster.value, dm, batchRow.iterator)
} finally { } finally {
batchCnt += 1 batchCnt += 1
dm.delete() dm.delete()

View File

@ -20,7 +20,6 @@ import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost} import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.Path
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.classification._ import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg._ import org.apache.spark.ml.linalg._
import org.apache.spark.ml.util._ import org.apache.spark.ml.util._
@ -329,26 +328,26 @@ class XGBoostClassificationModel private[ml](
} }
} }
private[scala] def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix): private[scala] def producePredictionItrs(booster: Booster, dm: DMatrix):
Array[Iterator[Row]] = { Array[Iterator[Row]] = {
val rawPredictionItr = { val rawPredictionItr = {
broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)). booster.predict(dm, outPutMargin = true, $(treeLimit)).
map(Row(_)).iterator map(Row(_)).iterator
} }
val probabilityItr = { val probabilityItr = {
broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)). booster.predict(dm, outPutMargin = false, $(treeLimit)).
map(Row(_)).iterator map(Row(_)).iterator
} }
val predLeafItr = { val predLeafItr = {
if (isDefined(leafPredictionCol)) { if (isDefined(leafPredictionCol)) {
broadcastBooster.value.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator booster.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator
} else { } else {
Iterator() Iterator()
} }
} }
val predContribItr = { val predContribItr = {
if (isDefined(contribPredictionCol)) { if (isDefined(contribPredictionCol)) {
broadcastBooster.value.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator booster.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator
} else { } else {
Iterator() Iterator()
} }

View File

@ -30,7 +30,6 @@ import org.apache.spark.ml.param._
import org.apache.spark.sql._ import org.apache.spark.sql._
import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter} import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
import org.apache.spark.sql.types.StructType import org.apache.spark.sql.types.StructType
@ -298,14 +297,14 @@ class XGBoostRegressionModel private[ml] (
} }
} }
private[scala] def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix): private[scala] def producePredictionItrs(booster: Booster, dm: DMatrix):
Array[Iterator[Row]] = { Array[Iterator[Row]] = {
val originalPredictionItr = { val originalPredictionItr = {
booster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator booster.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
} }
val predLeafItr = { val predLeafItr = {
if (isDefined(leafPredictionCol)) { if (isDefined(leafPredictionCol)) {
booster.value.predictLeaf(dm, $(treeLimit)). booster.predictLeaf(dm, $(treeLimit)).
map(Row(_)).iterator map(Row(_)).iterator
} else { } else {
Iterator() Iterator()
@ -313,7 +312,7 @@ class XGBoostRegressionModel private[ml] (
} }
val predContribItr = { val predContribItr = {
if (isDefined(contribPredictionCol)) { if (isDefined(contribPredictionCol)) {
booster.value.predictContrib(dm, $(treeLimit)). booster.predictContrib(dm, $(treeLimit)).
map(Row(_)).iterator map(Row(_)).iterator
} else { } else {
Iterator() Iterator()