[jvm-packages] add jni for setting feature name and type (#7966)

This commit is contained in:
Bobby Wang
2022-06-03 11:09:48 +08:00
committed by GitHub
parent 6426449c8b
commit 78694405a6
6 changed files with 189 additions and 4 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
@@ -1044,3 +1044,61 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
setHandle(jenv, jout, result);
return ret;
}
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo
(JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jvalues) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
int size = jenv->GetArrayLength(jvalues);
// tmp storage for java strings
std::vector<std::string> values;
for (int i = 0; i < size; i++) {
jstring jstr = (jstring)(jenv->GetObjectArrayElement(jvalues, i));
const char *value = jenv->GetStringUTFChars(jstr, 0);
values.emplace_back(value);
if (value) jenv->ReleaseStringUTFChars(jstr, value);
}
std::vector<char const*> c_values;
c_values.resize(size);
std::transform(values.cbegin(), values.cend(),
c_values.begin(),
[](auto const &str) { return str.c_str(); });
int ret = XGDMatrixSetStrFeatureInfo(handle, field, c_values.data(), size);
JVM_CHECK_CALL(ret);
if (field) jenv->ReleaseStringUTFChars(jfield, field);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetStrFeatureInfo
* Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo
(JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jlongArray joutLenArray,
jobjectArray joutValueArray) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char *field = jenv->GetStringUTFChars(jfield, 0);
bst_ulong out_len = 0;
char const **c_out_features;
int ret = XGDMatrixGetStrFeatureInfo(handle, field, &out_len, &c_out_features);
jlong jlen = (jlong) out_len;
jenv->SetLongArrayRegion(joutLenArray, 0, 1, &jlen);
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"),
jenv->NewStringUTF(""));
for (int i = 0; i < jlen; i++) {
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF(c_out_features[i]));
}
jenv->SetObjectArrayElement(joutValueArray, 0, jinfos);
JVM_CHECK_CALL(ret);
if (field) jenv->ReleaseStringUTFChars(jfield, field);
return ret;
}