Add public group getter for java and scala (#4838)
* Add public group getter for java and scala * Remove unnecessary param from javadoc * Fix typo * Fix another typo * Add semicolon * Fix javadoc return statement * Fix missing return statement * Add a unit test
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
f90e7f9aa8
commit
0fc7dcfe6c
@@ -204,6 +204,16 @@ public class DMatrix {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get group sizes of DMatrix
|
||||
*
|
||||
* @throws XGBoostError native error
|
||||
* @return group size as array
|
||||
*/
|
||||
public int[] getGroup() throws XGBoostError {
|
||||
return getIntInfo("group_ptr");
|
||||
}
|
||||
|
||||
private float[] getFloatInfo(String field) throws XGBoostError {
|
||||
float[][] infos = new float[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
|
||||
|
||||
@@ -149,6 +149,14 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
jDMatrix.setGroup(group)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get group sizes of DMatrix (used for ranking)
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getGroup(): Array[Int] = {
|
||||
jDMatrix.getGroup()
|
||||
}
|
||||
|
||||
/**
|
||||
* get label values
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user