From 5c1af13f84db1cf9d4b0b318ae6a687b19b6d22a Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sat, 5 Mar 2016 17:50:40 -0500 Subject: [PATCH] distributed in RDD --- jvm-packages/pom.xml | 16 +++- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 28 ++++--- .../xgboost4j/scala/spark/XGBoostSuite.scala | 83 +++++++++++++++++++ jvm-packages/xgboost4j/pom.xml | 13 --- 4 files changed, 116 insertions(+), 24 deletions(-) create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 7bfb35a7f..43f602df6 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -20,6 +20,7 @@ xgboost4j xgboost4j-demo + xgboost4j-spark xgboost4j-flink @@ -118,6 +119,19 @@ maven-surefire-plugin 2.19.1 + + org.scalatest + scalatest-maven-plugin + 1.0 + + + test + + test + + + + @@ -150,7 +164,7 @@ com.typesafe config - 1.3.0 + 1.2.1 diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 49417e2c6..99da3dee2 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -28,27 +28,35 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} object XGBoost { - private var _sc: Option[SparkContext] = None - implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = { new XGBoostModel(booster) } + private[spark] def buildDistributedBoosters( + trainingData: RDD[LabeledPoint], + xgBoostConfMap: Map[String, AnyRef], + numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = { + import DataUtils._ + val sc = trainingData.sparkContext + val dataUtilsBroadcast = sc.broadcast(DataUtils) + trainingData.repartition(numWorkers).mapPartitions { + trainingSamples => + val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) + Iterator(SXGBoost.train(xgBoostConfMap, dMatrix, round, + watches = new HashMap[String, DMatrix], obj, eval)) + }.cache() + } + def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = { import DataUtils._ - val sc = trainingData.sparkContext - val dataUtilsBroadcast = sc.broadcast(DataUtils) - val filePath = config.getString("inputPath") // configuration entry name to be fixed val numWorkers = config.getInt("numWorkers") val round = config.getInt("round") + val sc = trainingData.sparkContext // TODO: build configuration map from config val xgBoostConfigMap = new HashMap[String, AnyRef]() - val boosters = trainingData.repartition(numWorkers).mapPartitions { - trainingSamples => - val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) - Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval)) - }.cache() + val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round, + obj, eval) // force the job sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters) // TODO: how to choose best model diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala new file mode 100644 index 000000000..98946ee63 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala @@ -0,0 +1,83 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import java.io.File + +import scala.collection.mutable.ListBuffer +import scala.io.Source + +import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.{SparkConf, SparkContext} +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class XGBoostSuite extends FunSuite with BeforeAndAfterAll { + + private var sc: SparkContext = null + private val numWorker = 4 + + override def beforeAll(): Unit = { + // build SparkContext + val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite") + sc = new SparkContext(sparkConf) + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + + private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = { + val labelAndFeatures = line.split(" ") + val label = labelAndFeatures(0).toInt + val features = labelAndFeatures.tail + val denseFeature = new Array[Double](129) + for (feature <- features) { + val idAndValue = feature.split(":") + denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble + } + LabeledPoint(label, new DenseVector(denseFeature)) + } + + private def buildRDD(filePath: String): RDD[LabeledPoint] = { + val file = Source.fromFile(new File(filePath)) + val sampleList = new ListBuffer[LabeledPoint] + for (sample <- file.getLines()) { + sampleList += fromSVMStringToLabeledPoint(sample) + } + sc.parallelize(sampleList, numWorker) + } + + private def buildTrainingAndTestRDD(): (RDD[LabeledPoint], RDD[LabeledPoint]) = { + val trainRDD = buildRDD(getClass.getResource("/agaricus.txt.train").getFile) + val testRDD = buildRDD(getClass.getResource("/agaricus.txt.test").getFile) + (trainRDD, testRDD) + } + + test("build RDD containing boosters") { + val (trainingRDD, testRDD) = buildTrainingAndTestRDD() + val boosterRDD = XGBoost.buildDistributedBoosters( + trainingRDD, + Map[String, AnyRef](), + numWorker, 4, null, null) + val boosterCount = boosterRDD.count() + assert(boosterCount === numWorker) + } +} diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml index 4ab4414a1..6e1d733a4 100644 --- a/jvm-packages/xgboost4j/pom.xml +++ b/jvm-packages/xgboost4j/pom.xml @@ -29,19 +29,6 @@ false - - org.scalatest - scalatest-maven-plugin - 1.0 - - - test - - test - - - -