diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala index 862d65f37..200ca3ba1 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala @@ -25,6 +25,7 @@ object DistTrainWithFlink { // read trainining data val trainData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train") + val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test") // define parameters val paramMap = List( "eta" -> 0.1, @@ -34,7 +35,7 @@ object DistTrainWithFlink { val round = 2 // train the model val model = XGBoost.train(paramMap, trainData, round) - val predTrain = model.predict(trainData.map{x => x.vector}) - model.saveModelToHadoop("file:///path/to/xgboost.model") + val predTest = model.predict(testData.map{x => x.vector}) + model.saveModelAsHadoopFile("file:///path/to/xgboost.model") } } diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala index acbd6e656..94b36be91 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala @@ -70,7 +70,7 @@ object XGBoost { * @param modelPath The path that is accessible by hadoop filesystem API. * @return The loaded model */ - def loadModelFromHadoop(modelPath: String) : XGBoostModel = { + def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = { new XGBoostModel( XGBoostScala.loadModel( FileSystem diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala index 6391e2a39..54bcdb27b 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala @@ -31,7 +31,7 @@ class XGBoostModel (booster: Booster) extends Serializable { * * @param modelPath The model path as in Hadoop path. */ - def saveModelToHadoop(modelPath: String): Unit = { + def saveModelAsHadoopFile(modelPath: String): Unit = { booster.saveModel(FileSystem .get(new Configuration) .create(new Path(modelPath))) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index f577320b6..6dba097a6 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -49,7 +49,7 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri * * @param modelPath The model path as in Hadoop path. */ - def saveModelToHadoopFile(modelPath: String): Unit = { + def saveModelAsHadoopFile(modelPath: String): Unit = { val path = new Path(modelPath) val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path) booster.saveModel(outputStream)