[jvm-packages] make XGBoostModel hold BoosterParams as well (#2214)

* add back train method but mark as deprecated

* fix scalastyle error

* make XGBoostModel hold BoosterParams as well
This commit is contained in:
Nan Zhu
2017-04-21 08:12:50 -07:00
committed by GitHub
parent e38bea3cdf
commit 392aa6d1d3
4 changed files with 22 additions and 9 deletions

View File

@@ -119,7 +119,7 @@ class XGBoostEstimator private[spark](
val trainedModel = XGBoost.trainWithRDD(instances, derivedXGBoosterParamMap,
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing)).setParent(this)
val returnedModel = copyValues(trainedModel)
val returnedModel = copyValues(trainedModel, extractParamMap())
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
}

View File

@@ -17,10 +17,12 @@
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.spark.params.DefaultXGBoostParamsWriter
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
import org.apache.spark.ml.PredictionModel
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
@@ -36,13 +38,15 @@ import org.json4s.DefaultFormats
* the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
*/
abstract class XGBoostModel(protected var _booster: Booster)
extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params with MLWritable {
extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable
with Params with MLWritable {
def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
// scalastyle:off
final val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external memory for prediction")
final val useExternalMemory = new BooleanParam(this, "use_external_memory",
"whether to use external memory for prediction")
setDefault(useExternalMemory, false)

View File

@@ -18,11 +18,9 @@ package ml.dmlc.xgboost4j.scala.spark.params
import scala.collection.immutable.HashSet
import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
trait BoosterParams extends Params {
this: XGBoostEstimator =>
/**
* Booster to use, options: {'gbtree', 'gblinear', 'dart'}