diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index d62badc61..69e606d9c 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -683,5 +683,3 @@ private object Watches { new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName) } } - - diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 85b5c2602..f08435f3a 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014,2021 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -636,13 +636,27 @@ public class Booster implements Serializable, KryoSerializable { } /** + * Save model into raw byte array. Currently it's using the deprecated format as + * default, which will be changed into `ubj` in future releases. * - * @return the saved byte array. + * @return the saved byte array * @throws XGBoostError native error */ public byte[] toByteArray() throws XGBoostError { + return this.toByteArray("deprecated"); + } + + /** + * Save model into raw byte array. + * + * @param format The output format. Available options are "json", "ubj" and "deprecated". + * + * @return the saved byte array + * @throws XGBoostError native error + */ + public byte[] toByteArray(String format) throws XGBoostError { byte[][] bytes = new byte[1][]; - XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes)); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModelToBuffer(this.handle, format, bytes)); return bytes[0]; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 6f3fd1768..22b3155fe 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -107,7 +107,7 @@ class XGBoostJNI { public final static native int XGBoosterLoadModelFromBuffer(long handle, byte[] bytes); - public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes); + public final static native int XGBoosterSaveModelToBuffer(long handle, String format, byte[][] out_bytes); public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats, String format, String[][] out_strings); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index e442c4f75..88f5607d3 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -301,6 +301,19 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) def getVersion: Int = booster.getVersion + /** + * Save model into a raw byte array. Available options are "json", "ubj" and "deprecated". + */ + @throws(classOf[XGBoostError]) + def toByteArray(format: String): Array[Byte] = { + booster.toByteArray(format) + } + + /** + * Save model into a raw byte array. Currently it's using the deprecated format as + * default, which will be changed into `ubj` in future releases. + */ + @throws(classOf[XGBoostError]) def toByteArray: Array[Byte] = { booster.toByteArray } diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index fa97fc67d..4fd6131ac 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -662,20 +662,25 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterGetModelRaw - * Signature: (J[[B)I + * Method: XGBoosterSaveModelToBuffer + * Signature: (JLjava/lang/String;[[B)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelRaw - (JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModelToBuffer + (JNIEnv * jenv, jclass jcls, jlong jhandle, jstring jformat, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; + const char *format = jenv->GetStringUTFChars(jformat, 0); bst_ulong len = 0; - const char* result; - int ret = XGBoosterGetModelRaw(handle, &len, &result); - JVM_CHECK_CALL(ret); + const char *result{nullptr}; + xgboost::Json config {xgboost::Object{}}; + config["format"] = std::string{format}; + std::string config_str; + xgboost::Json::Dump(config, &config_str); + int ret = XGBoosterSaveModelToBuffer(handle, config_str.c_str(), &len, &result); + JVM_CHECK_CALL(ret); if (result) { jbyteArray jarray = jenv->NewByteArray(len); - jenv->SetByteArrayRegion(jarray, 0, len, (jbyte*)result); + jenv->SetByteArrayRegion(jarray, 0, len, (jbyte *)result); jenv->SetObjectArrayElement(jout, 0, jarray); } return ret; diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 28fa0a938..16ef166fc 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -209,11 +209,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterGetModelRaw - * Signature: (J[[B)I + * Method: XGBoosterSaveModelToBuffer + * Signature: (JLjava/lang/String;[[B)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelRaw - (JNIEnv *, jclass, jlong, jobjectArray); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModelToBuffer + (JNIEnv *, jclass, jlong, jstring, jobjectArray); /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index e4070ca79..cce1254d0 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2021 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -115,7 +115,9 @@ public class BoosterImplTest { booster.saveModel(temp.getAbsolutePath()); Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath()); - assert (Arrays.equals(bst2.toByteArray(), booster.toByteArray())); + assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj"))); + assert (Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json"))); + assert (Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated"))); float[][] predicts2 = bst2.predict(testMat, true, 0); TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f); } diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala index ab4ff870b..157971f82 100644 --- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -117,6 +117,7 @@ class ScalaBoosterImplSuite extends FunSuite { val bst2: Booster = XGBoost.loadModel(temp.getAbsolutePath) assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray)) + assert(java.util.Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj"))) val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0) TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f) }