framework of xgboost-spark

iterator

return java iterator and recover test
This commit is contained in:
CodingCat
2016-03-04 23:26:45 -05:00
parent 1540773340
commit b2d705ffb0
15 changed files with 194 additions and 156 deletions

View File

@@ -28,7 +28,7 @@ import org.apache.commons.logging.LogFactory;
*/
public class DMatrix {
private static final Log logger = LogFactory.getLog(DMatrix.class);
private long handle = 0;
protected long handle = 0;
//load native library
static {

View File

@@ -4,8 +4,6 @@ package ml.dmlc.xgboost4j;
* A mini-batch of data that can be converted to DMatrix.
* The data is in sparse matrix CSR format.
*
* Usually this object is not needed.
*
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
*/
public class DataBatch {
@@ -19,6 +17,19 @@ public class DataBatch {
int[] featureIndex = null;
/** value of each non-missing entry in the sparse matrix */
float[] featureValue = null;
public DataBatch() {}
public DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
float[] featureValue) {
this.rowOffset = rowOffset;
this.weight = weight;
this.label = label;
this.featureIndex = featureIndex;
this.featureValue = featureValue;
}
/**
* Get number of rows in the data batch.
* @return Number of rows in the data batch.

View File

@@ -491,8 +491,7 @@ class JavaBoosterImpl implements Booster {
}
// making Booster serializable
private void writeObject(java.io.ObjectOutputStream out)
throws IOException {
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {

View File

@@ -27,7 +27,8 @@ class XgboostJNI {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
long[] out);

View File

@@ -16,7 +16,9 @@
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
import _root_.scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, DataBatch, XGBoostError}
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
@@ -43,6 +45,10 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
this(new JDMatrix(headers, indices, data, st))
}
private[xgboost4j] def this(dataBatch: DataBatch) {
this(new JDMatrix(List(dataBatch).asJava.iterator, null))
}
/**
* create DMatrix from dense matrix
*