[jvm-packages] add format option when saving a model (#7940)

This commit is contained in:
Bobby Wang
2022-05-30 15:49:59 +08:00
committed by GitHub
parent cc6d57aa0d
commit 6275cdc486
8 changed files with 153 additions and 30 deletions

View File

@@ -16,16 +16,18 @@
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.io.{File, FileInputStream}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import scala.math.min
import scala.util.Random
import org.apache.commons.io.IOUtils
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
@@ -105,4 +107,22 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features", "group")
}
/**
 * Checks whether two files hold the same content.
 *
 * Both files are opened as input streams, compared with `IOUtils.contentEquals`,
 * and closed again through `withResource` before the result is returned.
 *
 * @param lhs path of the first file
 * @param rhs path of the second file
 * @return the result of `IOUtils.contentEquals` on the two streams
 */
protected def compareTwoFiles(lhs: String, rhs: String): Boolean =
  withResource(new FileInputStream(lhs)) { leftStream =>
    // The second stream is opened only after the first one was opened successfully.
    withResource(new FileInputStream(rhs)) { rightStream =>
      IOUtils.contentEquals(leftStream, rightStream)
    }
  }
/**
 * Executes the provided code block and then closes the resource.
 *
 * The resource is closed exactly once, whether `block` returns or throws.
 * When `block` throws and `close()` throws as well, the exception from `block`
 * is the one that propagates; the failure from `close()` is attached to it as a
 * suppressed exception so that the root cause is not masked by the cleanup error.
 *
 * @param r     the resource handed to `block` and closed afterwards
 * @param block the code to run with the resource
 * @tparam T    the resource type
 * @tparam V    the result type of `block`
 * @return whatever `block` returns
 */
protected def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
  val result =
    try {
      block(r)
    } catch {
      // Catch-and-rethrow only: every Throwable still reaches the caller.
      case primary: Throwable =>
        try {
          r.close()
        } catch {
          // addSuppressed rejects self-suppression, hence the identity guard.
          case secondary: Throwable if secondary ne primary =>
            primary.addSuppressed(secondary)
        }
        throw primary
    }
  // Success path: a failure while closing is reported to the caller as before.
  r.close()
  result
}
}

View File

@@ -429,30 +429,29 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.overwrite().save(modelPath)
val nativeModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeModelPath)
model.write.option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
model.nativeBooster.saveModel(nativeJsonModelPath)
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
nativeModelPath))
}
nativeJsonModelPath))
/**
 * Opens both paths as file input streams and reports whether
 * `IOUtils.contentEquals` considers their contents equal.
 * Each stream is wrapped in `withResource`.
 */
private def compareTwoFiles(lhs: String, rhs: String): Boolean = {
withResource(new FileInputStream(lhs)) { lfis =>
withResource(new FileInputStream(rhs)) { rfis =>
IOUtils.contentEquals(lfis, rfis)
}
}
}
// test default "deprecated"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
nativeDeprecatedModelPath))
/** Executes the provided code block and then closes the resource */
private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
// json file should be different from ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath))
}
}

View File

@@ -16,6 +16,8 @@
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg.{Vector, Vectors}
@@ -25,7 +27,7 @@ import org.scalatest.FunSuite
import org.apache.spark.ml.feature.VectorAssembler
class XGBoostRegressorSuite extends FunSuite with PerTest {
class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
@@ -310,4 +312,42 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
val df1 = model.transform(vectorizedInput)
df1.show()
}
test("XGBoostRegressionModel should be compatible") {
  // Resolves a file name against the suite's temporary folder.
  def inTmp(name: String): String = new File(tempDir.toFile, name).getPath

  // Path of the booster file inside a directory produced by `model.write`.
  def boosterFileIn(savedDir: String): String =
    new File(savedDir, "data/XGBoostRegressionModel").getPath

  val trainingDF = buildDataFrame(Regression.train)
  val regressorParams = Map(
    "eta" -> "1",
    "max_depth" -> "6",
    "silent" -> "1",
    "objective" -> "reg:squarederror",
    "num_round" -> 5,
    "tree_method" -> treeMethod,
    "num_workers" -> numWorkers)
  val model = new XGBoostRegressor(regressorParams).fit(trainingDF)

  // Explicit "json" format: the written file must equal the booster's own
  // save to a ".json" path.
  val explicitJsonDir = inTmp("xgbc")
  model.write.option("format", "json").save(explicitJsonDir)
  val boosterJsonFile = inTmp("nativeModel.json")
  model.nativeBooster.saveModel(boosterJsonFile)
  assert(compareTwoFiles(boosterFileIn(explicitJsonDir), boosterJsonFile))

  // No format option (the default, "deprecated"): the written file must equal
  // the booster's own save to an extension-less path.
  val defaultFormatDir = inTmp("xgbcUbj")
  model.write.save(defaultFormatDir)
  val boosterDefaultFile = inTmp("nativeModel")
  model.nativeBooster.saveModel(boosterDefaultFile)
  assert(compareTwoFiles(boosterFileIn(defaultFormatDir), boosterDefaultFile))

  // The "json" output must differ from the booster's own save to a ".ubj" path.
  val secondJsonDir = inTmp("xgbcJson")
  model.write.option("format", "json").save(secondJsonDir)
  val boosterUbjFile = inTmp("nativeModel1.ubj")
  model.nativeBooster.saveModel(boosterUbjFile)
  assert(!compareTwoFiles(boosterFileIn(secondJsonDir), boosterUbjFile))
}
}