diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java index e234fef60..0707cff2d 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java @@ -1,9 +1,10 @@ package ml.dmlc.xgboost4j; import java.io.IOException; +import java.io.Serializable; import java.util.Map; -public interface Booster { +public interface Booster extends Serializable { /** * set parameter @@ -109,12 +110,25 @@ public interface Booster { float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError; /** - * save model to modelPath - * + * save model to modelPath, the model path support depends on the path support + * in libxgboost. For example, if we want to save to hdfs, libxgboost need to be + * compiled with HDFS support. + * See also toByteArray * @param modelPath model path */ void saveModel(String modelPath) throws XGBoostError; + /** + * Save the model as byte array representation. + * Write these bytes to a file will give compatible format with other xgboost bindings. + * + * If java natively support HDFS file API, use toByteArray and write the ByteArray, + * + * @return the saved byte array. + * @throws XGBoostError + */ + byte[] toByteArray() throws XGBoostError; + /** * Dump model into a text file. * diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java index 76a9195a0..ae265a36d 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java @@ -57,7 +57,6 @@ class JavaBoosterImpl implements Booster { setParams(params); } - /** * load model from modelPath * @@ -440,6 +439,22 @@ class JavaBoosterImpl implements Booster { return featureScore; } + /** + * Save the model as byte array representation. + * Write these bytes to a file will give compatible format with other xgboost bindings. + * + * If java natively support HDFS file API, use toByteArray and write the ByteArray, + * + * @return the saved byte array. + * @throws XGBoostError + */ + public byte[] toByteArray() throws XGBoostError { + byte[][] bytes = new byte[1][]; + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterGetModelRaw(this.handle, bytes)); + return bytes[0]; + } + + /** * Load the booster model from thread-local rabit checkpoint. * This is only used in distributed training. @@ -475,6 +490,27 @@ class JavaBoosterImpl implements Booster { return handles; } + // making Booster serializable + private void writeObject(java.io.ObjectOutputStream out) + throws IOException { + try { + out.writeObject(this.toByteArray()); + } catch (XGBoostError ex) { + throw new IOException(ex.toString()); + } + } + + private void readObject(java.io.ObjectInputStream in) + throws IOException, ClassNotFoundException { + try { + this.init(null); + byte[] bytes = (byte[])in.readObject(); + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); + } catch (XGBoostError ex) { + throw new IOException(ex.toString()); + } + } + @Override protected void finalize() throws Throwable { super.finalize(); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index 5d5cd5619..524a5aa92 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import ml.dmlc.xgboost4j.XGBoostError -trait Booster { +trait Booster extends Serializable { /** diff --git a/src/common/common.cc b/src/common/common.cc index 2010e9ee4..43a23853e 100644 --- a/src/common/common.cc +++ b/src/common/common.cc @@ -19,5 +19,5 @@ typedef ThreadLocalStore RandomThreadLocalStore; GlobalRandomEngine& GlobalRandom() { return RandomThreadLocalStore::Get()->engine; } -} +} // namespace common } // namespace xgboost