From bc7643d35eead36aae89e0f99f48e3cd6d4fea88 Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Tue, 4 Jun 2024 18:01:51 +0800
Subject: [PATCH] [jvm-packages] Don't cast to float if it's already float (#10386)

---
 .../dmlc/xgboost4j/scala/rapids/spark/GpuUtils.scala | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuUtils.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuUtils.scala
index c88aefa4e..79a8d5449 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuUtils.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuUtils.scala
@@ -89,9 +89,13 @@ private[spark] object GpuUtils {
     val featureNameSet = featureNames.distinct
     validateSchema(dataset.schema, featureNameSet, labelName, weightName, marginName, fitting)
 
-    val castToFloat = (ds: Dataset[_], colName: String) => {
-      val colMeta = ds.schema(colName).metadata
-      ds.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
+    val castToFloat = (df: DataFrame, colName: String) => {
+      if (df.schema(colName).dataType.isInstanceOf[FloatType]) {
+        df
+      } else {
+        val colMeta = df.schema(colName).metadata
+        df.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
+      }
     }
     val colNames = if (fitting) {
       var names = featureNameSet :+ labelName