[jvm-packages] Parameter tuning tool for XGBoost (#1664)
This commit is contained in:
@@ -16,16 +16,12 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.io.Source
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.sql._
|
||||
|
||||
class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
@@ -66,13 +62,15 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
"id", "features", "label")
|
||||
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
|
||||
collect().map(row =>
|
||||
(row.getAs[Int]("id"), row.getAs[mutable.WrappedArray[Float]]("probabilities"))
|
||||
(row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))
|
||||
).toMap
|
||||
assert(testDF.count() === predResultsFromDF.size)
|
||||
// the vector length in probabilities column is 2 since we have to fit to the evaluator in
|
||||
// Spark
|
||||
for (i <- predResultFromSeq.indices) {
|
||||
assert(predResultFromSeq(i).length === predResultsFromDF(i).length)
|
||||
assert(predResultFromSeq(i).length === predResultsFromDF(i).values.length - 1)
|
||||
for (j <- predResultFromSeq(i).indices) {
|
||||
assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j))
|
||||
assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j + 1))
|
||||
}
|
||||
}
|
||||
cleanExternalCache("XGBoostDFSuite")
|
||||
@@ -160,4 +158,29 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
assert(predictionDF.columns.contains("final_prediction") === false)
|
||||
cleanExternalCache("XGBoostDFSuite")
|
||||
}
|
||||
|
||||
// Checks that settings travel in both directions between the raw XGBoost parameter map
// handed to XGBoostEstimator and the Spark ML params the estimator exposes.
test("xgboost and spark parameters synchronize correctly") {
  val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")

  // from xgboost params to spark params
  val xgbEstimator = new XGBoostEstimator(xgbParamMap)
  // Compare whole Options rather than calling `.get`: an unset param then shows up as a
  // readable "None did not equal Some(...)" failure instead of a NoSuchElementException.
  assert(xgbEstimator.get(xgbEstimator.eta) === Some(1.0))
  assert(xgbEstimator.get(xgbEstimator.objective) === Some("binary:logistic"))

  // from spark to xgboost params
  val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty)
  val copiedParams = xgbEstimatorCopy.xgboostParams
  // Values are normalised through toString before comparison, as in the original checks.
  assert(copiedParams.get("eta").map(_.toString.toDouble) === Some(1.0))
  assert(copiedParams.get("objective").map(_.toString) === Some("binary:logistic"))
}
|
||||
|
||||
// Checks how the evaluation metric is resolved: the test expects "error" when the XGBoost
// parameter map does not mention a metric, and expects an explicit Spark-side override
// passed to copy() to end up in the regenerated xgboostParams map.
test("eval_metric is configured correctly") {
  val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
  val xgbEstimator = new XGBoostEstimator(xgbParamMap)
  // Compare the Option itself (no `.get`) so a missing value fails with a clear message,
  // matching the style of the xgboostParams assertions below.
  assert(xgbEstimator.get(xgbEstimator.evalMetric) === Some("error"))

  // Copying with no extra params keeps the metric.
  val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty)
  assert(xgbEstimatorCopy.xgboostParams.get("eval_metric") === Some("error"))

  // ParamMap.put mutates its receiver, so the override lives in its own map instead of
  // altering one that was already handed to the previous copy() call.
  val overrideParamMap = ParamMap.empty.put(xgbEstimator.evalMetric, "logloss")
  val xgbEstimatorCopy1 = xgbEstimator.copy(overrideParamMap)
  assert(xgbEstimatorCopy1.xgboostParams.get("eval_metric") === Some("logloss"))
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user