[jvm-packages] add format option when saving a model (#7940)

This commit is contained in:
Bobby Wang
2022-05-30 15:49:59 +08:00
committed by GitHub
parent cc6d57aa0d
commit 6275cdc486
8 changed files with 153 additions and 30 deletions

View File

@@ -34,6 +34,7 @@ import org.apache.commons.logging.LogFactory;
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
*/
public class Booster implements Serializable, KryoSerializable {
public static final String DEFAULT_FORMAT = "deprecated";
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
@@ -391,7 +392,22 @@ public class Booster implements Serializable, KryoSerializable {
* @param out The output stream
*/
public void saveModel(OutputStream out) throws XGBoostError, IOException {
out.write(this.toByteArray());
saveModel(out, DEFAULT_FORMAT);
}
/**
* Save the model to file opened as output stream.
* The model format is compatible with other xgboost bindings.
* The output stream can only save one xgboost model.
* This function will close the OutputStream after the save.
*
* @param out The output stream
* @param format The model format (ubj, json, deprecated)
* @throws XGBoostError
* @throws IOException
*/
public void saveModel(OutputStream out, String format) throws XGBoostError, IOException {
out.write(this.toByteArray(format));
out.close();
}
@@ -643,7 +659,7 @@ public class Booster implements Serializable, KryoSerializable {
* @throws XGBoostError native error
*/
public byte[] toByteArray() throws XGBoostError {
return this.toByteArray("deprecated");
return this.toByteArray(DEFAULT_FORMAT);
}
/**

View File

@@ -207,6 +207,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
def saveModel(modelPath: String): Unit = {
booster.saveModel(modelPath)
}
/**
* save model to Output stream
*
@@ -216,6 +217,18 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
def saveModel(out: java.io.OutputStream): Unit = {
booster.saveModel(out)
}
/**
* save model to Output stream
* @param out output stream
* @param format the supported model format, (json, ubj, deprecated)
* @throws ml.dmlc.xgboost4j.java.XGBoostError
*/
@throws(classOf[XGBoostError])
def saveModel(out: java.io.OutputStream, format: String): Unit = {
booster.saveModel(out, format)
}
/**
* Dump model as Array of string
*
@@ -315,7 +328,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
*/
@throws(classOf[XGBoostError])
def toByteArray: Array[Byte] = {
booster.toByteArray
booster.toByteArray()
}
/**