diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index dc11e307d..7bfb35a7f 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -21,7 +21,6 @@
xgboost4j
xgboost4j-demo
xgboost4j-flink
- xgboost4j-spark
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java
index 19d4b7da9..fc14e361e 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java
@@ -6,13 +6,13 @@ package ml.dmlc.xgboost4j;
*/
public class LabeledPoint {
/** Label of the point */
- float label;
+ public float label;
/** Weight of this data point */
- float weight = 1.0f;
+ public float weight = 1.0f;
/** Feature indices, used for sparse input */
- int[] indices = null;
+ public int[] indices = null;
/** Feature values */
- float[] values;
+ public float[] values;
private LabeledPoint() {}
@@ -27,6 +27,7 @@ public class LabeledPoint {
ret.label = label;
ret.indices = indices;
ret.values = values;
+ assert indices.length == values.length;
return ret;
}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java
index a5a9b4972..2a7461377 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java
@@ -21,6 +21,8 @@ import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import ml.dmlc.xgboost4j.LabeledPoint;
+
/**
* DMatrix for xgboost.
*
@@ -52,20 +54,18 @@ public class DMatrix {
* Create DMatrix from iterator.
*
* @param iter The data iterator of mini batch to provide the data.
- * @param cache_info Cache path information, used for external memory setting, can be null.
+ * @param cacheInfo Cache path information, used for external memory setting, can be null.
* @throws XGBoostError
*/
- public DMatrix(Iterator iter, String cache_info) throws XGBoostError {
+ public DMatrix(Iterator iter, String cacheInfo) throws XGBoostError {
if (iter == null) {
throw new NullPointerException("iter: null");
}
- try {
- logger.info(iter.getClass().getMethod("next").toString());
- } catch(NoSuchMethodException e) {
- logger.info(e.toString());
- }
+ // Rows are grouped into mini batches of at most 32 << 10 (32768) rows before being handed to XGBoostJNI.
+ int batchSize = 32 << 10;
+ Iterator batchIter = new DataBatch.BatchIterator(iter, batchSize); // BatchIterator reads each element of iter as a LabeledPoint
long[] out = new long[1];
- JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
+ JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out));
handle = out[0];
}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
index b4bbbe690..2a52d0b9b 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
@@ -1,12 +1,16 @@
package ml.dmlc.xgboost4j.java;
+import java.util.Iterator;
+
+import ml.dmlc.xgboost4j.LabeledPoint;
+
/**
* A mini-batch of data that can be converted to DMatrix.
* The data is in sparse matrix CSR format.
*
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
*/
-public class DataBatch {
+class DataBatch {
/** The offset of each rows in the sparse matrix */
long[] rowOffset = null;
/** weight of each data point, can be null */
@@ -51,4 +55,58 @@ public class DataBatch {
b.featureValue = this.featureValue;
return b;
}
+
+  /**
+   * Adapts an iterator of {@link LabeledPoint} rows into an iterator of {@link DataBatch}
+   * mini batches laid out in CSR form.
+   *
+   * Each call to {@link #next()} pulls at most {@code batchSize} rows from the wrapped
+   * iterator and packs them into one batch.
+   */
+  static class BatchIterator implements Iterator<DataBatch> {
+    /** Source of the rows to be grouped. */
+    private final Iterator<LabeledPoint> base;
+    /** Upper bound on the number of rows placed in a single batch. */
+    private final int batchSize;
+
+    /**
+     * @param base iterator that supplies the rows
+     * @param batchSize maximum number of rows per batch
+     */
+    BatchIterator(Iterator<LabeledPoint> base, int batchSize) {
+      this.base = base;
+      this.batchSize = batchSize;
+    }
+
+    /** @return whether the wrapped iterator still has rows to batch */
+    @Override
+    public boolean hasNext() {
+      return base.hasNext();
+    }
+
+    /**
+     * Packs up to {@code batchSize} rows into a new batch.
+     *
+     * Rows whose {@code indices} is null are treated as dense: value {@code j} is stored
+     * under feature index {@code j}. If the wrapped iterator is already exhausted the
+     * result is a batch with zero rows.
+     *
+     * NOTE(review): the per-row {@code weight} of LabeledPoint is not copied into the batch.
+     *
+     * @return the batch built from the rows consumed by this call
+     */
+    @Override
+    public DataBatch next() {
+      int numRows = 0;
+      int numElem = 0;
+      java.util.List<LabeledPoint> batch = new java.util.ArrayList<LabeledPoint>();
+      for (int i = 0; i < this.batchSize; ++i) {
+        if (!base.hasNext()) {
+          break;
+        }
+        LabeledPoint inst = base.next();
+        batch.add(inst);
+        numElem += inst.values.length;
+        ++numRows;
+      }
+      DataBatch ret = new DataBatch();
+      // rowOffset has one extra slot holding the end offset of the last row
+      ret.rowOffset = new long[numRows + 1];
+      ret.label = new float[numRows];
+      ret.featureIndex = new int[numElem];
+      ret.featureValue = new float[numElem];
+      // position in featureIndex/featureValue where the current row starts
+      int offset = 0;
+      for (int i = 0; i < batch.size(); ++i) {
+        LabeledPoint inst = batch.get(i);
+        ret.rowOffset[i] = offset;
+        ret.label[i] = inst.label;
+        if (inst.indices != null) {
+          System.arraycopy(inst.indices, 0, ret.featureIndex, offset, inst.indices.length);
+        } else {
+          for (int j = 0; j < inst.values.length; ++j) {
+            ret.featureIndex[offset + j] = j;
+          }
+        }
+        System.arraycopy(inst.values, 0, ret.featureValue, offset, inst.values.length);
+        offset += inst.values.length;
+      }
+      ret.rowOffset[batch.size()] = offset;
+      return ret;
+    }
+
+    /**
+     * Removal is not supported.
+     *
+     * @throws UnsupportedOperationException always
+     */
+    @Override
+    public void remove() {
+      throw new UnsupportedOperationException("not implemented");
+    }
+  }
}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
index c62784c08..cdf3a9844 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
@@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala
import _root_.scala.collection.JavaConverters._
-
+import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, DataBatch, XGBoostError}
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
@@ -31,6 +31,17 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
this(new JDMatrix(dataPath))
}
+ /**
+ * Create a DMatrix from an iterator of LabeledPoint.
+ *
+ * The Scala iterator is exposed as a java.util.Iterator through `asJava` and passed,
+ * together with `cacheInfo`, to the Java DMatrix constructor.
+ *
+ * @param dataIter An iterator of LabeledPoint
+ * @param cacheInfo Cache path information, used for external memory setting, can be null.
+ * @throws XGBoostError native error
+ */
+ def this(dataIter: Iterator[LabeledPoint], cacheInfo: String) {
+ this(new JDMatrix(dataIter.asJava, cacheInfo))
+ }
+
/**
* create DMatrix from sparse matrix
*
@@ -44,10 +55,6 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
this(new JDMatrix(headers, indices, data, st))
}
- private[xgboost4j] def this(dataBatches: Iterator[DataBatch]) {
- this(new JDMatrix(dataBatches.asJava, null))
- }
-
/**
* create DMatrix from dense matrix
*
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
index 056291ddf..d9004b418 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
@@ -15,10 +15,12 @@
*/
package ml.dmlc.xgboost4j.java;
+import java.awt.*;
import java.util.Arrays;
import java.util.Random;
import junit.framework.TestCase;
+import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DataBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;
@@ -34,33 +36,19 @@ public class DMatrixTest {
@Test
public void testCreateFromDataIterator() throws XGBoostError {
//create DMatrix from DataIterator
- /**
- * sparse matrix
- * 1 0 2 3 0
- * 4 0 2 3 5
- * 3 1 2 5 0
- */
- DataBatch batch = new DataBatch();
- batch.featureIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
- batch.featureValue = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
- batch.rowOffset = new long[]{0, 3, 7, 11};
- batch.label = new float[] {0.1f, 0.2f, 0.3f};
+
java.util.ArrayList labelall = new java.util.ArrayList();
- int nrep = 3;
- java.util.List blist = new java.util.LinkedList();
+ int nrep = 3000;
+ java.util.List blist = new java.util.LinkedList();
for (int i = 0; i < nrep; ++i) {
- batch.label = new float[] {0.1f+i, 0.2f+i, 0.3f+i};
- blist.add(batch.shallowCopy());
- for (float f : batch.label) {
- labelall.add(f);
- }
+ LabeledPoint p = LabeledPoint.fromSparseVector(
+ 0.1f + i, new int[]{0, 2, 3}, new float[]{3, 4, 5});
+ blist.add(p);
+ labelall.add(p.label);
}
DMatrix dmat = new DMatrix(blist.iterator(), null);
// get label
float[] labels = dmat.getLabel();
- // get label
- TestCase.assertTrue(batch.label.length * nrep == labels.length);
-
for (int i = 0; i < labels.length; ++i) {
TestCase.assertTrue(labelall.get(i) == labels[i]);
}