[JVM] Add Iterator loading API

This commit is contained in:
tqchen
2016-03-04 17:22:08 -08:00
parent 770b3451ca
commit 86871d4be9
10 changed files with 451 additions and 5 deletions

View File

@@ -28,6 +28,41 @@ import org.junit.Test;
*/
public class DMatrixTest {
@Test
public void testCreateFromDataIterator() throws XGBoostError {
//create DMatrix from DataIterator
/**
* sparse matrix
* 1 0 2 3 0
* 4 0 2 3 5
* 3 1 2 5 0
*/
DataBatch batch = new DataBatch();
batch.featureIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
batch.featureValue = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
batch.rowOffset = new long[]{0, 3, 7, 11};
batch.label = new float[] {0.1f, 0.2f, 0.3f};
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
int nrep = 3;
java.util.List<DataBatch> blist = new java.util.LinkedList<DataBatch>();
for (int i = 0; i < nrep; ++i) {
batch.label = new float[] {0.1f+i, 0.2f+i, 0.3f+i};
blist.add(batch.shallowCopy());
for (float f : batch.label) {
labelall.add(f);
}
}
DMatrix dmat = new DMatrix(blist.iterator(), null);
// get label
float[] labels = dmat.getLabel();
// get label
TestCase.assertTrue(batch.label.length * nrep == labels.length);
for (int i = 0; i < labels.length; ++i) {
TestCase.assertTrue(labelall.get(i) == labels[i]);
}
}
@Test
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file