diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 20fb6f75c..7c69d7786 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -62,31 +62,23 @@ public class Booster implements Serializable, KryoSerializable { if (modelPath == null) { throw new NullPointerException("modelPath : null"); } - Booster ret = new Booster(new HashMap(), new DMatrix[0]); + Booster ret = new Booster(new HashMap<>(), new DMatrix[0]); XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath)); return ret; } /** - * Load a new Booster model from a file opened as input stream. - * The assumption is the input stream only contains one XGBoost Model. + * Load a new Booster model from a byte array buffer. + * The assumption is the array only contains one XGBoost Model. * This can be used to load existing booster models saved by other xgboost bindings. * - * @param in The input stream of the file. - * @return The create boosted + * @param buffer The byte contents of the booster. + * @return The created boosted * @throws XGBoostError - * @throws IOException */ - static Booster loadModel(InputStream in) throws XGBoostError, IOException { - int size; - byte[] buf = new byte[1<<20]; - ByteArrayOutputStream os = new ByteArrayOutputStream(); - while ((size = in.read(buf)) != -1) { - os.write(buf, 0, size); - } - in.close(); - Booster ret = new Booster(new HashMap(), new DMatrix[0]); - XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray())); + static Booster loadModel(byte[] buffer) throws XGBoostError { + Booster ret = new Booster(new HashMap<>(), new DMatrix[0]); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle, buffer)); return ret; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 8adf4c0ae..4a84f29af 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -15,10 +15,7 @@ */ package ml.dmlc.xgboost4j.java; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; +import java.io.*; import java.util.*; import org.apache.commons.logging.Log; @@ -56,9 +53,28 @@ public class XGBoost { * @throws XGBoostError * @throws IOException */ - public static Booster loadModel(InputStream in) - throws XGBoostError, IOException { - return Booster.loadModel(in); + public static Booster loadModel(InputStream in) throws XGBoostError, IOException { + int size; + byte[] buf = new byte[1<<20]; + ByteArrayOutputStream os = new ByteArrayOutputStream(); + while ((size = in.read(buf)) != -1) { + os.write(buf, 0, size); + } + in.close(); + return Booster.loadModel(buf); + } + + /** + * Load a new Booster model from a byte array buffer. + * The assumption is the array only contains one XGBoost Model. + * This can be used to load existing booster models saved by other xgboost bindings. + * + * @param buffer The byte contents of the booster. + * @return The create boosted + * @throws XGBoostError + */ + public static Booster loadModel(byte[] buffer) throws XGBoostError, IOException { + return Booster.loadModel(buffer); } /**