[JVM-PKG] Update JNI to include rabit codes

This commit is contained in:
tqchen 2016-03-02 22:12:17 -08:00
parent ced6d45e01
commit 5c9e50148a
8 changed files with 609 additions and 280 deletions

View File

@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
import java.io.IOException; import java.io.IOException;
/** /**
* DMatrix for xgboost, similar to the python wrapper xgboost.py * DMatrix for xgboost.
* *
* @author hzx * @author hzx
*/ */

View File

@ -48,4 +48,5 @@ class JNIErrorHandle {
throw new XGBoostError(XgboostJNI.XGBGetLastError()); throw new XGBoostError(XgboostJNI.XGBGetLastError());
} }
} }
} }

View File

@ -441,6 +441,27 @@ class JavaBoosterImpl implements Booster {
return featureScore; return featureScore;
} }
/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
}
/**
* Save the booster model into thread-local rabit checkpoint.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
}
/** /**
* transfer DMatrix array to handle array (used for native functions) * transfer DMatrix array to handle array (used for native functions)
* *

View File

@ -0,0 +1,93 @@
package ml.dmlc.xgboost4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.Map;
import java.io.IOException;
/**
* Rabit global class for synchronization.
*/
public class Rabit {
private static final Log logger = LogFactory.getLog(DMatrix.class);
//load native library
static {
try {
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
private static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XgboostJNI.XGBGetLastError());
}
}
/**
* Initialize the rabit library on current working thread.
* @param envs The additional environment variables to pass to rabit.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
String[] args = new String[envs.size()];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey() + '=' + e.getValue();
}
checkCall(XgboostJNI.RabitInit(args));
}
/**
* Shutdown the rabit engine in current working thread, equals to finalize.
* @throws XGBoostError
*/
public static void shutdown() throws XGBoostError {
checkCall(XgboostJNI.RabitFinalize());
}
/**
* Print the message on rabit tracker.
* @param msg
* @throws XGBoostError
*/
public static void trackerPrint(String msg) throws XGBoostError {
checkCall(XgboostJNI.RabitTrackerPrint(msg));
}
/**
* Get version number of current stored model in the thread.
* which means how many calls to CheckPoint we made so far.
* @return version Number.
* @throws XGBoostError
*/
public static int versionNumber() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitVersionNumber(out));
return out[0];
}
/**
* get rank of current thread.
* @return the rank.
* @throws XGBoostError
*/
public static int getRank() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitGetRank(out));
return out[0];
}
/**
* get world size of current job.
* @return the worldsize
* @throws XGBoostError
*/
public static int getWorldSize() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitGetWorldSize(out));
return out[0];
}
}

View File

@ -72,15 +72,21 @@ public class XGBoost {
} }
//initialize booster //initialize booster
Booster booster = new JavaBoosterImpl(params, allMats); JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
int version = booster.loadRabitCheckpoint();
//begin to train //begin to train
for (int iter = 0; iter < round; iter++) { for (int iter = version / 2; iter < round; iter++) {
if (version % 2 == 0) {
if (obj != null) { if (obj != null) {
booster.update(dtrain, obj); booster.update(dtrain, obj);
} else { } else {
booster.update(dtrain, iter); booster.update(dtrain, iter);
} }
booster.saveRabitCheckpoint();
version += 1;
}
//evaluation //evaluation
if (evalMats != null && evalMats.length > 0) { if (evalMats != null && evalMats.length > 0) {
@ -92,6 +98,8 @@ public class XGBoost {
} }
logger.info(evalInfo); logger.info(evalInfo);
} }
booster.saveRabitCheckpoint();
version += 1;
} }
return booster; return booster;
} }

View File

@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j; package ml.dmlc.xgboost4j;
/** /**
* xgboost jni wrapper functions for xgboost_wrapper.h * xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
* *
* @author hzx * @author hzx
@ -80,4 +80,17 @@ class XgboostJNI {
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
String[][] out_strings); String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
// rabit functions
public final static native int RabitInit(String[] args);
public final static native int RabitFinalize();
public final static native int RabitTrackerPrint(String msg);
public final static native int RabitGetRank(int[] out);
public final static native int RabitGetWorldSize(int[] out);
public final static native int RabitVersionNumber(int[] out);
} }

View File

@ -10,25 +10,28 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "xgboost/c_api.h" #include <xgboost/c_api.h>
#include "xgboost4j.h" #include "./xgboost4j.h"
#include <cstring> #include <cstring>
#include <vector>
#include <string>
//helper functions //helper functions
//set handle //set handle
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) { void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
long out[1]; long out = (long) handle;
out[0] = (long) handle; jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out);
} }
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls) { (JNIEnv *jenv, jclass jcls) {
jstring jresult = 0 ; jstring jresult = 0;
const char* result = XGBGetLastError(); const char* result = XGBGetLastError();
if (result) jresult = jenv->NewStringUTF(result); if (result != NULL) {
jresult = jenv->NewStringUTF(result);
}
return jresult; return jresult;
} }
@ -37,7 +40,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
DMatrixHandle result; DMatrixHandle result;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGDMatrixCreateFromFile(fname, jsilent, &result); int ret = XGDMatrixCreateFromFile(fname, jsilent, &result);
if (fname) jenv->ReleaseStringUTFChars(jfname, fname); if (fname) {
jenv->ReleaseStringUTFChars(jfname, fname);
}
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
return ret; return ret;
} }
@ -271,12 +276,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
DMatrixHandle* handles; DMatrixHandle* handles = NULL;
bst_ulong len = 0; bst_ulong len = 0;
jlong* cjhandles = 0; jlong* cjhandles = 0;
BoosterHandle result; BoosterHandle result;
if(jhandles) { if (jhandles) {
len = (bst_ulong)jenv->GetArrayLength(jhandles); len = (bst_ulong)jenv->GetArrayLength(jhandles);
handles = new DMatrixHandle[len]; handles = new DMatrixHandle[len];
//put handle from jhandles to chandles //put handle from jhandles to chandles
@ -288,12 +293,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
int ret = XGBoosterCreate(handles, len, &result); int ret = XGBoosterCreate(handles, len, &result);
//release //release
if(jhandles) { if (jhandles) {
delete[] handles; delete[] handles;
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
} }
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
return ret; return ret;
} }
@ -480,7 +484,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
char *result; char *result;
int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result); int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result);
if (result){ if (result) {
jstring jinfo = jenv->NewStringUTF((const char *) result); jstring jinfo = jenv->NewStringUTF((const char *) result);
jenv->SetObjectArrayElement(jout, 0, jinfo); jenv->SetObjectArrayElement(jout, 0, jinfo);
} }
@ -511,3 +515,112 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
return ret; return ret;
} }
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
int version;
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
jenv->SetIntArrayRegion(jout, 0, 1, &version);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterSaveRabitCheckpoint(handle);
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
std::vector<std::string> args;
std::vector<char*> argv;
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
for (bst_ulong i = 0; i < len; ++i) {
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
std::string s(jenv->GetStringUTFChars(arg, 0),
jenv->GetStringLength(arg));
if (s.length() != 0) args.push_back(s);
}
for (size_t i = 0; i < args.size(); ++i) {
argv.push_back(&args[i][0]);
}
RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
(JNIEnv *jenv, jclass jcls) {
RabitFinalize();
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
std::string str(jenv->GetStringUTFChars(jmsg, 0),
jenv->GetStringLength(jmsg));
RabitTrackerPrint(str.c_str());
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int rank = RabitGetRank();
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitGetWorldSize();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitVersionNumber();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}

View File

@ -215,6 +215,86 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray); (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterGetAttr
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetAttr
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSetAttr
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetAttr
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
(JNIEnv *, jclass, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
(JNIEnv *, jclass);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
(JNIEnv *, jclass, jintArray);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif