[jvm-packages] add jni for setting feature name and type (#7966)

This commit is contained in:
Bobby Wang 2022-06-03 11:09:48 +08:00 committed by GitHub
parent 6426449c8b
commit 78694405a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 189 additions and 4 deletions

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2021 by Contributors
Copyright (c) 2021-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -33,6 +33,8 @@ import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;
import static org.junit.Assert.assertArrayEquals;
/**
* Test suite for DMatrix based on GPU
*/
@ -60,6 +62,16 @@ public class DMatrixTest {
dMatrix.setWeight(weightColumn);
dMatrix.setBaseMargin(baseMarginColumn);
String[] featureNames = new String[]{"f1"};
dMatrix.setFeatureNames(featureNames);
String[] retFeatureNames = dMatrix.getFeatureNames();
assertArrayEquals(featureNames, retFeatureNames);
String[] featureTypes = new String[]{"i"};
dMatrix.setFeatureTypes(featureTypes);
String[] retFeatureTypes = dMatrix.getFeatureTypes();
assertArrayEquals(featureTypes, retFeatureTypes);
float[] anchor = convertFloatTofloat(labelFloats);
float[] label = dMatrix.getLabel();
float[] weight = dMatrix.getWeight();

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -236,6 +236,66 @@ public class DMatrix {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json));
}
private void setXGBDMatrixFeatureInfo(String type, String[] values) throws XGBoostError {
if (type == null || type.isEmpty()) {
throw new XGBoostError("Found empty type");
}
if (values == null || values.length == 0) {
throw new XGBoostError("Found empty values");
}
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetStrFeatureInfo(handle, type, values));
}
private String[] getXGBDMatrixFeatureInfo(String type) throws XGBoostError {
if (type == null || type.isEmpty()) {
throw new XGBoostError("Found empty type");
}
long[] outLen = new long[1];
String[][] outValue = new String[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetStrFeatureInfo(handle, type, outLen, outValue));
if (outLen[0] != outValue[0].length) {
throw new RuntimeException("Failed to get " + type);
}
return outValue[0];
}
/**
* Set feature names
* @param values feature names to be set
* @throws XGBoostError
*/
public void setFeatureNames(String[] values) throws XGBoostError {
setXGBDMatrixFeatureInfo("feature_name", values);
}
/**
* Get feature names
* @return an array of feature names to be returned
* @throws XGBoostError
*/
public String[] getFeatureNames() throws XGBoostError {
return getXGBDMatrixFeatureInfo("feature_name");
}
/**
* Set feature types
* @param values feature types to be set
* @throws XGBoostError
*/
public void setFeatureTypes(String[] values) throws XGBoostError {
setXGBDMatrixFeatureInfo("feature_type", values);
}
/**
* Get feature types
* @return an array of feature types to be returned
* @throws XGBoostError
*/
public String[] getFeatureTypes() throws XGBoostError {
return getXGBDMatrixFeatureInfo("feature_type");
}
/**
* set label of dmatrix
*

View File

@ -82,6 +82,19 @@ class XGBoostJNI {
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
/**
* Set the feature information
* @param handle the DMatrix native address
* @param field "feature_names" or "feature_types"
* @param values an array of string
* @return 0 when success, -1 when failure happens
*/
public final static native int XGDMatrixSetStrFeatureInfo(long handle, String field,
String[] values);
public final static native int XGDMatrixGetStrFeatureInfo(long handle, String field,
long[] outLength, String[][] outValues);
public final static native int XGDMatrixNumRow(long handle, long[] row);
public final static native int XGBoosterCreate(long[] handles, long[] out);
@ -143,4 +156,5 @@ class XGBoostJNI {
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
String featureJson, float missing, int nthread, long[] out);
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
@ -1044,3 +1044,61 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
setHandle(jenv, jout, result);
return ret;
}
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo
(JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jvalues) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
int size = jenv->GetArrayLength(jvalues);
// tmp storage for java strings
std::vector<std::string> values;
for (int i = 0; i < size; i++) {
jstring jstr = (jstring)(jenv->GetObjectArrayElement(jvalues, i));
const char *value = jenv->GetStringUTFChars(jstr, 0);
values.emplace_back(value);
if (value) jenv->ReleaseStringUTFChars(jstr, value);
}
std::vector<char const*> c_values;
c_values.resize(size);
std::transform(values.cbegin(), values.cend(),
c_values.begin(),
[](auto const &str) { return str.c_str(); });
int ret = XGDMatrixSetStrFeatureInfo(handle, field, c_values.data(), size);
JVM_CHECK_CALL(ret);
if (field) jenv->ReleaseStringUTFChars(jfield, field);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetStrFeatureInfo
* Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo
(JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jlongArray joutLenArray,
jobjectArray joutValueArray) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char *field = jenv->GetStringUTFChars(jfield, 0);
bst_ulong out_len = 0;
char const **c_out_features;
int ret = XGDMatrixGetStrFeatureInfo(handle, field, &out_len, &c_out_features);
jlong jlen = (jlong) out_len;
jenv->SetLongArrayRegion(joutLenArray, 0, 1, &jlen);
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"),
jenv->NewStringUTF(""));
for (int i = 0; i < jlen; i++) {
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF(c_out_features[i]));
}
jenv->SetObjectArrayElement(joutValueArray, 0, jinfos);
JVM_CHECK_CALL(ret);
if (field) jenv->ReleaseStringUTFChars(jfield, field);
return ret;
}

View File

@ -359,6 +359,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetStrFeatureInfo
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetStrFeatureInfo
* Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo
(JNIEnv *, jclass, jlong, jstring, jlongArray, jobjectArray);
#ifdef __cplusplus
}
#endif

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -403,4 +403,29 @@ public class DMatrixTest {
//check
TestCase.assertTrue(Arrays.equals(new int[]{0, 5, 10}, dmat0.getGroup()));
}
@Test
public void testSetAndGetFeatureInfo() throws XGBoostError {
//create DMatrix from 10*5 dense matrix
int nrow = 10;
int ncol = 5;
float[] data = new float[nrow * ncol];
//put random nums
Random random = new Random();
for (int i = 0; i < nrow * ncol; i++) {
data[i] = random.nextInt();
}
DMatrix dmat = new DMatrix(data, nrow, ncol, Float.NaN);
String[] featureNames = new String[]{"f1", "f2", "f3", "f4", "f5"};
dmat.setFeatureNames(featureNames);
String[] retFeatureNames = dmat.getFeatureNames();
assertArrayEquals(featureNames, retFeatureNames);
String[] featureTypes = new String[]{"i", "q", "c", "i", "q"};
dmat.setFeatureTypes(featureTypes);
String[] retFeatureTypes = dmat.getFeatureTypes();
assertArrayEquals(featureTypes, retFeatureTypes);
}
}