Scala 2.13 support. (#9099)

1. Updated the test logic
2. Added smoke tests for Spark examples.
3. Added integration tests for Spark with Scala 2.13
This commit is contained in:
Boris
2023-05-27 13:34:02 +02:00
committed by GitHub
parent 8c174ef2d3
commit a01df102c9
24 changed files with 325 additions and 160 deletions

View File

@@ -5,10 +5,11 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<artifactId>xgboost-jvm</artifactId>
<version>2.0.0-SNAPSHOT</version>
</parent>
<artifactId>xgboost4j_2.12</artifactId>
<name>xgboost4j</name>
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
<packaging>jar</packaging>
@@ -28,13 +29,13 @@
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.2</version>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.2.16</version>
<version>${scalatest.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>

View File

@@ -37,7 +37,7 @@ trait EvalTrait extends IEvaluation {
*/
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
private[scala] def eval(predicts: Array[Array[Float]], jdmat: java.DMatrix): Float = {
def eval(predicts: Array[Array[Float]], jdmat: java.DMatrix): Float = {
require(predicts.length == jdmat.getLabel.length, "predicts size and label size must match " +
s" predicts size: ${predicts.length}, label size: ${jdmat.getLabel.length}")
eval(predicts, new DMatrix(jdmat))

View File

@@ -31,7 +31,7 @@ trait ObjectiveTrait extends IObjective {
*/
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): List[Array[Float]]
private[scala] def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix):
def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix):
java.util.List[Array[Float]] = {
getGradient(predicts, new DMatrix(dtrain)).asJava
}

View File

@@ -17,12 +17,11 @@
package ml.dmlc.xgboost4j.scala
import java.io.InputStream
import ml.dmlc.xgboost4j.java.{XGBoostError, XGBoost => JXGBoost}
import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
import scala.collection.JavaConverters._
import scala.jdk.CollectionConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.fs.Path
/**
* XGBoost Scala Training function.
@@ -40,7 +39,12 @@ object XGBoost {
earlyStoppingRound: Int = 0,
prevBooster: Booster,
checkpointParams: Option[ExternalCheckpointParams]): Booster = {
val jWatches = watches.mapValues(_.jDMatrix).asJava
// we have to filter null value for customized obj and eval
val jParams: java.util.Map[String, AnyRef] =
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).toMap.asJava
val jWatches = watches.mapValues(_.jDMatrix).toMap.asJava
val jBooster = if (prevBooster == null) {
null
} else {
@@ -51,8 +55,7 @@ object XGBoost {
map(cp => {
JXGBoost.trainAndSaveCheckpoint(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
jParams,
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
cp.checkpointInterval,
cp.checkpointPath,
@@ -61,8 +64,7 @@ object XGBoost {
getOrElse(
JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
jParams,
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
)
if (prevBooster == null) {