[jvm-packages] Use ML's param system to build the passed-in params to XGBoost (#2043)

* add back train method but mark it as deprecated
* fix scalastyle error
* use ML's param system to build the passed-in params to XGBoost
* clean up
parent acce11d3f4
commit 185fe1d645
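For readers less familiar with the mechanism this commit adopts: Spark ML's param system declares each hyperparameter as a typed Param field on the estimator, assigns defaults through setDefault, and reads the current value back with $(...). A minimal, self-contained sketch of that plumbing (illustrative only; ToyParams is not code from this PR):

import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

// Toy holder showing the ML param plumbing the commit builds on.
class ToyParams(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("toy"))

  // A typed hyperparameter with a default, analogous to eta on the estimator.
  final val eta: DoubleParam = new DoubleParam(this, "eta", "step size shrinkage")
  setDefault(eta -> 0.3)

  // Read the current (explicitly set, or default) value back.
  def currentEta: Double = $(eta)

  override def copy(extra: ParamMap): ToyParams = defaultCopy(extra)
}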
@@ -32,7 +32,7 @@ import org.apache.spark.sql.{Dataset, Row}
  * XGBoost Estimator to produce a XGBoost model
  */
 class XGBoostEstimator private[spark](
-    override val uid: String, private[spark] var xgboostParams: Map[String, Any])
+    override val uid: String, xgboostParams: Map[String, Any])
   extends Predictor[MLVector, XGBoostEstimator, XGBoostModel]
   with LearningTaskParams with GeneralParams with BoosterParams {

@@ -41,7 +41,6 @@ class XGBoostEstimator private[spark](

   def this(uid: String) = this(uid, Map[String, Any]())

-
   // called in fromXGBParamMapToParams only when eval_metric is not defined
   private def setupDefaultEvalMetric(): String = {
     val objFunc = xgboostParams.getOrElse("objective", xgboostParams.getOrElse("obj_type", null))
@@ -93,16 +92,11 @@ class XGBoostEstimator private[spark](

   fromXGBParamMapToParams()

-  // only called when XGBParamMap is empty, i.e. in the constructor this(String)
-  // TODO: refactor to be functional
-  private def fromParamsToXGBParamMap(): Map[String, Any] = {
-    require(xgboostParams.isEmpty, "fromParamsToXGBParamMap can only be called when" +
-      " XGBParamMap is empty, i.e. in the constructor this(String)")
+  private[spark] def fromParamsToXGBParamMap: Map[String, Any] = {
     val xgbParamMap = new mutable.HashMap[String, Any]()
     for (param <- params) {
       xgbParamMap += param.name -> $(param)
     }
-    xgboostParams = xgbParamMap.toMap
     xgbParamMap.toMap
   }

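The rewritten method is now a pure snapshot: it walks every declared ML param and copies its current value into a plain Map keyed by param name, instead of mutating the old var. A hedged usage sketch follows; both the this(uid) constructor and the method's package-private access come from the diff above, so this only compiles inside the ml.dmlc.xgboost4j.scala.spark package, and the printed default is an assumption:

package ml.dmlc.xgboost4j.scala.spark

object ParamSnapshotSketch {
  def demo(): Unit = {
    val est = new XGBoostEstimator("xgb-sketch")
    // One entry per declared ML param, each read through $(...) at call time.
    val asXGB: Map[String, Any] = est.fromParamsToXGBParamMap
    println(asXGB.get("eta")) // the current eta value (the default unless set)
  }
}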
@@ -116,8 +110,9 @@ class XGBoostEstimator private[spark](
       LabeledPoint(label, feature)
     }
     transformSchema(trainingSet.schema, logging = true)
-    val trainedModel = XGBoost.trainWithRDD(instances, xgboostParams, $(round), $(nWorkers),
-      $(customObj), $(customEval), $(useExternalMemory), $(missing)).setParent(this)
+    val trainedModel = XGBoost.trainWithRDD(instances, fromParamsToXGBParamMap,
+      $(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
+      $(missing)).setParent(this)
     val returnedModel = copyValues(trainedModel)
     if (XGBoost.isClassificationTask(xgboostParams)) {
       val numClass = {
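train(...) here is Predictor's template method: user code calls the standard Spark ML fit(...), which validates the schema and dispatches to this override, and the override now materializes the XGBoost param map from the ML params at fit time. A hedged sketch of that call path (trainingDF and its "features"/"label" columns are placeholders, not part of this diff):

import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoostModel}
import org.apache.spark.sql.DataFrame

object FitSketch {
  // fit(...) is the public entry point; Spark ML routes it to train(...).
  def demo(trainingDF: DataFrame): XGBoostModel = {
    val estimator = new XGBoostEstimator("xgb-sketch")
    estimator.fit(trainingDF) // "features" (vector) and "label" columns assumed
  }
}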
@@ -133,11 +128,6 @@ class XGBoostEstimator private[spark](
   }

   override def copy(extra: ParamMap): XGBoostEstimator = {
-    val est = defaultCopy(extra).asInstanceOf[XGBoostEstimator]
-    // we need to synchronize the params here instead of in the constructor
-    // because we cannot guarantee that params (default implementation) is initialized fully
-    // before the other params
-    est.fromParamsToXGBParamMap()
-    est
+    defaultCopy(extra).asInstanceOf[XGBoostEstimator]
   }
 }
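With params read on demand, copy no longer needs the fragile post-copy synchronization the deleted comments describe: defaultCopy already transfers every explicitly-set and default param value to the new instance. A small package-internal sketch of that expectation (reusing the this(uid) constructor from the diff; "eta" is a param name the tests below also rely on):

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.ml.param.ParamMap

object CopySketch {
  def demo(): Unit = {
    val est = new XGBoostEstimator("xgb-sketch")
    val copied = est.copy(ParamMap.empty)
    // The copy answers from its own transferred param values; no manual sync.
    assert(copied.fromParamsToXGBParamMap("eta") == est.fromParamsToXGBParamMap("eta"))
  }
}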
@@ -196,7 +196,7 @@ trait BoosterParams extends Params {
     minChildWeight -> 1, maxDeltaStep -> 0,
     subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1,
     lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
-    scalePosWeight -> 1, sampleType -> "uniform", normalizeType -> "tree",
+    scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
     rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0)

   /**
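The only change in this hunk spells the scalePosWeight default as a Double literal. Now that defaults flow through ML's typed params into the XGBoost map, matching each literal to the param's declared type keeps the snapshot's values predictably typed. A minimal sketch of a typed default, assuming scalePosWeight is declared as a DoubleParam (this excerpt does not show the declaration; ToyBooster is illustrative):

import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

// Mirrors the 1 -> 1.0 change above with an explicitly Double default.
class ToyBooster(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("toyBooster"))
  final val scalePosWeight = new DoubleParam(this, "scale_pos_weight",
    "control the balance of positive and negative weights")
  setDefault(scalePosWeight -> 1.0) // Double literal matches DoubleParam
  override def copy(extra: ParamMap): ToyBooster = defaultCopy(extra)
}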
@@ -18,10 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark

 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}

 import org.apache.spark.SparkContext
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
 import org.apache.spark.sql._

 class XGBoostDFSuite extends SharedSparkContext with Utils {
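The two new test imports hint at the payoff of the param-system move: an estimator whose hyperparameters are real ML Params composes with Spark's standard tuning tools. A hedged sketch of what that enables (it assumes round is a publicly accessible IntParam, inferred from the $(round) usage above but not shown here, and trainingDF is a placeholder):

import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.DataFrame

object TuningSketch {
  def demo(trainingDF: DataFrame): CrossValidatorModel = {
    val est = new XGBoostEstimator("xgb-sketch")
    // Grid over an ML param; only possible because round lives in the param system.
    val grid = new ParamGridBuilder()
      .addGrid(est.round, Array(5, 10))
      .build()
    val cv = new CrossValidator()
      .setEstimator(est)
      .setEvaluator(new BinaryClassificationEvaluator())
      .setEstimatorParamMaps(grid)
      .setNumFolds(3)
    cv.fit(trainingDF)
  }
}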
@@ -47,23 +50,21 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
     val (testItr, auxTestItr) =
       loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator.duplicate
     import DataUtils._
+    val round = 5
     val trainDMatrix = new DMatrix(new JDMatrix(trainingItr, null))
     val testDMatrix = new DMatrix(new JDMatrix(testItr, null))
-    val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, 5)
+    val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, round)
     val predResultFromSeq = xgboostModel.predict(testDMatrix)
     val testSetItr = auxTestItr.zipWithIndex.map {
-      case (instance: LabeledPoint, id: Int) =>
-        (id, instance.features, instance.label)
+      case (instance: LabeledPoint, id: Int) => (id, instance.features, instance.label)
     }
     val trainingDF = buildTrainingDataframe()
     val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
-      round = 5, nWorkers = numWorkers, useExternalMemory = false)
+      round = round, nWorkers = numWorkers, useExternalMemory = false)
     val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
       "id", "features", "label")
     val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
-      collect().map(row =>
-        (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))
-      ).toMap
+      collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))).toMap
     assert(testDF.count() === predResultsFromDF.size)
     // the vector length in probabilties column is 2 since we have to fit to the evaluator in
     // Spark
@@ -169,8 +170,8 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
     assert(xgbEstimator.get(xgbEstimator.objective).get === "binary:logistic")
     // from spark to xgboost params
     val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty)
-    assert(xgbEstimatorCopy.xgboostParams.get("eta").get.toString.toDouble === 1.0)
-    assert(xgbEstimatorCopy.xgboostParams.get("objective").get.toString === "binary:logistic")
+    assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eta").toString.toDouble === 1.0)
+    assert(xgbEstimatorCopy.fromParamsToXGBParamMap("objective").toString === "binary:logistic")
   }

   test("eval_metric is configured correctly") {
@@ -179,11 +180,8 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
     assert(xgbEstimator.get(xgbEstimator.evalMetric).get === "error")
     val sparkParamMap = ParamMap.empty
     val xgbEstimatorCopy = xgbEstimator.copy(sparkParamMap)
-    assert(xgbEstimatorCopy.xgboostParams.get("eval_metric") === Some("error"))
+    assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eval_metric") === "error")
     val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss"))
-    assert(xgbEstimatorCopy1.xgboostParams.get("eval_metric") === Some("logloss"))
+    assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
   }
-
-
-
 }
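One Spark ML detail that makes the last test read tersely: ParamMap.put mutates the map and returns it, which is why the put(...) call can be passed straight into copy(...). A small package-internal sketch of the same idiom (evalMetric and the "eval_metric" key are taken from the test above):

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.ml.param.ParamMap

object OverrideSketch {
  def demo(): Unit = {
    val est = new XGBoostEstimator("xgb-sketch")
    val pm = ParamMap.empty
    // put(...) returns the same (mutated) ParamMap, so it chains into copy.
    val tuned = est.copy(pm.put(est.evalMetric, "logloss"))
    assert(tuned.fromParamsToXGBParamMap("eval_metric") == "logloss")
  }
}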