[jvm-packages] Exposed baseMargin (#2450)
* Disabled excessive Spark logging in tests * Fixed a singature of XGBoostModel.predict Prior to this commit XGBoostModel.predict produced an RDD with an array of predictions for each partition, effectively changing the shape wrt the input RDD. A more natural contract for prediction API is that given an RDD it returns a new RDD with the same number of elements. This allows the users to easily match inputs with predictions. This commit removes one layer of nesting in XGBoostModel.predict output. Even though the change is clearly non-backward compatible, I still think it is well justified. * Removed boxing in XGBoost.fromDenseToSparseLabeledPoints * Inlined XGBoost.repartitionData An if is more explicit than an opaque method name. * Moved XGBoost.convertBoosterToXGBoostModel to XGBoostModel * Check the input dimension in DMatrix.setBaseMargin Prior to this commit providing an array of incorrect dimensions would have resulted in memory corruption. Maybe backport this to C++? * Reduced nesting in XGBoost.buildDistributedBoosters * Ensured consistent naming of the params map * Cleaned up DataBatch to make it easier to comprehend * Made scalastyle happy * Added baseMargin to XGBoost.train and trainWithRDD * Deprecated XGBoost.train It is ambiguous and work only for RDDs. * Addressed review comments * Revert "Fixed a singature of XGBoostModel.predict" This reverts commit 06bd5dcae7780265dd57e93ed7d4135f4e78f9b4. * Addressed more review comments * Fixed NullPointerException in buildDistributedBoosters
This commit is contained in:
parent
6b287177c8
commit
d535340459
@ -49,7 +49,7 @@ object SparkWithRDD {
|
||||
"eta" -> 0.1f,
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
|
||||
val xgboostModel = XGBoost.trainWithRDD(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
|
||||
useExternalMemory = true)
|
||||
xgboostModel.booster.predict(new DMatrix(testSet))
|
||||
// save model to HDFS path
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
@ -30,7 +29,6 @@ import org.apache.spark.ml.linalg.SparseVector
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.{SparkContext, TaskContext}
|
||||
import scala.concurrent.duration.{Duration, FiniteDuration, MILLISECONDS}
|
||||
|
||||
object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(0L, "python")
|
||||
@ -53,97 +51,86 @@ case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
|
||||
object XGBoost extends Serializable {
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
private def convertBoosterToXGBoostModel(booster: Booster, isClassification: Boolean):
|
||||
XGBoostModel = {
|
||||
if (!isClassification) {
|
||||
new XGBoostRegressionModel(booster)
|
||||
} else {
|
||||
new XGBoostClassificationModel(booster)
|
||||
}
|
||||
}
|
||||
|
||||
private def fromDenseToSparseLabeledPoints(
|
||||
denseLabeledPoints: Iterator[MLLabeledPoint],
|
||||
missing: Float): Iterator[MLLabeledPoint] = {
|
||||
if (!missing.isNaN) {
|
||||
val sparseLabeledPoints = new ListBuffer[MLLabeledPoint]
|
||||
for (labelPoint <- denseLabeledPoints) {
|
||||
val dVector = labelPoint.features.toDense
|
||||
val indices = new ListBuffer[Int]
|
||||
val values = new ListBuffer[Double]
|
||||
for (i <- dVector.values.indices) {
|
||||
if (dVector.values(i) != missing) {
|
||||
denseLabeledPoints.map { case MLLabeledPoint(label, features) =>
|
||||
val dFeatures = features.toDense
|
||||
val indices = new mutable.ArrayBuilder.ofInt()
|
||||
val values = new mutable.ArrayBuilder.ofDouble()
|
||||
for (i <- dFeatures.values.indices) {
|
||||
if (dFeatures.values(i) != missing) {
|
||||
indices += i
|
||||
values += dVector.values(i)
|
||||
values += dFeatures.values(i)
|
||||
}
|
||||
}
|
||||
val sparseVector = new SparseVector(dVector.values.length, indices.toArray,
|
||||
values.toArray)
|
||||
sparseLabeledPoints += MLLabeledPoint(labelPoint.label, sparseVector)
|
||||
val sFeatures = new SparseVector(dFeatures.values.length, indices.result(),
|
||||
values.result())
|
||||
MLLabeledPoint(label, sFeatures)
|
||||
}
|
||||
sparseLabeledPoints.iterator
|
||||
} else {
|
||||
denseLabeledPoints
|
||||
}
|
||||
}
|
||||
|
||||
private def repartitionData(trainingData: RDD[MLLabeledPoint], numWorkers: Int):
|
||||
RDD[MLLabeledPoint] = {
|
||||
if (numWorkers != trainingData.partitions.length) {
|
||||
logger.info(s"repartitioning training set to $numWorkers partitions")
|
||||
trainingData.repartition(numWorkers)
|
||||
} else {
|
||||
trainingData
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] def buildDistributedBoosters(
|
||||
trainingSet: RDD[MLLabeledPoint],
|
||||
xgBoostConfMap: Map[String, Any],
|
||||
params: Map[String, Any],
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
|
||||
useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
|
||||
numWorkers: Int,
|
||||
round: Int,
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait,
|
||||
useExternalMemory: Boolean,
|
||||
missing: Float,
|
||||
baseMargin: RDD[Float]): RDD[Booster] = {
|
||||
import DataUtils._
|
||||
val partitionedTrainingSet = repartitionData(trainingSet, numWorkers)
|
||||
|
||||
val partitionedTrainingSet = if (trainingSet.getNumPartitions != numWorkers) {
|
||||
logger.info(s"repartitioning training set to $numWorkers partitions")
|
||||
trainingSet.repartition(numWorkers)
|
||||
} else {
|
||||
trainingSet
|
||||
}
|
||||
val partitionedBaseMargin = Option(baseMargin)
|
||||
.getOrElse(trainingSet.sparkContext.emptyRDD)
|
||||
.repartition(partitionedTrainingSet.getNumPartitions)
|
||||
val appName = partitionedTrainingSet.context.appName
|
||||
// to workaround the empty partitions in training dataset,
|
||||
// this might not be the best efficient implementation, see
|
||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||
partitionedTrainingSet.mapPartitions {
|
||||
trainingSamples =>
|
||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv)
|
||||
var booster: Booster = null
|
||||
if (trainingSamples.hasNext) {
|
||||
val cacheFileName: String = {
|
||||
if (useExternalMemory) {
|
||||
s"$appName-${TaskContext.get().stageId()}-" +
|
||||
s"dtrain_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
|
||||
val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
|
||||
try {
|
||||
if (xgBoostConfMap.contains("groupData") && xgBoostConfMap("groupData") != null) {
|
||||
trainingSet.setGroup(xgBoostConfMap("groupData").asInstanceOf[Seq[Seq[Int]]](
|
||||
TaskContext.getPartitionId()).toArray)
|
||||
}
|
||||
booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
|
||||
watches = new mutable.HashMap[String, DMatrix] {
|
||||
put("train", trainingSet)
|
||||
}.toMap, obj, eval)
|
||||
Rabit.shutdown()
|
||||
} finally {
|
||||
trainingSet.delete()
|
||||
}
|
||||
} else {
|
||||
Rabit.shutdown()
|
||||
throw new XGBoostError(s"detect the empty partition in training dataset, partition ID:" +
|
||||
s" ${TaskContext.getPartitionId().toString}")
|
||||
partitionedTrainingSet.zipPartitions(partitionedBaseMargin) { (trainingSamples, baseMargin) =>
|
||||
if (trainingSamples.isEmpty) {
|
||||
throw new XGBoostError(
|
||||
s"detected an empty partition in the training data, partition ID:" +
|
||||
s" ${TaskContext.getPartitionId()}")
|
||||
}
|
||||
val cacheFileName = if (useExternalMemory) {
|
||||
s"$appName-${TaskContext.get().stageId()}-" +
|
||||
s"dtrain_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv)
|
||||
val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
|
||||
val trainingMatrix = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
|
||||
try {
|
||||
if (params.contains("groupData") && params("groupData") != null) {
|
||||
trainingMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
|
||||
TaskContext.getPartitionId()).toArray)
|
||||
}
|
||||
if (baseMargin.nonEmpty) {
|
||||
trainingMatrix.setBaseMargin(baseMargin.toArray)
|
||||
}
|
||||
val booster = SXGBoost.train(trainingMatrix, params, round,
|
||||
watches = Map("train" -> trainingMatrix), obj, eval)
|
||||
Iterator(booster)
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
trainingMatrix.delete()
|
||||
}
|
||||
}.cache()
|
||||
}
|
||||
|
||||
@ -191,8 +178,8 @@ object XGBoost extends Serializable {
|
||||
fit(trainingData)
|
||||
}
|
||||
|
||||
private[spark] def isClassificationTask(paramsMap: Map[String, Any]): Boolean = {
|
||||
val objective = paramsMap.getOrElse("objective", paramsMap.getOrElse("obj_type", null))
|
||||
private[spark] def isClassificationTask(params: Map[String, Any]): Boolean = {
|
||||
val objective = params.getOrElse("objective", params.getOrElse("obj_type", null))
|
||||
objective != null && {
|
||||
val objStr = objective.toString
|
||||
objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" &&
|
||||
@ -212,18 +199,26 @@ object XGBoost extends Serializable {
|
||||
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
|
||||
* true, the user may save the RAM cost for running XGBoost within Spark
|
||||
* @param missing the value represented the missing value in the dataset
|
||||
* @param baseMargin initial prediction for boosting.
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
|
||||
* @return XGBoostModel when successful training
|
||||
*/
|
||||
@deprecated("Use XGBoost.trainWithRDD instead.")
|
||||
def train(
|
||||
trainingData: RDD[MLLabeledPoint], params: Map[String, Any], round: Int,
|
||||
nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
|
||||
useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
|
||||
require(nWorkers > 0, "you must specify more than 0 workers")
|
||||
trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
|
||||
trainingData: RDD[MLLabeledPoint],
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
nWorkers: Int,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
useExternalMemory: Boolean = false,
|
||||
missing: Float = Float.NaN,
|
||||
baseMargin: RDD[Float] = null): XGBoostModel = {
|
||||
trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory,
|
||||
missing, baseMargin)
|
||||
}
|
||||
|
||||
private def overrideParamMapAccordingtoTaskCPUs(
|
||||
private def overrideParamsAccordingToTaskCPUs(
|
||||
params: Map[String, Any],
|
||||
sc: SparkContext): Map[String, Any] = {
|
||||
val coresPerTask = sc.getConf.get("spark.task.cpus", "1").toInt
|
||||
@ -262,14 +257,21 @@ object XGBoost extends Serializable {
|
||||
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
|
||||
* true, the user may save the RAM cost for running XGBoost within Spark
|
||||
* @param missing the value represented the missing value in the dataset
|
||||
* @param baseMargin initial prediction for boosting.
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
|
||||
* @return XGBoostModel when successful training
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def trainWithRDD(
|
||||
trainingData: RDD[MLLabeledPoint], params: Map[String, Any], round: Int,
|
||||
nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
|
||||
useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
|
||||
trainingData: RDD[MLLabeledPoint],
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
nWorkers: Int,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
useExternalMemory: Boolean = false,
|
||||
missing: Float = Float.NaN,
|
||||
baseMargin: RDD[Float] = null): XGBoostModel = {
|
||||
if (params.contains("tree_method")) {
|
||||
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
|
||||
" for now")
|
||||
@ -288,9 +290,10 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
val tracker = startTracker(nWorkers, trackerConf)
|
||||
try {
|
||||
val overridedConfMap = overrideParamMapAccordingtoTaskCPUs(params, trainingData.sparkContext)
|
||||
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
|
||||
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
|
||||
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, trainingData.sparkContext)
|
||||
val boosters = buildDistributedBoosters(trainingData, overriddenParams,
|
||||
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing,
|
||||
baseMargin)
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
// force the job
|
||||
@ -302,7 +305,7 @@ object XGBoost extends Serializable {
|
||||
val isClsTask = isClassificationTask(params)
|
||||
val trackerReturnVal = tracker.waitFor(0L)
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
postTrackerReturnProcessing(trackerReturnVal, boosters, overridedConfMap, sparkJobThread,
|
||||
postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams, sparkJobThread,
|
||||
isClsTask)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
@ -311,11 +314,10 @@ object XGBoost extends Serializable {
|
||||
|
||||
private def postTrackerReturnProcessing(
|
||||
trackerReturnVal: Int, distributedBoosters: RDD[Booster],
|
||||
configMap: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean):
|
||||
params: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean):
|
||||
XGBoostModel = {
|
||||
if (trackerReturnVal == 0) {
|
||||
val xgboostModel = convertBoosterToXGBoostModel(distributedBoosters.first(),
|
||||
isClassificationTask)
|
||||
val xgboostModel = XGBoostModel(distributedBoosters.first(), isClassificationTask)
|
||||
distributedBoosters.unpersist(false)
|
||||
xgboostModel
|
||||
} else {
|
||||
|
||||
@ -125,16 +125,15 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
case (null, _) => {
|
||||
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
||||
val Array(evName, predNumeric) = predStr.split(":")
|
||||
Rabit.shutdown()
|
||||
Iterator(Some(evName, predNumeric.toFloat))
|
||||
}
|
||||
case _ => {
|
||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
dMatrix.delete()
|
||||
}
|
||||
} else {
|
||||
@ -170,10 +169,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
}
|
||||
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
|
||||
try {
|
||||
val res = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(res)
|
||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
dMatrix.delete()
|
||||
}
|
||||
}
|
||||
@ -185,13 +183,16 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
*
|
||||
* @param testSet test set represented as RDD
|
||||
* @param useExternalCache whether to use external cache for the test set
|
||||
* @param outputMargin whether to output raw untransformed margin value
|
||||
*/
|
||||
def predict(testSet: RDD[MLVector], useExternalCache: Boolean = false):
|
||||
RDD[Array[Array[Float]]] = {
|
||||
def predict(
|
||||
testSet: RDD[MLVector],
|
||||
useExternalCache: Boolean = false,
|
||||
outputMargin: Boolean = false): RDD[Array[Array[Float]]] = {
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||
val appName = testSet.context.appName
|
||||
testSet.mapPartitions { testSamples =>
|
||||
if (testSamples.hasNext) {
|
||||
if (testSamples.nonEmpty) {
|
||||
import DataUtils._
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
@ -204,10 +205,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
}
|
||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
|
||||
try {
|
||||
val res = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(res)
|
||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
dMatrix.delete()
|
||||
}
|
||||
} else {
|
||||
@ -334,6 +334,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
}
|
||||
|
||||
object XGBoostModel extends MLReadable[XGBoostModel] {
|
||||
private[spark] def apply(booster: Booster, isClassification: Boolean): XGBoostModel = {
|
||||
if (!isClassification) {
|
||||
new XGBoostRegressionModel(booster)
|
||||
} else {
|
||||
new XGBoostClassificationModel(booster)
|
||||
}
|
||||
}
|
||||
|
||||
override def read: MLReader[XGBoostModel] = new XGBoostModelModelReader
|
||||
|
||||
|
||||
@ -0,0 +1 @@
|
||||
log4j.logger.org.apache.spark=ERROR
|
||||
@ -22,19 +22,22 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
|
||||
trait SharedSparkContext extends FunSuite with BeforeAndAfter with BeforeAndAfterAll
|
||||
with Serializable {
|
||||
|
||||
@transient protected implicit var sc: SparkContext = null
|
||||
@transient protected implicit var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
// build SparkContext
|
||||
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
|
||||
set("spark.driver.memory", "512m")
|
||||
val sparkConf = new SparkConf()
|
||||
.setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite")
|
||||
.set("spark.driver.memory", "512m")
|
||||
.set("spark.ui.enabled", "false")
|
||||
|
||||
sc = new SparkContext(sparkConf)
|
||||
sc.setLogLevel("ERROR")
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
if (sc != null) {
|
||||
sc.stop()
|
||||
sc = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,17 +17,15 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.nio.file.Files
|
||||
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
|
||||
import java.util.concurrent.LinkedBlockingDeque
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.io.Source
|
||||
import scala.util.Random
|
||||
import scala.concurrent.duration._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import org.scalatest.Ignore
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
@ -83,7 +81,8 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap,
|
||||
new java.util.HashMap[String, String](),
|
||||
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
|
||||
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true,
|
||||
missing = Float.NaN, baseMargin = null)
|
||||
val boosterCount = boosterRDD.count()
|
||||
assert(boosterCount === 2)
|
||||
cleanExternalCache("XGBoostSuite")
|
||||
@ -390,4 +389,30 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
val predResult1: Array[Array[Float]] = predRDD.collect()(0)
|
||||
assert(testRDD.count() === predResult1.length)
|
||||
}
|
||||
|
||||
test("test use base margin") {
|
||||
val trainSet = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile)
|
||||
val trainRDD = sc.parallelize(trainSet, numSlices = 1)
|
||||
|
||||
val testSet = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile)
|
||||
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
|
||||
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise")
|
||||
|
||||
val trainMargin = {
|
||||
XGBoost.trainWithRDD(trainRDD, paramMap, round = 1, nWorkers = 2)
|
||||
.predict(trainRDD.map(_.features), outputMargin = true)
|
||||
.flatMap { _.flatten.iterator }
|
||||
}
|
||||
|
||||
val xgBoostModel = XGBoost.trainWithRDD(
|
||||
trainRDD,
|
||||
paramMap,
|
||||
round = 1,
|
||||
nWorkers = 2,
|
||||
baseMargin = trainMargin)
|
||||
|
||||
assert(testRDD.count() === xgBoostModel.predict(testRDD).first().length)
|
||||
}
|
||||
}
|
||||
|
||||
@ -171,26 +171,26 @@ public class DMatrix {
|
||||
}
|
||||
|
||||
/**
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
* Set base margin (initial prediction).
|
||||
*
|
||||
* @param baseMargin base margin
|
||||
* @throws XGBoostError native error
|
||||
* The margin must have the same number of elements as the number of
|
||||
* rows in this matrix.
|
||||
*/
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
if (baseMargin.length != rowNum()) {
|
||||
throw new IllegalArgumentException(String.format(
|
||||
"base margin must have exactly %s elements, got %s",
|
||||
rowNum(), baseMargin.length));
|
||||
}
|
||||
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
|
||||
}
|
||||
|
||||
/**
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
*
|
||||
* @param baseMargin base margin
|
||||
* @throws XGBoostError native error
|
||||
* Set base margin (initial prediction).
|
||||
*/
|
||||
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
|
||||
float[] flattenMargin = flatten(baseMargin);
|
||||
setBaseMargin(flattenMargin);
|
||||
setBaseMargin(flatten(baseMargin));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -236,10 +236,7 @@ public class DMatrix {
|
||||
}
|
||||
|
||||
/**
|
||||
* get base margin of the DMatrix
|
||||
*
|
||||
* @return base margin
|
||||
* @throws XGBoostError native error
|
||||
* Get base margin of the DMatrix.
|
||||
*/
|
||||
public float[] getBaseMargin() throws XGBoostError {
|
||||
return getFloatInfo("base_margin");
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
|
||||
@ -13,20 +14,18 @@ import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
*/
|
||||
class DataBatch {
|
||||
/** The offset of each rows in the sparse matrix */
|
||||
long[] rowOffset = null;
|
||||
final long[] rowOffset;
|
||||
/** weight of each data point, can be null */
|
||||
float[] weight = null;
|
||||
final float[] weight;
|
||||
/** label of each data point, can be null */
|
||||
float[] label = null;
|
||||
final float[] label;
|
||||
/** index of each feature(column) in the sparse matrix */
|
||||
int[] featureIndex = null;
|
||||
final int[] featureIndex;
|
||||
/** value of each non-missing entry in the sparse matrix */
|
||||
float[] featureValue = null;
|
||||
final float[] featureValue ;
|
||||
|
||||
public DataBatch() {}
|
||||
|
||||
public DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
|
||||
float[] featureValue) {
|
||||
DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
|
||||
float[] featureValue) {
|
||||
this.rowOffset = rowOffset;
|
||||
this.weight = weight;
|
||||
this.label = label;
|
||||
@ -34,80 +33,62 @@ class DataBatch {
|
||||
this.featureValue = featureValue;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get number of rows in the data batch.
|
||||
* @return Number of rows in the data batch.
|
||||
*/
|
||||
public int numRows() {
|
||||
return rowOffset.length - 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Shallow copy a DataBatch
|
||||
* @return a copy of the batch
|
||||
*/
|
||||
public DataBatch shallowCopy() {
|
||||
DataBatch b = new DataBatch();
|
||||
b.rowOffset = this.rowOffset;
|
||||
b.weight = this.weight;
|
||||
b.label = this.label;
|
||||
b.featureIndex = this.featureIndex;
|
||||
b.featureValue = this.featureValue;
|
||||
return b;
|
||||
}
|
||||
|
||||
static class BatchIterator implements Iterator<DataBatch> {
|
||||
private Iterator<LabeledPoint> base;
|
||||
private int batchSize;
|
||||
private final Iterator<LabeledPoint> base;
|
||||
private final int batchSize;
|
||||
|
||||
BatchIterator(java.util.Iterator<LabeledPoint> base, int batchSize) {
|
||||
BatchIterator(Iterator<LabeledPoint> base, int batchSize) {
|
||||
this.base = base;
|
||||
this.batchSize = batchSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return base.hasNext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataBatch next() {
|
||||
int num_rows = 0, num_elem = 0;
|
||||
java.util.List<LabeledPoint> batch = new java.util.ArrayList<LabeledPoint>();
|
||||
for (int i = 0; i < this.batchSize; ++i) {
|
||||
if (!base.hasNext()) break;
|
||||
LabeledPoint inst = base.next();
|
||||
batch.add(inst);
|
||||
num_elem += inst.values.length;
|
||||
++num_rows;
|
||||
int numRows = 0;
|
||||
int numElem = 0;
|
||||
List<LabeledPoint> batch = new ArrayList<>(batchSize);
|
||||
while (base.hasNext() && batch.size() < batchSize) {
|
||||
LabeledPoint labeledPoint = base.next();
|
||||
batch.add(labeledPoint);
|
||||
numElem += labeledPoint.values.length;
|
||||
numRows++;
|
||||
}
|
||||
DataBatch ret = new DataBatch();
|
||||
// label
|
||||
ret.rowOffset = new long[num_rows + 1];
|
||||
ret.label = new float[num_rows];
|
||||
ret.featureIndex = new int[num_elem];
|
||||
ret.featureValue = new float[num_elem];
|
||||
// current offset
|
||||
|
||||
long[] rowOffset = new long[numRows + 1];
|
||||
float[] label = new float[numRows];
|
||||
int[] featureIndex = new int[numElem];
|
||||
float[] featureValue = new float[numElem];
|
||||
|
||||
int offset = 0;
|
||||
for (int i = 0; i < batch.size(); ++i) {
|
||||
LabeledPoint inst = batch.get(i);
|
||||
ret.rowOffset[i] = offset;
|
||||
ret.label[i] = inst.label;
|
||||
if (inst.indices != null) {
|
||||
System.arraycopy(inst.indices, 0, ret.featureIndex, offset, inst.indices.length);
|
||||
} else{
|
||||
for (int j = 0; j < inst.values.length; ++j) {
|
||||
ret.featureIndex[offset + j] = j;
|
||||
for (int i = 0; i < batch.size(); i++) {
|
||||
LabeledPoint labeledPoint = batch.get(i);
|
||||
rowOffset[i] = offset;
|
||||
label[i] = labeledPoint.label;
|
||||
if (labeledPoint.indices != null) {
|
||||
System.arraycopy(labeledPoint.indices, 0, featureIndex, offset,
|
||||
labeledPoint.indices.length);
|
||||
} else {
|
||||
for (int j = 0; j < labeledPoint.values.length; j++) {
|
||||
featureIndex[offset + j] = j;
|
||||
}
|
||||
}
|
||||
System.arraycopy(inst.values, 0, ret.featureValue, offset, inst.values.length);
|
||||
offset += inst.values.length;
|
||||
|
||||
System.arraycopy(labeledPoint.values, 0, featureValue, offset, labeledPoint.values.length);
|
||||
offset += labeledPoint.values.length;
|
||||
}
|
||||
ret.rowOffset[batch.size()] = offset;
|
||||
return ret;
|
||||
|
||||
rowOffset[batch.size()] = offset;
|
||||
return new DataBatch(rowOffset, null, label, featureIndex, featureValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void remove() {
|
||||
throw new Error("not implemented");
|
||||
throw new UnsupportedOperationException("DataBatch.BatchIterator.remove");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user