Scala 2.13 support. (#9099)
1. Updated the test logic 2. Added smoke tests for Spark examples. 3. Added integration tests for Spark with Scala 2.13
This commit is contained in:
@@ -73,12 +73,13 @@ object DistTrainWithFlink {
|
||||
.map(_.f1.f0)
|
||||
.returns(testDataTypeHint)
|
||||
|
||||
val paramMap = mapAsJavaMap(Map(
|
||||
("eta", "0.1".asInstanceOf[AnyRef]),
|
||||
("max_depth", "2"),
|
||||
("objective", "binary:logistic"),
|
||||
("verbosity", "1")
|
||||
))
|
||||
val paramMap = Map(
|
||||
("eta", "0.1".asInstanceOf[AnyRef]),
|
||||
("max_depth", "2"),
|
||||
("objective", "binary:logistic"),
|
||||
("verbosity", "1")
|
||||
)
|
||||
.asJava
|
||||
|
||||
// number of iterations
|
||||
val round = 2
|
||||
|
||||
@@ -20,10 +20,9 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
|
||||
import org.apache.spark.ml.feature._
|
||||
import org.apache.spark.ml.tuning._
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}
|
||||
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
|
||||
|
||||
// this example works with Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris)
|
||||
|
||||
@@ -50,6 +49,13 @@ object SparkMLlibPipeline {
|
||||
.appName("XGBoost4J-Spark Pipeline Example")
|
||||
.getOrCreate()
|
||||
|
||||
run(spark, inputPath, nativeModelPath, pipelineModelPath, treeMethod, numWorkers)
|
||||
.show(false)
|
||||
}
|
||||
private[spark] def run(spark: SparkSession, inputPath: String, nativeModelPath: String,
|
||||
pipelineModelPath: String, treeMethod: String,
|
||||
numWorkers: Int): DataFrame = {
|
||||
|
||||
// Load dataset
|
||||
val schema = new StructType(Array(
|
||||
StructField("sepal length", DoubleType, true),
|
||||
@@ -90,11 +96,11 @@ object SparkMLlibPipeline {
|
||||
val labelConverter = new IndexToString()
|
||||
.setInputCol("prediction")
|
||||
.setOutputCol("realLabel")
|
||||
.setLabels(labelIndexer.labels)
|
||||
.setLabels(labelIndexer.labelsArray(0))
|
||||
|
||||
val pipeline = new Pipeline()
|
||||
.setStages(Array(assembler, labelIndexer, booster, labelConverter))
|
||||
val model = pipeline.fit(training)
|
||||
val model: PipelineModel = pipeline.fit(training)
|
||||
|
||||
// Batch prediction
|
||||
val prediction = model.transform(test)
|
||||
@@ -136,6 +142,6 @@ object SparkMLlibPipeline {
|
||||
|
||||
// Load a saved model and serving
|
||||
val model2 = PipelineModel.load(pipelineModelPath)
|
||||
model2.transform(test).show(false)
|
||||
model2.transform(test)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,9 +17,8 @@
|
||||
package ml.dmlc.xgboost4j.scala.example.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
|
||||
|
||||
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
|
||||
|
||||
// this example works with Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris)
|
||||
@@ -38,6 +37,12 @@ object SparkTraining {
|
||||
|
||||
val spark = SparkSession.builder().getOrCreate()
|
||||
val inputPath = args(0)
|
||||
val results: DataFrame = run(spark, inputPath, treeMethod, numWorkers)
|
||||
results.show()
|
||||
}
|
||||
|
||||
private[spark] def run(spark: SparkSession, inputPath: String,
|
||||
treeMethod: String, numWorkers: Int): DataFrame = {
|
||||
val schema = new StructType(Array(
|
||||
StructField("sepal length", DoubleType, true),
|
||||
StructField("sepal width", DoubleType, true),
|
||||
@@ -81,7 +86,6 @@ object SparkTraining {
|
||||
setFeaturesCol("features").
|
||||
setLabelCol("classIndex")
|
||||
val xgbClassificationModel = xgbClassifier.fit(train)
|
||||
val results = xgbClassificationModel.transform(test)
|
||||
results.show()
|
||||
xgbClassificationModel.transform(test)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user