Merge pull request #921 from tqchen/master
[JVM] Add LabeledPoint read support
This commit is contained in:
commit
74bda4bfc5
@ -21,7 +21,6 @@
|
|||||||
<module>xgboost4j</module>
|
<module>xgboost4j</module>
|
||||||
<module>xgboost4j-demo</module>
|
<module>xgboost4j-demo</module>
|
||||||
<module>xgboost4j-flink</module>
|
<module>xgboost4j-flink</module>
|
||||||
<module>xgboost4j-spark</module>
|
|
||||||
</modules>
|
</modules>
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
|
|||||||
@ -6,13 +6,13 @@ package ml.dmlc.xgboost4j;
|
|||||||
*/
|
*/
|
||||||
public class LabeledPoint {
|
public class LabeledPoint {
|
||||||
/** Label of the point */
|
/** Label of the point */
|
||||||
float label;
|
public float label;
|
||||||
/** Weight of this data point */
|
/** Weight of this data point */
|
||||||
float weight = 1.0f;
|
public float weight = 1.0f;
|
||||||
/** Feature indices, used for sparse input */
|
/** Feature indices, used for sparse input */
|
||||||
int[] indices = null;
|
public int[] indices = null;
|
||||||
/** Feature values */
|
/** Feature values */
|
||||||
float[] values;
|
public float[] values;
|
||||||
|
|
||||||
private LabeledPoint() {}
|
private LabeledPoint() {}
|
||||||
|
|
||||||
@ -27,6 +27,7 @@ public class LabeledPoint {
|
|||||||
ret.label = label;
|
ret.label = label;
|
||||||
ret.indices = indices;
|
ret.indices = indices;
|
||||||
ret.values = values;
|
ret.values = values;
|
||||||
|
assert indices.length == values.length;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -21,6 +21,8 @@ import java.util.Iterator;
|
|||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DMatrix for xgboost.
|
* DMatrix for xgboost.
|
||||||
*
|
*
|
||||||
@ -52,20 +54,18 @@ public class DMatrix {
|
|||||||
* Create DMatrix from iterator.
|
* Create DMatrix from iterator.
|
||||||
*
|
*
|
||||||
* @param iter The data iterator of mini batch to provide the data.
|
* @param iter The data iterator of mini batch to provide the data.
|
||||||
* @param cache_info Cache path information, used for external memory setting, can be null.
|
* @param cacheInfo Cache path information, used for external memory setting, can be null.
|
||||||
* @throws XGBoostError
|
* @throws XGBoostError
|
||||||
*/
|
*/
|
||||||
public DMatrix(Iterator<DataBatch> iter, String cache_info) throws XGBoostError {
|
public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError {
|
||||||
if (iter == null) {
|
if (iter == null) {
|
||||||
throw new NullPointerException("iter: null");
|
throw new NullPointerException("iter: null");
|
||||||
}
|
}
|
||||||
try {
|
// 32k as batch size
|
||||||
logger.info(iter.getClass().getMethod("next").toString());
|
int batchSize = 32 << 10;
|
||||||
} catch(NoSuchMethodException e) {
|
Iterator<DataBatch> batchIter = new DataBatch.BatchIterator(iter, batchSize);
|
||||||
logger.info(e.toString());
|
|
||||||
}
|
|
||||||
long[] out = new long[1];
|
long[] out = new long[1];
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
|
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out));
|
||||||
handle = out[0];
|
handle = out[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,16 @@
|
|||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.util.Iterator;
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A mini-batch of data that can be converted to DMatrix.
|
* A mini-batch of data that can be converted to DMatrix.
|
||||||
* The data is in sparse matrix CSR format.
|
* The data is in sparse matrix CSR format.
|
||||||
*
|
*
|
||||||
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
|
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
|
||||||
*/
|
*/
|
||||||
public class DataBatch {
|
class DataBatch {
|
||||||
/** The offset of each rows in the sparse matrix */
|
/** The offset of each rows in the sparse matrix */
|
||||||
long[] rowOffset = null;
|
long[] rowOffset = null;
|
||||||
/** weight of each data point, can be null */
|
/** weight of each data point, can be null */
|
||||||
@ -51,4 +55,58 @@ public class DataBatch {
|
|||||||
b.featureValue = this.featureValue;
|
b.featureValue = this.featureValue;
|
||||||
return b;
|
return b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static class BatchIterator implements Iterator<DataBatch> {
|
||||||
|
private Iterator<LabeledPoint> base;
|
||||||
|
private int batchSize;
|
||||||
|
|
||||||
|
BatchIterator(java.util.Iterator<LabeledPoint> base, int batchSize) {
|
||||||
|
this.base = base;
|
||||||
|
this.batchSize = batchSize;
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return base.hasNext();
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public DataBatch next() {
|
||||||
|
int num_rows = 0, num_elem = 0;
|
||||||
|
java.util.List<LabeledPoint> batch = new java.util.ArrayList<LabeledPoint>();
|
||||||
|
for (int i = 0; i < this.batchSize; ++i) {
|
||||||
|
if (!base.hasNext()) break;
|
||||||
|
LabeledPoint inst = base.next();
|
||||||
|
batch.add(inst);
|
||||||
|
num_elem += inst.values.length;
|
||||||
|
++num_rows;
|
||||||
|
}
|
||||||
|
DataBatch ret = new DataBatch();
|
||||||
|
// label
|
||||||
|
ret.rowOffset = new long[num_rows + 1];
|
||||||
|
ret.label = new float[num_rows];
|
||||||
|
ret.featureIndex = new int[num_elem];
|
||||||
|
ret.featureValue = new float[num_elem];
|
||||||
|
// current offset
|
||||||
|
int offset = 0;
|
||||||
|
for (int i = 0; i < batch.size(); ++i) {
|
||||||
|
LabeledPoint inst = batch.get(i);
|
||||||
|
ret.rowOffset[i] = offset;
|
||||||
|
ret.label[i] = inst.label;
|
||||||
|
if (inst.indices != null) {
|
||||||
|
System.arraycopy(inst.indices, 0, ret.featureIndex, offset, inst.indices.length);
|
||||||
|
} else{
|
||||||
|
for (int j = 0; j < inst.values.length; ++j) {
|
||||||
|
ret.featureIndex[offset + j] = j;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
System.arraycopy(inst.values, 0, ret.featureValue, offset, inst.values.length);
|
||||||
|
offset += inst.values.length;
|
||||||
|
}
|
||||||
|
ret.rowOffset[batch.size()] = offset;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public void remove() {
|
||||||
|
throw new Error("not implemented");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,7 +17,7 @@
|
|||||||
package ml.dmlc.xgboost4j.scala
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import _root_.scala.collection.JavaConverters._
|
import _root_.scala.collection.JavaConverters._
|
||||||
|
import ml.dmlc.xgboost4j.LabeledPoint
|
||||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, DataBatch, XGBoostError}
|
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, DataBatch, XGBoostError}
|
||||||
|
|
||||||
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||||
@ -31,6 +31,17 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
|||||||
this(new JDMatrix(dataPath))
|
this(new JDMatrix(dataPath))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* init DMatrix from Iterator of LabeledPoint
|
||||||
|
*
|
||||||
|
* @param dataIter An iterator of LabeledPoint
|
||||||
|
* @param cacheInfo Cache path information, used for external memory setting, can be null.
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
def this(dataIter: Iterator[LabeledPoint], cacheInfo: String) {
|
||||||
|
this(new JDMatrix(dataIter.asJava, cacheInfo))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create DMatrix from sparse matrix
|
* create DMatrix from sparse matrix
|
||||||
*
|
*
|
||||||
@ -44,10 +55,6 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
|||||||
this(new JDMatrix(headers, indices, data, st))
|
this(new JDMatrix(headers, indices, data, st))
|
||||||
}
|
}
|
||||||
|
|
||||||
private[xgboost4j] def this(dataBatches: Iterator[DataBatch]) {
|
|
||||||
this(new JDMatrix(dataBatches.asJava, null))
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create DMatrix from dense matrix
|
* create DMatrix from dense matrix
|
||||||
*
|
*
|
||||||
|
|||||||
@ -15,10 +15,12 @@
|
|||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.awt.*;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import junit.framework.TestCase;
|
import junit.framework.TestCase;
|
||||||
|
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||||
import ml.dmlc.xgboost4j.java.DataBatch;
|
import ml.dmlc.xgboost4j.java.DataBatch;
|
||||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||||
@ -34,33 +36,19 @@ public class DMatrixTest {
|
|||||||
@Test
|
@Test
|
||||||
public void testCreateFromDataIterator() throws XGBoostError {
|
public void testCreateFromDataIterator() throws XGBoostError {
|
||||||
//create DMatrix from DataIterator
|
//create DMatrix from DataIterator
|
||||||
/**
|
|
||||||
* sparse matrix
|
|
||||||
* 1 0 2 3 0
|
|
||||||
* 4 0 2 3 5
|
|
||||||
* 3 1 2 5 0
|
|
||||||
*/
|
|
||||||
DataBatch batch = new DataBatch();
|
|
||||||
batch.featureIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
|
|
||||||
batch.featureValue = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
|
|
||||||
batch.rowOffset = new long[]{0, 3, 7, 11};
|
|
||||||
batch.label = new float[] {0.1f, 0.2f, 0.3f};
|
|
||||||
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
|
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
|
||||||
int nrep = 3;
|
int nrep = 3000;
|
||||||
java.util.List<DataBatch> blist = new java.util.LinkedList<DataBatch>();
|
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
|
||||||
for (int i = 0; i < nrep; ++i) {
|
for (int i = 0; i < nrep; ++i) {
|
||||||
batch.label = new float[] {0.1f+i, 0.2f+i, 0.3f+i};
|
LabeledPoint p = LabeledPoint.fromSparseVector(
|
||||||
blist.add(batch.shallowCopy());
|
0.1f + i, new int[]{0, 2, 3}, new float[]{3, 4, 5});
|
||||||
for (float f : batch.label) {
|
blist.add(p);
|
||||||
labelall.add(f);
|
labelall.add(p.label);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
DMatrix dmat = new DMatrix(blist.iterator(), null);
|
DMatrix dmat = new DMatrix(blist.iterator(), null);
|
||||||
// get label
|
// get label
|
||||||
float[] labels = dmat.getLabel();
|
float[] labels = dmat.getLabel();
|
||||||
// get label
|
|
||||||
TestCase.assertTrue(batch.label.length * nrep == labels.length);
|
|
||||||
|
|
||||||
for (int i = 0; i < labels.length; ++i) {
|
for (int i = 0; i < labels.length; ++i) {
|
||||||
TestCase.assertTrue(labelall.get(i) == labels[i]);
|
TestCase.assertTrue(labelall.get(i) == labels[i]);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user