Merge pull request #907 from tqchen/master
[DIST] Enable multiple thread make rabit and xgboost threadsafe
This commit is contained in:
commit
770b3451ca
1
.gitignore
vendored
1
.gitignore
vendored
@ -78,5 +78,4 @@ tags
|
|||||||
*.iml
|
*.iml
|
||||||
*.class
|
*.class
|
||||||
target
|
target
|
||||||
|
|
||||||
*.swp
|
*.swp
|
||||||
|
|||||||
5
NEWS.md
5
NEWS.md
@ -22,6 +22,11 @@ This file records the changes in xgboost library in reverse chronological order.
|
|||||||
- The windows version is still blocked due to Rtools do not support ```std::thread```.
|
- The windows version is still blocked due to Rtools do not support ```std::thread```.
|
||||||
* rabit and dmlc-core are maintained through git submodule
|
* rabit and dmlc-core are maintained through git submodule
|
||||||
- Anyone can open PR to update these dependencies now.
|
- Anyone can open PR to update these dependencies now.
|
||||||
|
* Improvements
|
||||||
|
- Rabit and xgboost libs are not thread-safe and use thread local PRNGs
|
||||||
|
- This could fix some of the previous problem which runs xgboost on multiple threads.
|
||||||
|
* JVM Package
|
||||||
|
- Enable xgboost4j for java and scala
|
||||||
|
|
||||||
## v0.47 (2016.01.14)
|
## v0.47 (2016.01.14)
|
||||||
|
|
||||||
|
|||||||
2
jvm-packages/.gitignore
vendored
Normal file
2
jvm-packages/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
tracker.py
|
||||||
|
build.sh
|
||||||
@ -27,6 +27,8 @@ fi
|
|||||||
|
|
||||||
rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
|
rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
|
||||||
mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
|
mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
|
||||||
|
# copy python to native resources
|
||||||
|
cp ../dmlc-core/tracker/dmlc_tracker/tracker.py xgboost4j/src/main/resources/tracker.py
|
||||||
|
|
||||||
popd > /dev/null
|
popd > /dev/null
|
||||||
echo "complete"
|
echo "complete"
|
||||||
|
|||||||
2
jvm-packages/test_distributed.sh
Normal file → Executable file
2
jvm-packages/test_distributed.sh
Normal file → Executable file
@ -1,5 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Simple script to test distributed version, to be deleted later.
|
# Simple script to test distributed version, to be deleted later.
|
||||||
cd xgboost4j-demo
|
cd xgboost4j-demo
|
||||||
../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain
|
java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4
|
||||||
cd ..
|
cd ..
|
||||||
|
|||||||
@ -2,19 +2,37 @@ package ml.dmlc.xgboost4j.demo;
|
|||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.apache.commons.logging.Log;
|
||||||
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.*;
|
import ml.dmlc.xgboost4j.*;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Distributed training example, used to quick test distributed training.
|
* Distributed training example, used to quick test distributed training.
|
||||||
*
|
*
|
||||||
* @author tqchen
|
* @author tqchen
|
||||||
*/
|
*/
|
||||||
public class DistTrain {
|
public class DistTrain {
|
||||||
|
private static final Log logger = LogFactory.getLog(DistTrain.class);
|
||||||
|
private Map<String, String> envs = null;
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException, XGBoostError {
|
private class Worker implements Runnable {
|
||||||
|
private final int workerId;
|
||||||
|
|
||||||
|
Worker(int workerId) {
|
||||||
|
this.workerId = workerId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void run() {
|
||||||
|
try {
|
||||||
|
Map<String, String> worker_env = new HashMap<String, String>(envs);
|
||||||
|
|
||||||
|
worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
|
||||||
// always initialize rabit module before training.
|
// always initialize rabit module before training.
|
||||||
Rabit.init(new HashMap<String, String>());
|
Rabit.init(worker_env);
|
||||||
|
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
@ -24,9 +42,9 @@ public class DistTrain {
|
|||||||
params.put("eta", 1.0);
|
params.put("eta", 1.0);
|
||||||
params.put("max_depth", 2);
|
params.put("max_depth", 2);
|
||||||
params.put("silent", 1);
|
params.put("silent", 1);
|
||||||
|
params.put("nthread", 2);
|
||||||
params.put("objective", "binary:logistic");
|
params.put("objective", "binary:logistic");
|
||||||
|
|
||||||
|
|
||||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
watches.put("train", trainMat);
|
watches.put("train", trainMat);
|
||||||
watches.put("test", testMat);
|
watches.put("test", testMat);
|
||||||
@ -39,5 +57,24 @@ public class DistTrain {
|
|||||||
|
|
||||||
// always shutdown rabit module after training.
|
// always shutdown rabit module after training.
|
||||||
Rabit.shutdown();
|
Rabit.shutdown();
|
||||||
|
} catch (Exception ex){
|
||||||
|
logger.error(ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
|
||||||
|
RabitTracker tracker = new RabitTracker(nWorkers);
|
||||||
|
if (tracker.start()) {
|
||||||
|
envs = tracker.getWorkerEnvs();
|
||||||
|
for (int i = 0; i < nWorkers; ++i) {
|
||||||
|
new Thread(new Worker(i)).start();
|
||||||
|
}
|
||||||
|
tracker.waitFor();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
|
||||||
|
new DistTrain().start(Integer.parseInt(args[0]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1,9 +1,10 @@
|
|||||||
package ml.dmlc.xgboost4j;
|
package ml.dmlc.xgboost4j;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public interface Booster {
|
public interface Booster extends Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set parameter
|
* set parameter
|
||||||
@ -109,12 +110,25 @@ public interface Booster {
|
|||||||
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
|
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* save model to modelPath
|
* save model to modelPath, the model path support depends on the path support
|
||||||
*
|
* in libxgboost. For example, if we want to save to hdfs, libxgboost need to be
|
||||||
|
* compiled with HDFS support.
|
||||||
|
* See also toByteArray
|
||||||
* @param modelPath model path
|
* @param modelPath model path
|
||||||
*/
|
*/
|
||||||
void saveModel(String modelPath) throws XGBoostError;
|
void saveModel(String modelPath) throws XGBoostError;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save the model as byte array representation.
|
||||||
|
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||||
|
*
|
||||||
|
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
|
||||||
|
*
|
||||||
|
* @return the saved byte array.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
byte[] toByteArray() throws XGBoostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dump model into a text file.
|
* Dump model into a text file.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -57,7 +57,6 @@ class JavaBoosterImpl implements Booster {
|
|||||||
setParams(params);
|
setParams(params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* load model from modelPath
|
* load model from modelPath
|
||||||
*
|
*
|
||||||
@ -440,6 +439,22 @@ class JavaBoosterImpl implements Booster {
|
|||||||
return featureScore;
|
return featureScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save the model as byte array representation.
|
||||||
|
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||||
|
*
|
||||||
|
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
|
||||||
|
*
|
||||||
|
* @return the saved byte array.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public byte[] toByteArray() throws XGBoostError {
|
||||||
|
byte[][] bytes = new byte[1][];
|
||||||
|
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||||
|
return bytes[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load the booster model from thread-local rabit checkpoint.
|
* Load the booster model from thread-local rabit checkpoint.
|
||||||
* This is only used in distributed training.
|
* This is only used in distributed training.
|
||||||
@ -475,6 +490,27 @@ class JavaBoosterImpl implements Booster {
|
|||||||
return handles;
|
return handles;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// making Booster serializable
|
||||||
|
private void writeObject(java.io.ObjectOutputStream out)
|
||||||
|
throws IOException {
|
||||||
|
try {
|
||||||
|
out.writeObject(this.toByteArray());
|
||||||
|
} catch (XGBoostError ex) {
|
||||||
|
throw new IOException(ex.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readObject(java.io.ObjectInputStream in)
|
||||||
|
throws IOException, ClassNotFoundException {
|
||||||
|
try {
|
||||||
|
this.init(null);
|
||||||
|
byte[] bytes = (byte[])in.readObject();
|
||||||
|
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||||
|
} catch (XGBoostError ex) {
|
||||||
|
throw new IOException(ex.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void finalize() throws Throwable {
|
protected void finalize() throws Throwable {
|
||||||
super.finalize();
|
super.finalize();
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import java.lang.reflect.Field;
|
|||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* class to load native library
|
* class to load native library
|
||||||
*
|
*
|
||||||
@ -61,12 +62,32 @@ class NativeLibLoader {
|
|||||||
* three characters
|
* three characters
|
||||||
*/
|
*/
|
||||||
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
|
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
|
||||||
|
String temp = createTempFileFromResource(path);
|
||||||
|
// Finally, load the library
|
||||||
|
System.load(temp);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a temp file that copies the resource from current JAR archive
|
||||||
|
* <p/>
|
||||||
|
* The file from JAR is copied into system temp file.
|
||||||
|
* The temporary file is deleted after exiting.
|
||||||
|
* Method uses String as filename because the pathname is "abstract", not system-dependent.
|
||||||
|
* <p/>
|
||||||
|
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
|
||||||
|
* {@code path}.
|
||||||
|
* @param path Path to the resources in the jar
|
||||||
|
* @return The created temp file.
|
||||||
|
* @throws IOException
|
||||||
|
* @throws IllegalArgumentException
|
||||||
|
*/
|
||||||
|
static String createTempFileFromResource(String path) throws
|
||||||
|
IOException, IllegalArgumentException {
|
||||||
|
// Obtain filename from path
|
||||||
if (!path.startsWith("/")) {
|
if (!path.startsWith("/")) {
|
||||||
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Obtain filename from path
|
|
||||||
String[] parts = path.split("/");
|
String[] parts = path.split("/");
|
||||||
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
|
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
|
||||||
|
|
||||||
@ -83,7 +104,6 @@ class NativeLibLoader {
|
|||||||
if (filename == null || prefix.length() < 3) {
|
if (filename == null || prefix.length() < 3) {
|
||||||
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
|
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare temporary file
|
// Prepare temporary file
|
||||||
File temp = File.createTempFile(prefix, suffix);
|
File temp = File.createTempFile(prefix, suffix);
|
||||||
temp.deleteOnExit();
|
temp.deleteOnExit();
|
||||||
@ -113,9 +133,7 @@ class NativeLibLoader {
|
|||||||
os.close();
|
os.close();
|
||||||
is.close();
|
is.close();
|
||||||
}
|
}
|
||||||
|
return temp.getAbsolutePath();
|
||||||
// Finally, load the library
|
|
||||||
System.load(temp.getAbsolutePath());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -133,8 +151,9 @@ class NativeLibLoader {
|
|||||||
try {
|
try {
|
||||||
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
|
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
|
||||||
loadLibraryFromJar(libraryFromJar);
|
loadLibraryFromJar(libraryFromJar);
|
||||||
} catch (IOException e1) {
|
} catch (IOException ioe) {
|
||||||
throw e1;
|
logger.error("failed to load library from both native path and jar");
|
||||||
|
throw ioe;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,144 @@
|
|||||||
|
package ml.dmlc.xgboost4j;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import java.io.*;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
|
import org.apache.commons.logging.Log;
|
||||||
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Distributed RabitTracker, need to be started on driver code before running distributed jobs.
|
||||||
|
*/
|
||||||
|
public class RabitTracker {
|
||||||
|
// Maybe per tracker logger?
|
||||||
|
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||||
|
// tracker python file.
|
||||||
|
private static String tracker_py = null;
|
||||||
|
// environment variable to be pased.
|
||||||
|
private Map<String, String> envs = new HashMap<String, String>();
|
||||||
|
// number of workers to be submitted.
|
||||||
|
private int num_workers;
|
||||||
|
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||||
|
|
||||||
|
static {
|
||||||
|
try {
|
||||||
|
initTrackerPy();
|
||||||
|
} catch (IOException ex) {
|
||||||
|
logger.error("load tracker library failed.");
|
||||||
|
logger.error(ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tracker logger that logs output from tracker.
|
||||||
|
*/
|
||||||
|
private class TrackerProcessLogger implements Runnable {
|
||||||
|
public void run() {
|
||||||
|
|
||||||
|
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
|
||||||
|
BufferedReader reader = new BufferedReader(new InputStreamReader(
|
||||||
|
trackerProcess.get().getErrorStream()));
|
||||||
|
String line;
|
||||||
|
try {
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
trackerProcessLogger.info(line);
|
||||||
|
}
|
||||||
|
} catch (IOException ex) {
|
||||||
|
trackerProcessLogger.error(ex.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void initTrackerPy() throws IOException {
|
||||||
|
try {
|
||||||
|
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
|
||||||
|
} catch (IOException ioe) {
|
||||||
|
logger.trace("cannot access tracker python script");
|
||||||
|
throw ioe;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public RabitTracker(int num_workers) {
|
||||||
|
this.num_workers = num_workers;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get environments that can be used to pass to worker.
|
||||||
|
* @return The environment settings.
|
||||||
|
*/
|
||||||
|
public Map<String, String> getWorkerEnvs() {
|
||||||
|
return envs;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void loadEnvs(InputStream ins) throws IOException {
|
||||||
|
try {
|
||||||
|
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
|
||||||
|
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
|
||||||
|
String line;
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
String[] sep = line.split("=");
|
||||||
|
if (sep.length == 2) {
|
||||||
|
envs.put(sep[0], sep[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reader.close();
|
||||||
|
} catch (IOException ioe){
|
||||||
|
logger.error("cannot get runtime configuration from tracker process");
|
||||||
|
ioe.printStackTrace();
|
||||||
|
throw ioe;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean startTrackerProcess() {
|
||||||
|
try {
|
||||||
|
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
|
||||||
|
" --num-workers=" + String.valueOf(num_workers)));
|
||||||
|
loadEnvs(trackerProcess.get().getInputStream());
|
||||||
|
return true;
|
||||||
|
} catch (IOException ioe) {
|
||||||
|
ioe.printStackTrace();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void stop() {
|
||||||
|
if (trackerProcess.get() != null) {
|
||||||
|
trackerProcess.get().destroy();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean start() {
|
||||||
|
if (startTrackerProcess()) {
|
||||||
|
logger.debug("Tracker started, with env=" + envs.toString());
|
||||||
|
// also start a tracker logger
|
||||||
|
Thread logger_thread = new Thread(new TrackerProcessLogger());
|
||||||
|
logger_thread.setDaemon(true);
|
||||||
|
logger_thread.start();
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
logger.error("FAULT: failed to start tracker process");
|
||||||
|
stop();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void waitFor() {
|
||||||
|
try {
|
||||||
|
trackerProcess.get().waitFor();
|
||||||
|
logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
|
||||||
|
stop();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
// we should not get here as RabitTracker is accessed in the main thread
|
||||||
|
e.printStackTrace();
|
||||||
|
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -74,9 +74,9 @@ class XgboostJNI {
|
|||||||
|
|
||||||
public final static native int XGBoosterSaveModel(long handle, String fname);
|
public final static native int XGBoosterSaveModel(long handle, String fname);
|
||||||
|
|
||||||
public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
|
public final static native int XGBoosterLoadModelFromBuffer(long handle, byte[] bytes);
|
||||||
|
|
||||||
public final static native int XGBoosterGetModelRaw(long handle, String[] out_string);
|
public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes);
|
||||||
|
|
||||||
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
|
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
|
||||||
String[][] out_strings);
|
String[][] out_strings);
|
||||||
|
|||||||
@ -23,7 +23,7 @@ import scala.collection.mutable
|
|||||||
import ml.dmlc.xgboost4j.XGBoostError
|
import ml.dmlc.xgboost4j.XGBoostError
|
||||||
|
|
||||||
|
|
||||||
trait Booster {
|
trait Booster extends Serializable {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -13,6 +13,8 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <xgboost/c_api.h>
|
#include <xgboost/c_api.h>
|
||||||
|
#include <xgboost/base.h>
|
||||||
|
#include <xgboost/logging.h>
|
||||||
#include "./xgboost4j.h"
|
#include "./xgboost4j.h"
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -276,27 +278,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
|
||||||
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
|
||||||
DMatrixHandle* handles = NULL;
|
std::vector<DMatrixHandle> handles;
|
||||||
bst_ulong len = 0;
|
if (jhandles != nullptr) {
|
||||||
jlong* cjhandles = 0;
|
size_t len = jenv->GetArrayLength(jhandles);
|
||||||
BoosterHandle result;
|
jlong *cjhandles = jenv->GetLongArrayElements(jhandles, 0);
|
||||||
|
for (size_t i = 0; i < len; ++i) {
|
||||||
if (jhandles) {
|
handles.push_back((DMatrixHandle) cjhandles[i]);
|
||||||
len = (bst_ulong)jenv->GetArrayLength(jhandles);
|
|
||||||
handles = new DMatrixHandle[len];
|
|
||||||
//put handle from jhandles to chandles
|
|
||||||
cjhandles = jenv->GetLongArrayElements(jhandles, 0);
|
|
||||||
for(bst_ulong i=0; i<len; i++) {
|
|
||||||
handles[i] = (DMatrixHandle) cjhandles[i];
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
int ret = XGBoosterCreate(handles, len, &result);
|
|
||||||
//release
|
|
||||||
if (jhandles) {
|
|
||||||
delete[] handles;
|
|
||||||
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
|
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
|
||||||
}
|
}
|
||||||
|
BoosterHandle result;
|
||||||
|
int ret = XGBoosterCreate(dmlc::BeginPtr(handles), handles.size(), &result);
|
||||||
setHandle(jenv, jout, result);
|
setHandle(jenv, jout, result);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
@ -369,43 +361,34 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
DMatrixHandle* dmats = 0;
|
std::vector<DMatrixHandle> dmats;
|
||||||
char **evnames = 0;
|
std::vector<std::string> evnames;
|
||||||
char *result = 0;
|
std::vector<const char*> evchars;
|
||||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats);
|
|
||||||
if(len > 0) {
|
size_t len = static_cast<size_t>(jenv->GetArrayLength(jdmats));
|
||||||
dmats = new DMatrixHandle[len];
|
// put handle from jhandles to chandles
|
||||||
evnames = new char*[len];
|
|
||||||
}
|
|
||||||
//put handle from jhandles to chandles
|
|
||||||
jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0);
|
jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0);
|
||||||
for(bst_ulong i=0; i<len; i++) {
|
for (size_t i = 0; i < len; ++i) {
|
||||||
dmats[i] = (DMatrixHandle) cjdmats[i];
|
dmats.push_back((DMatrixHandle) cjdmats[i]);
|
||||||
}
|
|
||||||
//transfer jObjectArray to char**, user strcpy and release JNI char* inplace
|
|
||||||
for(bst_ulong i=0; i<len; i++) {
|
|
||||||
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
|
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
|
||||||
const char* cevname = jenv->GetStringUTFChars(jevname, 0);
|
const char *s =jenv->GetStringUTFChars(jevname, 0);
|
||||||
evnames[i] = new char[jenv->GetStringLength(jevname)];
|
evnames.push_back(std::string(s, jenv->GetStringLength(jevname)));
|
||||||
strcpy(evnames[i], cevname);
|
if (s != nullptr) jenv->ReleaseStringUTFChars(jevname, s);
|
||||||
jenv->ReleaseStringUTFChars(jevname, cevname);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result);
|
|
||||||
if(len > 0) {
|
|
||||||
delete[] dmats;
|
|
||||||
//release string chars
|
|
||||||
for(bst_ulong i=0; i<len; i++) {
|
|
||||||
delete[] evnames[i];
|
|
||||||
}
|
|
||||||
delete[] evnames;
|
|
||||||
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
|
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
|
||||||
|
for (size_t i = 0; i < len; ++i) {
|
||||||
|
evchars.push_back(evnames[i].c_str());
|
||||||
|
}
|
||||||
|
const char* result;
|
||||||
|
int ret = XGBoosterEvalOneIter(handle, jiter,
|
||||||
|
dmlc::BeginPtr(dmats),
|
||||||
|
dmlc::BeginPtr(evchars),
|
||||||
|
len, &result);
|
||||||
|
jstring jinfo = nullptr;
|
||||||
|
if (result != nullptr) {
|
||||||
|
jinfo = jenv->NewStringUTF(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
jstring jinfo = 0;
|
|
||||||
if (result) jinfo = jenv->NewStringUTF((const char *) result);
|
|
||||||
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -456,37 +439,40 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
|
|||||||
|
|
||||||
int ret = XGBoosterSaveModel(handle, fname);
|
int ret = XGBoosterSaveModel(handle, fname);
|
||||||
if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
|
if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
* Method: XGBoosterLoadModelFromBuffer
|
* Method: XGBoosterLoadModelFromBuffer
|
||||||
* Signature: (JJJ)V
|
* Signature: (J[B)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
void *buf = (void*) jbuf;
|
jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0);
|
||||||
return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen);
|
int ret = XGBoosterLoadModelFromBuffer(
|
||||||
|
handle, buffer, jenv->GetArrayLength(jbytes));
|
||||||
|
jenv->ReleaseByteArrayElements(jbytes, buffer, 0);
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
* Method: XGBoosterGetModelRaw
|
* Method: XGBoosterGetModelRaw
|
||||||
* Signature: (J)Ljava/lang/String;
|
* Signature: (J[[B)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
||||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
bst_ulong len = 0;
|
bst_ulong len = 0;
|
||||||
char *result;
|
const char* result;
|
||||||
|
int ret = XGBoosterGetModelRaw(handle, &len, &result);
|
||||||
|
|
||||||
int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result);
|
|
||||||
if (result) {
|
if (result) {
|
||||||
jstring jinfo = jenv->NewStringUTF((const char *) result);
|
jbyteArray jarray = jenv->NewByteArray(len);
|
||||||
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
jenv->SetByteArrayRegion(jarray, 0, len, (jbyte*)result);
|
||||||
|
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
@ -553,15 +539,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
|
|||||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
|
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
|
||||||
for (bst_ulong i = 0; i < len; ++i) {
|
for (bst_ulong i = 0; i < len; ++i) {
|
||||||
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
|
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
|
||||||
std::string s(jenv->GetStringUTFChars(arg, 0),
|
const char *s = jenv->GetStringUTFChars(arg, 0);
|
||||||
jenv->GetStringLength(arg));
|
args.push_back(std::string(s, jenv->GetStringLength(arg)));
|
||||||
if (s.length() != 0) args.push_back(s);
|
if (s != nullptr) jenv->ReleaseStringUTFChars(arg, s);
|
||||||
|
if (args.back().length() == 0) args.pop_back();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < args.size(); ++i) {
|
for (size_t i = 0; i < args.size(); ++i) {
|
||||||
argv.push_back(&args[i][0]);
|
argv.push_back(&args[i][0]);
|
||||||
}
|
}
|
||||||
RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]);
|
|
||||||
|
RabitInit(args.size(), dmlc::BeginPtr(argv));
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -194,15 +194,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
|
|||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
* Method: XGBoosterLoadModelFromBuffer
|
* Method: XGBoosterLoadModelFromBuffer
|
||||||
* Signature: (JJJ)I
|
* Signature: (J[B)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
|
||||||
(JNIEnv *, jclass, jlong, jlong, jlong);
|
(JNIEnv *, jclass, jlong, jbyteArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
* Method: XGBoosterGetModelRaw
|
* Method: XGBoosterGetModelRaw
|
||||||
* Signature: (J[Ljava/lang/String;)I
|
* Signature: (J[[B)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
||||||
(JNIEnv *, jclass, jlong, jobjectArray);
|
(JNIEnv *, jclass, jlong, jobjectArray);
|
||||||
|
|||||||
@ -4,12 +4,20 @@
|
|||||||
* \brief Enable all kinds of global variables in common.
|
* \brief Enable all kinds of global variables in common.
|
||||||
*/
|
*/
|
||||||
#include "./random.h"
|
#include "./random.h"
|
||||||
|
#include "./thread_local.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
|
/*! \brief thread local entry for random. */
|
||||||
|
struct RandomThreadLocalEntry {
|
||||||
|
/*! \brief the random engine instance. */
|
||||||
|
GlobalRandomEngine engine;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef ThreadLocalStore<RandomThreadLocalEntry> RandomThreadLocalStore;
|
||||||
|
|
||||||
GlobalRandomEngine& GlobalRandom() {
|
GlobalRandomEngine& GlobalRandom() {
|
||||||
static GlobalRandomEngine inst;
|
return RandomThreadLocalStore::Get()->engine;
|
||||||
return inst;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -61,7 +61,8 @@ typedef RandomEngine GlobalRandomEngine;
|
|||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief global singleton of a random engine.
|
* \brief global singleton of a random engine.
|
||||||
* Only use this engine when necessary, not thread-safe.
|
* This random engine is thread-local and
|
||||||
|
* only visible to current thread.
|
||||||
*/
|
*/
|
||||||
GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
|
GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,8 @@
|
|||||||
#ifndef XGBOOST_COMMON_THREAD_LOCAL_H_
|
#ifndef XGBOOST_COMMON_THREAD_LOCAL_H_
|
||||||
#define XGBOOST_COMMON_THREAD_LOCAL_H_
|
#define XGBOOST_COMMON_THREAD_LOCAL_H_
|
||||||
|
|
||||||
|
#include <dmlc/base.h>
|
||||||
|
|
||||||
#if DMLC_ENABLE_STD_THREAD
|
#if DMLC_ENABLE_STD_THREAD
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
#include "../common/common.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace gbm {
|
namespace gbm {
|
||||||
@ -265,13 +266,11 @@ class GBTree : public GradientBooster {
|
|||||||
inline void InitUpdater() {
|
inline void InitUpdater() {
|
||||||
if (updaters.size() != 0) return;
|
if (updaters.size() != 0) return;
|
||||||
std::string tval = tparam.updater_seq;
|
std::string tval = tparam.updater_seq;
|
||||||
char *pstr;
|
std::vector<std::string> ups = common::Split(tval, ',');
|
||||||
pstr = std::strtok(&tval[0], ",");
|
for (const std::string& pstr : ups) {
|
||||||
while (pstr != nullptr) {
|
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr.c_str()));
|
||||||
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr));
|
|
||||||
up->Init(this->cfg);
|
up->Init(this->cfg);
|
||||||
updaters.push_back(std::move(up));
|
updaters.push_back(std::move(up));
|
||||||
pstr = std::strtok(nullptr, ",");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// do group specific group
|
// do group specific group
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user