[jvm-packages] Fixed test/train persistence (#2949)
* [jvm-packages] Fixed test/train persistence

  Prior to this patch both data sets were persisted in the same directory,
  i.e. the test data replaced the training one, which led to

  * training on less data (since usually test < train) and
  * test loss being exactly equal to the training loss.

  Closes #2945.

* Cleanup file cache after the training

* Addressed review comments
This commit is contained in:
parent
7759ab99ee
commit
7c6673cb9e
@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import java.io.File
|
||||||
|
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
@ -112,7 +114,6 @@ object XGBoost extends Serializable {
|
|||||||
data
|
data
|
||||||
}
|
}
|
||||||
val partitionedBaseMargin = partitionedData.map(_.baseMargin)
|
val partitionedBaseMargin = partitionedData.map(_.baseMargin)
|
||||||
val appName = partitionedData.context.appName
|
|
||||||
// to workaround the empty partitions in training dataset,
|
// to workaround the empty partitions in training dataset,
|
||||||
// this might not be the best efficient implementation, see
|
// this might not be the best efficient implementation, see
|
||||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||||
@ -122,17 +123,21 @@ object XGBoost extends Serializable {
|
|||||||
s"detected an empty partition in the training data, partition ID:" +
|
s"detected an empty partition in the training data, partition ID:" +
|
||||||
s" ${TaskContext.getPartitionId()}")
|
s" ${TaskContext.getPartitionId()}")
|
||||||
}
|
}
|
||||||
val cacheFileName = if (useExternalMemory) {
|
val taskId = TaskContext.getPartitionId().toString
|
||||||
s"$appName-${TaskContext.get().stageId()}-" +
|
val cacheDirName = if (useExternalMemory) {
|
||||||
s"dtrain_cache-${TaskContext.getPartitionId()}"
|
val dir = new File(s"${TaskContext.get().stageId()}-cache-$taskId")
|
||||||
|
if (!(dir.exists() || dir.mkdirs())) {
|
||||||
|
throw new XGBoostError(s"failed to create cache directory: $dir")
|
||||||
|
}
|
||||||
|
Some(dir.toString)
|
||||||
} else {
|
} else {
|
||||||
null
|
None
|
||||||
}
|
}
|
||||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||||
Rabit.init(rabitEnv)
|
Rabit.init(rabitEnv)
|
||||||
val watches = Watches(params,
|
val watches = Watches(params,
|
||||||
removeMissingValues(labeledPoints, missing),
|
removeMissingValues(labeledPoints, missing),
|
||||||
fromBaseMarginsToArray(baseMargins), cacheFileName)
|
fromBaseMarginsToArray(baseMargins), cacheDirName)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
|
val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
|
||||||
@ -442,7 +447,10 @@ object XGBoost extends Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private class Watches private(val train: DMatrix, val test: DMatrix) {
|
private class Watches private(
|
||||||
|
val train: DMatrix,
|
||||||
|
val test: DMatrix,
|
||||||
|
private val cacheDirName: Option[String]) {
|
||||||
|
|
||||||
def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
|
def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
|
||||||
.filter { case (_, matrix) => matrix.rowNum > 0 }
|
.filter { case (_, matrix) => matrix.rowNum > 0 }
|
||||||
@ -451,6 +459,13 @@ private class Watches private(val train: DMatrix, val test: DMatrix) {
|
|||||||
|
|
||||||
def delete(): Unit = {
|
def delete(): Unit = {
|
||||||
toMap.values.foreach(_.delete())
|
toMap.values.foreach(_.delete())
|
||||||
|
cacheDirName.foreach { name =>
|
||||||
|
for (cacheFile <- new File(name).listFiles()) {
|
||||||
|
if (!cacheFile.delete()) {
|
||||||
|
throw new IllegalStateException(s"failed to delete $cacheFile")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString: String = toMap.toString
|
override def toString: String = toMap.toString
|
||||||
@ -462,7 +477,7 @@ private object Watches {
|
|||||||
params: Map[String, Any],
|
params: Map[String, Any],
|
||||||
labeledPoints: Iterator[XGBLabeledPoint],
|
labeledPoints: Iterator[XGBLabeledPoint],
|
||||||
baseMarginsOpt: Option[Array[Float]],
|
baseMarginsOpt: Option[Array[Float]],
|
||||||
cacheFileName: String): Watches = {
|
cacheDirName: Option[String]): Watches = {
|
||||||
val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
|
val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
|
||||||
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
|
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
|
||||||
val r = new Random(seed)
|
val r = new Random(seed)
|
||||||
@ -475,8 +490,8 @@ private object Watches {
|
|||||||
|
|
||||||
accepted
|
accepted
|
||||||
}
|
}
|
||||||
val trainMatrix = new DMatrix(trainPoints, cacheFileName)
|
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
|
||||||
val testMatrix = new DMatrix(testPoints.iterator, cacheFileName)
|
val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
|
||||||
r.setSeed(seed)
|
r.setSeed(seed)
|
||||||
for (baseMargins <- baseMarginsOpt) {
|
for (baseMargins <- baseMarginsOpt) {
|
||||||
val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
|
val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
|
||||||
@ -489,6 +504,6 @@ private object Watches {
|
|||||||
trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
|
trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
|
||||||
TaskContext.getPartitionId()).toArray)
|
TaskContext.getPartitionId()).toArray)
|
||||||
}
|
}
|
||||||
new Watches(train = trainMatrix, test = testMatrix)
|
new Watches(trainMatrix, testMatrix, cacheDirName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,15 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
|
|
||||||
import org.apache.spark.ml.linalg.DenseVector
|
import org.apache.spark.ml.linalg.DenseVector
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.DataTypes
|
import org.apache.spark.sql.types.DataTypes
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
|
import org.scalatest.prop.TableDrivenPropertyChecks
|
||||||
|
|
||||||
class XGBoostDFSuite extends FunSuite with PerTest {
|
class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks {
|
||||||
private def buildDataFrame(
|
private def buildDataFrame(
|
||||||
labeledPoints: Seq[XGBLabeledPoint],
|
labeledPoints: Seq[XGBLabeledPoint],
|
||||||
numPartitions: Int = numWorkers): DataFrame = {
|
numPartitions: Int = numWorkers): DataFrame = {
|
||||||
@ -252,12 +252,14 @@ class XGBoostDFSuite extends FunSuite with PerTest {
|
|||||||
test("train/test split") {
|
test("train/test split") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
|
"objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
|
||||||
|
|
||||||
val trainingDf = buildDataFrame(Classification.train)
|
val trainingDf = buildDataFrame(Classification.train)
|
||||||
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
|
|
||||||
nWorkers = numWorkers)
|
forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
|
||||||
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
|
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
|
||||||
assert(testObjectiveHistory.length === 5)
|
nWorkers = numWorkers, useExternalMemory = useExternalMemory)
|
||||||
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
|
||||||
|
assert(testObjectiveHistory.length === 5)
|
||||||
|
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user