[jvm-packages] force the use of per-group weights in the Spark layer (#4118)
This commit is contained in:
parent
ba584e5e9f
commit
3320a52192
@ -618,18 +618,34 @@ private object Watches {
|
||||
val dms = nameAndlabeledPointGroupSets.map {
|
||||
case (name, labeledPointsGroups) =>
|
||||
val baseMargins = new mutable.ArrayBuilder.ofFloat
|
||||
val duplicatedItr = labeledPointsGroups.map(labeledPoints => {
|
||||
labeledPoints.map { labeledPoint =>
|
||||
val groupsInfo = new mutable.ArrayBuilder.ofInt
|
||||
val weights = new mutable.ArrayBuilder.ofFloat
|
||||
val iter = labeledPointsGroups.filter(labeledPointGroup => {
|
||||
var groupWeight = -1.0f
|
||||
var groupSize = 0
|
||||
labeledPointGroup.map { labeledPoint => {
|
||||
if (groupWeight < 0) {
|
||||
groupWeight = labeledPoint.weight
|
||||
} else if (groupWeight != labeledPoint.weight) {
|
||||
throw new IllegalArgumentException("the instances in the same group have to be" +
|
||||
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
|
||||
}
|
||||
baseMargins += labeledPoint.baseMargin
|
||||
groupSize += 1
|
||||
labeledPoint
|
||||
}
|
||||
}
|
||||
weights += groupWeight
|
||||
groupsInfo += groupSize
|
||||
true
|
||||
})
|
||||
val dMatrix = new DMatrix(duplicatedItr.flatMap(_.iterator),
|
||||
cachedDirName.map(_ + s"/$name").orNull)
|
||||
val dMatrix = new DMatrix(iter.flatMap(_.iterator), cachedDirName.map(_ + s"/$name").orNull)
|
||||
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
|
||||
if (baseMargin.isDefined) {
|
||||
dMatrix.setBaseMargin(baseMargin.get)
|
||||
}
|
||||
dMatrix.setGroup(groupsInfo.result())
|
||||
dMatrix.setWeight(weights.result())
|
||||
(name, dMatrix)
|
||||
}.toArray
|
||||
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
|
||||
@ -645,20 +661,46 @@ private object Watches {
|
||||
val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
|
||||
val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
|
||||
val testBaseMargins = new mutable.ArrayBuilder.ofFloat
|
||||
|
||||
val trainGroups = new mutable.ArrayBuilder.ofInt
|
||||
val testGroups = new mutable.ArrayBuilder.ofInt
|
||||
|
||||
val trainWeights = new mutable.ArrayBuilder.ofFloat
|
||||
val testWeights = new mutable.ArrayBuilder.ofFloat
|
||||
|
||||
val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
|
||||
val accepted = r.nextDouble() <= trainTestRatio
|
||||
if (!accepted) {
|
||||
var groupWeight = -1.0f
|
||||
var groupSize = 0
|
||||
labeledPointGroup.foreach(labeledPoint => {
|
||||
testPoints += labeledPoint
|
||||
testBaseMargins += labeledPoint.baseMargin
|
||||
if (groupWeight < 0) {
|
||||
groupWeight = labeledPoint.weight
|
||||
} else if (labeledPoint.weight != groupWeight) {
|
||||
throw new IllegalArgumentException("the instances in the same group have to be" +
|
||||
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
|
||||
}
|
||||
groupSize += 1
|
||||
})
|
||||
testGroups += labeledPointGroup.length
|
||||
testWeights += groupWeight
|
||||
testGroups += groupSize
|
||||
} else {
|
||||
labeledPointGroup.foreach(trainBaseMargins += _.baseMargin)
|
||||
trainGroups += labeledPointGroup.length
|
||||
var groupWeight = -1.0f
|
||||
var groupSize = 0
|
||||
labeledPointGroup.foreach { labeledPoint => {
|
||||
if (groupWeight < 0) {
|
||||
groupWeight = labeledPoint.weight
|
||||
} else if (labeledPoint.weight != groupWeight) {
|
||||
throw new IllegalArgumentException("the instances in the same group have to be" +
|
||||
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
|
||||
}
|
||||
trainBaseMargins += labeledPoint.baseMargin
|
||||
groupSize += 1
|
||||
}}
|
||||
trainWeights += groupWeight
|
||||
trainGroups += groupSize
|
||||
}
|
||||
accepted
|
||||
}
|
||||
@ -666,10 +708,12 @@ private object Watches {
|
||||
val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
|
||||
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
|
||||
trainMatrix.setGroup(trainGroups.result())
|
||||
trainMatrix.setWeight(trainWeights.result())
|
||||
|
||||
val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
|
||||
if (trainTestRatio < 1.0) {
|
||||
testMatrix.setGroup(testGroups.result())
|
||||
testMatrix.setWeight(testWeights.result())
|
||||
}
|
||||
|
||||
val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
|
||||
|
||||
@ -293,12 +293,11 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
|
||||
test("distributed training with group data") {
|
||||
val trainingRDD = sc.parallelize(Ranking.train, 5)
|
||||
val (booster, metrics) = XGBoost.trainDistributed(
|
||||
val (booster, _) = XGBoost.trainDistributed(
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
|
||||
"missing" -> Float.NaN).toMap,
|
||||
"objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"missing" -> Float.NaN, "use_external_memory" -> false).toMap,
|
||||
hasGroup = true)
|
||||
|
||||
assert(booster != null)
|
||||
@ -337,8 +336,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
"objective" -> "binary:logistic",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||
|
||||
val xgb1 = new XGBoostClassifier(paramMap1)
|
||||
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgb1 = new XGBoostClassifier(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val model1 = xgb1.fit(train)
|
||||
assert(model1.summary.validationObjectiveHistory.length === 2)
|
||||
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
||||
@ -367,8 +365,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
|
||||
val xgb1 = new XGBoostRegressor(paramMap1)
|
||||
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val model1 = xgb1.fit(train)
|
||||
assert(model1 != null)
|
||||
assert(model1.summary.validationObjectiveHistory.length === 2)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user