diff --git a/doc/jvm/xgboost4j-intro.md b/doc/jvm/xgboost4j-intro.md index bc8509d79..09d5b29c7 100644 --- a/doc/jvm/xgboost4j-intro.md +++ b/doc/jvm/xgboost4j-intro.md @@ -24,7 +24,7 @@ Many of these machine learning libraries(e.g. [XGBoost](https://github.com/dmlc/ requires new computation abstraction and native support(e.g. C++ for GPU computing). They are also often [much more efficient](http://arxiv.org/abs/1603.02754). -The gap between the implementation fundamentals of the general data processing frameworks and the more specific machine learning libraries/systems prohibits the smooth connection between these two types of systems, thus brings unnecessary inconvenience to the end user. The common workflow to the user is to utilize the systems like Flink/Spark to preprocess/clean data, pass the results to machine learning systems like [XGBoost](https://github.com/dmlc/xgboost)/[MxNet](https://github.com/dmlc/mxnet)) via the file system and then conduct the following machine learning phase. While such process won't hurt performance as much in data processing case(because machine learning takes a lot of time compared to data loading), it create a bit inconvenience for the users. +The gap between the implementation fundamentals of the general data processing frameworks and the more specific machine learning libraries/systems prohibits the smooth connection between these two types of systems, thus brings unnecessary inconvenience to the end user. The common workflow to the user is to utilize the systems like Flink/Spark to preprocess/clean data, pass the results to machine learning systems like [XGBoost](https://github.com/dmlc/xgboost)/[MxNet](https://github.com/dmlc/mxnet)) via the file system and then conduct the following machine learning phase. While such process won't hurt performance as much in data processing case(because machine learning takes a lot of time compared to data loading), it creates a bit inconvenience for the users. We want best of both worlds, so we can use the data processing frameworks like Flink and Spark toghether with the best distributed machine learning solutions. @@ -37,7 +37,7 @@ XGBoost and XGBoost4J adopts Unix Philosophy. XGBoost **does its best in one thing -- tree boosting** and is **being designed to work with other systems**. We strongly believe that machine learning solution should not be restricted to certain language or certain platform. -Specifically, users will be able to use distributed XGBoost in both Flink and Spark. +Specifically, users will be able to use distributed XGBoost in both Flink and Spark, and possibly more frameworks in Future. We have made the API in a portable way so it **can be easily ported to other Dataflow frameworks provided by the Cloud**. XGBoost4J shares its core with other XGBoost libraries, which means data scientists can use R/python read and visualize the model trained distributedly. @@ -85,10 +85,10 @@ watches += "test" -> testMax val round = 2 // train a model -val booster = XGBoost.train(params.toMap, trainMax, round, watches.toMap) +val booster = XGBoost.train(trainMax, params.toMap, round, watches.toMap) ``` -In Scala: +We then evaluate our model: ```scala val predicts = booster.predict(testMax) @@ -111,7 +111,7 @@ In Spark, the dataset is represented as the [Resilient Distributed Dataset (RDD) val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).repartition(args(1).toInt) ``` -We move forward to train the models, in Spark: +We move forward to train the models: ```scala val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound) @@ -169,6 +169,8 @@ xgboostModel.predict(testData.map{x => x.vector}) It is the first release of XGBoost4J package, we are actively move forward for more charming features in the next release. You can watch our progress in [XGBoost4J Road Map](https://github.com/dmlc/xgboost/issues/935). +While we are trying our best to keep the minimum changes to the APIs, it is still subject to the incompatible changes. + ## Further Readings If you are interested in knowing more about XGBoost, you can find rich resources in diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java index 2fe829b41..7a74852f4 100644 --- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java +++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java @@ -67,7 +67,7 @@ public class BasicWalkThrough { int round = 2; //train a boost model - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + Booster booster = XGBoost.train(trainMat, params, round, watches, null, null); //predict float[][] predicts = booster.predict(testMat); @@ -111,7 +111,7 @@ public class BasicWalkThrough { HashMap watches2 = new HashMap(); watches2.put("train", trainMat2); watches2.put("test", testMat2); - Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null); + Booster booster3 = XGBoost.train(trainMat2, params, round, watches2, null, null); float[][] predicts3 = booster3.predict(testMat2); //check predicts diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java index 0649de2ae..7eb9e99f0 100644 --- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java +++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java @@ -48,7 +48,7 @@ public class BoostFromPrediction { watches.put("test", testMat); //train xgboost for 1 round - Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null); + Booster booster = XGBoost.train(trainMat, params, 1, watches, null, null); float[][] trainPred = booster.predict(trainMat, true); float[][] testPred = booster.predict(testMat, true); @@ -57,6 +57,6 @@ public class BoostFromPrediction { testMat.setBaseMargin(testPred); System.out.println("result of running from initial prediction"); - Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null); + Booster booster2 = XGBoost.train(trainMat, params, 1, watches, null, null); } } diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java index dbaa4ff0f..dbe5f368c 100644 --- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java +++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java @@ -49,7 +49,7 @@ public class CrossValidation { //set additional eval_metrics String[] metrics = null; - String[] evalHist = XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null, + String[] evalHist = XGBoost.crossValidation(trainMat, params, round, nfold, metrics, null, null); } } diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java index 5fc132fa6..6d529974c 100644 --- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java +++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java @@ -163,6 +163,6 @@ public class CustomObjective { //train a booster System.out.println("begin to train the booster model"); - Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval); + Booster booster = XGBoost.train(trainMat, params, round, watches, obj, eval); } } diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java index 26434a1a4..349098ae1 100644 --- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java +++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java @@ -56,6 +56,6 @@ public class ExternalMemory { int round = 2; //train a boost model - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + Booster booster = XGBoost.train(trainMat, params, round, watches, null, null); } } diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java index 37c4e756a..422cdea6a 100644 --- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java +++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java @@ -60,7 +60,7 @@ public class GeneralizedLinearModel { //train a booster int round = 4; - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + Booster booster = XGBoost.train(trainMat, params, round, watches, null, null); float[][] predicts = booster.predict(testMat); diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java index 9b3f3e27a..c98534a93 100644 --- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java +++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java @@ -51,7 +51,7 @@ public class PredictFirstNtree { //train a booster int round = 3; - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + Booster booster = XGBoost.train(trainMat, params, round, watches, null, null); //predict use 1 tree float[][] predicts1 = booster.predict(testMat, false, 1); diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java index c063df368..0fcfb39de 100644 --- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java +++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java @@ -49,7 +49,7 @@ public class PredictLeafIndices { //train a booster int round = 3; - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + Booster booster = XGBoost.train(trainMat, params, round, watches, null, null); //predict using first 2 tree float[][] leafindex = booster.predictLeaf(testMat, 2); diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala index fdfb50c94..0f6cd1c04 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala @@ -43,7 +43,7 @@ class BasicWalkThrough { val round = 2 // train a model - val booster = XGBoost.train(params.toMap, trainMax, round, watches.toMap) + val booster = XGBoost.train(trainMax, params.toMap, round, watches.toMap) // predict val predicts = booster.predict(testMax) // save model to model path @@ -78,7 +78,7 @@ class BasicWalkThrough { val watches2 = new mutable.HashMap[String, DMatrix] watches2 += "train" -> trainMax2 watches2 += "test" -> testMax2 - val booster3 = XGBoost.train(params.toMap, trainMax2, round, watches2.toMap, null, null) + val booster3 = XGBoost.train(trainMax2, params.toMap, round, watches2.toMap, null, null) val predicts3 = booster3.predict(testMax2) println(checkPredicts(predicts, predicts3)) } diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala index a68d479c1..3163ea47e 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala @@ -39,7 +39,7 @@ class BoostFromPrediction { val round = 2 // train a model - val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) + val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap) val trainPred = booster.predict(trainMat, true) val testPred = booster.predict(testMat, true) @@ -48,6 +48,6 @@ class BoostFromPrediction { testMat.setBaseMargin(testPred) System.out.println("result of running from initial prediction") - val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null) + val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null) } } diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala index 493aa2e62..14dfa8dfd 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala @@ -41,6 +41,6 @@ class CrossValidation { val metrics: Array[String] = null val evalHist: Array[String] = - XGBoost.crossValidation(params.toMap, trainMat, round, nfold, metrics, null, null) + XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics, null, null) } } diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala index 3f27e9031..c0b6914c4 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala @@ -150,8 +150,8 @@ class CustomObjective { val round = 2 // train a model - val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) - XGBoost.train(params.toMap, trainMat, round, watches.toMap, new LogRegObj, new EvalError) + val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap) + XGBoost.train(trainMat, params.toMap, round, watches.toMap, new LogRegObj, new EvalError) } } diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala index 61faf3293..b1bd02347 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala @@ -45,7 +45,7 @@ class ExternalMemory { val round = 2 // train a model - val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) + val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap) val trainPred = booster.predict(trainMat, true) val testPred = booster.predict(testMat, true) @@ -54,6 +54,6 @@ class ExternalMemory { testMat.setBaseMargin(testPred) System.out.println("result of running from initial prediction") - val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null) + val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null) } } diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala index 580f8351a..11238c0f6 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala @@ -52,7 +52,7 @@ class GeneralizedLinearModel { watches += "test" -> testMat val round = 4 - val booster = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null) + val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null) val predicts = booster.predict(testMat) val eval = new CustomEval println(s"error=${eval.eval(predicts, testMat)}") diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala index 8dd83e6c7..f62c518a7 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala @@ -38,7 +38,7 @@ class PredictFirstNTree { val round = 3 // train a model - val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) + val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap) // predict use 1 tree val predicts1 = booster.predict(testMat, false, 1) diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala index d7194c73f..32a75eeec 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala @@ -39,7 +39,7 @@ class PredictLeafIndices { watches += "test" -> testMat val round = 3 - val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) + val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap) // predict using first 2 tree val leafIndex = booster.predictLeaf(testMat, 2) diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala index 8d00ec9c1..3577ebcc1 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala @@ -56,7 +56,7 @@ object XGBoost { val trainMat = new DMatrix(dataIter, null) val watches = List("train" -> trainMat).toMap val round = 2 - val booster = XGBoostScala.train(paramMap, trainMat, round, watches, null, null) + val booster = XGBoostScala.train(trainMat, paramMap, round, watches, null, null) Rabit.shutdown() collector.collect(new XGBoostModel(booster)) } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index a0551fba5..4b4da36cb 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -56,9 +56,10 @@ object XGBoost extends Serializable { trainingSamples => rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) Rabit.init(rabitEnv.asJava) - val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) - val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round, - watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval) + val trainingSet = new DMatrix(new JDMatrix(trainingSamples, null)) + val booster = SXGBoost.train(trainingSet, xgBoostConfMap, round, + watches = new mutable.HashMap[String, DMatrix]{put("train", trainingSet)}.toMap, + obj, eval) Rabit.shutdown() Iterator(booster) }.cache() diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 09159e53f..f2ce989f6 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -60,8 +60,8 @@ public class XGBoost { /** * Train a booster with given parameters. * - * @param params Booster params. * @param dtrain Data to be trained. + * @param params Booster params. * @param round Number of boosting iterations. * @param watches a group of items to be evaluated during training, this allows user to watch * performance on the validation set. @@ -70,11 +70,13 @@ public class XGBoost { * @return trained booster * @throws XGBoostError native error */ - public static Booster train(Map params, - DMatrix dtrain, int round, - Map watches, - IObjective obj, - IEvaluation eval) throws XGBoostError { + public static Booster train( + DMatrix dtrain, + Map params, + int round, + Map watches, + IObjective obj, + IEvaluation eval) throws XGBoostError { //collect eval matrixs String[] evalNames; @@ -139,8 +141,8 @@ public class XGBoost { /** * Cross-validation with given parameters. * - * @param params Booster params. * @param data Data to be trained. + * @param params Booster params. * @param round Number of boosting iterations. * @param nfold Number of folds in CV. * @param metrics Evaluation metrics to be watched in CV. @@ -150,8 +152,8 @@ public class XGBoost { * @throws XGBoostError native error */ public static String[] crossValidation( - Map params, DMatrix data, + Map params, int round, int nfold, String[] metrics, diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index 6ed9cfb62..15f16be51 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -28,8 +28,8 @@ object XGBoost { /** * Train a booster given parameters. * - * @param params Parameters. * @param dtrain Data to be trained. + * @param params Parameters. * @param round Number of boosting iterations. * @param watches a group of items to be evaluated during training, this allows user to watch * performance on the validation set. @@ -39,8 +39,8 @@ object XGBoost { */ @throws(classOf[XGBoostError]) def train( - params: Map[String, Any], dtrain: DMatrix, + params: Map[String, Any], round: Int, watches: Map[String, DMatrix] = Map[String, DMatrix](), obj: ObjectiveTrait = null, @@ -49,10 +49,11 @@ object XGBoost { val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)} val xgboostInJava = JXGBoost.train( + dtrain.jDMatrix, params.map{ case (key: String, value) => (key, value.toString) }.toMap[String, AnyRef].asJava, - dtrain.jDMatrix, round, jWatches.asJava, + round, jWatches.asJava, obj, eval) new Booster(xgboostInJava) } @@ -60,8 +61,8 @@ object XGBoost { /** * Cross-validation with given parameters. * - * @param params Booster params. * @param data Data to be trained. + * @param params Booster params. * @param round Number of boosting iterations. * @param nfold Number of folds in CV. * @param metrics Evaluation metrics to be watched in CV. @@ -71,17 +72,17 @@ object XGBoost { */ @throws(classOf[XGBoostError]) def crossValidation( - params: Map[String, Any], data: DMatrix, + params: Map[String, Any], round: Int, nfold: Int = 5, metrics: Array[String] = null, obj: ObjectiveTrait = null, eval: EvalTrait = null): Array[String] = { - JXGBoost.crossValidation(params.map{ - case (key: String, value) => (key, value.toString) - }.toMap[String, AnyRef].asJava, - data.jDMatrix, round, nfold, metrics, obj, eval) + JXGBoost.crossValidation( + data.jDMatrix, params.map{ case (key: String, value) => (key, value.toString)}. + toMap[String, AnyRef].asJava, + round, nfold, metrics, obj, eval) } /** diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 848d42407..d8cb1d505 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -94,7 +94,7 @@ public class BoosterImplTest { int round = 5; //train a boost model - return XGBoost.train(paramMap, trainMat, round, watches, null, null); + return XGBoost.train(trainMat, paramMap, round, watches, null, null); } @Test @@ -177,6 +177,6 @@ public class BoosterImplTest { //do 5-fold cross validation int round = 2; int nfold = 5; - String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null); + String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null); } } diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala index b9e6443ca..147a486c5 100644 --- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala @@ -74,7 +74,7 @@ class ScalaBoosterImplSuite extends FunSuite { val watches = List("train" -> trainMat, "test" -> testMat).toMap val round = 2 - XGBoost.train(paramMap, trainMat, round, watches, null, null) + XGBoost.train(trainMat, paramMap, round, watches, null, null) } test("basic operation of booster") { @@ -126,6 +126,6 @@ class ScalaBoosterImplSuite extends FunSuite { "objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap val round = 2 val nfold = 5 - XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null) + XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null) } }