From 3f536b5308d8145e42fb69fcfbe6b5564d5e3d2b Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Wed, 13 Apr 2022 12:52:50 +0800 Subject: [PATCH] [jvm-packages] fix evaluation when featuresCols is used (#7798) --- .../ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala | 11 ++++++++--- .../scala/spark/XGBoostClassifierSuite.scala | 1 + .../xgboost4j/scala/spark/XGBoostRegressorSuite.scala | 1 + 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala index 67deb6979..32fd6938e 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala @@ -140,9 +140,13 @@ object PreXGBoost extends PreXGBoostProvider { val (xgbInput, featuresName) = est.vectorize(dataset) + val evalSets = est.getEvalSets(params).transform((_, df) => { + val (dfTransformed, _) = est.vectorize(df) + dfTransformed + }) + (PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group, - est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params), - xgbInput) + est.getNumWorkers, est.needDeterministicRepartitioning), evalSets, xgbInput) case _ => throw new RuntimeException("Unsupporting " + estimator) } @@ -154,7 +158,8 @@ object PreXGBoost extends PreXGBoostProvider { // transform the eval Dataset[_] to RDD[XGBLabeledPoint] val evalRDDMap = evalSet.map { case (name, dataFrame) => (name, - DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams, dataFrame).head) + DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams, + dataFrame.asInstanceOf[DataFrame]).head) } val hasGroup = packedParams.group.map(_ != defaultGroupColumn).getOrElse(false) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala index 91f4a4cfa..0fa851f57 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala @@ -370,6 +370,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest { val xgbClassifier = new XGBoostClassifier(paramMap) .setFeaturesCol(featuresName) .setLabelCol("label") + .setEvalSets(Map("eval" -> xgbInput)) val model = xgbClassifier.fit(xgbInput) assert(model.getFeaturesCols.sameElements(featuresName)) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index 04e510640..e427c17e3 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -273,6 +273,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest { val xgbClassifier = new XGBoostRegressor(paramMap) .setFeaturesCol(featuresName) .setLabelCol("label") + .setEvalSets(Map("eval" -> xgbInput)) val model = xgbClassifier.fit(xgbInput) assert(model.getFeaturesCols.sameElements(featuresName))