[JVM-PKG] Update JNI to include rabit codes

This commit is contained in:
tqchen 2016-03-02 22:12:17 -08:00
parent ced6d45e01
commit 5c9e50148a
8 changed files with 609 additions and 280 deletions

View File

@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
import java.io.IOException; import java.io.IOException;
/** /**
* DMatrix for xgboost, similar to the python wrapper xgboost.py * DMatrix for xgboost.
* *
* @author hzx * @author hzx
*/ */

View File

@ -48,4 +48,5 @@ class JNIErrorHandle {
throw new XGBoostError(XgboostJNI.XGBGetLastError()); throw new XGBoostError(XgboostJNI.XGBGetLastError());
} }
} }
} }

View File

@ -441,6 +441,27 @@ class JavaBoosterImpl implements Booster {
return featureScore; return featureScore;
} }
/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
}
/**
* Save the booster model into thread-local rabit checkpoint.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
}
/** /**
* transfer DMatrix array to handle array (used for native functions) * transfer DMatrix array to handle array (used for native functions)
* *

View File

@ -0,0 +1,93 @@
package ml.dmlc.xgboost4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.Map;
import java.io.IOException;
/**
* Rabit global class for synchronization.
*/
public class Rabit {
private static final Log logger = LogFactory.getLog(DMatrix.class);
//load native library
static {
try {
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
private static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XgboostJNI.XGBGetLastError());
}
}
/**
* Initialize the rabit library on current working thread.
* @param envs The additional environment variables to pass to rabit.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
String[] args = new String[envs.size()];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey() + '=' + e.getValue();
}
checkCall(XgboostJNI.RabitInit(args));
}
/**
* Shutdown the rabit engine in current working thread, equals to finalize.
* @throws XGBoostError
*/
public static void shutdown() throws XGBoostError {
checkCall(XgboostJNI.RabitFinalize());
}
/**
* Print the message on rabit tracker.
* @param msg
* @throws XGBoostError
*/
public static void trackerPrint(String msg) throws XGBoostError {
checkCall(XgboostJNI.RabitTrackerPrint(msg));
}
/**
* Get version number of current stored model in the thread.
* which means how many calls to CheckPoint we made so far.
* @return version Number.
* @throws XGBoostError
*/
public static int versionNumber() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitVersionNumber(out));
return out[0];
}
/**
* get rank of current thread.
* @return the rank.
* @throws XGBoostError
*/
public static int getRank() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitGetRank(out));
return out[0];
}
/**
* get world size of current job.
* @return the worldsize
* @throws XGBoostError
*/
public static int getWorldSize() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitGetWorldSize(out));
return out[0];
}
}

View File

@ -72,15 +72,21 @@ public class XGBoost {
} }
//initialize booster //initialize booster
Booster booster = new JavaBoosterImpl(params, allMats); JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
int version = booster.loadRabitCheckpoint();
//begin to train //begin to train
for (int iter = 0; iter < round; iter++) { for (int iter = version / 2; iter < round; iter++) {
if (version % 2 == 0) {
if (obj != null) { if (obj != null) {
booster.update(dtrain, obj); booster.update(dtrain, obj);
} else { } else {
booster.update(dtrain, iter); booster.update(dtrain, iter);
} }
booster.saveRabitCheckpoint();
version += 1;
}
//evaluation //evaluation
if (evalMats != null && evalMats.length > 0) { if (evalMats != null && evalMats.length > 0) {
@ -92,6 +98,8 @@ public class XGBoost {
} }
logger.info(evalInfo); logger.info(evalInfo);
} }
booster.saveRabitCheckpoint();
version += 1;
} }
return booster; return booster;
} }

View File

@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j; package ml.dmlc.xgboost4j;
/** /**
* xgboost jni wrapper functions for xgboost_wrapper.h * xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
* *
* @author hzx * @author hzx
@ -80,4 +80,17 @@ class XgboostJNI {
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
String[][] out_strings); String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
// rabit functions
public final static native int RabitInit(String[] args);
public final static native int RabitFinalize();
public final static native int RabitTrackerPrint(String msg);
public final static native int RabitGetRank(int[] out);
public final static native int RabitGetWorldSize(int[] out);
public final static native int RabitVersionNumber(int[] out);
} }

View File

@ -12,23 +12,26 @@
limitations under the License. limitations under the License.
*/ */
#include "xgboost/c_api.h" #include <xgboost/c_api.h>
#include "xgboost4j.h" #include "./xgboost4j.h"
#include <cstring> #include <cstring>
#include <vector>
#include <string>
//helper functions //helper functions
//set handle //set handle
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) { void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
long out[1]; long out = (long) handle;
out[0] = (long) handle; jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out);
} }
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls) { (JNIEnv *jenv, jclass jcls) {
jstring jresult = 0; jstring jresult = 0;
const char* result = XGBGetLastError(); const char* result = XGBGetLastError();
if (result) jresult = jenv->NewStringUTF(result); if (result != NULL) {
jresult = jenv->NewStringUTF(result);
}
return jresult; return jresult;
} }
@ -37,7 +40,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
DMatrixHandle result; DMatrixHandle result;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGDMatrixCreateFromFile(fname, jsilent, &result); int ret = XGDMatrixCreateFromFile(fname, jsilent, &result);
if (fname) jenv->ReleaseStringUTFChars(jfname, fname); if (fname) {
jenv->ReleaseStringUTFChars(jfname, fname);
}
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
return ret; return ret;
} }
@ -271,7 +276,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
DMatrixHandle* handles; DMatrixHandle* handles = NULL;
bst_ulong len = 0; bst_ulong len = 0;
jlong* cjhandles = 0; jlong* cjhandles = 0;
BoosterHandle result; BoosterHandle result;
@ -293,7 +298,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
} }
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
return ret; return ret;
} }
@ -511,3 +515,112 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
return ret; return ret;
} }
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
int version;
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
jenv->SetIntArrayRegion(jout, 0, 1, &version);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterSaveRabitCheckpoint(handle);
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
std::vector<std::string> args;
std::vector<char*> argv;
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
for (bst_ulong i = 0; i < len; ++i) {
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
std::string s(jenv->GetStringUTFChars(arg, 0),
jenv->GetStringLength(arg));
if (s.length() != 0) args.push_back(s);
}
for (size_t i = 0; i < args.size(); ++i) {
argv.push_back(&args[i][0]);
}
RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
(JNIEnv *jenv, jclass jcls) {
RabitFinalize();
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
std::string str(jenv->GetStringUTFChars(jmsg, 0),
jenv->GetStringLength(jmsg));
RabitTrackerPrint(str.c_str());
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int rank = RabitGetRank();
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitGetWorldSize();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitVersionNumber();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}

View File

@ -215,6 +215,86 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray); (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterGetAttr
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetAttr
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSetAttr
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetAttr
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
(JNIEnv *, jclass, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
(JNIEnv *, jclass);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
(JNIEnv *, jclass, jintArray);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif