framework of xgboost-spark
iterator return java iterator and recover test
This commit is contained in:
@@ -28,7 +28,7 @@ import org.apache.commons.logging.LogFactory;
|
||||
*/
|
||||
public class DMatrix {
|
||||
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
||||
private long handle = 0;
|
||||
protected long handle = 0;
|
||||
|
||||
//load native library
|
||||
static {
|
||||
|
||||
@@ -4,8 +4,6 @@ package ml.dmlc.xgboost4j;
|
||||
* A mini-batch of data that can be converted to DMatrix.
|
||||
* The data is in sparse matrix CSR format.
|
||||
*
|
||||
* Usually this object is not needed.
|
||||
*
|
||||
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
|
||||
*/
|
||||
public class DataBatch {
|
||||
@@ -19,6 +17,19 @@ public class DataBatch {
|
||||
int[] featureIndex = null;
|
||||
/** value of each non-missing entry in the sparse matrix */
|
||||
float[] featureValue = null;
|
||||
|
||||
public DataBatch() {}
|
||||
|
||||
public DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
|
||||
float[] featureValue) {
|
||||
this.rowOffset = rowOffset;
|
||||
this.weight = weight;
|
||||
this.label = label;
|
||||
this.featureIndex = featureIndex;
|
||||
this.featureValue = featureValue;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get number of rows in the data batch.
|
||||
* @return Number of rows in the data batch.
|
||||
|
||||
@@ -491,8 +491,7 @@ class JavaBoosterImpl implements Booster {
|
||||
}
|
||||
|
||||
// making Booster serializable
|
||||
private void writeObject(java.io.ObjectOutputStream out)
|
||||
throws IOException {
|
||||
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
||||
try {
|
||||
out.writeObject(this.toByteArray());
|
||||
} catch (XGBoostError ex) {
|
||||
|
||||
@@ -27,7 +27,8 @@ class XgboostJNI {
|
||||
|
||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);
|
||||
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
||||
String cache_info, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
|
||||
long[] out);
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
||||
import _root_.scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, DataBatch, XGBoostError}
|
||||
|
||||
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
|
||||
@@ -43,6 +45,10 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
this(new JDMatrix(headers, indices, data, st))
|
||||
}
|
||||
|
||||
private[xgboost4j] def this(dataBatch: DataBatch) {
|
||||
this(new JDMatrix(List(dataBatch).asJava.iterator, null))
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user