diff --git a/jvm-packages/README.md b/jvm-packages/README.md index a390a7288..62aa79268 100644 --- a/jvm-packages/README.md +++ b/jvm-packages/README.md @@ -56,7 +56,12 @@ object DistTrainWithSpark { "usage: program num_of_rounds training_path model_path") sys.exit(1) } - val sc = new SparkContext() + // if you do not want to use KryoSerializer in Spark, you can ignore the related configuration + val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoost-spark-example") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + sparkConf.registerKryoClasses(Array(classOf[Booster])) + val sc = new SparkContext(sparkConf) + val sc = new SparkContext(sparkConf) val inputTrainPath = args(1) val outputModelPath = args(2) // number of iterations diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 1409877be..5d0cbd00b 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -170,6 +170,11 @@ + + com.esotericsoftware.kryo + kryo + 2.21 + org.scala-lang scala-compiler diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala index 82e6e626b..978e8f0ee 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala @@ -16,9 +16,9 @@ package ml.dmlc.xgboost4j.scala.example.spark -import ml.dmlc.xgboost4j.scala.DMatrix +import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.util.MLUtils object DistTrainWithSpark { @@ -28,7 +28,10 @@ object DistTrainWithSpark { "usage: program num_of_rounds num_workers training_path test_path model_path") sys.exit(1) } - val sc = new SparkContext() + val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoost-spark-example") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + sparkConf.registerKryoClasses(Array(classOf[Booster])) + val sc = new SparkContext(sparkConf) val inputTrainPath = args(2) val inputTestPath = args(3) val outputModelPath = args(4) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala index 26032694d..96abe9946 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala @@ -29,8 +29,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} import org.scalatest.{BeforeAndAfter, FunSuite} -import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError} -import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} +import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix, XGBoostError} +import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait} class XGBoostSuite extends FunSuite with BeforeAndAfter { @@ -171,4 +171,23 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter { } customSparkContext.stop() } + + test("kryoSerializer test") { + sc.stop() + sc = null + val eval = new EvalError() + val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + sparkConf.registerKryoClasses(Array(classOf[Booster])) + val customSparkContext = new SparkContext(sparkConf) + val trainingRDD = buildTrainingRDD(Some(customSparkContext)) + val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator + import DataUtils._ + val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null)) + val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", + "objective" -> "binary:logistic").toMap + val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers) + assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1) + customSparkContext.stop() + } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 5778149f2..cce3eb1cd 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -20,13 +20,17 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; /** * Booster for xgboost, this is a model API that support interactive build of a XGBoost Model */ -public class Booster implements Serializable { +public class Booster implements Serializable, KryoSerializable { private static final Log logger = LogFactory.getLog(Booster.class); // handle to the booster. private long handle = 0; @@ -436,7 +440,8 @@ public class Booster implements Serializable { try { out.writeObject(this.toByteArray()); } catch (XGBoostError ex) { - throw new IOException(ex.toString()); + ex.printStackTrace(); + logger.error(ex.getMessage()); } } @@ -447,7 +452,8 @@ public class Booster implements Serializable { byte[] bytes = (byte[])in.readObject(); JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); } catch (XGBoostError ex) { - throw new IOException(ex.toString()); + ex.printStackTrace(); + logger.error(ex.getMessage()); } } @@ -463,4 +469,33 @@ public class Booster implements Serializable { handle = 0; } } + + @Override + public void write(Kryo kryo, Output output) { + try { + byte[] serObj = this.toByteArray(); + int serObjSize = serObj.length; + System.out.println("==== serialized obj size " + serObjSize); + output.writeInt(serObjSize); + output.write(serObj); + } catch (XGBoostError ex) { + ex.printStackTrace(); + logger.error(ex.getMessage()); + } + } + + @Override + public void read(Kryo kryo, Input input) { + try { + this.init(null); + int serObjSize = input.readInt(); + System.out.println("==== the size of the object: " + serObjSize); + byte[] bytes = new byte[serObjSize]; + input.readBytes(bytes); + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); + } catch (XGBoostError ex) { + ex.printStackTrace(); + logger.error(ex.getMessage()); + } + } } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index cb8f16f8c..f0e01062e 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -18,12 +18,15 @@ package ml.dmlc.xgboost4j.scala import java.io.IOException +import com.esotericsoftware.kryo.io.{Output, Input} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import ml.dmlc.xgboost4j.java.{Booster => JBooster} import ml.dmlc.xgboost4j.java.XGBoostError import scala.collection.JavaConverters._ import scala.collection.mutable -class Booster private[xgboost4j](booster: JBooster) extends Serializable { +class Booster private[xgboost4j](private var booster: JBooster) + extends Serializable with KryoSerializable { /** * Set parameter to the Booster. @@ -193,4 +196,12 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable { super.finalize() dispose } + + override def write(kryo: Kryo, output: Output): Unit = { + kryo.writeObject(output, booster) + } + + override def read(kryo: Kryo, input: Input): Unit = { + booster = kryo.readObject(input, classOf[JBooster]) + } }