[DOC-JVM] Refactor JVM docs
This commit is contained in:
parent
79f9fceb6b
commit
c05c5bc7bc
2
NEWS.md
2
NEWS.md
@ -27,6 +27,8 @@ This file records the changes in xgboost library in reverse chronological order.
|
|||||||
- This could fix some of the previous problem which runs xgboost on multiple threads.
|
- This could fix some of the previous problem which runs xgboost on multiple threads.
|
||||||
* JVM Package
|
* JVM Package
|
||||||
- Enable xgboost4j for java and scala
|
- Enable xgboost4j for java and scala
|
||||||
|
- XGBoost distributed now runs on Flink and Spark.
|
||||||
|
|
||||||
|
|
||||||
## v0.47 (2016.01.14)
|
## v0.47 (2016.01.14)
|
||||||
|
|
||||||
|
|||||||
20
doc/jvm/index.md
Normal file
20
doc/jvm/index.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
XGBoost JVM Package
|
||||||
|
===================
|
||||||
|
[](https://travis-ci.org/dmlc/xgboost)
|
||||||
|
[](../LICENSE)
|
||||||
|
|
||||||
|
You have find XGBoost JVM Package!
|
||||||
|
|
||||||
|
Installation
|
||||||
|
------------
|
||||||
|
To build XGBoost4J contains two steps.
|
||||||
|
- First type the following command to build JNI library.
|
||||||
|
```bash
|
||||||
|
./create_jni.sh
|
||||||
|
```
|
||||||
|
- Then package the libary. you can run `mvn package` in xgboost4j folder or just use IDE(eclipse/netbeans) to open this maven project and build.
|
||||||
|
|
||||||
|
Contents
|
||||||
|
--------
|
||||||
|
* [Java Overview Tutorial](java_intro.md)
|
||||||
|
* [Code Examples](https://github.com/dmlc/xgboost/tree/master/jvm-packages/xgboost4j-example)
|
||||||
@ -1,23 +1,8 @@
|
|||||||
xgboost4j : java wrapper for xgboost
|
XGBoost4J Java API
|
||||||
====
|
==================
|
||||||
|
This tutorial introduces
|
||||||
|
|
||||||
This page will introduce xgboost4j, the java wrapper for xgboost, including:
|
## Data Interface
|
||||||
* [Building](#build-xgboost4j)
|
|
||||||
* [Data Interface](#data-interface)
|
|
||||||
* [Setting Parameters](#setting-parameters)
|
|
||||||
* [Train Model](#training-model)
|
|
||||||
* [Prediction](#prediction)
|
|
||||||
|
|
||||||
=
|
|
||||||
#### Build xgboost4j
|
|
||||||
* Build native library
|
|
||||||
first make sure you have installed jdk and `JAVA_HOME` has been setted properly, then simply run `./create_wrap.sh`.
|
|
||||||
|
|
||||||
* Package xgboost4j
|
|
||||||
to package xgboost4j, you can run `mvn package` in xgboost4j folder or just use IDE(eclipse/netbeans) to open this maven project and build.
|
|
||||||
|
|
||||||
=
|
|
||||||
#### Data Interface
|
|
||||||
Like the xgboost python module, xgboost4j use ```DMatrix``` to handle data, libsvm txt format file, sparse matrix in CSR/CSC format, and dense matrix is supported.
|
Like the xgboost python module, xgboost4j use ```DMatrix``` to handle data, libsvm txt format file, sparse matrix in CSR/CSC format, and dense matrix is supported.
|
||||||
|
|
||||||
* To import ```DMatrix``` :
|
* To import ```DMatrix``` :
|
||||||
@ -72,7 +57,7 @@ float[] weights = new float[] {1f,2f,1f};
|
|||||||
dmat.setWeight(weights);
|
dmat.setWeight(weights);
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Setting Parameters
|
## Setting Parameters
|
||||||
* in xgboost4j any ```Iterable<Entry<String, Object>>``` object could be used as parameters.
|
* in xgboost4j any ```Iterable<Entry<String, Object>>``` object could be used as parameters.
|
||||||
|
|
||||||
* to set parameters, for non-multiple value params, you can simply use entrySet of an Map:
|
* to set parameters, for non-multiple value params, you can simply use entrySet of an Map:
|
||||||
@ -100,7 +85,7 @@ List<Entry<String, Object>> params = new ArrayList<Entry<String, Object>>() {
|
|||||||
};
|
};
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Training Model
|
## Training Model
|
||||||
With parameters and data, you are able to train a booster model.
|
With parameters and data, you are able to train a booster model.
|
||||||
* Import ```Trainer``` and ```Booster``` :
|
* Import ```Trainer``` and ```Booster``` :
|
||||||
```java
|
```java
|
||||||
@ -145,7 +130,7 @@ Params param = new Params() {
|
|||||||
Booster booster = new Booster(param, "model.bin");
|
Booster booster = new Booster(param, "model.bin");
|
||||||
```
|
```
|
||||||
|
|
||||||
####Prediction
|
## Prediction
|
||||||
after training and loading a model, you use it to predict other data, the predict results will be a two-dimension float array (nsample, nclass) ,for predict leaf, it would be (nsample, nclass*ntrees)
|
after training and loading a model, you use it to predict other data, the predict results will be a two-dimension float array (nsample, nclass) ,for predict leaf, it would be (nsample, nclass*ntrees)
|
||||||
```java
|
```java
|
||||||
DMatrix dtest = new DMatrix("test.svm.txt");
|
DMatrix dtest = new DMatrix("test.svm.txt");
|
||||||
@ -1,33 +1,73 @@
|
|||||||
# xgboost4j
|
# XGBoost4J: Distributed XGBoost for Scala/Java
|
||||||
this is a java wrapper for xgboost
|
[](https://travis-ci.org/dmlc/xgboost)
|
||||||
|
[](../LICENSE)
|
||||||
|
|
||||||
the structure of this wrapper is almost the same as the official python wrapper.
|
[Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) |
|
||||||
|
[Resources](../demo/README.md) |
|
||||||
|
[Release Notes](../NEWS.md)
|
||||||
|
|
||||||
core of this wrapper is two classes:
|
XGBoost4J is the JVM package of xgboost. It brings all the optimizations
|
||||||
|
and power xgboost into JVM ecosystem.
|
||||||
|
|
||||||
* DMatrix: for handling data
|
- Train XGBoost models on scala and java with easy customizations.
|
||||||
|
- Run distributed xgboost natively on jvm frameworks such as Flink and Spark.
|
||||||
|
|
||||||
* Booster: for train and predict
|
You can find more about XGBoost on [Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) and [Resource Page](../demo/README.md).
|
||||||
|
|
||||||
## usage:
|
## Hello World
|
||||||
please refer to [xgboost4j.md](doc/xgboost4j.md) for more information.
|
### XGBoost Scala
|
||||||
|
```scala
|
||||||
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
|
import ml.dmlc.xgboost4j.scala.XGBoost
|
||||||
|
|
||||||
besides, simple examples could be found in [xgboost4j-demo](xgboost4j-demo/README.md)
|
object XGBoostScalaExample {
|
||||||
|
def main(args: Array[String]) {
|
||||||
|
// read trainining data, available at xgboost/demo/data
|
||||||
|
val trainData =
|
||||||
|
new DMatrix("/path/to/agaricus.txt.train")
|
||||||
|
// define parameters
|
||||||
|
val paramMap = List(
|
||||||
|
"eta" -> 0.1,
|
||||||
|
"max_depth" -> 2,
|
||||||
|
"objective" -> "binary:logistic").toMap
|
||||||
|
// number of iterations
|
||||||
|
val round = 2
|
||||||
|
// train the model
|
||||||
|
val model = XGBoost.train(paramMap, trainData, round)
|
||||||
|
// run prediction
|
||||||
|
val predTrain = model.predict(trainData)
|
||||||
|
// save model to the file.
|
||||||
|
model.saveModel("/local/path/to/model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### XGBoost Flink
|
||||||
|
```scala
|
||||||
|
import ml.dmlc.xgboost4j.scala.flink.XGBoost
|
||||||
|
import org.apache.flink.api.scala._
|
||||||
|
import org.apache.flink.api.scala.ExecutionEnvironment
|
||||||
|
import org.apache.flink.ml.MLUtils
|
||||||
|
|
||||||
## build native library
|
object DistTrainWithFlink {
|
||||||
|
def main(args: Array[String]) {
|
||||||
|
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
|
||||||
|
// read trainining data
|
||||||
|
val trainData =
|
||||||
|
MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
|
||||||
|
// define parameters
|
||||||
|
val paramMap = List(
|
||||||
|
"eta" -> 0.1,
|
||||||
|
"max_depth" -> 2,
|
||||||
|
"objective" -> "binary:logistic").toMap
|
||||||
|
// number of iterations
|
||||||
|
val round = 2
|
||||||
|
// train the model
|
||||||
|
val model = XGBoost.train(paramMap, trainData, round)
|
||||||
|
val predTrain = model.predict(trainData.map{x => x.vector})
|
||||||
|
model.saveModelToHadoop("file:///path/to/xgboost.model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
for windows: open the xgboost.sln in "../windows" folder, you will found the xgboost4j project, you should do the following steps to build wrapper library:
|
### XGBoost Spark
|
||||||
* Select x64/win32 and Release in build
|
|
||||||
* (if you have setted `JAVA_HOME` properly in windows environment variables, escape this step) right click on xgboost4j project -> choose "Properties" -> click on "C/C++" in the window -> change the "Additional Include Directories" to fit your jdk install path.
|
|
||||||
* rebuild all
|
|
||||||
* double click "create_wrap.bat" to set library to proper place
|
|
||||||
|
|
||||||
for linux:
|
|
||||||
* make sure you have installed jdk and `JAVA_HOME` has been setted properly
|
|
||||||
* run "create_wrap.sh"
|
|
||||||
|
|
||||||
for osx:
|
|
||||||
* make sure you have installed jdk
|
|
||||||
* for single thread xgboost, simply run "create_wrap.sh"
|
|
||||||
* for build with openMP, please refer to [build.md](../doc/build.md) to get openmp supported compiler first, and change the line "dis_omp=1" to "dis_omp=0" in "create_wrap.sh", then run "create_wrap.sh"
|
|
||||||
|
|||||||
@ -5,8 +5,8 @@
|
|||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboostjvm</artifactId>
|
<artifactId>xgboost-jvm</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
<packaging>pom</packaging>
|
<packaging>pom</packaging>
|
||||||
<properties>
|
<properties>
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
@ -19,7 +19,7 @@
|
|||||||
</properties>
|
</properties>
|
||||||
<modules>
|
<modules>
|
||||||
<module>xgboost4j</module>
|
<module>xgboost4j</module>
|
||||||
<module>xgboost4j-demo</module>
|
<module>xgboost4j-example</module>
|
||||||
<module>xgboost4j-spark</module>
|
<module>xgboost4j-spark</module>
|
||||||
<module>xgboost4j-flink</module>
|
<module>xgboost4j-flink</module>
|
||||||
</modules>
|
</modules>
|
||||||
|
|||||||
@ -1,10 +0,0 @@
|
|||||||
xgboost4j examples
|
|
||||||
====
|
|
||||||
* [Basic walkthrough of wrappers](src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java)
|
|
||||||
* [Cutomize loss function, and evaluation metric](src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java)
|
|
||||||
* [Boosting from existing prediction](src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java)
|
|
||||||
* [Predicting using first n trees](src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java)
|
|
||||||
* [Generalized Linear Model](src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java)
|
|
||||||
* [Cross validation](src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java)
|
|
||||||
* [Predicting leaf indices](src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java)
|
|
||||||
* [External Memory](src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java)
|
|
||||||
18
jvm-packages/xgboost4j-example/README.md
Normal file
18
jvm-packages/xgboost4j-example/README.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
XGBoost4J Code Examples
|
||||||
|
=======================
|
||||||
|
|
||||||
|
## Java API
|
||||||
|
* [Basic walkthrough of wrappers](src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java)
|
||||||
|
* [Cutomize loss function, and evaluation metric](src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java)
|
||||||
|
* [Boosting from existing prediction](src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java)
|
||||||
|
* [Predicting using first n trees](src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java)
|
||||||
|
* [Generalized Linear Model](src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java)
|
||||||
|
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java)
|
||||||
|
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
|
||||||
|
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)
|
||||||
|
|
||||||
|
## Spark API
|
||||||
|
* [Distributed Training with Spark](src/main/scala/ml/dmlc/xgboost4j/scala/spark/example/DistTrainWithSpark.scala)
|
||||||
|
|
||||||
|
## Flink API
|
||||||
|
* [Distributed Training with Flink](src/main/scala/ml/dmlc/xgboost4j/scala/flink/example/DistTrainWithFlink.scala)
|
||||||
@ -5,11 +5,11 @@
|
|||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
<parent>
|
<parent>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboostjvm</artifactId>
|
<artifactId>xgboost-jvm</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
</parent>
|
</parent>
|
||||||
<artifactId>xgboost4j-demo</artifactId>
|
<artifactId>xgboost4j-example</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
<packaging>jar</packaging>
|
<packaging>jar</packaging>
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
@ -26,7 +26,12 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost4j-spark</artifactId>
|
<artifactId>xgboost4j-spark</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>ml.dmlc</groupId>
|
||||||
|
<artifactId>xgboost4j-flink</artifactId>
|
||||||
|
<version>0.5</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.commons</groupId>
|
<groupId>org.apache.commons</groupId>
|
||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java.demo;
|
package ml.dmlc.xgboost4j.java.example;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
@ -24,7 +24,7 @@ import ml.dmlc.xgboost4j.java.Booster;
|
|||||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||||
import ml.dmlc.xgboost4j.java.demo.util.DataLoader;
|
import ml.dmlc.xgboost4j.java.example.util.DataLoader;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* a simple example of java wrapper for xgboost
|
* a simple example of java wrapper for xgboost
|
||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java.demo;
|
package ml.dmlc.xgboost4j.java.example;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
|
||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java.demo;
|
package ml.dmlc.xgboost4j.java.example;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java.demo;
|
package ml.dmlc.xgboost4j.java.example;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java.demo;
|
package ml.dmlc.xgboost4j.java.example;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
|
||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java.demo;
|
package ml.dmlc.xgboost4j.java.example;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.java.Booster;
|
|||||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||||
import ml.dmlc.xgboost4j.java.demo.util.CustomEval;
|
import ml.dmlc.xgboost4j.java.example.util.CustomEval;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this is an example of fit generalized linear model in xgboost
|
* this is an example of fit generalized linear model in xgboost
|
||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java.demo;
|
package ml.dmlc.xgboost4j.java.example;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.java.Booster;
|
|||||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||||
import ml.dmlc.xgboost4j.java.demo.util.CustomEval;
|
import ml.dmlc.xgboost4j.java.example.util.CustomEval;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* predict first ntree
|
* predict first ntree
|
||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java.demo;
|
package ml.dmlc.xgboost4j.java.example;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java.demo.util;
|
package ml.dmlc.xgboost4j.java.example.util;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java.demo.util;
|
package ml.dmlc.xgboost4j.java.example.util;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@ -13,33 +13,29 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
package ml.dmlc.xgboost4j.scala.flink.example
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.flink
|
import ml.dmlc.xgboost4j.scala.flink.XGBoost
|
||||||
|
|
||||||
import org.apache.commons.logging.Log
|
|
||||||
import org.apache.commons.logging.LogFactory
|
|
||||||
import org.apache.flink.api.common.functions.RichMapPartitionFunction
|
|
||||||
import org.apache.flink.api.scala._
|
import org.apache.flink.api.scala._
|
||||||
import org.apache.flink.api.scala.DataSet
|
|
||||||
import org.apache.flink.api.scala.ExecutionEnvironment
|
import org.apache.flink.api.scala.ExecutionEnvironment
|
||||||
import org.apache.flink.ml.common.LabeledVector
|
|
||||||
import org.apache.flink.ml.MLUtils
|
import org.apache.flink.ml.MLUtils
|
||||||
import org.apache.flink.util.Collector
|
|
||||||
|
|
||||||
|
object DistTrainWithFlink {
|
||||||
|
|
||||||
object Test {
|
|
||||||
val log = LogFactory.getLog(this.getClass)
|
|
||||||
def main(args: Array[String]) {
|
def main(args: Array[String]) {
|
||||||
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
|
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
|
||||||
val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
|
// read trainining data
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
val trainData =
|
||||||
|
MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
|
||||||
|
// define parameters
|
||||||
|
val paramMap = List(
|
||||||
|
"eta" -> 0.1,
|
||||||
|
"max_depth" -> 2,
|
||||||
"objective" -> "binary:logistic").toMap
|
"objective" -> "binary:logistic").toMap
|
||||||
|
// number of iterations
|
||||||
val round = 2
|
val round = 2
|
||||||
val model = XGBoost.train(paramMap, data, round)
|
// train the model
|
||||||
|
val model = XGBoost.train(paramMap, trainData, round)
|
||||||
|
val predTrain = model.predict(trainData.map{x => x.vector})
|
||||||
log.info(model)
|
model.saveModelToHadoop("file:///path/to/xgboost.model")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -14,7 +14,7 @@
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.demo
|
package ml.dmlc.xgboost4j.scala.spark.example
|
||||||
|
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
|
||||||
@ -5,11 +5,11 @@
|
|||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
<parent>
|
<parent>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboostjvm</artifactId>
|
<artifactId>xgboost-jvm</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
</parent>
|
</parent>
|
||||||
<artifactId>xgboost4j-flink</artifactId>
|
<artifactId>xgboost4j-flink</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
<plugin>
|
<plugin>
|
||||||
@ -26,7 +26,7 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost4j</artifactId>
|
<artifactId>xgboost4j</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.commons</groupId>
|
<groupId>org.apache.commons</groupId>
|
||||||
|
|||||||
@ -14,7 +14,8 @@
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.flink
|
package ml.dmlc.xgboost4j.scala.flink
|
||||||
|
|
||||||
import scala.collection.JavaConverters.asScalaIteratorConverter;
|
import scala.collection.JavaConverters.asScalaIteratorConverter;
|
||||||
import ml.dmlc.xgboost4j.LabeledPoint
|
import ml.dmlc.xgboost4j.LabeledPoint
|
||||||
import ml.dmlc.xgboost4j.java.{RabitTracker, Rabit}
|
import ml.dmlc.xgboost4j.java.{RabitTracker, Rabit}
|
||||||
@ -35,7 +36,7 @@ object XGBoost {
|
|||||||
*
|
*
|
||||||
* @param workerEnvs
|
* @param workerEnvs
|
||||||
*/
|
*/
|
||||||
private class MapFunction(paramMap: Map[String, AnyRef],
|
private class MapFunction(paramMap: Map[String, Any],
|
||||||
round: Int,
|
round: Int,
|
||||||
workerEnvs: java.util.Map[String, String])
|
workerEnvs: java.util.Map[String, String])
|
||||||
extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
|
extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
|
||||||
@ -69,7 +70,7 @@ object XGBoost {
|
|||||||
* @param modelPath The path that is accessible by hadoop filesystem API.
|
* @param modelPath The path that is accessible by hadoop filesystem API.
|
||||||
* @return The loaded model
|
* @return The loaded model
|
||||||
*/
|
*/
|
||||||
def loadModel(modelPath: String) : XGBoostModel = {
|
def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
|
||||||
new XGBoostModel(
|
new XGBoostModel(
|
||||||
XGBoostScala.loadModel(
|
XGBoostScala.loadModel(
|
||||||
FileSystem
|
FileSystem
|
||||||
@ -84,7 +85,7 @@ object XGBoost {
|
|||||||
* @param dtrain The training data.
|
* @param dtrain The training data.
|
||||||
* @param round Number of rounds to train.
|
* @param round Number of rounds to train.
|
||||||
*/
|
*/
|
||||||
def train(params: Map[String, AnyRef],
|
def train(params: Map[String, Any],
|
||||||
dtrain: DataSet[LabeledVector],
|
dtrain: DataSet[LabeledVector],
|
||||||
round: Int): XGBoostModel = {
|
round: Int): XGBoostModel = {
|
||||||
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
|
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
|
||||||
@ -14,7 +14,7 @@
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.flink
|
package ml.dmlc.xgboost4j.scala.flink
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.LabeledPoint
|
import ml.dmlc.xgboost4j.LabeledPoint
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||||
@ -31,7 +31,7 @@ class XGBoostModel (booster: Booster) extends Serializable {
|
|||||||
*
|
*
|
||||||
* @param modelPath The model path as in Hadoop path.
|
* @param modelPath The model path as in Hadoop path.
|
||||||
*/
|
*/
|
||||||
def saveModel(modelPath: String): Unit = {
|
def saveModelToHadoop(modelPath: String): Unit = {
|
||||||
booster.saveModel(FileSystem
|
booster.saveModel(FileSystem
|
||||||
.get(new Configuration)
|
.get(new Configuration)
|
||||||
.create(new Path(modelPath)))
|
.create(new Path(modelPath)))
|
||||||
@ -5,8 +5,8 @@
|
|||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
<parent>
|
<parent>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboostjvm</artifactId>
|
<artifactId>xgboost-jvm</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
</parent>
|
</parent>
|
||||||
<artifactId>xgboost4j-spark</artifactId>
|
<artifactId>xgboost4j-spark</artifactId>
|
||||||
<build>
|
<build>
|
||||||
@ -24,7 +24,7 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost4j</artifactId>
|
<artifactId>xgboost4j</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.spark</groupId>
|
<groupId>org.apache.spark</groupId>
|
||||||
|
|||||||
@ -5,11 +5,11 @@
|
|||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
<parent>
|
<parent>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboostjvm</artifactId>
|
<artifactId>xgboost-jvm</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
</parent>
|
</parent>
|
||||||
<artifactId>xgboost4j</artifactId>
|
<artifactId>xgboost4j</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.5</version>
|
||||||
<packaging>jar</packaging>
|
<packaging>jar</packaging>
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
|
|||||||
@ -39,14 +39,20 @@ object XGBoost {
|
|||||||
*/
|
*/
|
||||||
@throws(classOf[XGBoostError])
|
@throws(classOf[XGBoostError])
|
||||||
def train(
|
def train(
|
||||||
params: Map[String, AnyRef],
|
params: Map[String, Any],
|
||||||
dtrain: DMatrix,
|
dtrain: DMatrix,
|
||||||
round: Int,
|
round: Int,
|
||||||
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
||||||
obj: ObjectiveTrait = null,
|
obj: ObjectiveTrait = null,
|
||||||
eval: EvalTrait = null): Booster = {
|
eval: EvalTrait = null): Booster = {
|
||||||
|
|
||||||
|
|
||||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||||
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
|
val xgboostInJava = JXGBoost.train(
|
||||||
|
params.map{
|
||||||
|
case (key: String, value) => (key, value.toString)
|
||||||
|
}.toMap[String, AnyRef].asJava,
|
||||||
|
dtrain.jDMatrix, round, jWatches.asJava,
|
||||||
obj, eval)
|
obj, eval)
|
||||||
new Booster(xgboostInJava)
|
new Booster(xgboostInJava)
|
||||||
}
|
}
|
||||||
@ -65,14 +71,17 @@ object XGBoost {
|
|||||||
*/
|
*/
|
||||||
@throws(classOf[XGBoostError])
|
@throws(classOf[XGBoostError])
|
||||||
def crossValidation(
|
def crossValidation(
|
||||||
params: Map[String, AnyRef],
|
params: Map[String, Any],
|
||||||
data: DMatrix,
|
data: DMatrix,
|
||||||
round: Int,
|
round: Int,
|
||||||
nfold: Int = 5,
|
nfold: Int = 5,
|
||||||
metrics: Array[String] = null,
|
metrics: Array[String] = null,
|
||||||
obj: ObjectiveTrait = null,
|
obj: ObjectiveTrait = null,
|
||||||
eval: EvalTrait = null): Array[String] = {
|
eval: EvalTrait = null): Array[String] = {
|
||||||
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
JXGBoost.crossValidation(params.map{
|
||||||
|
case (key: String, value) => (key, value.toString)
|
||||||
|
}.toMap[String, AnyRef].asJava,
|
||||||
|
data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user