[jvm-packages] make XGBoostModel hold BoosterParams as well (#2214)
* add back train method but mark as deprecated
* fix scalastyle error
* make XGBoostModel hold BoosterParams as well
This commit is contained in:
parent
e38bea3cdf
commit
392aa6d1d3
@ -119,7 +119,7 @@ class XGBoostEstimator private[spark](
|
|||||||
val trainedModel = XGBoost.trainWithRDD(instances, derivedXGBoosterParamMap,
|
val trainedModel = XGBoost.trainWithRDD(instances, derivedXGBoosterParamMap,
|
||||||
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
|
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
|
||||||
$(missing)).setParent(this)
|
$(missing)).setParent(this)
|
||||||
val returnedModel = copyValues(trainedModel)
|
val returnedModel = copyValues(trainedModel, extractParamMap())
|
||||||
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
|
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
|
||||||
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
|
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,10 +17,12 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.DefaultXGBoostParamsWriter
|
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
|
||||||
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
|
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
|
||||||
|
|
||||||
import org.apache.spark.ml.PredictionModel
|
import org.apache.spark.ml.PredictionModel
|
||||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||||
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
|
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
|
||||||
@ -36,13 +38,15 @@ import org.json4s.DefaultFormats
|
|||||||
* the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
|
* the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
|
||||||
*/
|
*/
|
||||||
abstract class XGBoostModel(protected var _booster: Booster)
|
abstract class XGBoostModel(protected var _booster: Booster)
|
||||||
extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params with MLWritable {
|
extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable
|
||||||
|
with Params with MLWritable {
|
||||||
|
|
||||||
def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
|
def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
|
||||||
|
|
||||||
// scalastyle:off
|
// scalastyle:off
|
||||||
|
|
||||||
final val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external memory for prediction")
|
final val useExternalMemory = new BooleanParam(this, "use_external_memory",
|
||||||
|
"whether to use external memory for prediction")
|
||||||
|
|
||||||
setDefault(useExternalMemory, false)
|
setDefault(useExternalMemory, false)
|
||||||
|
|
||||||
|
|||||||
@ -18,11 +18,9 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
|||||||
|
|
||||||
import scala.collection.immutable.HashSet
|
import scala.collection.immutable.HashSet
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
|
|
||||||
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
|
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
|
||||||
|
|
||||||
trait BoosterParams extends Params {
|
trait BoosterParams extends Params {
|
||||||
this: XGBoostEstimator =>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Booster to use, options: {'gbtree', 'gblinear', 'dart'}
|
* Booster to use, options: {'gbtree', 'gblinear', 'dart'}
|
||||||
|
|||||||
@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import java.io.{File, FileNotFoundException}
|
import java.io.File
|
||||||
|
|
||||||
import scala.collection.mutable.ListBuffer
|
import scala.collection.mutable.ListBuffer
|
||||||
import scala.io.Source
|
import scala.io.Source
|
||||||
@ -25,11 +25,10 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
|||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||||
|
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
|
import org.apache.spark.ml.Pipeline
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.LabeledPoint
|
||||||
import org.apache.spark.ml.linalg.DenseVector
|
import org.apache.spark.ml.linalg.DenseVector
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
|
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
|
|
||||||
class XGBoostDFSuite extends SharedSparkContext with Utils {
|
class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||||
@ -272,4 +271,16 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
|||||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
|
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
|
||||||
assert(testDF.count() === predResultsFromDF.size)
|
assert(testDF.count() === predResultsFromDF.size)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("params of estimator and produced model are coordinated correctly") {
|
||||||
|
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||||
|
val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile)
|
||||||
|
val spark = SparkSession.builder().getOrCreate()
|
||||||
|
import spark.implicits._
|
||||||
|
val model =
|
||||||
|
XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers)
|
||||||
|
assert(model.get[Double](model.eta).get == 0.1)
|
||||||
|
assert(model.get[Int](model.maxDepth).get == 6)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user