[jvm-packages] add jni for setting feature name and type (#7966)
This commit is contained in:
parent
6426449c8b
commit
78694405a6
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -33,6 +33,8 @@ import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
/**
|
||||
* Test suite for DMatrix based on GPU
|
||||
*/
|
||||
@ -60,6 +62,16 @@ public class DMatrixTest {
|
||||
dMatrix.setWeight(weightColumn);
|
||||
dMatrix.setBaseMargin(baseMarginColumn);
|
||||
|
||||
String[] featureNames = new String[]{"f1"};
|
||||
dMatrix.setFeatureNames(featureNames);
|
||||
String[] retFeatureNames = dMatrix.getFeatureNames();
|
||||
assertArrayEquals(featureNames, retFeatureNames);
|
||||
|
||||
String[] featureTypes = new String[]{"i"};
|
||||
dMatrix.setFeatureTypes(featureTypes);
|
||||
String[] retFeatureTypes = dMatrix.getFeatureTypes();
|
||||
assertArrayEquals(featureTypes, retFeatureTypes);
|
||||
|
||||
float[] anchor = convertFloatTofloat(labelFloats);
|
||||
float[] label = dMatrix.getLabel();
|
||||
float[] weight = dMatrix.getWeight();
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -236,6 +236,66 @@ public class DMatrix {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json));
|
||||
}
|
||||
|
||||
private void setXGBDMatrixFeatureInfo(String type, String[] values) throws XGBoostError {
|
||||
if (type == null || type.isEmpty()) {
|
||||
throw new XGBoostError("Found empty type");
|
||||
}
|
||||
if (values == null || values.length == 0) {
|
||||
throw new XGBoostError("Found empty values");
|
||||
}
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetStrFeatureInfo(handle, type, values));
|
||||
}
|
||||
|
||||
private String[] getXGBDMatrixFeatureInfo(String type) throws XGBoostError {
|
||||
if (type == null || type.isEmpty()) {
|
||||
throw new XGBoostError("Found empty type");
|
||||
}
|
||||
long[] outLen = new long[1];
|
||||
String[][] outValue = new String[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetStrFeatureInfo(handle, type, outLen, outValue));
|
||||
|
||||
if (outLen[0] != outValue[0].length) {
|
||||
throw new RuntimeException("Failed to get " + type);
|
||||
}
|
||||
return outValue[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Set feature names
|
||||
* @param values feature names to be set
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureNames(String[] values) throws XGBoostError {
|
||||
setXGBDMatrixFeatureInfo("feature_name", values);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature names
|
||||
* @return an array of feature names to be returned
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public String[] getFeatureNames() throws XGBoostError {
|
||||
return getXGBDMatrixFeatureInfo("feature_name");
|
||||
}
|
||||
|
||||
/**
|
||||
* Set feature types
|
||||
* @param values feature types to be set
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureTypes(String[] values) throws XGBoostError {
|
||||
setXGBDMatrixFeatureInfo("feature_type", values);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature types
|
||||
* @return an array of feature types to be returned
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public String[] getFeatureTypes() throws XGBoostError {
|
||||
return getXGBDMatrixFeatureInfo("feature_type");
|
||||
}
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
*
|
||||
|
||||
@ -82,6 +82,19 @@ class XGBoostJNI {
|
||||
|
||||
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
|
||||
|
||||
/**
|
||||
* Set the feature information
|
||||
* @param handle the DMatrix native address
|
||||
* @param field "feature_names" or "feature_types"
|
||||
* @param values an array of string
|
||||
* @return 0 when success, -1 when failure happens
|
||||
*/
|
||||
public final static native int XGDMatrixSetStrFeatureInfo(long handle, String field,
|
||||
String[] values);
|
||||
|
||||
public final static native int XGDMatrixGetStrFeatureInfo(long handle, String field,
|
||||
long[] outLength, String[][] outValues);
|
||||
|
||||
public final static native int XGDMatrixNumRow(long handle, long[] row);
|
||||
|
||||
public final static native int XGBoosterCreate(long[] handles, long[] out);
|
||||
@ -143,4 +156,5 @@ class XGBoostJNI {
|
||||
|
||||
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
String featureJson, float missing, int nthread, long[] out);
|
||||
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
@ -1044,3 +1044,61 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo
|
||||
(JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jvalues) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||
int size = jenv->GetArrayLength(jvalues);
|
||||
|
||||
// tmp storage for java strings
|
||||
std::vector<std::string> values;
|
||||
for (int i = 0; i < size; i++) {
|
||||
jstring jstr = (jstring)(jenv->GetObjectArrayElement(jvalues, i));
|
||||
const char *value = jenv->GetStringUTFChars(jstr, 0);
|
||||
values.emplace_back(value);
|
||||
if (value) jenv->ReleaseStringUTFChars(jstr, value);
|
||||
}
|
||||
|
||||
std::vector<char const*> c_values;
|
||||
c_values.resize(size);
|
||||
std::transform(values.cbegin(), values.cend(),
|
||||
c_values.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
|
||||
int ret = XGDMatrixSetStrFeatureInfo(handle, field, c_values.data(), size);
|
||||
JVM_CHECK_CALL(ret);
|
||||
|
||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo
|
||||
(JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jlongArray joutLenArray,
|
||||
jobjectArray joutValueArray) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
const char *field = jenv->GetStringUTFChars(jfield, 0);
|
||||
|
||||
bst_ulong out_len = 0;
|
||||
char const **c_out_features;
|
||||
int ret = XGDMatrixGetStrFeatureInfo(handle, field, &out_len, &c_out_features);
|
||||
|
||||
jlong jlen = (jlong) out_len;
|
||||
jenv->SetLongArrayRegion(joutLenArray, 0, 1, &jlen);
|
||||
|
||||
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"),
|
||||
jenv->NewStringUTF(""));
|
||||
for (int i = 0; i < jlen; i++) {
|
||||
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF(c_out_features[i]));
|
||||
}
|
||||
jenv->SetObjectArrayElement(joutValueArray, 0, jinfos);
|
||||
|
||||
JVM_CHECK_CALL(ret);
|
||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -359,6 +359,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
||||
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jlongArray, jobjectArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -403,4 +403,29 @@ public class DMatrixTest {
|
||||
//check
|
||||
TestCase.assertTrue(Arrays.equals(new int[]{0, 5, 10}, dmat0.getGroup()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSetAndGetFeatureInfo() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
int nrow = 10;
|
||||
int ncol = 5;
|
||||
float[] data = new float[nrow * ncol];
|
||||
//put random nums
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < nrow * ncol; i++) {
|
||||
data[i] = random.nextInt();
|
||||
}
|
||||
|
||||
DMatrix dmat = new DMatrix(data, nrow, ncol, Float.NaN);
|
||||
|
||||
String[] featureNames = new String[]{"f1", "f2", "f3", "f4", "f5"};
|
||||
dmat.setFeatureNames(featureNames);
|
||||
String[] retFeatureNames = dmat.getFeatureNames();
|
||||
assertArrayEquals(featureNames, retFeatureNames);
|
||||
|
||||
String[] featureTypes = new String[]{"i", "q", "c", "i", "q"};
|
||||
dmat.setFeatureTypes(featureTypes);
|
||||
String[] retFeatureTypes = dmat.getFeatureTypes();
|
||||
assertArrayEquals(featureTypes, retFeatureTypes);
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user