[jvm-packages] fix spark-rapids compatibility issue (#8240)
* [jvm-packages] fix spark-rapids compatibility issue

spark-rapids (from 22.10) has shimmed GpuColumnVector, which means we can't call it directly anymore. So this PR calls the unshimmed GpuColumnVectorUtils (looked up via reflection) instead, falling back to GpuColumnVector for older spark-rapids releases.
This commit is contained in:
parent ab342af242
commit 8d247f0d64
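The fix boils down to a reflective lookup: the unshimmed GpuColumnVectorUtils keeps a stable binary name across spark-rapids shim layers, so it can be resolved at runtime, while the old direct call is kept as a fallback for releases that predate the shimming. A minimal standalone sketch of that pattern (it mirrors the GpuUtils change in the diff below; the object name ShimSafeExtract is hypothetical, not the verbatim PR code):

import ai.rapids.cudf.Table
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnVector

// Sketch only: illustrates the reflective-lookup-with-fallback pattern.
object ShimSafeExtract {
  def extractHostColumns(table: Table, types: Array[DataType]): Array[ColumnVector] =
    try {
      // spark-rapids 22.10+: resolve the unshimmed entry point at runtime.
      val clazz = Class.forName("com.nvidia.spark.rapids.GpuColumnVectorUtils")
      clazz.getDeclaredMethod("extractHostColumns", classOf[Table], classOf[Array[DataType]])
        .invoke(null, table, types).asInstanceOf[Array[ColumnVector]]
    } catch {
      case _: ClassNotFoundException =>
        // Pre-22.10: GpuColumnVector is not shimmed yet, call it directly.
        com.nvidia.spark.rapids.GpuColumnVector.extractColumns(table, types).map(_.copyToHost())
    }
}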
@@ -16,10 +16,8 @@
 package ml.dmlc.xgboost4j.scala.rapids.spark

 import scala.collection.Iterator
 import scala.collection.JavaConverters._

-import com.nvidia.spark.rapids.{GpuColumnVector}
-
 import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
 import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, DeviceQuantileDMatrix}
@@ -331,7 +329,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
     } else {
       try {
         currentBatch = new ColumnarBatch(
-          GpuColumnVector.extractColumns(table, dataTypes).map(_.copyToHost()),
+          GpuUtils.extractBatchToHost(table, dataTypes),
           table.getRowCount().toInt)
         val rowIterator = currentBatch.rowIterator().asScala
           .map(toUnsafe)
@@ -17,16 +17,32 @@
 package ml.dmlc.xgboost4j.scala.rapids.spark

 import ai.rapids.cudf.Table
-import com.nvidia.spark.rapids.ColumnarRdd
+import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVector}
+import ml.dmlc.xgboost4j.scala.spark.util.Utils

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.types.{FloatType, NumericType, StructType}
+import org.apache.spark.sql.types.{DataType, FloatType, NumericType, StructType}
+import org.apache.spark.sql.vectorized.ColumnVector

 private[spark] object GpuUtils {

+  def extractBatchToHost(table: Table, types: Array[DataType]): Array[ColumnVector] = {
+    // spark-rapids has shimmed GpuColumnVector from 22.10, so resolve the unshimmed class reflectively.
+    try {
+      val clazz = Utils.classForName("com.nvidia.spark.rapids.GpuColumnVectorUtils")
+      clazz.getDeclaredMethod("extractHostColumns", classOf[Table], classOf[Array[DataType]])
+        .invoke(null, table, types).asInstanceOf[Array[ColumnVector]]
+    } catch {
+      case _: ClassNotFoundException =>
+        // Older spark-rapids versions: use GpuColumnVector directly.
+        GpuColumnVector.extractColumns(table, types).map(_.copyToHost())
+      case e: Throwable => throw e
+    }
+  }
+
   def toColumnarRdd(df: DataFrame): RDD[Table] = ColumnarRdd(df)

   def seqIntToSeqInteger(x: Seq[Int]): Seq[Integer] = x.map(new Integer(_))
@@ -20,7 +20,6 @@ import java.nio.file.{Files, Path}
 import java.sql.{Date, Timestamp}
 import java.util.{Locale, TimeZone}

-import com.nvidia.spark.rapids.RapidsConf
 import org.scalatest.{BeforeAndAfterAll, FunSuite}

 import org.apache.spark.SparkConf
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark.util
 import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}

 // based on org.apache.spark.util copy/paste
-private[spark] object Utils {
+object Utils {

   def getSparkClassLoader: ClassLoader = getClass.getClassLoader

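For reference, GpuUtils.extractBatchToHost depends on Utils.classForName, which this last hunk makes reachable from the rapids package by dropping the private[spark] modifier (private[spark] scopes Utils to ml.dmlc.xgboost4j.scala.spark, while GpuUtils lives in ml.dmlc.xgboost4j.scala.rapids.spark). The helper's body isn't shown in this diff; a minimal sketch of the Spark-style shape it presumably has, per the copy/paste comment above, with classForName's exact signature being an assumption here:

package ml.dmlc.xgboost4j.scala.spark.util

// Sketch only: the real object is copied from org.apache.spark.util.
object Utils {
  def getSparkClassLoader: ClassLoader = getClass.getClassLoader

  // Assumed shape of classForName: initialize the class via the Spark
  // class loader so plugin classes like GpuColumnVectorUtils are found.
  def classForName(className: String): Class[_] =
    Class.forName(className, true, getSparkClassLoader)
}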