Add public group getter for java and scala (#4838)

* Add public group getter for java and scala

* Remove unnecessary param from javadoc

* Fix typo

* Fix another typo

* Add semicolon

* Fix javadoc return statement

* Fix missing return statement

* Add a unit test
This commit is contained in:
Stephanie Yang 2019-09-09 13:07:48 -04:00 committed by Philip Hyunsu Cho
parent f90e7f9aa8
commit 0fc7dcfe6c
3 changed files with 47 additions and 0 deletions

View File

@ -204,6 +204,16 @@ public class DMatrix {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group)); XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
} }
/**
* Get group sizes of DMatrix
*
* @throws XGBoostError native error
* @return group size as array
*/
public int[] getGroup() throws XGBoostError {
return getIntInfo("group_ptr");
}
private float[] getFloatInfo(String field) throws XGBoostError { private float[] getFloatInfo(String field) throws XGBoostError {
float[][] infos = new float[1][]; float[][] infos = new float[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos)); XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos));

View File

@ -149,6 +149,14 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
jDMatrix.setGroup(group) jDMatrix.setGroup(group)
} }
/**
* Get group sizes of DMatrix (used for ranking)
*/
@throws(classOf[XGBoostError])
def getGroup(): Array[Int] = {
jDMatrix.getGroup()
}
/** /**
* get label values * get label values
* *

View File

@ -223,4 +223,33 @@ public class DMatrixTest {
TestCase.assertTrue(dmat0.rowNum() == 10); TestCase.assertTrue(dmat0.rowNum() == 10);
TestCase.assertTrue(dmat0.getLabel().length == 10); TestCase.assertTrue(dmat0.getLabel().length == 10);
} }
@Test
public void testSetAndGetGroup() throws XGBoostError {
//create DMatrix from 10*5 dense matrix
int nrow = 10;
int ncol = 5;
float[] data0 = new float[nrow * ncol];
//put random nums
Random random = new Random();
for (int i = 0; i < nrow * ncol; i++) {
data0[i] = random.nextFloat();
}
//create label
float[] label0 = new float[nrow];
for (int i = 0; i < nrow; i++) {
label0[i] = random.nextFloat();
}
//create two groups
int[] groups = new int[]{5, 5};
DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f);
dmat0.setLabel(label0);
dmat0.setGroup(groups);
//check
TestCase.assertTrue(Arrays.equals(new int[]{0, 5, 10}, dmat0.getGroup()));
}
} }