[jvm-packages] make XGBoostModel hold BoosterParams as well (#2214)
* add back train method but mark as deprecated * fix scalastyle error * make XGBoostModel hold BoosterParams as well
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.{File, FileNotFoundException}
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.io.Source
|
||||
@@ -25,11 +25,10 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
|
||||
import org.apache.spark.ml.Pipeline
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
|
||||
import org.apache.spark.sql._
|
||||
|
||||
class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
@@ -272,4 +271,16 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
|
||||
assert(testDF.count() === predResultsFromDF.size)
|
||||
}
|
||||
|
||||
test("params of estimator and produced model are coordinated correctly") {
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||
val trainingSet = loadCSVPoints(getClass.getResource("/dermatology.data").getFile)
|
||||
val spark = SparkSession.builder().getOrCreate()
|
||||
import spark.implicits._
|
||||
val model =
|
||||
XGBoost.trainWithDataFrame(trainingSet.toDF(), paramMap, round = 5, nWorkers = numWorkers)
|
||||
assert(model.get[Double](model.eta).get == 0.1)
|
||||
assert(model.get[Int](model.maxDepth).get == 6)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user