[jvm-packages] Implement new save_raw in jvm-packages. (#7570)
* New `toByteArray` that accepts a parameter for format.
This commit is contained in:
parent
b4ec1682c6
commit
ac7a36367c
@ -683,5 +683,3 @@ private object Watches {
|
|||||||
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
|
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014,2021 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -636,13 +636,27 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* Save model into raw byte array. Currently it's using the deprecated format as
|
||||||
|
* default, which will be changed into `ubj` in future releases.
|
||||||
*
|
*
|
||||||
* @return the saved byte array.
|
* @return the saved byte array
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public byte[] toByteArray() throws XGBoostError {
|
public byte[] toByteArray() throws XGBoostError {
|
||||||
|
return this.toByteArray("deprecated");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save model into raw byte array.
|
||||||
|
*
|
||||||
|
* @param format The output format. Available options are "json", "ubj" and "deprecated".
|
||||||
|
*
|
||||||
|
* @return the saved byte array
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public byte[] toByteArray(String format) throws XGBoostError {
|
||||||
byte[][] bytes = new byte[1][];
|
byte[][] bytes = new byte[1][];
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModelToBuffer(this.handle, format, bytes));
|
||||||
return bytes[0];
|
return bytes[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -107,7 +107,7 @@ class XGBoostJNI {
|
|||||||
|
|
||||||
public final static native int XGBoosterLoadModelFromBuffer(long handle, byte[] bytes);
|
public final static native int XGBoosterLoadModelFromBuffer(long handle, byte[] bytes);
|
||||||
|
|
||||||
public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes);
|
public final static native int XGBoosterSaveModelToBuffer(long handle, String format, byte[][] out_bytes);
|
||||||
|
|
||||||
public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats,
|
public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats,
|
||||||
String format, String[][] out_strings);
|
String format, String[][] out_strings);
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -301,6 +301,19 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
|
|
||||||
def getVersion: Int = booster.getVersion
|
def getVersion: Int = booster.getVersion
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def toByteArray(format: String): Array[Byte] = {
|
||||||
|
booster.toByteArray(format)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save model into a raw byte array. Currently it's using the deprecated format as
|
||||||
|
* default, which will be changed into `ubj` in future releases.
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
def toByteArray: Array[Byte] = {
|
def toByteArray: Array[Byte] = {
|
||||||
booster.toByteArray
|
booster.toByteArray
|
||||||
}
|
}
|
||||||
|
|||||||
@ -662,17 +662,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel
|
|||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGBoosterGetModelRaw
|
* Method: XGBoosterSaveModelToBuffer
|
||||||
* Signature: (J[[B)I
|
* Signature: (JLjava/lang/String;[[B)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelRaw
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModelToBuffer
|
||||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
(JNIEnv * jenv, jclass jcls, jlong jhandle, jstring jformat, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
|
const char *format = jenv->GetStringUTFChars(jformat, 0);
|
||||||
bst_ulong len = 0;
|
bst_ulong len = 0;
|
||||||
const char* result;
|
const char *result{nullptr};
|
||||||
int ret = XGBoosterGetModelRaw(handle, &len, &result);
|
xgboost::Json config {xgboost::Object{}};
|
||||||
JVM_CHECK_CALL(ret);
|
config["format"] = std::string{format};
|
||||||
|
std::string config_str;
|
||||||
|
xgboost::Json::Dump(config, &config_str);
|
||||||
|
|
||||||
|
int ret = XGBoosterSaveModelToBuffer(handle, config_str.c_str(), &len, &result);
|
||||||
|
JVM_CHECK_CALL(ret);
|
||||||
if (result) {
|
if (result) {
|
||||||
jbyteArray jarray = jenv->NewByteArray(len);
|
jbyteArray jarray = jenv->NewByteArray(len);
|
||||||
jenv->SetByteArrayRegion(jarray, 0, len, (jbyte *)result);
|
jenv->SetByteArrayRegion(jarray, 0, len, (jbyte *)result);
|
||||||
|
|||||||
@ -209,11 +209,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel
|
|||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGBoosterGetModelRaw
|
* Method: XGBoosterSaveModelToBuffer
|
||||||
* Signature: (J[[B)I
|
* Signature: (JLjava/lang/String;[[B)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelRaw
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModelToBuffer
|
||||||
(JNIEnv *, jclass, jlong, jobjectArray);
|
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014-2021 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -115,7 +115,9 @@ public class BoosterImplTest {
|
|||||||
booster.saveModel(temp.getAbsolutePath());
|
booster.saveModel(temp.getAbsolutePath());
|
||||||
|
|
||||||
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
|
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
|
||||||
assert (Arrays.equals(bst2.toByteArray(), booster.toByteArray()));
|
assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));
|
||||||
|
assert (Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json")));
|
||||||
|
assert (Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated")));
|
||||||
float[][] predicts2 = bst2.predict(testMat, true, 0);
|
float[][] predicts2 = bst2.predict(testMat, true, 0);
|
||||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -117,6 +117,7 @@ class ScalaBoosterImplSuite extends FunSuite {
|
|||||||
|
|
||||||
val bst2: Booster = XGBoost.loadModel(temp.getAbsolutePath)
|
val bst2: Booster = XGBoost.loadModel(temp.getAbsolutePath)
|
||||||
assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray))
|
assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray))
|
||||||
|
assert(java.util.Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")))
|
||||||
val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0)
|
val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0)
|
||||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f)
|
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user