[BLOCKING] Handle empty rows in data iterators correctly (#5929)

* [jvm-packages] Handle empty rows in data iterators correctly

* Fix clang-tidy error

* last empty row

* Add comments [skip ci]

Co-authored-by: Nan Zhu <nanzhu@uber.com>
This commit is contained in:
Philip Hyunsu Cho
2020-07-25 13:46:19 -07:00
committed by GitHub
parent a4de2f68e4
commit 487ab0ce73
5 changed files with 79 additions and 19 deletions

View File

@@ -171,4 +171,31 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
val model = new XGBoostClassifier(paramMap).fit(inputDF)
model.transform(inputDF).collect()
}
// https://github.com/dmlc/xgboost/pull/5929
test("handle the empty last row correctly with a missing value as 0") {
val spark = ss
import spark.implicits._
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
// vector,
val testDF = Seq(
(7.0f, 0.0f, -1.0f, 1.0f, 1.0),
(1.0f, 0.0f, 1.0f, 1.0f, 1.0),
(0.0f, 1.0f, 0.0f, 1.0f, 0.0),
(1.0f, 0.0f, 1.0f, 1.0f, 1.0),
(1.0f, -1.0f, 0.0f, 1.0f, 0.0),
(0.0f, 0.0f, 0.0f, 1.0f, 1.0),
(0.0f, 0.0f, 0.0f, 0.0f, 0.0)
).toDF("col1", "col2", "col3", "col4", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3", "col4"))
.setOutputCol("features")
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
inputDF.show()
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> 0.0f,
"num_workers" -> 1, "allow_non_zero_for_missing" -> "true").toMap
val model = new XGBoostClassifier(paramMap).fit(inputDF)
model.transform(inputDF).collect()
}
}