[jvm-packages] Fix "obj_type" error to enable custom objectives and evaluations (#3646)
credits to @mmui
This commit is contained in:
@@ -140,15 +140,18 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
}
|
||||
|
||||
test("XGBoost and Spark parameters synchronize correctly") {
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
|
||||
"objective_type" -> "classification")
|
||||
// from xgboost params to spark params
|
||||
val xgb = new XGBoostClassifier(xgbParamMap)
|
||||
assert(xgb.getEta === 1.0)
|
||||
assert(xgb.getObjective === "binary:logistic")
|
||||
assert(xgb.getObjectiveType === "classification")
|
||||
// from spark to xgboost params
|
||||
val xgbCopy = xgb.copy(ParamMap.empty)
|
||||
assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
|
||||
assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
|
||||
assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
|
||||
val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
|
||||
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user