[jvm-packages] fix evaluation when featuresCols is used (#7798)
This commit is contained in:
parent
4b00c64d96
commit
3f536b5308
@ -140,9 +140,13 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
|
|
||||||
val (xgbInput, featuresName) = est.vectorize(dataset)
|
val (xgbInput, featuresName) = est.vectorize(dataset)
|
||||||
|
|
||||||
|
val evalSets = est.getEvalSets(params).transform((_, df) => {
|
||||||
|
val (dfTransformed, _) = est.vectorize(df)
|
||||||
|
dfTransformed
|
||||||
|
})
|
||||||
|
|
||||||
(PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
|
(PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
|
||||||
est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params),
|
est.getNumWorkers, est.needDeterministicRepartitioning), evalSets, xgbInput)
|
||||||
xgbInput)
|
|
||||||
|
|
||||||
case _ => throw new RuntimeException("Unsupporting " + estimator)
|
case _ => throw new RuntimeException("Unsupporting " + estimator)
|
||||||
}
|
}
|
||||||
@ -154,7 +158,8 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
// transform the eval Dataset[_] to RDD[XGBLabeledPoint]
|
// transform the eval Dataset[_] to RDD[XGBLabeledPoint]
|
||||||
val evalRDDMap = evalSet.map {
|
val evalRDDMap = evalSet.map {
|
||||||
case (name, dataFrame) => (name,
|
case (name, dataFrame) => (name,
|
||||||
DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams, dataFrame).head)
|
DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams,
|
||||||
|
dataFrame.asInstanceOf[DataFrame]).head)
|
||||||
}
|
}
|
||||||
|
|
||||||
val hasGroup = packedParams.group.map(_ != defaultGroupColumn).getOrElse(false)
|
val hasGroup = packedParams.group.map(_ != defaultGroupColumn).getOrElse(false)
|
||||||
|
|||||||
@ -370,6 +370,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
val xgbClassifier = new XGBoostClassifier(paramMap)
|
val xgbClassifier = new XGBoostClassifier(paramMap)
|
||||||
.setFeaturesCol(featuresName)
|
.setFeaturesCol(featuresName)
|
||||||
.setLabelCol("label")
|
.setLabelCol("label")
|
||||||
|
.setEvalSets(Map("eval" -> xgbInput))
|
||||||
|
|
||||||
val model = xgbClassifier.fit(xgbInput)
|
val model = xgbClassifier.fit(xgbInput)
|
||||||
assert(model.getFeaturesCols.sameElements(featuresName))
|
assert(model.getFeaturesCols.sameElements(featuresName))
|
||||||
|
|||||||
@ -273,6 +273,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
val xgbClassifier = new XGBoostRegressor(paramMap)
|
val xgbClassifier = new XGBoostRegressor(paramMap)
|
||||||
.setFeaturesCol(featuresName)
|
.setFeaturesCol(featuresName)
|
||||||
.setLabelCol("label")
|
.setLabelCol("label")
|
||||||
|
.setEvalSets(Map("eval" -> xgbInput))
|
||||||
|
|
||||||
val model = xgbClassifier.fit(xgbInput)
|
val model = xgbClassifier.fit(xgbInput)
|
||||||
assert(model.getFeaturesCols.sameElements(featuresName))
|
assert(model.getFeaturesCols.sameElements(featuresName))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user