[jvm-packages] Scala implementation of the Rabit tracker. (#1612)
* [jvm-packages] Scala implementation of the Rabit tracker. A Scala implementation of RabitTracker that is interface-interchangable with the Java implementation, ported from `tracker.py` in the [dmlc-core project](https://github.com/dmlc/dmlc-core). * [jvm-packages] Updated Akka dependency in pom.xml. * Refactored the RabitTracker directory structure. * Fixed premature stopping of connection handler. Added a new finite state "AwaitingPortNumber" to explicitly wait for the worker to send the port, and close the connection. Stopping the actor prematurely sends a TCP RST to the worker, causing the worker to crash on AssertionError. * Added interface IRabitTracker so that user can switch implementations. * Default timeout duration changes. * Dependency for Akka tests. * Removed the main function of RabitTracker. * A skeleton for testing Akka-based Rabit tracker. * waitFor() in RabitTracker no longer throws exceptions. * Completed unit test for the 'start' command of Rabit tracker. * Preliminary support for Rabit Allreduce via JNI (no prepare function support yet.) * Fixed the default timeout duration. * Use Java container to avoid serialization issues due to intermediate wrappers. * Added tests for Allreduce/model training using Scala Rabit tracker. * Added spill-over unit test for the Scala Rabit tracker. * Fixed a typo. * Overhaul of RabitTracker interface per code review. - Removed methods start() waitFor() (no arguments) from IRabitTracker. - The timeout in start(timeout) is now worker connection timeout, as tcp socket binding timeout is less intuitive. - Dropped time unit from start(...) and waitFor(...) methods; the default time unit is millisecond. - Moved random port number generation into the RabitTrackerHandler. - Moved all Rabit-related classes to package ml.dmlc.xgboost4j.scala.rabit. * More code refactoring and comments. * Unified timeout constants. Readable tracker status code. * Add comments to indicate that allReduce is for tests only. Removed all other variants. * Removed unused imports. * Simplified signatures of training methods. - Moved TrackerConf into parameter map. - Changed GeneralParams so that TrackerConf becomes a standalone parameter. - Updated test cases accordingly. * Changed monitoring strategies. * Reverted monitoring changes. * Update test case for Rabit AllReduce. * Mix in UncaughtExceptionHandler into IRabitTracker to prevent tracker from hanging due to exceptions thrown by workers. * More comprehensive test cases for exception handling and worker connection timeout. * Handle executor loss due to unknown cause: the newly spawned executor will attempt to connect to the tracker. Interrupt tracker in such case. * Per code-review, removed training timeout from TrackerConf. Timeout logic must be implemented explicitly and externally in the driver code. * Reverted scalastyle-config changes. * Visibility scope change. Interface tweaks. * Use match pattern to handle tracker_conf parameter. * Minor clarification in JNI code. * Clearer intent in match pattern to suppress warnings. * Removed Future from constructor. Block in start() and waitFor() instead. * Revert inadvertent comment changes. * Removed debugging information. * Updated test cases that are a bit finicky. * Added comments on the reasoning behind the unit tests for testing Rabit tracker robustness.
This commit is contained in:
@@ -16,11 +16,10 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.hadoop.fs.{FSDataInputStream, Path}
|
||||
@@ -30,6 +29,25 @@ import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.{SparkContext, TaskContext}
|
||||
|
||||
import scala.concurrent.duration.{Duration, MILLISECONDS}
|
||||
|
||||
object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
|
||||
}
|
||||
|
||||
/**
|
||||
* Rabit tracker configurations.
|
||||
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
|
||||
* Set timeout length to zero to disable timeout.
|
||||
* Use a finite, non-zero timeout value to prevent tracker from
|
||||
* hanging indefinitely (supported by "scala" implementation only.)
|
||||
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
|
||||
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
|
||||
* in Scala without Python components, and with full support of timeouts.
|
||||
* The Scala implementation is currently experimental, use at your own risk.
|
||||
*/
|
||||
case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
|
||||
|
||||
object XGBoost extends Serializable {
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
@@ -80,7 +98,7 @@ object XGBoost extends Serializable {
|
||||
private[spark] def buildDistributedBoosters(
|
||||
trainingSet: RDD[MLLabeledPoint],
|
||||
xgBoostConfMap: Map[String, Any],
|
||||
rabitEnv: mutable.Map[String, String],
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
|
||||
useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
|
||||
import DataUtils._
|
||||
@@ -92,7 +110,7 @@ object XGBoost extends Serializable {
|
||||
partitionedTrainingSet.mapPartitions {
|
||||
trainingSamples =>
|
||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
Rabit.init(rabitEnv)
|
||||
var booster: Booster = null
|
||||
if (trainingSamples.hasNext) {
|
||||
val cacheFileName: String = {
|
||||
@@ -211,9 +229,21 @@ object XGBoost extends Serializable {
|
||||
overridedParams
|
||||
}
|
||||
|
||||
private def startTracker(nWorkers: Int): RabitTracker = {
|
||||
val tracker = new RabitTracker(nWorkers)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||
val tracker: IRabitTracker = trackerConf.trackerImpl match {
|
||||
case "scala" => new RabitTracker(nWorkers)
|
||||
case "python" => new PyRabitTracker(nWorkers)
|
||||
case _ => new PyRabitTracker(nWorkers)
|
||||
}
|
||||
|
||||
val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
|
||||
trackerConf.workerConnectionTimeout.toMillis
|
||||
} else {
|
||||
// 0 == Duration.Inf
|
||||
0L
|
||||
}
|
||||
|
||||
require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
|
||||
tracker
|
||||
}
|
||||
|
||||
@@ -227,7 +257,7 @@ object XGBoost extends Serializable {
|
||||
* @param obj the user-defined objective function, null by default
|
||||
* @param eval the user-defined evaluation function, null by default
|
||||
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
|
||||
* true, the user may save the RAM cost for running XGBoost within Spark
|
||||
* true, the user may save the RAM cost for running XGBoost within Spark
|
||||
* @param missing the value represented the missing value in the dataset
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
|
||||
* @return XGBoostModel when successful training
|
||||
@@ -243,19 +273,26 @@ object XGBoost extends Serializable {
|
||||
" you have to specify the objective type as classification or regression with a" +
|
||||
" customized objective function")
|
||||
}
|
||||
val tracker = startTracker(nWorkers)
|
||||
val trackerConf = params.get("tracker_conf") match {
|
||||
case None => TrackerConf()
|
||||
case Some(conf: TrackerConf) => conf
|
||||
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
|
||||
"instance of TrackerConf.")
|
||||
}
|
||||
val tracker = startTracker(nWorkers, trackerConf)
|
||||
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
|
||||
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
||||
tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing)
|
||||
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
// force the job
|
||||
boosters.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkJobThread.start()
|
||||
val isClsTask = isClassificationTask(params)
|
||||
val trackerReturnVal = tracker.waitFor()
|
||||
val trackerReturnVal = tracker.waitFor(0L)
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
|
||||
isClsTask)
|
||||
|
||||
@@ -16,9 +16,12 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||
import org.apache.spark.ml.param._
|
||||
|
||||
import scala.concurrent.duration.{Duration, NANOSECONDS}
|
||||
|
||||
trait GeneralParams extends Params {
|
||||
|
||||
/**
|
||||
@@ -69,7 +72,38 @@ trait GeneralParams extends Params {
|
||||
*/
|
||||
val missing = new FloatParam(this, "missing", "the value treated as missing")
|
||||
|
||||
/**
|
||||
* Rabit tracker configurations. The parameter must be provided as an instance of the
|
||||
* TrackerConf class, which has the following definition:
|
||||
*
|
||||
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
|
||||
* trackerImpl: String)
|
||||
*
|
||||
* See below for detailed explanations.
|
||||
*
|
||||
* - trackerImpl: Select the implementation of Rabit tracker.
|
||||
* default: "python"
|
||||
*
|
||||
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
|
||||
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
|
||||
* The "scala" version removes Python components, and fully supports timeout settings.
|
||||
*
|
||||
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
|
||||
* default: 0 millisecond (no timeout)
|
||||
*
|
||||
* The timeout value should take the time of data loading and pre-processing into account,
|
||||
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
|
||||
* perform data transformation before calling XGBoost.train(), so that this timeout truly
|
||||
* reflects the connection delay. Set a reasonable timeout value to prevent model
|
||||
* training/testing from hanging indefinitely, possible due to network issues.
|
||||
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
|
||||
* Ignored if the tracker implementation is "python".
|
||||
*/
|
||||
val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations")
|
||||
|
||||
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
||||
useExternalMemory -> false, silent -> 0,
|
||||
customObj -> null, customEval -> null, missing -> Float.NaN)
|
||||
customObj -> null, customEval -> null, missing -> Float.NaN,
|
||||
trackerConf -> TrackerConf()
|
||||
)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
|
||||
class RabitTrackerRobustnessSuite extends FunSuite with Utils {
|
||||
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
|
||||
/*
|
||||
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
|
||||
same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
|
||||
by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
|
||||
corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
|
||||
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
|
||||
that should be avoided.
|
||||
*/
|
||||
val sparkConf = new SparkConf().setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
|
||||
implicit val sparkContext = new SparkContext(sparkConf)
|
||||
sparkContext.setLogLevel("ERROR")
|
||||
|
||||
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new PyRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
/*
|
||||
Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
|
||||
last created worker. A cascading event chain will be triggered once the RuntimeException is
|
||||
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
|
||||
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
|
||||
|
||||
The Java RabitTracker class reacts to exceptions by killing the spawned process running
|
||||
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
|
||||
it is killed, the resulted connection failure will trigger the Rabit worker to call
|
||||
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
|
||||
|
||||
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
|
||||
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
|
||||
running in separate containers. However, as unit tests are run in Spark local mode, in which
|
||||
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
|
||||
ultimately kills the entire process, which also happens to host the Spark driver, causing
|
||||
the entire Spark application to crash.
|
||||
|
||||
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
|
||||
the exception is thrown at last, ideally after all worker connections have been established.
|
||||
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
|
||||
process to ensure that pending worker connections are handled.
|
||||
*/
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val index = iter.next()
|
||||
Thread.sleep(100 + index * 10)
|
||||
if (index == workerCount) {
|
||||
// kill the worker by throwing an exception
|
||||
throw new RuntimeException("Worker exception.")
|
||||
}
|
||||
Rabit.shutdown()
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0) != 0)
|
||||
sparkContext.stop()
|
||||
}
|
||||
|
||||
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
|
||||
val sparkConf = new SparkConf().setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
|
||||
implicit val sparkContext = new SparkContext(sparkConf)
|
||||
sparkContext.setLogLevel("ERROR")
|
||||
|
||||
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val index = iter.next()
|
||||
Thread.sleep(100 + index * 10)
|
||||
if (index == workerCount) {
|
||||
// kill the worker by throwing an exception
|
||||
throw new RuntimeException("Worker exception.")
|
||||
}
|
||||
Rabit.shutdown()
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||
sparkContext.stop()
|
||||
}
|
||||
|
||||
test("test Scala RabitTracker's workerConnectionTimeout") {
|
||||
val sparkConf = new SparkConf().setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
|
||||
implicit val sparkContext = new SparkContext(sparkConf)
|
||||
sparkContext.setLogLevel("ERROR")
|
||||
|
||||
val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(500)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
val index = iter.next()
|
||||
// simulate that the first worker cannot connect to tracker due to network issues.
|
||||
if (index != 1) {
|
||||
Rabit.init(trackerEnvs)
|
||||
Thread.sleep(1000)
|
||||
Rabit.shutdown()
|
||||
}
|
||||
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
// should fail due to connection timeout
|
||||
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||
sparkContext.stop()
|
||||
}
|
||||
}
|
||||
@@ -17,18 +17,60 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.nio.file.Files
|
||||
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import scala.concurrent.duration._
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.{Vector => SparkVector, Vectors}
|
||||
import org.apache.spark.ml.linalg.{Vectors, Vector => SparkVector}
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
|
||||
val vectorLength = 100
|
||||
val rdd = sc.parallelize(
|
||||
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
|
||||
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
|
||||
|
||||
val rawData = rdd.mapPartitions { iter =>
|
||||
Iterator(iter.toArray)
|
||||
}.collect()
|
||||
|
||||
val maxVec = (0 until vectorLength).toArray.map { j =>
|
||||
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
|
||||
}
|
||||
|
||||
val allReduceResults = rdd.mapPartitions { iter =>
|
||||
Rabit.init(trackerEnvs)
|
||||
val arr = iter.toArray
|
||||
val results = Rabit.allReduce(arr, Rabit.OpType.MAX)
|
||||
Rabit.shutdown()
|
||||
Iterator(results)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
allReduceResults.foreachPartition(() => _)
|
||||
val byPartitionResults = allReduceResults.collect()
|
||||
assert(byPartitionResults(0).length == vectorLength)
|
||||
collectedAllReduceResults.put(byPartitionResults(0))
|
||||
}
|
||||
}
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0L) == 0)
|
||||
sparkThread.join()
|
||||
|
||||
assert(collectedAllReduceResults.poll().sameElements(maxVec))
|
||||
}
|
||||
|
||||
test("build RDD containing boosters with the specified worker number") {
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
@@ -36,7 +78,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap,
|
||||
new scala.collection.mutable.HashMap[String, String],
|
||||
new java.util.HashMap[String, String](),
|
||||
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
|
||||
val boosterCount = boosterRDD.count()
|
||||
assert(boosterCount === 2)
|
||||
@@ -59,6 +101,21 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
cleanExternalCache("XGBoostSuite")
|
||||
}
|
||||
|
||||
test("training with Scala-implemented Rabit tracker") {
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = buildTrainingRDD(sc)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = true)
|
||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix) < 0.1)
|
||||
}
|
||||
|
||||
test("test with dense vectors containing missing value") {
|
||||
def buildDenseRDD(): RDD[LabeledPoint] = {
|
||||
val nrow = 100
|
||||
|
||||
Reference in New Issue
Block a user