framework of xgboost-spark

iterator

return java iterator and recover test
CodingCat 2016-03-04 23:26:45 -05:00
parent 1540773340
commit b2d705ffb0
15 changed files with 194 additions and 156 deletions

@ -1 +1 @@
-Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f
+Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0

View File

@ -20,7 +20,7 @@
   <modules>
     <module>xgboost4j</module>
     <module>xgboost4j-demo</module>
-    <module>xgboost4jspark</module>
+    <module>xgboost4j-spark</module>
   </modules>
   <build>
     <plugins>

View File

@ -17,7 +17,7 @@
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-core_2.10</artifactId>
+      <artifactId>spark-mllib_2.10</artifactId>
       <version>1.6.0</version>
     </dependency>
   </dependencies>

View File

@ -0,0 +1,74 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.util.{Iterator => JIterator}
import scala.collection.mutable.ListBuffer
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.DataBatch
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
private[spark] object DataUtils extends Serializable {
  private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
    (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
  }

  private def fetchUpdateFromVector(feature: Vector): (List[Int], List[Float]) = feature match {
    case denseFeature: DenseVector =>
      fetchUpdateFromSparseVector(denseFeature.toSparse)
    case sparseFeature: SparseVector =>
      fetchUpdateFromSparseVector(sparseFeature)
  }

  def fromLabeledPointsToSparseMatrix(points: Iterator[LabeledPoint]): JIterator[DataBatch] = {
    // TODO: support weight
    var samplePos = 0
    // TODO: change hard value
    val loadingBatchSize = 100
    val rowOffset = new ListBuffer[Long]
    val label = new ListBuffer[Float]
    val featureIndices = new ListBuffer[Int]
    val featureValues = new ListBuffer[Float]
    val dataBatches = new ListBuffer[DataBatch]
    for (point <- points) {
      val (nonZeroIndices, nonZeroValues) = fetchUpdateFromVector(point.features)
      // CSR row pointer: this row's entries start where the previous row's ended
      rowOffset += featureIndices.length.toLong
      label += point.label.toFloat
      featureIndices ++= nonZeroIndices
      featureValues ++= nonZeroValues
      samplePos += 1
      if (samplePos % loadingBatchSize == 0) {
        // close the row pointers and emit a data batch (toArray already copies)
        rowOffset += featureIndices.length.toLong
        dataBatches += new DataBatch(
          rowOffset.toArray, null, label.toArray, featureIndices.toArray,
          featureValues.toArray)
        rowOffset.clear()
        label.clear()
        featureIndices.clear()
        featureValues.clear()
      }
    }
    // emit the trailing partial batch, if any
    if (label.nonEmpty) {
      rowOffset += featureIndices.length.toLong
      dataBatches += new DataBatch(
        rowOffset.toArray, null, label.toArray, featureIndices.toArray,
        featureValues.toArray)
    }
    dataBatches.iterator.asJava
  }
}
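
To make the CSR layout concrete, here is a small illustrative driver. It is not part of the commit: the input vectors and the expected arrays in the comments are made-up examples, and it sits in the same package only because DataUtils is private[spark].

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object CsrBatchSketch {
  def main(args: Array[String]): Unit = {
    // two sparse rows over 4 columns:
    //   row 0 -> entries at columns 1 and 3, row 1 -> entry at column 2
    // expected CSR layout inside the emitted DataBatch:
    //   rowOffset = [0, 2, 3], featureIndex = [1, 3, 2], featureValue = [0.5, 0.7, 0.9]
    val points = Iterator(
      LabeledPoint(1.0, Vectors.sparse(4, Array(1, 3), Array(0.5, 0.7))),
      LabeledPoint(0.0, Vectors.sparse(4, Array(2), Array(0.9))))
    val batches = DataUtils.fromLabeledPointsToSparseMatrix(points)
    var numBatches = 0
    while (batches.hasNext) { batches.next(); numBatches += 1 }
    // both rows fit within one 100-row batch
    println(s"emitted $numBatches DataBatch(es)")
  }
}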

View File

@ -0,0 +1,55 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.immutable.HashMap
import scala.collection.JavaConverters._
import com.typesafe.config.Config
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
object XGBoost {
  private var _sc: Option[SparkContext] = None

  implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
    new XGBoostModel(booster)
  }

  def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
      eval: EvalTrait = null): XGBoostModel = {
    val sc = trainingData.sparkContext
    val dataUtilsBroadcast = sc.broadcast(DataUtils)
    val filePath = config.getString("inputPath") // configuration entry name to be fixed
    val numWorkers = config.getInt("numWorkers")
    val round = config.getInt("round")
    // TODO: build configuration map from config
    val xgBoostConfigMap = new HashMap[String, AnyRef]()
    val boosters = trainingData.repartition(numWorkers).mapPartitions {
      trainingSamples =>
        val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples)
        val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
        Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
    }
    // TODO: how to choose best model
    boosters.first()
  }
}
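
A sketch of how a driver might call the new entry point. The config keys ("inputPath", "numWorkers", "round") are the ones read by train above; the app name, file path, and values are placeholders.

package ml.dmlc.xgboost4j.scala.spark

import com.typesafe.config.ConfigFactory
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.util.MLUtils

object TrainSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("xgboost-spark-sketch"))
    // keys match the entries XGBoost.train reads; values are placeholders
    val config = ConfigFactory.parseString(
      """
        |inputPath = "/path/to/train.libsvm"
        |numWorkers = 2
        |round = 5
      """.stripMargin)
    val trainingData = MLUtils.loadLibSVMFile(sc, config.getString("inputPath"))
    // obj and eval default to null, matching the signature above
    val model: XGBoostModel = XGBoost.train(config, trainingData)
    println(s"trained model: $model")
    sc.stop()
  }
}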

View File

@ -0,0 +1,37 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
class XGBoostModel(booster: Booster) extends Serializable {
  def predict(testSet: RDD[LabeledPoint]): RDD[Array[Array[Float]]] = {
    val broadcastBooster = testSet.sparkContext.broadcast(booster)
    val dataUtils = testSet.sparkContext.broadcast(DataUtils)
    testSet.mapPartitions { testSamples =>
      val dataBatches = dataUtils.value.fromLabeledPointsToSparseMatrix(testSamples)
      val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
      Iterator(broadcastBooster.value.predict(dMatrix))
    }
  }
}
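
A companion sketch for scoring (hypothetical test path; assumes a model trained as in the previous sketch). Each partition is converted through DataUtils and yields one prediction array, which is why the result type is RDD[Array[Array[Float]]].

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils

object PredictSketch {
  def score(sc: SparkContext, model: XGBoostModel): Unit = {
    val testSet = MLUtils.loadLibSVMFile(sc, "/path/to/test.libsvm")
    // one element per partition: the booster's predictions for that partition's rows
    val predictions = model.predict(testSet)
    predictions.collect().foreach { partitionPreds =>
      println(s"partition scored ${partitionPreds.length} rows")
    }
  }
}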

View File

@ -28,7 +28,7 @@ import org.apache.commons.logging.LogFactory;
  */
 public class DMatrix {
   private static final Log logger = LogFactory.getLog(DMatrix.class);
-  private long handle = 0;
+  protected long handle = 0;
   //load native library
   static {

View File

@ -4,8 +4,6 @@ package ml.dmlc.xgboost4j;
  * A mini-batch of data that can be converted to DMatrix.
  * The data is in sparse matrix CSR format.
  *
- * Usually this object is not needed.
- *
  * This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
  */
 public class DataBatch {
@ -19,6 +17,19 @@ public class DataBatch {
   int[] featureIndex = null;
   /** value of each non-missing entry in the sparse matrix */
   float[] featureValue = null;
+  public DataBatch() {}
+  public DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
+                   float[] featureValue) {
+    this.rowOffset = rowOffset;
+    this.weight = weight;
+    this.label = label;
+    this.featureIndex = featureIndex;
+    this.featureValue = featureValue;
+  }
   /**
    * Get number of rows in the data batch.
    * @return Number of rows in the data batch.

View File

@ -491,8 +491,7 @@ class JavaBoosterImpl implements Booster {
   }
   // making Booster serializable
-  private void writeObject(java.io.ObjectOutputStream out)
-      throws IOException {
+  private void writeObject(java.io.ObjectOutputStream out) throws IOException {
     try {
       out.writeObject(this.toByteArray());
     } catch (XGBoostError ex) {

View File

@ -27,7 +27,8 @@ class XgboostJNI {
   public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
-  public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);
+  final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
+      String cache_info, long[] out);
   public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
       long[] out);

View File

@ -16,7 +16,9 @@
 package ml.dmlc.xgboost4j.scala
-import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
+import _root_.scala.collection.JavaConverters._
+import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, DataBatch, XGBoostError}
 class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
@ -43,6 +45,10 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
     this(new JDMatrix(headers, indices, data, st))
   }
+  private[xgboost4j] def this(dataBatch: DataBatch) {
+    this(new JDMatrix(List(dataBatch).asJava.iterator, null))
+  }
   /**
    * create DMatrix from dense matrix
    *
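
Because the new constructor is private[xgboost4j], a caller must live inside that package. The following sketch (made-up CSR contents; assumes the native library is on the load path) shows a single in-memory DataBatch wrapped into a DMatrix through the one-element Java iterator.

package ml.dmlc.xgboost4j.scala

import ml.dmlc.xgboost4j.DataBatch

object SingleBatchSketch {
  def main(args: Array[String]): Unit = {
    // one row with entries at columns 0 and 2; CSR row pointers are [0, 2]
    val batch = new DataBatch(Array(0L, 2L), null, Array(1.0f),
      Array(0, 2), Array(0.5f, 1.5f))
    // delegates to new JDMatrix(List(batch).asJava.iterator, null) as defined above
    val dMatrix = new DMatrix(batch)
    println("built a DMatrix from a single in-memory DataBatch")
  }
}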

View File

@ -1,32 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import java.io.DataInputStream
private[xgboost4j] object DMatrixBuilder extends Serializable {
  def buildDMatrixfromBinaryData(inStream: DataInputStream): DMatrix = {
    // TODO: currently it is random statement for making compiler happy
    new DMatrix(new Array[Float](1), 1, 1)
  }

  def buildDMatrixfromBinaryData(binaryArray: Array[Byte]): DMatrix = {
    // TODO: currently it is random statement for making compiler happy
    new DMatrix(new Array[Float](1), 1, 1)
  }
}

View File

@ -1,47 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.rdd.RDD
class Boosters(boosters: RDD[Booster]) {
  def save(path: String): Unit = {
  }

  def chooseBestBooster(boosters: RDD[Booster]): Booster = {
    // TODO:
    null
  }
}

object Boosters {
  implicit def boosterRDDToBoosters(boosterRDD: RDD[Booster]): Boosters = {
    new Boosters(boosterRDD)
  }

  // load booster from path
  def apply(path: String): RDD[Booster] = {
    // TODO
    null
  }
}

View File

@ -1,66 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.immutable.HashMap
import scala.collection.mutable.ListBuffer
import com.typesafe.config.Config
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, DMatrixBuilder, Booster, ObjectiveTrait, EvalTrait}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
object XGBoost {
  private var _sc: Option[SparkContext] = None

  private def buildSparkContext(config: Config): SparkContext = {
    if (_sc.isEmpty) {
      // TODO: build SparkContext with the user configuration (cores per task, and cores per
      // executor (or total cores))
      // NOTE: currently Spark has limited support of configuration of core number in executors
    }
    _sc.get
  }

  def train(config: Config, obj: ObjectiveTrait = null, eval: EvalTrait = null): RDD[Booster] = {
    val sc = buildSparkContext(config)
    val filePath = config.getString("inputPath") // configuration entry name to be fixed
    val numWorkers = config.getInt("numWorkers")
    val round = config.getInt("round")
    // TODO: build configuration map from config
    val xgBoostConfigMap = new HashMap[String, AnyRef]()
    sc.binaryFiles(filePath, numWorkers).mapPartitions {
      trainingFiles =>
        val boosters = new ListBuffer[Booster]
        // we assume one file per DMatrix
        for ((_, fileInStream) <- trainingFiles) {
          // TODO:
          // step1: build DMatrix from fileInStream.toArray (which returns an Array[Byte]) or
          // from a fileInStream.open() (which returns a DataInputStream)
          val dMatrix = DMatrixBuilder.buildDMatrixfromBinaryData(fileInStream.toArray())
          // step2: build a Booster
          // TODO: how to build watches list???
          boosters += SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval)
        }
        // TODO
        boosters.iterator
    }
  }
}

View File

rabit

@ -1 +1 @@
-Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
+Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0