[jvm-packages] Add Rapids plugin support (#7491)

* Add GPU pre-processing pipeline.
Bobby Wang 2021-12-17 13:11:12 +08:00 committed by GitHub
parent 5b1161bb64
commit 24e25802a7
24 changed files with 2035 additions and 37 deletions


@@ -34,13 +34,16 @@
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<flink.version>1.7.2</flink.version>
<spark.version>3.0.0</spark.version>
<spark.version>3.0.1</spark.version>
<scala.version>2.12.8</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>2.7.3</hadoop.version>
<maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
<log.capi.invocation>OFF</log.capi.invocation>
<use.cuda>OFF</use.cuda>
<cudf.version>21.08.2</cudf.version>
<spark.rapids.version>21.08.0</spark.rapids.version>
<cudf.classifier>cuda11</cudf.classifier>
</properties>
<repositories>
<repository>


@@ -12,11 +12,6 @@
<version>1.6.0-SNAPSHOT</version>
<packaging>jar</packaging>
<properties>
<cudf.version>21.08.2</cudf.version>
<cudf.classifier>cuda11</cudf.classifier>
</properties>
<dependencies>
<dependency>
<groupId>ai.rapids</groupId>


@@ -44,5 +44,18 @@
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>ai.rapids</groupId>
<artifactId>cudf</artifactId>
<version>${cudf.version}</version>
<classifier>${cudf.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.nvidia</groupId>
<artifactId>rapids-4-spark_${scala.binary.version}</artifactId>
<version>${spark.rapids.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>
</project>


@@ -0,0 +1,68 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.nvidia.spark;
import java.util.List;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Table;
import org.apache.spark.sql.types.*;
/**
* Wrapper of a cudf Table together with its Spark schema, for use from Scala
*/
public class GpuColumnBatch implements AutoCloseable {
private final StructType schema;
private Table table; // the original Table
public GpuColumnBatch(Table table, StructType schema) {
this.table = table;
this.schema = schema;
}
@Override
public void close() {
if (table != null) {
table.close();
table = null;
}
}
/** Slice the columns indicated by indices into a new Table. */
public Table slice(List<Integer> indices) {
if (indices == null || indices.size() == 0) {
return null;
}
int len = indices.size();
ColumnVector[] cv = new ColumnVector[len];
for (int i = 0; i < len; i++) {
int index = indices.get(i);
if (index >= table.getNumberOfColumns()) {
throw new RuntimeException("Wrong index");
}
cv[i] = table.getColumn(index);
}
return new Table(cv);
}
public StructType getSchema() {
return schema;
}
}
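
For reference, a minimal sketch of how this wrapper might be used from Scala. The column
values and schema below are hypothetical and only illustrate the slice call:

  import java.util.Arrays
  import ai.rapids.cudf.{ColumnVector, Table}
  import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
  import org.apache.spark.sql.types.{FloatType, StructField, StructType}

  val schema = StructType(Seq(
    StructField("f1", FloatType), StructField("f2", FloatType), StructField("label", FloatType)))
  // Build a tiny two-row Table on the GPU (hypothetical data).
  val f1 = ColumnVector.fromFloats(1.0f, 2.0f)
  val f2 = ColumnVector.fromFloats(3.0f, 4.0f)
  val label = ColumnVector.fromFloats(0.0f, 1.0f)
  val table = new Table(f1, f2, label)
  // The Table keeps its own references, so the individual vectors can be closed now.
  Seq(f1, f2, label).foreach(_.close())
  val batch = new GpuColumnBatch(table, schema)
  // Slice only the feature columns (indices 0 and 1) into a new Table.
  val featureTable = batch.slice(Arrays.asList(Integer.valueOf(0), Integer.valueOf(1)))
  featureTable.close()
  batch.close() // also closes the wrapped table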


@@ -0,0 +1 @@
ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost
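
This one-line resource file registers the GPU provider with java.util.ServiceLoader. A rough
sketch of the discovery side (the snippet is illustrative; the actual lookup lives in the
PreXGBoost changes near the end of this diff):

  import java.util.ServiceLoader
  import scala.collection.JavaConverters._
  import ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider

  val classLoader = Option(Thread.currentThread().getContextClassLoader)
    .getOrElse(getClass.getClassLoader)
  // Finds ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost when
  // xgboost4j-spark-gpu is on the classpath.
  val providers = ServiceLoader.load(classOf[PreXGBoostProvider], classLoader).asScala.toList
  val optionProvider: Option[PreXGBoostProvider] = providers.headOption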


@@ -1 +0,0 @@
../../../xgboost4j-spark/src/main/scala


@@ -0,0 +1,572 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import scala.collection.Iterator
import scala.collection.JavaConverters._
import com.nvidia.spark.rapids.{GpuColumnVector}
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, DeviceQuantileDMatrix}
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor}
import org.apache.commons.logging.LogFactory
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.functions.{col, collect_list, struct}
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
/**
* GpuPreXGBoost brings the RAPIDS plugin into XGBoost4j-Spark to accelerate the
* XGBoost4j training and transform processes
*/
class GpuPreXGBoost extends PreXGBoostProvider {
/**
* Whether the provider is enabled or not
*
* @param dataset the input dataset
* @return Boolean
*/
override def providerEnabled(dataset: Option[Dataset[_]]): Boolean = {
GpuPreXGBoost.providerEnabled(dataset)
}
/**
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
*
* @param estimator [[XGBoostClassifier]] or [[XGBoostRegressor]]
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
* RDD[Watches] will be used as the training input
* Option[ RDD[_] ] is the optional cached RDD
*/
override def buildDatasetToRDD(estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
GpuPreXGBoost.buildDatasetToRDD(estimator, dataset, params)
}
/**
* Transform Dataset
*
* @param model [[XGBoostClassificationModel]] or [[XGBoostRegressionModel]]
* @param dataset the input Dataset to transform
* @return the transformed DataFrame
*/
override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame = {
GpuPreXGBoost.transformDataset(model, dataset)
}
override def transformSchema(
xgboostEstimator: XGBoostEstimatorCommon,
schema: StructType): StructType = {
GpuPreXGBoost.transformSchema(xgboostEstimator, schema)
}
}
object GpuPreXGBoost extends PreXGBoostProvider {
private val logger = LogFactory.getLog("XGBoostSpark")
private val FEATURES_COLS = "features_cols"
private val TRAIN_NAME = "train"
override def providerEnabled(dataset: Option[Dataset[_]]): Boolean = {
// RuntimeConfig
val optionConf = dataset.map(ds => Some(ds.sparkSession.conf))
.getOrElse(SparkSession.getActiveSession.map(ss => ss.conf))
if (optionConf.isDefined) {
val conf = optionConf.get
val rapidsEnabled = try {
conf.get("spark.rapids.sql.enabled").toBoolean
} catch {
// The RAPIDS plugin defaults "spark.rapids.sql.enabled" to true
case _: NoSuchElementException => true
case _: Throwable => false // Any exception will return false
}
rapidsEnabled && conf.get("spark.sql.extensions", "")
.split(",")
.contains("com.nvidia.spark.rapids.SQLExecPlugin")
} else false
}
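// For illustration (hypothetical configuration, not part of this commit): providerEnabled
// returns true for a session configured roughly like
//   SparkSession.builder()
//     .config("spark.sql.extensions", "com.nvidia.spark.rapids.SQLExecPlugin")
//     .config("spark.rapids.sql.enabled", "true")
//     .getOrCreate()
// i.e. the RAPIDS SQL extension is registered and has not been explicitly disabled.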
/**
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
*
* @param estimator supports XGBoostClassifier and XGBoostRegressor
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
* RDD[Watches] will be used as the training input
* Option[ RDD[_] ] is the optional cached RDD
*/
override def buildDatasetToRDD(
estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
estimator match {
case est: XGBoostEstimatorCommon =>
require(est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
s"GPU train requires tree_method set to gpu_hist")
val groupName = estimator match {
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
regressor.getGroupCol } else ""
case _: XGBoostClassifier => ""
case _ => throw new RuntimeException("Unsupported estimator: " + estimator)
}
// Check schema and cast columns' type
(GpuUtils.getColumnNames(est)(est.labelCol, est.weightCol, est.baseMarginCol),
est.getFeaturesCols, groupName, est.getEvalSets(params))
case _ => throw new RuntimeException("Unsupported estimator: " + estimator)
}
val castedDF = GpuUtils.prepareColumnType(dataset, feturesCols, labelName, weightName,
marginName)
// Check columns and build column data batch
val trainingData = GpuUtils.buildColumnDataBatch(feturesCols,
labelName, weightName, marginName, "", castedDF)
// eval map
val evalDataMap = evalSets.map {
case (name, df) =>
val castDF = GpuUtils.prepareColumnType(df, feturesCols, labelName,
weightName, marginName)
(name, GpuUtils.buildColumnDataBatch(feturesCols, labelName, weightName,
marginName, groupName, castDF))
}
xgbExecParams: XGBoostExecutionParams =>
val dataMap = prepareInputData(trainingData, evalDataMap, xgbExecParams.numWorkers,
xgbExecParams.cacheTrainingSet)
(buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
}
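// For illustration (hypothetical driver-side usage, not part of this commit): the returned
// closure is applied with the resolved execution params, along the lines of
//   val buildFn = GpuPreXGBoost.buildDatasetToRDD(classifier, trainDf, params)
//   val (watchesRdd, cachedRdd) = buildFn(xgbExecParams)
// so all GPU-specific data preparation stays behind this single entry point.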
/**
* Transform Dataset
*
* @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
* @param dataset the input Dataset to transform
* @return the transformed DataFrame
*/
override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame = {
val (booster, predictFunc, schema, featureColNames, missing) = model match {
case m: XGBoostClassificationModel =>
Seq(XGBoostClassificationModel._rawPredictionCol,
XGBoostClassificationModel._probabilityCol, m.leafPredictionCol, m.contribPredictionCol)
// predict and turn to Row
val predictFunc =
(broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
m.producePredictionItrs(broadcastBooster, dm)
m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr,
predLeafItr, predContribItr)
}
// prepare the final Schema
var schema = StructType(dataset.schema.fields ++
Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)) ++
Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))
if (m.isDefined(m.leafPredictionCol)) {
schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
if (m.isDefined(m.contribPredictionCol)) {
schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
(m._booster, predictFunc, schema, m.getFeaturesCols, m.getMissing)
case m: XGBoostRegressionModel =>
Seq(XGBoostRegressionModel._originalPredictionCol, m.leafPredictionCol,
m.contribPredictionCol)
// predict and turn to Row
val predictFunc =
(broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, predLeafItr, predContribItr) =
m.producePredictionItrs(broadcastBooster, dm)
m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr,
predContribItr)
}
// prepare the final Schema
var schema = StructType(dataset.schema.fields ++
Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))
if (m.isDefined(m.leafPredictionCol)) {
schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
if (m.isDefined(m.contribPredictionCol)) {
schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
(m._booster, predictFunc, schema, m.getFeaturesCols, m.getMissing)
}
val sc = dataset.sparkSession.sparkContext
// Prepare some variables that will be passed to the executors.
val bOrigSchema = sc.broadcast(dataset.schema)
val bRowSchema = sc.broadcast(schema)
val bBooster = sc.broadcast(booster)
// Small values, so there is no need to broadcast them
val isLocal = sc.isLocal
val featureIds = featureColNames.distinct.map(dataset.schema.fieldIndex)
// start the transform: DataFrame -> columnar RDD -> mapPartitions
val rowRDD: RDD[Row] = GpuUtils.toColumnarRdd(dataset.asInstanceOf[DataFrame]).mapPartitions {
tableIters =>
// UnsafeProjection is not serializable so do it on the executor side
val toUnsafe = UnsafeProjection.create(bOrigSchema.value)
// Iterator on Row
new Iterator[Row] {
// Convert InternalRow to Row
private val converter: InternalRow => Row = CatalystTypeConverters
.createToScalaConverter(bOrigSchema.value)
.asInstanceOf[InternalRow => Row]
// GPU batches read in must be closed by the receiver (us)
@transient var currentBatch: ColumnarBatch = null
// Iterator on Row
var iter: Iterator[Row] = null
// set some GPU-related params on the booster:
// - gpu_id
// - predictor: force the GPU predictor, since the native layer does not save the predictor.
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
bBooster.value.setParam("gpu_id", gpuId.toString)
bBooster.value.setParam("predictor", "gpu_predictor")
logger.info("GPU transform on device: " + gpuId)
TaskContext.get().addTaskCompletionListener[Unit](_ => {
closeCurrentBatch() // close the last ColumnarBatch
})
private def closeCurrentBatch(): Unit = {
if (currentBatch != null) {
currentBatch.close()
currentBatch = null
}
}
def loadNextBatch(): Unit = {
closeCurrentBatch()
if (tableIters.hasNext) {
val dataTypes = bOrigSchema.value.fields.map(x => x.dataType)
iter = withResource(tableIters.next()) { table =>
val gpuColumnBatch = new GpuColumnBatch(table, bOrigSchema.value)
// Create DMatrix
val feaTable = gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(featureIds).asJava)
if (feaTable == null) {
throw new RuntimeException("Something wrong for feature indices")
}
try {
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null)
val dm = new DMatrix(cudfColumnBatch, missing, 1)
if (dm == null) {
Iterator.empty
} else {
try {
currentBatch = new ColumnarBatch(
GpuColumnVector.extractColumns(table, dataTypes).map(_.copyToHost()),
table.getRowCount().toInt)
val rowIterator = currentBatch.rowIterator().asScala
.map(toUnsafe)
.map(converter(_))
predictFunc(bBooster, dm, rowIterator)
} finally {
dm.delete()
}
}
} finally {
feaTable.close()
}
}
} else {
iter = null
}
}
override def hasNext: Boolean = {
val itHasNext = iter != null && iter.hasNext
if (!itHasNext) { // No rows left in the current ColumnarBatch
loadNextBatch()
iter != null && iter.hasNext
} else {
itHasNext
}
}
override def next(): Row = {
if (iter == null || !iter.hasNext) {
loadNextBatch()
}
if (iter == null) {
throw new NoSuchElementException()
}
iter.next()
}
}
}
bOrigSchema.unpersist(blocking = false)
bRowSchema.unpersist(blocking = false)
bBooster.unpersist(blocking = false)
dataset.sparkSession.createDataFrame(rowRDD, schema)
}
/**
* Transform schema
*
* @param est supporting XGBoostClassifier/XGBoostClassificationModel and
* XGBoostRegressor/XGBoostRegressionModel
* @param schema the input schema
* @return the transformed schema
*/
override def transformSchema(
est: XGBoostEstimatorCommon,
schema: StructType): StructType = {
val fit = est match {
case _: XGBoostClassifier | _: XGBoostRegressor => true
case _ => false
}
val Seq(label, weight, margin) = GpuUtils.getColumnNames(est)(est.labelCol, est.weightCol,
est.baseMarginCol)
GpuUtils.validateSchema(schema, est.getFeaturesCols, label, weight, margin, fit)
}
/**
* Repartition all the Columnar Dataset (training and evaluation) to nWorkers,
* and assemble them into a map
*/
private def prepareInputData(
trainingData: ColumnDataBatch,
evalSetsMap: Map[String, ColumnDataBatch],
nWorkers: Int,
isCacheData: Boolean): Map[String, ColumnDataBatch] = {
// Cache is not supported
if (isCacheData) {
logger.warn("the cache param will be ignored by GPU pipeline!")
}
(Map(TRAIN_NAME -> trainingData) ++ evalSetsMap).map {
case (name, colData) =>
// There is no cheap way to get the number of partitions from a DataFrame, so always repartition
val newDF = colData.groupColName
.map(gn => repartitionForGroup(gn, colData.rawDF, nWorkers))
.getOrElse(colData.rawDF.repartition(nWorkers))
name -> ColumnDataBatch(newDF, colData.colIndices, colData.groupColName)
}
}
private def repartitionForGroup(
groupName: String,
dataFrame: DataFrame,
nWorkers: Int): DataFrame = {
// Group the data first
logger.info("Start groupBy for LTR")
val schema = dataFrame.schema
val groupedDF = dataFrame
.groupBy(groupName)
.agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")
implicit val encoder = RowEncoder(schema)
// Expand the grouped rows after repartition
groupedDF.repartition(nWorkers).mapPartitions(iter => {
new Iterator[Row] {
var iterInRow: Iterator[Any] = Iterator.empty
override def hasNext: Boolean = {
if (iter.hasNext && !iterInRow.hasNext) {
// the first is groupId, second is list
iterInRow = iter.next.getSeq(1).iterator
}
iterInRow.hasNext
}
override def next(): Row = {
iterInRow.next.asInstanceOf[Row]
}
}
})
}
private def buildRDDWatches(
dataMap: Map[String, ColumnDataBatch],
xgbExeParams: XGBoostExecutionParams,
noEvalSet: Boolean): RDD[Watches] = {
val sc = dataMap(TRAIN_NAME).rawDF.sparkSession.sparkContext
val maxBin = xgbExeParams.toMap.getOrElse("max_bin", 256).asInstanceOf[Int]
// Start training
if (noEvalSet) {
// Get the indices here at driver side to avoid passing the whole Map to executor(s)
val colIndicesForTrain = dataMap(TRAIN_NAME).colIndices
GpuUtils.toColumnarRdd(dataMap(TRAIN_NAME).rawDF).mapPartitions({
iter =>
val iterColBatch = iter.map(table => new GpuColumnBatch(table, null))
Iterator(buildWatches(
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
colIndicesForTrain, iterColBatch, maxBin))
})
} else {
// Train with evaluation sets
// Get the indices here at driver side to avoid passing the whole Map to executor(s)
val nameAndColIndices = dataMap.map(nc => (nc._1, nc._2.colIndices))
coPartitionForGpu(dataMap, sc, xgbExeParams.numWorkers).mapPartitions {
nameAndColumnBatchIter =>
Iterator(buildWatchesWithEval(
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
nameAndColIndices, nameAndColumnBatchIter, maxBin))
}
}
}
private def buildWatches(
cachedDirName: Option[String],
missing: Float,
indices: ColumnIndices,
iter: Iterator[GpuColumnBatch],
maxBin: Int): Watches = {
val (dm, time) = GpuUtils.time {
buildDMatrix(iter, indices, missing, maxBin)
}
logger.debug("Benchmark[Train: Build DMatrix incrementally] " + time)
val (aDMatrix, aName) = if (dm == null) {
(Array.empty[DMatrix], Array.empty[String])
} else {
(Array(dm), Array("train"))
}
new Watches(aDMatrix, aName, cachedDirName)
}
private def buildWatchesWithEval(
cachedDirName: Option[String],
missing: Float,
indices: Map[String, ColumnIndices],
nameAndColumns: Iterator[(String, Iterator[GpuColumnBatch])],
maxBin: Int): Watches = {
val dms = nameAndColumns.map {
case (name, iter) => (name, {
val (dm, time) = GpuUtils.time {
buildDMatrix(iter, indices(name), missing, maxBin)
}
logger.debug(s"Benchmark[Train build $name DMatrix] " + time)
dm
})
}.filter(_._2 != null).toArray
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
}
/**
* Build DeviceQuantileDMatrix based on GpuColumnBatches
*
* @param iter a sequence of GpuColumnBatch
* @param indices indicate the feature, label, weight, base margin column ids.
* @param missing the missing value
* @param maxBin the maxBin
* @return DMatrix
*/
private def buildDMatrix(
iter: Iterator[GpuColumnBatch],
indices: ColumnIndices,
missing: Float,
maxBin: Int): DMatrix = {
val rapidsIterator = new RapidsIterator(iter, indices)
new DeviceQuantileDMatrix(rapidsIterator, missing, maxBin, 1)
}
// zip all the Columnar RDDs into one RDD containing named column data batch.
private def coPartitionForGpu(
dataMap: Map[String, ColumnDataBatch],
sc: SparkContext,
nWorkers: Int): RDD[(String, Iterator[GpuColumnBatch])] = {
val emptyDataRdd = sc.parallelize(
Array.fill[(String, Iterator[GpuColumnBatch])](nWorkers)(null), nWorkers)
dataMap.foldLeft(emptyDataRdd) {
case (zippedRdd, (name, gdfColData)) =>
zippedRdd.zipPartitions(GpuUtils.toColumnarRdd(gdfColData.rawDF)) {
(itWrapper, iterCol) =>
val itCol = iterCol.map(table => new GpuColumnBatch(table, null))
(itWrapper.toArray :+ (name -> itCol)).filter(x => x != null).toIterator
}
}
}
private[this] class RapidsIterator(
base: Iterator[GpuColumnBatch],
indices: ColumnIndices) extends Iterator[CudfColumnBatch] {
override def hasNext: Boolean = base.hasNext
override def next(): CudfColumnBatch = {
// Since we have sliced the original Table into separate tables, we need to close the original one.
withResource(base.next()) { gpuColumnBatch =>
val weights = indices.weightId.map(Seq(_)).getOrElse(Seq.empty)
val margins = indices.marginId.map(Seq(_)).getOrElse(Seq.empty)
new CudfColumnBatch(
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(indices.featureIds).asJava),
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(Seq(indices.labelId)).asJava),
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(weights).asJava),
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava));
}
}
}
/** Executes the provided code block and then closes the resource */
def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
}
}
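
Taken together, a minimal sketch of the user-facing flow this provider enables. It assumes a
session with the RAPIDS Accelerator configured (the exact plugin settings depend on the
Accelerator version) and hypothetical column names and paths:

  import org.apache.spark.sql.SparkSession
  import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

  val spark = SparkSession.builder()
    .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
    .config("spark.sql.extensions", "com.nvidia.spark.rapids.SQLExecPlugin")
    .config("spark.rapids.sql.enabled", "true")
    .getOrCreate()

  // On the GPU path the raw numeric columns are used directly; no VectorAssembler is needed.
  val train = spark.read.parquet("/path/to/train")

  val model = new XGBoostClassifier(Map(
      "objective" -> "binary:logistic",
      "num_round" -> 10,
      "num_workers" -> 1))
    .setFeaturesCols(Seq("f1", "f2", "f3"))
    .setLabelCol("label")
    .setTreeMethod("gpu_hist")
    .fit(train)

  val predictions = model.transform(train)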


@@ -0,0 +1,167 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import ai.rapids.cudf.Table
import com.nvidia.spark.rapids.ColumnarRdd
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{FloatType, NumericType, StructType}
private[spark] object GpuUtils {
def toColumnarRdd(df: DataFrame): RDD[Table] = ColumnarRdd(df)
def seqIntToSeqInteger(x: Seq[Int]): Seq[Integer] = x.map(new Integer(_))
/** APIs related to GPU column data */
def buildColumnDataBatch(featureNames: Seq[String],
labelName: String,
weightName: String,
marginName: String,
groupName: String,
dataFrame: DataFrame): ColumnDataBatch = {
// Some checks first
val schema = dataFrame.schema
val featureNameSet = featureNames.distinct
GpuUtils.validateSchema(schema, featureNameSet, labelName, weightName, marginName)
// group column
val (opGroup, groupId) = if (groupName.isEmpty) {
(None, None)
} else {
GpuUtils.checkNumericType(schema, groupName)
(Some(groupName), Some(schema.fieldIndex(groupName)))
}
// weight and base margin columns
val Seq(weightId, marginId) = Seq(weightName, marginName).map {
name =>
if (name.isEmpty) None else Some(schema.fieldIndex(name))
}
val colsIndices = ColumnIndices(featureNameSet.map(schema.fieldIndex),
schema.fieldIndex(labelName), weightId, marginId, groupId)
ColumnDataBatch(dataFrame, colsIndices, opGroup)
}
def checkNumericType(schema: StructType, colName: String,
msg: String = ""): Unit = {
val actualDataType = schema(colName).dataType
val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
require(actualDataType.isInstanceOf[NumericType],
s"Column $colName must be of NumericType but found: " +
s"${actualDataType.catalogString}.$message")
}
/** Check the columns and cast them to FloatType */
def prepareColumnType(
dataset: Dataset[_],
featureNames: Seq[String],
labelName: String = "",
weightName: String = "",
marginName: String = "",
fitting: Boolean = true): DataFrame = {
// check first
val featureNameSet = featureNames.distinct
validateSchema(dataset.schema, featureNameSet, labelName, weightName, marginName, fitting)
val castToFloat = (ds: Dataset[_], colName: String) => {
val colMeta = ds.schema(colName).metadata
ds.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
}
val colNames = if (fitting) {
var names = featureNameSet :+ labelName
if (weightName.nonEmpty) {
names = names :+ weightName
}
if (marginName.nonEmpty) {
names = names :+ marginName
}
names
} else {
featureNameSet
}
colNames.foldLeft(dataset.asInstanceOf[DataFrame])(
(ds, colName) => castToFloat(ds, colName))
}
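// For illustration (hypothetical columns, not part of this commit): given a DataFrame with
// f1: DoubleType, f2: IntegerType and label: DoubleType, the call
//   GpuUtils.prepareColumnType(df, Seq("f1", "f2"), "label")
// returns a DataFrame in which f1, f2 and label are all cast to FloatType while keeping
// each column's original metadata.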
/** Validate input schema */
def validateSchema(schema: StructType,
featureNames: Seq[String],
labelName: String = "",
weightName: String = "",
marginName: String = "",
fitting: Boolean = true): StructType = {
val msg = if (fitting) "train" else "transform"
// feature columns
require(featureNames.nonEmpty, s"Gpu $msg requires features columns. " +
"please refer to setFeaturesCols!")
featureNames.foreach(fn => checkNumericType(schema, fn))
if (fitting) {
require(labelName.nonEmpty, "label column is not set.")
checkNumericType(schema, labelName)
if (weightName.nonEmpty) {
checkNumericType(schema, weightName)
}
if (marginName.nonEmpty) {
checkNumericType(schema, marginName)
}
}
schema
}
def time[R](block: => R): (R, Float) = {
val t0 = System.currentTimeMillis
val result = block // call-by-name
val t1 = System.currentTimeMillis
(result, (t1 - t0).toFloat / 1000)
}
/** Get column names from Parameter */
def getColumnNames(params: Params)(cols: Param[String]*): Seq[String] = {
// get the column name; null or an undefined param is treated as ""
def getColumnName(params: Params)(param: Param[String]): String = {
if (params.isDefined(param)) {
val colName = params.getOrDefault(param)
if (colName != null) colName else ""
} else ""
}
val getName = getColumnName(params)(_)
cols.map(getName)
}
}
/**
* A container holding the column ids
*/
private[spark] case class ColumnIndices(
featureIds: Seq[Int],
labelId: Int,
weightId: Option[Int],
marginId: Option[Int],
groupId: Option[Int])
private[spark] case class ColumnDataBatch(
rawDF: DataFrame,
colIndices: ColumnIndices,
groupColName: Option[String])
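
As a rough illustration of how these helpers fit together (GpuUtils and the case classes are
private[spark], so a snippet like this only compiles inside the
ml.dmlc.xgboost4j.scala.rapids.spark package; the column names are hypothetical):

  package ml.dmlc.xgboost4j.scala.rapids.spark

  import org.apache.spark.sql.DataFrame

  object ColumnDataBatchExample {
    def describe(df: DataFrame): Unit = {
      // Validates the schema, then records the positions of the feature and label columns.
      val batch: ColumnDataBatch = GpuUtils.buildColumnDataBatch(
        featureNames = Seq("f1", "f2", "f3"),
        labelName = "label",
        weightName = "",
        marginName = "",
        groupName = "",
        dataFrame = df)
      println(s"feature indices = ${batch.colIndices.featureIds}, " +
        s"label index = ${batch.colIndices.labelId}")
    }
  }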


@@ -0,0 +1 @@
../../../../../../../../xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark


@@ -0,0 +1 @@
../../../../xgboost4j-spark/src/main/scala/org


@@ -1 +0,0 @@
../../xgboost4j-spark/src/test


@@ -0,0 +1 @@
../../../xgboost4j-spark/src/test/resources


@@ -0,0 +1,293 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import java.nio.file.{Files, Path}
import java.sql.{Date, Timestamp}
import java.util.{Locale, TimeZone}
import com.nvidia.spark.rapids.RapidsConf
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.rapids.execution.TrampolineUtil
trait GpuTestSuite extends FunSuite with TmpFolderSuite {
import SparkSessionHolder.withSparkSession
protected def getResourcePath(resource: String): String = {
require(resource.startsWith("/"), "resource must start with /")
getClass.getResource(resource).getPath
}
def enableCsvConf(): SparkConf = {
new SparkConf()
.set(RapidsConf.ENABLE_READ_CSV_DATES.key, "true")
.set(RapidsConf.ENABLE_READ_CSV_BYTES.key, "true")
.set(RapidsConf.ENABLE_READ_CSV_SHORTS.key, "true")
.set(RapidsConf.ENABLE_READ_CSV_INTEGERS.key, "true")
.set(RapidsConf.ENABLE_READ_CSV_LONGS.key, "true")
.set(RapidsConf.ENABLE_READ_CSV_FLOATS.key, "true")
.set(RapidsConf.ENABLE_READ_CSV_DOUBLES.key, "true")
}
def withGpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
// set "spark.rapids.sql.explain" to "ALL" to check if the operators
// can be replaced by GPU
val c = conf.clone()
.set("spark.rapids.sql.enabled", "true")
withSparkSession(c, f)
}
def withCpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
val c = conf.clone()
.set("spark.rapids.sql.enabled", "false") // Just to be sure
withSparkSession(c, f)
}
def compareResults(
sort: Boolean,
floatEpsilon: Double,
fromLeft: Array[Row],
fromRight: Array[Row]): Boolean = {
if (sort) {
val left = fromLeft.map(_.toSeq).sortWith(seqLt)
val right = fromRight.map(_.toSeq).sortWith(seqLt)
compare(left, right, floatEpsilon)
} else {
compare(fromLeft, fromRight, floatEpsilon)
}
}
// we guarantee that the types will be the same
private def seqLt(a: Seq[Any], b: Seq[Any]): Boolean = {
if (a.length < b.length) {
return true
}
// lengths are the same
for (i <- a.indices) {
val v1 = a(i)
val v2 = b(i)
if (v1 != v2) {
// null is always < anything but null
if (v1 == null) {
return true
}
if (v2 == null) {
return false
}
(v1, v2) match {
case (i1: Int, i2: Int) => if (i1 < i2) {
return true
} else if (i1 > i2) {
return false
}// else equal go on
case (i1: Long, i2: Long) => if (i1 < i2) {
return true
} else if (i1 > i2) {
return false
} // else equal go on
case (i1: Float, i2: Float) => if (i1.isNaN() && !i2.isNaN()) return false
else if (!i1.isNaN() && i2.isNaN()) return true
else if (i1 < i2) {
return true
} else if (i1 > i2) {
return false
} // else equal go on
case (i1: Date, i2: Date) => if (i1.before(i2)) {
return true
} else if (i1.after(i2)) {
return false
} // else equal go on
case (i1: Double, i2: Double) => if (i1.isNaN() && !i2.isNaN()) return false
else if (!i1.isNaN() && i2.isNaN()) return true
else if (i1 < i2) {
return true
} else if (i1 > i2) {
return false
} // else equal go on
case (i1: Short, i2: Short) => if (i1 < i2) {
return true
} else if (i1 > i2) {
return false
} // else equal go on
case (i1: Timestamp, i2: Timestamp) => if (i1.before(i2)) {
return true
} else if (i1.after(i2)) {
return false
} // else equal go on
case (s1: String, s2: String) =>
val cmp = s1.compareTo(s2)
if (cmp < 0) {
return true
} else if (cmp > 0) {
return false
} // else equal go on
case (o1, _) =>
throw new UnsupportedOperationException(o1.getClass + " is not supported yet")
}
}
}
// They are equal...
false
}
private def compare(expected: Any, actual: Any, epsilon: Double = 0.0): Boolean = {
def doublesAreEqualWithinPercentage(expected: Double, actual: Double): (String, Boolean) = {
if (!compare(expected, actual)) {
if (expected != 0) {
val v = Math.abs((expected - actual) / expected)
(s"\n\nABS($expected - $actual) / ABS($actual) == $v is not <= $epsilon ", v <= epsilon)
} else {
val v = Math.abs(expected - actual)
(s"\n\nABS($expected - $actual) == $v is not <= $epsilon ", v <= epsilon)
}
} else {
("SUCCESS", true)
}
}
(expected, actual) match {
case (a: Float, b: Float) if a.isNaN && b.isNaN => true
case (a: Double, b: Double) if a.isNaN && b.isNaN => true
case (null, null) => true
case (null, _) => false
case (_, null) => false
case (a: Array[_], b: Array[_]) =>
a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r, epsilon) }
case (a: Map[_, _], b: Map[_, _]) =>
a.size == b.size && a.keys.forall { aKey =>
b.keys.find(bKey => compare(aKey, bKey))
.exists(bKey => compare(a(aKey), b(bKey), epsilon))
}
case (a: Iterable[_], b: Iterable[_]) =>
a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r, epsilon) }
case (a: Product, b: Product) =>
compare(a.productIterator.toSeq, b.productIterator.toSeq, epsilon)
case (a: Row, b: Row) =>
compare(a.toSeq, b.toSeq, epsilon)
// 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0.
case (a: Double, b: Double) if epsilon <= 0 =>
java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b)
case (a: Double, b: Double) if epsilon > 0 =>
val ret = doublesAreEqualWithinPercentage(a, b)
if (!ret._2) {
System.err.println(ret._1 + " (double)")
}
ret._2
case (a: Float, b: Float) if epsilon <= 0 =>
java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b)
case (a: Float, b: Float) if epsilon > 0 =>
val ret = doublesAreEqualWithinPercentage(a, b)
if (!ret._2) {
System.err.println(ret._1 + " (float)")
}
ret._2
case (a, b) => a == b
}
}
}
trait TmpFolderSuite extends BeforeAndAfterAll { self: FunSuite =>
protected var tempDir: Path = _
override def beforeAll(): Unit = {
super.beforeAll()
tempDir = Files.createTempDirectory(getClass.getName)
}
override def afterAll(): Unit = {
JavaUtils.deleteRecursively(tempDir.toFile)
super.afterAll()
}
protected def createTmpFolder(prefix: String): Path = {
Files.createTempDirectory(tempDir, prefix)
}
}
object SparkSessionHolder extends Logging {
private var spark = createSparkSession()
private var origConf = spark.conf.getAll
private var origConfKeys = origConf.keys.toSet
private def setAllConfs(confs: Array[(String, String)]): Unit = confs.foreach {
case (key, value) if spark.conf.get(key, null) != value =>
spark.conf.set(key, value)
case _ => // No need to modify it
}
private def createSparkSession(): SparkSession = {
TrampolineUtil.cleanupAnyExistingSession()
// Timezone is fixed to UTC to allow timestamps to work by default
TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
// Add Locale setting
Locale.setDefault(Locale.US)
val builder = SparkSession.builder()
.master("local[1]")
.config("spark.sql.adaptive.enabled", "false")
.config("spark.rapids.sql.enabled", "false")
.config("spark.rapids.sql.test.enabled", "false")
.config("spark.plugins", "com.nvidia.spark.SQLPlugin")
.config("spark.rapids.memory.gpu.pooling.enabled", "false") // Disable RMM for unit tests.
.appName("XGBoost4j-Spark-Gpu unit test")
builder.getOrCreate()
}
private def reinitSession(): Unit = {
spark = createSparkSession()
origConf = spark.conf.getAll
origConfKeys = origConf.keys.toSet
}
def sparkSession: SparkSession = {
if (SparkSession.getActiveSession.isEmpty) {
reinitSession()
}
spark
}
def resetSparkSessionConf(): Unit = {
if (SparkSession.getActiveSession.isEmpty) {
reinitSession()
} else {
setAllConfs(origConf.toArray)
val currentKeys = spark.conf.getAll.keys.toSet
val toRemove = currentKeys -- origConfKeys
toRemove.foreach(spark.conf.unset)
}
logDebug(s"RESET CONF TO: ${spark.conf.getAll}")
}
def withSparkSession[U](conf: SparkConf, f: SparkSession => U): U = {
resetSparkSessionConf
logDebug(s"SETTING CONF: ${conf.getAll.toMap}")
setAllConfs(conf.getAll)
logDebug(s"RUN WITH CONF: ${spark.conf.getAll}\n")
f(spark)
}
}


@@ -0,0 +1,226 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{FloatType, StructField, StructType}
class GpuXGBoostClassifierSuite extends GpuTestSuite {
private val dataPath = if (new java.io.File("../../demo/data/veterans_lung_cancer.csv").isFile) {
"../../demo/data/veterans_lung_cancer.csv"
} else {
"../demo/data/veterans_lung_cancer.csv"
}
val labelName = "label_col"
val schema = StructType(Seq(
StructField("f1", FloatType), StructField("f2", FloatType), StructField("f3", FloatType),
StructField("f4", FloatType), StructField("f5", FloatType), StructField("f6", FloatType),
StructField("f7", FloatType), StructField("f8", FloatType), StructField("f9", FloatType),
StructField("f10", FloatType), StructField("f11", FloatType), StructField("f12", FloatType),
StructField(labelName, FloatType)
))
val featureNames = schema.fieldNames.filter(s => !s.equals(labelName)).toSeq
test("The transform result should be same for several runs on same model") {
withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
// Get a model
val model = new XGBoostClassifier(xgbParam)
.fit(originalDf)
val left = model.transform(testDf).collect()
val right = model.transform(testDf).collect()
// The left side should be the same as the right
assert(compareResults(true, 0.000001, left, right))
}
}
test("use weight") {
withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
val getWeightFromF1 = udf({ f1: Float => if (f1.toInt % 2 == 0) 1.0f else 0.001f })
val dfWithWeight = originalDf.withColumn("weight", getWeightFromF1(col("f1")))
val model = new XGBoostClassifier(xgbParam)
.fit(originalDf)
val model2 = new XGBoostClassifier(xgbParam)
.setWeightCol("weight")
.fit(dfWithWeight)
val left = model.transform(testDf).collect()
val right = model2.transform(testDf).collect()
// The left side should differ from the right
assert(!compareResults(true, 0.000001, left, right))
}
}
test("Save model and transform GPU dataset") {
// Train a model on GPU
val (gpuModel, testDf) = withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, testDf) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
val classifier = new XGBoostClassifier(xgbParam)
.setFeaturesCols(featureNames)
.setLabelCol(labelName)
.setTreeMethod("gpu_hist")
(classifier.fit(rawInput), testDf)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
gpuModel.write.overwrite().save(xgbrModel)
val gpuModelFromFile = XGBoostClassificationModel.load(xgbrModel)
// transform on GPU
withGpuSparkSession() { spark =>
val left = gpuModel
.transform(testDf)
.select(labelName, "rawPrediction", "probability", "prediction")
.collect()
val right = gpuModelFromFile
.transform(testDf)
.select(labelName, "rawPrediction", "probability", "prediction")
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
test("Model trained on CPU can transform GPU dataset") {
// Train a model on CPU
val cpuModel = withCpuSparkSession() { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
.setInputCols(featureNames.toArray)
.setOutputCol("features")
val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName)
val classifier = new XGBoostClassifier(xgbParam)
.setFeaturesCol("features")
.setLabelCol(labelName)
.setTreeMethod("auto")
classifier.fit(trainingDf)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
cpuModel.write.overwrite().save(xgbrModel)
val cpuModelFromFile = XGBoostClassificationModel.load(xgbrModel)
// transform on GPU
withGpuSparkSession() { spark =>
val Array(_, testDf) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
// Since the CPU model does not carry the features-cols information that the GPU transform
// pipeline requires, the end user needs to call setFeaturesCols on the model manually
val thrown = intercept[IllegalArgumentException](cpuModel
.transform(testDf)
.collect())
assert(thrown.getMessage.contains("Gpu transform requires features columns. " +
"please refer to setFeaturesCols"))
val left = cpuModel
.setFeaturesCols(featureNames)
.transform(testDf)
.collect()
val right = cpuModelFromFile
.setFeaturesCols(featureNames)
.transform(testDf)
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
test("Model trained on GPU can transform CPU dataset") {
// Train a model on GPU
val gpuModel = withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
val classifier = new XGBoostClassifier(xgbParam)
.setFeaturesCols(featureNames)
.setLabelCol(labelName)
.setTreeMethod("gpu_hist")
classifier.fit(rawInput)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
gpuModel.write.overwrite().save(xgbrModel)
val gpuModelFromFile = XGBoostClassificationModel.load(xgbrModel)
// transform on CPU
withCpuSparkSession() { spark =>
val Array(_, rawInput) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
val featureColName = "feature_col"
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
.setInputCols(featureNames.toArray)
.setOutputCol(featureColName)
val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName)
// Since the GPU model does not carry the features column name that the CPU
// transform pipeline requires, the end user needs to call setFeaturesCol on the model manually
val thrown = intercept[IllegalArgumentException](
gpuModel
.transform(testDf)
.collect())
assert(thrown.getMessage.contains("features does not exist"))
val left = gpuModel
.setFeaturesCol(featureColName)
.transform(testDf)
.select(labelName, "rawPrediction", "probability", "prediction")
.collect()
val right = gpuModelFromFile
.setFeaturesCol(featureColName)
.transform(testDf)
.select(labelName, "rawPrediction", "probability", "prediction")
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
}


@@ -0,0 +1,182 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StringType
class GpuXGBoostGeneralSuite extends GpuTestSuite {
private val labelName = "label_col"
private val weightName = "weight_col"
private val baseMarginName = "margin_col"
private val featureNames = Seq("f1", "f2", "f3")
private val allColumnNames = featureNames :+ weightName :+ baseMarginName :+ labelName
private val trainingData = Seq(
// f1, f2, f3, weight, margin, label
(1.0f, 2.0f, 3.0f, 1.0f, 0.5f, 0),
(2.0f, 3.0f, 4.0f, 2.0f, 0.6f, 0),
(1.2f, 2.1f, 3.1f, 1.1f, 0.51f, 0),
(2.3f, 3.1f, 4.1f, 2.1f, 0.61f, 0),
(3.0f, 4.0f, 5.0f, 1.5f, 0.3f, 1),
(4.0f, 5.0f, 6.0f, 2.5f, 0.4f, 1),
(3.1f, 4.1f, 5.1f, 1.6f, 0.4f, 1),
(4.1f, 5.1f, 6.1f, 2.6f, 0.5f, 1),
(5.0f, 6.0f, 7.0f, 1.0f, 0.2f, 2),
(6.0f, 7.0f, 8.0f, 1.3f, 0.6f, 2),
(5.1f, 6.1f, 7.1f, 1.2f, 0.1f, 2),
(6.1f, 7.1f, 8.1f, 1.4f, 0.7f, 2),
(6.2f, 7.2f, 8.2f, 1.5f, 0.8f, 2))
test("MLlib way setting features_cols should work") {
withGpuSparkSession() { spark =>
import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*)
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
new XGBoostClassifier(xgbParam)
.fit(trainingDf)
}
}
test("disorder feature columns should work") {
withGpuSparkSession() { spark =>
import spark.implicits._
var trainingDf = trainingData.toDF(allColumnNames: _*)
trainingDf = trainingDf.select(labelName, "f2", weightName, "f3", baseMarginName, "f1")
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
new XGBoostClassifier(xgbParam)
.setFeaturesCols(featureNames)
.setLabelCol(labelName)
.fit(trainingDf)
}
}
test("Throw exception when feature/label columns are not numeric type") {
withGpuSparkSession() { spark =>
import spark.implicits._
val originalDf = trainingData.toDF(allColumnNames: _*)
var trainingDf = originalDf.withColumn("f2", col("f2").cast(StringType))
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
val thrown1 = intercept[IllegalArgumentException] {
new XGBoostClassifier(xgbParam)
.setFeaturesCols(featureNames)
.setLabelCol(labelName)
.fit(trainingDf)
}
assert(thrown1.getMessage.contains("Column f2 must be of NumericType but found: string."))
trainingDf = originalDf.withColumn(labelName, col(labelName).cast(StringType))
val thrown2 = intercept[IllegalArgumentException] {
new XGBoostClassifier(xgbParam)
.setFeaturesCols(featureNames)
.setLabelCol(labelName)
.fit(trainingDf)
}
assert(thrown2.getMessage.contains(
s"Column $labelName must be of NumericType but found: string."))
}
}
test("Throw exception when features_cols or label_col is not set") {
withGpuSparkSession() { spark =>
import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*)
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
val thrown = intercept[IllegalArgumentException] {
new XGBoostClassifier(xgbParam)
.setLabelCol(labelName)
.fit(trainingDf)
}
assert(thrown.getMessage.contains("Gpu train requires features columns."))
val thrown1 = intercept[IllegalArgumentException] {
new XGBoostClassifier(xgbParam)
.setFeaturesCols(featureNames)
.fit(trainingDf)
}
assert(thrown1.getMessage.contains("label does not exist."))
}
}
test("Throw exception when tree method is not set to gpu_hist") {
withGpuSparkSession() { spark =>
import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*)
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "hist")
val thrown = intercept[IllegalArgumentException] {
new XGBoostClassifier(xgbParam)
.setFeaturesCols(featureNames)
.setLabelCol(labelName)
.fit(trainingDf)
}
assert(thrown.getMessage.contains("GPU train requires tree_method set to gpu_hist"))
}
}
test("Train with eval") {
withGpuSparkSession() { spark =>
import spark.implicits._
val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*)
.randomSplit(Array(0.6, 0.2, 0.2), seed = 1)
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
val model1 = new XGBoostClassifier(xgbParam)
.setFeaturesCols(featureNames)
.setLabelCol(labelName)
.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
.fit(trainingDf)
assert(model1.summary.validationObjectiveHistory.length === 2)
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
}
}
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
withGpuSparkSession() { spark =>
import spark.implicits._
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val xgbc = new XGBoostClassifier(xgbParam)
xgbc.write.overwrite().save(xgbcPath)
val paramMap2 = XGBoostClassifier.load(xgbcPath).MLlib2XGBoostParams
xgbParam.foreach {
case (k, v) => assert(v.toString == paramMap2(k).toString)
}
}
}
}


@@ -0,0 +1,239 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{FloatType, IntegerType, StructField, StructType}
class GpuXGBoostRegressorSuite extends GpuTestSuite {
val labelName = "label_col"
val groupName = "group_col"
val schema = StructType(Seq(
StructField(labelName, FloatType),
StructField("f1", FloatType),
StructField("f2", FloatType),
StructField("f3", FloatType),
StructField(groupName, IntegerType)))
val featureNames = schema.fieldNames.filter(s =>
!(s.equals(labelName) || s.equals(groupName))).toSeq
test("The transform result should be same for several runs on same model") {
withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
// Get a model
val model = new XGBoostRegressor(xgbParam)
.fit(originalDf)
val left = model.transform(testDf).collect()
val right = model.transform(testDf).collect()
// The left side should be the same as the right
assert(compareResults(true, 0.000001, left, right))
}
}
test("use weight") {
withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val getWeightFromF1 = udf({ f1: Float => if (f1.toInt % 2 == 0) 1.0f else 0.001f })
val dfWithWeight = originalDf.withColumn("weight", getWeightFromF1(col("f1")))
val model = new XGBoostRegressor(xgbParam)
.fit(originalDf)
val model2 = new XGBoostRegressor(xgbParam)
.setWeightCol("weight")
.fit(dfWithWeight)
val left = model.transform(testDf).collect()
val right = model2.transform(testDf).collect()
// The left side should differ from the right
assert(!compareResults(true, 0.000001, left, right))
}
}
test("Save model and transform GPU dataset") {
// Train a model on GPU
val (gpuModel, testDf) = withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val classifier = new XGBoostRegressor(xgbParam)
.setFeaturesCols(featureNames)
.setLabelCol(labelName)
.setTreeMethod("gpu_hist")
(classifier.fit(rawInput), testDf)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
gpuModel.write.overwrite().save(xgbrModel)
val gpuModelFromFile = XGBoostRegressionModel.load(xgbrModel)
// transform on GPU
withGpuSparkSession() { spark =>
val left = gpuModel
.transform(testDf)
.select(labelName, "prediction")
.collect()
val right = gpuModelFromFile
.transform(testDf)
.select(labelName, "prediction")
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
test("Model trained on CPU can transform GPU dataset") {
// Train a model on CPU
val cpuModel = withCpuSparkSession() { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
.setInputCols(featureNames.toArray)
.setOutputCol("features")
val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName)
val classifier = new XGBoostRegressor(xgbParam)
.setFeaturesCol("features")
.setLabelCol(labelName)
.setTreeMethod("auto")
classifier.fit(trainingDf)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
cpuModel.write.overwrite().save(xgbrModel)
val cpuModelFromFile = XGBoostRegressionModel.load(xgbrModel)
// transform on GPU
withGpuSparkSession() { spark =>
val Array(_, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
// Since the CPU model does not carry the features-cols information that the GPU transform
// pipeline requires, the end user needs to call setFeaturesCols on the model manually
val thrown = intercept[IllegalArgumentException](cpuModel
.transform(testDf)
.collect())
assert(thrown.getMessage.contains("Gpu transform requires features columns. " +
"please refer to setFeaturesCols"))
val left = cpuModel
.setFeaturesCols(featureNames)
.transform(testDf)
.collect()
val right = cpuModelFromFile
.setFeaturesCols(featureNames)
.transform(testDf)
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
test("Model trained on GPU can transform CPU dataset") {
// Train a model on GPU
val gpuModel = withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val regressor = new XGBoostRegressor(xgbParam)
.setFeaturesCols(featureNames)
.setLabelCol(labelName)
.setTreeMethod("gpu_hist")
regressor.fit(rawInput)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
gpuModel.write.overwrite().save(xgbrModel)
val gpuModelFromFile = XGBoostRegressionModel.load(xgbrModel)
// transform on CPU
withCpuSparkSession() { spark =>
val Array(_, rawInput) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val featureColName = "feature_col"
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
.setInputCols(featureNames.toArray)
.setOutputCol(featureColName)
val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName)
// The GPU-trained model does not carry the features column name that the CPU transform
// pipeline requires, so the end user has to call setFeaturesCol on the model manually
val thrown = intercept[IllegalArgumentException](
gpuModel
.transform(testDf)
.collect())
assert(thrown.getMessage.contains("features does not exist"))
val left = gpuModel
.setFeaturesCol(featureColName)
.transform(testDf)
.select(labelName, "prediction")
.collect()
val right = gpuModelFromFile
.setFeaturesCol(featureColName)
.transform(testDf)
.select(labelName, "prediction")
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
test("Ranking: train with Group") {
withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:pairwise",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(trainingDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val model = new XGBoostRegressor(xgbParam)
.setGroupCol(groupName)
.fit(trainingDf)
val ret = model.transform(testDf).collect()
assert(testDf.count() === ret.length)
}
}
}
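The tests above encode the hand-off rule in both directions: a CPU-trained model needs setFeaturesCols before it can run through the GPU transform pipeline, while a GPU-trained model needs setFeaturesCol pointed at an assembled vector column before it can run through the CPU pipeline. A minimal sketch of that hand-off follows; the path, DataFrames and column names are placeholders for illustration, not part of this change.

// Hedged sketch only; "/tmp/xgbrModel", gpuReadDf and vectorAssembledDf are placeholders.
val featureNames = Seq("f1", "f2", "f3")           // placeholder numeric feature columns
val labelName = "label"                            // placeholder label column

// A CPU-trained model must be told which raw numeric columns feed the GPU transform.
val cpuTrained = XGBoostRegressionModel.load("/tmp/xgbrModel")
cpuTrained
  .setFeaturesCols(featureNames)
  .transform(gpuReadDf)                            // DataFrame read inside withGpuSparkSession
  .select(labelName, "prediction")

// A GPU-trained model must be pointed at an assembled vector column for the CPU transform.
val gpuTrained = XGBoostRegressionModel.load("/tmp/xgbrModel")
gpuTrained
  .setFeaturesCol("feature_col")                   // output column of a VectorAssembler
  .transform(vectorAssembledDf)
  .select(labelName, "prediction")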

View File

@ -17,6 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import java.util.ServiceLoader
import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable}
@ -24,7 +25,6 @@ import scala.collection.{AbstractIterator, Iterator, mutable}
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel._originalPredictionCol
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import org.apache.spark.rdd.RDD
@ -35,7 +35,7 @@ import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
@ -43,7 +43,7 @@ import org.apache.spark.storage.StorageLevel
/**
* PreXGBoost handles data preparation before training and transformation
*/
object PreXGBoost {
object PreXGBoost extends PreXGBoostProvider {
private val logger = LogFactory.getLog("XGBoostSpark")
@ -51,6 +51,48 @@ object PreXGBoost {
private lazy val defaultWeightColumn = lit(1.0)
private lazy val defaultGroupColumn = lit(-1)
// Find the correct PreXGBoostProvider by ServiceLoader
private val optionProvider: Option[PreXGBoostProvider] = {
val classLoader = Option(Thread.currentThread().getContextClassLoader)
.getOrElse(getClass.getClassLoader)
val serviceLoader = ServiceLoader.load(classOf[PreXGBoostProvider], classLoader)
// For now, we only trust GpuPreXGBoost.
serviceLoader.asScala.filter(x => x.getClass.getName.equals(
"ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost")).toList match {
case Nil => None
case head::Nil =>
Some(head)
case _ => None
}
}
/**
* Transform schema
*
* @param xgboostEstimator supporting XGBoostClassifier/XGBoostClassificationModel and
* XGBoostRegressor/XGBoostRegressionModel
* @param schema the input schema
* @return the transformed schema
*/
override def transformSchema(
xgboostEstimator: XGBoostEstimatorCommon,
schema: StructType): StructType = {
if (optionProvider.isDefined && optionProvider.get.providerEnabled(None)) {
return optionProvider.get.transformSchema(xgboostEstimator, schema)
}
xgboostEstimator match {
case est: XGBoostClassifier => est.transformSchemaInternal(schema)
case model: XGBoostClassificationModel => model.transformSchemaInternal(schema)
case reg: XGBoostRegressor => reg.transformSchemaInternal(schema)
case model: XGBoostRegressionModel => model.transformSchemaInternal(schema)
case _ => throw new RuntimeException("Unsupporting " + xgboostEstimator)
}
}
/**
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
*
@ -61,11 +103,15 @@ object PreXGBoost {
* RDD[Watches] will be used as the training input
* Option[RDD[_]\] is the optional cached RDD
*/
def buildDatasetToRDD(
override def buildDatasetToRDD(
estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) = {
if (optionProvider.isDefined && optionProvider.get.providerEnabled(Some(dataset))) {
return optionProvider.get.buildDatasetToRDD(estimator, dataset, params)
}
val (packedParams, evalSet) = estimator match {
case est: XGBoostEstimatorCommon =>
// get weight column, if weight is not defined, default to lit(1.0)
@ -131,7 +177,11 @@ object PreXGBoost {
* @param dataset the input Dataset to transform
* @return the transformed DataFrame
*/
def transformDataFrame(model: Model[_], dataset: Dataset[_]): DataFrame = {
override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame = {
if (optionProvider.isDefined && optionProvider.get.providerEnabled(Some(dataset))) {
return optionProvider.get.transformDataset(model, dataset)
}
/** get the necessary parameters */
val (booster, inferBatchSize, featuresCol, useExternalMemory, missing, allowNonZeroForMissing,
@ -467,7 +517,7 @@ object PreXGBoost {
}
}
private def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
private[scala] def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
val taskId = TaskContext.getPartitionId().toString
if (useExternalMemory) {
val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")

View File

@ -0,0 +1,71 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset}
/**
* PreXGBoost implementation provider
*/
private[scala] trait PreXGBoostProvider {
/**
* Whether the provider is enabled or not
* @param dataset the input dataset
* @return Boolean
*/
def providerEnabled(dataset: Option[Dataset[_]]): Boolean = false
/**
* Transform schema
* @param xgboostEstimator supporting XGBoostClassifier/XGBoostClassificationModel and
* XGBoostRegressor/XGBoostRegressionModel
* @param schema the input schema
* @return the transformed schema
*/
def transformSchema(xgboostEstimator: XGBoostEstimatorCommon, schema: StructType): StructType
/**
* Convert the Dataset[_] to RDD[Watches] which will be fed to XGBoost
*
* @param estimator supports XGBoostClassifier and XGBoostRegressor
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (RDD[[Watches]], Option[ RDD[_] ])
* RDD[Watches] will be used as the training input
* Option[ RDD[_] ] is the optional cached RDD
*/
def buildDatasetToRDD(
estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]])
/**
* Transform Dataset
*
* @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
* @param dataset the input Dataset to transform
* @return the transformed DataFrame
*/
def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame
}
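For orientation, a minimal no-op provider sketch is shown below; the class is hypothetical and exists only for illustration. Because the trait is private[scala], a real provider must live under the ml.dmlc.xgboost4j.scala package, and this change ships and accepts only GpuPreXGBoost.

package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset}

// Hypothetical provider, for illustration only; it never takes over the pipeline.
class NoOpPreXGBoostProvider extends PreXGBoostProvider {

  // PreXGBoost consults this flag before delegating any work to the provider.
  override def providerEnabled(dataset: Option[Dataset[_]]): Boolean = false

  override def transformSchema(
      xgboostEstimator: XGBoostEstimatorCommon,
      schema: StructType): StructType = schema

  override def buildDatasetToRDD(
      estimator: Estimator[_],
      dataset: Dataset[_],
      params: Map[String, Any]): XGBoostExecutionParams => (RDD[Watches], Option[RDD[_]]) =
    throw new UnsupportedOperationException("illustration only")

  override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame =
    throw new UnsupportedOperationException("illustration only")
}

// Discovery would additionally require listing the class in
// META-INF/services/ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider; note that
// PreXGBoost currently filters the discovered providers down to GpuPreXGBoost only.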

View File

@ -53,12 +53,12 @@ object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L, "python")
}
private[this] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
maximizeEvalMetrics: Boolean)
private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
private[spark] case class XGBoostExecutionParams(
private[scala] case class XGBoostExecutionParams(
numWorkers: Int,
numRounds: Int,
useExternalMemory: Boolean,
@ -257,7 +257,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
private def getGPUAddrFromResources: Int = {
def getGPUAddrFromResources: Int = {
val tc = TaskContext.get()
if (tc == null) {
throw new RuntimeException("Something wrong for task context")
@ -473,7 +473,7 @@ object XGBoost extends Serializable {
}
class Watches private(
class Watches private[scala] (
val datasets: Array[DMatrix],
val names: Array[String],
val cacheDirName: Option[String]) {

View File

@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
import org.apache.hadoop.fs.Path
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg._
@ -27,9 +28,10 @@ import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.json4s.DefaultFormats
import scala.collection.{Iterator, mutable}
import org.apache.spark.sql.types.StructType
class XGBoostClassifier (
override val uid: String,
private[spark] val xgboostParams: Map[String, Any])
@ -142,6 +144,13 @@ class XGBoostClassifier (
def setSinglePrecisionHistogram(value: Boolean): this.type =
set(singlePrecisionHistogram, value)
/**
* This API is only used in the GPU train pipeline of xgboost4j-spark-gpu, which requires
* all feature columns to be numeric types.
*/
def setFeaturesCols(value: Seq[String]): this.type =
set(featuresCols, value)
// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
@ -154,6 +163,15 @@ class XGBoostClassifier (
}
}
// Callback from PreXGBoost
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
super.transformSchema(schema)
}
override def transformSchema(schema: StructType): StructType = {
PreXGBoost.transformSchema(this, schema)
}
override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = {
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
@ -196,7 +214,7 @@ object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
class XGBoostClassificationModel private[ml](
override val uid: String,
override val numClasses: Int,
private[spark] val _booster: Booster)
private[scala] val _booster: Booster)
extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
with XGBoostClassifierParams with InferenceParams
with MLWritable with Serializable {
@ -242,6 +260,13 @@ class XGBoostClassificationModel private[ml](
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
/**
* This API is only used in the GPU train pipeline of xgboost4j-spark-gpu, which requires
* all feature columns to be numeric types.
*/
def setFeaturesCols(value: Seq[String]): this.type =
set(featuresCols, value)
/**
* Single instance prediction.
* Note: The performance is not ideal, use it carefully!
@ -271,7 +296,7 @@ class XGBoostClassificationModel private[ml](
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
}
private[spark] def produceResultIterator(
private[scala] def produceResultIterator(
originalRowItr: Iterator[Row],
rawPredictionItr: Iterator[Row],
probabilityItr: Iterator[Row],
@ -306,7 +331,7 @@ class XGBoostClassificationModel private[ml](
}
}
private[spark] def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
private[scala] def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
Array[Iterator[Row]] = {
val rawPredictionItr = {
broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)).
@ -333,6 +358,14 @@ class XGBoostClassificationModel private[ml](
Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
}
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
super.transformSchema(schema)
}
override def transformSchema(schema: StructType): StructType = {
PreXGBoost.transformSchema(this, schema)
}
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if (isDefined(thresholds)) {
@ -343,7 +376,7 @@ class XGBoostClassificationModel private[ml](
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = PreXGBoost.transformDataFrame(this, dataset)
var outputData = PreXGBoost.transformDataset(this, dataset)
var numColsOutput = 0
val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
@ -404,8 +437,8 @@ class XGBoostClassificationModel private[ml](
object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
private[spark] val _rawPredictionCol = "_rawPrediction"
private[spark] val _probabilityCol = "_probability"
private[scala] val _rawPredictionCol = "_rawPrediction"
private[scala] val _probabilityCol = "_probability"
override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader

View File

@ -32,6 +32,7 @@ import org.apache.spark.sql.functions._
import org.json4s.DefaultFormats
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.types.StructType
class XGBoostRegressor (
override val uid: String,
@ -145,6 +146,13 @@ class XGBoostRegressor (
def setSinglePrecisionHistogram(value: Boolean): this.type =
set(singlePrecisionHistogram, value)
/**
* This API is only used in the GPU train pipeline of xgboost4j-spark-gpu, which requires
* all feature columns to be numeric types.
*/
def setFeaturesCols(value: Seq[String]): this.type =
set(featuresCols, value)
// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
@ -155,6 +163,14 @@ class XGBoostRegressor (
}
}
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
super.transformSchema(schema)
}
override def transformSchema(schema: StructType): StructType = {
PreXGBoost.transformSchema(this, schema)
}
override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = {
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
@ -191,7 +207,7 @@ object XGBoostRegressor extends DefaultParamsReadable[XGBoostRegressor] {
class XGBoostRegressionModel private[ml] (
override val uid: String,
private[spark] val _booster: Booster)
private[scala] val _booster: Booster)
extends PredictionModel[Vector, XGBoostRegressionModel]
with XGBoostRegressorParams with InferenceParams
with MLWritable with Serializable {
@ -237,6 +253,13 @@ class XGBoostRegressionModel private[ml] (
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
/**
* This API is only used in the GPU train pipeline of xgboost4j-spark-gpu, which requires
* all feature columns to be numeric types.
*/
def setFeaturesCols(value: Seq[String]): this.type =
set(featuresCols, value)
/**
* Single instance prediction.
* Note: The performance is not ideal, use it carefully!
@ -251,7 +274,7 @@ class XGBoostRegressionModel private[ml] (
_booster.predict(data = dm)(0)(0)
}
private[spark] def produceResultIterator(
private[scala] def produceResultIterator(
originalRowItr: Iterator[Row],
predictionItr: Iterator[Row],
predLeafItr: Iterator[Row],
@ -283,7 +306,7 @@ class XGBoostRegressionModel private[ml] (
}
}
private[spark] def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
private[scala] def producePredictionItrs(booster: Broadcast[Booster], dm: DMatrix):
Array[Iterator[Row]] = {
val originalPredictionItr = {
booster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
@ -307,11 +330,19 @@ class XGBoostRegressionModel private[ml] (
Array(originalPredictionItr, predLeafItr, predContribItr)
}
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
super.transformSchema(schema)
}
override def transformSchema(schema: StructType): StructType = {
PreXGBoost.transformSchema(this, schema)
}
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = PreXGBoost.transformDataFrame(this, dataset)
var outputData = PreXGBoost.transformDataset(this, dataset)
var numColsOutput = 0
val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
@ -342,7 +373,7 @@ class XGBoostRegressionModel private[ml] (
object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
private[spark] val _originalPredictionCol = "_originalPrediction"
private[scala] val _originalPredictionCol = "_originalPrediction"
override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014,2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -289,7 +289,7 @@ private[spark] trait BoosterParams extends Params {
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
}
private[spark] object BoosterParams {
private[scala] object BoosterParams {
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")

View File

@ -0,0 +1,53 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.apache.spark.ml.param.{BooleanParam, Param, Params}
trait GpuParams extends Params {
/**
* Param for the names of feature columns.
* @group param
*/
final val featuresCols: StringSeqParam = new StringSeqParam(this, "featuresCols",
"a sequence of feature column names.")
setDefault(featuresCols, Seq.empty[String])
/** @group getParam */
final def getFeaturesCols: Seq[String] = $(featuresCols)
}
class StringSeqParam(
parent: Params,
name: String,
doc: String) extends Param[Seq[String]](parent, name, doc) {
override def jsonEncode(value: Seq[String]): String = {
import org.json4s.JsonDSL._
compact(render(value))
}
override def jsonDecode(json: String): Seq[String] = {
implicit val formats = DefaultFormats
parse(json).extract[Seq[String]]
}
}
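A brief usage sketch, assuming placeholder column names: the estimator setter stores a plain Seq[String], and StringSeqParam serializes it as a JSON array so model save/load can round-trip the value.

// Placeholder column names; illustration only.
val regressor = new XGBoostRegressor()
  .setFeaturesCols(Seq("f1", "f2", "f3"))
regressor.getFeaturesCols                        // Seq("f1", "f2", "f3")

// StringSeqParam encodes and decodes the value as a JSON array, e.g. ["f1","f2"].
val demoParam = new StringSeqParam(regressor, "demoCols", "illustration only")
val json = demoParam.jsonEncode(Seq("f1", "f2")) // """["f1","f2"]"""
demoParam.jsonDecode(json)                       // Seq("f1", "f2")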

View File

@ -18,16 +18,16 @@ package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol
with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
with HasLabelCol {
with HasLabelCol with GpuParams {
def needDeterministicRepartitioning: Boolean = {
getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
}
}
private[spark] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass
private[scala] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass
private[spark] trait XGBoostRegressorParams extends XGBoostEstimatorCommon with HasGroupCol
private[scala] trait XGBoostRegressorParams extends XGBoostEstimatorCommon with HasGroupCol