[jvm-packages] More brooming in tests (#2517)

* Deduplicated DataFrame creation in XGBoostDFSuite

* Extracted dermatology.data into MultiClassification

* Moved cache cleaning to SharedSparkContext

Cache files are prefixed with the appName, so this seems to be just the place to delete them.
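
A minimal sketch of that idea, assuming a ScalaTest mixin trait; the trait shape, field names, and app name below are illustrative, not the PR's actual code:

import java.io.File

import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, Suite}

// Sketch only: names and structure are assumptions, not the actual test trait.
trait SharedSparkContextSketch extends BeforeAndAfterAll { self: Suite =>
  @transient protected var sc: SparkContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite"))
  }

  override def afterAll(): Unit = {
    try {
      if (sc != null) {
        // Cache files are prefixed with the appName, so the shared teardown is
        // a natural place to sweep them out of the working directory.
        new File(".").listFiles()
          .filter(_.getName.startsWith(sc.appName))
          .foreach(_.delete())
        sc.stop()
      }
    } finally {
      super.afterAll()
    }
  }
}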

* Removed redundant JMatrix calls in xgboost4j-spark

* Slightly more readable buildDenseRDD in XGBoostGeneralSuite

* Generalized train/test DataFrame construction in XGBoostDFSuite
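
For illustration only, a hypothetical helper of this kind (object name, method name, and signature are assumptions): the train and the test fixtures can go through one shared code path instead of each suite building its DataFrame by hand.

import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical helper, not the PR's actual code.
object DataFrameFixtures {
  def buildDataFrame(
      spark: SparkSession,
      labeledPoints: Seq[LabeledPoint],
      numPartitions: Int = 4): DataFrame =
    spark.createDataFrame(spark.sparkContext.parallelize(labeledPoints, numPartitions))
}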

* Changed SharedSparkContext to setup a new context per-test

Hence the new name: PerTestSparkSession :)
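
Roughly, the per-test lifecycle looks like the following sketch (details are assumptions; the commits below refine the real trait further): each test gets its own freshly built SparkSession, which is stopped again afterwards.

import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterEach, Suite}

// Sketch of the per-test lifecycle; the real trait in this PR may differ.
trait PerTestSparkSession extends BeforeAndAfterEach { self: Suite =>
  @transient protected var spark: SparkSession = _

  override def beforeEach(): Unit = {
    super.beforeEach()
    // Eagerly build a fresh session so every test starts from a clean
    // SparkContext instead of reusing whatever a previous test left behind.
    spark = SparkSession.builder()
      .master("local[*]")
      .appName(suiteName)
      .getOrCreate()
  }

  override def afterEach(): Unit = {
    try {
      if (spark != null) spark.stop()
    } finally {
      spark = null
      super.afterEach()
    }
  }
}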

* Fused Utils into PerTestSparkSession

* Whitespace fix in XGBoostDFSuite

* Ensure SparkSession is always eagerly created in PerTestSparkSession

* Renamed PerTestSparkSession->PerTest

because it was doing slightly more than creating/stopping the session.
Authored by Sergei Lebedev on 2017-07-18 22:08:48 +02:00, committed by Nan Zhu
parent ca7fc9fda3
commit 4eb255262f
11 changed files with 182 additions and 288 deletions

ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
-import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
+import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.logging.LogFactory
@@ -115,7 +115,7 @@ object XGBoost extends Serializable {
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv)
val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
-val trainingMatrix = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
+val trainingMatrix = new DMatrix(partitionItr, cacheFileName)
try {
if (params.contains("groupData") && params("groupData") != null) {
trainingMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
@@ -221,7 +221,7 @@ object XGBoost extends Serializable {
private def overrideParamsAccordingToTaskCPUs(
params: Map[String, Any],
sc: SparkContext): Map[String, Any] = {
-val coresPerTask = sc.getConf.get("spark.task.cpus", "1").toInt
+val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
var overridedParams = params
if (overridedParams.contains("nthread")) {
val nThread = overridedParams("nthread").toString.toInt
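
The spark.task.cpus lookup above switches to SparkConf's typed accessor instead of parsing the string by hand. A small standalone illustration (the config value used here is arbitrary):

import org.apache.spark.SparkConf

object TaskCpusExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().set("spark.task.cpus", "2")
    // Typed accessor with a default, replacing conf.get("spark.task.cpus", "1").toInt
    val coresPerTask = conf.getInt("spark.task.cpus", 1)
    println(coresPerTask) // 2; would be 1 if the key were unset
  }
}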

ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala

@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
-import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
@@ -66,7 +66,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
if (testSamples.nonEmpty) {
-val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
+val dMatrix = new DMatrix(testSamples)
try {
broadcastBooster.value.predictLeaf(dMatrix).iterator
} finally {
@@ -202,7 +202,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
null
}
}
-val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
+val dMatrix = new DMatrix(testSamples, cacheFileName)
try {
broadcastBooster.value.predict(dMatrix).iterator
} finally {