[JVM] Make JVM Serializable

This commit is contained in:
tqchen 2016-03-03 21:04:02 -08:00
parent e80d3db64b
commit 0df2ed80c8
4 changed files with 56 additions and 6 deletions

View File

@ -1,9 +1,10 @@
package ml.dmlc.xgboost4j; package ml.dmlc.xgboost4j;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable;
import java.util.Map; import java.util.Map;
public interface Booster { public interface Booster extends Serializable {
/** /**
* set parameter * set parameter
@ -109,12 +110,25 @@ public interface Booster {
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError; float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
/** /**
* save model to modelPath * save model to modelPath, the model path support depends on the path support
* * in libxgboost. For example, if we want to save to hdfs, libxgboost need to be
* compiled with HDFS support.
* See also toByteArray
* @param modelPath model path * @param modelPath model path
*/ */
void saveModel(String modelPath) throws XGBoostError; void saveModel(String modelPath) throws XGBoostError;
/**
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
*
* @return the saved byte array.
* @throws XGBoostError
*/
byte[] toByteArray() throws XGBoostError;
/** /**
* Dump model into a text file. * Dump model into a text file.
* *

View File

@ -57,7 +57,6 @@ class JavaBoosterImpl implements Booster {
setParams(params); setParams(params);
} }
/** /**
* load model from modelPath * load model from modelPath
* *
@ -440,6 +439,22 @@ class JavaBoosterImpl implements Booster {
return featureScore; return featureScore;
} }
/**
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
*
* @return the saved byte array.
* @throws XGBoostError
*/
public byte[] toByteArray() throws XGBoostError {
byte[][] bytes = new byte[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterGetModelRaw(this.handle, bytes));
return bytes[0];
}
/** /**
* Load the booster model from thread-local rabit checkpoint. * Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training. * This is only used in distributed training.
@ -475,6 +490,27 @@ class JavaBoosterImpl implements Booster {
return handles; return handles;
} }
// making Booster serializable
private void writeObject(java.io.ObjectOutputStream out)
throws IOException {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
private void readObject(java.io.ObjectInputStream in)
throws IOException, ClassNotFoundException {
try {
this.init(null);
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
@Override @Override
protected void finalize() throws Throwable { protected void finalize() throws Throwable {
super.finalize(); super.finalize();

View File

@ -23,7 +23,7 @@ import scala.collection.mutable
import ml.dmlc.xgboost4j.XGBoostError import ml.dmlc.xgboost4j.XGBoostError
trait Booster { trait Booster extends Serializable {
/** /**

View File

@ -19,5 +19,5 @@ typedef ThreadLocalStore<RandomThreadLocalEntry> RandomThreadLocalStore;
GlobalRandomEngine& GlobalRandom() { GlobalRandomEngine& GlobalRandom() {
return RandomThreadLocalStore::Get()->engine; return RandomThreadLocalStore::Get()->engine;
} }
} } // namespace common
} // namespace xgboost } // namespace xgboost