> params) throws XGBoostError {
- dmats = new DMatrix[] {dtrain, dtest};
- booster = new Booster(params, dmats);
- names = new String[] {"train", "test"};
- this.dtrain = dtrain;
- this.dtest = dtest;
- }
-
- /**
- * update one iteration
- * @param iter iteration num
- * @throws org.dmlc.xgboost4j.util.XGBoostError native error
- */
- public void update(int iter) throws XGBoostError {
- booster.update(dtrain, iter);
- }
-
- /**
- * update one iteration
- * @param iter iteration num
- * @param obj customized objective
- * @throws org.dmlc.xgboost4j.util.XGBoostError native error
- */
- public void update(int iter, IObjective obj) throws XGBoostError {
- booster.update(dtrain, iter, obj);
- }
-
- /**
- * evaluation
- * @param iter iteration num
- * @return evaluation
- * @throws org.dmlc.xgboost4j.util.XGBoostError native error
- */
- public String eval(int iter) throws XGBoostError {
- return booster.evalSet(dmats, names, iter);
- }
-
- /**
- * evaluation
- * @param iter iteration num
- * @param eval customized eval
- * @return evaluation
- * @throws org.dmlc.xgboost4j.util.XGBoostError native error
- */
- public String eval(int iter, IEvaluation eval) throws XGBoostError {
- return booster.evalSet(dmats, names, iter, eval);
- }
-}
diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/ErrorHandle.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/ErrorHandle.java
deleted file mode 100644
index aad9f6174..000000000
--- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/ErrorHandle.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-package org.dmlc.xgboost4j.util;
-
-import java.io.IOException;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.dmlc.xgboost4j.wrapper.XgboostJNI;
-
-/**
- * Error handle for Xgboost.
- */
-public class ErrorHandle {
- private static final Log logger = LogFactory.getLog(ErrorHandle.class);
-
- //load native library
- static {
- try {
- Initializer.InitXgboost();
- } catch (IOException ex) {
- logger.error("load native library failed.");
- logger.error(ex);
- }
- }
-
- /**
- * Check the return value of C API.
- * @param ret return value of the xgboostJNI C API call
- * @throws org.dmlc.xgboost4j.util.XGBoostError native error
- */
- public static void checkCall(int ret) throws XGBoostError {
- if(ret != 0) {
- throw new XGBoostError(XgboostJNI.XGBGetLastError());
- }
- }
-}
diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Initializer.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Initializer.java
deleted file mode 100644
index 5dbbe4b28..000000000
--- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Initializer.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-package org.dmlc.xgboost4j.util;
-
-import java.io.IOException;
-import java.lang.reflect.Field;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-/**
- * class to load native library
- * @author hzx
- */
-public class Initializer {
- private static final Log logger = LogFactory.getLog(Initializer.class);
-
- static boolean initialized = false;
- public static final String nativePath = "./lib";
- public static final String nativeResourcePath = "/lib/";
- public static final String[] libNames = new String[] {"xgboost4j"};
-
- public static synchronized void InitXgboost() throws IOException {
- if(initialized == false) {
- for(String libName: libNames) {
- smartLoad(libName);
- }
- initialized = true;
- }
- }
-
- /**
- * load native library, this method will first try to load library from java.library.path, then try to load library in jar package.
- * @param libName library path
- * @throws IOException exception
- */
- private static void smartLoad(String libName) throws IOException {
- addNativeDir(nativePath);
- try {
- System.loadLibrary(libName);
- }
- catch (UnsatisfiedLinkError e) {
- try {
- NativeUtils.loadLibraryFromJar(nativeResourcePath + System.mapLibraryName(libName));
- }
- catch (IOException e1) {
- throw e1;
- }
- }
- }
-
- /**
- * Add libPath to java.library.path so that native libraries in libPath will be loaded properly
- * @param libPath library path
- * @throws IOException exception
- */
- public static void addNativeDir(String libPath) throws IOException {
- try {
- Field field = ClassLoader.class.getDeclaredField("usr_paths");
- field.setAccessible(true);
- String[] paths = (String[]) field.get(null);
- for (String path : paths) {
- if (libPath.equals(path)) {
- return;
- }
- }
- String[] tmp = new String[paths.length+1];
- System.arraycopy(paths,0,tmp,0,paths.length);
- tmp[paths.length] = libPath;
- field.set(null, tmp);
- } catch (IllegalAccessException e) {
- logger.error(e.getMessage());
- throw new IOException("Failed to get permissions to set library path");
- } catch (NoSuchFieldException e) {
- logger.error(e.getMessage());
- throw new IOException("Failed to get field handle to set library path");
- }
- }
-}
diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/NativeUtils.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/NativeUtils.java
deleted file mode 100644
index 77e299fa2..000000000
--- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/NativeUtils.java
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-package org.dmlc.xgboost4j.util;
-
-import java.io.File;
-import java.io.FileNotFoundException;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-
-
-/**
- * Simple library class for working with JNI (Java Native Interface)
- *
- * See
- * http://adamheinrich.com/2012/how-to-load-native-jni-library-from-jar
- *
- * Author Adam Heirnich <adam@adamh.cz>, http://www.adamh.cz
- */
-public class NativeUtils {
-
- /**
- * Private constructor - this class will never be instanced
- */
- private NativeUtils() {
- }
-
- /**
- * Loads library from current JAR archive
- *
- * The file from JAR is copied into system temporary directory and then loaded.
- * The temporary file is deleted after exiting.
- * Method uses String as filename because the pathname is "abstract", not system-dependent.
- *
- * The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to {@code path}.
- *
- * @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext
- * @throws IOException If temporary file creation or read/write operation fails
- * @throws IllegalArgumentException If source file (param path) does not exist
- * @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than three characters
- */
- public static void loadLibraryFromJar(String path) throws IOException {
-
- if (!path.startsWith("/")) {
- throw new IllegalArgumentException("The path has to be absolute (start with '/').");
- }
-
- // Obtain filename from path
- String[] parts = path.split("/");
- String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
-
- // Split filename into prefix and suffix (extension)
- String prefix = "";
- String suffix = null;
- if (filename != null) {
- parts = filename.split("\\.", 2);
- prefix = parts[0];
- suffix = (parts.length > 1) ? "."+parts[parts.length - 1] : null; // Thanks, davs! :-)
- }
-
- // Check if the filename is okay
- if (filename == null || prefix.length() < 3) {
- throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
- }
-
- // Prepare temporary file
- File temp = File.createTempFile(prefix, suffix);
- temp.deleteOnExit();
-
- if (!temp.exists()) {
- throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
- }
-
- // Prepare buffer for data copying
- byte[] buffer = new byte[1024];
- int readBytes;
-
- // Open and check input stream
- InputStream is = NativeUtils.class.getResourceAsStream(path);
- if (is == null) {
- throw new FileNotFoundException("File " + path + " was not found inside JAR.");
- }
-
- // Open output stream and copy data between source file in JAR and the temporary file
- OutputStream os = new FileOutputStream(temp);
- try {
- while ((readBytes = is.read(buffer)) != -1) {
- os.write(buffer, 0, readBytes);
- }
- } finally {
- // If read/write fails, close streams safely before throwing an exception
- os.close();
- is.close();
- }
-
- // Finally, load the library
- System.load(temp.getAbsolutePath());
- }
-}
diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java
deleted file mode 100644
index 994a8b4ac..000000000
--- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java
+++ /dev/null
@@ -1,238 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-package org.dmlc.xgboost4j.util;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.dmlc.xgboost4j.IEvaluation;
-import org.dmlc.xgboost4j.Booster;
-import org.dmlc.xgboost4j.DMatrix;
-import org.dmlc.xgboost4j.IObjective;
-
-
-/**
- * trainer for xgboost
- * @author hzx
- */
-public class Trainer {
- private static final Log logger = LogFactory.getLog(Trainer.class);
-
- /**
- * Train a booster with given parameters.
- * @param params Booster params.
- * @param dtrain Data to be trained.
- * @param round Number of boosting iterations.
- * @param watchs a group of items to be evaluated during training; this allows the user to watch performance on the validation set.
- * @param obj customized objective (set to null if not used)
- * @param eval customized evaluation (set to null if not used)
- * @return trained booster
- * @throws org.dmlc.xgboost4j.util.XGBoostError native error
- */
- public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
- Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) throws XGBoostError {
-
- //collect eval matrices
- String[] evalNames;
- DMatrix[] evalMats;
- List<String> names = new ArrayList<>();
- List<DMatrix> mats = new ArrayList<>();
-
- for(Entry<String, DMatrix> evalEntry : watchs) {
- names.add(evalEntry.getKey());
- mats.add(evalEntry.getValue());
- }
-
- evalNames = names.toArray(new String[names.size()]);
- evalMats = mats.toArray(new DMatrix[mats.size()]);
-
- //collect all data matrices
- DMatrix[] allMats;
- if(evalMats!=null && evalMats.length>0) {
- allMats = new DMatrix[evalMats.length+1];
- allMats[0] = dtrain;
- System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
- }
- else {
- allMats = new DMatrix[1];
- allMats[0] = dtrain;
- }
-
- //initialize booster
- Booster booster = new Booster(params, allMats);
-
- //begin to train
- for(int iter=0; iter<round; iter++) {
- if(obj != null) {
- booster.update(dtrain, iter, obj);
- }
- else {
- booster.update(dtrain, iter);
- }
-
- //evaluation
- if(evalMats.length>0) {
- String evalInfo;
- if(eval != null) {
- evalInfo = booster.evalSet(evalMats, evalNames, iter, eval);
- }
- else {
- evalInfo = booster.evalSet(evalMats, evalNames, iter);
- }
- logger.info(evalInfo);
- }
- }
- return booster;
- }
-
- /**
- * Cross-validation with given parameters.
- * @param params Booster params.
- * @param data Data to be trained.
- * @param round Number of boosting iterations.
- * @param nfold Number of folds in CV.
- * @param metrics Evaluation metrics to be watched in CV.
- * @param obj customized objective (set to null if not used)
- * @param eval customized evaluation (set to null if not used)
- * @return evaluation history
- * @throws org.dmlc.xgboost4j.util.XGBoostError native error
- */
- public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) throws XGBoostError {
- CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
- String[] evalHist = new String[round];
- String[] results = new String[cvPacks.length];
- for(int i=0; i> params, String[] evalMetrics) throws XGBoostError {
- List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
- int step = samples.size()/nfold;
- int[] testSlice = new int[step];
- int[] trainSlice = new int[samples.size()-step];
- int testid, trainid;
- CVPack[] cvPacks = new CVPack[nfold];
- for(int i=0; i(i*step) && j<(i*step+step) && testid genRandPermutationNums(int start, int end) {
- List<Integer> samples = new ArrayList<>();
- for(int i=start; i > cvMap = new HashMap<>();
- String aggResult = results[0].split("\t")[0];
- for(String result : results) {
- String[] items = result.split("\t");
- for(int i=1; i<items.length; i++) {
- String[] kv = items[i].split(":");
- String key = kv[0];
- Float value = Float.valueOf(kv[1]);
- if(!cvMap.containsKey(key)) {
- cvMap.put(key, new ArrayList<Float>());
- }
- cvMap.get(key).add(value);
- }
- }
-
- for(String key : cvMap.keySet()) {
- float value = 0f;
- for(Float tvalue : cvMap.get(key)) {
- value += tvalue;
- }
- value /= cvMap.get(key).size();
- aggResult += String.format("\tcv-%s:%f", key, value);
- }
-
- return aggResult;
- }
-}
diff --git a/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java b/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java
deleted file mode 100644
index 20c64b316..000000000
--- a/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-package org.dmlc.xgboost4j;
-
-import java.util.AbstractMap;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
-import junit.framework.TestCase;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.dmlc.xgboost4j.util.Trainer;
-import org.dmlc.xgboost4j.util.XGBoostError;
-import org.junit.Test;
-
-/**
- * test cases for Booster
- * @author hzx
- */
-public class BoosterTest {
- public static class EvalError implements IEvaluation {
- private static final Log logger = LogFactory.getLog(EvalError.class);
-
- String evalMetric = "custom_error";
-
- public EvalError() {
- }
-
- @Override
- public String getMetric() {
- return evalMetric;
- }
-
- @Override
- public float eval(float[][] predicts, DMatrix dmat) {
- float error = 0f;
- float[] labels;
- try {
- labels = dmat.getLabel();
- } catch (XGBoostError ex) {
- logger.error(ex);
- return -1f;
- }
- int nrow = predicts.length;
- for(int i=0; i<nrow; i++) {
- if(labels[i]==0f && predicts[i][0]>0) {
- error++;
- }
- else if(labels[i]==1f && predicts[i][0]<=0) {
- error++;
- }
- }
-
- return error/labels.length;
- }
- }
-
- @Test
- public void testBoosterBasic() throws XGBoostError {
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
- DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
-
- //set params
- Map<String, Object> paramMap = new HashMap<String, Object>() {
- {
- put("eta", 1.0);
- put("max_depth", 2);
- put("silent", 1);
- put("objective", "binary:logistic");
- }
- };
- Iterable<Entry<String, Object>> param = paramMap.entrySet();
-
- //set watchList
- List<Entry<String, DMatrix>> watchs = new ArrayList<>();
- watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
- watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
-
- //set round
- int round = 2;
-
- //train a boost model
- Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
-
- //predict raw output
- float[][] predicts = booster.predict(testMat, true);
-
- //eval
- IEvaluation eval = new EvalError();
- //error must be less than 0.1
- TestCase.assertTrue(eval.eval(predicts, testMat)<0.1f);
-
- //test dump model
-
- }
-
- /**
- * test cross validation
- * @throws XGBoostError
- */
- @Test
- public void testCV() throws XGBoostError {
- //load train mat
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
-
- //set params
- Map<String, Object> param = new HashMap<String, Object>() {
- {
- put("eta", 1.0);
- put("max_depth", 3);
- put("silent", 1);
- put("nthread", 6);
- put("objective", "binary:logistic");
- put("gamma", 1.0);
- put("eval_metric", "error");
- }
- };
-
- //do 5-fold cross validation
- int round = 2;
- int nfold = 5;
- //set additional eval_metrics
- String[] metrics = null;
-
- String[] evalHist = Trainer.crossValiation(param.entrySet(), trainMat, round, nfold, metrics, null, null);
- }
-}
diff --git a/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java b/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java
deleted file mode 100644
index 343dd3ed9..000000000
--- a/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-package org.dmlc.xgboost4j;
-
-import java.util.Arrays;
-import java.util.Random;
-import junit.framework.TestCase;
-import org.dmlc.xgboost4j.util.XGBoostError;
-import org.junit.Test;
-
-/**
- * test cases for DMatrix
- * @author hzx
- */
-public class DMatrixTest {
-
- @Test
- public void testCreateFromFile() throws XGBoostError {
- //create DMatrix from file
- DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
- //get label
- float[] labels = dmat.getLabel();
- //check length
- TestCase.assertTrue(dmat.rowNum()==labels.length);
- //set weights
- float[] weights = Arrays.copyOf(labels, labels.length);
- dmat.setWeight(weights);
- float[] dweights = dmat.getWeight();
- TestCase.assertTrue(Arrays.equals(weights, dweights));
- }
-
- @Test
- public void testCreateFromCSR() throws XGBoostError {
- //create Matrix from csr format sparse Matrix and labels
- /**
- * sparse matrix
- * 1 0 2 3 0
- * 4 0 2 3 5
- * 3 1 2 5 0
- */
- float[] data = new float[] {1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
- int[] colIndex = new int[] {0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
- long[] rowHeaders = new long[] {0, 3, 7, 11};
- DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
- //check row num
- System.out.println(dmat1.rowNum());
- TestCase.assertTrue(dmat1.rowNum()==3);
- //test set label
- float[] label1 = new float[] {1, 0, 1};
- dmat1.setLabel(label1);
- float[] label2 = dmat1.getLabel();
- TestCase.assertTrue(Arrays.equals(label1, label2));
- }
-
- @Test
- public void testCreateFromDenseMatrix() throws XGBoostError {
- //create DMatrix from 10*5 dense matrix
- int nrow = 10;
- int ncol = 5;
- float[] data0 = new float[nrow*ncol];
- //put random nums
- Random random = new Random();
- for(int i=0; i
-/* Header for class org_dmlc_xgboost4j_wrapper_XgboostJNI */
-
-#ifndef _Included_org_dmlc_xgboost4j_wrapper_XgboostJNI
-#define _Included_org_dmlc_xgboost4j_wrapper_XgboostJNI
-#ifdef __cplusplus
-extern "C" {
-#endif
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBGetLastError
- * Signature: ()Ljava/lang/String;
- */
-JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError
- (JNIEnv *, jclass);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixCreateFromFile
- * Signature: (Ljava/lang/String;I[J)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
- (JNIEnv *, jclass, jstring, jint, jlongArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixCreateFromCSR
- * Signature: ([J[I[F[J)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
- (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixCreateFromCSC
- * Signature: ([J[I[F[J)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
- (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixCreateFromMat
- * Signature: ([FIIF[J)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
- (JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixSliceDMatrix
- * Signature: (J[I[J)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
- (JNIEnv *, jclass, jlong, jintArray, jlongArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixFree
- * Signature: (J)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
- (JNIEnv *, jclass, jlong);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixSaveBinary
- * Signature: (JLjava/lang/String;I)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
- (JNIEnv *, jclass, jlong, jstring, jint);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixSetFloatInfo
- * Signature: (JLjava/lang/String;[F)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
- (JNIEnv *, jclass, jlong, jstring, jfloatArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixSetUIntInfo
- * Signature: (JLjava/lang/String;[I)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
- (JNIEnv *, jclass, jlong, jstring, jintArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixSetGroup
- * Signature: (J[I)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
- (JNIEnv *, jclass, jlong, jintArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixGetFloatInfo
- * Signature: (JLjava/lang/String;[[F)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
- (JNIEnv *, jclass, jlong, jstring, jobjectArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixGetUIntInfo
- * Signature: (JLjava/lang/String;[[I)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
- (JNIEnv *, jclass, jlong, jstring, jobjectArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGDMatrixNumRow
- * Signature: (J[J)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
- (JNIEnv *, jclass, jlong, jlongArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterCreate
- * Signature: ([J[J)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
- (JNIEnv *, jclass, jlongArray, jlongArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterFree
- * Signature: (J)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
- (JNIEnv *, jclass, jlong);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterSetParam
- * Signature: (JLjava/lang/String;Ljava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
- (JNIEnv *, jclass, jlong, jstring, jstring);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterUpdateOneIter
- * Signature: (JIJ)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
- (JNIEnv *, jclass, jlong, jint, jlong);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterBoostOneIter
- * Signature: (JJ[F[F)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
- (JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterEvalOneIter
- * Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
- (JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterPredict
- * Signature: (JJIJ[[F)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
- (JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterLoadModel
- * Signature: (JLjava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
- (JNIEnv *, jclass, jlong, jstring);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterSaveModel
- * Signature: (JLjava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
- (JNIEnv *, jclass, jlong, jstring);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterLoadModelFromBuffer
- * Signature: (JJJ)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
- (JNIEnv *, jclass, jlong, jlong, jlong);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterGetModelRaw
- * Signature: (J[Ljava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
- (JNIEnv *, jclass, jlong, jobjectArray);
-
-/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
- * Method: XGBoosterDumpModel
- * Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
- */
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
- (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
-
-#ifdef __cplusplus
-}
-#endif
-#endif
diff --git a/java/README.md b/jvm-packages/README.md
similarity index 100%
rename from java/README.md
rename to jvm-packages/README.md
diff --git a/jvm-packages/checkstyle-suppressions.xml b/jvm-packages/checkstyle-suppressions.xml
new file mode 100644
index 000000000..21550e139
--- /dev/null
+++ b/jvm-packages/checkstyle-suppressions.xml
@@ -0,0 +1,33 @@
+
+
+
+
+
+
+
+
+
diff --git a/jvm-packages/checkstyle.xml b/jvm-packages/checkstyle.xml
new file mode 100644
index 000000000..9583ec282
--- /dev/null
+++ b/jvm-packages/checkstyle.xml
@@ -0,0 +1,169 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/java/create_wrap.bat b/jvm-packages/create_jni.bat
similarity index 98%
rename from java/create_wrap.bat
rename to jvm-packages/create_jni.bat
index ce4d99327..cbc0681c1 100644
--- a/java/create_wrap.bat
+++ b/jvm-packages/create_jni.bat
@@ -17,4 +17,4 @@ exit
:end
echo "source library not found, please build it first from ..\windows\xgboost.sln"
pause
- exit
\ No newline at end of file
+ exit
diff --git a/java/create_wrap.sh b/jvm-packages/create_jni.sh
similarity index 81%
rename from java/create_wrap.sh
rename to jvm-packages/create_jni.sh
index fb3b1f149..13e6a8556 100755
--- a/java/create_wrap.sh
+++ b/jvm-packages/create_jni.sh
@@ -16,8 +16,8 @@ if [ $(uname) == "Darwin" ]; then
fi
cd ..
-make java no_omp=${dis_omp}
-cd java
+make jvm no_omp=${dis_omp}
+cd jvm-packages
echo "move native lib"
libPath="xgboost4j/src/main/resources/lib"
@@ -26,7 +26,7 @@ if [ ! -d "$libPath" ]; then
fi
rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
-mv libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
+mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
popd > /dev/null
echo "complete"
diff --git a/java/doc/xgboost4j.md b/jvm-packages/doc/xgboost4j.md
similarity index 100%
rename from java/doc/xgboost4j.md
rename to jvm-packages/doc/xgboost4j.md
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
new file mode 100644
index 000000000..5ec221175
--- /dev/null
+++ b/jvm-packages/pom.xml
@@ -0,0 +1,117 @@
+
+
+ 4.0.0
+
+ org.dmlc
+ xgboostjvm
+ 0.1
+ pom
+
+ UTF-8
+ UTF-8
+ 1.7
+ 1.7
+ 3.3.9
+ 2.11.7
+ 2.11
+
+
+ xgboost4j
+ xgboost4j-demo
+
+
+
+
+ org.scalastyle
+ scalastyle-maven-plugin
+ 0.8.0
+
+ false
+ true
+ true
+ ${basedir}/src/main/scala
+ ${basedir}/src/test/scala
+ scalastyle-config.xml
+ UTF-8
+
+
+
+ checkstyle
+ validate
+
+ check
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-checkstyle-plugin
+ 2.17
+
+ checkstyle.xml
+ true
+
+
+
+ checkstyle
+ validate
+
+ check
+
+
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 3.2.2
+
+
+ compile
+
+ compile
+
+ compile
+
+
+ test-compile
+
+ testCompile
+
+ test-compile
+
+
+ process-resources
+
+ compile
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+ 2.19.1
+
+ -Djava.library.path=lib/
+
+
+
+
+
+
+ commons-logging
+ commons-logging
+ 1.2
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ 2.2.6
+ test
+
+
+
diff --git a/jvm-packages/scalastyle-config.xml b/jvm-packages/scalastyle-config.xml
new file mode 100644
index 000000000..27bb4fa8a
--- /dev/null
+++ b/jvm-packages/scalastyle-config.xml
@@ -0,0 +1,291 @@
+
+
+
+
+ Scalastyle standard configuration
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW
+
+
+
+
+
+ ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW
+
+
+
+
+
+
+
+
+ ^FunSuite[A-Za-z]*$
+ Tests must extend org.apache.spark.SparkFunSuite instead.
+
+
+
+
+ ^println$
+
+
+
+
+ @VisibleForTesting
+
+
+
+
+ Runtime\.getRuntime\.addShutdownHook
+
+
+
+
+ mutable\.SynchronizedBuffer
+
+
+
+
+ Class\.forName
+
+
+
+
+
+ JavaConversions
+ Instead of importing implicits in scala.collection.JavaConversions._, import
+ scala.collection.JavaConverters._ and use .asScala / .asJava methods
+
+
+
+
+ java,scala,3rdParty,spark
+ javax?\..*
+ scala\..*
+ (?!org\.apache\.spark\.).*
+ org\.apache\.spark\..*
+
+
+
+
+
+ COMMA
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 800>
+
+
+
+
+ 30
+
+
+
+
+ 10
+
+
+
+
+ 50
+
+
+
+
+
+
+
+
+
+
+ -1,0,1,2,3
+
+
+
diff --git a/java/xgboost4j-demo/LICENSE b/jvm-packages/xgboost4j-demo/LICENSE
similarity index 100%
rename from java/xgboost4j-demo/LICENSE
rename to jvm-packages/xgboost4j-demo/LICENSE
diff --git a/java/xgboost4j-demo/README.md b/jvm-packages/xgboost4j-demo/README.md
similarity index 100%
rename from java/xgboost4j-demo/README.md
rename to jvm-packages/xgboost4j-demo/README.md
diff --git a/jvm-packages/xgboost4j-demo/pom.xml b/jvm-packages/xgboost4j-demo/pom.xml
new file mode 100644
index 000000000..d8e679b78
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/pom.xml
@@ -0,0 +1,26 @@
+
+
+ 4.0.0
+
+ org.dmlc
+ xgboostjvm
+ 0.1
+
+ xgboost4j-demo
+ 0.1
+ jar
+
+
+ org.dmlc
+ xgboost4j
+ 0.1
+
+
+ org.apache.commons
+ commons-lang3
+ 3.4
+
+
+
\ No newline at end of file
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BasicWalkThrough.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BasicWalkThrough.java
new file mode 100644
index 000000000..af5dd8a86
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BasicWalkThrough.java
@@ -0,0 +1,120 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.demo;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+
+import ml.dmlc.xgboost4j.Booster;
+import ml.dmlc.xgboost4j.DMatrix;
+import ml.dmlc.xgboost4j.XGBoost;
+import ml.dmlc.xgboost4j.XGBoostError;
+import ml.dmlc.xgboost4j.demo.util.DataLoader;
+
+/**
+ * a simple example of java wrapper for xgboost
+ *
+ * @author hzx
+ */
+public class BasicWalkThrough {
+ public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
+ if (fPredicts.length != sPredicts.length) {
+ return false;
+ }
+
+ for (int i = 0; i < fPredicts.length; i++) {
+ if (!Arrays.equals(fPredicts[i], sPredicts[i])) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+
+ public static void main(String[] args) throws IOException, XGBoostError {
+ // load file from text file, also binary buffer generated by xgboost4j
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+
+ HashMap<String, Object> params = new HashMap<String, Object>();
+ params.put("eta", 1.0);
+ params.put("max_depth", 2);
+ params.put("silent", 1);
+ params.put("objective", "binary:logistic");
+
+
+ HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
+ watches.put("train", trainMat);
+ watches.put("test", testMat);
+
+ //set round
+ int round = 2;
+
+ //train a boost model
+ Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
+
+ //predict
+ float[][] predicts = booster.predict(testMat);
+
+ //save model to modelPath
+ File file = new File("./model");
+ if (!file.exists()) {
+ file.mkdirs();
+ }
+
+ String modelPath = "./model/xgb.model";
+ booster.saveModel(modelPath);
+
+ //dump model
+ booster.dumpModel("./model/dump.raw.txt", false);
+
+ //dump model with feature map
+ booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
+
+ //save dmatrix into binary buffer
+ testMat.saveBinary("./model/dtest.buffer");
+
+ //reload model and data
+ Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model");
+ DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
+ float[][] predicts2 = booster2.predict(testMat2);
+
+
+ //check the two predicts
+ System.out.println(checkPredicts(predicts, predicts2));
+
+ System.out.println("start build dmatrix from csr sparse data ...");
+ //build dmatrix from CSR Sparse Matrix
+ DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
+
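+ //in CSR format, rowHeaders holds the cumulative number of stored values at the start of each row,
+ //colIndex the column index of each value, and data the values themselves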
+ DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
+ DMatrix.SparseType.CSR);
+ trainMat2.setLabel(spData.labels);
+
+ //specify watchList
+ HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
+ watches2.put("train", trainMat2);
+ watches2.put("test", testMat2);
+ Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null);
+ float[][] predicts3 = booster3.predict(testMat2);
+
+ //check predicts
+ System.out.println(checkPredicts(predicts, predicts3));
+ }
+}
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BoostFromPrediction.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BoostFromPrediction.java
new file mode 100644
index 000000000..335efc2d7
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BoostFromPrediction.java
@@ -0,0 +1,62 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.demo;
+
+import java.util.HashMap;
+
+import ml.dmlc.xgboost4j.Booster;
+import ml.dmlc.xgboost4j.DMatrix;
+import ml.dmlc.xgboost4j.XGBoost;
+import ml.dmlc.xgboost4j.XGBoostError;
+
+/**
+ * example of starting from an initial base prediction
+ *
+ * @author hzx
+ */
+public class BoostFromPrediction {
+ public static void main(String[] args) throws XGBoostError {
+ System.out.println("start running example to start from a initial prediction");
+
+ // load file from text file, also binary buffer generated by xgboost4j
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+
+ //specify parameters
+ HashMap<String, Object> params = new HashMap<String, Object>();
+ params.put("eta", 1.0);
+ params.put("max_depth", 2);
+ params.put("silent", 1);
+ params.put("objective", "binary:logistic");
+
+ //specify watchList
+ HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
+ watches.put("train", trainMat);
+ watches.put("test", testMat);
+
+ //train xgboost for 1 round
+ Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null);
+
+ float[][] trainPred = booster.predict(trainMat, true);
+ float[][] testPred = booster.predict(testMat, true);
+
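+ //set the raw (margin) predictions as the base margin of each DMatrix, so the next
+ //training run continues boosting from these predictions instead of starting from zero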
+ trainMat.setBaseMargin(trainPred);
+ testMat.setBaseMargin(testPred);
+
+ System.out.println("result of running from initial prediction");
+ Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null);
+ }
+}
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CrossValidation.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CrossValidation.java
new file mode 100644
index 000000000..115b1dc5b
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CrossValidation.java
@@ -0,0 +1,54 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.demo;
+
+import java.io.IOException;
+import java.util.HashMap;
+
+import ml.dmlc.xgboost4j.DMatrix;
+import ml.dmlc.xgboost4j.XGBoost;
+import ml.dmlc.xgboost4j.XGBoostError;
+
+/**
+ * an example of cross validation
+ *
+ * @author hzx
+ */
+public class CrossValidation {
+ public static void main(String[] args) throws IOException, XGBoostError {
+ //load train mat
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+
+ //set params
+ HashMap<String, Object> params = new HashMap<String, Object>();
+
+ params.put("eta", 1.0);
+ params.put("max_depth", 3);
+ params.put("silent", 1);
+ params.put("nthread", 6);
+ params.put("objective", "binary:logistic");
+ params.put("gamma", 1.0);
+ params.put("eval_metric", "error");
+
+ //do 5-fold cross validation
+ int round = 2;
+ int nfold = 5;
+ //set additional eval_metrics
+ String[] metrics = null;
+
+ String[] evalHist = XGBoost.crossValiation(params, trainMat, round, nfold, metrics, null, null);
+ }
+}
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CustomObjective.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CustomObjective.java
new file mode 100644
index 000000000..be09fd701
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CustomObjective.java
@@ -0,0 +1,167 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.demo;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import ml.dmlc.xgboost4j.*;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+/**
+ * an example of a user-defined objective and evaluation function
+ * NOTE: when you use a customized loss function, the default prediction value is the margin.
+ * This may make the built-in evaluation metrics not function properly;
+ * for example, with logistic loss the prediction is the score before the logistic transformation,
+ * while the built-in evaluation error assumes the input is after the logistic transformation.
+ * Keep this in mind when you use the customization, and you may need to write a customized
+ * evaluation function.
+ *
+ * @author hzx
+ */
+public class CustomObjective {
+ /**
+ * log-likelihood loss objective function
+ */
+ public static class LogRegObj implements IObjective {
+ private static final Log logger = LogFactory.getLog(LogRegObj.class);
+
+ /**
+ * simple sigmoid function
+ *
+ * @param input input value
+ * @return sigmoid of the input; note this function does not care about numerical stability and
+ * is only used as an example
+ */
+ public float sigmoid(float input) {
+ float val = (float) (1 / (1 + Math.exp(-input)));
+ return val;
+ }
+
+ public float[][] transform(float[][] predicts) {
+ int nrow = predicts.length;
+ float[][] transPredicts = new float[nrow][1];
+
+ for (int i = 0; i < nrow; i++) {
+ transPredicts[i][0] = sigmoid(predicts[i][0]);
+ }
+
+ return transPredicts;
+ }
+
+ @Override
+ public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
+ int nrow = predicts.length;
+ List<float[]> gradients = new ArrayList<float[]>();
+ float[] labels;
+ try {
+ labels = dtrain.getLabel();
+ } catch (XGBoostError ex) {
+ logger.error(ex);
+ return null;
+ }
+ float[] grad = new float[nrow];
+ float[] hess = new float[nrow];
+
+ float[][] transPredicts = transform(predicts);
+
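+ //for logistic loss, gradient = p - label and hessian = p * (1 - p),
+ //where p is the sigmoid-transformed prediction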
+ for (int i = 0; i < nrow; i++) {
+ float predict = transPredicts[i][0];
+ grad[i] = predict - labels[i];
+ hess[i] = predict * (1 - predict);
+ }
+
+ gradients.add(grad);
+ gradients.add(hess);
+ return gradients;
+ }
+ }
+
+ /**
+ * user-defined eval function.
+ * NOTE: when you use a customized loss function, the default prediction value is the margin.
+ * This may make the built-in evaluation metrics not function properly;
+ * for example, with logistic loss the prediction is the score before the logistic transformation,
+ * while the built-in evaluation error assumes the input is after the logistic transformation.
+ * Keep this in mind when you use the customization, and you may need to write a customized
+ * evaluation function.
+ */
+ public static class EvalError implements IEvaluation {
+ private static final Log logger = LogFactory.getLog(EvalError.class);
+
+ String evalMetric = "custom_error";
+
+ public EvalError() {
+ }
+
+ @Override
+ public String getMetric() {
+ return evalMetric;
+ }
+
+ @Override
+ public float eval(float[][] predicts, DMatrix dmat) {
+ float error = 0f;
+ float[] labels;
+ try {
+ labels = dmat.getLabel();
+ } catch (XGBoostError ex) {
+ logger.error(ex);
+ return -1f;
+ }
+ int nrow = predicts.length;
+ for (int i = 0; i < nrow; i++) {
+ if (labels[i] == 0f && predicts[i][0] > 0) {
+ error++;
+ } else if (labels[i] == 1f && predicts[i][0] <= 0) {
+ error++;
+ }
+ }
+
+ return error / labels.length;
+ }
+ }
+
+ public static void main(String[] args) throws XGBoostError {
+ //load train mat (svmlight format)
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ //load valid mat (svmlight format)
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+
+ HashMap<String, Object> params = new HashMap<String, Object>();
+ params.put("eta", 1.0);
+ params.put("max_depth", 2);
+ params.put("silent", 1);
+
+
+ //set round
+ int round = 2;
+
+ //specify watchList
+ HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
+ watches.put("train", trainMat);
+ watches.put("test", testMat);
+
+ //user define obj and eval
+ IObjective obj = new LogRegObj();
+ IEvaluation eval = new EvalError();
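+ //passing them to train() below makes xgboost call getGradient() and eval() each boosting round
+ //in place of the built-in objective and evaluation metric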
+
+ //train a booster
+ System.out.println("begin to train the booster model");
+ Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval);
+ }
+}
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/ExternalMemory.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/ExternalMemory.java
new file mode 100644
index 000000000..095382953
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/ExternalMemory.java
@@ -0,0 +1,61 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.demo;
+
+import java.util.HashMap;
+
+import ml.dmlc.xgboost4j.Booster;
+import ml.dmlc.xgboost4j.DMatrix;
+import ml.dmlc.xgboost4j.XGBoost;
+import ml.dmlc.xgboost4j.XGBoostError;
+
+/**
+ * simple example for using external memory version
+ *
+ * @author hzx
+ */
+public class ExternalMemory {
+ public static void main(String[] args) throws XGBoostError {
+ //this is the only difference: add a # followed by a cache prefix name
+ //several cache files with this prefix will be generated
+ //currently only conversion from a libsvm file is supported
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
+
+ //specify parameters
+ HashMap<String, Object> params = new HashMap<String, Object>();
+ params.put("eta", 1.0);
+ params.put("max_depth", 2);
+ params.put("silent", 1);
+ params.put("objective", "binary:logistic");
+
+ //performance notice: set nthread to the number of physical cores of your cpu
+ //some cpus offer two threads per core; for example, on a 4-core cpu with 8 threads,
+ // set nthread=4
+ //param.put("nthread", num_real_cpu);
+
+ //specify watchList
+ HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
+ watches.put("train", trainMat);
+ watches.put("test", testMat);
+
+ //set round
+ int round = 2;
+
+ //train a boost model
+ Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
+ }
+}
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/GeneralizedLinearModel.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/GeneralizedLinearModel.java
new file mode 100644
index 000000000..8fae69032
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/GeneralizedLinearModel.java
@@ -0,0 +1,70 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.demo;
+
+import ml.dmlc.xgboost4j.Booster;
+import ml.dmlc.xgboost4j.DMatrix;
+import ml.dmlc.xgboost4j.XGBoost;
+import ml.dmlc.xgboost4j.XGBoostError;
+import ml.dmlc.xgboost4j.demo.util.CustomEval;
+
+import java.util.HashMap;
+
+/**
+ * this is an example of fitting a generalized linear model in xgboost
+ * basically, we are using a linear model instead of trees as our booster
+ *
+ * @author hzx
+ */
+public class GeneralizedLinearModel {
+ public static void main(String[] args) throws XGBoostError {
+ // load file from text file, also binary buffer generated by xgboost4j
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+
+ //specify parameters
+ //change booster to gblinear, so that we are fitting a linear model
+ // alpha is the L1 regularizer
+ //lambda is the L2 regularizer
+ //you can also set lambda_bias which is L2 regularizer on the bias term
+ HashMap<String, Object> params = new HashMap<String, Object>();
+ params.put("alpha", 0.0001);
+ params.put("silent", 1);
+ params.put("objective", "binary:logistic");
+ params.put("booster", "gblinear");
+
+ //normally, you do not need to set eta (step_size)
+ //XGBoost uses a parallel coordinate descent algorithm (shotgun),
+ //so parallelization could affect convergence in certain cases;
+ //setting eta to a smaller value, e.g. 0.5, can make the optimization more stable
+ //param.put("eta", "0.5");
+
+
+ //specify watchList
+ HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
+ watches.put("train", trainMat);
+ watches.put("test", testMat);
+
+ //train a booster
+ int round = 4;
+ Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
+
+ float[][] predicts = booster.predict(testMat);
+
+ CustomEval eval = new CustomEval();
+ System.out.println("error=" + eval.eval(predicts, testMat));
+ }
+}
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictFirstNtree.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictFirstNtree.java
new file mode 100644
index 000000000..defa437d3
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictFirstNtree.java
@@ -0,0 +1,66 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.demo;
+
+import java.util.HashMap;
+
+import ml.dmlc.xgboost4j.Booster;
+import ml.dmlc.xgboost4j.DMatrix;
+import ml.dmlc.xgboost4j.XGBoost;
+import ml.dmlc.xgboost4j.XGBoostError;
+import ml.dmlc.xgboost4j.demo.util.CustomEval;
+
+/**
+ * predict using only the first n trees
+ *
+ * @author hzx
+ */
+public class PredictFirstNtree {
+ public static void main(String[] args) throws XGBoostError {
+ // load file from text file, also binary buffer generated by xgboost4j
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+
+ //specify parameters
+ HashMap<String, Object> params = new HashMap<String, Object>();
+
+ params.put("eta", 1.0);
+ params.put("max_depth", 2);
+ params.put("silent", 1);
+ params.put("objective", "binary:logistic");
+
+
+ //specify watchList
+ HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
+ watches.put("train", trainMat);
+ watches.put("test", testMat);
+
+
+ //train a booster
+ int round = 3;
+ Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
+
+ //predict using only the first tree
+ float[][] predicts1 = booster.predict(testMat, false, 1);
+ //by default, all trees are used for prediction
+ float[][] predicts2 = booster.predict(testMat);
+
+ //use a simple evaluation class to check error result
+ CustomEval eval = new CustomEval();
+ System.out.println("error of predicts1: " + eval.eval(predicts1, testMat));
+ System.out.println("error of predicts2: " + eval.eval(predicts2, testMat));
+ }
+}
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictLeafIndices.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictLeafIndices.java
new file mode 100644
index 000000000..d18987292
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictLeafIndices.java
@@ -0,0 +1,66 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.demo;
+
+import java.util.Arrays;
+import java.util.HashMap;
+
+import ml.dmlc.xgboost4j.Booster;
+import ml.dmlc.xgboost4j.DMatrix;
+import ml.dmlc.xgboost4j.XGBoost;
+import ml.dmlc.xgboost4j.XGBoostError;
+
+/**
+ * predict leaf indices
+ *
+ * @author hzx
+ */
+public class PredictLeafIndices {
+ public static void main(String[] args) throws XGBoostError {
+ // load file from text file, also binary buffer generated by xgboost4j
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+
+ //specify parameters
+ HashMap<String, Object> params = new HashMap<String, Object>();
+ params.put("eta", 1.0);
+ params.put("max_depth", 2);
+ params.put("silent", 1);
+ params.put("objective", "binary:logistic");
+
+ //specify watchList
+ HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
+ watches.put("train", trainMat);
+ watches.put("test", testMat);
+
+
+ //train a booster
+ int round = 3;
+ Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
+
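+ //when the leaf-index flag is set, each prediction row contains the index of the leaf
+ //that the example ends up in for every evaluated tree, rather than a score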
+ //predict leaf indices using only the first 2 trees
+ float[][] leafindex = booster.predict(testMat, 2, true);
+ for (float[] leafs : leafindex) {
+ System.out.println(Arrays.toString(leafs));
+ }
+
+ //predict leaf indices using all trees
+ leafindex = booster.predict(testMat, 0, true);
+ for (float[] leafs : leafindex) {
+ System.out.println(Arrays.toString(leafs));
+ }
+ }
+}
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/CustomEval.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/CustomEval.java
new file mode 100644
index 000000000..31e841b03
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/CustomEval.java
@@ -0,0 +1,60 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.demo.util;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import ml.dmlc.xgboost4j.DMatrix;
+import ml.dmlc.xgboost4j.IEvaluation;
+import ml.dmlc.xgboost4j.XGBoostError;
+
+/**
+ * a simple custom evaluation class (classification error) used by the examples
+ *
+ * @author hzx
+ */
+public class CustomEval implements IEvaluation {
+ private static final Log logger = LogFactory.getLog(CustomEval.class);
+
+ String evalMetric = "custom_error";
+
+ @Override
+ public String getMetric() {
+ return evalMetric;
+ }
+
+ @Override
+ public float eval(float[][] predicts, DMatrix dmat) {
+ float error = 0f;
+ float[] labels;
+ try {
+ labels = dmat.getLabel();
+ } catch (XGBoostError ex) {
+ logger.error(ex);
+ return -1f;
+ }
+ int nrow = predicts.length;
+ for (int i = 0; i < nrow; i++) {
+ if (labels[i] == 0f && predicts[i][0] > 0.5) {
+ error++;
+ } else if (labels[i] == 1f && predicts[i][0] <= 0.5) {
+ error++;
+ }
+ }
+
+ return error / labels.length;
+ }
+}
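A minimal usage sketch for the class above: the same metric can be passed as the eval argument of XGBoost.train (introduced further down in this change), so the custom error is reported for every watched DMatrix at each round. The variables trainMat and testMat are assumed to be DMatrix instances loaded as in the demos earlier.

    Map<String, Object> params = new HashMap<String, Object>();
    params.put("eta", 1.0);
    params.put("max_depth", 2);
    params.put("objective", "binary:logistic");
    Map<String, DMatrix> watches = new HashMap<String, DMatrix>();
    watches.put("train", trainMat);
    watches.put("test", testMat);
    // each round then logs something like "train-custom_error:... test-custom_error:..."
    Booster booster = XGBoost.train(params, trainMat, 3, watches, null, new CustomEval());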
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/DataLoader.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/DataLoader.java
new file mode 100644
index 000000000..0dcaca8c2
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/DataLoader.java
@@ -0,0 +1,123 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.demo.util;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.*;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * util class for loading data
+ *
+ * @author hzx
+ */
+public class DataLoader {
+ public static class DenseData {
+ public float[] labels;
+ public float[] data;
+ public int nrow;
+ public int ncol;
+ }
+
+ public static class CSRSparseData {
+ public float[] labels;
+ public float[] data;
+ public long[] rowHeaders;
+ public int[] colIndex;
+ }
+
+ public static DenseData loadCSVFile(String filePath) throws IOException {
+ DenseData denseData = new DenseData();
+
+ File f = new File(filePath);
+ FileInputStream in = new FileInputStream(f);
+ BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
+
+ denseData.nrow = 0;
+ denseData.ncol = -1;
+ String line;
+    List<Float> tlabels = new ArrayList<>();
+    List<Float> tdata = new ArrayList<>();
+
+ while ((line = reader.readLine()) != null) {
+ String[] items = line.trim().split(",");
+ if (items.length == 0) {
+ continue;
+ }
+ denseData.nrow++;
+ if (denseData.ncol == -1) {
+ denseData.ncol = items.length - 1;
+ }
+
+ tlabels.add(Float.valueOf(items[items.length - 1]));
+ for (int i = 0; i < items.length - 1; i++) {
+ tdata.add(Float.valueOf(items[i]));
+ }
+ }
+
+ reader.close();
+ in.close();
+
+ denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
+ denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
+
+ return denseData;
+ }
+
+ public static CSRSparseData loadSVMFile(String filePath) throws IOException {
+ CSRSparseData spData = new CSRSparseData();
+
+    List<Float> tlabels = new ArrayList<>();
+    List<Float> tdata = new ArrayList<>();
+    List<Long> theaders = new ArrayList<>();
+    List<Integer> tindex = new ArrayList<>();
+
+ File f = new File(filePath);
+ FileInputStream in = new FileInputStream(f);
+ BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
+
+ String line;
+ long rowheader = 0;
+ theaders.add(rowheader);
+ while ((line = reader.readLine()) != null) {
+ String[] items = line.trim().split(" ");
+ if (items.length == 0) {
+ continue;
+ }
+
+ rowheader += items.length - 1;
+ theaders.add(rowheader);
+ tlabels.add(Float.valueOf(items[0]));
+
+ for (int i = 1; i < items.length; i++) {
+ String[] tup = items[i].split(":");
+ assert tup.length == 2;
+
+ tdata.add(Float.valueOf(tup[1]));
+ tindex.add(Integer.valueOf(tup[0]));
+ }
+ }
+
+ spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
+ spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
+ spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
+ spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
+
+ return spData;
+ }
+}
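A minimal sketch of how the loader above is typically combined with the CSR constructor of DMatrix defined later in this change; the agaricus path is the same one the demos use, and setLabel attaches the parsed labels.

    // parse a libsvm file into CSR arrays, then hand them to the native DMatrix
    DataLoader.CSRSparseData spData =
        DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
    DMatrix trainMat = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
        DMatrix.SparseType.CSR);
    trainMat.setLabel(spData.labels);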
diff --git a/java/xgboost4j/LICENSE b/jvm-packages/xgboost4j/LICENSE
similarity index 100%
rename from java/xgboost4j/LICENSE
rename to jvm-packages/xgboost4j/LICENSE
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
new file mode 100644
index 000000000..271d918f9
--- /dev/null
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -0,0 +1,35 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.dmlc</groupId>
+    <artifactId>xgboostjvm</artifactId>
+    <version>0.1</version>
+  </parent>
+  <artifactId>xgboost4j</artifactId>
+  <version>0.1</version>
+  <packaging>jar</packaging>
+
+  <build>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-javadoc-plugin</artifactId>
+        <version>2.10.3</version>
+        <configuration>
+          <show>protected</show>
+          <nohelp>true</nohelp>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+
+  <dependencies>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <version>4.11</version>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+</project>
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java
new file mode 100644
index 000000000..e234fef60
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java
@@ -0,0 +1,153 @@
+package ml.dmlc.xgboost4j;
+
+import java.io.IOException;
+import java.util.Map;
+
+public interface Booster {
+
+ /**
+ * set parameter
+ *
+ * @param key param name
+ * @param value param value
+ */
+ void setParam(String key, String value) throws XGBoostError;
+
+ /**
+ * set parameters
+ *
+ * @param params parameters key-value map
+ */
+  void setParams(Map<String, Object> params) throws XGBoostError;
+
+ /**
+ * Update (one iteration)
+ *
+ * @param dtrain training data
+ * @param iter current iteration number
+ */
+ void update(DMatrix dtrain, int iter) throws XGBoostError;
+
+ /**
+   * update with a customized objective function
+ *
+ * @param dtrain training data
+ * @param obj customized objective class
+ */
+ void update(DMatrix dtrain, IObjective obj) throws XGBoostError;
+
+ /**
+   * update with given grad and hess
+ *
+ * @param dtrain training data
+ * @param grad first order of gradient
+   * @param hess second order of gradient
+ */
+ void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError;
+
+ /**
+   * evaluate with the given DMatrices.
+   *
+   * @param evalMatrixs DMatrices for evaluation
+   * @param evalNames names for the eval DMatrices, used to label the results
+ * @param iter current eval iteration
+ * @return eval information
+ */
+ String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError;
+
+ /**
+ * evaluate with given customized Evaluation class
+ *
+ * @param evalMatrixs evaluation matrix
+ * @param evalNames evaluation names
+ * @param eval custom evaluator
+ * @return eval information
+ */
+ String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) throws XGBoostError;
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @return predict result
+ */
+ float[][] predict(DMatrix data) throws XGBoostError;
+
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @param outPutMargin Whether to output the raw untransformed margin value.
+ * @return predict result
+ */
+ float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError;
+
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @param outPutMargin Whether to output the raw untransformed margin value.
+ * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
+ * @return predict result
+ */
+ float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError;
+
+
+ /**
+ * Predict with data
+ * @param data dmatrix storing the input
+ * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
+ * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
+ * nsample = data.numRow with each record indicating the predicted leaf index of
+ * each sample in each tree. Note that the leaf index of a tree is unique per
+ * tree, so you may find leaf 1 in both tree 1 and tree 0.
+ * @return predict result
+ * @throws XGBoostError native error
+ */
+ float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
+
+ /**
+ * save model to modelPath
+ *
+ * @param modelPath model path
+ */
+ void saveModel(String modelPath) throws XGBoostError;
+
+ /**
+ * Dump model into a text file.
+ *
+ * @param modelPath file to save dumped model info
+ * @param withStats bool Controls whether the split statistics are output.
+ */
+ void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError;
+
+ /**
+ * Dump model into a text file.
+ *
+ * @param modelPath file to save dumped model info
+ * @param featureMap featureMap file
+ * @param withStats bool
+ * Controls whether the split statistics are output.
+ */
+ void dumpModel(String modelPath, String featureMap, boolean withStats)
+ throws IOException, XGBoostError;
+
+ /**
+ * get importance of each feature
+ *
+ * @return featureMap key: feature index, value: feature importance score
+ */
+  Map<String, Integer> getFeatureScore() throws XGBoostError;
+
+ /**
+ * get importance of each feature
+ *
+ * @param featureMap file to save dumped model info
+ * @return featureMap key: feature index, value: feature importance score
+ */
+  Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError;
+
+ void dispose();
+}
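A short sketch exercising several of the interface methods above after training; params, trainMat, testMat and watches are assumed to be set up as in the demos, and the output file names are placeholders.

    Booster booster = XGBoost.train(params, trainMat, 10, watches, null, null);
    booster.saveModel("xgb.model");                       // placeholder output path
    booster.dumpModel("dump.raw.txt", true);              // text dump including split statistics
    Map<String, Integer> importance = booster.getFeatureScore();
    for (Map.Entry<String, Integer> entry : importance.entrySet()) {
      System.out.println(entry.getKey() + " -> " + entry.getValue());
    }
    float[][] margins = booster.predict(testMat, true);   // raw untransformed margins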
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java
new file mode 100644
index 000000000..4b498caf1
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java
@@ -0,0 +1,256 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.io.IOException;
+
+/**
+ * DMatrix for xgboost, similar to the python wrapper xgboost.py
+ *
+ * @author hzx
+ */
+public class DMatrix {
+ private static final Log logger = LogFactory.getLog(DMatrix.class);
+ private long handle = 0;
+
+ //load native library
+ static {
+ try {
+ NativeLibLoader.initXgBoost();
+ } catch (IOException ex) {
+ logger.error("load native library failed.");
+ logger.error(ex);
+ }
+ }
+
+ /**
+ * sparse matrix type (CSR or CSC)
+ */
+ public static enum SparseType {
+ CSR,
+ CSC;
+ }
+
+ public DMatrix(String dataPath) throws XGBoostError {
+ if (dataPath == null) {
+ throw new NullPointerException("dataPath: null");
+ }
+ long[] out = new long[1];
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
+ handle = out[0];
+ }
+
+ public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
+ long[] out = new long[1];
+ if (st == SparseType.CSR) {
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
+ } else if (st == SparseType.CSC) {
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
+ } else {
+      throw new UnknownError("unknown sparsetype");
+ }
+ handle = out[0];
+ }
+
+ /**
+ * create DMatrix from dense matrix
+ *
+ * @param data data values
+ * @param nrow number of rows
+ * @param ncol number of columns
+ * @throws XGBoostError native error
+ */
+ public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
+ long[] out = new long[1];
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
+ handle = out[0];
+ }
+
+ /**
+ * used for DMatrix slice
+ */
+ protected DMatrix(long handle) {
+ this.handle = handle;
+ }
+
+
+ /**
+ * set label of dmatrix
+ *
+ * @param labels labels
+ * @throws XGBoostError native error
+ */
+ public void setLabel(float[] labels) throws XGBoostError {
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
+ }
+
+ /**
+ * set weight of each instance
+ *
+ * @param weights weights
+ * @throws XGBoostError native error
+ */
+ public void setWeight(float[] weights) throws XGBoostError {
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
+ }
+
+ /**
+ * if specified, xgboost will start from this init margin
+ * can be used to specify initial prediction to boost from
+ *
+ * @param baseMargin base margin
+ * @throws XGBoostError native error
+ */
+ public void setBaseMargin(float[] baseMargin) throws XGBoostError {
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
+ }
+
+ /**
+ * if specified, xgboost will start from this init margin
+ * can be used to specify initial prediction to boost from
+ *
+ * @param baseMargin base margin
+ * @throws XGBoostError native error
+ */
+ public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
+ float[] flattenMargin = flatten(baseMargin);
+ setBaseMargin(flattenMargin);
+ }
+
+ /**
+ * Set group sizes of DMatrix (used for ranking)
+ *
+ * @param group group size as array
+ * @throws XGBoostError native error
+ */
+ public void setGroup(int[] group) throws XGBoostError {
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
+ }
+
+ private float[] getFloatInfo(String field) throws XGBoostError {
+ float[][] infos = new float[1][];
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
+ return infos[0];
+ }
+
+ private int[] getIntInfo(String field) throws XGBoostError {
+ int[][] infos = new int[1][];
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
+ return infos[0];
+ }
+
+ /**
+ * get label values
+ *
+ * @return label
+ * @throws XGBoostError native error
+ */
+ public float[] getLabel() throws XGBoostError {
+ return getFloatInfo("label");
+ }
+
+ /**
+ * get weight of the DMatrix
+ *
+ * @return weights
+ * @throws XGBoostError native error
+ */
+ public float[] getWeight() throws XGBoostError {
+ return getFloatInfo("weight");
+ }
+
+ /**
+ * get base margin of the DMatrix
+ *
+ * @return base margin
+ * @throws XGBoostError native error
+ */
+ public float[] getBaseMargin() throws XGBoostError {
+ return getFloatInfo("base_margin");
+ }
+
+ /**
+ * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
+ *
+ * @param rowIndex row index
+ * @return sliced new DMatrix
+ * @throws XGBoostError native error
+ */
+ public DMatrix slice(int[] rowIndex) throws XGBoostError {
+ long[] out = new long[1];
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
+ long sHandle = out[0];
+ DMatrix sMatrix = new DMatrix(sHandle);
+ return sMatrix;
+ }
+
+ /**
+ * get the row number of DMatrix
+ *
+ * @return number of rows
+ * @throws XGBoostError native error
+ */
+ public long rowNum() throws XGBoostError {
+ long[] rowNum = new long[1];
+ JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle, rowNum));
+ return rowNum[0];
+ }
+
+ /**
+ * save DMatrix to filePath
+ */
+ public void saveBinary(String filePath) {
+ XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
+ }
+
+ /**
+ * Get the handle
+ */
+ public long getHandle() {
+ return handle;
+ }
+
+ /**
+ * flatten a mat to array
+ */
+ private static float[] flatten(float[][] mat) {
+ int size = 0;
+ for (float[] array : mat) size += array.length;
+ float[] result = new float[size];
+ int pos = 0;
+ for (float[] ar : mat) {
+ System.arraycopy(ar, 0, result, pos, ar.length);
+ pos += ar.length;
+ }
+
+ return result;
+ }
+
+ @Override
+ protected void finalize() {
+ dispose();
+ }
+
+ public synchronized void dispose() {
+ if (handle != 0) {
+ XgboostJNI.XGDMatrixFree(handle);
+ handle = 0;
+ }
+ }
+}
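A minimal sketch of the dense constructor and the info setters above; the values are arbitrary illustration data.

    // 4 rows x 2 columns, row-major layout
    float[] data = new float[]{1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f};
    DMatrix dmat = new DMatrix(data, 4, 2);
    dmat.setLabel(new float[]{0f, 1f, 0f, 1f});
    dmat.setWeight(new float[]{1f, 1f, 2f, 2f});
    DMatrix firstTwo = dmat.slice(new int[]{0, 1});   // keep only rows 0 and 1
    System.out.println(firstTwo.rowNum());            // prints 2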
diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IEvaluation.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IEvaluation.java
similarity index 59%
rename from java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IEvaluation.java
rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IEvaluation.java
index 3793bff41..079cd057e 100644
--- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IEvaluation.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IEvaluation.java
@@ -1,10 +1,10 @@
/*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
-
+
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
@@ -13,29 +13,27 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
-package org.dmlc.xgboost4j;
+package ml.dmlc.xgboost4j;
/**
* interface for customized evaluation
- *
+ *
* @author hzx
*/
public interface IEvaluation {
- /**
- * get evaluate metric
- *
- * @return evalMetric
- */
- public abstract String getMetric();
+ /**
+ * get evaluate metric
+ *
+ * @return evalMetric
+ */
+ String getMetric();
- /**
- * evaluate with predicts and data
- *
- * @param predicts
- * predictions as array
- * @param dmat
- * data matrix to evaluate
- * @return result of the metric
- */
- public abstract float eval(float[][] predicts, DMatrix dmat);
+ /**
+ * evaluate with predicts and data
+ *
+ * @param predicts predictions as array
+ * @param dmat data matrix to evaluate
+ * @return result of the metric
+ */
+ float eval(float[][] predicts, DMatrix dmat);
}
diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IObjective.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IObjective.java
similarity index 60%
rename from java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IObjective.java
rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IObjective.java
index 640f46e6d..97ef9aed4 100644
--- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IObjective.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IObjective.java
@@ -1,10 +1,10 @@
/*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
-
+
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
@@ -13,20 +13,22 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
-package org.dmlc.xgboost4j;
+package ml.dmlc.xgboost4j;
import java.util.List;
/**
* interface for customize Object function
+ *
* @author hzx
*/
public interface IObjective {
- /**
- * user define objective function, return gradient and second order gradient
- * @param predicts untransformed margin predicts
- * @param dtrain training data
- * @return List with two float array, correspond to first order grad and second order grad
- */
- public abstract List getGradient(float[][] predicts, DMatrix dtrain);
+ /**
+ * user define objective function, return gradient and second order gradient
+ *
+ * @param predicts untransformed margin predicts
+ * @param dtrain training data
+ * @return List with two float array, correspond to first order grad and second order grad
+ */
+  List<float[]> getGradient(float[][] predicts, DMatrix dtrain);
}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JNIErrorHandle.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JNIErrorHandle.java
new file mode 100644
index 000000000..06474dbb4
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JNIErrorHandle.java
@@ -0,0 +1,51 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.io.IOException;
+
+/**
+ * Error handler for XGBoost.
+ */
+class JNIErrorHandle {
+
+  private static final Log logger = LogFactory.getLog(JNIErrorHandle.class);
+
+ //load native library
+ static {
+ try {
+ NativeLibLoader.initXgBoost();
+ } catch (IOException ex) {
+ logger.error("load native library failed.");
+ logger.error(ex);
+ }
+ }
+
+ /**
+ * Check the return value of C API.
+ *
+   * @param ret return value of an XGBoost JNI C API call
+ * @throws XGBoostError native error
+ */
+ static void checkCall(int ret) throws XGBoostError {
+ if (ret != 0) {
+ throw new XGBoostError(XgboostJNI.XGBGetLastError());
+ }
+ }
+}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java
new file mode 100644
index 000000000..321b7fead
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java
@@ -0,0 +1,470 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.io.*;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+
+/**
+ * Booster for xgboost, similar to the python wrapper xgboost.py;
+ * custom objective and evaluation functions are supported via IObjective and IEvaluation.
+ *
+ * @author hzx
+ */
+class JavaBoosterImpl implements Booster {
+ private static final Log logger = LogFactory.getLog(JavaBoosterImpl.class);
+
+ long handle = 0;
+
+ //load native library
+ static {
+ try {
+ NativeLibLoader.initXgBoost();
+ } catch (IOException ex) {
+ logger.error("load native library failed.");
+ logger.error(ex);
+ }
+ }
+
+ /**
+ * init Booster from dMatrixs
+ *
+ * @param params parameters
+ * @param dMatrixs DMatrix array
+ * @throws XGBoostError native error
+ */
+  JavaBoosterImpl(Map<String, Object> params, DMatrix[] dMatrixs) throws XGBoostError {
+ init(dMatrixs);
+ setParam("seed", "0");
+ setParams(params);
+ }
+
+
+ /**
+ * load model from modelPath
+ *
+ * @param params parameters
+ * @param modelPath booster modelPath (model generated by booster.saveModel)
+ * @throws XGBoostError native error
+ */
+  JavaBoosterImpl(Map<String, Object> params, String modelPath) throws XGBoostError {
+ init(null);
+ if (modelPath == null) {
+ throw new NullPointerException("modelPath : null");
+ }
+ loadModel(modelPath);
+ setParam("seed", "0");
+ setParams(params);
+ }
+
+
+ private void init(DMatrix[] dMatrixs) throws XGBoostError {
+ long[] handles = null;
+ if (dMatrixs != null) {
+ handles = dmatrixsToHandles(dMatrixs);
+ }
+ long[] out = new long[1];
+ JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
+
+ handle = out[0];
+ }
+
+ /**
+ * set parameter
+ *
+ * @param key param name
+ * @param value param value
+ * @throws XGBoostError native error
+ */
+ public final void setParam(String key, String value) throws XGBoostError {
+ JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value));
+ }
+
+ /**
+ * set parameters
+ *
+ * @param params parameters key-value map
+ * @throws XGBoostError native error
+ */
+  public void setParams(Map<String, Object> params) throws XGBoostError {
+    if (params != null) {
+      for (Map.Entry<String, Object> entry : params.entrySet()) {
+ setParam(entry.getKey(), entry.getValue().toString());
+ }
+ }
+ }
+
+
+ /**
+ * Update (one iteration)
+ *
+ * @param dtrain training data
+ * @param iter current iteration number
+ * @throws XGBoostError native error
+ */
+ public void update(DMatrix dtrain, int iter) throws XGBoostError {
+ JNIErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
+ }
+
+ /**
+   * update with a customized objective function
+ *
+ * @param dtrain training data
+ * @param obj customized objective class
+ * @throws XGBoostError native error
+ */
+ public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
+ float[][] predicts = predict(dtrain, true);
+    List<float[]> gradients = obj.getGradient(predicts, dtrain);
+ boost(dtrain, gradients.get(0), gradients.get(1));
+ }
+
+ /**
+   * update with given grad and hess
+ *
+ * @param dtrain training data
+ * @param grad first order of gradient
+   * @param hess second order of gradient
+ * @throws XGBoostError native error
+ */
+ public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
+ if (grad.length != hess.length) {
+ throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
+ hess.length));
+ }
+ JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
+ hess));
+ }
+
+ /**
+   * evaluate with the given DMatrices.
+   *
+   * @param evalMatrixs DMatrices for evaluation
+   * @param evalNames names for the eval DMatrices, used to label the results
+ * @param iter current eval iteration
+ * @return eval information
+ * @throws XGBoostError native error
+ */
+ public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
+ long[] handles = dmatrixsToHandles(evalMatrixs);
+ String[] evalInfo = new String[1];
+ JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
+ evalInfo));
+ return evalInfo[0];
+ }
+
+ /**
+ * evaluate with given customized Evaluation class
+ *
+ * @param evalMatrixs evaluation matrix
+ * @param evalNames evaluation names
+ * @param eval custom evaluator
+ * @return eval information
+ * @throws XGBoostError native error
+ */
+ public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
+ throws XGBoostError {
+ String evalInfo = "";
+ for (int i = 0; i < evalNames.length; i++) {
+ String evalName = evalNames[i];
+ DMatrix evalMat = evalMatrixs[i];
+ float evalResult = eval.eval(predict(evalMat), evalMat);
+ String evalMetric = eval.getMetric();
+ evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
+ }
+ return evalInfo;
+ }
+
+ /**
+ * base function for Predict
+ *
+ * @param data data
+ * @param outPutMargin output margin
+ * @param treeLimit limit number of trees
+   * @param predLeaf whether to output the predicted leaf index of each sample in each tree
+ * @return predict results
+ */
+ private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit,
+ boolean predLeaf) throws XGBoostError {
+ int optionMask = 0;
+ if (outPutMargin) {
+ optionMask = 1;
+ }
+ if (predLeaf) {
+ optionMask = 2;
+ }
+ float[][] rawPredicts = new float[1][];
+ JNIErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
+ treeLimit, rawPredicts));
+ int row = (int) data.rowNum();
+ int col = rawPredicts[0].length / row;
+ float[][] predicts = new float[row][col];
+ int r, c;
+ for (int i = 0; i < rawPredicts[0].length; i++) {
+ r = i / col;
+ c = i % col;
+ predicts[r][c] = rawPredicts[0][i];
+ }
+ return predicts;
+ }
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @return predict result
+ * @throws XGBoostError native error
+ */
+ public float[][] predict(DMatrix data) throws XGBoostError {
+ return pred(data, false, 0, false);
+ }
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @param outPutMargin Whether to output the raw untransformed margin value.
+ * @return predict result
+ * @throws XGBoostError native error
+ */
+ public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
+ return pred(data, outPutMargin, 0, false);
+ }
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @param outPutMargin Whether to output the raw untransformed margin value.
+ * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
+ * @return predict result
+ * @throws XGBoostError native error
+ */
+ public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
+ return pred(data, outPutMargin, treeLimit, false);
+ }
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
+ * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
+ * nsample = data.numRow with each record indicating the predicted leaf index
+ * of each sample in each tree.
+ * Note that the leaf index of a tree is unique per tree, so you may find leaf 1
+ * in both tree 1 and tree 0.
+ * @return predict result
+ * @throws XGBoostError native error
+ */
+ public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError {
+ return pred(data, false, treeLimit, predLeaf);
+ }
+
+ /**
+ * save model to modelPath
+ *
+ * @param modelPath model path
+ */
+  public void saveModel(String modelPath) throws XGBoostError {
+ JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveModel(handle, modelPath));
+ }
+
+ private void loadModel(String modelPath) {
+ XgboostJNI.XGBoosterLoadModel(handle, modelPath);
+ }
+
+ /**
+ * get the dump of the model as a string array
+ *
+ * @param withStats Controls whether the split statistics are output.
+ * @return dumped model information
+ * @throws XGBoostError native error
+ */
+ private String[] getDumpInfo(boolean withStats) throws XGBoostError {
+ int statsFlag = 0;
+ if (withStats) {
+ statsFlag = 1;
+ }
+ String[][] modelInfos = new String[1][];
+ JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
+ return modelInfos[0];
+ }
+
+ /**
+ * get the dump of the model as a string array
+ *
+ * @param featureMap featureMap file
+ * @param withStats Controls whether the split statistics are output.
+ * @return dumped model information
+ * @throws XGBoostError native error
+ */
+ private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
+ int statsFlag = 0;
+ if (withStats) {
+ statsFlag = 1;
+ }
+ String[][] modelInfos = new String[1][];
+ JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
+ modelInfos));
+ return modelInfos[0];
+ }
+
+ /**
+ * Dump model into a text file.
+ *
+ * @param modelPath file to save dumped model info
+ * @param withStats bool
+ * Controls whether the split statistics are output.
+ * @throws FileNotFoundException file not found
+ * @throws UnsupportedEncodingException unsupported feature
+ * @throws IOException error with model writing
+ * @throws XGBoostError native error
+ */
+ public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError {
+ File tf = new File(modelPath);
+ FileOutputStream out = new FileOutputStream(tf);
+ BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
+ String[] modelInfos = getDumpInfo(withStats);
+
+ for (int i = 0; i < modelInfos.length; i++) {
+ writer.write("booster [" + i + "]:\n");
+ writer.write(modelInfos[i]);
+ }
+
+ writer.close();
+ out.close();
+ }
+
+
+ /**
+ * Dump model into a text file.
+ *
+ * @param modelPath file to save dumped model info
+ * @param featureMap featureMap file
+ * @param withStats bool
+ * Controls whether the split statistics are output.
+ * @throws FileNotFoundException exception
+ * @throws UnsupportedEncodingException exception
+ * @throws IOException exception
+ * @throws XGBoostError native error
+ */
+ public void dumpModel(String modelPath, String featureMap, boolean withStats) throws
+ IOException, XGBoostError {
+ File tf = new File(modelPath);
+ FileOutputStream out = new FileOutputStream(tf);
+ BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
+ String[] modelInfos = getDumpInfo(featureMap, withStats);
+
+ for (int i = 0; i < modelInfos.length; i++) {
+ writer.write("booster [" + i + "]:\n");
+ writer.write(modelInfos[i]);
+ }
+
+ writer.close();
+ out.close();
+ }
+
+
+ /**
+ * get importance of each feature
+ *
+ * @return featureMap key: feature index, value: feature importance score
+ * @throws XGBoostError native error
+ */
+  public Map<String, Integer> getFeatureScore() throws XGBoostError {
+    String[] modelInfos = getDumpInfo(false);
+    Map<String, Integer> featureScore = new HashMap<String, Integer>();
+ for (String tree : modelInfos) {
+ for (String node : tree.split("\n")) {
+ String[] array = node.split("\\[");
+ if (array.length == 1) {
+ continue;
+ }
+ String fid = array[1].split("\\]")[0];
+ fid = fid.split("<")[0];
+ if (featureScore.containsKey(fid)) {
+ featureScore.put(fid, 1 + featureScore.get(fid));
+ } else {
+ featureScore.put(fid, 1);
+ }
+ }
+ }
+ return featureScore;
+ }
+
+
+ /**
+ * get importance of each feature
+ *
+ * @param featureMap file to save dumped model info
+ * @return featureMap key: feature index, value: feature importance score
+ * @throws XGBoostError native error
+ */
+  public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
+    String[] modelInfos = getDumpInfo(featureMap, false);
+    Map<String, Integer> featureScore = new HashMap<String, Integer>();
+ for (String tree : modelInfos) {
+ for (String node : tree.split("\n")) {
+ String[] array = node.split("\\[");
+ if (array.length == 1) {
+ continue;
+ }
+ String fid = array[1].split("\\]")[0];
+ fid = fid.split("<")[0];
+ if (featureScore.containsKey(fid)) {
+ featureScore.put(fid, 1 + featureScore.get(fid));
+ } else {
+ featureScore.put(fid, 1);
+ }
+ }
+ }
+ return featureScore;
+ }
+
+ /**
+ * transfer DMatrix array to handle array (used for native functions)
+ *
+   * @param dmatrixs DMatrix array
+ * @return handle array for input dmatrixs
+ */
+ private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
+ long[] handles = new long[dmatrixs.length];
+ for (int i = 0; i < dmatrixs.length; i++) {
+ handles[i] = dmatrixs[i].getHandle();
+ }
+ return handles;
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ super.finalize();
+ dispose();
+ }
+
+ public synchronized void dispose() {
+ if (handle != 0L) {
+ XgboostJNI.XGBoosterFree(handle);
+ handle = 0;
+ }
+ }
+}
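The update(dtrain, obj) path above hands raw margin predictions to IObjective.getGradient and feeds the returned gradient pair to boost(). A minimal custom objective for binary logistic loss could look like the sketch below; the class name and the sigmoid/log-loss choice are illustrative assumptions, not something defined elsewhere in this change.

    import java.util.Arrays;
    import java.util.List;
    import ml.dmlc.xgboost4j.DMatrix;
    import ml.dmlc.xgboost4j.IObjective;
    import ml.dmlc.xgboost4j.XGBoostError;

    public class LogRegObjective implements IObjective {
      @Override
      public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
        float[] labels;
        try {
          labels = dtrain.getLabel();
        } catch (XGBoostError ex) {
          throw new RuntimeException(ex);
        }
        int nrow = predicts.length;
        float[] grad = new float[nrow];
        float[] hess = new float[nrow];
        for (int i = 0; i < nrow; i++) {
          // predicts holds untransformed margins, so apply the sigmoid first
          float p = 1f / (1f + (float) Math.exp(-predicts[i][0]));
          grad[i] = p - labels[i];      // first order gradient of log loss
          hess[i] = p * (1f - p);       // second order gradient of log loss
        }
        return Arrays.asList(grad, hess);
      }
    }

    // usage: XGBoost.train(params, trainMat, 3, watches, new LogRegObjective(), null);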
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java
new file mode 100644
index 000000000..85e60b3ef
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java
@@ -0,0 +1,170 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.io.*;
+import java.lang.reflect.Field;
+
+/**
+ * class to load native library
+ *
+ * @author hzx
+ */
+class NativeLibLoader {
+ private static final Log logger = LogFactory.getLog(NativeLibLoader.class);
+
+ private static boolean initialized = false;
+ private static final String nativePath = "../lib/";
+ private static final String nativeResourcePath = "/lib/";
+ private static final String[] libNames = new String[]{"xgboost4j"};
+
+ public static synchronized void initXgBoost() throws IOException {
+ if (!initialized) {
+ for (String libName : libNames) {
+ smartLoad(libName);
+ }
+ initialized = true;
+ }
+ }
+
+ /**
+ * Loads library from current JAR archive
+ *
+ * The file from JAR is copied into system temporary directory and then loaded.
+ * The temporary file is deleted after exiting.
+ * Method uses String as filename because the pathname is "abstract", not system-dependent.
+ *
+ * The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
+ * {@code path}.
+ *
+ * @param path The filename inside JAR as absolute path (beginning with '/'),
+ * e.g. /package/File.ext
+ * @throws IOException If temporary file creation or read/write operation fails
+ * @throws IllegalArgumentException If source file (param path) does not exist
+ * @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than
+ * three characters
+ */
+ private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
+
+ if (!path.startsWith("/")) {
+ throw new IllegalArgumentException("The path has to be absolute (start with '/').");
+ }
+
+ // Obtain filename from path
+ String[] parts = path.split("/");
+ String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
+
+    // Split filename into prefix and suffix (extension)
+ String prefix = "";
+ String suffix = null;
+ if (filename != null) {
+ parts = filename.split("\\.", 2);
+ prefix = parts[0];
+ suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
+ }
+
+ // Check if the filename is okay
+ if (filename == null || prefix.length() < 3) {
+ throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
+ }
+
+ // Prepare temporary file
+ File temp = File.createTempFile(prefix, suffix);
+ temp.deleteOnExit();
+
+ if (!temp.exists()) {
+ throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
+ }
+
+ // Prepare buffer for data copying
+ byte[] buffer = new byte[1024];
+ int readBytes;
+
+ // Open and check input stream
+ InputStream is = NativeLibLoader.class.getResourceAsStream(path);
+ if (is == null) {
+ throw new FileNotFoundException("File " + path + " was not found inside JAR.");
+ }
+
+ // Open output stream and copy data between source file in JAR and the temporary file
+ OutputStream os = new FileOutputStream(temp);
+ try {
+ while ((readBytes = is.read(buffer)) != -1) {
+ os.write(buffer, 0, readBytes);
+ }
+ } finally {
+ // If read/write fails, close streams safely before throwing an exception
+ os.close();
+ is.close();
+ }
+
+ // Finally, load the library
+ System.load(temp.getAbsolutePath());
+ }
+
+ /**
+   * load native library; this method first tries to load the library from java.library.path,
+   * then falls back to the copy bundled in the jar package.
+ *
+ * @param libName library path
+ * @throws IOException exception
+ */
+ private static void smartLoad(String libName) throws IOException {
+ addNativeDir(nativePath);
+ try {
+ System.loadLibrary(libName);
+ } catch (UnsatisfiedLinkError e) {
+ try {
+ String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
+ loadLibraryFromJar(libraryFromJar);
+ } catch (IOException e1) {
+ throw e1;
+ }
+ }
+ }
+
+ /**
+   * Add libPath to java.library.path so that native libraries in libPath can be loaded properly
+ *
+ * @param libPath library path
+ * @throws IOException exception
+ */
+ private static void addNativeDir(String libPath) throws IOException {
+ try {
+ Field field = ClassLoader.class.getDeclaredField("usr_paths");
+ field.setAccessible(true);
+ String[] paths = (String[]) field.get(null);
+ for (String path : paths) {
+ if (libPath.equals(path)) {
+ return;
+ }
+ }
+ String[] tmp = new String[paths.length + 1];
+ System.arraycopy(paths, 0, tmp, 0, paths.length);
+ tmp[paths.length] = libPath;
+ field.set(null, tmp);
+ } catch (IllegalAccessException e) {
+ logger.error(e.getMessage());
+ throw new IOException("Failed to get permissions to set library path");
+ } catch (NoSuchFieldException e) {
+ logger.error(e.getMessage());
+ throw new IOException("Failed to get field handle to set library path");
+ }
+ }
+}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java
new file mode 100644
index 000000000..cea4ae5bf
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java
@@ -0,0 +1,336 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.util.*;
+
+
+/**
+ * trainer for xgboost
+ *
+ * @author hzx
+ */
+public class XGBoost {
+ private static final Log logger = LogFactory.getLog(XGBoost.class);
+
+ /**
+ * Train a booster with given parameters.
+ *
+ * @param params Booster params.
+ * @param dtrain Data to be trained.
+ * @param round Number of boosting iterations.
+ * @param watches a group of items to be evaluated during training, this allows user to watch
+ * performance on the validation set.
+ * @param obj customized objective (set to null if not used)
+ * @param eval customized evaluation (set to null if not used)
+ * @return trained booster
+ * @throws XGBoostError native error
+ */
+  public static Booster train(Map<String, Object> params, DMatrix dtrain, int round,
+                              Map<String, DMatrix> watches, IObjective obj,
+                              IEvaluation eval) throws XGBoostError {
+
+    //collect eval matrices
+    String[] evalNames;
+    DMatrix[] evalMats;
+    List<String> names = new ArrayList<String>();
+    List<DMatrix> mats = new ArrayList<DMatrix>();
+
+    for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
+ names.add(evalEntry.getKey());
+ mats.add(evalEntry.getValue());
+ }
+
+ evalNames = names.toArray(new String[names.size()]);
+ evalMats = mats.toArray(new DMatrix[mats.size()]);
+
+    //collect all data matrices
+ DMatrix[] allMats;
+ if (evalMats != null && evalMats.length > 0) {
+ allMats = new DMatrix[evalMats.length + 1];
+ allMats[0] = dtrain;
+ System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
+ } else {
+ allMats = new DMatrix[1];
+ allMats[0] = dtrain;
+ }
+
+ //initialize booster
+ Booster booster = new JavaBoosterImpl(params, allMats);
+
+ //begin to train
+ for (int iter = 0; iter < round; iter++) {
+ if (obj != null) {
+ booster.update(dtrain, obj);
+ } else {
+ booster.update(dtrain, iter);
+ }
+
+ //evaluation
+ if (evalMats != null && evalMats.length > 0) {
+ String evalInfo;
+ if (eval != null) {
+ evalInfo = booster.evalSet(evalMats, evalNames, eval);
+ } else {
+ evalInfo = booster.evalSet(evalMats, evalNames, iter);
+ }
+ logger.info(evalInfo);
+ }
+ }
+ return booster;
+ }
+
+ /**
+ * init Booster from dMatrixs
+ *
+ * @param params parameters
+ * @param dMatrixs DMatrix array
+ * @throws XGBoostError native error
+ */
+ public static Booster initBoostingModel(
+      Map<String, Object> params,
+ DMatrix[] dMatrixs) throws XGBoostError {
+ return new JavaBoosterImpl(params, dMatrixs);
+ }
+
+ /**
+ * load model from modelPath
+ *
+ * @param params parameters
+ * @param modelPath booster modelPath (model generated by booster.saveModel)
+ * @throws XGBoostError native error
+ */
+  public static Booster loadBoostModel(Map<String, Object> params, String modelPath)
+ throws XGBoostError {
+ return new JavaBoosterImpl(params, modelPath);
+ }
+
+ /**
+   * Cross-validation with given parameters.
+ *
+ * @param params Booster params.
+ * @param data Data to be trained.
+ * @param round Number of boosting iterations.
+ * @param nfold Number of folds in CV.
+ * @param metrics Evaluation metrics to be watched in CV.
+ * @param obj customized objective (set to null if not used)
+ * @param eval customized evaluation (set to null if not used)
+ * @return evaluation history
+ * @throws XGBoostError native error
+ */
+ public static String[] crossValiation(
+      Map<String, Object> params,
+ DMatrix data,
+ int round,
+ int nfold,
+ String[] metrics,
+ IObjective obj,
+ IEvaluation eval) throws XGBoostError {
+ CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
+ String[] evalHist = new String[round];
+ String[] results = new String[cvPacks.length];
+ for (int i = 0; i < round; i++) {
+ for (CVPack cvPack : cvPacks) {
+ if (obj != null) {
+ cvPack.update(obj);
+ } else {
+ cvPack.update(i);
+ }
+ }
+
+ for (int j = 0; j < cvPacks.length; j++) {
+ if (eval != null) {
+ results[j] = cvPacks[j].eval(eval);
+ } else {
+ results[j] = cvPacks[j].eval(i);
+ }
+ }
+
+ evalHist[i] = aggCVResults(results);
+ logger.info(evalHist[i]);
+ }
+ return evalHist;
+ }
+
+ /**
+ * make an n-fold array of CVPack from random indices
+ *
+ * @param data original data
+ * @param nfold num of folds
+ * @param params booster parameters
+ * @param evalMetrics Evaluation metrics
+ * @return CV package array
+ * @throws XGBoostError native error
+ */
+  private static CVPack[] makeNFold(DMatrix data, int nfold, Map<String, Object> params,
+                                    String[] evalMetrics) throws XGBoostError {
+    List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
+ int step = samples.size() / nfold;
+ int[] testSlice = new int[step];
+ int[] trainSlice = new int[samples.size() - step];
+ int testid, trainid;
+ CVPack[] cvPacks = new CVPack[nfold];
+ for (int i = 0; i < nfold; i++) {
+ testid = 0;
+ trainid = 0;
+ for (int j = 0; j < samples.size(); j++) {
+ if (j > (i * step) && j < (i * step + step) && testid < step) {
+ testSlice[testid] = samples.get(j);
+ testid++;
+ } else {
+ if (trainid < samples.size() - step) {
+ trainSlice[trainid] = samples.get(j);
+ trainid++;
+ } else {
+ testSlice[testid] = samples.get(j);
+ testid++;
+ }
+ }
+ }
+
+ DMatrix dtrain = data.slice(trainSlice);
+ DMatrix dtest = data.slice(testSlice);
+ CVPack cvPack = new CVPack(dtrain, dtest, params);
+ //set eval types
+ if (evalMetrics != null) {
+ for (String type : evalMetrics) {
+ cvPack.booster.setParam("eval_metric", type);
+ }
+ }
+ cvPacks[i] = cvPack;
+ }
+
+ return cvPacks;
+ }
+
+  private static List<Integer> genRandPermutationNums(int start, int end) {
+    List<Integer> samples = new ArrayList<Integer>();
+ for (int i = start; i < end; i++) {
+ samples.add(i);
+ }
+ Collections.shuffle(samples);
+ return samples;
+ }
+
+ /**
+ * Aggregate cross-validation results.
+ *
+ * @param results eval info from each data sample
+ * @return cross-validation eval info
+ */
+ private static String aggCVResults(String[] results) {
+    Map<String, List<Float>> cvMap = new HashMap<String, List<Float>>();
+ String aggResult = results[0].split("\t")[0];
+ for (String result : results) {
+ String[] items = result.split("\t");
+ for (int i = 1; i < items.length; i++) {
+ String[] tup = items[i].split(":");
+ String key = tup[0];
+ Float value = Float.valueOf(tup[1]);
+ if (!cvMap.containsKey(key)) {
+          cvMap.put(key, new ArrayList<Float>());
+ }
+ cvMap.get(key).add(value);
+ }
+ }
+
+ for (String key : cvMap.keySet()) {
+ float value = 0f;
+ for (Float tvalue : cvMap.get(key)) {
+ value += tvalue;
+ }
+ value /= cvMap.get(key).size();
+ aggResult += String.format("\tcv-%s:%f", key, value);
+ }
+
+ return aggResult;
+ }
+
+ /**
+ * cross validation package for xgb
+ *
+ * @author hzx
+ */
+ private static class CVPack {
+ DMatrix dtrain;
+ DMatrix dtest;
+ DMatrix[] dmats;
+ String[] names;
+ Booster booster;
+
+ /**
+     * create a cross validation package
+ *
+ * @param dtrain train data
+ * @param dtest test data
+ * @param params parameters
+ * @throws XGBoostError native error
+ */
+    public CVPack(DMatrix dtrain, DMatrix dtest, Map<String, Object> params)
+ throws XGBoostError {
+ dmats = new DMatrix[]{dtrain, dtest};
+ booster = XGBoost.initBoostingModel(params, dmats);
+ names = new String[]{"train", "test"};
+ this.dtrain = dtrain;
+ this.dtest = dtest;
+ }
+
+ /**
+ * update one iteration
+ *
+ * @param iter iteration num
+ * @throws XGBoostError native error
+ */
+ public void update(int iter) throws XGBoostError {
+ booster.update(dtrain, iter);
+ }
+
+ /**
+ * update one iteration
+ *
+ * @param obj customized objective
+ * @throws XGBoostError native error
+ */
+ public void update(IObjective obj) throws XGBoostError {
+ booster.update(dtrain, obj);
+ }
+
+ /**
+ * evaluation
+ *
+ * @param iter iteration num
+ * @return evaluation
+ * @throws XGBoostError native error
+ */
+ public String eval(int iter) throws XGBoostError {
+ return booster.evalSet(dmats, names, iter);
+ }
+
+ /**
+ * evaluation
+ *
+ * @param eval customized eval
+ * @return evaluation
+ * @throws XGBoostError native error
+ */
+ public String eval(IEvaluation eval) throws XGBoostError {
+ return booster.evalSet(dmats, names, eval);
+ }
+ }
+}
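A short sketch of the cross-validation entry point above (note the method name is spelled crossValiation in this change); params and trainMat are assumed to be set up as in the demos, and "error" is one of the built-in eval_metric values.

    String[] evalHist = XGBoost.crossValiation(
        params, trainMat, 5 /* rounds */, 3 /* folds */,
        new String[]{"error"}, null, null);
    for (String roundInfo : evalHist) {
      // one aggregated line per round, e.g. "[0]  cv-train-error:...  cv-test-error:..."
      System.out.println(roundInfo);
    }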
diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/XGBoostError.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoostError.java
similarity index 75%
rename from java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/XGBoostError.java
rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoostError.java
index dc7a9a0b2..1f62b22fc 100644
--- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/XGBoostError.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoostError.java
@@ -1,10 +1,10 @@
/*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
-
+
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
@@ -13,14 +13,15 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
-package org.dmlc.xgboost4j.util;
+package ml.dmlc.xgboost4j;
/**
* custom error class for xgboost
+ *
* @author hzx
*/
-public class XGBoostError extends Exception{
- public XGBoostError(String message) {
- super(message);
- }
+public class XGBoostError extends Exception {
+ public XGBoostError(String message) {
+ super(message);
+ }
}
diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java
similarity index 77%
rename from java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java
rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java
index 11cab988c..10ba1802b 100644
--- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java
@@ -13,38 +13,71 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
-package org.dmlc.xgboost4j.wrapper;
+package ml.dmlc.xgboost4j;
/**
* xgboost jni wrapper functions for xgboost_wrapper.h
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
+ *
* @author hzx
*/
-public class XgboostJNI {
+class XgboostJNI {
public final static native String XGBGetLastError();
+
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
- public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, long[] out);
- public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data, long[] out);
- public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out);
+
+ public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
+ long[] out);
+
+ public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data,
+ long[] out);
+
+ public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
+ float missing, long[] out);
+
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
+
public final static native int XGDMatrixFree(long handle);
+
public final static native int XGDMatrixSaveBinary(long handle, String fname, int silent);
+
public final static native int XGDMatrixSetFloatInfo(long handle, String field, float[] array);
+
public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
+
public final static native int XGDMatrixSetGroup(long handle, int[] group);
+
public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
+
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
+
public final static native int XGDMatrixNumRow(long handle, long[] row);
+
public final static native int XGBoosterCreate(long[] handles, long[] out);
+
public final static native int XGBoosterFree(long handle);
+
public final static native int XGBoosterSetParam(long handle, String name, String value);
+
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
- public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
- public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames, String[] eval_info);
- public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, int ntree_limit, float[][] predicts);
+
+ public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad,
+ float[] hess);
+
+ public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats,
+ String[] evnames, String[] eval_info);
+
+ public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
+ int ntree_limit, float[][] predicts);
+
public final static native int XGBoosterLoadModel(long handle, String fname);
+
public final static native int XGBoosterSaveModel(long handle, String fname);
+
public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
+
public final static native int XGBoosterGetModelRaw(long handle, String[] out_string);
- public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings);
+
+ public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
+ String[][] out_strings);
}
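
For readers unfamiliar with the handle-container convention noted in the class comment above, the sketch below shows how these natives are meant to be driven: each create call reports its status through the int return value and writes the new native handle into a one-element long[]. This is a minimal illustration only; it assumes the native library is already loaded, uses a hypothetical data path, and sits in the ml.dmlc.xgboost4j package because the class is now package-private. In normal use the DMatrix and Booster wrappers do this for you.

  package ml.dmlc.xgboost4j

  // Illustrative only: drive the raw JNI bindings directly (normally done via DMatrix/Booster).
  object JniHandleSketch {
    def main(args: Array[String]): Unit = {
      val out = new Array[Long](1)                        // container for the native handle
      val ret = XgboostJNI.XGDMatrixCreateFromFile("agaricus.txt.train", 1, out) // hypothetical path
      require(ret == 0, XgboostJNI.XGBGetLastError())     // non-zero return signals a native error

      val handle = out(0)                                 // DMatrix handle written by native code
      val nrow = new Array[Long](1)
      XgboostJNI.XGDMatrixNumRow(handle, nrow)
      println(s"rows: ${nrow(0)}")

      XgboostJNI.XGDMatrixFree(handle)                    // release the native DMatrix
    }
  }
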
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
new file mode 100644
index 000000000..5d5cd5619
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
@@ -0,0 +1,189 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import java.io.IOException
+
+import scala.collection.mutable
+
+import ml.dmlc.xgboost4j.XGBoostError
+
+
+trait Booster {
+
+
+ /**
+ * set parameter
+ *
+ * @param key param name
+ * @param value param value
+ */
+ @throws(classOf[XGBoostError])
+ def setParam(key: String, value: String)
+
+ /**
+ * set parameters
+ *
+ * @param params parameters key-value map
+ */
+ @throws(classOf[XGBoostError])
+ def setParams(params: Map[String, AnyRef])
+
+ /**
+ * Update (one iteration)
+ *
+ * @param dtrain training data
+ * @param iter current iteration number
+ */
+ @throws(classOf[XGBoostError])
+ def update(dtrain: DMatrix, iter: Int)
+
+ /**
+ * update with a customized objective function
+ *
+ * @param dtrain training data
+ * @param obj customized objective class
+ */
+ @throws(classOf[XGBoostError])
+ def update(dtrain: DMatrix, obj: ObjectiveTrait)
+
+ /**
+ * update with given grad and hess
+ *
+ * @param dtrain training data
+ * @param grad first order gradient
+ * @param hess second order gradient
+ */
+ @throws(classOf[XGBoostError])
+ def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float])
+
+ /**
+ * evaluate with the given DMatrices.
+ *
+ * @param evalMatrixs DMatrices for evaluation
+ * @param evalNames names for the eval DMatrices, used to check results
+ * @param iter current eval iteration
+ * @return eval information
+ */
+ @throws(classOf[XGBoostError])
+ def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String
+
+ /**
+ * evaluate with given customized Evaluation class
+ *
+ * @param evalMatrixs evaluation matrix
+ * @param evalNames evaluation names
+ * @param eval custom evaluator
+ * @return eval information
+ */
+ @throws(classOf[XGBoostError])
+ def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): String
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @return predict result
+ */
+ @throws(classOf[XGBoostError])
+ def predict(data: DMatrix): Array[Array[Float]]
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @param outPutMargin Whether to output the raw untransformed margin value.
+ * @return predict result
+ */
+ @throws(classOf[XGBoostError])
+ def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]]
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @param outPutMargin Whether to output the raw untransformed margin value.
+ * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
+ * @return predict result
+ */
+ @throws(classOf[XGBoostError])
+ def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]]
+
+ /**
+ * Predict with data
+ *
+ * @param data dmatrix storing the input
+ * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
+ * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
+ * nsample = data.numRow with each record indicating the predicted leaf index of
+ * each sample in each tree. Note that the leaf index of a tree is unique per
+ * tree, so you may find leaf 1 in both tree 1 and tree 0.
+ * @return predict result
+ * @throws XGBoostError native error
+ */
+ @throws(classOf[XGBoostError])
+ def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]]
+
+ /**
+ * save model to modelPath
+ *
+ * @param modelPath model path
+ */
+ @throws(classOf[XGBoostError])
+ def saveModel(modelPath: String)
+
+ /**
+ * Dump model into a text file.
+ *
+ * @param modelPath file to save dumped model info
+ * @param withStats controls whether the split statistics are output.
+ */
+ @throws(classOf[IOException])
+ @throws(classOf[XGBoostError])
+ def dumpModel(modelPath: String, withStats: Boolean)
+
+ /**
+ * Dump model into a text file.
+ *
+ * @param modelPath file to save dumped model info
+ * @param featureMap feature map file
+ * @param withStats controls whether the split statistics are output.
+ */
+ @throws(classOf[IOException])
+ @throws(classOf[XGBoostError])
+ def dumpModel(modelPath: String, featureMap: String, withStats: Boolean)
+
+ /**
+ * get importance of each feature
+ *
+ * @return featureMap key: feature index, value: feature importance score
+ */
+ @throws(classOf[XGBoostError])
+ def getFeatureScore: mutable.Map[String, Integer]
+
+ /**
+ * get importance of each feature
+ *
+ * @param featureMap feature map file
+ * @return featureMap key: feature index, value: feature importance score
+ */
+ @throws(classOf[XGBoostError])
+ def getFeatureScore(featureMap: String): mutable.Map[String, Integer]
+
+ def dispose
+}
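
As a rough illustration of how this trait is expected to be consumed, the sketch below loads a saved model through the XGBoost object introduced later in this change and runs prediction, evaluation, and feature-score lookup. The model path, data path, and object name are hypothetical; treat this as a minimal usage sketch, not part of the API.

  import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost}

  object BoosterSketch {
    def main(args: Array[String]): Unit = {
      // hypothetical model and data paths
      val booster: Booster = XGBoost.loadBoostModel(Map.empty[String, AnyRef], "xgb.model")
      val dtest = new DMatrix("agaricus.txt.test")

      // raw margin predictions and the per-iteration evaluation string
      val margins = booster.predict(dtest, true)
      println(s"first margin: ${margins(0)(0)}")
      println(booster.evalSet(Array(dtest), Array("test"), 0))

      // feature importance scores keyed by feature index
      booster.getFeatureScore.foreach { case (feature, score) => println(s"$feature -> $score") }

      booster.dispose
    }
  }
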
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
new file mode 100644
index 000000000..73fafc7f0
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
@@ -0,0 +1,177 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
+
+class DMatrix private(private[scala] val jDMatrix: JDMatrix) {
+
+ /**
+ * init DMatrix from file (svmlight format)
+ *
+ * @param dataPath path of data file
+ * @throws XGBoostError native error
+ */
+ def this(dataPath: String) {
+ this(new JDMatrix(dataPath))
+ }
+
+ /**
+ * create DMatrix from sparse matrix
+ *
+ * @param headers index of headers (rowHeaders for CSR or colHeaders for CSC)
+ * @param indices indices (column indices for CSR or row indices for CSC)
+ * @param data non-zero values (ordered by row for CSR or by column for CSC)
+ * @param st sparse matrix type (CSR or CSC)
+ */
+ @throws(classOf[XGBoostError])
+ def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) {
+ this(new JDMatrix(headers, indices, data, st))
+ }
+
+ /**
+ * create DMatrix from dense matrix
+ *
+ * @param data data values
+ * @param nrow number of rows
+ * @param ncol number of columns
+ */
+ @throws(classOf[XGBoostError])
+ def this(data: Array[Float], nrow: Int, ncol: Int) {
+ this(new JDMatrix(data, nrow, ncol))
+ }
+
+ /**
+ * set label of dmatrix
+ *
+ * @param labels labels
+ */
+ @throws(classOf[XGBoostError])
+ def setLabel(labels: Array[Float]): Unit = {
+ jDMatrix.setLabel(labels)
+ }
+
+ /**
+ * set weight of each instance
+ *
+ * @param weights weights
+ */
+ @throws(classOf[XGBoostError])
+ def setWeight(weights: Array[Float]): Unit = {
+ jDMatrix.setWeight(weights)
+ }
+
+ /**
+ * if specified, xgboost will start from this init margin
+ * can be used to specify initial prediction to boost from
+ *
+ * @param baseMargin base margin
+ */
+ @throws(classOf[XGBoostError])
+ def setBaseMargin(baseMargin: Array[Float]): Unit = {
+ jDMatrix.setBaseMargin(baseMargin)
+ }
+
+ /**
+ * if specified, xgboost will start from this init margin
+ * can be used to specify initial prediction to boost from
+ *
+ * @param baseMargin base margin
+ */
+ @throws(classOf[XGBoostError])
+ def setBaseMargin(baseMargin: Array[Array[Float]]): Unit = {
+ jDMatrix.setBaseMargin(baseMargin)
+ }
+
+ /**
+ * Set group sizes of DMatrix (used for ranking)
+ *
+ * @param group group size as array
+ */
+ @throws(classOf[XGBoostError])
+ def setGroup(group: Array[Int]): Unit = {
+ jDMatrix.setGroup(group)
+ }
+
+ /**
+ * get label values
+ *
+ * @return label
+ */
+ @throws(classOf[XGBoostError])
+ def getLabel: Array[Float] = {
+ jDMatrix.getLabel
+ }
+
+ /**
+ * get weight of the DMatrix
+ *
+ * @return weights
+ */
+ @throws(classOf[XGBoostError])
+ def getWeight: Array[Float] = {
+ jDMatrix.getWeight
+ }
+
+ /**
+ * get base margin of the DMatrix
+ *
+ * @return base margin
+ */
+ @throws(classOf[XGBoostError])
+ def getBaseMargin: Array[Float] = {
+ jDMatrix.getBaseMargin
+ }
+
+ /**
+ * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
+ *
+ * @param rowIndex row index
+ * @return sliced new DMatrix
+ */
+ @throws(classOf[XGBoostError])
+ def slice(rowIndex: Array[Int]): DMatrix = {
+ new DMatrix(jDMatrix.slice(rowIndex))
+ }
+
+ /**
+ * get the number of rows of the DMatrix
+ *
+ * @return number of rows
+ */
+ @throws(classOf[XGBoostError])
+ def rowNum: Long = {
+ jDMatrix.rowNum
+ }
+
+ /**
+ * save DMatrix to filePath
+ *
+ * @param filePath file path
+ */
+ def saveBinary(filePath: String): Unit = {
+ jDMatrix.saveBinary(filePath)
+ }
+
+ def getHandle: Long = {
+ jDMatrix.getHandle
+ }
+
+ def delete(): Unit = {
+ jDMatrix.dispose()
+ }
+}
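
A minimal sketch of constructing this wrapper from dense and CSR data and attaching labels. The values are made up and the sketch assumes the native library is on the load path; it only exercises constructors and setters defined in this file.

  import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
  import ml.dmlc.xgboost4j.scala.DMatrix

  object DMatrixSketch {
    def main(args: Array[String]): Unit = {
      // 2 x 3 dense matrix in row-major order
      val dense = new DMatrix(Array(1f, 2f, 3f, 4f, 5f, 6f), 2, 3)
      dense.setLabel(Array(0f, 1f))

      // the same two rows in CSR form: row headers, column indices, non-zero values
      val sparse = new DMatrix(Array(0L, 3L, 6L), Array(0, 1, 2, 0, 1, 2),
        Array(1f, 2f, 3f, 4f, 5f, 6f), JDMatrix.SparseType.CSR)

      println(s"dense rows: ${dense.rowNum}, labels: ${dense.getLabel.mkString(",")}")
      println(s"sparse rows: ${sparse.rowNum}")

      sparse.delete()
      dense.delete()
    }
  }
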
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala
new file mode 100644
index 000000000..461f515a1
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala
@@ -0,0 +1,38 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import ml.dmlc.xgboost4j.IEvaluation
+
+trait EvalTrait extends IEvaluation {
+
+ /**
+ * get evaluation metric
+ *
+ * @return evalMetric
+ */
+ def getMetric: String
+
+ /**
+ * evaluate with predicts and data
+ *
+ * @param predicts predictions as array
+ * @param dmat data matrix to evaluate
+ * @return result of the metric
+ */
+ def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
+}
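
For comparison with the Java EvalError metric used in the tests later in this change, here is a rough Scala sketch of a custom metric. It assumes that the IEvaluation overload taking the Java DMatrix remains abstract when mixing in this trait, so both overloads are implemented; the class name and metric name are made up.

  import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
  import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}

  // Illustrative custom metric: fraction of raw margins whose sign disagrees with the label.
  class SignErrorEval extends EvalTrait {
    override def getMetric: String = "sign_error"

    // Scala-side overload declared by EvalTrait
    override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float =
      errorRate(predicts, dmat.getLabel)

    // Java-side overload inherited from IEvaluation (assumed still abstract here)
    override def eval(predicts: Array[Array[Float]], dmat: JDMatrix): Float =
      errorRate(predicts, dmat.getLabel)

    private def errorRate(predicts: Array[Array[Float]], labels: Array[Float]): Float = {
      var wrong = 0
      for (i <- predicts.indices) {
        val predictedPositive = predicts(i)(0) > 0f
        if ((labels(i) == 1f) != predictedPositive) wrong += 1
      }
      wrong.toFloat / labels.length
    }
  }
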
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala
new file mode 100644
index 000000000..c5df8aead
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala
@@ -0,0 +1,30 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import ml.dmlc.xgboost4j.IObjective
+
+trait ObjectiveTrait extends IObjective {
+ /**
+ * user-defined objective function, returning the gradient and second order gradient
+ *
+ * @param predicts untransformed margin predictions
+ * @param dtrain training data
+ * @return List with two float arrays, corresponding to the first order grad and second order grad
+ */
+ def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]]
+}
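
And a matching sketch of a custom objective. It assumes the IObjective overload taking the Java DMatrix is still abstract here and shares the java.util.List[Array[Float]] return type shown above; the logistic-loss gradients follow the standard formulas and the class name is made up.

  import java.util.{ArrayList => JArrayList, List => JList}

  import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
  import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait}

  // Illustrative custom objective: logistic loss on raw margins.
  class LogLossObj extends ObjectiveTrait {

    // Scala-side overload declared by ObjectiveTrait
    override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): JList[Array[Float]] =
      gradients(predicts, dtrain.getLabel)

    // Java-side overload inherited from IObjective (assumed still abstract here)
    override def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix): JList[Array[Float]] =
      gradients(predicts, dtrain.getLabel)

    private def gradients(predicts: Array[Array[Float]], labels: Array[Float]): JList[Array[Float]] = {
      val grad = new Array[Float](labels.length)
      val hess = new Array[Float](labels.length)
      for (i <- labels.indices) {
        val p = (1.0 / (1.0 + math.exp(-predicts(i)(0)))).toFloat // sigmoid of the margin
        grad(i) = p - labels(i)                                   // first order gradient
        hess(i) = p * (1.0f - p)                                  // second order gradient
      }
      val out = new JArrayList[Array[Float]]()
      out.add(grad)
      out.add(hess)
      out
    }
  }
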
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala
new file mode 100644
index 000000000..06af4541b
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala
@@ -0,0 +1,100 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import ml.dmlc.xgboost4j.{Booster => JBooster, IEvaluation, IObjective}
+
+private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) extends Booster {
+
+ override def setParam(key: String, value: String): Unit = {
+ booster.setParam(key, value)
+ }
+
+ override def update(dtrain: DMatrix, iter: Int): Unit = {
+ booster.update(dtrain.jDMatrix, iter)
+ }
+
+ override def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
+ booster.update(dtrain.jDMatrix, obj)
+ }
+
+ override def dumpModel(modelPath: String, withStats: Boolean): Unit = {
+ booster.dumpModel(modelPath, withStats)
+ }
+
+ override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = {
+ booster.dumpModel(modelPath, featureMap, withStats)
+ }
+
+ override def setParams(params: Map[String, AnyRef]): Unit = {
+ booster.setParams(params.asJava)
+ }
+
+ override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
+ booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
+ }
+
+ override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait):
+ String = {
+ booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
+ }
+
+ override def dispose: Unit = {
+ booster.dispose()
+ }
+
+ override def predict(data: DMatrix): Array[Array[Float]] = {
+ booster.predict(data.jDMatrix)
+ }
+
+ override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = {
+ booster.predict(data.jDMatrix, outPutMargin)
+ }
+
+ override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
+ Array[Array[Float]] = {
+ booster.predict(data.jDMatrix, outPutMargin, treeLimit)
+ }
+
+ override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
+ booster.predict(data.jDMatrix, treeLimit, predLeaf)
+ }
+
+ override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
+ booster.boost(dtrain.jDMatrix, grad, hess)
+ }
+
+ override def getFeatureScore: mutable.Map[String, Integer] = {
+ booster.getFeatureScore.asScala
+ }
+
+ override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
+ booster.getFeatureScore(featureMap).asScala
+ }
+
+ override def saveModel(modelPath: String): Unit = {
+ booster.saveModel(modelPath)
+ }
+
+ override def finalize(): Unit = {
+ super.finalize()
+ dispose
+ }
+}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
new file mode 100644
index 000000000..737e4765d
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
@@ -0,0 +1,52 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import _root_.scala.collection.JavaConverters._
+import ml.dmlc.xgboost4j.{XGBoost => JXGBoost}
+
+object XGBoost {
+
+ def train(params: Map[String, AnyRef], dtrain: DMatrix, round: Int,
+ watches: Map[String, DMatrix], obj: ObjectiveTrait, eval: EvalTrait): Booster = {
+ val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
+ val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
+ obj, eval)
+ new ScalaBoosterImpl(xgboostInJava)
+ }
+
+ def crossValiation(
+ params: Map[String, AnyRef],
+ data: DMatrix,
+ round: Int,
+ nfold: Int,
+ metrics: Array[String],
+ obj: ObjectiveTrait,
+ eval: EvalTrait): Array[String] = {
+ JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
+ }
+
+ def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
+ val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
+ new ScalaBoosterImpl(xgboostInJava)
+ }
+
+ def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = {
+ val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath)
+ new ScalaBoosterImpl(xgboostInJava)
+ }
+}
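
Putting the pieces together, a minimal end-to-end training sketch against this object. File paths, parameter values, and the output model name are hypothetical, and nulls stand in for the optional custom objective, evaluation, and extra metrics.

  import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}

  object TrainSketch {
    def main(args: Array[String]): Unit = {
      // hypothetical data paths in libsvm format
      val dtrain = new DMatrix("agaricus.txt.train")
      val dtest = new DMatrix("agaricus.txt.test")

      val params: Map[String, AnyRef] = Map(
        "eta" -> "1.0",
        "max_depth" -> "2",
        "objective" -> "binary:logistic")

      // train for two rounds, watching both matrices; no custom objective or eval
      val booster = XGBoost.train(params, dtrain, 2,
        Map("train" -> dtrain, "test" -> dtest), null, null)
      booster.saveModel("xgb.model")

      // 5-fold cross validation with the same parameters (method name as spelled in this change)
      val history = XGBoost.crossValiation(params, dtrain, 2, 5, null, null, null)
      history.foreach(println)
    }
  }
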
diff --git a/java/xgboost4j_wrapper.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
similarity index 81%
rename from java/xgboost4j_wrapper.cpp
rename to jvm-packages/xgboost4j/src/native/xgboost4j.cpp
index 865426752..0d976a33f 100644
--- a/java/xgboost4j_wrapper.cpp
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -13,7 +13,7 @@
*/
#include "xgboost/c_api.h"
-#include "xgboost4j_wrapper.h"
+#include "xgboost4j.h"
#include
//helper functions
@@ -24,7 +24,7 @@ void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out);
}
-JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError
+JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls) {
jstring jresult = 0 ;
const char* result = XGBGetLastError();
@@ -32,7 +32,7 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastE
return jresult;
}
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
DMatrixHandle result;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
@@ -43,11 +43,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromCSR
* Signature: ([J[J[F)J
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
@@ -65,11 +65,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromCSC
* Signature: ([J[J[F)J
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
@@ -89,11 +89,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromMat
* Signature: ([FIIF)J
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
DMatrixHandle result;
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
@@ -107,11 +107,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSliceDMatrix
* Signature: (J[I)J
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) {
DMatrixHandle result;
DMatrixHandle handle = (DMatrixHandle) jhandle;
@@ -128,11 +128,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSlice
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixFree
* Signature: (J)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
int ret = XGDMatrixFree(handle);
@@ -140,11 +140,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSaveBinary
* Signature: (JLjava/lang/String;I)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
@@ -154,11 +154,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveB
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSetFloatInfo
* Signature: (JLjava/lang/String;[F)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
@@ -173,11 +173,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFl
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSetUIntInfo
* Signature: (JLjava/lang/String;[I)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
@@ -192,11 +192,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUI
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSetGroup
* Signature: (J[I)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
jint* array = jenv->GetIntArrayElements(jarray, NULL);
@@ -208,11 +208,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGr
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixGetFloatInfo
* Signature: (JLjava/lang/String;)[F
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
@@ -230,11 +230,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFl
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixGetUIntInfo
* Signature: (JLjava/lang/String;)[I
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
@@ -251,11 +251,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUI
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixNumRow
* Signature: (J)J
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
bst_ulong result[1];
@@ -265,11 +265,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRo
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterCreate
* Signature: ([J)J
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
DMatrixHandle* handles;
bst_ulong len = 0;
@@ -298,11 +298,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreat
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterFree
* Signature: (J)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterFree(handle);
@@ -310,11 +310,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSetParam
* Signature: (JLjava/lang/String;Ljava/lang/String;)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
BoosterHandle handle = (BoosterHandle) jhandle;
const char* name = jenv->GetStringUTFChars(jname, 0);
@@ -327,11 +327,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetPa
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterUpdateOneIter
* Signature: (JIJ)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
@@ -339,11 +339,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdat
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
@@ -358,11 +358,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoost
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterEvalOneIter
* Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String;
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle* dmats = 0;
@@ -406,11 +406,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalO
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterPredict
* Signature: (JJIJ)[F
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dmat = (DMatrixHandle) jdmat;
@@ -426,11 +426,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredi
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadModel
* Signature: (JLjava/lang/String;)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
BoosterHandle handle = (BoosterHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
@@ -441,11 +441,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveModel
* Signature: (JLjava/lang/String;)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
BoosterHandle handle = (BoosterHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
@@ -457,11 +457,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveM
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadModelFromBuffer
* Signature: (JJJ)V
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
BoosterHandle handle = (BoosterHandle) jhandle;
void *buf = (void*) jbuf;
@@ -469,11 +469,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterGetModelRaw
* Signature: (J)Ljava/lang/String;
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
bst_ulong len = 0;
@@ -488,11 +488,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetMo
}
/*
- * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterDumpModel
* Signature: (JLjava/lang/String;I)[Ljava/lang/String;
*/
-JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
@@ -510,4 +510,4 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpM
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
return ret;
-}
\ No newline at end of file
+}
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h
new file mode 100644
index 000000000..d93da0ee6
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h
@@ -0,0 +1,221 @@
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include <jni.h>
+/* Header for class ml_dmlc_xgboost4j_XgboostJNI */
+
+#ifndef _Included_ml_dmlc_xgboost4j_XgboostJNI
+#define _Included_ml_dmlc_xgboost4j_XgboostJNI
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBGetLastError
+ * Signature: ()Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
+ (JNIEnv *, jclass);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixCreateFromFile
+ * Signature: (Ljava/lang/String;I[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
+ (JNIEnv *, jclass, jstring, jint, jlongArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixCreateFromCSR
+ * Signature: ([J[I[F[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
+ (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixCreateFromCSC
+ * Signature: ([J[I[F[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
+ (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixCreateFromMat
+ * Signature: ([FIIF[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
+ (JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixSliceDMatrix
+ * Signature: (J[I[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
+ (JNIEnv *, jclass, jlong, jintArray, jlongArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixFree
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
+ (JNIEnv *, jclass, jlong);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixSaveBinary
+ * Signature: (JLjava/lang/String;I)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
+ (JNIEnv *, jclass, jlong, jstring, jint);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixSetFloatInfo
+ * Signature: (JLjava/lang/String;[F)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
+ (JNIEnv *, jclass, jlong, jstring, jfloatArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixSetUIntInfo
+ * Signature: (JLjava/lang/String;[I)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
+ (JNIEnv *, jclass, jlong, jstring, jintArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixSetGroup
+ * Signature: (J[I)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
+ (JNIEnv *, jclass, jlong, jintArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixGetFloatInfo
+ * Signature: (JLjava/lang/String;[[F)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
+ (JNIEnv *, jclass, jlong, jstring, jobjectArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixGetUIntInfo
+ * Signature: (JLjava/lang/String;[[I)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
+ (JNIEnv *, jclass, jlong, jstring, jobjectArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixNumRow
+ * Signature: (J[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
+ (JNIEnv *, jclass, jlong, jlongArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterCreate
+ * Signature: ([J[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
+ (JNIEnv *, jclass, jlongArray, jlongArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterFree
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
+ (JNIEnv *, jclass, jlong);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterSetParam
+ * Signature: (JLjava/lang/String;Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
+ (JNIEnv *, jclass, jlong, jstring, jstring);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterUpdateOneIter
+ * Signature: (JIJ)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
+ (JNIEnv *, jclass, jlong, jint, jlong);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterBoostOneIter
+ * Signature: (JJ[F[F)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
+ (JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterEvalOneIter
+ * Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
+ (JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterPredict
+ * Signature: (JJII[[F)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
+ (JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterLoadModel
+ * Signature: (JLjava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
+ (JNIEnv *, jclass, jlong, jstring);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterSaveModel
+ * Signature: (JLjava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
+ (JNIEnv *, jclass, jlong, jstring);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterLoadModelFromBuffer
+ * Signature: (JJJ)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
+ (JNIEnv *, jclass, jlong, jlong, jlong);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterGetModelRaw
+ * Signature: (J[Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
+ (JNIEnv *, jclass, jlong, jobjectArray);
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBoosterDumpModel
+ * Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
+ (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java
new file mode 100644
index 000000000..e44bc95bc
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java
@@ -0,0 +1,138 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j;
+
+import junit.framework.TestCase;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.junit.Test;
+
+import java.util.*;
+
+/**
+ * test cases for Booster
+ *
+ * @author hzx
+ */
+public class BoosterImplTest {
+ public static class EvalError implements IEvaluation {
+ private static final Log logger = LogFactory.getLog(EvalError.class);
+
+ String evalMetric = "custom_error";
+
+ public EvalError() {
+ }
+
+ @Override
+ public String getMetric() {
+ return evalMetric;
+ }
+
+ @Override
+ public float eval(float[][] predicts, DMatrix dmat) {
+ float error = 0f;
+ float[] labels;
+ try {
+ labels = dmat.getLabel();
+ } catch (XGBoostError ex) {
+ logger.error(ex);
+ return -1f;
+ }
+ int nrow = predicts.length;
+ for (int i = 0; i < nrow; i++) {
+ if (labels[i] == 0f && predicts[i][0] > 0) {
+ error++;
+ } else if (labels[i] == 1f && predicts[i][0] <= 0) {
+ error++;
+ }
+ }
+
+ return error / labels.length;
+ }
+ }
+
+ @Test
+ public void testBoosterBasic() throws XGBoostError {
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+
+ //set params
+ Map<String, Object> paramMap = new HashMap<String, Object>() {
+ {
+ put("eta", 1.0);
+ put("max_depth", 2);
+ put("silent", 1);
+ put("objective", "binary:logistic");
+ }
+ };
+
+ //set watchList
+ HashMap<String, DMatrix> watches = new HashMap<>();
+
+ watches.put("train", trainMat);
+ watches.put("test", testMat);
+
+ //set round
+ int round = 2;
+
+ //train a boost model
+ Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null);
+
+ //predict raw output
+ float[][] predicts = booster.predict(testMat, true);
+
+ //eval
+ IEvaluation eval = new EvalError();
+ //error must be less than 0.1
+ TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
+
+ //test dump model
+
+ }
+
+ /**
+ * test cross validation
+ *
+ * @throws XGBoostError
+ */
+ @Test
+ public void testCV() throws XGBoostError {
+ //load train mat
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+
+ //set params
+ Map<String, Object> param = new HashMap<String, Object>() {
+ {
+ put("eta", 1.0);
+ put("max_depth", 3);
+ put("silent", 1);
+ put("nthread", 6);
+ put("objective", "binary:logistic");
+ put("gamma", 1.0);
+ put("eval_metric", "error");
+ }
+ };
+
+ //do 5-fold cross validation
+ int round = 2;
+ int nfold = 5;
+ //set additional eval_metrics
+ String[] metrics = null;
+
+ String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, metrics,
+ null, null);
+ }
+}
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java
new file mode 100644
index 000000000..9b3a8b860
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java
@@ -0,0 +1,103 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j;
+
+import junit.framework.TestCase;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Random;
+
+/**
+ * test cases for DMatrix
+ *
+ * @author hzx
+ */
+public class DMatrixTest {
+
+ @Test
+ public void testCreateFromFile() throws XGBoostError {
+ //create DMatrix from file
+ DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
+ //get label
+ float[] labels = dmat.getLabel();
+ //check length
+ TestCase.assertTrue(dmat.rowNum() == labels.length);
+ //set weights
+ float[] weights = Arrays.copyOf(labels, labels.length);
+ dmat.setWeight(weights);
+ float[] dweights = dmat.getWeight();
+ TestCase.assertTrue(Arrays.equals(weights, dweights));
+ }
+
+ @Test
+ public void testCreateFromCSR() throws XGBoostError {
+ //create Matrix from csr format sparse Matrix and labels
+ /**
+ * sparse matrix
+ * 1 0 2 3 0
+ * 4 0 2 3 5
+ * 3 1 2 5 0
+ */
+ float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
+ int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
+ long[] rowHeaders = new long[]{0, 3, 7, 11};
+ DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
+ //check row num
+ System.out.println(dmat1.rowNum());
+ TestCase.assertTrue(dmat1.rowNum() == 3);
+ //test set label
+ float[] label1 = new float[]{1, 0, 1};
+ dmat1.setLabel(label1);
+ float[] label2 = dmat1.getLabel();
+ TestCase.assertTrue(Arrays.equals(label1, label2));
+ }
+
+ @Test
+ public void testCreateFromDenseMatrix() throws XGBoostError {
+ //create DMatrix from 10*5 dense matrix
+ int nrow = 10;
+ int ncol = 5;
+ float[] data0 = new float[nrow * ncol];
+ //put random nums
+ Random random = new Random();
+ for (int i = 0; i < nrow * ncol; i++) {
+ data0[i] = random.nextFloat();
+ }
+
+ //create label
+ float[] label0 = new float[nrow];
+ for (int i = 0; i < nrow; i++) {
+ label0[i] = random.nextFloat();
+ }
+
+ DMatrix dmat0 = new DMatrix(data0, nrow, ncol);
+ dmat0.setLabel(label0);
+
+ //check
+ TestCase.assertTrue(dmat0.rowNum() == 10);
+ TestCase.assertTrue(dmat0.getLabel().length == 10);
+
+ //set weights for each instance
+ float[] weights = new float[nrow];
+ for (int i = 0; i < nrow; i++) {
+ weights[i] = random.nextFloat();
+ }
+ dmat0.setWeight(weights);
+
+ TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
+ }
+}
diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh
index 5795d89ff..bf3a781e7 100755
--- a/tests/travis/run_test.sh
+++ b/tests/travis/run_test.sh
@@ -73,10 +73,9 @@ fi
if [ ${TASK} == "java_test" ]; then
set -e
- make java
- cd java
- ./create_wrap.sh
- cd xgboost4j
+ make jvm-packages
+ cd jvm-packages
+ ./create_jni.sh
mvn clean install -DskipTests=true
mvn test
fi