[jvm-packages] Allow for bypassing spark missing value check (#4805)
* Allow for bypassing the Spark missing-value check
* Update documentation for dealing with missing values in Spark XGBoost
This commit is contained in:
@@ -150,4 +150,32 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
|
||||
new XGBoostClassifier(paramMap).fit(inputDF)
|
||||
}
|
||||
}
|
||||
|
||||
test("specify a non-zero missing value but set allow_non_zero_missing_value " +
    "does not stop application") {
  val spark = ss
  import spark.implicits._
  ss.sparkContext.setLogLevel("INFO")
  // NOTE(review): Spark decides sparse vs. dense per row via
  // 1.5 * (nnz + 1.0) < size, so these rows deliberately mix zero and
  // non-zero entries to exercise both vector representations.
  val rawDF = Seq(
    (7.0f, 0.0f, -1.0f, 1.0f, 1.0),
    (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
    (0.0f, 1.0f, 0.0f, 1.0f, 0.0),
    (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
    (1.0f, -1.0f, 0.0f, 1.0f, 0.0),
    (0.0f, 0.0f, 0.0f, 1.0f, 1.0),
    (-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
  ).toDF("col1", "col2", "col3", "col4", "label")
  val assembler = new VectorAssembler()
    .setInputCols(Array("col1", "col2", "col3", "col4"))
    .setOutputCol("features")
  val trainDF = assembler.transform(rawDF).select("features", "label")
  trainDF.show()
  // `missing` is non-zero (-1.0f); training must still proceed because the
  // bypass flag "allow_non_zero_for_missing_value" is enabled.
  val params = Map(
    "eta" -> "1",
    "max_depth" -> "2",
    "objective" -> "binary:logistic",
    "missing" -> -1.0f,
    "num_workers" -> 1,
    "allow_non_zero_for_missing_value" -> "true")
  val model = new XGBoostClassifier(params).fit(trainDF)
  model.transform(trainDF).collect()
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user