[jvm-packages] Spark pipeline persistence (#1906)

commit cf6b173bd7 (parent 5b54b9437c)
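In short: XGBoostModel now mixes in MLWritable and gains an MLReadable companion object, so a trained model, alone or inside a Spark PipelineModel, can be saved and reloaded. Metadata is written by a DefaultXGBoostParamsWriter copied from Spark, the booster itself goes to a platform-independent xgboost file under data/, and Boolean params move from Param[Boolean] to BooleanParam (with useExternalMemory renamed to use_external_memory), presumably so they JSON-encode cleanly during persistence. A new test round-trips both a bare model and a full pipeline.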
@@ -17,10 +17,9 @@
 package ml.dmlc.xgboost4j.scala.spark

 import scala.collection.mutable

 import ml.dmlc.xgboost4j.scala.Booster
 import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
-import org.apache.spark.ml.param.{DoubleArrayParam, Param, ParamMap}
+import org.apache.spark.ml.param.{BooleanParam, DoubleArrayParam, Param, ParamMap}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -43,7 +42,7 @@ class XGBoostClassificationModel private[spark](

   /**
    * whether to output raw margin
    */
-  final val outputMargin: Param[Boolean] = new Param[Boolean](this, "outputMargin", "whether to output untransformed margin value ")
+  final val outputMargin = new BooleanParam(this, "outputMargin", "whether to output untransformed margin value")

   setDefault(outputMargin, false)
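The switch to BooleanParam matters for the persistence added below: as far as I recall from Spark 2.x, the generic Param[T].jsonEncode only handles a few types (strings, vectors) and throws NotImplementedError otherwise, while the typed BooleanParam overrides it. A minimal sketch, assuming that Spark behavior:

    import org.apache.spark.ml.param.BooleanParam
    import org.apache.spark.ml.util.Identifiable

    // Sketch only. A plain Param[Boolean] would hit Param.jsonEncode's
    // NotImplementedError during metadata serialization; the typed
    // BooleanParam round-trips through JSON:
    val p = new BooleanParam(Identifiable.randomUID("demo"), "outputMargin",
      "whether to output untransformed margin value")
    assert(p.jsonEncode(false) == "false")
    assert(p.jsonDecode("false") == false)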
@@ -17,30 +17,32 @@
 package ml.dmlc.xgboost4j.scala.spark

 import scala.collection.JavaConverters._

-import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit}
+import ml.dmlc.xgboost4j.scala.spark.params.DefaultXGBoostParamsWriter
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
 import org.apache.hadoop.fs.{FSDataOutputStream, Path}
 import org.apache.spark.ml.PredictionModel
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
-import org.apache.spark.ml.param.{Param, Params}
+import org.apache.spark.ml.param.{BooleanParam, ParamMap, Params}
+import org.apache.spark.ml.util._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.types.{ArrayType, FloatType}
 import org.apache.spark.{SparkContext, TaskContext}
+import org.json4s.DefaultFormats

 /**
  * the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
  */
 abstract class XGBoostModel(protected var _booster: Booster)
-  extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params {
+  extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params with MLWritable {

   def setLabelCol(name: String): XGBoostModel = set(labelCol, name)

   // scalastyle:off

-  final val useExternalMemory: Param[Boolean] = new Param[Boolean](this, "useExternalMemory", "whether to use external memory for prediction")
+  final val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external memory for prediction")

   setDefault(useExternalMemory, false)
@@ -295,4 +297,38 @@ abstract class XGBoostModel(protected var _booster: Booster)
   }

   def booster: Booster = _booster
+
+  override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)
+
+  override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)
+}
+
+object XGBoostModel extends MLReadable[XGBoostModel] {
+
+  override def read: MLReader[XGBoostModel] = new XGBoostModelModelReader
+
+  override def load(path: String): XGBoostModel = super.load(path)
+
+  private[XGBoostModel] class XGBoostModelModelWriter(instance: XGBoostModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      implicit val format = DefaultFormats
+      implicit val sc = super.sparkSession.sparkContext
+      DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
+
+      val dataPath = new Path(path, "data").toString
+      instance.saveModelAsHadoopFile(dataPath)
+    }
+  }
+
+  private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
+
+    private val className = classOf[XGBoostModel].getName
+
+    override def load(path: String): XGBoostModel = {
+      implicit val sc = super.sparkSession.sparkContext
+      val dataPath = new Path(path, "data").toString
+      // not used / all data resides in the platform-independent xgboost model file
+      // val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
+      XGBoost.loadModelFromHadoopFile(dataPath)
+    }
+  }
 }
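Taken together, the writer and reader give the model the standard Spark save/load surface. A minimal sketch of the round trip (path illustrative; `trained` is assumed to be a fitted XGBoostModel):

    // Hypothetical round trip; `trained` comes from e.g. XGBoostEstimator.fit(...).
    trained.write.overwrite.save("/tmp/xgbModel")
    // /tmp/xgbModel/metadata  - params JSON from DefaultXGBoostParamsWriter
    // /tmp/xgbModel/data      - the platform-independent xgboost booster file
    val restored: XGBoostModel = XGBoostModel.load("/tmp/xgbModel")
    assert(java.util.Arrays.equals(trained.booster.toByteArray, restored.booster.toByteArray))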
@@ -0,0 +1,86 @@ (new file in ml.dmlc.xgboost4j.scala.spark.params)
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.param.{ParamPair, Params}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+import org.json4s.{JObject, _}
+
+// This originates from a copy-paste of apache-spark's DefaultParamsWriter
+private[spark] object DefaultXGBoostParamsWriter {
+
+  /**
+   * Saves metadata + Params to: path + "/metadata"
+   *  - class
+   *  - timestamp
+   *  - sparkVersion
+   *  - uid
+   *  - paramMap
+   *  - (optionally, extra metadata)
+   *
+   * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
+   * @param paramMap If given, this is saved in the "paramMap" field.
+   *                 Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
+   *                 [[org.apache.spark.ml.param.Param.jsonEncode()]].
+   */
+  def saveMetadata(
+      instance: Params,
+      path: String,
+      sc: SparkContext,
+      extraMetadata: Option[JObject] = None,
+      paramMap: Option[JValue] = None): Unit = {
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
+    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+  }
+
+  /**
+   * Helper for [[saveMetadata()]] which extracts the JSON to save.
+   * This is useful for ensemble models which need to save metadata for many sub-models.
+   *
+   * @see [[saveMetadata()]] for details on what this includes.
+   */
+  def getMetadataToSave(
+      instance: Params,
+      sc: SparkContext,
+      extraMetadata: Option[JObject] = None,
+      paramMap: Option[JValue] = None): String = {
+    val uid = instance.uid
+    val cls = instance.getClass.getName
+    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+    val jsonParams = paramMap.getOrElse(render(params.map {
+      case ParamPair(p, v) =>
+        p.name -> parse(p.jsonEncode(v))
+    }.toList))
+    val basicMetadata = ("class" -> cls) ~
+      ("timestamp" -> System.currentTimeMillis()) ~
+      ("sparkVersion" -> sc.version) ~
+      ("uid" -> uid) ~
+      ("paramMap" -> jsonParams)
+    val metadata = extraMetadata match {
+      case Some(jObject) =>
+        basicMetadata ~ jObject
+      case None =>
+        basicMetadata
+    }
+    val metadataJson: String = compact(render(metadata))
+    metadataJson
+  }
+}
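For orientation, getMetadataToSave renders a single compact JSON line; a sketch with made-up values for a saved model:

    val metadataJson = DefaultXGBoostParamsWriter.getMetadataToSave(model, sc)
    // e.g. (all values invented for illustration):
    // {"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel",
    //  "timestamp":1490000000000,"sparkVersion":"2.1.0",
    //  "uid":"XGBoostClassificationModel_4f8a9c2b1d3e",
    //  "paramMap":{"use_external_memory":false,"outputMargin":false}}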
@@ -0,0 +1,33 @@ (new file in ml.dmlc.xgboost4j.scala.spark.params)
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+// based on a copy-paste from org.apache.spark.util
+private[spark] object Utils {
+
+  def getSparkClassLoader: ClassLoader = getClass.getClassLoader
+
+  def getContextOrSparkClassLoader: ClassLoader =
+    Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)
+
+  // scalastyle:off classforname
+  /** Preferred alternative to Class.forName(className) */
+  def classForName(className: String): Class[_] = {
+    Class.forName(className, true, getContextOrSparkClassLoader)
+    // scalastyle:on classforname
+  }
+}
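Utils is not referenced by the writer above; presumably it is groundwork for a matching params reader, which would need to resolve the persisted "class" field by name. A sketch of that intended use (class name illustrative):

    // Resolve a class from saved metadata, preferring the thread's context
    // classloader and falling back to this jar's own loader.
    val cls: Class[_] = Utils.classForName("ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator")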
@@ -0,0 +1,101 @@ (new test in ml.dmlc.xgboost4j.scala.spark)
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import org.apache.spark.SparkConf
+import org.apache.spark.ml.feature._
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.sql.SparkSession
+
+import scala.concurrent.duration._
+
+case class Foobar(TARGET: Int, bar: Double, baz: Double)
+
+class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
+
+  test("test sparks pipeline persistence of dataframe-based model") {
+    // maybe move to shared context, but requires session to import implicits.
+    // what about introducing https://github.com/holdenk/spark-testing-base ?
+    val conf: SparkConf = new SparkConf()
+      .setAppName("foo")
+      .setMaster("local[*]")
+
+    val spark: SparkSession = SparkSession
+      .builder()
+      .config(conf)
+      .getOrCreate()
+
+    import spark.implicits._
+
+    val df = Seq(Foobar(0, 0.5, 1), Foobar(1, 0.01, 0.8),
+      Foobar(0, 0.8, 0.5), Foobar(1, 8.4, 0.04))
+      .toDS
+
+    val vectorAssembler = new VectorAssembler()
+      .setInputCols(df.columns
+        .filter(!_.contains("TARGET")))
+      .setOutputCol("features")
+
+    val xgbEstimator = new XGBoostEstimator(Map("num_rounds" -> 10,
+      "tracker_conf" -> TrackerConf(1 minute, "scala")
+    ))
+      .setFeaturesCol("features")
+      .setLabelCol("TARGET")
+
+    // separate: persist the fitted model on its own
+    val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
+    predModel.write.overwrite.save("test2xgbPipe")
+    val same2Model = XGBoostModel.load("test2xgbPipe")
+
+    assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
+    val predParamMap = predModel.extractParamMap()
+    val same2ParamMap = same2Model.extractParamMap()
+    assert(predParamMap.get(predModel.useExternalMemory)
+      === same2ParamMap.get(same2Model.useExternalMemory))
+    assert(predParamMap.get(predModel.featuresCol) === same2ParamMap.get(same2Model.featuresCol))
+    assert(predParamMap.get(predModel.predictionCol)
+      === same2ParamMap.get(same2Model.predictionCol))
+    assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol))
+
+    // chained: persist the whole pipeline
+    val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
+    predictionModel.write.overwrite.save("testxgbPipe")
+    val sameModel = PipelineModel.load("testxgbPipe")
+
+    val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb }.head
+    val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb }.head
+
+    assert(java.util.Arrays.equals(
+      predictionModelXGB.booster.toByteArray,
+      sameModelXGB.booster.toByteArray
+    ))
+    val predictionModelXGBParamMap = predictionModel.extractParamMap()
+    val sameModelXGBParamMap = sameModel.extractParamMap()
+    assert(predictionModelXGBParamMap.get(predictionModelXGB.useExternalMemory)
+      === sameModelXGBParamMap.get(sameModelXGB.useExternalMemory))
+    assert(predictionModelXGBParamMap.get(predictionModelXGB.featuresCol)
+      === sameModelXGBParamMap.get(sameModelXGB.featuresCol))
+    assert(predictionModelXGBParamMap.get(predictionModelXGB.predictionCol)
+      === sameModelXGBParamMap.get(sameModelXGB.predictionCol))
+    assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol)
+      === sameModelXGBParamMap.get(sameModelXGB.labelCol))
+  }
+}