[jvm-packages] make XGBoostModel hold BoosterParams as well (#2214)

* add back train method but mark as deprecated

* fix scalastyle error

* make XGBoostModel hold BoosterParams as well
This commit is contained in:
Nan Zhu 2017-04-21 08:12:50 -07:00 committed by GitHub
parent e38bea3cdf
commit 392aa6d1d3
4 changed files with 22 additions and 9 deletions

View File

@@ -119,7 +119,7 @@ class XGBoostEstimator private[spark](
     val trainedModel = XGBoost.trainWithRDD(instances, derivedXGBoosterParamMap,
       $(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
       $(missing)).setParent(this)
-    val returnedModel = copyValues(trainedModel)
+    val returnedModel = copyValues(trainedModel, extractParamMap())
     if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
       returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
     }

View File

@@ -17,10 +17,12 @@
 package ml.dmlc.xgboost4j.scala.spark
 import scala.collection.JavaConverters._
 import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
-import ml.dmlc.xgboost4j.scala.spark.params.DefaultXGBoostParamsWriter
+import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
 import org.apache.hadoop.fs.{FSDataOutputStream, Path}
 import org.apache.spark.ml.PredictionModel
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
@@ -36,13 +38,15 @@ import org.json4s.DefaultFormats
  * the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
  */
 abstract class XGBoostModel(protected var _booster: Booster)
-  extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params with MLWritable {
+  extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable
+  with Params with MLWritable {
   def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
   // scalastyle:off
-  final val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external memory for prediction")
+  final val useExternalMemory = new BooleanParam(this, "use_external_memory",
+    "whether to use external memory for prediction")
   setDefault(useExternalMemory, false)

View File

@@ -18,11 +18,9 @@ package ml.dmlc.xgboost4j.scala.spark.params
 import scala.collection.immutable.HashSet
-import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
 import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
 trait BoosterParams extends Params {
-  this: XGBoostEstimator =>
   /**
    * Booster to use, options: {'gbtree', 'gblinear', 'dart'}

View File

@@ -16,7 +16,7 @@
 package ml.dmlc.xgboost4j.scala.spark
-import java.io.{File, FileNotFoundException}
+import java.io.File
 import scala.collection.mutable.ListBuffer
 import scala.io.Source
@@ -25,11 +25,10 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
 import org.apache.spark.sql._
 class XGBoostDFSuite extends SharedSparkContext with Utils {
@@ -272,4 +271,16 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
     collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
     assert(testDF.count() === predResultsFromDF.size)
   }
+  test("params of estimator and produced model are coordinated correctly") {
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6")
+    val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile)
+    val spark = SparkSession.builder().getOrCreate()
+    import spark.implicits._
+    val model =
+      XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers)
+    assert(model.get[Double](model.eta).get == 0.1)
+    assert(model.get[Int](model.maxDepth).get == 6)
+  }
 }