From d80cec338443e4801d6bfeae90f644ec62611266 Mon Sep 17 00:00:00 2001 From: Ruimin Wang Date: Mon, 21 Nov 2016 21:52:26 +0800 Subject: [PATCH] [jvm-pacakges] the first parameter in getModelDump should be featuremap path not model path (#1788) * fix the model dump in xgboost4j example * Modify the dump model part of scala version * add the forgotten modelInfos --- .../java/example/BasicWalkThrough.java | 19 +++++++++++++++---- .../scala/example/BasicWalkThrough.scala | 15 ++++++++++++--- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java index 7a74852f4..3852c75ef 100644 --- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java +++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java @@ -17,6 +17,7 @@ package ml.dmlc.xgboost4j.java.example; import java.io.File; import java.io.IOException; +import java.io.PrintWriter; import java.util.Arrays; import java.util.HashMap; @@ -46,6 +47,18 @@ public class BasicWalkThrough { return true; } + public static void saveDumpModel(String modelPath, String[] modelInfos) throws IOException { + try{ + PrintWriter writer = new PrintWriter(modelPath, "UTF-8"); + for(int i = 0; i < modelInfos.length; ++ i) { + writer.print("booster[" + i + "]:\n"); + writer.print(modelInfos[i]); + } + writer.close(); + } catch (Exception e) { + e.printStackTrace(); + } + } public static void main(String[] args) throws IOException, XGBoostError { // load file from text file, also binary buffer generated by xgboost4j @@ -81,11 +94,9 @@ public class BasicWalkThrough { String modelPath = "./model/xgb.model"; booster.saveModel(modelPath); - //dump model - booster.getModelDump("./model/dump.raw.txt", false); - //dump model with feature map - booster.getModelDump("../../demo/data/featmap.txt", false); + String[] modelInfos = booster.getModelDump("../../demo/data/featmap.txt", false); + saveDumpModel("./model/dump.raw.txt", modelInfos); //save dmatrix into binary buffer testMat.saveBinary("./model/dtest.buffer"); diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala index ee8fde8ed..ffc0c6a1d 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala @@ -17,6 +17,7 @@ package ml.dmlc.xgboost4j.scala.example import java.io.File +import java.io.PrintWriter import scala.collection.mutable @@ -25,6 +26,15 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix} object BasicWalkThrough { + def saveDumpModel(modelPath: String, modelInfos: Array[String]): Unit = { + val writer = new PrintWriter(modelPath, "UTF-8") + for (i <- 0 until modelInfos.length) { + writer.print(s"booster[$i]:\n") + writer.print(modelInfos(i)) + } + writer.close() + } + def main(args: Array[String]): Unit = { val trainMax = new DMatrix("../../demo/data/agaricus.txt.train") val testMax = new DMatrix("../../demo/data/agaricus.txt.test") @@ -50,10 +60,9 @@ object BasicWalkThrough { file.mkdirs() } booster.saveModel(file.getAbsolutePath + "/xgb.model") - // dump model - booster.getModelDump(file.getAbsolutePath + "/dump.raw.txt", false) // dump model with feature map - booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false) + val modelInfos = booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false) + saveDumpModel(file.getAbsolutePath + "/dump.raw.txt", modelInfos) // save dmatrix into binary buffer testMax.saveBinary(file.getAbsolutePath + "/dtest.buffer")