Revert "[jvm-packages] update rabit, surface new changes to spark, add parity and failure tests (#4876)" (#4965)

This reverts commit 86ed01c4bbecef66e1bc4d02fb13116bd6130fae.
This commit is contained in:
Jiaming Yuan 2019-10-18 17:02:36 -04:00 committed by Nan Zhu
parent 86ed01c4bb
commit 010b8f1428
73 changed files with 115 additions and 343 deletions

View File

@ -94,16 +94,36 @@ set_target_properties(dmlc PROPERTIES
list(APPEND LINKED_LIBRARIES_PRIVATE dmlc)
# rabit
set(RABIT_BUILD_DMLC OFF)
set(DMLC_ROOT ${xgboost_SOURCE_DIR}/dmlc-core)
set(RABIT_WITH_R_LIB ${R_LIB})
add_subdirectory(rabit)
if (RABIT_MOCK)
list(APPEND LINKED_LIBRARIES_PRIVATE rabit_mock_static)
else()
list(APPEND LINKED_LIBRARIES_PRIVATE rabit)
endif(RABIT_MOCK)
# full rabit doesn't build on windows, so we can't import it as subdirectory
if(MINGW OR R_LIB OR WIN32)
set(RABIT_SOURCES
rabit/src/engine_empty.cc
rabit/src/c_api.cc)
else ()
if(RABIT_MOCK)
set(RABIT_SOURCES
rabit/src/allreduce_base.cc
rabit/src/allreduce_robust.cc
rabit/src/engine_mock.cc
rabit/src/c_api.cc)
else()
set(RABIT_SOURCES
rabit/src/allreduce_base.cc
rabit/src/allreduce_robust.cc
rabit/src/engine.cc
rabit/src/c_api.cc)
endif(RABIT_MOCK)
endif (MINGW OR R_LIB OR WIN32)
add_library(rabit STATIC ${RABIT_SOURCES})
target_include_directories(rabit PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}/dmlc-core/include>
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}/rabit/include/rabit>)
set_target_properties(rabit
PROPERTIES
CXX_STANDARD 11
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON)
list(APPEND LINKED_LIBRARIES_PRIVATE rabit)
# Exports some R specific definitions and objects
if (R_LIB)

View File

@ -1,3 +1,2 @@
tracker.py
build.sh
build.sh

View File

@ -18,7 +18,7 @@ CONFIG = {
"USE_HDFS": "OFF",
"USE_AZURE": "OFF",
"USE_S3": "OFF",
"RABIT_MOCK": "OFF",
"USE_CUDA": "OFF",
"JVM_BINDINGS": "ON"
}
@ -68,7 +68,6 @@ def normpath(path):
if __name__ == "__main__":
CONFIG["RABIT_MOCK"] = str(sys.argv[1])
if sys.platform == "darwin":
# Enable if your compiler supports OpenMP.
CONFIG["USE_OPENMP"] = "OFF"

View File

@ -37,7 +37,6 @@
<spark.version>2.4.3</spark.version>
<scala.version>2.12.8</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<rabit.mock>OFF</rabit.mock>
</properties>
<repositories>
<repository>
@ -53,47 +52,6 @@
<module>xgboost4j-flink</module>
</modules>
<profiles>
<profile>
<id>dev</id>
<properties>
<rabit.mock>ON</rabit.mock>
</properties>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.0.2</version>
<executions>
<execution>
<id>empty-javadoc-jar</id>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<classifier>javadoc</classifier>
<classesDirectory>${basedir}/javadoc</classesDirectory>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>2.2.1</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar-no-fork</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
<profile>
<id>release</id>
<build>

View File

@ -49,7 +49,7 @@ This file is divided into 3 sections:
<check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true">
<parameters>
<parameter name="header"><![CDATA[/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -21,11 +21,11 @@ import java.nio.file.Files
import scala.collection.{AbstractIterator, mutable}
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
import ml.dmlc.xgboost4j.scala.spark.XGBoost.logger
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
@ -155,7 +155,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
overridedParams
}
private[spark] def buildXGBRuntimeParams: XGBoostExecutionParams = {
def buildXGBRuntimeParams: XGBoostExecutionParams = {
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
val round = overridedParams("num_round").asInstanceOf[Int]
val useExternalMemory = overridedParams("use_external_memory").asInstanceOf[Boolean]
@ -221,25 +221,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
private[spark] def buildRabitParams: Map[String, String] = Map(
"rabit_reduce_ring_mincount" ->
overridedParams.getOrElse("rabit_reduce_ring_mincount", 32<<10).toString,
"rabit_reduce_buffer" ->
overridedParams.getOrElse("rabit_reduce_buffer", "256MB").toString,
"rabit_bootstrap_cache" ->
overridedParams.getOrElse("rabit_bootstrap_cache", false).toString,
"rabit_debug" ->
overridedParams.getOrElse("rabit_debug", false).toString,
"rabit_timeout" ->
overridedParams.getOrElse("rabit_timeout", false).toString,
"rabit_timeout_sec" ->
overridedParams.getOrElse("rabit_timeout_sec", 1800).toString,
"DMLC_WORKER_CONNECT_RETRY" ->
overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString,
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" ->
overridedParams.getOrElse("dmlc_worker_stop_process_on_error", false).toString
)
}
/**
@ -340,6 +321,7 @@ object XGBoost extends Serializable {
}
val taskId = TaskContext.getPartitionId().toString
rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
try {
Rabit.init(rabitEnv)
@ -508,9 +490,8 @@ object XGBoost extends Serializable {
evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
(Booster, Map[String, Array[Float]]) = {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
val xgbRabitParams = xgbParamsFactory.buildRabitParams
val xgbExecParams = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext).
buildXGBRuntimeParams
val sc = trainingData.sparkContext
val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
checkpointPath)
@ -529,12 +510,12 @@ object XGBoost extends Serializable {
val parallelismTracker = new SparkParallelismTracker(sc,
xgbExecParams.timeoutRequestWorkers,
xgbExecParams.numWorkers)
val rabitEnv = tracker.getWorkerEnvs.asScala ++ xgbRabitParams
val rabitEnv = tracker.getWorkerEnvs
val boostersAndMetrics = if (hasGroup) {
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv.asJava,
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv,
checkpointRound, prevBooster, evalSetsMap)
} else {
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv.asJava,
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
checkpointRound, prevBooster, evalSetsMap)
}
val sparkJobThread = new Thread() {

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -50,7 +50,7 @@ class XGBoostClassifier (
def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbc"), xgboostParams)
XGBoost2MLlibParams(xgboostParams)
XGBoostToMLlibParams(xgboostParams)
def setWeightCol(value: String): this.type = set(weightCol, value)

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.spark.params._
import org.apache.spark.ml.param.shared.HasWeightCol
private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
with BoosterParams with ParamMapFuncs with NonParamVariables {
def needDeterministicRepartitioning: Boolean = {
getCheckpointPath.nonEmpty && getCheckpointInterval > 0

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -54,7 +54,7 @@ class XGBoostRegressor (
def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbr"), xgboostParams)
XGBoost2MLlibParams(xgboostParams)
XGBoostToMLlibParams(xgboostParams)
def setWeightCol(value: String): this.type = set(weightCol, value)

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -241,7 +241,7 @@ trait HasNumClass extends Params {
private[spark] trait ParamMapFuncs extends Params {
def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
for ((paramName, paramValue) <- xgboostParams) {
if ((paramName == "booster" && paramValue != "gbtree") ||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,62 +0,0 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.param._
private[spark] trait RabitParams extends Params {
/**
* Rabit worker configurations. These parameters are passed to Rabit.Init and control:
* rabit_reduce_ring_mincount - threshold for enabling ring-based allreduce/broadcast operations.
* rabit_reduce_buffer - buffer size used to receive data and run the reduction
* rabit_bootstrap_cache - enable saving the allreduce cache before loadcheckpoint
* rabit_debug - enable more verbose rabit logging to stdout
* rabit_timeout - enable the sidecar thread after rabit observes failures
* rabit_timeout_sec - wait interval before exiting after rabit observes failures
* dmlc_worker_connect_retry - number of retries when connecting to the tracker
* dmlc_worker_stop_process_on_error - exit the process when rabit sees an assert/error
*/
final val ringReduceMin = new IntParam(this, "rabitReduceRingMincount",
"minimal counts of enable allreduce/broadcast with ring based topology",
ParamValidators.gtEq(1))
final def reduceBuffer: Param[String] = new Param[String](this, "rabitReduceBuffer",
"buffer size (MB/GB) allocated to each xgb trainner recv and run reduction",
(buf: String) => buf.contains("MB") || buf.contains("GB"))
final def bootstrapCache: BooleanParam = new BooleanParam(this, "rabitBootstrapCache",
"enable save allreduce cache before loadcheckpoint, used to allow failed task retry")
final def rabitDebug: BooleanParam = new BooleanParam(this, "rabitDebug",
"enable more verbose rabit logging to stdout")
final def rabitTimeout: BooleanParam = new BooleanParam(this, "rabitTimeout",
"enable failure timeout sidecar threads")
final def timeoutInterval: IntParam = new IntParam(this, "rabitTimeoutSec",
"timeout threshold after rabit observed failures", (interval: Int) => interval > 0)
final def connectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
"number of retry worker do before fail", ParamValidators.gtEq(1))
final def exitOnError: BooleanParam = new BooleanParam(this, "dmlcWorkerStopProcessOnError",
"exit process when rabit see assert error")
setDefault(ringReduceMin -> (32 << 10), reduceBuffer -> "256MB", bootstrapCache -> false,
rabitDebug -> false, connectRetry -> 5, rabitTimeout -> false, timeoutInterval -> 1800,
exitOnError -> false)
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,10 +16,7 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import scala.collection.JavaConverters._
import org.apache.spark.sql._
import org.scalatest.FunSuite
@ -31,7 +28,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
test("nthread configuration must be no larger than spark.task.cpus") {
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic", "num_workers" -> numWorkers,
"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
intercept[IllegalArgumentException] {
@ -43,7 +40,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
// TODO write an isolated test for Booster.
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator, null)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
@ -55,7 +52,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
ss.conf.set("spark.ssl.enabled", true)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
val training = buildDataFrame(Classification.train)

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,110 +0,0 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.{Rabit, XGBoostError}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import scala.collection.JavaConverters._
import org.apache.spark.sql._
import org.scalatest.FunSuite
class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.kryo.classesToRegister", classOf[Booster].getName)
test("test parity classification prediction") {
val training = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val model1 = new XGBoostClassifier(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
).fit(training)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val model2 = new XGBoostClassifier(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"rabit_bootstrap_cache" -> true, "rabit_debug" -> true, "rabit_reduce_ring_mincount" -> 100,
"rabit_reduce_buffer" -> "2MB", "DMLC_WORKER_CONNECT_RETRY" -> 1,
"rabit_timeout" -> true, "rabit_timeout_sec" -> 5)).fit(training)
assert(Rabit.rabitEnvs.asScala.size > 7)
Rabit.rabitEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_bootstrap_cache") assert(item._2 == "true")
if (item._1.toString == "rabit_debug") assert(item._2 == "true")
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "100")
if (item._1.toString == "rabit_reduce_buffer") assert(item._2 == "2MB")
if (item._1.toString == "dmlc_worker_connect_retry") assert(item._2 == "1")
if (item._1.toString == "rabit_timeout") assert(item._2 == "true")
if (item._1.toString == "rabit_timeout_sec") assert(item._2 == "5")
})
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(p1 == p2)
}
}
test("test parity regression prediction") {
val training = buildDataFrame(Regression.train)
val testDM = new DMatrix(Regression.test.iterator, null)
val testDF = buildDataFrame(Classification.test)
val model1 = new XGBoostRegressor(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
).fit(training)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val model2 = new XGBoostRegressor(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"rabit_bootstrap_cache" -> true, "rabit_debug" -> true, "rabit_reduce_ring_mincount" -> 100,
"rabit_reduce_buffer" -> "2MB", "DMLC_WORKER_CONNECT_RETRY" -> 1,
"rabit_timeout" -> true, "rabit_timeout_sec" -> 5)).fit(training)
assert(Rabit.rabitEnvs.asScala.size > 7)
Rabit.rabitEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_bootstrap_cache") assert(item._2 == "true")
if (item._1.toString == "rabit_debug") assert(item._2 == "true")
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "100")
if (item._1.toString == "rabit_reduce_buffer") assert(item._2 == "2MB")
if (item._1.toString == "dmlc_worker_connect_retry") assert(item._2 == "true")
if (item._1.toString == "rabit_timeout") assert(item._2 == "true")
if (item._1.toString == "rabit_timeout_sec") assert(item._2 == "5")
if (item._1.toString == "DMLC_WORKER_STOP_PROCESS_ON_ERROR") assert(item._2 == "false")
})
// check the equality of single instance prediction
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(math.abs(p1 - p2) < 0.00001f)
}
}
test("test graceful failure handle") {
val training = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
// mock rank 0 failure during 4th allreduce synchronization
Rabit.mockList = Array("0,4,0,0").toList.asJava
intercept[XGBoostError] {
new XGBoostClassifier(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"rabit_timeout" -> true, "rabit_timeout_sec" -> 1,
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> false)).fit(training)
}
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -66,7 +66,6 @@
<executable>python</executable>
<arguments>
<argument>create_jni.py</argument>
<argument>${rabit.mock}</argument>
</arguments>
<workingDirectory>${user.dir}</workingDirectory>
</configuration>

View File

@ -3,8 +3,6 @@ package ml.dmlc.xgboost4j.java;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
@ -53,25 +51,18 @@ public class Rabit {
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
}
}
// used as a way to test/debug the rabit init parameters that were passed in
public static Map<String, String> rabitEnvs;
public static List<String> mockList = new LinkedList<>();
/**
* Initialize the rabit library on current working thread.
* @param envs The additional environment variables to pass to rabit.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
rabitEnvs = envs;
String[] args = new String[envs.size() + mockList.size()];
String[] args = new String[envs.size()];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey() + '=' + e.getValue();
}
// pass the list of rabit mock strings, e.g. mock=0,1,0,0
for(String mock : mockList) {
args[idx++] = "mock=" + mock;
}
checkCall(XGBoostJNI.RabitInit(args));
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

2
rabit

@ -1 +1 @@
Subproject commit d22e0809a890ce0bc7af8d76c3c504b333d62f49
Subproject commit 9a7ac85d7eb65b1a0b904e1fa8d5a01b910adda4

View File

@ -16,7 +16,7 @@ spark_version=$1
rm -rf build/
cd jvm-packages
mvn --no-transfer-progress package -Dspark.version=${spark_version} -Pdev
mvn --no-transfer-progress package -Dspark.version=${spark_version}
set +x
set +e

View File

@ -27,8 +27,8 @@ fi
if [ ${TASK} == "java_test" ]; then
set -e
cd jvm-packages
mvn -q clean install -DskipTests -Dmaven.test.skip -Pdev
mvn -q test -Pdev
mvn -q clean install -DskipTests -Dmaven.test.skip
mvn -q test
fi
if [ ${TASK} == "cmake_test" ]; then