diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 8228d6712..7ef0f2b58 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -4,7 +4,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
4.0.0
- org.dmlc
+ ml.dmlc
xgboostjvm
0.1
pom
@@ -14,12 +14,13 @@
1.7
1.7
3.3.9
- 2.11.7
- 2.11
+ 2.10.5
+ 2.10
xgboost4j
xgboost4j-demo
+ xgboost4jspark
diff --git a/jvm-packages/scalastyle-config.xml b/jvm-packages/scalastyle-config.xml
index 204b72a20..fd4d22e9a 100644
--- a/jvm-packages/scalastyle-config.xml
+++ b/jvm-packages/scalastyle-config.xml
@@ -188,8 +188,8 @@ This file is divided into 3 sections:
java,scala,3rdParty,spark
javax?\..*
scala\..*
- (?!org\.apache\.spark\.).*
- org\.apache\.spark\..*
+ (?!ml\.dmlc\.xgboost4j\.).*
+ ml\.dmlc\.xgboost4j\..*
diff --git a/jvm-packages/xgboost4j-demo/pom.xml b/jvm-packages/xgboost4j-demo/pom.xml
index d8e679b78..4873a4f6d 100644
--- a/jvm-packages/xgboost4j-demo/pom.xml
+++ b/jvm-packages/xgboost4j-demo/pom.xml
@@ -4,7 +4,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
4.0.0
- org.dmlc
+ ml.dmlc
xgboostjvm
0.1
@@ -13,7 +13,7 @@
jar
- org.dmlc
+ ml.dmlc
xgboost4j
0.1
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CrossValidation.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CrossValidation.java
index 115b1dc5b..c3e913fd2 100644
--- a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CrossValidation.java
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CrossValidation.java
@@ -49,6 +49,7 @@ public class CrossValidation {
//set additional eval_metrics
String[] metrics = null;
- String[] evalHist = XGBoost.crossValiation(params, trainMat, round, nfold, metrics, null, null);
+ String[] evalHist = XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null,
+ null);
}
}
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
index fc6b45ccd..51fe13777 100644
--- a/jvm-packages/xgboost4j/pom.xml
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -4,7 +4,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
4.0.0
- org.dmlc
+ ml.dmlc
xgboostjvm
0.1
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IObjective.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IObjective.java
index 97ef9aed4..c95b18ff6 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IObjective.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IObjective.java
@@ -15,6 +15,7 @@
*/
package ml.dmlc.xgboost4j;
+import java.io.Serializable;
import java.util.List;
/**
@@ -22,7 +23,7 @@ import java.util.List;
*
* @author hzx
*/
-public interface IObjective {
+public interface IObjective extends Serializable {
/**
* user define objective function, return gradient and second order gradient
*
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java
index 839b006c4..293ce6728 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java
@@ -143,7 +143,7 @@ public class XGBoost {
* @return evaluation history
* @throws XGBoostError native error
*/
- public static String[] crossValiation(
+ public static String[] crossValidation(
Map params,
DMatrix data,
int round,
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
index e32bd46a7..8e15f8174 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
@@ -35,7 +35,7 @@ object XGBoost {
new ScalaBoosterImpl(xgboostInJava)
}
- def crossValiation(
+ def crossValidation(
params: Map[String, AnyRef],
data: DMatrix,
round: Int,
@@ -43,7 +43,7 @@ object XGBoost {
metrics: Array[String] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Array[String] = {
- JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
+ JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
}
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java
index 8f0f3a97e..4a0dd8e16 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java
@@ -130,6 +130,6 @@ public class BoosterImplTest {
//do 5-fold cross validation
int round = 2;
int nfold = 5;
- String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
+ String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
}
}
diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
index e911ec985..ab805a70c 100644
--- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
+++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
@@ -85,6 +85,6 @@ class ScalaBoosterImplSuite extends FunSuite {
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
val nfold = 5
- XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
+ XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
}
}
diff --git a/jvm-packages/xgboost4jspark/pom.xml b/jvm-packages/xgboost4jspark/pom.xml
new file mode 100644
index 000000000..b74adfb91
--- /dev/null
+++ b/jvm-packages/xgboost4jspark/pom.xml
@@ -0,0 +1,24 @@
+
+
+ 4.0.0
+
+ ml.dmlc
+ xgboostjvm
+ 0.1
+
+ xgboost4jspark
+
+
+ ml.dmlc
+ xgboost4j
+ 0.1
+
+
+ org.apache.spark
+ spark-core_2.10
+ 1.6.0
+
+
+
\ No newline at end of file
diff --git a/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrixBuilder.scala b/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrixBuilder.scala
new file mode 100644
index 000000000..04884ebcf
--- /dev/null
+++ b/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrixBuilder.scala
@@ -0,0 +1,32 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import java.io.DataInputStream
+
+/** Placeholder factory for DMatrix instances; neither overload parses its input yet. */
+private[xgboost4j] object DMatrixBuilder extends Serializable {
+
+  // TODO: dummy value (a single zero Float, size arguments 1 and 1) until parsing exists
+  private def placeholder(): DMatrix = new DMatrix(new Array[Float](1), 1, 1)
+
+  /** Meant to build a DMatrix from a binary stream; `inStream` is currently ignored. */
+  def buildDMatrixfromBinaryData(inStream: DataInputStream): DMatrix = placeholder()
+
+  /** Meant to build a DMatrix from raw bytes; `binaryArray` is currently ignored. */
+  def buildDMatrixfromBinaryData(binaryArray: Array[Byte]): DMatrix = placeholder()
+}
diff --git a/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Boosters.scala b/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Boosters.scala
new file mode 100644
index 000000000..1fec5a9db
--- /dev/null
+++ b/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Boosters.scala
@@ -0,0 +1,47 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.scala.Booster
+import org.apache.spark.rdd.RDD
+
+/** Wrapper around an RDD of boosters; both operations are unimplemented stubs for now. */
+class Boosters(boosters: RDD[Booster]) {
+
+  /** TODO: persist the boosters under `path`; currently does nothing. */
+  def save(path: String): Unit = ()
+
+  /**
+   * TODO: pick one booster out of `boosters`; currently always yields `null`.
+   * Note that the parameter shadows the constructor argument of the same name.
+   */
+  def chooseBestBooster(boosters: RDD[Booster]): Booster = null
+}
+
+/** Companion holding the implicit RDD enrichment and a (stubbed) loader. */
+object Boosters {
+  // Implicit conversions are a SIP-18 language feature; importing the flag at the definition
+  // site avoids the compiler's feature warning.
+  import scala.language.implicitConversions
+
+  /** Wraps `boosterRDD` so that the methods of [[Boosters]] can be called on it directly. */
+  implicit def boosterRDDToBoosters(boosterRDD: RDD[Booster]): Boosters = new Boosters(boosterRDD)
+
+  /** TODO: load boosters from `path`; not implemented yet, so this always returns `null`. */
+  def apply(path: String): RDD[Booster] = null
+}
+
diff --git a/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
new file mode 100644
index 000000000..6e276cc9d
--- /dev/null
+++ b/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -0,0 +1,66 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import scala.collection.immutable.HashMap
+import scala.collection.mutable.ListBuffer
+
+import com.typesafe.config.Config
+import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, DMatrixBuilder, Booster, ObjectiveTrait, EvalTrait}
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+
+/** Entry point for training XGBoost boosters on Spark (work in progress). */
+object XGBoost {
+  private var _sc: Option[SparkContext] = None // cached by buildSparkContext on first use
+
+  /**
+   * Returns the cached SparkContext, obtaining one via `SparkContext.getOrCreate()` on first
+   * use so that callers never hit an empty Option.
+   * TODO: honour the user configuration (cores per task, cores per executor or total cores).
+   * NOTE: Spark currently has limited support for configuring the core number in executors.
+   */
+  private def buildSparkContext(config: Config): SparkContext = _sc.getOrElse {
+    val created = SparkContext.getOrCreate()
+    _sc = Some(created)
+    created
+  }
+
+  /**
+   * Trains one booster per input file under the configured path and returns them as an RDD.
+   * @param config must provide "inputPath" (String), "numWorkers" (Int) and "round" (Int)
+   * @param obj optional objective, forwarded unchanged to `SXGBoost.train`
+   * @param eval optional evaluation, forwarded unchanged to `SXGBoost.train`
+   */
+  def train(config: Config, obj: ObjectiveTrait = null, eval: EvalTrait = null): RDD[Booster] = {
+    val sc = buildSparkContext(config)
+    val filePath = config.getString("inputPath") // configuration entry name to be fixed
+    val numWorkers = config.getInt("numWorkers")
+    val round = config.getInt("round")
+    // TODO: build configuration map from config
+    val xgBoostConfigMap = new HashMap[String, AnyRef]()
+    sc.binaryFiles(filePath, numWorkers).mapPartitions { trainingFiles =>
+      // One file is assumed to hold one DMatrix; map the iterator instead of filling a buffer.
+      trainingFiles.map { case (_, fileInStream) =>
+        // TODO: real parsing; toArray() yields Array[Byte], open() would yield a DataInputStream
+        val dMatrix = DMatrixBuilder.buildDMatrixfromBinaryData(fileInStream.toArray())
+        // TODO: how to build watches list???
+        SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval)
+      }
+    }
+  }
+}