[jvm-packages] force use of per-group weights in the Spark layer (#4118)
This commit is contained in:
@@ -618,18 +618,34 @@ private object Watches {
|
||||
val dms = nameAndlabeledPointGroupSets.map {
|
||||
case (name, labeledPointsGroups) =>
|
||||
val baseMargins = new mutable.ArrayBuilder.ofFloat
|
||||
val duplicatedItr = labeledPointsGroups.map(labeledPoints => {
|
||||
labeledPoints.map { labeledPoint =>
|
||||
val groupsInfo = new mutable.ArrayBuilder.ofInt
|
||||
val weights = new mutable.ArrayBuilder.ofFloat
|
||||
val iter = labeledPointsGroups.filter(labeledPointGroup => {
|
||||
var groupWeight = -1.0f
|
||||
var groupSize = 0
|
||||
labeledPointGroup.map { labeledPoint => {
|
||||
if (groupWeight < 0) {
|
||||
groupWeight = labeledPoint.weight
|
||||
} else if (groupWeight != labeledPoint.weight) {
|
||||
throw new IllegalArgumentException("the instances in the same group have to be" +
|
||||
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
|
||||
}
|
||||
baseMargins += labeledPoint.baseMargin
|
||||
groupSize += 1
|
||||
labeledPoint
|
||||
}
|
||||
}
|
||||
weights += groupWeight
|
||||
groupsInfo += groupSize
|
||||
true
|
||||
})
|
||||
val dMatrix = new DMatrix(duplicatedItr.flatMap(_.iterator),
|
||||
cachedDirName.map(_ + s"/$name").orNull)
|
||||
val dMatrix = new DMatrix(iter.flatMap(_.iterator), cachedDirName.map(_ + s"/$name").orNull)
|
||||
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
|
||||
if (baseMargin.isDefined) {
|
||||
dMatrix.setBaseMargin(baseMargin.get)
|
||||
}
|
||||
dMatrix.setGroup(groupsInfo.result())
|
||||
dMatrix.setWeight(weights.result())
|
||||
(name, dMatrix)
|
||||
}.toArray
|
||||
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
|
||||
@@ -645,20 +661,46 @@ private object Watches {
|
||||
val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
|
||||
val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
|
||||
val testBaseMargins = new mutable.ArrayBuilder.ofFloat
|
||||
|
||||
val trainGroups = new mutable.ArrayBuilder.ofInt
|
||||
val testGroups = new mutable.ArrayBuilder.ofInt
|
||||
|
||||
val trainWeights = new mutable.ArrayBuilder.ofFloat
|
||||
val testWeights = new mutable.ArrayBuilder.ofFloat
|
||||
|
||||
val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
|
||||
val accepted = r.nextDouble() <= trainTestRatio
|
||||
if (!accepted) {
|
||||
var groupWeight = -1.0f
|
||||
var groupSize = 0
|
||||
labeledPointGroup.foreach(labeledPoint => {
|
||||
testPoints += labeledPoint
|
||||
testBaseMargins += labeledPoint.baseMargin
|
||||
if (groupWeight < 0) {
|
||||
groupWeight = labeledPoint.weight
|
||||
} else if (labeledPoint.weight != groupWeight) {
|
||||
throw new IllegalArgumentException("the instances in the same group have to be" +
|
||||
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
|
||||
}
|
||||
groupSize += 1
|
||||
})
|
||||
testGroups += labeledPointGroup.length
|
||||
testWeights += groupWeight
|
||||
testGroups += groupSize
|
||||
} else {
|
||||
labeledPointGroup.foreach(trainBaseMargins += _.baseMargin)
|
||||
trainGroups += labeledPointGroup.length
|
||||
var groupWeight = -1.0f
|
||||
var groupSize = 0
|
||||
labeledPointGroup.foreach { labeledPoint => {
|
||||
if (groupWeight < 0) {
|
||||
groupWeight = labeledPoint.weight
|
||||
} else if (labeledPoint.weight != groupWeight) {
|
||||
throw new IllegalArgumentException("the instances in the same group have to be" +
|
||||
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
|
||||
}
|
||||
trainBaseMargins += labeledPoint.baseMargin
|
||||
groupSize += 1
|
||||
}}
|
||||
trainWeights += groupWeight
|
||||
trainGroups += groupSize
|
||||
}
|
||||
accepted
|
||||
}
|
||||
@@ -666,10 +708,12 @@ private object Watches {
|
||||
val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
|
||||
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
|
||||
trainMatrix.setGroup(trainGroups.result())
|
||||
trainMatrix.setWeight(trainWeights.result())
|
||||
|
||||
val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
|
||||
if (trainTestRatio < 1.0) {
|
||||
testMatrix.setGroup(testGroups.result())
|
||||
testMatrix.setWeight(testWeights.result())
|
||||
}
|
||||
|
||||
val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
|
||||
|
||||
Reference in New Issue
Block a user