[jvm-packages] support missing value when constructing dmatrix with iterator (#10628)

This commit is contained in:
Bobby Wang
2024-07-23 23:25:07 +08:00
committed by GitHub
parent b3ed81877a
commit 7949a8d5f4
8 changed files with 300 additions and 162 deletions

View File

@@ -15,15 +15,18 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import junit.framework.TestCase;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
@@ -36,6 +39,32 @@ import static org.junit.Assert.assertEquals;
*/
public class DMatrixTest {
/**
 * Builds a DMatrix from a LabeledPoint iterator with several different
 * "missing" sentinels and checks how many cells are reported as non-missing.
 *
 * Verification uses JUnit's assertEquals rather than the Java {@code assert}
 * statement: the latter is skipped unless the JVM runs with -ea, which would
 * let this test pass without checking anything.
 */
@Test
public void testCreateFromDataIteratorWithMissingValue() throws XGBoostError {
  // Three rows of four values each (12 cells): one NaN, four zeros, one 21.
  // The indices argument is null - presumably this marks the rows as dense.
  java.util.List<LabeledPoint> blist = new java.util.LinkedList<>();
  blist.add(new LabeledPoint(0.1f, 4, null, new float[]{1, 0, 0, 0}));
  blist.add(new LabeledPoint(0.1f, 4, null, new float[]{Float.NaN, 13, 14, 15}));
  blist.add(new LabeledPoint(0.1f, 4, null, new float[]{21, 23, 0, 25}));

  // Default missing value: Float.NaN -> only the NaN cell is dropped
  DMatrix dmat = new DMatrix(blist.iterator(), null);
  assertEquals(11, dmat.nonMissingNum());

  // missing value 0 -> the four zeros are dropped; the expected count also
  // excludes the NaN cell
  dmat = new DMatrix(blist.iterator(), null, 0.0f);
  assertEquals(12 - 4 - 1, dmat.nonMissingNum());

  // missing value 21 -> the single 21 is dropped, plus the NaN cell
  dmat = new DMatrix(blist.iterator(), null, 21.0f);
  assertEquals(12 - 1 - 1, dmat.nonMissingNum());

  // missing value 101010101010 matches no cell -> only the NaN cell is dropped
  dmat = new DMatrix(blist.iterator(), null, 101010101010.0f);
  assertEquals(12 - 1, dmat.nonMissingNum());
}
@Test
public void testCreateFromDataIterator() throws XGBoostError {
//create DMatrix from DataIterator
@@ -45,7 +74,7 @@ public class DMatrixTest {
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
for (int i = 0; i < nrep; ++i) {
LabeledPoint p = new LabeledPoint(
0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
blist.add(p);
labelall.add(p.label());
}
@@ -290,7 +319,7 @@ public class DMatrixTest {
} finally {
if (dmat0 != null) {
dmat0.dispose();
} else if (data0 != null){
} else if (data0 != null) {
data0.dispose();
}
}
@@ -309,9 +338,9 @@ public class DMatrixTest {
// (3,1) -> 2
// (2,3) -> 3
float[][] data = new float[][]{
new float[]{4f, 5f},
new float[]{3f, 1f},
new float[]{2f, 3f}
new float[]{4f, 5f},
new float[]{3f, 1f},
new float[]{2f, 3f}
};
data0 = new BigDenseMatrix(3, 2);
for (int i = 0; i < data0.nrow; i++)
@@ -428,4 +457,40 @@ public class DMatrixTest {
String[] retFeatureTypes = dmat.getFeatureTypes();
assertArrayEquals(featureTypes, retFeatureTypes);
}
/**
 * Sets per-row query ids on a DMatrix and checks the group array returned by
 * getGroup(). The expected arrays below are the cumulative row offsets at
 * which each run of equal query ids starts, followed by the total row count.
 */
@Test
public void testSetAndGetQueryId() throws XGBoostError {
  // create DMatrix from 10*5 dense matrix
  int nrow = 10;
  int ncol = 5;
  float[] data0 = new float[nrow * ncol];
  // fill with random numbers; the actual values are irrelevant to this test
  Random random = new Random();
  for (int i = 0; i < nrow * ncol; i++) {
    data0[i] = random.nextFloat();
  }
  // create label
  float[] label0 = new float[nrow];
  for (int i = 0; i < nrow; i++) {
    label0[i] = random.nextFloat();
  }

  // every row has its own query id -> ten groups of one row each
  int[] qid = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  int[] qidExpected = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f);
  dmat0.setLabel(label0);
  dmat0.setQueryId(qid);
  // assertArrayEquals reports the first differing index on failure
  assertArrayEquals(qidExpected, dmat0.getGroup());

  // repeated query ids -> six groups with sizes 3, 1, 2, 2, 1, 1
  int[] qid1 = new int[]{10, 10, 10, 20, 60, 60, 80, 80, 90, 100};
  int[] qidExpected1 = new int[]{0, 3, 4, 6, 8, 9, 10};
  dmat0.setQueryId(qid1);
  assertArrayEquals(qidExpected1, dmat0.getGroup());
}
}