[BREAKING][jvm-packages] fix the non-zero missing value handling (#4349)

* fix the nan and non-zero missing value handling

* fix nan handling part

* add missing value

* Update MissingValueHandlingSuite.scala

* Update MissingValueHandlingSuite.scala

* stylistic fix
This commit is contained in:
Nan Zhu
2019-04-26 11:10:33 -07:00
committed by GitHub
parent 2d875ec019
commit 995698b0cb
5 changed files with 201 additions and 88 deletions

View File

@@ -4,6 +4,9 @@ import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import ml.dmlc.xgboost4j.LabeledPoint;
/**
@@ -13,6 +16,7 @@ import ml.dmlc.xgboost4j.LabeledPoint;
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
*/
class DataBatch {
private static final Log logger = LogFactory.getLog(DataBatch.class);
/** The offset of each rows in the sparse matrix */
final long[] rowOffset;
/** weight of each data point, can be null */
@@ -49,44 +53,49 @@ class DataBatch {
@Override
public DataBatch next() {
int numRows = 0;
int numElem = 0;
List<LabeledPoint> batch = new ArrayList<>(batchSize);
while (base.hasNext() && batch.size() < batchSize) {
LabeledPoint labeledPoint = base.next();
batch.add(labeledPoint);
numElem += labeledPoint.values().length;
numRows++;
}
long[] rowOffset = new long[numRows + 1];
float[] label = new float[numRows];
int[] featureIndex = new int[numElem];
float[] featureValue = new float[numElem];
float[] weight = new float[numRows];
int offset = 0;
for (int i = 0; i < batch.size(); i++) {
LabeledPoint labeledPoint = batch.get(i);
rowOffset[i] = offset;
label[i] = labeledPoint.label();
weight[i] = labeledPoint.weight();
if (labeledPoint.indices() != null) {
System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
labeledPoint.indices().length);
} else {
for (int j = 0; j < labeledPoint.values().length; j++) {
featureIndex[offset + j] = j;
}
try {
int numRows = 0;
int numElem = 0;
List<LabeledPoint> batch = new ArrayList<>(batchSize);
while (base.hasNext() && batch.size() < batchSize) {
LabeledPoint labeledPoint = base.next();
batch.add(labeledPoint);
numElem += labeledPoint.values().length;
numRows++;
}
System.arraycopy(labeledPoint.values(), 0, featureValue, offset,
labeledPoint.values().length);
offset += labeledPoint.values().length;
}
long[] rowOffset = new long[numRows + 1];
float[] label = new float[numRows];
int[] featureIndex = new int[numElem];
float[] featureValue = new float[numElem];
float[] weight = new float[numRows];
rowOffset[batch.size()] = offset;
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
int offset = 0;
for (int i = 0; i < batch.size(); i++) {
LabeledPoint labeledPoint = batch.get(i);
rowOffset[i] = offset;
label[i] = labeledPoint.label();
weight[i] = labeledPoint.weight();
if (labeledPoint.indices() != null) {
System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
labeledPoint.indices().length);
} else {
for (int j = 0; j < labeledPoint.values().length; j++) {
featureIndex[offset + j] = j;
}
}
System.arraycopy(labeledPoint.values(), 0, featureValue, offset,
labeledPoint.values().length);
offset += labeledPoint.values().length;
}
rowOffset[batch.size()] = offset;
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
} catch (RuntimeException runtimeError) {
logger.error(runtimeError);
return null;
}
}
@Override