Add public group getter for java and scala (#4838)
* Add public group getter for java and scala * Remove unnecessary param from javadoc * Fix typo * Fix another typo * Add semicolon * Fix javadoc return statement * Fix missing return statement * Add a unit test
This commit is contained in:
parent
f90e7f9aa8
commit
0fc7dcfe6c
@ -204,6 +204,16 @@ public class DMatrix {
|
|||||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
|
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get group sizes of DMatrix
|
||||||
|
*
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
* @return group size as array
|
||||||
|
*/
|
||||||
|
public int[] getGroup() throws XGBoostError {
|
||||||
|
return getIntInfo("group_ptr");
|
||||||
|
}
|
||||||
|
|
||||||
private float[] getFloatInfo(String field) throws XGBoostError {
|
private float[] getFloatInfo(String field) throws XGBoostError {
|
||||||
float[][] infos = new float[1][];
|
float[][] infos = new float[1][];
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
|
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
|
||||||
|
|||||||
@ -149,6 +149,14 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
|||||||
jDMatrix.setGroup(group)
|
jDMatrix.setGroup(group)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get group sizes of DMatrix (used for ranking)
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def getGroup(): Array[Int] = {
|
||||||
|
jDMatrix.getGroup()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get label values
|
* get label values
|
||||||
*
|
*
|
||||||
|
|||||||
@ -223,4 +223,33 @@ public class DMatrixTest {
|
|||||||
TestCase.assertTrue(dmat0.rowNum() == 10);
|
TestCase.assertTrue(dmat0.rowNum() == 10);
|
||||||
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSetAndGetGroup() throws XGBoostError {
|
||||||
|
//create DMatrix from 10*5 dense matrix
|
||||||
|
int nrow = 10;
|
||||||
|
int ncol = 5;
|
||||||
|
float[] data0 = new float[nrow * ncol];
|
||||||
|
//put random nums
|
||||||
|
Random random = new Random();
|
||||||
|
for (int i = 0; i < nrow * ncol; i++) {
|
||||||
|
data0[i] = random.nextFloat();
|
||||||
|
}
|
||||||
|
|
||||||
|
//create label
|
||||||
|
float[] label0 = new float[nrow];
|
||||||
|
for (int i = 0; i < nrow; i++) {
|
||||||
|
label0[i] = random.nextFloat();
|
||||||
|
}
|
||||||
|
|
||||||
|
//create two groups
|
||||||
|
int[] groups = new int[]{5, 5};
|
||||||
|
|
||||||
|
DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f);
|
||||||
|
dmat0.setLabel(label0);
|
||||||
|
dmat0.setGroup(groups);
|
||||||
|
|
||||||
|
//check
|
||||||
|
TestCase.assertTrue(Arrays.equals(new int[]{0, 5, 10}, dmat0.getGroup()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user