support kryo serialization

This commit is contained in:
CodingCat
2016-03-13 11:55:14 -04:00
parent 9011acf52b
commit f2ef958ebb
6 changed files with 88 additions and 10 deletions

View File

@@ -20,13 +20,17 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
*/
public class Booster implements Serializable {
public class Booster implements Serializable, KryoSerializable {
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
@@ -436,7 +440,8 @@ public class Booster implements Serializable {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
ex.printStackTrace();
logger.error(ex.getMessage());
}
}
@@ -447,7 +452,8 @@ public class Booster implements Serializable {
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
ex.printStackTrace();
logger.error(ex.getMessage());
}
}
@@ -463,4 +469,33 @@ public class Booster implements Serializable {
handle = 0;
}
}
@Override
public void write(Kryo kryo, Output output) {
try {
byte[] serObj = this.toByteArray();
int serObjSize = serObj.length;
System.out.println("==== serialized obj size " + serObjSize);
output.writeInt(serObjSize);
output.write(serObj);
} catch (XGBoostError ex) {
ex.printStackTrace();
logger.error(ex.getMessage());
}
}
@Override
public void read(Kryo kryo, Input input) {
try {
this.init(null);
int serObjSize = input.readInt();
System.out.println("==== the size of the object: " + serObjSize);
byte[] bytes = new byte[serObjSize];
input.readBytes(bytes);
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
ex.printStackTrace();
logger.error(ex.getMessage());
}
}
}

View File

@@ -18,12 +18,15 @@ package ml.dmlc.xgboost4j.scala
import java.io.IOException
import com.esotericsoftware.kryo.io.{Output, Input}
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import ml.dmlc.xgboost4j.java.XGBoostError
import scala.collection.JavaConverters._
import scala.collection.mutable
class Booster private[xgboost4j](booster: JBooster) extends Serializable {
class Booster private[xgboost4j](private var booster: JBooster)
extends Serializable with KryoSerializable {
/**
* Set parameter to the Booster.
@@ -193,4 +196,12 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable {
super.finalize()
dispose
}
override def write(kryo: Kryo, output: Output): Unit = {
kryo.writeObject(output, booster)
}
override def read(kryo: Kryo, input: Input): Unit = {
booster = kryo.readObject(input, classOf[JBooster])
}
}