From 78694405a6123b12c62531958c4c7a655e7f0bc0 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Fri, 3 Jun 2022 11:09:48 +0800 Subject: [PATCH] [jvm-packages] add jni for setting feature name and type (#7966) --- .../dmlc/xgboost4j/gpu/java/DMatrixTest.java | 14 ++++- .../java/ml/dmlc/xgboost4j/java/DMatrix.java | 62 ++++++++++++++++++- .../ml/dmlc/xgboost4j/java/XGBoostJNI.java | 14 +++++ .../xgboost4j/src/native/xgboost4j.cpp | 60 +++++++++++++++++- jvm-packages/xgboost4j/src/native/xgboost4j.h | 16 +++++ .../ml/dmlc/xgboost4j/java/DMatrixTest.java | 27 +++++++- 6 files changed, 189 insertions(+), 4 deletions(-) diff --git a/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/DMatrixTest.java b/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/DMatrixTest.java index b0c96a828..341064571 100644 --- a/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/DMatrixTest.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2021 by Contributors + Copyright (c) 2021-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -33,6 +33,8 @@ import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix; import ml.dmlc.xgboost4j.java.ColumnBatch; import ml.dmlc.xgboost4j.java.XGBoostError; +import static org.junit.Assert.assertArrayEquals; + /** * Test suite for DMatrix based on GPU */ @@ -60,6 +62,16 @@ public class DMatrixTest { dMatrix.setWeight(weightColumn); dMatrix.setBaseMargin(baseMarginColumn); + String[] featureNames = new String[]{"f1"}; + dMatrix.setFeatureNames(featureNames); + String[] retFeatureNames = dMatrix.getFeatureNames(); + assertArrayEquals(featureNames, retFeatureNames); + + String[] featureTypes = new String[]{"i"}; + dMatrix.setFeatureTypes(featureTypes); + String[] retFeatureTypes = dMatrix.getFeatureTypes(); + assertArrayEquals(featureTypes, retFeatureTypes); + float[] anchor = convertFloatTofloat(labelFloats); float[] label = dMatrix.getLabel(); float[] weight = dMatrix.getWeight(); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index 37263eae4..68934ad1e 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -236,6 +236,91 @@ public class DMatrix {
     XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json));
   }
 
+  /**
+   * Store string-valued feature information on the native DMatrix.
+   *
+   * @param type   name of the native field to write, must not be null or empty
+   * @param values strings to store, must not be null or empty
+   * @throws XGBoostError if an argument is null or empty, or when raised by
+   *                      {@code XGBoostJNI.checkCall}
+   */
+  private void setXGBDMatrixFeatureInfo(String type, String[] values) throws XGBoostError {
+    if (type == null || type.isEmpty()) {
+      throw new XGBoostError("Found empty type");
+    }
+    if (values == null || values.length == 0) {
+      throw new XGBoostError("Found empty values");
+    }
+    XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetStrFeatureInfo(handle, type, values));
+  }
+
+  /**
+   * Read string-valued feature information from the native DMatrix.
+   *
+   * @param type name of the native field to read, must not be null or empty
+   * @return the strings the native call wrote for {@code type}
+   * @throws XGBoostError if {@code type} is null or empty, or when raised by
+   *                      {@code XGBoostJNI.checkCall}
+   * @throws RuntimeException if the native call produced no array, or one whose length
+   *                          differs from the reported length
+   */
+  private String[] getXGBDMatrixFeatureInfo(String type) throws XGBoostError {
+    if (type == null || type.isEmpty()) {
+      throw new XGBoostError("Found empty type");
+    }
+    long[] outLen = new long[1];
+    String[][] outValue = new String[1][];
+    XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetStrFeatureInfo(handle, type, outLen, outValue));
+
+    // The output slot stays null if the native side never filled it; report that the same
+    // way as a length mismatch instead of failing with a NullPointerException.
+    String[] result = outValue[0];
+    if (result == null || outLen[0] != result.length) {
+      throw new RuntimeException("Failed to get " + type);
+    }
+    return result;
+  }
+
+  /**
+   * Set feature names
+   *
+   * @param values feature names to be set, must not be null or empty
+   * @throws XGBoostError if {@code values} is null or empty, or the names cannot be stored
+   */
+  public void setFeatureNames(String[] values) throws XGBoostError {
+    setXGBDMatrixFeatureInfo("feature_name", values);
+  }
+
+  /**
+   * Get feature names
+   *
+   * @return an array holding the feature names of this DMatrix
+   * @throws XGBoostError if the names cannot be read
+   */
+  public String[] getFeatureNames() throws XGBoostError {
+    return getXGBDMatrixFeatureInfo("feature_name");
+  }
+
+  /**
+   * Set feature types
+   *
+   * @param values feature types to be set, must not be null or empty
+   * @throws XGBoostError if {@code values} is null or empty, or the types cannot be stored
+   */
+  public void setFeatureTypes(String[] values) throws XGBoostError {
+    setXGBDMatrixFeatureInfo("feature_type", values);
+  }
+
+  /**
+   * Get feature types
+   *
+   * @return an array holding the feature types of this DMatrix
+   * @throws XGBoostError if the types cannot be read
+   */
+  public String[] getFeatureTypes() throws XGBoostError {
+    return getXGBDMatrixFeatureInfo("feature_type");
+  }
+
   /**
    * set label of dmatrix
    *
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
index 22b3155fe..d2285af90 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
@@ -82,6 +82,19 @@ class XGBoostJNI {
 
   public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
 
+  /**
+   * Set the string feature information of a DMatrix
+   * @param handle the DMatrix native address
+   * @param field "feature_name" or "feature_type"
+   * @param values the strings to store for that field
+   * @return 0 when success, -1 when failure happens
+   */
+  public final static native int XGDMatrixSetStrFeatureInfo(long handle, String field,
+                                                            String[] values);
+  /** Get the feature information: outLength[0] receives the count, outValues[0] the strings. */
+  public final static native int XGDMatrixGetStrFeatureInfo(long handle, String field,
+                                                            long[] outLength, String[][] outValues);
+
   public final static native int XGDMatrixNumRow(long handle, long[] row);
 
   public final static native int XGBoosterCreate(long[] handles, long[] out);
@@ -143,4 +156,5 @@ class XGBoostJNI {
 
   public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
     String featureJson, float missing, int nthread, long[] out);
+
 }
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
index 4fd6131ac..630040731 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -1,5 +1,5 @@
 /*
-  Copyright (c) 2014 by Contributors
+  Copyright (c) 2014-2022 by Contributors
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
You may obtain a copy of the License at
@@ -1044,3 +1044,80 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
   setHandle(jenv, jout, result);
   return ret;
 }
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    XGDMatrixSetStrFeatureInfo
+ * Signature: (JLjava/lang/String;[Ljava/lang/String;)I
+ *
+ * Copies the Java strings in jvalues into native storage and passes them to
+ * XGDMatrixSetStrFeatureInfo for the field named by jfield.
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo
+  (JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jvalues) {
+  DMatrixHandle handle = (DMatrixHandle) jhandle;
+  const char* field = jenv->GetStringUTFChars(jfield, 0);
+  int size = jenv->GetArrayLength(jvalues);
+
+  // tmp storage for java strings
+  std::vector<std::string> values;
+  values.reserve(size);
+  for (int i = 0; i < size; i++) {
+    jstring jstr = (jstring)(jenv->GetObjectArrayElement(jvalues, i));
+    // A null array element must not reach GetStringUTFChars, and a null C string must not
+    // reach the std::string constructor; both cases are stored as an empty string.
+    const char *value = jstr ? jenv->GetStringUTFChars(jstr, 0) : nullptr;
+    values.emplace_back(value ? value : "");
+    if (value) jenv->ReleaseStringUTFChars(jstr, value);
+  }
+
+  // The c_str() pointers stay valid because `values` is not modified past this point.
+  std::vector<char const*> c_values;
+  c_values.resize(size);
+  std::transform(values.cbegin(), values.cend(),
+                 c_values.begin(),
+                 [](auto const &str) { return str.c_str(); });
+
+  int ret = XGDMatrixSetStrFeatureInfo(handle, field, c_values.data(), size);
+  // Release before JVM_CHECK_CALL so the UTF buffer is freed even if the macro leaves the
+  // function early on a non-zero return code (NOTE(review): macro body not visible here).
+  if (field) jenv->ReleaseStringUTFChars(jfield, field);
+  JVM_CHECK_CALL(ret);
+  return ret;
+}
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    XGDMatrixGetStrFeatureInfo
+ * Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I
+ *
+ * Writes the number of strings to joutLenArray[0] and a new String[] holding them to
+ * joutValueArray[0].
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo
+  (JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jlongArray joutLenArray,
+   jobjectArray joutValueArray) {
+  DMatrixHandle handle = (DMatrixHandle) jhandle;
+  const char *field = jenv->GetStringUTFChars(jfield, 0);
+
+  bst_ulong out_len = 0;
+  // Initialised so the pointer is never indeterminate if the call below fails.
+  char const **c_out_features = nullptr;
+  int ret = XGDMatrixGetStrFeatureInfo(handle, field, &out_len, &c_out_features);
+  // The field name is only needed for the call above; release it here so no later exit
+  // path can skip the release.
+  if (field) jenv->ReleaseStringUTFChars(jfield, field);
+
+  jlong jlen = (jlong) out_len;
+  jenv->SetLongArrayRegion(joutLenArray, 0, 1, &jlen);
+
+  jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"),
+                                             jenv->NewStringUTF(""));
+  for (int i = 0; i < jlen; i++) {
+    jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF(c_out_features[i]));
+  }
+  jenv->SetObjectArrayElement(joutValueArray, 0, jinfos);
+
+  JVM_CHECK_CALL(ret);
+  return ret;
+}
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h
index 16ef166fc..2db64a169 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.h
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h
@@ -359,6 +359,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
   (JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
 
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    XGDMatrixSetStrFeatureInfo
+ * Signature: (JLjava/lang/String;[Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo
+  (JNIEnv *, jclass, jlong, jstring, jobjectArray);
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    XGDMatrixGetStrFeatureInfo
+ * Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo
+  (JNIEnv *, jclass, jlong, jstring, jlongArray, jobjectArray);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
index 721b9a25f..7ea1604c3 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the
License.
@@ -403,4 +403,29 @@ public class DMatrixTest {
     //check
     TestCase.assertTrue(Arrays.equals(new int[]{0, 5, 10}, dmat0.getGroup()));
   }
+
+  @Test
+  public void testSetAndGetFeatureInfo() throws XGBoostError {
+    // Build a dense 10 x 5 matrix filled with random values.
+    final int rows = 10;
+    final int cols = 5;
+    final Random rng = new Random();
+    final float[] cells = new float[rows * cols];
+    int idx = 0;
+    while (idx < cells.length) {
+      cells[idx++] = rng.nextInt();
+    }
+
+    DMatrix matrix = new DMatrix(cells, rows, cols, Float.NaN);
+
+    // Feature names must come back unchanged after being set.
+    String[] expectedNames = {"f1", "f2", "f3", "f4", "f5"};
+    matrix.setFeatureNames(expectedNames);
+    assertArrayEquals(expectedNames, matrix.getFeatureNames());
+
+    // Feature types must come back unchanged after being set.
+    String[] expectedTypes = {"i", "q", "c", "i", "q"};
+    matrix.setFeatureTypes(expectedTypes);
+    assertArrayEquals(expectedTypes, matrix.getFeatureTypes());
+  }
 }