[JVM] Add LabeledPoint read support

fix
This commit is contained in:
tqchen
2016-03-05 13:00:18 -08:00
parent ac8e950227
commit 514df14baf
6 changed files with 93 additions and 40 deletions

View File

@@ -15,10 +15,12 @@
*/
package ml.dmlc.xgboost4j.java;
import java.awt.*;
import java.util.Arrays;
import java.util.Random;
import junit.framework.TestCase;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DataBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;
@@ -34,33 +36,19 @@ public class DMatrixTest {
@Test
public void testCreateFromDataIterator() throws XGBoostError {
//create DMatrix from DataIterator
/**
* sparse matrix
* 1 0 2 3 0
* 4 0 2 3 5
* 3 1 2 5 0
*/
DataBatch batch = new DataBatch();
batch.featureIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
batch.featureValue = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
batch.rowOffset = new long[]{0, 3, 7, 11};
batch.label = new float[] {0.1f, 0.2f, 0.3f};
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
int nrep = 3;
java.util.List<DataBatch> blist = new java.util.LinkedList<DataBatch>();
int nrep = 3000;
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
for (int i = 0; i < nrep; ++i) {
batch.label = new float[] {0.1f+i, 0.2f+i, 0.3f+i};
blist.add(batch.shallowCopy());
for (float f : batch.label) {
labelall.add(f);
}
LabeledPoint p = LabeledPoint.fromSparseVector(
0.1f + i, new int[]{0, 2, 3}, new float[]{3, 4, 5});
blist.add(p);
labelall.add(p.label);
}
DMatrix dmat = new DMatrix(blist.iterator(), null);
// get label
float[] labels = dmat.getLabel();
// get label
TestCase.assertTrue(batch.label.length * nrep == labels.length);
for (int i = 0; i < labels.length; ++i) {
TestCase.assertTrue(labelall.get(i) == labels[i]);
}