[jvm-packages] delete all constraints from spark layer about obj and eval metrics and handle error in jvm layer (#4560)

* temp

* prediction part

* remove supported*

* add for test

* fix param name

* add rabit

* update rabit

* return value of rabit init

* eliminate compilation warnings

* update rabit

* shutdown

* update rabit again

* check sparkcontext shutdown

* fix logic

* sleep

* fix tests

* test with relaxed threshold

* create new thread each time

* stop for job quitting

* update rabit

* update rabit

* update rabit

* update git modules
This commit is contained in:
Nan Zhu
2019-06-27 08:47:37 -07:00
committed by GitHub
parent dd01f7c4f5
commit abffbe014e
11 changed files with 143 additions and 51 deletions

View File

@@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
@@ -153,9 +153,11 @@ object XGBoost extends Serializable {
}
val taskId = TaskContext.getPartitionId().toString
rabitEnv.put("DMLC_TASK_ID", taskId)
Rabit.init(rabitEnv)
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
try {
Rabit.init(rabitEnv)
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
.map(_.toString.toInt).getOrElse(0)
val overridedParams = if (numEarlyStoppingRounds > 0 &&
@@ -176,6 +178,10 @@ object XGBoost extends Serializable {
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
} catch {
case xgbException: XGBoostError =>
logger.error(s"XGBooster worker $taskId has failed due to ", xgbException)
throw xgbException
} finally {
Rabit.shutdown()
watches.delete()
@@ -467,6 +473,12 @@ object XGBoost extends Serializable {
tracker.stop()
}
}.last
} catch {
case t: Throwable =>
// if the job was aborted due to an exception
logger.error("the job was aborted due to ", t)
trainingData.sparkContext.stop()
throw t
} finally {
uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
transformedTrainingData)

View File

@@ -292,7 +292,9 @@ class XGBoostClassificationModel private[ml](
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
if (batchCnt == 0) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
Rabit.init(rabitEnv.asJava)
}

View File

@@ -269,7 +269,9 @@ class XGBoostRegressionModel private[ml] (
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
if (batchCnt == 0) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
Rabit.init(rabitEnv.asJava)
}

View File

@@ -28,9 +28,8 @@ private[spark] trait LearningTaskParams extends Params {
* count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma.
* default: reg:squarederror
*/
final val objective = new Param[String](this, "objective", "objective function used for " +
s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}",
(value: String) => LearningTaskParams.supportedObjective.contains(value))
final val objective = new Param[String](this, "objective",
"objective function used for training")
final def getObjective: String = $(objective)
@@ -62,9 +61,7 @@ private[spark] trait LearningTaskParams extends Params {
*/
final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
"validation data, a default metric will be assigned according to objective " +
"(rmse for regression, and error for classification, mean average precision for ranking), " +
s"options: {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
(value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
"(rmse for regression, and error for classification, mean average precision for ranking)")
final def getEvalMetric: String = $(evalMetric)
@@ -106,9 +103,6 @@ private[spark] trait LearningTaskParams extends Params {
}
private[spark] object LearningTaskParams {
val supportedObjective = HashSet("reg:linear", "reg:squarederror", "reg:logistic",
"reg:squaredlogerror", "binary:logistic", "binary:logitraw", "count:poisson", "multi:softmax",
"multi:softprob", "rank:pairwise", "rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")
val supportedObjectiveType = HashSet("regression", "classification")
@@ -116,6 +110,4 @@ private[spark] object LearningTaskParams {
val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "logloss", "error", "merror",
"mlogloss", "gamma-deviance")
val supportedEvalMetrics = evalMetricsToMaximize union evalMetricsToMinimize
}

View File

@@ -17,6 +17,7 @@
package org.apache.spark
import java.net.URL
import java.util.concurrent.atomic.AtomicBoolean
import org.apache.commons.logging.LogFactory
@@ -123,18 +124,30 @@ private[spark] class TaskFailedListener extends SparkListener {
case taskEndReason: TaskFailedReason =>
logger.error(s"Training Task Failed during XGBoost Training: " +
s"$taskEndReason, stopping SparkContext")
// Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
// in a separate thread
val sparkContextKiller = new Thread() {
override def run(): Unit = {
LiveListenerBus.withinListenerThread.withValue(false) {
SparkContext.getOrCreate().stop()
}
}
}
sparkContextKiller.setDaemon(true)
sparkContextKiller.start()
TaskFailedListener.startedSparkContextKiller()
case _ =>
}
}
}
// Companion object holding process-wide state so that the SparkContext-killing
// thread is started at most once, even if several failed tasks trigger the
// listener concurrently.
object TaskFailedListener {
// Guarded by this object's monitor inside startedSparkContextKiller.
// NOTE(review): reads of this flag outside the synchronized block would be
// unsynchronized — confirm all accesses go through startedSparkContextKiller.
var killerStarted = false
// Starts (once) a daemon thread that stops the current SparkContext.
// Synchronized on the object so concurrent listener callbacks cannot spawn
// more than one killer thread.
private def startedSparkContextKiller(): Unit = this.synchronized {
if (!killerStarted) {
// Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
// in a separate thread
val sparkContextKiller = new Thread() {
override def run(): Unit = {
// withinListenerThread must be false here, or stop() would refuse to run
// from what it believes is the listener bus thread.
LiveListenerBus.withinListenerThread.withValue(false) {
SparkContext.getOrCreate().stop()
}
}
}
// Daemon so a pending stop() does not keep the JVM alive on exit.
sparkContextKiller.setDaemon(true)
sparkContextKiller.start()
killerStarted = true
}
}
}