[jvm] Add ability to load booster direct from byte array (#6655)

* Add ability to load booster direct from byte array

* fix compiler error

* move InputStream to byte-buffer conversion

- move it from Booster to XGBoost facade class
This commit is contained in:
Honza Sterba 2021-02-23 20:28:27 +01:00 committed by GitHub
parent 872e559b91
commit 17913713b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 23 deletions

View File

@ -62,31 +62,23 @@ public class Booster implements Serializable, KryoSerializable {
if (modelPath == null) { if (modelPath == null) {
throw new NullPointerException("modelPath : null"); throw new NullPointerException("modelPath : null");
} }
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]); Booster ret = new Booster(new HashMap<>(), new DMatrix[0]);
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath)); XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath));
return ret; return ret;
} }
/** /**
* Load a new Booster model from a file opened as input stream. * Load a new Booster model from a byte array buffer.
* The assumption is the input stream only contains one XGBoost Model. * The assumption is the array only contains one XGBoost Model.
* This can be used to load existing booster models saved by other xgboost bindings. * This can be used to load existing booster models saved by other xgboost bindings.
* *
* @param in The input stream of the file. * @param buffer The byte contents of the booster.
* @return The create boosted * @return The created boosted
* @throws XGBoostError * @throws XGBoostError
* @throws IOException
*/ */
static Booster loadModel(InputStream in) throws XGBoostError, IOException { static Booster loadModel(byte[] buffer) throws XGBoostError {
int size; Booster ret = new Booster(new HashMap<>(), new DMatrix[0]);
byte[] buf = new byte[1<<20]; XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle, buffer));
ByteArrayOutputStream os = new ByteArrayOutputStream();
while ((size = in.read(buf)) != -1) {
os.write(buf, 0, size);
}
in.close();
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray()));
return ret; return ret;
} }

View File

@ -15,10 +15,7 @@
*/ */
package ml.dmlc.xgboost4j.java; package ml.dmlc.xgboost4j.java;
import java.io.File; import java.io.*;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.*; import java.util.*;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
@ -56,9 +53,28 @@ public class XGBoost {
* @throws XGBoostError * @throws XGBoostError
* @throws IOException * @throws IOException
*/ */
public static Booster loadModel(InputStream in) public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
throws XGBoostError, IOException { int size;
return Booster.loadModel(in); byte[] buf = new byte[1<<20];
ByteArrayOutputStream os = new ByteArrayOutputStream();
while ((size = in.read(buf)) != -1) {
os.write(buf, 0, size);
}
in.close();
return Booster.loadModel(buf);
}
/**
* Load a new Booster model from a byte array buffer.
* The assumption is the array only contains one XGBoost Model.
* This can be used to load existing booster models saved by other xgboost bindings.
*
* @param buffer The byte contents of the booster.
* @return The create boosted
* @throws XGBoostError
*/
public static Booster loadModel(byte[] buffer) throws XGBoostError, IOException {
return Booster.loadModel(buffer);
} }
/** /**