support kryo serialization
This commit is contained in:
@@ -20,13 +20,17 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import com.esotericsoftware.kryo.Kryo;
|
||||
import com.esotericsoftware.kryo.KryoSerializable;
|
||||
import com.esotericsoftware.kryo.io.Input;
|
||||
import com.esotericsoftware.kryo.io.Output;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
|
||||
*/
|
||||
public class Booster implements Serializable {
|
||||
public class Booster implements Serializable, KryoSerializable {
|
||||
private static final Log logger = LogFactory.getLog(Booster.class);
|
||||
// handle to the booster.
|
||||
private long handle = 0;
|
||||
@@ -436,7 +440,8 @@ public class Booster implements Serializable {
|
||||
try {
|
||||
out.writeObject(this.toByteArray());
|
||||
} catch (XGBoostError ex) {
|
||||
throw new IOException(ex.toString());
|
||||
ex.printStackTrace();
|
||||
logger.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -447,7 +452,8 @@ public class Booster implements Serializable {
|
||||
byte[] bytes = (byte[])in.readObject();
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||
} catch (XGBoostError ex) {
|
||||
throw new IOException(ex.toString());
|
||||
ex.printStackTrace();
|
||||
logger.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -463,4 +469,33 @@ public class Booster implements Serializable {
|
||||
handle = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(Kryo kryo, Output output) {
|
||||
try {
|
||||
byte[] serObj = this.toByteArray();
|
||||
int serObjSize = serObj.length;
|
||||
System.out.println("==== serialized obj size " + serObjSize);
|
||||
output.writeInt(serObjSize);
|
||||
output.write(serObj);
|
||||
} catch (XGBoostError ex) {
|
||||
ex.printStackTrace();
|
||||
logger.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void read(Kryo kryo, Input input) {
|
||||
try {
|
||||
this.init(null);
|
||||
int serObjSize = input.readInt();
|
||||
System.out.println("==== the size of the object: " + serObjSize);
|
||||
byte[] bytes = new byte[serObjSize];
|
||||
input.readBytes(bytes);
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||
} catch (XGBoostError ex) {
|
||||
ex.printStackTrace();
|
||||
logger.error(ex.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,12 +18,15 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
import com.esotericsoftware.kryo.io.{Output, Input}
|
||||
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
|
||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
class Booster private[xgboost4j](booster: JBooster) extends Serializable {
|
||||
class Booster private[xgboost4j](private var booster: JBooster)
|
||||
extends Serializable with KryoSerializable {
|
||||
|
||||
/**
|
||||
* Set parameter to the Booster.
|
||||
@@ -193,4 +196,12 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable {
|
||||
super.finalize()
|
||||
dispose
|
||||
}
|
||||
|
||||
override def write(kryo: Kryo, output: Output): Unit = {
|
||||
kryo.writeObject(output, booster)
|
||||
}
|
||||
|
||||
override def read(kryo: Kryo, input: Input): Unit = {
|
||||
booster = kryo.readObject(input, classOf[JBooster])
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user