[DIST] Enable multiple thread and tracker, make rabit and xgboost more thread-safe by using thread local variables.

This commit is contained in:
tqchen
2016-03-03 11:36:34 -08:00
parent 12dc92f7e0
commit e80d3db64b
17 changed files with 323 additions and 153 deletions

View File

@@ -13,6 +13,8 @@
*/
#include <xgboost/c_api.h>
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include "./xgboost4j.h"
#include <cstring>
#include <vector>
@@ -276,27 +278,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
DMatrixHandle* handles = NULL;
bst_ulong len = 0;
jlong* cjhandles = 0;
BoosterHandle result;
if (jhandles) {
len = (bst_ulong)jenv->GetArrayLength(jhandles);
handles = new DMatrixHandle[len];
//put handle from jhandles to chandles
cjhandles = jenv->GetLongArrayElements(jhandles, 0);
for(bst_ulong i=0; i<len; i++) {
handles[i] = (DMatrixHandle) cjhandles[i];
std::vector<DMatrixHandle> handles;
if (jhandles != nullptr) {
size_t len = jenv->GetArrayLength(jhandles);
jlong *cjhandles = jenv->GetLongArrayElements(jhandles, 0);
for (size_t i = 0; i < len; ++i) {
handles.push_back((DMatrixHandle) cjhandles[i]);
}
}
int ret = XGBoosterCreate(handles, len, &result);
//release
if (jhandles) {
delete[] handles;
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
}
BoosterHandle result;
int ret = XGBoosterCreate(dmlc::BeginPtr(handles), handles.size(), &result);
setHandle(jenv, jout, result);
return ret;
}
@@ -369,43 +361,34 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle* dmats = 0;
char **evnames = 0;
char *result = 0;
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats);
if(len > 0) {
dmats = new DMatrixHandle[len];
evnames = new char*[len];
}
//put handle from jhandles to chandles
std::vector<DMatrixHandle> dmats;
std::vector<std::string> evnames;
std::vector<const char*> evchars;
size_t len = static_cast<size_t>(jenv->GetArrayLength(jdmats));
// put handle from jhandles to chandles
jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0);
for(bst_ulong i=0; i<len; i++) {
dmats[i] = (DMatrixHandle) cjdmats[i];
}
//transfer jObjectArray to char**, user strcpy and release JNI char* inplace
for(bst_ulong i=0; i<len; i++) {
for (size_t i = 0; i < len; ++i) {
dmats.push_back((DMatrixHandle) cjdmats[i]);
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
const char* cevname = jenv->GetStringUTFChars(jevname, 0);
evnames[i] = new char[jenv->GetStringLength(jevname)];
strcpy(evnames[i], cevname);
jenv->ReleaseStringUTFChars(jevname, cevname);
const char *s =jenv->GetStringUTFChars(jevname, 0);
evnames.push_back(std::string(s, jenv->GetStringLength(jevname)));
if (s != nullptr) jenv->ReleaseStringUTFChars(jevname, s);
}
int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result);
if(len > 0) {
delete[] dmats;
//release string chars
for(bst_ulong i=0; i<len; i++) {
delete[] evnames[i];
}
delete[] evnames;
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
for (size_t i = 0; i < len; ++i) {
evchars.push_back(evnames[i].c_str());
}
const char* result;
int ret = XGBoosterEvalOneIter(handle, jiter,
dmlc::BeginPtr(dmats),
dmlc::BeginPtr(evchars),
len, &result);
jstring jinfo = nullptr;
if (result != nullptr) {
jinfo = jenv->NewStringUTF(result);
}
jstring jinfo = 0;
if (result) jinfo = jenv->NewStringUTF((const char *) result);
jenv->SetObjectArrayElement(jout, 0, jinfo);
return ret;
}
@@ -456,37 +439,40 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
int ret = XGBoosterSaveModel(handle, fname);
if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadModelFromBuffer
* Signature: (JJJ)V
* Signature: (J[B)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
(JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) {
BoosterHandle handle = (BoosterHandle) jhandle;
void *buf = (void*) jbuf;
return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen);
jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0);
int ret = XGBoosterLoadModelFromBuffer(
handle, buffer, jenv->GetArrayLength(jbytes));
jenv->ReleaseByteArrayElements(jbytes, buffer, 0);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterGetModelRaw
* Signature: (J)Ljava/lang/String;
* Signature: (J[[B)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
bst_ulong len = 0;
char *result;
const char* result;
int ret = XGBoosterGetModelRaw(handle, &len, &result);
int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result);
if (result) {
jstring jinfo = jenv->NewStringUTF((const char *) result);
jenv->SetObjectArrayElement(jout, 0, jinfo);
jbyteArray jarray = jenv->NewByteArray(len);
jenv->SetByteArrayRegion(jarray, 0, len, (jbyte*)result);
jenv->SetObjectArrayElement(jout, 0, jarray);
}
return ret;
}
@@ -553,15 +539,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
for (bst_ulong i = 0; i < len; ++i) {
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
std::string s(jenv->GetStringUTFChars(arg, 0),
jenv->GetStringLength(arg));
if (s.length() != 0) args.push_back(s);
const char *s = jenv->GetStringUTFChars(arg, 0);
args.push_back(std::string(s, jenv->GetStringLength(arg)));
if (s != nullptr) jenv->ReleaseStringUTFChars(arg, s);
if (args.back().length() == 0) args.pop_back();
}
for (size_t i = 0; i < args.size(); ++i) {
argv.push_back(&args[i][0]);
}
RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]);
RabitInit(args.size(), dmlc::BeginPtr(argv));
return 0;
}