diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala
index 9c517da94..851cffea9 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala
@@ -49,7 +49,7 @@ object SparkWithRDD {
       "eta" -> 0.1f,
       "max_depth" -> 2,
       "objective" -> "binary:logistic").toMap
-    val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
+    val xgboostModel = XGBoost.trainWithRDD(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
       useExternalMemory = true)
     xgboostModel.booster.predict(new DMatrix(testSet))
     // save model to HDFS path
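Note: XGBoost.train is deprecated by this change in favor of XGBoost.trainWithRDD, which also exposes the new baseMargin parameter (see the XGBoost.scala hunks below). A minimal call-site sketch, reusing sc, trainRDD, paramMap, and numRound from the example above; the margins RDD here is hypothetical:

    import org.apache.spark.rdd.RDD
    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    // one margin value per training row; baseMargin defaults to null (no initial prediction)
    val margins: RDD[Float] = trainRDD.map(_ => 0.5f)
    val model = XGBoost.trainWithRDD(trainRDD, paramMap, numRound,
      nWorkers = 2, useExternalMemory = true, baseMargin = margins)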
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 05cbee80d..a4eb4d81f 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -17,7 +17,6 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import scala.collection.mutable
-import scala.collection.mutable.ListBuffer
 
 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
@@ -30,7 +29,6 @@ import org.apache.spark.ml.linalg.SparseVector
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.{SparkContext, TaskContext}
-import scala.concurrent.duration.{Duration, FiniteDuration, MILLISECONDS}
 
 object TrackerConf {
   def apply(): TrackerConf = TrackerConf(0L, "python")
@@ -53,97 +51,86 @@ case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
 
-  private def convertBoosterToXGBoostModel(booster: Booster, isClassification: Boolean):
-      XGBoostModel = {
-    if (!isClassification) {
-      new XGBoostRegressionModel(booster)
-    } else {
-      new XGBoostClassificationModel(booster)
-    }
-  }
-
   private def fromDenseToSparseLabeledPoints(
       denseLabeledPoints: Iterator[MLLabeledPoint],
       missing: Float): Iterator[MLLabeledPoint] = {
     if (!missing.isNaN) {
-      val sparseLabeledPoints = new ListBuffer[MLLabeledPoint]
-      for (labelPoint <- denseLabeledPoints) {
-        val dVector = labelPoint.features.toDense
-        val indices = new ListBuffer[Int]
-        val values = new ListBuffer[Double]
-        for (i <- dVector.values.indices) {
-          if (dVector.values(i) != missing) {
+      denseLabeledPoints.map { case MLLabeledPoint(label, features) =>
+        val dFeatures = features.toDense
+        val indices = new mutable.ArrayBuilder.ofInt()
+        val values = new mutable.ArrayBuilder.ofDouble()
+        for (i <- dFeatures.values.indices) {
+          if (dFeatures.values(i) != missing) {
             indices += i
-            values += dVector.values(i)
+            values += dFeatures.values(i)
           }
         }
-        val sparseVector = new SparseVector(dVector.values.length, indices.toArray,
-          values.toArray)
-        sparseLabeledPoints += MLLabeledPoint(labelPoint.label, sparseVector)
+        val sFeatures = new SparseVector(dFeatures.values.length, indices.result(),
+          values.result())
+        MLLabeledPoint(label, sFeatures)
       }
-      sparseLabeledPoints.iterator
     } else {
       denseLabeledPoints
     }
   }
 
-  private def repartitionData(trainingData: RDD[MLLabeledPoint], numWorkers: Int):
-      RDD[MLLabeledPoint] = {
-    if (numWorkers != trainingData.partitions.length) {
-      logger.info(s"repartitioning training set to $numWorkers partitions")
-      trainingData.repartition(numWorkers)
-    } else {
-      trainingData
-    }
-  }
-
   private[spark] def buildDistributedBoosters(
       trainingSet: RDD[MLLabeledPoint],
-      xgBoostConfMap: Map[String, Any],
+      params: Map[String, Any],
       rabitEnv: java.util.Map[String, String],
-      numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
-      useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
+      numWorkers: Int,
+      round: Int,
+      obj: ObjectiveTrait,
+      eval: EvalTrait,
+      useExternalMemory: Boolean,
+      missing: Float,
+      baseMargin: RDD[Float]): RDD[Booster] = {
     import DataUtils._
-    val partitionedTrainingSet = repartitionData(trainingSet, numWorkers)
+
+    val partitionedTrainingSet = if (trainingSet.getNumPartitions != numWorkers) {
+      logger.info(s"repartitioning training set to $numWorkers partitions")
+      trainingSet.repartition(numWorkers)
+    } else {
+      trainingSet
+    }
+
+    val partitionedBaseMargin = Option(baseMargin)
+      .getOrElse(trainingSet.sparkContext.emptyRDD)
+      .repartition(partitionedTrainingSet.getNumPartitions)
     val appName = partitionedTrainingSet.context.appName
     // to workaround the empty partitions in training dataset,
     // this might not be the best efficient implementation, see
     // (https://github.com/dmlc/xgboost/issues/1277)
-    partitionedTrainingSet.mapPartitions {
-      trainingSamples =>
-        rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
-        Rabit.init(rabitEnv)
-        var booster: Booster = null
-        if (trainingSamples.hasNext) {
-          val cacheFileName: String = {
-            if (useExternalMemory) {
-              s"$appName-${TaskContext.get().stageId()}-" +
-                s"dtrain_cache-${TaskContext.getPartitionId()}"
-            } else {
-              null
-            }
-          }
-          val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
-          val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
-          try {
-            if (xgBoostConfMap.contains("groupData") && xgBoostConfMap("groupData") != null) {
-              trainingSet.setGroup(xgBoostConfMap("groupData").asInstanceOf[Seq[Seq[Int]]](
-                TaskContext.getPartitionId()).toArray)
-            }
-            booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
-              watches = new mutable.HashMap[String, DMatrix] {
-                put("train", trainingSet)
-              }.toMap, obj, eval)
-            Rabit.shutdown()
-          } finally {
-            trainingSet.delete()
-          }
-        } else {
-          Rabit.shutdown()
-          throw new XGBoostError(s"detect the empty partition in training dataset, partition ID:" +
-            s" ${TaskContext.getPartitionId().toString}")
+    partitionedTrainingSet.zipPartitions(partitionedBaseMargin) { (trainingSamples, baseMargin) =>
+      if (trainingSamples.isEmpty) {
+        throw new XGBoostError(
+          s"detected an empty partition in the training data, partition ID:" +
+            s" ${TaskContext.getPartitionId()}")
+      }
+      val cacheFileName = if (useExternalMemory) {
+        s"$appName-${TaskContext.get().stageId()}-" +
+          s"dtrain_cache-${TaskContext.getPartitionId()}"
+      } else {
+        null
+      }
+      rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
+      Rabit.init(rabitEnv)
+      val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
+      val trainingMatrix = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
+      try {
+        if (params.contains("groupData") && params("groupData") != null) {
+          trainingMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
+            TaskContext.getPartitionId()).toArray)
+        }
+        if (baseMargin.nonEmpty) {
+          trainingMatrix.setBaseMargin(baseMargin.toArray)
         }
+        val booster = SXGBoost.train(trainingMatrix, params, round,
+          watches = Map("train" -> trainingMatrix), obj, eval)
         Iterator(booster)
+      } finally {
+        Rabit.shutdown()
+        trainingMatrix.delete()
+      }
     }.cache()
   }
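Note: fromDenseToSparseLabeledPoints drops every entry equal to missing while keeping the logical vector length, and mutable.ArrayBuilder replaces ListBuffer to avoid boxing each element. A standalone sketch of the same filtering on a toy vector (illustrative values only, not part of the patch):

    import org.apache.spark.ml.linalg.Vectors

    val missing = 0.0
    val dense = Vectors.dense(1.0, 0.0, 3.0).toDense
    val (indices, values) = dense.values.indices
      .collect { case i if dense.values(i) != missing => (i, dense.values(i)) }
      .unzip
    // indices == Seq(0, 2), values == Seq(1.0, 3.0); the vector length stays 3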
@@ -191,8 +178,8 @@ object XGBoost extends Serializable {
     fit(trainingData)
   }
 
-  private[spark] def isClassificationTask(paramsMap: Map[String, Any]): Boolean = {
-    val objective = paramsMap.getOrElse("objective", paramsMap.getOrElse("obj_type", null))
+  private[spark] def isClassificationTask(params: Map[String, Any]): Boolean = {
+    val objective = params.getOrElse("objective", params.getOrElse("obj_type", null))
     objective != null && {
       val objStr = objective.toString
       objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" &&
@@ -212,18 +199,26 @@ object XGBoost extends Serializable {
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
    * @param missing the value represented the missing value in the dataset
+   * @param baseMargin initial prediction for boosting.
    * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
    * @return XGBoostModel when successful training
    */
+  @deprecated("Use XGBoost.trainWithRDD instead.")
   def train(
-      trainingData: RDD[MLLabeledPoint], params: Map[String, Any], round: Int,
-      nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
-      useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
-    require(nWorkers > 0, "you must specify more than 0 workers")
-    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
+      trainingData: RDD[MLLabeledPoint],
+      params: Map[String, Any],
+      round: Int,
+      nWorkers: Int,
+      obj: ObjectiveTrait = null,
+      eval: EvalTrait = null,
+      useExternalMemory: Boolean = false,
+      missing: Float = Float.NaN,
+      baseMargin: RDD[Float] = null): XGBoostModel = {
+    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory,
+      missing, baseMargin)
   }
 
-  private def overrideParamMapAccordingtoTaskCPUs(
+  private def overrideParamsAccordingToTaskCPUs(
       params: Map[String, Any],
       sc: SparkContext): Map[String, Any] = {
     val coresPerTask = sc.getConf.get("spark.task.cpus", "1").toInt
@@ -262,14 +257,21 @@ object XGBoost extends Serializable {
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
    * @param missing the value represented the missing value in the dataset
+   * @param baseMargin initial prediction for boosting.
    * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
    * @return XGBoostModel when successful training
    */
   @throws(classOf[XGBoostError])
   def trainWithRDD(
-      trainingData: RDD[MLLabeledPoint], params: Map[String, Any], round: Int,
-      nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
-      useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
+      trainingData: RDD[MLLabeledPoint],
+      params: Map[String, Any],
+      round: Int,
+      nWorkers: Int,
+      obj: ObjectiveTrait = null,
+      eval: EvalTrait = null,
+      useExternalMemory: Boolean = false,
+      missing: Float = Float.NaN,
+      baseMargin: RDD[Float] = null): XGBoostModel = {
     if (params.contains("tree_method")) {
       require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
         " for now")
     }
@@ -288,9 +290,10 @@ object XGBoost extends Serializable {
     }
     val tracker = startTracker(nWorkers, trackerConf)
     try {
-      val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
-      val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
-        tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
+      val overriddenParams = overrideParamsAccordingToTaskCPUs(params, trainingData.sparkContext)
+      val boosters = buildDistributedBoosters(trainingData, overriddenParams,
+        tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing,
+        baseMargin)
       val sparkJobThread = new Thread() {
         override def run() {
           // force the job
@@ -302,7 +305,7 @@ object XGBoost extends Serializable {
       val isClsTask = isClassificationTask(params)
       val trackerReturnVal = tracker.waitFor(0L)
      logger.info(s"Rabit returns with exit code $trackerReturnVal")
-      postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
+      postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams, sparkJobThread,
         isClsTask)
     } finally {
       tracker.stop()
     }
   }
 
   private def postTrackerReturnProcessing(
       trackerReturnVal: Int, distributedBoosters: RDD[Booster],
-      configMap: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean):
+      params: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean):
     XGBoostModel = {
     if (trackerReturnVal == 0) {
-      val xgboostModel = convertBoosterToXGBoostModel(distributedBoosters.first(),
-        isClassificationTask)
+      val xgboostModel = XGBoostModel(distributedBoosters.first(), isClassificationTask)
       distributedBoosters.unpersist(false)
       xgboostModel
     } else {
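Note: the base margin is paired with training rows positionally. zipPartitions requires both RDDs to have the same number of partitions and zips each pair of partitions element by element, so baseMargin must hold exactly one value per training row in matching order. A toy sketch of that pairing (plain RDD semantics, not the XGBoost API; assumes a live sc):

    val rows    = sc.parallelize(Seq("r0", "r1", "r2", "r3"), numSlices = 2)
    val margins = sc.parallelize(Seq(0.1f, 0.2f, 0.3f, 0.4f), numSlices = 2)
    // throws if the partition counts differ; per-partition element
    // counts and ordering must also line up
    val paired = rows.zipPartitions(margins) { (rs, ms) => rs.zip(ms) }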
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index b4d405364..a2ea44443 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -125,16 +125,15 @@ abstract class XGBoostModel(protected var _booster: Booster)
           case (null, _) => {
             val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
             val Array(evName, predNumeric) = predStr.split(":")
-            Rabit.shutdown()
             Iterator(Some(evName, predNumeric.toFloat))
           }
           case _ => {
             val predictions = broadcastBooster.value.predict(dMatrix)
-            Rabit.shutdown()
             Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
           }
         }
       } finally {
+        Rabit.shutdown()
         dMatrix.delete()
       }
     } else {
@@ -170,10 +169,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
         }
         val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
         try {
-          val res = broadcastBooster.value.predict(dMatrix)
-          Rabit.shutdown()
-          Iterator(res)
+          Iterator(broadcastBooster.value.predict(dMatrix))
         } finally {
+          Rabit.shutdown()
           dMatrix.delete()
         }
       }
@@ -185,13 +183,16 @@ abstract class XGBoostModel(protected var _booster: Booster)
    *
    * @param testSet test set represented as RDD
    * @param useExternalCache whether to use external cache for the test set
+   * @param outputMargin whether to output raw untransformed margin value
    */
-  def predict(testSet: RDD[MLVector], useExternalCache: Boolean = false):
-      RDD[Array[Array[Float]]] = {
+  def predict(
+      testSet: RDD[MLVector],
+      useExternalCache: Boolean = false,
+      outputMargin: Boolean = false): RDD[Array[Array[Float]]] = {
     val broadcastBooster = testSet.sparkContext.broadcast(_booster)
     val appName = testSet.context.appName
     testSet.mapPartitions { testSamples =>
-      if (testSamples.hasNext) {
+      if (testSamples.nonEmpty) {
         import DataUtils._
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
@@ -204,10 +205,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
         }
         val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
         try {
-          val res = broadcastBooster.value.predict(dMatrix)
-          Rabit.shutdown()
-          Iterator(res)
+          Iterator(broadcastBooster.value.predict(dMatrix, outputMargin))
         } finally {
+          Rabit.shutdown()
           dMatrix.delete()
         }
       } else {
@@ -334,6 +334,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
 }
 
 object XGBoostModel extends MLReadable[XGBoostModel] {
+  private[spark] def apply(booster: Booster, isClassification: Boolean): XGBoostModel = {
+    if (!isClassification) {
+      new XGBoostRegressionModel(booster)
+    } else {
+      new XGBoostClassificationModel(booster)
+    }
+  }
 
   override def read: MLReader[XGBoostModel] = new XGBoostModelModelReader
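Note: predict(testSet, outputMargin = true) returns the raw, untransformed scores (for binary:logistic, the values before the sigmoid transform), which is the form setBaseMargin expects. A usage sketch with hypothetical names, assuming a trained model and featuresRDD: RDD[Vector]:

    // each partition yields one Array[Array[Float]] of per-row predictions
    val margins: RDD[Float] = model.predict(featuresRDD, outputMargin = true)
      .flatMap(_.flatten.iterator)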
- set("spark.driver.memory", "512m") + val sparkConf = new SparkConf() + .setMaster("local[*]") + .setAppName("XGBoostSuite") + .set("spark.driver.memory", "512m") + .set("spark.ui.enabled", "false") + sc = new SparkContext(sparkConf) - sc.setLogLevel("ERROR") } override def afterAll() { if (sc != null) { sc.stop() + sc = null } } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index 0f2ed94ef..d4007401b 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -17,17 +17,15 @@ package ml.dmlc.xgboost4j.scala.spark import java.nio.file.Files -import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque} +import java.util.concurrent.LinkedBlockingDeque import scala.collection.mutable.ListBuffer import scala.io.Source import scala.util.Random -import scala.concurrent.duration._ -import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker} +import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix} import ml.dmlc.xgboost4j.scala.DMatrix import ml.dmlc.xgboost4j.scala.rabit.RabitTracker -import org.scalatest.Ignore import org.apache.spark.SparkContext import org.apache.spark.ml.feature.LabeledPoint @@ -83,7 +81,8 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { List("eta" -> "1", "max_depth" -> "6", "silent" -> "1", "objective" -> "binary:logistic").toMap, new java.util.HashMap[String, String](), - numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true) + numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true, + missing = Float.NaN, baseMargin = null) val boosterCount = boosterRDD.count() assert(boosterCount === 2) cleanExternalCache("XGBoostSuite") @@ -390,4 +389,30 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val predResult1: Array[Array[Float]] = predRDD.collect()(0) assert(testRDD.count() === predResult1.length) } + + test("test use base margin") { + val trainSet = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile) + val trainRDD = sc.parallelize(trainSet, numSlices = 1) + + val testSet = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile) + val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features) + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "rank:pairwise") + + val trainMargin = { + XGBoost.trainWithRDD(trainRDD, paramMap, round = 1, nWorkers = 2) + .predict(trainRDD.map(_.features), outputMargin = true) + .flatMap { _.flatten.iterator } + } + + val xgBoostModel = XGBoost.trainWithRDD( + trainRDD, + paramMap, + round = 1, + nWorkers = 2, + baseMargin = trainMargin) + + assert(testRDD.count() === xgBoostModel.predict(testRDD).first().length) + } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index e0fc1247c..4d8a49dba 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -171,26 +171,26 @@ public class DMatrix { } /** - * if specified, xgboost will start from this init margin - * can be used to specify initial prediction to 
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java
index e0fc1247c..4d8a49dba 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java
@@ -171,26 +171,26 @@ public class DMatrix {
   }
 
   /**
-   * if specified, xgboost will start from this init margin
-   * can be used to specify initial prediction to boost from
+   * Set base margin (initial prediction).
    *
-   * @param baseMargin base margin
-   * @throws XGBoostError native error
+   * The margin must have the same number of elements as the number of
+   * rows in this matrix.
    */
   public void setBaseMargin(float[] baseMargin) throws XGBoostError {
+    if (baseMargin.length != rowNum()) {
+      throw new IllegalArgumentException(String.format(
+          "base margin must have exactly %s elements, got %s",
+          rowNum(), baseMargin.length));
+    }
+
     XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
   }
 
   /**
-   * if specified, xgboost will start from this init margin
-   * can be used to specify initial prediction to boost from
-   *
-   * @param baseMargin base margin
-   * @throws XGBoostError native error
+   * Set base margin (initial prediction).
    */
   public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
-    float[] flattenMargin = flatten(baseMargin);
-    setBaseMargin(flattenMargin);
+    setBaseMargin(flatten(baseMargin));
   }
 
   /**
@@ -236,10 +236,7 @@ public class DMatrix {
   }
 
   /**
-   * get base margin of the DMatrix
-   *
-   * @return base margin
-   * @throws XGBoostError native error
+   * Get base margin of the DMatrix.
    */
   public float[] getBaseMargin() throws XGBoostError {
     return getFloatInfo("base_margin");
   }
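Note: the added length check turns a mismatched margin into an immediate IllegalArgumentException on the JVM side rather than a harder-to-trace native error. A sketch via the Scala wrapper (illustrative data; assumes the wrapper delegates to this Java method):

    import ml.dmlc.xgboost4j.scala.DMatrix

    val mat = new DMatrix(Array(1f, 2f, 3f, 4f), 2, 2)  // 2 rows x 2 columns, dense
    mat.setBaseMargin(Array(0.5f, -0.5f))  // ok: one margin per row
    mat.setBaseMargin(Array(0.5f))         // throws: 2 elements expected, got 1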
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
index d2ff3b612..09c74b2d3 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
@@ -1,7 +1,8 @@
 package ml.dmlc.xgboost4j.java;
 
-import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.List;
 
 import ml.dmlc.xgboost4j.LabeledPoint;
 
@@ -13,20 +14,18 @@ import ml.dmlc.xgboost4j.LabeledPoint;
  */
 class DataBatch {
   /** The offset of each rows in the sparse matrix */
-  long[] rowOffset = null;
+  final long[] rowOffset;
   /** weight of each data point, can be null */
-  float[] weight = null;
+  final float[] weight;
   /** label of each data point, can be null */
-  float[] label = null;
+  final float[] label;
   /** index of each feature(column) in the sparse matrix */
-  int[] featureIndex = null;
+  final int[] featureIndex;
   /** value of each non-missing entry in the sparse matrix */
-  float[] featureValue = null;
+  final float[] featureValue;
 
-  public DataBatch() {}
-
-  public DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
-                   float[] featureValue) {
+  DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
+            float[] featureValue) {
     this.rowOffset = rowOffset;
     this.weight = weight;
     this.label = label;
@@ -34,80 +33,62 @@ class DataBatch {
     this.featureValue = featureValue;
   }
 
-  /**
-   * Get number of rows in the data batch.
-   * @return Number of rows in the data batch.
-   */
-  public int numRows() {
-    return rowOffset.length - 1;
-  }
-
-  /**
-   * Shallow copy a DataBatch
-   * @return a copy of the batch
-   */
-  public DataBatch shallowCopy() {
-    DataBatch b = new DataBatch();
-    b.rowOffset = this.rowOffset;
-    b.weight = this.weight;
-    b.label = this.label;
-    b.featureIndex = this.featureIndex;
-    b.featureValue = this.featureValue;
-    return b;
-  }
-
   static class BatchIterator implements Iterator<DataBatch> {
-    private Iterator<LabeledPoint> base;
-    private int batchSize;
+    private final Iterator<LabeledPoint> base;
+    private final int batchSize;
 
-    BatchIterator(java.util.Iterator<LabeledPoint> base, int batchSize) {
+    BatchIterator(Iterator<LabeledPoint> base, int batchSize) {
       this.base = base;
      this.batchSize = batchSize;
     }
 
+    @Override
     public boolean hasNext() {
       return base.hasNext();
     }
 
+    @Override
     public DataBatch next() {
-      int num_rows = 0, num_elem = 0;
-      java.util.List<LabeledPoint> batch = new java.util.ArrayList<LabeledPoint>();
-      for (int i = 0; i < this.batchSize; ++i) {
-        if (!base.hasNext()) break;
-        LabeledPoint inst = base.next();
-        batch.add(inst);
-        num_elem += inst.values.length;
-        ++num_rows;
+      int numRows = 0;
+      int numElem = 0;
+      List<LabeledPoint> batch = new ArrayList<>(batchSize);
+      while (base.hasNext() && batch.size() < batchSize) {
+        LabeledPoint labeledPoint = base.next();
+        batch.add(labeledPoint);
+        numElem += labeledPoint.values.length;
+        numRows++;
       }
-      DataBatch ret = new DataBatch();
-      // label
-      ret.rowOffset = new long[num_rows + 1];
-      ret.label = new float[num_rows];
-      ret.featureIndex = new int[num_elem];
-      ret.featureValue = new float[num_elem];
-      // current offset
+
+      long[] rowOffset = new long[numRows + 1];
+      float[] label = new float[numRows];
+      int[] featureIndex = new int[numElem];
+      float[] featureValue = new float[numElem];
+
       int offset = 0;
-      for (int i = 0; i < batch.size(); ++i) {
-        LabeledPoint inst = batch.get(i);
-        ret.rowOffset[i] = offset;
-        ret.label[i] = inst.label;
-        if (inst.indices != null) {
-          System.arraycopy(inst.indices, 0, ret.featureIndex, offset, inst.indices.length);
-        } else {
-          for (int j = 0; j < inst.values.length; ++j) {
-            ret.featureIndex[offset + j] = j;
+      for (int i = 0; i < batch.size(); i++) {
+        LabeledPoint labeledPoint = batch.get(i);
+        rowOffset[i] = offset;
+        label[i] = labeledPoint.label;
+        if (labeledPoint.indices != null) {
+          System.arraycopy(labeledPoint.indices, 0, featureIndex, offset,
+              labeledPoint.indices.length);
+        } else {
+          for (int j = 0; j < labeledPoint.values.length; j++) {
+            featureIndex[offset + j] = j;
           }
         }
-        System.arraycopy(inst.values, 0, ret.featureValue, offset, inst.values.length);
-        offset += inst.values.length;
+
+        System.arraycopy(labeledPoint.values, 0, featureValue, offset, labeledPoint.values.length);
+        offset += labeledPoint.values.length;
       }
-      ret.rowOffset[batch.size()] = offset;
-      return ret;
+
+      rowOffset[batch.size()] = offset;
+      return new DataBatch(rowOffset, null, label, featureIndex, featureValue);
     }
 
+    @Override
     public void remove() {
-      throw new Error("not implemented");
+      throw new UnsupportedOperationException("DataBatch.BatchIterator.remove");
     }
   }
 }
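Note: BatchIterator.next() packs up to batchSize labeled points into one CSR block. A worked example of the arrays it produces for two sparse rows (illustrative values only):

    // row 0: {col 0 -> 1.0, col 2 -> 3.0};  row 1: {col 1 -> 2.0}
    val rowOffset    = Array(0L, 2L, 3L)        // row i spans [rowOffset(i), rowOffset(i + 1))
    val featureIndex = Array(0, 2, 1)           // column index of each stored value
    val featureValue = Array(1.0f, 3.0f, 2.0f)  // the stored values themselves
    val label        = Array(1.0f, 0.0f)        // one label per row; weight stays null here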