[jvm-pacakges] the first parameter in getModelDump should be featuremap path not model path (#1788)
* fix the model dump in xgboost4j example * Modify the dump model part of scala version * add the forgotten modelInfos
This commit is contained in:
parent
97371ff7e5
commit
d80cec3384
@ -17,6 +17,7 @@ package ml.dmlc.xgboost4j.java.example;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.PrintWriter;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
|
||||
@ -46,6 +47,18 @@ public class BasicWalkThrough {
|
||||
return true;
|
||||
}
|
||||
|
||||
public static void saveDumpModel(String modelPath, String[] modelInfos) throws IOException {
|
||||
try{
|
||||
PrintWriter writer = new PrintWriter(modelPath, "UTF-8");
|
||||
for(int i = 0; i < modelInfos.length; ++ i) {
|
||||
writer.print("booster[" + i + "]:\n");
|
||||
writer.print(modelInfos[i]);
|
||||
}
|
||||
writer.close();
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws IOException, XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
@ -81,11 +94,9 @@ public class BasicWalkThrough {
|
||||
String modelPath = "./model/xgb.model";
|
||||
booster.saveModel(modelPath);
|
||||
|
||||
//dump model
|
||||
booster.getModelDump("./model/dump.raw.txt", false);
|
||||
|
||||
//dump model with feature map
|
||||
booster.getModelDump("../../demo/data/featmap.txt", false);
|
||||
String[] modelInfos = booster.getModelDump("../../demo/data/featmap.txt", false);
|
||||
saveDumpModel("./model/dump.raw.txt", modelInfos);
|
||||
|
||||
//save dmatrix into binary buffer
|
||||
testMat.saveBinary("./model/dtest.buffer");
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import java.io.File
|
||||
import java.io.PrintWriter
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
@ -25,6 +26,15 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
|
||||
object BasicWalkThrough {
|
||||
def saveDumpModel(modelPath: String, modelInfos: Array[String]): Unit = {
|
||||
val writer = new PrintWriter(modelPath, "UTF-8")
|
||||
for (i <- 0 until modelInfos.length) {
|
||||
writer.print(s"booster[$i]:\n")
|
||||
writer.print(modelInfos(i))
|
||||
}
|
||||
writer.close()
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
@ -50,10 +60,9 @@ object BasicWalkThrough {
|
||||
file.mkdirs()
|
||||
}
|
||||
booster.saveModel(file.getAbsolutePath + "/xgb.model")
|
||||
// dump model
|
||||
booster.getModelDump(file.getAbsolutePath + "/dump.raw.txt", false)
|
||||
// dump model with feature map
|
||||
booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false)
|
||||
val modelInfos = booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false)
|
||||
saveDumpModel(file.getAbsolutePath + "/dump.raw.txt", modelInfos)
|
||||
// save dmatrix into binary buffer
|
||||
testMax.saveBinary(file.getAbsolutePath + "/dtest.buffer")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user