[jvm-packages] support missing value when constructing dmatrix with iterator (#10628)
This commit is contained in:
@@ -15,15 +15,18 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
@@ -36,6 +39,32 @@ import static org.junit.Assert.assertEquals;
|
||||
*/
|
||||
public class DMatrixTest {
|
||||
|
||||
|
||||
@Test
|
||||
public void testCreateFromDataIteratorWithMissingValue() throws XGBoostError {
|
||||
//create DMatrix from DataIterator
|
||||
java.util.List<LabeledPoint> blist = new java.util.LinkedList<>();
|
||||
blist.add(new LabeledPoint(0.1f, 4, null, new float[]{1, 0, 0, 0}));
|
||||
blist.add(new LabeledPoint(0.1f, 4, null, new float[]{Float.NaN, 13, 14, 15}));
|
||||
blist.add(new LabeledPoint(0.1f, 4, null, new float[]{21, 23, 0, 25}));
|
||||
|
||||
// Default missing value: Float.NaN
|
||||
DMatrix dmat = new DMatrix(blist.iterator(), null);
|
||||
assert dmat.nonMissingNum() == 11;
|
||||
|
||||
// missing value 0
|
||||
dmat = new DMatrix(blist.iterator(), null, 0.0f);
|
||||
assert dmat.nonMissingNum() == 12 - 4 - 1;
|
||||
|
||||
// missing value 21
|
||||
dmat = new DMatrix(blist.iterator(), null, 21.0f);
|
||||
assert dmat.nonMissingNum() == 12 - 1 - 1;
|
||||
|
||||
// missing value 101010101010
|
||||
dmat = new DMatrix(blist.iterator(), null, 101010101010.0f);
|
||||
assert dmat.nonMissingNum() == 12 - 1;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromDataIterator() throws XGBoostError {
|
||||
//create DMatrix from DataIterator
|
||||
@@ -45,7 +74,7 @@ public class DMatrixTest {
|
||||
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
|
||||
for (int i = 0; i < nrep; ++i) {
|
||||
LabeledPoint p = new LabeledPoint(
|
||||
0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
|
||||
0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
|
||||
blist.add(p);
|
||||
labelall.add(p.label());
|
||||
}
|
||||
@@ -290,7 +319,7 @@ public class DMatrixTest {
|
||||
} finally {
|
||||
if (dmat0 != null) {
|
||||
dmat0.dispose();
|
||||
} else if (data0 != null){
|
||||
} else if (data0 != null) {
|
||||
data0.dispose();
|
||||
}
|
||||
}
|
||||
@@ -309,9 +338,9 @@ public class DMatrixTest {
|
||||
// (3,1) -> 2
|
||||
// (2,3) -> 3
|
||||
float[][] data = new float[][]{
|
||||
new float[]{4f, 5f},
|
||||
new float[]{3f, 1f},
|
||||
new float[]{2f, 3f}
|
||||
new float[]{4f, 5f},
|
||||
new float[]{3f, 1f},
|
||||
new float[]{2f, 3f}
|
||||
};
|
||||
data0 = new BigDenseMatrix(3, 2);
|
||||
for (int i = 0; i < data0.nrow; i++)
|
||||
@@ -428,4 +457,40 @@ public class DMatrixTest {
|
||||
String[] retFeatureTypes = dmat.getFeatureTypes();
|
||||
assertArrayEquals(featureTypes, retFeatureTypes);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSetAndGetQueryId() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
int nrow = 10;
|
||||
int ncol = 5;
|
||||
float[] data0 = new float[nrow * ncol];
|
||||
//put random nums
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < nrow * ncol; i++) {
|
||||
data0[i] = random.nextFloat();
|
||||
}
|
||||
|
||||
//create label
|
||||
float[] label0 = new float[nrow];
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
label0[i] = random.nextFloat();
|
||||
}
|
||||
|
||||
//create two groups
|
||||
int[] qid = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
|
||||
int[] qidExpected = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
|
||||
|
||||
DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f);
|
||||
dmat0.setLabel(label0);
|
||||
dmat0.setQueryId(qid);
|
||||
//check
|
||||
TestCase.assertTrue(Arrays.equals(qidExpected, dmat0.getGroup()));
|
||||
|
||||
//create two groups
|
||||
int[] qid1 = new int[]{10, 10, 10, 20, 60, 60, 80, 80, 90, 100};
|
||||
int[] qidExpected1 = new int[]{0, 3, 4, 6, 8, 9, 10};
|
||||
dmat0.setQueryId(qid1);
|
||||
TestCase.assertTrue(Arrays.equals(qidExpected1, dmat0.getGroup()));
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user