[BREAKING][jvm-packages] fix the non-zero missing value handling (#4349)
* fix the nan and non-zero missing value handling
* fix nan handling part
* add missing value
* Update MissingValueHandlingSuite.scala
* Update MissingValueHandlingSuite.scala
* stylistic fix
parent 2d875ec019
commit 995698b0cb
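In short, this commit tightens how a user-specified missing value interacts with the input format: a non-zero missing marker is only accepted when every feature vector is dense, because sparse (or empty) vectors have already dropped their zero entries, and the NaN path is now validated the same way. A minimal sketch of the per-value filtering rule, for illustration only (the function name is hypothetical, not the library's API):

  // Which feature values are kept when building the DMatrix, per this commit's filters.
  def keepValue(v: Float, missing: Float): Boolean =
    if (missing.isNaN) !v.isNaN   // NaN marker: only NaN entries are treated as missing
    else v != missing             // explicit marker: entries equal to `missing` are dropped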
XGBoost.scala

@@ -18,9 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 import java.nio.file.Files
-import java.util.Properties

-import scala.collection.mutable.ListBuffer
 import scala.collection.{AbstractIterator, mutable}
 import scala.util.Random

@@ -32,8 +30,8 @@ import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory

 import org.apache.spark.rdd.RDD
-import org.apache.spark.{SparkContext, SparkException, SparkParallelismTracker, TaskContext}
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.storage.StorageLevel

@@ -75,8 +73,9 @@ object XGBoost extends Serializable {
     if (missing != 0.0f) {
       xgbLabelPoints.map(labeledPoint => {
         if (labeledPoint.indices != null) {
-          throw new RuntimeException("you can only specify missing value as 0.0 when you have" +
-            " SparseVector as your feature format")
+          throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
+            s" set value $missing) when you have SparseVector or Empty vector as your feature" +
+            " format")
         }
         labeledPoint
       })
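For context (an illustration, not part of the diff): once Spark stores a row as a SparseVector, the explicit zeros are already gone, so a non-zero missing marker can no longer distinguish a stored 0.0 from an omitted position; hence the hard error above.

  import org.apache.spark.ml.linalg.Vectors

  // The same logical row in both encodings: the sparse form keeps only indices 0 and 2,
  // so a later "treat -1.0 as missing" filter cannot tell the dropped 0.0 from an absent feature.
  val dense  = Vectors.dense(1.0, 0.0, -1.0)
  val sparse = Vectors.sparse(3, Array(0, 2), Array(1.0, -1.0))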
@@ -107,7 +106,8 @@ object XGBoost extends Serializable {
       removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
         missing, (v: Float) => v != missing)
     } else {
-      removeMissingValues(xgbLabelPoints, missing, (v: Float) => !v.isNaN)
+      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
+        missing, (v: Float) => !v.isNaN)
     }
   }

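Note the effect of this hunk: the NaN branch now also routes through verifyMissingSetting, so the SparseVector/empty-vector guard shown above applies regardless of whether `missing` is an explicit value or the NaN default; before this change the NaN path skipped that check entirely.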
MissingValueHandlingSuite.scala (new file)

@@ -0,0 +1,148 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import scala.util.Random

import ml.dmlc.xgboost4j.java.XGBoostError
import org.scalatest.FunSuite

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame

class MissingValueHandlingSuite extends FunSuite with PerTest {
  test("dense vectors containing missing value") {
    def buildDenseDataFrame(): DataFrame = {
      val numRows = 100
      val numCols = 5
      val data = (0 until numRows).map { x =>
        val label = Random.nextInt(2)
        val values = Array.tabulate[Double](numCols) { c =>
          if (c == numCols - 1) 0 else Random.nextDouble
        }
        (label, Vectors.dense(values))
      }
      ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
    }
    val denseDF = buildDenseDataFrame().repartition(4)
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
    val model = new XGBoostClassifier(paramMap).fit(denseDF)
    model.transform(denseDF).collect()
  }

  test("handle Float.NaN as missing value correctly") {
    val spark = ss
    import spark.implicits._
    val testDF = Seq(
      (1.0f, 0.0f, Float.NaN, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (1.0f, Float.NaN, 0.0f, 0.0),
      (0.0f, 1.0f, 0.0f, 1.0),
      (Float.NaN, 0.0f, 0.0f, 1.0)
    ).toDF("col1", "col2", "col3", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3"))
      .setOutputCol("features")
      .setHandleInvalid("keep")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
    val model = new XGBoostClassifier(paramMap).fit(inputDF)
    model.transform(inputDF).collect()
  }

  test("specify a non-zero missing value but with dense vector does not stop" +
    " application") {
    val spark = ss
    import spark.implicits._
    // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
    // vector,
    val testDF = Seq(
      (1.0f, 0.0f, -1.0f, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (1.0f, -1.0f, 0.0f, 0.0),
      (0.0f, 1.0f, 0.0f, 1.0),
      (-1.0f, 0.0f, 0.0f, 1.0)
    ).toDF("col1", "col2", "col3", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3"))
      .setOutputCol("features")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
    val model = new XGBoostClassifier(paramMap).fit(inputDF)
    model.transform(inputDF).collect()
  }

  test("specify a non-zero missing value and meet an empty vector we should" +
    " stop the application") {
    val spark = ss
    import spark.implicits._
    val testDF = Seq(
      (1.0f, 0.0f, -1.0f, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0),
      (1.0f, -1.0f, 0.0f, 0.0),
      (0.0f, 0.0f, 0.0f, 1.0), // empty vector
      (-1.0f, 0.0f, 0.0f, 1.0)
    ).toDF("col1", "col2", "col3", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3"))
      .setOutputCol("features")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
    intercept[XGBoostError] {
      new XGBoostClassifier(paramMap).fit(inputDF)
    }
  }

  test("specify a non-zero missing value and meet a Sparse vector we should" +
    " stop the application") {
    val spark = ss
    import spark.implicits._
    ss.sparkContext.setLogLevel("INFO")
    // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
    // vector,
    val testDF = Seq(
      (1.0f, 0.0f, -1.0f, 1.0f, 1.0),
      (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
      (0.0f, 1.0f, 0.0f, 1.0f, 0.0),
      (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
      (1.0f, -1.0f, 0.0f, 1.0f, 0.0),
      (0.0f, 0.0f, 0.0f, 1.0f, 1.0),
      (-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
    ).toDF("col1", "col2", "col3", "col4", "label")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3", "col4"))
      .setOutputCol("features")
    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
    inputDF.show()
    val paramMap = List("eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
    intercept[XGBoostError] {
      new XGBoostClassifier(paramMap).fit(inputDF)
    }
  }
}
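As a usage sketch for callers affected by this breaking change (assumptions: a DataFrame `df` with "features" and "label" columns; the densifying UDF is illustrative and not part of this commit): since VectorAssembler may emit sparse rows whenever 1.5 * (nnz + 1.0) < size, a pipeline that wants a non-zero missing marker can force dense vectors first.

  import org.apache.spark.ml.linalg.{Vector, Vectors}
  import org.apache.spark.sql.functions.udf

  // Hypothetical helper: force every feature vector into its dense form so explicit zeros
  // survive and the non-zero missing marker (-999 here) stays unambiguous.
  val toDense = udf((v: Vector) => Vectors.dense(v.toArray))
  val denseDF = df.withColumn("features", toDense(df("features")))

  val model = new XGBoostClassifier(Map(
    "objective" -> "binary:logistic",
    "missing" -> -999.0f,
    "num_workers" -> 1)).fit(denseDF)

Densifying costs memory; the alternative is to keep the default NaN marker and let VectorAssembler's setHandleInvalid("keep") pass NaN through, as the second test above does.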
XGBoostClassifierSuite.scala

@@ -287,7 +287,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
   test("infrequent features") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic",
-      "num_round" -> 5, "num_workers" -> 2)
+      "num_round" -> 5, "num_workers" -> 2, "missing" -> 0)
     import DataUtils._
     val sparkSession = SparkSession.builder().getOrCreate()
     import sparkSession.implicits._
@@ -308,7 +308,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
   test("infrequent features (use_external_memory)") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic",
-      "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true)
+      "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
     import DataUtils._
     val sparkSession = SparkSession.builder().getOrCreate()
     import sparkSession.implicits._
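A hedged reading of these two one-line changes: both tests feed sparse input, so with the stricter guard above they must explicitly opt back into the zero-as-missing convention via "missing" -> 0 rather than rely on the NaN default.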
XGBoostGeneralSuite.scala

@@ -225,50 +225,6 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(x < 0.1)
   }

-  test("dense vectors containing missing value") {
-    def buildDenseDataFrame(): DataFrame = {
-      val numRows = 100
-      val numCols = 5
-      val data = (0 until numRows).map { x =>
-        val label = Random.nextInt(2)
-        val values = Array.tabulate[Double](numCols) { c =>
-          if (c == numCols - 1) 0 else Random.nextDouble
-        }
-        (label, Vectors.dense(values))
-      }
-      ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
-    }
-    val denseDF = buildDenseDataFrame().repartition(4)
-    val paramMap = List("eta" -> "1", "max_depth" -> "2",
-      "objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
-    val model = new XGBoostClassifier(paramMap).fit(denseDF)
-    model.transform(denseDF).collect()
-  }
-
-  test("handle Float.NaN as missing value correctly") {
-    val spark = ss
-    import spark.implicits._
-    val testDF = Seq(
-      (1.0f, 0.0f, Float.NaN, 1.0),
-      (1.0f, 0.0f, 1.0f, 1.0),
-      (0.0f, 1.0f, 0.0f, 0.0),
-      (1.0f, 0.0f, 1.0f, 1.0),
-      (1.0f, Float.NaN, 0.0f, 0.0),
-      (0.0f, 0.0f, 0.0f, 0.0),
-      (0.0f, 1.0f, 0.0f, 1.0),
-      (Float.NaN, 0.0f, 0.0f, 1.0)
-    ).toDF("col1", "col2", "col3", "label")
-    val vectorAssembler = new VectorAssembler()
-      .setInputCols(Array("col1", "col2", "col3"))
-      .setOutputCol("features")
-      .setHandleInvalid("keep")
-    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
-    val paramMap = List("eta" -> "1", "max_depth" -> "2",
-      "objective" -> "binary:logistic", "num_workers" -> 1).toMap
-    val model = new XGBoostClassifier(paramMap).fit(inputDF)
-    model.transform(inputDF).collect()
-  }
-
   test("training with spark parallelism checks disabled") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
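The two tests removed here are not dropped from coverage: they reappear, lightly adjusted, in the new MissingValueHandlingSuite added earlier in this commit.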
DataBatch.java

@@ -4,6 +4,9 @@ import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;

+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
 import ml.dmlc.xgboost4j.LabeledPoint;

 /**
@@ -13,6 +16,7 @@ import ml.dmlc.xgboost4j.LabeledPoint;
  * This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
  */
 class DataBatch {
+  private static final Log logger = LogFactory.getLog(DataBatch.class);
   /** The offset of each rows in the sparse matrix */
   final long[] rowOffset;
   /** weight of each data point, can be null */
@@ -49,44 +53,49 @@ class DataBatch {

     @Override
     public DataBatch next() {
-      int numRows = 0;
-      int numElem = 0;
-      List<LabeledPoint> batch = new ArrayList<>(batchSize);
-      while (base.hasNext() && batch.size() < batchSize) {
-        LabeledPoint labeledPoint = base.next();
-        batch.add(labeledPoint);
-        numElem += labeledPoint.values().length;
-        numRows++;
-      }
-
-      long[] rowOffset = new long[numRows + 1];
-      float[] label = new float[numRows];
-      int[] featureIndex = new int[numElem];
-      float[] featureValue = new float[numElem];
-      float[] weight = new float[numRows];
-
-      int offset = 0;
-      for (int i = 0; i < batch.size(); i++) {
-        LabeledPoint labeledPoint = batch.get(i);
-        rowOffset[i] = offset;
-        label[i] = labeledPoint.label();
-        weight[i] = labeledPoint.weight();
-        if (labeledPoint.indices() != null) {
-          System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
-              labeledPoint.indices().length);
-        } else {
-          for (int j = 0; j < labeledPoint.values().length; j++) {
-            featureIndex[offset + j] = j;
-          }
-        }
-
-        System.arraycopy(labeledPoint.values(), 0, featureValue, offset,
-            labeledPoint.values().length);
-        offset += labeledPoint.values().length;
-      }
-
-      rowOffset[batch.size()] = offset;
-      return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
+      try {
+        int numRows = 0;
+        int numElem = 0;
+        List<LabeledPoint> batch = new ArrayList<>(batchSize);
+        while (base.hasNext() && batch.size() < batchSize) {
+          LabeledPoint labeledPoint = base.next();
+          batch.add(labeledPoint);
+          numElem += labeledPoint.values().length;
+          numRows++;
+        }
+
+        long[] rowOffset = new long[numRows + 1];
+        float[] label = new float[numRows];
+        int[] featureIndex = new int[numElem];
+        float[] featureValue = new float[numElem];
+        float[] weight = new float[numRows];
+
+        int offset = 0;
+        for (int i = 0; i < batch.size(); i++) {
+          LabeledPoint labeledPoint = batch.get(i);
+          rowOffset[i] = offset;
+          label[i] = labeledPoint.label();
+          weight[i] = labeledPoint.weight();
+          if (labeledPoint.indices() != null) {
+            System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
+                labeledPoint.indices().length);
+          } else {
+            for (int j = 0; j < labeledPoint.values().length; j++) {
+              featureIndex[offset + j] = j;
+            }
+          }
+
+          System.arraycopy(labeledPoint.values(), 0, featureValue, offset,
+              labeledPoint.values().length);
+          offset += labeledPoint.values().length;
+        }
+
+        rowOffset[batch.size()] = offset;
+        return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
+      } catch (RuntimeException runtimeError) {
+        logger.error(runtimeError);
+        return null;
+      }
     }

     @Override
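Effect of the DataBatch change: next() now catches RuntimeExceptions raised while draining the underlying LabeledPoint iterator, logs them via the newly added commons-logging logger, and returns null rather than letting the exception escape from the iterator.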