[jvm-packages] Allow for bypassing spark missing value check (#4805)

* Allow for bypassing spark missing value check

* Update documentation for dealing with missing values in spark xgboost
This commit is contained in:
cpfarrell
2019-12-18 10:48:20 -08:00
committed by Nan Zhu
parent 27b3646d29
commit bc9d88259f
6 changed files with 134 additions and 40 deletions

View File

@@ -150,4 +150,32 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
new XGBoostClassifier(paramMap).fit(inputDF)
}
}
test("specify a non-zero missing value but set allow_non_zero_missing_value " +
  "does not stop application") {
  val spark = ss
  import spark.implicits._
  ss.sparkContext.setLogLevel("INFO")
  // Spark decides between sparse and dense vector encoding with the rule
  // 1.5 * (nnz + 1.0) < size, so this data is laid out to exercise both
  // encodings while containing the non-zero sentinel value (-1.0f) used as
  // "missing" below.
  val rawDF = Seq(
    (7.0f, 0.0f, -1.0f, 1.0f, 1.0),
    (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
    (0.0f, 1.0f, 0.0f, 1.0f, 0.0),
    (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
    (1.0f, -1.0f, 0.0f, 1.0f, 0.0),
    (0.0f, 0.0f, 0.0f, 1.0f, 1.0),
    (-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
  ).toDF("col1", "col2", "col3", "col4", "label")
  val assembler = new VectorAssembler()
    .setInputCols(Array("col1", "col2", "col3", "col4"))
    .setOutputCol("features")
  val trainDF = assembler.transform(rawDF).select("features", "label")
  trainDF.show()
  // With allow_non_zero_for_missing_value enabled, training must proceed
  // despite the non-zero "missing" setting instead of aborting the job.
  val params = Map(
    "eta" -> "1",
    "max_depth" -> "2",
    "objective" -> "binary:logistic",
    "missing" -> -1.0f,
    "num_workers" -> 1,
    "allow_non_zero_for_missing_value" -> "true")
  val model = new XGBoostClassifier(params).fit(trainDF)
  model.transform(trainDF).collect()
}
}