[jvm-packages] clean up example (#10618)
This commit is contained in:
parent
485d90218c
commit
003b418312
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -23,7 +23,7 @@ import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.java.example.util.DataLoader
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
|
||||
|
||||
object BasicWalkThrough {
|
||||
def saveDumpModel(modelPath: String, modelInfos: Array[String]): Unit = {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -18,7 +18,8 @@ package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
|
||||
|
||||
|
||||
object BoostFromPrediction {
|
||||
def main(args: Array[String]): Unit = {
|
||||
@ -48,6 +49,6 @@ object BoostFromPrediction {
|
||||
testMat.setBaseMargin(testPred)
|
||||
|
||||
System.out.println("result of running from initial prediction")
|
||||
val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
|
||||
XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@ package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
|
||||
|
||||
object CrossValidation {
|
||||
def main(args: Array[String]): Unit = {
|
||||
@ -40,7 +40,6 @@ object CrossValidation {
|
||||
// set additional eval_metrics
|
||||
val metrics: Array[String] = null
|
||||
|
||||
val evalHist: Array[String] =
|
||||
XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics)
|
||||
XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -18,9 +18,10 @@ package ml.dmlc.xgboost4j.scala.example
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
import org.apache.commons.logging.{Log, LogFactory}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix, EvalTrait, ObjectiveTrait}
|
||||
import org.apache.commons.logging.{LogFactory, Log}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait, ObjectiveTrait, XGBoost}
|
||||
|
||||
/**
|
||||
* an example user define objective and eval
|
||||
@ -150,7 +151,7 @@ object CustomObjective {
|
||||
|
||||
val round = 2
|
||||
// train a model
|
||||
val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
|
||||
XGBoost.train(trainMat, params.toMap, round, watches.toMap)
|
||||
XGBoost.train(trainMat, params.toMap, round, watches.toMap,
|
||||
obj = new LogRegObj, eval = new EvalError)
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
|
||||
|
||||
object ExternalMemory {
|
||||
def main(args: Array[String]): Unit = {
|
||||
@ -54,6 +54,6 @@ object ExternalMemory {
|
||||
testMat.setBaseMargin(testPred)
|
||||
|
||||
System.out.println("result of running from initial prediction")
|
||||
val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
|
||||
XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -17,7 +17,7 @@ package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
|
||||
import ml.dmlc.xgboost4j.scala.example.util.CustomEval
|
||||
|
||||
|
||||
@ -51,7 +51,6 @@ object GeneralizedLinearModel {
|
||||
watches += "train" -> trainMat
|
||||
watches += "test" -> testMat
|
||||
|
||||
val round = 4
|
||||
val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
|
||||
val predicts = booster.predict(testMat)
|
||||
val eval = new CustomEval
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -17,8 +17,8 @@ package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
|
||||
import ml.dmlc.xgboost4j.scala.example.util.CustomEval
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
|
||||
object PredictFirstNTree {
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -16,11 +16,9 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import java.util
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
|
||||
|
||||
object PredictLeafIndices {
|
||||
|
||||
@ -49,7 +47,7 @@ object PredictLeafIndices {
|
||||
|
||||
// predict all trees
|
||||
val leafIndex2 = booster.predictLeaf(testMat, 0)
|
||||
for (leafs <- leafIndex) {
|
||||
for (leafs <- leafIndex2) {
|
||||
println(java.util.Arrays.toString(leafs))
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 - 2023 by Contributors
|
||||
Copyright (c) 2014 - 2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -17,12 +17,14 @@ package ml.dmlc.xgboost4j.scala.example.flink
|
||||
|
||||
import java.lang.{Double => JDouble, Long => JLong}
|
||||
import java.nio.file.{Path, Paths}
|
||||
import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
|
||||
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
|
||||
import org.apache.flink.ml.linalg.{Vector, Vectors}
|
||||
import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
|
||||
|
||||
import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
|
||||
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
|
||||
import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
|
||||
import org.apache.flink.api.java.utils.DataSetUtils
|
||||
import org.apache.flink.ml.linalg.{Vector, Vectors}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
|
||||
|
||||
|
||||
object DistTrainWithFlink {
|
||||
|
||||
@ -22,6 +22,7 @@ import org.apache.spark.ml.feature._
|
||||
import org.apache.spark.ml.tuning._
|
||||
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
|
||||
|
||||
// this example works with Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris)
|
||||
@ -87,11 +88,9 @@ object SparkMLlibPipeline {
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> numWorkers,
|
||||
"device" -> device
|
||||
)
|
||||
)
|
||||
).setNumRound(10).setNumWorkers(numWorkers)
|
||||
booster.setFeaturesCol("features")
|
||||
booster.setLabelCol("classIndex")
|
||||
val labelConverter = new IndexToString()
|
||||
|
||||
@ -16,11 +16,13 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.example.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
|
||||
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
|
||||
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
|
||||
|
||||
|
||||
// this example works with Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris)
|
||||
object SparkTraining {
|
||||
|
||||
@ -78,13 +80,13 @@ private[spark] def run(spark: SparkSession, inputPath: String,
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> numWorkers,
|
||||
"device" -> device,
|
||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2),
|
||||
"device" -> device)
|
||||
val xgbClassifier = new XGBoostClassifier(xgbParam).
|
||||
setFeaturesCol("features").
|
||||
setLabelCol("classIndex")
|
||||
.setNumWorkers(numWorkers)
|
||||
.setNumRound(10)
|
||||
val xgbClassificationModel = xgbClassifier.fit(train)
|
||||
xgbClassificationModel.transform(test)
|
||||
}
|
||||
|
||||
@ -15,9 +15,10 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example.util
|
||||
|
||||
import org.apache.commons.logging.{Log, LogFactory}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
|
||||
import org.apache.commons.logging.{Log, LogFactory}
|
||||
|
||||
class CustomEval extends EvalTrait {
|
||||
private val logger: Log = LogFactory.getLog(classOf[CustomEval])
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -15,12 +15,13 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.example.flink
|
||||
|
||||
import java.nio.file.Paths
|
||||
|
||||
import org.apache.flink.api.java.ExecutionEnvironment
|
||||
import org.scalatest.Inspectors._
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
|
||||
import java.nio.file.Paths
|
||||
|
||||
class DistTrainWithFlinkExampleTest extends AnyFunSuite {
|
||||
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
|
||||
|
||||
@ -15,14 +15,15 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example.flink
|
||||
|
||||
import java.nio.file.Paths
|
||||
|
||||
import scala.jdk.CollectionConverters._
|
||||
|
||||
import org.apache.flink.api.java.ExecutionEnvironment
|
||||
import org.scalatest.Inspectors._
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
|
||||
import java.nio.file.Paths
|
||||
import scala.jdk.CollectionConverters._
|
||||
|
||||
class DistTrainWithFlinkSuite extends AnyFunSuite {
|
||||
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
|
||||
private val data = parentPath.resolve("veterans_lung_cancer.csv")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -15,16 +15,18 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example.spark
|
||||
|
||||
|
||||
import java.io.File
|
||||
import java.nio.file.{Files, StandardOpenOption}
|
||||
|
||||
import scala.jdk.CollectionConverters._
|
||||
import scala.util.{Random, Try}
|
||||
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.slf4j.LoggerFactory
|
||||
|
||||
import java.io.File
|
||||
import java.nio.file.{Files, StandardOpenOption}
|
||||
import scala.jdk.CollectionConverters._
|
||||
import scala.util.{Random, Try}
|
||||
|
||||
class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
||||
private val logger = LoggerFactory.getLogger(classOf[SparkExamplesTest])
|
||||
private val random = new Random(42)
|
||||
@ -53,7 +55,7 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
||||
}
|
||||
|
||||
if (spark == null) {
|
||||
spark = SparkSession
|
||||
spark = SparkSession
|
||||
.builder()
|
||||
.appName("XGBoost4J-Spark Pipeline Example")
|
||||
.master(s"local[${numWorkers}]")
|
||||
@ -92,7 +94,7 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
||||
e
|
||||
)
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def cleanExternalCache(prefix: String): Unit = {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user