[jvm-packages] make XGBoostModel hold BoosterParams as well (#2214)
* add back train method but mark as deprecated * fix scalastyle error * make XGBoostModel hold BoosterParams as well
This commit is contained in:
parent
e38bea3cdf
commit
392aa6d1d3
@ -119,7 +119,7 @@ class XGBoostEstimator private[spark](
|
||||
val trainedModel = XGBoost.trainWithRDD(instances, derivedXGBoosterParamMap,
|
||||
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
|
||||
$(missing)).setParent(this)
|
||||
val returnedModel = copyValues(trainedModel)
|
||||
val returnedModel = copyValues(trainedModel, extractParamMap())
|
||||
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
|
||||
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
|
||||
}
|
||||
|
||||
@ -17,10 +17,12 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.DefaultXGBoostParamsWriter
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
|
||||
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
|
||||
|
||||
import org.apache.spark.ml.PredictionModel
|
||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
|
||||
@ -36,13 +38,15 @@ import org.json4s.DefaultFormats
|
||||
* the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
|
||||
*/
|
||||
abstract class XGBoostModel(protected var _booster: Booster)
|
||||
extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params with MLWritable {
|
||||
extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable
|
||||
with Params with MLWritable {
|
||||
|
||||
def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
|
||||
|
||||
// scalastyle:off
|
||||
|
||||
final val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external memory for prediction")
|
||||
final val useExternalMemory = new BooleanParam(this, "use_external_memory",
|
||||
"whether to use external memory for prediction")
|
||||
|
||||
setDefault(useExternalMemory, false)
|
||||
|
||||
|
||||
@ -18,11 +18,9 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import scala.collection.immutable.HashSet
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
|
||||
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
|
||||
|
||||
trait BoosterParams extends Params {
|
||||
this: XGBoostEstimator =>
|
||||
|
||||
/**
|
||||
* Booster to use, options: {'gbtree', 'gblinear', 'dart'}
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.{File, FileNotFoundException}
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.io.Source
|
||||
@ -25,11 +25,10 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
|
||||
import org.apache.spark.ml.Pipeline
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
|
||||
import org.apache.spark.sql._
|
||||
|
||||
class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
@ -272,4 +271,16 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
|
||||
assert(testDF.count() === predResultsFromDF.size)
|
||||
}
|
||||
|
||||
test("params of estimator and produced model are coordinated correctly") {
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||
val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile)
|
||||
val spark = SparkSession.builder().getOrCreate()
|
||||
import spark.implicits._
|
||||
val model =
|
||||
XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers)
|
||||
assert(model.get[Double](model.eta).get == 0.1)
|
||||
assert(model.get[Int](model.maxDepth).get == 6)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user