[jvm-packages] fix "key not found: train" issue (#6842)

* [jvm-packages] fix "key not found: train" issue

* fix bug
This commit is contained in:
Bobby Wang
2021-04-19 14:28:39 +08:00
committed by GitHub
parent 556a83022d
commit 2c684ffd32
2 changed files with 16 additions and 2 deletions

View File

@@ -16,10 +16,12 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.XGBoostError
import scala.util.Random
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.{TaskContext}
import org.apache.spark.TaskContext
import org.scalatest.FunSuite
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.lit
@@ -367,4 +369,16 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
df2.collect()
}
test("throw exception for empty partition in trainingset") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "2", "num_round" -> 5,
"num_workers" -> numWorkers, "tree_method" -> "auto")
// The Dmatrix will be empty
val trainingDF = buildDataFrame(Seq(XGBLabeledPoint(1.0f, 1, Array(), Array())))
val xgb = new XGBoostClassifier(paramMap)
intercept[XGBoostError] {
val model = xgb.fit(trainingDF)
}
}
}