[JVM] Refactor, add filesys API
This commit is contained in:
@@ -82,16 +82,16 @@ public class BasicWalkThrough {
|
||||
booster.saveModel(modelPath);
|
||||
|
||||
//dump model
|
||||
booster.dumpModel("./model/dump.raw.txt", false);
|
||||
booster.getModelDump("./model/dump.raw.txt", false);
|
||||
|
||||
//dump model with feature map
|
||||
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
|
||||
booster.getModelDump("../../demo/data/featmap.txt", false);
|
||||
|
||||
//save dmatrix into binary buffer
|
||||
testMat.saveBinary("./model/dtest.buffer");
|
||||
|
||||
//reload model and data
|
||||
Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model");
|
||||
Booster booster2 = XGBoost.loadModel("./model/xgb.model");
|
||||
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
|
||||
float[][] predicts2 = booster2.predict(testMat2);
|
||||
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package ml.dmlc.xgboost4j.java.demo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.*;
|
||||
|
||||
/**
|
||||
* Distributed training example, used to quick test distributed training.
|
||||
*
|
||||
* @author tqchen
|
||||
*/
|
||||
public class DistTrain {
|
||||
private static final Log logger = LogFactory.getLog(DistTrain.class);
|
||||
private Map<String, String> envs = null;
|
||||
|
||||
private class Worker implements Runnable {
|
||||
private final int workerId;
|
||||
|
||||
Worker(int workerId) {
|
||||
this.workerId = workerId;
|
||||
}
|
||||
|
||||
public void run() {
|
||||
try {
|
||||
Map<String, String> worker_env = new HashMap<String, String>(envs);
|
||||
|
||||
worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
|
||||
// always initialize rabit module before training.
|
||||
Rabit.init(worker_env);
|
||||
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("nthread", 2);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
// always shutdown rabit module after training.
|
||||
Rabit.shutdown();
|
||||
} catch (Exception ex){
|
||||
logger.error(ex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
|
||||
RabitTracker tracker = new RabitTracker(nWorkers);
|
||||
if (tracker.start()) {
|
||||
envs = tracker.getWorkerEnvs();
|
||||
for (int i = 0; i < nWorkers; ++i) {
|
||||
new Thread(new Worker(i)).start();
|
||||
}
|
||||
tracker.waitFor();
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
|
||||
new DistTrain().start(Integer.parseInt(args[0]));
|
||||
}
|
||||
}
|
||||
@@ -52,13 +52,13 @@ public class PredictLeafIndices {
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
//predict using first 2 tree
|
||||
float[][] leafindex = booster.predict(testMat, 2, true);
|
||||
float[][] leafindex = booster.predictLeaf(testMat, 2);
|
||||
for (float[] leafs : leafindex) {
|
||||
System.out.println(Arrays.toString(leafs));
|
||||
}
|
||||
|
||||
//predict all trees
|
||||
leafindex = booster.predict(testMat, 0, true);
|
||||
leafindex = booster.predictLeaf(testMat, 0);
|
||||
for (float[] leafs : leafindex) {
|
||||
System.out.println(Arrays.toString(leafs));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user