[jvm-packages] force use per-group weights in spark layer (#4118)

This commit is contained in:
Nan Zhu
2019-02-09 13:38:03 -08:00
committed by Jiaming Yuan
parent ba584e5e9f
commit 3320a52192
2 changed files with 56 additions and 15 deletions

View File

@@ -293,12 +293,11 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
test("distributed training with group data") {
val trainingRDD = sc.parallelize(Ranking.train, 5)
val (booster, metrics) = XGBoost.trainDistributed(
val (booster, _) = XGBoost.trainDistributed(
trainingRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap,
"objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers,
"missing" -> Float.NaN, "use_external_memory" -> false).toMap,
hasGroup = true)
assert(booster != null)
@@ -337,8 +336,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> numWorkers)
val xgb1 = new XGBoostClassifier(paramMap1)
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
val xgb1 = new XGBoostClassifier(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
val model1 = xgb1.fit(train)
assert(model1.summary.validationObjectiveHistory.length === 2)
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
@@ -367,8 +365,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise",
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
val xgb1 = new XGBoostRegressor(paramMap1)
xgb1.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
val model1 = xgb1.fit(train)
assert(model1 != null)
assert(model1.summary.validationObjectiveHistory.length === 2)