Merge pull request #926 from tqchen/master

[JVM] Refactor, add filesys API
Tianqi Chen 2016-03-06 11:49:01 -08:00
commit cf2a7851eb
17 changed files with 597 additions and 896 deletions

View File

@@ -82,16 +82,16 @@ public class BasicWalkThrough {
booster.saveModel(modelPath);
//dump model
booster.dumpModel("./model/dump.raw.txt", false);
booster.getModelDump("./model/dump.raw.txt", false);
//dump model with feature map
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
booster.getModelDump("../../demo/data/featmap.txt", false);
//save dmatrix into binary buffer
testMat.saveBinary("./model/dtest.buffer");
//reload model and data
Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model");
Booster booster2 = XGBoost.loadModel("./model/xgb.model");
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
float[][] predicts2 = booster2.predict(testMat2);
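For reference, a minimal usage sketch of the renamed API as it appears above (paths are the walkthrough's own):

// Sketch: dump the trained model with a feature map and reload it through
// the new static loader; getModelDump returns the trees as strings instead
// of writing a file.
String[] dump = booster.getModelDump("../../demo/data/featmap.txt", true);
for (String tree : dump) {
  System.out.println(tree);
}
Booster reloaded = XGBoost.loadModel("./model/xgb.model");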

View File

@@ -1,79 +0,0 @@
package ml.dmlc.xgboost4j.java.demo;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import ml.dmlc.xgboost4j.java.*;
/**
* Distributed training example, used to quick test distributed training.
*
* @author tqchen
*/
public class DistTrain {
private static final Log logger = LogFactory.getLog(DistTrain.class);
private Map<String, String> envs = null;
private class Worker implements Runnable {
private final int workerId;
Worker(int workerId) {
this.workerId = workerId;
}
public void run() {
try {
Map<String, String> worker_env = new HashMap<String, String>(envs);
worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
// always initialize rabit module before training.
Rabit.init(worker_env);
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("nthread", 2);
params.put("objective", "binary:logistic");
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//set round
int round = 2;
//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
// always shutdown rabit module after training.
Rabit.shutdown();
} catch (Exception ex){
logger.error(ex);
}
}
}
void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
RabitTracker tracker = new RabitTracker(nWorkers);
if (tracker.start()) {
envs = tracker.getWorkerEnvs();
for (int i = 0; i < nWorkers; ++i) {
new Thread(new Worker(i)).start();
}
tracker.waitFor();
}
}
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
new DistTrain().start(Integer.parseInt(args[0]));
}
}

View File

@@ -52,13 +52,13 @@ public class PredictLeafIndices {
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
//predict using the first 2 trees
float[][] leafindex = booster.predict(testMat, 2, true);
float[][] leafindex = booster.predictLeaf(testMat, 2);
for (float[] leafs : leafindex) {
System.out.println(Arrays.toString(leafs));
}
//predict all trees
leafindex = booster.predict(testMat, 0, true);
leafindex = booster.predictLeaf(testMat, 0);
for (float[] leafs : leafindex) {
System.out.println(Arrays.toString(leafs));
}

View File

@@ -37,6 +37,8 @@ object Test {
"objective" -> "binary:logistic").toMap
val round = 2
val model = XGBoost.train(paramMap, data, round)
log.info(model)
}
}

View File

@@ -25,6 +25,9 @@ import org.apache.flink.api.scala.DataSet
import org.apache.flink.api.scala._
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.util.Collector
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
object XGBoost {
/**
@@ -60,6 +63,20 @@ object XGBoost {
val logger = LogFactory.getLog(this.getClass)
/**
* Load XGBoost model from path, using Hadoop Filesystem API.
*
* @param modelPath The path that is accessible by hadoop filesystem API.
* @return The loaded model
*/
def loadModel(modelPath: String) : XGBoostModel = {
new XGBoostModel(
XGBoostScala.loadModel(
FileSystem
.get(new Configuration)
.open(new Path(modelPath))))
}
/**
* Train a xgboost model with Flink.
*
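The same loading pattern is available from plain Java through the new stream overload; a hedged sketch, reusing the Hadoop imports shown above (the hdfs:// path is illustrative):

// Open the model through the Hadoop filesystem API and hand the stream to
// XGBoost.loadModel(InputStream); the stream is closed after the call.
InputStream in = FileSystem.get(new Configuration())
    .open(new Path("hdfs:///tmp/xgb.model"));  // illustrative path
Booster booster = XGBoost.loadModel(in);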

View File

@@ -16,8 +16,45 @@
package ml.dmlc.xgboost4j.flink
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
import org.apache.flink.api.scala.DataSet
import org.apache.flink.api.scala._
import org.apache.flink.ml.math.Vector
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
class XGBoostModel (booster: Booster) extends Serializable {
/**
* Save the model as a Hadoop filesystem file.
*
* @param modelPath The model path as in Hadoop path.
*/
def saveModel(modelPath: String): Unit = {
booster.saveModel(FileSystem
.get(new Configuration)
.create(new Path(modelPath)))
}
/**
* Predict given vector dataset.
*
* @param data The dataset to be predicted.
* @return The prediction result.
*/
def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
val predictMap: Iterator[Vector] => TraversableOnce[Array[Float]] =
(it: Iterator[Vector]) => {
val mapper = (x: Vector) => {
val (index, value) = x.toSeq.unzip
LabeledPoint.fromSparseVector(0.0f,
index.toArray, value.map(z => z.toFloat).toArray)
}
val dataIter = for (x <- it) yield mapper(x)
val dmat = new DMatrix(dataIter, null)
this.booster.predict(dmat)
}
data.mapPartition(predictMap)
}
}

View File

@@ -1,41 +1,146 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.Serializable;
import java.io.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public interface Booster extends Serializable {
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Booster for xgboost, this is a model API that supports interactive building of an XGBoost model
*/
public class Booster implements Serializable {
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
//load native library
static {
try {
NativeLibLoader.initXGBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* set parameter
* Create a new Booster with an empty state.
*
* @param params Model parameters
* @param cacheMats Cached DMatrix entries,
* prediction on these DMatrices will be faster than on uncached data.
* @throws XGBoostError native error
*/
Booster(Map<String, Object> params, DMatrix[] cacheMats) throws XGBoostError {
init(cacheMats);
setParam("seed", "0");
setParams(params);
}
/**
* Load a new Booster model from modelPath
* @param modelPath The path to the model.
* @return The created Booster.
* @throws XGBoostError
*/
static Booster loadModel(String modelPath) throws XGBoostError {
if (modelPath == null) {
throw new NullPointerException("modelPath : null");
}
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath));
return ret;
}
/**
* Load a new Booster model from a file opened as input stream.
* The assumption is the input stream only contains one XGBoost Model.
* This can be used to load existing booster models saved by other xgboost bindings.
*
* @param in The input stream of the file.
* @return The created Booster
* @throws XGBoostError
* @throws IOException
*/
static Booster loadModel(InputStream in) throws XGBoostError, IOException {
int size;
byte[] buf = new byte[1<<20];
ByteArrayOutputStream os = new ByteArrayOutputStream();
while ((size = in.read(buf)) != -1) {
os.write(buf, 0, size);
}
in.close();
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray()));
return ret;
}
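XGBoost.loadModel(InputStream) is the public entry point for this package-private method; a usage sketch for a model file written by any binding (the file name is illustrative):

// Load a model saved by another binding (e.g. the Python wrapper).
// The whole stream is buffered into memory and closed by the call.
InputStream in = new FileInputStream("xgb.model");  // illustrative file
Booster booster = XGBoost.loadModel(in);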
/**
* Set parameter to the Booster.
*
* @param key param name
* @param value param value
* @throws XGBoostError native error
*/
void setParam(String key, String value) throws XGBoostError;
public final void setParam(String key, Object value) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value.toString()));
}
/**
* set parameters
* Set parameters to the Booster.
*
* @param params parameters key-value map
* @throws XGBoostError native error
*/
void setParams(Map<String, Object> params) throws XGBoostError;
public void setParams(Map<String, Object> params) throws XGBoostError {
if (params != null) {
for (Map.Entry<String, Object> entry : params.entrySet()) {
setParam(entry.getKey(), entry.getValue().toString());
}
}
}
/**
* Update (one iteration)
* Update the booster for one iteration.
*
* @param dtrain training data
* @param iter current iteration number
* @throws XGBoostError native error
*/
void update(DMatrix dtrain, int iter) throws XGBoostError;
public void update(DMatrix dtrain, int iter) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
/**
* update with customize obj func
* Update with a customized objective function
*
* @param dtrain training data
* @param obj customized objective class
* @throws XGBoostError native error
*/
void update(DMatrix dtrain, IObjective obj) throws XGBoostError;
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
float[][] predicts = this.predict(dtrain, true, 0, false);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1));
}
/**
* update with given grad and hess
@@ -43,8 +148,16 @@ public interface Booster extends Serializable {
* @param dtrain training data
* @param grad first order of gradient
* @param hess second order of gradient
* @throws XGBoostError native error
*/
void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError;
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
if (grad.length != hess.length) {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
hess.length));
}
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle,
dtrain.getHandle(), grad, hess));
}
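A sketch of one manual boosting step with externally supplied gradients; the all-zero grad and constant hess are illustrative only, real values come from an objective evaluated on the current predictions:

// grad and hess need one entry per training row and equal lengths,
// otherwise boost() throws an AssertionError.
int n = (int) dtrain.rowNum();
float[] grad = new float[n];  // illustrative: all-zero gradient
float[] hess = new float[n];
java.util.Arrays.fill(hess, 1.0f);
booster.boost(dtrain, grad, hess);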
/**
* evaluate with given dmatrixs.
@@ -53,8 +166,15 @@ public interface Booster extends Serializable {
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
* @throws XGBoostError native error
*/
String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError;
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
long[] handles = dmatrixsToHandles(evalMatrixs);
String[] evalInfo = new String[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
evalInfo));
return evalInfo[0];
}
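A usage sketch, assuming the train/test DMatrix objects from the demos above:

// Returns the native eval string for one iteration, along the lines of
// "[1]\ttrain-error:0.04\ttest-error:0.05" (values illustrative).
String evalInfo = booster.evalSet(new DMatrix[]{trainMat, testMat},
    new String[]{"train", "test"}, 1);
System.out.println(evalInfo);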
/**
* evaluate with given customized Evaluation class
@@ -63,60 +183,171 @@ public interface Booster extends Serializable {
* @param evalNames evaluation names
* @param eval custom evaluator
* @return eval information
* @throws XGBoostError native error
*/
String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) throws XGBoostError;
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
throws XGBoostError {
String evalInfo = "";
for (int i = 0; i < evalNames.length; i++) {
String evalName = evalNames[i];
DMatrix evalMat = evalMatrixs[i];
float evalResult = eval.eval(predict(evalMat), evalMat);
String evalMetric = eval.getMetric();
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
}
return evalInfo;
}
/**
* Advanced predict function with all the options.
*
* @param data data
* @param outputMargin output margin
* @param treeLimit limit number of trees, 0 means all trees.
* @param predLeaf whether to output the predicted leaf index of each sample in each tree
* @return predict results
*/
private synchronized float[][] predict(DMatrix data,
boolean outputMargin,
int treeLimit,
boolean predLeaf) throws XGBoostError {
int optionMask = 0;
if (outputMargin) {
optionMask = 1;
}
if (predLeaf) {
optionMask = 2;
}
float[][] rawPredicts = new float[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
treeLimit, rawPredicts));
int row = (int) data.rowNum();
int col = rawPredicts[0].length / row;
float[][] predicts = new float[row][col];
int r, c;
for (int i = 0; i < rawPredicts[0].length; i++) {
r = i / col;
c = i % col;
predicts[r][c] = rawPredicts[0][i];
}
return predicts;
}
/**
* Predict leaf indices given the data
*
* @param data The input data.
* @param treeLimit Number of trees to include, 0 means all trees.
* @return The leaf indices of the instance.
* @throws XGBoostError
*/
public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError {
return this.predict(data, false, treeLimit, true);
}
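A sketch of consuming the leaf matrix, which has one row per sample and one column per tree:

float[][] leaves = booster.predictLeaf(testMat, 0);  // 0 = use all trees
// leaves[i][t] is the leaf that sample i reaches in tree t; leaf indices
// are only unique within a single tree.
System.out.println(java.util.Arrays.toString(leaves[0]));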
/**
* Predict with data
*
* @param data dmatrix storing the input
* @return predict result
*/
float[][] predict(DMatrix data) throws XGBoostError;
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
*/
float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError;
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result
*/
float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError;
/**
* Predict with data
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index of
* each sample in each tree. Note that the leaf index of a tree is unique per
* tree, so you may find leaf 1 in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
public float[][] predict(DMatrix data) throws XGBoostError {
return this.predict(data, false, 0, false);
}
/**
* save model to modelPath, the model path support depends on the path support
* in libxgboost. For example, if we want to save to hdfs, libxgboost need to be
* compiled with HDFS support.
* See also toByteArray
* Predict with data
*
* @param data data
* @param outputMargin output margin
* @return predict results
*/
public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError {
return this.predict(data, outputMargin, 0, false);
}
/**
* Advanced predict function with all the options.
*
* @param data data
* @param outputMargin output margin
* @param treeLimit limit number of trees, 0 means all trees.
* @return predict results
*/
public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError {
return this.predict(data, outputMargin, treeLimit, false);
}
/**
* Save model to modelPath
*
* @param modelPath model path
*/
void saveModel(String modelPath) throws XGBoostError;
public void saveModel(String modelPath) throws XGBoostError{
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
}
/**
* Save the model to file opened as output stream.
* The model format is compatible with other xgboost bindings.
* The output stream can only save one xgboost model.
* This function will close the OutputStream after the save.
*
* @param out The output stream
*/
public void saveModel(OutputStream out) throws XGBoostError, IOException {
out.write(this.toByteArray());
out.close();
}
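The counterpart to the stream loader above; a hedged sketch writing to HDFS through the Hadoop filesystem API (the Hadoop dependency and path are assumptions):

// saveModel(OutputStream) writes toByteArray() and closes the stream,
// so no explicit close is needed.
OutputStream out = FileSystem.get(new Configuration())
    .create(new Path("hdfs:///tmp/xgb.model"));  // illustrative path
booster.saveModel(out);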
/**
* Get the dump of the model as a string array
*
* @param featureMap path to the feature map file used to name features, or null
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
public String[] getModelDump(String featureMap, boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (featureMap == null) {
featureMap = "";
}
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(
XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
return modelInfos[0];
}
/**
* Get importance of each feature
*
* @param featureMap path to the feature map file, or null
* @return key: feature index, value: feature importance score
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
String[] modelInfos = getModelDump(featureMap, false);
Map<String, Integer> featureScore = new HashMap<String, Integer>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
} else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
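A sketch printing the split-count importances; passing null keeps the raw f&lt;index&gt; feature names:

Map<String, Integer> scores = booster.getFeatureScore(null);
for (Map.Entry<String, Integer> e : scores.entrySet()) {
  // key is a feature id such as "f29", value is how often it was split on
  System.out.println(e.getKey() + ": " + e.getValue());
}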
/**
* Save the model as byte array representation.
@@ -127,41 +358,93 @@ public interface Booster extends Serializable {
* @return the saved byte array.
* @throws XGBoostError
*/
byte[] toByteArray() throws XGBoostError;
public byte[] toByteArray() throws XGBoostError {
byte[][] bytes = new byte[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
return bytes[0];
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param withStats bool Controls whether the split statistics are output.
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError;
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
* Controls whether the split statistics are output.
* Save the booster model into thread-local rabit checkpoint.
* This is only used in distributed training.
* @throws XGBoostError
*/
void dumpModel(String modelPath, String featureMap, boolean withStats)
throws IOException, XGBoostError;
void saveRabitCheckpoint() throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
}
/**
* get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score
* Internal initialization function.
* @param cacheMats The cached DMatrix.
* @throws XGBoostError
*/
Map<String, Integer> getFeatureScore() throws XGBoostError ;
private void init(DMatrix[] cacheMats) throws XGBoostError {
long[] handles = null;
if (cacheMats != null) {
handles = dmatrixsToHandles(cacheMats);
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
handle = out[0];
}
/**
* get importance of each feature
* transfer DMatrix array to handle array (used for native functions)
*
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
* @param dmatrixs the input DMatrix array
* @return handle array for input dmatrixs
*/
Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError;
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
long[] handles = new long[dmatrixs.length];
for (int i = 0; i < dmatrixs.length; i++) {
handles[i] = dmatrixs[i].getHandle();
}
return handles;
}
void dispose();
// making Booster serializable
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
private void readObject(java.io.ObjectInputStream in)
throws IOException, ClassNotFoundException {
try {
this.init(null);
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
@Override
protected void finalize() throws Throwable {
super.finalize();
dispose();
}
public synchronized void dispose() {
if (handle != 0L) {
XGBoostJNI.XGBoosterFree(handle);
handle = 0;
}
}
}
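Because writeObject/readObject round-trip through the raw model bytes, a trained Booster now survives plain Java serialization; a minimal sketch:

// In-memory serialize/deserialize cycle.
ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(bos);
oos.writeObject(booster);
oos.close();
ObjectInputStream ois = new ObjectInputStream(
    new ByteArrayInputStream(bos.toByteArray()));
Booster copy = (Booster) ois.readObject();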

View File

@@ -35,7 +35,7 @@ public class DMatrix {
//load native library
static {
try {
NativeLibLoader.initXgBoost();
NativeLibLoader.initXGBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);

View File

@@ -30,7 +30,7 @@ class JNIErrorHandle {
//load native library
static {
try {
NativeLibLoader.initXgBoost();
NativeLibLoader.initXGBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);

View File

@@ -1,525 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Booster for xgboost, similar to the python wrapper xgboost.py
* but custom obj function and eval function not supported at present.
*
* @author hzx
*/
class JavaBoosterImpl implements Booster {
private static final Log logger = LogFactory.getLog(JavaBoosterImpl.class);
long handle = 0;
//load native library
static {
try {
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* init Booster from dMatrixs
*
* @param params parameters
* @param dMatrixs DMatrix array
* @throws XGBoostError native error
*/
JavaBoosterImpl(Map<String, Object> params, DMatrix[] dMatrixs) throws XGBoostError {
init(dMatrixs);
setParam("seed", "0");
setParams(params);
}
/**
* load model from modelPath
*
* @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel)
* @throws XGBoostError native error
*/
JavaBoosterImpl(Map<String, Object> params, String modelPath) throws XGBoostError {
init(null);
if (modelPath == null) {
throw new NullPointerException("modelPath : null");
}
loadModel(modelPath);
setParam("seed", "0");
setParams(params);
}
private void init(DMatrix[] dMatrixs) throws XGBoostError {
long[] handles = null;
if (dMatrixs != null) {
handles = dmatrixsToHandles(dMatrixs);
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
handle = out[0];
}
/**
* set parameter
*
* @param key param name
* @param value param value
* @throws XGBoostError native error
*/
public final void setParam(String key, String value) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
}
/**
* set parameters
*
* @param params parameters key-value map
* @throws XGBoostError native error
*/
public void setParams(Map<String, Object> params) throws XGBoostError {
if (params != null) {
for (Map.Entry<String, Object> entry : params.entrySet()) {
setParam(entry.getKey(), entry.getValue().toString());
}
}
}
/**
* Update (one iteration)
*
* @param dtrain training data
* @param iter current iteration number
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, int iter) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
/**
* update with customize obj func
*
* @param dtrain training data
* @param obj customized objective class
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
float[][] predicts = predict(dtrain, true);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1));
}
/**
* update with give grad and hess
*
* @param dtrain training data
* @param grad first order of gradient
* @param hess seconde order of gradient
* @throws XGBoostError native error
*/
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
if (grad.length != hess.length) {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
hess.length));
}
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
hess));
}
/**
* evaluate with given dmatrixs.
*
* @param evalMatrixs dmatrixs for evaluation
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
* @throws XGBoostError native error
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
long[] handles = dmatrixsToHandles(evalMatrixs);
String[] evalInfo = new String[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
evalInfo));
return evalInfo[0];
}
/**
* evaluate with given customized Evaluation class
*
* @param evalMatrixs evaluation matrix
* @param evalNames evaluation names
* @param eval custom evaluator
* @return eval information
* @throws XGBoostError native error
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
throws XGBoostError {
String evalInfo = "";
for (int i = 0; i < evalNames.length; i++) {
String evalName = evalNames[i];
DMatrix evalMat = evalMatrixs[i];
float evalResult = eval.eval(predict(evalMat), evalMat);
String evalMetric = eval.getMetric();
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
}
return evalInfo;
}
/**
* base function for Predict
*
* @param data data
* @param outPutMargin output margin
* @param treeLimit limit number of trees
* @param predLeaf prediction minimum to keep leafs
* @return predict results
*/
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit,
boolean predLeaf) throws XGBoostError {
int optionMask = 0;
if (outPutMargin) {
optionMask = 1;
}
if (predLeaf) {
optionMask = 2;
}
float[][] rawPredicts = new float[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
treeLimit, rawPredicts));
int row = (int) data.rowNum();
int col = rawPredicts[0].length / row;
float[][] predicts = new float[row][col];
int r, c;
for (int i = 0; i < rawPredicts[0].length; i++) {
r = i / col;
c = i % col;
predicts[r][c] = rawPredicts[0][i];
}
return predicts;
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data) throws XGBoostError {
return pred(data, false, 0, false);
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
return pred(data, outPutMargin, 0, false);
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
return pred(data, outPutMargin, treeLimit, false);
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index
* of each sample in each tree.
* Note that the leaf index of a tree is unique per tree, so you may find leaf 1
* in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError {
return pred(data, false, treeLimit, predLeaf);
}
/**
* save model to modelPath
*
* @param modelPath model path
*/
public void saveModel(String modelPath) throws XGBoostError{
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
}
private void loadModel(String modelPath) {
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
}
/**
* get the dump of the model as a string array
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
return modelInfos[0];
}
/**
* get the dump of the model as a string array
*
* @param featureMap featureMap file
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
modelInfos));
return modelInfos[0];
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param withStats bool
* Controls whether the split statistics are output.
* @throws FileNotFoundException file not found
* @throws UnsupportedEncodingException unsupported feature
* @throws IOException error with model writing
* @throws XGBoostError native error
*/
public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
String[] modelInfos = getDumpInfo(withStats);
for (int i = 0; i < modelInfos.length; i++) {
writer.write("booster [" + i + "]:\n");
writer.write(modelInfos[i]);
}
writer.close();
out.close();
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
* Controls whether the split statistics are output.
* @throws FileNotFoundException exception
* @throws UnsupportedEncodingException exception
* @throws IOException exception
* @throws XGBoostError native error
*/
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws
IOException, XGBoostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
String[] modelInfos = getDumpInfo(featureMap, withStats);
for (int i = 0; i < modelInfos.length; i++) {
writer.write("booster [" + i + "]:\n");
writer.write(modelInfos[i]);
}
writer.close();
out.close();
}
/**
* get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore() throws XGBoostError {
String[] modelInfos = getDumpInfo(false);
Map<String, Integer> featureScore = new HashMap<String, Integer>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
} else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
/**
* get importance of each feature
*
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
String[] modelInfos = getDumpInfo(featureMap, false);
Map<String, Integer> featureScore = new HashMap<String, Integer>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
} else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
/**
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
*
* @return the saved byte array.
* @throws XGBoostError
*/
public byte[] toByteArray() throws XGBoostError {
byte[][] bytes = new byte[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
return bytes[0];
}
/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
}
/**
* Save the booster model into thread-local rabit checkpoint.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
}
/**
* transfer DMatrix array to handle array (used for native functions)
*
* @param dmatrixs
* @return handle array for input dmatrixs
*/
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
long[] handles = new long[dmatrixs.length];
for (int i = 0; i < dmatrixs.length; i++) {
handles[i] = dmatrixs[i].getHandle();
}
return handles;
}
// making Booster serializable
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
private void readObject(java.io.ObjectInputStream in)
throws IOException, ClassNotFoundException {
try {
this.init(null);
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
@Override
protected void finalize() throws Throwable {
super.finalize();
dispose();
}
public synchronized void dispose() {
if (handle != 0L) {
XGBoostJNI.XGBoosterFree(handle);
handle = 0;
}
}
}

View File

@@ -35,7 +35,7 @@ class NativeLibLoader {
private static final String nativeResourcePath = "/lib/";
private static final String[] libNames = new String[]{"xgboost4j"};
public static synchronized void initXgBoost() throws IOException {
public static synchronized void initXGBoost() throws IOException {
if (!initialized) {
for (String libName : libNames) {
smartLoad(libName);

View File

@@ -14,7 +14,7 @@ public class Rabit {
//load native library
static {
try {
NativeLibLoader.initXgBoost();
NativeLibLoader.initXGBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);

View File

@@ -15,6 +15,8 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import org.apache.commons.logging.Log;
@@ -28,6 +30,33 @@ import org.apache.commons.logging.LogFactory;
public class XGBoost {
private static final Log logger = LogFactory.getLog(XGBoost.class);
/**
* Load model from modelPath.
*
* @param modelPath booster modelPath (model generated by booster.saveModel)
* @return The loaded Booster
* @throws XGBoostError native error
*/
public static Booster loadModel(String modelPath)
throws XGBoostError {
return Booster.loadModel(modelPath);
}
/**
* Load a new Booster model from a file opened as input stream.
* The assumption is the input stream only contains one XGBoost Model.
* This can be used to load existing booster models saved by other xgboost bindings.
*
* @param in The input stream of the file,
* will be closed after this function call.
* @return The created Booster
* @throws XGBoostError
* @throws IOException
*/
public static Booster loadModel(InputStream in)
throws XGBoostError, IOException {
return Booster.loadModel(in);
}
/**
* Train a booster with given parameters.
*
@@ -41,9 +70,11 @@ public class XGBoost {
* @return trained booster
* @throws XGBoostError native error
*/
public static Booster train(Map<String, Object> params, DMatrix dtrain, int round,
Map<String, DMatrix> watches, IObjective obj,
IEvaluation eval) throws XGBoostError {
public static Booster train(Map<String, Object> params,
DMatrix dtrain, int round,
Map<String, DMatrix> watches,
IObjective obj,
IEvaluation eval) throws XGBoostError {
//collect eval matrices
String[] evalNames;
@@ -71,7 +102,7 @@ public class XGBoost {
}
//initialize booster
JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
Booster booster = new Booster(params, allMats);
int version = booster.loadRabitCheckpoint();
@@ -106,32 +137,7 @@ public class XGBoost {
}
/**
* init Booster from dMatrixs
*
* @param params parameters
* @param dMatrixs DMatrix array
* @throws XGBoostError native error
*/
public static Booster initBoostingModel(
Map<String, Object> params,
DMatrix[] dMatrixs) throws XGBoostError {
return new JavaBoosterImpl(params, dMatrixs);
}
/**
* load model from modelPath
*
* @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel)
* @throws XGBoostError native error
*/
public static Booster loadBoostModel(Map<String, Object> params, String modelPath)
throws XGBoostError {
return new JavaBoosterImpl(params, modelPath);
}
/**
* Cross-validation with given paramaters.
* Cross-validation with given parameters.
*
* @param params Booster params.
* @param data Data to be trained.
@@ -294,7 +300,7 @@ public class XGBoost {
public CVPack(DMatrix dtrain, DMatrix dtest, Map<String, Object> params)
throws XGBoostError {
dmats = new DMatrix[]{dtrain, dtest};
booster = XGBoost.initBoostingModel(params, dmats);
booster = new Booster(params, dmats);
names = new String[]{"train", "test"};
this.dtrain = dtrain;
this.dtest = dtest;
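Putting the refactor together: the JavaBoosterImpl indirection is gone, train returns a concrete Booster, and reloading goes through the new static loader. A sketch against the agaricus demo data used throughout:

Map<String, Object> params = new HashMap<String, Object>();
params.put("objective", "binary:logistic");
params.put("max_depth", 2);
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
Map<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
Booster booster = XGBoost.train(params, trainMat, 2, watches, null, null);
booster.saveModel("./model/xgb.model");
Booster reloaded = XGBoost.loadModel("./model/xgb.model");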

View File

@@ -18,20 +18,23 @@ package ml.dmlc.xgboost4j.scala
import java.io.IOException
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import ml.dmlc.xgboost4j.java.XGBoostError
import scala.collection.JavaConverters._
import scala.collection.mutable
trait Booster extends Serializable {
class Booster private[xgboost4j](booster: JBooster) extends Serializable {
/**
* set parameter
*
* @param key param name
* @param value param value
*/
* Set parameter to the Booster.
*
* @param key param name
* @param value param value
*/
@throws(classOf[XGBoostError])
def setParam(key: String, value: String)
def setParam(key: String, value: AnyRef): Unit = {
booster.setParam(key, value)
}
/**
* set parameters
@@ -39,7 +42,9 @@ trait Booster extends Serializable {
* @param params parameters key-value map
*/
@throws(classOf[XGBoostError])
def setParams(params: Map[String, AnyRef])
def setParams(params: Map[String, AnyRef]): Unit = {
booster.setParams(params.asJava)
}
/**
* Update (one iteration)
@@ -48,7 +53,9 @@ trait Booster extends Serializable {
* @param iter current iteration number
*/
@throws(classOf[XGBoostError])
def update(dtrain: DMatrix, iter: Int)
def update(dtrain: DMatrix, iter: Int): Unit = {
booster.update(dtrain.jDMatrix, iter)
}
/**
* update with customized obj func
@@ -57,7 +64,9 @@ trait Booster extends Serializable {
* @param obj customized objective class
*/
@throws(classOf[XGBoostError])
def update(dtrain: DMatrix, obj: ObjectiveTrait)
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, obj)
}
/**
* update with given grad and hess
@@ -67,7 +76,9 @@ trait Booster extends Serializable {
* @param hess second order of gradient
*/
@throws(classOf[XGBoostError])
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float])
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, grad, hess)
}
/**
* evaluate with given dmatrixs.
@@ -78,7 +89,10 @@ trait Booster extends Serializable {
* @return eval information
*/
@throws(classOf[XGBoostError])
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int)
: String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
}
/**
* evaluate with given customized Evaluation class
@@ -89,26 +103,11 @@ trait Booster extends Serializable {
* @return eval information
*/
@throws(classOf[XGBoostError])
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): String
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait)
: String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @return predict result
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix): Array[Array[Float]]
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]]
/**
* Predict with data
@@ -119,22 +118,24 @@ trait Booster extends Serializable {
* @return predict result
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]]
def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0)
: Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
}
/**
* Predict with data
* Predict the leaf indices
*
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index of
* each sample in each tree. Note that the leaf index of a tree is unique per
* tree, so you may find leaf 1 in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]]
def predictLeaf(data: DMatrix, treeLimit: Int = 0)
: Array[Array[Float]] = {
booster.predictLeaf(data.jDMatrix, treeLimit)
}
/**
* save model to modelPath
@@ -142,46 +143,50 @@ trait Booster extends Serializable {
* @param modelPath model path
*/
@throws(classOf[XGBoostError])
def saveModel(modelPath: String)
def saveModel(modelPath: String): Unit = {
booster.saveModel(modelPath)
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param withStats bool Controls whether the split statistics are output.
*/
@throws(classOf[IOException])
* Save model to an output stream.
*
* @param out Output stream
*/
@throws(classOf[XGBoostError])
def dumpModel(modelPath: String, withStats: Boolean)
def saveModel(out: java.io.OutputStream): Unit = {
booster.saveModel(out)
}
/**
* Dump model into a text file.
* Dump the model as an array of strings.
*
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
* Controls whether the split statistics are output.
*/
@throws(classOf[IOException])
@throws(classOf[XGBoostError])
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean)
def getModelDump(featureMap: String = null, withStats: Boolean = false)
: Array[String] = {
booster.getModelDump(featureMap, withStats)
}
/**
* get importance of each feature
* Get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score
*/
@throws(classOf[XGBoostError])
def getFeatureScore: mutable.Map[String, Integer]
def getFeatureScore(featureMap: String = null): mutable.Map[String, Integer] = {
booster.getFeatureScore(featureMap).asScala
}
/**
* get importance of each feature
*
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
*/
@throws(classOf[XGBoostError])
def getFeatureScore(featureMap: String): mutable.Map[String, Integer]
* Dispose the booster when it is no longer needed
*/
def dispose: Unit = {
booster.dispose()
}
def dispose
override def finalize(): Unit = {
super.finalize()
dispose
}
}

View File

@@ -1,99 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java
import scala.collection.JavaConverters._
import scala.collection.mutable
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: java.Booster) extends Booster {
override def setParam(key: String, value: String): Unit = {
booster.setParam(key, value)
}
override def update(dtrain: DMatrix, iter: Int): Unit = {
booster.update(dtrain.jDMatrix, iter)
}
override def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, obj)
}
override def dumpModel(modelPath: String, withStats: Boolean): Unit = {
booster.dumpModel(modelPath, withStats)
}
override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = {
booster.dumpModel(modelPath, featureMap, withStats)
}
override def setParams(params: Map[String, AnyRef]): Unit = {
booster.setParams(params.asJava)
}
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
}
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait):
String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
}
override def dispose: Unit = {
booster.dispose()
}
override def predict(data: DMatrix): Array[Array[Float]] = {
booster.predict(data.jDMatrix)
}
override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin)
}
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
}
override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
booster.predict(data.jDMatrix, treeLimit, predLeaf)
}
override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, grad, hess)
}
override def getFeatureScore: mutable.Map[String, Integer] = {
booster.getFeatureScore.asScala
}
override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
booster.getFeatureScore(featureMap).asScala
}
override def saveModel(modelPath: String): Unit = {
booster.saveModel(modelPath)
}
override def finalize(): Unit = {
super.finalize()
dispose
}
}

View File

@@ -16,11 +16,28 @@
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
import java.io.InputStream
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError}
import scala.collection.JavaConverters._
/**
* XGBoost Scala Training function.
*/
object XGBoost {
/**
* Train a booster given parameters.
*
* @param params Parameters.
* @param dtrain Data to be trained.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training; this allows the user to watch
* performance on the validation set.
* @param obj customized objective
* @param eval customized evaluation
* @return The trained booster.
*/
@throws(classOf[XGBoostError])
def train(
params: Map[String, AnyRef],
dtrain: DMatrix,
@@ -31,9 +48,22 @@ object XGBoost {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
obj, eval)
new ScalaBoosterImpl(xgboostInJava)
new Booster(xgboostInJava)
}
/**
* Cross-validation with given parameters.
*
* @param params Booster params.
* @param data Data to be trained.
* @param round Number of boosting iterations.
* @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV.
* @param obj customized objective
* @param eval customized evaluation
* @return evaluation history
*/
@throws(classOf[XGBoostError])
def crossValidation(
params: Map[String, AnyRef],
data: DMatrix,
@@ -45,13 +75,28 @@ object XGBoost {
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
}
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
new ScalaBoosterImpl(xgboostInJava)
/**
* Load model from modelPath.
*
* @param modelPath booster modelPath
* @return The loaded Booster
*/
@throws(classOf[XGBoostError])
def loadModel(modelPath: String): Booster = {
val xgboostInJava = JXGBoost.loadModel(modelPath)
new Booster(xgboostInJava)
}
def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = {
val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath)
new ScalaBoosterImpl(xgboostInJava)
/**
* Load a new Booster model from a file opened as input stream.
* The assumption is the input stream only contains one XGBoost Model.
* This can be used to load existing booster models saved by other XGBoost bindings.
*
* @param in The input stream of the file.
* @return The created Booster
*/
@throws(classOf[XGBoostError])
def loadModel(in: InputStream): Booster = {
val xgboostInJava = JXGBoost.loadModel(in)
new Booster(xgboostInJava)
}
}

View File

@@ -15,6 +15,10 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -67,7 +71,7 @@ public class BoosterImplTest {
}
@Test
public void testBoosterBasic() throws XGBoostError {
public void testBoosterBasic() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
@@ -94,15 +98,20 @@ public class BoosterImplTest {
Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null);
//predict raw output
float[][] predicts = booster.predict(testMat, true);
float[][] predicts = booster.predict(testMat, true, 0);
//eval
IEvaluation eval = new EvalError();
//error must be less than 0.1
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
//test dump model
// save and load
File temp = File.createTempFile("temp", "model");
temp.deleteOnExit();
booster.saveModel(temp.getAbsolutePath());
Booster bst2 = XGBoost.loadModel(new FileInputStream(temp.getAbsolutePath()));
assert (Arrays.equals(bst2.toByteArray(), booster.toByteArray()));
}
/**