Distributed Fast Histogram Algorithm (#4011)
* add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * init * allow hist algo * more changes * temp * update * remove hist sync * update rabit * change hist size * change the histogram * update kfactor * sync per node stats * temp * update * final * code clean * update rabit * more cleanup * fix errors * fix failed tests * enforce c++11 * fix lint issue * broadcast subsampled feature correctly * revert some changes * fix lint issue * enable monotone and interaction constraints * don't specify default for monotone and interactions * update docs
This commit is contained in:
@@ -263,8 +263,10 @@ object XGBoost extends Serializable {
|
||||
validateSparkSslConf(sparkContext)
|
||||
|
||||
if (params.contains("tree_method")) {
|
||||
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
|
||||
" for now")
|
||||
require(params("tree_method") == "hist" ||
|
||||
params("tree_method") == "approx" ||
|
||||
params("tree_method") == "auto", "xgboost4j-spark only supports tree_method as 'hist'," +
|
||||
" 'approx' and 'auto'")
|
||||
}
|
||||
if (params.contains("train_test_ratio")) {
|
||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||
|
||||
@@ -50,10 +50,21 @@ private[spark] trait BoosterParams extends Params {
|
||||
* overfitting. [default=6] range: [1, Int.MaxValue]
|
||||
*/
|
||||
final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree, increase this " +
|
||||
"value will make model more complex/likely to be overfitting.", (value: Int) => value >= 1)
|
||||
"value will make model more complex/likely to be overfitting.", (value: Int) => value >= 0)
|
||||
|
||||
final def getMaxDepth: Int = $(maxDepth)
|
||||
|
||||
|
||||
/**
|
||||
* Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.
|
||||
*/
|
||||
final val maxLeaves = new IntParam(this, "maxLeaves",
|
||||
"Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.",
|
||||
(value: Int) => value >= 0)
|
||||
|
||||
final def getMaxLeaves: Int = $(maxDepth)
|
||||
|
||||
|
||||
/**
|
||||
* minimum sum of instance weight(hessian) needed in a child. If the tree partition step results
|
||||
* in a leaf node with the sum of instance weight less than min_child_weight, then the building
|
||||
@@ -147,7 +158,9 @@ private[spark] trait BoosterParams extends Params {
|
||||
* growth policy for fast histogram algorithm
|
||||
*/
|
||||
final val growPolicy = new Param[String](this, "growPolicy",
|
||||
"growth policy for fast histogram algorithm",
|
||||
"Controls a way new nodes are added to the tree. Currently supported only if" +
|
||||
" tree_method is set to hist. Choices: depthwise, lossguide. depthwise: split at nodes" +
|
||||
" closest to the root. lossguide: split at nodes with highest loss change.",
|
||||
(value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
|
||||
|
||||
final def getGrowPolicy: String = $(growPolicy)
|
||||
@@ -242,6 +255,22 @@ private[spark] trait BoosterParams extends Params {
|
||||
|
||||
final def getTreeLimit: Int = $(treeLimit)
|
||||
|
||||
final val monotoneConstraints = new Param[String](this, name = "monotoneConstraints",
|
||||
doc = "a list in length of number of features, 1 indicate monotonic increasing, - 1 means " +
|
||||
"decreasing, 0 means no constraint. If it is shorter than number of features, 0 will be " +
|
||||
"padded ")
|
||||
|
||||
final def getMonotoneConstraints: String = $(monotoneConstraints)
|
||||
|
||||
final val interactionConstraints = new Param[String](this,
|
||||
name = "interactionConstraints",
|
||||
doc = "Constraints for interaction representing permitted interactions. The constraints" +
|
||||
" must be specified in the form of a nest list, e.g. [[0, 1], [2, 3, 4]]," +
|
||||
" where each inner list is a group of indices of features that are allowed to interact" +
|
||||
" with each other. See tutorial for more information")
|
||||
|
||||
final def getInteractionConstraints: String = $(interactionConstraints)
|
||||
|
||||
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
||||
minChildWeight -> 1, maxDeltaStep -> 0,
|
||||
growPolicy -> "depthwise", maxBins -> 16,
|
||||
|
||||
@@ -231,10 +231,11 @@ private[spark] trait ParamMapFuncs extends Params {
|
||||
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
|
||||
for ((paramName, paramValue) <- xgboostParams) {
|
||||
if ((paramName == "booster" && paramValue != "gbtree") ||
|
||||
(paramName == "updater" && paramValue != "grow_histmaker,prune")) {
|
||||
(paramName == "updater" && (paramValue != "grow_histmaker,prune" ||
|
||||
paramValue != "hist"))) {
|
||||
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
|
||||
s" XGBoost-Spark only supports gbtree as booster type" +
|
||||
" and grow_histmaker,prune as the updater type")
|
||||
" and grow_histmaker,prune or hist as the updater type")
|
||||
}
|
||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||
params.find(_.name == name) match {
|
||||
|
||||
Reference in New Issue
Block a user