[jvm-packages] jvm tests should clean up after themselves (#4706)

Oleksandr Pryimak 2019-08-04 14:09:11 -07:00 committed by Nan Zhu
parent 4fe0d8203e
commit b68de018b8
5 changed files with 62 additions and 53 deletions
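
The Scala test suites previously created temp directories via Files.createTempDirectory (or under java.io.tmpdir) and did not reliably delete them. A new TmpFolderPerSuite trait now allocates one temp directory per suite in beforeAll and deletes it recursively in afterAll; suites obtain scratch folders through createTmpFolder instead of calling Files.createTempDirectory directly. BoosterImplTest stops writing temp files altogether by round-tripping the saved model through in-memory byte streams.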

CheckpointManagerSuite.scala

@@ -17,12 +17,11 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.File
-import java.nio.file.Files
 
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.FunSuite
 import org.apache.hadoop.fs.{FileSystem, Path}
 
-class CheckpointManagerSuite extends FunSuite with PerTest with BeforeAndAfterAll {
+class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
 
   private lazy val (model4, model8) = {
     val training = buildDataFrame(Classification.train)
@@ -33,7 +32,7 @@ class CheckpointManagerSuite extends FunSuite with PerTest with BeforeAndAfterAll
   }
 
   test("test update/load models") {
-    val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
+    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
     val manager = new CheckpointManager(sc, tmpPath)
     manager.updateCheckpoint(model4._booster)
     var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
@@ -49,7 +48,7 @@ class CheckpointManagerSuite extends FunSuite with PerTest with BeforeAndAfterAll
   }
 
   test("test cleanUpHigherVersions") {
-    val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
+    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
    val manager = new CheckpointManager(sc, tmpPath)
     manager.updateCheckpoint(model8._booster)
     manager.cleanUpHigherVersions(round = 8)
@@ -60,7 +59,7 @@ class CheckpointManagerSuite extends FunSuite with PerTest with BeforeAndAfterAll
   }
 
   test("test checkpoint rounds") {
-    val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
+    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
     val manager = new CheckpointManager(sc, tmpPath)
     assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
     assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))

PersistenceSuite.scala

@@ -16,7 +16,7 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import java.io.{File, FileNotFoundException}
+import java.io.File
 import java.util.Arrays
 
 import scala.io.Source
@@ -26,40 +26,9 @@ import scala.util.Random
 
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.network.util.JavaUtils
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.FunSuite
 
-class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
+class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
 
-  private var tempDir: File = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    tempDir = new File(System.getProperty("java.io.tmpdir"), this.getClass.getName)
-    if (tempDir.exists) {
-      tempDir.delete
-    }
-    tempDir.mkdirs
-  }
-
-  override def afterAll(): Unit = {
-    JavaUtils.deleteRecursively(tempDir)
-    super.afterAll()
-  }
-
-  private def delete(f: File) {
-    if (f.exists) {
-      if (f.isDirectory) {
-        for (c <- f.listFiles) {
-          delete(c)
-        }
-      }
-      if (!f.delete) {
-        throw new FileNotFoundException("Failed to delete file: " + f)
-      }
-    }
-  }
-
   test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
     val eval = new EvalError()
@@ -69,7 +38,7 @@ class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
     val xgbc = new XGBoostClassifier(paramMap)
-    val xgbcPath = new File(tempDir, "xgbc").getPath
+    val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
     xgbc.write.overwrite().save(xgbcPath)
     val xgbc2 = XGBoostClassifier.load(xgbcPath)
     val paramMap2 = xgbc2.MLlib2XGBoostParams
@@ -80,7 +49,7 @@ class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     val model = xgbc.fit(trainingDF)
     val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
     assert(evalResults < 0.1)
-    val xgbcModelPath = new File(tempDir, "xgbcModel").getPath
+    val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
     model.write.overwrite.save(xgbcModelPath)
     val model2 = XGBoostClassificationModel.load(xgbcModelPath)
     assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
@@ -100,7 +69,7 @@ class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "reg:squarederror", "num_round" -> "10", "num_workers" -> numWorkers)
     val xgbr = new XGBoostRegressor(paramMap)
-    val xgbrPath = new File(tempDir, "xgbr").getPath
+    val xgbrPath = new File(tempDir.toFile, "xgbr").getPath
     xgbr.write.overwrite().save(xgbrPath)
     val xgbr2 = XGBoostRegressor.load(xgbrPath)
     val paramMap2 = xgbr2.MLlib2XGBoostParams
@@ -111,7 +80,7 @@ class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     val model = xgbr.fit(trainingDF)
     val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
     assert(evalResults < 0.1)
-    val xgbrModelPath = new File(tempDir, "xgbrModel").getPath
+    val xgbrModelPath = new File(tempDir.toFile, "xgbrModel").getPath
     model.write.overwrite.save(xgbrModelPath)
     val model2 = XGBoostRegressionModel.load(xgbrModelPath)
     assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
@@ -140,7 +109,7 @@ class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
 
     // Construct MLlib pipeline, save and load
     val pipeline = new Pipeline().setStages(Array(assembler, xgb))
-    val pipePath = new File(tempDir, "pipeline").getPath
+    val pipePath = new File(tempDir.toFile, "pipeline").getPath
     pipeline.write.overwrite().save(pipePath)
     val pipeline2 = Pipeline.read.load(pipePath)
     val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier]
@@ -151,7 +120,7 @@ class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
 
     // Model training, save and load
     val pipeModel = pipeline.fit(df)
-    val pipeModelPath = new File(tempDir, "pipelineModel").getPath
+    val pipeModelPath = new File(tempDir.toFile, "pipelineModel").getPath
     pipeModel.write.overwrite.save(pipeModelPath)
     val pipeModel2 = PipelineModel.load(pipeModelPath)

TmpFolderPerSuite.scala (new file)

@@ -0,0 +1,42 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.nio.file.{Files, Path}
+
+import org.apache.spark.network.util.JavaUtils
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+trait TmpFolderPerSuite extends BeforeAndAfterAll { self: FunSuite =>
+
+  protected var tempDir: Path = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    tempDir = Files.createTempDirectory(getClass.getName)
+  }
+
+  override def afterAll(): Unit = {
+    JavaUtils.deleteRecursively(tempDir.toFile)
+    super.afterAll()
+  }
+
+  protected def createTmpFolder(prefix: String): Path = {
+    Files.createTempDirectory(tempDir, prefix)
+  }
+}
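
For context, a minimal usage sketch of the new trait (the suite name and test body are hypothetical; the pattern mirrors how PersistenceSuite and CheckpointManagerSuite consume it). The self-type `self: FunSuite =>` means the trait can only be mixed into a FunSuite:

package ml.dmlc.xgboost4j.scala.spark

import java.nio.file.Files

import org.scalatest.FunSuite

class TmpFolderExampleSuite extends FunSuite with TmpFolderPerSuite {
  test("artifacts land in a managed scratch folder") {
    // createTmpFolder returns a fresh Path under the suite-level tempDir
    val out = createTmpFolder("model").resolve("model.bin")
    Files.write(out, Array[Byte](1, 2, 3))
    assert(Files.exists(out))
    // No manual cleanup needed: afterAll() deletes tempDir recursively.
  }
}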

XGBoostGeneralSuite.scala

@@ -30,7 +30,7 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.ml.feature.VectorAssembler
 
-class XGBoostGeneralSuite extends FunSuite with PerTest {
+class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
 
   test("distributed training with the specified worker number") {
     val trainingRDD = sc.parallelize(Classification.train)
@@ -184,7 +184,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
 
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
-    val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
     val paramMap = Map("eta" -> "1", "max_depth" -> 2,
       "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
       "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
@@ -211,7 +211,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
 
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
-    val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
     val paramMap = Map("eta" -> "1", "max_depth" -> 2,
       "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
       "checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true)

BoosterImplTest.java

@@ -126,11 +126,10 @@ public class BoosterImplTest {
     Booster booster = trainBooster(trainMat, testMat);
 
-    Path tempDir = Files.createTempDirectory("boosterTest-");
-    File tempFile = Files.createTempFile("", "").toFile();
-    booster.saveModel(new FileOutputStream(tempFile));
+    ByteArrayOutputStream output = new ByteArrayOutputStream();
+    booster.saveModel(output);
     IEvaluation eval = new EvalError();
-    Booster loadedBooster = XGBoost.loadModel(new FileInputStream(tempFile));
+    Booster loadedBooster = XGBoost.loadModel(new ByteArrayInputStream(output.toByteArray()));
 
     float originalPredictError = eval.eval(booster.predict(testMat, true), testMat);
     TestCase.assertTrue("originalPredictErr:" + originalPredictError,
         originalPredictError < 0.1f);
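
The Java test now avoids the filesystem entirely: the model is serialized to a ByteArrayOutputStream and read back from a ByteArrayInputStream, so save/load symmetry is still exercised with nothing left on disk. A minimal sketch of that round trip, written here in Scala against the Java classes (the object and method names are hypothetical; saveModel(OutputStream) and loadModel(InputStream) are the overloads used in the hunk above):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import ml.dmlc.xgboost4j.java.{Booster, XGBoost}

object ModelRoundTrip {
  // Serialize a trained booster to memory and load it back; no temp files.
  def roundTrip(booster: Booster): Booster = {
    val output = new ByteArrayOutputStream()
    booster.saveModel(output)
    XGBoost.loadModel(new ByteArrayInputStream(output.toByteArray))
  }
}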