[jvm-packages] create dmatrix with specified missing value (#1272)

* create dmatrix with specified missing value

* update dmlc-core

* support for predict method in spark package

repartitioning

work around

* add more elements to work around the empty-partition issue in the training set
This commit is contained in:
Nan Zhu
2016-06-21 17:35:17 -04:00
committed by GitHub
parent c9a73fe2a9
commit bd5b07873e
6 changed files with 143 additions and 2 deletions

View File

@@ -118,6 +118,19 @@ public class DMatrix {
handle = out[0];
}
/**
* create DMatrix from dense matrix
* @param data data values
* @param nrow number of rows
* @param ncol number of columns
* @param missing the specified value to represent the missing value
*/
public DMatrix(float[] data, int nrow, int ncol, float missing) throws XGBoostError {
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, missing, out));
handle = out[0];
}
/**
* used for DMatrix slice
*/

View File

@@ -67,6 +67,19 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
this(new JDMatrix(data, nrow, ncol))
}
/**
 * Build a DMatrix from a dense matrix in which one particular value marks missing entries.
 * Delegates to the Java DMatrix constructor that takes the same four arguments.
 *
 * @param data    values of the dense matrix
 * @param nrow    number of rows
 * @param ncol    number of columns
 * @param missing the value that stands for a missing entry
 */
@throws(classOf[XGBoostError])
def this(data: Array[Float], nrow: Int, ncol: Int, missing: Float) =
  this(new JDMatrix(data, nrow, ncol, missing))
/**
* set label of dmatrix
*

View File

@@ -125,4 +125,34 @@ public class DMatrixTest {
TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
}
/**
 * Builds a 10x5 dense matrix in which every tenth value is a sentinel, creates a
 * DMatrix that is told to treat that sentinel as missing, and checks that the row
 * count and the label length both equal the number of rows.
 */
@Test
public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
  final int nrow = 10;
  final int ncol = 5;
  // Single definition of the sentinel so the data fill and the constructor argument
  // cannot drift apart. It is negative, while Random.nextFloat() only yields values
  // in [0, 1), so generated data can never collide with it.
  final float missing = -0.1f;
  // Fixed seed keeps the generated data reproducible if the test ever fails.
  Random random = new Random(42);

  // fill the dense matrix: every tenth cell is the sentinel, the rest are random
  float[] data0 = new float[nrow * ncol];
  for (int i = 0; i < data0.length; i++) {
    data0[i] = (i % 10 == 0) ? missing : random.nextFloat();
  }

  // one random label per row
  float[] label0 = new float[nrow];
  for (int i = 0; i < label0.length; i++) {
    label0[i] = random.nextFloat();
  }

  DMatrix dmat0 = new DMatrix(data0, nrow, ncol, missing);
  dmat0.setLabel(label0);

  // assertEquals reports expected vs. actual on failure, unlike assertTrue(a == b)
  TestCase.assertEquals(nrow, dmat0.rowNum());
  TestCase.assertEquals(nrow, dmat0.getLabel().length);
}
}

View File

@@ -82,4 +82,28 @@ class DMatrixSuite extends FunSuite {
dmat0.setWeight(weights)
assert(weights === dmat0.getWeight)
}
test("create DMatrix from DenseMatrix with missing value") {
  val nrow = 10
  val ncol = 5
  // sentinel value that the DMatrix is told to treat as missing
  val missing = -0.1f
  // every tenth cell carries the sentinel; all other cells get a random value
  val data0 = Array.tabulate(nrow * ncol) { i =>
    if (i % 10 == 0) missing else Random.nextFloat()
  }
  // one random label per row
  val label0 = Array.fill(nrow)(Random.nextFloat())
  val dmat0 = new DMatrix(data0, nrow, ncol, missing)
  dmat0.setLabel(label0)
  // both the row count and the label length must equal the number of rows
  assert(dmat0.rowNum === 10)
  assert(dmat0.getLabel.length === 10)
}
}