[jvm-packages] fix evaluation when featuresCols is used (#7798)

This commit is contained in:
Bobby Wang 2022-04-13 12:52:50 +08:00 committed by GitHub
parent 4b00c64d96
commit 3f536b5308
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 3 deletions

View File

@ -140,9 +140,13 @@ object PreXGBoost extends PreXGBoostProvider {
val (xgbInput, featuresName) = est.vectorize(dataset)
val evalSets = est.getEvalSets(params).transform((_, df) => {
val (dfTransformed, _) = est.vectorize(df)
dfTransformed
})
(PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params),
xgbInput)
est.getNumWorkers, est.needDeterministicRepartitioning), evalSets, xgbInput)
case _ => throw new RuntimeException("Unsupporting " + estimator)
}
@ -154,7 +158,8 @@ object PreXGBoost extends PreXGBoostProvider {
// transform the eval Dataset[_] to RDD[XGBLabeledPoint]
val evalRDDMap = evalSet.map {
case (name, dataFrame) => (name,
DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams, dataFrame).head)
DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams,
dataFrame.asInstanceOf[DataFrame]).head)
}
val hasGroup = packedParams.group.map(_ != defaultGroupColumn).getOrElse(false)

View File

@ -370,6 +370,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
val xgbClassifier = new XGBoostClassifier(paramMap)
.setFeaturesCol(featuresName)
.setLabelCol("label")
.setEvalSets(Map("eval" -> xgbInput))
val model = xgbClassifier.fit(xgbInput)
assert(model.getFeaturesCols.sameElements(featuresName))

View File

@ -273,6 +273,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
val xgbClassifier = new XGBoostRegressor(paramMap)
.setFeaturesCol(featuresName)
.setLabelCol("label")
.setEvalSets(Map("eval" -> xgbInput))
val model = xgbClassifier.fit(xgbInput)
assert(model.getFeaturesCols.sameElements(featuresName))