[JVM] Make JVM Serializable
This commit is contained in:
parent
e80d3db64b
commit
0df2ed80c8
@ -1,9 +1,10 @@
|
|||||||
package ml.dmlc.xgboost4j;
|
package ml.dmlc.xgboost4j;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public interface Booster {
|
public interface Booster extends Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set parameter
|
* set parameter
|
||||||
@ -109,12 +110,25 @@ public interface Booster {
|
|||||||
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
|
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* save model to modelPath
|
* save model to modelPath, the model path support depends on the path support
|
||||||
*
|
* in libxgboost. For example, if we want to save to hdfs, libxgboost need to be
|
||||||
|
* compiled with HDFS support.
|
||||||
|
* See also toByteArray
|
||||||
* @param modelPath model path
|
* @param modelPath model path
|
||||||
*/
|
*/
|
||||||
void saveModel(String modelPath) throws XGBoostError;
|
void saveModel(String modelPath) throws XGBoostError;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save the model as byte array representation.
|
||||||
|
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||||
|
*
|
||||||
|
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
|
||||||
|
*
|
||||||
|
* @return the saved byte array.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
byte[] toByteArray() throws XGBoostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dump model into a text file.
|
* Dump model into a text file.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -57,7 +57,6 @@ class JavaBoosterImpl implements Booster {
|
|||||||
setParams(params);
|
setParams(params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* load model from modelPath
|
* load model from modelPath
|
||||||
*
|
*
|
||||||
@ -440,6 +439,22 @@ class JavaBoosterImpl implements Booster {
|
|||||||
return featureScore;
|
return featureScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save the model as byte array representation.
|
||||||
|
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||||
|
*
|
||||||
|
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
|
||||||
|
*
|
||||||
|
* @return the saved byte array.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public byte[] toByteArray() throws XGBoostError {
|
||||||
|
byte[][] bytes = new byte[1][];
|
||||||
|
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||||
|
return bytes[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load the booster model from thread-local rabit checkpoint.
|
* Load the booster model from thread-local rabit checkpoint.
|
||||||
* This is only used in distributed training.
|
* This is only used in distributed training.
|
||||||
@ -475,6 +490,27 @@ class JavaBoosterImpl implements Booster {
|
|||||||
return handles;
|
return handles;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// making Booster serializable
|
||||||
|
private void writeObject(java.io.ObjectOutputStream out)
|
||||||
|
throws IOException {
|
||||||
|
try {
|
||||||
|
out.writeObject(this.toByteArray());
|
||||||
|
} catch (XGBoostError ex) {
|
||||||
|
throw new IOException(ex.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readObject(java.io.ObjectInputStream in)
|
||||||
|
throws IOException, ClassNotFoundException {
|
||||||
|
try {
|
||||||
|
this.init(null);
|
||||||
|
byte[] bytes = (byte[])in.readObject();
|
||||||
|
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||||
|
} catch (XGBoostError ex) {
|
||||||
|
throw new IOException(ex.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void finalize() throws Throwable {
|
protected void finalize() throws Throwable {
|
||||||
super.finalize();
|
super.finalize();
|
||||||
|
|||||||
@ -23,7 +23,7 @@ import scala.collection.mutable
|
|||||||
import ml.dmlc.xgboost4j.XGBoostError
|
import ml.dmlc.xgboost4j.XGBoostError
|
||||||
|
|
||||||
|
|
||||||
trait Booster {
|
trait Booster extends Serializable {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -19,5 +19,5 @@ typedef ThreadLocalStore<RandomThreadLocalEntry> RandomThreadLocalStore;
|
|||||||
GlobalRandomEngine& GlobalRandom() {
|
GlobalRandomEngine& GlobalRandom() {
|
||||||
return RandomThreadLocalStore::Get()->engine;
|
return RandomThreadLocalStore::Get()->engine;
|
||||||
}
|
}
|
||||||
}
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user