[jvm-packages] Don't cast to float if it's already float (#10386)
This commit is contained in:
parent
9b7633c01d
commit
bc7643d35e
@ -89,9 +89,13 @@ private[spark] object GpuUtils {
|
|||||||
val featureNameSet = featureNames.distinct
|
val featureNameSet = featureNames.distinct
|
||||||
validateSchema(dataset.schema, featureNameSet, labelName, weightName, marginName, fitting)
|
validateSchema(dataset.schema, featureNameSet, labelName, weightName, marginName, fitting)
|
||||||
|
|
||||||
val castToFloat = (ds: Dataset[_], colName: String) => {
|
val castToFloat = (df: DataFrame, colName: String) => {
|
||||||
val colMeta = ds.schema(colName).metadata
|
if (df.schema(colName).dataType.isInstanceOf[FloatType]) {
|
||||||
ds.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
|
df
|
||||||
|
} else {
|
||||||
|
val colMeta = df.schema(colName).metadata
|
||||||
|
df.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
val colNames = if (fitting) {
|
val colNames = if (fitting) {
|
||||||
var names = featureNameSet :+ labelName
|
var names = featureNameSet :+ labelName
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user