[jvm-packages] (xgboost-spark) preserving num_class across save & load (#2742)
* [bugfix] (xgboost-spark) preserving num_class across save & load * add testcase for save & load of multiclass model
This commit is contained in:
@@ -192,6 +192,7 @@ class XGBoostDFSuite extends FunSuite with PerTest {
|
||||
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers)
|
||||
assert(model.get[Double](model.eta).get == 0.1)
|
||||
assert(model.get[Int](model.maxDepth).get == 6)
|
||||
assert(model.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
|
||||
}
|
||||
|
||||
test("test use base margin") {
|
||||
|
||||
@@ -286,7 +286,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
import DataUtils._
|
||||
val tempDir = Files.createTempDirectory("xgboosttest-")
|
||||
val tempFile = Files.createTempFile(tempDir, "", "")
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
var trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:linear")
|
||||
// validate regression model
|
||||
@@ -318,6 +318,26 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
assert(loadedXGBoostModel.getFeaturesCol == "features")
|
||||
assert(loadedXGBoostModel.getLabelCol == "label")
|
||||
assert(loadedXGBoostModel.getPredictionCol == "prediction")
|
||||
// (multiclass) classification model
|
||||
trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML)
|
||||
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = false)
|
||||
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
|
||||
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(
|
||||
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5))
|
||||
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
|
||||
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
|
||||
"raw_col")
|
||||
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
|
||||
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep)
|
||||
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
|
||||
assert(loadedXGBoostModel.getFeaturesCol == "features")
|
||||
assert(loadedXGBoostModel.getLabelCol == "label")
|
||||
assert(loadedXGBoostModel.getPredictionCol == "prediction")
|
||||
}
|
||||
|
||||
test("test use groupData") {
|
||||
|
||||
Reference in New Issue
Block a user