[jvm-packages] Don't cast to float if it's already float (#10386)
This commit is contained in:
parent
9b7633c01d
commit
bc7643d35e
@ -89,9 +89,13 @@ private[spark] object GpuUtils {
|
||||
val featureNameSet = featureNames.distinct
|
||||
validateSchema(dataset.schema, featureNameSet, labelName, weightName, marginName, fitting)
|
||||
|
||||
val castToFloat = (ds: Dataset[_], colName: String) => {
|
||||
val colMeta = ds.schema(colName).metadata
|
||||
ds.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
|
||||
val castToFloat = (df: DataFrame, colName: String) => {
|
||||
if (df.schema(colName).dataType.isInstanceOf[FloatType]) {
|
||||
df
|
||||
} else {
|
||||
val colMeta = df.schema(colName).metadata
|
||||
df.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
|
||||
}
|
||||
}
|
||||
val colNames = if (fitting) {
|
||||
var names = featureNameSet :+ labelName
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user