[jvm-packages] use ML's param system to build the passed-in params to XGBoost (#2043)

* add back train method but mark as deprecated

* fix scalastyle error

* use ML's param system to build the passed-in params to XGBoost

* clean up

Nan Zhu, 2017-02-18 11:56:27 -08:00 (committed by GitHub)
parent acce11d3f4
commit 185fe1d645
3 changed files with 19 additions and 31 deletions
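
The gist of the change: XGBoostEstimator previously kept a mutable xgboostParams map that had to be re-synchronized with the Spark ML Params by hand (see the copy() override removed below). After this commit, Spark ML's Param system is the single source of truth, and the XGBoost-level Map[String, Any] is derived from it on demand through fromParamsToXGBParamMap. Below is a minimal self-contained sketch of that pattern against Spark ML's public Params API; the SketchEstimator class and its two params are illustrative, not the actual XGBoost code.

import scala.collection.mutable

import org.apache.spark.ml.param.{DoubleParam, IntParam, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

class SketchEstimator(override val uid: String) extends Params {

  def this() = this(Identifiable.randomUID("sketch"))

  // Two illustrative knobs registered with Spark ML's Param system.
  val eta: DoubleParam = new DoubleParam(this, "eta", "step size shrinkage")
  val round: IntParam = new IntParam(this, "round", "number of boosting rounds")
  setDefault(eta -> 0.3, round -> 10)

  // Derive the library-level parameter map from the registered Params on
  // demand; this mirrors the shape of fromParamsToXGBParamMap in the diff.
  def toBoosterParamMap: Map[String, Any] = {
    val xgbParamMap = new mutable.HashMap[String, Any]()
    for (param <- params if isDefined(param)) {
      xgbParamMap += param.name -> $(param)
    }
    xgbParamMap.toMap
  }

  override def copy(extra: ParamMap): SketchEstimator = defaultCopy(extra)
}

With this shape, new SketchEstimator().toBoosterParamMap yields Map("eta" -> 0.3, "round" -> 10), and a copy produced by defaultCopy yields the same map with no manual bookkeeping.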

XGBoostEstimator.scala

@@ -32,7 +32,7 @@ import org.apache.spark.sql.{Dataset, Row}
  * XGBoost Estimator to produce a XGBoost model
  */
 class XGBoostEstimator private[spark](
-    override val uid: String, private[spark] var xgboostParams: Map[String, Any])
+    override val uid: String, xgboostParams: Map[String, Any])
   extends Predictor[MLVector, XGBoostEstimator, XGBoostModel]
   with LearningTaskParams with GeneralParams with BoosterParams {
@@ -41,7 +41,6 @@ class XGBoostEstimator private[spark](
   def this(uid: String) = this(uid, Map[String, Any]())
-
   // called in fromXGBParamMapToParams only when eval_metric is not defined
   private def setupDefaultEvalMetric(): String = {
     val objFunc = xgboostParams.getOrElse("objective", xgboostParams.getOrElse("obj_type", null))
@@ -93,16 +92,11 @@ class XGBoostEstimator private[spark](
   fromXGBParamMapToParams()

-  // only called when XGBParamMap is empty, i.e. in the constructor this(String)
-  // TODO: refactor to be functional
-  private def fromParamsToXGBParamMap(): Map[String, Any] = {
-    require(xgboostParams.isEmpty, "fromParamsToXGBParamMap can only be called when" +
-      " XGBParamMap is empty, i.e. in the constructor this(String)")
+  private[spark] def fromParamsToXGBParamMap: Map[String, Any] = {
     val xgbParamMap = new mutable.HashMap[String, Any]()
     for (param <- params) {
       xgbParamMap += param.name -> $(param)
     }
-    xgboostParams = xgbParamMap.toMap
     xgbParamMap.toMap
   }
@@ -116,8 +110,9 @@ class XGBoostEstimator private[spark](
       LabeledPoint(label, feature)
     }
     transformSchema(trainingSet.schema, logging = true)
-    val trainedModel = XGBoost.trainWithRDD(instances, xgboostParams, $(round), $(nWorkers),
-      $(customObj), $(customEval), $(useExternalMemory), $(missing)).setParent(this)
+    val trainedModel = XGBoost.trainWithRDD(instances, fromParamsToXGBParamMap,
+      $(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
+      $(missing)).setParent(this)
     val returnedModel = copyValues(trainedModel)
     if (XGBoost.isClassificationTask(xgboostParams)) {
       val numClass = {
@@ -133,11 +128,6 @@ class XGBoostEstimator private[spark](
   }

   override def copy(extra: ParamMap): XGBoostEstimator = {
-    val est = defaultCopy(extra).asInstanceOf[XGBoostEstimator]
-    // we need to synchronize the params here instead of in the constructor
-    // because we cannot guarantee that params (default implementation) is initialized fully
-    // before the other params
-    est.fromParamsToXGBParamMap()
-    est
+    defaultCopy(extra).asInstanceOf[XGBoostEstimator]
   }
 }
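
With the mutable map gone, copy() collapses to a plain defaultCopy: Spark ML already copies every Param value, so there is nothing left to re-synchronize by hand. A hedged usage sketch of the resulting round trip; both the two-argument constructor and fromParamsToXGBParamMap are private[spark], so this only compiles from the same package, as in the test suite further down, and the uid and values are illustrative.

import org.apache.spark.ml.param.ParamMap

val est = new XGBoostEstimator("xgb-sketch",
  Map("eta" -> 1.0, "objective" -> "binary:logistic"))
val copied = est.copy(ParamMap.empty)
// defaultCopy carried the Param values over; the XGBoost-level map is
// rebuilt from Spark Params rather than patched after the fact.
assert(copied.fromParamsToXGBParamMap("objective") == "binary:logistic")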

BoosterParams.scala

@@ -196,7 +196,7 @@ trait BoosterParams extends Params {
     minChildWeight -> 1, maxDeltaStep -> 0,
     subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1,
     lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
-    scalePosWeight -> 1, sampleType -> "uniform", normalizeType -> "tree",
+    scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
     rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0)

   /**

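A note on the one-character change above: scalePosWeight is presumably a Double-typed param, and writing its default as 1.0 makes that explicit at the call site instead of leaning on Scala's Int-to-Double widening; it also keeps the value's string form ("1.0") unambiguous once defaults travel out through Map[String, Any]. A toy sketch under that assumption (DefaultsSketch and its param are illustrative, not the real BoosterParams trait):

import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

class DefaultsSketch(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("defaults"))

  val scalePosWeight: DoubleParam =
    new DoubleParam(this, "scale_pos_weight", "balance of positive/negative weights")
  // Explicit Double literal; a bare 1 would compile only via numeric widening.
  setDefault(scalePosWeight -> 1.0)

  override def copy(extra: ParamMap): DefaultsSketch = defaultCopy(extra)
}

val sketch = new DefaultsSketch()
// Read back as Any, the stored default is a Double, so it prints "1.0", not "1".
val asAny: Any = sketch.getOrDefault(sketch.scalePosWeight)
assert(asAny.toString == "1.0")
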
XGBoostDFSuite.scala

@@ -18,10 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import org.apache.spark.SparkContext
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
 import org.apache.spark.sql._

 class XGBoostDFSuite extends SharedSparkContext with Utils {
@@ -47,23 +50,21 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
     val (testItr, auxTestItr) =
       loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator.duplicate
     import DataUtils._
+    val round = 5
     val trainDMatrix = new DMatrix(new JDMatrix(trainingItr, null))
     val testDMatrix = new DMatrix(new JDMatrix(testItr, null))
-    val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, 5)
+    val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, round)
     val predResultFromSeq = xgboostModel.predict(testDMatrix)
     val testSetItr = auxTestItr.zipWithIndex.map {
-      case (instance: LabeledPoint, id: Int) =>
-        (id, instance.features, instance.label)
+      case (instance: LabeledPoint, id: Int) => (id, instance.features, instance.label)
     }
     val trainingDF = buildTrainingDataframe()
     val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
-      round = 5, nWorkers = numWorkers, useExternalMemory = false)
+      round = round, nWorkers = numWorkers, useExternalMemory = false)
     val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
       "id", "features", "label")
     val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
-      collect().map(row =>
-        (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))
-      ).toMap
+      collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))).toMap
     assert(testDF.count() === predResultsFromDF.size)
     // the vector length in probabilties column is 2 since we have to fit to the evaluator in
     // Spark
@@ -169,8 +170,8 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
     assert(xgbEstimator.get(xgbEstimator.objective).get === "binary:logistic")
     // from spark to xgboost params
     val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty)
-    assert(xgbEstimatorCopy.xgboostParams.get("eta").get.toString.toDouble === 1.0)
-    assert(xgbEstimatorCopy.xgboostParams.get("objective").get.toString === "binary:logistic")
+    assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eta").toString.toDouble === 1.0)
+    assert(xgbEstimatorCopy.fromParamsToXGBParamMap("objective").toString === "binary:logistic")
   }

   test("eval_metric is configured correctly") {
@@ -179,11 +180,8 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
     assert(xgbEstimator.get(xgbEstimator.evalMetric).get === "error")
     val sparkParamMap = ParamMap.empty
     val xgbEstimatorCopy = xgbEstimator.copy(sparkParamMap)
-    assert(xgbEstimatorCopy.xgboostParams.get("eval_metric") === Some("error"))
+    assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eval_metric") === "error")
     val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss"))
-    assert(xgbEstimatorCopy1.xgboostParams.get("eval_metric") === Some("logloss"))
+    assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
   }
 }
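
The two imports added at the top of the suite (BinaryClassificationEvaluator and the CrossValidator/ParamGridBuilder pair) point at the payoff of routing everything through Spark ML Params: the estimator can now plug into Spark's standard model-selection tooling, with grid values flowing in as ordinary Params. A hedged sketch of what that enables; it assumes a public Map-based constructor and a maxDepth IntParam on BoosterParams, and trainingDF stands in for any DataFrame with "features" and "label" columns as in the suite above.

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val xgb = new XGBoostEstimator(Map("objective" -> "binary:logistic"))
val grid = new ParamGridBuilder()
  .addGrid(xgb.eta, Array(0.1, 0.3))   // eta is asserted on in the tests above
  .addGrid(xgb.maxDepth, Array(4, 6))  // assumed IntParam from BoosterParams
  .build()
val cv = new CrossValidator()
  .setEstimator(xgb)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(grid)
  .setNumFolds(3)
val best = cv.fit(trainingDF).bestModel  // each grid point becomes a ParamMap overlay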