[jvm-packages] Scala/Java interface for Fast Histogram Algorithm (#1966)

* add back train method but mark as deprecated

* fix scalastyle error

* first commit in scala binding for fast histo

* java test

* add missed scala tests

* spark training

* add back train method but mark as deprecated

* fix scalastyle error

* local change

* first commit in scala binding for fast histo

* local change

* fix df frame test
This commit is contained in:
Nan Zhu
2017-03-04 15:37:24 -08:00
committed by GitHub
parent ac30a0aff5
commit ab13fd72bd
10 changed files with 400 additions and 37 deletions

View File

@@ -126,9 +126,22 @@ trait BoosterParams extends Params {
* [default='auto']
*/
val treeMethod = new Param[String](this, "tree_method",
"The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx'}",
"The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}",
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
/**
* growth policy for fast histogram algorithm
*/
val growthPolicty = new Param[String](this, "grow_policy",
"growth policy for fast histogram algorithm",
(value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
/**
* maximum number of bins in histogram
*/
val maxBins = new IntParam(this, "max_bin", "maximum number of bins in histogram",
(value: Int) => value > 0)
/**
* This is only used for approximate greedy algorithm.
* This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select
@@ -194,6 +207,7 @@ trait BoosterParams extends Params {
setDefault(boosterType -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
minChildWeight -> 1, maxDeltaStep -> 0,
growthPolicty -> "depthwise", maxBins -> 16,
subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1,
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
@@ -227,7 +241,9 @@ private[spark] object BoosterParams {
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
val supportedTreeMethods = HashSet("auto", "exact", "approx")
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist")
val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
val supportedSampleType = HashSet("uniform", "weighted")