[jvm-packages] add jni for setting feature name and type (#7966)
This commit is contained in:
parent
6426449c8b
commit
78694405a6
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021 by Contributors
|
Copyright (c) 2021-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -33,6 +33,8 @@ import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
|
|||||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Test suite for DMatrix based on GPU
|
* Test suite for DMatrix based on GPU
|
||||||
*/
|
*/
|
||||||
@ -60,6 +62,16 @@ public class DMatrixTest {
|
|||||||
dMatrix.setWeight(weightColumn);
|
dMatrix.setWeight(weightColumn);
|
||||||
dMatrix.setBaseMargin(baseMarginColumn);
|
dMatrix.setBaseMargin(baseMarginColumn);
|
||||||
|
|
||||||
|
String[] featureNames = new String[]{"f1"};
|
||||||
|
dMatrix.setFeatureNames(featureNames);
|
||||||
|
String[] retFeatureNames = dMatrix.getFeatureNames();
|
||||||
|
assertArrayEquals(featureNames, retFeatureNames);
|
||||||
|
|
||||||
|
String[] featureTypes = new String[]{"i"};
|
||||||
|
dMatrix.setFeatureTypes(featureTypes);
|
||||||
|
String[] retFeatureTypes = dMatrix.getFeatureTypes();
|
||||||
|
assertArrayEquals(featureTypes, retFeatureTypes);
|
||||||
|
|
||||||
float[] anchor = convertFloatTofloat(labelFloats);
|
float[] anchor = convertFloatTofloat(labelFloats);
|
||||||
float[] label = dMatrix.getLabel();
|
float[] label = dMatrix.getLabel();
|
||||||
float[] weight = dMatrix.getWeight();
|
float[] weight = dMatrix.getWeight();
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -236,6 +236,66 @@ public class DMatrix {
|
|||||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json));
|
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void setXGBDMatrixFeatureInfo(String type, String[] values) throws XGBoostError {
|
||||||
|
if (type == null || type.isEmpty()) {
|
||||||
|
throw new XGBoostError("Found empty type");
|
||||||
|
}
|
||||||
|
if (values == null || values.length == 0) {
|
||||||
|
throw new XGBoostError("Found empty values");
|
||||||
|
}
|
||||||
|
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetStrFeatureInfo(handle, type, values));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String[] getXGBDMatrixFeatureInfo(String type) throws XGBoostError {
|
||||||
|
if (type == null || type.isEmpty()) {
|
||||||
|
throw new XGBoostError("Found empty type");
|
||||||
|
}
|
||||||
|
long[] outLen = new long[1];
|
||||||
|
String[][] outValue = new String[1][];
|
||||||
|
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetStrFeatureInfo(handle, type, outLen, outValue));
|
||||||
|
|
||||||
|
if (outLen[0] != outValue[0].length) {
|
||||||
|
throw new RuntimeException("Failed to get " + type);
|
||||||
|
}
|
||||||
|
return outValue[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set feature names
|
||||||
|
* @param values feature names to be set
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public void setFeatureNames(String[] values) throws XGBoostError {
|
||||||
|
setXGBDMatrixFeatureInfo("feature_name", values);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get feature names
|
||||||
|
* @return an array of feature names
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public String[] getFeatureNames() throws XGBoostError {
|
||||||
|
return getXGBDMatrixFeatureInfo("feature_name");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set feature types
|
||||||
|
* @param values feature types to be set
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public void setFeatureTypes(String[] values) throws XGBoostError {
|
||||||
|
setXGBDMatrixFeatureInfo("feature_type", values);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get feature types
|
||||||
|
* @return an array of feature types
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public String[] getFeatureTypes() throws XGBoostError {
|
||||||
|
return getXGBDMatrixFeatureInfo("feature_type");
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set label of dmatrix
|
* set label of dmatrix
|
||||||
*
|
*
|
||||||
|
|||||||
@ -82,6 +82,19 @@ class XGBoostJNI {
|
|||||||
|
|
||||||
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
|
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the feature information
|
||||||
|
* @param handle the DMatrix native address
|
||||||
|
* @param field "feature_name" or "feature_type"
|
||||||
|
* @param values an array of string
|
||||||
|
* @return 0 when success, -1 when failure happens
|
||||||
|
*/
|
||||||
|
public final static native int XGDMatrixSetStrFeatureInfo(long handle, String field,
|
||||||
|
String[] values);
|
||||||
|
|
||||||
|
public final static native int XGDMatrixGetStrFeatureInfo(long handle, String field,
|
||||||
|
long[] outLength, String[][] outValues);
|
||||||
|
|
||||||
public final static native int XGDMatrixNumRow(long handle, long[] row);
|
public final static native int XGDMatrixNumRow(long handle, long[] row);
|
||||||
|
|
||||||
public final static native int XGBoosterCreate(long[] handles, long[] out);
|
public final static native int XGBoosterCreate(long[] handles, long[] out);
|
||||||
@ -143,4 +156,5 @@ class XGBoostJNI {
|
|||||||
|
|
||||||
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
||||||
String featureJson, float missing, int nthread, long[] out);
|
String featureJson, float missing, int nthread, long[] out);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
@ -1044,3 +1044,61 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
|||||||
setHandle(jenv, jout, result);
|
setHandle(jenv, jout, result);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo
|
||||||
|
(JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jvalues) {
|
||||||
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
|
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||||
|
int size = jenv->GetArrayLength(jvalues);
|
||||||
|
|
||||||
|
// tmp storage for java strings
|
||||||
|
std::vector<std::string> values;
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
jstring jstr = (jstring)(jenv->GetObjectArrayElement(jvalues, i));
|
||||||
|
const char *value = jenv->GetStringUTFChars(jstr, 0);
|
||||||
|
values.emplace_back(value);
|
||||||
|
if (value) jenv->ReleaseStringUTFChars(jstr, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<char const*> c_values;
|
||||||
|
c_values.resize(size);
|
||||||
|
std::transform(values.cbegin(), values.cend(),
|
||||||
|
c_values.begin(),
|
||||||
|
[](auto const &str) { return str.c_str(); });
|
||||||
|
|
||||||
|
int ret = XGDMatrixSetStrFeatureInfo(handle, field, c_values.data(), size);
|
||||||
|
JVM_CHECK_CALL(ret);
|
||||||
|
|
||||||
|
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGDMatrixGetStrFeatureInfo
|
||||||
|
* Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo
|
||||||
|
(JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jlongArray joutLenArray,
|
||||||
|
jobjectArray joutValueArray) {
|
||||||
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
|
const char *field = jenv->GetStringUTFChars(jfield, 0);
|
||||||
|
|
||||||
|
bst_ulong out_len = 0;
|
||||||
|
char const **c_out_features;
|
||||||
|
int ret = XGDMatrixGetStrFeatureInfo(handle, field, &out_len, &c_out_features);
|
||||||
|
|
||||||
|
jlong jlen = (jlong) out_len;
|
||||||
|
jenv->SetLongArrayRegion(joutLenArray, 0, 1, &jlen);
|
||||||
|
|
||||||
|
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"),
|
||||||
|
jenv->NewStringUTF(""));
|
||||||
|
for (int i = 0; i < jlen; i++) {
|
||||||
|
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF(c_out_features[i]));
|
||||||
|
}
|
||||||
|
jenv->SetObjectArrayElement(joutValueArray, 0, jinfos);
|
||||||
|
|
||||||
|
JVM_CHECK_CALL(ret);
|
||||||
|
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|||||||
@ -359,6 +359,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
||||||
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
|
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGDMatrixSetStrFeatureInfo
|
||||||
|
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo
|
||||||
|
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGDMatrixGetStrFeatureInfo
|
||||||
|
* Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo
|
||||||
|
(JNIEnv *, jclass, jlong, jstring, jlongArray, jobjectArray);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -403,4 +403,29 @@ public class DMatrixTest {
|
|||||||
//check
|
//check
|
||||||
TestCase.assertTrue(Arrays.equals(new int[]{0, 5, 10}, dmat0.getGroup()));
|
TestCase.assertTrue(Arrays.equals(new int[]{0, 5, 10}, dmat0.getGroup()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSetAndGetFeatureInfo() throws XGBoostError {
|
||||||
|
//create DMatrix from 10*5 dense matrix
|
||||||
|
int nrow = 10;
|
||||||
|
int ncol = 5;
|
||||||
|
float[] data = new float[nrow * ncol];
|
||||||
|
//put random nums
|
||||||
|
Random random = new Random();
|
||||||
|
for (int i = 0; i < nrow * ncol; i++) {
|
||||||
|
data[i] = random.nextInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
DMatrix dmat = new DMatrix(data, nrow, ncol, Float.NaN);
|
||||||
|
|
||||||
|
String[] featureNames = new String[]{"f1", "f2", "f3", "f4", "f5"};
|
||||||
|
dmat.setFeatureNames(featureNames);
|
||||||
|
String[] retFeatureNames = dmat.getFeatureNames();
|
||||||
|
assertArrayEquals(featureNames, retFeatureNames);
|
||||||
|
|
||||||
|
String[] featureTypes = new String[]{"i", "q", "c", "i", "q"};
|
||||||
|
dmat.setFeatureTypes(featureTypes);
|
||||||
|
String[] retFeatureTypes = dmat.getFeatureTypes();
|
||||||
|
assertArrayEquals(featureTypes, retFeatureTypes);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user