[JVM-PKG] Update JNI to include rabit codes

This commit is contained in:
tqchen
2016-03-02 22:12:17 -08:00
parent ced6d45e01
commit 5c9e50148a
8 changed files with 609 additions and 280 deletions

View File

@@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
import java.io.IOException;
/**
* DMatrix for xgboost, similar to the python wrapper xgboost.py
* DMatrix for xgboost.
*
* @author hzx
*/

View File

@@ -48,4 +48,5 @@ class JNIErrorHandle {
throw new XGBoostError(XgboostJNI.XGBGetLastError());
}
}
}

View File

@@ -441,6 +441,27 @@ class JavaBoosterImpl implements Booster {
return featureScore;
}
/**
 * Restore this booster's model from the thread-local rabit checkpoint.
 * Intended for distributed training only.
 *
 * @return the version number stored with the checkpoint.
 * @throws XGBoostError if the native call reports an error.
 */
int loadRabitCheckpoint() throws XGBoostError {
  // The native side writes the checkpoint version into slot 0 of this holder.
  final int[] versionHolder = new int[1];
  final int status = XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, versionHolder);
  JNIErrorHandle.checkCall(status);
  return versionHolder[0];
}
/**
 * Store this booster's model in the thread-local rabit checkpoint.
 * Intended for distributed training only.
 *
 * @throws XGBoostError if the native call reports an error.
 */
void saveRabitCheckpoint() throws XGBoostError {
  final int status = XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle);
  JNIErrorHandle.checkCall(status);
}
/**
* transfer DMatrix array to handle array (used for native functions)
*

View File

@@ -0,0 +1,93 @@
package ml.dmlc.xgboost4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.Map;
import java.io.IOException;
/**
* Rabit global class for synchronization.
*/
public class Rabit {
  // Keyed on Rabit (not DMatrix) so log output is attributed to the right class.
  private static final Log logger = LogFactory.getLog(Rabit.class);

  // Load the native library when this class is first initialized.
  static {
    try {
      NativeLibLoader.initXgBoost();
    } catch (IOException ex) {
      // Best effort: the failure is logged (with its stack trace) rather than rethrown,
      // so later native calls are what will surface the problem to the caller.
      logger.error("load native library failed.", ex);
    }
  }

  /**
   * Translate a native status code into an exception.
   *
   * @param ret status code returned by a native call; zero means success.
   * @throws XGBoostError carrying the last native error message when {@code ret} is non-zero.
   */
  private static void checkCall(int ret) throws XGBoostError {
    if (ret != 0) {
      throw new XGBoostError(XgboostJNI.XGBGetLastError());
    }
  }

  /**
   * Initialize the rabit library on current working thread.
   *
   * @param envs The additional environment variables to pass to rabit; each entry is
   *             forwarded as a {@code key=value} argument. A {@code null} map is
   *             treated as empty.
   * @throws XGBoostError if the native call returns a non-zero status.
   */
  public static void init(Map<String, String> envs) throws XGBoostError {
    if (envs == null) {
      // No extra settings: initialize with an empty argument list instead of failing with an NPE.
      checkCall(XgboostJNI.RabitInit(new String[0]));
      return;
    }
    String[] args = new String[envs.size()];
    int idx = 0;
    for (Map.Entry<String, String> e : envs.entrySet()) {
      args[idx++] = e.getKey() + '=' + e.getValue();
    }
    checkCall(XgboostJNI.RabitInit(args));
  }

  /**
   * Shutdown the rabit engine in current working thread, equals to finalize.
   *
   * @throws XGBoostError if the native call returns a non-zero status.
   */
  public static void shutdown() throws XGBoostError {
    checkCall(XgboostJNI.RabitFinalize());
  }

  /**
   * Print the message on rabit tracker.
   *
   * @param msg the message to print.
   * @throws XGBoostError if the native call returns a non-zero status.
   */
  public static void trackerPrint(String msg) throws XGBoostError {
    checkCall(XgboostJNI.RabitTrackerPrint(msg));
  }

  /**
   * Get version number of current stored model in the thread,
   * which means how many calls to CheckPoint we made so far.
   *
   * @return the version number.
   * @throws XGBoostError if the native call returns a non-zero status.
   */
  public static int versionNumber() throws XGBoostError {
    // The native side writes its result into slot 0 of the array.
    int[] out = new int[1];
    checkCall(XgboostJNI.RabitVersionNumber(out));
    return out[0];
  }

  /**
   * Get rank of current thread.
   *
   * @return the rank.
   * @throws XGBoostError if the native call returns a non-zero status.
   */
  public static int getRank() throws XGBoostError {
    int[] out = new int[1];
    checkCall(XgboostJNI.RabitGetRank(out));
    return out[0];
  }

  /**
   * Get world size of current job.
   *
   * @return the world size.
   * @throws XGBoostError if the native call returns a non-zero status.
   */
  public static int getWorldSize() throws XGBoostError {
    int[] out = new int[1];
    checkCall(XgboostJNI.RabitGetWorldSize(out));
    return out[0];
  }
}

View File

@@ -72,14 +72,20 @@ public class XGBoost {
}
//initialize booster
Booster booster = new JavaBoosterImpl(params, allMats);
JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
int version = booster.loadRabitCheckpoint();
//begin to train
for (int iter = 0; iter < round; iter++) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
for (int iter = version / 2; iter < round; iter++) {
if (version % 2 == 0) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
}
booster.saveRabitCheckpoint();
version += 1;
}
//evaluation
@@ -92,6 +98,8 @@ public class XGBoost {
}
logger.info(evalInfo);
}
booster.saveRabitCheckpoint();
version += 1;
}
return booster;
}

View File

@@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j;
/**
* xgboost jni wrapper functions for xgboost_wrapper.h
* xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
*
* @author hzx
@@ -80,4 +80,17 @@ class XgboostJNI {
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
// Rabit checkpoint support. Both calls return a status code that callers pass to
// JNIErrorHandle.checkCall.
// Loads the booster from the rabit checkpoint; the stored version number is
// written to out_version[0].
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
// Saves the booster into the rabit checkpoint.
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
// rabit functions
// Every function returns a status code; the Rabit wrapper class treats any non-zero
// value as a failure and reads the message via XGBGetLastError().
// Functions taking an int[] write their result into element 0 of that array.
// args: entries built by Rabit.init in "key=value" form.
public final static native int RabitInit(String[] args);
// Exposed to Java callers as Rabit.shutdown().
public final static native int RabitFinalize();
// msg: text to print on the rabit tracker.
public final static native int RabitTrackerPrint(String msg);
// out[0] receives the rank.
public final static native int RabitGetRank(int[] out);
// out[0] receives the world size.
public final static native int RabitGetWorldSize(int[] out);
// out[0] receives the version number of the stored model.
public final static native int RabitVersionNumber(int[] out);
}