[jvm-packages] Fixed test/train persistence (#2949)

* [jvm-packages] Fixed test/train persistence

Prior to this patch both data sets were persisted in the same directory,
i.e. the test data replaced the training one which led to

* training on less data (since usually test < train) and
* test loss being exactly equal to the training loss.

Closes #2945.

* Cleanup file cache after the training

* Addressed review comments
This commit is contained in:
Sergei Lebedev
2017-12-19 16:11:48 +01:00
committed by Nan Zhu
parent 7759ab99ee
commit 7c6673cb9e
2 changed files with 37 additions and 20 deletions

View File

@@ -18,15 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import org.scalatest.FunSuite
import org.scalatest.prop.TableDrivenPropertyChecks
class XGBoostDFSuite extends FunSuite with PerTest {
class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks {
private def buildDataFrame(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
@@ -252,12 +252,14 @@ class XGBoostDFSuite extends FunSuite with PerTest {
test("train/test split") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
val trainingDf = buildDataFrame(Classification.train)
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
nWorkers = numWorkers)
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
assert(testObjectiveHistory.length === 5)
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = useExternalMemory)
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
assert(testObjectiveHistory.length === 5)
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
}
}
}