diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 99da3dee2..8b0d0a71e 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -19,14 +19,14 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import scala.collection.immutable.HashMap
 
 import com.typesafe.config.Config
-import org.apache.spark.SparkContext
+import org.apache.spark.{TaskContext, SparkContext}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 
-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker}
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 
-object XGBoost {
+object XGBoost extends Serializable {
   implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
     new XGBoostModel(booster)
@@ -38,28 +38,43 @@ object XGBoost {
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
     import DataUtils._
     val sc = trainingData.sparkContext
-    val dataUtilsBroadcast = sc.broadcast(DataUtils)
-    trainingData.repartition(numWorkers).mapPartitions {
-      trainingSamples =>
-        val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
-        Iterator(SXGBoost.train(xgBoostConfMap, dMatrix, round,
-          watches = new HashMap[String, DMatrix], obj, eval))
-    }.cache()
+    val tracker = new RabitTracker(numWorkers)
+    if (tracker.start()) {
+      trainingData.repartition(numWorkers).mapPartitions {
+        trainingSamples =>
+          Rabit.init(new java.util.HashMap[String, String]() {
+            put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
+          })
+          val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
+          val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
+            watches = new HashMap[String, DMatrix], obj, eval)
+          Rabit.shutdown()
+          Iterator(booster)
+      }.cache()
+    } else {
+      null
+    }
   }
 
   def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
-      eval: EvalTrait = null): XGBoostModel = {
+      eval: EvalTrait = null): Option[XGBoostModel] = {
     import DataUtils._
     val numWorkers = config.getInt("numWorkers")
     val round = config.getInt("round")
     val sc = trainingData.sparkContext
-    // TODO: build configuration map from config
-    val xgBoostConfigMap = new HashMap[String, AnyRef]()
-    val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
-      obj, eval)
-    // force the job
-    sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
-    // TODO: how to choose best model
-    boosters.first()
+    val tracker = new RabitTracker(numWorkers)
+    if (tracker.start()) {
+      // TODO: build configuration map from config
+      val xgBoostConfigMap = new HashMap[String, AnyRef]()
+      val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
+        obj, eval)
+      // force the job
+      sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
+      tracker.waitFor()
+      // TODO: how to choose best model
+      Some(boosters.first())
+    } else {
+      None
+    }
   }
 }
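Note on the reworked control flow above: train() now returns Option[XGBoostModel], with None signaling that the RabitTracker could not be started, and buildDistributedBoosters runs each partition's training between Rabit.init and Rabit.shutdown so the per-partition boosters can synchronize through the tracker. A minimal caller-side sketch of how the Option-returning API is meant to be consumed; the ConfigFactory wiring and trainingRDD are illustrative assumptions, not part of this diff:

import com.typesafe.config.ConfigFactory
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import ml.dmlc.xgboost4j.scala.spark.XGBoost

// numWorkers drives both the repartitioning and the number of Rabit
// workers the tracker waits for; round is the number of boosting rounds.
val config = ConfigFactory.parseString("numWorkers = 4\nround = 2")

def fit(trainingRDD: RDD[LabeledPoint]) =
  XGBoost.train(config, trainingRDD) match {
    case Some(model) => model // tracker started and all workers finished
    case None => sys.error("Rabit tracker failed to start")
  }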
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index d09e43969..20ad68cfe 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -33,4 +33,8 @@ class XGBoostModel(booster: Booster) extends Serializable {
       Iterator(broadcastBooster.value.predict(dMatrix))
     }
   }
+
+  def predict(testSet: DMatrix): Array[Array[Float]] = {
+    booster.predict(testSet)
+  }
 }
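The predict(testSet: DMatrix) overload added above evaluates locally on the driver, complementing the existing RDD-based predict that broadcasts the booster to executors. A short usage sketch under stated assumptions: model is an already-trained XGBoostModel and agaricus.txt.test is available as a local libsvm file (the path-based DMatrix constructor loads it eagerly):

import ml.dmlc.xgboost4j.scala.DMatrix

// Build a local DMatrix on the driver and predict without a Spark job.
val testMatrix = new DMatrix("agaricus.txt.test")
val predictions: Array[Array[Float]] = model.predict(testMatrix)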
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
index 98946ee63..23c9924d1 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
@@ -20,7 +20,11 @@ import java.io.File
 
 import scala.collection.mutable.ListBuffer
 import scala.io.Source
+import scala.tools.reflect.Eval
 
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
+import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
+import org.apache.commons.logging.LogFactory
 import org.apache.spark.mllib.linalg.DenseVector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
@@ -32,6 +36,48 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
   private var sc: SparkContext = null
   private val numWorker = 4
 
+  private class EvalError extends EvalTrait {
+
+    val logger = LogFactory.getLog(classOf[EvalError])
+
+    private[xgboost4j] var evalMetric: String = "custom_error"
+
+    /**
+     * get evaluate metric
+     *
+     * @return evalMetric
+     */
+    override def getMetric: String = evalMetric
+
+    /**
+     * evaluate with predicts and data
+     *
+     * @param predicts predictions as array
+     * @param dmat data matrix to evaluate
+     * @return result of the metric
+     */
+    override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
+      var error: Float = 0f
+      var labels: Array[Float] = null
+      try {
+        labels = dmat.getLabel
+      } catch {
+        case ex: XGBoostError =>
+          logger.error(ex)
+          return -1f
+      }
+      val nrow: Int = predicts.length
+      for (i <- 0 until nrow) {
+        if (labels(i) == 0.0 && predicts(i)(0) > 0) {
+          error += 1
+        } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
+          error += 1
+        }
+      }
+      error / labels.length
+    }
+  }
+
   override def beforeAll(): Unit = {
     // build SparkContext
     val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
@@ -56,28 +102,41 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
     LabeledPoint(label, new DenseVector(denseFeature))
   }
 
-  private def buildRDD(filePath: String): RDD[LabeledPoint] = {
+  private def readFile(filePath: String): List[LabeledPoint] = {
     val file = Source.fromFile(new File(filePath))
     val sampleList = new ListBuffer[LabeledPoint]
     for (sample <- file.getLines()) {
       sampleList += fromSVMStringToLabeledPoint(sample)
     }
+    sampleList.toList
+  }
+
+  private def buildRDD(filePath: String): RDD[LabeledPoint] = {
+    val sampleList = readFile(filePath)
     sc.parallelize(sampleList, numWorker)
   }
 
-  private def buildTrainingAndTestRDD(): (RDD[LabeledPoint], RDD[LabeledPoint]) = {
+  private def buildTrainingRDD(): RDD[LabeledPoint] = {
     val trainRDD = buildRDD(getClass.getResource("/agaricus.txt.train").getFile)
-    val testRDD = buildRDD(getClass.getResource("/agaricus.txt.test").getFile)
-    (trainRDD, testRDD)
+    trainRDD
   }
 
   test("build RDD containing boosters") {
-    val (trainingRDD, testRDD) = buildTrainingAndTestRDD()
+    val trainingRDD = buildTrainingRDD()
+    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+    import DataUtils._
+    val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val boosterRDD = XGBoost.buildDistributedBoosters(
       trainingRDD,
-      Map[String, AnyRef](),
-      numWorker, 4, null, null)
+      List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+        "objective" -> "binary:logistic").toMap,
+      numWorker, 2, null, null)
     val boosterCount = boosterRDD.count()
     assert(boosterCount === numWorker)
+
+    val boosters = boosterRDD.collect()
+    for (booster <- boosters) {
+      val predicts = booster.predict(testSetDMatrix, true)
+      assert(new EvalError().eval(predicts, testSetDMatrix) < 0.1)
+    }
   }
 }
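In the test above, booster.predict(testSetDMatrix, true) requests raw margins (the second argument is the output-margin flag), which is why EvalError thresholds predictions at 0 rather than 0.5. A small self-contained check of the metric's counting rule, using made-up margins rather than the agaricus data:

// Three rows of (margin, label); only the third row (label 0, positive
// margin) is misclassified under the zero threshold, so error = 1/3.
val margins = Array(Array(2.3f), Array(-1.2f), Array(0.4f))
val labels = Array(1.0f, 0.0f, 0.0f)
val errors = margins.zip(labels).count {
  case (m, l) => (l == 0.0f && m(0) > 0) || (l == 1.0f && m(0) <= 0)
}
assert(errors / labels.length.toFloat == 1f / 3)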