[jvm-packages] Implement new save_raw in jvm-packages. (#7570)

* New `toByteArray` that accepts a parameter for format.
This commit is contained in:
Jiaming Yuan 2022-01-19 16:00:14 +08:00 committed by GitHub
parent b4ec1682c6
commit ac7a36367c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 56 additions and 23 deletions

View File

@ -683,5 +683,3 @@ private object Watches {
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName) new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014,2021 by Contributors Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -636,13 +636,27 @@ public class Booster implements Serializable, KryoSerializable {
} }
/** /**
* Save model into raw byte array. Currently it's using the deprecated format as
* default, which will be changed into `ubj` in future releases.
* *
* @return the saved byte array. * @return the saved byte array
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
public byte[] toByteArray() throws XGBoostError { public byte[] toByteArray() throws XGBoostError {
return this.toByteArray("deprecated");
}
/**
* Save model into raw byte array.
*
* @param format The output format. Available options are "json", "ubj" and "deprecated".
*
* @return the saved byte array
* @throws XGBoostError native error
*/
public byte[] toByteArray(String format) throws XGBoostError {
byte[][] bytes = new byte[1][]; byte[][] bytes = new byte[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes)); XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModelToBuffer(this.handle, format, bytes));
return bytes[0]; return bytes[0];
} }

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014 by Contributors Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -107,7 +107,7 @@ class XGBoostJNI {
public final static native int XGBoosterLoadModelFromBuffer(long handle, byte[] bytes); public final static native int XGBoosterLoadModelFromBuffer(long handle, byte[] bytes);
public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes); public final static native int XGBoosterSaveModelToBuffer(long handle, String format, byte[][] out_bytes);
public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats, public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats,
String format, String[][] out_strings); String format, String[][] out_strings);

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014 by Contributors Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -301,6 +301,19 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
def getVersion: Int = booster.getVersion def getVersion: Int = booster.getVersion
/**
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".
*/
@throws(classOf[XGBoostError])
def toByteArray(format: String): Array[Byte] = {
booster.toByteArray(format)
}
/**
* Save model into a raw byte array. Currently it's using the deprecated format as
* default, which will be changed into `ubj` in future releases.
*/
@throws(classOf[XGBoostError])
def toByteArray: Array[Byte] = { def toByteArray: Array[Byte] = {
booster.toByteArray booster.toByteArray
} }

View File

@ -662,20 +662,25 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetModelRaw * Method: XGBoosterSaveModelToBuffer
* Signature: (J[[B)I * Signature: (JLjava/lang/String;[[B)I
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelRaw JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModelToBuffer
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) { (JNIEnv * jenv, jclass jcls, jlong jhandle, jstring jformat, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
const char *format = jenv->GetStringUTFChars(jformat, 0);
bst_ulong len = 0; bst_ulong len = 0;
const char* result; const char *result{nullptr};
int ret = XGBoosterGetModelRaw(handle, &len, &result); xgboost::Json config {xgboost::Object{}};
JVM_CHECK_CALL(ret); config["format"] = std::string{format};
std::string config_str;
xgboost::Json::Dump(config, &config_str);
int ret = XGBoosterSaveModelToBuffer(handle, config_str.c_str(), &len, &result);
JVM_CHECK_CALL(ret);
if (result) { if (result) {
jbyteArray jarray = jenv->NewByteArray(len); jbyteArray jarray = jenv->NewByteArray(len);
jenv->SetByteArrayRegion(jarray, 0, len, (jbyte*)result); jenv->SetByteArrayRegion(jarray, 0, len, (jbyte *)result);
jenv->SetObjectArrayElement(jout, 0, jarray); jenv->SetObjectArrayElement(jout, 0, jarray);
} }
return ret; return ret;

View File

@ -209,11 +209,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetModelRaw * Method: XGBoosterSaveModelToBuffer
* Signature: (J[[B)I * Signature: (JLjava/lang/String;[[B)I
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelRaw JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModelToBuffer
(JNIEnv *, jclass, jlong, jobjectArray); (JNIEnv *, jclass, jlong, jstring, jobjectArray);
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014-2021 by Contributors Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -115,7 +115,9 @@ public class BoosterImplTest {
booster.saveModel(temp.getAbsolutePath()); booster.saveModel(temp.getAbsolutePath());
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath()); Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
assert (Arrays.equals(bst2.toByteArray(), booster.toByteArray())); assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));
assert (Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json")));
assert (Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated")));
float[][] predicts2 = bst2.predict(testMat, true, 0); float[][] predicts2 = bst2.predict(testMat, true, 0);
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f); TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
} }

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014 by Contributors Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -117,6 +117,7 @@ class ScalaBoosterImplSuite extends FunSuite {
val bst2: Booster = XGBoost.loadModel(temp.getAbsolutePath) val bst2: Booster = XGBoost.loadModel(temp.getAbsolutePath)
assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray)) assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray))
assert(java.util.Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")))
val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0) val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0)
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f) TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f)
} }