[jvm-packages] support spark 2.4 and compatibility test with previous xgboost version (#4377)

* bump spark version

* keep float.nan

* handle brokenly changed name/value

* add test

* add model files

* add model files

* update doc
This commit is contained in:
Nan Zhu
2019-04-17 11:33:13 -07:00
committed by GitHub
parent 711397d645
commit 65db8d0626
8 changed files with 36 additions and 7 deletions

View File

@@ -19,9 +19,11 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.{File, FileNotFoundException}
import java.util.Arrays
import ml.dmlc.xgboost4j.scala.DMatrix
import scala.io.Source
import ml.dmlc.xgboost4j.scala.DMatrix
import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.network.util.JavaUtils
@@ -162,5 +164,17 @@ class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
assert(xgbModel.getNumRound === xgbModel2.getNumRound)
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
}
test("cross-version model loading (0.82)") {
val modelPath = getClass.getResource("/model/0.82/model").getPath
val model = XGBoostClassificationModel.read.load(modelPath)
val r = new Random(0)
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features")
model.transform(assembler.transform(df)).show()
}
}

View File

@@ -261,6 +261,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3"))
.setOutputCol("features")
.setHandleInvalid("keep")
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "num_workers" -> 1).toMap