[jvm-packages] delete all constraints from spark layer about obj and eval metrics and handle error in jvm layer (#4560)
* temp
* prediction part
* remove supported*
* add for test
* fix param name
* add rabit
* update rabit
* return value of rabit init
* eliminate compilation warnings
* update rabit
* shutdown
* update rabit again
* check sparkcontext shutdown
* fix logic
* sleep
* fix tests
* test with relaxed threshold
* create new thread each time
* stop for job quitting
* update rabit
* update rabit
* update rabit
* update git modules
parent dd01f7c4f5
commit abffbe014e
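The net effect of this commit is that the Spark layer no longer validates `objective` / `eval_metric` against hard-coded whitelists; whatever string the user supplies is forwarded to the native layer, and an invalid value surfaces as an error during training, after which Rabit is shut down and the task-failure listener stops the SparkContext. A minimal usage sketch of the new behavior, assuming a prepared `trainingDF` with `features`/`label` columns (those names are illustrative, not from this commit):

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // Any objective / eval_metric string is now passed straight through instead of
    // being checked against LearningTaskParams.supportedObjective /
    // supportedEvalMetrics (both removed in this commit).
    val params = Map(
      "eta" -> 0.1,
      "objective" -> "multi:softmax",
      "eval_metric" -> "mlogloss",
      "num_class" -> 6,
      "num_round" -> 5,
      "num_workers" -> 2)

    val xgb = new XGBoostClassifier(params)
      .setFeaturesCol("features")
      .setLabelCol("label")

    // A misspelled objective (e.g. "wrong_objective_function") is no longer rejected
    // up front by parameter validation; it fails here, inside the JVM/native layer.
    val model = xgb.fit(trainingDF)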
@@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
+import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.storage.StorageLevel
 
@@ -153,9 +153,11 @@ object XGBoost extends Serializable {
       }
       val taskId = TaskContext.getPartitionId().toString
       rabitEnv.put("DMLC_TASK_ID", taskId)
-      Rabit.init(rabitEnv)
+      rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
 
       try {
+        Rabit.init(rabitEnv)
+
         val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
           .map(_.toString.toInt).getOrElse(0)
         val overridedParams = if (numEarlyStoppingRounds > 0 &&
@@ -176,6 +178,10 @@ object XGBoost extends Serializable {
           watches.toMap, metrics, obj, eval,
           earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
         Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
+      } catch {
+        case xgbException: XGBoostError =>
+          logger.error(s"XGBooster worker $taskId has failed due to ", xgbException)
+          throw xgbException
       } finally {
         Rabit.shutdown()
         watches.delete()
@@ -467,6 +473,12 @@ object XGBoost extends Serializable {
           tracker.stop()
         }
       }.last
+    } catch {
+      case t: Throwable =>
+        // if the job was aborted due to an exception
+        logger.error("the job was aborted due to ", t)
+        trainingData.sparkContext.stop()
+        throw t
     } finally {
       uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
         transformedTrainingData)
@@ -292,7 +292,9 @@ class XGBoostClassificationModel private[ml](
 
       private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
         if (batchCnt == 0) {
-          val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+          val rabitEnv = Array(
+            "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
+            "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
           Rabit.init(rabitEnv.asJava)
         }
 
@@ -269,7 +269,9 @@ class XGBoostRegressionModel private[ml] (
 
       private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
         if (batchCnt == 0) {
-          val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+          val rabitEnv = Array(
+            "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
+            "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
           Rabit.init(rabitEnv.asJava)
         }
 
@@ -28,9 +28,8 @@ private[spark] trait LearningTaskParams extends Params {
    * count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma.
    * default: reg:squarederror
    */
-  final val objective = new Param[String](this, "objective", "objective function used for " +
-    s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}",
-    (value: String) => LearningTaskParams.supportedObjective.contains(value))
+  final val objective = new Param[String](this, "objective",
+    "objective function used for training")
 
   final def getObjective: String = $(objective)
 
@@ -62,9 +61,7 @@ private[spark] trait LearningTaskParams extends Params {
    */
   final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
     "validation data, a default metric will be assigned according to objective " +
-    "(rmse for regression, and error for classification, mean average precision for ranking), " +
-    s"options: {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
-    (value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
+    "(rmse for regression, and error for classification, mean average precision for ranking)")
 
   final def getEvalMetric: String = $(evalMetric)
 
@@ -106,9 +103,6 @@ private[spark] trait LearningTaskParams extends Params {
 }
 
 private[spark] object LearningTaskParams {
-  val supportedObjective = HashSet("reg:linear", "reg:squarederror", "reg:logistic",
-    "reg:squaredlogerror", "binary:logistic", "binary:logitraw", "count:poisson", "multi:softmax",
-    "multi:softprob", "rank:pairwise", "rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")
 
   val supportedObjectiveType = HashSet("regression", "classification")
 
@@ -116,6 +110,4 @@ private[spark] object LearningTaskParams {
 
   val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "logloss", "error", "merror",
     "mlogloss", "gamma-deviance")
-
-  val supportedEvalMetrics = evalMetricsToMaximize union evalMetricsToMinimize
 }
@@ -17,6 +17,7 @@
 package org.apache.spark
 
 import java.net.URL
+import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.commons.logging.LogFactory
 
@@ -123,18 +124,30 @@ private[spark] class TaskFailedListener extends SparkListener {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
           s"$taskEndReason, stopping SparkContext")
-        // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
-        // in a separate thread
-        val sparkContextKiller = new Thread() {
-          override def run(): Unit = {
-            LiveListenerBus.withinListenerThread.withValue(false) {
-              SparkContext.getOrCreate().stop()
-            }
-          }
-        }
-        sparkContextKiller.setDaemon(true)
-        sparkContextKiller.start()
+        TaskFailedListener.startedSparkContextKiller()
       case _ =>
     }
   }
 }
+
+object TaskFailedListener {
+
+  var killerStarted = false
+
+  private def startedSparkContextKiller(): Unit = this.synchronized {
+    if (!killerStarted) {
+      // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
+      // in a separate thread
+      val sparkContextKiller = new Thread() {
+        override def run(): Unit = {
+          LiveListenerBus.withinListenerThread.withValue(false) {
+            SparkContext.getOrCreate().stop()
+          }
+        }
+      }
+      sparkContextKiller.setDaemon(true)
+      sparkContextKiller.start()
+      killerStarted = true
+    }
+  }
+}
@@ -0,0 +1,81 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.java.XGBoostError
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.ml.param.ParamMap
+
+class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
+
+  test("XGBoost and Spark parameters synchronize correctly") {
+    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
+      "objective_type" -> "classification")
+    // from xgboost params to spark params
+    val xgb = new XGBoostClassifier(xgbParamMap)
+    assert(xgb.getEta === 1.0)
+    assert(xgb.getObjective === "binary:logistic")
+    assert(xgb.getObjectiveType === "classification")
+    // from spark to xgboost params
+    val xgbCopy = xgb.copy(ParamMap.empty)
+    assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
+    assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
+    assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
+    val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
+    assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
+  }
+
+  private def waitForSparkContextShutdown(): Unit = {
+    var totalWaitedTime = 0L
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
+      Thread.sleep(10000)
+      totalWaitedTime += 10000
+    }
+    assert(ss.sparkContext.isStopped === true)
+  }
+
+  test("fail training elegantly with unsupported objective function") {
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "wrong_objective_function", "num_class" -> "6", "num_round" -> 5,
+      "num_workers" -> numWorkers)
+    val trainingDF = buildDataFrame(MultiClassification.train)
+    val xgb = new XGBoostClassifier(paramMap)
+    try {
+      val model = xgb.fit(trainingDF)
+    } catch {
+      case e: Throwable => // swallow anything
+    } finally {
+      waitForSparkContextShutdown()
+    }
+  }
+
+  test("fail training elegantly with unsupported eval metrics") {
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
+      "num_workers" -> numWorkers, "eval_metric" -> "wrong_eval_metrics")
+    val trainingDF = buildDataFrame(MultiClassification.train)
+    val xgb = new XGBoostClassifier(paramMap)
+    try {
+      val model = xgb.fit(trainingDF)
+    } catch {
+      case e: Throwable => // swallow anything
+    } finally {
+      waitForSparkContextShutdown()
+    }
+  }
+}
@@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{SparkConf, SparkContext, TaskFailedListener}
 import org.apache.spark.sql._
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
 
 import scala.util.Random
 
 trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
@@ -50,6 +50,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
         cleanExternalCache(currentSession.sparkContext.appName)
         currentSession = null
       }
+      TaskFailedListener.killerStarted = false
     }
   }
 
@@ -29,7 +29,7 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.scalatest.FunSuite
 
 
-class RabitSuite extends FunSuite with PerTest {
+class RabitRobustnessSuite extends FunSuite with PerTest {
 
   test("training with Scala-implemented Rabit tracker") {
     val eval = new EvalError()
@@ -160,23 +160,6 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     assert(model.summary.validationObjectiveHistory.isEmpty)
   }
 
-  test("XGBoost and Spark parameters synchronize correctly") {
-    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
-      "objective_type" -> "classification")
-    // from xgboost params to spark params
-    val xgb = new XGBoostClassifier(xgbParamMap)
-    assert(xgb.getEta === 1.0)
-    assert(xgb.getObjective === "binary:logistic")
-    assert(xgb.getObjectiveType === "classification")
-    // from spark to xgboost params
-    val xgbCopy = xgb.copy(ParamMap.empty)
-    assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
-    assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
-    assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
-    val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
-    assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
-  }
-
   test("multi class classification") {
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
@@ -818,8 +818,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
     argv.push_back(&args[i][0]);
   }
 
-  RabitInit(args.size(), dmlc::BeginPtr(argv));
-  return 0;
+  if (RabitInit(args.size(), dmlc::BeginPtr(argv))) {
+    return 0;
+  } else {
+    return 1;
+  }
 }
 
 /*
@@ -829,8 +832,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
 */
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
   (JNIEnv *jenv, jclass jcls) {
-  RabitFinalize();
-  return 0;
+  if (RabitFinalize()) {
+    return 0;
+  } else {
+    return 1;
+  }
 }
 
 /*
rabit
@@ -1 +1 @@
-Subproject commit a429748e244f67f6f144a697f3aa1b1978705b11
+Subproject commit 65b718a5e786bd7d0a850f3fa1df0dbdab023eb1