Set feature_names and feature_types in jvm-packages (#9364)

* 1. Add parameters to set feature names and feature types
2. Save feature names and feature types to native json model

* Change serialization and deserialization format to ubj.
This commit is contained in:
jinmfeng001
2023-07-12 15:18:46 +08:00
committed by GitHub
parent 3632242e0b
commit a1367ea1f8
12 changed files with 295 additions and 8 deletions

View File

@@ -1148,3 +1148,68 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFea
if (field) jenv->ReleaseStringUTFChars(jfield, field);
return ret;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGBoosterSetStrFeatureInfo
 * Signature: (JLjava/lang/String;[Ljava/lang/String;)I
 *
 * Copies every element of `jfeatures` into a native string and forwards the
 * resulting array to XGBoosterSetStrFeatureInfo for the booster identified by
 * `jhandle`, under the info field named by `jfield`.
 *
 * Returns the status code of the native call, or -1 when a JNI string could
 * not be obtained (a Java exception is pending in that case).
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo(
    JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
    jobjectArray jfeatures) {
  BoosterHandle handle = (BoosterHandle)jhandle;
  const jsize feature_num = jenv->GetArrayLength(jfeatures);

  // Own the feature strings natively so the JNI buffers and local references
  // can be released inside the loop instead of piling up until return.
  std::vector<std::string> features;
  features.reserve(static_cast<size_t>(feature_num));
  for (jsize i = 0; i < feature_num; ++i) {
    jstring jfeature = (jstring)jenv->GetObjectArrayElement(jfeatures, i);
    const char *s = jenv->GetStringUTFChars(jfeature, nullptr);
    if (s == nullptr) {
      // GetStringUTFChars failed; bail out before touching the buffer.
      jenv->DeleteLocalRef(jfeature);
      return -1;
    }
    // The buffer is NUL-terminated modified UTF-8, so let std::string measure
    // it in bytes. GetStringLength() counts UTF-16 code units and would
    // truncate any name containing non-ASCII characters.
    features.emplace_back(s);
    jenv->ReleaseStringUTFChars(jfeature, s);
    jenv->DeleteLocalRef(jfeature);
  }

  // Build the pointer view only after `features` has stopped growing, so the
  // c_str() pointers stay valid.
  std::vector<char const *> features_char;
  features_char.reserve(features.size());
  for (std::string const &feature : features) {
    features_char.push_back(feature.c_str());
  }

  const char *field = jenv->GetStringUTFChars(jfield, nullptr);
  int ret = XGBoosterSetStrFeatureInfo(
      handle, field, dmlc::BeginPtr(features_char),
      static_cast<bst_ulong>(features_char.size()));
  // Release before the status check in case JVM_CHECK_CALL returns early.
  if (field) jenv->ReleaseStringUTFChars(jfield, field);
  JVM_CHECK_CALL(ret);
  return ret;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGBoosterGetStrFeatureInfo
 * Signature: (JLjava/lang/String;[Ljava/lang/String;)I
 *
 * Queries XGBoosterGetStrFeatureInfo for the info field named by `jfield` on
 * the booster identified by `jhandle`, and stores each returned string into
 * the caller-provided array `jout`, starting at index 0.
 *
 * Returns the status code of the native call. If storing into `jout` raises a
 * Java exception (for example because `jout` is shorter than the number of
 * returned strings), copying stops and the exception is left pending.
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
    JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
    jobjectArray jout) {
  BoosterHandle handle = (BoosterHandle)jhandle;
  const char *field = jenv->GetStringUTFChars(jfield, nullptr);
  // Seeded with the output array length; the native call overwrites it with
  // the number of strings it actually returns.
  bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jout);
  const char **features = nullptr;
  int ret = XGBoosterGetStrFeatureInfo(handle, field, &feature_num, &features);
  // Release before the status check in case JVM_CHECK_CALL returns early.
  if (field) jenv->ReleaseStringUTFChars(jfield, field);
  JVM_CHECK_CALL(ret);

  for (bst_ulong i = 0; i < feature_num; ++i) {
    jstring jfeature = jenv->NewStringUTF(features[i]);
    if (jfeature == nullptr) break;  // string creation failed
    jenv->SetObjectArrayElement(jout, static_cast<jsize>(i), jfeature);
    // The array now holds its own reference; drop ours so that wide feature
    // sets do not exhaust the local reference table.
    jenv->DeleteLocalRef(jfeature);
    // Most JNI functions must not be called while an exception is pending.
    if (jenv->ExceptionCheck()) break;
  }
  return ret;
}