[jvm-packages] Spark pipeline persistence (#1906)

geoHeil 2017-03-06 03:35:37 +01:00 committed by Nan Zhu
parent 5b54b9437c
commit cf6b173bd7
5 changed files with 263 additions and 8 deletions

View File

@@ -17,10 +17,9 @@
 package ml.dmlc.xgboost4j.scala.spark
 import scala.collection.mutable
 import ml.dmlc.xgboost4j.scala.Booster
 import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
-import org.apache.spark.ml.param.{DoubleArrayParam, Param, ParamMap}
+import org.apache.spark.ml.param.{BooleanParam, DoubleArrayParam, Param, ParamMap}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -43,7 +42,7 @@ class XGBoostClassificationModel private[spark](
   /**
    * whether to output raw margin
    */
-  final val outputMargin: Param[Boolean] = new Param[Boolean](this, "outputMargin", "whether to output untransformed margin value ")
+  final val outputMargin = new BooleanParam(this, "outputMargin", "whether to output untransformed margin value")
   setDefault(outputMargin, false)
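
Note: the move from Param[Boolean] to BooleanParam is what makes the new persistence work. Spark's generic Param[T].jsonEncode only knows how to encode String and Vector values and throws NotImplementedError for anything else, while the typed BooleanParam overrides jsonEncode/jsonDecode. A minimal sketch of the difference (uid and param name are illustrative; behavior as in Spark 2.x):

import org.apache.spark.ml.param.{BooleanParam, Param}

object ParamEncodeSketch extends App {
  // generic Param[T]: the default jsonEncode handles only String and Vector,
  // so encoding a Boolean fails, which would break DefaultXGBoostParamsWriter
  val generic = new Param[Boolean]("demoUid", "outputMargin", "demo doc")
  try generic.jsonEncode(false)
  catch { case e: NotImplementedError => println(s"generic Param: $e") }

  // BooleanParam overrides jsonEncode/jsonDecode, so it serializes cleanly
  // into the "paramMap" field of the persisted metadata JSON
  val typed = new BooleanParam("demoUid", "outputMargin", "demo doc")
  println(typed.jsonEncode(false)) // prints: false
}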

View File

@@ -17,30 +17,32 @@
 package ml.dmlc.xgboost4j.scala.spark
 import scala.collection.JavaConverters._
-import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit}
+import ml.dmlc.xgboost4j.scala.spark.params.DefaultXGBoostParamsWriter
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
 import org.apache.hadoop.fs.{FSDataOutputStream, Path}
 import org.apache.spark.ml.PredictionModel
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
-import org.apache.spark.ml.param.{Param, Params}
+import org.apache.spark.ml.param.{BooleanParam, ParamMap, Params}
+import org.apache.spark.ml.util._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.types.{ArrayType, FloatType}
 import org.apache.spark.{SparkContext, TaskContext}
+import org.json4s.DefaultFormats

 /**
  * the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
  */
 abstract class XGBoostModel(protected var _booster: Booster)
-  extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params {
+  extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params with MLWritable {
   def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
   // scalastyle:off
-  final val useExternalMemory: Param[Boolean] = new Param[Boolean](this, "useExternalMemory", "whether to use external memory for prediction")
+  final val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external memory for prediction")
   setDefault(useExternalMemory, false)
@@ -295,4 +297,38 @@ abstract class XGBoostModel(protected var _booster: Booster)
   }
   def booster: Booster = _booster
+  override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)
+  override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)
+}
+
+object XGBoostModel extends MLReadable[XGBoostModel] {
+
+  override def read: MLReader[XGBoostModel] = new XGBoostModelModelReader
+
+  override def load(path: String): XGBoostModel = super.load(path)
+
+  private[XGBoostModel] class XGBoostModelModelWriter(instance: XGBoostModel) extends MLWriter {
+    override protected def saveImpl(path: String): Unit = {
+      implicit val format = DefaultFormats
+      implicit val sc = super.sparkSession.sparkContext
+      DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
+      val dataPath = new Path(path, "data").toString
+      instance.saveModelAsHadoopFile(dataPath)
+    }
+  }
+
+  private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
+    private val className = classOf[XGBoostModel].getName
+    override def load(path: String): XGBoostModel = {
+      implicit val sc = super.sparkSession.sparkContext
+      val dataPath = new Path(path, "data").toString
+      // not used / all data resides in platform independent xgboost model file
+      // val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
+      XGBoost.loadModelFromHadoopFile(dataPath)
+    }
+  }
 }
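
Note: taken together, write stores the Spark-level metadata under <path>/metadata and delegates the booster itself to saveModelAsHadoopFile under <path>/data, while read only reverses the booster step. A minimal sketch of the new round trip (xgbEstimator, trainingData, and the path are illustrative, mirroring the test below):

// sketch: persist a fitted model through the standard Spark ML API
val model: XGBoostModel = xgbEstimator.fit(trainingData)
model.write.overwrite().save("/tmp/xgbModel")

// reload it later; the booster bytes come back from <path>/data via
// XGBoost.loadModelFromHadoopFile
val restored: XGBoostModel = XGBoostModel.load("/tmp/xgbModel")
assert(java.util.Arrays.equals(model.booster.toByteArray, restored.booster.toByteArray))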

View File

@@ -0,0 +1,86 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.{ParamPair, Params}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.json4s.{JObject, _}
// This originates from a copy-paste of apache-spark's DefaultParamsWriter
private[spark] object DefaultXGBoostParamsWriter {

  /**
   * Saves metadata + Params to: path + "/metadata"
   *  - class
   *  - timestamp
   *  - sparkVersion
   *  - uid
   *  - paramMap
   *  - (optionally, extra metadata)
   *
   * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
   * @param paramMap If given, this is saved in the "paramMap" field.
   *                 Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
   *                 [[org.apache.spark.ml.param.Param.jsonEncode()]].
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }

  /**
   * Helper for [[saveMetadata()]] which extracts the JSON to save.
   * This is useful for ensemble models which need to save metadata for many sub-models.
   *
   * @see [[saveMetadata()]] for details on what this includes.
   */
  def getMetadataToSave(
      instance: Params,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): String = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    val jsonParams = paramMap.getOrElse(render(params.map {
      case ParamPair(p, v) =>
        p.name -> parse(p.jsonEncode(v))
    }.toList))
    val basicMetadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    val metadataJson: String = compact(render(metadata))
    metadataJson
  }
}
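
Note: the writer emits the metadata as a single compact JSON line under <path>/metadata. A sketch of inspecting it before anything is written (model and spark are assumed to be a fitted XGBoostModel and an active SparkSession; the printed shape is illustrative, exact uid and timestamp vary):

// sketch: render the metadata JSON for a fitted model without writing files
val json: String = DefaultXGBoostParamsWriter.getMetadataToSave(model, spark.sparkContext)
println(json)
// roughly: {"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostModel",
//           "timestamp":1488765300000,"sparkVersion":"2.1.0",
//           "uid":"...","paramMap":{"use_external_memory":false,...}}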

View File

@@ -0,0 +1,33 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
// based on a copy-paste of org.apache.spark.util.Utils
private[spark] object Utils {

  def getSparkClassLoader: ClassLoader = getClass.getClassLoader

  def getContextOrSparkClassLoader: ClassLoader =
    Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)

  // scalastyle:off classforname
  /** Preferred alternative to Class.forName(className) */
  def classForName(className: String): Class[_] = {
    Class.forName(className, true, getContextOrSparkClassLoader)
    // scalastyle:on classforname
  }
}
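
Note: classForName resolves against the thread's context class loader first, which matters on Spark executors where user classes may only be visible to a child class loader. A short sketch of the intended use, feeding it the "class" field that saveMetadata records (this reader-side wiring is an assumption; the loader in this commit does not consume the metadata yet):

// sketch: re-instantiate the persisted class from the name stored in metadata
val className = "ml.dmlc.xgboost4j.scala.spark.XGBoostModel" // from the metadata "class" field
val cls: Class[_] = Utils.classForName(className)
println(cls.getName)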

View File

@@ -0,0 +1,101 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.SparkConf
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession
import scala.concurrent.duration._
case class Foobar(TARGET: Int, bar: Double, baz: Double)
class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {

  test("test spark's pipeline persistence of dataframe-based model") {
    // maybe move to shared context, but requires session to import implicits.
    // what about introducing https://github.com/holdenk/spark-testing-base ?
    val conf: SparkConf = new SparkConf()
      .setAppName("foo")
      .setMaster("local[*]")
    val spark: SparkSession = SparkSession
      .builder()
      .config(conf)
      .getOrCreate()
    import spark.implicits._

    // maybe move to shared context, but requires session to import implicits
    val df = Seq(Foobar(0, 0.5, 1), Foobar(1, 0.01, 0.8),
      Foobar(0, 0.8, 0.5), Foobar(1, 8.4, 0.04))
      .toDS

    val vectorAssembler = new VectorAssembler()
      .setInputCols(df.columns.filter(!_.contains("TARGET")))
      .setOutputCol("features")
    val xgbEstimator = new XGBoostEstimator(Map("num_rounds" -> 10,
      "tracker_conf" -> TrackerConf(1 minute, "scala")))
      .setFeaturesCol("features")
      .setLabelCol("TARGET")

    // separate: persist and reload the fitted model on its own
    val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
    predModel.write.overwrite.save("test2xgbPipe")
    val same2Model = XGBoostModel.load("test2xgbPipe")

    assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
    val predParamMap = predModel.extractParamMap()
    val same2ParamMap = same2Model.extractParamMap()
    assert(predParamMap.get(predModel.useExternalMemory)
      === same2ParamMap.get(same2Model.useExternalMemory))
    assert(predParamMap.get(predModel.featuresCol) === same2ParamMap.get(same2Model.featuresCol))
    assert(predParamMap.get(predModel.predictionCol)
      === same2ParamMap.get(same2Model.predictionCol))
    assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol))

    // chained: persist and reload the model inside a full pipeline
    val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
    predictionModel.write.overwrite.save("testxgbPipe")
    val sameModel = PipelineModel.load("testxgbPipe")

    val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb }.head
    val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb }.head
    assert(java.util.Arrays.equals(
      predictionModelXGB.booster.toByteArray,
      sameModelXGB.booster.toByteArray))
    val predictionModelXGBParamMap = predictionModel.extractParamMap()
    val sameModelXGBParamMap = sameModel.extractParamMap()
    assert(predictionModelXGBParamMap.get(predictionModelXGB.useExternalMemory)
      === sameModelXGBParamMap.get(sameModelXGB.useExternalMemory))
    assert(predictionModelXGBParamMap.get(predictionModelXGB.featuresCol)
      === sameModelXGBParamMap.get(sameModelXGB.featuresCol))
    assert(predictionModelXGBParamMap.get(predictionModelXGB.predictionCol)
      === sameModelXGBParamMap.get(sameModelXGB.predictionCol))
    assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol)
      === sameModelXGBParamMap.get(sameModelXGB.labelCol))
  }
}
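
Note: beyond the byte-level and param-level equality checks above, the practical payoff is that a reloaded pipeline can score new data directly. A short follow-on sketch, reusing df and the saved path from the test:

// sketch: reload the persisted pipeline and score a frame end to end;
// the default predictionCol yields a "prediction" column
val reloaded = PipelineModel.load("testxgbPipe")
reloaded.transform(df).select("prediction").show()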