[jvm-packages] Implement new save_raw in jvm-packages. (#7570)
* New `toByteArray` that accepts a parameter for format.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014,2021 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -636,13 +636,27 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* Save model into raw byte array. Currently it's using the deprecated format as
|
||||
* default, which will be changed into `ubj` in future releases.
|
||||
*
|
||||
* @return the saved byte array.
|
||||
* @return the saved byte array
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public byte[] toByteArray() throws XGBoostError {
|
||||
return this.toByteArray("deprecated");
|
||||
}
|
||||
|
||||
/**
|
||||
* Save model into raw byte array.
|
||||
*
|
||||
* @param format The output format. Available options are "json", "ubj" and "deprecated".
|
||||
*
|
||||
* @return the saved byte array
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public byte[] toByteArray(String format) throws XGBoostError {
|
||||
byte[][] bytes = new byte[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModelToBuffer(this.handle, format, bytes));
|
||||
return bytes[0];
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -107,7 +107,7 @@ class XGBoostJNI {
|
||||
|
||||
public final static native int XGBoosterLoadModelFromBuffer(long handle, byte[] bytes);
|
||||
|
||||
public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes);
|
||||
public final static native int XGBoosterSaveModelToBuffer(long handle, String format, byte[][] out_bytes);
|
||||
|
||||
public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats,
|
||||
String format, String[][] out_strings);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -301,6 +301,19 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
|
||||
def getVersion: Int = booster.getVersion
|
||||
|
||||
/**
|
||||
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def toByteArray(format: String): Array[Byte] = {
|
||||
booster.toByteArray(format)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save model into a raw byte array. Currently it's using the deprecated format as
|
||||
* default, which will be changed into `ubj` in future releases.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def toByteArray: Array[Byte] = {
|
||||
booster.toByteArray
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user