merge with master
This commit is contained in:
commit
16008ebfb8
@ -82,16 +82,16 @@ public class BasicWalkThrough {
|
|||||||
booster.saveModel(modelPath);
|
booster.saveModel(modelPath);
|
||||||
|
|
||||||
//dump model
|
//dump model
|
||||||
booster.dumpModel("./model/dump.raw.txt", false);
|
booster.getModelDump("./model/dump.raw.txt", false);
|
||||||
|
|
||||||
//dump model with feature map
|
//dump model with feature map
|
||||||
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
|
booster.getModelDump("../../demo/data/featmap.txt", false);
|
||||||
|
|
||||||
//save dmatrix into binary buffer
|
//save dmatrix into binary buffer
|
||||||
testMat.saveBinary("./model/dtest.buffer");
|
testMat.saveBinary("./model/dtest.buffer");
|
||||||
|
|
||||||
//reload model and data
|
//reload model and data
|
||||||
Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model");
|
Booster booster2 = XGBoost.loadModel("./model/xgb.model");
|
||||||
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
|
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
|
||||||
float[][] predicts2 = booster2.predict(testMat2);
|
float[][] predicts2 = booster2.predict(testMat2);
|
||||||
|
|
||||||
|
|||||||
@ -1,79 +0,0 @@
|
|||||||
package ml.dmlc.xgboost4j.java.demo;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
|
||||||
import org.apache.commons.logging.LogFactory;
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Distributed training example, used to quick test distributed training.
|
|
||||||
*
|
|
||||||
* @author tqchen
|
|
||||||
*/
|
|
||||||
public class DistTrain {
|
|
||||||
private static final Log logger = LogFactory.getLog(DistTrain.class);
|
|
||||||
private Map<String, String> envs = null;
|
|
||||||
|
|
||||||
private class Worker implements Runnable {
|
|
||||||
private final int workerId;
|
|
||||||
|
|
||||||
Worker(int workerId) {
|
|
||||||
this.workerId = workerId;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void run() {
|
|
||||||
try {
|
|
||||||
Map<String, String> worker_env = new HashMap<String, String>(envs);
|
|
||||||
|
|
||||||
worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
|
|
||||||
// always initialize rabit module before training.
|
|
||||||
Rabit.init(worker_env);
|
|
||||||
|
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
|
||||||
|
|
||||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
|
||||||
params.put("eta", 1.0);
|
|
||||||
params.put("max_depth", 2);
|
|
||||||
params.put("silent", 1);
|
|
||||||
params.put("nthread", 2);
|
|
||||||
params.put("objective", "binary:logistic");
|
|
||||||
|
|
||||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
|
||||||
watches.put("train", trainMat);
|
|
||||||
watches.put("test", testMat);
|
|
||||||
|
|
||||||
//set round
|
|
||||||
int round = 2;
|
|
||||||
|
|
||||||
//train a boost model
|
|
||||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
|
||||||
|
|
||||||
// always shutdown rabit module after training.
|
|
||||||
Rabit.shutdown();
|
|
||||||
} catch (Exception ex){
|
|
||||||
logger.error(ex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
|
|
||||||
RabitTracker tracker = new RabitTracker(nWorkers);
|
|
||||||
if (tracker.start()) {
|
|
||||||
envs = tracker.getWorkerEnvs();
|
|
||||||
for (int i = 0; i < nWorkers; ++i) {
|
|
||||||
new Thread(new Worker(i)).start();
|
|
||||||
}
|
|
||||||
tracker.waitFor();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
|
|
||||||
new DistTrain().start(Integer.parseInt(args[0]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -52,13 +52,13 @@ public class PredictLeafIndices {
|
|||||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||||
|
|
||||||
//predict using first 2 tree
|
//predict using first 2 tree
|
||||||
float[][] leafindex = booster.predict(testMat, 2, true);
|
float[][] leafindex = booster.predictLeaf(testMat, 2);
|
||||||
for (float[] leafs : leafindex) {
|
for (float[] leafs : leafindex) {
|
||||||
System.out.println(Arrays.toString(leafs));
|
System.out.println(Arrays.toString(leafs));
|
||||||
}
|
}
|
||||||
|
|
||||||
//predict all trees
|
//predict all trees
|
||||||
leafindex = booster.predict(testMat, 0, true);
|
leafindex = booster.predictLeaf(testMat, 0);
|
||||||
for (float[] leafs : leafindex) {
|
for (float[] leafs : leafindex) {
|
||||||
System.out.println(Arrays.toString(leafs));
|
System.out.println(Arrays.toString(leafs));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -37,6 +37,8 @@ object Test {
|
|||||||
"objective" -> "binary:logistic").toMap
|
"objective" -> "binary:logistic").toMap
|
||||||
val round = 2
|
val round = 2
|
||||||
val model = XGBoost.train(paramMap, data, round)
|
val model = XGBoost.train(paramMap, data, round)
|
||||||
|
|
||||||
|
|
||||||
log.info(model)
|
log.info(model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -25,6 +25,9 @@ import org.apache.flink.api.scala.DataSet
|
|||||||
import org.apache.flink.api.scala._
|
import org.apache.flink.api.scala._
|
||||||
import org.apache.flink.ml.common.LabeledVector
|
import org.apache.flink.ml.common.LabeledVector
|
||||||
import org.apache.flink.util.Collector
|
import org.apache.flink.util.Collector
|
||||||
|
import org.apache.hadoop.fs.FileSystem
|
||||||
|
import org.apache.hadoop.fs.Path
|
||||||
|
import org.apache.hadoop.conf.Configuration
|
||||||
|
|
||||||
object XGBoost {
|
object XGBoost {
|
||||||
/**
|
/**
|
||||||
@ -60,6 +63,20 @@ object XGBoost {
|
|||||||
|
|
||||||
val logger = LogFactory.getLog(this.getClass)
|
val logger = LogFactory.getLog(this.getClass)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load XGBoost model from path, using Hadoop Filesystem API.
|
||||||
|
*
|
||||||
|
* @param modelPath The path that is accessible by hadoop filesystem API.
|
||||||
|
* @return The loaded model
|
||||||
|
*/
|
||||||
|
def loadModel(modelPath: String) : XGBoostModel = {
|
||||||
|
new XGBoostModel(
|
||||||
|
XGBoostScala.loadModel(
|
||||||
|
FileSystem
|
||||||
|
.get(new Configuration)
|
||||||
|
.open(new Path(modelPath))))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Train a xgboost model with link.
|
* Train a xgboost model with link.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -16,8 +16,45 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.flink
|
package ml.dmlc.xgboost4j.flink
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.Booster
|
import ml.dmlc.xgboost4j.LabeledPoint
|
||||||
|
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||||
|
import org.apache.flink.api.scala.DataSet
|
||||||
|
import org.apache.flink.api.scala._
|
||||||
|
import org.apache.flink.ml.math.Vector
|
||||||
|
import org.apache.hadoop.fs.FileSystem
|
||||||
|
import org.apache.hadoop.fs.Path
|
||||||
|
import org.apache.hadoop.conf.Configuration
|
||||||
|
|
||||||
class XGBoostModel (booster: Booster) extends Serializable {
|
class XGBoostModel (booster: Booster) extends Serializable {
|
||||||
|
/**
|
||||||
|
* Save the model as a Hadoop filesystem file.
|
||||||
|
*
|
||||||
|
* @param modelPath The model path as in Hadoop path.
|
||||||
|
*/
|
||||||
|
def saveModel(modelPath: String): Unit = {
|
||||||
|
booster.saveModel(FileSystem
|
||||||
|
.get(new Configuration)
|
||||||
|
.create(new Path(modelPath)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Predict given vector dataset.
|
||||||
|
*
|
||||||
|
* @param data The dataset to be predicted.
|
||||||
|
* @return The prediction result.
|
||||||
|
*/
|
||||||
|
def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
|
||||||
|
val predictMap: Iterator[Vector] => TraversableOnce[Array[Float]] =
|
||||||
|
(it: Iterator[Vector]) => {
|
||||||
|
val mapper = (x: Vector) => {
|
||||||
|
val (index, value) = x.toSeq.unzip
|
||||||
|
LabeledPoint.fromSparseVector(0.0f,
|
||||||
|
index.toArray, value.map(z => z.toFloat).toArray)
|
||||||
|
}
|
||||||
|
val dataIter = for (x <- it) yield mapper(x)
|
||||||
|
val dmat = new DMatrix(dataIter, null)
|
||||||
|
this.booster.predict(dmat)
|
||||||
|
}
|
||||||
|
data.mapPartition(predictMap)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -24,20 +24,17 @@ import org.apache.commons.logging.Log;
|
|||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Booster for xgboost, similar to the python wrapper xgboost.py
|
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
|
||||||
* but custom obj function and eval function not supported at present.
|
|
||||||
*
|
|
||||||
* @author hzx
|
|
||||||
*/
|
*/
|
||||||
public class Booster implements Serializable {
|
public class Booster implements Serializable {
|
||||||
private static final Log logger = LogFactory.getLog(Booster.class);
|
private static final Log logger = LogFactory.getLog(Booster.class);
|
||||||
|
// handle to the booster.
|
||||||
long handle = 0;
|
private long handle = 0;
|
||||||
|
|
||||||
//load native library
|
//load native library
|
||||||
static {
|
static {
|
||||||
try {
|
try {
|
||||||
NativeLibLoader.initXgBoost();
|
NativeLibLoader.initXGBoost();
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
logger.error("load native library failed.");
|
logger.error("load native library failed.");
|
||||||
logger.error(ex);
|
logger.error(ex);
|
||||||
@ -45,60 +42,70 @@ public class Booster implements Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* init Booster from dMatrixs
|
* Create a new Booster with empty stage.
|
||||||
*
|
*
|
||||||
* @param params parameters
|
* @param params Model parameters
|
||||||
* @param dMatrixs DMatrix array
|
* @param cacheMats Cached DMatrix entries,
|
||||||
|
* the prediction of these DMatrices will become faster than not-cached data.
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
Booster(Map<String, Object> params, DMatrix[] dMatrixs) throws XGBoostError {
|
Booster(Map<String, Object> params, DMatrix[] cacheMats) throws XGBoostError {
|
||||||
init(dMatrixs);
|
init(cacheMats);
|
||||||
setParam("seed", "0");
|
setParam("seed", "0");
|
||||||
setParams(params);
|
setParams(params);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* load model from modelPath
|
* Load a new Booster model from modelPath
|
||||||
*
|
* @param modelPath The path to the model.
|
||||||
* @param params parameters
|
* @return The created Booster.
|
||||||
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
* @throws XGBoostError
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
*/
|
||||||
Booster(Map<String, Object> params, String modelPath) throws XGBoostError {
|
static Booster loadModel(String modelPath) throws XGBoostError {
|
||||||
init(null);
|
|
||||||
if (modelPath == null) {
|
if (modelPath == null) {
|
||||||
throw new NullPointerException("modelPath : null");
|
throw new NullPointerException("modelPath : null");
|
||||||
}
|
}
|
||||||
loadModel(modelPath);
|
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
|
||||||
setParam("seed", "0");
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath));
|
||||||
setParams(params);
|
return ret;
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private void init(DMatrix[] dMatrixs) throws XGBoostError {
|
|
||||||
long[] handles = null;
|
|
||||||
if (dMatrixs != null) {
|
|
||||||
handles = dmatrixsToHandles(dMatrixs);
|
|
||||||
}
|
|
||||||
long[] out = new long[1];
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
|
|
||||||
|
|
||||||
handle = out[0];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set parameter
|
* Load a new Booster model from a file opened as input stream.
|
||||||
|
* The assumption is the input stream only contains one XGBoost Model.
|
||||||
|
* This can be used to load existing booster models saved by other xgboost bindings.
|
||||||
|
*
|
||||||
|
* @param in The input stream of the file.
|
||||||
|
* @return The create boosted
|
||||||
|
* @throws XGBoostError
|
||||||
|
* @throws IOException
|
||||||
|
*/
|
||||||
|
static Booster loadModel(InputStream in) throws XGBoostError, IOException {
|
||||||
|
int size;
|
||||||
|
byte[] buf = new byte[1<<20];
|
||||||
|
ByteArrayOutputStream os = new ByteArrayOutputStream();
|
||||||
|
while ((size = in.read(buf)) != -1) {
|
||||||
|
os.write(buf, 0, size);
|
||||||
|
}
|
||||||
|
in.close();
|
||||||
|
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray()));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set parameter to the Booster.
|
||||||
*
|
*
|
||||||
* @param key param name
|
* @param key param name
|
||||||
* @param value param value
|
* @param value param value
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public final void setParam(String key, String value) throws XGBoostError {
|
public final void setParam(String key, Object value) throws XGBoostError {
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value.toString()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set parameters
|
* Set parameters to the Booster.
|
||||||
*
|
*
|
||||||
* @param params parameters key-value map
|
* @param params parameters key-value map
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
@ -111,9 +118,8 @@ public class Booster implements Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update (one iteration)
|
* Update the booster for one iteration.
|
||||||
*
|
*
|
||||||
* @param dtrain training data
|
* @param dtrain training data
|
||||||
* @param iter current iteration number
|
* @param iter current iteration number
|
||||||
@ -124,14 +130,14 @@ public class Booster implements Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* update with customize obj func
|
* Update with customize obj func
|
||||||
*
|
*
|
||||||
* @param dtrain training data
|
* @param dtrain training data
|
||||||
* @param obj customized objective class
|
* @param obj customized objective class
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||||
float[][] predicts = predict(dtrain, true);
|
float[][] predicts = this.predict(dtrain, true, 0, false);
|
||||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||||
}
|
}
|
||||||
@ -149,8 +155,8 @@ public class Booster implements Serializable {
|
|||||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||||
hess.length));
|
hess.length));
|
||||||
}
|
}
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle,
|
||||||
hess));
|
dtrain.getHandle(), grad, hess));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -193,18 +199,20 @@ public class Booster implements Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* base function for Predict
|
* Advanced predict function with all the options.
|
||||||
*
|
*
|
||||||
* @param data data
|
* @param data data
|
||||||
* @param outPutMargin output margin
|
* @param outputMargin output margin
|
||||||
* @param treeLimit limit number of trees
|
* @param treeLimit limit number of trees, 0 means all trees.
|
||||||
* @param predLeaf prediction minimum to keep leafs
|
* @param predLeaf prediction minimum to keep leafs
|
||||||
* @return predict results
|
* @return predict results
|
||||||
*/
|
*/
|
||||||
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit,
|
private synchronized float[][] predict(DMatrix data,
|
||||||
|
boolean outputMargin,
|
||||||
|
int treeLimit,
|
||||||
boolean predLeaf) throws XGBoostError {
|
boolean predLeaf) throws XGBoostError {
|
||||||
int optionMask = 0;
|
int optionMask = 0;
|
||||||
if (outPutMargin) {
|
if (outputMargin) {
|
||||||
optionMask = 1;
|
optionMask = 1;
|
||||||
}
|
}
|
||||||
if (predLeaf) {
|
if (predLeaf) {
|
||||||
@ -225,6 +233,18 @@ public class Booster implements Serializable {
|
|||||||
return predicts;
|
return predicts;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Predict leaf indices given the data
|
||||||
|
*
|
||||||
|
* @param data The input data.
|
||||||
|
* @param treeLimit Number of trees to include, 0 means all trees.
|
||||||
|
* @return The leaf indices of the instance.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError {
|
||||||
|
return this.predict(data, false, treeLimit, true);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict with data
|
* Predict with data
|
||||||
*
|
*
|
||||||
@ -233,53 +253,34 @@ public class Booster implements Serializable {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public float[][] predict(DMatrix data) throws XGBoostError {
|
public float[][] predict(DMatrix data) throws XGBoostError {
|
||||||
return pred(data, false, 0, false);
|
return this.predict(data, false, 0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict with data
|
* Predict with data
|
||||||
*
|
*
|
||||||
* @param data dmatrix storing the input
|
* @param data data
|
||||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
* @param outputMargin output margin
|
||||||
* @return predict result
|
* @return predict results
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
*/
|
||||||
public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
|
public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError {
|
||||||
return pred(data, outPutMargin, 0, false);
|
return this.predict(data, outputMargin, 0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict with data
|
* Advanced predict function with all the options.
|
||||||
*
|
*
|
||||||
* @param data dmatrix storing the input
|
* @param data data
|
||||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
* @param outputMargin output margin
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
* @param treeLimit limit number of trees, 0 means all trees.
|
||||||
* @return predict result
|
* @return predict results
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
*/
|
||||||
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
|
public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError {
|
||||||
return pred(data, outPutMargin, treeLimit, false);
|
return this.predict(data, outputMargin, treeLimit, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict with data
|
* Save model to modelPath
|
||||||
*
|
|
||||||
* @param data dmatrix storing the input
|
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
|
||||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
|
||||||
* nsample = data.numRow with each record indicating the predicted leaf index
|
|
||||||
* of each sample in each tree.
|
|
||||||
* Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
|
||||||
* in both tree 1 and tree 0.
|
|
||||||
* @return predict result
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError {
|
|
||||||
return pred(data, false, treeLimit, predLeaf);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* save model to modelPath
|
|
||||||
*
|
*
|
||||||
* @param modelPath model path
|
* @param modelPath model path
|
||||||
*/
|
*/
|
||||||
@ -287,8 +288,65 @@ public class Booster implements Serializable {
|
|||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void loadModel(String modelPath) {
|
/**
|
||||||
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
|
* Save the model to file opened as output stream.
|
||||||
|
* The model format is compatible with other xgboost bindings.
|
||||||
|
* The output stream can only save one xgboost model.
|
||||||
|
* This function will close the OutputStream after the save.
|
||||||
|
*
|
||||||
|
* @param out The output stream
|
||||||
|
*/
|
||||||
|
public void saveModel(OutputStream out) throws XGBoostError, IOException {
|
||||||
|
out.write(this.toByteArray());
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the dump of the model as a string array
|
||||||
|
*
|
||||||
|
* @param withStats Controls whether the split statistics are output.
|
||||||
|
* @return dumped model information
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public String[] getModelDump(String featureMap, boolean withStats) throws XGBoostError {
|
||||||
|
int statsFlag = 0;
|
||||||
|
if (featureMap == null) {
|
||||||
|
featureMap = "";
|
||||||
|
}
|
||||||
|
if (withStats) {
|
||||||
|
statsFlag = 1;
|
||||||
|
}
|
||||||
|
String[][] modelInfos = new String[1][];
|
||||||
|
JNIErrorHandle.checkCall(
|
||||||
|
XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
|
||||||
|
return modelInfos[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get importance of each feature
|
||||||
|
*
|
||||||
|
* @return featureMap key: feature index, value: feature importance score, can be nill
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
||||||
|
String[] modelInfos = getModelDump(featureMap, false);
|
||||||
|
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
||||||
|
for (String tree : modelInfos) {
|
||||||
|
for (String node : tree.split("\n")) {
|
||||||
|
String[] array = node.split("\\[");
|
||||||
|
if (array.length == 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String fid = array[1].split("\\]")[0];
|
||||||
|
fid = fid.split("<")[0];
|
||||||
|
if (featureScore.containsKey(fid)) {
|
||||||
|
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||||
|
} else {
|
||||||
|
featureScore.put(fid, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return featureScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -309,152 +367,17 @@ public class Booster implements Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get the dump of the model as a string array
|
* get the dump of the model as a byte array
|
||||||
*
|
*
|
||||||
* @param featureMap featureMap file
|
|
||||||
* @param withStats Controls whether the split statistics are output.
|
|
||||||
* @return dumped model information
|
* @return dumped model information
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
|
|
||||||
int statsFlag = 0;
|
|
||||||
if (withStats) {
|
|
||||||
statsFlag = 1;
|
|
||||||
}
|
|
||||||
String[][] modelInfos = new String[1][];
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
|
||||||
modelInfos));
|
|
||||||
return modelInfos[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Dump model into a text file.
|
|
||||||
*
|
|
||||||
* @param modelPath file to save dumped model info
|
|
||||||
* @param withStats bool
|
|
||||||
* Controls whether the split statistics are output.
|
|
||||||
* @throws FileNotFoundException file not found
|
|
||||||
* @throws UnsupportedEncodingException unsupported feature
|
|
||||||
* @throws IOException error with model writing
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError {
|
|
||||||
File tf = new File(modelPath);
|
|
||||||
FileOutputStream out = new FileOutputStream(tf);
|
|
||||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
|
||||||
String[] modelInfos = getDumpInfo(withStats);
|
|
||||||
|
|
||||||
for (int i = 0; i < modelInfos.length; i++) {
|
|
||||||
writer.write("booster [" + i + "]:\n");
|
|
||||||
writer.write(modelInfos[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.close();
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Dump model into a text file.
|
|
||||||
*
|
|
||||||
* @param modelPath file to save dumped model info
|
|
||||||
* @param featureMap featureMap file
|
|
||||||
* @param withStats bool
|
|
||||||
* Controls whether the split statistics are output.
|
|
||||||
* @throws FileNotFoundException exception
|
|
||||||
* @throws UnsupportedEncodingException exception
|
|
||||||
* @throws IOException exception
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws
|
|
||||||
IOException, XGBoostError {
|
|
||||||
File tf = new File(modelPath);
|
|
||||||
FileOutputStream out = new FileOutputStream(tf);
|
|
||||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
|
||||||
String[] modelInfos = getDumpInfo(featureMap, withStats);
|
|
||||||
|
|
||||||
for (int i = 0; i < modelInfos.length; i++) {
|
|
||||||
writer.write("booster [" + i + "]:\n");
|
|
||||||
writer.write(modelInfos[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.close();
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* get importance of each feature
|
|
||||||
*
|
|
||||||
* @return featureMap key: feature index, value: feature importance score
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public Map<String, Integer> getFeatureScore() throws XGBoostError {
|
|
||||||
String[] modelInfos = getDumpInfo(false);
|
|
||||||
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
|
||||||
for (String tree : modelInfos) {
|
|
||||||
for (String node : tree.split("\n")) {
|
|
||||||
String[] array = node.split("\\[");
|
|
||||||
if (array.length == 1) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
String fid = array[1].split("\\]")[0];
|
|
||||||
fid = fid.split("<")[0];
|
|
||||||
if (featureScore.containsKey(fid)) {
|
|
||||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
|
||||||
} else {
|
|
||||||
featureScore.put(fid, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return featureScore;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* get importance of each feature
|
|
||||||
*
|
|
||||||
* @param featureMap file to save dumped model info
|
|
||||||
* @return featureMap key: feature index, value: feature importance score
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
|
||||||
String[] modelInfos = getDumpInfo(featureMap, false);
|
|
||||||
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
|
||||||
for (String tree : modelInfos) {
|
|
||||||
for (String node : tree.split("\n")) {
|
|
||||||
String[] array = node.split("\\[");
|
|
||||||
if (array.length == 1) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
String fid = array[1].split("\\]")[0];
|
|
||||||
fid = fid.split("<")[0];
|
|
||||||
if (featureScore.containsKey(fid)) {
|
|
||||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
|
||||||
} else {
|
|
||||||
featureScore.put(fid, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return featureScore;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Save the model as byte array representation.
|
|
||||||
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
|
||||||
*
|
|
||||||
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
|
|
||||||
*
|
|
||||||
* @return the saved byte array.
|
|
||||||
* @throws XGBoostError
|
|
||||||
*/
|
|
||||||
public byte[] toByteArray() throws XGBoostError {
|
public byte[] toByteArray() throws XGBoostError {
|
||||||
byte[][] bytes = new byte[1][];
|
byte[][] bytes = new byte[1][];
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||||
return bytes[0];
|
return bytes[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load the booster model from thread-local rabit checkpoint.
|
* Load the booster model from thread-local rabit checkpoint.
|
||||||
* This is only used in distributed training.
|
* This is only used in distributed training.
|
||||||
@ -476,6 +399,22 @@ public class Booster implements Serializable {
|
|||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Internal initialization function.
|
||||||
|
* @param cacheMats The cached DMatrix.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
private void init(DMatrix[] cacheMats) throws XGBoostError {
|
||||||
|
long[] handles = null;
|
||||||
|
if (cacheMats != null) {
|
||||||
|
handles = dmatrixsToHandles(cacheMats);
|
||||||
|
}
|
||||||
|
long[] out = new long[1];
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
|
||||||
|
|
||||||
|
handle = out[0];
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* transfer DMatrix array to handle array (used for native functions)
|
* transfer DMatrix array to handle array (used for native functions)
|
||||||
*
|
*
|
||||||
@ -499,7 +438,8 @@ public class Booster implements Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
|
private void readObject(java.io.ObjectInputStream in)
|
||||||
|
throws IOException, ClassNotFoundException {
|
||||||
try {
|
try {
|
||||||
this.init(null);
|
this.init(null);
|
||||||
byte[] bytes = (byte[])in.readObject();
|
byte[] bytes = (byte[])in.readObject();
|
||||||
|
|||||||
@ -35,7 +35,7 @@ public class DMatrix {
|
|||||||
//load native library
|
//load native library
|
||||||
static {
|
static {
|
||||||
try {
|
try {
|
||||||
NativeLibLoader.initXgBoost();
|
NativeLibLoader.initXGBoost();
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
logger.error("load native library failed.");
|
logger.error("load native library failed.");
|
||||||
logger.error(ex);
|
logger.error(ex);
|
||||||
|
|||||||
@ -30,7 +30,7 @@ class JNIErrorHandle {
|
|||||||
//load native library
|
//load native library
|
||||||
static {
|
static {
|
||||||
try {
|
try {
|
||||||
NativeLibLoader.initXgBoost();
|
NativeLibLoader.initXGBoost();
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
logger.error("load native library failed.");
|
logger.error("load native library failed.");
|
||||||
logger.error(ex);
|
logger.error(ex);
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class NativeLibLoader {
|
|||||||
private static final String nativeResourcePath = "/lib/";
|
private static final String nativeResourcePath = "/lib/";
|
||||||
private static final String[] libNames = new String[]{"xgboost4j"};
|
private static final String[] libNames = new String[]{"xgboost4j"};
|
||||||
|
|
||||||
public static synchronized void initXgBoost() throws IOException {
|
public static synchronized void initXGBoost() throws IOException {
|
||||||
if (!initialized) {
|
if (!initialized) {
|
||||||
for (String libName : libNames) {
|
for (String libName : libNames) {
|
||||||
smartLoad(libName);
|
smartLoad(libName);
|
||||||
|
|||||||
@ -15,7 +15,7 @@ public class Rabit implements Serializable {
|
|||||||
//load native library
|
//load native library
|
||||||
static {
|
static {
|
||||||
try {
|
try {
|
||||||
NativeLibLoader.initXgBoost();
|
NativeLibLoader.initXGBoost();
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
logger.error("load native library failed.");
|
logger.error("load native library failed.");
|
||||||
logger.error(ex);
|
logger.error(ex);
|
||||||
|
|||||||
@ -15,6 +15,8 @@
|
|||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
@ -28,6 +30,33 @@ import org.apache.commons.logging.LogFactory;
|
|||||||
public class XGBoost {
|
public class XGBoost {
|
||||||
private static final Log logger = LogFactory.getLog(XGBoost.class);
|
private static final Log logger = LogFactory.getLog(XGBoost.class);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* load model from modelPath
|
||||||
|
*
|
||||||
|
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public static Booster loadModel(String modelPath)
|
||||||
|
throws XGBoostError {
|
||||||
|
return Booster.loadModel(modelPath);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load a new Booster model from a file opened as input stream.
|
||||||
|
* The assumption is the input stream only contains one XGBoost Model.
|
||||||
|
* This can be used to load existing booster models saved by other xgboost bindings.
|
||||||
|
*
|
||||||
|
* @param in The input stream of the file,
|
||||||
|
* will be closed after this function call.
|
||||||
|
* @return The create boosted
|
||||||
|
* @throws XGBoostError
|
||||||
|
* @throws IOException
|
||||||
|
*/
|
||||||
|
public static Booster loadModel(InputStream in)
|
||||||
|
throws XGBoostError, IOException {
|
||||||
|
return Booster.loadModel(in);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Train a booster with given parameters.
|
* Train a booster with given parameters.
|
||||||
*
|
*
|
||||||
@ -41,8 +70,10 @@ public class XGBoost {
|
|||||||
* @return trained booster
|
* @return trained booster
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public static Booster train(Map<String, Object> params, DMatrix dtrain, int round,
|
public static Booster train(Map<String, Object> params,
|
||||||
Map<String, DMatrix> watches, IObjective obj,
|
DMatrix dtrain, int round,
|
||||||
|
Map<String, DMatrix> watches,
|
||||||
|
IObjective obj,
|
||||||
IEvaluation eval) throws XGBoostError {
|
IEvaluation eval) throws XGBoostError {
|
||||||
|
|
||||||
//collect eval matrixs
|
//collect eval matrixs
|
||||||
@ -106,32 +137,7 @@ public class XGBoost {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* init Booster from dMatrixs
|
* Cross-validation with given parameters.
|
||||||
*
|
|
||||||
* @param params parameters
|
|
||||||
* @param dMatrixs DMatrix array
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public static Booster initBoostingModel(
|
|
||||||
Map<String, Object> params,
|
|
||||||
DMatrix[] dMatrixs) throws XGBoostError {
|
|
||||||
return new Booster(params, dMatrixs);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* load model from modelPath
|
|
||||||
*
|
|
||||||
* @param params parameters
|
|
||||||
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public static Booster loadBoostModel(Map<String, Object> params, String modelPath)
|
|
||||||
throws XGBoostError {
|
|
||||||
return new Booster(params, modelPath);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Cross-validation with given paramaters.
|
|
||||||
*
|
*
|
||||||
* @param params Booster params.
|
* @param params Booster params.
|
||||||
* @param data Data to be trained.
|
* @param data Data to be trained.
|
||||||
@ -294,7 +300,7 @@ public class XGBoost {
|
|||||||
public CVPack(DMatrix dtrain, DMatrix dtest, Map<String, Object> params)
|
public CVPack(DMatrix dtrain, DMatrix dtest, Map<String, Object> params)
|
||||||
throws XGBoostError {
|
throws XGBoostError {
|
||||||
dmats = new DMatrix[]{dtrain, dtest};
|
dmats = new DMatrix[]{dtrain, dtest};
|
||||||
booster = XGBoost.initBoostingModel(params, dmats);
|
booster = new Booster(params, dmats);
|
||||||
names = new String[]{"train", "test"};
|
names = new String[]{"train", "test"};
|
||||||
this.dtrain = dtrain;
|
this.dtrain = dtrain;
|
||||||
this.dtest = dtest;
|
this.dtest = dtest;
|
||||||
|
|||||||
@ -16,86 +16,177 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java
|
import java.io.IOException
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
|
||||||
|
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
|
||||||
class Booster private[xgboost4j](booster: java.Booster) extends Serializable {
|
class Booster private[xgboost4j](booster: JBooster) extends Serializable {
|
||||||
|
|
||||||
def setParam(key: String, value: String): Unit = {
|
/**
|
||||||
|
* Set parameter to the Booster.
|
||||||
|
*
|
||||||
|
* @param key param name
|
||||||
|
* @param value param value
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def setParam(key: String, value: AnyRef): Unit = {
|
||||||
booster.setParam(key, value)
|
booster.setParam(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
def update(dtrain: DMatrix, iter: Int): Unit = {
|
/**
|
||||||
booster.update(dtrain.jDMatrix, iter)
|
* set parameters
|
||||||
}
|
*
|
||||||
|
* @param params parameters key-value map
|
||||||
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
*/
|
||||||
booster.update(dtrain.jDMatrix, obj)
|
@throws(classOf[XGBoostError])
|
||||||
}
|
|
||||||
|
|
||||||
def dumpModel(modelPath: String, withStats: Boolean): Unit = {
|
|
||||||
booster.dumpModel(modelPath, withStats)
|
|
||||||
}
|
|
||||||
|
|
||||||
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = {
|
|
||||||
booster.dumpModel(modelPath, featureMap, withStats)
|
|
||||||
}
|
|
||||||
|
|
||||||
def setParams(params: Map[String, AnyRef]): Unit = {
|
def setParams(params: Map[String, AnyRef]): Unit = {
|
||||||
booster.setParams(params.asJava)
|
booster.setParams(params.asJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
|
/**
|
||||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
* Update (one iteration)
|
||||||
|
*
|
||||||
|
* @param dtrain training data
|
||||||
|
* @param iter current iteration number
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def update(dtrain: DMatrix, iter: Int): Unit = {
|
||||||
|
booster.update(dtrain.jDMatrix, iter)
|
||||||
}
|
}
|
||||||
|
|
||||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait):
|
/**
|
||||||
String = {
|
* update with customize obj func
|
||||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
*
|
||||||
}
|
* @param dtrain training data
|
||||||
|
* @param obj customized objective class
|
||||||
def dispose: Unit = {
|
*/
|
||||||
booster.dispose()
|
@throws(classOf[XGBoostError])
|
||||||
}
|
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||||
|
booster.update(dtrain.jDMatrix, obj)
|
||||||
def predict(data: DMatrix): Array[Array[Float]] = {
|
|
||||||
booster.predict(data.jDMatrix)
|
|
||||||
}
|
|
||||||
|
|
||||||
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = {
|
|
||||||
booster.predict(data.jDMatrix, outPutMargin)
|
|
||||||
}
|
|
||||||
|
|
||||||
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
|
|
||||||
Array[Array[Float]] = {
|
|
||||||
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
|
||||||
}
|
|
||||||
|
|
||||||
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
|
|
||||||
booster.predict(data.jDMatrix, treeLimit, predLeaf)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* update with give grad and hess
|
||||||
|
*
|
||||||
|
* @param dtrain training data
|
||||||
|
* @param grad first order of gradient
|
||||||
|
* @param hess seconde order of gradient
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||||
}
|
}
|
||||||
|
|
||||||
def getFeatureScore: mutable.Map[String, Integer] = {
|
/**
|
||||||
booster.getFeatureScore.asScala
|
* evaluate with given dmatrixs.
|
||||||
|
*
|
||||||
|
* @param evalMatrixs dmatrixs for evaluation
|
||||||
|
* @param evalNames name for eval dmatrixs, used for check results
|
||||||
|
* @param iter current eval iteration
|
||||||
|
* @return eval information
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int)
|
||||||
|
: String = {
|
||||||
|
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
||||||
}
|
}
|
||||||
|
|
||||||
def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
|
/**
|
||||||
|
* evaluate with given customized Evaluation class
|
||||||
|
*
|
||||||
|
* @param evalMatrixs evaluation matrix
|
||||||
|
* @param evalNames evaluation names
|
||||||
|
* @param eval custom evaluator
|
||||||
|
* @return eval information
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait)
|
||||||
|
: String = {
|
||||||
|
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Predict with data
|
||||||
|
*
|
||||||
|
* @param data dmatrix storing the input
|
||||||
|
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||||
|
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||||
|
* @return predict result
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0)
|
||||||
|
: Array[Array[Float]] = {
|
||||||
|
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Predict the leaf indices
|
||||||
|
*
|
||||||
|
* @param data dmatrix storing the input
|
||||||
|
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||||
|
* @return predict result
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def predictLeaf(data: DMatrix, treeLimit: Int = 0)
|
||||||
|
: Array[Array[Float]] = {
|
||||||
|
booster.predictLeaf(data.jDMatrix, treeLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* save model to modelPath
|
||||||
|
*
|
||||||
|
* @param modelPath model path
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def saveModel(modelPath: String): Unit = {
|
||||||
|
booster.saveModel(modelPath)
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* save model to Output stream
|
||||||
|
*
|
||||||
|
* @param out Output stream
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def saveModel(out: java.io.OutputStream): Unit = {
|
||||||
|
booster.saveModel(out)
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Dump model as Array of string
|
||||||
|
*
|
||||||
|
* @param featureMap featureMap file
|
||||||
|
* @param withStats bool
|
||||||
|
* Controls whether the split statistics are output.
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def getModelDump(featureMap: String = null, withStats: Boolean = false)
|
||||||
|
: Array[String] = {
|
||||||
|
booster.getModelDump(featureMap, withStats)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get importance of each feature
|
||||||
|
*
|
||||||
|
* @return featureMap key: feature index, value: feature importance score
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def getFeatureScore(featureMap: String = null): mutable.Map[String, Integer] = {
|
||||||
booster.getFeatureScore(featureMap).asScala
|
booster.getFeatureScore(featureMap).asScala
|
||||||
}
|
}
|
||||||
|
|
||||||
def saveModel(modelPath: String): Unit = {
|
/**
|
||||||
booster.saveModel(modelPath)
|
* Dispose the booster when it is no longer needed
|
||||||
|
*/
|
||||||
|
def dispose: Unit = {
|
||||||
|
booster.dispose()
|
||||||
}
|
}
|
||||||
|
|
||||||
override def finalize(): Unit = {
|
override def finalize(): Unit = {
|
||||||
super.finalize()
|
super.finalize()
|
||||||
dispose
|
dispose
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,12 +16,28 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
|
import java.io.InputStream
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError}
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
|
/**
|
||||||
|
* XGBoost Scala Training function.
|
||||||
|
*/
|
||||||
object XGBoost {
|
object XGBoost {
|
||||||
|
/**
|
||||||
|
* Train a booster given parameters.
|
||||||
|
*
|
||||||
|
* @param params Parameters.
|
||||||
|
* @param dtrain Data to be trained.
|
||||||
|
* @param round Number of boosting iterations.
|
||||||
|
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||||
|
* performance on the validation set.
|
||||||
|
* @param obj customized objective
|
||||||
|
* @param eval customized evaluation
|
||||||
|
* @return The trained booster.
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
def train(
|
def train(
|
||||||
params: Map[String, AnyRef],
|
params: Map[String, AnyRef],
|
||||||
dtrain: DMatrix,
|
dtrain: DMatrix,
|
||||||
@ -35,6 +51,19 @@ object XGBoost {
|
|||||||
new Booster(xgboostInJava)
|
new Booster(xgboostInJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cross-validation with given parameters.
|
||||||
|
*
|
||||||
|
* @param params Booster params.
|
||||||
|
* @param data Data to be trained.
|
||||||
|
* @param round Number of boosting iterations.
|
||||||
|
* @param nfold Number of folds in CV.
|
||||||
|
* @param metrics Evaluation metrics to be watched in CV.
|
||||||
|
* @param obj customized objective
|
||||||
|
* @param eval customized evaluation
|
||||||
|
* @return evaluation history
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
def crossValidation(
|
def crossValidation(
|
||||||
params: Map[String, AnyRef],
|
params: Map[String, AnyRef],
|
||||||
data: DMatrix,
|
data: DMatrix,
|
||||||
@ -46,13 +75,28 @@ object XGBoost {
|
|||||||
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||||
}
|
}
|
||||||
|
|
||||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
/**
|
||||||
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
|
* load model from modelPath
|
||||||
|
*
|
||||||
|
* @param modelPath booster modelPath
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def loadModel(modelPath: String): Booster = {
|
||||||
|
val xgboostInJava = JXGBoost.loadModel(modelPath)
|
||||||
new Booster(xgboostInJava)
|
new Booster(xgboostInJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = {
|
/**
|
||||||
val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath)
|
* Load a new Booster model from a file opened as input stream.
|
||||||
|
* The assumption is the input stream only contains one XGBoost Model.
|
||||||
|
* This can be used to load existing booster models saved by other XGBoost bindings.
|
||||||
|
*
|
||||||
|
* @param in The input stream of the file.
|
||||||
|
* @return The create booster
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def loadModel(in: InputStream): Booster = {
|
||||||
|
val xgboostInJava = JXGBoost.loadModel(in)
|
||||||
new Booster(xgboostInJava)
|
new Booster(xgboostInJava)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,6 +15,10 @@
|
|||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@ -67,7 +71,7 @@ public class BoosterImplTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBoosterBasic() throws XGBoostError {
|
public void testBoosterBasic() throws XGBoostError, IOException {
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
@ -94,15 +98,20 @@ public class BoosterImplTest {
|
|||||||
Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null);
|
Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null);
|
||||||
|
|
||||||
//predict raw output
|
//predict raw output
|
||||||
float[][] predicts = booster.predict(testMat, true);
|
float[][] predicts = booster.predict(testMat, true, 0);
|
||||||
|
|
||||||
//eval
|
//eval
|
||||||
IEvaluation eval = new EvalError();
|
IEvaluation eval = new EvalError();
|
||||||
//error must be less than 0.1
|
//error must be less than 0.1
|
||||||
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
|
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
|
||||||
|
|
||||||
//test dump model
|
// save and load
|
||||||
|
File temp = File.createTempFile("temp", "model");
|
||||||
|
temp.deleteOnExit();
|
||||||
|
booster.saveModel(temp.getAbsolutePath());
|
||||||
|
|
||||||
|
Booster bst2 = XGBoost.loadModel(new FileInputStream(temp.getAbsolutePath()));
|
||||||
|
assert (Arrays.equals(bst2.toByteArray(), booster.toByteArray()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user