[jvm-packages] More brooming in tests (#2517)
* Deduplicated DataFrame creation in XGBoostDFSuite
* Extracted dermatology.data into MultiClassification
* Moved cache cleaning to SharedSparkContext. Cache files are prefixed with appName, therefore this seems to be just the place to delete them.
* Removed redundant JDMatrix calls in xgboost4j-spark
* Slightly more readable buildDenseRDD in XGBoostGeneralSuite
* Generalized train/test DataFrame construction in XGBoostDFSuite
* Changed SharedSparkContext to set up a new context per test. Hence the new name: PerTestSparkSession :)
* Fused Utils into PerTestSparkSession
* Whitespace fix in XGBoostDFSuite
* Ensure SparkSession is always eagerly created in PerTestSparkSession
* Renamed PerTestSparkSession -> PerTest because it was doing slightly more than creating/stopping the session.
This commit is contained in:
parent ca7fc9fda3
commit 4eb255262f
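The heart of the change is the new PerTest trait (added in full in the diff below), which replaces the shared SparkContext of SharedSparkContext and the old Utils helper: each test gets a freshly and eagerly created SparkSession, and afterEach stops it and deletes the external-memory cache files prefixed with the appName. The following is a minimal sketch of how a suite uses the trait; the suite name, the extra config key, and the test body are illustrative only and are not part of this commit.

import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite

// Illustrative suite only: it mixes in the PerTest trait introduced by this commit.
class ExamplePerTestSuite extends FunSuite with PerTest {

  // Optional: extend the default session builder, as XGBoostConfigureSuite does for Kryo.
  override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
    .config("spark.task.cpus", "1")

  test("each test runs against a fresh SparkSession") {
    // `ss`, `sc` and `numWorkers` come from PerTest; the session was created eagerly in beforeEach.
    val rdd = sc.parallelize(1 to numWorkers, numWorkers)
    assert(rdd.count() == numWorkers)
    // afterEach stops the session and removes external cache files prefixed with the appName.
  }
}

Suites that need a customized session override sparkSessionBuilder as above; everything else simply mixes the trait in and uses ss/sc directly.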
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable

-import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
+import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.logging.LogFactory
@@ -115,7 +115,7 @@ object XGBoost extends Serializable {
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv)
val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
-val trainingMatrix = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
+val trainingMatrix = new DMatrix(partitionItr, cacheFileName)
try {
if (params.contains("groupData") && params("groupData") != null) {
trainingMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
@@ -221,7 +221,7 @@ object XGBoost extends Serializable {
private def overrideParamsAccordingToTaskCPUs(
params: Map[String, Any],
sc: SparkContext): Map[String, Any] = {
-val coresPerTask = sc.getConf.get("spark.task.cpus", "1").toInt
+val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
var overridedParams = params
if (overridedParams.contains("nthread")) {
val nThread = overridedParams("nthread").toString.toInt
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

import scala.collection.JavaConverters._

-import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
@@ -66,7 +66,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
if (testSamples.nonEmpty) {
-val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
+val dMatrix = new DMatrix(testSamples)
try {
broadcastBooster.value.predictLeaf(dMatrix).iterator
} finally {
@@ -202,7 +202,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
null
}
}
-val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
+val dMatrix = new DMatrix(testSamples, cacheFileName)
try {
broadcastBooster.value.predict(dMatrix).iterator
} finally {
@@ -0,0 +1,65 @@
+/*
+Copyright (c) 2014 by Contributors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.File
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SparkSession
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
+  protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
+
+  @transient private var currentSession: SparkSession = _
+
+  def ss: SparkSession = getOrCreateSession
+  implicit def sc: SparkContext = ss.sparkContext
+
+  protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
+    .master("local[*]")
+    .appName("XGBoostSuite")
+    .config("spark.ui.enabled", false)
+    .config("spark.driver.memory", "512m")
+
+  override def beforeEach(): Unit = getOrCreateSession
+
+  override def afterEach() {
+    synchronized {
+      if (currentSession != null) {
+        currentSession.stop()
+        cleanExternalCache(currentSession.sparkContext.appName)
+        currentSession = null
+      }
+    }
+  }
+
+  private def getOrCreateSession = synchronized {
+    if (currentSession == null) {
+      currentSession = sparkSessionBuilder.getOrCreate()
+      currentSession.sparkContext.setLogLevel("ERROR")
+    }
+    currentSession
+  }
+
+  private def cleanExternalCache(prefix: String): Unit = {
+    val dir = new File(".")
+    for (file <- dir.listFiles() if file.getName.startsWith(prefix)) {
+      file.delete()
+    }
+  }
+}
@@ -23,7 +23,7 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.FunSuite


-class RabitTrackerRobustnessSuite extends FunSuite with Utils {
+class RabitTrackerRobustnessSuite extends FunSuite with PerTest {
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -33,12 +33,7 @@ class RabitTrackerRobustnessSuite extends FunSuite with Utils {
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
that should be avoided.
*/
-val sparkConf = new SparkConf().setMaster("local[*]")
-.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
-implicit val sparkContext = new SparkContext(sparkConf)
-sparkContext.setLogLevel("ERROR")
-
-val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
+val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()

val tracker = new PyRabitTracker(numWorkers)
tracker.start(0)
@@ -90,16 +85,10 @@ class RabitTrackerRobustnessSuite extends FunSuite with Utils {
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0) != 0)
-sparkContext.stop()
}

test("test Scala RabitTracker's exception handling: it should not hang forever.") {
-val sparkConf = new SparkConf().setMaster("local[*]")
-.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
-implicit val sparkContext = new SparkContext(sparkConf)
-sparkContext.setLogLevel("ERROR")
-
-val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
+val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()

val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
@@ -127,16 +116,10 @@ class RabitTrackerRobustnessSuite extends FunSuite with Utils {
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
-sparkContext.stop()
}

test("test Scala RabitTracker's workerConnectionTimeout") {
-val sparkConf = new SparkConf().setMaster("local[*]")
-.setAppName("XGBoostSuite").set("spark.driver.memory", "512m")
-implicit val sparkContext = new SparkContext(sparkConf)
-sparkContext.setLogLevel("ERROR")
-
-val rdd = sparkContext.parallelize(1 to numWorkers, numWorkers).cache()
+val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()

val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(500)
@@ -164,6 +147,5 @@ class RabitTrackerRobustnessSuite extends FunSuite with Utils {
sparkThread.start()
// should fail due to connection timeout
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
-sparkContext.stop()
}
}
@@ -1,43 +0,0 @@
-/*
-Copyright (c) 2014 by Contributors
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
-
-trait SharedSparkContext extends FunSuite with BeforeAndAfter with BeforeAndAfterAll
-  with Serializable {
-
-  @transient protected implicit var sc: SparkContext = _
-
-  override def beforeAll() {
-    val sparkConf = new SparkConf()
-      .setMaster("local[*]")
-      .setAppName("XGBoostSuite")
-      .set("spark.driver.memory", "512m")
-      .set("spark.ui.enabled", "false")
-
-    sc = new SparkContext(sparkConf)
-  }
-
-  override def afterAll() {
-    if (sc != null) {
-      sc.stop()
-      sc = null
-    }
-  }
-}
@@ -56,6 +56,25 @@ object Classification extends TrainTestData {
val test: Seq[MLLabeledPoint] = getLabeledPoints("/agaricus.txt.test", zeroBased = false)
}

+object MultiClassification extends TrainTestData {
+  val train: Seq[MLLabeledPoint] = getLabeledPoints("/dermatology.data")
+
+  private def getLabeledPoints(resource: String): Seq[MLLabeledPoint] = {
+    getResourceLines(resource).map { line =>
+      val featuresAndLabel = line.split(",")
+      val label = featuresAndLabel.last.toDouble - 1
+      val values = new Array[Double](featuresAndLabel.length - 1)
+      values(values.length - 1) =
+        if (featuresAndLabel(featuresAndLabel.length - 2) == "?") 1 else 0
+      for (i <- 0 until values.length - 2) {
+        values(i) = featuresAndLabel(i).toDouble
+      }
+
+      MLLabeledPoint(label, MLVectors.dense(values.take(values.length - 1)))
+    }.toList
+  }
+}
+
object Regression extends TrainTestData {
val train: Seq[MLLabeledPoint] = getLabeledPoints("/machine.txt.train", zeroBased = true)
val test: Seq[MLLabeledPoint] = getLabeledPoints("/machine.txt.test", zeroBased = true)
@@ -1,30 +0,0 @@
-/*
-Copyright (c) 2014 by Contributors
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import java.io.File
-
-trait Utils extends Serializable {
-  protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
-
-  protected def cleanExternalCache(prefix: String): Unit = {
-    val dir = new File(".")
-    for (file <- dir.listFiles() if file.getName.startsWith(prefix)) {
-      file.delete()
-    }
-  }
-}
@@ -16,43 +16,36 @@

package ml.dmlc.xgboost4j.scala.spark

-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite

-class XGBoostConfigureSuite extends FunSuite with Utils {
+class XGBoostConfigureSuite extends FunSuite with PerTest {
+  override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
+    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    .config("spark.kryo.classesToRegister", classOf[Booster].getName)

-test("nthread configuration must be equal to spark.task.cpus") {
+test("nthread configuration must be no larger than spark.task.cpus") {
-val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
-set("spark.task.cpus", "4")
-val customSparkContext = new SparkContext(sparkConf)
-customSparkContext.setLogLevel("ERROR")
-// start another app
-val trainingRDD = customSparkContext.parallelize(Classification.train)
+val trainingRDD = sc.parallelize(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
-"objective" -> "binary:logistic", "nthread" -> 6)
+"objective" -> "binary:logistic",
+"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
intercept[IllegalArgumentException] {
XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
}
-customSparkContext.stop()
}

test("kryoSerializer test") {
-val eval = new EvalError()
-val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
-.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-sparkConf.registerKryoClasses(Array(classOf[Booster]))
-val customSparkContext = new SparkContext(sparkConf)
-customSparkContext.setLogLevel("ERROR")
-val trainingRDD = customSparkContext.parallelize(Classification.train)
import DataUtils._
-val testSetDMatrix = new DMatrix(new JDMatrix(Classification.test.iterator, null))
+// TODO write an isolated test for Booster.
+val trainingRDD = sc.parallelize(Classification.train)
+val testSetDMatrix = new DMatrix(Classification.test.iterator, null)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
+val eval = new EvalError()
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
-customSparkContext.stop()
}
}
@@ -16,63 +16,42 @@

package ml.dmlc.xgboost4j.scala.spark

-import java.io.File

-import scala.collection.mutable
-import scala.collection.mutable.ListBuffer
-import scala.io.Source
-import scala.util.Random

-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}

-import org.apache.spark.SparkContext
+import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
-import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.evaluation.RegressionEvaluator
-import org.apache.spark.ml.feature.{LabeledPoint, VectorAssembler}
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql._
+import org.scalatest.FunSuite

-class XGBoostDFSuite extends SharedSparkContext with Utils {
+class XGBoostDFSuite extends FunSuite with PerTest {
+  private def buildDataFrame(
+      instances: Seq[MLLabeledPoint],
+      numPartitions: Int = numWorkers): DataFrame = {
+    val it = instances.iterator.zipWithIndex
+      .map { case (instance: MLLabeledPoint, id: Int) =>
+        (id, instance.label, instance.features)
+      }

-private var trainingDF: DataFrame = null
+    ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
+      .toDF("id", "label", "features")
-after {
-cleanExternalCache("XGBoostDFSuite")
-}

-private def buildTrainingDataframe(sparkContext: Option[SparkContext] = None): DataFrame = {
-if (trainingDF == null) {
-val labeledPointsRDD = sparkContext.getOrElse(sc)
-.parallelize(Classification.train, numWorkers)
-val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
-import sparkSession.implicits._
-trainingDF = sparkSession.createDataset(labeledPointsRDD).toDF
-}
-trainingDF
-}
}

test("test consistency and order preservation of dataframe-based model") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val trainingItr = Classification.train.iterator
-val (testItr, auxTestItr) = Classification.test.iterator.duplicate
+val testItr = Classification.test.iterator
import DataUtils._
val round = 5
-val trainDMatrix = new DMatrix(new JDMatrix(trainingItr, null))
+val trainDMatrix = new DMatrix(trainingItr)
-val testDMatrix = new DMatrix(new JDMatrix(testItr, null))
+val testDMatrix = new DMatrix(testItr)
val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, round)
val predResultFromSeq = xgboostModel.predict(testDMatrix)
-val testSetItr = auxTestItr.zipWithIndex.map {
-case (instance: LabeledPoint, id: Int) => (id, instance.features, instance.label)
-}
-val trainingDF = buildTrainingDataframe()
+val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = round, nWorkers = numWorkers)
-val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
-"id", "features", "label")
+val testDF = buildDataFrame(Classification.test)
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))).toMap
assert(testDF.count() === predResultsFromDF.size)
@@ -89,78 +68,59 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
test("test transformLeaf") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
-val testItr = Classification.test.iterator
+val trainingDF = buildDataFrame(Classification.train)
-val trainingDF = buildTrainingDataframe()
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers)
-val testSetItr = testItr.zipWithIndex.map {
+val testDF = buildDataFrame(Classification.test)
-case (instance: LabeledPoint, id: Int) =>
-(id, instance.features, instance.label)
-}
-val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
-"id", "features", "label")
xgBoostModelWithDF.transformLeaf(testDF).show()
}

test("test schema of XGBoostRegressionModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear")
-val testItr = Regression.test.iterator.zipWithIndex
+val trainingDF = buildDataFrame(Regression.train)
-.map { case (instance: LabeledPoint, id: Int) => (id, instance.features, instance.label) }
-val trainingDF = {
-val labeledPointsRDD = sc.parallelize(Regression.train, numWorkers)
-val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
-import sparkSession.implicits._
-sparkSession.createDataset(labeledPointsRDD).toDF
-}
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = true)
xgBoostModelWithDF.setPredictionCol("final_prediction")
-val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF(
-"id", "features", "label")
+val testDF = buildDataFrame(Regression.test)
val predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
-assert(predictionDF.columns.contains("id") === true)
+assert(predictionDF.columns.contains("id"))
-assert(predictionDF.columns.contains("features") === true)
+assert(predictionDF.columns.contains("features"))
-assert(predictionDF.columns.contains("label") === true)
+assert(predictionDF.columns.contains("label"))
-assert(predictionDF.columns.contains("final_prediction") === true)
+assert(predictionDF.columns.contains("final_prediction"))
predictionDF.show()
}

test("test schema of XGBoostClassificationModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
-val testItr = Classification.test.iterator.
+val trainingDF = buildDataFrame(Classification.train)
-zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
-(id, instance.features, instance.label)
-}
-val trainingDF = buildTrainingDataframe()
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = true)
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(
"raw_prediction").setPredictionCol("final_prediction")
-val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF(
-"id", "features", "label")
+val testDF = buildDataFrame(Classification.test)
var predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
-assert(predictionDF.columns.contains("id") === true)
+assert(predictionDF.columns.contains("id"))
-assert(predictionDF.columns.contains("features") === true)
+assert(predictionDF.columns.contains("features"))
-assert(predictionDF.columns.contains("label") === true)
+assert(predictionDF.columns.contains("label"))
-assert(predictionDF.columns.contains("raw_prediction") === true)
+assert(predictionDF.columns.contains("raw_prediction"))
-assert(predictionDF.columns.contains("final_prediction") === true)
+assert(predictionDF.columns.contains("final_prediction"))
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("").
setPredictionCol("final_prediction")
predictionDF = xgBoostModelWithDF.transform(testDF)
-assert(predictionDF.columns.contains("id") === true)
+assert(predictionDF.columns.contains("id"))
-assert(predictionDF.columns.contains("features") === true)
+assert(predictionDF.columns.contains("features"))
-assert(predictionDF.columns.contains("label") === true)
+assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction") === false)
-assert(predictionDF.columns.contains("final_prediction") === true)
+assert(predictionDF.columns.contains("final_prediction"))
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].
setRawPredictionCol("raw_prediction").setPredictionCol("")
predictionDF = xgBoostModelWithDF.transform(testDF)
-assert(predictionDF.columns.contains("id") === true)
+assert(predictionDF.columns.contains("id"))
-assert(predictionDF.columns.contains("features") === true)
+assert(predictionDF.columns.contains("features"))
-assert(predictionDF.columns.contains("label") === true)
+assert(predictionDF.columns.contains("label"))
-assert(predictionDF.columns.contains("raw_prediction") === true)
+assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("final_prediction") === false)
}

@@ -193,69 +153,33 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error")
val testItr = Classification.test.iterator
-val trainingDF = buildTrainingDataframe()
+val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 10, nWorkers = math.min(2, numWorkers))
val error = new EvalError
import DataUtils._
-val testSetDMatrix = new DMatrix(new JDMatrix(testItr, null))
+val testSetDMatrix = new DMatrix(testItr)
assert(error.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}

-private def convertCSVPointToLabelPoint(valueArray: Array[String]): LabeledPoint = {
-val intValueArray = new Array[Double](valueArray.length)
-intValueArray(valueArray.length - 2) = {
-if (valueArray(valueArray.length - 2) == "?") {
-1
-} else {
-0
-}
-}
-intValueArray(valueArray.length - 1) = valueArray(valueArray.length - 1).toDouble - 1
-for (i <- 0 until intValueArray.length - 2) {
-intValueArray(i) = valueArray(i).toDouble
-}
-LabeledPoint(intValueArray.last, new DenseVector(intValueArray.take(intValueArray.length - 1)))
-}
-
-private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): Seq[LabeledPoint] = {
-val file = Source.fromFile(new File(filePath))
-val sampleList = new ListBuffer[LabeledPoint]
-for (sample <- file.getLines()) {
-sampleList += convertCSVPointToLabelPoint(sample.split(","))
-}
-sampleList
-}
-
test("multi_class classification test") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
-val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile)
+val trainingDF = buildDataFrame(MultiClassification.train)
-val spark = SparkSession.builder().getOrCreate()
+XGBoost.trainWithDataFrame(trainingDF.toDF(), paramMap, round = 5, nWorkers = numWorkers)
-import spark.implicits._
-XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers)
}

test("test DF use nested groupData") {
-val testItr = Ranking.test.iterator.zipWithIndex
+val trainingDF = buildDataFrame(Ranking.train0, 1)
-.map { case (instance: LabeledPoint, id: Int) => (id, instance.features, instance.label) }
+.union(buildDataFrame(Ranking.train1, 1))
-val trainingDF = {
-val labeledPointsRDD0 = sc.parallelize(Ranking.train0, numSlices = 1)
-val labeledPointsRDD1 = sc.parallelize(Ranking.train1, numSlices = 1)
-val labeledPointsRDD = labeledPointsRDD0.union(labeledPointsRDD1)
-val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
-import sparkSession.implicits._
-sparkSession.createDataset(labeledPointsRDD).toDF
-}
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)

val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = 2)
-val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF(
-"id", "features", "label")
+val testDF = buildDataFrame(Ranking.test)
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
assert(testDF.count() === predResultsFromDF.size)
@@ -264,11 +188,8 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
test("params of estimator and produced model are coordinated correctly") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
-val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile)
+val trainingDF = buildDataFrame(MultiClassification.train)
-val spark = SparkSession.builder().getOrCreate()
+val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers)
-import spark.implicits._
-val model =
-XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers)
assert(model.get[Double](model.eta).get == 0.1)
assert(model.get[Int](model.maxDepth).get == 6)
}
@@ -22,7 +22,7 @@ import java.util.concurrent.LinkedBlockingDeque
import scala.collection.mutable.ListBuffer
import scala.util.Random

-import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker

@@ -30,8 +30,9 @@ import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vectors, Vector => SparkVector}
import org.apache.spark.rdd.RDD
+import org.scalatest.FunSuite

-class XGBoostGeneralSuite extends SharedSparkContext with Utils {
+class XGBoostGeneralSuite extends FunSuite with PerTest {
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
@@ -84,29 +85,26 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
missing = Float.NaN, baseMargin = null)
val boosterCount = boosterRDD.count()
assert(boosterCount === 2)
-cleanExternalCache("XGBoostSuite")
}

test("training with external memory cache") {
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train)
import DataUtils._
-val testSetDMatrix = new DMatrix(new JDMatrix(Classification.test.iterator, null))
+val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
-// clean
-cleanExternalCache("XGBoostSuite")
}

test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train)
import DataUtils._
-val testSetDMatrix = new DMatrix(new JDMatrix(Classification.test.iterator, null))
+val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap
@@ -120,7 +118,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train)
import DataUtils._
-val testSetDMatrix = new DMatrix(new JDMatrix(Classification.test.iterator, null))
+val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "eval_metric" -> "error")
@@ -135,7 +133,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train)
import DataUtils._
-val testSetDMatrix = new DMatrix(new JDMatrix(Classification.test.iterator, null))
+val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "1",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "error")
@@ -150,7 +148,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train)
import DataUtils._
-val testSetDMatrix = new DMatrix(new JDMatrix(Classification.test.iterator, null))
+val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
@@ -166,7 +164,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train)
import DataUtils._
-val testSetDMatrix = new DMatrix(new JDMatrix(Classification.test.iterator, null))
+val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
@@ -182,7 +180,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train)
import DataUtils._
-val testSetDMatrix = new DMatrix(new JDMatrix(Classification.test.iterator, null))
+val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
@@ -196,29 +194,19 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {

test("test with dense vectors containing missing value") {
def buildDenseRDD(): RDD[LabeledPoint] = {
-val nrow = 100
+val numRows = 100
-val ncol = 5
+val numCols = 5
-val data0 = Array.ofDim[Double](nrow, ncol)
-// put random nums
+val labeledPoints = (0 until numRows).map { _ =>
-for (r <- 0 until nrow; c <- 0 until ncol) {
+val label = Random.nextDouble()
-data0(r)(c) = {
+val values = Array.tabulate[Double](numCols) { c =>
-if (c == ncol - 1) {
+if (c == numCols - 1) -0.1 else Random.nextDouble()
--0.1
-} else {
-Random.nextDouble()
-}
}
-}

+LabeledPoint(label, Vectors.dense(values))
}
-// create label
-val label0 = new Array[Double](nrow)
+sc.parallelize(labeledPoints)
-for (i <- label0.indices) {
-label0(i) = Random.nextDouble()
-}
-val points = new ListBuffer[LabeledPoint]
-for (r <- 0 until nrow) {
-points += LabeledPoint(label0(r), Vectors.dense(data0(r)))
-}
-sc.parallelize(points)
}

val trainingRDD = buildDenseRDD().repartition(4)
@@ -228,8 +216,6 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers,
useExternalMemory = true)
xgBoostModel.predict(testRDD.map(_.features.toDense), missingValue = -0.1f).collect()
-// clean
-cleanExternalCache("XGBoostSuite")
}

test("test consistency of prediction functions with RDD") {
@@ -280,7 +266,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train)
import DataUtils._
-val testSetDMatrix = new DMatrix(new JDMatrix(Classification.test.iterator, null))
+val testSetDMatrix = new DMatrix(Classification.test.iterator)
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
@@ -24,11 +24,12 @@ import org.apache.spark.SparkConf
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession
+import org.scalatest.{BeforeAndAfterAll, FunSuite}

-class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
+class XGBoostSparkPipelinePersistence extends FunSuite with PerTest
+  with BeforeAndAfterAll {

override def afterAll(): Unit = {
-super.afterAll()
delete(new File("./testxgbPipe"))
delete(new File("./testxgbEst"))
delete(new File("./testxgbModel"))