[jvm-packages] force use per-group weights in spark layer (#4118)
This commit is contained in:
@@ -293,12 +293,11 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
|
||||
test("distributed training with group data") {
|
||||
val trainingRDD = sc.parallelize(Ranking.train, 5)
|
||||
val (booster, metrics) = XGBoost.trainDistributed(
|
||||
val (booster, _) = XGBoost.trainDistributed(
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
|
||||
"missing" -> Float.NaN).toMap,
|
||||
"objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"missing" -> Float.NaN, "use_external_memory" -> false).toMap,
|
||||
hasGroup = true)
|
||||
|
||||
assert(booster != null)
|
||||
@@ -337,8 +336,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
"objective" -> "binary:logistic",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||
|
||||
val xgb1 = new XGBoostClassifier(paramMap1)
|
||||
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgb1 = new XGBoostClassifier(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val model1 = xgb1.fit(train)
|
||||
assert(model1.summary.validationObjectiveHistory.length === 2)
|
||||
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
|
||||
@@ -367,8 +365,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
|
||||
val xgb1 = new XGBoostRegressor(paramMap1)
|
||||
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val model1 = xgb1.fit(train)
|
||||
assert(model1 != null)
|
||||
assert(model1.summary.validationObjectiveHistory.length === 2)
|
||||
|
||||
Reference in New Issue
Block a user