diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index c62814717..02dc5e58f 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -204,6 +204,16 @@ public class DMatrix { XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group)); } + /** + * Get group sizes of DMatrix + * + * @throws XGBoostError native error + * @return group size as array + */ + public int[] getGroup() throws XGBoostError { + return getIntInfo("group_ptr"); + } + private float[] getFloatInfo(String field) throws XGBoostError { float[][] infos = new float[1][]; XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos)); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala index bf2952ec5..629a39dbf 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala @@ -149,6 +149,14 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { jDMatrix.setGroup(group) } + /** + * Get group sizes of DMatrix (used for ranking) + */ + @throws(classOf[XGBoostError]) + def getGroup(): Array[Int] = { + jDMatrix.getGroup() + } + /** * get label values * diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index 24c783987..b121bb887 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -223,4 +223,33 @@ public class DMatrixTest { TestCase.assertTrue(dmat0.rowNum() == 10); TestCase.assertTrue(dmat0.getLabel().length == 10); } + + @Test + public void testSetAndGetGroup() throws XGBoostError { + //create DMatrix from 10*5 dense matrix + int nrow = 10; + int ncol = 5; + float[] data0 = new float[nrow * ncol]; + //put random nums + Random random = new Random(); + for (int i = 0; i < nrow * ncol; i++) { + data0[i] = random.nextFloat(); + } + + //create label + float[] label0 = new float[nrow]; + for (int i = 0; i < nrow; i++) { + label0[i] = random.nextFloat(); + } + + //create two groups + int[] groups = new int[]{5, 5}; + + DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f); + dmat0.setLabel(label0); + dmat0.setGroup(groups); + + //check + TestCase.assertTrue(Arrays.equals(new int[]{0, 5, 10}, dmat0.getGroup())); + } }