[DOC-JVM] Refactor JVM docs

This commit is contained in:
tqchen 2016-03-06 20:37:10 -08:00
parent 79f9fceb6b
commit c05c5bc7bc
27 changed files with 194 additions and 128 deletions

View File

@ -27,6 +27,8 @@ This file records the changes in xgboost library in reverse chronological order.
- This could fix some of the previous problems when running xgboost on multiple threads.
* JVM Package
- Enable xgboost4j for java and scala
- XGBoost distributed now runs on Flink and Spark.
## v0.47 (2016.01.14)

doc/jvm/index.md Normal file
View File

@ -0,0 +1,20 @@
XGBoost JVM Package
===================
[![Build Status](https://travis-ci.org/dmlc/xgboost.svg?branch=master)](https://travis-ci.org/dmlc/xgboost)
[![GitHub license](http://dmlc.github.io/img/apache2.svg)](../LICENSE)
You have found the XGBoost JVM Package!
Installation
------------
Building XGBoost4J involves two steps.
- First, run the following command to build the JNI library:
```bash
./create_jni.sh
```
- Then package the library: run `mvn package` in the xgboost4j folder, or open this Maven project in an IDE (Eclipse/NetBeans) and build it.
Contents
--------
* [Java Overview Tutorial](java_intro.md)
* [Code Examples](https://github.com/dmlc/xgboost/tree/master/jvm-packages/xgboost4j-example)

View File

@ -1,23 +1,8 @@
XGBoost4J Java API
==================
This tutorial introduces the XGBoost4J Java API.
## Data Interface
Like the xgboost python module, xgboost4j uses ```DMatrix``` to handle data; libsvm text format files, sparse matrices in CSR/CSC format, and dense matrices are supported.
* To import ```DMatrix``` :
@ -30,11 +15,11 @@ import org.dmlc.xgboost4j.DMatrix;
DMatrix dmat = new DMatrix("train.svm.txt");
```
* To load a sparse matrix in CSR/CSC format is a little more complicated; the usage is like:
suppose a sparse matrix:
1 0 2 0
4 0 0 3
3 1 2 0
for CSR format
```java
@ -52,12 +37,12 @@ int[] rowIndex = new int[] {0,1,2,2,0,2,1};
DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC);
```
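For reference, here is a minimal, dependency-free sketch (plain Java; the array names are illustrative, not the library's) of how the CSR arrays for the 3x4 matrix above line up: `rowHeaders[i]` marks where row `i` starts in `data`, and `colIndex` stores the column of each stored value.

```java
public class CsrSketch {
    public static void main(String[] args) {
        // CSR encoding of the matrix shown above:
        // 1 0 2 0
        // 4 0 0 3
        // 3 1 2 0
        long[] rowHeaders = new long[] {0, 2, 4, 7}; // row i occupies data[rowHeaders[i]..rowHeaders[i+1])
        float[] data = new float[] {1f, 2f, 4f, 3f, 3f, 1f, 2f};
        int[] colIndex = new int[] {0, 2, 0, 3, 0, 1, 2};

        // Rebuild the dense matrix to check the encoding.
        float[][] dense = new float[3][4];
        for (int i = 0; i < 3; i++) {
            for (long j = rowHeaders[i]; j < rowHeaders[i + 1]; j++) {
                dense[i][colIndex[(int) j]] = data[(int) j];
            }
        }
        System.out.println(dense[1][3]); // the value 3 sits at row 1, column 3
    }
}
```

The same three arrays, swapped to column headers and row indices, give the CSC layout used in the constructor above.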
* To load a 3*2 dense matrix, the usage is like:
suppose a matrix:
1 2
3 4
5 6
```java
float[] data = new float[] {1f,2f,3f,4f,5f,6f};
int nrow = 3;
@ -72,7 +57,7 @@ float[] weights = new float[] {1f,2f,1f};
dmat.setWeight(weights);
```
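As a plain-Java sketch (no xgboost dependency; the class name is illustrative), the dense constructor expects the values flattened row by row, so element (i, j) of an nrow*ncol matrix sits at index i*ncol + j:

```java
public class DenseSketch {
    public static void main(String[] args) {
        // Row-major flattening of the 3*2 matrix:
        // 1 2
        // 3 4
        // 5 6
        float[] data = new float[] {1f, 2f, 3f, 4f, 5f, 6f};
        int nrow = 3;
        int ncol = 2;
        // element (i, j) lives at data[i * ncol + j]
        System.out.println(data[2 * ncol + 1]); // row 2, column 1 holds 6
    }
}
```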
## Setting Parameters
* In xgboost4j, any ```Iterable<Entry<String, Object>>``` object can be used as parameters.
* To set parameters, for non-multiple value params, you can simply use the entrySet of a Map:
@ -100,7 +85,7 @@ List<Entry<String, Object>> params = new ArrayList<Entry<String, Object>>() {
};
```
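A self-contained sketch of this pattern (plain Java; the parameter values are illustrative): build a Map, then hand its entrySet anywhere an `Iterable<Entry<String, Object>>` is accepted.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

public class ParamsSketch {
    public static void main(String[] args) {
        Map<String, Object> params = new HashMap<>();
        params.put("eta", 0.1);
        params.put("max_depth", 2);
        params.put("objective", "binary:logistic");

        // entrySet() is an Iterable<Entry<String, Object>>, which is the
        // shape the parameter-accepting APIs described above expect.
        Iterable<Entry<String, Object>> iterable = params.entrySet();
        int count = 0;
        for (Entry<String, Object> e : iterable) {
            count++;
        }
        System.out.println(count);
    }
}
```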
## Training Model
With parameters and data, you are able to train a booster model.
* Import ```Trainer``` and ```Booster``` :
```java
@ -145,7 +130,7 @@ Params param = new Params() {
Booster booster = new Booster(param, "model.bin");
```
## Prediction
After training or loading a model, you can use it to predict on other data. The prediction result is a two-dimensional float array of shape (nsample, nclass); for leaf prediction it is (nsample, nclass*ntrees).
```java
DMatrix dtest = new DMatrix("test.svm.txt");

View File

@ -1,33 +1,73 @@
# XGBoost4J: Distributed XGBoost for Scala/Java
[![Build Status](https://travis-ci.org/dmlc/xgboost.svg?branch=master)](https://travis-ci.org/dmlc/xgboost)
[![GitHub license](http://dmlc.github.io/img/apache2.svg)](../LICENSE)
[Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) |
[Resources](../demo/README.md) |
[Release Notes](../NEWS.md)
XGBoost4J is the JVM package of xgboost. It brings all the optimizations
and power of xgboost into the JVM ecosystem.
- Train XGBoost models in Scala and Java with easy customization.
- Run distributed xgboost natively on JVM frameworks such as Flink and Spark.
You can find more about XGBoost on the [Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) and [Resource Page](../demo/README.md).
## Hello World
### XGBoost Scala
```scala
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.XGBoost

object XGBoostScalaExample {
  def main(args: Array[String]) {
    // read training data, available at xgboost/demo/data
    val trainData =
      new DMatrix("/path/to/agaricus.txt.train")
    // define parameters
    val paramMap = List(
      "eta" -> 0.1,
      "max_depth" -> 2,
      "objective" -> "binary:logistic").toMap
    // number of iterations
    val round = 2
    // train the model
    val model = XGBoost.train(paramMap, trainData, round)
    // run prediction
    val predTrain = model.predict(trainData)
    // save the model to a file
    model.saveModel("/local/path/to/model")
  }
}
```
### XGBoost Flink
```scala
import ml.dmlc.xgboost4j.scala.flink.XGBoost
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.ml.MLUtils

object DistTrainWithFlink {
  def main(args: Array[String]) {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    // read training data
    val trainData =
      MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
    // define parameters
    val paramMap = List(
      "eta" -> 0.1,
      "max_depth" -> 2,
      "objective" -> "binary:logistic").toMap
    // number of iterations
    val round = 2
    // train the model
    val model = XGBoost.train(paramMap, trainData, round)
    val predTrain = model.predict(trainData.map{x => x.vector})
    model.saveModelToHadoop("file:///path/to/xgboost.model")
  }
}
```
### XGBoost Spark

View File

@ -5,8 +5,8 @@
<modelVersion>4.0.0</modelVersion>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm</artifactId>
<version>0.5</version>
<packaging>pom</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
@ -19,7 +19,7 @@
</properties>
<modules>
<module>xgboost4j</module>
<module>xgboost4j-example</module>
<module>xgboost4j-spark</module>
<module>xgboost4j-flink</module>
</modules>

View File

@ -1,10 +0,0 @@
xgboost4j examples
====
* [Basic walkthrough of wrappers](src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java)
* [Customize loss function, and evaluation metric](src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java)
* [Boosting from existing prediction](src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java)
* [Predicting using first n trees](src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java)
* [Generalized Linear Model](src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java)
* [Cross validation](src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java)
* [Predicting leaf indices](src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java)
* [External Memory](src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java)

View File

@ -0,0 +1,18 @@
XGBoost4J Code Examples
=======================
## Java API
* [Basic walkthrough of wrappers](src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java)
* [Customize loss function, and evaluation metric](src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java)
* [Boosting from existing prediction](src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java)
* [Predicting using first n trees](src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java)
* [Generalized Linear Model](src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java)
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java)
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)
## Spark API
* [Distributed Training with Spark](src/main/scala/ml/dmlc/xgboost4j/scala/spark/example/DistTrainWithSpark.scala)
## Flink API
* [Distributed Training with Flink](src/main/scala/ml/dmlc/xgboost4j/scala/flink/example/DistTrainWithFlink.scala)

View File

@ -5,11 +5,11 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm</artifactId>
<version>0.5</version>
</parent>
<artifactId>xgboost4j-example</artifactId>
<version>0.5</version>
<packaging>jar</packaging>
<build>
<plugins>
@ -26,7 +26,12 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark</artifactId>
<version>0.5</version>
</dependency>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-flink</artifactId>
<version>0.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example;
import java.io.File;
import java.io.IOException;
@ -24,7 +24,7 @@ import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.example.util.DataLoader;
/**
* a simple example of java wrapper for xgboost

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example;
import java.util.HashMap;

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example;
import java.io.IOException;
import java.util.HashMap;

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example;
import java.util.ArrayList;
import java.util.HashMap;

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example;
import java.util.HashMap;

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example;
import java.util.HashMap;
@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.example.util.CustomEval;
/**
* this is an example of fit generalized linear model in xgboost

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example;
import java.util.HashMap;
@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.example.util.CustomEval;
/**
* predict first ntree

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example;
import java.util.Arrays;
import java.util.HashMap;

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example.util;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example.util;
import java.io.*;
import java.util.ArrayList;

View File

@ -13,33 +13,29 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.flink.example

import ml.dmlc.xgboost4j.scala.flink.XGBoost
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.ml.MLUtils

object DistTrainWithFlink {
def main(args: Array[String]) {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
// read training data
val trainData =
MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
// define parameters
val paramMap = List(
"eta" -> 0.1,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(paramMap, trainData, round)
val predTrain = model.predict(trainData.map{x => x.vector})
model.saveModelToHadoop("file:///path/to/xgboost.model")
}
}

View File

@ -14,7 +14,7 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.example
import java.io.File

View File

@ -5,11 +5,11 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm</artifactId>
<version>0.5</version>
</parent>
<artifactId>xgboost4j-flink</artifactId>
<version>0.5</version>
<build>
<plugins>
<plugin>
@ -26,7 +26,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>

View File

@ -14,7 +14,8 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.flink
import scala.collection.JavaConverters.asScalaIteratorConverter
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{RabitTracker, Rabit}
@ -35,7 +36,7 @@ object XGBoost {
*
* @param workerEnvs
*/
private class MapFunction(paramMap: Map[String, Any],
round: Int,
workerEnvs: java.util.Map[String, String])
extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
@ -69,7 +70,7 @@ object XGBoost {
* @param modelPath The path that is accessible by hadoop filesystem API.
* @return The loaded model
*/
def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
new XGBoostModel(
XGBoostScala.loadModel(
FileSystem
@ -84,7 +85,7 @@ object XGBoost {
* @param dtrain The training data.
* @param round Number of rounds to train.
*/
def train(params: Map[String, Any],
dtrain: DataSet[LabeledVector],
round: Int): XGBoostModel = {
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)

View File

@ -14,7 +14,7 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.flink
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
@ -31,7 +31,7 @@ class XGBoostModel (booster: Booster) extends Serializable {
*
* @param modelPath The model path as in Hadoop path.
*/
def saveModelToHadoop(modelPath: String): Unit = {
booster.saveModel(FileSystem
.get(new Configuration)
.create(new Path(modelPath)))

View File

@ -5,8 +5,8 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm</artifactId>
<version>0.5</version>
</parent>
<artifactId>xgboost4j-spark</artifactId>
<build>
@ -24,7 +24,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.5</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>

View File

@ -5,11 +5,11 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm</artifactId>
<version>0.5</version>
</parent>
<artifactId>xgboost4j</artifactId>
<version>0.5</version>
<packaging>jar</packaging>
<build>
<plugins>

View File

@ -39,14 +39,20 @@ object XGBoost {
*/
@throws(classOf[XGBoostError])
def train(
params: Map[String, Any],
dtrain: DMatrix,
round: Int,
watches: Map[String, DMatrix] = Map[String, DMatrix](),
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Booster = {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train(
params.map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
dtrain.jDMatrix, round, jWatches.asJava,
obj, eval)
new Booster(xgboostInJava)
}
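The Scala wrapper above simply stringifies every parameter value before crossing into the Java layer. In plain Java the same conversion (an illustrative sketch, not the library's code) looks like:

```java
import java.util.HashMap;
import java.util.Map;

public class ParamConvertSketch {
    public static void main(String[] args) {
        // Mixed-type parameters, as a user would write them.
        Map<String, Object> params = new HashMap<>();
        params.put("eta", 0.1);
        params.put("max_depth", 2);
        params.put("objective", "binary:logistic");

        // Convert every value to its string form, mirroring the
        // value.toString step in the Scala train() wrapper.
        Map<String, Object> converted = new HashMap<>();
        for (Map.Entry<String, Object> e : params.entrySet()) {
            converted.put(e.getKey(), e.getValue().toString());
        }
        System.out.println(converted.get("max_depth"));
    }
}
```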
@ -65,14 +71,17 @@ object XGBoost {
*/
@throws(classOf[XGBoostError])
def crossValidation(
params: Map[String, Any],
data: DMatrix,
round: Int,
nfold: Int = 5,
metrics: Array[String] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Array[String] = {
JXGBoost.crossValidation(params.map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
data.jDMatrix, round, nfold, metrics, obj, eval)
}
/**