[jvm-packages] refactor xgboost read/write (#7956)
1. Removed the duplicated default XGBoost read/write code, which was copied from Spark 2.3.x. 2. Moved some utilities into the util package.
This commit is contained in:
@@ -24,8 +24,9 @@ import scala.collection.{AbstractIterator, Iterator, mutable}
|
|||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.Rabit
|
import ml.dmlc.xgboost4j.java.Rabit
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
|
||||||
|
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
@@ -35,10 +36,8 @@ import org.apache.commons.logging.LogFactory
|
|||||||
|
|
||||||
import org.apache.spark.TaskContext
|
import org.apache.spark.TaskContext
|
||||||
import org.apache.spark.broadcast.Broadcast
|
import org.apache.spark.broadcast.Broadcast
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
import org.apache.spark.ml.{Estimator, Model}
|
||||||
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
|
|
||||||
import org.apache.spark.ml.linalg.Vector
|
import org.apache.spark.ml.linalg.Vector
|
||||||
import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
|
|
||||||
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
|
|
||||||
@@ -272,7 +271,7 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
|
|
||||||
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
|
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
|
||||||
|
|
||||||
import DataUtils._
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||||
val cacheInfo = {
|
val cacheInfo = {
|
||||||
if (useExternalMemory) {
|
if (useExternalMemory) {
|
||||||
s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
|
s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
|
||||||
|
|||||||
@@ -23,15 +23,13 @@ import org.apache.hadoop.fs.Path
|
|||||||
import org.apache.spark.broadcast.Broadcast
|
import org.apache.spark.broadcast.Broadcast
|
||||||
import org.apache.spark.ml.classification._
|
import org.apache.spark.ml.classification._
|
||||||
import org.apache.spark.ml.linalg._
|
import org.apache.spark.ml.linalg._
|
||||||
import org.apache.spark.ml.param._
|
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.json4s.DefaultFormats
|
|
||||||
import scala.collection.{Iterator, mutable}
|
import scala.collection.{Iterator, mutable}
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
|
import org.apache.spark.ml.param.ParamMap
|
||||||
|
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
class XGBoostClassifier (
|
class XGBoostClassifier (
|
||||||
@@ -274,7 +272,7 @@ class XGBoostClassificationModel private[ml](
|
|||||||
* Note: The performance is not ideal, use it carefully!
|
* Note: The performance is not ideal, use it carefully!
|
||||||
*/
|
*/
|
||||||
override def predict(features: Vector): Double = {
|
override def predict(features: Vector): Double = {
|
||||||
import DataUtils._
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||||
val dm = new DMatrix(processMissingValues(
|
val dm = new DMatrix(processMissingValues(
|
||||||
Iterator(features.asXGB),
|
Iterator(features.asXGB),
|
||||||
$(missing),
|
$(missing),
|
||||||
@@ -469,10 +467,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
|
|||||||
|
|
||||||
override protected def saveImpl(path: String): Unit = {
|
override protected def saveImpl(path: String): Unit = {
|
||||||
// Save metadata and Params
|
// Save metadata and Params
|
||||||
implicit val format = DefaultFormats
|
|
||||||
implicit val sc = super.sparkSession.sparkContext
|
|
||||||
|
|
||||||
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
|
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
|
||||||
|
|
||||||
// Save model data
|
// Save model data
|
||||||
val dataPath = new Path(path, "data").toString
|
val dataPath = new Path(path, "data").toString
|
||||||
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
|
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
|
||||||
@@ -495,18 +491,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
|
|||||||
val dataPath = new Path(path, "data").toString
|
val dataPath = new Path(path, "data").toString
|
||||||
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
|
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
|
||||||
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
|
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
|
||||||
|
val numClasses = DefaultXGBoostParamsReader.getNumClass(metadata, dataInStream)
|
||||||
// The xgboostVersion in the meta can specify if the model is the old xgboost in-compatible
|
|
||||||
// or the new xgboost compatible.
|
|
||||||
val numClasses = metadata.xgboostVersion.map { _ =>
|
|
||||||
implicit val format = DefaultFormats
|
|
||||||
// For binary:logistic, the numClass parameter can't be set to 2 or not be set.
|
|
||||||
// For multi:softprob or multi:softmax, the numClass parameter must be set correctly,
|
|
||||||
// or else, XGBoost will throw exception.
|
|
||||||
// So it's safe to get numClass from meta data.
|
|
||||||
(metadata.params \ "numClass").extractOpt[Int].getOrElse(2)
|
|
||||||
}.getOrElse(dataInStream.readInt())
|
|
||||||
|
|
||||||
val booster = SXGBoost.loadModel(dataInStream)
|
val booster = SXGBoost.loadModel(dataInStream)
|
||||||
val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
|
val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
|
||||||
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
|
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
|
|
||||||
import scala.collection.{Iterator, mutable}
|
import scala.collection.{Iterator, mutable}
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
|
import ml.dmlc.xgboost4j.scala.spark.params._
|
||||||
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
|
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||||
import org.apache.hadoop.fs.Path
|
import org.apache.hadoop.fs.Path
|
||||||
@@ -30,9 +29,9 @@ import org.apache.spark.ml._
|
|||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.json4s.DefaultFormats
|
|
||||||
|
|
||||||
import org.apache.spark.broadcast.Broadcast
|
import org.apache.spark.broadcast.Broadcast
|
||||||
|
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
class XGBoostRegressor (
|
class XGBoostRegressor (
|
||||||
@@ -260,7 +259,7 @@ class XGBoostRegressionModel private[ml] (
|
|||||||
* Note: The performance is not ideal, use it carefully!
|
* Note: The performance is not ideal, use it carefully!
|
||||||
*/
|
*/
|
||||||
override def predict(features: Vector): Double = {
|
override def predict(features: Vector): Double = {
|
||||||
import DataUtils._
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||||
val dm = new DMatrix(processMissingValues(
|
val dm = new DMatrix(processMissingValues(
|
||||||
Iterator(features.asXGB),
|
Iterator(features.asXGB),
|
||||||
$(missing),
|
$(missing),
|
||||||
@@ -384,8 +383,6 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
|
|||||||
|
|
||||||
override protected def saveImpl(path: String): Unit = {
|
override protected def saveImpl(path: String): Unit = {
|
||||||
// Save metadata and Params
|
// Save metadata and Params
|
||||||
implicit val format = DefaultFormats
|
|
||||||
implicit val sc = super.sparkSession.sparkContext
|
|
||||||
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
|
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
|
||||||
// Save model data
|
// Save model data
|
||||||
val dataPath = new Path(path, "data").toString
|
val dataPath = new Path(path, "data").toString
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014,2021 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
|||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||||
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.util.Utils
|
||||||
|
|
||||||
import org.apache.spark.ml.param.{Param, ParamPair, Params}
|
import org.apache.spark.ml.param.{Param, ParamPair, Params}
|
||||||
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
|
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
|
||||||
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
||||||
|
|||||||
@@ -1,161 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014-2022 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.params
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark
|
|
||||||
import org.apache.commons.logging.LogFactory
|
|
||||||
import org.apache.hadoop.fs.Path
|
|
||||||
import org.json4s.{DefaultFormats, JValue}
|
|
||||||
import org.json4s.JsonAST.JObject
|
|
||||||
import org.json4s.jackson.JsonMethods.{compact, parse, render}
|
|
||||||
|
|
||||||
import org.apache.spark.SparkContext
|
|
||||||
import org.apache.spark.ml.param.Params
|
|
||||||
import org.apache.spark.ml.util.MLReader
|
|
||||||
|
|
||||||
// This originates from apache-spark DefaultPramsReader copy paste
|
|
||||||
private[spark] object DefaultXGBoostParamsReader {
|
|
||||||
|
|
||||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
|
||||||
|
|
||||||
private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
|
|
||||||
|
|
||||||
private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
|
|
||||||
Map("objective" -> Map("reg:linear" -> "reg:squarederror"))
|
|
||||||
|
|
||||||
/**
|
|
||||||
* All info from metadata file.
|
|
||||||
*
|
|
||||||
* @param params paramMap, as a `JValue`
|
|
||||||
* @param metadata All metadata, including the other fields
|
|
||||||
* @param metadataJson Full metadata file String (for debugging)
|
|
||||||
*/
|
|
||||||
case class Metadata(
|
|
||||||
className: String,
|
|
||||||
uid: String,
|
|
||||||
timestamp: Long,
|
|
||||||
sparkVersion: String,
|
|
||||||
params: JValue,
|
|
||||||
metadata: JValue,
|
|
||||||
metadataJson: String,
|
|
||||||
xgboostVersion: Option[String] = None) {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
|
|
||||||
* This can be useful for getting a Param value before an instance of `Params`
|
|
||||||
* is available.
|
|
||||||
*/
|
|
||||||
def getParamValue(paramName: String): JValue = {
|
|
||||||
implicit val format = DefaultFormats
|
|
||||||
params match {
|
|
||||||
case JObject(pairs) =>
|
|
||||||
val values = pairs.filter { case (pName, jsonValue) =>
|
|
||||||
pName == paramName
|
|
||||||
}.map(_._2)
|
|
||||||
assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
|
|
||||||
s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
|
|
||||||
values.head
|
|
||||||
case _ =>
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
s"Cannot recognize JSON metadata: $metadataJson.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Load metadata saved using [[DefaultXGBoostParamsWriter.saveMetadata()]]
|
|
||||||
*
|
|
||||||
* @param expectedClassName If non empty, this is checked against the loaded metadata.
|
|
||||||
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
|
|
||||||
*/
|
|
||||||
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
|
|
||||||
val metadataPath = new Path(path, "metadata").toString
|
|
||||||
val metadataStr = sc.textFile(metadataPath, 1).first()
|
|
||||||
parseMetadata(metadataStr, expectedClassName)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse metadata JSON string produced by [[DefaultXGBoostParamsWriter.getMetadataToSave()]].
|
|
||||||
* This is a helper function for [[loadMetadata()]].
|
|
||||||
*
|
|
||||||
* @param metadataStr JSON string of metadata
|
|
||||||
* @param expectedClassName If non empty, this is checked against the loaded metadata.
|
|
||||||
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
|
|
||||||
*/
|
|
||||||
def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
|
|
||||||
val metadata = parse(metadataStr)
|
|
||||||
|
|
||||||
implicit val format = DefaultFormats
|
|
||||||
val className = (metadata \ "class").extract[String]
|
|
||||||
val uid = (metadata \ "uid").extract[String]
|
|
||||||
val timestamp = (metadata \ "timestamp").extract[Long]
|
|
||||||
val sparkVersion = (metadata \ "sparkVersion").extract[String]
|
|
||||||
val params = metadata \ "paramMap"
|
|
||||||
if (expectedClassName.nonEmpty) {
|
|
||||||
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
|
|
||||||
s" $expectedClassName but found class name $className")
|
|
||||||
}
|
|
||||||
val xgboostVersion = (metadata \ "xgboostVersion").extractOpt[String]
|
|
||||||
Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr, xgboostVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
|
|
||||||
paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T]
|
|
||||||
}
|
|
||||||
|
|
||||||
private def handleBrokenlyChangedName(paramName: String): String = {
|
|
||||||
paramNameCompatibilityMap.getOrElse(paramName, paramName)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract Params from metadata, and set them in the instance.
|
|
||||||
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
|
|
||||||
* TODO: Move to [[Metadata]] method
|
|
||||||
*/
|
|
||||||
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
|
|
||||||
implicit val format = DefaultFormats
|
|
||||||
metadata.params match {
|
|
||||||
case JObject(pairs) =>
|
|
||||||
pairs.foreach { case (paramName, jsonValue) =>
|
|
||||||
val finalName = handleBrokenlyChangedName(paramName)
|
|
||||||
// For the deleted parameters, we'd better to remove it instead of throwing an exception.
|
|
||||||
// So we need to check if the parameter exists instead of blindly setting it.
|
|
||||||
if (instance.hasParam(finalName)) {
|
|
||||||
val param = instance.getParam(finalName)
|
|
||||||
val value = param.jsonDecode(compact(render(jsonValue)))
|
|
||||||
instance.set(param, handleBrokenlyChangedValue(paramName, value))
|
|
||||||
} else {
|
|
||||||
logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case _ =>
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Load a `Params` instance from the given path, and return it.
|
|
||||||
* This assumes the instance implements [[org.apache.spark.ml.util.MLReadable]].
|
|
||||||
*/
|
|
||||||
def loadParamsInstance[T](path: String, sc: SparkContext): T = {
|
|
||||||
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
|
|
||||||
val cls = Utils.classForName(metadata.className)
|
|
||||||
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -1,151 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.params
|
|
||||||
|
|
||||||
import org.apache.hadoop.fs.Path
|
|
||||||
|
|
||||||
import org.apache.spark.SparkContext
|
|
||||||
import org.apache.spark.ml.param.{ParamPair, Params}
|
|
||||||
import org.json4s.jackson.JsonMethods._
|
|
||||||
import org.json4s.{JArray, JBool, JDouble, JField, JInt, JNothing, JObject, JString, JValue}
|
|
||||||
import JsonDSLXGBoost._
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark
|
|
||||||
|
|
||||||
// This originates from apache-spark DefaultPramsWriter copy paste
|
|
||||||
private[spark] object DefaultXGBoostParamsWriter {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Saves metadata + Params to: path + "/metadata"
|
|
||||||
* - class
|
|
||||||
* - timestamp
|
|
||||||
* - sparkVersion
|
|
||||||
* - uid
|
|
||||||
* - paramMap
|
|
||||||
* - (optionally, extra metadata)
|
|
||||||
*
|
|
||||||
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
|
|
||||||
* @param paramMap If given, this is saved in the "paramMap" field.
|
|
||||||
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
|
|
||||||
* [[org.apache.spark.ml.param.Param.jsonEncode()]].
|
|
||||||
*/
|
|
||||||
def saveMetadata(
|
|
||||||
instance: Params,
|
|
||||||
path: String,
|
|
||||||
sc: SparkContext,
|
|
||||||
extraMetadata: Option[JObject] = None,
|
|
||||||
paramMap: Option[JValue] = None): Unit = {
|
|
||||||
|
|
||||||
val metadataPath = new Path(path, "metadata").toString
|
|
||||||
val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
|
|
||||||
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Helper for [[saveMetadata()]] which extracts the JSON to save.
|
|
||||||
* This is useful for ensemble models which need to save metadata for many sub-models.
|
|
||||||
*
|
|
||||||
* @see [[saveMetadata()]] for details on what this includes.
|
|
||||||
*/
|
|
||||||
def getMetadataToSave(
|
|
||||||
instance: Params,
|
|
||||||
sc: SparkContext,
|
|
||||||
extraMetadata: Option[JObject] = None,
|
|
||||||
paramMap: Option[JValue] = None): String = {
|
|
||||||
val uid = instance.uid
|
|
||||||
val cls = instance.getClass.getName
|
|
||||||
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
|
|
||||||
val jsonParams = paramMap.getOrElse(render(params.filter{
|
|
||||||
case ParamPair(p, _) => p != null
|
|
||||||
}.map {
|
|
||||||
case ParamPair(p, v) =>
|
|
||||||
p.name -> parse(p.jsonEncode(v))
|
|
||||||
}.toList))
|
|
||||||
val basicMetadata = ("class" -> cls) ~
|
|
||||||
("timestamp" -> System.currentTimeMillis()) ~
|
|
||||||
("sparkVersion" -> sc.version) ~
|
|
||||||
("uid" -> uid) ~
|
|
||||||
("xgboostVersion" -> spark.VERSION) ~
|
|
||||||
("paramMap" -> jsonParams)
|
|
||||||
val metadata = extraMetadata match {
|
|
||||||
case Some(jObject) =>
|
|
||||||
basicMetadata ~ jObject
|
|
||||||
case None =>
|
|
||||||
basicMetadata
|
|
||||||
}
|
|
||||||
val metadataJson: String = compact(render(metadata))
|
|
||||||
metadataJson
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fix json4s bin-incompatible issue.
|
|
||||||
// This originates from org.json4s.JsonDSL of 3.6.6
|
|
||||||
object JsonDSLXGBoost {
|
|
||||||
|
|
||||||
implicit def seq2jvalue[A](s: Iterable[A])(implicit ev: A => JValue): JArray =
|
|
||||||
JArray(s.toList.map(ev))
|
|
||||||
|
|
||||||
implicit def map2jvalue[A](m: Map[String, A])(implicit ev: A => JValue): JObject =
|
|
||||||
JObject(m.toList.map { case (k, v) => JField(k, ev(v)) })
|
|
||||||
|
|
||||||
implicit def option2jvalue[A](opt: Option[A])(implicit ev: A => JValue): JValue = opt match {
|
|
||||||
case Some(x) => ev(x)
|
|
||||||
case None => JNothing
|
|
||||||
}
|
|
||||||
|
|
||||||
implicit def short2jvalue(x: Short): JValue = JInt(x)
|
|
||||||
implicit def byte2jvalue(x: Byte): JValue = JInt(x)
|
|
||||||
implicit def char2jvalue(x: Char): JValue = JInt(x)
|
|
||||||
implicit def int2jvalue(x: Int): JValue = JInt(x)
|
|
||||||
implicit def long2jvalue(x: Long): JValue = JInt(x)
|
|
||||||
implicit def bigint2jvalue(x: BigInt): JValue = JInt(x)
|
|
||||||
implicit def double2jvalue(x: Double): JValue = JDouble(x)
|
|
||||||
implicit def float2jvalue(x: Float): JValue = JDouble(x.toDouble)
|
|
||||||
implicit def bigdecimal2jvalue(x: BigDecimal): JValue = JDouble(x.doubleValue)
|
|
||||||
implicit def boolean2jvalue(x: Boolean): JValue = JBool(x)
|
|
||||||
implicit def string2jvalue(x: String): JValue = JString(x)
|
|
||||||
|
|
||||||
implicit def symbol2jvalue(x: Symbol): JString = JString(x.name)
|
|
||||||
implicit def pair2jvalue[A](t: (String, A))(implicit ev: A => JValue): JObject =
|
|
||||||
JObject(List(JField(t._1, ev(t._2))))
|
|
||||||
implicit def list2jvalue(l: List[JField]): JObject = JObject(l)
|
|
||||||
implicit def jobject2assoc(o: JObject): JsonListAssoc = new JsonListAssoc(o.obj)
|
|
||||||
implicit def pair2Assoc[A](t: (String, A))(implicit ev: A => JValue): JsonAssoc[A] =
|
|
||||||
new JsonAssoc(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
final class JsonAssoc[A](private val left: (String, A)) extends AnyVal {
|
|
||||||
def ~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject = {
|
|
||||||
val l: JValue = ev1(left._2)
|
|
||||||
val r: JValue = ev2(right._2)
|
|
||||||
JObject(JField(left._1, l) :: JField(right._1, r) :: Nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
def ~(right: JObject)(implicit ev: A => JValue): JObject = {
|
|
||||||
val l: JValue = ev(left._2)
|
|
||||||
JObject(JField(left._1, l) :: right.obj)
|
|
||||||
}
|
|
||||||
def ~~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject =
|
|
||||||
this.~(right)
|
|
||||||
def ~~(right: JObject)(implicit ev: A => JValue): JObject = this.~(right)
|
|
||||||
}
|
|
||||||
|
|
||||||
final class JsonListAssoc(private val left: List[JField]) extends AnyVal {
|
|
||||||
def ~(right: (String, JValue)): JObject = JObject(left ::: List(JField(right._1, right._2)))
|
|
||||||
def ~(right: JObject): JObject = JObject(left ::: right.obj)
|
|
||||||
def ~~(right: (String, JValue)): JObject = this.~(right)
|
|
||||||
def ~~(right: JObject): JObject = this.~(right)
|
|
||||||
}
|
|
||||||
@@ -17,9 +17,9 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.spark.params
|
package ml.dmlc.xgboost4j.scala.spark.params
|
||||||
|
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
import org.apache.spark.ml.feature.VectorAssembler
|
||||||
import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
|
|
||||||
import org.apache.spark.ml.param.{Param, ParamValidators}
|
import org.apache.spark.ml.param.{Param, ParamValidators}
|
||||||
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol}
|
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol}
|
||||||
|
import org.apache.spark.ml.util.XGBoostSchemaUtils
|
||||||
import org.apache.spark.sql.Dataset
|
import org.apache.spark.sql.Dataset
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014,2021 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark.util
|
||||||
|
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
|
||||||
@@ -24,8 +24,8 @@ import org.apache.spark.HashPartitioner
|
|||||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||||
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{Column, DataFrame, Row}
|
|
||||||
import org.apache.spark.sql.types.{FloatType, IntegerType}
|
import org.apache.spark.sql.types.{FloatType, IntegerType}
|
||||||
|
import org.apache.spark.sql.{Column, DataFrame, Row}
|
||||||
|
|
||||||
object DataUtils extends Serializable {
|
object DataUtils extends Serializable {
|
||||||
private[spark] implicit class XGBLabeledPointFeatures(
|
private[spark] implicit class XGBLabeledPointFeatures(
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014,2021 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.params
|
package ml.dmlc.xgboost4j.scala.spark.util
|
||||||
|
|
||||||
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}
|
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}
|
||||||
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2022 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.utils
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
|
|
||||||
|
|
||||||
import org.apache.spark.ml.util.MLWriter
|
|
||||||
|
|
||||||
private[spark] abstract class XGBoostWriter extends MLWriter {
|
|
||||||
|
|
||||||
/** Currently it's using the "deprecated" format as
|
|
||||||
* default, which will be changed into `ubj` in future releases. */
|
|
||||||
def getModelFormat(): String = {
|
|
||||||
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2022 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.ml.util
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark
|
||||||
|
import org.apache.commons.logging.LogFactory
|
||||||
|
import org.apache.hadoop.fs.FSDataInputStream
|
||||||
|
import org.json4s.DefaultFormats
|
||||||
|
import org.json4s.JsonAST.JObject
|
||||||
|
import org.json4s.JsonDSL._
|
||||||
|
import org.json4s.jackson.JsonMethods.{compact, render}
|
||||||
|
|
||||||
|
import org.apache.spark.SparkContext
|
||||||
|
import org.apache.spark.ml.param.Params
|
||||||
|
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
||||||
|
|
||||||
|
/**
 * Base [[MLWriter]] for XGBoost models that knows which model format to persist with.
 */
abstract class XGBoostWriter extends MLWriter {

  /**
   * Resolves the model format to use when saving.
   *
   * The value of the writer's "format" option wins when it is present; otherwise
   * `JBooster.DEFAULT_FORMAT` is returned. That default is currently the "deprecated"
   * format and will be changed into `ubj` in future releases.
   *
   * @return the name of the model format to write
   */
  def getModelFormat(): String = optionMap.get("format") match {
    case Some(requestedFormat) => requestedFormat
    case None => JBooster.DEFAULT_FORMAT
  }
}
|
||||||
|
|
||||||
|
/** Metadata-writing helpers shared by the XGBoost estimators and models. */
object DefaultXGBoostParamsWriter {

  /** Key of the extra metadata entry that records the XGBoost version used for saving. */
  val XGBOOST_VERSION_TAG = "xgboostVersion"

  /**
   * Saves metadata + Params to: path + "/metadata" using [[DefaultParamsWriter.saveMetadata]],
   * attaching the current XGBoost version as extra metadata.
   *
   * @param instance the Params instance whose metadata is persisted
   * @param path     the base directory of the saved instance
   * @param sc       the SparkContext handed through to the underlying writer
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext): Unit = {
    // The version tag is what lets a reader distinguish models written by older releases.
    val xgboostVersion = ml.dmlc.xgboost4j.scala.spark.VERSION
    val versionEntry: JObject = Map(XGBOOST_VERSION_TAG -> xgboostVersion)
    DefaultParamsWriter.saveMetadata(instance, path, sc, Some(versionEntry))
  }
}
|
||||||
|
|
||||||
|
/** Metadata-reading helpers shared by the XGBoost estimators and models. */
object DefaultXGBoostParamsReader {

  private val logger = LogFactory.getLog("XGBoostSpark")

  /** Parameter names that were renamed, keyed by the stored (old) name. */
  private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")

  /**
   * Parameter values that were renamed, keyed by parameter name and then by the old value.
   *
   * This transformation is not pretty, but it is needed because some tests rely on a model
   * saved by 0.82 in which the objective is "reg:linear".
   */
  private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
    Map("objective" -> Map("reg:linear" -> "reg:squarederror"))

  /**
   * Load metadata through [[DefaultParamsReader.loadMetadata()]].
   *
   * @param path              the base directory of the saved instance
   * @param sc                the SparkContext handed through to the underlying reader
   * @param expectedClassName If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata =
    DefaultParamsReader.loadMetadata(path, sc, expectedClassName)

  /**
   * Extract the explicitly set Params from metadata and apply them to the instance.
   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
   *
   * Parameters the instance does not define are skipped (with a warning) instead of failing.
   *
   * This API is mainly copied from DefaultParamsReader.
   */
  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
    // Default parameters are deliberately NOT restored from the metadata. The previous
    // save/load code (copied from spark 2.3.x) never set them, so a loaded model picked up
    // the defaults of the running XGBoost version rather than the ones stored with the
    // model. That behaviour is kept for compatibility, i.e. there is no
    // setParams(instance, metadata, isDefault = true) call here.
    setParams(instance, metadata, isDefault = false)
  }

  /**
   * Determine the number of classes of a saved model.
   *
   * This API is only for XGBoostClassificationModel.
   *
   * @param metadata     the loaded model metadata
   * @param dataInStream the model data stream; an Int is read from it only when the metadata
   *                     carries no XGBoost version tag
   * @return the number of classes
   */
  def getNumClass(metadata: Metadata, dataInStream: FSDataInputStream): Int = {
    implicit val format = DefaultFormats

    // The presence of the xgboostVersion tag tells the new, XGBoost-compatible layout apart
    // from the old incompatible one.
    val savedVersion = (metadata.metadata \ DefaultXGBoostParamsWriter.XGBOOST_VERSION_TAG)
      .extractOpt[String]

    if (savedVersion.isDefined) {
      // New layout: numClass lives in the params. For binary:logistic it may be 2 or left
      // unset, hence the fallback to 2; for multi:softprob or multi:softmax it must have been
      // set correctly or XGBoost would have thrown, so the stored value can be trusted.
      (metadata.params \ "numClass").extractOpt[Int].getOrElse(2)
    } else {
      // Old layout: the number of classes is read from the data stream itself.
      dataInStream.readInt()
    }
  }

  /**
   * Apply either the default or the explicitly set params stored in metadata to the instance.
   *
   * @throws IllegalArgumentException if the selected params are not a JSON object
   */
  private def setParams(
      instance: Params,
      metadata: Metadata,
      isDefault: Boolean): Unit = {
    val paramsToSet = if (isDefault) metadata.defaultParams else metadata.params
    paramsToSet match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val finalName = handleBrokenlyChangedName(paramName)
          // A parameter that has been deleted should be dropped rather than raise an
          // exception, so check that it still exists before setting it.
          if (!instance.hasParam(finalName)) {
            logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
          } else {
            val param = instance.getParam(finalName)
            val decoded = param.jsonDecode(compact(render(jsonValue)))
            // The value lookup is keyed by the stored name, not the renamed one.
            instance.set(param, handleBrokenlyChangedValue(paramName, decoded))
          }
        }
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
    }
  }

  /** Translate a stored parameter name to its current name; unknown names pass through. */
  private def handleBrokenlyChangedName(paramName: String): String =
    paramNameCompatibilityMap.get(paramName) match {
      case Some(renamed) => renamed
      case None => paramName
    }

  /** Translate a stored parameter value to its current value; unknown values pass through. */
  private def handleBrokenlyChangedValue[T](paramName: String, value: T): T =
    paramValueCompatibilityMap
      .get(paramName)
      .flatMap(_.get(value))
      .getOrElse(value)
      .asInstanceOf[T]
}
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.apache.spark.ml.linalg.xgboost
|
package org.apache.spark.ml.util
|
||||||
|
|
||||||
import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType}
|
import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType}
|
||||||
import org.apache.spark.ml.linalg.VectorUDT
|
import org.apache.spark.ml.linalg.VectorUDT
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@@ -18,7 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
|
|
||||||
import org.apache.spark.ml.linalg.Vectors
|
import org.apache.spark.ml.linalg.Vectors
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
||||||
|
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class FeatureSizeValidatingSuite extends FunSuite with PerTest {
|
|||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic",
|
"objective" -> "binary:logistic",
|
||||||
"num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
|
"num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
|
||||||
import DataUtils._
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||||
val sparkSession = ss
|
val sparkSession = ss
|
||||||
import sparkSession.implicits._
|
import sparkSession.implicits._
|
||||||
val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)
|
val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
|||||||
protected def buildDataFrame(
|
protected def buildDataFrame(
|
||||||
labeledPoints: Seq[XGBLabeledPoint],
|
labeledPoints: Seq[XGBLabeledPoint],
|
||||||
numPartitions: Int = numWorkers): DataFrame = {
|
numPartitions: Int = numWorkers): DataFrame = {
|
||||||
import DataUtils._
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||||
val it = labeledPoints.iterator.zipWithIndex
|
val it = labeledPoints.iterator.zipWithIndex
|
||||||
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
|
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
|
||||||
(id, labeledPoint.label, labeledPoint.features)
|
(id, labeledPoint.label, labeledPoint.features)
|
||||||
@@ -98,7 +98,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
|||||||
protected def buildDataFrameWithGroup(
|
protected def buildDataFrameWithGroup(
|
||||||
labeledPoints: Seq[XGBLabeledPoint],
|
labeledPoints: Seq[XGBLabeledPoint],
|
||||||
numPartitions: Int = numWorkers): DataFrame = {
|
numPartitions: Int = numWorkers): DataFrame = {
|
||||||
import DataUtils._
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||||
val it = labeledPoints.iterator.zipWithIndex
|
val it = labeledPoints.iterator.zipWithIndex
|
||||||
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
|
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
|
||||||
(id, labeledPoint.label, labeledPoint.features, labeledPoint.group)
|
(id, labeledPoint.label, labeledPoint.features, labeledPoint.group)
|
||||||
|
|||||||
@@ -310,7 +310,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
|
|||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic",
|
"objective" -> "binary:logistic",
|
||||||
"num_round" -> 5, "num_workers" -> 2, "missing" -> 0)
|
"num_round" -> 5, "num_workers" -> 2, "missing" -> 0)
|
||||||
import DataUtils._
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||||
val sparkSession = SparkSession.builder().getOrCreate()
|
val sparkSession = SparkSession.builder().getOrCreate()
|
||||||
import sparkSession.implicits._
|
import sparkSession.implicits._
|
||||||
val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
|
val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
|
||||||
@@ -331,7 +331,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
|
|||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic",
|
"objective" -> "binary:logistic",
|
||||||
"num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
|
"num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
|
||||||
import DataUtils._
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||||
val sparkSession = SparkSession.builder().getOrCreate()
|
val sparkSession = SparkSession.builder().getOrCreate()
|
||||||
import sparkSession.implicits._
|
import sparkSession.implicits._
|
||||||
val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
|
val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
|
||||||
|
|||||||
Reference in New Issue
Block a user