[Breaking][jvm-packages] make classification model be xgboost-compatible (#7896)
This commit is contained in:
@@ -16,16 +16,19 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.{File, FileInputStream}
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
import org.apache.commons.io.IOUtils
|
||||
|
||||
import org.apache.spark.Partitioner
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
|
||||
class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuite {
|
||||
|
||||
protected val treeMethod: String = "auto"
|
||||
|
||||
@@ -391,4 +394,37 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
df1.show()
|
||||
}
|
||||
|
||||
test("XGBoostClassificationModel should be compatible") {
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
|
||||
"num_workers" -> numWorkers, "tree_method" -> treeMethod)
|
||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(trainingDF)
|
||||
val modelPath = new File(tempDir.toFile, "xgbc").getPath
|
||||
model.write.overwrite().save(modelPath)
|
||||
val nativeModelPath = new File(tempDir.toFile, "nativeModel").getPath
|
||||
model.nativeBooster.saveModel(nativeModelPath)
|
||||
|
||||
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
|
||||
nativeModelPath))
|
||||
}
|
||||
|
||||
private def compareTwoFiles(lhs: String, rhs: String): Boolean = {
|
||||
withResource(new FileInputStream(lhs)) { lfis =>
|
||||
withResource(new FileInputStream(rhs)) { rfis =>
|
||||
IOUtils.contentEquals(lfis, rfis)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Executes the provided code block and then closes the resource */
|
||||
private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
|
||||
try {
|
||||
block(r)
|
||||
} finally {
|
||||
r.close()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user