[jvm-packages] update spark dependency to 3.0.0 (#5836)
This commit is contained in:
parent
23e2c6ec91
commit
9f85e92602
@ -34,7 +34,7 @@
|
||||
<maven.compiler.source>1.8</maven.compiler.source>
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
<flink.version>1.7.2</flink.version>
|
||||
<spark.version>2.4.3</spark.version>
|
||||
<spark.version>3.0.0</spark.version>
|
||||
<scala.version>2.12.8</scala.version>
|
||||
<scala.binary.version>2.12</scala.binary.version>
|
||||
<hadoop.version>2.7.3</hadoop.version>
|
||||
|
||||
@ -275,7 +275,7 @@ class XGBoostClassificationModel private[ml](
|
||||
}
|
||||
|
||||
// Actually we don't use this function at all, to make it pass compiler check.
|
||||
override protected def predictRaw(features: Vector): Vector = {
|
||||
override def predictRaw(features: Vector): Vector = {
|
||||
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
|
||||
}
|
||||
|
||||
|
||||
@ -60,13 +60,8 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
|
||||
val vectorAssembler = new VectorAssembler()
|
||||
.setInputCols(Array("col1", "col2", "col3"))
|
||||
.setOutputCol("features")
|
||||
org.apache.spark.SPARK_VERSION match {
|
||||
case version if version.startsWith("2.4") =>
|
||||
val m = vectorAssembler.getClass.getDeclaredMethods
|
||||
.filter(_.getName.contains("setHandleInvalid")).head
|
||||
m.invoke(vectorAssembler, "keep")
|
||||
case _ =>
|
||||
}
|
||||
.setHandleInvalid("keep")
|
||||
|
||||
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||
"objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
|
||||
|
||||
@ -127,7 +127,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||
|
||||
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
|
||||
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f})
|
||||
val trainingDF = buildDataFrame(Regression.train)
|
||||
.withColumn("weight", getWeightFromId(col("id")))
|
||||
val testDF = buildDataFrame(Regression.test)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user