diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala
index 12f18be88..6fbe6e129 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala
@@ -16,10 +16,8 @@ package ml.dmlc.xgboost4j.scala.rapids.spark
 
-import scala.collection.Iterator
 import scala.collection.JavaConverters._
 
-import com.nvidia.spark.rapids.{GpuColumnVector}
 import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
 import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, DeviceQuantileDMatrix}
@@ -331,7 +329,7 @@
       } else {
         try {
           currentBatch = new ColumnarBatch(
-            GpuColumnVector.extractColumns(table, dataTypes).map(_.copyToHost()),
+            GpuUtils.extractBatchToHost(table, dataTypes),
             table.getRowCount().toInt)
           val rowIterator = currentBatch.rowIterator().asScala
             .map(toUnsafe)
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuUtils.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuUtils.scala
index fdd1061a7..f5876fded 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuUtils.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuUtils.scala
@@ -17,16 +17,32 @@ package ml.dmlc.xgboost4j.scala.rapids.spark
 
 import ai.rapids.cudf.Table
-import com.nvidia.spark.rapids.ColumnarRdd
+import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVector}
+import ml.dmlc.xgboost4j.scala.spark.util.Utils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.types.{FloatType, NumericType, StructType}
+import org.apache.spark.sql.types.{DataType, FloatType, NumericType, StructType}
+import org.apache.spark.sql.vectorized.ColumnVector
 
 private[spark] object GpuUtils {
 
+  def extractBatchToHost(table: Table, types: Array[DataType]): Array[ColumnVector] = {
+    // spark-rapids has shimmed the GpuColumnVector from 22.10
+    try {
+      val clazz = Utils.classForName("com.nvidia.spark.rapids.GpuColumnVectorUtils")
+      clazz.getDeclaredMethod("extractHostColumns", classOf[Table], classOf[Array[DataType]])
+        .invoke(null, table, types).asInstanceOf[Array[ColumnVector]]
+    } catch {
+      case _: ClassNotFoundException =>
+        // If it's older version, use the GpuColumnVector
+        GpuColumnVector.extractColumns(table, types).map(_.copyToHost())
+      case e: Throwable => throw e
+    }
+  }
+
   def toColumnarRdd(df: DataFrame): RDD[Table] = ColumnarRdd(df)
 
   def seqIntToSeqInteger(x: Seq[Int]): Seq[Integer] = x.map(new Integer(_))
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuTestSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuTestSuite.scala
index 4d82459fa..a8d7a81ed 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuTestSuite.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuTestSuite.scala
@@ -20,7 +20,6 @@ import java.nio.file.{Files, Path}
 import java.sql.{Date, Timestamp}
 import java.util.{Locale, TimeZone}
 
-import com.nvidia.spark.rapids.RapidsConf
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import org.apache.spark.SparkConf
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/Utils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/Utils.scala
index d5e133b4c..710dd9adc 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/Utils.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/Utils.scala
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark.util
 
 import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}
 
 // based on org.apache.spark.util copy /paste
-private[spark] object Utils {
+object Utils {
 
   def getSparkClassLoader: ClassLoader = getClass.getClassLoader
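
A minimal standalone sketch of the reflection-based fallback pattern that the new GpuUtils.extractBatchToHost relies on, written so it runs without spark-rapids or cuDF on the classpath. The names example.NewUtils, extract and legacyExtract are placeholders for illustration only, not spark-rapids or XGBoost APIs; in the patch the lookup targets com.nvidia.spark.rapids.GpuColumnVectorUtils.extractHostColumns and the fallback is GpuColumnVector.extractColumns(...).map(_.copyToHost()).

// Sketch only: the version-shim pattern behind GpuUtils.extractBatchToHost.
// "example.NewUtils" is a placeholder class name, not a real spark-rapids class.
object ShimSketch {
  // Stand-in for the pre-22.10 code path (GpuColumnVector.extractColumns + copyToHost).
  def legacyExtract(rows: Array[Int]): Array[String] = rows.map(_.toString)

  def extract(rows: Array[Int]): Array[String] = {
    try {
      // Prefer the newer entry point if it is present on the classpath.
      val clazz = Class.forName("example.NewUtils")
      clazz.getDeclaredMethod("extract", classOf[Array[Int]])
        .invoke(null, rows).asInstanceOf[Array[String]]
    } catch {
      // Only a missing class signals an older release; other errors propagate.
      case _: ClassNotFoundException => legacyExtract(rows)
    }
  }

  def main(args: Array[String]): Unit =
    println(extract(Array(1, 2, 3)).mkString(", ")) // prints "1, 2, 3" via the fallback
}

Catching only ClassNotFoundException keeps genuine reflection or runtime failures visible, which is why the patched extractBatchToHost rethrows every other Throwable.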