[jvm-packages] Fixed test/train persistence (#2949)
* [jvm-packages] Fixed test/train persistence

  Prior to this patch, both data sets were persisted in the same directory, i.e. the test data replaced the training data, which led to

  * training on less data (since usually test < train) and
  * the test loss being exactly equal to the training loss.

  Closes #2945.

* Clean up the file cache after training
* Addressed review comments
This commit is contained in:
@@ -18,15 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.DataTypes
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.prop.TableDrivenPropertyChecks
|
||||
|
||||
class XGBoostDFSuite extends FunSuite with PerTest {
|
||||
class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks {
|
||||
private def buildDataFrame(
|
||||
labeledPoints: Seq[XGBLabeledPoint],
|
||||
numPartitions: Int = numWorkers): DataFrame = {
|
||||
@@ -252,12 +252,14 @@ class XGBoostDFSuite extends FunSuite with PerTest {
|
||||
test("train/test split") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
|
||||
|
||||
val trainingDf = buildDataFrame(Classification.train)
|
||||
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
|
||||
nWorkers = numWorkers)
|
||||
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
|
||||
assert(testObjectiveHistory.length === 5)
|
||||
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
||||
|
||||
forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
|
||||
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = useExternalMemory)
|
||||
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
|
||||
assert(testObjectiveHistory.length === 5)
|
||||
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user