Framework for xgboost-spark

iterator

Return a Java iterator and re-enable the test
This commit is contained in:
CodingCat
2016-03-04 23:26:45 -05:00
parent 1540773340
commit b2d705ffb0
15 changed files with 194 additions and 156 deletions

View File

@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven module descriptor for the Spark integration layer of XGBoost4J. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- Inherits version/plugin management from the xgboostjvm aggregator POM. -->
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
</parent>
<artifactId>xgboost4jspark</artifactId>
<dependencies>
<!-- Core JVM bindings; this module wraps them for Spark RDDs. -->
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.1</version>
</dependency>
<!-- Spark MLlib for LabeledPoint / Vector types (Scala 2.10 build). -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.10</artifactId>
<version>1.6.0</version>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,74 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.util.{Iterator => JIterator}
import scala.collection.mutable.ListBuffer
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.DataBatch
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
private[spark] object DataUtils extends Serializable {

  /** Extracts the non-zero (indices, values) of a sparse vector, values narrowed to Float. */
  private def fetchUpdateFromSparseVector(
      sparseFeature: SparseVector): (List[Int], List[Float]) = {
    (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
  }

  /** Normalizes any MLlib vector to (indices, values); dense vectors are sparsified first. */
  private def fetchUpdateFromVector(feature: Vector): (List[Int], List[Float]) = feature match {
    case denseFeature: DenseVector =>
      fetchUpdateFromSparseVector(denseFeature.toSparse)
    case sparseFeature: SparseVector =>
      fetchUpdateFromSparseVector(sparseFeature)
  }

  /**
   * Converts labeled points into CSR-style [[DataBatch]]es of at most `batchSize` rows each
   * and returns them as a Java iterator (the shape the native DMatrix constructor consumes).
   *
   * Fixes over the previous version:
   *  - `rowOffset(samplePos) = …` / `label(samplePos) = …` did an indexed update on a
   *    position that does not yet exist in an empty ListBuffer (IndexOutOfBoundsException);
   *    elements are now appended with `+=`.
   *  - the row offset is now the cumulative non-zero count within the batch (CSR
   *    semantics) instead of the buffer's own size, which was meaningless.
   *  - a trailing partial batch is emitted instead of being silently dropped.
   *
   * NOTE(review): the exact DataBatch constructor contract (rowOffset meaning, the `null`
   * second argument) is not visible from here — confirm against ml.dmlc.xgboost4j.DataBatch.
   *
   * @param points    the labeled points of one partition
   * @param batchSize maximum rows per batch (previously a hard-coded 100)
   */
  def fromLabeledPointsToSparseMatrix(
      points: Iterator[LabeledPoint],
      batchSize: Int = 100): JIterator[DataBatch] = {
    // TODO: support instance weights
    val rowOffset = new ListBuffer[Long]
    val label = new ListBuffer[Float]
    val featureIndices = new ListBuffer[Int]
    val featureValues = new ListBuffer[Float]
    val dataBatches = new ListBuffer[DataBatch]
    var rowsInBatch = 0

    // Packages the accumulated rows (if any) into a DataBatch and resets the buffers.
    def flushBatch(): Unit = {
      if (rowsInBatch > 0) {
        // toArray already copies out of the buffers, so no extra clone() is needed.
        dataBatches += new DataBatch(
          rowOffset.toArray, null, label.toArray,
          featureIndices.toArray, featureValues.toArray)
        rowOffset.clear()
        label.clear()
        featureIndices.clear()
        featureValues.clear()
        rowsInBatch = 0
      }
    }

    for (point <- points) {
      val (nonZeroIndices, nonZeroValues) = fetchUpdateFromVector(point.features)
      // CSR offset of this row's first non-zero entry within the current batch.
      rowOffset += featureIndices.size.toLong
      label += point.label.toFloat
      featureIndices ++= nonZeroIndices
      featureValues ++= nonZeroValues
      rowsInBatch += 1
      if (rowsInBatch == batchSize) {
        flushBatch()
      }
    }
    flushBatch() // emit the trailing partial batch instead of dropping it
    dataBatches.iterator.asJava
  }
}

View File

@@ -0,0 +1,55 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.immutable.HashMap
import scala.collection.JavaConverters._
import com.typesafe.config.Config
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
object XGBoost {

  /** Implicitly wraps a low-level [[Booster]] into the Spark-facing model type. */
  implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
    new XGBoostModel(booster)
  }

  /**
   * Trains an XGBoost model over an RDD of labeled points.
   *
   * Each of `numWorkers` partitions is packed into a DMatrix and trained locally for
   * `round` iterations; the first partition's booster is returned.
   *
   * Changes over the previous version:
   *  - dropped the unused `filePath = config.getString("inputPath")` read (it was never
   *    used and its TODO said the entry name was wrong anyway).
   *  - dropped the unused private `_sc` field.
   *  - `DataUtils` is a serializable singleton object, so referencing it directly in the
   *    task closure suffices — broadcasting it added no value.
   *
   * @param config       must provide integer entries "numWorkers" and "round"
   * @param trainingData the training set
   * @param obj          custom objective, or null for the built-in one
   * @param eval         custom evaluation metric, or null
   */
  def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
      eval: EvalTrait = null): XGBoostModel = {
    val numWorkers = config.getInt("numWorkers")
    val round = config.getInt("round")
    // TODO: build the full booster parameter map from `config`
    val xgBoostConfigMap = new HashMap[String, AnyRef]()
    val boosters = trainingData.repartition(numWorkers).mapPartitions {
      trainingSamples =>
        val dataBatches = DataUtils.fromLabeledPointsToSparseMatrix(trainingSamples)
        val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
        Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
    }
    // TODO: select the best model across partitions instead of taking the first
    boosters.first()
  }
}

View File

@@ -0,0 +1,37 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
/**
 * A trained XGBoost model usable for distributed prediction over Spark RDDs.
 *
 * @param booster the trained low-level booster this model wraps
 */
class XGBoostModel(booster: Booster) extends Serializable {

  /**
   * Scores every partition of `testSet`, returning one prediction matrix per partition.
   *
   * The booster is broadcast once so it is not re-serialized with every task closure.
   * The previous version also broadcast the `DataUtils` singleton object, which is
   * unnecessary — a serializable object referenced from the closure ships on its own.
   */
  def predict(testSet: RDD[LabeledPoint]): RDD[Array[Array[Float]]] = {
    val broadcastBooster = testSet.sparkContext.broadcast(booster)
    testSet.mapPartitions { testSamples =>
      val dataBatches = DataUtils.fromLabeledPointsToSparseMatrix(testSamples)
      val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
      Iterator(broadcastBooster.value.predict(dMatrix))
    }
  }
}