[DIST] Enable multiple thread and tracker, make rabit and xgboost more thread-safe by using thread local variables.
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
*/
|
||||
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include "./xgboost4j.h"
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
@@ -276,27 +278,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
|
||||
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
|
||||
DMatrixHandle* handles = NULL;
|
||||
bst_ulong len = 0;
|
||||
jlong* cjhandles = 0;
|
||||
BoosterHandle result;
|
||||
|
||||
if (jhandles) {
|
||||
len = (bst_ulong)jenv->GetArrayLength(jhandles);
|
||||
handles = new DMatrixHandle[len];
|
||||
//put handle from jhandles to chandles
|
||||
cjhandles = jenv->GetLongArrayElements(jhandles, 0);
|
||||
for(bst_ulong i=0; i<len; i++) {
|
||||
handles[i] = (DMatrixHandle) cjhandles[i];
|
||||
std::vector<DMatrixHandle> handles;
|
||||
if (jhandles != nullptr) {
|
||||
size_t len = jenv->GetArrayLength(jhandles);
|
||||
jlong *cjhandles = jenv->GetLongArrayElements(jhandles, 0);
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
handles.push_back((DMatrixHandle) cjhandles[i]);
|
||||
}
|
||||
}
|
||||
|
||||
int ret = XGBoosterCreate(handles, len, &result);
|
||||
//release
|
||||
if (jhandles) {
|
||||
delete[] handles;
|
||||
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
|
||||
}
|
||||
BoosterHandle result;
|
||||
int ret = XGBoosterCreate(dmlc::BeginPtr(handles), handles.size(), &result);
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
@@ -369,43 +361,34 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
DMatrixHandle* dmats = 0;
|
||||
char **evnames = 0;
|
||||
char *result = 0;
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats);
|
||||
if(len > 0) {
|
||||
dmats = new DMatrixHandle[len];
|
||||
evnames = new char*[len];
|
||||
}
|
||||
//put handle from jhandles to chandles
|
||||
std::vector<DMatrixHandle> dmats;
|
||||
std::vector<std::string> evnames;
|
||||
std::vector<const char*> evchars;
|
||||
|
||||
size_t len = static_cast<size_t>(jenv->GetArrayLength(jdmats));
|
||||
// put handle from jhandles to chandles
|
||||
jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0);
|
||||
for(bst_ulong i=0; i<len; i++) {
|
||||
dmats[i] = (DMatrixHandle) cjdmats[i];
|
||||
}
|
||||
//transfer jObjectArray to char**, user strcpy and release JNI char* inplace
|
||||
for(bst_ulong i=0; i<len; i++) {
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
dmats.push_back((DMatrixHandle) cjdmats[i]);
|
||||
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
|
||||
const char* cevname = jenv->GetStringUTFChars(jevname, 0);
|
||||
evnames[i] = new char[jenv->GetStringLength(jevname)];
|
||||
strcpy(evnames[i], cevname);
|
||||
jenv->ReleaseStringUTFChars(jevname, cevname);
|
||||
const char *s =jenv->GetStringUTFChars(jevname, 0);
|
||||
evnames.push_back(std::string(s, jenv->GetStringLength(jevname)));
|
||||
if (s != nullptr) jenv->ReleaseStringUTFChars(jevname, s);
|
||||
}
|
||||
|
||||
int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result);
|
||||
if(len > 0) {
|
||||
delete[] dmats;
|
||||
//release string chars
|
||||
for(bst_ulong i=0; i<len; i++) {
|
||||
delete[] evnames[i];
|
||||
}
|
||||
delete[] evnames;
|
||||
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
|
||||
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
evchars.push_back(evnames[i].c_str());
|
||||
}
|
||||
const char* result;
|
||||
int ret = XGBoosterEvalOneIter(handle, jiter,
|
||||
dmlc::BeginPtr(dmats),
|
||||
dmlc::BeginPtr(evchars),
|
||||
len, &result);
|
||||
jstring jinfo = nullptr;
|
||||
if (result != nullptr) {
|
||||
jinfo = jenv->NewStringUTF(result);
|
||||
}
|
||||
|
||||
jstring jinfo = 0;
|
||||
if (result) jinfo = jenv->NewStringUTF((const char *) result);
|
||||
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -456,37 +439,40 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
|
||||
|
||||
int ret = XGBoosterSaveModel(handle, fname);
|
||||
if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: XGBoosterLoadModelFromBuffer
|
||||
* Signature: (JJJ)V
|
||||
* Signature: (J[B)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
void *buf = (void*) jbuf;
|
||||
return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen);
|
||||
jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0);
|
||||
int ret = XGBoosterLoadModelFromBuffer(
|
||||
handle, buffer, jenv->GetArrayLength(jbytes));
|
||||
jenv->ReleaseByteArrayElements(jbytes, buffer, 0);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: XGBoosterGetModelRaw
|
||||
* Signature: (J)Ljava/lang/String;
|
||||
* Signature: (J[[B)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
bst_ulong len = 0;
|
||||
char *result;
|
||||
const char* result;
|
||||
int ret = XGBoosterGetModelRaw(handle, &len, &result);
|
||||
|
||||
int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result);
|
||||
if (result) {
|
||||
jstring jinfo = jenv->NewStringUTF((const char *) result);
|
||||
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
||||
jbyteArray jarray = jenv->NewByteArray(len);
|
||||
jenv->SetByteArrayRegion(jarray, 0, len, (jbyte*)result);
|
||||
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@@ -553,15 +539,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
|
||||
for (bst_ulong i = 0; i < len; ++i) {
|
||||
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
|
||||
std::string s(jenv->GetStringUTFChars(arg, 0),
|
||||
jenv->GetStringLength(arg));
|
||||
if (s.length() != 0) args.push_back(s);
|
||||
const char *s = jenv->GetStringUTFChars(arg, 0);
|
||||
args.push_back(std::string(s, jenv->GetStringLength(arg)));
|
||||
if (s != nullptr) jenv->ReleaseStringUTFChars(arg, s);
|
||||
if (args.back().length() == 0) args.pop_back();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
argv.push_back(&args[i][0]);
|
||||
}
|
||||
RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]);
|
||||
|
||||
RabitInit(args.size(), dmlc::BeginPtr(argv));
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user