diff --git a/dmlc-core b/dmlc-core
index 3f6ff43d3..71360023d 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f
+Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 7ef0f2b58..cfc409ddb 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -20,7 +20,7 @@
xgboost4j
xgboost4j-demo
- xgboost4jspark
+ xgboost4j-spark
diff --git a/jvm-packages/xgboost4jspark/pom.xml b/jvm-packages/xgboost4j-spark/pom.xml
similarity index 93%
rename from jvm-packages/xgboost4jspark/pom.xml
rename to jvm-packages/xgboost4j-spark/pom.xml
index b74adfb91..d847ca2f4 100644
--- a/jvm-packages/xgboost4jspark/pom.xml
+++ b/jvm-packages/xgboost4j-spark/pom.xml
@@ -17,7 +17,7 @@
org.apache.spark
- spark-core_2.10
+ spark-mllib_2.10
1.6.0
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
new file mode 100644
index 000000000..fd336a9c2
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
@@ -0,0 +1,74 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.util.{Iterator => JIterator}
+
+import scala.collection.mutable.ListBuffer
+import scala.collection.JavaConverters._
+
+import ml.dmlc.xgboost4j.DataBatch
+import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
+
+private[spark] object DataUtils extends Serializable {
+
+ private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
+ (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
+ }
+
+ private def fetchUpdateFromVector(feature: Vector) = feature match {
+ case denseFeature: DenseVector =>
+ fetchUpdateFromSparseVector(denseFeature.toSparse)
+ case sparseFeature: SparseVector =>
+ fetchUpdateFromSparseVector(sparseFeature)
+ }
+
+ def fromLabeledPointsToSparseMatrix(points: Iterator[LabeledPoint]): JIterator[DataBatch] = {
+ // TODO: support weight
+ var samplePos = 0
+ // TODO: change hard value
+ val loadingBatchSize = 100
+ val rowOffset = new ListBuffer[Long]
+ val label = new ListBuffer[Float]
+ val featureIndices = new ListBuffer[Int]
+ val featureValues = new ListBuffer[Float]
+ val dataBatches = new ListBuffer[DataBatch]
+ for (point <- points) {
+ val (nonZeroIndices, nonZeroValues) = fetchUpdateFromVector(point.features)
+ rowOffset(samplePos) = rowOffset.size
+ label(samplePos) = point.label.toFloat
+ for (i <- nonZeroIndices.indices) {
+ featureIndices += nonZeroIndices(i)
+ featureValues += nonZeroValues(i)
+ }
+ samplePos += 1
+ if (samplePos % loadingBatchSize == 0) {
+ // create a data batch
+ dataBatches += new DataBatch(
+ rowOffset.toArray.clone(),
+ null, label.toArray.clone(), featureIndices.toArray.clone(),
+ featureValues.toArray.clone())
+ rowOffset.clear()
+ label.clear()
+ featureIndices.clear()
+ featureValues.clear()
+ }
+ }
+ dataBatches.iterator.asJava
+ }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
new file mode 100644
index 000000000..96a6210a7
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -0,0 +1,55 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import scala.collection.immutable.HashMap
+import scala.collection.JavaConverters._
+
+import com.typesafe.config.Config
+import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+
+object XGBoost {
+
+ private var _sc: Option[SparkContext] = None
+
+ implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
+ new XGBoostModel(booster)
+ }
+
+ def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
+ eval: EvalTrait = null): XGBoostModel = {
+ val sc = trainingData.sparkContext
+ val dataUtilsBroadcast = sc.broadcast(DataUtils)
+ val filePath = config.getString("inputPath") // configuration entry name to be fixed
+ val numWorkers = config.getInt("numWorkers")
+ val round = config.getInt("round")
+ // TODO: build configuration map from config
+ val xgBoostConfigMap = new HashMap[String, AnyRef]()
+ val boosters = trainingData.repartition(numWorkers).mapPartitions {
+ trainingSamples =>
+ val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples)
+ val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
+ Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
+ }
+ // TODO: how to choose best model
+ boosters.first()
+ }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
new file mode 100644
index 000000000..849ad6168
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -0,0 +1,37 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import scala.collection.JavaConverters._
+
+import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+
+class XGBoostModel(booster: Booster) extends Serializable {
+
+ def predict(testSet: RDD[LabeledPoint]): RDD[Array[Array[Float]]] = {
+ val broadcastBooster = testSet.sparkContext.broadcast(booster)
+ val dataUtils = testSet.sparkContext.broadcast(DataUtils)
+ testSet.mapPartitions { testSamples =>
+ val dataBatches = dataUtils.value.fromLabeledPointsToSparseMatrix(testSamples)
+ val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
+ Iterator(broadcastBooster.value.predict(dMatrix))
+ }
+ }
+}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java
index e2b3ecc47..99f55055e 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java
@@ -28,7 +28,7 @@ import org.apache.commons.logging.LogFactory;
*/
public class DMatrix {
private static final Log logger = LogFactory.getLog(DMatrix.class);
- private long handle = 0;
+ protected long handle = 0;
//load native library
static {
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DataBatch.java
index 2e48b02f5..3fd2427b8 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DataBatch.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DataBatch.java
@@ -4,8 +4,6 @@ package ml.dmlc.xgboost4j;
* A mini-batch of data that can be converted to DMatrix.
* The data is in sparse matrix CSR format.
*
- * Usually this object is not needed.
- *
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
*/
public class DataBatch {
@@ -19,6 +17,19 @@ public class DataBatch {
int[] featureIndex = null;
/** value of each non-missing entry in the sparse matrix */
float[] featureValue = null;
+
+ public DataBatch() {}
+
+ public DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
+ float[] featureValue) {
+ this.rowOffset = rowOffset;
+ this.weight = weight;
+ this.label = label;
+ this.featureIndex = featureIndex;
+ this.featureValue = featureValue;
+ }
+
+
/**
* Get number of rows in the data batch.
* @return Number of rows in the data batch.
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java
index ae265a36d..2820f51b6 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java
@@ -491,8 +491,7 @@ class JavaBoosterImpl implements Booster {
}
// making Booster serializable
- private void writeObject(java.io.ObjectOutputStream out)
- throws IOException {
+ private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java
index 160396df0..922f8bc8e 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java
@@ -27,7 +27,8 @@ class XgboostJNI {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
- public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator iter, String cache_info, long[] out);
+ final static native int XGDMatrixCreateFromDataIter(java.util.Iterator iter,
+ String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
long[] out);
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
index 634aef190..8f3d73e2b 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
@@ -16,7 +16,9 @@
package ml.dmlc.xgboost4j.scala
-import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
+import _root_.scala.collection.JavaConverters._
+
+import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, DataBatch, XGBoostError}
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
@@ -43,6 +45,10 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
this(new JDMatrix(headers, indices, data, st))
}
+ private[xgboost4j] def this(dataBatch: DataBatch) {
+ this(new JDMatrix(List(dataBatch).asJava.iterator, null))
+ }
+
/**
* create DMatrix from dense matrix
*
diff --git a/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrixBuilder.scala b/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrixBuilder.scala
deleted file mode 100644
index 04884ebcf..000000000
--- a/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrixBuilder.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala
-
-import java.io.DataInputStream
-
-private[xgboost4j] object DMatrixBuilder extends Serializable {
-
- def buildDMatrixfromBinaryData(inStream: DataInputStream): DMatrix = {
- // TODO: currently it is random statement for making compiler happy
- new DMatrix(new Array[Float](1), 1, 1)
- }
-
- def buildDMatrixfromBinaryData(binaryArray: Array[Byte]): DMatrix = {
- // TODO: currently it is random statement for making compiler happy
- new DMatrix(new Array[Float](1), 1, 1)
- }
-}
diff --git a/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Boosters.scala b/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Boosters.scala
deleted file mode 100644
index 1fec5a9db..000000000
--- a/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Boosters.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import ml.dmlc.xgboost4j.scala.Booster
-import org.apache.spark.rdd.RDD
-
-class Boosters(boosters: RDD[Booster]) {
-
- def save(path: String): Unit = {
-
- }
-
- def chooseBestBooster(boosters: RDD[Booster]): Booster = {
- // TODO:
- null
- }
-
-}
-
-object Boosters {
-
- implicit def boosterRDDToBoosters(boosterRDD: RDD[Booster]): Boosters = {
- new Boosters(boosterRDD)
- }
-
- // load booster from path
- def apply(path: String): RDD[Booster] = {
- // TODO
- null
- }
-}
-
diff --git a/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
deleted file mode 100644
index 6e276cc9d..000000000
--- a/jvm-packages/xgboost4jspark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import scala.collection.immutable.HashMap
-import scala.collection.mutable.ListBuffer
-
-import com.typesafe.config.Config
-import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, DMatrixBuilder, Booster, ObjectiveTrait, EvalTrait}
-import org.apache.spark.SparkContext
-import org.apache.spark.rdd.RDD
-
-object XGBoost {
-
- private var _sc: Option[SparkContext] = None
-
- private def buildSparkContext(config: Config): SparkContext = {
- if (_sc.isEmpty) {
- // TODO:build SparkContext with the user configuration (cores per task, and cores per executor
- // (or total cores)
- // NOTE: currently Spark has limited support of configuration of core number in executors
- }
- _sc.get
- }
-
- def train(config: Config, obj: ObjectiveTrait = null, eval: EvalTrait = null): RDD[Booster] = {
- val sc = buildSparkContext(config)
- val filePath = config.getString("inputPath") // configuration entry name to be fixed
- val numWorkers = config.getInt("numWorkers")
- val round = config.getInt("round")
- // TODO: build configuration map from config
- val xgBoostConfigMap = new HashMap[String, AnyRef]()
- sc.binaryFiles(filePath, numWorkers).mapPartitions {
- trainingFiles =>
- val boosters = new ListBuffer[Booster]
- // we assume one file per DMatrix
- for ((_, fileInStream) <- trainingFiles) {
- // TODO:
- // step1: build DMatrix from fileInStream.toArray (which returns a Array[Byte]) or
- // from a fileInStream.open() (which returns a DataInputStream)
- val dMatrix = DMatrixBuilder.buildDMatrixfromBinaryData(fileInStream.toArray())
- // step2: build a Booster
- // TODO: how to build watches list???
- boosters += SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval)
- }
- // TODO
- boosters.iterator
- }
- }
-
-
-}
diff --git a/rabit b/rabit
index be50e7b63..1392e9f3d 160000
--- a/rabit
+++ b/rabit
@@ -1 +1 @@
-Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
+Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0