diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 7bc228513..3ad724a94 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -16,6 +16,8 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
+import java.io.File
+
 import scala.collection.mutable
 import scala.util.Random
 
@@ -112,7 +114,6 @@ object XGBoost extends Serializable {
       data
     }
     val partitionedBaseMargin = partitionedData.map(_.baseMargin)
-    val appName = partitionedData.context.appName
     // to workaround the empty partitions in training dataset,
     // this might not be the best efficient implementation, see
     // (https://github.com/dmlc/xgboost/issues/1277)
@@ -122,17 +123,21 @@ object XGBoost extends Serializable {
           s"detected an empty partition in the training data, partition ID:" +
             s" ${TaskContext.getPartitionId()}")
       }
-      val cacheFileName = if (useExternalMemory) {
-        s"$appName-${TaskContext.get().stageId()}-" +
-          s"dtrain_cache-${TaskContext.getPartitionId()}"
+      val taskId = TaskContext.getPartitionId().toString
+      val cacheDirName = if (useExternalMemory) {
+        val dir = new File(s"${TaskContext.get().stageId()}-cache-$taskId")
+        if (!(dir.exists() || dir.mkdirs())) {
+          throw new XGBoostError(s"failed to create cache directory: $dir")
+        }
+        Some(dir.toString)
       } else {
-        null
+        None
       }
-      rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
+      rabitEnv.put("DMLC_TASK_ID", taskId)
       Rabit.init(rabitEnv)
       val watches = Watches(params,
         removeMissingValues(labeledPoints, missing),
-        fromBaseMarginsToArray(baseMargins), cacheFileName)
+        fromBaseMarginsToArray(baseMargins), cacheDirName)
 
       try {
         val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
@@ -442,7 +447,10 @@ object XGBoost extends Serializable {
   }
 }
 
-private class Watches private(val train: DMatrix, val test: DMatrix) {
+private class Watches private(
+    val train: DMatrix,
+    val test: DMatrix,
+    private val cacheDirName: Option[String]) {
 
   def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
     .filter { case (_, matrix) => matrix.rowNum > 0 }
@@ -451,6 +459,13 @@ private class Watches private(val train: DMatrix, val test: DMatrix) {
 
   def delete(): Unit = {
     toMap.values.foreach(_.delete())
+    cacheDirName.foreach { name =>
+      for (cacheFile <- new File(name).listFiles()) {
+        if (!cacheFile.delete()) {
+          throw new IllegalStateException(s"failed to delete $cacheFile")
+        }
+      }
+    }
   }
 
   override def toString: String = toMap.toString
@@ -462,7 +477,7 @@ private object Watches {
       params: Map[String, Any],
       labeledPoints: Iterator[XGBLabeledPoint],
       baseMarginsOpt: Option[Array[Float]],
-      cacheFileName: String): Watches = {
+      cacheDirName: Option[String]): Watches = {
     val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
     val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
     val r = new Random(seed)
@@ -475,8 +490,8 @@ private object Watches {
       accepted
     }
 
-    val trainMatrix = new DMatrix(trainPoints, cacheFileName)
-    val testMatrix = new DMatrix(testPoints.iterator, cacheFileName)
+    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
+    val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
     r.setSeed(seed)
     for (baseMargins <- baseMarginsOpt) {
       val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
@@ -489,6 +504,6 @@ private object Watches {
       trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
         TaskContext.getPartitionId()).toArray)
     }
-    new Watches(train = trainMatrix, test = testMatrix)
+    new Watches(trainMatrix, testMatrix, cacheDirName)
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
index d427080dd..4cdcaaf39 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
@@ -18,15 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DataTypes
 import org.scalatest.FunSuite
+import org.scalatest.prop.TableDrivenPropertyChecks
 
-class XGBoostDFSuite extends FunSuite with PerTest {
+class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks {
 
   private def buildDataFrame(
       labeledPoints: Seq[XGBLabeledPoint],
       numPartitions: Int = numWorkers): DataFrame = {
@@ -252,12 +252,14 @@ class XGBoostDFSuite extends FunSuite with PerTest {
   test("train/test split") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
     val trainingDf = buildDataFrame(Classification.train)
-    val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
-      nWorkers = numWorkers)
-    val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
-    assert(testObjectiveHistory.length === 5)
-    assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
+
+    forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
+      val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
+        nWorkers = numWorkers, useExternalMemory = useExternalMemory)
+      val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
+      assert(testObjectiveHistory.length === 5)
+      assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
+    }
   }
 }
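The core of the change is the move from a single app-scoped cache file name to a per-task cache directory ("<stageId>-cache-<taskId>") that is created up front and emptied again in `Watches.delete()`; each `DMatrix` then gets its own cache prefix (`<dir>/train`, `<dir>/test`) instead of the one shared `cacheFileName` the old code passed to both. The standalone sketch below mirrors that create/clean-up lifecycle with plain `java.io.File` calls so it can run outside Spark; the `stageId`/`taskId` values and the `train.cache` file name are illustrative stand-ins, not part of the patch.

```scala
import java.io.File

object CacheDirLifecycleSketch {
  // Create the per-task cache directory, using the patch's naming scheme
  // "<stageId>-cache-<taskId>". Fails loudly if the directory cannot be
  // created, analogous to the XGBoostError path in the diff above.
  def createCacheDir(stageId: Int, taskId: String): String = {
    val dir = new File(s"$stageId-cache-$taskId")
    if (!(dir.exists() || dir.mkdirs())) {
      throw new IllegalStateException(s"failed to create cache directory: $dir")
    }
    dir.toString
  }

  // Delete every cache file in the directory, mirroring Watches.delete().
  // As in the patch, the directory itself is left in place.
  def cleanCacheDir(name: String): Unit = {
    for (cacheFile <- new File(name).listFiles()) {
      if (!cacheFile.delete()) {
        throw new IllegalStateException(s"failed to delete $cacheFile")
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val dir = createCacheDir(stageId = 0, taskId = "0")
    // Stand-in for the external-memory pages DMatrix would write
    // under the "<dir>/train" prefix during training.
    new File(dir, "train.cache").createNewFile()
    cleanCacheDir(dir)
  }
}
```

Keying the directory by stage and partition ID keeps concurrent tasks from clobbering each other's cache files, and the `dir.exists()` check lets a retried task reuse the directory rather than fail on `mkdirs()`.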
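On the test side, the suite mixes in ScalaTest's `TableDrivenPropertyChecks` so that the existing train/test-split test now runs once per `useExternalMemory` value instead of only covering the in-memory path. A minimal self-contained suite showing the same `forAll`/`Table` pattern; the suite name and the placeholder assertion are hypothetical, only the pattern matches the diff:

```scala
import org.scalatest.FunSuite
import org.scalatest.prop.TableDrivenPropertyChecks

// Sketch of the table-driven pattern adopted above: the body of forAll
// executes once per row of the table, here once with false and once with true.
class ExternalMemoryFlagSuite extends FunSuite with TableDrivenPropertyChecks {
  test("body runs for both flag values") {
    val flags = Table("useExternalMemory", false, true)
    forAll(flags) { useExternalMemory =>
      // In the real test, training runs here with the flag passed through
      // to trainWithDataFrame; this placeholder assertion always holds.
      assert(!useExternalMemory || useExternalMemory)
    }
  }
}
```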