[jvm-packages] support spark 2.4 and compatibility test with previous xgboost version (#4377)
* bump spark version * keep float.nan * handle brokenly changed name/value * add test * add model files * add model files * update doc
This commit is contained in:
@@ -19,9 +19,11 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import java.io.{File, FileNotFoundException}
|
||||
import java.util.Arrays
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import scala.io.Source
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.ml.feature._
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.network.util.JavaUtils
|
||||
@@ -162,5 +164,17 @@ class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
|
||||
assert(xgbModel.getNumRound === xgbModel2.getNumRound)
|
||||
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
|
||||
}
|
||||
|
||||
test("cross-version model loading (0.82)") {
|
||||
val modelPath = getClass.getResource("/model/0.82/model").getPath
|
||||
val model = XGBoostClassificationModel.read.load(modelPath)
|
||||
val r = new Random(0)
|
||||
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
|
||||
toDF("feature", "label")
|
||||
val assembler = new VectorAssembler()
|
||||
.setInputCols(df.columns.filter(!_.contains("label")))
|
||||
.setOutputCol("features")
|
||||
model.transform(assembler.transform(df)).show()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -261,6 +261,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
val vectorAssembler = new VectorAssembler()
|
||||
.setInputCols(Array("col1", "col2", "col3"))
|
||||
.setOutputCol("features")
|
||||
.setHandleInvalid("keep")
|
||||
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||
"objective" -> "binary:logistic", "num_workers" -> 1).toMap
|
||||
|
||||
Reference in New Issue
Block a user