[jvm-packages] Robust dmatrix creation (#1613)
* add back train method but mark as deprecated * robust matrix creation in jvm
This commit is contained in:
@@ -91,6 +91,78 @@ public class DMatrixTest {
|
||||
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromCSREx() throws XGBoostError {
|
||||
//create Matrix from csr format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2 3 0
|
||||
* 4 0 2 3 5
|
||||
* 3 1 2 5 0
|
||||
*/
|
||||
float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
|
||||
int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
|
||||
long[] rowHeaders = new long[]{0, 3, 7, 11};
|
||||
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, 5);
|
||||
//check row num
|
||||
TestCase.assertTrue(dmat1.rowNum() == 3);
|
||||
//test set label
|
||||
float[] label1 = new float[]{1, 0, 1};
|
||||
dmat1.setLabel(label1);
|
||||
float[] label2 = dmat1.getLabel();
|
||||
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromCSC() throws XGBoostError {
|
||||
//create Matrix from csc format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2
|
||||
* 3 0 4
|
||||
* 0 2 3
|
||||
* 5 3 1
|
||||
* 2 5 0
|
||||
*/
|
||||
float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1};
|
||||
int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
|
||||
long[] colHeaders = new long[]{0, 4, 7, 11};
|
||||
DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC);
|
||||
//check row num
|
||||
System.out.println(dmat1.rowNum());
|
||||
TestCase.assertTrue(dmat1.rowNum() == 5);
|
||||
//test set label
|
||||
float[] label1 = new float[]{1, 0, 1, 1, 1};
|
||||
dmat1.setLabel(label1);
|
||||
float[] label2 = dmat1.getLabel();
|
||||
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromCSCEx() throws XGBoostError {
|
||||
//create Matrix from csc format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2
|
||||
* 3 0 4
|
||||
* 0 2 3
|
||||
* 5 3 1
|
||||
* 2 5 0
|
||||
*/
|
||||
float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1};
|
||||
int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
|
||||
long[] colHeaders = new long[]{0, 4, 7, 11};
|
||||
DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, 5);
|
||||
//check row num
|
||||
System.out.println(dmat1.rowNum());
|
||||
TestCase.assertTrue(dmat1.rowNum() == 5);
|
||||
//test set label
|
||||
float[] label1 = new float[]{1, 0, 1, 1, 1};
|
||||
dmat1.setLabel(label1);
|
||||
float[] label2 = dmat1.getLabel();
|
||||
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromDenseMatrix() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
|
||||
@@ -56,6 +56,67 @@ class DMatrixSuite extends FunSuite {
|
||||
assert(label2 === label1)
|
||||
}
|
||||
|
||||
test("create DMatrix from CSREx") {
|
||||
// create Matrix from csr format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2 3 0
|
||||
* 4 0 2 3 5
|
||||
* 3 1 2 5 0
|
||||
*/
|
||||
val data = List[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5).toArray
|
||||
val colIndex = List(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||
val rowHeaders = List[Long](0, 3, 7, 11).toArray
|
||||
val dmat1 = new DMatrix(rowHeaders, colIndex, data, JDMatrix.SparseType.CSR, 5)
|
||||
assert(dmat1.rowNum === 3)
|
||||
val label1 = List[Float](1, 0, 1).toArray
|
||||
dmat1.setLabel(label1)
|
||||
val label2 = dmat1.getLabel
|
||||
assert(label2 === label1)
|
||||
}
|
||||
|
||||
test("create DMatrix from CSC") {
|
||||
// create Matrix from csc format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2
|
||||
* 3 0 4
|
||||
* 0 2 3
|
||||
* 5 3 1
|
||||
* 2 5 0
|
||||
*/
|
||||
val data = List[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1).toArray
|
||||
val rowIndex = List(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||
val colHeaders = List[Long](0, 4, 7, 11).toArray
|
||||
val dmat1 = new DMatrix(colHeaders, rowIndex, data, JDMatrix.SparseType.CSC)
|
||||
assert(dmat1.rowNum === 5)
|
||||
val label1 = List[Float](1, 0, 1, 1, 1).toArray
|
||||
dmat1.setLabel(label1)
|
||||
val label2 = dmat1.getLabel
|
||||
assert(label2 === label1)
|
||||
}
|
||||
|
||||
test("create DMatrix from CSCEx") {
|
||||
// create Matrix from csc format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2
|
||||
* 3 0 4
|
||||
* 0 2 3
|
||||
* 5 3 1
|
||||
* 2 5 0
|
||||
*/
|
||||
val data = List[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1).toArray
|
||||
val rowIndex = List(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||
val colHeaders = List[Long](0, 4, 7, 11).toArray
|
||||
val dmat1 = new DMatrix(colHeaders, rowIndex, data, JDMatrix.SparseType.CSC, 5)
|
||||
assert(dmat1.rowNum === 5)
|
||||
val label1 = List[Float](1, 0, 1, 1, 1).toArray
|
||||
dmat1.setLabel(label1)
|
||||
val label2 = dmat1.getLabel
|
||||
assert(label2 === label1)
|
||||
}
|
||||
|
||||
test("create DMatrix from DenseMatrix") {
|
||||
val nrow = 10
|
||||
val ncol = 5
|
||||
|
||||
Reference in New Issue
Block a user