[jvm-packages] Robust dmatrix creation (#1613)

* add back train method but mark as deprecated

* robust matrix creation in jvm
This commit is contained in:
Nan Zhu
2016-09-26 13:35:04 -04:00
committed by GitHub
parent 915ac0b8fe
commit 37bc122c90
7 changed files with 197 additions and 20 deletions

View File

@@ -91,6 +91,78 @@ public class DMatrixTest {
TestCase.assertTrue(Arrays.equals(label1, label2));
}
@Test
public void testCreateFromCSREx() throws XGBoostError {
  /*
   * Builds a DMatrix from CSR-format arrays using the constructor overload that
   * takes an extra int argument, then round-trips a label vector through it.
   *
   * Sparse matrix under test:
   *   1 0 2 3 0
   *   4 0 2 3 5
   *   3 1 2 5 0
   */
  long[] rowHeaders = {0, 3, 7, 11};
  int[] colIndex = {0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
  float[] values = {1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
  // NOTE(review): the trailing 5 is presumably the column count for CSR input - confirm
  // against the DMatrix constructor.
  DMatrix matrix = new DMatrix(rowHeaders, colIndex, values, DMatrix.SparseType.CSR, 5);

  // the matrix above has three rows
  boolean hasThreeRows = matrix.rowNum() == 3;
  TestCase.assertTrue(hasThreeRows);

  // labels written with setLabel must come back unchanged from getLabel
  float[] expectedLabels = {1, 0, 1};
  matrix.setLabel(expectedLabels);
  float[] actualLabels = matrix.getLabel();
  TestCase.assertTrue(Arrays.equals(expectedLabels, actualLabels));
}
@Test
public void testCreateFromCSC() throws XGBoostError {
  /*
   * Builds a DMatrix from CSC-format arrays (column headers, row indices, values)
   * and round-trips a label vector through it.
   *
   * Sparse matrix under test:
   *   1 0 2
   *   3 0 4
   *   0 2 3
   *   5 3 1
   *   2 5 0
   */
  float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1};
  int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
  long[] colHeaders = new long[]{0, 4, 7, 11};
  DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC);

  // check row num; assertEquals reports the actual value on failure, so no
  // debug print to stdout is needed
  TestCase.assertEquals(5L, dmat1.rowNum());

  // labels written with setLabel must come back unchanged from getLabel
  float[] label1 = new float[]{1, 0, 1, 1, 1};
  dmat1.setLabel(label1);
  float[] label2 = dmat1.getLabel();
  TestCase.assertTrue(Arrays.equals(label1, label2));
}
@Test
public void testCreateFromCSCEx() throws XGBoostError {
  /*
   * Builds a DMatrix from CSC-format arrays using the constructor overload that
   * takes an extra int argument, then round-trips a label vector through it.
   *
   * Sparse matrix under test:
   *   1 0 2
   *   3 0 4
   *   0 2 3
   *   5 3 1
   *   2 5 0
   */
  float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1};
  int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
  long[] colHeaders = new long[]{0, 4, 7, 11};
  // NOTE(review): the trailing 5 is presumably the row count for CSC input - confirm
  // against the DMatrix constructor.
  DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, 5);

  // check row num; assertEquals reports the actual value on failure, so no
  // debug print to stdout is needed
  TestCase.assertEquals(5L, dmat1.rowNum());

  // labels written with setLabel must come back unchanged from getLabel
  float[] label1 = new float[]{1, 0, 1, 1, 1};
  dmat1.setLabel(label1);
  float[] label2 = dmat1.getLabel();
  TestCase.assertTrue(Arrays.equals(label1, label2));
}
@Test
public void testCreateFromDenseMatrix() throws XGBoostError {
//create DMatrix from 10*5 dense matrix

View File

@@ -56,6 +56,67 @@ class DMatrixSuite extends FunSuite {
assert(label2 === label1)
}
test("create DMatrix from CSREx") {
  /*
   * Builds a DMatrix from CSR-format arrays using the constructor overload that
   * takes an extra Int argument, then round-trips a label vector through it.
   *
   * Sparse matrix under test:
   *   1 0 2 3 0
   *   4 0 2 3 5
   *   3 1 2 5 0
   */
  val rowHeaders = Array[Long](0, 3, 7, 11)
  val colIndex = Array(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3)
  val values = Array[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5)
  // NOTE(review): the trailing 5 is presumably the column count for CSR input - confirm
  val matrix = new DMatrix(rowHeaders, colIndex, values, JDMatrix.SparseType.CSR, 5)
  assert(matrix.rowNum === 3)

  // labels written with setLabel must come back unchanged from getLabel
  val expectedLabels = Array[Float](1, 0, 1)
  matrix.setLabel(expectedLabels)
  val actualLabels = matrix.getLabel
  assert(actualLabels === expectedLabels)
}
test("create DMatrix from CSC") {
  /*
   * Builds a DMatrix from CSC-format arrays (column headers, row indices, values)
   * and round-trips a label vector through it.
   *
   * Sparse matrix under test:
   *   1 0 2
   *   3 0 4
   *   0 2 3
   *   5 3 1
   *   2 5 0
   */
  val colHeaders = Array[Long](0, 4, 7, 11)
  val rowIndex = Array(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3)
  val values = Array[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1)
  val matrix = new DMatrix(colHeaders, rowIndex, values, JDMatrix.SparseType.CSC)
  assert(matrix.rowNum === 5)

  // labels written with setLabel must come back unchanged from getLabel
  val expectedLabels = Array[Float](1, 0, 1, 1, 1)
  matrix.setLabel(expectedLabels)
  val actualLabels = matrix.getLabel
  assert(actualLabels === expectedLabels)
}
test("create DMatrix from CSCEx") {
  /*
   * Builds a DMatrix from CSC-format arrays using the constructor overload that
   * takes an extra Int argument, then round-trips a label vector through it.
   *
   * Sparse matrix under test:
   *   1 0 2
   *   3 0 4
   *   0 2 3
   *   5 3 1
   *   2 5 0
   */
  val colHeaders = Array[Long](0, 4, 7, 11)
  val rowIndex = Array(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3)
  val values = Array[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1)
  // NOTE(review): the trailing 5 is presumably the row count for CSC input - confirm
  val matrix = new DMatrix(colHeaders, rowIndex, values, JDMatrix.SparseType.CSC, 5)
  assert(matrix.rowNum === 5)

  // labels written with setLabel must come back unchanged from getLabel
  val expectedLabels = Array[Float](1, 0, 1, 1, 1)
  matrix.setLabel(expectedLabels)
  val actualLabels = matrix.getLabel
  assert(actualLabels === expectedLabels)
}
test("create DMatrix from DenseMatrix") {
val nrow = 10
val ncol = 5