[jvm-packages] fixed several issues in unit tests (#2173)
* add back train method but mark as deprecated
* fix scalastyle error
* change class to object in examples
* fix compilation error
* fix several issues in tests
This commit is contained in:
parent
2715baef64
commit
8d8cbcc6db
@ -222,22 +222,22 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
LabeledPoint(intValueArray.last, new DenseVector(intValueArray.take(intValueArray.length - 1)))
|
||||
}
|
||||
|
||||
private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): List[LabeledPoint] = {
|
||||
private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): Seq[LabeledPoint] = {
|
||||
val file = Source.fromFile(new File(filePath))
|
||||
val sampleList = new ListBuffer[LabeledPoint]
|
||||
for (sample <- file.getLines()) {
|
||||
sampleList += convertCSVPointToLabelPoint(sample.split(","))
|
||||
}
|
||||
sampleList.toList
|
||||
sampleList
|
||||
}
|
||||
|
||||
test("multi_class classification test") {
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||
val testItr = loadCSVPoints(getClass.getResource("/dermatology.data").getFile).iterator
|
||||
val trainingDF = buildTrainingDataframe()
|
||||
XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = 5, nWorkers = numWorkers)
|
||||
val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile)
|
||||
val spark = SparkSession.builder().getOrCreate()
|
||||
import spark.implicits._
|
||||
XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers)
|
||||
}
|
||||
|
||||
test("test DF use nested groupData") {
|
||||
|
||||
@ -16,16 +16,35 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.{File, FileNotFoundException}
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.ml.feature._
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
|
||||
import scala.concurrent.duration._
|
||||
|
||||
case class Foobar(TARGET: Int, bar: Double, baz: Double)
|
||||
|
||||
class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
|
||||
|
||||
/** Runs the parent teardown, then removes the pipeline directories named below. */
override def afterAll(): Unit = {
  super.afterAll()
  // Paths are removed in this order; a failure on the first one skips the second.
  Seq("./testxgbPipe", "./test2xgbPipe").foreach { path =>
    delete(new File(path))
  }
}
|
||||
|
||||
/**
 * Recursively deletes `f`: when it is a directory, every entry beneath it is
 * deleted first, then `f` itself is removed.
 *
 * @param f file or directory to remove
 * @throws FileNotFoundException if `f` (or any entry beneath it) could not be
 *                               deleted, including when it does not exist
 */
private def delete(f: File): Unit = {
  if (f.isDirectory) {
    // File.listFiles() returns null (not an empty array) when the directory
    // cannot be listed; Option(...) turns that into a no-op instead of an NPE.
    Option(f.listFiles()).foreach(_.foreach(delete))
  }
  if (!f.delete()) {
    throw new FileNotFoundException("Failed to delete file: " + f)
  }
}
|
||||
|
||||
test("test sparks pipeline persistence of dataframe-based model") {
|
||||
// maybe move to shared context, but requires session to import implicits.
|
||||
// what about introducing https://github.com/holdenk/spark-testing-base ?
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user