[jvm-packages] fixed several issues in unit tests (#2173)

* add back train method but mark as deprecated

* fix scalastyle error

* change class to object in examples

* fix compilation error

* fix several issues in tests
This commit is contained in:
Nan Zhu 2017-04-06 06:25:23 -07:00 committed by GitHub
parent 2715baef64
commit 8d8cbcc6db
2 changed files with 26 additions and 7 deletions

View File

@ -222,22 +222,22 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
LabeledPoint(intValueArray.last, new DenseVector(intValueArray.take(intValueArray.length - 1))) LabeledPoint(intValueArray.last, new DenseVector(intValueArray.take(intValueArray.length - 1)))
} }
private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): List[LabeledPoint] = { private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): Seq[LabeledPoint] = {
val file = Source.fromFile(new File(filePath)) val file = Source.fromFile(new File(filePath))
val sampleList = new ListBuffer[LabeledPoint] val sampleList = new ListBuffer[LabeledPoint]
for (sample <- file.getLines()) { for (sample <- file.getLines()) {
sampleList += convertCSVPointToLabelPoint(sample.split(",")) sampleList += convertCSVPointToLabelPoint(sample.split(","))
} }
sampleList.toList sampleList
} }
test("multi_class classification test") { test("multi_class classification test") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6") "objective" -> "multi:softmax", "num_class" -> "6")
val testItr = loadCSVPoints(getClass.getResource("/dermatology.data").getFile).iterator val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile)
val trainingDF = buildTrainingDataframe() val spark = SparkSession.builder().getOrCreate()
XGBoost.trainWithDataFrame(trainingDF, paramMap, import spark.implicits._
round = 5, nWorkers = numWorkers) XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers)
} }
test("test DF use nested groupData") { test("test DF use nested groupData") {

View File

@ -16,16 +16,35 @@
package ml.dmlc.xgboost4j.scala.spark package ml.dmlc.xgboost4j.scala.spark
import java.io.{File, FileNotFoundException}
import org.apache.spark.SparkConf import org.apache.spark.SparkConf
import org.apache.spark.ml.feature._ import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.SparkSession
import scala.concurrent.duration._ import scala.concurrent.duration._
case class Foobar(TARGET: Int, bar: Double, baz: Double) case class Foobar(TARGET: Int, bar: Double, baz: Double)
class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils { class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
/** Tears down the suite: lets the shared Spark context shut down first,
 *  then removes the pipeline-persistence directories the tests wrote to
 *  the working directory.
 */
override def afterAll(): Unit = {
  super.afterAll()
  Seq("./testxgbPipe", "./test2xgbPipe").foreach(path => delete(new File(path)))
}
/** Recursively deletes a file or directory tree produced as test output.
 *
 *  A non-existent path is a no-op, so cleanup does not fail when a test
 *  never created its output directory. (The original threw
 *  FileNotFoundException in that case because File.delete() returns false
 *  for missing paths.)
 *
 *  @param f file or directory to remove
 *  @throws FileNotFoundException if an existing path cannot be deleted
 */
private def delete(f: File): Unit = {
  if (f.exists()) {
    if (f.isDirectory()) {
      // listFiles() is non-null here because f exists and is a directory
      for (c <- f.listFiles()) {
        delete(c)
      }
    }
    if (!f.delete()) {
      throw new FileNotFoundException("Failed to delete file: " + f)
    }
  }
}
test("test sparks pipeline persistence of dataframe-based model") { test("test sparks pipeline persistence of dataframe-based model") {
// maybe move to shared context, but requires session to import implicits. // maybe move to shared context, but requires session to import implicits.
// what about introducing https://github.com/holdenk/spark-testing-base ? // what about introducing https://github.com/holdenk/spark-testing-base ?