diff --git a/NEWS.md b/NEWS.md
index 6fc6a37a5..81afdbb5a 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -12,6 +12,9 @@ This file records the changes in xgboost library in reverse chronological order.
- Enable registry pattern to allow optionally plugin of objective, metric, tree constructor, data loader.
- Future plugin modules can be put into xgboost/plugin and register back to the library.
- Remove most of the raw pointers to smart ptrs, for RAII safety.
+* Add an official parameter, `tree_method`, for selecting the approximate algorithm.
+ - Change the default behavior to prefer the faster algorithm.
+ - The user will get a message when the approximate algorithm is chosen.
* Change library name to libxgboost.so
* Backward compatiblity
- The binary buffer file is not backward compatible with previous version.
diff --git a/R-package/README.md b/R-package/README.md
index b94f50af6..36c2308dd 100644
--- a/R-package/README.md
+++ b/R-package/README.md
@@ -1,5 +1,5 @@
-R package for xgboost
-=====================
+XGBoost R Package for Scalable GBM
+==================================
[](http://cran.r-project.org/web/packages/xgboost)
[](http://cran.rstudio.com/web/packages/xgboost/index.html)
diff --git a/doc/jvm/index.md b/doc/jvm/index.md
index 8b1f22d5e..3bd5b7dfa 100644
--- a/doc/jvm/index.md
+++ b/doc/jvm/index.md
@@ -18,3 +18,8 @@ Contents
--------
* [Java Overview Tutorial](java_intro.md)
* [Code Examples](https://github.com/dmlc/xgboost/tree/master/jvm-packages/xgboost4j-example)
+* [Java API Docs](http://dmlc.ml/docs/javadocs/index.html)
+* Scala API Docs
+ * [XGBoost4J](http://dmlc.ml/docs/scaladocs/xgboost4j/index.html)
+ * [XGBoost4J-Spark](http://dmlc.ml/docs/scaladocs/xgboost4j-spark/index.html)
+ * [XGBoost4J-Flink](http://dmlc.ml/docs/scaladocs/xgboost4j-flink/index.html)
\ No newline at end of file
diff --git a/doc/parameter.md b/doc/parameter.md
index af3986bbf..32f772fcc 100644
--- a/doc/parameter.md
+++ b/doc/parameter.md
@@ -53,6 +53,24 @@ Parameters for Tree Booster
- L2 regularization term on weights
* alpha [default=0]
- L1 regularization term on weights
+* tree_method, string [default='auto']
+ - The tree construction algorithm used in XGBoost (see description in the [reference paper](http://arxiv.org/abs/1603.02754))
+ - Distributed and external memory versions only support the approximate algorithm.
+ - Choices: {'auto', 'exact', 'approx'}
+ - 'auto': Use a heuristic to choose the faster one.
+ - For small to medium datasets, exact greedy will be used.
+ - For very large datasets, the approximate algorithm will be chosen.
+ - Because the old behavior was to always use exact greedy on a single machine,
+ the user will get a message when the approximate algorithm is chosen, to notify them of this choice.
+ - 'exact': Exact greedy algorithm.
+ - 'approx': Approximate greedy algorithm using sketching and histogram.
+* sketch_eps, [default=0.03]
+ - This is only used for the approximate greedy algorithm.
+ - This roughly translates into ```O(1 / sketch_eps)``` number of bins.
+ Compared to directly selecting the number of bins, this comes with a theoretical guarantee on sketch accuracy.
+ - Usually users do not have to tune this,
+ but consider setting it to a lower number for more accurate enumeration.
+ - range: (0, 1)
Parameters for Linear Booster
-----------------------------
diff --git a/jvm-packages/README.md b/jvm-packages/README.md
index a390a7288..62aa79268 100644
--- a/jvm-packages/README.md
+++ b/jvm-packages/README.md
@@ -56,7 +56,11 @@ object DistTrainWithSpark {
"usage: program num_of_rounds training_path model_path")
sys.exit(1)
}
- val sc = new SparkContext()
+ // if you do not want to use KryoSerializer in Spark, you can ignore the related configuration
+ val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoost-spark-example")
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ sparkConf.registerKryoClasses(Array(classOf[Booster]))
+ val sc = new SparkContext(sparkConf)
val inputTrainPath = args(1)
val outputModelPath = args(2)
// number of iterations
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index dff6e0359..5d0cbd00b 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -48,6 +48,41 @@
+
+ org.apache.maven.plugins
+ maven-javadoc-plugin
+ 2.10.3
+
+
+ ml.dmlc.xgboost4j.java.example
+
+
+
+
+ net.alchim31.maven
+ maven-site-plugin
+ 3.0
+
+
+
+ maven-project-info-reports-plugin
+ 2.2
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 3.2.1
+
+
+ -Xms64m
+ -Xmx1024m
+
+
+
+ ...
+
+
+
org.apache.maven.plugins
maven-checkstyle-plugin
@@ -135,6 +170,11 @@
+
+ com.esotericsoftware.kryo
+ kryo
+ 2.21
+
org.scala-lang
scala-compiler
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
index 82e6e626b..978e8f0ee 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
@@ -16,9 +16,9 @@
package ml.dmlc.xgboost4j.scala.example.spark
-import ml.dmlc.xgboost4j.scala.DMatrix
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost}
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.util.MLUtils
object DistTrainWithSpark {
@@ -28,7 +28,10 @@ object DistTrainWithSpark {
"usage: program num_of_rounds num_workers training_path test_path model_path")
sys.exit(1)
}
- val sc = new SparkContext()
+ val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoost-spark-example")
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ sparkConf.registerKryoClasses(Array(classOf[Booster]))
+ val sc = new SparkContext(sparkConf)
val inputTrainPath = args(2)
val inputTestPath = args(3)
val outputModelPath = args(4)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
index 26032694d..96abe9946 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
@@ -29,8 +29,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FunSuite}
-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
-import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
+import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix, XGBoostError}
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
class XGBoostSuite extends FunSuite with BeforeAndAfter {
@@ -171,4 +171,28 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
}
customSparkContext.stop()
}
+
+  test("kryoSerializer test") {
+    // stop the suite-level context before creating the Kryo-enabled one below
+    sc.stop()
+    sc = null
+    val eval = new EvalError()
+    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    sparkConf.registerKryoClasses(Array(classOf[Booster]))
+    val customSparkContext = new SparkContext(sparkConf)
+    try {
+      val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+      val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+      import DataUtils._
+      val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+      val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+        "objective" -> "binary:logistic").toMap
+      val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+      assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
+    } finally {
+      // always release the context, even when training or the assertion above throws
+      customSparkContext.stop()
+    }
+  }
}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
index 5778149f2..cce3eb1cd 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
@@ -20,13 +20,17 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
*/
-public class Booster implements Serializable {
+public class Booster implements Serializable, KryoSerializable {
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
@@ -436,7 +440,8 @@ public class Booster implements Serializable {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
- throw new IOException(ex.toString());
+      logger.error(ex.getMessage(), ex);
+      throw new IOException(ex); // fail loudly: when toByteArray() throws, nothing was written
}
}
@@ -447,7 +452,8 @@ public class Booster implements Serializable {
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
- throw new IOException(ex.toString());
+      logger.error(ex.getMessage(), ex);
+      throw new IOException(ex); // fail loudly: the model bytes were not loaded
}
}
@@ -463,4 +469,49 @@ public class Booster implements Serializable {
handle = 0;
}
}
+
+  /**
+   * Kryo serialization hook: writes the model as a length-prefixed byte array.
+   *
+   * @param kryo   the Kryo instance driving serialization (unused)
+   * @param output the stream that receives the length followed by the model bytes
+   * @throws RuntimeException if the model cannot be exported to bytes; the failure is
+   *                          propagated so an incomplete record is never silently emitted
+   */
+  @Override
+  public void write(Kryo kryo, Output output) {
+    try {
+      byte[] serObj = this.toByteArray();
+      // the length prefix lets read() size its buffer before consuming the payload
+      output.writeInt(serObj.length);
+      output.write(serObj);
+    } catch (XGBoostError ex) {
+      logger.error(ex.getMessage(), ex);
+      throw new RuntimeException("failed to serialize Booster with Kryo", ex);
+    }
+  }
+
+  /**
+   * Kryo deserialization hook: rebuilds the model from a length-prefixed byte array
+   * in the layout produced by {@link #write(Kryo, Output)}.
+   *
+   * @param kryo  the Kryo instance driving deserialization (unused)
+   * @param input the stream positioned at the length prefix
+   * @throws RuntimeException if initialization or model loading fails; propagated so a
+   *                          Booster without a loaded model is never handed back as valid
+   */
+  @Override
+  public void read(Kryo kryo, Input input) {
+    try {
+      // NOTE(review): init(null) presumably allocates a fresh native handle - confirm
+      this.init(null);
+      int serObjSize = input.readInt();
+      byte[] bytes = new byte[serObjSize];
+      input.readBytes(bytes);
+      JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
+    } catch (XGBoostError ex) {
+      logger.error(ex.getMessage(), ex);
+      throw new RuntimeException("failed to deserialize Booster with Kryo", ex);
+    }
+  }
}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
index cb8f16f8c..f0e01062e 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
@@ -18,12 +18,15 @@ package ml.dmlc.xgboost4j.scala
import java.io.IOException
+import com.esotericsoftware.kryo.io.{Output, Input}
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import ml.dmlc.xgboost4j.java.XGBoostError
import scala.collection.JavaConverters._
import scala.collection.mutable
-class Booster private[xgboost4j](booster: JBooster) extends Serializable {
+class Booster private[xgboost4j](private var booster: JBooster)
+ extends Serializable with KryoSerializable {
/**
* Set parameter to the Booster.
@@ -193,4 +196,12 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable {
super.finalize()
dispose
}
+
+ override def write(kryo: Kryo, output: Output): Unit = {
+ kryo.writeObject(output, booster)
+ }
+
+ override def read(kryo: Kryo, input: Input): Unit = {
+ booster = kryo.readObject(input, classOf[JBooster])
+ }
}
diff --git a/src/learner.cc b/src/learner.cc
index 0fd8e7af6..6a95e0bab 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -4,6 +4,7 @@
* \brief Implementation of learning algorithm.
* \author Tianqi Chen
*/
+#include
#include
#include
#include
@@ -69,6 +70,8 @@ struct LearnerTrainParam
bool seed_per_iteration;
// data split mode, can be row, col, or none.
int dsplit;
+ // tree construction method
+ int tree_method;
// internal test flag
std::string test_flag;
// maximum buffered row value
@@ -87,6 +90,11 @@ struct LearnerTrainParam
.add_enum("col", 1)
.add_enum("row", 2)
.describe("Data split mode for distributed trainig. ");
+ DMLC_DECLARE_FIELD(tree_method).set_default(0)
+ .add_enum("auto", 0)
+ .add_enum("approx", 1)
+ .add_enum("exact", 2)
+ .describe("Choice of tree construction method.");
DMLC_DECLARE_FIELD(test_flag).set_default("")
.describe("Internal test flag");
DMLC_DECLARE_FIELD(prob_buffer_row).set_default(1.0f).set_range(0.0f, 1.0f)
@@ -349,21 +357,42 @@ class LearnerImpl : public Learner {
// check if p_train is ready to used by training.
// if not, initialize the column access.
inline void LazyInitDMatrix(DMatrix *p_train) {
- if (p_train->HaveColAccess()) return;
- int ncol = static_cast(p_train->info().num_col);
- std::vector enabled(ncol, true);
- // set max row per batch to limited value
- // in distributed mode, use safe choice otherwise
- size_t max_row_perbatch = tparam.max_row_perbatch;
- if (tparam.test_flag == "block" || tparam.dsplit == 2) {
- max_row_perbatch = std::min(
- static_cast(32UL << 10UL), max_row_perbatch);
+ if (!p_train->HaveColAccess()) {
+ int ncol = static_cast(p_train->info().num_col);
+ std::vector enabled(ncol, true);
+ // set max row per batch to limited value
+ // in distributed mode, use safe choice otherwise
+ size_t max_row_perbatch = tparam.max_row_perbatch;
+ const size_t safe_max_row = static_cast(32UL << 10UL);
+
+ if (tparam.tree_method == 0 &&
+ p_train->info().num_row >= (4UL << 20UL)) {
+ LOG(CONSOLE) << "Tree method is automatically selected to be \'approx\'"
+ << " for faster speed."
+ << " to use old behavior(exact greedy algorithm on single machine),"
+ << " set tree_method to \'exact\'";
+ max_row_perbatch = std::min(max_row_perbatch, safe_max_row);
+ }
+
+ if (tparam.tree_method == 1) {
+ LOG(CONSOLE) << "Tree method is selected to be \'approx\'";
+ max_row_perbatch = std::min(max_row_perbatch, safe_max_row);
+ }
+
+ if (tparam.test_flag == "block" || tparam.dsplit == 2) {
+ max_row_perbatch = std::min(max_row_perbatch, safe_max_row);
+ }
+ // initialize column access
+ p_train->InitColAccess(enabled,
+ tparam.prob_buffer_row,
+ max_row_perbatch);
}
- // initialize column access
- p_train->InitColAccess(enabled,
- tparam.prob_buffer_row,
- max_row_perbatch);
+
if (!p_train->SingleColBlock() && cfg_.count("updater") == 0) {
+ if (tparam.tree_method == 2) {
+ LOG(CONSOLE) << "tree method is set to be 'exact',"
+ << " but currently we are only able to proceed with approximate algorithm";
+ }
cfg_["updater"] = "grow_histmaker,prune";
if (gbm_.get() != nullptr) {
gbm_->Configure(cfg_.begin(), cfg_.end());