[jvm-packages] Scala/Java interface for Fast Histogram Algorithm (#1966)
* add back train method but mark as deprecated * fix scalastyle error * first commit in scala binding for fast histo * java test * add missed scala tests * spark training * add back train method but mark as deprecated * fix scalastyle error * local change * first commit in scala binding for fast histo * local change * fix df frame test
This commit is contained in:
parent
ac30a0aff5
commit
ab13fd72bd
1
.gitignore
vendored
1
.gitignore
vendored
@ -88,3 +88,4 @@ build_tests
|
|||||||
/tests/cpp/xgboost_test
|
/tests/cpp/xgboost_test
|
||||||
|
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
lib/
|
||||||
@ -199,7 +199,7 @@
|
|||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<version>2.19.1</version>
|
<version>2.19.1</version>
|
||||||
<configuration>
|
<configuration>
|
||||||
<skipTests>true</skipTests>
|
<skipTests>false</skipTests>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
|
|||||||
@ -126,9 +126,22 @@ trait BoosterParams extends Params {
|
|||||||
* [default='auto']
|
* [default='auto']
|
||||||
*/
|
*/
|
||||||
val treeMethod = new Param[String](this, "tree_method",
|
val treeMethod = new Param[String](this, "tree_method",
|
||||||
"The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx'}",
|
"The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}",
|
||||||
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* growth policy for fast histogram algorithm
|
||||||
|
*/
|
||||||
|
val growthPolicty = new Param[String](this, "grow_policy",
|
||||||
|
"growth policy for fast histogram algorithm",
|
||||||
|
(value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* maximum number of bins in histogram
|
||||||
|
*/
|
||||||
|
val maxBins = new IntParam(this, "max_bin", "maximum number of bins in histogram",
|
||||||
|
(value: Int) => value > 0)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is only used for approximate greedy algorithm.
|
* This is only used for approximate greedy algorithm.
|
||||||
* This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select
|
* This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select
|
||||||
@ -194,6 +207,7 @@ trait BoosterParams extends Params {
|
|||||||
|
|
||||||
setDefault(boosterType -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
setDefault(boosterType -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
||||||
minChildWeight -> 1, maxDeltaStep -> 0,
|
minChildWeight -> 1, maxDeltaStep -> 0,
|
||||||
|
growthPolicty -> "depthwise", maxBins -> 16,
|
||||||
subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1,
|
subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1,
|
||||||
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
|
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
|
||||||
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
|
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
|
||||||
@ -227,7 +241,9 @@ private[spark] object BoosterParams {
|
|||||||
|
|
||||||
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
|
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
|
||||||
|
|
||||||
val supportedTreeMethods = HashSet("auto", "exact", "approx")
|
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist")
|
||||||
|
|
||||||
|
val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
|
||||||
|
|
||||||
val supportedSampleType = HashSet("uniform", "weighted")
|
val supportedSampleType = HashSet("uniform", "weighted")
|
||||||
|
|
||||||
|
|||||||
@ -190,6 +190,22 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
|||||||
assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
|
assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("fast histogram algorithm parameters are exposed correctly") {
|
||||||
|
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
|
||||||
|
"eval_metric" -> "error")
|
||||||
|
val testItr = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||||
|
val trainingDF = buildTrainingDataframe()
|
||||||
|
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||||
|
round = 10, nWorkers = math.min(2, numWorkers))
|
||||||
|
val error = new EvalError
|
||||||
|
import DataUtils._
|
||||||
|
val testSetDMatrix = new DMatrix(new JDMatrix(testItr, null))
|
||||||
|
assert(error.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
|
testSetDMatrix) < 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
private def convertCSVPointToLabelPoint(valueArray: Array[String]): LabeledPoint = {
|
private def convertCSVPointToLabelPoint(valueArray: Array[String]): LabeledPoint = {
|
||||||
val intValueArray = new Array[Double](valueArray.length)
|
val intValueArray = new Array[Double](valueArray.length)
|
||||||
intValueArray(valueArray.length - 2) = {
|
intValueArray(valueArray.length - 2) = {
|
||||||
|
|||||||
@ -111,11 +111,94 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
"objective" -> "binary:logistic",
|
"objective" -> "binary:logistic",
|
||||||
"tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
|
"tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
|
||||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||||
nWorkers = numWorkers, useExternalMemory = true)
|
nWorkers = numWorkers)
|
||||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
testSetDMatrix) < 0.1)
|
testSetDMatrix) < 0.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("test with fast histo depthwise") {
|
||||||
|
val eval = new EvalError()
|
||||||
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
|
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||||
|
import DataUtils._
|
||||||
|
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||||
|
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "depthwise", "eval_metric" -> "error")
|
||||||
|
// TODO: histogram algorithm seems to be very very sensitive to worker number
|
||||||
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||||
|
nWorkers = math.min(numWorkers, 2))
|
||||||
|
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
|
testSetDMatrix) < 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test with fast histo lossguide") {
|
||||||
|
val eval = new EvalError()
|
||||||
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
|
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||||
|
import DataUtils._
|
||||||
|
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||||
|
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "error")
|
||||||
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||||
|
nWorkers = math.min(numWorkers, 2))
|
||||||
|
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
|
testSetDMatrix)
|
||||||
|
assert(x < 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test with fast histo lossguide with max bin") {
|
||||||
|
val eval = new EvalError()
|
||||||
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
|
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||||
|
import DataUtils._
|
||||||
|
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||||
|
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
|
||||||
|
"eval_metric" -> "error")
|
||||||
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||||
|
nWorkers = math.min(numWorkers, 2))
|
||||||
|
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
|
testSetDMatrix)
|
||||||
|
assert(x < 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test with fast histo depthwidth with max depth") {
|
||||||
|
val eval = new EvalError()
|
||||||
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
|
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||||
|
import DataUtils._
|
||||||
|
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||||
|
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
|
||||||
|
"eval_metric" -> "error")
|
||||||
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
|
||||||
|
nWorkers = math.min(numWorkers, 2))
|
||||||
|
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
|
testSetDMatrix)
|
||||||
|
assert(x < 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test with fast histo depthwidth with max depth and max bin") {
|
||||||
|
val eval = new EvalError()
|
||||||
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
|
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||||
|
import DataUtils._
|
||||||
|
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||||
|
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
|
||||||
|
"eval_metric" -> "error")
|
||||||
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
|
||||||
|
nWorkers = math.min(numWorkers, 2))
|
||||||
|
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
|
testSetDMatrix)
|
||||||
|
assert(x < 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
test("test with dense vectors containing missing value") {
|
test("test with dense vectors containing missing value") {
|
||||||
def buildDenseRDD(): RDD[LabeledPoint] = {
|
def buildDenseRDD(): RDD[LabeledPoint] = {
|
||||||
val nrow = 100
|
val nrow = 100
|
||||||
@ -142,6 +225,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
}
|
}
|
||||||
sc.parallelize(points)
|
sc.parallelize(points)
|
||||||
}
|
}
|
||||||
|
|
||||||
val trainingRDD = buildDenseRDD().repartition(4)
|
val trainingRDD = buildDenseRDD().repartition(4)
|
||||||
val testRDD = buildDenseRDD().repartition(4)
|
val testRDD = buildDenseRDD().repartition(4)
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
@ -189,6 +273,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
val sampleList = new ListBuffer[SparkVector]
|
val sampleList = new ListBuffer[SparkVector]
|
||||||
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
|
sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
|
||||||
}
|
}
|
||||||
|
|
||||||
val trainingRDD = buildTrainingRDD(sc)
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
val testRDD = buildEmptyRDD()
|
val testRDD = buildEmptyRDD()
|
||||||
val tempDir = Files.createTempDirectory("xgboosttest-")
|
val tempDir = Files.createTempDirectory("xgboosttest-")
|
||||||
|
|||||||
@ -180,6 +180,26 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
return evalInfo[0];
|
return evalInfo[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* evaluate with given dmatrixs.
|
||||||
|
*
|
||||||
|
* @param evalMatrixs dmatrixs for evaluation
|
||||||
|
* @param evalNames name for eval dmatrixs, used for check results
|
||||||
|
* @param iter current eval iteration
|
||||||
|
* @param metricsOut output array containing the evaluation metrics for each evalMatrix
|
||||||
|
* @return eval information
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, float[] metricsOut)
|
||||||
|
throws XGBoostError {
|
||||||
|
String stringFormat = evalSet(evalMatrixs, evalNames, iter);
|
||||||
|
String[] metricPairs = stringFormat.split("\t");
|
||||||
|
for (int i = 1; i < metricPairs.length; i++) {
|
||||||
|
metricsOut[i - 1] = Float.valueOf(metricPairs[i].split(":")[1]);
|
||||||
|
}
|
||||||
|
return stringFormat;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* evaluate with given customized Evaluation class
|
* evaluate with given customized Evaluation class
|
||||||
*
|
*
|
||||||
|
|||||||
@ -57,19 +57,6 @@ public class XGBoost {
|
|||||||
return Booster.loadModel(in);
|
return Booster.loadModel(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Train a booster with given parameters.
|
|
||||||
*
|
|
||||||
* @param dtrain Data to be trained.
|
|
||||||
* @param params Booster params.
|
|
||||||
* @param round Number of boosting iterations.
|
|
||||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
|
||||||
* performance on the validation set.
|
|
||||||
* @param obj customized objective (set to null if not used)
|
|
||||||
* @param eval customized evaluation (set to null if not used)
|
|
||||||
* @return trained booster
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public static Booster train(
|
public static Booster train(
|
||||||
DMatrix dtrain,
|
DMatrix dtrain,
|
||||||
Map<String, Object> params,
|
Map<String, Object> params,
|
||||||
@ -77,6 +64,17 @@ public class XGBoost {
|
|||||||
Map<String, DMatrix> watches,
|
Map<String, DMatrix> watches,
|
||||||
IObjective obj,
|
IObjective obj,
|
||||||
IEvaluation eval) throws XGBoostError {
|
IEvaluation eval) throws XGBoostError {
|
||||||
|
return train(dtrain, params, round, watches, null, obj, eval);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Booster train(
|
||||||
|
DMatrix dtrain,
|
||||||
|
Map<String, Object> params,
|
||||||
|
int round,
|
||||||
|
Map<String, DMatrix> watches,
|
||||||
|
float[][] metrics,
|
||||||
|
IObjective obj,
|
||||||
|
IEvaluation eval) throws XGBoostError {
|
||||||
|
|
||||||
//collect eval matrixs
|
//collect eval matrixs
|
||||||
String[] evalNames;
|
String[] evalNames;
|
||||||
@ -94,7 +92,7 @@ public class XGBoost {
|
|||||||
|
|
||||||
//collect all data matrixs
|
//collect all data matrixs
|
||||||
DMatrix[] allMats;
|
DMatrix[] allMats;
|
||||||
if (evalMats != null && evalMats.length > 0) {
|
if (evalMats.length > 0) {
|
||||||
allMats = new DMatrix[evalMats.length + 1];
|
allMats = new DMatrix[evalMats.length + 1];
|
||||||
allMats[0] = dtrain;
|
allMats[0] = dtrain;
|
||||||
System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
|
System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
|
||||||
@ -121,12 +119,20 @@ public class XGBoost {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//evaluation
|
//evaluation
|
||||||
if (evalMats != null && evalMats.length > 0) {
|
if (evalMats.length > 0) {
|
||||||
String evalInfo;
|
String evalInfo;
|
||||||
if (eval != null) {
|
if (eval != null) {
|
||||||
evalInfo = booster.evalSet(evalMats, evalNames, eval);
|
evalInfo = booster.evalSet(evalMats, evalNames, eval);
|
||||||
} else {
|
} else {
|
||||||
|
if (metrics == null) {
|
||||||
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
||||||
|
} else {
|
||||||
|
float[] m = new float[evalMats.length];
|
||||||
|
evalInfo = booster.evalSet(evalMats, evalNames, iter, m);
|
||||||
|
for (int i = 0; i < m.length; i++) {
|
||||||
|
metrics[i][iter] = m[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (Rabit.getRank() == 0) {
|
if (Rabit.getRank() == 0) {
|
||||||
Rabit.trackerPrint(evalInfo + '\n');
|
Rabit.trackerPrint(evalInfo + '\n');
|
||||||
|
|||||||
@ -25,6 +25,41 @@ import scala.collection.JavaConverters._
|
|||||||
* XGBoost Scala Training function.
|
* XGBoost Scala Training function.
|
||||||
*/
|
*/
|
||||||
object XGBoost {
|
object XGBoost {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Train a booster given parameters.
|
||||||
|
*
|
||||||
|
* @param dtrain Data to be trained.
|
||||||
|
* @param params Parameters.
|
||||||
|
* @param round Number of boosting iterations.
|
||||||
|
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||||
|
* performance on the validation set.
|
||||||
|
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||||
|
* iteration
|
||||||
|
* @param obj customized objective
|
||||||
|
* @param eval customized evaluation
|
||||||
|
* @return The trained booster.
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def train(
|
||||||
|
dtrain: DMatrix,
|
||||||
|
params: Map[String, Any],
|
||||||
|
round: Int,
|
||||||
|
watches: Map[String, DMatrix],
|
||||||
|
metrics: Array[Array[Float]],
|
||||||
|
obj: ObjectiveTrait,
|
||||||
|
eval: EvalTrait): Booster = {
|
||||||
|
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||||
|
val xgboostInJava = JXGBoost.train(
|
||||||
|
dtrain.jDMatrix,
|
||||||
|
// we have to filter null value for customized obj and eval
|
||||||
|
params.filter(_._2 != null).map{
|
||||||
|
case (key: String, value) => (key, value.toString)
|
||||||
|
}.toMap[String, AnyRef].asJava,
|
||||||
|
round, jWatches.asJava, metrics, obj, eval)
|
||||||
|
new Booster(xgboostInJava)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Train a booster given parameters.
|
* Train a booster given parameters.
|
||||||
*
|
*
|
||||||
@ -45,16 +80,7 @@ object XGBoost {
|
|||||||
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
||||||
obj: ObjectiveTrait = null,
|
obj: ObjectiveTrait = null,
|
||||||
eval: EvalTrait = null): Booster = {
|
eval: EvalTrait = null): Booster = {
|
||||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
train(dtrain, params, round, watches, null, obj, eval)
|
||||||
val xgboostInJava = JXGBoost.train(
|
|
||||||
dtrain.jDMatrix,
|
|
||||||
// we have to filter null value for customized obj and eval
|
|
||||||
params.filter(_._2 != null).map{
|
|
||||||
case (key: String, value) => (key, value.toString)
|
|
||||||
}.toMap[String, AnyRef].asJava,
|
|
||||||
round, jWatches.asJava,
|
|
||||||
obj, eval)
|
|
||||||
new Booster(xgboostInJava)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -26,7 +26,6 @@ import java.util.HashMap;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import junit.framework.TestCase;
|
import junit.framework.TestCase;
|
||||||
import ml.dmlc.xgboost4j.java.*;
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
@ -151,6 +150,130 @@ public class BoosterImplTest {
|
|||||||
TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f);
|
TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void testWithFastHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
|
||||||
|
Map<String, Object> paramMap, float threshold) throws XGBoostError {
|
||||||
|
float[][] metrics = new float[watches.size()][round];
|
||||||
|
Booster booster = XGBoost.train(trainingSet, paramMap, round, watches,
|
||||||
|
metrics, null, null);
|
||||||
|
for (int i = 0; i < metrics.length; i++)
|
||||||
|
for (int j = 1; j < metrics[i].length; j++) {
|
||||||
|
TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1]);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < metrics.length; i++)
|
||||||
|
for (int j = 0; j < metrics[i].length; j++) {
|
||||||
|
TestCase.assertTrue(metrics[i][j] >= threshold);
|
||||||
|
}
|
||||||
|
booster.dispose();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFastHistoDepthWise() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
// testBoosterWithFastHistogram(trainMat, testMat);
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 3);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
put("tree_method", "hist");
|
||||||
|
put("grow_policy", "depthwise");
|
||||||
|
put("eval_metric", "auc");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Map<String, DMatrix> watches = new HashMap<>();
|
||||||
|
watches.put("training", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
|
testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFastHistoLossGuide() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
// testBoosterWithFastHistogram(trainMat, testMat);
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 0);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
put("tree_method", "hist");
|
||||||
|
put("grow_policy", "lossguide");
|
||||||
|
put("max_leaves", 8);
|
||||||
|
put("eval_metric", "auc");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Map<String, DMatrix> watches = new HashMap<>();
|
||||||
|
watches.put("training", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
|
testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFastHistoLossGuideMaxBin() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
// testBoosterWithFastHistogram(trainMat, testMat);
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 0);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
put("tree_method", "hist");
|
||||||
|
put("grow_policy", "lossguide");
|
||||||
|
put("max_leaves", 8);
|
||||||
|
put("max_bins", 16);
|
||||||
|
put("eval_metric", "auc");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Map<String, DMatrix> watches = new HashMap<>();
|
||||||
|
watches.put("training", trainMat);
|
||||||
|
testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFastHistoDepthwiseMaxDepth() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
// testBoosterWithFastHistogram(trainMat, testMat);
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 3);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
put("tree_method", "hist");
|
||||||
|
put("max_depth", 2);
|
||||||
|
put("grow_policy", "depthwise");
|
||||||
|
put("eval_metric", "auc");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Map<String, DMatrix> watches = new HashMap<>();
|
||||||
|
watches.put("training", trainMat);
|
||||||
|
testWithFastHisto(trainMat, watches, 10, paramMap, 0.85f);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFastHistoDepthwiseMaxDepthMaxBin() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
// testBoosterWithFastHistogram(trainMat, testMat);
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 3);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
put("tree_method", "hist");
|
||||||
|
put("max_depth", 2);
|
||||||
|
put("max_bin", 2);
|
||||||
|
put("grow_policy", "depthwise");
|
||||||
|
put("eval_metric", "auc");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Map<String, DMatrix> watches = new HashMap<>();
|
||||||
|
watches.put("training", trainMat);
|
||||||
|
testWithFastHisto(trainMat, watches, 10, paramMap, 0.85f);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* test cross valiation
|
* test cross valiation
|
||||||
*
|
*
|
||||||
|
|||||||
@ -77,6 +77,23 @@ class ScalaBoosterImplSuite extends FunSuite {
|
|||||||
XGBoost.train(trainMat, paramMap, round, watches, null, null)
|
XGBoost.train(trainMat, paramMap, round, watches, null, null)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private def trainBoosterWithFastHisto(
|
||||||
|
trainMat: DMatrix,
|
||||||
|
watches: Map[String, DMatrix],
|
||||||
|
round: Int,
|
||||||
|
paramMap: Map[String, String],
|
||||||
|
threshold: Float): Booster = {
|
||||||
|
val metrics = Array.fill(watches.size, round)(0.0f)
|
||||||
|
val booster = XGBoost.train(trainMat, paramMap, round, watches, metrics, null, null)
|
||||||
|
for (i <- 0 until watches.size; j <- 1 until metrics(i).length) {
|
||||||
|
assert(metrics(i)(j) >= metrics(i)(j - 1))
|
||||||
|
}
|
||||||
|
for (metricsArray <- metrics; m <- metricsArray) {
|
||||||
|
assert(m >= threshold)
|
||||||
|
}
|
||||||
|
booster
|
||||||
|
}
|
||||||
|
|
||||||
test("basic operation of booster") {
|
test("basic operation of booster") {
|
||||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
@ -128,4 +145,57 @@ class ScalaBoosterImplSuite extends FunSuite {
|
|||||||
val nfold = 5
|
val nfold = 5
|
||||||
XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null)
|
XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("test with fast histo depthwise") {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
val paramMap = List("max_depth" -> "3", "silent" -> "0",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "depthwise", "eval_metric" -> "auc").toMap
|
||||||
|
trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat, "test" -> testMat),
|
||||||
|
round = 10, paramMap, 0.0f)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test with fast histo lossguide") {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
val paramMap = List("max_depth" -> "0", "silent" -> "0",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "auc").toMap
|
||||||
|
trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat, "test" -> testMat),
|
||||||
|
round = 10, paramMap, 0.0f)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test with fast histo lossguide with max bin") {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
val paramMap = List("max_depth" -> "0", "silent" -> "0",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
|
||||||
|
"eval_metric" -> "auc").toMap
|
||||||
|
trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat),
|
||||||
|
round = 10, paramMap, 0.0f)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test with fast histo depthwidth with max depth") {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
val paramMap = List("max_depth" -> "0", "silent" -> "0",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
|
||||||
|
"eval_metric" -> "auc").toMap
|
||||||
|
trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat),
|
||||||
|
round = 10, paramMap, 0.85f)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test with fast histo depthwidth with max depth and max bin") {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
val paramMap = List("max_depth" -> "0", "silent" -> "0",
|
||||||
|
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||||
|
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
|
||||||
|
"eval_metric" -> "auc").toMap
|
||||||
|
trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat),
|
||||||
|
round = 10, paramMap, 0.85f)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user