[jvm-packages] fix compatibility problem of spark version (#4411)

* fix compatibility problem of spark version on MissingValueHandlingSuite.scala

* call setHandleInvalid by runtime reflection
This commit is contained in:
Xu Xiao 2019-05-01 00:13:05 +08:00 committed by Nan Zhu
parent 253fdd8a42
commit 797ba8e72d

View File

@ -16,14 +16,13 @@
package ml.dmlc.xgboost4j.scala.spark package ml.dmlc.xgboost4j.scala.spark
import scala.util.Random
import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.java.XGBoostError
import org.scalatest.FunSuite
import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame import org.apache.spark.sql.DataFrame
import org.scalatest.FunSuite
import scala.util.Random
class MissingValueHandlingSuite extends FunSuite with PerTest { class MissingValueHandlingSuite extends FunSuite with PerTest {
test("dense vectors containing missing value") { test("dense vectors containing missing value") {
@ -61,7 +60,13 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
val vectorAssembler = new VectorAssembler() val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3")) .setInputCols(Array("col1", "col2", "col3"))
.setOutputCol("features") .setOutputCol("features")
.setHandleInvalid("keep") org.apache.spark.SPARK_VERSION match {
case version if version.startsWith("2.4") =>
val m = vectorAssembler.getClass.getDeclaredMethods
.filter(_.getName.contains("setHandleInvalid")).head
m.invoke(vectorAssembler, "keep")
case _ =>
}
val inputDF = vectorAssembler.transform(testDF).select("features", "label") val inputDF = vectorAssembler.transform(testDF).select("features", "label")
val paramMap = List("eta" -> "1", "max_depth" -> "2", val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap "objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap