[jvm-packages] Don't cast to float if it's already float (#10386)

This commit is contained in:
Bobby Wang
2024-06-04 18:01:51 +08:00
committed by GitHub
parent 9b7633c01d
commit bc7643d35e

View File

@@ -89,9 +89,13 @@ private[spark] object GpuUtils {
val featureNameSet = featureNames.distinct
validateSchema(dataset.schema, featureNameSet, labelName, weightName, marginName, fitting)
val castToFloat = (ds: Dataset[_], colName: String) => {
val colMeta = ds.schema(colName).metadata
ds.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
val castToFloat = (df: DataFrame, colName: String) => {
if (df.schema(colName).dataType.isInstanceOf[FloatType]) {
df
} else {
val colMeta = df.schema(colName).metadata
df.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
}
}
val colNames = if (fitting) {
var names = featureNameSet :+ labelName