diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala index 59b0fb1cd..e09f10a4c 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala @@ -16,14 +16,13 @@ package ml.dmlc.xgboost4j.scala.spark -import scala.util.Random - import ml.dmlc.xgboost4j.java.XGBoostError -import org.scalatest.FunSuite - import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.linalg.Vectors import org.apache.spark.sql.DataFrame +import org.scalatest.FunSuite + +import scala.util.Random class MissingValueHandlingSuite extends FunSuite with PerTest { test("dense vectors containing missing value") { @@ -61,7 +60,13 @@ class MissingValueHandlingSuite extends FunSuite with PerTest { val vectorAssembler = new VectorAssembler() .setInputCols(Array("col1", "col2", "col3")) .setOutputCol("features") - .setHandleInvalid("keep") + org.apache.spark.SPARK_VERSION match { + case version if version.startsWith("2.4") => + val m = vectorAssembler.getClass.getDeclaredMethods + .filter(_.getName.contains("setHandleInvalid")).head + m.invoke(vectorAssembler, "keep") + case _ => + } val inputDF = vectorAssembler.transform(testDF).select("features", "label") val paramMap = List("eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap