[JVM] Add Iterator loading API
This commit is contained in:
@@ -28,6 +28,41 @@ import org.junit.Test;
|
||||
*/
|
||||
public class DMatrixTest {
|
||||
|
||||
@Test
|
||||
public void testCreateFromDataIterator() throws XGBoostError {
|
||||
//create DMatrix from DataIterator
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2 3 0
|
||||
* 4 0 2 3 5
|
||||
* 3 1 2 5 0
|
||||
*/
|
||||
DataBatch batch = new DataBatch();
|
||||
batch.featureIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
|
||||
batch.featureValue = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
|
||||
batch.rowOffset = new long[]{0, 3, 7, 11};
|
||||
batch.label = new float[] {0.1f, 0.2f, 0.3f};
|
||||
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
|
||||
int nrep = 3;
|
||||
java.util.List<DataBatch> blist = new java.util.LinkedList<DataBatch>();
|
||||
for (int i = 0; i < nrep; ++i) {
|
||||
batch.label = new float[] {0.1f+i, 0.2f+i, 0.3f+i};
|
||||
blist.add(batch.shallowCopy());
|
||||
for (float f : batch.label) {
|
||||
labelall.add(f);
|
||||
}
|
||||
}
|
||||
DMatrix dmat = new DMatrix(blist.iterator(), null);
|
||||
// get label
|
||||
float[] labels = dmat.getLabel();
|
||||
// get label
|
||||
TestCase.assertTrue(batch.label.length * nrep == labels.length);
|
||||
|
||||
for (int i = 0; i < labels.length; ++i) {
|
||||
TestCase.assertTrue(labelall.get(i) == labels[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromFile() throws XGBoostError {
|
||||
//create DMatrix from file
|
||||
|
||||
Reference in New Issue
Block a user