[jvm-packages] More brooming in tests (#2517)
* Deduplicated DataFrame creation in XGBoostDFSuite
* Extracted dermatology.data into MultiClassification
* Moved cache cleaning to SharedSparkContext. Cache files are prefixed with appName, therefore this seems to be just the place to delete them.
* Removed redundant JMatrix calls in xgboost4j-spark
* Slightly more readable buildDenseRDD in XGBoostGeneralSuite
* Generalized train/test DataFrame construction in XGBoostDFSuite
* Changed SharedSparkContext to set up a new context per test. Hence the new name: PerTestSparkSession :)
* Fused Utils into PerTestSparkSession
* Whitespace fix in XGBoostDFSuite
* Ensure SparkSession is always eagerly created in PerTestSparkSession
* Renamed PerTestSparkSession -> PerTest, because it was doing slightly more than creating/stopping the session.
This commit is contained in:
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
@@ -115,7 +115,7 @@ object XGBoost extends Serializable {
|
||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv)
|
||||
val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
|
||||
val trainingMatrix = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
|
||||
val trainingMatrix = new DMatrix(partitionItr, cacheFileName)
|
||||
try {
|
||||
if (params.contains("groupData") && params("groupData") != null) {
|
||||
trainingMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
|
||||
@@ -221,7 +221,7 @@ object XGBoost extends Serializable {
|
||||
private def overrideParamsAccordingToTaskCPUs(
|
||||
params: Map[String, Any],
|
||||
sc: SparkContext): Map[String, Any] = {
|
||||
val coresPerTask = sc.getConf.get("spark.task.cpus", "1").toInt
|
||||
val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
|
||||
var overridedParams = params
|
||||
if (overridedParams.contains("nthread")) {
|
||||
val nThread = overridedParams("nthread").toString.toInt
|
||||
|
||||
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
|
||||
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
|
||||
@@ -66,7 +66,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
if (testSamples.nonEmpty) {
|
||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
||||
val dMatrix = new DMatrix(testSamples)
|
||||
try {
|
||||
broadcastBooster.value.predictLeaf(dMatrix).iterator
|
||||
} finally {
|
||||
@@ -202,7 +202,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
null
|
||||
}
|
||||
}
|
||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
|
||||
val dMatrix = new DMatrix(testSamples, cacheFileName)
|
||||
try {
|
||||
broadcastBooster.value.predict(dMatrix).iterator
|
||||
} finally {
|
||||
|
||||
Reference in New Issue
Block a user