From 56f7a414d1f0bac549e7b853a8f67db9bcf9c518 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 6 Mar 2016 11:33:48 -0800 Subject: [PATCH] [JVM] Refactor, add filesys API --- .../xgboost4j/java/demo/BasicWalkThrough.java | 6 +- .../dmlc/xgboost4j/java/demo/DistTrain.java | 79 --- .../java/demo/PredictLeafIndices.java | 4 +- .../scala/ml/dmlc/xgboost4j/flink/Test.scala | 2 + .../ml/dmlc/xgboost4j/flink/XGBoost.scala | 17 + .../dmlc/xgboost4j/flink/XGBoostModel.scala | 39 +- .../java/ml/dmlc/xgboost4j/java/Booster.java | 437 ++++++++++++--- .../java/ml/dmlc/xgboost4j/java/DMatrix.java | 2 +- .../dmlc/xgboost4j/java/JNIErrorHandle.java | 2 +- .../dmlc/xgboost4j/java/JavaBoosterImpl.java | 525 ------------------ .../dmlc/xgboost4j/java/NativeLibLoader.java | 2 +- .../java/ml/dmlc/xgboost4j/java/Rabit.java | 2 +- .../java/ml/dmlc/xgboost4j/java/XGBoost.java | 68 +-- .../ml/dmlc/xgboost4j/scala/Booster.scala | 131 ++--- .../xgboost4j/scala/ScalaBoosterImpl.scala | 99 ---- .../ml/dmlc/xgboost4j/scala/XGBoost.scala | 63 ++- .../dmlc/xgboost4j/java/BoosterImplTest.java | 15 +- 17 files changed, 597 insertions(+), 896 deletions(-) delete mode 100644 jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/DistTrain.java delete mode 100644 jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JavaBoosterImpl.java delete mode 100644 jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/BasicWalkThrough.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/BasicWalkThrough.java index b0de7d0e0..d13bfbcd5 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/BasicWalkThrough.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/BasicWalkThrough.java @@ -82,16 +82,16 @@ public class BasicWalkThrough { booster.saveModel(modelPath); //dump model - booster.dumpModel("./model/dump.raw.txt", false); + booster.getModelDump("./model/dump.raw.txt", false); //dump model with feature map - booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false); + booster.getModelDump("../../demo/data/featmap.txt", false); //save dmatrix into binary buffer testMat.saveBinary("./model/dtest.buffer"); //reload model and data - Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model"); + Booster booster2 = XGBoost.loadModel("./model/xgb.model"); DMatrix testMat2 = new DMatrix("./model/dtest.buffer"); float[][] predicts2 = booster2.predict(testMat2); diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/DistTrain.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/DistTrain.java deleted file mode 100644 index 89c9bcdb0..000000000 --- a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/DistTrain.java +++ /dev/null @@ -1,79 +0,0 @@ -package ml.dmlc.xgboost4j.java.demo; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -import ml.dmlc.xgboost4j.java.*; - -/** - * Distributed training example, used to quick test distributed training. - * - * @author tqchen - */ -public class DistTrain { - private static final Log logger = LogFactory.getLog(DistTrain.class); - private Map envs = null; - - private class Worker implements Runnable { - private final int workerId; - - Worker(int workerId) { - this.workerId = workerId; - } - - public void run() { - try { - Map worker_env = new HashMap(envs); - - worker_env.put("DMLC_TASK_ID", String.valueOf(workerId)); - // always initialize rabit module before training. - Rabit.init(worker_env); - - // load file from text file, also binary buffer generated by xgboost4j - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - HashMap params = new HashMap(); - params.put("eta", 1.0); - params.put("max_depth", 2); - params.put("silent", 1); - params.put("nthread", 2); - params.put("objective", "binary:logistic"); - - HashMap watches = new HashMap(); - watches.put("train", trainMat); - watches.put("test", testMat); - - //set round - int round = 2; - - //train a boost model - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); - - // always shutdown rabit module after training. - Rabit.shutdown(); - } catch (Exception ex){ - logger.error(ex); - } - } - } - - void start(int nWorkers) throws IOException, XGBoostError, InterruptedException { - RabitTracker tracker = new RabitTracker(nWorkers); - if (tracker.start()) { - envs = tracker.getWorkerEnvs(); - for (int i = 0; i < nWorkers; ++i) { - new Thread(new Worker(i)).start(); - } - tracker.waitFor(); - } - } - - public static void main(String[] args) throws IOException, XGBoostError, InterruptedException { - new DistTrain().start(Integer.parseInt(args[0])); - } -} \ No newline at end of file diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/PredictLeafIndices.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/PredictLeafIndices.java index 420f38111..9c3f20b90 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/PredictLeafIndices.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/java/demo/PredictLeafIndices.java @@ -52,13 +52,13 @@ public class PredictLeafIndices { Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); //predict using first 2 tree - float[][] leafindex = booster.predict(testMat, 2, true); + float[][] leafindex = booster.predictLeaf(testMat, 2); for (float[] leafs : leafindex) { System.out.println(Arrays.toString(leafs)); } //predict all trees - leafindex = booster.predict(testMat, 0, true); + leafindex = booster.predictLeaf(testMat, 0); for (float[] leafs : leafindex) { System.out.println(Arrays.toString(leafs)); } diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala index e55e61702..3637badf0 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/Test.scala @@ -37,6 +37,8 @@ object Test { "objective" -> "binary:logistic").toMap val round = 2 val model = XGBoost.train(paramMap, data, round) + + log.info(model) } } diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoost.scala index 8f1e8260a..8f71cdd09 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoost.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoost.scala @@ -25,6 +25,9 @@ import org.apache.flink.api.scala.DataSet import org.apache.flink.api.scala._ import org.apache.flink.ml.common.LabeledVector import org.apache.flink.util.Collector +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration object XGBoost { /** @@ -60,6 +63,20 @@ object XGBoost { val logger = LogFactory.getLog(this.getClass) + /** + * Load XGBoost model from path, using Hadoop Filesystem API. + * + * @param modelPath The path that is accessible by hadoop filesystem API. + * @return The loaded model + */ + def loadModel(modelPath: String) : XGBoostModel = { + new XGBoostModel( + XGBoostScala.loadModel( + FileSystem + .get(new Configuration) + .open(new Path(modelPath)))) + } + /** * Train a xgboost model with link. * diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoostModel.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoostModel.scala index 4197bd724..ce072fd10 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/flink/XGBoostModel.scala @@ -16,8 +16,45 @@ package ml.dmlc.xgboost4j.flink -import ml.dmlc.xgboost4j.scala.Booster +import ml.dmlc.xgboost4j.LabeledPoint +import ml.dmlc.xgboost4j.scala.{DMatrix, Booster} +import org.apache.flink.api.scala.DataSet +import org.apache.flink.api.scala._ +import org.apache.flink.ml.math.Vector +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration class XGBoostModel (booster: Booster) extends Serializable { + /** + * Save the model as a Hadoop filesystem file. + * + * @param modelPath The model path as in Hadoop path. + */ + def saveModel(modelPath: String): Unit = { + booster.saveModel(FileSystem + .get(new Configuration) + .create(new Path(modelPath))) + } + /** + * Predict given vector dataset. + * + * @param data The dataset to be predicted. + * @return The prediction result. + */ + def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = { + val predictMap: Iterator[Vector] => TraversableOnce[Array[Float]] = + (it: Iterator[Vector]) => { + val mapper = (x: Vector) => { + val (index, value) = x.toSeq.unzip + LabeledPoint.fromSparseVector(0.0f, + index.toArray, value.map(z => z.toFloat).toArray) + } + val dataIter = for (x <- it) yield mapper(x) + val dmat = new DMatrix(dataIter, null) + this.booster.predict(dmat) + } + data.mapPartition(predictMap) + } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index e6e427900..08abc1afc 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -1,41 +1,146 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ package ml.dmlc.xgboost4j.java; -import java.io.IOException; -import java.io.Serializable; +import java.io.*; +import java.util.HashMap; +import java.util.List; import java.util.Map; -public interface Booster extends Serializable { +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +/** + * Booster for xgboost, this is a model API that support interactive build of a XGBOost Model + */ +public class Booster implements Serializable { + private static final Log logger = LogFactory.getLog(Booster.class); + // handle to the booster. + private long handle = 0; + + //load native library + static { + try { + NativeLibLoader.initXGBoost(); + } catch (IOException ex) { + logger.error("load native library failed."); + logger.error(ex); + } + } /** - * set parameter + * Create a new Booster with empty stage. + * + * @param params Model parameters + * @param cacheMats Cached DMatrix entries, + * the prediction of these DMatrices will become faster than not-cached data. + * @throws XGBoostError native error + */ + Booster(Map params, DMatrix[] cacheMats) throws XGBoostError { + init(cacheMats); + setParam("seed", "0"); + setParams(params); + } + + /** + * Load a new Booster model from modelPath + * @param modelPath The path to the model. + * @return The created Booster. + * @throws XGBoostError + */ + static Booster loadModel(String modelPath) throws XGBoostError { + if (modelPath == null) { + throw new NullPointerException("modelPath : null"); + } + Booster ret = new Booster(new HashMap(), new DMatrix[0]); + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath)); + return ret; + } + + /** + * Load a new Booster model from a file opened as input stream. + * The assumption is the input stream only contains one XGBoost Model. + * This can be used to load existing booster models saved by other xgboost bindings. + * + * @param in The input stream of the file. + * @return The create boosted + * @throws XGBoostError + * @throws IOException + */ + static Booster loadModel(InputStream in) throws XGBoostError, IOException { + int size; + byte[] buf = new byte[1<<20]; + ByteArrayOutputStream os = new ByteArrayOutputStream(); + while ((size = in.read(buf)) != -1) { + os.write(buf, 0, size); + } + in.close(); + Booster ret = new Booster(new HashMap(), new DMatrix[0]); + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray())); + return ret; + } + + /** + * Set parameter to the Booster. * * @param key param name * @param value param value + * @throws XGBoostError native error */ - void setParam(String key, String value) throws XGBoostError; + public final void setParam(String key, Object value) throws XGBoostError { + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value.toString())); + } /** - * set parameters + * Set parameters to the Booster. * * @param params parameters key-value map + * @throws XGBoostError native error */ - void setParams(Map params) throws XGBoostError; + public void setParams(Map params) throws XGBoostError { + if (params != null) { + for (Map.Entry entry : params.entrySet()) { + setParam(entry.getKey(), entry.getValue().toString()); + } + } + } /** - * Update (one iteration) + * Update the booster for one iteration. * * @param dtrain training data * @param iter current iteration number + * @throws XGBoostError native error */ - void update(DMatrix dtrain, int iter) throws XGBoostError; + public void update(DMatrix dtrain, int iter) throws XGBoostError { + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle())); + } /** - * update with customize obj func + * Update with customize obj func * * @param dtrain training data * @param obj customized objective class + * @throws XGBoostError native error */ - void update(DMatrix dtrain, IObjective obj) throws XGBoostError; + public void update(DMatrix dtrain, IObjective obj) throws XGBoostError { + float[][] predicts = this.predict(dtrain, true, 0, false); + List gradients = obj.getGradient(predicts, dtrain); + boost(dtrain, gradients.get(0), gradients.get(1)); + } /** * update with give grad and hess @@ -43,8 +148,16 @@ public interface Booster extends Serializable { * @param dtrain training data * @param grad first order of gradient * @param hess seconde order of gradient + * @throws XGBoostError native error */ - void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError; + public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError { + if (grad.length != hess.length) { + throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, + hess.length)); + } + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, + dtrain.getHandle(), grad, hess)); + } /** * evaluate with given dmatrixs. @@ -53,8 +166,15 @@ public interface Booster extends Serializable { * @param evalNames name for eval dmatrixs, used for check results * @param iter current eval iteration * @return eval information + * @throws XGBoostError native error */ - String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError; + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError { + long[] handles = dmatrixsToHandles(evalMatrixs); + String[] evalInfo = new String[1]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, + evalInfo)); + return evalInfo[0]; + } /** * evaluate with given customized Evaluation class @@ -63,60 +183,171 @@ public interface Booster extends Serializable { * @param evalNames evaluation names * @param eval custom evaluator * @return eval information + * @throws XGBoostError native error */ - String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) throws XGBoostError; + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) + throws XGBoostError { + String evalInfo = ""; + for (int i = 0; i < evalNames.length; i++) { + String evalName = evalNames[i]; + DMatrix evalMat = evalMatrixs[i]; + float evalResult = eval.eval(predict(evalMat), evalMat); + String evalMetric = eval.getMetric(); + evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult); + } + return evalInfo; + } + + /** + * Advanced predict function with all the options. + * + * @param data data + * @param outputMargin output margin + * @param treeLimit limit number of trees, 0 means all trees. + * @param predLeaf prediction minimum to keep leafs + * @return predict results + */ + private synchronized float[][] predict(DMatrix data, + boolean outputMargin, + int treeLimit, + boolean predLeaf) throws XGBoostError { + int optionMask = 0; + if (outputMargin) { + optionMask = 1; + } + if (predLeaf) { + optionMask = 2; + } + float[][] rawPredicts = new float[1][]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, + treeLimit, rawPredicts)); + int row = (int) data.rowNum(); + int col = rawPredicts[0].length / row; + float[][] predicts = new float[row][col]; + int r, c; + for (int i = 0; i < rawPredicts[0].length; i++) { + r = i / col; + c = i % col; + predicts[r][c] = rawPredicts[0][i]; + } + return predicts; + } + + /** + * Predict leaf indices given the data + * + * @param data The input data. + * @param treeLimit Number of trees to include, 0 means all trees. + * @return The leaf indices of the instance. + * @throws XGBoostError + */ + public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError { + return this.predict(data, false, treeLimit, true); + } /** * Predict with data * * @param data dmatrix storing the input * @return predict result - */ - float[][] predict(DMatrix data) throws XGBoostError; - - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param outPutMargin Whether to output the raw untransformed margin value. - * @return predict result - */ - float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError; - - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param outPutMargin Whether to output the raw untransformed margin value. - * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @return predict result - */ - float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError; - - - /** - * Predict with data - * @param data dmatrix storing the input - * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), - * nsample = data.numRow with each record indicating the predicted leaf index of - * each sample in each tree. Note that the leaf index of a tree is unique per - * tree, so you may find leaf 1 in both tree 1 and tree 0. - * @return predict result * @throws XGBoostError native error */ - float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError; + public float[][] predict(DMatrix data) throws XGBoostError { + return this.predict(data, false, 0, false); + } /** - * save model to modelPath, the model path support depends on the path support - * in libxgboost. For example, if we want to save to hdfs, libxgboost need to be - * compiled with HDFS support. - * See also toByteArray + * Predict with data + * + * @param data data + * @param outputMargin output margin + * @return predict results + */ + public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError { + return this.predict(data, outputMargin, 0, false); + } + + /** + * Advanced predict function with all the options. + * + * @param data data + * @param outputMargin output margin + * @param treeLimit limit number of trees, 0 means all trees. + * @return predict results + */ + public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError { + return this.predict(data, outputMargin, treeLimit, false); + } + + /** + * Save model to modelPath + * * @param modelPath model path */ - void saveModel(String modelPath) throws XGBoostError; + public void saveModel(String modelPath) throws XGBoostError{ + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath)); + } + + /** + * Save the model to file opened as output stream. + * The model format is compatible with other xgboost bindings. + * The output stream can only save one xgboost model. + * This function will close the OutputStream after the save. + * + * @param out The output stream + */ + public void saveModel(OutputStream out) throws XGBoostError, IOException { + out.write(this.toByteArray()); + out.close(); + } + + /** + * Get the dump of the model as a string array + * + * @param withStats Controls whether the split statistics are output. + * @return dumped model information + * @throws XGBoostError native error + */ + public String[] getModelDump(String featureMap, boolean withStats) throws XGBoostError { + int statsFlag = 0; + if (featureMap == null) { + featureMap = ""; + } + if (withStats) { + statsFlag = 1; + } + String[][] modelInfos = new String[1][]; + JNIErrorHandle.checkCall( + XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos)); + return modelInfos[0]; + } + + /** + * Get importance of each feature + * + * @return featureMap key: feature index, value: feature importance score, can be nill + * @throws XGBoostError native error + */ + public Map getFeatureScore(String featureMap) throws XGBoostError { + String[] modelInfos = getModelDump(featureMap, false); + Map featureScore = new HashMap(); + for (String tree : modelInfos) { + for (String node : tree.split("\n")) { + String[] array = node.split("\\["); + if (array.length == 1) { + continue; + } + String fid = array[1].split("\\]")[0]; + fid = fid.split("<")[0]; + if (featureScore.containsKey(fid)) { + featureScore.put(fid, 1 + featureScore.get(fid)); + } else { + featureScore.put(fid, 1); + } + } + } + return featureScore; + } /** * Save the model as byte array representation. @@ -127,41 +358,93 @@ public interface Booster extends Serializable { * @return the saved byte array. * @throws XGBoostError */ - byte[] toByteArray() throws XGBoostError; + public byte[] toByteArray() throws XGBoostError { + byte[][] bytes = new byte[1][]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes)); + return bytes[0]; + } /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param withStats bool Controls whether the split statistics are output. + * Load the booster model from thread-local rabit checkpoint. + * This is only used in distributed training. + * @return the stored version number of the checkpoint. + * @throws XGBoostError */ - void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError; + int loadRabitCheckpoint() throws XGBoostError { + int[] out = new int[1]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out)); + return out[0]; + } /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param featureMap featureMap file - * @param withStats bool - * Controls whether the split statistics are output. + * Save the booster model into thread-local rabit checkpoint. + * This is only used in distributed training. + * @throws XGBoostError */ - void dumpModel(String modelPath, String featureMap, boolean withStats) - throws IOException, XGBoostError; + void saveRabitCheckpoint() throws XGBoostError { + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle)); + } /** - * get importance of each feature - * - * @return featureMap key: feature index, value: feature importance score + * Internal initialization function. + * @param cacheMats The cached DMatrix. + * @throws XGBoostError */ - Map getFeatureScore() throws XGBoostError ; + private void init(DMatrix[] cacheMats) throws XGBoostError { + long[] handles = null; + if (cacheMats != null) { + handles = dmatrixsToHandles(cacheMats); + } + long[] out = new long[1]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out)); + + handle = out[0]; + } /** - * get importance of each feature + * transfer DMatrix array to handle array (used for native functions) * - * @param featureMap file to save dumped model info - * @return featureMap key: feature index, value: feature importance score + * @param dmatrixs + * @return handle array for input dmatrixs */ - Map getFeatureScore(String featureMap) throws XGBoostError; + private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) { + long[] handles = new long[dmatrixs.length]; + for (int i = 0; i < dmatrixs.length; i++) { + handles[i] = dmatrixs[i].getHandle(); + } + return handles; + } - void dispose(); + // making Booster serializable + private void writeObject(java.io.ObjectOutputStream out) throws IOException { + try { + out.writeObject(this.toByteArray()); + } catch (XGBoostError ex) { + throw new IOException(ex.toString()); + } + } + + private void readObject(java.io.ObjectInputStream in) + throws IOException, ClassNotFoundException { + try { + this.init(null); + byte[] bytes = (byte[])in.readObject(); + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); + } catch (XGBoostError ex) { + throw new IOException(ex.toString()); + } + } + + @Override + protected void finalize() throws Throwable { + super.finalize(); + dispose(); + } + + public synchronized void dispose() { + if (handle != 0L) { + XGBoostJNI.XGBoosterFree(handle); + handle = 0; + } + } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index 2a7461377..29ee6cf72 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -35,7 +35,7 @@ public class DMatrix { //load native library static { try { - NativeLibLoader.initXgBoost(); + NativeLibLoader.initXGBoost(); } catch (IOException ex) { logger.error("load native library failed."); logger.error(ex); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JNIErrorHandle.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JNIErrorHandle.java index bb64ee56f..c8888154b 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JNIErrorHandle.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JNIErrorHandle.java @@ -30,7 +30,7 @@ class JNIErrorHandle { //load native library static { try { - NativeLibLoader.initXgBoost(); + NativeLibLoader.initXGBoost(); } catch (IOException ex) { logger.error("load native library failed."); logger.error(ex); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JavaBoosterImpl.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JavaBoosterImpl.java deleted file mode 100644 index 37860dc6f..000000000 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JavaBoosterImpl.java +++ /dev/null @@ -1,525 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package ml.dmlc.xgboost4j.java; - -import java.io.*; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -/** - * Booster for xgboost, similar to the python wrapper xgboost.py - * but custom obj function and eval function not supported at present. - * - * @author hzx - */ -class JavaBoosterImpl implements Booster { - private static final Log logger = LogFactory.getLog(JavaBoosterImpl.class); - - long handle = 0; - - //load native library - static { - try { - NativeLibLoader.initXgBoost(); - } catch (IOException ex) { - logger.error("load native library failed."); - logger.error(ex); - } - } - - /** - * init Booster from dMatrixs - * - * @param params parameters - * @param dMatrixs DMatrix array - * @throws XGBoostError native error - */ - JavaBoosterImpl(Map params, DMatrix[] dMatrixs) throws XGBoostError { - init(dMatrixs); - setParam("seed", "0"); - setParams(params); - } - - /** - * load model from modelPath - * - * @param params parameters - * @param modelPath booster modelPath (model generated by booster.saveModel) - * @throws XGBoostError native error - */ - JavaBoosterImpl(Map params, String modelPath) throws XGBoostError { - init(null); - if (modelPath == null) { - throw new NullPointerException("modelPath : null"); - } - loadModel(modelPath); - setParam("seed", "0"); - setParams(params); - } - - - private void init(DMatrix[] dMatrixs) throws XGBoostError { - long[] handles = null; - if (dMatrixs != null) { - handles = dmatrixsToHandles(dMatrixs); - } - long[] out = new long[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out)); - - handle = out[0]; - } - - /** - * set parameter - * - * @param key param name - * @param value param value - * @throws XGBoostError native error - */ - public final void setParam(String key, String value) throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value)); - } - - /** - * set parameters - * - * @param params parameters key-value map - * @throws XGBoostError native error - */ - public void setParams(Map params) throws XGBoostError { - if (params != null) { - for (Map.Entry entry : params.entrySet()) { - setParam(entry.getKey(), entry.getValue().toString()); - } - } - } - - - /** - * Update (one iteration) - * - * @param dtrain training data - * @param iter current iteration number - * @throws XGBoostError native error - */ - public void update(DMatrix dtrain, int iter) throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle())); - } - - /** - * update with customize obj func - * - * @param dtrain training data - * @param obj customized objective class - * @throws XGBoostError native error - */ - public void update(DMatrix dtrain, IObjective obj) throws XGBoostError { - float[][] predicts = predict(dtrain, true); - List gradients = obj.getGradient(predicts, dtrain); - boost(dtrain, gradients.get(0), gradients.get(1)); - } - - /** - * update with give grad and hess - * - * @param dtrain training data - * @param grad first order of gradient - * @param hess seconde order of gradient - * @throws XGBoostError native error - */ - public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError { - if (grad.length != hess.length) { - throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, - hess.length)); - } - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, - hess)); - } - - /** - * evaluate with given dmatrixs. - * - * @param evalMatrixs dmatrixs for evaluation - * @param evalNames name for eval dmatrixs, used for check results - * @param iter current eval iteration - * @return eval information - * @throws XGBoostError native error - */ - public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError { - long[] handles = dmatrixsToHandles(evalMatrixs); - String[] evalInfo = new String[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, - evalInfo)); - return evalInfo[0]; - } - - /** - * evaluate with given customized Evaluation class - * - * @param evalMatrixs evaluation matrix - * @param evalNames evaluation names - * @param eval custom evaluator - * @return eval information - * @throws XGBoostError native error - */ - public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) - throws XGBoostError { - String evalInfo = ""; - for (int i = 0; i < evalNames.length; i++) { - String evalName = evalNames[i]; - DMatrix evalMat = evalMatrixs[i]; - float evalResult = eval.eval(predict(evalMat), evalMat); - String evalMetric = eval.getMetric(); - evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult); - } - return evalInfo; - } - - /** - * base function for Predict - * - * @param data data - * @param outPutMargin output margin - * @param treeLimit limit number of trees - * @param predLeaf prediction minimum to keep leafs - * @return predict results - */ - private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit, - boolean predLeaf) throws XGBoostError { - int optionMask = 0; - if (outPutMargin) { - optionMask = 1; - } - if (predLeaf) { - optionMask = 2; - } - float[][] rawPredicts = new float[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, - treeLimit, rawPredicts)); - int row = (int) data.rowNum(); - int col = rawPredicts[0].length / row; - float[][] predicts = new float[row][col]; - int r, c; - for (int i = 0; i < rawPredicts[0].length; i++) { - r = i / col; - c = i % col; - predicts[r][c] = rawPredicts[0][i]; - } - return predicts; - } - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @return predict result - * @throws XGBoostError native error - */ - public float[][] predict(DMatrix data) throws XGBoostError { - return pred(data, false, 0, false); - } - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param outPutMargin Whether to output the raw untransformed margin value. - * @return predict result - * @throws XGBoostError native error - */ - public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError { - return pred(data, outPutMargin, 0, false); - } - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param outPutMargin Whether to output the raw untransformed margin value. - * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @return predict result - * @throws XGBoostError native error - */ - public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError { - return pred(data, outPutMargin, treeLimit, false); - } - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), - * nsample = data.numRow with each record indicating the predicted leaf index - * of each sample in each tree. - * Note that the leaf index of a tree is unique per tree, so you may find leaf 1 - * in both tree 1 and tree 0. - * @return predict result - * @throws XGBoostError native error - */ - public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError { - return pred(data, false, treeLimit, predLeaf); - } - - /** - * save model to modelPath - * - * @param modelPath model path - */ - public void saveModel(String modelPath) throws XGBoostError{ - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath)); - } - - private void loadModel(String modelPath) { - XGBoostJNI.XGBoosterLoadModel(handle, modelPath); - } - - /** - * get the dump of the model as a string array - * - * @param withStats Controls whether the split statistics are output. - * @return dumped model information - * @throws XGBoostError native error - */ - private String[] getDumpInfo(boolean withStats) throws XGBoostError { - int statsFlag = 0; - if (withStats) { - statsFlag = 1; - } - String[][] modelInfos = new String[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos)); - return modelInfos[0]; - } - - /** - * get the dump of the model as a string array - * - * @param featureMap featureMap file - * @param withStats Controls whether the split statistics are output. - * @return dumped model information - * @throws XGBoostError native error - */ - private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError { - int statsFlag = 0; - if (withStats) { - statsFlag = 1; - } - String[][] modelInfos = new String[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, - modelInfos)); - return modelInfos[0]; - } - - /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param withStats bool - * Controls whether the split statistics are output. - * @throws FileNotFoundException file not found - * @throws UnsupportedEncodingException unsupported feature - * @throws IOException error with model writing - * @throws XGBoostError native error - */ - public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError { - File tf = new File(modelPath); - FileOutputStream out = new FileOutputStream(tf); - BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8")); - String[] modelInfos = getDumpInfo(withStats); - - for (int i = 0; i < modelInfos.length; i++) { - writer.write("booster [" + i + "]:\n"); - writer.write(modelInfos[i]); - } - - writer.close(); - out.close(); - } - - - /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param featureMap featureMap file - * @param withStats bool - * Controls whether the split statistics are output. - * @throws FileNotFoundException exception - * @throws UnsupportedEncodingException exception - * @throws IOException exception - * @throws XGBoostError native error - */ - public void dumpModel(String modelPath, String featureMap, boolean withStats) throws - IOException, XGBoostError { - File tf = new File(modelPath); - FileOutputStream out = new FileOutputStream(tf); - BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8")); - String[] modelInfos = getDumpInfo(featureMap, withStats); - - for (int i = 0; i < modelInfos.length; i++) { - writer.write("booster [" + i + "]:\n"); - writer.write(modelInfos[i]); - } - - writer.close(); - out.close(); - } - - - /** - * get importance of each feature - * - * @return featureMap key: feature index, value: feature importance score - * @throws XGBoostError native error - */ - public Map getFeatureScore() throws XGBoostError { - String[] modelInfos = getDumpInfo(false); - Map featureScore = new HashMap(); - for (String tree : modelInfos) { - for (String node : tree.split("\n")) { - String[] array = node.split("\\["); - if (array.length == 1) { - continue; - } - String fid = array[1].split("\\]")[0]; - fid = fid.split("<")[0]; - if (featureScore.containsKey(fid)) { - featureScore.put(fid, 1 + featureScore.get(fid)); - } else { - featureScore.put(fid, 1); - } - } - } - return featureScore; - } - - - /** - * get importance of each feature - * - * @param featureMap file to save dumped model info - * @return featureMap key: feature index, value: feature importance score - * @throws XGBoostError native error - */ - public Map getFeatureScore(String featureMap) throws XGBoostError { - String[] modelInfos = getDumpInfo(featureMap, false); - Map featureScore = new HashMap(); - for (String tree : modelInfos) { - for (String node : tree.split("\n")) { - String[] array = node.split("\\["); - if (array.length == 1) { - continue; - } - String fid = array[1].split("\\]")[0]; - fid = fid.split("<")[0]; - if (featureScore.containsKey(fid)) { - featureScore.put(fid, 1 + featureScore.get(fid)); - } else { - featureScore.put(fid, 1); - } - } - } - return featureScore; - } - - /** - * Save the model as byte array representation. - * Write these bytes to a file will give compatible format with other xgboost bindings. - * - * If java natively support HDFS file API, use toByteArray and write the ByteArray, - * - * @return the saved byte array. - * @throws XGBoostError - */ - public byte[] toByteArray() throws XGBoostError { - byte[][] bytes = new byte[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes)); - return bytes[0]; - } - - - /** - * Load the booster model from thread-local rabit checkpoint. - * This is only used in distributed training. - * @return the stored version number of the checkpoint. - * @throws XGBoostError - */ - int loadRabitCheckpoint() throws XGBoostError { - int[] out = new int[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out)); - return out[0]; - } - - /** - * Save the booster model into thread-local rabit checkpoint. - * This is only used in distributed training. - * @throws XGBoostError - */ - void saveRabitCheckpoint() throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle)); - } - - /** - * transfer DMatrix array to handle array (used for native functions) - * - * @param dmatrixs - * @return handle array for input dmatrixs - */ - private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) { - long[] handles = new long[dmatrixs.length]; - for (int i = 0; i < dmatrixs.length; i++) { - handles[i] = dmatrixs[i].getHandle(); - } - return handles; - } - - // making Booster serializable - private void writeObject(java.io.ObjectOutputStream out) throws IOException { - try { - out.writeObject(this.toByteArray()); - } catch (XGBoostError ex) { - throw new IOException(ex.toString()); - } - } - - private void readObject(java.io.ObjectInputStream in) - throws IOException, ClassNotFoundException { - try { - this.init(null); - byte[] bytes = (byte[])in.readObject(); - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); - } catch (XGBoostError ex) { - throw new IOException(ex.toString()); - } - } - - @Override - protected void finalize() throws Throwable { - super.finalize(); - dispose(); - } - - public synchronized void dispose() { - if (handle != 0L) { - XGBoostJNI.XGBoosterFree(handle); - handle = 0; - } - } -} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java index 199b9ae42..0ab2ca077 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java @@ -35,7 +35,7 @@ class NativeLibLoader { private static final String nativeResourcePath = "/lib/"; private static final String[] libNames = new String[]{"xgboost4j"}; - public static synchronized void initXgBoost() throws IOException { + public static synchronized void initXGBoost() throws IOException { if (!initialized) { for (String libName : libNames) { smartLoad(libName); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java index d8408d26c..0bc069048 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java @@ -14,7 +14,7 @@ public class Rabit { //load native library static { try { - NativeLibLoader.initXgBoost(); + NativeLibLoader.initXGBoost(); } catch (IOException ex) { logger.error("load native library failed."); logger.error(ex); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 44b7425a2..09159e53f 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -15,6 +15,8 @@ */ package ml.dmlc.xgboost4j.java; +import java.io.IOException; +import java.io.InputStream; import java.util.*; import org.apache.commons.logging.Log; @@ -28,6 +30,33 @@ import org.apache.commons.logging.LogFactory; public class XGBoost { private static final Log logger = LogFactory.getLog(XGBoost.class); + /** + * load model from modelPath + * + * @param modelPath booster modelPath (model generated by booster.saveModel) + * @throws XGBoostError native error + */ + public static Booster loadModel(String modelPath) + throws XGBoostError { + return Booster.loadModel(modelPath); + } + + /** + * Load a new Booster model from a file opened as input stream. + * The assumption is the input stream only contains one XGBoost Model. + * This can be used to load existing booster models saved by other xgboost bindings. + * + * @param in The input stream of the file, + * will be closed after this function call. + * @return The create boosted + * @throws XGBoostError + * @throws IOException + */ + public static Booster loadModel(InputStream in) + throws XGBoostError, IOException { + return Booster.loadModel(in); + } + /** * Train a booster with given parameters. * @@ -41,9 +70,11 @@ public class XGBoost { * @return trained booster * @throws XGBoostError native error */ - public static Booster train(Map params, DMatrix dtrain, int round, - Map watches, IObjective obj, - IEvaluation eval) throws XGBoostError { + public static Booster train(Map params, + DMatrix dtrain, int round, + Map watches, + IObjective obj, + IEvaluation eval) throws XGBoostError { //collect eval matrixs String[] evalNames; @@ -71,7 +102,7 @@ public class XGBoost { } //initialize booster - JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats); + Booster booster = new Booster(params, allMats); int version = booster.loadRabitCheckpoint(); @@ -106,32 +137,7 @@ public class XGBoost { } /** - * init Booster from dMatrixs - * - * @param params parameters - * @param dMatrixs DMatrix array - * @throws XGBoostError native error - */ - public static Booster initBoostingModel( - Map params, - DMatrix[] dMatrixs) throws XGBoostError { - return new JavaBoosterImpl(params, dMatrixs); - } - - /** - * load model from modelPath - * - * @param params parameters - * @param modelPath booster modelPath (model generated by booster.saveModel) - * @throws XGBoostError native error - */ - public static Booster loadBoostModel(Map params, String modelPath) - throws XGBoostError { - return new JavaBoosterImpl(params, modelPath); - } - - /** - * Cross-validation with given paramaters. + * Cross-validation with given parameters. * * @param params Booster params. * @param data Data to be trained. @@ -294,7 +300,7 @@ public class XGBoost { public CVPack(DMatrix dtrain, DMatrix dtest, Map params) throws XGBoostError { dmats = new DMatrix[]{dtrain, dtest}; - booster = XGBoost.initBoostingModel(params, dmats); + booster = new Booster(params, dmats); names = new String[]{"train", "test"}; this.dtrain = dtrain; this.dtest = dtest; diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index 88032a61a..bf1bedf93 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -18,20 +18,23 @@ package ml.dmlc.xgboost4j.scala import java.io.IOException +import ml.dmlc.xgboost4j.java.{Booster => JBooster} import ml.dmlc.xgboost4j.java.XGBoostError +import scala.collection.JavaConverters._ import scala.collection.mutable -trait Booster extends Serializable { - +class Booster private[xgboost4j](booster: JBooster) extends Serializable { /** - * set parameter - * - * @param key param name - * @param value param value - */ + * Set parameter to the Booster. + * + * @param key param name + * @param value param value + */ @throws(classOf[XGBoostError]) - def setParam(key: String, value: String) + def setParam(key: String, value: AnyRef): Unit = { + booster.setParam(key, value) + } /** * set parameters @@ -39,7 +42,9 @@ trait Booster extends Serializable { * @param params parameters key-value map */ @throws(classOf[XGBoostError]) - def setParams(params: Map[String, AnyRef]) + def setParams(params: Map[String, AnyRef]): Unit = { + booster.setParams(params.asJava) + } /** * Update (one iteration) @@ -48,7 +53,9 @@ trait Booster extends Serializable { * @param iter current iteration number */ @throws(classOf[XGBoostError]) - def update(dtrain: DMatrix, iter: Int) + def update(dtrain: DMatrix, iter: Int): Unit = { + booster.update(dtrain.jDMatrix, iter) + } /** * update with customize obj func @@ -57,7 +64,9 @@ trait Booster extends Serializable { * @param obj customized objective class */ @throws(classOf[XGBoostError]) - def update(dtrain: DMatrix, obj: ObjectiveTrait) + def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = { + booster.update(dtrain.jDMatrix, obj) + } /** * update with give grad and hess @@ -67,7 +76,9 @@ trait Booster extends Serializable { * @param hess seconde order of gradient */ @throws(classOf[XGBoostError]) - def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]) + def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = { + booster.boost(dtrain.jDMatrix, grad, hess) + } /** * evaluate with given dmatrixs. @@ -78,7 +89,10 @@ trait Booster extends Serializable { * @return eval information */ @throws(classOf[XGBoostError]) - def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String + def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int) + : String = { + booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter) + } /** * evaluate with given customized Evaluation class @@ -89,26 +103,11 @@ trait Booster extends Serializable { * @return eval information */ @throws(classOf[XGBoostError]) - def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): String + def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait) + : String = { + booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval) + } - /** - * Predict with data - * - * @param data dmatrix storing the input - * @return predict result - */ - @throws(classOf[XGBoostError]) - def predict(data: DMatrix): Array[Array[Float]] - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param outPutMargin Whether to output the raw untransformed margin value. - * @return predict result - */ - @throws(classOf[XGBoostError]) - def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] /** * Predict with data @@ -119,22 +118,24 @@ trait Booster extends Serializable { * @return predict result */ @throws(classOf[XGBoostError]) - def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]] + def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0) + : Array[Array[Float]] = { + booster.predict(data.jDMatrix, outPutMargin, treeLimit) + } /** - * Predict with data + * Predict the leaf indices * * @param data dmatrix storing the input * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), - * nsample = data.numRow with each record indicating the predicted leaf index of - * each sample in each tree. Note that the leaf index of a tree is unique per - * tree, so you may find leaf 1 in both tree 1 and tree 0. * @return predict result * @throws XGBoostError native error */ @throws(classOf[XGBoostError]) - def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] + def predictLeaf(data: DMatrix, treeLimit: Int = 0) + : Array[Array[Float]] = { + booster.predictLeaf(data.jDMatrix, treeLimit) + } /** * save model to modelPath @@ -142,46 +143,50 @@ trait Booster extends Serializable { * @param modelPath model path */ @throws(classOf[XGBoostError]) - def saveModel(modelPath: String) - + def saveModel(modelPath: String): Unit = { + booster.saveModel(modelPath) + } /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param withStats bool Controls whether the split statistics are output. - */ - @throws(classOf[IOException]) + * save model to Output stream + * + * @param out Output stream + */ @throws(classOf[XGBoostError]) - def dumpModel(modelPath: String, withStats: Boolean) - + def saveModel(out: java.io.OutputStream): Unit = { + booster.saveModel(out) + } /** - * Dump model into a text file. + * Dump model as Array of string * - * @param modelPath file to save dumped model info * @param featureMap featureMap file * @param withStats bool * Controls whether the split statistics are output. */ - @throws(classOf[IOException]) @throws(classOf[XGBoostError]) - def dumpModel(modelPath: String, featureMap: String, withStats: Boolean) + def getModelDump(featureMap: String = null, withStats: Boolean = false) + : Array[String] = { + booster.getModelDump(featureMap, withStats) + } /** - * get importance of each feature + * Get importance of each feature * * @return featureMap key: feature index, value: feature importance score */ @throws(classOf[XGBoostError]) - def getFeatureScore: mutable.Map[String, Integer] + def getFeatureScore(featureMap: String = null): mutable.Map[String, Integer] = { + booster.getFeatureScore(featureMap).asScala + } /** - * get importance of each feature - * - * @param featureMap file to save dumped model info - * @return featureMap key: feature index, value: feature importance score - */ - @throws(classOf[XGBoostError]) - def getFeatureScore(featureMap: String): mutable.Map[String, Integer] + * Dispose the booster when it is no longer needed + */ + def dispose: Unit = { + booster.dispose() + } - def dispose + override def finalize(): Unit = { + super.finalize() + dispose + } } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala deleted file mode 100644 index bdb2fb34c..000000000 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala - -import ml.dmlc.xgboost4j.java -import scala.collection.JavaConverters._ -import scala.collection.mutable - -private[scala] class ScalaBoosterImpl private[xgboost4j](booster: java.Booster) extends Booster { - - override def setParam(key: String, value: String): Unit = { - booster.setParam(key, value) - } - - override def update(dtrain: DMatrix, iter: Int): Unit = { - booster.update(dtrain.jDMatrix, iter) - } - - override def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = { - booster.update(dtrain.jDMatrix, obj) - } - - override def dumpModel(modelPath: String, withStats: Boolean): Unit = { - booster.dumpModel(modelPath, withStats) - } - - override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = { - booster.dumpModel(modelPath, featureMap, withStats) - } - - override def setParams(params: Map[String, AnyRef]): Unit = { - booster.setParams(params.asJava) - } - - override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = { - booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter) - } - - override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): - String = { - booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval) - } - - override def dispose: Unit = { - booster.dispose() - } - - override def predict(data: DMatrix): Array[Array[Float]] = { - booster.predict(data.jDMatrix) - } - - override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = { - booster.predict(data.jDMatrix, outPutMargin) - } - - override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): - Array[Array[Float]] = { - booster.predict(data.jDMatrix, outPutMargin, treeLimit) - } - - override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = { - booster.predict(data.jDMatrix, treeLimit, predLeaf) - } - - override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = { - booster.boost(dtrain.jDMatrix, grad, hess) - } - - override def getFeatureScore: mutable.Map[String, Integer] = { - booster.getFeatureScore.asScala - } - - override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = { - booster.getFeatureScore(featureMap).asScala - } - - override def saveModel(modelPath: String): Unit = { - booster.saveModel(modelPath) - } - - override def finalize(): Unit = { - super.finalize() - dispose - } -} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index 58ed51527..63627ed8d 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -16,11 +16,28 @@ package ml.dmlc.xgboost4j.scala -import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost} +import java.io.InputStream + +import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError} import scala.collection.JavaConverters._ +/** + * XGBoost Scala Training function. + */ object XGBoost { - + /** + * Train a booster given parameters. + * + * @param params Parameters. + * @param dtrain Data to be trained. + * @param round Number of boosting iterations. + * @param watches a group of items to be evaluated during training, this allows user to watch + * performance on the validation set. + * @param obj customized objective + * @param eval customized evaluation + * @return The trained booster. + */ + @throws(classOf[XGBoostError]) def train( params: Map[String, AnyRef], dtrain: DMatrix, @@ -31,9 +48,22 @@ object XGBoost { val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)} val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava, obj, eval) - new ScalaBoosterImpl(xgboostInJava) + new Booster(xgboostInJava) } + /** + * Cross-validation with given parameters. + * + * @param params Booster params. + * @param data Data to be trained. + * @param round Number of boosting iterations. + * @param nfold Number of folds in CV. + * @param metrics Evaluation metrics to be watched in CV. + * @param obj customized objective + * @param eval customized evaluation + * @return evaluation history + */ + @throws(classOf[XGBoostError]) def crossValidation( params: Map[String, AnyRef], data: DMatrix, @@ -45,13 +75,28 @@ object XGBoost { JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval) } - def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = { - val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix)) - new ScalaBoosterImpl(xgboostInJava) + /** + * load model from modelPath + * + * @param modelPath booster modelPath + */ + @throws(classOf[XGBoostError]) + def loadModel(modelPath: String): Booster = { + val xgboostInJava = JXGBoost.loadModel(modelPath) + new Booster(xgboostInJava) } - def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = { - val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath) - new ScalaBoosterImpl(xgboostInJava) + /** + * Load a new Booster model from a file opened as input stream. + * The assumption is the input stream only contains one XGBoost Model. + * This can be used to load existing booster models saved by other XGBoost bindings. + * + * @param in The input stream of the file. + * @return The create booster + */ + @throws(classOf[XGBoostError]) + def loadModel(in: InputStream): Booster = { + val xgboostInJava = JXGBoost.loadModel(in) + new Booster(xgboostInJava) } } diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 1d81b5d6b..239fb91b8 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -15,6 +15,10 @@ */ package ml.dmlc.xgboost4j.java; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -67,7 +71,7 @@ public class BoosterImplTest { } @Test - public void testBoosterBasic() throws XGBoostError { + public void testBoosterBasic() throws XGBoostError, IOException { DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); @@ -94,15 +98,20 @@ public class BoosterImplTest { Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null); //predict raw output - float[][] predicts = booster.predict(testMat, true); + float[][] predicts = booster.predict(testMat, true, 0); //eval IEvaluation eval = new EvalError(); //error must be less than 0.1 TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f); - //test dump model + // save and load + File temp = File.createTempFile("temp", "model"); + temp.deleteOnExit(); + booster.saveModel(temp.getAbsolutePath()); + Booster bst2 = XGBoost.loadModel(new FileInputStream(temp.getAbsolutePath())); + assert (Arrays.equals(bst2.toByteArray(), booster.toByteArray())); } /**