[jvm-packages] delete all constraints from spark layer about obj and eval metrics and handle error in jvm layer (#4560)
* temp
* prediction part
* remove supported*
* add for test
* fix param name
* add rabit
* update rabit
* return value of rabit init
* eliminate compilation warnings
* update rabit
* shutdown
* update rabit again
* check sparkcontext shutdown
* fix logic
* sleep
* fix tests
* test with relaxed threshold
* create new thread each time
* stop for job quitting
* update rabit
* update rabit
* update rabit
* update git modules
This commit is contained in:
parent dd01f7c4f5
commit abffbe014e
@@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
+import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.storage.StorageLevel
 
@@ -153,9 +153,11 @@ object XGBoost extends Serializable {
       }
       val taskId = TaskContext.getPartitionId().toString
       rabitEnv.put("DMLC_TASK_ID", taskId)
-      Rabit.init(rabitEnv)
+      rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
 
+      try {
+        Rabit.init(rabitEnv)
 
       val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
         .map(_.toString.toInt).getOrElse(0)
       val overridedParams = if (numEarlyStoppingRounds > 0 &&
@@ -176,6 +178,10 @@ object XGBoost extends Serializable {
         watches.toMap, metrics, obj, eval,
         earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
       Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
+      } catch {
+        case xgbException: XGBoostError =>
+          logger.error(s"XGBooster worker $taskId has failed due to ", xgbException)
+          throw xgbException
       } finally {
         Rabit.shutdown()
         watches.delete()
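
A minimal sketch of the worker-side shape these two hunks converge on, assuming a hand-built env map in place of the rabitEnv supplied by the tracker, with the actual booster training elided:

import java.util.{HashMap => JHashMap}
import ml.dmlc.xgboost4j.java.{Rabit, XGBoostError}

val rabitEnv = new JHashMap[String, String]()
rabitEnv.put("DMLC_TASK_ID", "0") // normally the Spark partition id
// "false" asks the native worker to report failures back to the JVM
// instead of terminating the executor process outright
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")

try {
  Rabit.init(rabitEnv)
  // ... distributed booster training runs here ...
} catch {
  case e: XGBoostError =>
    // surfaces to Spark as an ordinary task failure
    throw e
} finally {
  Rabit.shutdown() // always leave the Rabit ring, even on error
}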
@@ -467,6 +473,12 @@ object XGBoost extends Serializable {
         tracker.stop()
       }
     }.last
+    } catch {
+      case t: Throwable =>
+        // if the job was aborted due to an exception
+        logger.error("the job was aborted due to ", t)
+        trainingData.sparkContext.stop()
+        throw t
     } finally {
       uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
         transformedTrainingData)
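
The driver-side counterpart: if anything escapes the training job, the SparkContext is stopped before rethrowing, so the application fails fast instead of hanging on a half-torn-down Rabit ring. A sketch of that policy, with runDistributedTraining as a placeholder for the actual mapPartitions job:

import org.apache.spark.sql.SparkSession

// Stop the context on any failure, then rethrow: the whole application dies
// fast rather than limping on with a broken distributed training ring.
def trainOrStop[T](spark: SparkSession)(runDistributedTraining: => T): T = {
  try {
    runDistributedTraining
  } catch {
    case t: Throwable =>
      spark.sparkContext.stop()
      throw t
  }
}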
@@ -292,7 +292,9 @@ class XGBoostClassificationModel private[ml](
 
     private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
       if (batchCnt == 0) {
-        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+        val rabitEnv = Array(
+          "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
+          "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
         Rabit.init(rabitEnv.asJava)
       }
 
@@ -269,7 +269,9 @@ class XGBoostRegressionModel private[ml] (
 
     private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
       if (batchCnt == 0) {
-        val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+        val rabitEnv = Array(
+          "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
+          "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
         Rabit.init(rabitEnv.asJava)
       }
 
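
Both prediction hunks apply the same environment to the one-off Rabit.init guarded by batchCnt == 0. Condensed into a sketch (TaskContext.getPartitionId() is only meaningful inside a running Spark task; the helper name is illustrative):

import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.Rabit
import org.apache.spark.TaskContext

// Run once per partition, matching the batchCnt == 0 guard above, with the
// same non-fatal error policy that training uses.
def initRabitForPrediction(): Unit = {
  val rabitEnv = Map(
    "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
    "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false")
  Rabit.init(rabitEnv.asJava)
}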
@@ -28,9 +28,8 @@ private[spark] trait LearningTaskParams extends Params {
    * count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma.
    * default: reg:squarederror
    */
-  final val objective = new Param[String](this, "objective", "objective function used for " +
-    s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}",
-    (value: String) => LearningTaskParams.supportedObjective.contains(value))
+  final val objective = new Param[String](this, "objective",
+    "objective function used for training")
 
   final def getObjective: String = $(objective)
 
@@ -62,9 +61,7 @@ private[spark] trait LearningTaskParams extends Params {
    */
   final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
     "validation data, a default metric will be assigned according to objective " +
-    "(rmse for regression, and error for classification, mean average precision for ranking), " +
-    s"options: {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
-    (value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
+    "(rmse for regression, and error for classification, mean average precision for ranking)")
 
   final def getEvalMetric: String = $(evalMetric)
 
@@ -106,9 +103,6 @@ private[spark] trait LearningTaskParams extends Params {
   }
 
 private[spark] object LearningTaskParams {
-  val supportedObjective = HashSet("reg:linear", "reg:squarederror", "reg:logistic",
-    "reg:squaredlogerror", "binary:logistic", "binary:logitraw", "count:poisson", "multi:softmax",
-    "multi:softprob", "rank:pairwise", "rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")
 
   val supportedObjectiveType = HashSet("regression", "classification")
 
@@ -116,6 +110,4 @@ private[spark] object LearningTaskParams {
 
   val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "logloss", "error", "merror",
     "mlogloss", "gamma-deviance")
-
-  val supportedEvalMetrics = evalMetricsToMaximize union evalMetricsToMinimize
 }
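
The net effect on the params layer: objective and evalMetric become plain, unvalidated Param[String]s, so objectives or metrics added to native XGBoost need no Spark-side allow-list update, and a bad value now fails at training time in the JVM layer instead of at set-time. Roughly, as a sketch (trait name is illustrative):

import org.apache.spark.ml.param.{Param, Params}

// No allow-list and no validator: any string is accepted here and only
// checked when it reaches the native library during training.
trait RelaxedLearningParams extends Params {
  final val objective = new Param[String](this, "objective",
    "objective function used for training")

  final def getObjective: String = $(objective)
}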
@@ -17,6 +17,7 @@
 package org.apache.spark
 
 import java.net.URL
+import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.commons.logging.LogFactory
 
@@ -123,6 +124,18 @@ private[spark] class TaskFailedListener extends SparkListener {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
           s"$taskEndReason, stopping SparkContext")
+        TaskFailedListener.startedSparkContextKiller()
+      case _ =>
+    }
+  }
+}
+
+object TaskFailedListener {
+
+  var killerStarted = false
+
+  private def startedSparkContextKiller(): Unit = this.synchronized {
+    if (!killerStarted) {
       // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
       // in a separate thread
       val sparkContextKiller = new Thread() {
@@ -134,7 +147,7 @@ private[spark] class TaskFailedListener extends SparkListener {
       }
       sparkContextKiller.setDaemon(true)
       sparkContextKiller.start()
-      case _ =>
+      killerStarted = true
     }
   }
 }
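
For reference, the listener pattern these two hunks converge on, rewritten as a self-contained sketch outside the org.apache.spark package (FailFastListener is an illustrative name, not the class in this patch):

import org.apache.spark.{SparkContext, TaskFailedReason}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

class FailFastListener(sc: SparkContext) extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit =
    taskEnd.reason match {
      case _: TaskFailedReason => FailFastListener.startKiller(sc)
      case _ =>
    }
}

object FailFastListener {
  private var killerStarted = false

  // Spark disallows stopping the SparkContext from the listener thread, so a
  // daemon thread does it; the flag ensures only one killer is ever spawned.
  def startKiller(sc: SparkContext): Unit = this.synchronized {
    if (!killerStarted) {
      val killer = new Thread() {
        override def run(): Unit = sc.stop()
      }
      killer.setDaemon(true)
      killer.start()
      killerStarted = true
    }
  }
}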
@@ -0,0 +1,81 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.java.XGBoostError
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.ml.param.ParamMap
+
+class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
+
+  test("XGBoost and Spark parameters synchronize correctly") {
+    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
+      "objective_type" -> "classification")
+    // from xgboost params to spark params
+    val xgb = new XGBoostClassifier(xgbParamMap)
+    assert(xgb.getEta === 1.0)
+    assert(xgb.getObjective === "binary:logistic")
+    assert(xgb.getObjectiveType === "classification")
+    // from spark to xgboost params
+    val xgbCopy = xgb.copy(ParamMap.empty)
+    assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
+    assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
+    assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
+    val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
+    assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
+  }
+
+  private def waitForSparkContextShutdown(): Unit = {
+    var totalWaitedTime = 0L
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
+      Thread.sleep(10000)
+      totalWaitedTime += 10000
+    }
+    assert(ss.sparkContext.isStopped === true)
+  }
+
+  test("fail training elegantly with unsupported objective function") {
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "wrong_objective_function", "num_class" -> "6", "num_round" -> 5,
+      "num_workers" -> numWorkers)
+    val trainingDF = buildDataFrame(MultiClassification.train)
+    val xgb = new XGBoostClassifier(paramMap)
+    try {
+      val model = xgb.fit(trainingDF)
+    } catch {
+      case e: Throwable => // swallow anything
+    } finally {
+      waitForSparkContextShutdown()
+    }
+  }
+
+  test("fail training elegantly with unsupported eval metrics") {
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
+      "num_workers" -> numWorkers, "eval_metric" -> "wrong_eval_metrics")
+    val trainingDF = buildDataFrame(MultiClassification.train)
+    val xgb = new XGBoostClassifier(paramMap)
+    try {
+      val model = xgb.fit(trainingDF)
+    } catch {
+      case e: Throwable => // swallow anything
+    } finally {
+      waitForSparkContextShutdown()
+    }
+  }
+}
@@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{SparkConf, SparkContext, TaskFailedListener}
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
 
 import scala.util.Random
 
 trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
@@ -50,6 +50,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       cleanExternalCache(currentSession.sparkContext.appName)
       currentSession = null
     }
+    TaskFailedListener.killerStarted = false
   }
 }
 
@@ -29,7 +29,7 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.scalatest.FunSuite
 
 
-class RabitSuite extends FunSuite with PerTest {
+class RabitRobustnessSuite extends FunSuite with PerTest {
 
   test("training with Scala-implemented Rabit tracker") {
     val eval = new EvalError()
@@ -160,23 +160,6 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     assert(model.summary.validationObjectiveHistory.isEmpty)
   }
 
-  test("XGBoost and Spark parameters synchronize correctly") {
-    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
-      "objective_type" -> "classification")
-    // from xgboost params to spark params
-    val xgb = new XGBoostClassifier(xgbParamMap)
-    assert(xgb.getEta === 1.0)
-    assert(xgb.getObjective === "binary:logistic")
-    assert(xgb.getObjectiveType === "classification")
-    // from spark to xgboost params
-    val xgbCopy = xgb.copy(ParamMap.empty)
-    assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
-    assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
-    assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
-    val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
-    assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
-  }
-
   test("multi class classification") {
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
@@ -818,8 +818,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
     argv.push_back(&args[i][0]);
   }
 
-  RabitInit(args.size(), dmlc::BeginPtr(argv));
-  return 0;
+  if (RabitInit(args.size(), dmlc::BeginPtr(argv))) {
+    return 0;
+  } else {
+    return 1;
+  }
 }
 
 /*
@@ -829,8 +832,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
 */
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
   (JNIEnv *jenv, jclass jcls) {
-  RabitFinalize();
-  return 0;
+  if (RabitFinalize()) {
+    return 0;
+  } else {
+    return 1;
+  }
 }
 
 /*
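
With RabitInit and RabitFinalize now reporting success or failure, the JVM wrapper can turn the returned jint into an exception rather than ignoring it. A hedged sketch of that convention — the checkCall helper is an assumption about how the wrapper consumes these codes, and nativeRabitFinalize is a placeholder for the package-private JNI binding:

import ml.dmlc.xgboost4j.java.XGBoostError

// Convert a nonzero native status code into an exception.
def checkCall(ret: Int): Unit =
  if (ret != 0) throw new XGBoostError(s"native Rabit call failed with status $ret")

// Stand-in for the real JNI binding; returns 0 on success.
def nativeRabitFinalize(): Int = 0

checkCall(nativeRabitFinalize())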
rabit
@@ -1 +1 @@
-Subproject commit a429748e244f67f6f144a697f3aa1b1978705b11
+Subproject commit 65b718a5e786bd7d0a850f3fa1df0dbdab023eb1