diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 7bfb35a7f..43f602df6 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -20,6 +20,7 @@
xgboost4j
xgboost4j-demo
+ xgboost4j-spark
xgboost4j-flink
@@ -118,6 +119,19 @@
maven-surefire-plugin
2.19.1
+
+ org.scalatest
+ scalatest-maven-plugin
+ 1.0
+
+
+ test
+
+ test
+
+
+
+
@@ -150,7 +164,7 @@
com.typesafe
config
- 1.3.0
+ 1.2.1
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 49417e2c6..99da3dee2 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -28,27 +28,35 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
object XGBoost {
- private var _sc: Option[SparkContext] = None
-
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
new XGBoostModel(booster)
}
+ private[spark] def buildDistributedBoosters(
+ trainingData: RDD[LabeledPoint],
+ xgBoostConfMap: Map[String, AnyRef],
+ numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
+ import DataUtils._
+ val sc = trainingData.sparkContext
+ val dataUtilsBroadcast = sc.broadcast(DataUtils)
+ trainingData.repartition(numWorkers).mapPartitions {
+ trainingSamples =>
+ val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
+ Iterator(SXGBoost.train(xgBoostConfMap, dMatrix, round,
+ watches = new HashMap[String, DMatrix], obj, eval))
+ }.cache()
+ }
+
def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
eval: EvalTrait = null): XGBoostModel = {
import DataUtils._
- val sc = trainingData.sparkContext
- val dataUtilsBroadcast = sc.broadcast(DataUtils)
- val filePath = config.getString("inputPath") // configuration entry name to be fixed
val numWorkers = config.getInt("numWorkers")
val round = config.getInt("round")
+ val sc = trainingData.sparkContext
// TODO: build configuration map from config
val xgBoostConfigMap = new HashMap[String, AnyRef]()
- val boosters = trainingData.repartition(numWorkers).mapPartitions {
- trainingSamples =>
- val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
- Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
- }.cache()
+ val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
+ obj, eval)
// force the job
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
// TODO: how to choose best model
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
new file mode 100644
index 000000000..98946ee63
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
@@ -0,0 +1,83 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.File
+
+import scala.collection.mutable.ListBuffer
+import scala.io.Source
+
+import org.apache.spark.mllib.linalg.DenseVector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
+
+ private var sc: SparkContext = null
+ private val numWorker = 4
+
+ override def beforeAll(): Unit = {
+ // build SparkContext
+ val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
+ sc = new SparkContext(sparkConf)
+ }
+
+ override def afterAll(): Unit = {
+ if (sc != null) {
+ sc.stop()
+ }
+ }
+
+ private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
+ val labelAndFeatures = line.split(" ")
+ val label = labelAndFeatures(0).toInt
+ val features = labelAndFeatures.tail
+ val denseFeature = new Array[Double](129)
+ for (feature <- features) {
+ val idAndValue = feature.split(":")
+ denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
+ }
+ LabeledPoint(label, new DenseVector(denseFeature))
+ }
+
+ private def buildRDD(filePath: String): RDD[LabeledPoint] = {
+ val file = Source.fromFile(new File(filePath))
+ val sampleList = new ListBuffer[LabeledPoint]
+ for (sample <- file.getLines()) {
+ sampleList += fromSVMStringToLabeledPoint(sample)
+ }
+ sc.parallelize(sampleList, numWorker)
+ }
+
+ private def buildTrainingAndTestRDD(): (RDD[LabeledPoint], RDD[LabeledPoint]) = {
+ val trainRDD = buildRDD(getClass.getResource("/agaricus.txt.train").getFile)
+ val testRDD = buildRDD(getClass.getResource("/agaricus.txt.test").getFile)
+ (trainRDD, testRDD)
+ }
+
+ test("build RDD containing boosters") {
+ val (trainingRDD, testRDD) = buildTrainingAndTestRDD()
+ val boosterRDD = XGBoost.buildDistributedBoosters(
+ trainingRDD,
+ Map[String, AnyRef](),
+ numWorker, 4, null, null)
+ val boosterCount = boosterRDD.count()
+ assert(boosterCount === numWorker)
+ }
+}
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
index 4ab4414a1..6e1d733a4 100644
--- a/jvm-packages/xgboost4j/pom.xml
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -29,19 +29,6 @@
false
-
- org.scalatest
- scalatest-maven-plugin
- 1.0
-
-
- test
-
- test
-
-
-
-