[jvm-packages] add format option when saving a model (#7940)
This commit is contained in:
@@ -16,16 +16,18 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
import java.io.{File, FileInputStream}
|
||||
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.{BeforeAndAfterEach, FunSuite}
|
||||
|
||||
import scala.math.min
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.commons.io.IOUtils
|
||||
|
||||
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||
|
||||
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
|
||||
@@ -105,4 +107,22 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
|
||||
.toDF("id", "label", "features", "group")
|
||||
}
|
||||
|
||||
|
||||
/**
 * Checks whether two files hold exactly the same bytes.
 *
 * Both files are opened as streams, compared with `IOUtils.contentEquals`, and closed again
 * through `withResource` before the result is returned.
 *
 * @param lhs path of the first file
 * @param rhs path of the second file
 * @return the result of `IOUtils.contentEquals` on the two streams
 */
protected def compareTwoFiles(lhs: String, rhs: String): Boolean =
  withResource(new FileInputStream(lhs)) { left =>
    withResource(new FileInputStream(rhs)) { right =>
      IOUtils.contentEquals(left, right)
    }
  }
|
||||
|
||||
/**
 * Executes the provided code block and then closes the resource.
 *
 * The resource is closed on every path. When `block` completes normally, any exception from
 * `close()` propagates to the caller. When `block` throws, that exception is the one the
 * caller sees; a failure raised by `close()` during cleanup is attached to it via
 * `Throwable.addSuppressed` instead of replacing it, so the root cause is not hidden.
 *
 * @param r     the resource passed to `block`; always closed before this method returns
 * @param block the computation to run against `r`
 * @tparam T    the resource type
 * @tparam V    the result type of `block`
 * @return the value produced by `block`
 */
protected def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
  val result =
    try {
      block(r)
    } catch {
      // Catch everything only to clean up; the original throwable is always rethrown.
      case primary: Throwable =>
        try {
          r.close()
        } catch {
          case secondary: Throwable =>
            // addSuppressed rejects self-suppression, so guard against close() rethrowing
            // the very same instance.
            if (secondary ne primary) primary.addSuppressed(secondary)
        }
        throw primary
    }
  r.close()
  result
}
|
||||
}
|
||||
|
||||
@@ -429,30 +429,29 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
|
||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(trainingDF)
|
||||
|
||||
val modelPath = new File(tempDir.toFile, "xgbc").getPath
|
||||
model.write.overwrite().save(modelPath)
|
||||
val nativeModelPath = new File(tempDir.toFile, "nativeModel").getPath
|
||||
model.nativeBooster.saveModel(nativeModelPath)
|
||||
|
||||
model.write.option("format", "json").save(modelPath)
|
||||
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
|
||||
model.nativeBooster.saveModel(nativeJsonModelPath)
|
||||
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
|
||||
nativeModelPath))
|
||||
}
|
||||
nativeJsonModelPath))
|
||||
|
||||
private def compareTwoFiles(lhs: String, rhs: String): Boolean = {
|
||||
withResource(new FileInputStream(lhs)) { lfis =>
|
||||
withResource(new FileInputStream(rhs)) { rfis =>
|
||||
IOUtils.contentEquals(lfis, rfis)
|
||||
}
|
||||
}
|
||||
}
|
||||
// test default "deprecated"
|
||||
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
|
||||
model.write.save(modelUbjPath)
|
||||
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
|
||||
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
|
||||
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
|
||||
nativeDeprecatedModelPath))
|
||||
|
||||
/** Executes the provided code block and then closes the resource */
|
||||
private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
|
||||
try {
|
||||
block(r)
|
||||
} finally {
|
||||
r.close()
|
||||
}
|
||||
// json file should be different from ubj file
|
||||
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
|
||||
model.write.option("format", "json").save(modelJsonPath)
|
||||
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
|
||||
model.nativeBooster.saveModel(nativeUbjModelPath)
|
||||
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
|
||||
nativeUbjModelPath))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
@@ -25,7 +27,7 @@ import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
|
||||
class XGBoostRegressorSuite extends FunSuite with PerTest {
|
||||
class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
|
||||
protected val treeMethod: String = "auto"
|
||||
|
||||
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
|
||||
@@ -310,4 +312,42 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
||||
val df1 = model.transform(vectorizedInput)
|
||||
df1.show()
|
||||
}
|
||||
|
||||
test("XGBoostRegressionModel should be compatible") {
  // Resolves a child name against this suite's temporary directory.
  def inTmp(name: String): String = new File(tempDir.toFile, name).getPath

  // Path of the booster file inside a directory produced by `model.write...save(dir)`.
  def payloadOf(dir: String): String = new File(dir, "data/XGBoostRegressionModel").getPath

  val trainingDF = buildDataFrame(Regression.train)
  val params = Map(
    "eta" -> "1",
    "max_depth" -> "6",
    "silent" -> "1",
    "objective" -> "reg:squarederror",
    "num_round" -> 5,
    "tree_method" -> treeMethod,
    "num_workers" -> numWorkers)

  val fitted = new XGBoostRegressor(params).fit(trainingDF)

  // With the "json" format option, the file written through the Spark writer must be
  // byte-identical to what the native booster saves under a ".json" file name.
  val jsonDir = inTmp("xgbc")
  fitted.write.option("format", "json").save(jsonDir)
  val nativeJson = inTmp("nativeModel.json")
  fitted.nativeBooster.saveModel(nativeJson)
  assert(compareTwoFiles(payloadOf(jsonDir), nativeJson))

  // With no format option (the default, labelled "deprecated"), the written file must match
  // what the native booster saves under an extension-less file name.
  val defaultDir = inTmp("xgbcUbj")
  fitted.write.save(defaultDir)
  val nativeDefault = inTmp("nativeModel")
  fitted.nativeBooster.saveModel(nativeDefault)
  assert(compareTwoFiles(payloadOf(defaultDir), nativeDefault))

  // A model written with the "json" option must NOT be byte-identical to the native
  // booster's output under a ".ubj" file name.
  val secondJsonDir = inTmp("xgbcJson")
  fitted.write.option("format", "json").save(secondJsonDir)
  val nativeUbj = inTmp("nativeModel1.ubj")
  fitted.nativeBooster.saveModel(nativeUbj)
  assert(!compareTwoFiles(payloadOf(secondJsonDir), nativeUbj))
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user