Set feature_names and feature_types in jvm-packages (#9364)

* 1. Add parameters to set feature names and feature types
2. Save feature names and feature types to native json model

* Change serialization and deserialization format to ubj.
This commit is contained in:
jinmfeng001
2023-07-12 15:18:46 +08:00
committed by GitHub
parent 3632242e0b
commit a1367ea1f8
12 changed files with 295 additions and 8 deletions

View File

@@ -74,7 +74,9 @@ private[scala] case class XGBoostExecutionParams(
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean,
treeMethod: Option[String],
isLocal: Boolean) {
isLocal: Boolean,
featureNames: Option[Array[String]],
featureTypes: Option[Array[String]]) {
private var rawParamMap: Map[String, Any] = _
@@ -213,6 +215,13 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
.asInstanceOf[Boolean]
val featureNames = if (overridedParams.contains("feature_names")) {
Some(overridedParams("feature_names").asInstanceOf[Array[String]])
} else None
val featureTypes = if (overridedParams.contains("feature_types")){
Some(overridedParams("feature_types").asInstanceOf[Array[String]])
} else None
val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
missing, allowNonZeroForMissing, trackerConf,
checkpointParam,
@@ -220,7 +229,10 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
xgbExecEarlyStoppingParams,
cacheTrainingSet,
treeMethod,
isLocal)
isLocal,
featureNames,
featureTypes
)
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
@@ -531,6 +543,16 @@ private object Watches {
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
if (xgbExecutionParams.featureNames.isDefined) {
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
}
if (xgbExecutionParams.featureTypes.isDefined) {
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
}
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
}
@@ -643,6 +665,15 @@ private object Watches {
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
if (xgbExecutionParams.featureNames.isDefined) {
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
}
if (xgbExecutionParams.featureTypes.isDefined) {
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
}
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
}
}

View File

@@ -139,6 +139,12 @@ class XGBoostClassifier (
def setSinglePrecisionHistogram(value: Boolean): this.type =
set(singlePrecisionHistogram, value)
def setFeatureNames(value: Array[String]): this.type =
set(featureNames, value)
def setFeatureTypes(value: Array[String]): this.type =
set(featureTypes, value)
// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")

View File

@@ -141,6 +141,12 @@ class XGBoostRegressor (
def setSinglePrecisionHistogram(value: Boolean): this.type =
set(singlePrecisionHistogram, value)
def setFeatureNames(value: Array[String]): this.type =
set(featureNames, value)
def setFeatureTypes(value: Array[String]): this.type =
set(featureTypes, value)
// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")

View File

@@ -177,6 +177,21 @@ private[spark] trait GeneralParams extends Params {
final def getSeed: Long = $(seed)
/** Feature's name, it will be set to DMatrix and Booster, and in the final native json model.
* In native code, the parameter name is feature_name.
* */
final val featureNames = new StringArrayParam(this, "feature_names",
"an array of feature names")
final def getFeatureNames: Array[String] = $(featureNames)
/** Feature types, q is numeric and c is categorical.
* In native code, the parameter name is feature_type
* */
final val featureTypes = new StringArrayParam(this, "feature_types",
"an array of feature types")
final def getFeatureTypes: Array[String] = $(featureTypes)
}
trait HasLeafPredictionCol extends Params {

View File

@@ -27,6 +27,8 @@ import org.apache.commons.io.IOUtils
import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
import org.json4s.{DefaultFormats, Formats}
import org.json4s.jackson.parseJson
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
@@ -453,4 +455,26 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath))
}
test("native json model file should store feature_name and feature_type") {
val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
val featureTypes = (1 to 33).map(idx => "q").toArray
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "tree_method" -> treeMethod
)
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
.setFeatureNames(featureNames)
.setFeatureTypes(featureTypes)
val model = xgb.fit(trainingDF)
val modelStr = new String(model._booster.toByteArray("json"))
System.out.println(modelStr)
val jsonModel = parseJson(modelStr)
implicit val formats: Formats = DefaultFormats
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
assert(featureNamesInModel.length == 33)
assert(featureTypesInModel.length == 33)
}
}