[jvm] Add ability to load booster direct from byte array (#6655)

* Add ability to load booster direct from byte array

* fix compiler error

* move InputStream to byte-buffer conversion

- move it from Booster to XGBoost facade class
This commit is contained in:
Honza Sterba 2021-02-23 20:28:27 +01:00 committed by GitHub
parent 872e559b91
commit 17913713b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 23 deletions

View File

@ -62,31 +62,23 @@ public class Booster implements Serializable, KryoSerializable {
if (modelPath == null) {
throw new NullPointerException("modelPath : null");
}
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
Booster ret = new Booster(new HashMap<>(), new DMatrix[0]);
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath));
return ret;
}
/**
* Load a new Booster model from a file opened as input stream.
* The assumption is the input stream only contains one XGBoost Model.
* Load a new Booster model from a byte array buffer.
* The assumption is the array only contains one XGBoost Model.
* This can be used to load existing booster models saved by other xgboost bindings.
*
* @param in The input stream of the file.
* @return The create boosted
* @param buffer The byte contents of the booster.
* @return The created boosted
* @throws XGBoostError
* @throws IOException
*/
static Booster loadModel(InputStream in) throws XGBoostError, IOException {
int size;
byte[] buf = new byte[1<<20];
ByteArrayOutputStream os = new ByteArrayOutputStream();
while ((size = in.read(buf)) != -1) {
os.write(buf, 0, size);
}
in.close();
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray()));
static Booster loadModel(byte[] buffer) throws XGBoostError {
Booster ret = new Booster(new HashMap<>(), new DMatrix[0]);
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle, buffer));
return ret;
}

View File

@ -15,10 +15,7 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.*;
import java.util.*;
import org.apache.commons.logging.Log;
@ -56,9 +53,28 @@ public class XGBoost {
* @throws XGBoostError
* @throws IOException
*/
public static Booster loadModel(InputStream in)
throws XGBoostError, IOException {
return Booster.loadModel(in);
public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
int size;
byte[] buf = new byte[1<<20];
ByteArrayOutputStream os = new ByteArrayOutputStream();
while ((size = in.read(buf)) != -1) {
os.write(buf, 0, size);
}
in.close();
return Booster.loadModel(buf);
}
/**
* Load a new Booster model from a byte array buffer.
* The assumption is the array only contains one XGBoost Model.
* This can be used to load existing booster models saved by other xgboost bindings.
*
* @param buffer The byte contents of the booster.
* @return The create boosted
* @throws XGBoostError
*/
public static Booster loadModel(byte[] buffer) throws XGBoostError, IOException {
return Booster.loadModel(buffer);
}
/**