Allow JVM-Package to access inplace predict method (#9167)
--------- Co-authored-by: Stephan T. Lavavej <stl@nuwen.net> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com> Co-authored-by: Joe <25804777+ByteSizedJoe@users.noreply.github.com>
This commit is contained in:
@@ -684,6 +684,85 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterPredictFromDense
|
||||
* Signature: (J[FJJFIII[F[[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense(
|
||||
JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jlong num_rows, jlong num_features,
|
||||
jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type,
|
||||
jfloatArray jmargin, jobjectArray jout) {
|
||||
API_BEGIN();
|
||||
BoosterHandle handle = reinterpret_cast<BoosterHandle>(jhandle);
|
||||
|
||||
/**
|
||||
* Create array interface.
|
||||
*/
|
||||
namespace linalg = xgboost::linalg;
|
||||
jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr);
|
||||
xgboost::Context ctx;
|
||||
auto t_data = linalg::MakeTensorView(
|
||||
ctx.Device(),
|
||||
xgboost::common::Span{data, static_cast<std::size_t>(num_rows * num_features)}, num_rows,
|
||||
num_features);
|
||||
auto s_array = linalg::ArrayInterfaceStr(t_data);
|
||||
|
||||
/**
|
||||
* Create configuration object.
|
||||
*/
|
||||
xgboost::Json config{xgboost::Object{}};
|
||||
config["cache_id"] = xgboost::Integer{};
|
||||
config["type"] = xgboost::Integer{static_cast<std::int32_t>(predict_type)};
|
||||
config["iteration_begin"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_begin)};
|
||||
config["iteration_end"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_end)};
|
||||
config["missing"] = xgboost::Number{static_cast<float>(missing)};
|
||||
config["strict_shape"] = xgboost::Boolean{true};
|
||||
std::string s_config;
|
||||
xgboost::Json::Dump(config, &s_config);
|
||||
|
||||
/**
|
||||
* Handle base margin
|
||||
*/
|
||||
BoosterHandle proxy{nullptr};
|
||||
|
||||
float *margin{nullptr};
|
||||
if (jmargin) {
|
||||
margin = jenv->GetFloatArrayElements(jmargin, nullptr);
|
||||
JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
|
||||
JVM_CHECK_CALL(
|
||||
XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
|
||||
}
|
||||
|
||||
bst_ulong const *out_shape;
|
||||
bst_ulong out_dim;
|
||||
float const *result;
|
||||
auto ret = XGBoosterPredictFromDense(handle, s_array.c_str(), s_config.c_str(), proxy, &out_shape,
|
||||
&out_dim, &result);
|
||||
|
||||
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
||||
if (proxy) {
|
||||
XGDMatrixFree(proxy);
|
||||
jenv->ReleaseFloatArrayElements(jmargin, margin, 0);
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::size_t n{1};
|
||||
for (std::size_t i = 0; i < out_dim; ++i) {
|
||||
n *= out_shape[i];
|
||||
}
|
||||
|
||||
jfloatArray jarray = jenv->NewFloatArray(n);
|
||||
|
||||
jenv->SetFloatArrayRegion(jarray, 0, n, result);
|
||||
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||
|
||||
API_END();
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadModel
|
||||
|
||||
@@ -207,6 +207,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
||||
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterPredictFromDense
|
||||
* Signature: (J[FJJFIII[F[[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense
|
||||
(JNIEnv *, jclass, jlong, jfloatArray, jlong, jlong, jfloat, jint, jint, jint, jfloatArray, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadModel
|
||||
@@ -359,14 +367,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllred
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
|
||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDeviceQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;FII[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
|
||||
(JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGQuantileDMatrixCreateFromCallback
|
||||
|
||||
Reference in New Issue
Block a user