[jvm-packages] Fixed test/train persistence (#2949)

* [jvm-packages] Fixed test/train persistence

Prior to this patch, both data sets were persisted under the same cache path,
i.e. the test data replaced the training data, which led to

* training on less data (since usually test < train), and
* the test loss being exactly equal to the training loss.

Closes #2945.
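For context, a minimal sketch of the cache-path computation before and after the fix; the stage, task, and app names below are placeholder values standing in for what the patch reads from Spark's TaskContext:

    // Before: train and test DMatrix shared one cache prefix, so the
    // test matrix's on-disk cache overwrote the training matrix's.
    val stageId = 0     // placeholder for TaskContext.get().stageId()
    val taskId = "3"    // placeholder for TaskContext.getPartitionId().toString
    val appName = "app" // placeholder for the Spark application name
    val cacheFileName = s"$appName-$stageId-dtrain_cache-$taskId" // shared by both matrices

    // After: one directory per task, with distinct per-split file names.
    val cacheDirName = Some(s"$stageId-cache-$taskId")
    val trainCache = cacheDirName.map(_ + "/train").orNull // "0-cache-3/train"
    val testCache = cacheDirName.map(_ + "/test").orNull   // "0-cache-3/test"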

* Clean up the file cache after training

* Addressed review comments
Sergei Lebedev, 2017-12-19 16:11:48 +01:00 (committed by Nan Zhu)
parent 7759ab99ee
commit 7c6673cb9e
2 changed files with 37 additions and 20 deletions

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

@@ -16,6 +16,8 @@
 package ml.dmlc.xgboost4j.scala.spark
 
+import java.io.File
+
 import scala.collection.mutable
 import scala.util.Random
@@ -112,7 +114,6 @@ object XGBoost extends Serializable {
       data
     }
     val partitionedBaseMargin = partitionedData.map(_.baseMargin)
-    val appName = partitionedData.context.appName
     // to workaround the empty partitions in training dataset,
     // this might not be the best efficient implementation, see
     // (https://github.com/dmlc/xgboost/issues/1277)
@@ -122,17 +123,21 @@ object XGBoost extends Serializable {
           s"detected an empty partition in the training data, partition ID:" +
             s" ${TaskContext.getPartitionId()}")
       }
-      val cacheFileName = if (useExternalMemory) {
-        s"$appName-${TaskContext.get().stageId()}-" +
-          s"dtrain_cache-${TaskContext.getPartitionId()}"
+      val taskId = TaskContext.getPartitionId().toString
+      val cacheDirName = if (useExternalMemory) {
+        val dir = new File(s"${TaskContext.get().stageId()}-cache-$taskId")
+        if (!(dir.exists() || dir.mkdirs())) {
+          throw new XGBoostError(s"failed to create cache directory: $dir")
+        }
+        Some(dir.toString)
       } else {
-        null
+        None
       }
-      rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
+      rabitEnv.put("DMLC_TASK_ID", taskId)
       Rabit.init(rabitEnv)
       val watches = Watches(params,
         removeMissingValues(labeledPoints, missing),
-        fromBaseMarginsToArray(baseMargins), cacheFileName)
+        fromBaseMarginsToArray(baseMargins), cacheDirName)
 
       try {
         val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
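A note on the create-or-fail idiom above: File.mkdirs() returns true only when the directory was actually created, so the exists() check must come first to tolerate a pre-existing directory. A standalone sketch of the same pattern, with a hypothetical helper name and IllegalStateException standing in for the library's XGBoostError:

    import java.io.File

    def ensureDir(path: String): File = {
      val dir = new File(path)
      // mkdirs() is only attempted when the directory is not already there
      if (!(dir.exists() || dir.mkdirs())) {
        throw new IllegalStateException(s"failed to create cache directory: $dir")
      }
      dir
    }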
@@ -442,7 +447,10 @@ object XGBoost extends Serializable {
   }
 }
 
-private class Watches private(val train: DMatrix, val test: DMatrix) {
+private class Watches private(
+    val train: DMatrix,
+    val test: DMatrix,
+    private val cacheDirName: Option[String]) {
 
   def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
     .filter { case (_, matrix) => matrix.rowNum > 0 }
@@ -451,6 +459,13 @@ private class Watches private(val train: DMatrix, val test: DMatrix) {
   def delete(): Unit = {
     toMap.values.foreach(_.delete())
+    cacheDirName.foreach { name =>
+      for (cacheFile <- new File(name).listFiles()) {
+        if (!cacheFile.delete()) {
+          throw new IllegalStateException(s"failed to delete $cacheFile")
+        }
+      }
+    }
   }
 
   override def toString: String = toMap.toString
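The cleanup added to delete() can be exercised in isolation. A minimal sketch of the same loop under a hypothetical helper name, with a null guard added here because File.listFiles() returns null for a path that is not an existing directory (the patch itself does not guard for that case):

    import java.io.File

    def deleteCacheFiles(dirName: String): Unit = {
      val files = Option(new File(dirName).listFiles()).getOrElse(Array.empty[File])
      for (cacheFile <- files) {
        if (!cacheFile.delete()) {
          throw new IllegalStateException(s"failed to delete $cacheFile")
        }
      }
    }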
@@ -462,7 +477,7 @@ private object Watches {
       params: Map[String, Any],
       labeledPoints: Iterator[XGBLabeledPoint],
       baseMarginsOpt: Option[Array[Float]],
-      cacheFileName: String): Watches = {
+      cacheDirName: Option[String]): Watches = {
     val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
     val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
     val r = new Random(seed)
@@ -475,8 +490,8 @@ private object Watches {
       accepted
     }
 
-    val trainMatrix = new DMatrix(trainPoints, cacheFileName)
-    val testMatrix = new DMatrix(testPoints.iterator, cacheFileName)
+    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
+    val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
     r.setSeed(seed)
     for (baseMargins <- baseMarginsOpt) {
       val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
@@ -489,6 +504,6 @@ private object Watches {
       trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
         TaskContext.getPartitionId()).toArray)
     }
-    new Watches(train = trainMatrix, test = testMatrix)
+    new Watches(trainMatrix, testMatrix, cacheDirName)
   }
 }

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala

@@ -18,15 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DataTypes
 import org.scalatest.FunSuite
+import org.scalatest.prop.TableDrivenPropertyChecks
 
-class XGBoostDFSuite extends FunSuite with PerTest {
+class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks {
 
   private def buildDataFrame(
       labeledPoints: Seq[XGBLabeledPoint],
       numPartitions: Int = numWorkers): DataFrame = {
@@ -252,12 +252,14 @@ class XGBoostDFSuite extends FunSuite with PerTest {
   test("train/test split") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
     val trainingDf = buildDataFrame(Classification.train)
 
-    val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
-      nWorkers = numWorkers)
-    val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
-    assert(testObjectiveHistory.length === 5)
-    assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
+    forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
+      val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
+        nWorkers = numWorkers, useExternalMemory = useExternalMemory)
+      val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
+      assert(testObjectiveHistory.length === 5)
+      assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
+    }
   }
 }
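The updated test now runs once per row of a ScalaTest table, covering both the in-memory and external-memory code paths. A self-contained sketch of the forAll/Table pattern used above (the println body is illustrative only):

    import org.scalatest.prop.TableDrivenPropertyChecks._

    // A single-column table: the body runs once per value, here with
    // external memory off and then on.
    forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
      println(s"running with useExternalMemory = $useExternalMemory")
    }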