adjust the API signature as well as the docs

This commit is contained in:
CodingCat 2016-03-11 15:22:44 -05:00
parent 97e4dcde98
commit 400b1faecc
23 changed files with 58 additions and 52 deletions

View File

@ -24,7 +24,7 @@ Many of these machine learning libraries(e.g. [XGBoost](https://github.com/dmlc/
requires new computation abstraction and native support(e.g. C++ for GPU computing). requires new computation abstraction and native support(e.g. C++ for GPU computing).
They are also often [much more efficient](http://arxiv.org/abs/1603.02754). They are also often [much more efficient](http://arxiv.org/abs/1603.02754).
The gap between the implementation fundamentals of the general data processing frameworks and the more specific machine learning libraries/systems prohibits the smooth connection between these two types of systems, thus brings unnecessary inconvenience to the end user. The common workflow to the user is to utilize the systems like Flink/Spark to preprocess/clean data, pass the results to machine learning systems like [XGBoost](https://github.com/dmlc/xgboost)/[MxNet](https://github.com/dmlc/mxnet)) via the file system and then conduct the following machine learning phase. While such process won't hurt performance as much in data processing case(because machine learning takes a lot of time compared to data loading), it create a bit inconvenience for the users. The gap between the implementation fundamentals of the general data processing frameworks and the more specific machine learning libraries/systems prohibits the smooth connection between these two types of systems, thus brings unnecessary inconvenience to the end user. The common workflow to the user is to utilize the systems like Flink/Spark to preprocess/clean data, pass the results to machine learning systems like [XGBoost](https://github.com/dmlc/xgboost)/[MxNet](https://github.com/dmlc/mxnet)) via the file system and then conduct the following machine learning phase. While such process won't hurt performance as much in data processing case(because machine learning takes a lot of time compared to data loading), it creates a bit inconvenience for the users.
We want best of both worlds, so we can use the data processing frameworks like Flink and Spark toghether with We want best of both worlds, so we can use the data processing frameworks like Flink and Spark toghether with
the best distributed machine learning solutions. the best distributed machine learning solutions.
@ -37,7 +37,7 @@ XGBoost and XGBoost4J adopts Unix Philosophy.
XGBoost **does its best in one thing -- tree boosting** and is **being designed to work with other systems**. XGBoost **does its best in one thing -- tree boosting** and is **being designed to work with other systems**.
We strongly believe that machine learning solution should not be restricted to certain language or certain platform. We strongly believe that machine learning solution should not be restricted to certain language or certain platform.
Specifically, users will be able to use distributed XGBoost in both Flink and Spark. Specifically, users will be able to use distributed XGBoost in both Flink and Spark, and possibly more frameworks in Future.
We have made the API in a portable way so it **can be easily ported to other Dataflow frameworks provided by the Cloud**. We have made the API in a portable way so it **can be easily ported to other Dataflow frameworks provided by the Cloud**.
XGBoost4J shares its core with other XGBoost libraries, which means data scientists can use R/python XGBoost4J shares its core with other XGBoost libraries, which means data scientists can use R/python
read and visualize the model trained distributedly. read and visualize the model trained distributedly.
@ -85,10 +85,10 @@ watches += "test" -> testMax
val round = 2 val round = 2
// train a model // train a model
val booster = XGBoost.train(params.toMap, trainMax, round, watches.toMap) val booster = XGBoost.train(trainMax, params.toMap, round, watches.toMap)
``` ```
In Scala: We then evaluate our model:
```scala ```scala
val predicts = booster.predict(testMax) val predicts = booster.predict(testMax)
@ -111,7 +111,7 @@ In Spark, the dataset is represented as the [Resilient Distributed Dataset (RDD)
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).repartition(args(1).toInt) val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).repartition(args(1).toInt)
``` ```
We move forward to train the models, in Spark: We move forward to train the models:
```scala ```scala
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound) val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound)
@ -169,6 +169,8 @@ xgboostModel.predict(testData.map{x => x.vector})
It is the first release of XGBoost4J package, we are actively move forward for more charming features in the next release. You can watch our progress in [XGBoost4J Road Map](https://github.com/dmlc/xgboost/issues/935). It is the first release of XGBoost4J package, we are actively move forward for more charming features in the next release. You can watch our progress in [XGBoost4J Road Map](https://github.com/dmlc/xgboost/issues/935).
While we are trying our best to keep the minimum changes to the APIs, it is still subject to the incompatible changes.
## Further Readings ## Further Readings
If you are interested in knowing more about XGBoost, you can find rich resources in If you are interested in knowing more about XGBoost, you can find rich resources in

View File

@ -67,7 +67,7 @@ public class BasicWalkThrough {
int round = 2; int round = 2;
//train a boost model //train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
//predict //predict
float[][] predicts = booster.predict(testMat); float[][] predicts = booster.predict(testMat);
@ -111,7 +111,7 @@ public class BasicWalkThrough {
HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>(); HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
watches2.put("train", trainMat2); watches2.put("train", trainMat2);
watches2.put("test", testMat2); watches2.put("test", testMat2);
Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null); Booster booster3 = XGBoost.train(trainMat2, params, round, watches2, null, null);
float[][] predicts3 = booster3.predict(testMat2); float[][] predicts3 = booster3.predict(testMat2);
//check predicts //check predicts

View File

@ -48,7 +48,7 @@ public class BoostFromPrediction {
watches.put("test", testMat); watches.put("test", testMat);
//train xgboost for 1 round //train xgboost for 1 round
Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null); Booster booster = XGBoost.train(trainMat, params, 1, watches, null, null);
float[][] trainPred = booster.predict(trainMat, true); float[][] trainPred = booster.predict(trainMat, true);
float[][] testPred = booster.predict(testMat, true); float[][] testPred = booster.predict(testMat, true);
@ -57,6 +57,6 @@ public class BoostFromPrediction {
testMat.setBaseMargin(testPred); testMat.setBaseMargin(testPred);
System.out.println("result of running from initial prediction"); System.out.println("result of running from initial prediction");
Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null); Booster booster2 = XGBoost.train(trainMat, params, 1, watches, null, null);
} }
} }

View File

@ -49,7 +49,7 @@ public class CrossValidation {
//set additional eval_metrics //set additional eval_metrics
String[] metrics = null; String[] metrics = null;
String[] evalHist = XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null, String[] evalHist = XGBoost.crossValidation(trainMat, params, round, nfold, metrics, null,
null); null);
} }
} }

View File

@ -163,6 +163,6 @@ public class CustomObjective {
//train a booster //train a booster
System.out.println("begin to train the booster model"); System.out.println("begin to train the booster model");
Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval); Booster booster = XGBoost.train(trainMat, params, round, watches, obj, eval);
} }
} }

View File

@ -56,6 +56,6 @@ public class ExternalMemory {
int round = 2; int round = 2;
//train a boost model //train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
} }
} }

View File

@ -60,7 +60,7 @@ public class GeneralizedLinearModel {
//train a booster //train a booster
int round = 4; int round = 4;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
float[][] predicts = booster.predict(testMat); float[][] predicts = booster.predict(testMat);

View File

@ -51,7 +51,7 @@ public class PredictFirstNtree {
//train a booster //train a booster
int round = 3; int round = 3;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
//predict use 1 tree //predict use 1 tree
float[][] predicts1 = booster.predict(testMat, false, 1); float[][] predicts1 = booster.predict(testMat, false, 1);

View File

@ -49,7 +49,7 @@ public class PredictLeafIndices {
//train a booster //train a booster
int round = 3; int round = 3;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
//predict using first 2 tree //predict using first 2 tree
float[][] leafindex = booster.predictLeaf(testMat, 2); float[][] leafindex = booster.predictLeaf(testMat, 2);

View File

@ -43,7 +43,7 @@ class BasicWalkThrough {
val round = 2 val round = 2
// train a model // train a model
val booster = XGBoost.train(params.toMap, trainMax, round, watches.toMap) val booster = XGBoost.train(trainMax, params.toMap, round, watches.toMap)
// predict // predict
val predicts = booster.predict(testMax) val predicts = booster.predict(testMax)
// save model to model path // save model to model path
@ -78,7 +78,7 @@ class BasicWalkThrough {
val watches2 = new mutable.HashMap[String, DMatrix] val watches2 = new mutable.HashMap[String, DMatrix]
watches2 += "train" -> trainMax2 watches2 += "train" -> trainMax2
watches2 += "test" -> testMax2 watches2 += "test" -> testMax2
val booster3 = XGBoost.train(params.toMap, trainMax2, round, watches2.toMap, null, null) val booster3 = XGBoost.train(trainMax2, params.toMap, round, watches2.toMap, null, null)
val predicts3 = booster3.predict(testMax2) val predicts3 = booster3.predict(testMax2)
println(checkPredicts(predicts, predicts3)) println(checkPredicts(predicts, predicts3))
} }

View File

@ -39,7 +39,7 @@ class BoostFromPrediction {
val round = 2 val round = 2
// train a model // train a model
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
val trainPred = booster.predict(trainMat, true) val trainPred = booster.predict(trainMat, true)
val testPred = booster.predict(testMat, true) val testPred = booster.predict(testMat, true)
@ -48,6 +48,6 @@ class BoostFromPrediction {
testMat.setBaseMargin(testPred) testMat.setBaseMargin(testPred)
System.out.println("result of running from initial prediction") System.out.println("result of running from initial prediction")
val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null) val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
} }
} }

View File

@ -41,6 +41,6 @@ class CrossValidation {
val metrics: Array[String] = null val metrics: Array[String] = null
val evalHist: Array[String] = val evalHist: Array[String] =
XGBoost.crossValidation(params.toMap, trainMat, round, nfold, metrics, null, null) XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics, null, null)
} }
} }

View File

@ -150,8 +150,8 @@ class CustomObjective {
val round = 2 val round = 2
// train a model // train a model
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
XGBoost.train(params.toMap, trainMat, round, watches.toMap, new LogRegObj, new EvalError) XGBoost.train(trainMat, params.toMap, round, watches.toMap, new LogRegObj, new EvalError)
} }
} }

View File

@ -45,7 +45,7 @@ class ExternalMemory {
val round = 2 val round = 2
// train a model // train a model
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
val trainPred = booster.predict(trainMat, true) val trainPred = booster.predict(trainMat, true)
val testPred = booster.predict(testMat, true) val testPred = booster.predict(testMat, true)
@ -54,6 +54,6 @@ class ExternalMemory {
testMat.setBaseMargin(testPred) testMat.setBaseMargin(testPred)
System.out.println("result of running from initial prediction") System.out.println("result of running from initial prediction")
val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null) val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
} }
} }

View File

@ -52,7 +52,7 @@ class GeneralizedLinearModel {
watches += "test" -> testMat watches += "test" -> testMat
val round = 4 val round = 4
val booster = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null) val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
val predicts = booster.predict(testMat) val predicts = booster.predict(testMat)
val eval = new CustomEval val eval = new CustomEval
println(s"error=${eval.eval(predicts, testMat)}") println(s"error=${eval.eval(predicts, testMat)}")

View File

@ -38,7 +38,7 @@ class PredictFirstNTree {
val round = 3 val round = 3
// train a model // train a model
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
// predict use 1 tree // predict use 1 tree
val predicts1 = booster.predict(testMat, false, 1) val predicts1 = booster.predict(testMat, false, 1)

View File

@ -39,7 +39,7 @@ class PredictLeafIndices {
watches += "test" -> testMat watches += "test" -> testMat
val round = 3 val round = 3
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
// predict using first 2 tree // predict using first 2 tree
val leafIndex = booster.predictLeaf(testMat, 2) val leafIndex = booster.predictLeaf(testMat, 2)

View File

@ -56,7 +56,7 @@ object XGBoost {
val trainMat = new DMatrix(dataIter, null) val trainMat = new DMatrix(dataIter, null)
val watches = List("train" -> trainMat).toMap val watches = List("train" -> trainMat).toMap
val round = 2 val round = 2
val booster = XGBoostScala.train(paramMap, trainMat, round, watches, null, null) val booster = XGBoostScala.train(trainMat, paramMap, round, watches, null, null)
Rabit.shutdown() Rabit.shutdown()
collector.collect(new XGBoostModel(booster)) collector.collect(new XGBoostModel(booster))
} }

View File

@ -56,9 +56,10 @@ object XGBoost extends Serializable {
trainingSamples => trainingSamples =>
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava) Rabit.init(rabitEnv.asJava)
val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) val trainingSet = new DMatrix(new JDMatrix(trainingSamples, null))
val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round, val booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval) watches = new mutable.HashMap[String, DMatrix]{put("train", trainingSet)}.toMap,
obj, eval)
Rabit.shutdown() Rabit.shutdown()
Iterator(booster) Iterator(booster)
}.cache() }.cache()

View File

@ -60,8 +60,8 @@ public class XGBoost {
/** /**
* Train a booster with given parameters. * Train a booster with given parameters.
* *
* @param params Booster params.
* @param dtrain Data to be trained. * @param dtrain Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations. * @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch * @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set. * performance on the validation set.
@ -70,8 +70,10 @@ public class XGBoost {
* @return trained booster * @return trained booster
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
public static Booster train(Map<String, Object> params, public static Booster train(
DMatrix dtrain, int round, DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches, Map<String, DMatrix> watches,
IObjective obj, IObjective obj,
IEvaluation eval) throws XGBoostError { IEvaluation eval) throws XGBoostError {
@ -139,8 +141,8 @@ public class XGBoost {
/** /**
* Cross-validation with given parameters. * Cross-validation with given parameters.
* *
* @param params Booster params.
* @param data Data to be trained. * @param data Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations. * @param round Number of boosting iterations.
* @param nfold Number of folds in CV. * @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV. * @param metrics Evaluation metrics to be watched in CV.
@ -150,8 +152,8 @@ public class XGBoost {
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
public static String[] crossValidation( public static String[] crossValidation(
Map<String, Object> params,
DMatrix data, DMatrix data,
Map<String, Object> params,
int round, int round,
int nfold, int nfold,
String[] metrics, String[] metrics,

View File

@ -28,8 +28,8 @@ object XGBoost {
/** /**
* Train a booster given parameters. * Train a booster given parameters.
* *
* @param params Parameters.
* @param dtrain Data to be trained. * @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations. * @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch * @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set. * performance on the validation set.
@ -39,8 +39,8 @@ object XGBoost {
*/ */
@throws(classOf[XGBoostError]) @throws(classOf[XGBoostError])
def train( def train(
params: Map[String, Any],
dtrain: DMatrix, dtrain: DMatrix,
params: Map[String, Any],
round: Int, round: Int,
watches: Map[String, DMatrix] = Map[String, DMatrix](), watches: Map[String, DMatrix] = Map[String, DMatrix](),
obj: ObjectiveTrait = null, obj: ObjectiveTrait = null,
@ -49,10 +49,11 @@ object XGBoost {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)} val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train( val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
params.map{ params.map{
case (key: String, value) => (key, value.toString) case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava, }.toMap[String, AnyRef].asJava,
dtrain.jDMatrix, round, jWatches.asJava, round, jWatches.asJava,
obj, eval) obj, eval)
new Booster(xgboostInJava) new Booster(xgboostInJava)
} }
@ -60,8 +61,8 @@ object XGBoost {
/** /**
* Cross-validation with given parameters. * Cross-validation with given parameters.
* *
* @param params Booster params.
* @param data Data to be trained. * @param data Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations. * @param round Number of boosting iterations.
* @param nfold Number of folds in CV. * @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV. * @param metrics Evaluation metrics to be watched in CV.
@ -71,17 +72,17 @@ object XGBoost {
*/ */
@throws(classOf[XGBoostError]) @throws(classOf[XGBoostError])
def crossValidation( def crossValidation(
params: Map[String, Any],
data: DMatrix, data: DMatrix,
params: Map[String, Any],
round: Int, round: Int,
nfold: Int = 5, nfold: Int = 5,
metrics: Array[String] = null, metrics: Array[String] = null,
obj: ObjectiveTrait = null, obj: ObjectiveTrait = null,
eval: EvalTrait = null): Array[String] = { eval: EvalTrait = null): Array[String] = {
JXGBoost.crossValidation(params.map{ JXGBoost.crossValidation(
case (key: String, value) => (key, value.toString) data.jDMatrix, params.map{ case (key: String, value) => (key, value.toString)}.
}.toMap[String, AnyRef].asJava, toMap[String, AnyRef].asJava,
data.jDMatrix, round, nfold, metrics, obj, eval) round, nfold, metrics, obj, eval)
} }
/** /**

View File

@ -94,7 +94,7 @@ public class BoosterImplTest {
int round = 5; int round = 5;
//train a boost model //train a boost model
return XGBoost.train(paramMap, trainMat, round, watches, null, null); return XGBoost.train(trainMat, paramMap, round, watches, null, null);
} }
@Test @Test
@ -177,6 +177,6 @@ public class BoosterImplTest {
//do 5-fold cross validation //do 5-fold cross validation
int round = 2; int round = 2;
int nfold = 5; int nfold = 5;
String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null); String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
} }
} }

View File

@ -74,7 +74,7 @@ class ScalaBoosterImplSuite extends FunSuite {
val watches = List("train" -> trainMat, "test" -> testMat).toMap val watches = List("train" -> trainMat, "test" -> testMat).toMap
val round = 2 val round = 2
XGBoost.train(paramMap, trainMat, round, watches, null, null) XGBoost.train(trainMat, paramMap, round, watches, null, null)
} }
test("basic operation of booster") { test("basic operation of booster") {
@ -126,6 +126,6 @@ class ScalaBoosterImplSuite extends FunSuite {
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap "objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2 val round = 2
val nfold = 5 val nfold = 5
XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null) XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null)
} }
} }