[jvm-packages] fix the persistence of XGBoostEstimator (#2265)
* add back train method but mark it as deprecated
* fix scalastyle error
* fix the persistence of XGBoostEstimator
* test persistence of a complete pipeline
* fix compilation issue
* do not allow persisting custom_eval and custom_obj
* fix the failed test
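For context, the behaviour this change enables (mirroring the new tests in this diff): an unfitted XGBoostEstimator, and a Pipeline that contains one, can now be saved and reloaded through the standard Spark ML persistence API. A minimal spark-shell style sketch; parameter values and paths are illustrative:

    import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.VectorAssembler

    // Illustrative parameters; any of the usual XGBoost params are handled the same way.
    val params = Map("eta" -> "0.1", "max_depth" -> "6", "objective" -> "multi:softmax", "num_class" -> "6")
    val estimator = new XGBoostEstimator(params)

    // Persist and restore the estimator on its own ...
    estimator.write.overwrite().save("/tmp/xgbEstimator")
    val restoredEstimator = XGBoostEstimator.read.load("/tmp/xgbEstimator")

    // ... or as a stage of an (unfitted) Pipeline.
    val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
    val pipeline = new Pipeline().setStages(Array(assembler, estimator))
    pipeline.write.overwrite().save("/tmp/xgbPipeline")
    val restoredPipeline = Pipeline.read.load("/tmp/xgbPipeline")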
parent 6bf968efe6
commit 428453f7d6

.gitignore (vendored): 3 additions
@@ -89,3 +89,6 @@ build_tests
 .DS_Store
 lib/
+
+# spark
+metastore_db
@@ -18,21 +18,22 @@ package ml.dmlc.xgboost4j.scala.spark

 import scala.collection.mutable
 import scala.collection.mutable.ListBuffer

 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.{FSDataInputStream, Path}

 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.SparseVector
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.{SparkContext, TaskContext}

-import scala.concurrent.duration.{Duration, MILLISECONDS}
+import scala.concurrent.duration.{Duration, FiniteDuration, MILLISECONDS}

 object TrackerConf {
-  def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
+  def apply(): TrackerConf = TrackerConf(0L, "python")
 }

 /**
@@ -40,13 +41,14 @@ object TrackerConf {
  * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
  *                                Set timeout length to zero to disable timeout.
  *                                Use a finite, non-zero timeout value to prevent tracker from
- *                                hanging indefinitely (supported by "scala" implementation only.)
+ *                                hanging indefinitely (in milliseconds)
+ *                                (supported by "scala" implementation only.)
  * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
  *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
  *                    in Scala without Python components, and with full support of timeouts.
  *                    The Scala implementation is currently experimental, use at your own risk.
  */
-case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
+case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)

 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
@@ -240,14 +242,7 @@ object XGBoost extends Serializable {
       case _ => new PyRabitTracker(nWorkers)
     }

-    val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
-      trackerConf.workerConnectionTimeout.toMillis
-    } else {
-      // 0 == Duration.Inf
-      0L
-    }
-
-    require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
+    require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
     tracker
   }
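Knock-on effect of the change above: workerConnectionTimeout is now a plain millisecond count instead of a scala.concurrent.duration.Duration, so callers that used to write TrackerConf(1 minute, "scala") pass milliseconds directly, as the updated tests further down do. A small sketch, value illustrative:

    import ml.dmlc.xgboost4j.scala.spark.TrackerConf

    // Timeout is in milliseconds; per the scaladoc above, zero means wait indefinitely.
    val trackerConf = TrackerConf(60 * 1000L, "scala")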
@@ -18,12 +18,14 @@ package ml.dmlc.xgboost4j.scala.spark

 import scala.collection.mutable

-import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, GeneralParams, LearningTaskParams}
+import ml.dmlc.xgboost4j.scala.spark.params._
+import org.json4s.DefaultFormats

 import org.apache.spark.ml.Predictor
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector => MLVector}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.sql.{Dataset, Row}
@@ -34,7 +36,7 @@ import org.apache.spark.sql.{Dataset, Row}
 class XGBoostEstimator private[spark](
   override val uid: String, xgboostParams: Map[String, Any])
   extends Predictor[MLVector, XGBoostEstimator, XGBoostModel]
-  with LearningTaskParams with GeneralParams with BoosterParams {
+  with LearningTaskParams with GeneralParams with BoosterParams with MLWritable {

   def this(xgboostParams: Map[String, Any]) =
     this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any])
@@ -129,4 +131,38 @@ class XGBoostEstimator private[spark](
   override def copy(extra: ParamMap): XGBoostEstimator = {
     defaultCopy(extra).asInstanceOf[XGBoostEstimator]
   }
+
+  override def write: MLWriter = new XGBoostEstimator.XGBoostEstimatorWriter(this)
 }
+
+object XGBoostEstimator extends MLReadable[XGBoostEstimator] {
+
+  override def read: MLReader[XGBoostEstimator] = new XGBoostEstimatorReader
+
+  override def load(path: String): XGBoostEstimator = super.load(path)
+
+  private[XGBoostEstimator] class XGBoostEstimatorWriter(instance: XGBoostEstimator)
+    extends MLWriter {
+    override protected def saveImpl(path: String): Unit = {
+      require(instance.fromParamsToXGBParamMap("custom_eval") == null &&
+        instance.fromParamsToXGBParamMap("custom_obj") == null,
+        "we do not support persist XGBoostEstimator with customized evaluator and objective" +
+          " function for now")
+      implicit val format = DefaultFormats
+      implicit val sc = super.sparkSession.sparkContext
+      DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
+    }
+  }
+
+  private class XGBoostEstimatorReader extends MLReader[XGBoostEstimator] {
+
+    override def load(path: String): XGBoostEstimator = {
+      val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
+      val cls = Utils.classForName(metadata.className)
+      val instance =
+        cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
+      DefaultXGBoostParamsReader.getAndSetParams(instance, metadata)
+      instance.asInstanceOf[XGBoostEstimator]
+    }
+  }
+}
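A consequence of the require() in saveImpl above: an estimator configured with a custom objective or evaluation function cannot be persisted, since those are arbitrary user classes with no JSON encoding. A sketch of what a caller should expect; MyCustomObjective is a hypothetical ObjectiveTrait implementation:

    // Hypothetical: MyCustomObjective extends ml.dmlc.xgboost4j.scala.ObjectiveTrait.
    val est = new XGBoostEstimator(Map("eta" -> "0.1", "custom_obj" -> new MyCustomObjective))
    // est.write.save(...) is expected to fail the require() above, because
    // custom_obj / custom_eval cannot be written as JSON metadata.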
@@ -324,14 +324,13 @@ object XGBoostModel extends MLReadable[XGBoostModel] {
       implicit val format = DefaultFormats
       implicit val sc = super.sparkSession.sparkContext
       DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)

       val dataPath = new Path(path, "data").toString
       instance.saveModelAsHadoopFile(dataPath)
     }
   }

   private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
     private val className = classOf[XGBoostModel].getName

     override def load(path: String): XGBoostModel = {
       implicit val sc = super.sparkSession.sparkContext
       val dataPath = new Path(path, "data").toString
@@ -340,5 +339,4 @@ object XGBoostModel extends MLReadable[XGBoostModel] {
       XGBoost.loadModelFromHadoopFile(dataPath)
     }
   }

 }
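For orientation, the writer/reader pair above gives a saved model directory two parts: the JSON params written by DefaultXGBoostParamsWriter.saveMetadata under "metadata", and the booster written by saveModelAsHadoopFile under "data". A minimal round trip, mirroring the tests later in this diff (path illustrative):

    // model is a fitted XGBoostModel; the output path is illustrative.
    model.write.overwrite().save("/tmp/xgbModel")
    val restored = ml.dmlc.xgboost4j.scala.spark.XGBoostModel.load("/tmp/xgbModel")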
@@ -0,0 +1,106 @@ (new file)
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.params

import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
import org.json4s.jackson.JsonMethods.{compact, parse, render}

import org.apache.spark.ml.param.{Param, ParamPair, Params}

class GroupDataParam(
    parent: Params,
    name: String,
    doc: String) extends Param[Seq[Seq[Int]]](parent, name, doc) {

  /** Creates a param pair with the given value (for Java). */
  override def w(value: Seq[Seq[Int]]): ParamPair[Seq[Seq[Int]]] = super.w(value)

  override def jsonEncode(value: Seq[Seq[Int]]): String = {
    import org.json4s.jackson.Serialization
    implicit val formats = Serialization.formats(NoTypeHints)
    compact(render(Extraction.decompose(value)))
  }

  override def jsonDecode(json: String): Seq[Seq[Int]] = {
    implicit val formats = DefaultFormats
    parse(json).extract[Seq[Seq[Int]]]
  }
}

class CustomEvalParam(
    parent: Params,
    name: String,
    doc: String) extends Param[EvalTrait](parent, name, doc) {

  /** Creates a param pair with the given value (for Java). */
  override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value)

  override def jsonEncode(value: EvalTrait): String = {
    import org.json4s.jackson.Serialization
    implicit val formats = Serialization.formats(NoTypeHints)
    compact(render(Extraction.decompose(value)))
  }

  override def jsonDecode(json: String): EvalTrait = {
    implicit val formats = DefaultFormats
    parse(json).extract[EvalTrait]
  }
}

class CustomObjParam(
    parent: Params,
    name: String,
    doc: String) extends Param[ObjectiveTrait](parent, name, doc) {

  /** Creates a param pair with the given value (for Java). */
  override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value)

  override def jsonEncode(value: ObjectiveTrait): String = {
    import org.json4s.jackson.Serialization
    implicit val formats = Serialization.formats(NoTypeHints)
    compact(render(Extraction.decompose(value)))
  }

  override def jsonDecode(json: String): ObjectiveTrait = {
    implicit val formats = DefaultFormats
    parse(json).extract[ObjectiveTrait]
  }
}

class TrackerConfParam(
    parent: Params,
    name: String,
    doc: String) extends Param[TrackerConf](parent, name, doc) {

  /** Creates a param pair with the given value (for Java). */
  override def w(value: TrackerConf): ParamPair[TrackerConf] = super.w(value)

  override def jsonEncode(value: TrackerConf): String = {
    import org.json4s.jackson.Serialization
    implicit val formats = Serialization.formats(NoTypeHints)
    compact(render(Extraction.decompose(value)))
  }

  override def jsonDecode(json: String): TrackerConf = {
    implicit val formats = DefaultFormats
    val parsedValue = parse(json)
    println(parsedValue.children)
    parsedValue.extract[TrackerConf]
  }
}
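These Param subclasses exist so that the XGBoost-specific param types survive the JSON round trip performed by DefaultXGBoostParamsWriter/Reader. A small json4s-only sketch of the encoding TrackerConfParam relies on (values illustrative); note that it only works because TrackerConf now carries a plain Long:

    import ml.dmlc.xgboost4j.scala.spark.TrackerConf
    import org.json4s.{Extraction, NoTypeHints}
    import org.json4s.jackson.JsonMethods.{compact, parse, render}
    import org.json4s.jackson.Serialization

    implicit val formats = Serialization.formats(NoTypeHints)
    // What jsonEncode produces, e.g. {"workerConnectionTimeout":60000,"trackerImpl":"scala"}
    val json = compact(render(Extraction.decompose(TrackerConf(60000L, "scala"))))
    // ... and jsonDecode reverses it.
    val roundTripped = parse(json).extract[TrackerConf]
    assert(roundTripped == TrackerConf(60000L, "scala"))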
@@ -0,0 +1,136 @@ (new file)
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JValue}
import org.json4s.JsonAST.JObject
import org.json4s.jackson.JsonMethods.{compact, parse, render}

import org.apache.spark.SparkContext
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.MLReader

// This originates from a copy-paste of apache-spark's DefaultParamsReader
private[spark] object DefaultXGBoostParamsReader {

  /**
   * All info from metadata file.
   *
   * @param params paramMap, as a `JValue`
   * @param metadata All metadata, including the other fields
   * @param metadataJson Full metadata file String (for debugging)
   */
  case class Metadata(
      className: String,
      uid: String,
      timestamp: Long,
      sparkVersion: String,
      params: JValue,
      metadata: JValue,
      metadataJson: String) {

    /**
     * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
     * This can be useful for getting a Param value before an instance of `Params`
     * is available.
     */
    def getParamValue(paramName: String): JValue = {
      implicit val format = DefaultFormats
      params match {
        case JObject(pairs) =>
          val values = pairs.filter { case (pName, jsonValue) =>
            pName == paramName
          }.map(_._2)
          assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
            s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
          values.head
        case _ =>
          throw new IllegalArgumentException(
            s"Cannot recognize JSON metadata: $metadataJson.")
      }
    }
  }

  /**
   * Load metadata saved using [[DefaultXGBoostParamsWriter.saveMetadata()]]
   *
   * @param expectedClassName If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    parseMetadata(metadataStr, expectedClassName)
  }

  /**
   * Parse metadata JSON string produced by [[DefaultXGBoostParamsWriter.getMetadataToSave()]].
   * This is a helper function for [[loadMetadata()]].
   *
   * @param metadataStr JSON string of metadata
   * @param expectedClassName If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
  }

  /**
   * Extract Params from metadata, and set them in the instance.
   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
   * TODO: Move to [[Metadata]] method
   */
  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
    implicit val format = DefaultFormats
    metadata.params match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
    }
  }

  /**
   * Load a `Params` instance from the given path, and return it.
   * This assumes the instance implements [[org.apache.spark.ml.util.MLReadable]].
   */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
    val cls = Utils.classForName(metadata.className)
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}
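For reference, parseMetadata above implies the on-disk "metadata" part holds a single JSON line of roughly the following shape; the keys come from the code, the values here are purely illustrative:

    {"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator",
     "timestamp":1490000000000,
     "sparkVersion":"2.1.0",
     "uid":"XGBoostEstimator_0123456789ab",
     "paramMap":{"round":1,"nWorkers":1,"eta":0.1}}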
@@ -46,6 +46,7 @@ private[spark] object DefaultXGBoostParamsWriter {
       sc: SparkContext,
       extraMetadata: Option[JObject] = None,
       paramMap: Option[JValue] = None): Unit = {

     val metadataPath = new Path(path, "metadata").toString
     val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
     sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -65,7 +66,9 @@ private[spark] object DefaultXGBoostParamsWriter {
     val uid = instance.uid
     val cls = instance.getClass.getName
     val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
-    val jsonParams = paramMap.getOrElse(render(params.map {
+    val jsonParams = paramMap.getOrElse(render(params.filter{
+      case ParamPair(p, _) => p != null
+    }.map {
       case ParamPair(p, v) =>
         p.name -> parse(p.jsonEncode(v))
     }.toList))
@@ -20,8 +20,6 @@ import ml.dmlc.xgboost4j.scala.spark.TrackerConf
 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import org.apache.spark.ml.param._

-import scala.concurrent.duration.{Duration, NANOSECONDS}
-
 trait GeneralParams extends Params {

   /**
@@ -58,13 +56,13 @@ trait GeneralParams extends Params {
   /**
    * customized objective function provided by user. default: null
    */
-  val customObj = new Param[ObjectiveTrait](this, "custom_obj", "customized objective function " +
+  val customObj = new CustomObjParam(this, "custom_obj", "customized objective function " +
     "provided by user")

   /**
    * customized evaluation function provided by user. default: null
    */
-  val customEval = new Param[EvalTrait](this, "custom_eval", "customized evaluation function " +
+  val customEval = new CustomEvalParam(this, "custom_eval", "customized evaluation function " +
     "provided by user")

   /**
@@ -99,7 +97,7 @@ trait GeneralParams extends Params {
    * Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
    * Ignored if the tracker implementation is "python".
    */
-  val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations")
+  val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")

   setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
     useExternalMemory -> false, silent -> 0,
@@ -57,7 +57,7 @@ trait LearningTaskParams extends Params {
    * group data specify each group sizes for ranking task. To correspond to partition of
    * training data, it is nested.
    */
-  val groupData = new Param[Seq[Seq[Int]]](this, "groupData", "group data specify each group size" +
+  val groupData = new GroupDataParam(this, "groupData", "group data specify each group size" +
     " for ranking task. To correspond to partition of training data, it is nested.")

   setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null)
@@ -18,17 +18,21 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.File

 import scala.collection.mutable
 import scala.collection.mutable.ListBuffer
 import scala.io.Source
 import scala.util.Random

 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}

 import org.apache.spark.SparkContext
 import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.feature.{LabeledPoint, VectorAssembler}
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
 import org.apache.spark.sql._

 class XGBoostDFSuite extends SharedSparkContext with Utils {
@@ -110,7 +110,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic",
-      "tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
+      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap
     val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
       nWorkers = numWorkers)
     assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
@@ -18,67 +18,84 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.{File, FileNotFoundException}

 import scala.util.Random

 import org.apache.spark.SparkConf
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.SparkSession
-import scala.concurrent.duration._

 case class Foobar(TARGET: Int, bar: Double, baz: Double)

 class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {

   override def afterAll(): Unit = {
     super.afterAll()
     delete(new File("./testxgbPipe"))
-    delete(new File("./test2xgbPipe"))
+    delete(new File("./testxgbEst"))
+    delete(new File("./testxgbModel"))
+    delete(new File("./test2xgbModel"))
   }
   private def delete(f: File) {
-    if (f.isDirectory()) {
-      for (c <- f.listFiles()) {
-        delete(c)
+    if (f.exists()) {
+      if (f.isDirectory()) {
+        for (c <- f.listFiles()) {
+          delete(c)
+        }
+      }
+      if (!f.delete()) {
+        throw new FileNotFoundException("Failed to delete file: " + f)
+      }
+    }
-    if (!f.delete()) {
-      throw new FileNotFoundException("Failed to delete file: " + f)
-    }
   }

-  test("test sparks pipeline persistence of dataframe-based model") {
-    // maybe move to shared context, but requires session to import implicits.
-    // what about introducing https://github.com/holdenk/spark-testing-base ?
-    val conf: SparkConf = new SparkConf()
-      .setAppName("foo")
-      .setMaster("local[*]")
+  test("test persistence of XGBoostEstimator") {
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6")
+    val xgbEstimator = new XGBoostEstimator(paramMap)
+    xgbEstimator.write.overwrite().save("./testxgbEst")
+    val loadedxgbEstimator = XGBoostEstimator.read.load("./testxgbEst")
+    val loadedParamMap = loadedxgbEstimator.fromParamsToXGBParamMap
+    paramMap.foreach {
+      case (k, v) => assert(v == loadedParamMap(k).toString)
+    }
+  }

-    val spark: SparkSession = SparkSession
-      .builder()
-      .config(conf)
-      .getOrCreate()
+  test("test persistence of a complete pipeline") {
+    val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
+    val spark = SparkSession.builder().config(conf).getOrCreate()
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6")
+    val r = new Random(0)
+    val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
+    val xgbEstimator = new XGBoostEstimator(paramMap)
+    val pipeline = new Pipeline().setStages(Array(assembler, xgbEstimator))
+    pipeline.write.overwrite().save("testxgbPipe")
+    val loadedPipeline = Pipeline.read.load("testxgbPipe")
+    val loadedEstimator = loadedPipeline.getStages(1).asInstanceOf[XGBoostEstimator]
+    val loadedParamMap = loadedEstimator.fromParamsToXGBParamMap
+    paramMap.foreach {
+      case (k, v) => assert(v == loadedParamMap(k).toString)
+    }
+  }

-  import spark.implicits._
+  test("test persistence of XGBoostModel") {
+    val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
+    val spark = SparkSession.builder().config(conf).getOrCreate()
+    val r = new Random(0)
+    // maybe move to shared context, but requires session to import implicits

-    val df = Seq(Foobar(0, 0.5, 1), Foobar(1, 0.01, 0.8),
-      Foobar(0, 0.8, 0.5), Foobar(1, 8.4, 0.04))
-      .toDS

+    val df = spark.createDataFrame(Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))).
+      toDF("feature", "label")
     val vectorAssembler = new VectorAssembler()
       .setInputCols(df.columns
-        .filter(!_.contains("TARGET")))
+        .filter(!_.contains("label")))
       .setOutputCol("features")

     val xgbEstimator = new XGBoostEstimator(Map("num_rounds" -> 10,
-      "tracker_conf" -> TrackerConf(1 minute, "scala")
-    ))
-      .setFeaturesCol("features")
-      .setLabelCol("TARGET")

+      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")
+    )).setFeaturesCol("features").setLabelCol("label")
     // separate
     val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
-    predModel.write.overwrite.save("test2xgbPipe")
-    val same2Model = XGBoostModel.load("test2xgbPipe")
+    predModel.write.overwrite.save("test2xgbModel")
+    val same2Model = XGBoostModel.load("test2xgbModel")

     assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
     val predParamMap = predModel.extractParamMap()
@@ -93,8 +110,8 @@ class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {

     // chained
     val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
-    predictionModel.write.overwrite.save("testxgbPipe")
-    val sameModel = PipelineModel.load("testxgbPipe")
+    predictionModel.write.overwrite.save("testxgbModel")
+    val sameModel = PipelineModel.load("testxgbModel")

     val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb } head
     val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb } head