diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala index 787cd753b..c7330c578 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala @@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params import scala.collection.mutable +import com.google.common.base.CaseFormat import org.apache.spark.ml.param._ private[spark] trait ParamMapConversion extends NonXGBoostParams { @@ -28,20 +29,23 @@ private[spark] trait ParamMapConversion extends NonXGBoostParams { * @param xgboostParams XGBoost style parameters */ def xgboost2SparkParams(xgboostParams: Map[String, Any]): Unit = { - for ((name, paramValue) <- xgboostParams) { - params.find(_.name == name).foreach { - case _: DoubleParam => - set(name, paramValue.toString.toDouble) - case _: BooleanParam => - set(name, paramValue.toString.toBoolean) - case _: IntParam => - set(name, paramValue.toString.toInt) - case _: FloatParam => - set(name, paramValue.toString.toFloat) - case _: LongParam => - set(name, paramValue.toString.toLong) - case _: Param[_] => - set(name, paramValue) + for ((paramName, paramValue) <- xgboostParams) { + val lowerCamelName = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName) + val lowerName = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, paramName) + val qualifiedNames = mutable.Set(paramName, lowerName, lowerCamelName) + params.find(p => qualifiedNames.contains(p.name)) foreach { + case p: DoubleParam => + set(p.name, paramValue.toString.toDouble) + case p: BooleanParam => + set(p.name, paramValue.toString.toBoolean) + case p: IntParam => + set(p.name, paramValue.toString.toInt) + case p: FloatParam => + set(p.name, paramValue.toString.toFloat) + case p: LongParam => + set(p.name, paramValue.toString.toLong) + case p: Param[_] => + set(p.name, paramValue) } } } @@ -49,7 +53,7 @@ private[spark] trait ParamMapConversion extends NonXGBoostParams { /** * Convert the user-supplied parameters to the XGBoost parameters. * - * Note that this also contains jvm-specific parameters. + * Note that this doesn't contain jvm-specific parameters. */ def getXGBoostParams: Map[String, Any] = { val xgboostParams = new mutable.HashMap[String, Any]() diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala index de0b8e3dd..25b4b9e02 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala @@ -69,6 +69,48 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu assert(model.getContribPredictionCol === "contrib") } + test("camel case parameters") { + val xgbParams: Map[String, Any] = Map( + "max_depth" -> 5, + "featuresCol" -> "abc", + "num_workers" -> 2, + "numRound" -> 11 + ) + val estimator = new XGBoostClassifier(xgbParams) + assert(estimator.getFeaturesCol === "abc") + assert(estimator.getNumWorkers === 2) + assert(estimator.getNumRound === 11) + assert(estimator.getMaxDepth === 5) + + val xgbParams1: Map[String, Any] = Map( + "maxDepth" -> 5, + "features_col" -> "abc", + "numWorkers" -> 2, + "num_round" -> 11 + ) + val estimator1 = new XGBoostClassifier(xgbParams1) + assert(estimator1.getFeaturesCol === "abc") + assert(estimator1.getNumWorkers === 2) + assert(estimator1.getNumRound === 11) + assert(estimator1.getMaxDepth === 5) + } + + test("get xgboost parameters") { + val params: Map[String, Any] = Map( + "max_depth" -> 5, + "featuresCol" -> "abc", + "label" -> "class", + "num_workers" -> 2, + "tree_method" -> "hist", + "numRound" -> 11, + "not_exist_parameters" -> "hello" + ) + val estimator = new XGBoostClassifier(params) + val xgbParams = estimator.getXGBoostParams + assert(xgbParams.size === 2) + assert(xgbParams.contains("max_depth") && xgbParams.contains("tree_method")) + } + test("nthread") { val classifier = new XGBoostClassifier().setNthread(100)