[jvm-packages] delete all constraints from spark layer about obj and eval metrics and handle error in jvm layer (#4560)
* temp * prediction part * remove supported* * add for test * fix param name * add rabit * update rabit * return value of rabit init * eliminate compilation warnings * update rabit * shutdown * update rabit again * check sparkcontext shutdown * fix logic * sleep * fix tests * test with relaxed threshold * create new thread each time * stop for job quitting * udpate rabit * update rabit * update rabit * update git modules
This commit is contained in:
@@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
|
||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
@@ -153,9 +153,11 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
val taskId = TaskContext.getPartitionId().toString
|
||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||
Rabit.init(rabitEnv)
|
||||
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
|
||||
|
||||
try {
|
||||
Rabit.init(rabitEnv)
|
||||
|
||||
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
|
||||
.map(_.toString.toInt).getOrElse(0)
|
||||
val overridedParams = if (numEarlyStoppingRounds > 0 &&
|
||||
@@ -176,6 +178,10 @@ object XGBoost extends Serializable {
|
||||
watches.toMap, metrics, obj, eval,
|
||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
||||
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
|
||||
} catch {
|
||||
case xgbException: XGBoostError =>
|
||||
logger.error(s"XGBooster worker $taskId has failed due to ", xgbException)
|
||||
throw xgbException
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
watches.delete()
|
||||
@@ -467,6 +473,12 @@ object XGBoost extends Serializable {
|
||||
tracker.stop()
|
||||
}
|
||||
}.last
|
||||
} catch {
|
||||
case t: Throwable =>
|
||||
// if the job was aborted due to an exception
|
||||
logger.error("the job was aborted due to ", t)
|
||||
trainingData.sparkContext.stop()
|
||||
throw t
|
||||
} finally {
|
||||
uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
|
||||
transformedTrainingData)
|
||||
|
||||
@@ -292,7 +292,9 @@ class XGBoostClassificationModel private[ml](
|
||||
|
||||
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
|
||||
if (batchCnt == 0) {
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
val rabitEnv = Array(
|
||||
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
|
||||
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
}
|
||||
|
||||
|
||||
@@ -269,7 +269,9 @@ class XGBoostRegressionModel private[ml] (
|
||||
|
||||
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
|
||||
if (batchCnt == 0) {
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
val rabitEnv = Array(
|
||||
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
|
||||
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
}
|
||||
|
||||
|
||||
@@ -28,9 +28,8 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
* count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma.
|
||||
* default: reg:squarederror
|
||||
*/
|
||||
final val objective = new Param[String](this, "objective", "objective function used for " +
|
||||
s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}",
|
||||
(value: String) => LearningTaskParams.supportedObjective.contains(value))
|
||||
final val objective = new Param[String](this, "objective",
|
||||
"objective function used for training")
|
||||
|
||||
final def getObjective: String = $(objective)
|
||||
|
||||
@@ -62,9 +61,7 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
*/
|
||||
final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
|
||||
"validation data, a default metric will be assigned according to objective " +
|
||||
"(rmse for regression, and error for classification, mean average precision for ranking), " +
|
||||
s"options: {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
|
||||
(value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
|
||||
"(rmse for regression, and error for classification, mean average precision for ranking)")
|
||||
|
||||
final def getEvalMetric: String = $(evalMetric)
|
||||
|
||||
@@ -106,9 +103,6 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
}
|
||||
|
||||
private[spark] object LearningTaskParams {
|
||||
val supportedObjective = HashSet("reg:linear", "reg:squarederror", "reg:logistic",
|
||||
"reg:squaredlogerror", "binary:logistic", "binary:logitraw", "count:poisson", "multi:softmax",
|
||||
"multi:softprob", "rank:pairwise", "rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")
|
||||
|
||||
val supportedObjectiveType = HashSet("regression", "classification")
|
||||
|
||||
@@ -116,6 +110,4 @@ private[spark] object LearningTaskParams {
|
||||
|
||||
val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "logloss", "error", "merror",
|
||||
"mlogloss", "gamma-deviance")
|
||||
|
||||
val supportedEvalMetrics = evalMetricsToMaximize union evalMetricsToMinimize
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package org.apache.spark
|
||||
|
||||
import java.net.URL
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
@@ -123,18 +124,30 @@ private[spark] class TaskFailedListener extends SparkListener {
|
||||
case taskEndReason: TaskFailedReason =>
|
||||
logger.error(s"Training Task Failed during XGBoost Training: " +
|
||||
s"$taskEndReason, stopping SparkContext")
|
||||
// Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
|
||||
// in a separate thread
|
||||
val sparkContextKiller = new Thread() {
|
||||
override def run(): Unit = {
|
||||
LiveListenerBus.withinListenerThread.withValue(false) {
|
||||
SparkContext.getOrCreate().stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
sparkContextKiller.setDaemon(true)
|
||||
sparkContextKiller.start()
|
||||
TaskFailedListener.startedSparkContextKiller()
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object TaskFailedListener {
|
||||
|
||||
var killerStarted = false
|
||||
|
||||
private def startedSparkContextKiller(): Unit = this.synchronized {
|
||||
if (!killerStarted) {
|
||||
// Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
|
||||
// in a separate thread
|
||||
val sparkContextKiller = new Thread() {
|
||||
override def run(): Unit = {
|
||||
LiveListenerBus.withinListenerThread.withValue(false) {
|
||||
SparkContext.getOrCreate().stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
sparkContextKiller.setDaemon(true)
|
||||
sparkContextKiller.start()
|
||||
killerStarted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user