diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala index c00d16ad0..a5bbdb60c 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala @@ -119,7 +119,7 @@ class XGBoostEstimator private[spark]( val trainedModel = XGBoost.trainWithRDD(instances, derivedXGBoosterParamMap, $(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory), $(missing)).setParent(this) - val returnedModel = copyValues(trainedModel) + val returnedModel = copyValues(trainedModel, extractParamMap()) if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) { returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses) } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 1ce2e33b2..c1b615993 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -17,10 +17,12 @@ package ml.dmlc.xgboost4j.scala.spark import scala.collection.JavaConverters._ + import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix} -import ml.dmlc.xgboost4j.scala.spark.params.DefaultXGBoostParamsWriter +import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter} import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait} import org.apache.hadoop.fs.{FSDataOutputStream, Path} + import org.apache.spark.ml.PredictionModel import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector} @@ -36,13 +38,15 @@ import org.json4s.DefaultFormats * the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]] */ abstract class XGBoostModel(protected var _booster: Booster) - extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params with MLWritable { + extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable + with Params with MLWritable { def setLabelCol(name: String): XGBoostModel = set(labelCol, name) // scalastyle:off - final val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external memory for prediction") + final val useExternalMemory = new BooleanParam(this, "use_external_memory", + "whether to use external memory for prediction") setDefault(useExternalMemory, false) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala index 6bb35491b..c01b31140 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala @@ -18,11 +18,9 @@ package ml.dmlc.xgboost4j.scala.spark.params import scala.collection.immutable.HashSet -import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params} trait BoosterParams extends Params { - this: XGBoostEstimator => /** * Booster to use, options: {'gbtree', 'gblinear', 'dart'} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala index c58f9d372..01eaca737 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala @@ -16,7 +16,7 @@ package ml.dmlc.xgboost4j.scala.spark -import java.io.{File, FileNotFoundException} +import java.io.File import scala.collection.mutable.ListBuffer import scala.io.Source @@ -25,11 +25,10 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} import org.apache.spark.SparkContext -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.Pipeline import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.DenseVector import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} import org.apache.spark.sql._ class XGBoostDFSuite extends SharedSparkContext with Utils { @@ -272,4 +271,16 @@ class XGBoostDFSuite extends SharedSparkContext with Utils { collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap assert(testDF.count() === predResultsFromDF.size) } + + test("params of estimator and produced model are coordinated correctly") { + val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "multi:softmax", "num_class" -> "6") + val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile) + val spark = SparkSession.builder().getOrCreate() + import spark.implicits._ + val model = + XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers) + assert(model.get[Double](model.eta).get == 0.1) + assert(model.get[Int](model.maxDepth).get == 6) + } }