Merge pull request #922 from CodingCat/label
spark with new labeledpoint
This commit is contained in:
commit
3ddddfce79
@ -20,6 +20,7 @@
|
||||
<modules>
|
||||
<module>xgboost4j</module>
|
||||
<module>xgboost4j-demo</module>
|
||||
<module>xgboost4j-spark</module>
|
||||
<module>xgboost4j-flink</module>
|
||||
</modules>
|
||||
<build>
|
||||
@ -118,6 +119,19 @@
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<version>2.19.1</version>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest-maven-plugin</artifactId>
|
||||
<version>1.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>test</id>
|
||||
<goals>
|
||||
<goal>test</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
<dependencies>
|
||||
@ -150,7 +164,7 @@
|
||||
<dependency>
|
||||
<groupId>com.typesafe</groupId>
|
||||
<artifactId>config</artifactId>
|
||||
<version>1.3.0</version>
|
||||
<version>1.2.1</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@ -16,17 +16,28 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.util.{Iterator => JIterator}
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.DataBatch
|
||||
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
|
||||
private[spark] object DataUtils extends Serializable {
|
||||
|
||||
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
|
||||
java.util.Iterator[LabeledPoint] = {
|
||||
(for (p <- sps) yield {
|
||||
p.features match {
|
||||
case denseFeature: DenseVector =>
|
||||
LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
|
||||
case sparseFeature: SparseVector =>
|
||||
LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
|
||||
sparseFeature.values.map(_.toFloat))
|
||||
}
|
||||
}).asJava
|
||||
}
|
||||
|
||||
private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
|
||||
(sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
|
||||
}
|
||||
@ -37,38 +48,4 @@ private[spark] object DataUtils extends Serializable {
|
||||
case sparseFeature: SparseVector =>
|
||||
fetchUpdateFromSparseVector(sparseFeature)
|
||||
}
|
||||
|
||||
def fromLabeledPointsToSparseMatrix(points: Iterator[LabeledPoint]): JIterator[DataBatch] = {
|
||||
// TODO: support weight
|
||||
var samplePos = 0
|
||||
// TODO: change hard value
|
||||
val loadingBatchSize = 100
|
||||
val rowOffset = new ListBuffer[Long]
|
||||
val label = new ListBuffer[Float]
|
||||
val featureIndices = new ListBuffer[Int]
|
||||
val featureValues = new ListBuffer[Float]
|
||||
val dataBatches = new ListBuffer[DataBatch]
|
||||
for (point <- points) {
|
||||
val (nonZeroIndices, nonZeroValues) = fetchUpdateFromVector(point.features)
|
||||
rowOffset(samplePos) = rowOffset.size
|
||||
label(samplePos) = point.label.toFloat
|
||||
for (i <- nonZeroIndices.indices) {
|
||||
featureIndices += nonZeroIndices(i)
|
||||
featureValues += nonZeroValues(i)
|
||||
}
|
||||
samplePos += 1
|
||||
if (samplePos % loadingBatchSize == 0) {
|
||||
// create a data batch
|
||||
dataBatches += new DataBatch(
|
||||
rowOffset.toArray.clone(),
|
||||
null, label.toArray.clone(), featureIndices.toArray.clone(),
|
||||
featureValues.toArray.clone())
|
||||
rowOffset.clear()
|
||||
label.clear()
|
||||
featureIndices.clear()
|
||||
featureValues.clear()
|
||||
}
|
||||
}
|
||||
dataBatches.iterator.asJava
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,41 +17,64 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.immutable.HashMap
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import com.typesafe.config.Config
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.{TaskContext, SparkContext}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
object XGBoost {
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
|
||||
private var _sc: Option[SparkContext] = None
|
||||
object XGBoost extends Serializable {
|
||||
|
||||
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
|
||||
new XGBoostModel(booster)
|
||||
}
|
||||
|
||||
def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): XGBoostModel = {
|
||||
private[spark] def buildDistributedBoosters(
|
||||
trainingData: RDD[LabeledPoint],
|
||||
xgBoostConfMap: Map[String, AnyRef],
|
||||
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
|
||||
import DataUtils._
|
||||
val sc = trainingData.sparkContext
|
||||
val dataUtilsBroadcast = sc.broadcast(DataUtils)
|
||||
val filePath = config.getString("inputPath") // configuration entry name to be fixed
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
if (tracker.start()) {
|
||||
trainingData.repartition(numWorkers).mapPartitions {
|
||||
trainingSamples =>
|
||||
Rabit.init(new java.util.HashMap[String, String]() {
|
||||
put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||
})
|
||||
val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
|
||||
val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
|
||||
watches = new HashMap[String, DMatrix], obj, eval)
|
||||
Rabit.shutdown()
|
||||
Iterator(booster)
|
||||
}.cache()
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): Option[XGBoostModel] = {
|
||||
import DataUtils._
|
||||
val numWorkers = config.getInt("numWorkers")
|
||||
val round = config.getInt("round")
|
||||
// TODO: build configuration map from config
|
||||
val xgBoostConfigMap = new HashMap[String, AnyRef]()
|
||||
val boosters = trainingData.repartition(numWorkers).mapPartitions {
|
||||
trainingSamples =>
|
||||
val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples)
|
||||
val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
|
||||
Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
|
||||
}.cache()
|
||||
// force the job
|
||||
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
|
||||
// TODO: how to choose best model
|
||||
boosters.first()
|
||||
val sc = trainingData.sparkContext
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
if (tracker.start()) {
|
||||
// TODO: build configuration map from config
|
||||
val xgBoostConfigMap = new HashMap[String, AnyRef]()
|
||||
val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
|
||||
obj, eval)
|
||||
// force the job
|
||||
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
|
||||
tracker.waitFor()
|
||||
// TODO: how to choose best model
|
||||
Some(boosters.first())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,22 +16,25 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
class XGBoostModel(booster: Booster) extends Serializable {
|
||||
|
||||
def predict(testSet: RDD[LabeledPoint]): RDD[Array[Array[Float]]] = {
|
||||
def predict(testSet: RDD[SparkLabeledPoint]): RDD[Array[Array[Float]]] = {
|
||||
import DataUtils._
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(booster)
|
||||
val dataUtils = testSet.sparkContext.broadcast(DataUtils)
|
||||
testSet.mapPartitions { testSamples =>
|
||||
val dataBatches = dataUtils.value.fromLabeledPointsToSparseMatrix(testSamples)
|
||||
val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
|
||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
||||
}
|
||||
}
|
||||
|
||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(testSet)
|
||||
}
|
||||
}
|
||||
|
||||
1611
jvm-packages/xgboost4j-spark/src/test/resources/agaricus.txt.test
Normal file
1611
jvm-packages/xgboost4j-spark/src/test/resources/agaricus.txt.test
Normal file
File diff suppressed because it is too large
Load Diff
6513
jvm-packages/xgboost4j-spark/src/test/resources/agaricus.txt.train
Normal file
6513
jvm-packages/xgboost4j-spark/src/test/resources/agaricus.txt.train
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,142 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.io.Source
|
||||
import scala.tools.reflect.Eval
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.spark.mllib.linalg.DenseVector
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
|
||||
class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
|
||||
|
||||
private var sc: SparkContext = null
|
||||
private val numWorker = 4
|
||||
|
||||
private class EvalError extends EvalTrait {
|
||||
|
||||
val logger = LogFactory.getLog(classOf[EvalError])
|
||||
|
||||
private[xgboost4j] var evalMetric: String = "custom_error"
|
||||
|
||||
/**
|
||||
* get evaluate metric
|
||||
*
|
||||
* @return evalMetric
|
||||
*/
|
||||
override def getMetric: String = evalMetric
|
||||
|
||||
/**
|
||||
* evaluate with predicts and data
|
||||
*
|
||||
* @param predicts predictions as array
|
||||
* @param dmat data matrix to evaluate
|
||||
* @return result of the metric
|
||||
*/
|
||||
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
|
||||
var error: Float = 0f
|
||||
var labels: Array[Float] = null
|
||||
try {
|
||||
labels = dmat.getLabel
|
||||
} catch {
|
||||
case ex: XGBoostError =>
|
||||
logger.error(ex)
|
||||
return -1f
|
||||
}
|
||||
val nrow: Int = predicts.length
|
||||
for (i <- 0 until nrow) {
|
||||
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
|
||||
error += 1
|
||||
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
|
||||
error += 1
|
||||
}
|
||||
}
|
||||
error / labels.length
|
||||
}
|
||||
}
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
// build SparkContext
|
||||
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
|
||||
sc = new SparkContext(sparkConf)
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
if (sc != null) {
|
||||
sc.stop()
|
||||
}
|
||||
}
|
||||
|
||||
private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
|
||||
val labelAndFeatures = line.split(" ")
|
||||
val label = labelAndFeatures(0).toInt
|
||||
val features = labelAndFeatures.tail
|
||||
val denseFeature = new Array[Double](129)
|
||||
for (feature <- features) {
|
||||
val idAndValue = feature.split(":")
|
||||
denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
|
||||
}
|
||||
LabeledPoint(label, new DenseVector(denseFeature))
|
||||
}
|
||||
|
||||
private def readFile(filePath: String): List[LabeledPoint] = {
|
||||
val file = Source.fromFile(new File(filePath))
|
||||
val sampleList = new ListBuffer[LabeledPoint]
|
||||
for (sample <- file.getLines()) {
|
||||
sampleList += fromSVMStringToLabeledPoint(sample)
|
||||
}
|
||||
sampleList.toList
|
||||
}
|
||||
|
||||
private def buildRDD(filePath: String): RDD[LabeledPoint] = {
|
||||
val sampleList = readFile(filePath)
|
||||
sc.parallelize(sampleList, numWorker)
|
||||
}
|
||||
|
||||
private def buildTrainingRDD(): RDD[LabeledPoint] = {
|
||||
val trainRDD = buildRDD(getClass.getResource("/agaricus.txt.train").getFile)
|
||||
trainRDD
|
||||
}
|
||||
|
||||
test("build RDD containing boosters") {
|
||||
val trainingRDD = buildTrainingRDD()
|
||||
val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
import DataUtils._
|
||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||
val boosterRDD = XGBoost.buildDistributedBoosters(
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic").toMap,
|
||||
numWorker, 2, null, null)
|
||||
val boosterCount = boosterRDD.count()
|
||||
assert(boosterCount === numWorker)
|
||||
val boosters = boosterRDD.collect()
|
||||
for (booster <- boosters) {
|
||||
val predicts = booster.predict(testSetDMatrix, true)
|
||||
assert(new EvalError().eval(predicts, testSetDMatrix) < 0.1)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -29,19 +29,6 @@
|
||||
<skipAssembly>false</skipAssembly>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest-maven-plugin</artifactId>
|
||||
<version>1.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>test</id>
|
||||
<goals>
|
||||
<goal>test</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
<dependencies>
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user