[jvm-packages] remove all objective and eval-metric constraints from the Spark layer and handle errors in the JVM layer (#4560)

* temp

* prediction part

* remove supported*

* add for test

* fix param name

* add rabit

* update rabit

* return value of rabit init

* eliminate compilation warnings

* update rabit

* shutdown

* update rabit again

* check sparkcontext shutdown

* fix logic

* sleep

* fix tests

* test with relaxed threshold

* create new thread each time

* stop for job quitting

* update rabit

* update rabit

* update rabit

* update git modules
Nan Zhu 2019-06-27 08:47:37 -07:00 committed by GitHub
parent dd01f7c4f5
commit abffbe014e
11 changed files with 143 additions and 51 deletions

View File

@@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.rdd.RDD
-import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
+import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.storage.StorageLevel
@@ -153,9 +153,11 @@ object XGBoost extends Serializable {
       }
       val taskId = TaskContext.getPartitionId().toString
       rabitEnv.put("DMLC_TASK_ID", taskId)
-      Rabit.init(rabitEnv)
+      rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
       try {
+        Rabit.init(rabitEnv)
         val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
           .map(_.toString.toInt).getOrElse(0)
         val overridedParams = if (numEarlyStoppingRounds > 0 &&
@@ -176,6 +178,10 @@ object XGBoost extends Serializable {
           watches.toMap, metrics, obj, eval,
           earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
         Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
+      } catch {
+        case xgbException: XGBoostError =>
+          logger.error(s"XGBooster worker $taskId has failed due to ", xgbException)
+          throw xgbException
       } finally {
         Rabit.shutdown()
         watches.delete()
@@ -467,6 +473,12 @@ object XGBoost extends Serializable {
           tracker.stop()
         }
       }.last
+    } catch {
+      case t: Throwable =>
+        // if the job was aborted due to an exception
+        logger.error("the job was aborted due to ", t)
+        trainingData.sparkContext.stop()
+        throw t
     } finally {
       uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
         transformedTrainingData)
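
Taken together, the three hunks above follow one pattern: Rabit.init moves inside the try so that Rabit.shutdown() in the finally always runs, a worker-side XGBoostError is logged and rethrown to fail the Spark task, and the driver stops the SparkContext when the job aborts rather than leaving executors hanging. A condensed sketch of the worker-side lifecycle; rabitEnv and trainBooster are illustrative stand-ins, not the real xgboost4j-spark code path:

import ml.dmlc.xgboost4j.java.{Rabit, XGBoostError}

object RabitWorkerSketch {
  // Sketch only: `trainBooster` stands in for the actual training call.
  def runWorker(rabitEnv: java.util.Map[String, String])
               (trainBooster: () => Unit): Unit = {
    // report native errors back to the JVM instead of killing the process
    rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
    try {
      // init inside the try: even a failed init still reaches the finally
      Rabit.init(rabitEnv)
      trainBooster()
    } catch {
      case e: XGBoostError =>
        // rethrow so Spark marks the task (and eventually the job) as failed
        throw e
    } finally {
      Rabit.shutdown()
    }
  }
}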

View File

@@ -292,7 +292,9 @@ class XGBoostClassificationModel private[ml](
     private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
       if (batchCnt == 0) {
-        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+        val rabitEnv = Array(
+          "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
+          "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
         Rabit.init(rabitEnv.asJava)
       }

View File

@@ -269,7 +269,9 @@ class XGBoostRegressionModel private[ml] (
     private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
       if (batchCnt == 0) {
-        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+        val rabitEnv = Array(
+          "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
+          "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
         Rabit.init(rabitEnv.asJava)
       }
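
The classifier and regressor prediction paths make the matching change: the per-partition Rabit environment now also carries DMLC_WORKER_STOP_PROCESS_ON_ERROR=false, so an inference-time native failure surfaces as a JVM exception rather than terminating the executor process. A minimal sketch of that map construction as a standalone helper; the name predictionRabitEnv is hypothetical, since the models build the map inline as shown above:

import org.apache.spark.TaskContext
import scala.collection.JavaConverters._

object PredictionRabitEnvSketch {
  // Builds the env passed to Rabit.init at the start of a prediction partition.
  def predictionRabitEnv(): java.util.Map[String, String] =
    Map(
      "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
      "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false"
    ).asJava
}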

View File

@@ -28,9 +28,8 @@ private[spark] trait LearningTaskParams extends Params {
    * count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma.
    * default: reg:squarederror
    */
-  final val objective = new Param[String](this, "objective", "objective function used for " +
-    s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}",
-    (value: String) => LearningTaskParams.supportedObjective.contains(value))
+  final val objective = new Param[String](this, "objective",
+    "objective function used for training")

   final def getObjective: String = $(objective)
@@ -62,9 +61,7 @@ private[spark] trait LearningTaskParams extends Params {
    */
   final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
     "validation data, a default metric will be assigned according to objective " +
-    "(rmse for regression, and error for classification, mean average precision for ranking), " +
-    s"options: {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
-    (value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
+    "(rmse for regression, and error for classification, mean average precision for ranking)")

   final def getEvalMetric: String = $(evalMetric)
@@ -106,9 +103,6 @@ private[spark] trait LearningTaskParams extends Params {
 }

 private[spark] object LearningTaskParams {
-  val supportedObjective = HashSet("reg:linear", "reg:squarederror", "reg:logistic",
-    "reg:squaredlogerror", "binary:logistic", "binary:logitraw", "count:poisson", "multi:softmax",
-    "multi:softprob", "rank:pairwise", "rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")

   val supportedObjectiveType = HashSet("regression", "classification")
@@ -116,6 +110,4 @@ private[spark] object LearningTaskParams {
   val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "logloss", "error", "merror",
     "mlogloss", "gamma-deviance")
-
-  val supportedEvalMetrics = evalMetricsToMaximize union evalMetricsToMinimize
 }
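
The effect of these removals is that objective and evalMetric no longer carry an isValid whitelist, so an unrecognized value passes through to native XGBoost, which rejects it with an XGBoostError that the new error-handling paths surface. For contrast, a standalone sketch of the two styles using Spark's Param API; the MyParams trait and its whitelist are illustrative, not the real LearningTaskParams:

import org.apache.spark.ml.param.{Param, Params}

trait MyParams extends Params {
  // Old style: validation baked into the Spark layer. Objectives added in
  // newer native releases would be rejected here before training starts.
  final val validatedObjective = new Param[String](this, "validatedObjective",
    "objective function used for training",
    (value: String) => Set("reg:squarederror", "binary:logistic").contains(value))

  // New style: no validator; any string is accepted and the native
  // layer decides whether it names a real objective.
  final val objective = new Param[String](this, "objective",
    "objective function used for training")
}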

View File

@@ -17,6 +17,7 @@
 package org.apache.spark

 import java.net.URL
+import java.util.concurrent.atomic.AtomicBoolean

 import org.apache.commons.logging.LogFactory
@@ -123,6 +124,18 @@ private[spark] class TaskFailedListener extends SparkListener {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
           s"$taskEndReason, stopping SparkContext")
+        TaskFailedListener.startedSparkContextKiller()
+      case _ =>
+    }
+  }
+}
+
+object TaskFailedListener {
+
+  var killerStarted = false
+
+  private def startedSparkContextKiller(): Unit = this.synchronized {
+    if (!killerStarted) {
       // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
       // in a separate thread
       val sparkContextKiller = new Thread() {
@@ -134,7 +147,7 @@ private[spark] class TaskFailedListener extends SparkListener {
         }
       sparkContextKiller.setDaemon(true)
       sparkContextKiller.start()
-      case _ =>
+      killerStarted = true
     }
   }
 }
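
Two details are worth noting in this listener change. First, listener callbacks run on Spark's listener thread, which must not stop the SparkContext itself, hence the one-shot daemon "killer" thread. Second, the guard is a plain var killerStarted protected by this.synchronized (despite the AtomicBoolean import), and the test harness resets it between suites. A condensed, self-contained sketch of the guard; SparkContext.getOrCreate().stop() and the grace period are assumptions for illustration, not necessarily what the listener does:

import org.apache.spark.SparkContext

object ContextKillerSketch {
  // one-shot guard: at most one killer thread per JVM
  private var killerStarted = false

  def startKillerOnce(): Unit = this.synchronized {
    if (!killerStarted) {
      // delegate the stop to a short-lived daemon thread, since the
      // listener thread itself is not allowed to call SparkContext.stop()
      val killer = new Thread() {
        override def run(): Unit = {
          Thread.sleep(1000) // grace period so the failure gets logged first
          SparkContext.getOrCreate().stop()
        }
      }
      killer.setDaemon(true)
      killer.start()
      killerStarted = true
    }
  }
}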

View File

@@ -0,0 +1,81 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.java.XGBoostError
import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.spark.ml.param.ParamMap

class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {

  test("XGBoost and Spark parameters synchronize correctly") {
    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
      "objective_type" -> "classification")
    // from xgboost params to spark params
    val xgb = new XGBoostClassifier(xgbParamMap)
    assert(xgb.getEta === 1.0)
    assert(xgb.getObjective === "binary:logistic")
    assert(xgb.getObjectiveType === "classification")
    // from spark to xgboost params
    val xgbCopy = xgb.copy(ParamMap.empty)
    assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
    assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
    assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
    val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
    assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
  }

  private def waitForSparkContextShutdown(): Unit = {
    var totalWaitedTime = 0L
    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
      Thread.sleep(10000)
      totalWaitedTime += 10000
    }
    assert(ss.sparkContext.isStopped === true)
  }

  test("fail training elegantly with unsupported objective function") {
    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "wrong_objective_function", "num_class" -> "6", "num_round" -> 5,
      "num_workers" -> numWorkers)
    val trainingDF = buildDataFrame(MultiClassification.train)
    val xgb = new XGBoostClassifier(paramMap)
    try {
      val model = xgb.fit(trainingDF)
    } catch {
      case e: Throwable => // swallow anything
    } finally {
      waitForSparkContextShutdown()
    }
  }

  test("fail training elegantly with unsupported eval metrics") {
    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
      "num_workers" -> numWorkers, "eval_metric" -> "wrong_eval_metrics")
    val trainingDF = buildDataFrame(MultiClassification.train)
    val xgb = new XGBoostClassifier(paramMap)
    try {
      val model = xgb.fit(trainingDF)
    } catch {
      case e: Throwable => // swallow anything
    } finally {
      waitForSparkContextShutdown()
    }
  }
}

View File

@@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File

 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{SparkConf, SparkContext, TaskFailedListener}
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}

 import scala.util.Random

 trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
@@ -50,6 +50,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       cleanExternalCache(currentSession.sparkContext.appName)
       currentSession = null
     }
+    TaskFailedListener.killerStarted = false
   }
 }

View File

@@ -29,7 +29,7 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.scalatest.FunSuite

-class RabitSuite extends FunSuite with PerTest {
+class RabitRobustnessSuite extends FunSuite with PerTest {

   test("training with Scala-implemented Rabit tracker") {
     val eval = new EvalError()

View File

@@ -160,23 +160,6 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     assert(model.summary.validationObjectiveHistory.isEmpty)
   }

-  test("XGBoost and Spark parameters synchronize correctly") {
-    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
-      "objective_type" -> "classification")
-    // from xgboost params to spark params
-    val xgb = new XGBoostClassifier(xgbParamMap)
-    assert(xgb.getEta === 1.0)
-    assert(xgb.getObjective === "binary:logistic")
-    assert(xgb.getObjectiveType === "classification")
-    // from spark to xgboost params
-    val xgbCopy = xgb.copy(ParamMap.empty)
-    assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
-    assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
-    assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
-    val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
-    assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
-  }
-
   test("multi class classification") {
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,

View File

@@ -818,8 +818,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
     argv.push_back(&args[i][0]);
   }

-  RabitInit(args.size(), dmlc::BeginPtr(argv));
-  return 0;
+  if (RabitInit(args.size(), dmlc::BeginPtr(argv))) {
+    return 0;
+  } else {
+    return 1;
+  }
 }

 /*
@@ -829,8 +832,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
  */
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
   (JNIEnv *jenv, jclass jcls) {
-  RabitFinalize();
-  return 0;
+  if (RabitFinalize()) {
+    return 0;
+  } else {
+    return 1;
+  }
 }

 /*
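
With RabitInit and RabitFinalize results now propagated as 0/1 status codes, the JVM wrapper can translate a nonzero code into an XGBoostError instead of assuming success. A sketch of that convention; the checkCall helper name follows common xgboost4j style, but its exact form here is an assumption:

import ml.dmlc.xgboost4j.java.XGBoostError

object JniStatusSketch {
  // Translate a C-style status code from a native call into a JVM exception.
  @throws(classOf[XGBoostError])
  def checkCall(ret: Int): Unit =
    if (ret != 0) {
      throw new XGBoostError(s"native Rabit call failed with status $ret")
    }
}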

rabit

@@ -1 +1 @@
-Subproject commit a429748e244f67f6f144a697f3aa1b1978705b11
+Subproject commit 65b718a5e786bd7d0a850f3fa1df0dbdab023eb1