[jvm] Add ability to load booster direct from byte array (#6655)
* Add ability to load booster direct from byte array * fix compiler error * move InputStream to byte-buffer conversion - move it from Booster to XGBoost facade class
This commit is contained in:
parent
872e559b91
commit
17913713b5
@ -62,31 +62,23 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
if (modelPath == null) {
|
if (modelPath == null) {
|
||||||
throw new NullPointerException("modelPath : null");
|
throw new NullPointerException("modelPath : null");
|
||||||
}
|
}
|
||||||
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
|
Booster ret = new Booster(new HashMap<>(), new DMatrix[0]);
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath));
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath));
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load a new Booster model from a file opened as input stream.
|
* Load a new Booster model from a byte array buffer.
|
||||||
* The assumption is the input stream only contains one XGBoost Model.
|
* The assumption is the array only contains one XGBoost Model.
|
||||||
* This can be used to load existing booster models saved by other xgboost bindings.
|
* This can be used to load existing booster models saved by other xgboost bindings.
|
||||||
*
|
*
|
||||||
* @param in The input stream of the file.
|
* @param buffer The byte contents of the booster.
|
||||||
* @return The create boosted
|
* @return The created boosted
|
||||||
* @throws XGBoostError
|
* @throws XGBoostError
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
static Booster loadModel(InputStream in) throws XGBoostError, IOException {
|
static Booster loadModel(byte[] buffer) throws XGBoostError {
|
||||||
int size;
|
Booster ret = new Booster(new HashMap<>(), new DMatrix[0]);
|
||||||
byte[] buf = new byte[1<<20];
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle, buffer));
|
||||||
ByteArrayOutputStream os = new ByteArrayOutputStream();
|
|
||||||
while ((size = in.read(buf)) != -1) {
|
|
||||||
os.write(buf, 0, size);
|
|
||||||
}
|
|
||||||
in.close();
|
|
||||||
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
|
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray()));
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -15,10 +15,7 @@
|
|||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.*;
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.io.OutputStream;
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
@ -56,9 +53,28 @@ public class XGBoost {
|
|||||||
* @throws XGBoostError
|
* @throws XGBoostError
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
*/
|
*/
|
||||||
public static Booster loadModel(InputStream in)
|
public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
|
||||||
throws XGBoostError, IOException {
|
int size;
|
||||||
return Booster.loadModel(in);
|
byte[] buf = new byte[1<<20];
|
||||||
|
ByteArrayOutputStream os = new ByteArrayOutputStream();
|
||||||
|
while ((size = in.read(buf)) != -1) {
|
||||||
|
os.write(buf, 0, size);
|
||||||
|
}
|
||||||
|
in.close();
|
||||||
|
return Booster.loadModel(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load a new Booster model from a byte array buffer.
|
||||||
|
* The assumption is the array only contains one XGBoost Model.
|
||||||
|
* This can be used to load existing booster models saved by other xgboost bindings.
|
||||||
|
*
|
||||||
|
* @param buffer The byte contents of the booster.
|
||||||
|
* @return The create boosted
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public static Booster loadModel(byte[] buffer) throws XGBoostError, IOException {
|
||||||
|
return Booster.loadModel(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user