[JVM-PKG] Update JNI to include rabit codes
This commit is contained in:
@@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* DMatrix for xgboost, similar to the python wrapper xgboost.py
|
||||
* DMatrix for xgboost.
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
|
||||
@@ -48,4 +48,5 @@ class JNIErrorHandle {
|
||||
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -441,6 +441,27 @@ class JavaBoosterImpl implements Booster {
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the booster model from thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @return the stored version number of the checkpoint.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
int loadRabitCheckpoint() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the booster model into thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void saveRabitCheckpoint() throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* transfer DMatrix array to handle array (used for native functions)
|
||||
*
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import java.util.Map;
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Rabit global class for synchronization.
|
||||
*/
|
||||
public class Rabit {
|
||||
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
NativeLibLoader.initXgBoost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkCall(int ret) throws XGBoostError {
|
||||
if (ret != 0) {
|
||||
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the rabit library on current working thread.
|
||||
* @param envs The additional environment variables to pass to rabit.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void init(Map<String, String> envs) throws XGBoostError {
|
||||
String[] args = new String[envs.size()];
|
||||
int idx = 0;
|
||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||
args[idx++] = e.getKey() + '=' + e.getValue();
|
||||
}
|
||||
checkCall(XgboostJNI.RabitInit(args));
|
||||
}
|
||||
|
||||
/**
|
||||
* Shutdown the rabit engine in current working thread, equals to finalize.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void shutdown() throws XGBoostError {
|
||||
checkCall(XgboostJNI.RabitFinalize());
|
||||
}
|
||||
|
||||
/**
|
||||
* Print the message on rabit tracker.
|
||||
* @param msg
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void trackerPrint(String msg) throws XGBoostError {
|
||||
checkCall(XgboostJNI.RabitTrackerPrint(msg));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get version number of current stored model in the thread.
|
||||
* which means how many calls to CheckPoint we made so far.
|
||||
* @return version Number.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int versionNumber() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitVersionNumber(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get rank of current thread.
|
||||
* @return the rank.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int getRank() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitGetRank(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get world size of current job.
|
||||
* @return the worldsize
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int getWorldSize() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitGetWorldSize(out));
|
||||
return out[0];
|
||||
}
|
||||
}
|
||||
@@ -72,14 +72,20 @@ public class XGBoost {
|
||||
}
|
||||
|
||||
//initialize booster
|
||||
Booster booster = new JavaBoosterImpl(params, allMats);
|
||||
JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
|
||||
|
||||
int version = booster.loadRabitCheckpoint();
|
||||
|
||||
//begin to train
|
||||
for (int iter = 0; iter < round; iter++) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
for (int iter = version / 2; iter < round; iter++) {
|
||||
if (version % 2 == 0) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
version += 1;
|
||||
}
|
||||
|
||||
//evaluation
|
||||
@@ -92,6 +98,8 @@ public class XGBoost {
|
||||
}
|
||||
logger.info(evalInfo);
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
version += 1;
|
||||
}
|
||||
return booster;
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
/**
|
||||
* xgboost jni wrapper functions for xgboost_wrapper.h
|
||||
* xgboost JNI functions
|
||||
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
||||
*
|
||||
* @author hzx
|
||||
@@ -80,4 +80,17 @@ class XgboostJNI {
|
||||
|
||||
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
|
||||
String[][] out_strings);
|
||||
|
||||
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||
|
||||
// rabit functions
|
||||
public final static native int RabitInit(String[] args);
|
||||
public final static native int RabitFinalize();
|
||||
public final static native int RabitTrackerPrint(String msg);
|
||||
public final static native int RabitGetRank(int[] out);
|
||||
public final static native int RabitGetWorldSize(int[] out);
|
||||
public final static native int RabitVersionNumber(int[] out);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user