[jvm-packages] fix evaluation when featuresCols is used (#7798)
This commit is contained in:
parent
4b00c64d96
commit
3f536b5308
@ -140,9 +140,13 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
|
||||
val (xgbInput, featuresName) = est.vectorize(dataset)
|
||||
|
||||
val evalSets = est.getEvalSets(params).transform((_, df) => {
|
||||
val (dfTransformed, _) = est.vectorize(df)
|
||||
dfTransformed
|
||||
})
|
||||
|
||||
(PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
|
||||
est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params),
|
||||
xgbInput)
|
||||
est.getNumWorkers, est.needDeterministicRepartitioning), evalSets, xgbInput)
|
||||
|
||||
case _ => throw new RuntimeException("Unsupporting " + estimator)
|
||||
}
|
||||
@ -154,7 +158,8 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
// transform the eval Dataset[_] to RDD[XGBLabeledPoint]
|
||||
val evalRDDMap = evalSet.map {
|
||||
case (name, dataFrame) => (name,
|
||||
DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams, dataFrame).head)
|
||||
DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams,
|
||||
dataFrame.asInstanceOf[DataFrame]).head)
|
||||
}
|
||||
|
||||
val hasGroup = packedParams.group.map(_ != defaultGroupColumn).getOrElse(false)
|
||||
|
||||
@ -370,6 +370,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
val xgbClassifier = new XGBoostClassifier(paramMap)
|
||||
.setFeaturesCol(featuresName)
|
||||
.setLabelCol("label")
|
||||
.setEvalSets(Map("eval" -> xgbInput))
|
||||
|
||||
val model = xgbClassifier.fit(xgbInput)
|
||||
assert(model.getFeaturesCols.sameElements(featuresName))
|
||||
|
||||
@ -273,6 +273,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
||||
val xgbClassifier = new XGBoostRegressor(paramMap)
|
||||
.setFeaturesCol(featuresName)
|
||||
.setLabelCol("label")
|
||||
.setEvalSets(Map("eval" -> xgbInput))
|
||||
|
||||
val model = xgbClassifier.fit(xgbInput)
|
||||
assert(model.getFeaturesCols.sameElements(featuresName))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user