[Spark] Refactor train, predict, add save
This commit is contained in:
@@ -1,6 +1,6 @@
# XGBoost4J: Distributed XGBoost for Scala/Java

[Build Status](https://travis-ci.org/dmlc/xgboost)
[Documentation Status](https://xgboost.readthedocs.org/en/latest/jvm/index.html)
[Build Status](https://travis-ci.org/dmlc/xgboost)
[Documentation Status](https://xgboost.readthedocs.org/en/latest/jvm/index.html)
[GitHub license](../LICENSE)

[Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) |
||||
@@ -72,3 +72,32 @@ object DistTrainWithFlink {
```

### XGBoost Spark

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
import ml.dmlc.xgboost4j.scala.spark.XGBoost

/**
 * Example entry point: trains an XGBoost model on a LibSVM data set with
 * Spark and saves the resulting model to a Hadoop-accessible path.
 *
 * Expected arguments: num_of_rounds training_path model_path
 */
object DistTrainWithSpark {
  def main(args: Array[String]): Unit = {
    // Exactly three command-line arguments are required.
    if (args.length != 3) {
      println("usage: program num_of_rounds training_path model_path")
      sys.exit(1)
    }
    // Number of boosting iterations comes first on the command line.
    val numRounds = args(0).toInt
    val trainingPath = args(1)
    val modelPath = args(2)
    val sparkContext = new SparkContext()
    // Load the LibSVM-formatted training data as an RDD of labeled points.
    val trainingData = MLUtils.loadLibSVMFile(sparkContext, trainingPath)
    // Booster parameters for training.
    val params = Map(
      "eta" -> 0.1f,
      "max_depth" -> 2,
      "objective" -> "binary:logistic")
    val model = XGBoost.train(trainingData, params, numRounds)
    // Persist the trained model to the given HDFS path.
    model.saveModelToHadoop(modelPath)
  }
}
```
Reference in New Issue
Block a user