[Spark] Refactor train, predict, add save

This commit is contained in:
tqchen
2016-03-06 21:51:08 -08:00
parent 3402953633
commit 435a0425b9
5 changed files with 112 additions and 63 deletions

View File

@@ -1,6 +1,6 @@
# XGBoost4J: Distributed XGBoost for Scala/Java
[![Build Status](https://travis-ci.org/dmlc/xgboost.svg?branch=master)](https://travis-ci.org/dmlc/xgboost)
[![Documentation Status](https://readthedocs.org/projects/xgboost/badge/?version=latest)](https://xgboost.readthedocs.org/en/latest/jvm/index.html)
[![Build Status](https://travis-ci.org/dmlc/xgboost.svg?branch=master)](https://travis-ci.org/dmlc/xgboost)
[![Documentation Status](https://readthedocs.org/projects/xgboost/badge/?version=latest)](https://xgboost.readthedocs.org/en/latest/jvm/index.html)
[![GitHub license](http://dmlc.github.io/img/apache2.svg)](../LICENSE)
[Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) |
@@ -72,3 +72,32 @@ object DistTrainWithFlink {
```
### XGBoost Spark
```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
import ml.dmlc.xgboost4j.scala.spark.XGBoost
object DistTrainWithSpark {
def main(args: Array[String]): Unit = {
if (args.length != 3) {
println(
"usage: program num_of_rounds training_path model_path")
sys.exit(1)
}
val sc = new SparkContext()
val inputTrainPath = args(1)
val outputModelPath = args(2)
// number of iterations
val numRound = args(0).toInt
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
// training parameters
val paramMap = List(
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
val model = XGBoost.train(trainRDD, paramMap, numRound)
// save model to HDFS path
model.saveModelToHadoop(outputModelPath)
}
}
```