[BREAKING][jvm-packages] fix the non-zero missing value handling (#4349)
* fix the nan and non-zero missing value handling
* fix nan handling part
* add missing value
* Update MissingValueHandlingSuite.scala
* Update MissingValueHandlingSuite.scala
* stylistic fix
This commit is contained in:
parent 2d875ec019
commit 995698b0cb
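In plain terms, the breaking change is this: a non-default missing value is now accepted only when every feature vector is dense, because a SparseVector encodes its zeros implicitly and cannot tell a real zero apart from an absent value. Below is a minimal sketch of the rule being enforced, using a stand-in point type (the real check is verifyMissingSetting in XGBoost.scala, shown in the diff that follows, operating on xgboost4j labeled points):

    // Stand-in for xgboost4j's labeled point: indices == null means dense.
    case class Point(label: Float, indices: Array[Int], values: Array[Float])

    // Mirrors verifyMissingSetting: with a non-default `missing`, reject any
    // sparse point, whose implicit zeros would otherwise be conflated with
    // real feature values.
    def verifyMissingSetting(points: Iterator[Point], missing: Float): Iterator[Point] =
      if (missing != 0.0f) {
        points.map { p =>
          if (p.indices != null) {
            throw new RuntimeException(
              s"you can only specify missing value as 0.0 (the currently set value $missing)" +
                " when you have SparseVector or Empty vector as your feature format")
          }
          p
        }
      } else {
        points
      }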
@@ -18,9 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

import java.io.File
import java.nio.file.Files
import java.util.Properties

import scala.collection.mutable.ListBuffer
import scala.collection.{AbstractIterator, mutable}
import scala.util.Random

@@ -32,8 +30,8 @@ import org.apache.commons.io.FileUtils

import org.apache.commons.logging.LogFactory

import org.apache.spark.rdd.RDD
-import org.apache.spark.{SparkContext, SparkException, SparkParallelismTracker, TaskContext}
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

@@ -75,8 +73,9 @@ object XGBoost extends Serializable {
    if (missing != 0.0f) {
      xgbLabelPoints.map(labeledPoint => {
        if (labeledPoint.indices != null) {
-          throw new RuntimeException("you can only specify missing value as 0.0 when you have" +
-            " SparseVector as your feature format")
+          throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
+            s" set value $missing) when you have SparseVector or Empty vector as your feature" +
+            " format")
        }
        labeledPoint
      })

@@ -107,7 +106,8 @@ object XGBoost extends Serializable {
-      removeMissingValues(xgbLabelPoints, missing, (v: Float) => v != missing)
+      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
+        missing, (v: Float) => v != missing)
    } else {
-      removeMissingValues(xgbLabelPoints, missing, (v: Float) => !v.isNaN)
+      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
+        missing, (v: Float) => !v.isNaN)
    }
  }
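The two branches above differ only in the keep-predicate handed to removeMissingValues: an explicit missing value drops exact matches of that sentinel, while the NaN default keeps everything that is not NaN. A self-contained illustration of the filtering step (plain float arrays stand in for the internal labeled points):

    // Illustration only: apply the same predicates the patched code passes
    // to removeMissingValues.
    def keepValues(values: Array[Float], keep: Float => Boolean): Array[Float] =
      values.filter(keep)

    val row = Array(1.0f, -1.0f, Float.NaN, 0.5f)

    // missing set to -1.0f: drop exact matches of the sentinel
    val withSentinel = keepValues(row, v => v != -1.0f) // Array(1.0, NaN, 0.5)

    // missing left as the NaN default: drop the NaNs
    val withDefault = keepValues(row, v => !v.isNaN)    // Array(1.0, -1.0, 0.5)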
@@ -0,0 +1,148 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import scala.util.Random

import ml.dmlc.xgboost4j.java.XGBoostError
import org.scalatest.FunSuite

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame

class MissingValueHandlingSuite extends FunSuite with PerTest {
  test("dense vectors containing missing value") {
    def buildDenseDataFrame(): DataFrame = {
      val numRows = 100
      val numCols = 5
      val data = (0 until numRows).map { x =>
        val label = Random.nextInt(2)
        val values = Array.tabulate[Double](numCols) { c =>
          if (c == numCols - 1) 0 else Random.nextDouble
        }
        (label, Vectors.dense(values))
      }
      ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
    }
    val denseDF = buildDenseDataFrame().repartition(4)
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
    val model = new XGBoostClassifier(paramMap).fit(denseDF)
    model.transform(denseDF).collect()
  }

  test("handle Float.NaN as missing value correctly") {
    val spark = ss
    import spark.implicits._
    val testDF = Seq(
      (1.0f, 0.0f, Float.NaN, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (1.0f, Float.NaN, 0.0f, 0.0),
      (0.0f, 1.0f, 0.0f, 1.0),
      (Float.NaN, 0.0f, 0.0f, 1.0)
    ).toDF("col1", "col2", "col3", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3"))
      .setOutputCol("features")
      .setHandleInvalid("keep")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
    val model = new XGBoostClassifier(paramMap).fit(inputDF)
    model.transform(inputDF).collect()
  }

  test("specify a non-zero missing value but with dense vector does not stop" +
    " application") {
    val spark = ss
    import spark.implicits._
    // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
    // vector
    val testDF = Seq(
      (1.0f, 0.0f, -1.0f, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (1.0f, -1.0f, 0.0f, 0.0),
      (0.0f, 1.0f, 0.0f, 1.0),
      (-1.0f, 0.0f, 0.0f, 1.0)
    ).toDF("col1", "col2", "col3", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3"))
      .setOutputCol("features")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
    val model = new XGBoostClassifier(paramMap).fit(inputDF)
    model.transform(inputDF).collect()
  }
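The comment in this test refers to Spark's compression heuristic: an ml.linalg vector is stored sparse when 1.5 * (nnz + 1.0) < size, and VectorAssembler emits vectors in that compressed form. A quick sketch of how the rows above land on either side of the rule:

    import org.apache.spark.ml.linalg.Vectors

    // size = 3, nnz = 2: 1.5 * (2 + 1) = 4.5 is not < 3, so the compressed
    // form stays dense and the non-zero missing of -1.0f is accepted.
    val stillDense = Vectors.dense(1.0, 0.0, -1.0).compressed

    // size = 3, nnz = 0: 1.5 * (0 + 1) = 1.5 < 3, so an all-zero row is
    // stored sparse; that is exactly what trips the empty-vector test below.
    val goesSparse = Vectors.dense(0.0, 0.0, 0.0).compressed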
test("specify a non-zero missing value and meet an empty vector we should" +
|
||||
" stop the application") {
|
||||
val spark = ss
|
||||
import spark.implicits._
|
||||
val testDF = Seq(
|
||||
(1.0f, 0.0f, -1.0f, 1.0),
|
||||
(1.0f, 0.0f, 1.0f, 1.0),
|
||||
(0.0f, 1.0f, 0.0f, 0.0),
|
||||
(1.0f, 0.0f, 1.0f, 1.0),
|
||||
(1.0f, -1.0f, 0.0f, 0.0),
|
||||
(0.0f, 0.0f, 0.0f, 1.0),// empty vector
|
||||
(-1.0f, 0.0f, 0.0f, 1.0)
|
||||
).toDF("col1", "col2", "col3", "label")
|
||||
val vectorAssembler = new VectorAssembler()
|
||||
.setInputCols(Array("col1", "col2", "col3"))
|
||||
.setOutputCol("features")
|
||||
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||
"objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
|
||||
intercept[XGBoostError] {
|
||||
new XGBoostClassifier(paramMap).fit(inputDF)
|
||||
}
|
||||
}
|
||||
|
||||
test("specify a non-zero missing value and meet a Sparse vector we should" +
|
||||
" stop the application") {
|
||||
val spark = ss
|
||||
import spark.implicits._
|
||||
ss.sparkContext.setLogLevel("INFO")
|
||||
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
|
||||
// vector,
|
||||
val testDF = Seq(
|
||||
(1.0f, 0.0f, -1.0f, 1.0f, 1.0),
|
||||
(1.0f, 0.0f, 1.0f, 1.0f, 1.0),
|
||||
(0.0f, 1.0f, 0.0f, 1.0f, 0.0),
|
||||
(1.0f, 0.0f, 1.0f, 1.0f, 1.0),
|
||||
(1.0f, -1.0f, 0.0f, 1.0f, 0.0),
|
||||
(0.0f, 0.0f, 0.0f, 1.0f, 1.0),
|
||||
(-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
|
||||
).toDF("col1", "col2", "col3", "col4", "label")
|
||||
val vectorAssembler = new VectorAssembler()
|
||||
.setInputCols(Array("col1", "col2", "col3", "col4"))
|
||||
.setOutputCol("features")
|
||||
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
||||
inputDF.show()
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||
"objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
|
||||
intercept[XGBoostError] {
|
||||
new XGBoostClassifier(paramMap).fit(inputDF)
|
||||
}
|
||||
}
|
||||
}
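One detail worth calling out in these tests: VectorAssembler rejects rows containing NaN unless handleInvalid is set to "keep", which is why only the NaN test configures it. A short usage sketch (column names match the tests; toy setup):

    import org.apache.spark.ml.feature.VectorAssembler

    // "keep" passes NaN through into the assembled vector so XGBoost can
    // treat it as the missing value; the default ("error") would fail the job.
    val assembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3"))
      .setOutputCol("features")
      .setHandleInvalid("keep")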
@@ -287,7 +287,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
  test("infrequent features") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic",
-      "num_round" -> 5, "num_workers" -> 2)
+      "num_round" -> 5, "num_workers" -> 2, "missing" -> 0)
    import DataUtils._
    val sparkSession = SparkSession.builder().getOrCreate()
    import sparkSession.implicits._

@@ -308,7 +308,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
  test("infrequent features (use_external_memory)") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic",
-      "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true)
+      "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
    import DataUtils._
    val sparkSession = SparkSession.builder().getOrCreate()
    import sparkSession.implicits._
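Both infrequent-features tests train on genuinely sparse data, so under the new rule they must declare missing -> 0 explicitly; with the NaN default, verifyMissingSetting would now reject the SparseVector input outright. A hedged illustration of the kind of row involved (toy values, not taken from the suite):

    import org.apache.spark.ml.linalg.Vectors

    // An "infrequent features" row is mostly zeros, so Spark keeps it sparse.
    val row = Vectors.sparse(100, Array(3, 97), Array(1.0, 2.0))

    // After this commit, training on such rows requires "missing" -> 0,
    // i.e. explicitly declaring the implicit zeros to be the missing value.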
@@ -225,50 +225,6 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
    assert(x < 0.1)
  }

-  test("dense vectors containing missing value") {
-    def buildDenseDataFrame(): DataFrame = {
-      val numRows = 100
-      val numCols = 5
-      val data = (0 until numRows).map { x =>
-        val label = Random.nextInt(2)
-        val values = Array.tabulate[Double](numCols) { c =>
-          if (c == numCols - 1) 0 else Random.nextDouble
-        }
-        (label, Vectors.dense(values))
-      }
-      ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
-    }
-    val denseDF = buildDenseDataFrame().repartition(4)
-    val paramMap = List("eta" -> "1", "max_depth" -> "2",
-      "objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
-    val model = new XGBoostClassifier(paramMap).fit(denseDF)
-    model.transform(denseDF).collect()
-  }
-
-  test("handle Float.NaN as missing value correctly") {
-    val spark = ss
-    import spark.implicits._
-    val testDF = Seq(
-      (1.0f, 0.0f, Float.NaN, 1.0),
-      (1.0f, 0.0f, 1.0f, 1.0),
-      (0.0f, 1.0f, 0.0f, 0.0),
-      (1.0f, 0.0f, 1.0f, 1.0),
-      (1.0f, Float.NaN, 0.0f, 0.0),
-      (0.0f, 0.0f, 0.0f, 0.0),
-      (0.0f, 1.0f, 0.0f, 1.0),
-      (Float.NaN, 0.0f, 0.0f, 1.0)
-    ).toDF("col1", "col2", "col3", "label")
-    val vectorAssembler = new VectorAssembler()
-      .setInputCols(Array("col1", "col2", "col3"))
-      .setOutputCol("features")
-      .setHandleInvalid("keep")
-    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
-    val paramMap = List("eta" -> "1", "max_depth" -> "2",
-      "objective" -> "binary:logistic", "num_workers" -> 1).toMap
-    val model = new XGBoostClassifier(paramMap).fit(inputDF)
-    model.transform(inputDF).collect()
-  }
-
  test("training with spark parallelism checks disabled") {
    val eval = new EvalError()
    val training = buildDataFrame(Classification.train)
@@ -4,6 +4,9 @@ import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
import ml.dmlc.xgboost4j.LabeledPoint;

/**

@@ -13,6 +16,7 @@ import ml.dmlc.xgboost4j.LabeledPoint;
 * This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
 */
class DataBatch {
+  private static final Log logger = LogFactory.getLog(DataBatch.class);
  /** The offset of each rows in the sparse matrix */
  final long[] rowOffset;
  /** weight of each data point, can be null */

@@ -49,6 +53,7 @@ class DataBatch {

  @Override
  public DataBatch next() {
+    try {
      int numRows = 0;
      int numElem = 0;
      List<LabeledPoint> batch = new ArrayList<>(batchSize);

@@ -87,6 +92,10 @@ class DataBatch {

      rowOffset[batch.size()] = offset;
      return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
+    } catch (RuntimeException runtimeError) {
+      logger.error(runtimeError);
+      return null;
+    }
  }

  @Override
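The DataBatch change wraps batch construction in try/catch so that a RuntimeException thrown while draining the iterator (for example, the verifyMissingSetting rejection above) is logged before next() gives up. A Scala sketch of the same catch-log-return-null shape (illustrative only; BatchGuard and nextBatch are hypothetical names, the real code is the Java above):

    import org.apache.commons.logging.LogFactory

    object BatchGuard {
      private val logger = LogFactory.getLog(getClass)

      // Same shape as DataBatch.next(): log iterator failures and return
      // null instead of letting the exception escape unannounced.
      def nextBatch[T >: Null](build: () => T): T =
        try {
          build()
        } catch {
          case runtimeError: RuntimeException =>
            logger.error(runtimeError)
            null
        }
    }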