[jvm-packages] fix the persistence of XGBoostEstimator (#2265)
* add back train method but mark as deprecated
* fix scalastyle error
* fix the persistence of XGBoostEstimator
* test persistence of a complete pipeline
* fix compilation issue
* do not allow persisting custom_eval and custom_obj
* fix the failed test
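At a glance, the change makes XGBoostEstimator implement MLWritable and adds a companion MLReadable, so both the bare estimator and a Spark ML Pipeline containing it can be saved and reloaded. A minimal usage sketch of what this enables; the column names, paths, and the training DataFrame `trainingDF` (with feature columns f0/f1 and a "label" column) are assumptions for illustration, not taken from the diff:

    import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
    import org.apache.spark.ml.{Pipeline, PipelineModel}
    import org.apache.spark.ml.feature.VectorAssembler

    val assembler = new VectorAssembler()
      .setInputCols(Array("f0", "f1"))          // hypothetical feature columns
      .setOutputCol("features")
    val xgb = new XGBoostEstimator(Map(
      "eta" -> "0.1", "max_depth" -> "6", "objective" -> "binary:logistic"))
    val pipeline = new Pipeline().setStages(Array(assembler, xgb))

    // The unfitted pipeline (and the estimator alone) can now be persisted ...
    pipeline.write.overwrite().save("/tmp/xgb-pipeline")
    val reloaded = Pipeline.read.load("/tmp/xgb-pipeline")

    // ... and a fitted PipelineModel containing an XGBoostModel stage still round-trips as before.
    val fitted = pipeline.fit(trainingDF)
    fitted.write.overwrite().save("/tmp/xgb-pipeline-model")
    val reloadedModel = PipelineModel.load("/tmp/xgb-pipeline-model")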
Commit 428453f7d6 (parent 6bf968efe6)

.gitignore (vendored)
@@ -88,4 +88,7 @@ build_tests
 /tests/cpp/xgboost_test
 .DS_Store
 lib/
+
+# spark
+metastore_db
@@ -18,21 +18,22 @@ package ml.dmlc.xgboost4j.scala.spark

 import scala.collection.mutable
 import scala.collection.mutable.ListBuffer

 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.{FSDataInputStream, Path}

 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.SparseVector
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.{SparkContext, TaskContext}
-import scala.concurrent.duration.{Duration, FiniteDuration, MILLISECONDS}
+import scala.concurrent.duration.{Duration, MILLISECONDS}

 object TrackerConf {
-  def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
+  def apply(): TrackerConf = TrackerConf(0L, "python")
 }

 /**
@@ -40,13 +41,14 @@ object TrackerConf {
  * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
  *                                Set timeout length to zero to disable timeout.
  *                                Use a finite, non-zero timeout value to prevent tracker from
- *                                hanging indefinitely (supported by "scala" implementation only.)
+ *                                hanging indefinitely (in milliseconds)
+ *                                (supported by "scala" implementation only.)
  * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
  *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
  *                    in Scala without Python components, and with full support of timeouts.
  *                    The Scala implementation is currently experimental, use at your own risk.
  */
-case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
+case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)

 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
@@ -240,14 +242,7 @@ object XGBoost extends Serializable {
       case _ => new PyRabitTracker(nWorkers)
     }

-    val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
-      trackerConf.workerConnectionTimeout.toMillis
-    } else {
-      // 0 == Duration.Inf
-      0L
-    }
-
-    require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
+    require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
     tracker
   }
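The net effect of the TrackerConf changes above: workerConnectionTimeout is now a plain Long in milliseconds rather than a scala.concurrent.duration.Duration, which keeps the value trivially JSON-serializable for the new param persistence, and the tracker is started with it directly. A small sketch, with illustrative values:

    import ml.dmlc.xgboost4j.scala.spark.TrackerConf

    // Wait at most one minute for all workers to connect; honoured by the "scala" tracker only.
    val scalaTracker = TrackerConf(60 * 1000L, "scala")

    // The default configuration: 0L means wait indefinitely, using the Python tracker wrapper.
    val defaultTracker = TrackerConf()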
@@ -18,12 +18,14 @@ package ml.dmlc.xgboost4j.scala.spark

 import scala.collection.mutable

-import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, GeneralParams, LearningTaskParams}
+import ml.dmlc.xgboost4j.scala.spark.params._
+import org.json4s.DefaultFormats

 import org.apache.spark.ml.Predictor
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector => MLVector}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.sql.{Dataset, Row}
@@ -34,7 +36,7 @@ import org.apache.spark.sql.{Dataset, Row}
 class XGBoostEstimator private[spark](
     override val uid: String, xgboostParams: Map[String, Any])
   extends Predictor[MLVector, XGBoostEstimator, XGBoostModel]
-  with LearningTaskParams with GeneralParams with BoosterParams {
+  with LearningTaskParams with GeneralParams with BoosterParams with MLWritable {

   def this(xgboostParams: Map[String, Any]) =
     this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any])
@@ -129,4 +131,38 @@ class XGBoostEstimator private[spark](
   override def copy(extra: ParamMap): XGBoostEstimator = {
     defaultCopy(extra).asInstanceOf[XGBoostEstimator]
   }

+  override def write: MLWriter = new XGBoostEstimator.XGBoostEstimatorWriter(this)
+}
+
+object XGBoostEstimator extends MLReadable[XGBoostEstimator] {
+
+  override def read: MLReader[XGBoostEstimator] = new XGBoostEstimatorReader
+
+  override def load(path: String): XGBoostEstimator = super.load(path)
+
+  private[XGBoostEstimator] class XGBoostEstimatorWriter(instance: XGBoostEstimator)
+      extends MLWriter {
+    override protected def saveImpl(path: String): Unit = {
+      require(instance.fromParamsToXGBParamMap("custom_eval") == null &&
+        instance.fromParamsToXGBParamMap("custom_obj") == null,
+        "we do not support persist XGBoostEstimator with customized evaluator and objective" +
+          " function for now")
+      implicit val format = DefaultFormats
+      implicit val sc = super.sparkSession.sparkContext
+      DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
+    }
+  }
+
+  private class XGBoostEstimatorReader extends MLReader[XGBoostEstimator] {
+
+    override def load(path: String): XGBoostEstimator = {
+      val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
+      val cls = Utils.classForName(metadata.className)
+      val instance =
+        cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
+      DefaultXGBoostParamsReader.getAndSetParams(instance, metadata)
+      instance.asInstanceOf[XGBoostEstimator]
+    }
+  }
 }
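Because EvalTrait and ObjectiveTrait instances are arbitrary user classes, the writer above deliberately refuses to persist an estimator that carries them: the require fails fast instead of writing metadata that could never be decoded. A hedged ScalaTest-style sketch of that behaviour; MyObjective is a hypothetical user-defined ObjectiveTrait, not part of this commit:

    // Assumes something like:  class MyObjective extends ObjectiveTrait { ... }
    val est = new XGBoostEstimator(Map("custom_obj" -> new MyObjective))

    // XGBoostEstimatorWriter.saveImpl rejects the estimator before anything is written.
    intercept[IllegalArgumentException] {
      est.write.overwrite().save("/tmp/est-with-custom-obj")   // path is illustrative
    }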
@@ -324,14 +324,13 @@ object XGBoostModel extends MLReadable[XGBoostModel] {
       implicit val format = DefaultFormats
       implicit val sc = super.sparkSession.sparkContext
       DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)

       val dataPath = new Path(path, "data").toString
       instance.saveModelAsHadoopFile(dataPath)
     }
   }

   private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
-    private val className = classOf[XGBoostModel].getName
     override def load(path: String): XGBoostModel = {
       implicit val sc = super.sparkSession.sparkContext
       val dataPath = new Path(path, "data").toString
@@ -340,5 +339,4 @@ object XGBoostModel extends MLReadable[XGBoostModel] {
       XGBoost.loadModelFromHadoopFile(dataPath)
     }
   }
-
 }
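For completeness, the model side keeps its existing layout: Spark ML metadata is written next to the native booster, which goes under a "data" subdirectory via saveModelAsHadoopFile. A short usage sketch mirroring the test added in this commit; the path is illustrative and xgbModel is assumed to be an already trained XGBoostModel:

    xgbModel.write.overwrite.save("/tmp/xgb-model")      // metadata + booster under /tmp/xgb-model/data
    val restored = XGBoostModel.load("/tmp/xgb-model")
    assert(java.util.Arrays.equals(xgbModel.booster.toByteArray, restored.booster.toByteArray))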
@@ -0,0 +1,106 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
+import ml.dmlc.xgboost4j.scala.spark.TrackerConf
+import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
+import org.json4s.jackson.JsonMethods.{compact, parse, render}
+
+import org.apache.spark.ml.param.{Param, ParamPair, Params}
+
+class GroupDataParam(
+    parent: Params,
+    name: String,
+    doc: String) extends Param[Seq[Seq[Int]]](parent, name, doc) {
+
+  /** Creates a param pair with the given value (for Java). */
+  override def w(value: Seq[Seq[Int]]): ParamPair[Seq[Seq[Int]]] = super.w(value)
+
+  override def jsonEncode(value: Seq[Seq[Int]]): String = {
+    import org.json4s.jackson.Serialization
+    implicit val formats = Serialization.formats(NoTypeHints)
+    compact(render(Extraction.decompose(value)))
+  }
+
+  override def jsonDecode(json: String): Seq[Seq[Int]] = {
+    implicit val formats = DefaultFormats
+    parse(json).extract[Seq[Seq[Int]]]
+  }
+}
+
+class CustomEvalParam(
+    parent: Params,
+    name: String,
+    doc: String) extends Param[EvalTrait](parent, name, doc) {
+
+  /** Creates a param pair with the given value (for Java). */
+  override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value)
+
+  override def jsonEncode(value: EvalTrait): String = {
+    import org.json4s.jackson.Serialization
+    implicit val formats = Serialization.formats(NoTypeHints)
+    compact(render(Extraction.decompose(value)))
+  }
+
+  override def jsonDecode(json: String): EvalTrait = {
+    implicit val formats = DefaultFormats
+    parse(json).extract[EvalTrait]
+  }
+}
+
+class CustomObjParam(
+    parent: Params,
+    name: String,
+    doc: String) extends Param[ObjectiveTrait](parent, name, doc) {
+
+  /** Creates a param pair with the given value (for Java). */
+  override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value)
+
+  override def jsonEncode(value: ObjectiveTrait): String = {
+    import org.json4s.jackson.Serialization
+    implicit val formats = Serialization.formats(NoTypeHints)
+    compact(render(Extraction.decompose(value)))
+  }
+
+  override def jsonDecode(json: String): ObjectiveTrait = {
+    implicit val formats = DefaultFormats
+    parse(json).extract[ObjectiveTrait]
+  }
+}
+
+class TrackerConfParam(
+    parent: Params,
+    name: String,
+    doc: String) extends Param[TrackerConf](parent, name, doc) {
+
+  /** Creates a param pair with the given value (for Java). */
+  override def w(value: TrackerConf): ParamPair[TrackerConf] = super.w(value)
+
+  override def jsonEncode(value: TrackerConf): String = {
+    import org.json4s.jackson.Serialization
+    implicit val formats = Serialization.formats(NoTypeHints)
+    compact(render(Extraction.decompose(value)))
+  }
+
+  override def jsonDecode(json: String): TrackerConf = {
+    implicit val formats = DefaultFormats
+    val parsedValue = parse(json)
+    println(parsedValue.children)
+    parsedValue.extract[TrackerConf]
+  }
+}
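These Param subclasses exist because the values they carry (nested group sizes, a TrackerConf, custom functions) are not primitives, so they need their own jsonEncode/jsonDecode for metadata persistence. A self-contained sketch of the json4s round trip that GroupDataParam relies on, using made-up group sizes:

    import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
    import org.json4s.jackson.JsonMethods.{compact, parse, render}
    import org.json4s.jackson.Serialization

    object GroupDataJsonRoundTrip extends App {
      val groups: Seq[Seq[Int]] = Seq(Seq(3, 2), Seq(5))

      // Encode the way GroupDataParam.jsonEncode does.
      val encoded: String = {
        implicit val formats = Serialization.formats(NoTypeHints)
        compact(render(Extraction.decompose(groups)))    // "[[3,2],[5]]"
      }

      // Decode the way GroupDataParam.jsonDecode does.
      val decoded: Seq[Seq[Int]] = {
        implicit val formats = DefaultFormats
        parse(encoded).extract[Seq[Seq[Int]]]
      }
      assert(decoded == groups)
    }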
@@ -0,0 +1,136 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JValue}
+import org.json4s.JsonAST.JObject
+import org.json4s.jackson.JsonMethods.{compact, parse, render}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.param.Params
+import org.apache.spark.ml.util.MLReader
+
+// Copied from Apache Spark's DefaultParamsReader.
+private[spark] object DefaultXGBoostParamsReader {
+
+  /**
+   * All info from metadata file.
+   *
+   * @param params paramMap, as a `JValue`
+   * @param metadata All metadata, including the other fields
+   * @param metadataJson Full metadata file String (for debugging)
+   */
+  case class Metadata(
+      className: String,
+      uid: String,
+      timestamp: Long,
+      sparkVersion: String,
+      params: JValue,
+      metadata: JValue,
+      metadataJson: String) {
+
+    /**
+     * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
+     * This can be useful for getting a Param value before an instance of `Params`
+     * is available.
+     */
+    def getParamValue(paramName: String): JValue = {
+      implicit val format = DefaultFormats
+      params match {
+        case JObject(pairs) =>
+          val values = pairs.filter { case (pName, jsonValue) =>
+            pName == paramName
+          }.map(_._2)
+          assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
+            s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
+          values.head
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Cannot recognize JSON metadata: $metadataJson.")
+      }
+    }
+  }
+
+  /**
+   * Load metadata saved using [[DefaultXGBoostParamsWriter.saveMetadata()]]
+   *
+   * @param expectedClassName If non empty, this is checked against the loaded metadata.
+   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+   */
+  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataStr = sc.textFile(metadataPath, 1).first()
+    parseMetadata(metadataStr, expectedClassName)
+  }
+
+  /**
+   * Parse metadata JSON string produced by [[DefaultXGBoostParamsWriter.getMetadataToSave()]].
+   * This is a helper function for [[loadMetadata()]].
+   *
+   * @param metadataStr JSON string of metadata
+   * @param expectedClassName If non empty, this is checked against the loaded metadata.
+   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+   */
+  def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
+    val metadata = parse(metadataStr)
+
+    implicit val format = DefaultFormats
+    val className = (metadata \ "class").extract[String]
+    val uid = (metadata \ "uid").extract[String]
+    val timestamp = (metadata \ "timestamp").extract[Long]
+    val sparkVersion = (metadata \ "sparkVersion").extract[String]
+    val params = metadata \ "paramMap"
+    if (expectedClassName.nonEmpty) {
+      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
+        s" $expectedClassName but found class name $className")
+    }
+
+    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
+  }
+
+  /**
+   * Extract Params from metadata, and set them in the instance.
+   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   * TODO: Move to [[Metadata]] method
+   */
+  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
+    implicit val format = DefaultFormats
+    metadata.params match {
+      case JObject(pairs) =>
+        pairs.foreach { case (paramName, jsonValue) =>
+          val param = instance.getParam(paramName)
+          val value = param.jsonDecode(compact(render(jsonValue)))
+          instance.set(param, value)
+        }
+      case _ =>
+        throw new IllegalArgumentException(
+          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
+    }
+  }
+
+  /**
+   * Load a `Params` instance from the given path, and return it.
+   * This assumes the instance implements [[org.apache.spark.ml.util.MLReadable]].
+   */
+  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
+    val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
+    val cls = Utils.classForName(metadata.className)
+    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
+  }
+}
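To make the reader concrete: DefaultXGBoostParamsWriter.saveMetadata emits a JSON object with class, timestamp, sparkVersion, uid and a paramMap, and parseMetadata above pulls those fields back out. A hedged example of a string it would accept; the field names match the code, but the concrete values are invented, and the call assumes code living in the same package since the object is private[spark]:

    val sampleMetadata =
      """{"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator",
        | "timestamp":1495000000000,
        | "sparkVersion":"2.1.1",
        | "uid":"XGBoostEstimator_0123456789ab",
        | "paramMap":{"round":1,"nWorkers":1}}""".stripMargin

    val meta = DefaultXGBoostParamsReader.parseMetadata(sampleMetadata)
    assert(meta.className == "ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator")
    assert(meta.uid == "XGBoostEstimator_0123456789ab")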
@@ -46,6 +46,7 @@ private[spark] object DefaultXGBoostParamsWriter {
       sc: SparkContext,
       extraMetadata: Option[JObject] = None,
       paramMap: Option[JValue] = None): Unit = {
+
     val metadataPath = new Path(path, "metadata").toString
     val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
     sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -65,7 +66,9 @@ private[spark] object DefaultXGBoostParamsWriter {
     val uid = instance.uid
     val cls = instance.getClass.getName
     val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
-    val jsonParams = paramMap.getOrElse(render(params.map {
+    val jsonParams = paramMap.getOrElse(render(params.filter {
+      case ParamPair(p, _) => p != null
+    }.map {
       case ParamPair(p, v) =>
         p.name -> parse(p.jsonEncode(v))
     }.toList))
@@ -20,8 +20,6 @@ import ml.dmlc.xgboost4j.scala.spark.TrackerConf
 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import org.apache.spark.ml.param._

-import scala.concurrent.duration.{Duration, NANOSECONDS}
-
 trait GeneralParams extends Params {

   /**
@@ -58,13 +56,13 @@ trait GeneralParams extends Params {
   /**
    * customized objective function provided by user. default: null
    */
-  val customObj = new Param[ObjectiveTrait](this, "custom_obj", "customized objective function " +
+  val customObj = new CustomObjParam(this, "custom_obj", "customized objective function " +
     "provided by user")

   /**
    * customized evaluation function provided by user. default: null
    */
-  val customEval = new Param[EvalTrait](this, "custom_eval", "customized evaluation function " +
+  val customEval = new CustomEvalParam(this, "custom_eval", "customized evaluation function " +
     "provided by user")

   /**
@@ -99,7 +97,7 @@ trait GeneralParams extends Params {
    * Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
    * Ignored if the tracker implementation is "python".
    */
-  val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations")
+  val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")

   setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
     useExternalMemory -> false, silent -> 0,
@@ -57,7 +57,7 @@ trait LearningTaskParams extends Params {
    * group data specify each group sizes for ranking task. To correspond to partition of
    * training data, it is nested.
    */
-  val groupData = new Param[Seq[Seq[Int]]](this, "groupData", "group data specify each group size" +
+  val groupData = new GroupDataParam(this, "groupData", "group data specify each group size" +
     " for ranking task. To correspond to partition of training data, it is nested.")

   setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null)
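The switch from the generic Param[...] to the typed params above is what makes the extractParamMap → JSON → getAndSetParams round trip work for these values; as far as I can tell from Spark's Param API, the stock jsonEncode only handles simple values, so a TrackerConf or Seq[Seq[Int]] could not be written before. A usage sketch in the spirit of the tests in this commit (the path is illustrative):

    val estimator = new XGBoostEstimator(Map(
      "eta" -> "0.1", "max_depth" -> "6",
      "tracker_conf" -> TrackerConf(60 * 1000L, "scala")))   // persisted via TrackerConfParam

    estimator.write.overwrite().save("/tmp/xgb-estimator")
    val restored = XGBoostEstimator.load("/tmp/xgb-estimator")
    assert(restored.fromParamsToXGBParamMap("eta").toString == "0.1")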
@@ -18,17 +18,21 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.File

+import scala.collection.mutable
 import scala.collection.mutable.ListBuffer
 import scala.io.Source
+import scala.util.Random

 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}

 import org.apache.spark.SparkContext
 import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.feature.{LabeledPoint, VectorAssembler}
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
 import org.apache.spark.sql._

 class XGBoostDFSuite extends SharedSparkContext with Utils {
@@ -110,7 +110,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic",
-      "tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
+      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap
     val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
       nWorkers = numWorkers)
     assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
@@ -18,67 +18,84 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.{File, FileNotFoundException}

+import scala.util.Random
+
 import org.apache.spark.SparkConf
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.SparkSession
-import scala.concurrent.duration._
-
-case class Foobar(TARGET: Int, bar: Double, baz: Double)

 class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {

   override def afterAll(): Unit = {
     super.afterAll()
     delete(new File("./testxgbPipe"))
-    delete(new File("./test2xgbPipe"))
+    delete(new File("./testxgbEst"))
+    delete(new File("./testxgbModel"))
+    delete(new File("./test2xgbModel"))
   }

   private def delete(f: File) {
-    if (f.isDirectory()) {
-      for (c <- f.listFiles()) {
-        delete(c)
-      }
-    }
-    if (!f.delete()) {
-      throw new FileNotFoundException("Failed to delete file: " + f)
-    }
+    if (f.exists()) {
+      if (f.isDirectory()) {
+        for (c <- f.listFiles()) {
+          delete(c)
+        }
+      }
+      if (!f.delete()) {
+        throw new FileNotFoundException("Failed to delete file: " + f)
+      }
+    }
   }

-  test("test sparks pipeline persistence of dataframe-based model") {
-    // maybe move to shared context, but requires session to import implicits.
-    // what about introducing https://github.com/holdenk/spark-testing-base ?
-    val conf: SparkConf = new SparkConf()
-      .setAppName("foo")
-      .setMaster("local[*]")
-
-    val spark: SparkSession = SparkSession
-      .builder()
-      .config(conf)
-      .getOrCreate()
-
-    import spark.implicits._
+  test("test persistence of XGBoostEstimator") {
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6")
+    val xgbEstimator = new XGBoostEstimator(paramMap)
+    xgbEstimator.write.overwrite().save("./testxgbEst")
+    val loadedxgbEstimator = XGBoostEstimator.read.load("./testxgbEst")
+    val loadedParamMap = loadedxgbEstimator.fromParamsToXGBParamMap
+    paramMap.foreach {
+      case (k, v) => assert(v == loadedParamMap(k).toString)
+    }
+  }
+
+  test("test persistence of a complete pipeline") {
+    val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
+    val spark = SparkSession.builder().config(conf).getOrCreate()
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6")
+    val r = new Random(0)
+    val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
+    val xgbEstimator = new XGBoostEstimator(paramMap)
+    val pipeline = new Pipeline().setStages(Array(assembler, xgbEstimator))
+    pipeline.write.overwrite().save("testxgbPipe")
+    val loadedPipeline = Pipeline.read.load("testxgbPipe")
+    val loadedEstimator = loadedPipeline.getStages(1).asInstanceOf[XGBoostEstimator]
+    val loadedParamMap = loadedEstimator.fromParamsToXGBParamMap
+    paramMap.foreach {
+      case (k, v) => assert(v == loadedParamMap(k).toString)
+    }
+  }
+
+  test("test persistence of XGBoostModel") {
+    val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
+    val spark = SparkSession.builder().config(conf).getOrCreate()
+    val r = new Random(0)
     // maybe move to shared context, but requires session to import implicits
-    val df = Seq(Foobar(0, 0.5, 1), Foobar(1, 0.01, 0.8),
-      Foobar(0, 0.8, 0.5), Foobar(1, 8.4, 0.04))
-      .toDS
+    val df = spark.createDataFrame(Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))).
+      toDF("feature", "label")

     val vectorAssembler = new VectorAssembler()
       .setInputCols(df.columns
-        .filter(!_.contains("TARGET")))
+        .filter(!_.contains("label")))
       .setOutputCol("features")

     val xgbEstimator = new XGBoostEstimator(Map("num_rounds" -> 10,
-      "tracker_conf" -> TrackerConf(1 minute, "scala")
-    ))
-      .setFeaturesCol("features")
-      .setLabelCol("TARGET")
+      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")
+    )).setFeaturesCol("features").setLabelCol("label")

     // separate
     val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
-    predModel.write.overwrite.save("test2xgbPipe")
-    val same2Model = XGBoostModel.load("test2xgbPipe")
+    predModel.write.overwrite.save("test2xgbModel")
+    val same2Model = XGBoostModel.load("test2xgbModel")

     assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
     val predParamMap = predModel.extractParamMap()
@@ -93,8 +110,8 @@ class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {

     // chained
     val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
-    predictionModel.write.overwrite.save("testxgbPipe")
-    val sameModel = PipelineModel.load("testxgbPipe")
+    predictionModel.write.overwrite.save("testxgbModel")
+    val sameModel = PipelineModel.load("testxgbModel")

     val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb } head
     val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb } head