[jvm-packages] Don't cast to float if it's already float (#10386)

This commit is contained in:
Bobby Wang 2024-06-04 18:01:51 +08:00 committed by GitHub
parent 9b7633c01d
commit bc7643d35e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -89,9 +89,13 @@ private[spark] object GpuUtils {
val featureNameSet = featureNames.distinct
validateSchema(dataset.schema, featureNameSet, labelName, weightName, marginName, fitting)
val castToFloat = (ds: Dataset[_], colName: String) => {
val colMeta = ds.schema(colName).metadata
ds.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
val castToFloat = (df: DataFrame, colName: String) => {
if (df.schema(colName).dataType.isInstanceOf[FloatType]) {
df
} else {
val colMeta = df.schema(colName).metadata
df.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
}
}
val colNames = if (fitting) {
var names = featureNameSet :+ labelName