[jvm] Add ability to load booster direct from byte array (#6655)
* Add ability to load booster direct from byte array * fix compiler error * move InputStream to byte-buffer conversion - move it from Booster to XGBoost facade class
This commit is contained in:
parent
872e559b91
commit
17913713b5
@ -62,31 +62,23 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
if (modelPath == null) {
|
||||
throw new NullPointerException("modelPath : null");
|
||||
}
|
||||
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
|
||||
Booster ret = new Booster(new HashMap<>(), new DMatrix[0]);
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath));
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a new Booster model from a file opened as input stream.
|
||||
* The assumption is the input stream only contains one XGBoost Model.
|
||||
* Load a new Booster model from a byte array buffer.
|
||||
* The assumption is the array only contains one XGBoost Model.
|
||||
* This can be used to load existing booster models saved by other xgboost bindings.
|
||||
*
|
||||
* @param in The input stream of the file.
|
||||
* @return The create boosted
|
||||
* @param buffer The byte contents of the booster.
|
||||
* @return The created boosted
|
||||
* @throws XGBoostError
|
||||
* @throws IOException
|
||||
*/
|
||||
static Booster loadModel(InputStream in) throws XGBoostError, IOException {
|
||||
int size;
|
||||
byte[] buf = new byte[1<<20];
|
||||
ByteArrayOutputStream os = new ByteArrayOutputStream();
|
||||
while ((size = in.read(buf)) != -1) {
|
||||
os.write(buf, 0, size);
|
||||
}
|
||||
in.close();
|
||||
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray()));
|
||||
static Booster loadModel(byte[] buffer) throws XGBoostError {
|
||||
Booster ret = new Booster(new HashMap<>(), new DMatrix[0]);
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle, buffer));
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@ -15,10 +15,7 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
@ -56,9 +53,28 @@ public class XGBoost {
|
||||
* @throws XGBoostError
|
||||
* @throws IOException
|
||||
*/
|
||||
public static Booster loadModel(InputStream in)
|
||||
throws XGBoostError, IOException {
|
||||
return Booster.loadModel(in);
|
||||
public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
|
||||
int size;
|
||||
byte[] buf = new byte[1<<20];
|
||||
ByteArrayOutputStream os = new ByteArrayOutputStream();
|
||||
while ((size = in.read(buf)) != -1) {
|
||||
os.write(buf, 0, size);
|
||||
}
|
||||
in.close();
|
||||
return Booster.loadModel(buf);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a new Booster model from a byte array buffer.
|
||||
* The assumption is the array only contains one XGBoost Model.
|
||||
* This can be used to load existing booster models saved by other xgboost bindings.
|
||||
*
|
||||
* @param buffer The byte contents of the booster.
|
||||
* @return The create boosted
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static Booster loadModel(byte[] buffer) throws XGBoostError, IOException {
|
||||
return Booster.loadModel(buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user