[jvm-packages] fixed several issues in unit tests (#2173)

* add back train method but mark as deprecated

* fix scalastyle error

* change class to object in examples

* fix compilation error

* fix several issues in tests
Nan Zhu, 2017-04-06 06:25:23 -07:00 (committed by GitHub)
parent 2715baef64 · commit 8d8cbcc6db
2 changed files with 26 additions and 7 deletions
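
The "train" deprecation from the first bullet is not visible in the two test files below, which only touch tests. As a minimal sketch of the mechanism it describes, with simplified signatures assumed for illustration rather than taken from the real xgboost4j-spark API:

object DeprecationSketch {
  // A re-added method marked deprecated: existing callers keep compiling,
  // but scalac warns and points them at the replacement entry point.
  @deprecated("Use trainWithDataFrame instead", "0.7")
  def train(params: Map[String, Any], round: Int): Unit =
    trainWithDataFrame(params, round)

  def trainWithDataFrame(params: Map[String, Any], round: Int): Unit = {
    // stand-in for the actual training path
  }
}

Compiling a call to train with -deprecation then surfaces the warning without breaking source compatibility.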

ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala

@@ -222,22 +222,22 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
     LabeledPoint(intValueArray.last, new DenseVector(intValueArray.take(intValueArray.length - 1)))
   }
 
-  private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): List[LabeledPoint] = {
+  private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): Seq[LabeledPoint] = {
     val file = Source.fromFile(new File(filePath))
     val sampleList = new ListBuffer[LabeledPoint]
     for (sample <- file.getLines()) {
       sampleList += convertCSVPointToLabelPoint(sample.split(","))
     }
-    sampleList.toList
+    sampleList
   }
 
   test("multi_class classification test") {
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "multi:softmax", "num_class" -> "6")
-    val testItr = loadCSVPoints(getClass.getResource("/dermatology.data").getFile).iterator
-    val trainingDF = buildTrainingDataframe()
-    XGBoost.trainWithDataFrame(trainingDF, paramMap,
-      round = 5, nWorkers = numWorkers)
+    val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile)
+    val spark = SparkSession.builder().getOrCreate()
+    import spark.implicits._
+    XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers)
  }
 
   test("test DF use nested groupData") {

ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala

@@ -16,16 +16,35 @@
 package ml.dmlc.xgboost4j.scala.spark
 
+import java.io.{File, FileNotFoundException}
+
 import org.apache.spark.SparkConf
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.SparkSession
 
 import scala.concurrent.duration._
 
 case class Foobar(TARGET: Int, bar: Double, baz: Double)
 
 class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
 
+  override def afterAll(): Unit = {
+    super.afterAll()
+    delete(new File("./testxgbPipe"))
+    delete(new File("./test2xgbPipe"))
+  }
+
+  private def delete(f: File) {
+    if (f.isDirectory()) {
+      for (c <- f.listFiles()) {
+        delete(c)
+      }
+    }
+    if (!f.delete()) {
+      throw new FileNotFoundException("Failed to delete file: " + f)
+    }
+  }
+
   test("test sparks pipeline persistence of dataframe-based model") {
     // maybe move to shared context, but requires session to import implicits.
     // what about introducing https://github.com/holdenk/spark-testing-base ?
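
The comment in that new test leaves session sharing as an open question. One way to answer it, sketched under the assumption that the suites keep using ScalaTest (which they already do via SharedSparkContext); the trait name and settings here are illustrative, not from the repository:

import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, Suite}

trait SharedSparkSession extends BeforeAndAfterAll { self: Suite =>

  // One session per suite; tests can `import spark.implicits._` from this field.
  @transient lazy val spark: SparkSession = SparkSession
    .builder()
    .master("local[2]")
    .appName(suiteName)
    .getOrCreate()

  override def afterAll(): Unit = {
    try spark.stop()
    finally super.afterAll()
  }
}

The spark-testing-base project mentioned in the comment packages the same idea in its DataFrameSuiteBase, along with DataFrame equality helpers.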