[jvm-packages] update spark dependency to 3.0.0 (#5836)
This commit is contained in:
parent
23e2c6ec91
commit
9f85e92602
@ -34,7 +34,7 @@
|
|||||||
<maven.compiler.source>1.8</maven.compiler.source>
|
<maven.compiler.source>1.8</maven.compiler.source>
|
||||||
<maven.compiler.target>1.8</maven.compiler.target>
|
<maven.compiler.target>1.8</maven.compiler.target>
|
||||||
<flink.version>1.7.2</flink.version>
|
<flink.version>1.7.2</flink.version>
|
||||||
<spark.version>2.4.3</spark.version>
|
<spark.version>3.0.0</spark.version>
|
||||||
<scala.version>2.12.8</scala.version>
|
<scala.version>2.12.8</scala.version>
|
||||||
<scala.binary.version>2.12</scala.binary.version>
|
<scala.binary.version>2.12</scala.binary.version>
|
||||||
<hadoop.version>2.7.3</hadoop.version>
|
<hadoop.version>2.7.3</hadoop.version>
|
||||||
|
|||||||
@ -275,7 +275,7 @@ class XGBoostClassificationModel private[ml](
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Actually we don't use this function at all, to make it pass compiler check.
|
// Actually we don't use this function at all, to make it pass compiler check.
|
||||||
override protected def predictRaw(features: Vector): Vector = {
|
override def predictRaw(features: Vector): Vector = {
|
||||||
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
|
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -60,13 +60,8 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
|
|||||||
val vectorAssembler = new VectorAssembler()
|
val vectorAssembler = new VectorAssembler()
|
||||||
.setInputCols(Array("col1", "col2", "col3"))
|
.setInputCols(Array("col1", "col2", "col3"))
|
||||||
.setOutputCol("features")
|
.setOutputCol("features")
|
||||||
org.apache.spark.SPARK_VERSION match {
|
.setHandleInvalid("keep")
|
||||||
case version if version.startsWith("2.4") =>
|
|
||||||
val m = vectorAssembler.getClass.getDeclaredMethods
|
|
||||||
.filter(_.getName.contains("setHandleInvalid")).head
|
|
||||||
m.invoke(vectorAssembler, "keep")
|
|
||||||
case _ =>
|
|
||||||
}
|
|
||||||
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||||
"objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
|
"objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
|
||||||
|
|||||||
@ -127,7 +127,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
|
||||||
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
|
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f})
|
||||||
val trainingDF = buildDataFrame(Regression.train)
|
val trainingDF = buildDataFrame(Regression.train)
|
||||||
.withColumn("weight", getWeightFromId(col("id")))
|
.withColumn("weight", getWeightFromId(col("id")))
|
||||||
val testDF = buildDataFrame(Regression.test)
|
val testDF = buildDataFrame(Regression.test)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user