[jvm-packages] force the use of per-group weights in the Spark layer (#4118)
This commit is contained in:
parent
ba584e5e9f
commit
3320a52192
@ -618,18 +618,34 @@ private object Watches {
|
|||||||
val dms = nameAndlabeledPointGroupSets.map {
|
val dms = nameAndlabeledPointGroupSets.map {
|
||||||
case (name, labeledPointsGroups) =>
|
case (name, labeledPointsGroups) =>
|
||||||
val baseMargins = new mutable.ArrayBuilder.ofFloat
|
val baseMargins = new mutable.ArrayBuilder.ofFloat
|
||||||
val duplicatedItr = labeledPointsGroups.map(labeledPoints => {
|
val groupsInfo = new mutable.ArrayBuilder.ofInt
|
||||||
labeledPoints.map { labeledPoint =>
|
val weights = new mutable.ArrayBuilder.ofFloat
|
||||||
|
val iter = labeledPointsGroups.filter(labeledPointGroup => {
|
||||||
|
var groupWeight = -1.0f
|
||||||
|
var groupSize = 0
|
||||||
|
labeledPointGroup.map { labeledPoint => {
|
||||||
|
if (groupWeight < 0) {
|
||||||
|
groupWeight = labeledPoint.weight
|
||||||
|
} else if (groupWeight != labeledPoint.weight) {
|
||||||
|
throw new IllegalArgumentException("the instances in the same group have to be" +
|
||||||
|
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
|
||||||
|
}
|
||||||
baseMargins += labeledPoint.baseMargin
|
baseMargins += labeledPoint.baseMargin
|
||||||
|
groupSize += 1
|
||||||
labeledPoint
|
labeledPoint
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
weights += groupWeight
|
||||||
|
groupsInfo += groupSize
|
||||||
|
true
|
||||||
})
|
})
|
||||||
val dMatrix = new DMatrix(duplicatedItr.flatMap(_.iterator),
|
val dMatrix = new DMatrix(iter.flatMap(_.iterator), cachedDirName.map(_ + s"/$name").orNull)
|
||||||
cachedDirName.map(_ + s"/$name").orNull)
|
|
||||||
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
|
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
|
||||||
if (baseMargin.isDefined) {
|
if (baseMargin.isDefined) {
|
||||||
dMatrix.setBaseMargin(baseMargin.get)
|
dMatrix.setBaseMargin(baseMargin.get)
|
||||||
}
|
}
|
||||||
|
dMatrix.setGroup(groupsInfo.result())
|
||||||
|
dMatrix.setWeight(weights.result())
|
||||||
(name, dMatrix)
|
(name, dMatrix)
|
||||||
}.toArray
|
}.toArray
|
||||||
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
|
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
|
||||||
@ -645,20 +661,46 @@ private object Watches {
|
|||||||
val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
|
val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
|
||||||
val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
|
val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
|
||||||
val testBaseMargins = new mutable.ArrayBuilder.ofFloat
|
val testBaseMargins = new mutable.ArrayBuilder.ofFloat
|
||||||
|
|
||||||
val trainGroups = new mutable.ArrayBuilder.ofInt
|
val trainGroups = new mutable.ArrayBuilder.ofInt
|
||||||
val testGroups = new mutable.ArrayBuilder.ofInt
|
val testGroups = new mutable.ArrayBuilder.ofInt
|
||||||
|
|
||||||
|
val trainWeights = new mutable.ArrayBuilder.ofFloat
|
||||||
|
val testWeights = new mutable.ArrayBuilder.ofFloat
|
||||||
|
|
||||||
val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
|
val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
|
||||||
val accepted = r.nextDouble() <= trainTestRatio
|
val accepted = r.nextDouble() <= trainTestRatio
|
||||||
if (!accepted) {
|
if (!accepted) {
|
||||||
|
var groupWeight = -1.0f
|
||||||
|
var groupSize = 0
|
||||||
labeledPointGroup.foreach(labeledPoint => {
|
labeledPointGroup.foreach(labeledPoint => {
|
||||||
testPoints += labeledPoint
|
testPoints += labeledPoint
|
||||||
testBaseMargins += labeledPoint.baseMargin
|
testBaseMargins += labeledPoint.baseMargin
|
||||||
|
if (groupWeight < 0) {
|
||||||
|
groupWeight = labeledPoint.weight
|
||||||
|
} else if (labeledPoint.weight != groupWeight) {
|
||||||
|
throw new IllegalArgumentException("the instances in the same group have to be" +
|
||||||
|
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
|
||||||
|
}
|
||||||
|
groupSize += 1
|
||||||
})
|
})
|
||||||
testGroups += labeledPointGroup.length
|
testWeights += groupWeight
|
||||||
|
testGroups += groupSize
|
||||||
} else {
|
} else {
|
||||||
labeledPointGroup.foreach(trainBaseMargins += _.baseMargin)
|
var groupWeight = -1.0f
|
||||||
trainGroups += labeledPointGroup.length
|
var groupSize = 0
|
||||||
|
labeledPointGroup.foreach { labeledPoint => {
|
||||||
|
if (groupWeight < 0) {
|
||||||
|
groupWeight = labeledPoint.weight
|
||||||
|
} else if (labeledPoint.weight != groupWeight) {
|
||||||
|
throw new IllegalArgumentException("the instances in the same group have to be" +
|
||||||
|
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
|
||||||
|
}
|
||||||
|
trainBaseMargins += labeledPoint.baseMargin
|
||||||
|
groupSize += 1
|
||||||
|
}}
|
||||||
|
trainWeights += groupWeight
|
||||||
|
trainGroups += groupSize
|
||||||
}
|
}
|
||||||
accepted
|
accepted
|
||||||
}
|
}
|
||||||
@ -666,10 +708,12 @@ private object Watches {
|
|||||||
val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
|
val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
|
||||||
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
|
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
|
||||||
trainMatrix.setGroup(trainGroups.result())
|
trainMatrix.setGroup(trainGroups.result())
|
||||||
|
trainMatrix.setWeight(trainWeights.result())
|
||||||
|
|
||||||
val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
|
val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
|
||||||
if (trainTestRatio < 1.0) {
|
if (trainTestRatio < 1.0) {
|
||||||
testMatrix.setGroup(testGroups.result())
|
testMatrix.setGroup(testGroups.result())
|
||||||
|
testMatrix.setWeight(testWeights.result())
|
||||||
}
|
}
|
||||||
|
|
||||||
val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
|
val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
|
||||||
|
|||||||
@ -293,12 +293,11 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
|
|
||||||
test("distributed training with group data") {
|
test("distributed training with group data") {
|
||||||
val trainingRDD = sc.parallelize(Ranking.train, 5)
|
val trainingRDD = sc.parallelize(Ranking.train, 5)
|
||||||
val (booster, metrics) = XGBoost.trainDistributed(
|
val (booster, _) = XGBoost.trainDistributed(
|
||||||
trainingRDD,
|
trainingRDD,
|
||||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
"objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||||
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
|
"missing" -> Float.NaN, "use_external_memory" -> false).toMap,
|
||||||
"missing" -> Float.NaN).toMap,
|
|
||||||
hasGroup = true)
|
hasGroup = true)
|
||||||
|
|
||||||
assert(booster != null)
|
assert(booster != null)
|
||||||
@ -337,8 +336,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
"objective" -> "binary:logistic",
|
"objective" -> "binary:logistic",
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
|
||||||
val xgb1 = new XGBoostClassifier(paramMap1)
|
val xgb1 = new XGBoostClassifier(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||||
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
|
||||||
val model1 = xgb1.fit(train)
|
val model1 = xgb1.fit(train)
|
||||||
assert(model1.summary.validationObjectiveHistory.length === 2)
|
assert(model1.summary.validationObjectiveHistory.length === 2)
|
||||||
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
||||||
@ -367,8 +365,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "rank:pairwise",
|
"objective" -> "rank:pairwise",
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
|
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
|
||||||
val xgb1 = new XGBoostRegressor(paramMap1)
|
val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||||
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
|
||||||
val model1 = xgb1.fit(train)
|
val model1 = xgb1.fit(train)
|
||||||
assert(model1 != null)
|
assert(model1 != null)
|
||||||
assert(model1.summary.validationObjectiveHistory.length === 2)
|
assert(model1.summary.validationObjectiveHistory.length === 2)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user