diff --git a/jvm-packages/README.md b/jvm-packages/README.md index 8f0e89916..28dbba685 100644 --- a/jvm-packages/README.md +++ b/jvm-packages/README.md @@ -34,7 +34,7 @@ object XGBoostScalaExample { // number of iterations val round = 2 // train the model - val model = XGBoost.train(paramMap, trainData, round) + val model = XGBoost.train(trainData, paramMap, round) // run prediction val predTrain = model.predict(trainData) // save model to the file. @@ -43,34 +43,6 @@ object XGBoostScalaExample { } ``` -### XGBoost Flink -```scala -import ml.dmlc.xgboost4j.scala.flink.XGBoost -import org.apache.flink.api.scala._ -import org.apache.flink.api.scala.ExecutionEnvironment -import org.apache.flink.ml.MLUtils - -object DistTrainWithFlink { - def main(args: Array[String]) { - val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment - // read trainining data - val trainData = - MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train") - // define parameters - val paramMap = List( - "eta" -> 0.1, - "max_depth" -> 2, - "objective" -> "binary:logistic").toMap - // number of iterations - val round = 2 - // train the model - val model = XGBoost.train(paramMap, trainData, round) - val predTrain = model.predict(trainData.map{x => x.vector}) - model.saveModelToHadoop("file:///path/to/xgboost.model") - } -} -``` - ### XGBoost Spark ```scala import org.apache.spark.SparkContext @@ -101,3 +73,33 @@ object DistTrainWithSpark { } } ``` + +### XGBoost Flink +```scala +import ml.dmlc.xgboost4j.scala.flink.XGBoost +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.ml.MLUtils + +object DistTrainWithFlink { + def main(args: Array[String]) { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + // read trainining data + val trainData = + MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train") + // define parameters + val paramMap = List( + "eta" -> 0.1, + "max_depth" -> 2, + "objective" -> "binary:logistic").toMap + // number of iterations + val round = 2 + // train the model + val model = XGBoost.train(trainData, paramMap, round) + val predTrain = model.predict(trainData.map{x => x.vector}) + model.saveModelToHadoop("file:///path/to/xgboost.model") + } +} +``` + +