[BREAKING][jvm-packages] fix the non-zero missing value handling (#4349)
* fix the nan and non-zero missing value handling
* fix nan handling part
* add missing value
* Update MissingValueHandlingSuite.scala
* Update MissingValueHandlingSuite.scala
* stylistic fix
This commit is contained in:
@@ -4,6 +4,9 @@ import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
|
||||
/**
|
||||
@@ -13,6 +16,7 @@ import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
* This class is used to support advanced creation of DMatrix from an Iterator of DataBatch.
|
||||
*/
|
||||
class DataBatch {
|
||||
private static final Log logger = LogFactory.getLog(DataBatch.class);
|
||||
/** The offset of each row in the sparse matrix */
|
||||
final long[] rowOffset;
|
||||
/** weight of each data point, can be null */
|
||||
@@ -49,44 +53,49 @@ class DataBatch {
|
||||
|
||||
@Override
|
||||
public DataBatch next() {
|
||||
int numRows = 0;
|
||||
int numElem = 0;
|
||||
List<LabeledPoint> batch = new ArrayList<>(batchSize);
|
||||
while (base.hasNext() && batch.size() < batchSize) {
|
||||
LabeledPoint labeledPoint = base.next();
|
||||
batch.add(labeledPoint);
|
||||
numElem += labeledPoint.values().length;
|
||||
numRows++;
|
||||
}
|
||||
|
||||
long[] rowOffset = new long[numRows + 1];
|
||||
float[] label = new float[numRows];
|
||||
int[] featureIndex = new int[numElem];
|
||||
float[] featureValue = new float[numElem];
|
||||
float[] weight = new float[numRows];
|
||||
|
||||
int offset = 0;
|
||||
for (int i = 0; i < batch.size(); i++) {
|
||||
LabeledPoint labeledPoint = batch.get(i);
|
||||
rowOffset[i] = offset;
|
||||
label[i] = labeledPoint.label();
|
||||
weight[i] = labeledPoint.weight();
|
||||
if (labeledPoint.indices() != null) {
|
||||
System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
|
||||
labeledPoint.indices().length);
|
||||
} else {
|
||||
for (int j = 0; j < labeledPoint.values().length; j++) {
|
||||
featureIndex[offset + j] = j;
|
||||
}
|
||||
try {
|
||||
int numRows = 0;
|
||||
int numElem = 0;
|
||||
List<LabeledPoint> batch = new ArrayList<>(batchSize);
|
||||
while (base.hasNext() && batch.size() < batchSize) {
|
||||
LabeledPoint labeledPoint = base.next();
|
||||
batch.add(labeledPoint);
|
||||
numElem += labeledPoint.values().length;
|
||||
numRows++;
|
||||
}
|
||||
|
||||
System.arraycopy(labeledPoint.values(), 0, featureValue, offset,
|
||||
labeledPoint.values().length);
|
||||
offset += labeledPoint.values().length;
|
||||
}
|
||||
long[] rowOffset = new long[numRows + 1];
|
||||
float[] label = new float[numRows];
|
||||
int[] featureIndex = new int[numElem];
|
||||
float[] featureValue = new float[numElem];
|
||||
float[] weight = new float[numRows];
|
||||
|
||||
rowOffset[batch.size()] = offset;
|
||||
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
|
||||
int offset = 0;
|
||||
for (int i = 0; i < batch.size(); i++) {
|
||||
LabeledPoint labeledPoint = batch.get(i);
|
||||
rowOffset[i] = offset;
|
||||
label[i] = labeledPoint.label();
|
||||
weight[i] = labeledPoint.weight();
|
||||
if (labeledPoint.indices() != null) {
|
||||
System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
|
||||
labeledPoint.indices().length);
|
||||
} else {
|
||||
for (int j = 0; j < labeledPoint.values().length; j++) {
|
||||
featureIndex[offset + j] = j;
|
||||
}
|
||||
}
|
||||
|
||||
System.arraycopy(labeledPoint.values(), 0, featureValue, offset,
|
||||
labeledPoint.values().length);
|
||||
offset += labeledPoint.values().length;
|
||||
}
|
||||
|
||||
rowOffset[batch.size()] = offset;
|
||||
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
|
||||
} catch (RuntimeException runtimeError) {
|
||||
logger.error(runtimeError);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
Reference in New Issue
Block a user