Merge pull request #979 from CodingCat/kryo

[jvm-packages] support kryo serialization
This commit is contained in:
Tianqi Chen 2016-03-13 11:25:01 -07:00
commit 5fb09dc0ab
6 changed files with 88 additions and 10 deletions

View File

@ -56,7 +56,12 @@ object DistTrainWithSpark {
"usage: program num_of_rounds training_path model_path")
sys.exit(1)
}
val sc = new SparkContext()
// if you do not want to use KryoSerializer in Spark, you can ignore the related configuration
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoost-spark-example")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val sc = new SparkContext(sparkConf)
val sc = new SparkContext(sparkConf)
val inputTrainPath = args(1)
val outputModelPath = args(2)
// number of iterations

View File

@ -170,6 +170,11 @@
</plugins>
</build>
<dependencies>
<dependency>
<groupId>com.esotericsoftware.kryo</groupId>
<artifactId>kryo</artifactId>
<version>2.21</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-compiler</artifactId>

View File

@ -16,9 +16,9 @@
package ml.dmlc.xgboost4j.scala.example.spark
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost}
import org.apache.spark.SparkContext
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.util.MLUtils
object DistTrainWithSpark {
@ -28,7 +28,10 @@ object DistTrainWithSpark {
"usage: program num_of_rounds num_workers training_path test_path model_path")
sys.exit(1)
}
val sc = new SparkContext()
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoost-spark-example")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val sc = new SparkContext(sparkConf)
val inputTrainPath = args(2)
val inputTestPath = args(3)
val outputModelPath = args(4)

View File

@ -29,8 +29,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FunSuite}
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix, XGBoostError}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
class XGBoostSuite extends FunSuite with BeforeAndAfter {
@ -171,4 +171,23 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
}
customSparkContext.stop()
}
test("kryoSerializer test") {
sc.stop()
sc = null
val eval = new EvalError()
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val customSparkContext = new SparkContext(sparkConf)
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
customSparkContext.stop()
}
}

View File

@ -20,13 +20,17 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
*/
public class Booster implements Serializable {
public class Booster implements Serializable, KryoSerializable {
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
@ -436,7 +440,8 @@ public class Booster implements Serializable {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
ex.printStackTrace();
logger.error(ex.getMessage());
}
}
@ -447,7 +452,8 @@ public class Booster implements Serializable {
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
ex.printStackTrace();
logger.error(ex.getMessage());
}
}
@ -463,4 +469,33 @@ public class Booster implements Serializable {
handle = 0;
}
}
@Override
public void write(Kryo kryo, Output output) {
try {
byte[] serObj = this.toByteArray();
int serObjSize = serObj.length;
System.out.println("==== serialized obj size " + serObjSize);
output.writeInt(serObjSize);
output.write(serObj);
} catch (XGBoostError ex) {
ex.printStackTrace();
logger.error(ex.getMessage());
}
}
@Override
public void read(Kryo kryo, Input input) {
try {
this.init(null);
int serObjSize = input.readInt();
System.out.println("==== the size of the object: " + serObjSize);
byte[] bytes = new byte[serObjSize];
input.readBytes(bytes);
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
ex.printStackTrace();
logger.error(ex.getMessage());
}
}
}

View File

@ -18,12 +18,15 @@ package ml.dmlc.xgboost4j.scala
import java.io.IOException
import com.esotericsoftware.kryo.io.{Output, Input}
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import ml.dmlc.xgboost4j.java.XGBoostError
import scala.collection.JavaConverters._
import scala.collection.mutable
class Booster private[xgboost4j](booster: JBooster) extends Serializable {
class Booster private[xgboost4j](private var booster: JBooster)
extends Serializable with KryoSerializable {
/**
* Set parameter to the Booster.
@ -193,4 +196,12 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable {
super.finalize()
dispose
}
override def write(kryo: Kryo, output: Output): Unit = {
kryo.writeObject(output, booster)
}
override def read(kryo: Kryo, input: Input): Unit = {
booster = kryo.readObject(input, classOf[JBooster])
}
}