Scala 2.13 support. (#9099)
1. Updated the test logic 2. Added smoke tests for Spark examples. 3. Added integration tests for Spark with Scala 2.13
This commit is contained in:
@@ -5,10 +5,11 @@
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||
<artifactId>xgboost-jvm</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j_2.12</artifactId>
|
||||
<name>xgboost4j</name>
|
||||
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
@@ -28,13 +29,13 @@
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>4.13.2</version>
|
||||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest_${scala.binary.version}</artifactId>
|
||||
<version>3.2.16</version>
|
||||
<version>${scalatest.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
@@ -37,7 +37,7 @@ trait EvalTrait extends IEvaluation {
|
||||
*/
|
||||
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
||||
|
||||
private[scala] def eval(predicts: Array[Array[Float]], jdmat: java.DMatrix): Float = {
|
||||
def eval(predicts: Array[Array[Float]], jdmat: java.DMatrix): Float = {
|
||||
require(predicts.length == jdmat.getLabel.length, "predicts size and label size must match " +
|
||||
s" predicts size: ${predicts.length}, label size: ${jdmat.getLabel.length}")
|
||||
eval(predicts, new DMatrix(jdmat))
|
||||
|
||||
@@ -31,7 +31,7 @@ trait ObjectiveTrait extends IObjective {
|
||||
*/
|
||||
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): List[Array[Float]]
|
||||
|
||||
private[scala] def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix):
|
||||
def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix):
|
||||
java.util.List[Array[Float]] = {
|
||||
getGradient(predicts, new DMatrix(dtrain)).asJava
|
||||
}
|
||||
|
||||
@@ -17,12 +17,11 @@
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.InputStream
|
||||
import ml.dmlc.xgboost4j.java.{XGBoostError, XGBoost => JXGBoost}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import scala.jdk.CollectionConverters._
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.hadoop.fs.Path
|
||||
|
||||
/**
|
||||
* XGBoost Scala Training function.
|
||||
@@ -40,7 +39,12 @@ object XGBoost {
|
||||
earlyStoppingRound: Int = 0,
|
||||
prevBooster: Booster,
|
||||
checkpointParams: Option[ExternalCheckpointParams]): Booster = {
|
||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||
|
||||
// we have to filter null value for customized obj and eval
|
||||
val jParams: java.util.Map[String, AnyRef] =
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).toMap.asJava
|
||||
|
||||
val jWatches = watches.mapValues(_.jDMatrix).toMap.asJava
|
||||
val jBooster = if (prevBooster == null) {
|
||||
null
|
||||
} else {
|
||||
@@ -51,8 +55,7 @@ object XGBoost {
|
||||
map(cp => {
|
||||
JXGBoost.trainAndSaveCheckpoint(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
jParams,
|
||||
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
|
||||
cp.checkpointInterval,
|
||||
cp.checkpointPath,
|
||||
@@ -61,8 +64,7 @@ object XGBoost {
|
||||
getOrElse(
|
||||
JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
jParams,
|
||||
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||
)
|
||||
if (prevBooster == null) {
|
||||
|
||||
Reference in New Issue
Block a user