[jvm-packages] use ML's para system to build the passed-in params to XGBoost (#2043)
* add back train method but mark as deprecated * fix scalastyle error * use ML's para system to build the passed-in params to XGBoost * clean
This commit is contained in:
@@ -18,10 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
|
||||
import org.apache.spark.sql._
|
||||
|
||||
class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
@@ -47,23 +50,21 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
val (testItr, auxTestItr) =
|
||||
loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator.duplicate
|
||||
import DataUtils._
|
||||
val round = 5
|
||||
val trainDMatrix = new DMatrix(new JDMatrix(trainingItr, null))
|
||||
val testDMatrix = new DMatrix(new JDMatrix(testItr, null))
|
||||
val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, 5)
|
||||
val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, round)
|
||||
val predResultFromSeq = xgboostModel.predict(testDMatrix)
|
||||
val testSetItr = auxTestItr.zipWithIndex.map {
|
||||
case (instance: LabeledPoint, id: Int) =>
|
||||
(id, instance.features, instance.label)
|
||||
case (instance: LabeledPoint, id: Int) => (id, instance.features, instance.label)
|
||||
}
|
||||
val trainingDF = buildTrainingDataframe()
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = 5, nWorkers = numWorkers, useExternalMemory = false)
|
||||
round = round, nWorkers = numWorkers, useExternalMemory = false)
|
||||
val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
|
||||
"id", "features", "label")
|
||||
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
|
||||
collect().map(row =>
|
||||
(row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))
|
||||
).toMap
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))).toMap
|
||||
assert(testDF.count() === predResultsFromDF.size)
|
||||
// the vector length in probabilties column is 2 since we have to fit to the evaluator in
|
||||
// Spark
|
||||
@@ -169,8 +170,8 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
assert(xgbEstimator.get(xgbEstimator.objective).get === "binary:logistic")
|
||||
// from spark to xgboost params
|
||||
val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty)
|
||||
assert(xgbEstimatorCopy.xgboostParams.get("eta").get.toString.toDouble === 1.0)
|
||||
assert(xgbEstimatorCopy.xgboostParams.get("objective").get.toString === "binary:logistic")
|
||||
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eta").toString.toDouble === 1.0)
|
||||
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("objective").toString === "binary:logistic")
|
||||
}
|
||||
|
||||
test("eval_metric is configured correctly") {
|
||||
@@ -179,11 +180,8 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
assert(xgbEstimator.get(xgbEstimator.evalMetric).get === "error")
|
||||
val sparkParamMap = ParamMap.empty
|
||||
val xgbEstimatorCopy = xgbEstimator.copy(sparkParamMap)
|
||||
assert(xgbEstimatorCopy.xgboostParams.get("eval_metric") === Some("error"))
|
||||
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eval_metric") === "error")
|
||||
val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss"))
|
||||
assert(xgbEstimatorCopy1.xgboostParams.get("eval_metric") === Some("logloss"))
|
||||
assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user