[jvm-packages] bring back camel case variants of parameters (#10845)

This commit is contained in:
Bobby Wang 2024-09-25 14:14:42 +08:00 committed by GitHub
parent 2179baa50c
commit a049490cdb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 15 deletions

View File

@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params
import scala.collection.mutable import scala.collection.mutable
import com.google.common.base.CaseFormat
import org.apache.spark.ml.param._ import org.apache.spark.ml.param._
private[spark] trait ParamMapConversion extends NonXGBoostParams { private[spark] trait ParamMapConversion extends NonXGBoostParams {
@ -28,20 +29,23 @@ private[spark] trait ParamMapConversion extends NonXGBoostParams {
* @param xgboostParams XGBoost style parameters * @param xgboostParams XGBoost style parameters
*/ */
def xgboost2SparkParams(xgboostParams: Map[String, Any]): Unit = { def xgboost2SparkParams(xgboostParams: Map[String, Any]): Unit = {
for ((name, paramValue) <- xgboostParams) { for ((paramName, paramValue) <- xgboostParams) {
params.find(_.name == name).foreach { val lowerCamelName = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
case _: DoubleParam => val lowerName = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, paramName)
set(name, paramValue.toString.toDouble) val qualifiedNames = mutable.Set(paramName, lowerName, lowerCamelName)
case _: BooleanParam => params.find(p => qualifiedNames.contains(p.name)) foreach {
set(name, paramValue.toString.toBoolean) case p: DoubleParam =>
case _: IntParam => set(p.name, paramValue.toString.toDouble)
set(name, paramValue.toString.toInt) case p: BooleanParam =>
case _: FloatParam => set(p.name, paramValue.toString.toBoolean)
set(name, paramValue.toString.toFloat) case p: IntParam =>
case _: LongParam => set(p.name, paramValue.toString.toInt)
set(name, paramValue.toString.toLong) case p: FloatParam =>
case _: Param[_] => set(p.name, paramValue.toString.toFloat)
set(name, paramValue) case p: LongParam =>
set(p.name, paramValue.toString.toLong)
case p: Param[_] =>
set(p.name, paramValue)
} }
} }
} }
@ -49,7 +53,7 @@ private[spark] trait ParamMapConversion extends NonXGBoostParams {
/** /**
* Convert the user-supplied parameters to the XGBoost parameters. * Convert the user-supplied parameters to the XGBoost parameters.
* *
* Note that this also contains jvm-specific parameters. * Note that this doesn't contain jvm-specific parameters.
*/ */
def getXGBoostParams: Map[String, Any] = { def getXGBoostParams: Map[String, Any] = {
val xgboostParams = new mutable.HashMap[String, Any]() val xgboostParams = new mutable.HashMap[String, Any]()

View File

@ -69,6 +69,48 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
assert(model.getContribPredictionCol === "contrib") assert(model.getContribPredictionCol === "contrib")
} }
test("camel case parameters") {
val xgbParams: Map[String, Any] = Map(
"max_depth" -> 5,
"featuresCol" -> "abc",
"num_workers" -> 2,
"numRound" -> 11
)
val estimator = new XGBoostClassifier(xgbParams)
assert(estimator.getFeaturesCol === "abc")
assert(estimator.getNumWorkers === 2)
assert(estimator.getNumRound === 11)
assert(estimator.getMaxDepth === 5)
val xgbParams1: Map[String, Any] = Map(
"maxDepth" -> 5,
"features_col" -> "abc",
"numWorkers" -> 2,
"num_round" -> 11
)
val estimator1 = new XGBoostClassifier(xgbParams1)
assert(estimator1.getFeaturesCol === "abc")
assert(estimator1.getNumWorkers === 2)
assert(estimator1.getNumRound === 11)
assert(estimator1.getMaxDepth === 5)
}
test("get xgboost parameters") {
val params: Map[String, Any] = Map(
"max_depth" -> 5,
"featuresCol" -> "abc",
"label" -> "class",
"num_workers" -> 2,
"tree_method" -> "hist",
"numRound" -> 11,
"not_exist_parameters" -> "hello"
)
val estimator = new XGBoostClassifier(params)
val xgbParams = estimator.getXGBoostParams
assert(xgbParams.size === 2)
assert(xgbParams.contains("max_depth") && xgbParams.contains("tree_method"))
}
test("nthread") { test("nthread") {
val classifier = new XGBoostClassifier().setNthread(100) val classifier = new XGBoostClassifier().setNthread(100)