distributed in RDD

CodingCat 2016-03-05 17:50:40 -05:00
parent fb41e4e673
commit 5c1af13f84
4 changed files with 116 additions and 24 deletions

View File

@@ -20,6 +20,7 @@
  <modules>
    <module>xgboost4j</module>
    <module>xgboost4j-demo</module>
    <module>xgboost4j-spark</module>
    <module>xgboost4j-flink</module>
  </modules>
  <build>
@@ -118,6 +119,19 @@
        <artifactId>maven-surefire-plugin</artifactId>
        <version>2.19.1</version>
      </plugin>
      <plugin>
        <groupId>org.scalatest</groupId>
        <artifactId>scalatest-maven-plugin</artifactId>
        <version>1.0</version>
        <executions>
          <execution>
            <id>test</id>
            <goals>
              <goal>test</goal>
            </goals>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>
  <dependencies>
@@ -150,7 +164,7 @@
    <dependency>
      <groupId>com.typesafe</groupId>
      <artifactId>config</artifactId>
      <version>1.3.0</version>
      <version>1.2.1</version>
    </dependency>
  </dependencies>
</project>
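
The com.typesafe config dependency pinned above backs the Config object that the Spark training entry point in the next file reads its settings from. Below is a minimal sketch, not part of the commit, assuming only the key names that XGBoost.train looks up (inputPath, numWorkers, round); the values are illustrative.

// Hedged sketch: assemble a Typesafe Config carrying the keys read by
// XGBoost.train in the next file. Key names come from that method; the
// values here are examples only.
import com.typesafe.config.{Config, ConfigFactory}

object TrainingConfSketch {
  val conf: Config = ConfigFactory.parseString(
    """
      |inputPath = "/tmp/agaricus.txt.train"
      |numWorkers = 4
      |round = 10
    """.stripMargin)
}

Parsing from a string keeps the sketch self-contained; a real job would more likely call ConfigFactory.load() with an application.conf on the classpath.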

View File

@@ -28,27 +28,35 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
object XGBoost {
  private var _sc: Option[SparkContext] = None

  implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
    new XGBoostModel(booster)
  }

  private[spark] def buildDistributedBoosters(
      trainingData: RDD[LabeledPoint],
      xgBoostConfMap: Map[String, AnyRef],
      numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
    import DataUtils._
    val sc = trainingData.sparkContext
    val dataUtilsBroadcast = sc.broadcast(DataUtils)
    trainingData.repartition(numWorkers).mapPartitions {
      trainingSamples =>
        val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
        Iterator(SXGBoost.train(xgBoostConfMap, dMatrix, round,
          watches = new HashMap[String, DMatrix], obj, eval))
    }.cache()
  }

  def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
      eval: EvalTrait = null): XGBoostModel = {
    import DataUtils._
    val sc = trainingData.sparkContext
    val dataUtilsBroadcast = sc.broadcast(DataUtils)
    val filePath = config.getString("inputPath") // configuration entry name to be fixed
    val numWorkers = config.getInt("numWorkers")
    val round = config.getInt("round")
    val sc = trainingData.sparkContext
    // TODO: build configuration map from config
    val xgBoostConfigMap = new HashMap[String, AnyRef]()
    val boosters = trainingData.repartition(numWorkers).mapPartitions {
      trainingSamples =>
        val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
        Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
    }.cache()
    val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
      obj, eval)
    // force the job
    sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
    // TODO: how to choose best model
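
A usage sketch of the new buildDistributedBoosters path, not taken from the commit: the training RDD is repartitioned into numWorkers partitions, each partition becomes one DMatrix, and one Booster is trained per partition, so the returned RDD holds numWorkers independent boosters. Since the method is private[spark], the sketch sits in the same package; the object name, toy data, and parameter values are assumptions.

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}

// Hedged usage sketch (not part of this commit): train one Booster per
// partition via buildDistributedBoosters and check the count.
object DistributedBoostersSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("distributed-boosters-sketch"))
    // Toy LabeledPoints; a real job would parse LIBSVM or another source.
    val trainingData = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.0)),
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0, 0.5)),
      LabeledPoint(1.0, Vectors.dense(0.5, 0.2, 1.5)),
      LabeledPoint(0.0, Vectors.dense(0.1, 0.9, 0.3))))
    val boosters = XGBoost.buildDistributedBoosters(
      trainingData,
      Map[String, AnyRef]("eta" -> "0.3", "max_depth" -> "2", "objective" -> "binary:logistic"),
      numWorkers = 2, round = 4, obj = null, eval = null)
    // Each repartitioned partition trains its own booster.
    assert(boosters.count() == 2)
    sc.stop()
  }
}

As the diff shows, each partition is trained independently here with no synchronization between workers, which is why the TODO about choosing the best model is still open at this commit.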

View File

@@ -0,0 +1,83 @@
/*
 Copyright (c) 2014 by Contributors
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import java.io.File

import scala.collection.mutable.ListBuffer
import scala.io.Source

import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, FunSuite}

class XGBoostSuite extends FunSuite with BeforeAndAfterAll {

  private var sc: SparkContext = null
  private val numWorker = 4

  override def beforeAll(): Unit = {
    // build SparkContext
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
    sc = new SparkContext(sparkConf)
  }

  override def afterAll(): Unit = {
    if (sc != null) {
      sc.stop()
    }
  }

  private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
    val labelAndFeatures = line.split(" ")
    val label = labelAndFeatures(0).toInt
    val features = labelAndFeatures.tail
    val denseFeature = new Array[Double](129)
    for (feature <- features) {
      val idAndValue = feature.split(":")
      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
    }
    LabeledPoint(label, new DenseVector(denseFeature))
  }

  private def buildRDD(filePath: String): RDD[LabeledPoint] = {
    val file = Source.fromFile(new File(filePath))
    val sampleList = new ListBuffer[LabeledPoint]
    for (sample <- file.getLines()) {
      sampleList += fromSVMStringToLabeledPoint(sample)
    }
    sc.parallelize(sampleList, numWorker)
  }

  private def buildTrainingAndTestRDD(): (RDD[LabeledPoint], RDD[LabeledPoint]) = {
    val trainRDD = buildRDD(getClass.getResource("/agaricus.txt.train").getFile)
    val testRDD = buildRDD(getClass.getResource("/agaricus.txt.test").getFile)
    (trainRDD, testRDD)
  }

  test("build RDD containing boosters") {
    val (trainingRDD, testRDD) = buildTrainingAndTestRDD()
    val boosterRDD = XGBoost.buildDistributedBoosters(
      trainingRDD,
      Map[String, AnyRef](),
      numWorker, 4, null, null)
    val boosterCount = boosterRDD.count()
    assert(boosterCount === numWorker)
  }
}
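
For comparison, a sketch that is not part of the commit: Spark's MLUtils.loadLibSVMFile can parse the same agaricus LIBSVM files directly into an RDD[LabeledPoint] with sparse feature vectors and zero-based indices, instead of the hand-rolled dense parser above; loadTrainingRDD is just an illustrative wrapper name.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

object LibSVMLoaderSketch {
  // Hedged alternative to buildRDD above: MLUtils parses "label index:value ..."
  // lines into LabeledPoint with sparse vectors, shifting indices to zero-based.
  def loadTrainingRDD(sc: SparkContext, path: String, numWorkers: Int): RDD[LabeledPoint] =
    MLUtils.loadLibSVMFile(sc, path).repartition(numWorkers)
}

The dense parser in the test writes LIBSVM's one-based indices straight into a 129-slot array, which works as long as indices stay below 129 but differs from MLUtils' zero-based sparse output.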

View File

@@ -29,19 +29,6 @@
          <skipAssembly>false</skipAssembly>
        </configuration>
      </plugin>
      <plugin>
        <groupId>org.scalatest</groupId>
        <artifactId>scalatest-maven-plugin</artifactId>
        <version>1.0</version>
        <executions>
          <execution>
            <id>test</id>
            <goals>
              <goal>test</goal>
            </goals>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>
  <dependencies>