[jvm-packages] add format option when saving a model (#7940)
This commit is contained in:
@@ -34,6 +34,7 @@ import org.apache.commons.logging.LogFactory;
|
||||
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
|
||||
*/
|
||||
public class Booster implements Serializable, KryoSerializable {
|
||||
public static final String DEFAULT_FORMAT = "deprecated";
|
||||
private static final Log logger = LogFactory.getLog(Booster.class);
|
||||
// handle to the booster.
|
||||
private long handle = 0;
|
||||
@@ -391,7 +392,22 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
* @param out The output stream
|
||||
*/
|
||||
public void saveModel(OutputStream out) throws XGBoostError, IOException {
|
||||
out.write(this.toByteArray());
|
||||
saveModel(out, DEFAULT_FORMAT);
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model to file opened as output stream.
|
||||
* The model format is compatible with other xgboost bindings.
|
||||
* The output stream can only save one xgboost model.
|
||||
* This function will close the OutputStream after the save.
|
||||
*
|
||||
* @param out The output stream
|
||||
* @param format The model format (ubj, json, deprecated)
|
||||
* @throws XGBoostError
|
||||
* @throws IOException
|
||||
*/
|
||||
public void saveModel(OutputStream out, String format) throws XGBoostError, IOException {
|
||||
out.write(this.toByteArray(format));
|
||||
out.close();
|
||||
}
|
||||
|
||||
@@ -643,7 +659,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public byte[] toByteArray() throws XGBoostError {
|
||||
return this.toByteArray("deprecated");
|
||||
return this.toByteArray(DEFAULT_FORMAT);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -207,6 +207,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
def saveModel(modelPath: String): Unit = {
|
||||
booster.saveModel(modelPath)
|
||||
}
|
||||
|
||||
/**
|
||||
* save model to Output stream
|
||||
*
|
||||
@@ -216,6 +217,18 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
def saveModel(out: java.io.OutputStream): Unit = {
|
||||
booster.saveModel(out)
|
||||
}
|
||||
|
||||
/**
|
||||
* save model to Output stream
|
||||
* @param out output stream
|
||||
* @param format the supported model format, (json, ubj, deprecated)
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def saveModel(out: java.io.OutputStream, format: String): Unit = {
|
||||
booster.saveModel(out, format)
|
||||
}
|
||||
|
||||
/**
|
||||
* Dump model as Array of string
|
||||
*
|
||||
@@ -315,7 +328,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def toByteArray: Array[Byte] = {
|
||||
booster.toByteArray
|
||||
booster.toByteArray()
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user