[jvm-packages] fix spark-rapids compatibility issue (#8240)

* [jvm-packages] fix spark-rapids compatibility issue

spark-rapids (from 22.10) has shimmed GpuColumnVector, which means
we can't call it directly. So this PR calls the unshimmed
GpuColumnVectorUtils via reflection instead, falling back to
GpuColumnVector for older spark-rapids versions.
Author: Bobby Wang (committed by GitHub)
Date:   2022-09-22 23:31:29 +08:00
parent ab342af242
commit 8d247f0d64
4 changed files with 20 additions and 7 deletions
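
For context: because spark-rapids 22.10 started shipping GpuColumnVector behind version-specific shims, a class that xgboost4j-gpu references directly may no longer be loadable at runtime. The fix below therefore probes for the small unshimmed entry point first. A minimal, hypothetical illustration of that probe (only the class name com.nvidia.spark.rapids.GpuColumnVectorUtils comes from this commit; the helper itself is illustrative, not part of the change):

// Hypothetical probe, not part of this commit: returns true when the
// unshimmed utility class exposed by spark-rapids >= 22.10 can be loaded.
def unshimmedUtilsAvailable(): Boolean =
  try {
    Class.forName("com.nvidia.spark.rapids.GpuColumnVectorUtils")
    true
  } catch {
    case _: ClassNotFoundException => false
  }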

View File: GpuPreXGBoost.scala

@@ -16,10 +16,8 @@
 package ml.dmlc.xgboost4j.scala.rapids.spark

-import scala.collection.Iterator
 import scala.collection.JavaConverters._

-import com.nvidia.spark.rapids.{GpuColumnVector}

 import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
 import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, DeviceQuantileDMatrix}
@@ -331,7 +329,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
       } else {
         try {
           currentBatch = new ColumnarBatch(
-            GpuColumnVector.extractColumns(table, dataTypes).map(_.copyToHost()),
+            GpuUtils.extractBatchToHost(table, dataTypes),
             table.getRowCount().toInt)
           val rowIterator = currentBatch.rowIterator().asScala
             .map(toUnsafe)

View File: GpuUtils.scala

@@ -17,16 +17,32 @@
 package ml.dmlc.xgboost4j.scala.rapids.spark

 import ai.rapids.cudf.Table
-import com.nvidia.spark.rapids.ColumnarRdd
+import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVector}
+import ml.dmlc.xgboost4j.scala.spark.util.Utils

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.types.{FloatType, NumericType, StructType}
+import org.apache.spark.sql.types.{DataType, FloatType, NumericType, StructType}
+import org.apache.spark.sql.vectorized.ColumnVector

 private[spark] object GpuUtils {

+  def extractBatchToHost(table: Table, types: Array[DataType]): Array[ColumnVector] = {
+    // spark-rapids has shimmed GpuColumnVector from 22.10
+    try {
+      val clazz = Utils.classForName("com.nvidia.spark.rapids.GpuColumnVectorUtils")
+      clazz.getDeclaredMethod("extractHostColumns", classOf[Table], classOf[Array[DataType]])
+        .invoke(null, table, types).asInstanceOf[Array[ColumnVector]]
+    } catch {
+      case _: ClassNotFoundException =>
+        // If it's an older version, use GpuColumnVector directly
+        GpuColumnVector.extractColumns(table, types).map(_.copyToHost())
+      case e: Throwable => throw e
+    }
+  }
+
   def toColumnarRdd(df: DataFrame): RDD[Table] = ColumnarRdd(df)

   def seqIntToSeqInteger(x: Seq[Int]): Seq[Integer] = x.map(new Integer(_))
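
A hedged usage sketch of the new helper (the hostBatchFrom wrapper is hypothetical; the call shape mirrors the GpuPreXGBoost hunk above): extractBatchToHost copies every column of a device-side cuDF Table into host memory, so the resulting ColumnarBatch can be iterated row by row on the CPU.

import ai.rapids.cudf.Table
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical wrapper, assuming GpuUtils from this commit is in scope:
// build a host-backed ColumnarBatch from a device Table.
def hostBatchFrom(table: Table, dataTypes: Array[DataType]): ColumnarBatch =
  new ColumnarBatch(
    GpuUtils.extractBatchToHost(table, dataTypes),
    table.getRowCount().toInt)

Dispatching on ClassNotFoundException rather than parsing a spark-rapids version string keeps the code compatible in both directions without a hard dependency on either API.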

View File

@@ -20,7 +20,6 @@ import java.nio.file.{Files, Path}
 import java.sql.{Date, Timestamp}
 import java.util.{Locale, TimeZone}

-import com.nvidia.spark.rapids.RapidsConf
 import org.scalatest.{BeforeAndAfterAll, FunSuite}

 import org.apache.spark.SparkConf

View File: Utils.scala

@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark.util
 import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}

 // based on org.apache.spark.util copy /paste
-private[spark] object Utils {
+object Utils {

   def getSparkClassLoader: ClassLoader = getClass.getClassLoader
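
The reflective path above relies on Utils.classForName, whose body is not shown in this diff. Since the object is documented as a copy of org.apache.spark.util, it plausibly resolves classes through the thread context class loader; a sketch under that assumption (not the verbatim xgboost code):

// Assumed shape of the classForName helper, modeled on Spark's
// org.apache.spark.util.Utils: prefer the thread context class loader
// (where the spark-rapids plugin jar is visible) over the loader that
// loaded this class.
object Utils {
  def getSparkClassLoader: ClassLoader = getClass.getClassLoader

  def getContextOrSparkClassLoader: ClassLoader =
    Option(Thread.currentThread().getContextClassLoader)
      .getOrElse(getSparkClassLoader)

  def classForName(className: String): Class[_] =
    Class.forName(className, true, getContextOrSparkClassLoader)
}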