[jvm-packages] remove default parameters (#7938)
This commit is contained in:
parent
47224dd6d3
commit
fbc3d861bb
@ -166,7 +166,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
def buildXGBRuntimeParams: XGBoostExecutionParams = {
|
def buildXGBRuntimeParams: XGBoostExecutionParams = {
|
||||||
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
||||||
val round = overridedParams("num_round").asInstanceOf[Int]
|
val round = overridedParams("num_round").asInstanceOf[Int]
|
||||||
val useExternalMemory = overridedParams("use_external_memory").asInstanceOf[Boolean]
|
val useExternalMemory = overridedParams
|
||||||
|
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
|
||||||
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
|
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
|
||||||
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
|
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
|
||||||
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
|
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014,2021 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -261,6 +261,7 @@ private[spark] trait BoosterParams extends Params {
|
|||||||
|
|
||||||
final val treeLimit = new IntParam(this, name = "treeLimit",
|
final val treeLimit = new IntParam(this, name = "treeLimit",
|
||||||
doc = "number of trees used in the prediction; defaults to 0 (use all trees).")
|
doc = "number of trees used in the prediction; defaults to 0 (use all trees).")
|
||||||
|
setDefault(treeLimit, 0)
|
||||||
|
|
||||||
final def getTreeLimit: Int = $(treeLimit)
|
final def getTreeLimit: Int = $(treeLimit)
|
||||||
|
|
||||||
@ -280,13 +281,6 @@ private[spark] trait BoosterParams extends Params {
|
|||||||
|
|
||||||
final def getInteractionConstraints: String = $(interactionConstraints)
|
final def getInteractionConstraints: String = $(interactionConstraints)
|
||||||
|
|
||||||
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
|
||||||
minChildWeight -> 1, maxDeltaStep -> 0,
|
|
||||||
growPolicy -> "depthwise", maxBins -> 256,
|
|
||||||
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
|
|
||||||
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
|
|
||||||
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
|
|
||||||
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private[scala] object BoosterParams {
|
private[scala] object BoosterParams {
|
||||||
|
|||||||
@ -29,6 +29,7 @@ private[spark] trait GeneralParams extends Params {
|
|||||||
*/
|
*/
|
||||||
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
|
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
|
||||||
ParamValidators.gtEq(1))
|
ParamValidators.gtEq(1))
|
||||||
|
setDefault(numRound, 1)
|
||||||
|
|
||||||
final def getNumRound: Int = $(numRound)
|
final def getNumRound: Int = $(numRound)
|
||||||
|
|
||||||
@ -37,6 +38,7 @@ private[spark] trait GeneralParams extends Params {
|
|||||||
*/
|
*/
|
||||||
final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
|
final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
|
||||||
ParamValidators.gtEq(1))
|
ParamValidators.gtEq(1))
|
||||||
|
setDefault(numWorkers, 1)
|
||||||
|
|
||||||
final def getNumWorkers: Int = $(numWorkers)
|
final def getNumWorkers: Int = $(numWorkers)
|
||||||
|
|
||||||
@ -45,6 +47,7 @@ private[spark] trait GeneralParams extends Params {
|
|||||||
*/
|
*/
|
||||||
final val nthread = new IntParam(this, "nthread", "number of threads used by per worker",
|
final val nthread = new IntParam(this, "nthread", "number of threads used by per worker",
|
||||||
ParamValidators.gtEq(1))
|
ParamValidators.gtEq(1))
|
||||||
|
setDefault(nthread, 1)
|
||||||
|
|
||||||
final def getNthread: Int = $(nthread)
|
final def getNthread: Int = $(nthread)
|
||||||
|
|
||||||
@ -53,6 +56,7 @@ private[spark] trait GeneralParams extends Params {
|
|||||||
*/
|
*/
|
||||||
final val useExternalMemory = new BooleanParam(this, "useExternalMemory",
|
final val useExternalMemory = new BooleanParam(this, "useExternalMemory",
|
||||||
"whether to use external memory as cache")
|
"whether to use external memory as cache")
|
||||||
|
setDefault(useExternalMemory, false)
|
||||||
|
|
||||||
final def getUseExternalMemory: Boolean = $(useExternalMemory)
|
final def getUseExternalMemory: Boolean = $(useExternalMemory)
|
||||||
|
|
||||||
@ -94,6 +98,7 @@ private[spark] trait GeneralParams extends Params {
|
|||||||
* the value treated as missing. default: Float.NaN
|
* the value treated as missing. default: Float.NaN
|
||||||
*/
|
*/
|
||||||
final val missing = new FloatParam(this, "missing", "the value treated as missing")
|
final val missing = new FloatParam(this, "missing", "the value treated as missing")
|
||||||
|
setDefault(missing, Float.NaN)
|
||||||
|
|
||||||
final def getMissing: Float = $(missing)
|
final def getMissing: Float = $(missing)
|
||||||
|
|
||||||
@ -109,6 +114,7 @@ private[spark] trait GeneralParams extends Params {
|
|||||||
"not use Spark's VectorAssembler class to construct the feature vector " +
|
"not use Spark's VectorAssembler class to construct the feature vector " +
|
||||||
"but instead used a method that preserves zeros in your vector."
|
"but instead used a method that preserves zeros in your vector."
|
||||||
)
|
)
|
||||||
|
setDefault(allowNonZeroForMissing, false)
|
||||||
|
|
||||||
final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissing)
|
final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissing)
|
||||||
|
|
||||||
@ -163,18 +169,14 @@ private[spark] trait GeneralParams extends Params {
|
|||||||
* Ignored if the tracker implementation is "python".
|
* Ignored if the tracker implementation is "python".
|
||||||
*/
|
*/
|
||||||
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
|
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
|
||||||
|
setDefault(trackerConf, TrackerConf())
|
||||||
|
|
||||||
/** Random seed for the C++ part of XGBoost and train/test splitting. */
|
/** Random seed for the C++ part of XGBoost and train/test splitting. */
|
||||||
final val seed = new LongParam(this, "seed", "random seed")
|
final val seed = new LongParam(this, "seed", "random seed")
|
||||||
|
setDefault(seed, 0L)
|
||||||
|
|
||||||
final def getSeed: Long = $(seed)
|
final def getSeed: Long = $(seed)
|
||||||
|
|
||||||
setDefault(numRound -> 1, numWorkers -> 1, nthread -> 1,
|
|
||||||
useExternalMemory -> false, silent -> 0, verbosity -> 1,
|
|
||||||
customObj -> null, customEval -> null, missing -> Float.NaN,
|
|
||||||
trackerConf -> TrackerConf(), seed -> 0,
|
|
||||||
checkpointPath -> "", checkpointInterval -> -1,
|
|
||||||
allowNonZeroForMissing -> false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
trait HasLeafPredictionCol extends Params {
|
trait HasLeafPredictionCol extends Params {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -26,7 +26,7 @@ private[spark] trait InferenceParams extends Params {
|
|||||||
final val inferBatchSize = new IntParam(this, "batchSize", "batch size of inference iteration")
|
final val inferBatchSize = new IntParam(this, "batchSize", "batch size of inference iteration")
|
||||||
|
|
||||||
/** @group getParam */
|
/** @group getParam */
|
||||||
final def getInferBatchSize: Int = ${inferBatchSize}
|
final def getInferBatchSize: Int = $(inferBatchSize)
|
||||||
|
|
||||||
setDefault(inferBatchSize, 32 << 10)
|
setDefault(inferBatchSize, 32 << 10)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -71,6 +71,7 @@ private[spark] trait LearningTaskParams extends Params {
|
|||||||
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
||||||
"fraction of training points to use for testing",
|
"fraction of training points to use for testing",
|
||||||
ParamValidators.inRange(0, 1))
|
ParamValidators.inRange(0, 1))
|
||||||
|
setDefault(trainTestRatio, 1.0)
|
||||||
|
|
||||||
final def getTrainTestRatio: Double = $(trainTestRatio)
|
final def getTrainTestRatio: Double = $(trainTestRatio)
|
||||||
|
|
||||||
@ -105,8 +106,6 @@ private[spark] trait LearningTaskParams extends Params {
|
|||||||
|
|
||||||
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
|
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
|
||||||
|
|
||||||
setDefault(baseScore -> 0.5, trainTestRatio -> 1.0,
|
|
||||||
numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark] object LearningTaskParams {
|
private[spark] object LearningTaskParams {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -29,12 +29,14 @@ private[spark] trait RabitParams extends Params {
|
|||||||
final val rabitRingReduceThreshold = new IntParam(this, "rabitRingReduceThreshold",
|
final val rabitRingReduceThreshold = new IntParam(this, "rabitRingReduceThreshold",
|
||||||
"threshold count to enable allreduce/broadcast with ring based topology",
|
"threshold count to enable allreduce/broadcast with ring based topology",
|
||||||
ParamValidators.gtEq(1))
|
ParamValidators.gtEq(1))
|
||||||
|
setDefault(rabitRingReduceThreshold, (32 << 10))
|
||||||
|
|
||||||
final def rabitTimeout: IntParam = new IntParam(this, "rabitTimeout",
|
final def rabitTimeout: IntParam = new IntParam(this, "rabitTimeout",
|
||||||
"timeout threshold after rabit observed failures")
|
"timeout threshold after rabit observed failures")
|
||||||
|
setDefault(rabitTimeout, -1)
|
||||||
|
|
||||||
final def rabitConnectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
|
final def rabitConnectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
|
||||||
"number of retry worker do before fail", ParamValidators.gtEq(1))
|
"number of retry worker do before fail", ParamValidators.gtEq(1))
|
||||||
|
setDefault(rabitConnectRetry, 5)
|
||||||
|
|
||||||
setDefault(rabitRingReduceThreshold -> (32 << 10), rabitConnectRetry -> 5, rabitTimeout -> -1)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,7 +29,8 @@ private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
|
|||||||
with HasLabelCol with HasFeaturesCols with HasHandleInvalid {
|
with HasLabelCol with HasFeaturesCols with HasHandleInvalid {
|
||||||
|
|
||||||
def needDeterministicRepartitioning: Boolean = {
|
def needDeterministicRepartitioning: Boolean = {
|
||||||
getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
|
isDefined(checkpointPath) && getCheckpointPath != null && getCheckpointPath.nonEmpty &&
|
||||||
|
isDefined(checkpointInterval) && getCheckpointInterval > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -84,4 +84,11 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
|
|||||||
|
|
||||||
new XGBoostClassifier(paramMap).fit(trainingDF)
|
new XGBoostClassifier(paramMap).fit(trainingDF)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Default parameters") {
|
||||||
|
val classifier = new XGBoostClassifier()
|
||||||
|
intercept[NoSuchElementException] {
|
||||||
|
classifier.getBaseScore
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user