diff --git a/.gitignore b/.gitignore
index 151a23603..4be5f8c1e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -88,3 +88,4 @@ build_tests
/tests/cpp/xgboost_test
.DS_Store
+lib/
\ No newline at end of file
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 04cffa32c..f018b316a 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -199,7 +199,7 @@
maven-surefire-plugin
2.19.1
- true
+ false
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
index 6273a0763..6bb35491b 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
@@ -126,9 +126,22 @@ trait BoosterParams extends Params {
* [default='auto']
*/
val treeMethod = new Param[String](this, "tree_method",
- "The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx'}",
+ "The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}",
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
+ /**
+ * growth policy for fast histogram algorithm
+ */
+ val growthPolicy = new Param[String](this, "grow_policy",
+ "growth policy for fast histogram algorithm",
+ (value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
+
+ /**
+ * maximum number of bins in histogram
+ */
+ val maxBins = new IntParam(this, "max_bin", "maximum number of bins in histogram",
+ (value: Int) => value > 0)
+
/**
* This is only used for approximate greedy algorithm.
* This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select
@@ -194,6 +207,7 @@ trait BoosterParams extends Params {
setDefault(boosterType -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
minChildWeight -> 1, maxDeltaStep -> 0,
+ growthPolicy -> "depthwise", maxBins -> 16,
subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1,
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
@@ -227,7 +241,9 @@ private[spark] object BoosterParams {
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
- val supportedTreeMethods = HashSet("auto", "exact", "approx")
+ val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist")
+
+ val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
val supportedSampleType = HashSet("uniform", "weighted")
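
For orientation, here is a minimal usage sketch of the three Spark-side knobs added above ("tree_method" = "hist", "grow_policy", "max_bin"), in the same paramMap style the test suites below use. The DataFrame `trainingDF` is a placeholder for a labeled training set and is not part of this diff.

    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    // "hist" switches on the fast histogram algorithm; grow_policy and max_bin
    // are the two new parameters exposed by BoosterParams above.
    val paramMap = Map(
      "objective"   -> "binary:logistic",
      "tree_method" -> "hist",
      "grow_policy" -> "lossguide",   // or "depthwise" (the default)
      "max_leaves"  -> "8",
      "max_bin"     -> "16")          // must be > 0; 16 is the default set above

    // trainingDF: assumed DataFrame of labeled points (see XGBoostDFSuite below)
    val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 10, nWorkers = 2)
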
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
index 1734123e3..aff16f146 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
@@ -190,6 +190,22 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
}
+ test("fast histogram algorithm parameters are exposed correctly") {
+ val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
+ "eval_metric" -> "error")
+ val testItr = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ val trainingDF = buildTrainingDataframe()
+ val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
+ round = 10, nWorkers = math.min(2, numWorkers))
+ val error = new EvalError
+ import DataUtils._
+ val testSetDMatrix = new DMatrix(new JDMatrix(testItr, null))
+ assert(error.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
+ testSetDMatrix) < 0.1)
+ }
+
private def convertCSVPointToLabelPoint(valueArray: Array[String]): LabeledPoint = {
val intValueArray = new Array[Double](valueArray.length)
intValueArray(valueArray.length - 2) = {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 1874a3b6d..20c5263ef 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -111,11 +111,94 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
"objective" -> "binary:logistic",
"tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
- nWorkers = numWorkers, useExternalMemory = true)
+ nWorkers = numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}
+ test("test with fast histo depthwise") {
+ val eval = new EvalError()
+ val trainingRDD = buildTrainingRDD(sc)
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ import DataUtils._
+ val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+ val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "depthwise", "eval_metric" -> "error")
+ // TODO: the histogram algorithm seems to be very sensitive to the number of workers
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+ nWorkers = math.min(numWorkers, 2))
+ assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+ testSetDMatrix) < 0.1)
+ }
+
+ test("test with fast histo lossguide") {
+ val eval = new EvalError()
+ val trainingRDD = buildTrainingRDD(sc)
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ import DataUtils._
+ val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+ val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "1",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "error")
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+ nWorkers = math.min(numWorkers, 2))
+ val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+ testSetDMatrix)
+ assert(x < 0.1)
+ }
+
+ test("test with fast histo lossguide with max bin") {
+ val eval = new EvalError()
+ val trainingRDD = buildTrainingRDD(sc)
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ import DataUtils._
+ val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+ val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
+ "eval_metric" -> "error")
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+ nWorkers = math.min(numWorkers, 2))
+ val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+ testSetDMatrix)
+ assert(x < 0.1)
+ }
+
+ test("test with fast histo depthwidth with max depth") {
+ val eval = new EvalError()
+ val trainingRDD = buildTrainingRDD(sc)
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ import DataUtils._
+ val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+ val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
+ "eval_metric" -> "error")
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
+ nWorkers = math.min(numWorkers, 2))
+ val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+ testSetDMatrix)
+ assert(x < 0.1)
+ }
+
+ test("test with fast histo depthwidth with max depth and max bin") {
+ val eval = new EvalError()
+ val trainingRDD = buildTrainingRDD(sc)
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ import DataUtils._
+ val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+ val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
+ "eval_metric" -> "error")
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
+ nWorkers = math.min(numWorkers, 2))
+ val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
+ testSetDMatrix)
+ assert(x < 0.1)
+ }
+
test("test with dense vectors containing missing value") {
def buildDenseRDD(): RDD[LabeledPoint] = {
val nrow = 100
@@ -142,6 +225,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
}
sc.parallelize(points)
}
+
val trainingRDD = buildDenseRDD().repartition(4)
val testRDD = buildDenseRDD().repartition(4)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
@@ -189,6 +273,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val sampleList = new ListBuffer[SparkVector]
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
}
+
val trainingRDD = buildTrainingRDD(sc)
val testRDD = buildEmptyRDD()
val tempDir = Files.createTempDirectory("xgboosttest-")
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
index cce3eb1cd..7c01de4a0 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
@@ -180,6 +180,26 @@ public class Booster implements Serializable, KryoSerializable {
return evalInfo[0];
}
+ /**
+ * Evaluate with the given DMatrices and report the raw metric values.
+ *
+ * @param evalMatrixs DMatrices used for evaluation
+ * @param evalNames names of the evaluation DMatrices, used to label the results
+ * @param iter current evaluation iteration
+ * @param metricsOut output array receiving the metric value for each entry of evalMatrixs
+ * @return evaluation information
+ * @throws XGBoostError native error
+ */
+ public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, float[] metricsOut)
+ throws XGBoostError {
+ String stringFormat = evalSet(evalMatrixs, evalNames, iter);
+ String[] metricPairs = stringFormat.split("\t");
+ for (int i = 1; i < metricPairs.length; i++) {
+ metricsOut[i - 1] = Float.valueOf(metricPairs[i].split(":")[1]);
+ }
+ return stringFormat;
+ }
+
/**
* evaluate with given customized Evaluation class
*
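
The new overload simply re-parses the per-iteration string that evalSet already produces. Below is a small sketch of that parsing on an assumed example string; the "[iter]\tname-metric:value" layout is what the tab/colon splits above rely on.

    // Illustration of the split logic in the new evalSet overload, in Scala.
    val evalInfo = "[3]\ttrain-auc:0.987\ttest-auc:0.954"  // assumed example output
    val metricsOut = new Array[Float](2)
    val pairs = evalInfo.split("\t")                       // pairs(0) is "[3]"
    for (i <- 1 until pairs.length) {
      metricsOut(i - 1) = pairs(i).split(":")(1).toFloat
    }
    // metricsOut == Array(0.987f, 0.954f), one value per evaluation DMatrix
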
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
index f2ce989f6..b8601b19e 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
@@ -57,26 +57,24 @@ public class XGBoost {
return Booster.loadModel(in);
}
- /**
- * Train a booster with given parameters.
- *
- * @param dtrain Data to be trained.
- * @param params Booster params.
- * @param round Number of boosting iterations.
- * @param watches a group of items to be evaluated during training, this allows user to watch
- * performance on the validation set.
- * @param obj customized objective (set to null if not used)
- * @param eval customized evaluation (set to null if not used)
- * @return trained booster
- * @throws XGBoostError native error
- */
public static Booster train(
- DMatrix dtrain,
- Map<String, Object> params,
- int round,
- Map<String, DMatrix> watches,
- IObjective obj,
- IEvaluation eval) throws XGBoostError {
+ DMatrix dtrain,
+ Map<String, Object> params,
+ int round,
+ Map<String, DMatrix> watches,
+ IObjective obj,
+ IEvaluation eval) throws XGBoostError {
+ return train(dtrain, params, round, watches, null, obj, eval);
+ }
+
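+ /**
+ * Train a booster with given parameters, recording per-iteration evaluation metrics.
+ *
+ * @param dtrain Data to be trained.
+ * @param params Booster params.
+ * @param round Number of boosting iterations.
+ * @param watches a group of items to be evaluated during training; this allows the user to
+ * watch performance on the validation set.
+ * @param metrics array containing the evaluation metrics for each matrix in watches for each
+ * iteration
+ * @param obj customized objective (set to null if not used)
+ * @param eval customized evaluation (set to null if not used)
+ * @return trained booster
+ * @throws XGBoostError native error
+ */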
+ public static Booster train(
+ DMatrix dtrain,
+ Map<String, Object> params,
+ int round,
+ Map<String, DMatrix> watches,
+ float[][] metrics,
+ IObjective obj,
+ IEvaluation eval) throws XGBoostError {
//collect eval matrixs
String[] evalNames;
@@ -94,7 +92,7 @@ public class XGBoost {
//collect all data matrixs
DMatrix[] allMats;
- if (evalMats != null && evalMats.length > 0) {
+ if (evalMats.length > 0) {
allMats = new DMatrix[evalMats.length + 1];
allMats[0] = dtrain;
System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
@@ -121,12 +119,20 @@ public class XGBoost {
}
//evaluation
- if (evalMats != null && evalMats.length > 0) {
+ if (evalMats.length > 0) {
String evalInfo;
if (eval != null) {
evalInfo = booster.evalSet(evalMats, evalNames, eval);
} else {
- evalInfo = booster.evalSet(evalMats, evalNames, iter);
+ if (metrics == null) {
+ evalInfo = booster.evalSet(evalMats, evalNames, iter);
+ } else {
+ float[] m = new float[evalMats.length];
+ evalInfo = booster.evalSet(evalMats, evalNames, iter, m);
+ for (int i = 0; i < m.length; i++) {
+ metrics[i][iter] = m[i];
+ }
+ }
}
if (Rabit.getRank() == 0) {
Rabit.trackerPrint(evalInfo + '\n');
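
When a non-null metrics array is passed, the loop above fills it in place: row i corresponds to the i-th evaluation matrix and column iter to the boosting round, so callers are expected to pre-allocate it as watches.size x round. A sketch of that caller-side contract (names and values here are placeholders):

    // Caller-side allocation matching the copy loop above.
    val round = 10
    val watchNames = Array("train", "test")               // same order as the watches map
    val metrics = Array.ofDim[Float](watchNames.length, round)
    // after training, metrics(1)(round - 1) would hold the final-round metric on "test"

Passing null for metrics keeps the previous behaviour: the evaluation string is only printed through the Rabit tracker.
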
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
index cb842af72..48d57af17 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
@@ -25,6 +25,41 @@ import scala.collection.JavaConverters._
* XGBoost Scala Training function.
*/
object XGBoost {
+
+ /**
+ * Train a booster given parameters.
+ *
+ * @param dtrain Data to be trained.
+ * @param params Parameters.
+ * @param round Number of boosting iterations.
+ * @param watches a group of items to be evaluated during training; this allows the user to
+ * watch performance on the validation set.
+ * @param metrics array containing the evaluation metrics for each matrix in watches for each
+ * iteration
+ * @param obj customized objective
+ * @param eval customized evaluation
+ * @return The trained booster.
+ */
+ @throws(classOf[XGBoostError])
+ def train(
+ dtrain: DMatrix,
+ params: Map[String, Any],
+ round: Int,
+ watches: Map[String, DMatrix],
+ metrics: Array[Array[Float]],
+ obj: ObjectiveTrait,
+ eval: EvalTrait): Booster = {
+ val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
+ val xgboostInJava = JXGBoost.train(
+ dtrain.jDMatrix,
+ // we have to filter null value for customized obj and eval
+ params.filter(_._2 != null).map{
+ case (key: String, value) => (key, value.toString)
+ }.toMap[String, AnyRef].asJava,
+ round, jWatches.asJava, metrics, obj, eval)
+ new Booster(xgboostInJava)
+ }
+
/**
* Train a booster given parameters.
*
@@ -45,16 +80,7 @@ object XGBoost {
watches: Map[String, DMatrix] = Map[String, DMatrix](),
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Booster = {
- val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
- val xgboostInJava = JXGBoost.train(
- dtrain.jDMatrix,
- // we have to filter null value for customized obj and eval
- params.filter(_._2 != null).map{
- case (key: String, value) => (key, value.toString)
- }.toMap[String, AnyRef].asJava,
- round, jWatches.asJava,
- obj, eval)
- new Booster(xgboostInJava)
+ train(dtrain, params, round, watches, null, obj, eval)
}
/**
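
Putting the Scala wrapper together, here is a hedged end-to-end sketch using the demo data paths that the test suites below also use; the comment about the watch ordering is an assumption about insertion order of the small Map.

    import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}

    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat  = new DMatrix("../../demo/data/agaricus.txt.test")
    val params: Map[String, Any] = Map(
      "objective"   -> "binary:logistic",
      "tree_method" -> "hist",
      "grow_policy" -> "depthwise",
      "max_depth"   -> "3",
      "eval_metric" -> "auc")
    val watches = Map("train" -> trainMat, "test" -> testMat)

    // one row per watch, one column per boosting round, filled during training
    val metrics = Array.fill(watches.size, 10)(0.0f)
    val booster = XGBoost.train(trainMat, params, 10, watches, metrics, null, null)
    println(s"final test auc = ${metrics(1)(9)}")   // assuming "test" stays the second entry
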
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
index d8cb1d505..c60b2406a 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
@@ -26,7 +26,6 @@ import java.util.HashMap;
import java.util.Map;
import junit.framework.TestCase;
-import ml.dmlc.xgboost4j.java.*;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Test;
@@ -151,6 +150,130 @@ public class BoosterImplTest {
TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f);
}
+ private void testWithFastHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
+ Map<String, Object> paramMap, float threshold) throws XGBoostError {
+ float[][] metrics = new float[watches.size()][round];
+ Booster booster = XGBoost.train(trainingSet, paramMap, round, watches,
+ metrics, null, null);
+ for (int i = 0; i < metrics.length; i++) {
+   for (int j = 1; j < metrics[i].length; j++) {
+     TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1]);
+   }
+ }
+ for (int i = 0; i < metrics.length; i++) {
+   for (int j = 0; j < metrics[i].length; j++) {
+     TestCase.assertTrue(metrics[i][j] >= threshold);
+   }
+ }
+ booster.dispose();
+ }
+
+ @Test
+ public void testFastHistoDepthWise() throws XGBoostError {
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ Map<String, Object> paramMap = new HashMap<String, Object>() {
+ {
+ put("max_depth", 3);
+ put("silent", 1);
+ put("objective", "binary:logistic");
+ put("tree_method", "hist");
+ put("grow_policy", "depthwise");
+ put("eval_metric", "auc");
+ }
+ };
+ Map<String, DMatrix> watches = new HashMap<>();
+ watches.put("training", trainMat);
+ watches.put("test", testMat);
+ testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f);
+ }
+
+ @Test
+ public void testFastHistoLossGuide() throws XGBoostError {
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ Map<String, Object> paramMap = new HashMap<String, Object>() {
+ {
+ put("max_depth", 0);
+ put("silent", 1);
+ put("objective", "binary:logistic");
+ put("tree_method", "hist");
+ put("grow_policy", "lossguide");
+ put("max_leaves", 8);
+ put("eval_metric", "auc");
+ }
+ };
+ Map<String, DMatrix> watches = new HashMap<>();
+ watches.put("training", trainMat);
+ watches.put("test", testMat);
+ testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f);
+ }
+
+ @Test
+ public void testFastHistoLossGuideMaxBin() throws XGBoostError {
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ Map<String, Object> paramMap = new HashMap<String, Object>() {
+ {
+ put("max_depth", 0);
+ put("silent", 1);
+ put("objective", "binary:logistic");
+ put("tree_method", "hist");
+ put("grow_policy", "lossguide");
+ put("max_leaves", 8);
+ put("max_bins", 16);
+ put("eval_metric", "auc");
+ }
+ };
+ Map<String, DMatrix> watches = new HashMap<>();
+ watches.put("training", trainMat);
+ testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f);
+ }
+
+ @Test
+ public void testFastHistoDepthwiseMaxDepth() throws XGBoostError {
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ Map<String, Object> paramMap = new HashMap<String, Object>() {
+ {
+ put("silent", 1);
+ put("objective", "binary:logistic");
+ put("tree_method", "hist");
+ put("max_depth", 2);
+ put("grow_policy", "depthwise");
+ put("eval_metric", "auc");
+ }
+ };
+ Map<String, DMatrix> watches = new HashMap<>();
+ watches.put("training", trainMat);
+ testWithFastHisto(trainMat, watches, 10, paramMap, 0.85f);
+ }
+
+ @Test
+ public void testFastHistoDepthwiseMaxDepthMaxBin() throws XGBoostError {
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ Map<String, Object> paramMap = new HashMap<String, Object>() {
+ {
+ put("silent", 1);
+ put("objective", "binary:logistic");
+ put("tree_method", "hist");
+ put("max_depth", 2);
+ put("max_bin", 2);
+ put("grow_policy", "depthwise");
+ put("eval_metric", "auc");
+ }
+ };
+ Map<String, DMatrix> watches = new HashMap<>();
+ watches.put("training", trainMat);
+ testWithFastHisto(trainMat, watches, 10, paramMap, 0.85f);
+ }
+
/**
* test cross valiation
*
diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
index 147a486c5..fc4badc7b 100644
--- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
+++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
@@ -77,6 +77,23 @@ class ScalaBoosterImplSuite extends FunSuite {
XGBoost.train(trainMat, paramMap, round, watches, null, null)
}
+ private def trainBoosterWithFastHisto(
+ trainMat: DMatrix,
+ watches: Map[String, DMatrix],
+ round: Int,
+ paramMap: Map[String, String],
+ threshold: Float): Booster = {
+ val metrics = Array.fill(watches.size, round)(0.0f)
+ val booster = XGBoost.train(trainMat, paramMap, round, watches, metrics, null, null)
+ for (i <- 0 until watches.size; j <- 1 until metrics(i).length) {
+ assert(metrics(i)(j) >= metrics(i)(j - 1))
+ }
+ for (metricsArray <- metrics; m <- metricsArray) {
+ assert(m >= threshold)
+ }
+ booster
+ }
+
test("basic operation of booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
@@ -128,4 +145,57 @@ class ScalaBoosterImplSuite extends FunSuite {
val nfold = 5
XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null)
}
+
+ test("test with fast histo depthwise") {
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val paramMap = List("max_depth" -> "3", "silent" -> "0",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "depthwise", "eval_metric" -> "auc").toMap
+ trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat, "test" -> testMat),
+ round = 10, paramMap, 0.0f)
+ }
+
+ test("test with fast histo lossguide") {
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val paramMap = List("max_depth" -> "0", "silent" -> "0",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "auc").toMap
+ trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat, "test" -> testMat),
+ round = 10, paramMap, 0.0f)
+ }
+
+ test("test with fast histo lossguide with max bin") {
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val paramMap = List("max_depth" -> "0", "silent" -> "0",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
+ "eval_metric" -> "auc").toMap
+ trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat),
+ round = 10, paramMap, 0.0f)
+ }
+
+ test("test with fast histo depthwidth with max depth") {
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val paramMap = List("max_depth" -> "0", "silent" -> "0",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
+ "eval_metric" -> "auc").toMap
+ trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat),
+ round = 10, paramMap, 0.85f)
+ }
+
+ test("test with fast histo depthwidth with max depth and max bin") {
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val paramMap = List("max_depth" -> "0", "silent" -> "0",
+ "objective" -> "binary:logistic", "tree_method" -> "hist",
+ "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
+ "eval_metric" -> "auc").toMap
+ trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat),
+ round = 10, paramMap, 0.85f)
+ }
}