Merge pull request #969 from tqchen/master

JVM API Update
This commit is contained in:
Tianqi Chen 2016-03-11 12:36:27 -08:00
commit cbabaeba0c
28 changed files with 188 additions and 189 deletions

View File

@ -1,92 +1,76 @@
---
layout: post
title: "XGBoost4J Package Released"
date: 2016-03-15 12:00:00
author: Nan Zhu Tianqi Chen
categories: rstats
title: XGBoost4J: Portable Distributed Tree Boosting in DataFlow
date: 2016-03-15 12:00:00
author: Nan Zhu, Tianqi Chen
comments: true
---
## Introduction
[XGBoost](https://github.com/dmlc/xgboost) is a library designed and optimized for tree boosting. The gradient boosted trees model was originally proposed by Friedman et al. By embracing multi-threading and introducing regularization, XGBoost delivers higher computational power and more accurate predictions. **More than half of the winning solutions in machine learning challenges** hosted at Kaggle adopt XGBoost ([Incomplete list](https://github.com/dmlc/xgboost/tree/master/demo#machine-learning-challenge-winning-solutions)).
XGBoost provides native interfaces for C++, R, Python, Julia and Java users.
It is used in both data exploration and [production pipelines](https://github.com/dmlc/xgboost/tree/master/demo#usecases) to solve real-world machine learning problems.
We started the [xgboost4j](https://github.com/dmlc/xgboost/tree/master/jvm-packages) (XGBoost for JVM) project three weeks ago, covering a new design and implementation of the Java/Scala interface and its integration with dataflow frameworks. Today, we are happy to announce the availability of the first version of XGBoost4J. In this post, we give a brief introduction to this new package.
Distributed XGBoost is described in a [recently published paper](http://arxiv.org/abs/1603.02754).
In short, the XGBoost system runs orders of magnitude faster than existing distributed machine learning alternatives,
and uses far fewer resources. The reader is more than welcome to refer to the paper for more details.
## Motivation
Despite the great success, one of our goals is to make XGBoost even more available for all production scenarios.
Programming languages and data processing/storage systems based on the Java Virtual Machine (JVM) play significant roles in the BigData ecosystem. [Hadoop](http://hadoop.apache.org/), [Spark](http://spark.apache.org/) and the more recently introduced [Flink](http://flink.apache.org/) are very useful solutions to general large-scale data processing.
On the other side, the emerging demands of machine learning and deep learning
have inspired many excellent machine learning libraries.
Many of these machine learning libraries (e.g. [XGBoost](https://github.com/dmlc/xgboost)/[MxNet](https://github.com/dmlc/mxnet))
require new computation abstractions and native support (e.g. C++ for GPU computing).
They are also often [much more efficient](http://arxiv.org/abs/1603.02754).
The gap between the implementation fundamentals of the general data processing frameworks and the more specialized machine learning libraries/systems prevents a smooth connection between these two types of systems, and thus brings unnecessary inconvenience to the end user. The common workflow is to use systems like Flink/Spark to preprocess/clean data, pass the results to machine learning systems like [XGBoost](https://github.com/dmlc/xgboost)/[MxNet](https://github.com/dmlc/mxnet) via the file system, and then run the subsequent machine learning phase. While such a process does not hurt performance as much as it would in a pure data processing job (because machine learning takes a lot of time compared to data loading), it creates some inconvenience for the users.
We want the best of both worlds, so we can use data processing frameworks like Flink and Spark together with
the best distributed machine learning solutions.
To resolve the situation, we introduce the newly brewed [XGBoost4J](https://github.com/dmlc/xgboost/tree/master/jvm-packages),
<b>XGBoost</b> for the <b>J</b>VM Platform. We aim to provide clean Java/Scala APIs and integration with the most popular data processing systems developed in JVM-based languages.
## Unix Philosophy in Machine Learning
XGBoost and XGBoost4J adopt the Unix philosophy:
XGBoost **does its best at one thing -- tree boosting** and is **designed to work with other systems**.
We strongly believe that machine learning solutions should not be restricted to a certain language or a certain platform.
Specifically, users will be able to use distributed XGBoost in both Flink and Spark, and possibly in more frameworks in the future.
We have made the API portable so it **can be easily ported to other Dataflow frameworks provided by the Cloud**.
XGBoost4J shares its core with the other XGBoost libraries, which means data scientists can use R/Python
to read and visualize models trained in a distributed fashion.
It also means that users can start with the single-machine version for exploration,
which can already handle hundreds of millions of examples.
## System Overview
In the following figure, we describe the overall architecture of XGBoost4J. XGBoost4J provides Java/Scala APIs that call the core functionality of the XGBoost library. Most importantly, it not only supports single-machine model training, but also provides an abstraction layer which masks the differences between the underlying data processing engines (they can be Spark, Flink, or just distributed servers across the cluster).
![XGBoost4J Architecture](https://raw.githubusercontent.com/dmlc/web-data/master/xgboost/xgboost4j.png)
By calling the XGBoost4J API, users can scale the model training to the cluster. XGBoost4J runs an XGBoost worker inside each Spark/Flink task and coordinates them across the cluster. The communication among the distributed model training tasks and the XGBoost4J runtime environment goes through [Rabit](https://github.com/dmlc/rabit).
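To give a concrete taste of that abstraction, the two distributed entry points mirror each other. The calls below are taken from the walkthrough later in this post; only the collection type holding the training data differs (the variable names here are illustrative):
```scala
// Spark: ml.dmlc.xgboost4j.scala.spark.XGBoost, training data as an RDD[LabeledPoint]
val sparkModel = XGBoost.train(trainRDD, paramMap, numRound)
// Flink: ml.dmlc.xgboost4j.scala.flink.XGBoost, training data as a DataSet[LabeledVector]
val flinkModel = XGBoost.train(trainData, paramMap, round)
```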
With the abstraction of XGBoost4J, users can build a unified data analytics application ranging from Extract-Transform-Load and data exploration to machine learning model training and the final data product service. The following figure illustrates an example application built on top of Apache Spark. The application seamlessly embeds XGBoost into the processing pipeline and exchanges data with other Spark-based processing phases through Spark's distributed memory layer.
![Unified Data Analytics Pipeline with XGBoost4J](https://raw.githubusercontent.com/dmlc/web-data/master/xgboost/unified_pipeline.png)
## Single-machine Training Walk-through
In this section, we will work through the APIs of XGBoost4J by examples.
We will mainly use Scala for demonstration, but we also provide a complete API for Java users, so the Java equivalents are shown along the way.
To start the model training and evaluation, we need to prepare the training and test sets.
In Java, we do:
```java
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
```
Or in Scala:
```scala
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
```
After preparing the data, we can train our model:
In Java:
```java
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//set round
int round = 2;
//train a boost model
Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
```
In Scala:
```scala
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
params += "max_depth" -> 2
params += "silent" -> 1
params += "objective" -> "binary:logistic"
val watches = new mutable.HashMap[String, DMatrix]
watches += "train" -> trainMax
watches += "test" -> testMax
val round = 2
// train a model
val booster = XGBoost.train(trainMax, params.toMap, round, watches.toMap)
```
With the booster we got in either Java or Scala, we can evaluate it on our test set.
In Java:
```java
float[][] predicts = booster.predict(testMat);
```
In Scala:
```scala
val predicts = booster.predict(testMax)
```
`predict` outputs the prediction results, and you can define a customized evaluation method to derive your own metrics (see the examples in [Customized Evaluation Metric in Java](https://github.com/dmlc/xgboost/blob/master/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java) and [Customized Evaluation Metric in Scala](https://github.com/dmlc/xgboost/blob/master/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala)).
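As an illustration, here is a minimal sketch of a custom error metric in Scala. It assumes the `EvalTrait` interface exposes `getMetric` and `eval(predicts, dmat)` and that `DMatrix` exposes `getLabel`, following the linked CustomObjective example; the class name and the 0.5 threshold are illustrative:
```scala
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}

// counts the fraction of examples whose predicted probability
// falls on the wrong side of the 0.5 decision threshold
class SimpleErrorEval extends EvalTrait {
  override def getMetric: String = "simple-error"

  override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
    val labels = dmat.getLabel
    val wrong = predicts.zip(labels).count { case (pred, label) =>
      val hardPrediction = if (pred(0) > 0.5f) 1.0f else 0.0f
      hardPrediction != label
    }
    wrong.toFloat / labels.length
  }
}

// pass it (together with a custom objective, or null) as the last arguments of train
// val booster = XGBoost.train(trainMax, params.toMap, round, watches.toMap, null, new SimpleErrorEval)
```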
## Distributed Model Training with Distributed Dataflow Frameworks
The most exciting part of this XGBoost4J release is the integration with distributed dataflow frameworks. The most popular data processing frameworks fall into this category, e.g. [Apache Spark](http://spark.apache.org/), [Apache Flink](http://flink.apache.org/), etc. In this part, we will walk through the steps to build a unified data analytics application containing data preprocessing and distributed model training with Spark and Flink. (Currently, we only provide the Scala API for the integration with Spark and Flink.)
Similar to the single-machine training, we need to prepare the training and test datasets.
### Spark Example
In Spark, the dataset is represented as a [Resilient Distributed Dataset (RDD)](http://spark.apache.org/docs/latest/programming-guide.html#resilient-distributed-datasets-rdds). We can utilize Spark's distributed tools to parse a libSVM file and wrap it as an RDD:
```scala
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).repartition(args(1).toInt)
```
We move forward to train the model:
```scala
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound)
```
The next step is to evaluate the model. You can either predict locally or in a distributed fashion:
```scala
// testSet is an RDD containing test set data represented as
// org.apache.spark.mllib.regression.LabeledPoint
val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath)
// local prediction
// import methods in DataUtils to convert Iterator[org.apache.spark.mllib.regression.LabeledPoint]
// to Iterator[ml.dmlc.xgboost4j.LabeledPoint] automatically
import DataUtils._
xgboostModel.predict(new DMatrix(testSet.collect().iterator))
// distributed prediction
xgboostModel.predict(testSet)
```
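The distributed-trained model can also be persisted to HDFS or any Hadoop-compatible file system. A minimal sketch using the `saveModelAsHadoopFile` call from the Spark example shipped with this release (the output path is illustrative):
```scala
// save the trained model for later serving or offline analysis
xgboostModel.saveModelAsHadoopFile("hdfs:///path/to/xgboost.model")
```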
### Flink Example
In Flink, we represent the training data as Flink's [DataSet](https://ci.apache.org/projects/flink/flink-docs-master/apis/batch/index.html):
```scala
val trainData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
```
Model training can be done as follows:
```scala
val xgboostModel = XGBoost.train(trainData, paramMap, round)
```
We can then run prediction, either locally or in a distributed fashion:
```scala
// testData is a Dataset containing testset data represented as
// org.apache.flink.ml.math.Vector.LabeledVector
val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
// local prediction
xgboostModel.predict(testData.collect().iterator)
// distributed prediction
xgboostModel.predict(testData.map{x => x.vector})
```
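The Flink-trained model can likewise be persisted. A minimal sketch following the examples shipped with this release (the README example calls `saveModelToHadoop`, while the bundled Flink example source calls `saveModelAsHadoopFile`; the path is the placeholder used there):
```scala
// persist the Flink-trained model for later reuse
xgboostModel.saveModelToHadoop("file:///path/to/xgboost.model")
```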
## Road Map
This is the first release of the XGBoost4J package, and we are actively moving forward to bring more features in the next releases. You can watch our progress on the [XGBoost4J Road Map](https://github.com/dmlc/xgboost/issues/935).
While we are trying our best to keep changes to the APIs minimal, they are still subject to incompatible changes.
## Further Readings
If you are interested in knowing more about XGBoost, you can find rich resources in:
- [The github repository of XGBoost](https://github.com/dmlc/xgboost)
- [The comprehensive documentation site for XGBoost](http://xgboost.readthedocs.org/en/latest/index.html)
- [Introduction of the Parameters](http://xgboost.readthedocs.org/en/latest/parameter.html)
- [Awesome XGBoost, a curated list of examples, tutorials, blogs about XGBoost usecases](https://github.com/dmlc/xgboost/tree/master/demo)
## Acknowledgements
We would like to send many thanks to [Zixuan Huang](https://github.com/yanqingmen), the early developer of XGBoost for Java.

View File

@ -34,7 +34,7 @@ object XGBoostScalaExample {
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(paramMap, trainData, round)
val model = XGBoost.train(trainData, paramMap, round)
// run prediction
val predTrain = model.predict(trainData)
// save model to the file.
@ -43,34 +43,6 @@ object XGBoostScalaExample {
}
```
### XGBoost Flink
```scala
import ml.dmlc.xgboost4j.scala.flink.XGBoost
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.ml.MLUtils
object DistTrainWithFlink {
def main(args: Array[String]) {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
// read training data
val trainData =
MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
// define parameters
val paramMap = List(
"eta" -> 0.1,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(paramMap, trainData, round)
val predTrain = model.predict(trainData.map{x => x.vector})
model.saveModelToHadoop("file:///path/to/xgboost.model")
}
}
```
### XGBoost Spark
```scala
import org.apache.spark.SparkContext
@ -101,3 +73,33 @@ object DistTrainWithSpark {
}
}
```
### XGBoost Flink
```scala
import ml.dmlc.xgboost4j.scala.flink.XGBoost
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.ml.MLUtils
object DistTrainWithFlink {
def main(args: Array[String]) {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
// read training data
val trainData =
MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
// define parameters
val paramMap = List(
"eta" -> 0.1,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(trainData, paramMap, round)
val predTrain = model.predict(trainData.map{x => x.vector})
model.saveModelToHadoop("file:///path/to/xgboost.model")
}
}
```

View File

@ -67,7 +67,7 @@ public class BasicWalkThrough {
int round = 2;
//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
//predict
float[][] predicts = booster.predict(testMat);
@ -111,7 +111,7 @@ public class BasicWalkThrough {
HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
watches2.put("train", trainMat2);
watches2.put("test", testMat2);
Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null);
Booster booster3 = XGBoost.train(trainMat2, params, round, watches2, null, null);
float[][] predicts3 = booster3.predict(testMat2);
//check predicts

View File

@ -48,7 +48,7 @@ public class BoostFromPrediction {
watches.put("test", testMat);
//train xgboost for 1 round
Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null);
Booster booster = XGBoost.train(trainMat, params, 1, watches, null, null);
float[][] trainPred = booster.predict(trainMat, true);
float[][] testPred = booster.predict(testMat, true);
@ -57,6 +57,6 @@ public class BoostFromPrediction {
testMat.setBaseMargin(testPred);
System.out.println("result of running from initial prediction");
Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null);
Booster booster2 = XGBoost.train(trainMat, params, 1, watches, null, null);
}
}

View File

@ -49,7 +49,7 @@ public class CrossValidation {
//set additional eval_metrics
String[] metrics = null;
String[] evalHist = XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null,
String[] evalHist = XGBoost.crossValidation(trainMat, params, round, nfold, metrics, null,
null);
}
}

View File

@ -163,6 +163,6 @@ public class CustomObjective {
//train a booster
System.out.println("begin to train the booster model");
Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval);
Booster booster = XGBoost.train(trainMat, params, round, watches, obj, eval);
}
}

View File

@ -56,6 +56,6 @@ public class ExternalMemory {
int round = 2;
//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
}
}

View File

@ -60,7 +60,7 @@ public class GeneralizedLinearModel {
//train a booster
int round = 4;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
float[][] predicts = booster.predict(testMat);

View File

@ -51,7 +51,7 @@ public class PredictFirstNtree {
//train a booster
int round = 3;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
//predict use 1 tree
float[][] predicts1 = booster.predict(testMat, false, 1);

View File

@ -49,7 +49,7 @@ public class PredictLeafIndices {
//train a booster
int round = 3;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
//predict using first 2 tree
float[][] leafindex = booster.predictLeaf(testMat, 2);

View File

@ -43,7 +43,7 @@ class BasicWalkThrough {
val round = 2
// train a model
val booster = XGBoost.train(params.toMap, trainMax, round, watches.toMap)
val booster = XGBoost.train(trainMax, params.toMap, round, watches.toMap)
// predict
val predicts = booster.predict(testMax)
// save model to model path
@ -78,7 +78,7 @@ class BasicWalkThrough {
val watches2 = new mutable.HashMap[String, DMatrix]
watches2 += "train" -> trainMax2
watches2 += "test" -> testMax2
val booster3 = XGBoost.train(params.toMap, trainMax2, round, watches2.toMap, null, null)
val booster3 = XGBoost.train(trainMax2, params.toMap, round, watches2.toMap, null, null)
val predicts3 = booster3.predict(testMax2)
println(checkPredicts(predicts, predicts3))
}

View File

@ -39,7 +39,7 @@ class BoostFromPrediction {
val round = 2
// train a model
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
val trainPred = booster.predict(trainMat, true)
val testPred = booster.predict(testMat, true)
@ -48,6 +48,6 @@ class BoostFromPrediction {
testMat.setBaseMargin(testPred)
System.out.println("result of running from initial prediction")
val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null)
val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
}
}

View File

@ -41,6 +41,6 @@ class CrossValidation {
val metrics: Array[String] = null
val evalHist: Array[String] =
XGBoost.crossValidation(params.toMap, trainMat, round, nfold, metrics, null, null)
XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics, null, null)
}
}

View File

@ -150,8 +150,8 @@ class CustomObjective {
val round = 2
// train a model
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
XGBoost.train(params.toMap, trainMat, round, watches.toMap, new LogRegObj, new EvalError)
val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
XGBoost.train(trainMat, params.toMap, round, watches.toMap, new LogRegObj, new EvalError)
}
}

View File

@ -45,7 +45,7 @@ class ExternalMemory {
val round = 2
// train a model
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
val trainPred = booster.predict(trainMat, true)
val testPred = booster.predict(testMat, true)
@ -54,6 +54,6 @@ class ExternalMemory {
testMat.setBaseMargin(testPred)
System.out.println("result of running from initial prediction")
val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null)
val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
}
}

View File

@ -52,7 +52,7 @@ class GeneralizedLinearModel {
watches += "test" -> testMat
val round = 4
val booster = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null)
val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
val predicts = booster.predict(testMat)
val eval = new CustomEval
println(s"error=${eval.eval(predicts, testMat)}")

View File

@ -38,7 +38,7 @@ class PredictFirstNTree {
val round = 3
// train a model
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
// predict use 1 tree
val predicts1 = booster.predict(testMat, false, 1)

View File

@ -39,7 +39,7 @@ class PredictLeafIndices {
watches += "test" -> testMat
val round = 3
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
// predict using first 2 tree
val leafIndex = booster.predictLeaf(testMat, 2)

View File

@ -34,7 +34,7 @@ object DistTrainWithFlink {
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(paramMap, trainData, round)
val model = XGBoost.train(trainData, paramMap, round)
val predTest = model.predict(testData.map{x => x.vector})
model.saveModelAsHadoopFile("file:///path/to/xgboost.model")
}

View File

@ -16,29 +16,34 @@
package ml.dmlc.xgboost4j.scala.example.spark
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
object DistTrainWithSpark {
def main(args: Array[String]): Unit = {
if (args.length != 4) {
if (args.length != 5) {
println(
"usage: program num_of_rounds num_workers training_path model_path")
"usage: program num_of_rounds num_workers training_path test_path model_path")
sys.exit(1)
}
val sc = new SparkContext()
val inputTrainPath = args(2)
val outputModelPath = args(3)
val inputTestPath = args(3)
val outputModelPath = args(4)
// number of iterations
val numRound = args(0).toInt
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).repartition(args(1).toInt)
import DataUtils._
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath).collect().iterator
// training parameters
val paramMap = List(
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound)
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt)
xgboostModel.predict(new DMatrix(testSet))
// save model to HDFS path
xgboostModel.saveModelAsHadoopFile(outputModelPath)
}

View File

@ -56,7 +56,7 @@ object XGBoost {
val trainMat = new DMatrix(dataIter, null)
val watches = List("train" -> trainMat).toMap
val round = 2
val booster = XGBoostScala.train(paramMap, trainMat, round, watches, null, null)
val booster = XGBoostScala.train(trainMat, paramMap, round, watches, null, null)
Rabit.shutdown()
collector.collect(new XGBoostModel(booster))
}
@ -81,13 +81,14 @@ object XGBoost {
/**
* Train a xgboost model with Flink.
*
* @param params The parameters to XGBoost.
* @param dtrain The training data.
* @param params The parameters to XGBoost.
* @param round Number of rounds to train.
*/
def train(params: Map[String, Any],
dtrain: DataSet[LabeledVector],
round: Int): XGBoostModel = {
def train(
dtrain: DataSet[LabeledVector],
params: Map[String, Any],
round: Int): XGBoostModel = {
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
if (tracker.start()) {
dtrain

View File

@ -37,6 +37,15 @@ class XGBoostModel (booster: Booster) extends Serializable {
.create(new Path(modelPath)))
}
/**
* predict with the given DMatrix
* @param testSet the local test set represented as DMatrix
* @return prediction result
*/
def predict(testSet: DMatrix): Array[Array[Float]] = {
booster.predict(testSet, true, 0)
}
/**
* Predict given vector dataset.
*
@ -44,7 +53,7 @@ class XGBoostModel (booster: Booster) extends Serializable {
* @return The prediction result.
*/
def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
val predictMap: Iterator[Vector] => TraversableOnce[Array[Float]] =
val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
(it: Iterator[Vector]) => {
val mapper = (x: Vector) => {
val (index, value) = x.toSeq.unzip

View File

@ -56,9 +56,10 @@ object XGBoost extends Serializable {
trainingSamples =>
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval)
val trainingSet = new DMatrix(new JDMatrix(trainingSamples, null))
val booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
watches = new mutable.HashMap[String, DMatrix]{put("train", trainingSet)}.toMap,
obj, eval)
Rabit.shutdown()
Iterator(booster)
}.cache()

View File

@ -60,8 +60,8 @@ public class XGBoost {
/**
* Train a booster with given parameters.
*
* @param params Booster params.
* @param dtrain Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
@ -70,11 +70,13 @@ public class XGBoost {
* @return trained booster
* @throws XGBoostError native error
*/
public static Booster train(Map<String, Object> params,
DMatrix dtrain, int round,
Map<String, DMatrix> watches,
IObjective obj,
IEvaluation eval) throws XGBoostError {
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches,
IObjective obj,
IEvaluation eval) throws XGBoostError {
//collect eval matrixs
String[] evalNames;
@ -139,8 +141,8 @@ public class XGBoost {
/**
* Cross-validation with given parameters.
*
* @param params Booster params.
* @param data Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations.
* @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV.
@ -150,8 +152,8 @@ public class XGBoost {
* @throws XGBoostError native error
*/
public static String[] crossValidation(
Map<String, Object> params,
DMatrix data,
Map<String, Object> params,
int round,
int nfold,
String[] metrics,

View File

@ -35,10 +35,10 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
* init DMatrix from Iterator of LabeledPoint
*
* @param dataIter An iterator of LabeledPoint
* @param cacheInfo Cache path information, used for external memory setting, can be null.
* @param cacheInfo Cache path information, used for external memory setting, null by default.
* @throws XGBoostError native error
*/
def this(dataIter: Iterator[LabeledPoint], cacheInfo: String) {
def this(dataIter: Iterator[LabeledPoint], cacheInfo: String = null) {
this(new JDMatrix(dataIter.asJava, cacheInfo))
}

View File

@ -28,8 +28,8 @@ object XGBoost {
/**
* Train a booster given parameters.
*
* @param params Parameters.
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
@ -39,8 +39,8 @@ object XGBoost {
*/
@throws(classOf[XGBoostError])
def train(
params: Map[String, Any],
dtrain: DMatrix,
params: Map[String, Any],
round: Int,
watches: Map[String, DMatrix] = Map[String, DMatrix](),
obj: ObjectiveTrait = null,
@ -49,10 +49,11 @@ object XGBoost {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
params.map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
dtrain.jDMatrix, round, jWatches.asJava,
round, jWatches.asJava,
obj, eval)
new Booster(xgboostInJava)
}
@ -60,8 +61,8 @@ object XGBoost {
/**
* Cross-validation with given parameters.
*
* @param params Booster params.
* @param data Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations.
* @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV.
@ -71,17 +72,17 @@ object XGBoost {
*/
@throws(classOf[XGBoostError])
def crossValidation(
params: Map[String, Any],
data: DMatrix,
params: Map[String, Any],
round: Int,
nfold: Int = 5,
metrics: Array[String] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Array[String] = {
JXGBoost.crossValidation(params.map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
data.jDMatrix, round, nfold, metrics, obj, eval)
JXGBoost.crossValidation(
data.jDMatrix, params.map{ case (key: String, value) => (key, value.toString)}.
toMap[String, AnyRef].asJava,
round, nfold, metrics, obj, eval)
}
/**

View File

@ -94,7 +94,7 @@ public class BoosterImplTest {
int round = 5;
//train a boost model
return XGBoost.train(paramMap, trainMat, round, watches, null, null);
return XGBoost.train(trainMat, paramMap, round, watches, null, null);
}
@Test
@ -177,6 +177,6 @@ public class BoosterImplTest {
//do 5-fold cross validation
int round = 2;
int nfold = 5;
String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
}
}

View File

@ -74,7 +74,7 @@ class ScalaBoosterImplSuite extends FunSuite {
val watches = List("train" -> trainMat, "test" -> testMat).toMap
val round = 2
XGBoost.train(paramMap, trainMat, round, watches, null, null)
XGBoost.train(trainMat, paramMap, round, watches, null, null)
}
test("basic operation of booster") {
@ -126,6 +126,6 @@ class ScalaBoosterImplSuite extends FunSuite {
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
val nfold = 5
XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null)
}
}