[jvm-pacakges] the first parameter in getModelDump should be featuremap path not model path (#1788)
* fix the model dump in xgboost4j example * Modify the dump model part of scala version * add the forgotten modelInfos
This commit is contained in:
parent
97371ff7e5
commit
d80cec3384
@ -17,6 +17,7 @@ package ml.dmlc.xgboost4j.java.example;
|
|||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.PrintWriter;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
|
||||||
@ -46,6 +47,18 @@ public class BasicWalkThrough {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static void saveDumpModel(String modelPath, String[] modelInfos) throws IOException {
|
||||||
|
try{
|
||||||
|
PrintWriter writer = new PrintWriter(modelPath, "UTF-8");
|
||||||
|
for(int i = 0; i < modelInfos.length; ++ i) {
|
||||||
|
writer.print("booster[" + i + "]:\n");
|
||||||
|
writer.print(modelInfos[i]);
|
||||||
|
}
|
||||||
|
writer.close();
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException, XGBoostError {
|
public static void main(String[] args) throws IOException, XGBoostError {
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
@ -81,11 +94,9 @@ public class BasicWalkThrough {
|
|||||||
String modelPath = "./model/xgb.model";
|
String modelPath = "./model/xgb.model";
|
||||||
booster.saveModel(modelPath);
|
booster.saveModel(modelPath);
|
||||||
|
|
||||||
//dump model
|
|
||||||
booster.getModelDump("./model/dump.raw.txt", false);
|
|
||||||
|
|
||||||
//dump model with feature map
|
//dump model with feature map
|
||||||
booster.getModelDump("../../demo/data/featmap.txt", false);
|
String[] modelInfos = booster.getModelDump("../../demo/data/featmap.txt", false);
|
||||||
|
saveDumpModel("./model/dump.raw.txt", modelInfos);
|
||||||
|
|
||||||
//save dmatrix into binary buffer
|
//save dmatrix into binary buffer
|
||||||
testMat.saveBinary("./model/dtest.buffer");
|
testMat.saveBinary("./model/dtest.buffer");
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.example
|
package ml.dmlc.xgboost4j.scala.example
|
||||||
|
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
import java.io.PrintWriter
|
||||||
|
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
|
||||||
@ -25,6 +26,15 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader
|
|||||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||||
|
|
||||||
object BasicWalkThrough {
|
object BasicWalkThrough {
|
||||||
|
def saveDumpModel(modelPath: String, modelInfos: Array[String]): Unit = {
|
||||||
|
val writer = new PrintWriter(modelPath, "UTF-8")
|
||||||
|
for (i <- 0 until modelInfos.length) {
|
||||||
|
writer.print(s"booster[$i]:\n")
|
||||||
|
writer.print(modelInfos(i))
|
||||||
|
}
|
||||||
|
writer.close()
|
||||||
|
}
|
||||||
|
|
||||||
def main(args: Array[String]): Unit = {
|
def main(args: Array[String]): Unit = {
|
||||||
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
|
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
|
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
@ -50,10 +60,9 @@ object BasicWalkThrough {
|
|||||||
file.mkdirs()
|
file.mkdirs()
|
||||||
}
|
}
|
||||||
booster.saveModel(file.getAbsolutePath + "/xgb.model")
|
booster.saveModel(file.getAbsolutePath + "/xgb.model")
|
||||||
// dump model
|
|
||||||
booster.getModelDump(file.getAbsolutePath + "/dump.raw.txt", false)
|
|
||||||
// dump model with feature map
|
// dump model with feature map
|
||||||
booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false)
|
val modelInfos = booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false)
|
||||||
|
saveDumpModel(file.getAbsolutePath + "/dump.raw.txt", modelInfos)
|
||||||
// save dmatrix into binary buffer
|
// save dmatrix into binary buffer
|
||||||
testMax.saveBinary(file.getAbsolutePath + "/dtest.buffer")
|
testMax.saveBinary(file.getAbsolutePath + "/dtest.buffer")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user