[jvm-packages] fix spark-rapids compatibility issue (#8240)
* [jvm-packages] fix spark-rapids compatibility issue spark-rapids (from 22.10) has shimmed GpuColumnVector, which means we can't call it directly. So this PR call the UnshimmedGpuColumnVector
This commit is contained in:
parent
ab342af242
commit
8d247f0d64
@ -16,10 +16,8 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.rapids.spark
|
package ml.dmlc.xgboost4j.scala.rapids.spark
|
||||||
|
|
||||||
import scala.collection.Iterator
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import com.nvidia.spark.rapids.{GpuColumnVector}
|
|
||||||
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
||||||
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
|
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, DeviceQuantileDMatrix}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, DeviceQuantileDMatrix}
|
||||||
@ -331,7 +329,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
} else {
|
} else {
|
||||||
try {
|
try {
|
||||||
currentBatch = new ColumnarBatch(
|
currentBatch = new ColumnarBatch(
|
||||||
GpuColumnVector.extractColumns(table, dataTypes).map(_.copyToHost()),
|
GpuUtils.extractBatchToHost(table, dataTypes),
|
||||||
table.getRowCount().toInt)
|
table.getRowCount().toInt)
|
||||||
val rowIterator = currentBatch.rowIterator().asScala
|
val rowIterator = currentBatch.rowIterator().asScala
|
||||||
.map(toUnsafe)
|
.map(toUnsafe)
|
||||||
|
|||||||
@ -17,16 +17,32 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.rapids.spark
|
package ml.dmlc.xgboost4j.scala.rapids.spark
|
||||||
|
|
||||||
import ai.rapids.cudf.Table
|
import ai.rapids.cudf.Table
|
||||||
import com.nvidia.spark.rapids.ColumnarRdd
|
import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVector}
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.util.Utils
|
||||||
|
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.ml.param.{Param, Params}
|
import org.apache.spark.ml.param.{Param, Params}
|
||||||
import org.apache.spark.sql.functions.col
|
import org.apache.spark.sql.functions.col
|
||||||
import org.apache.spark.sql.types.{FloatType, NumericType, StructType}
|
import org.apache.spark.sql.types.{DataType, FloatType, NumericType, StructType}
|
||||||
|
import org.apache.spark.sql.vectorized.ColumnVector
|
||||||
|
|
||||||
private[spark] object GpuUtils {
|
private[spark] object GpuUtils {
|
||||||
|
|
||||||
|
def extractBatchToHost(table: Table, types: Array[DataType]): Array[ColumnVector] = {
|
||||||
|
// spark-rapids has shimmed the GpuColumnVector from 22.10
|
||||||
|
try {
|
||||||
|
val clazz = Utils.classForName("com.nvidia.spark.rapids.GpuColumnVectorUtils")
|
||||||
|
clazz.getDeclaredMethod("extractHostColumns", classOf[Table], classOf[Array[DataType]])
|
||||||
|
.invoke(null, table, types).asInstanceOf[Array[ColumnVector]]
|
||||||
|
} catch {
|
||||||
|
case _: ClassNotFoundException =>
|
||||||
|
// If it's older version, use the GpuColumnVector
|
||||||
|
GpuColumnVector.extractColumns(table, types).map(_.copyToHost())
|
||||||
|
case e: Throwable => throw e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def toColumnarRdd(df: DataFrame): RDD[Table] = ColumnarRdd(df)
|
def toColumnarRdd(df: DataFrame): RDD[Table] = ColumnarRdd(df)
|
||||||
|
|
||||||
def seqIntToSeqInteger(x: Seq[Int]): Seq[Integer] = x.map(new Integer(_))
|
def seqIntToSeqInteger(x: Seq[Int]): Seq[Integer] = x.map(new Integer(_))
|
||||||
|
|||||||
@ -20,7 +20,6 @@ import java.nio.file.{Files, Path}
|
|||||||
import java.sql.{Date, Timestamp}
|
import java.sql.{Date, Timestamp}
|
||||||
import java.util.{Locale, TimeZone}
|
import java.util.{Locale, TimeZone}
|
||||||
|
|
||||||
import com.nvidia.spark.rapids.RapidsConf
|
|
||||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||||
|
|
||||||
import org.apache.spark.SparkConf
|
import org.apache.spark.SparkConf
|
||||||
|
|||||||
@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark.util
|
|||||||
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}
|
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}
|
||||||
|
|
||||||
// based on org.apache.spark.util copy /paste
|
// based on org.apache.spark.util copy /paste
|
||||||
private[spark] object Utils {
|
object Utils {
|
||||||
|
|
||||||
def getSparkClassLoader: ClassLoader = getClass.getClassLoader
|
def getSparkClassLoader: ClassLoader = getClass.getClassLoader
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user