[jvm-packages] Fixed test/train persistence (#2949)
* [jvm-packages] Fixed test/train persistence Prior to this patch both data sets were persisted in the same directory, i.e. the test data replaced the training one which led to * training on less data (since usually test < train) and * test loss being exactly equal to the training loss. Closes #2945. * Cleanup file cache after the training * Addressed review comments
This commit is contained in:
parent
7759ab99ee
commit
7c6673cb9e
@ -16,6 +16,8 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
|
||||
@ -112,7 +114,6 @@ object XGBoost extends Serializable {
|
||||
data
|
||||
}
|
||||
val partitionedBaseMargin = partitionedData.map(_.baseMargin)
|
||||
val appName = partitionedData.context.appName
|
||||
// to workaround the empty partitions in training dataset,
|
||||
// this might not be the best efficient implementation, see
|
||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||
@ -122,17 +123,21 @@ object XGBoost extends Serializable {
|
||||
s"detected an empty partition in the training data, partition ID:" +
|
||||
s" ${TaskContext.getPartitionId()}")
|
||||
}
|
||||
val cacheFileName = if (useExternalMemory) {
|
||||
s"$appName-${TaskContext.get().stageId()}-" +
|
||||
s"dtrain_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
val taskId = TaskContext.getPartitionId().toString
|
||||
val cacheDirName = if (useExternalMemory) {
|
||||
val dir = new File(s"${TaskContext.get().stageId()}-cache-$taskId")
|
||||
if (!(dir.exists() || dir.mkdirs())) {
|
||||
throw new XGBoostError(s"failed to create cache directory: $dir")
|
||||
}
|
||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||
Some(dir.toString)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||
Rabit.init(rabitEnv)
|
||||
val watches = Watches(params,
|
||||
removeMissingValues(labeledPoints, missing),
|
||||
fromBaseMarginsToArray(baseMargins), cacheFileName)
|
||||
fromBaseMarginsToArray(baseMargins), cacheDirName)
|
||||
|
||||
try {
|
||||
val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
|
||||
@ -442,7 +447,10 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
private class Watches private(val train: DMatrix, val test: DMatrix) {
|
||||
private class Watches private(
|
||||
val train: DMatrix,
|
||||
val test: DMatrix,
|
||||
private val cacheDirName: Option[String]) {
|
||||
|
||||
def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
|
||||
.filter { case (_, matrix) => matrix.rowNum > 0 }
|
||||
@ -451,6 +459,13 @@ private class Watches private(val train: DMatrix, val test: DMatrix) {
|
||||
|
||||
def delete(): Unit = {
|
||||
toMap.values.foreach(_.delete())
|
||||
cacheDirName.foreach { name =>
|
||||
for (cacheFile <- new File(name).listFiles()) {
|
||||
if (!cacheFile.delete()) {
|
||||
throw new IllegalStateException(s"failed to delete $cacheFile")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def toString: String = toMap.toString
|
||||
@ -462,7 +477,7 @@ private object Watches {
|
||||
params: Map[String, Any],
|
||||
labeledPoints: Iterator[XGBLabeledPoint],
|
||||
baseMarginsOpt: Option[Array[Float]],
|
||||
cacheFileName: String): Watches = {
|
||||
cacheDirName: Option[String]): Watches = {
|
||||
val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
|
||||
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
|
||||
val r = new Random(seed)
|
||||
@ -475,8 +490,8 @@ private object Watches {
|
||||
|
||||
accepted
|
||||
}
|
||||
val trainMatrix = new DMatrix(trainPoints, cacheFileName)
|
||||
val testMatrix = new DMatrix(testPoints.iterator, cacheFileName)
|
||||
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
|
||||
val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
|
||||
r.setSeed(seed)
|
||||
for (baseMargins <- baseMarginsOpt) {
|
||||
val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
|
||||
@ -489,6 +504,6 @@ private object Watches {
|
||||
trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
|
||||
TaskContext.getPartitionId()).toArray)
|
||||
}
|
||||
new Watches(train = trainMatrix, test = testMatrix)
|
||||
new Watches(trainMatrix, testMatrix, cacheDirName)
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,15 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.DataTypes
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.prop.TableDrivenPropertyChecks
|
||||
|
||||
class XGBoostDFSuite extends FunSuite with PerTest {
|
||||
class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks {
|
||||
private def buildDataFrame(
|
||||
labeledPoints: Seq[XGBLabeledPoint],
|
||||
numPartitions: Int = numWorkers): DataFrame = {
|
||||
@ -252,12 +252,14 @@ class XGBoostDFSuite extends FunSuite with PerTest {
|
||||
test("train/test split") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
|
||||
|
||||
val trainingDf = buildDataFrame(Classification.train)
|
||||
|
||||
forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
|
||||
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
|
||||
nWorkers = numWorkers)
|
||||
nWorkers = numWorkers, useExternalMemory = useExternalMemory)
|
||||
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
|
||||
assert(testObjectiveHistory.length === 5)
|
||||
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user