[jvm-packages] force use per-group weights in spark layer (#4118)

This commit is contained in:
Nan Zhu
2019-02-09 13:38:03 -08:00
committed by Jiaming Yuan
parent ba584e5e9f
commit 3320a52192
2 changed files with 56 additions and 15 deletions

View File

@@ -618,18 +618,34 @@ private object Watches {
val dms = nameAndlabeledPointGroupSets.map {
case (name, labeledPointsGroups) =>
val baseMargins = new mutable.ArrayBuilder.ofFloat
val duplicatedItr = labeledPointsGroups.map(labeledPoints => {
labeledPoints.map { labeledPoint =>
val groupsInfo = new mutable.ArrayBuilder.ofInt
val weights = new mutable.ArrayBuilder.ofFloat
val iter = labeledPointsGroups.filter(labeledPointGroup => {
var groupWeight = -1.0f
var groupSize = 0
labeledPointGroup.map { labeledPoint => {
if (groupWeight < 0) {
groupWeight = labeledPoint.weight
} else if (groupWeight != labeledPoint.weight) {
throw new IllegalArgumentException("the instances in the same group have to be" +
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
}
baseMargins += labeledPoint.baseMargin
groupSize += 1
labeledPoint
}
}
weights += groupWeight
groupsInfo += groupSize
true
})
val dMatrix = new DMatrix(duplicatedItr.flatMap(_.iterator),
cachedDirName.map(_ + s"/$name").orNull)
val dMatrix = new DMatrix(iter.flatMap(_.iterator), cachedDirName.map(_ + s"/$name").orNull)
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
if (baseMargin.isDefined) {
dMatrix.setBaseMargin(baseMargin.get)
}
dMatrix.setGroup(groupsInfo.result())
dMatrix.setWeight(weights.result())
(name, dMatrix)
}.toArray
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
@@ -645,20 +661,46 @@ private object Watches {
val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
val testBaseMargins = new mutable.ArrayBuilder.ofFloat
val trainGroups = new mutable.ArrayBuilder.ofInt
val testGroups = new mutable.ArrayBuilder.ofInt
val trainWeights = new mutable.ArrayBuilder.ofFloat
val testWeights = new mutable.ArrayBuilder.ofFloat
val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
val accepted = r.nextDouble() <= trainTestRatio
if (!accepted) {
var groupWeight = -1.0f
var groupSize = 0
labeledPointGroup.foreach(labeledPoint => {
testPoints += labeledPoint
testBaseMargins += labeledPoint.baseMargin
if (groupWeight < 0) {
groupWeight = labeledPoint.weight
} else if (labeledPoint.weight != groupWeight) {
throw new IllegalArgumentException("the instances in the same group have to be" +
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
}
groupSize += 1
})
testGroups += labeledPointGroup.length
testWeights += groupWeight
testGroups += groupSize
} else {
labeledPointGroup.foreach(trainBaseMargins += _.baseMargin)
trainGroups += labeledPointGroup.length
var groupWeight = -1.0f
var groupSize = 0
labeledPointGroup.foreach { labeledPoint => {
if (groupWeight < 0) {
groupWeight = labeledPoint.weight
} else if (labeledPoint.weight != groupWeight) {
throw new IllegalArgumentException("the instances in the same group have to be" +
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
}
trainBaseMargins += labeledPoint.baseMargin
groupSize += 1
}}
trainWeights += groupWeight
trainGroups += groupSize
}
accepted
}
@@ -666,10 +708,12 @@ private object Watches {
val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
trainMatrix.setGroup(trainGroups.result())
trainMatrix.setWeight(trainWeights.result())
val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
if (trainTestRatio < 1.0) {
testMatrix.setGroup(testGroups.result())
testMatrix.setWeight(testWeights.result())
}
val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)