[JVM] Make JVM Serializable

This commit is contained in:
tqchen 2016-03-03 21:04:02 -08:00
parent e80d3db64b
commit 0df2ed80c8
4 changed files with 56 additions and 6 deletions

View File

@ -1,9 +1,10 @@
package ml.dmlc.xgboost4j;
import java.io.IOException;
import java.io.Serializable;
import java.util.Map;
public interface Booster {
public interface Booster extends Serializable {
/**
* set parameter
@ -109,12 +110,25 @@ public interface Booster {
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
/**
* save model to modelPath
*
* save model to modelPath, the model path support depends on the path support
* in libxgboost. For example, if we want to save to hdfs, libxgboost need to be
* compiled with HDFS support.
* See also toByteArray
* @param modelPath model path
*/
void saveModel(String modelPath) throws XGBoostError;
/**
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
*
* @return the saved byte array.
* @throws XGBoostError
*/
byte[] toByteArray() throws XGBoostError;
/**
* Dump model into a text file.
*

View File

@ -57,7 +57,6 @@ class JavaBoosterImpl implements Booster {
setParams(params);
}
/**
* load model from modelPath
*
@ -440,6 +439,22 @@ class JavaBoosterImpl implements Booster {
return featureScore;
}
/**
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
*
* @return the saved byte array.
* @throws XGBoostError
*/
public byte[] toByteArray() throws XGBoostError {
byte[][] bytes = new byte[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterGetModelRaw(this.handle, bytes));
return bytes[0];
}
/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
@ -475,6 +490,27 @@ class JavaBoosterImpl implements Booster {
return handles;
}
// making Booster serializable
private void writeObject(java.io.ObjectOutputStream out)
throws IOException {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
private void readObject(java.io.ObjectInputStream in)
throws IOException, ClassNotFoundException {
try {
this.init(null);
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
@Override
protected void finalize() throws Throwable {
super.finalize();

View File

@ -23,7 +23,7 @@ import scala.collection.mutable
import ml.dmlc.xgboost4j.XGBoostError
trait Booster {
trait Booster extends Serializable {
/**

View File

@ -19,5 +19,5 @@ typedef ThreadLocalStore<RandomThreadLocalEntry> RandomThreadLocalStore;
GlobalRandomEngine& GlobalRandom() {
return RandomThreadLocalStore::Get()->engine;
}
}
} // namespace common
} // namespace xgboost