[jvm-packages] clean up example (#10618)

This commit is contained in:
Bobby Wang 2024-07-23 12:15:51 +08:00 committed by GitHub
parent 485d90218c
commit 003b418312
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 58 additions and 52 deletions

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2023 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -23,7 +23,7 @@ import scala.collection.mutable
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.java.example.util.DataLoader
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
object BasicWalkThrough {
def saveDumpModel(modelPath: String, modelInfos: Array[String]): Unit = {

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -18,7 +18,8 @@ package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
object BoostFromPrediction {
def main(args: Array[String]): Unit = {
@ -48,6 +49,6 @@ object BoostFromPrediction {
testMat.setBaseMargin(testPred)
System.out.println("result of running from initial prediction")
val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
}
}

View File

@ -17,7 +17,7 @@ package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
object CrossValidation {
def main(args: Array[String]): Unit = {
@ -40,7 +40,6 @@ object CrossValidation {
// set additional eval_metrics
val metrics: Array[String] = null
val evalHist: Array[String] =
XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics)
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -18,9 +18,10 @@ package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import org.apache.commons.logging.{Log, LogFactory}
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix, EvalTrait, ObjectiveTrait}
import org.apache.commons.logging.{LogFactory, Log}
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait, ObjectiveTrait, XGBoost}
/**
* an example of a user-defined objective and eval
@ -150,7 +151,7 @@ object CustomObjective {
val round = 2
// train a model
val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
XGBoost.train(trainMat, params.toMap, round, watches.toMap)
XGBoost.train(trainMat, params.toMap, round, watches.toMap,
obj = new LogRegObj, eval = new EvalError)
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
object ExternalMemory {
def main(args: Array[String]): Unit = {
@ -54,6 +54,6 @@ object ExternalMemory {
testMat.setBaseMargin(testPred)
System.out.println("result of running from initial prediction")
val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -17,7 +17,7 @@ package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
import ml.dmlc.xgboost4j.scala.example.util.CustomEval
@ -51,7 +51,6 @@ object GeneralizedLinearModel {
watches += "train" -> trainMat
watches += "test" -> testMat
val round = 4
val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
val predicts = booster.predict(testMat)
val eval = new CustomEval

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -17,8 +17,8 @@ package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
import ml.dmlc.xgboost4j.scala.example.util.CustomEval
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictFirstNTree {

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,11 +16,9 @@
package ml.dmlc.xgboost4j.scala.example
import java.util
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
object PredictLeafIndices {
@ -49,7 +47,7 @@ object PredictLeafIndices {
// predict all trees
val leafIndex2 = booster.predictLeaf(testMat, 0)
for (leafs <- leafIndex) {
for (leafs <- leafIndex2) {
println(java.util.Arrays.toString(leafs))
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 - 2023 by Contributors
Copyright (c) 2014 - 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -17,12 +17,14 @@ package ml.dmlc.xgboost4j.scala.example.flink
import java.lang.{Double => JDouble, Long => JLong}
import java.nio.file.{Path, Paths}
import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
import org.apache.flink.ml.linalg.{Vector, Vectors}
import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
import org.apache.flink.api.java.utils.DataSetUtils
import org.apache.flink.ml.linalg.{Vector, Vectors}
import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
object DistTrainWithFlink {

View File

@ -22,6 +22,7 @@ import org.apache.spark.ml.feature._
import org.apache.spark.ml.tuning._
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types._
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
// this example works with Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris)
@ -87,11 +88,9 @@ object SparkMLlibPipeline {
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> numWorkers,
"device" -> device
)
)
).setNumRound(10).setNumWorkers(numWorkers)
booster.setFeaturesCol("features")
booster.setLabelCol("classIndex")
val labelConverter = new IndexToString()

View File

@ -16,11 +16,13 @@
package ml.dmlc.xgboost4j.scala.example.spark
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
// this example works with Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris)
object SparkTraining {
@ -78,13 +80,13 @@ private[spark] def run(spark: SparkSession, inputPath: String,
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> numWorkers,
"device" -> device,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2),
"device" -> device)
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").
setLabelCol("classIndex")
.setNumWorkers(numWorkers)
.setNumRound(10)
val xgbClassificationModel = xgbClassifier.fit(train)
xgbClassificationModel.transform(test)
}

View File

@ -15,9 +15,10 @@
*/
package ml.dmlc.xgboost4j.scala.example.util
import org.apache.commons.logging.{Log, LogFactory}
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
import org.apache.commons.logging.{Log, LogFactory}
class CustomEval extends EvalTrait {
private val logger: Log = LogFactory.getLog(classOf[CustomEval])

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2023 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -15,12 +15,13 @@
*/
package ml.dmlc.xgboost4j.java.example.flink
import java.nio.file.Paths
import org.apache.flink.api.java.ExecutionEnvironment
import org.scalatest.Inspectors._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._
import java.nio.file.Paths
class DistTrainWithFlinkExampleTest extends AnyFunSuite {
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")

View File

@ -15,14 +15,15 @@
*/
package ml.dmlc.xgboost4j.scala.example.flink
import java.nio.file.Paths
import scala.jdk.CollectionConverters._
import org.apache.flink.api.java.ExecutionEnvironment
import org.scalatest.Inspectors._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._
import java.nio.file.Paths
import scala.jdk.CollectionConverters._
class DistTrainWithFlinkSuite extends AnyFunSuite {
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
private val data = parentPath.resolve("veterans_lung_cancer.csv")

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2023 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -15,16 +15,18 @@
*/
package ml.dmlc.xgboost4j.scala.example.spark
import java.io.File
import java.nio.file.{Files, StandardOpenOption}
import scala.jdk.CollectionConverters._
import scala.util.{Random, Try}
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.slf4j.LoggerFactory
import java.io.File
import java.nio.file.{Files, StandardOpenOption}
import scala.jdk.CollectionConverters._
import scala.util.{Random, Try}
class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
private val logger = LoggerFactory.getLogger(classOf[SparkExamplesTest])
private val random = new Random(42)