[jvm-packages] fix "key not found: train" issue (#6842)
* [jvm-packages] fix "key not found: train" issue * fix bug
This commit is contained in:
parent
556a83022d
commit
2c684ffd32
@ -377,7 +377,7 @@ object XGBoost extends Serializable {
|
|||||||
// to work around empty partitions in the training dataset,
|
// to work around empty partitions in the training dataset,
|
||||||
// this might not be the most efficient implementation, see
|
// this might not be the most efficient implementation, see
|
||||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||||
if (watches.toMap("train").rowNum == 0) {
|
if (!watches.toMap.contains("train")) {
|
||||||
throw new XGBoostError(
|
throw new XGBoostError(
|
||||||
s"detected an empty partition in the training data, partition ID:" +
|
s"detected an empty partition in the training data, partition ID:" +
|
||||||
s" ${TaskContext.getPartitionId()}")
|
s" ${TaskContext.getPartitionId()}")
|
||||||
|
|||||||
@ -16,10 +16,12 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
|
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
import org.apache.spark.{TaskContext}
|
import org.apache.spark.TaskContext
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
import org.apache.spark.ml.feature.VectorAssembler
|
||||||
import org.apache.spark.sql.functions.lit
|
import org.apache.spark.sql.functions.lit
|
||||||
@ -367,4 +369,16 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
|||||||
df2.collect()
|
df2.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expects xgb.fit to raise XGBoostError when the training data yields an empty DMatrix.
test("throw exception for empty partition in trainingset") {
|
||||||
|
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "multi:softmax", "num_class" -> "2", "num_round" -> 5,
|
||||||
|
"num_workers" -> numWorkers, "tree_method" -> "auto")
|
||||||
|
// The single labeled point is built from empty arrays, so the DMatrix will be empty
|
||||||
|
val trainingDF = buildDataFrame(Seq(XGBLabeledPoint(1.0f, 1, Array(), Array())))
|
||||||
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
|
intercept[XGBoostError] {
|
||||||
|
val model = xgb.fit(trainingDF)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user