diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
index f8098c0de..4afa4cddc 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
@@ -222,22 +222,22 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
     LabeledPoint(intValueArray.last,
       new DenseVector(intValueArray.take(intValueArray.length - 1)))
   }
-  private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): List[LabeledPoint] = {
+  private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): Seq[LabeledPoint] = {
     val file = Source.fromFile(new File(filePath))
     val sampleList = new ListBuffer[LabeledPoint]
     for (sample <- file.getLines()) {
       sampleList += convertCSVPointToLabelPoint(sample.split(","))
     }
-    sampleList.toList
+    sampleList
   }
 
   test("multi_class classification test") {
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "multi:softmax", "num_class" -> "6")
-    val testItr = loadCSVPoints(getClass.getResource("/dermatology.data").getFile).iterator
-    val trainingDF = buildTrainingDataframe()
-    XGBoost.trainWithDataFrame(trainingDF, paramMap,
-      round = 5, nWorkers = numWorkers)
+    val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile)
+    val spark = SparkSession.builder().getOrCreate()
+    import spark.implicits._
+    XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers)
   }
 
   test("test DF use nested groupData") {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala
index 9f339f135..07a289528 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala
@@ -16,16 +16,35 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
+import java.io.{File, FileNotFoundException}
+
 import org.apache.spark.SparkConf
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.SparkSession
-
 import scala.concurrent.duration._
 
 case class Foobar(TARGET: Int, bar: Double, baz: Double)
 
 class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    delete(new File("./testxgbPipe"))
+    delete(new File("./test2xgbPipe"))
+  }
+
+  private def delete(f: File) {
+    if (f.isDirectory()) {
+      for (c <- f.listFiles()) {
+        delete(c)
+      }
+    }
+    if (!f.delete()) {
+      throw new FileNotFoundException("Failed to delete file: " + f)
+    }
+  }
+
   test("test sparks pipeline persistence of dataframe-based model") {
     // maybe move to shared context, but requires session to import implicits.
     // what about introducing https://github.com/holdenk/spark-testing-base ?