[jvm-pacakges] the first parameter in getModelDump should be featuremap path not model path (#1788)

* fix the model dump in xgboost4j example

* Modify the dump model part of scala version

* add the forgotten modelInfos
This commit is contained in:
Ruimin Wang 2016-11-21 21:52:26 +08:00 committed by Nan Zhu
parent 97371ff7e5
commit d80cec3384
2 changed files with 27 additions and 7 deletions

View File

@ -17,6 +17,7 @@ package ml.dmlc.xgboost4j.java.example;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
@ -46,6 +47,18 @@ public class BasicWalkThrough {
return true; return true;
} }
public static void saveDumpModel(String modelPath, String[] modelInfos) throws IOException {
try{
PrintWriter writer = new PrintWriter(modelPath, "UTF-8");
for(int i = 0; i < modelInfos.length; ++ i) {
writer.print("booster[" + i + "]:\n");
writer.print(modelInfos[i]);
}
writer.close();
} catch (Exception e) {
e.printStackTrace();
}
}
public static void main(String[] args) throws IOException, XGBoostError { public static void main(String[] args) throws IOException, XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j // load file from text file, also binary buffer generated by xgboost4j
@ -81,11 +94,9 @@ public class BasicWalkThrough {
String modelPath = "./model/xgb.model"; String modelPath = "./model/xgb.model";
booster.saveModel(modelPath); booster.saveModel(modelPath);
//dump model
booster.getModelDump("./model/dump.raw.txt", false);
//dump model with feature map //dump model with feature map
booster.getModelDump("../../demo/data/featmap.txt", false); String[] modelInfos = booster.getModelDump("../../demo/data/featmap.txt", false);
saveDumpModel("./model/dump.raw.txt", modelInfos);
//save dmatrix into binary buffer //save dmatrix into binary buffer
testMat.saveBinary("./model/dtest.buffer"); testMat.saveBinary("./model/dtest.buffer");

View File

@ -17,6 +17,7 @@
package ml.dmlc.xgboost4j.scala.example package ml.dmlc.xgboost4j.scala.example
import java.io.File import java.io.File
import java.io.PrintWriter
import scala.collection.mutable import scala.collection.mutable
@ -25,6 +26,15 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix} import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object BasicWalkThrough { object BasicWalkThrough {
def saveDumpModel(modelPath: String, modelInfos: Array[String]): Unit = {
val writer = new PrintWriter(modelPath, "UTF-8")
for (i <- 0 until modelInfos.length) {
writer.print(s"booster[$i]:\n")
writer.print(modelInfos(i))
}
writer.close()
}
def main(args: Array[String]): Unit = { def main(args: Array[String]): Unit = {
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train") val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
val testMax = new DMatrix("../../demo/data/agaricus.txt.test") val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
@ -50,10 +60,9 @@ object BasicWalkThrough {
file.mkdirs() file.mkdirs()
} }
booster.saveModel(file.getAbsolutePath + "/xgb.model") booster.saveModel(file.getAbsolutePath + "/xgb.model")
// dump model
booster.getModelDump(file.getAbsolutePath + "/dump.raw.txt", false)
// dump model with feature map // dump model with feature map
booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false) val modelInfos = booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false)
saveDumpModel(file.getAbsolutePath + "/dump.raw.txt", modelInfos)
// save dmatrix into binary buffer // save dmatrix into binary buffer
testMax.saveBinary(file.getAbsolutePath + "/dtest.buffer") testMax.saveBinary(file.getAbsolutePath + "/dtest.buffer")