Allow JVM-Package to access inplace predict method (#9167)

---------

Co-authored-by: Stephan T. Lavavej <stl@nuwen.net>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Co-authored-by: Joe <25804777+ByteSizedJoe@users.noreply.github.com>
This commit is contained in:
Jon Yoquinto
2023-09-11 17:29:51 -06:00
committed by GitHub
parent 9027686cac
commit d05ea589fb
5 changed files with 384 additions and 18 deletions

View File

@@ -684,6 +684,85 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterPredictFromDense
* Signature: (J[FJJFIII[F[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense(
JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jlong num_rows, jlong num_features,
jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type,
jfloatArray jmargin, jobjectArray jout) {
API_BEGIN();
BoosterHandle handle = reinterpret_cast<BoosterHandle>(jhandle);
/**
* Create array interface.
*/
namespace linalg = xgboost::linalg;
jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr);
xgboost::Context ctx;
auto t_data = linalg::MakeTensorView(
ctx.Device(),
xgboost::common::Span{data, static_cast<std::size_t>(num_rows * num_features)}, num_rows,
num_features);
auto s_array = linalg::ArrayInterfaceStr(t_data);
/**
* Create configuration object.
*/
xgboost::Json config{xgboost::Object{}};
config["cache_id"] = xgboost::Integer{};
config["type"] = xgboost::Integer{static_cast<std::int32_t>(predict_type)};
config["iteration_begin"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_begin)};
config["iteration_end"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_end)};
config["missing"] = xgboost::Number{static_cast<float>(missing)};
config["strict_shape"] = xgboost::Boolean{true};
std::string s_config;
xgboost::Json::Dump(config, &s_config);
/**
* Handle base margin
*/
BoosterHandle proxy{nullptr};
float *margin{nullptr};
if (jmargin) {
margin = jenv->GetFloatArrayElements(jmargin, nullptr);
JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
JVM_CHECK_CALL(
XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
}
bst_ulong const *out_shape;
bst_ulong out_dim;
float const *result;
auto ret = XGBoosterPredictFromDense(handle, s_array.c_str(), s_config.c_str(), proxy, &out_shape,
&out_dim, &result);
jenv->ReleaseFloatArrayElements(jdata, data, 0);
if (proxy) {
XGDMatrixFree(proxy);
jenv->ReleaseFloatArrayElements(jmargin, margin, 0);
}
if (ret != 0) {
return ret;
}
std::size_t n{1};
for (std::size_t i = 0; i < out_dim; ++i) {
n *= out_shape[i];
}
jfloatArray jarray = jenv->NewFloatArray(n);
jenv->SetFloatArrayRegion(jarray, 0, n, result);
jenv->SetObjectArrayElement(jout, 0, jarray);
API_END();
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel

View File

@@ -207,6 +207,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterPredictFromDense
* Signature: (J[FJJFIII[F[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense
(JNIEnv *, jclass, jlong, jfloatArray, jlong, jlong, jfloat, jint, jint, jint, jfloatArray, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel
@@ -359,14 +367,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllred
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDeviceQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;FII[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
(JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback