[jvm-packages] make XGBoostModel hold BoosterParams as well (#2214)
* add back train method but mark as deprecated * fix scalastyle error * make XGBoostModel hold BoosterParams as well
This commit is contained in:
@@ -119,7 +119,7 @@ class XGBoostEstimator private[spark](
|
||||
val trainedModel = XGBoost.trainWithRDD(instances, derivedXGBoosterParamMap,
|
||||
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
|
||||
$(missing)).setParent(this)
|
||||
val returnedModel = copyValues(trainedModel)
|
||||
val returnedModel = copyValues(trainedModel, extractParamMap())
|
||||
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
|
||||
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
|
||||
}
|
||||
|
||||
@@ -17,10 +17,12 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.DefaultXGBoostParamsWriter
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
|
||||
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
|
||||
|
||||
import org.apache.spark.ml.PredictionModel
|
||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
|
||||
@@ -36,13 +38,15 @@ import org.json4s.DefaultFormats
|
||||
* the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
|
||||
*/
|
||||
abstract class XGBoostModel(protected var _booster: Booster)
|
||||
extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params with MLWritable {
|
||||
extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable
|
||||
with Params with MLWritable {
|
||||
|
||||
def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
|
||||
|
||||
// scalastyle:off
|
||||
|
||||
final val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external memory for prediction")
|
||||
final val useExternalMemory = new BooleanParam(this, "use_external_memory",
|
||||
"whether to use external memory for prediction")
|
||||
|
||||
setDefault(useExternalMemory, false)
|
||||
|
||||
|
||||
@@ -18,11 +18,9 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import scala.collection.immutable.HashSet
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
|
||||
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
|
||||
|
||||
trait BoosterParams extends Params {
|
||||
this: XGBoostEstimator =>
|
||||
|
||||
/**
|
||||
* Booster to use, options: {'gbtree', 'gblinear', 'dart'}
|
||||
|
||||
Reference in New Issue
Block a user