[jvm-packages] create dmatrix with specified missing value (#1272)
* create dmatrix with specified missing value * update dmlc-core * support for predict method in spark package repartitioning work around * add more elements to work around training set empty partition issue
This commit is contained in:
@@ -118,6 +118,19 @@ public class DMatrix {
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
* @param data data values
|
||||
* @param nrow number of rows
|
||||
* @param ncol number of columns
|
||||
* @param missing the specified value to represent the missing value
|
||||
*/
|
||||
public DMatrix(float[] data, int nrow, int ncol, float missing) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, missing, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* used for DMatrix slice
|
||||
*/
|
||||
|
||||
@@ -67,6 +67,19 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
this(new JDMatrix(data, nrow, ncol))
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
*
|
||||
* @param data data values
|
||||
* @param nrow number of rows
|
||||
* @param ncol number of columns
|
||||
* @param missing the specified value to represent the missing value
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def this(data: Array[Float], nrow: Int, ncol: Int, missing: Float) {
|
||||
this(new JDMatrix(data, nrow, ncol, missing))
|
||||
}
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
*
|
||||
|
||||
@@ -125,4 +125,34 @@ public class DMatrixTest {
|
||||
|
||||
TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
int nrow = 10;
|
||||
int ncol = 5;
|
||||
float[] data0 = new float[nrow * ncol];
|
||||
//put random nums
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < nrow * ncol; i++) {
|
||||
if (i % 10 == 0) {
|
||||
data0[i] = -0.1f;
|
||||
} else {
|
||||
data0[i] = random.nextFloat();
|
||||
}
|
||||
}
|
||||
|
||||
//create label
|
||||
float[] label0 = new float[nrow];
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
label0[i] = random.nextFloat();
|
||||
}
|
||||
|
||||
DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f);
|
||||
dmat0.setLabel(label0);
|
||||
|
||||
//check
|
||||
TestCase.assertTrue(dmat0.rowNum() == 10);
|
||||
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,4 +82,28 @@ class DMatrixSuite extends FunSuite {
|
||||
dmat0.setWeight(weights)
|
||||
assert(weights === dmat0.getWeight)
|
||||
}
|
||||
|
||||
test("create DMatrix from DenseMatrix with missing value") {
|
||||
val nrow = 10
|
||||
val ncol = 5
|
||||
val data0 = new Array[Float](nrow * ncol)
|
||||
// put random nums
|
||||
for (i <- data0.indices) {
|
||||
if (i % 10 == 0) {
|
||||
data0(i) = -0.1f
|
||||
} else {
|
||||
data0(i) = Random.nextFloat()
|
||||
}
|
||||
}
|
||||
// create label
|
||||
val label0 = new Array[Float](nrow)
|
||||
for (i <- label0.indices) {
|
||||
label0(i) = Random.nextFloat()
|
||||
}
|
||||
val dmat0 = new DMatrix(data0, nrow, ncol, -0.1f)
|
||||
dmat0.setLabel(label0)
|
||||
// check
|
||||
assert(dmat0.rowNum === 10)
|
||||
assert(dmat0.getLabel.length === 10)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user