[JVM] Make JVM Serializable
This commit is contained in:
parent
e80d3db64b
commit
0df2ed80c8
@ -1,9 +1,10 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
public interface Booster {
|
||||
public interface Booster extends Serializable {
|
||||
|
||||
/**
|
||||
* set parameter
|
||||
@ -109,12 +110,25 @@ public interface Booster {
|
||||
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
|
||||
|
||||
/**
|
||||
* save model to modelPath
|
||||
*
|
||||
* save model to modelPath, the model path support depends on the path support
|
||||
* in libxgboost. For example, if we want to save to hdfs, libxgboost need to be
|
||||
* compiled with HDFS support.
|
||||
* See also toByteArray
|
||||
* @param modelPath model path
|
||||
*/
|
||||
void saveModel(String modelPath) throws XGBoostError;
|
||||
|
||||
/**
|
||||
* Save the model as byte array representation.
|
||||
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||
*
|
||||
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
|
||||
*
|
||||
* @return the saved byte array.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
byte[] toByteArray() throws XGBoostError;
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
|
||||
@ -57,7 +57,6 @@ class JavaBoosterImpl implements Booster {
|
||||
setParams(params);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* load model from modelPath
|
||||
*
|
||||
@ -440,6 +439,22 @@ class JavaBoosterImpl implements Booster {
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as byte array representation.
|
||||
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||
*
|
||||
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
|
||||
*
|
||||
* @return the saved byte array.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public byte[] toByteArray() throws XGBoostError {
|
||||
byte[][] bytes = new byte[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||
return bytes[0];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Load the booster model from thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
@ -475,6 +490,27 @@ class JavaBoosterImpl implements Booster {
|
||||
return handles;
|
||||
}
|
||||
|
||||
// making Booster serializable
|
||||
private void writeObject(java.io.ObjectOutputStream out)
|
||||
throws IOException {
|
||||
try {
|
||||
out.writeObject(this.toByteArray());
|
||||
} catch (XGBoostError ex) {
|
||||
throw new IOException(ex.toString());
|
||||
}
|
||||
}
|
||||
|
||||
private void readObject(java.io.ObjectInputStream in)
|
||||
throws IOException, ClassNotFoundException {
|
||||
try {
|
||||
this.init(null);
|
||||
byte[] bytes = (byte[])in.readObject();
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||
} catch (XGBoostError ex) {
|
||||
throw new IOException(ex.toString());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() throws Throwable {
|
||||
super.finalize();
|
||||
|
||||
@ -23,7 +23,7 @@ import scala.collection.mutable
|
||||
import ml.dmlc.xgboost4j.XGBoostError
|
||||
|
||||
|
||||
trait Booster {
|
||||
trait Booster extends Serializable {
|
||||
|
||||
|
||||
/**
|
||||
|
||||
@ -19,5 +19,5 @@ typedef ThreadLocalStore<RandomThreadLocalEntry> RandomThreadLocalStore;
|
||||
GlobalRandomEngine& GlobalRandom() {
|
||||
return RandomThreadLocalStore::Get()->engine;
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user