[jvm-packages] fix compatibility problem of spark version (#4411)
* fix compatibility problem of spark version on MissingValueHandlingSuite.scala * call setHandleInvalid by runtime reflection
This commit is contained in:
parent
253fdd8a42
commit
797ba8e72d
@ -16,14 +16,13 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import scala.util.Random
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
import org.scalatest.FunSuite
|
|
||||||
|
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
import org.apache.spark.ml.feature.VectorAssembler
|
||||||
import org.apache.spark.ml.linalg.Vectors
|
import org.apache.spark.ml.linalg.Vectors
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.DataFrame
|
||||||
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
|
import scala.util.Random
|
||||||
|
|
||||||
class MissingValueHandlingSuite extends FunSuite with PerTest {
|
class MissingValueHandlingSuite extends FunSuite with PerTest {
|
||||||
test("dense vectors containing missing value") {
|
test("dense vectors containing missing value") {
|
||||||
@ -61,7 +60,13 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
|
|||||||
val vectorAssembler = new VectorAssembler()
|
val vectorAssembler = new VectorAssembler()
|
||||||
.setInputCols(Array("col1", "col2", "col3"))
|
.setInputCols(Array("col1", "col2", "col3"))
|
||||||
.setOutputCol("features")
|
.setOutputCol("features")
|
||||||
.setHandleInvalid("keep")
|
org.apache.spark.SPARK_VERSION match {
|
||||||
|
case version if version.startsWith("2.4") =>
|
||||||
|
val m = vectorAssembler.getClass.getDeclaredMethods
|
||||||
|
.filter(_.getName.contains("setHandleInvalid")).head
|
||||||
|
m.invoke(vectorAssembler, "keep")
|
||||||
|
case _ =>
|
||||||
|
}
|
||||||
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||||
"objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
|
"objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user