diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index dc11e307d..7bfb35a7f 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -21,7 +21,6 @@ xgboost4j xgboost4j-demo xgboost4j-flink - xgboost4j-spark diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java index 19d4b7da9..fc14e361e 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java @@ -6,13 +6,13 @@ package ml.dmlc.xgboost4j; */ public class LabeledPoint { /** Label of the point */ - float label; + public float label; /** Weight of this data point */ - float weight = 1.0f; + public float weight = 1.0f; /** Feature indices, used for sparse input */ - int[] indices = null; + public int[] indices = null; /** Feature values */ - float[] values; + public float[] values; private LabeledPoint() {} @@ -27,6 +27,7 @@ public class LabeledPoint { ret.label = label; ret.indices = indices; ret.values = values; + assert indices.length == values.length; return ret; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index a5a9b4972..2a7461377 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -21,6 +21,8 @@ import java.util.Iterator; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import ml.dmlc.xgboost4j.LabeledPoint; + /** * DMatrix for xgboost. * @@ -52,20 +54,18 @@ public class DMatrix { * Create DMatrix from iterator. * * @param iter The data iterator of mini batch to provide the data. - * @param cache_info Cache path information, used for external memory setting, can be null. + * @param cacheInfo Cache path information, used for external memory setting, can be null. * @throws XGBoostError */ - public DMatrix(Iterator iter, String cache_info) throws XGBoostError { + public DMatrix(Iterator iter, String cacheInfo) throws XGBoostError { if (iter == null) { throw new NullPointerException("iter: null"); } - try { - logger.info(iter.getClass().getMethod("next").toString()); - } catch(NoSuchMethodException e) { - logger.info(e.toString()); - } + // 32k as batch size + int batchSize = 32 << 10; + Iterator batchIter = new DataBatch.BatchIterator(iter, batchSize); long[] out = new long[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out)); + JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out)); handle = out[0]; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java index b4bbbe690..2a52d0b9b 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java @@ -1,12 +1,16 @@ package ml.dmlc.xgboost4j.java; +import java.util.Iterator; + +import ml.dmlc.xgboost4j.LabeledPoint; + /** * A mini-batch of data that can be converted to DMatrix. * The data is in sparse matrix CSR format. * * This class is used to support advanced creation of DMatrix from Iterator of DataBatch, */ -public class DataBatch { +class DataBatch { /** The offset of each rows in the sparse matrix */ long[] rowOffset = null; /** weight of each data point, can be null */ @@ -51,4 +55,58 @@ public class DataBatch { b.featureValue = this.featureValue; return b; } + + static class BatchIterator implements Iterator { + private Iterator base; + private int batchSize; + + BatchIterator(java.util.Iterator base, int batchSize) { + this.base = base; + this.batchSize = batchSize; + } + @Override + public boolean hasNext() { + return base.hasNext(); + } + @Override + public DataBatch next() { + int num_rows = 0, num_elem = 0; + java.util.List batch = new java.util.ArrayList(); + for (int i = 0; i < this.batchSize; ++i) { + if (!base.hasNext()) break; + LabeledPoint inst = base.next(); + batch.add(inst); + num_elem += inst.values.length; + ++num_rows; + } + DataBatch ret = new DataBatch(); + // label + ret.rowOffset = new long[num_rows + 1]; + ret.label = new float[num_rows]; + ret.featureIndex = new int[num_elem]; + ret.featureValue = new float[num_elem]; + // current offset + int offset = 0; + for (int i = 0; i < batch.size(); ++i) { + LabeledPoint inst = batch.get(i); + ret.rowOffset[i] = offset; + ret.label[i] = inst.label; + if (inst.indices != null) { + System.arraycopy(inst.indices, 0, ret.featureIndex, offset, inst.indices.length); + } else{ + for (int j = 0; j < inst.values.length; ++j) { + ret.featureIndex[offset + j] = j; + } + } + System.arraycopy(inst.values, 0, ret.featureValue, offset, inst.values.length); + offset += inst.values.length; + } + ret.rowOffset[batch.size()] = offset; + return ret; + } + @Override + public void remove() { + throw new Error("not implemented"); + } + } } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala index c62784c08..cdf3a9844 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala @@ -17,7 +17,7 @@ package ml.dmlc.xgboost4j.scala import _root_.scala.collection.JavaConverters._ - +import ml.dmlc.xgboost4j.LabeledPoint import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, DataBatch, XGBoostError} class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { @@ -31,6 +31,17 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { this(new JDMatrix(dataPath)) } + /** + * init DMatrix from Iterator of LabeledPoint + * + * @param dataIter An iterator of LabeledPoint + * @param cacheInfo Cache path information, used for external memory setting, can be null. + * @throws XGBoostError native error + */ + def this(dataIter: Iterator[LabeledPoint], cacheInfo: String) { + this(new JDMatrix(dataIter.asJava, cacheInfo)) + } + /** * create DMatrix from sparse matrix * @@ -44,10 +55,6 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { this(new JDMatrix(headers, indices, data, st)) } - private[xgboost4j] def this(dataBatches: Iterator[DataBatch]) { - this(new JDMatrix(dataBatches.asJava, null)) - } - /** * create DMatrix from dense matrix * diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index 056291ddf..d9004b418 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -15,10 +15,12 @@ */ package ml.dmlc.xgboost4j.java; +import java.awt.*; import java.util.Arrays; import java.util.Random; import junit.framework.TestCase; +import ml.dmlc.xgboost4j.LabeledPoint; import ml.dmlc.xgboost4j.java.DMatrix; import ml.dmlc.xgboost4j.java.DataBatch; import ml.dmlc.xgboost4j.java.XGBoostError; @@ -34,33 +36,19 @@ public class DMatrixTest { @Test public void testCreateFromDataIterator() throws XGBoostError { //create DMatrix from DataIterator - /** - * sparse matrix - * 1 0 2 3 0 - * 4 0 2 3 5 - * 3 1 2 5 0 - */ - DataBatch batch = new DataBatch(); - batch.featureIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3}; - batch.featureValue = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5}; - batch.rowOffset = new long[]{0, 3, 7, 11}; - batch.label = new float[] {0.1f, 0.2f, 0.3f}; + java.util.ArrayList labelall = new java.util.ArrayList(); - int nrep = 3; - java.util.List blist = new java.util.LinkedList(); + int nrep = 3000; + java.util.List blist = new java.util.LinkedList(); for (int i = 0; i < nrep; ++i) { - batch.label = new float[] {0.1f+i, 0.2f+i, 0.3f+i}; - blist.add(batch.shallowCopy()); - for (float f : batch.label) { - labelall.add(f); - } + LabeledPoint p = LabeledPoint.fromSparseVector( + 0.1f + i, new int[]{0, 2, 3}, new float[]{3, 4, 5}); + blist.add(p); + labelall.add(p.label); } DMatrix dmat = new DMatrix(blist.iterator(), null); // get label float[] labels = dmat.getLabel(); - // get label - TestCase.assertTrue(batch.label.length * nrep == labels.length); - for (int i = 0; i < labels.length; ++i) { TestCase.assertTrue(labelall.get(i) == labels[i]); }