Set feature_names and feature_types in jvm-packages (#9364)
* 1. Add parameters to set feature names and feature types 2. Save feature names and feature types to native json model * Change serialization and deserialization format to ubj.
This commit is contained in:
@@ -27,6 +27,8 @@ import org.apache.commons.io.IOUtils
|
||||
|
||||
import org.apache.spark.Partitioner
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
import org.json4s.{DefaultFormats, Formats}
|
||||
import org.json4s.jackson.parseJson
|
||||
|
||||
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
|
||||
|
||||
@@ -453,4 +455,26 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
|
||||
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
|
||||
nativeUbjModelPath))
|
||||
}
|
||||
|
||||
test("native json model file should store feature_name and feature_type") {
|
||||
val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
|
||||
val featureTypes = (1 to 33).map(idx => "q").toArray
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
|
||||
"num_workers" -> numWorkers, "tree_method" -> treeMethod
|
||||
)
|
||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
.setFeatureNames(featureNames)
|
||||
.setFeatureTypes(featureTypes)
|
||||
val model = xgb.fit(trainingDF)
|
||||
val modelStr = new String(model._booster.toByteArray("json"))
|
||||
System.out.println(modelStr)
|
||||
val jsonModel = parseJson(modelStr)
|
||||
implicit val formats: Formats = DefaultFormats
|
||||
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
|
||||
val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
|
||||
assert(featureNamesInModel.length == 33)
|
||||
assert(featureTypesInModel.length == 33)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user