[DOC-JVM] Refactor JVM docs
parent 79f9fceb6b
commit c05c5bc7bc

NEWS.md (2 changes)
@@ -27,6 +27,8 @@ This file records the changes in xgboost library in reverse chronological order.
- This could fix some of the previous problems when running xgboost on multiple threads.
* JVM Package
  - Enable xgboost4j for java and scala
  - XGBoost distributed now runs on Flink and Spark.

## v0.47 (2016.01.14)
doc/jvm/index.md (new file, 20 lines)

@@ -0,0 +1,20 @@
XGBoost JVM Package
===================
[](https://travis-ci.org/dmlc/xgboost)
[](../LICENSE)

You have found the XGBoost JVM Package!

Installation
------------
Building XGBoost4J involves two steps.
- First, run the following command to build the JNI library:
```bash
./create_jni.sh
```
- Then package the library: run `mvn package` in the xgboost4j folder, or open this maven project in an IDE (eclipse/netbeans) and build it there.

Contents
--------
* [Java Overview Tutorial](java_intro.md)
* [Code Examples](https://github.com/dmlc/xgboost/tree/master/jvm-packages/xgboost4j-example)
@@ -1,23 +1,8 @@
-xgboost4j : java wrapper for xgboost
-====
+XGBoost4J Java API
+==================
+This tutorial introduces the XGBoost4J Java API.

-This page will introduce xgboost4j, the java wrapper for xgboost, including:
-* [Building](#build-xgboost4j)
-* [Data Interface](#data-interface)
-* [Setting Parameters](#setting-parameters)
-* [Train Model](#training-model)
-* [Prediction](#prediction)
-
-=
-#### Build xgboost4j
-* Build native library
-first make sure you have installed jdk and `JAVA_HOME` has been set properly, then simply run `./create_wrap.sh`.
-
-* Package xgboost4j
-to package xgboost4j, you can run `mvn package` in the xgboost4j folder or just use an IDE (eclipse/netbeans) to open this maven project and build.
-
-=
-#### Data Interface
+## Data Interface
 Like the xgboost python module, xgboost4j uses ```DMatrix``` to handle data; libsvm text format files, sparse matrices in CSR/CSC format, and dense matrices are supported.

 * To import ```DMatrix``` :
@@ -30,11 +15,11 @@ import org.dmlc.xgboost4j.DMatrix;
 DMatrix dmat = new DMatrix("train.svm.txt");
 ```

 * Loading a sparse matrix in CSR/CSC format is a little more complicated. Suppose a sparse matrix:

   1 0 2 0
   4 0 0 3
   3 1 2 0

 for CSR format
 ```java
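As a cross-check of the CSR layout described above, here is a small pure-Java sketch with no xgboost dependency; the array names mirror the tutorial's convention but are illustrative, not part of the xgboost4j API. It rebuilds the dense rows of the 3x4 example matrix from its CSR arrays:

```java
// Pure-Java sketch of the CSR layout for the 3x4 matrix above.
public class CsrSketch {
    // non-zero values, scanned row by row
    static float[] data = new float[] {1f, 2f, 4f, 3f, 3f, 1f, 2f};
    // column index of each non-zero value
    static int[] colIndex = new int[] {0, 2, 0, 3, 0, 1, 2};
    // rowHeaders[i] = offset in `data` where row i starts; last entry = nnz
    static int[] rowHeaders = new int[] {0, 2, 4, 7};

    // rebuild a dense row to verify the layout round-trips
    static float[] denseRow(int i, int ncol) {
        float[] row = new float[ncol];
        for (int j = rowHeaders[i]; j < rowHeaders[i + 1]; j++) {
            row[colIndex[j]] = data[j];
        }
        return row;
    }
}
```

Reading the arrays back row by row reproduces exactly the three rows of the example matrix.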
@@ -52,12 +37,12 @@ int[] rowIndex = new int[] {0,1,2,2,0,2,1};
 DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC);
 ```

 * To load a 3*2 dense matrix, suppose a matrix:

   1 2
   3 4
   5 6

 ```java
 float[] data = new float[] {1f,2f,3f,4f,5f,6f};
 int nrow = 3;
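The dense constructor expects the values in row-major order. A small pure-Java sketch, illustrative only and not part of the xgboost4j API, of how element (i, j) of the 3x2 matrix above maps into the flat array:

```java
// Row-major layout: element (i, j) sits at data[i * ncol + j].
public class DenseSketch {
    static float[] data = new float[] {1f, 2f, 3f, 4f, 5f, 6f};
    static int ncol = 2;

    static float at(int i, int j) {
        return data[i * ncol + j];
    }
}
```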
@@ -72,7 +57,7 @@ float[] weights = new float[] {1f,2f,1f};
 dmat.setWeight(weights);
 ```

-#### Setting Parameters
+## Setting Parameters
 * In xgboost4j any ```Iterable<Entry<String, Object>>``` object can be used as parameters.

 * To set parameters, for non-multiple-value params, you can simply use the entrySet of a Map:
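A minimal self-contained sketch of building such a parameter iterable from a plain HashMap (the parameter values are illustrative):

```java
import java.util.HashMap;
import java.util.Map;

// Any Iterable<Map.Entry<String, Object>> qualifies, so the entrySet of a
// plain HashMap is the simplest way to pass single-valued parameters.
public class ParamsSketch {
    static Iterable<Map.Entry<String, Object>> params() {
        Map<String, Object> p = new HashMap<>();
        p.put("eta", 0.1);
        p.put("max_depth", 2);
        p.put("objective", "binary:logistic");
        return p.entrySet();
    }
}
```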
@@ -100,7 +85,7 @@ List<Entry<String, Object>> params = new ArrayList<Entry<String, Object>>() {
 };
 ```

-#### Training Model
+## Training Model
 With parameters and data, you are able to train a booster model.
 * Import ```Trainer``` and ```Booster``` :
 ```java
@@ -145,7 +130,7 @@ Params param = new Params() {
 Booster booster = new Booster(param, "model.bin");
 ```

-####Prediction
+## Prediction
 After training or loading a model, you can use it to predict other data. The prediction result is a two-dimensional float array of shape (nsample, nclass); for leaf prediction it is (nsample, nclass*ntrees).
 ```java
 DMatrix dtest = new DMatrix("test.svm.txt");
@@ -1,33 +1,73 @@
-# xgboost4j
-this is a java wrapper for xgboost
+# XGBoost4J: Distributed XGBoost for Scala/Java
+[](https://travis-ci.org/dmlc/xgboost)
+[](../LICENSE)

-the structure of this wrapper is almost the same as the official python wrapper.
+[Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) |
+[Resources](../demo/README.md) |
+[Release Notes](../NEWS.md)

-core of this wrapper is two classes:
-* DMatrix: for handling data
-* Booster: for train and predict
+XGBoost4J is the JVM package of xgboost. It brings all the optimizations
+and power of xgboost into the JVM ecosystem.
+
+- Train XGBoost models in scala and java with easy customization.
+- Run distributed xgboost natively on jvm frameworks such as Flink and Spark.

-## usage:
-please refer to [xgboost4j.md](doc/xgboost4j.md) for more information.
-besides, simple examples could be found in [xgboost4j-demo](xgboost4j-demo/README.md)
+You can find more about XGBoost in the [Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) and on the [Resource Page](../demo/README.md).

+## Hello World
+### XGBoost Scala
+```scala
+import ml.dmlc.xgboost4j.scala.DMatrix
+import ml.dmlc.xgboost4j.scala.XGBoost
+
+object XGBoostScalaExample {
+  def main(args: Array[String]) {
+    // read training data, available at xgboost/demo/data
+    val trainData =
+      new DMatrix("/path/to/agaricus.txt.train")
+    // define parameters
+    val paramMap = List(
+      "eta" -> 0.1,
+      "max_depth" -> 2,
+      "objective" -> "binary:logistic").toMap
+    // number of iterations
+    val round = 2
+    // train the model
+    val model = XGBoost.train(paramMap, trainData, round)
+    // run prediction
+    val predTrain = model.predict(trainData)
+    // save model to the file.
+    model.saveModel("/local/path/to/model")
+  }
+}
+```

-## build native library
-for windows: open xgboost.sln in the "../windows" folder, you will find the xgboost4j project; do the following steps to build the wrapper library:
-* Select x64/win32 and Release in build
-* (if you have set `JAVA_HOME` properly in windows environment variables, skip this step) right click on the xgboost4j project -> choose "Properties" -> click on "C/C++" in the window -> change the "Additional Include Directories" to fit your jdk install path.
-* rebuild all
-* double click "create_wrap.bat" to put the library in the proper place
-
-for linux:
-* make sure you have installed jdk and `JAVA_HOME` has been set properly
-* run "create_wrap.sh"
-
-for osx:
-* make sure you have installed jdk
-* for single thread xgboost, simply run "create_wrap.sh"
-* to build with openMP, please refer to [build.md](../doc/build.md) to get an openmp-supported compiler first, then change the line "dis_omp=1" to "dis_omp=0" in "create_wrap.sh" and run "create_wrap.sh"
+### XGBoost Flink
+```scala
+import ml.dmlc.xgboost4j.scala.flink.XGBoost
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.ml.MLUtils
+
+object DistTrainWithFlink {
+  def main(args: Array[String]) {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    // read training data
+    val trainData =
+      MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
+    // define parameters
+    val paramMap = List(
+      "eta" -> 0.1,
+      "max_depth" -> 2,
+      "objective" -> "binary:logistic").toMap
+    // number of iterations
+    val round = 2
+    // train the model
+    val model = XGBoost.train(paramMap, trainData, round)
+    val predTrain = model.predict(trainData.map{x => x.vector})
+    model.saveModelToHadoop("file:///path/to/xgboost.model")
+  }
+}
+```
+
+### XGBoost Spark
@@ -5,8 +5,8 @@
 <modelVersion>4.0.0</modelVersion>

 <groupId>ml.dmlc</groupId>
-<artifactId>xgboostjvm</artifactId>
-<version>0.1</version>
+<artifactId>xgboost-jvm</artifactId>
+<version>0.5</version>
 <packaging>pom</packaging>
 <properties>
 <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
@@ -19,7 +19,7 @@
 </properties>
 <modules>
 <module>xgboost4j</module>
-<module>xgboost4j-demo</module>
+<module>xgboost4j-example</module>
 <module>xgboost4j-spark</module>
 <module>xgboost4j-flink</module>
 </modules>
@@ -1,10 +0,0 @@
-xgboost4j examples
-====
-* [Basic walkthrough of wrappers](src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java)
-* [Customize loss function, and evaluation metric](src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java)
-* [Boosting from existing prediction](src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java)
-* [Predicting using first n trees](src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java)
-* [Generalized Linear Model](src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java)
-* [Cross validation](src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java)
-* [Predicting leaf indices](src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java)
-* [External Memory](src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java)

jvm-packages/xgboost4j-example/README.md (new file, 18 lines)

@@ -0,0 +1,18 @@
XGBoost4J Code Examples
=======================

## Java API
* [Basic walkthrough of wrappers](src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java)
* [Customize loss function, and evaluation metric](src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java)
* [Boosting from existing prediction](src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java)
* [Predicting using first n trees](src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java)
* [Generalized Linear Model](src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java)
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java)
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)

## Spark API
* [Distributed Training with Spark](src/main/scala/ml/dmlc/xgboost4j/scala/spark/example/DistTrainWithSpark.scala)

## Flink API
* [Distributed Training with Flink](src/main/scala/ml/dmlc/xgboost4j/scala/flink/example/DistTrainWithFlink.scala)
@@ -5,11 +5,11 @@
 <modelVersion>4.0.0</modelVersion>
 <parent>
 <groupId>ml.dmlc</groupId>
-<artifactId>xgboostjvm</artifactId>
-<version>0.1</version>
+<artifactId>xgboost-jvm</artifactId>
+<version>0.5</version>
 </parent>
-<artifactId>xgboost4j-demo</artifactId>
-<version>0.1</version>
+<artifactId>xgboost4j-example</artifactId>
+<version>0.5</version>
 <packaging>jar</packaging>
 <build>
 <plugins>
@@ -26,7 +26,12 @@
 <dependency>
 <groupId>ml.dmlc</groupId>
 <artifactId>xgboost4j-spark</artifactId>
-<version>0.1</version>
+<version>0.5</version>
 </dependency>
+<dependency>
+<groupId>ml.dmlc</groupId>
+<artifactId>xgboost4j-flink</artifactId>
+<version>0.5</version>
+</dependency>
 <dependency>
 <groupId>org.apache.commons</groupId>
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.java.demo;
+package ml.dmlc.xgboost4j.java.example;

 import java.io.File;
 import java.io.IOException;
@@ -24,7 +24,7 @@ import ml.dmlc.xgboost4j.java.Booster;
 import ml.dmlc.xgboost4j.java.DMatrix;
 import ml.dmlc.xgboost4j.java.XGBoost;
 import ml.dmlc.xgboost4j.java.XGBoostError;
-import ml.dmlc.xgboost4j.java.demo.util.DataLoader;
+import ml.dmlc.xgboost4j.java.example.util.DataLoader;

 /**
  * a simple example of java wrapper for xgboost
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.java.demo;
+package ml.dmlc.xgboost4j.java.example;

 import java.util.HashMap;
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.java.demo;
+package ml.dmlc.xgboost4j.java.example;

 import java.io.IOException;
 import java.util.HashMap;
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.java.demo;
+package ml.dmlc.xgboost4j.java.example;

 import java.util.ArrayList;
 import java.util.HashMap;
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.java.demo;
+package ml.dmlc.xgboost4j.java.example;

 import java.util.HashMap;
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.java.demo;
+package ml.dmlc.xgboost4j.java.example;

 import java.util.HashMap;
@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.java.Booster;
 import ml.dmlc.xgboost4j.java.DMatrix;
 import ml.dmlc.xgboost4j.java.XGBoost;
 import ml.dmlc.xgboost4j.java.XGBoostError;
-import ml.dmlc.xgboost4j.java.demo.util.CustomEval;
+import ml.dmlc.xgboost4j.java.example.util.CustomEval;

 /**
  * this is an example of fit generalized linear model in xgboost
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.java.demo;
+package ml.dmlc.xgboost4j.java.example;

 import java.util.HashMap;
@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.java.Booster;
 import ml.dmlc.xgboost4j.java.DMatrix;
 import ml.dmlc.xgboost4j.java.XGBoost;
 import ml.dmlc.xgboost4j.java.XGBoostError;
-import ml.dmlc.xgboost4j.java.demo.util.CustomEval;
+import ml.dmlc.xgboost4j.java.example.util.CustomEval;

 /**
  * predict first ntree
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.java.demo;
+package ml.dmlc.xgboost4j.java.example;

 import java.util.Arrays;
 import java.util.HashMap;
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.java.demo.util;
+package ml.dmlc.xgboost4j.java.example.util;

 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.java.demo.util;
+package ml.dmlc.xgboost4j.java.example.util;

 import java.io.*;
 import java.util.ArrayList;
@@ -13,33 +13,29 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package ml.dmlc.xgboost4j.flink
+package ml.dmlc.xgboost4j.scala.flink.example

-import org.apache.commons.logging.Log
-import org.apache.commons.logging.LogFactory
-import org.apache.flink.api.common.functions.RichMapPartitionFunction
+import ml.dmlc.xgboost4j.scala.flink.XGBoost
 import org.apache.flink.api.scala._
-import org.apache.flink.api.scala.DataSet
 import org.apache.flink.api.scala.ExecutionEnvironment
-import org.apache.flink.ml.common.LabeledVector
 import org.apache.flink.ml.MLUtils
-import org.apache.flink.util.Collector

-object Test {
-  val log = LogFactory.getLog(this.getClass)
+object DistTrainWithFlink {
   def main(args: Array[String]) {
     val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
-    val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
-    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+    // read training data
+    val trainData =
+      MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
+    // define parameters
+    val paramMap = List(
+      "eta" -> 0.1,
+      "max_depth" -> 2,
+      "objective" -> "binary:logistic").toMap
+    // number of iterations
     val round = 2
-    val model = XGBoost.train(paramMap, data, round)
-    log.info(model)
+    // train the model
+    val model = XGBoost.train(paramMap, trainData, round)
+    val predTrain = model.predict(trainData.map{x => x.vector})
+    model.saveModelToHadoop("file:///path/to/xgboost.model")
   }
 }
@@ -14,7 +14,7 @@
 limitations under the License.
 */

-package ml.dmlc.xgboost4j.scala.spark.demo
+package ml.dmlc.xgboost4j.scala.spark.example

 import java.io.File
@@ -5,11 +5,11 @@
 <modelVersion>4.0.0</modelVersion>
 <parent>
 <groupId>ml.dmlc</groupId>
-<artifactId>xgboostjvm</artifactId>
-<version>0.1</version>
+<artifactId>xgboost-jvm</artifactId>
+<version>0.5</version>
 </parent>
 <artifactId>xgboost4j-flink</artifactId>
-<version>0.1</version>
+<version>0.5</version>
 <build>
 <plugins>
 <plugin>
@@ -26,7 +26,7 @@
 <dependency>
 <groupId>ml.dmlc</groupId>
 <artifactId>xgboost4j</artifactId>
-<version>0.1</version>
+<version>0.5</version>
 </dependency>
 <dependency>
 <groupId>org.apache.commons</groupId>
@@ -14,7 +14,8 @@
 limitations under the License.
 */

-package ml.dmlc.xgboost4j.flink
+package ml.dmlc.xgboost4j.scala.flink

 import scala.collection.JavaConverters.asScalaIteratorConverter
 import ml.dmlc.xgboost4j.LabeledPoint
 import ml.dmlc.xgboost4j.java.{RabitTracker, Rabit}
@@ -35,7 +36,7 @@ object XGBoost {
   *
   * @param workerEnvs
   */
-  private class MapFunction(paramMap: Map[String, AnyRef],
+  private class MapFunction(paramMap: Map[String, Any],
                             round: Int,
                             workerEnvs: java.util.Map[String, String])
     extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
@@ -69,7 +70,7 @@ object XGBoost {
   * @param modelPath The path that is accessible by hadoop filesystem API.
   * @return The loaded model
   */
-  def loadModel(modelPath: String) : XGBoostModel = {
+  def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
     new XGBoostModel(
       XGBoostScala.loadModel(
         FileSystem
@@ -84,7 +85,7 @@ object XGBoost {
   * @param dtrain The training data.
   * @param round Number of rounds to train.
   */
-  def train(params: Map[String, AnyRef],
+  def train(params: Map[String, Any],
             dtrain: DataSet[LabeledVector],
             round: Int): XGBoostModel = {
     val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
@@ -14,7 +14,7 @@
 limitations under the License.
 */

-package ml.dmlc.xgboost4j.flink
+package ml.dmlc.xgboost4j.scala.flink

 import ml.dmlc.xgboost4j.LabeledPoint
 import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
@@ -31,7 +31,7 @@ class XGBoostModel (booster: Booster) extends Serializable {
   *
   * @param modelPath The model path as in Hadoop path.
   */
-  def saveModel(modelPath: String): Unit = {
+  def saveModelToHadoop(modelPath: String): Unit = {
     booster.saveModel(FileSystem
       .get(new Configuration)
       .create(new Path(modelPath)))
@@ -5,8 +5,8 @@
 <modelVersion>4.0.0</modelVersion>
 <parent>
 <groupId>ml.dmlc</groupId>
-<artifactId>xgboostjvm</artifactId>
-<version>0.1</version>
+<artifactId>xgboost-jvm</artifactId>
+<version>0.5</version>
 </parent>
 <artifactId>xgboost4j-spark</artifactId>
 <build>
@@ -24,7 +24,7 @@
 <dependency>
 <groupId>ml.dmlc</groupId>
 <artifactId>xgboost4j</artifactId>
-<version>0.1</version>
+<version>0.5</version>
 </dependency>
 <dependency>
 <groupId>org.apache.spark</groupId>
@@ -5,11 +5,11 @@
 <modelVersion>4.0.0</modelVersion>
 <parent>
 <groupId>ml.dmlc</groupId>
-<artifactId>xgboostjvm</artifactId>
-<version>0.1</version>
+<artifactId>xgboost-jvm</artifactId>
+<version>0.5</version>
 </parent>
 <artifactId>xgboost4j</artifactId>
-<version>0.1</version>
+<version>0.5</version>
 <packaging>jar</packaging>
 <build>
 <plugins>
@@ -39,14 +39,20 @@ object XGBoost {
   */
  @throws(classOf[XGBoostError])
  def train(
-     params: Map[String, AnyRef],
+     params: Map[String, Any],
      dtrain: DMatrix,
      round: Int,
      watches: Map[String, DMatrix] = Map[String, DMatrix](),
      obj: ObjectiveTrait = null,
      eval: EvalTrait = null): Booster = {

    val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
-   val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
+   val xgboostInJava = JXGBoost.train(
+     params.map{
+       case (key: String, value) => (key, value.toString)
+     }.toMap[String, AnyRef].asJava,
+     dtrain.jDMatrix, round, jWatches.asJava,
      obj, eval)
    new Booster(xgboostInJava)
  }
@@ -65,14 +71,17 @@ object XGBoost {
   */
  @throws(classOf[XGBoostError])
  def crossValidation(
-     params: Map[String, AnyRef],
+     params: Map[String, Any],
      data: DMatrix,
      round: Int,
      nfold: Int = 5,
      metrics: Array[String] = null,
      obj: ObjectiveTrait = null,
      eval: EvalTrait = null): Array[String] = {
-   JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
+   JXGBoost.crossValidation(params.map{
+     case (key: String, value) => (key, value.toString)
+   }.toMap[String, AnyRef].asJava,
+     data.jDMatrix, round, nfold, metrics, obj, eval)
  }

 /**
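The change above makes the Scala wrapper accept `Map[String, Any]` and normalize every parameter value to its string form before handing it to the Java API. The same normalization, sketched in plain standalone Java (illustrative, not part of the xgboost4j API):

```java
import java.util.HashMap;
import java.util.Map;

// Sketch of the conversion performed by the Scala wrapper: parameter values
// of arbitrary type are turned into strings before the Java-side call.
public class ParamStringify {
    static Map<String, Object> stringify(Map<String, Object> params) {
        Map<String, Object> out = new HashMap<>();
        for (Map.Entry<String, Object> e : params.entrySet()) {
            out.put(e.getKey(), e.getValue().toString());
        }
        return out;
    }
}
```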
||||
Loading…
x
Reference in New Issue
Block a user