[jvm-packages] [breaking] rework xgboost4j-spark and xgboost4j-spark-gpu (#10639)

- Introduce an abstract XGBoost Estimator
- Update to the latest XGBoost parameters
  - Add all XGBoost parameters supported in XGBoost4j-spark.
  - Add setter and getter for these parameters.
  - Remove the deprecated parameters
- Address the missing value handling
- Remove any ETL operations in XGBoost
- Rework the GPU plugin
- Expand sanity tests for CPU and GPU consistency
Bobby Wang 2024-09-11 15:54:19 +08:00 committed by GitHub
parent d94f6679fc
commit 67c8c96784
75 changed files with 4537 additions and 7556 deletions

View File

@ -38,6 +38,7 @@ Contents
XGBoost4J-Spark-GPU Tutorial <xgboost4j_spark_gpu_tutorial>
Code Examples <https://github.com/dmlc/xgboost/tree/master/jvm-packages/xgboost4j-example>
API docs <api>
How to migrate to XGBoost-Spark jvm 3.x <xgboost_spark_migration>
.. note::

View File

@ -0,0 +1,162 @@
########################################################
Migration Guide: How to migrate to XGBoost-Spark jvm 3.x
########################################################
XGBoost-Spark jvm packages underwent significant modifications in version 3.0,
which may cause compatibility issues with existing user code.
This guide will walk you through the process of updating your code to ensure
it's compatible with XGBoost-Spark 3.0 and later versions.
**********************
XGBoost Spark Packages
**********************
XGBoost-Spark 3.0 introduced a single uber package named xgboost-spark_2.12-3.0.0.jar, which bundles
both xgboost4j and xgboost4j-spark. This means you can now simply use ``xgboost-spark`` for your application.
* For CPU
.. code-block:: xml
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-spark_${scala.binary.version}</artifactId>
<version>3.0.0</version>
</dependency>
* For GPU
.. code-block:: xml
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-spark-gpu_${scala.binary.version}</artifactId>
<version>3.0.0</version>
</dependency>
When submitting the XGBoost application to the Spark cluster, you only need to specify the single ``xgboost-spark`` package.
* For CPU
.. code-block:: bash
spark-submit \
--jars xgboost-spark_2.12-3.0.0.jar \
... \
* For GPU
.. code-block:: bash
spark-submit \
--jars xgboost-spark-gpu_2.12-3.0.0.jar \
... \
***************
XGBoost Ranking
***************
Learning to rank with XGBoostRegressor and a ranking objective has been replaced by a dedicated ``XGBoostRanker``,
which is specifically designed to support ranking algorithms.
.. code-block:: scala
// before 3.0
val regressor = new XGBoostRegressor().setObjective("rank:ndcg")
// after 3.0
val ranker = new XGBoostRanker()
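When training a ranker, a query/group column is also required so that rows from the same query are grouped together. A minimal sketch, assuming the group column setter exposed through HasGroupCol and using placeholder names:
.. code-block:: scala
// Sketch only: "query_id", "feature_column", "label_column" and trainDf are placeholders.
val ranker = new XGBoostRanker()
.setGroupCol("query_id")
.setFeaturesCol("feature_column")
.setLabelCol("label_column")
val model = ranker.fit(trainDf)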
******************************
XGBoost Constructor Parameters
******************************
XGBoost-Spark now categorizes parameters into two groups: XGBoost-Spark parameters and XGBoost parameters.
When constructing an XGBoost estimator, only XGBoost parameters may be passed in the constructor map; XGBoost-Spark-specific
parameters must be configured through the estimator's setter methods. Note that
`XGBoost Parameters <https://xgboost.readthedocs.io/en/stable/parameter.html>`_
can be set both during construction and through the estimator's setter methods.
.. code-block:: scala
// before 3.0
val xgboost_paras = Map(
"eta" -> "1",
"max_depth" -> "6",
"objective" -> "binary:logistic",
"num_round" -> 5,
"num_workers" -> 1,
"features" -> "feature_column",
"label" -> "label_column",
)
val classifier = new XGBoostClassifier(xgboost_paras)
// after 3.0
val xgboost_paras = Map(
"eta" -> "1",
"max_depth" -> "6",
"objective" -> "binary:logistic",
)
val classifier = new XGBoostClassifier(xgboost_paras)
.setNumRound(5)
.setNumWorkers(1)
.setFeaturesCol("feature_column")
.setLabelCol("label_column")
// Or you can use setters to set all parameters
val classifier = new XGBoostClassifier()
.setNumRound(5)
.setNumWorkers(1)
.setFeaturesCol("feature_column")
.setLabelCol("label_column")
.setEta(1)
.setMaxDepth(6)
.setObjective("binary:logistic")
******************
Removed Parameters
******************
Starting from 3.0, the following parameters are removed.
- cacheTrainingSet
If you wish to cache the training dataset, cache it in your own code before fitting it to an estimator.
.. code-block:: scala
val df = input.cache()
val model = new XGBoostClassifier().fit(df)
- trainTestRatio
Evaluation on a held-out dataset can instead be configured as follows.
.. code-block:: scala
val Array(train, eval) = trainDf.randomSplit(Array(0.7, 0.3))
val classifier = new XGBoostClassifier().setEvalDataset(eval)
val model = classifier.fit(train)
- tracker_conf
The Rabit tracker can now be configured through the estimator's setter methods, as shown below.
.. code-block:: scala
val classifier = new XGBoostClassifier()
.setRabitTrackerTimeout(100)
.setRabitTrackerHostIp("192.168.0.2")
.setRabitTrackerPort(19203)
- rabitRingReduceThreshold
- rabitTimeout
- rabitConnectRetry
- singlePrecisionHistogram
- lambdaBias
- objectiveType

View File

@ -46,7 +46,7 @@
<use.cuda>OFF</use.cuda>
<cudf.version>24.06.0</cudf.version>
<spark.rapids.version>24.06.0</spark.rapids.version>
<cudf.classifier>cuda12</cudf.classifier>
<spark.rapids.classifier>cuda12</spark.rapids.classifier>
<scalatest.version>3.2.19</scalatest.version>
<scala-collection-compat.version>2.12.0</scala-collection-compat.version>
<skip.native.build>false</skip.native.build>

View File

@ -54,6 +54,7 @@
<groupId>com.nvidia</groupId>
<artifactId>rapids-4-spark_${scala.binary.version}</artifactId>
<version>${spark.rapids.version}</version>
<classifier>${spark.rapids.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>

View File

@ -35,11 +35,39 @@ public class QuantileDMatrix extends DMatrix {
float missing,
int maxBin,
int nthread) throws XGBoostError {
this(iter, null, missing, maxBin, nthread);
}
/**
* Create QuantileDMatrix from iterator based on the cuda array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array
* interface
* @param refDMatrix The reference QuantileDMatrix that provides quantile information, needed
* when creating validation/test dataset with QuantileDMatrix. Supplying the
* training DMatrix as a reference means that the same quantisation
* applied to the training data is applied to the validation/test data
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @throws XGBoostError
*/
public QuantileDMatrix(
Iterator<ColumnBatch> iter,
QuantileDMatrix refDMatrix,
float missing,
int maxBin,
int nthread) throws XGBoostError {
super(0);
long[] out = new long[1];
String conf = getConfig(missing, maxBin, nthread);
long[] ref = null;
if (refDMatrix != null) {
ref = new long[1];
ref[0] = refDMatrix.getHandle();
}
XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
iter, null, conf, out));
iter, ref, conf, out));
handle = out[0];
}
@ -87,4 +115,5 @@ public class QuantileDMatrix extends DMatrix {
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
missing, maxBin, nthread);
}
}

View File

@ -1,68 +0,0 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.nvidia.spark;
import java.util.List;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Table;
import org.apache.spark.sql.types.*;
/**
* Wrapper of CudfTable with schema for scala
*/
public class GpuColumnBatch implements AutoCloseable {
private final StructType schema;
private Table table; // the original Table
public GpuColumnBatch(Table table, StructType schema) {
this.table = table;
this.schema = schema;
}
@Override
public void close() {
if (table != null) {
table.close();
table = null;
}
}
/** Slice the columns indicated by indices into a Table*/
public Table slice(List<Integer> indices) {
if (indices == null || indices.size() == 0) {
return null;
}
int len = indices.size();
ColumnVector[] cv = new ColumnVector[len];
for (int i = 0; i < len; i++) {
int index = indices.get(i);
if (index >= table.getNumberOfColumns()) {
throw new RuntimeException("Wrong index");
}
cv[i] = table.getColumn(index);
}
return new Table(cv);
}
public StructType getSchema() {
return schema;
}
}

View File

@ -1 +0,0 @@
ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost

View File

@ -0,0 +1 @@
ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2021 by Contributors
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,17 +16,17 @@
package ml.dmlc.xgboost4j.scala
import _root_.scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, QuantileDMatrix => JQuantileDMatrix}
import scala.collection.JavaConverters._
class QuantileDMatrix private[scala](
private[scala] override val jDMatrix: JQuantileDMatrix) extends DMatrix(jDMatrix) {
/**
* Create QuantileDMatrix from iterator based on the cuda array interface
* Create QuantileDMatrix from iterator based on the array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
* @param iter the XGBoost ColumnBatch batch to provide the corresponding array interface
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
@ -36,6 +36,27 @@ class QuantileDMatrix private[scala](
this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread))
}
/**
* Create QuantileDMatrix from iterator based on the array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding array interface
* @param refDMatrix The reference QuantileDMatrix that provides quantile information, needed
* when creating validation/test dataset with QuantileDMatrix. Supplying the
* training DMatrix as a reference means that the same quantisation applied
* to the training data is applied to the validation/test data
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @throws XGBoostError
*/
def this(iter: Iterator[ColumnBatch],
ref: QuantileDMatrix,
missing: Float,
maxBin: Int,
nthread: Int) {
this(new JQuantileDMatrix(iter.asJava, ref.jDMatrix, missing, maxBin, nthread))
}
/**
* set label of dmatrix
*
@ -84,7 +105,7 @@ class QuantileDMatrix private[scala](
throw new XGBoostError("QuantileDMatrix does not support setGroup.")
/**
* Set label of DMatrix from cuda array interface
* Set label of DMatrix from array interface
*/
@throws(classOf[XGBoostError])
override def setLabel(column: Column): Unit =
@ -104,4 +125,9 @@ class QuantileDMatrix private[scala](
override def setBaseMargin(column: Column): Unit =
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")
@throws(classOf[XGBoostError])
override def setQueryId(column: Column): Unit = {
throw new XGBoostError("QuantileDMatrix does not support setQueryId.")
}
}
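A minimal usage sketch of the reference constructor added above (trainBatches and evalBatches are assumed to be prepared Iterator[ColumnBatch] values; the missing value and bin count are illustrative):
// Build the training matrix first, then pass it as the reference so the
// validation data is binned with the same quantile cuts as the training data.
val trainDM = new QuantileDMatrix(trainBatches, Float.NaN, 256, 1)
val evalDM = new QuantileDMatrix(evalBatches, trainDM, Float.NaN, 256, 1)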

View File

@ -1,602 +0,0 @@
/*
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
import ml.dmlc.xgboost4j.java.CudfColumnBatch
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, QuantileDMatrix}
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor}
import org.apache.commons.logging.LogFactory
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.functions.{col, collect_list, struct}
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
/**
* GpuPreXGBoost brings Rapids-Plugin to XGBoost4j-Spark to accelerate XGBoost4j
* training and transform process
*/
class GpuPreXGBoost extends PreXGBoostProvider {
/**
* Whether the provider is enabled or not
*
* @param dataset the input dataset
* @return Boolean
*/
override def providerEnabled(dataset: Option[Dataset[_]]): Boolean = {
GpuPreXGBoost.providerEnabled(dataset)
}
/**
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
*
* @param estimator [[XGBoostClassifier]] or [[XGBoostRegressor]]
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (RDD[[() => Watches]], Option[ RDD[_] ])
* RDD[() => Watches] will be used as the training input
* Option[ RDD[_] ] is the optional cached RDD
*/
override def buildDatasetToRDD(estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]):
XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]) = {
GpuPreXGBoost.buildDatasetToRDD(estimator, dataset, params)
}
/**
* Transform Dataset
*
* @param model [[XGBoostClassificationModel]] or [[XGBoostRegressionModel]]
* @param dataset the input Dataset to transform
* @return the transformed DataFrame
*/
override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame = {
GpuPreXGBoost.transformDataset(model, dataset)
}
override def transformSchema(
xgboostEstimator: XGBoostEstimatorCommon,
schema: StructType): StructType = {
GpuPreXGBoost.transformSchema(xgboostEstimator, schema)
}
}
class BoosterFlag extends Serializable {
// indicate if the GPU parameters are set.
var isGpuParamsSet = false
}
object GpuPreXGBoost extends PreXGBoostProvider {
private val logger = LogFactory.getLog("XGBoostSpark")
private val FEATURES_COLS = "features_cols"
private val TRAIN_NAME = "train"
override def providerEnabled(dataset: Option[Dataset[_]]): Boolean = {
// RuntimeConfig
val optionConf = dataset.map(ds => Some(ds.sparkSession.conf))
.getOrElse(SparkSession.getActiveSession.map(ss => ss.conf))
if (optionConf.isDefined) {
val conf = optionConf.get
val rapidsEnabled = try {
conf.get("spark.rapids.sql.enabled").toBoolean
} catch {
// Rapids plugin has default "spark.rapids.sql.enabled" to true
case _: NoSuchElementException => true
case _: Throwable => false // Any exception will return false
}
rapidsEnabled && conf.get("spark.sql.extensions", "")
.split(",")
.contains("com.nvidia.spark.rapids.SQLExecPlugin")
} else false
}
/**
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
*
* @param estimator supports XGBoostClassifier and XGBoostRegressor
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (RDD[[() => Watches]], Option[ RDD[_] ])
* RDD[() => Watches] will be used as the training input to build DMatrix
* Option[ RDD[_] ] is the optional cached RDD
*/
override def buildDatasetToRDD(
estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]):
XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]) = {
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
estimator match {
case est: XGBoostEstimatorCommon =>
require(
est.isDefined(est.device) &&
(est.getDevice.equals("cuda") || est.getDevice.equals("gpu")) ||
est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
s"GPU train requires `device` set to `cuda` or `gpu`."
)
val groupName = estimator match {
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
regressor.getGroupCol } else ""
case _: XGBoostClassifier => ""
case _ => throw new RuntimeException("Unsupported estimator: " + estimator)
}
// Check schema and cast columns' type
(GpuUtils.getColumnNames(est)(est.labelCol, est.weightCol, est.baseMarginCol),
est.getFeaturesCols, groupName, est.getEvalSets(params))
case _ => throw new RuntimeException("Unsupported estimator: " + estimator)
}
val castedDF = GpuUtils.prepareColumnType(dataset, feturesCols, labelName, weightName,
marginName)
// Check columns and build column data batch
val trainingData = GpuUtils.buildColumnDataBatch(feturesCols,
labelName, weightName, marginName, groupName, castedDF)
// eval map
val evalDataMap = evalSets.map {
case (name, df) =>
val castDF = GpuUtils.prepareColumnType(df, feturesCols, labelName,
weightName, marginName)
(name, GpuUtils.buildColumnDataBatch(feturesCols, labelName, weightName,
marginName, groupName, castDF))
}
xgbExecParams: XGBoostExecutionParams =>
val dataMap = prepareInputData(trainingData, evalDataMap, xgbExecParams.numWorkers,
xgbExecParams.cacheTrainingSet)
(buildRDDWatches(dataMap, xgbExecParams, evalDataMap.isEmpty), None)
}
/**
* Transform Dataset
*
* @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
* @param dataset the input Dataset to transform
* @return the transformed DataFrame
*/
override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame = {
val (booster, predictFunc, schema, featureColNames, missing) = model match {
case m: XGBoostClassificationModel =>
Seq(XGBoostClassificationModel._rawPredictionCol,
XGBoostClassificationModel._probabilityCol, m.leafPredictionCol, m.contribPredictionCol)
// predict and turn to Row
val predictFunc =
(booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
m.producePredictionItrs(booster, dm)
m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr,
predLeafItr, predContribItr)
}
// prepare the final Schema
var schema = StructType(dataset.schema.fields ++
Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)) ++
Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))
if (m.isDefined(m.leafPredictionCol)) {
schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
if (m.isDefined(m.contribPredictionCol)) {
schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
(m._booster, predictFunc, schema, m.getFeaturesCols, m.getMissing)
case m: XGBoostRegressionModel =>
Seq(XGBoostRegressionModel._originalPredictionCol, m.leafPredictionCol,
m.contribPredictionCol)
// predict and turn to Row
val predictFunc =
(booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, predLeafItr, predContribItr) =
m.producePredictionItrs(booster, dm)
m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr,
predContribItr)
}
// prepare the final Schema
var schema = StructType(dataset.schema.fields ++
Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))
if (m.isDefined(m.leafPredictionCol)) {
schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
if (m.isDefined(m.contribPredictionCol)) {
schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
(m._booster, predictFunc, schema, m.getFeaturesCols, m.getMissing)
}
val sc = dataset.sparkSession.sparkContext
// Prepare some vars will be passed to executors.
val bOrigSchema = sc.broadcast(dataset.schema)
val bRowSchema = sc.broadcast(schema)
val bBooster = sc.broadcast(booster)
val bBoosterFlag = sc.broadcast(new BoosterFlag)
// Small vars so don't need to broadcast them
val isLocal = sc.isLocal
val featureIds = featureColNames.distinct.map(dataset.schema.fieldIndex)
// start transform by df->rd->mapPartition
val rowRDD: RDD[Row] = GpuUtils.toColumnarRdd(dataset.asInstanceOf[DataFrame]).mapPartitions {
tableIters =>
// UnsafeProjection is not serializable so do it on the executor side
val toUnsafe = UnsafeProjection.create(bOrigSchema.value)
// booster is visible for all spark tasks in the same executor
val booster = bBooster.value
val boosterFlag = bBoosterFlag.value
synchronized {
// there are two kind of race conditions,
// 1. multi-taskes set parameters at a time
// 2. one task sets parameter and another task reads the parameter
// both of them can cause potential un-expected behavior, moreover,
// it may cause executor crash
// So add synchronized to allow only one task to set parameter if it is not set.
// and rely on BlockManager to ensure the same booster only be called once to
// set parameter.
if (!boosterFlag.isGpuParamsSet) {
// set some params of gpu related to booster
// - gpu id
// - predictor: Force to gpu predictor since native doesn't save predictor.
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
booster.setParam("device", s"cuda:$gpuId")
logger.info("GPU transform on device: " + gpuId)
boosterFlag.isGpuParamsSet = true;
}
}
// Iterator on Row
new Iterator[Row] {
// Convert InternalRow to Row
private val converter: InternalRow => Row = CatalystTypeConverters
.createToScalaConverter(bOrigSchema.value)
.asInstanceOf[InternalRow => Row]
// GPU batches read in must be closed by the receiver (us)
@transient var currentBatch: ColumnarBatch = null
// Iterator on Row
var iter: Iterator[Row] = null
TaskContext.get().addTaskCompletionListener[Unit](_ => {
closeCurrentBatch() // close the last ColumnarBatch
})
private def closeCurrentBatch(): Unit = {
if (currentBatch != null) {
currentBatch.close()
currentBatch = null
}
}
def loadNextBatch(): Unit = {
closeCurrentBatch()
if (tableIters.hasNext) {
val dataTypes = bOrigSchema.value.fields.map(x => x.dataType)
iter = withResource(tableIters.next()) { table =>
val gpuColumnBatch = new GpuColumnBatch(table, bOrigSchema.value)
// Create DMatrix
val feaTable = gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(featureIds).asJava)
if (feaTable == null) {
throw new RuntimeException("Something wrong for feature indices")
}
try {
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null, null)
val dm = new DMatrix(cudfColumnBatch, missing, 1)
if (dm == null) {
Iterator.empty
} else {
try {
currentBatch = new ColumnarBatch(
GpuUtils.extractBatchToHost(table, dataTypes),
table.getRowCount().toInt)
val rowIterator = currentBatch.rowIterator().asScala
.map(toUnsafe)
.map(converter(_))
predictFunc(booster, dm, rowIterator)
} finally {
dm.delete()
}
}
} finally {
feaTable.close()
}
}
} else {
iter = null
}
}
override def hasNext: Boolean = {
val itHasNext = iter != null && iter.hasNext
if (!itHasNext) { // Don't have extra Row for current ColumnarBatch
loadNextBatch()
iter != null && iter.hasNext
} else {
itHasNext
}
}
override def next(): Row = {
if (iter == null || !iter.hasNext) {
loadNextBatch()
}
if (iter == null) {
throw new NoSuchElementException()
}
iter.next()
}
}
}
bOrigSchema.unpersist(blocking = false)
bRowSchema.unpersist(blocking = false)
bBooster.unpersist(blocking = false)
dataset.sparkSession.createDataFrame(rowRDD, schema)
}
/**
* Transform schema
*
* @param est supporting XGBoostClassifier/XGBoostClassificationModel and
* XGBoostRegressor/XGBoostRegressionModel
* @param schema the input schema
* @return the transformed schema
*/
override def transformSchema(
est: XGBoostEstimatorCommon,
schema: StructType): StructType = {
val fit = est match {
case _: XGBoostClassifier | _: XGBoostRegressor => true
case _ => false
}
val Seq(label, weight, margin) = GpuUtils.getColumnNames(est)(est.labelCol, est.weightCol,
est.baseMarginCol)
GpuUtils.validateSchema(schema, est.getFeaturesCols, label, weight, margin, fit)
}
/**
* Repartition all the Columnar Dataset (training and evaluation) to nWorkers,
* and assemble them into a map
*/
private def prepareInputData(
trainingData: ColumnDataBatch,
evalSetsMap: Map[String, ColumnDataBatch],
nWorkers: Int,
isCacheData: Boolean): Map[String, ColumnDataBatch] = {
// Cache is not supported
if (isCacheData) {
logger.warn("the cache param will be ignored by GPU pipeline!")
}
(Map(TRAIN_NAME -> trainingData) ++ evalSetsMap).map {
case (name, colData) =>
// No light cost way to get number of partitions from DataFrame, so always repartition
val newDF = colData.groupColName
.map(gn => repartitionForGroup(gn, colData.rawDF, nWorkers))
.getOrElse(repartitionInputData(colData.rawDF, nWorkers))
name -> ColumnDataBatch(newDF, colData.colIndices, colData.groupColName)
}
}
private def repartitionInputData(dataFrame: DataFrame, nWorkers: Int): DataFrame = {
// we can't involve any coalesce operation here, since Barrier mode will check
// the RDD patterns which does not allow coalesce.
dataFrame.repartition(nWorkers)
}
private def repartitionForGroup(
groupName: String,
dataFrame: DataFrame,
nWorkers: Int): DataFrame = {
// Group the data first
logger.info("Start groupBy for LTR")
val schema = dataFrame.schema
val groupedDF = dataFrame
.groupBy(groupName)
.agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")
implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false))
// Expand the grouped rows after repartition
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
new Iterator[Row] {
var iterInRow: Iterator[Any] = Iterator.empty
override def hasNext: Boolean = {
if (iter.hasNext && !iterInRow.hasNext) {
// the first is groupId, second is list
iterInRow = iter.next.getSeq(1).iterator
}
iterInRow.hasNext
}
override def next(): Row = {
iterInRow.next.asInstanceOf[Row]
}
}
})
}
private def buildRDDWatches(
dataMap: Map[String, ColumnDataBatch],
xgbExeParams: XGBoostExecutionParams,
noEvalSet: Boolean): RDD[() => Watches] = {
val sc = dataMap(TRAIN_NAME).rawDF.sparkSession.sparkContext
val maxBin = xgbExeParams.toMap.getOrElse("max_bin", 256).asInstanceOf[Int]
// Start training
if (noEvalSet) {
// Get the indices here at driver side to avoid passing the whole Map to executor(s)
val colIndicesForTrain = dataMap(TRAIN_NAME).colIndices
GpuUtils.toColumnarRdd(dataMap(TRAIN_NAME).rawDF).mapPartitions({
iter =>
val iterColBatch = iter.map(table => new GpuColumnBatch(table, null))
Iterator(() => buildWatches(
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
colIndicesForTrain, iterColBatch, maxBin))
})
} else {
// Train with evaluation sets
// Get the indices here at driver side to avoid passing the whole Map to executor(s)
val nameAndColIndices = dataMap.map(nc => (nc._1, nc._2.colIndices))
coPartitionForGpu(dataMap, sc, xgbExeParams.numWorkers).mapPartitions {
nameAndColumnBatchIter =>
Iterator(() => buildWatchesWithEval(
PreXGBoost.getCacheDirName(xgbExeParams.useExternalMemory), xgbExeParams.missing,
nameAndColIndices, nameAndColumnBatchIter, maxBin))
}
}
}
private def buildWatches(
cachedDirName: Option[String],
missing: Float,
indices: ColumnIndices,
iter: Iterator[GpuColumnBatch],
maxBin: Int): Watches = {
val (dm, time) = GpuUtils.time {
buildDMatrix(iter, indices, missing, maxBin)
}
logger.debug("Benchmark[Train: Build DMatrix incrementally] " + time)
val (aDMatrix, aName) = if (dm == null) {
(Array.empty[DMatrix], Array.empty[String])
} else {
(Array(dm), Array("train"))
}
new Watches(aDMatrix, aName, cachedDirName)
}
private def buildWatchesWithEval(
cachedDirName: Option[String],
missing: Float,
indices: Map[String, ColumnIndices],
nameAndColumns: Iterator[(String, Iterator[GpuColumnBatch])],
maxBin: Int): Watches = {
val dms = nameAndColumns.map {
case (name, iter) => (name, {
val (dm, time) = GpuUtils.time {
buildDMatrix(iter, indices(name), missing, maxBin)
}
logger.debug(s"Benchmark[Train build $name DMatrix] " + time)
dm
})
}.filter(_._2 != null).toArray
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
}
/**
* Build QuantileDMatrix based on GpuColumnBatches
*
* @param iter a sequence of GpuColumnBatch
* @param indices indicate the feature, label, weight, base margin column ids.
* @param missing the missing value
* @param maxBin the maxBin
* @return DMatrix
*/
private def buildDMatrix(
iter: Iterator[GpuColumnBatch],
indices: ColumnIndices,
missing: Float,
maxBin: Int): DMatrix = {
val rapidsIterator = new RapidsIterator(iter, indices)
new QuantileDMatrix(rapidsIterator, missing, maxBin, 1)
}
// zip all the Columnar RDDs into one RDD containing named column data batch.
private def coPartitionForGpu(
dataMap: Map[String, ColumnDataBatch],
sc: SparkContext,
nWorkers: Int): RDD[(String, Iterator[GpuColumnBatch])] = {
val emptyDataRdd = sc.parallelize(
Array.fill[(String, Iterator[GpuColumnBatch])](nWorkers)(null), nWorkers)
dataMap.foldLeft(emptyDataRdd) {
case (zippedRdd, (name, gdfColData)) =>
zippedRdd.zipPartitions(GpuUtils.toColumnarRdd(gdfColData.rawDF)) {
(itWrapper, iterCol) =>
val itCol = iterCol.map(table => new GpuColumnBatch(table, null))
(itWrapper.toArray :+ (name -> itCol)).filter(x => x != null).toIterator
}
}
}
private[this] class RapidsIterator(
base: Iterator[GpuColumnBatch],
indices: ColumnIndices) extends Iterator[CudfColumnBatch] {
override def hasNext: Boolean = base.hasNext
override def next(): CudfColumnBatch = {
// Since we have sliced original Table into different tables. Needs to close the original one.
withResource(base.next()) { gpuColumnBatch =>
val weights = indices.weightId.map(Seq(_)).getOrElse(Seq.empty)
val margins = indices.marginId.map(Seq(_)).getOrElse(Seq.empty)
new CudfColumnBatch(
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(indices.featureIds).asJava),
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(Seq(indices.labelId)).asJava),
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(weights).asJava),
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava),
null);
}
}
}
/** Executes the provided code block and then closes the resource */
def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
}
}

View File

@ -1,178 +0,0 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import ai.rapids.cudf.Table
import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVectorUtils}
import ml.dmlc.xgboost4j.scala.spark.util.Utils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DataType, FloatType, NumericType, StructType}
import org.apache.spark.sql.vectorized.ColumnVector
private[spark] object GpuUtils {
def extractBatchToHost(table: Table, types: Array[DataType]): Array[ColumnVector] = {
// spark-rapids has shimmed the GpuColumnVector from 22.10
GpuColumnVectorUtils.extractHostColumns(table, types)
}
def toColumnarRdd(df: DataFrame): RDD[Table] = ColumnarRdd(df)
def seqIntToSeqInteger(x: Seq[Int]): Seq[Integer] = x.map(new Integer(_))
/** APIs for gpu column data related */
def buildColumnDataBatch(featureNames: Seq[String],
labelName: String,
weightName: String,
marginName: String,
groupName: String,
dataFrame: DataFrame): ColumnDataBatch = {
// Some check first
val schema = dataFrame.schema
val featureNameSet = featureNames.distinct
GpuUtils.validateSchema(schema, featureNameSet, labelName, weightName, marginName)
// group column
val (opGroup, groupId) = if (groupName.isEmpty) {
(None, None)
} else {
GpuUtils.checkNumericType(schema, groupName)
(Some(groupName), Some(schema.fieldIndex(groupName)))
}
// weight and base margin columns
val Seq(weightId, marginId) = Seq(weightName, marginName).map {
name =>
if (name.isEmpty) None else Some(schema.fieldIndex(name))
}
val colsIndices = ColumnIndices(featureNameSet.map(schema.fieldIndex),
schema.fieldIndex(labelName), weightId, marginId, groupId)
ColumnDataBatch(dataFrame, colsIndices, opGroup)
}
def checkNumericType(schema: StructType, colName: String,
msg: String = ""): Unit = {
val actualDataType = schema(colName).dataType
val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
require(actualDataType.isInstanceOf[NumericType],
s"Column $colName must be of NumericType but found: " +
s"${actualDataType.catalogString}.$message")
}
/** Check and Cast the columns to FloatType */
def prepareColumnType(
dataset: Dataset[_],
featureNames: Seq[String],
labelName: String = "",
weightName: String = "",
marginName: String = "",
fitting: Boolean = true): DataFrame = {
// check first
val featureNameSet = featureNames.distinct
validateSchema(dataset.schema, featureNameSet, labelName, weightName, marginName, fitting)
val castToFloat = (df: DataFrame, colName: String) => {
if (df.schema(colName).dataType.isInstanceOf[FloatType]) {
df
} else {
val colMeta = df.schema(colName).metadata
df.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
}
}
val colNames = if (fitting) {
var names = featureNameSet :+ labelName
if (weightName.nonEmpty) {
names = names :+ weightName
}
if (marginName.nonEmpty) {
names = names :+ marginName
}
names
} else {
featureNameSet
}
colNames.foldLeft(dataset.asInstanceOf[DataFrame])(
(ds, colName) => castToFloat(ds, colName))
}
/** Validate input schema */
def validateSchema(schema: StructType,
featureNames: Seq[String],
labelName: String = "",
weightName: String = "",
marginName: String = "",
fitting: Boolean = true): StructType = {
val msg = if (fitting) "train" else "transform"
// feature columns
require(featureNames.nonEmpty, s"Gpu $msg requires features columns. " +
"please refer to `setFeaturesCol(value: Array[String])`!")
featureNames.foreach(fn => checkNumericType(schema, fn))
if (fitting) {
require(labelName.nonEmpty, "label column is not set.")
checkNumericType(schema, labelName)
if (weightName.nonEmpty) {
checkNumericType(schema, weightName)
}
if (marginName.nonEmpty) {
checkNumericType(schema, marginName)
}
}
schema
}
def time[R](block: => R): (R, Float) = {
val t0 = System.currentTimeMillis
val result = block // call-by-name
val t1 = System.currentTimeMillis
(result, (t1 - t0).toFloat / 1000)
}
/** Get column names from Parameter */
def getColumnNames(params: Params)(cols: Param[String]*): Seq[String] = {
// get column name, null | undefined will be casted to ""
def getColumnName(params: Params)(param: Param[String]): String = {
if (params.isDefined(param)) {
val colName = params.getOrDefault(param)
if (colName != null) colName else ""
} else ""
}
val getName = getColumnName(params)(_)
cols.map(getName)
}
}
/**
* A container to contain the column ids
*/
private[spark] case class ColumnIndices(
featureIds: Seq[Int],
labelId: Int,
weightId: Option[Int],
marginId: Option[Int],
groupId: Option[Int])
private[spark] case class ColumnDataBatch(
rawDF: DataFrame,
colIndices: ColumnIndices,
groupColName: Option[String])

View File

@ -0,0 +1,315 @@
/*
Copyright (c) 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import ai.rapids.cudf.Table
import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVectorUtils}
import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.ml.param.Param
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{DataType, FloatType, IntegerType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import ml.dmlc.xgboost4j.java.CudfColumnBatch
import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix}
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol
/**
* GpuXGBoostPlugin is the XGBoost plugin that leverages spark-rapids
* to accelerate XGBoost from ETL through training.
*/
class GpuXGBoostPlugin extends XGBoostPlugin {
private val logger = LogFactory.getLog("XGBoostSparkGpuPlugin")
/**
* Whether the plugin is enabled or not; if not enabled, fall back
* to the regular CPU pipeline.
*
* @param dataset the input dataset
* @return Boolean
*/
override def isEnabled(dataset: Dataset[_]): Boolean = {
val conf = dataset.sparkSession.conf
val hasRapidsPlugin = conf.get("spark.plugins", "").split(",").contains(
"com.nvidia.spark.SQLPlugin")
val rapidsEnabled = try {
conf.get("spark.rapids.sql.enabled").toBoolean
} catch {
// Rapids plugin has default "spark.rapids.sql.enabled" to true
case _: NoSuchElementException => true
case _: Throwable => false // Any exception will return false
}
hasRapidsPlugin && rapidsEnabled
}
// TODO, support numeric type
private[spark] def preprocess[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
estimator: XGBoostEstimator[T, M], dataset: Dataset[_]): Dataset[_] = {
// Columns to be selected for XGBoost training
val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty
val schema = dataset.schema
def selectCol(c: Param[String], targetType: DataType = FloatType) = {
// TODO support numeric types
if (estimator.isDefinedNonEmpty(c)) {
selectedCols.append(estimator.castIfNeeded(schema, estimator.getOrDefault(c), targetType))
}
}
Seq(estimator.labelCol, estimator.weightCol, estimator.baseMarginCol)
.foreach(p => selectCol(p))
estimator match {
case p: HasGroupCol => selectCol(p.groupCol, IntegerType)
case _ =>
}
// TODO support array/vector feature
estimator.getFeaturesCols.foreach { name =>
val col = estimator.castIfNeeded(dataset.schema, name)
selectedCols.append(col)
}
val input = dataset.select(selectedCols.toArray: _*)
estimator.repartitionIfNeeded(input)
}
// visible for testing
private[spark] def validate[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
estimator: XGBoostEstimator[T, M],
dataset: Dataset[_]): Unit = {
require(estimator.getTreeMethod == "gpu_hist" || estimator.getDevice != "cpu",
"Using Spark-Rapids to accelerate XGBoost must set device=cuda")
}
/**
* Convert Dataset to RDD[Watches] which will be fed into XGBoost
*
* @param estimator which estimator to be handled.
* @param dataset to be converted.
* @return RDD[Watches]
*/
override def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
estimator: XGBoostEstimator[T, M],
dataset: Dataset[_]): RDD[Watches] = {
validate(estimator, dataset)
val train = preprocess(estimator, dataset)
val schema = train.schema
val indices = estimator.buildColumnIndices(schema)
val maxBin = estimator.getMaxBins
val nthread = estimator.getNthread
val missing = estimator.getMissing
/** build QuantileDMatrix on the executor side */
def buildQuantileDMatrix(iter: Iterator[Table],
ref: Option[QuantileDMatrix] = None): QuantileDMatrix = {
val colBatchIter = iter.map { table =>
withResource(new GpuColumnBatch(table)) { batch =>
new CudfColumnBatch(
batch.select(indices.featureIds.get),
batch.select(indices.labelId),
batch.select(indices.weightId.getOrElse(-1)),
batch.select(indices.marginId.getOrElse(-1)),
batch.select(indices.groupId.getOrElse(-1)));
}
}
ref.map(r => new QuantileDMatrix(colBatchIter, r, missing, maxBin, nthread)).getOrElse(
new QuantileDMatrix(colBatchIter, missing, maxBin, nthread)
)
}
estimator.getEvalDataset().map { evalDs =>
val evalProcessed = preprocess(estimator, evalDs)
ColumnarRdd(train.toDF()).zipPartitions(ColumnarRdd(evalProcessed.toDF())) {
(trainIter, evalIter) =>
new Iterator[Watches] {
override def hasNext: Boolean = trainIter.hasNext
override def next(): Watches = {
val trainDM = buildQuantileDMatrix(trainIter)
val evalDM = buildQuantileDMatrix(evalIter, Some(trainDM))
new Watches(Array(trainDM, evalDM),
Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)
}
}
}
}.getOrElse(
ColumnarRdd(train.toDF()).mapPartitions { iter =>
new Iterator[Watches] {
override def hasNext: Boolean = iter.hasNext
override def next(): Watches = {
val dm = buildQuantileDMatrix(iter)
new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
}
}
}
)
}
override def transform[M <: XGBoostModel[M]](model: XGBoostModel[M],
dataset: Dataset[_]): DataFrame = {
val sc = dataset.sparkSession.sparkContext
val (transformedSchema, pred) = model.preprocess(dataset)
val bBooster = sc.broadcast(model.nativeBooster)
val bOriginalSchema = sc.broadcast(dataset.schema)
val featureIds = model.getFeaturesCols.distinct.map(dataset.schema.fieldIndex).toList
val isLocal = sc.isLocal
val missing = model.getMissing
val nThread = model.getNthread
val rdd = ColumnarRdd(dataset.asInstanceOf[DataFrame]).mapPartitions { tableIters =>
// booster is visible for all spark tasks in the same executor
val booster = bBooster.value
val originalSchema = bOriginalSchema.value
// UnsafeProjection is not serializable so do it on the executor side
val toUnsafe = UnsafeProjection.create(originalSchema)
if (!booster.deviceIsSet) {
booster.deviceIsSet.synchronized {
if (!booster.deviceIsSet) {
booster.deviceIsSet = true
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
booster.setParam("device", s"cuda:$gpuId")
logger.info("GPU transform on GPU device: cuda:" + gpuId)
}
}
}
// Iterator on Row
new Iterator[Row] {
// Convert InternalRow to Row
private val converter: InternalRow => Row = CatalystTypeConverters
.createToScalaConverter(originalSchema)
.asInstanceOf[InternalRow => Row]
// GPU batches read in must be closed by the receiver
@transient var currentBatch: ColumnarBatch = null
// Iterator on Row
var iter: Iterator[Row] = null
TaskContext.get().addTaskCompletionListener[Unit](_ => {
closeCurrentBatch() // close the last ColumnarBatch
})
private def closeCurrentBatch(): Unit = {
if (currentBatch != null) {
currentBatch.close()
currentBatch = null
}
}
def loadNextBatch(): Unit = {
closeCurrentBatch()
if (tableIters.hasNext) {
val dataTypes = originalSchema.fields.map(x => x.dataType)
iter = withResource(tableIters.next()) { table =>
// Create DMatrix
val featureTable = new GpuColumnBatch(table).select(featureIds)
if (featureTable == null) {
val msg = featureIds.mkString(",")
throw new RuntimeException(s"Couldn't create feature table for the " +
s"feature indices $msg")
}
try {
val cudfColumnBatch = new CudfColumnBatch(featureTable, null, null, null, null)
val dm = new DMatrix(cudfColumnBatch, missing, nThread)
if (dm == null) {
Iterator.empty
} else {
try {
currentBatch = new ColumnarBatch(
GpuColumnVectorUtils.extractHostColumns(table, dataTypes),
table.getRowCount().toInt)
val rowIterator = currentBatch.rowIterator().asScala.map(toUnsafe)
.map(converter(_))
model.predictInternal(booster, dm, pred, rowIterator).toIterator
} finally {
dm.delete()
}
}
} finally {
featureTable.close()
}
}
} else {
iter = null
}
}
override def hasNext: Boolean = {
val itHasNext = iter != null && iter.hasNext
if (!itHasNext) { // Don't have extra Row for current ColumnarBatch
loadNextBatch()
iter != null && iter.hasNext
} else {
itHasNext
}
}
override def next(): Row = {
if (iter == null || !iter.hasNext) {
loadNextBatch()
}
if (iter == null) {
throw new NoSuchElementException()
}
iter.next()
}
}
}
bBooster.unpersist(false)
bOriginalSchema.unpersist(false)
val output = dataset.sparkSession.createDataFrame(rdd, transformedSchema)
model.postTransform(output, pred).toDF()
}
}
private class GpuColumnBatch(table: Table) extends AutoCloseable {
def select(index: Int): Table = {
select(Seq(index))
}
def select(indices: Seq[Int]): Table = {
if (!indices.forall(index => index < table.getNumberOfColumns && index >= 0)) {
return null;
}
new Table(indices.map(table.getColumn): _*)
}
override def close(): Unit = {
if (Option(table).isDefined) {
table.close()
}
}
}
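For reference, a minimal sketch of a SparkSession configuration under which isEnabled above returns true (it assumes the spark-rapids jars are already on the classpath; exact settings depend on the deployment):
import org.apache.spark.sql.SparkSession
// The plugin checks that spark.plugins contains com.nvidia.spark.SQLPlugin and that
// spark.rapids.sql.enabled is not false; the estimator must also set device=cuda
// (see the validate method above).
val spark = SparkSession.builder()
.config("spark.plugins", "com.nvidia.spark.SQLPlugin")
.config("spark.rapids.sql.enabled", "true")
.getOrCreate()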

View File

@ -16,9 +16,7 @@
package ml.dmlc.xgboost4j.java;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.*;
import ai.rapids.cudf.Table;
import junit.framework.TestCase;
@ -122,8 +120,7 @@ public class DMatrixTest {
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0, q_0));
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1, q_1));
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 256, 1);
QuantileDMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 256, 1);
float[] anchorLabel = convertFloatTofloat(label1, label2);
float[] anchorWeight = convertFloatTofloat(weight1, weight2);
float[] anchorBaseMargin = convertFloatTofloat(baseMargin1, baseMargin2);
@ -135,6 +132,57 @@ public class DMatrixTest {
}
}
private Float[] generateFloatArray(int size, long seed) {
Float[] array = new Float[size];
Random random = new Random(seed);
for (int i = 0; i < size; i++) {
array[i] = random.nextFloat();
}
return array;
}
@Test
public void testGetQuantileCut() throws XGBoostError {
int rows = 100;
try (
Table X_0 = new Table.TestBuilder()
.column(generateFloatArray(rows, 1l))
.column(generateFloatArray(rows, 2l))
.column(generateFloatArray(rows, 3l))
.column(generateFloatArray(rows, 4l))
.column(generateFloatArray(rows, 5l))
.build();
Table y_0 = new Table.TestBuilder().column(generateFloatArray(rows, 6l)).build();
Table X_1 = new Table.TestBuilder()
.column(generateFloatArray(rows, 11l))
.column(generateFloatArray(rows, 12l))
.column(generateFloatArray(rows, 13l))
.column(generateFloatArray(rows, 14l))
.column(generateFloatArray(rows, 15l))
.build();
Table y_1 = new Table.TestBuilder().column(generateFloatArray(rows, 16l)).build();
) {
List<ColumnBatch> tables = new LinkedList<>();
tables.add(new CudfColumnBatch(X_0, y_0, null, null, null));
QuantileDMatrix train = new QuantileDMatrix(tables.iterator(), 0.0f, 256, 1);
tables.clear();
tables.add(new CudfColumnBatch(X_1, y_1, null, null, null));
QuantileDMatrix eval = new QuantileDMatrix(tables.iterator(), train, 0.0f, 256, 1);
DMatrix.QuantileCut trainCut = train.getQuantileCut();
DMatrix.QuantileCut evalCut = eval.getQuantileCut();
TestCase.assertTrue(trainCut.getIndptr().length == evalCut.getIndptr().length);
TestCase.assertTrue(Arrays.equals(trainCut.getIndptr(), evalCut.getIndptr()));
TestCase.assertTrue(trainCut.getValues().length == evalCut.getValues().length);
TestCase.assertTrue(Arrays.equals(trainCut.getValues(), evalCut.getValues()));
}
}
private float[] convertFloatTofloat(Float[]... datas) {
int totalLength = 0;
for (Float[] data : datas) {

View File

@ -16,11 +16,13 @@
package ml.dmlc.xgboost4j.scala
import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf.Table
import ml.dmlc.xgboost4j.java.CudfColumnBatch
import org.scalatest.funsuite.AnyFunSuite
import scala.collection.mutable.ArrayBuffer
import ml.dmlc.xgboost4j.java.CudfColumnBatch
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
class QuantileDMatrixSuite extends AnyFunSuite {
@ -73,13 +75,4 @@ class QuantileDMatrixSuite extends AnyFunSuite {
}
}
}
/** Executes the provided code block and then closes the resource */
private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
}
}

View File

@ -1,288 +0,0 @@
/*
Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import java.nio.file.{Files, Path}
import java.sql.{Date, Timestamp}
import java.util.{Locale, TimeZone}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.{GpuTestUtils, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
trait GpuTestSuite extends AnyFunSuite with TmpFolderSuite {
import SparkSessionHolder.withSparkSession
protected def getResourcePath(resource: String): String = {
require(resource.startsWith("/"), "resource must start with /")
getClass.getResource(resource).getPath
}
def enableCsvConf(): SparkConf = {
new SparkConf()
.set("spark.rapids.sql.csv.read.float.enabled", "true")
.set("spark.rapids.sql.csv.read.double.enabled", "true")
}
def withGpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
// set "spark.rapids.sql.explain" to "ALL" to check if the operators
// can be replaced by GPU
val c = conf.clone()
.set("spark.rapids.sql.enabled", "true")
withSparkSession(c, f)
}
def withCpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
val c = conf.clone()
.set("spark.rapids.sql.enabled", "false") // Just to be sure
withSparkSession(c, f)
}
def compareResults(
sort: Boolean,
floatEpsilon: Double,
fromLeft: Array[Row],
fromRight: Array[Row]): Boolean = {
if (sort) {
val left = fromLeft.map(_.toSeq).sortWith(seqLt)
val right = fromRight.map(_.toSeq).sortWith(seqLt)
compare(left, right, floatEpsilon)
} else {
compare(fromLeft, fromRight, floatEpsilon)
}
}
// we guarantee that the types will be the same
private def seqLt(a: Seq[Any], b: Seq[Any]): Boolean = {
if (a.length < b.length) {
return true
}
// lengths are the same
for (i <- a.indices) {
val v1 = a(i)
val v2 = b(i)
if (v1 != v2) {
// null is always < anything but null
if (v1 == null) {
return true
}
if (v2 == null) {
return false
}
(v1, v2) match {
case (i1: Int, i2: Int) => if (i1 < i2) {
return true
} else if (i1 > i2) {
return false
}// else equal go on
case (i1: Long, i2: Long) => if (i1 < i2) {
return true
} else if (i1 > i2) {
return false
} // else equal go on
case (i1: Float, i2: Float) => if (i1.isNaN() && !i2.isNaN()) return false
else if (!i1.isNaN() && i2.isNaN()) return true
else if (i1 < i2) {
return true
} else if (i1 > i2) {
return false
} // else equal go on
case (i1: Date, i2: Date) => if (i1.before(i2)) {
return true
} else if (i1.after(i2)) {
return false
} // else equal go on
case (i1: Double, i2: Double) => if (i1.isNaN() && !i2.isNaN()) return false
else if (!i1.isNaN() && i2.isNaN()) return true
else if (i1 < i2) {
return true
} else if (i1 > i2) {
return false
} // else equal go on
case (i1: Short, i2: Short) => if (i1 < i2) {
return true
} else if (i1 > i2) {
return false
} // else equal go on
case (i1: Timestamp, i2: Timestamp) => if (i1.before(i2)) {
return true
} else if (i1.after(i2)) {
return false
} // else equal go on
case (s1: String, s2: String) =>
val cmp = s1.compareTo(s2)
if (cmp < 0) {
return true
} else if (cmp > 0) {
return false
} // else equal go on
case (o1, _) =>
throw new UnsupportedOperationException(o1.getClass + " is not supported yet")
}
}
}
// They are equal...
false
}
private def compare(expected: Any, actual: Any, epsilon: Double = 0.0): Boolean = {
def doublesAreEqualWithinPercentage(expected: Double, actual: Double): (String, Boolean) = {
if (!compare(expected, actual)) {
if (expected != 0) {
val v = Math.abs((expected - actual) / expected)
(s"\n\nABS($expected - $actual) / ABS($actual) == $v is not <= $epsilon ", v <= epsilon)
} else {
val v = Math.abs(expected - actual)
(s"\n\nABS($expected - $actual) == $v is not <= $epsilon ", v <= epsilon)
}
} else {
("SUCCESS", true)
}
}
(expected, actual) match {
case (a: Float, b: Float) if a.isNaN && b.isNaN => true
case (a: Double, b: Double) if a.isNaN && b.isNaN => true
case (null, null) => true
case (null, _) => false
case (_, null) => false
case (a: Array[_], b: Array[_]) =>
a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r, epsilon) }
case (a: Map[_, _], b: Map[_, _]) =>
a.size == b.size && a.keys.forall { aKey =>
b.keys.find(bKey => compare(aKey, bKey))
.exists(bKey => compare(a(aKey), b(bKey), epsilon))
}
case (a: Iterable[_], b: Iterable[_]) =>
a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r, epsilon) }
case (a: Product, b: Product) =>
compare(a.productIterator.toSeq, b.productIterator.toSeq, epsilon)
case (a: Row, b: Row) =>
compare(a.toSeq, b.toSeq, epsilon)
// 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0.
case (a: Double, b: Double) if epsilon <= 0 =>
java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b)
case (a: Double, b: Double) if epsilon > 0 =>
val ret = doublesAreEqualWithinPercentage(a, b)
if (!ret._2) {
System.err.println(ret._1 + " (double)")
}
ret._2
case (a: Float, b: Float) if epsilon <= 0 =>
java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b)
case (a: Float, b: Float) if epsilon > 0 =>
val ret = doublesAreEqualWithinPercentage(a, b)
if (!ret._2) {
System.err.println(ret._1 + " (float)")
}
ret._2
case (a, b) => a == b
}
}
}
trait TmpFolderSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
protected var tempDir: Path = _
override def beforeAll(): Unit = {
super.beforeAll()
tempDir = Files.createTempDirectory(getClass.getName)
}
override def afterAll(): Unit = {
JavaUtils.deleteRecursively(tempDir.toFile)
super.afterAll()
}
protected def createTmpFolder(prefix: String): Path = {
Files.createTempDirectory(tempDir, prefix)
}
}
object SparkSessionHolder extends Logging {
private var spark = createSparkSession()
private var origConf = spark.conf.getAll
private var origConfKeys = origConf.keys.toSet
private def setAllConfs(confs: Array[(String, String)]): Unit = confs.foreach {
case (key, value) if spark.conf.get(key, null) != value =>
spark.conf.set(key, value)
case _ => // No need to modify it
}
private def createSparkSession(): SparkSession = {
GpuTestUtils.cleanupAnyExistingSession()
// Timezone is fixed to UTC to allow timestamps to work by default
TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
// Add Locale setting
Locale.setDefault(Locale.US)
val builder = SparkSession.builder()
.master("local[2]")
.config("spark.sql.adaptive.enabled", "false")
.config("spark.rapids.sql.enabled", "false")
.config("spark.rapids.sql.test.enabled", "false")
.config("spark.plugins", "com.nvidia.spark.SQLPlugin")
.config("spark.rapids.memory.gpu.pooling.enabled", "false") // Disable RMM for unit tests.
.config("spark.sql.files.maxPartitionBytes", "1000")
.appName("XGBoost4j-Spark-Gpu unit test")
builder.getOrCreate()
}
private def reinitSession(): Unit = {
spark = createSparkSession()
origConf = spark.conf.getAll
origConfKeys = origConf.keys.toSet
}
def sparkSession: SparkSession = {
if (SparkSession.getActiveSession.isEmpty) {
reinitSession()
}
spark
}
def resetSparkSessionConf(): Unit = {
if (SparkSession.getActiveSession.isEmpty) {
reinitSession()
} else {
setAllConfs(origConf.toArray)
val currentKeys = spark.conf.getAll.keys.toSet
val toRemove = currentKeys -- origConfKeys
toRemove.foreach(spark.conf.unset)
}
logDebug(s"RESET CONF TO: ${spark.conf.getAll}")
}
def withSparkSession[U](conf: SparkConf, f: SparkSession => U): U = {
resetSparkSessionConf()
logDebug(s"SETTING CONF: ${conf.getAll.toMap}")
setAllConfs(conf.getAll)
logDebug(s"RUN WITH CONF: ${spark.conf.getAll}\n")
spark.sparkContext.setLogLevel("WARN")
f(spark)
}
}

View File

@ -1,232 +0,0 @@
/*
Copyright (c) 2021-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.{col, udf, when}
import org.apache.spark.sql.types.{FloatType, StructField, StructType}
class GpuXGBoostClassifierSuite extends GpuTestSuite {
private val dataPath = if (new java.io.File("../../demo/data/veterans_lung_cancer.csv").isFile) {
"../../demo/data/veterans_lung_cancer.csv"
} else {
"../demo/data/veterans_lung_cancer.csv"
}
val labelName = "label_col"
val schema = StructType(Seq(
StructField("f1", FloatType), StructField("f2", FloatType), StructField("f3", FloatType),
StructField("f4", FloatType), StructField("f5", FloatType), StructField("f6", FloatType),
StructField("f7", FloatType), StructField("f8", FloatType), StructField("f9", FloatType),
StructField("f10", FloatType), StructField("f11", FloatType), StructField("f12", FloatType),
StructField(labelName, FloatType)
))
val featureNames = schema.fieldNames.filter(s => !s.equals(labelName))
test("The transform result should be same for several runs on same model") {
withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
.randomSplit(Array(0.7, 0.3), seed = 1)
// Get a model
val model = new XGBoostClassifier(xgbParam)
.fit(originalDf)
val left = model.transform(testDf).collect()
val right = model.transform(testDf).collect()
// The left should be the same as the right
assert(compareResults(true, 0.000001, left, right))
}
}
test("use weight") {
withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
.randomSplit(Array(0.7, 0.3), seed = 1)
val getWeightFromF1 = udf({ f1: Float => if (f1.toInt % 2 == 0) 1.0f else 0.001f })
val dfWithWeight = originalDf.withColumn("weight", getWeightFromF1(col("f1")))
val model = new XGBoostClassifier(xgbParam)
.fit(originalDf)
val model2 = new XGBoostClassifier(xgbParam)
.setWeightCol("weight")
.fit(dfWithWeight)
val left = model.transform(testDf).collect()
val right = model2.transform(testDf).collect()
// The left should differ from the right
assert(!compareResults(true, 0.000001, left, right))
}
}
test("Save model and transform GPU dataset") {
// Train a model on GPU
val (gpuModel, testDf) = withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, testDf) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
.randomSplit(Array(0.7, 0.3), seed = 1)
val classifier = new XGBoostClassifier(xgbParam)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.setTreeMethod("gpu_hist")
(classifier.fit(rawInput), testDf)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
gpuModel.write.overwrite().save(xgbrModel)
val gpuModelFromFile = XGBoostClassificationModel.load(xgbrModel)
// transform on GPU
withGpuSparkSession() { spark =>
val left = gpuModel
.transform(testDf)
.select(labelName, "rawPrediction", "probability", "prediction")
.collect()
val right = gpuModelFromFile
.transform(testDf)
.select(labelName, "rawPrediction", "probability", "prediction")
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
test("Model trained on CPU can transform GPU dataset") {
// Train a model on CPU
val cpuModel = withCpuSparkSession() { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
.randomSplit(Array(0.7, 0.3), seed = 1)
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
.setInputCols(featureNames)
.setOutputCol("features")
val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName)
val classifier = new XGBoostClassifier(xgbParam)
.setFeaturesCol("features")
.setLabelCol(labelName)
.setTreeMethod("auto")
classifier.fit(trainingDf)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
cpuModel.write.overwrite().save(xgbrModel)
val cpuModelFromFile = XGBoostClassificationModel.load(xgbrModel)
// transform on GPU
withGpuSparkSession() { spark =>
val Array(_, testDf) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
.randomSplit(Array(0.7, 0.3), seed = 1)
// The CPU model does not carry the feature-column information that the GPU transform
// pipeline requires, so the end user must call setFeaturesCol(features: Array[String])
// on the model manually
val thrown = intercept[NoSuchElementException](cpuModel
.transform(testDf)
.collect())
assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))
val left = cpuModel
.setFeaturesCol(featureNames)
.transform(testDf)
.collect()
val right = cpuModelFromFile
.setFeaturesCol(featureNames)
.transform(testDf)
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
test("Model trained on GPU can transform CPU dataset") {
// Train a model on GPU
val gpuModel = withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
.randomSplit(Array(0.7, 0.3), seed = 1)
val classifier = new XGBoostClassifier(xgbParam)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.setTreeMethod("gpu_hist")
classifier.fit(rawInput)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
gpuModel.write.overwrite().save(xgbrModel)
val gpuModelFromFile = XGBoostClassificationModel.load(xgbrModel)
// transform on CPU
withCpuSparkSession() { spark =>
val Array(_, rawInput) = spark.read.option("header", "true").schema(schema)
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
.randomSplit(Array(0.7, 0.3), seed = 1)
val featureColName = "feature_col"
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
.setInputCols(featureNames)
.setOutputCol(featureColName)
val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName)
// The GPU model does not carry the feature-column name that the CPU transform
// pipeline requires, so the end user must call setFeaturesCol on the model manually
intercept[IllegalArgumentException](
gpuModel
.transform(testDf)
.collect())
val left = gpuModel
.setFeaturesCol(featureColName)
.transform(testDf)
.select(labelName, "rawPrediction", "probability", "prediction")
.collect()
val right = gpuModelFromFile
.setFeaturesCol(featureColName)
.transform(testDf)
.select(labelName, "rawPrediction", "probability", "prediction")
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
}

View File

@ -1,212 +0,0 @@
/*
Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StringType
class GpuXGBoostGeneralSuite extends GpuTestSuite {
private val labelName = "label_col"
private val weightName = "weight_col"
private val baseMarginName = "margin_col"
private val featureNames = Array("f1", "f2", "f3")
private val allColumnNames = featureNames :+ weightName :+ baseMarginName :+ labelName
private val trainingData = Seq(
// f1, f2, f3, weight, margin, label
(1.0f, 2.0f, 3.0f, 1.0f, 0.5f, 0),
(2.0f, 3.0f, 4.0f, 2.0f, 0.6f, 0),
(1.2f, 2.1f, 3.1f, 1.1f, 0.51f, 0),
(2.3f, 3.1f, 4.1f, 2.1f, 0.61f, 0),
(3.0f, 4.0f, 5.0f, 1.5f, 0.3f, 1),
(4.0f, 5.0f, 6.0f, 2.5f, 0.4f, 1),
(3.1f, 4.1f, 5.1f, 1.6f, 0.4f, 1),
(4.1f, 5.1f, 6.1f, 2.6f, 0.5f, 1),
(5.0f, 6.0f, 7.0f, 1.0f, 0.2f, 2),
(6.0f, 7.0f, 8.0f, 1.3f, 0.6f, 2),
(5.1f, 6.1f, 7.1f, 1.2f, 0.1f, 2),
(6.1f, 7.1f, 8.1f, 1.4f, 0.7f, 2),
(6.2f, 7.2f, 8.2f, 1.5f, 0.8f, 2))
test("MLlib way setting features_cols should work") {
withGpuSparkSession() { spark =>
import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*)
val xgbParam = Map(
"eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
"tree_method" -> "hist", "device" -> "cuda",
"features_cols" -> featureNames, "label_col" -> labelName
)
new XGBoostClassifier(xgbParam)
.fit(trainingDf)
}
}
test("disorder feature columns should work") {
withGpuSparkSession() { spark =>
import spark.implicits._
var trainingDf = trainingData.toDF(allColumnNames: _*)
trainingDf = trainingDf.select(labelName, "f2", weightName, "f3", baseMarginName, "f1")
val xgbParam = Map(
"eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
"tree_method" -> "hist", "device" -> "cuda"
)
new XGBoostClassifier(xgbParam)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.fit(trainingDf)
}
}
test("Throw exception when feature/label columns are not numeric type") {
withGpuSparkSession() { spark =>
import spark.implicits._
val originalDf = trainingData.toDF(allColumnNames: _*)
var trainingDf = originalDf.withColumn("f2", col("f2").cast(StringType))
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
val thrown1 = intercept[IllegalArgumentException] {
new XGBoostClassifier(xgbParam)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.fit(trainingDf)
}
assert(thrown1.getMessage.contains("Column f2 must be of NumericType but found: string."))
trainingDf = originalDf.withColumn(labelName, col(labelName).cast(StringType))
val thrown2 = intercept[IllegalArgumentException] {
new XGBoostClassifier(xgbParam)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.fit(trainingDf)
}
assert(thrown2.getMessage.contains(
s"Column $labelName must be of NumericType but found: string."))
}
}
test("Throw exception when features_cols or label_col is not set") {
withGpuSparkSession() { spark =>
import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*)
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
// GPU training requires featuresCols; if it is not specified,
// a NoSuchElementException will be thrown
val thrown = intercept[NoSuchElementException] {
new XGBoostClassifier(xgbParam)
.setLabelCol(labelName)
.fit(trainingDf)
}
assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))
val thrown1 = intercept[IllegalArgumentException] {
new XGBoostClassifier(xgbParam)
.setFeaturesCol(featureNames)
.fit(trainingDf)
}
assert(thrown1.getMessage.contains("label does not exist."))
}
}
test("Throw exception when device is not set to cuda") {
withGpuSparkSession() { spark =>
import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*)
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "hist")
val thrown = intercept[IllegalArgumentException] {
new XGBoostClassifier(xgbParam)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.fit(trainingDf)
}
assert(thrown.getMessage.contains("GPU train requires `device` set to `cuda`"))
}
}
test("Train with eval") {
withGpuSparkSession() { spark =>
import spark.implicits._
val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*)
.randomSplit(Array(0.6, 0.2, 0.2), seed = 1)
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
val model1 = new XGBoostClassifier(xgbParam)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
.fit(trainingDf)
assert(model1.summary.validationObjectiveHistory.length === 2)
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
}
}
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
withGpuSparkSession() { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val xgbc = new XGBoostClassifier(xgbParam)
xgbc.write.overwrite().save(xgbcPath)
val paramMap2 = XGBoostClassifier.load(xgbcPath).MLlib2XGBoostParams
xgbParam.foreach {
case (k, v: Array[String]) =>
assert(v.sameElements(paramMap2(k).asInstanceOf[Array[String]]))
case (k, v) =>
assert(v.toString == paramMap2(k).toString)
}
}
}
test("device ordinal should not be specified") {
withGpuSparkSession() { spark =>
import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*)
val params = Map(
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 5,
"num_workers" -> 1
)
val thrown = intercept[IllegalArgumentException] {
new XGBoostClassifier(params)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.setDevice("cuda:1")
.fit(trainingDf)
}
assert(thrown.getMessage.contains("device given invalid value cuda:1"))
}
}
}

View File

@ -1,258 +0,0 @@
/*
Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{FloatType, IntegerType, StructField, StructType}
class GpuXGBoostRegressorSuite extends GpuTestSuite {
val labelName = "label_col"
val groupName = "group_col"
val schema = StructType(Seq(
StructField(labelName, FloatType),
StructField("f1", FloatType),
StructField("f2", FloatType),
StructField("f3", FloatType),
StructField(groupName, IntegerType)))
val featureNames = schema.fieldNames.filter(s =>
!(s.equals(labelName) || s.equals(groupName)))
test("The transform result should be same for several runs on same model") {
withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
// Get a model
val model = new XGBoostRegressor(xgbParam)
.fit(originalDf)
val left = model.transform(testDf).collect()
val right = model.transform(testDf).collect()
// The left should be the same as the right
assert(compareResults(true, 0.000001, left, right))
}
}
test("Tree method gpu_hist still works") {
withGpuSparkSession(enableCsvConf()) { spark =>
val params = Map(
"tree_method" -> "gpu_hist",
"features_cols" -> featureNames,
"label_col" -> labelName,
"num_round" -> 10,
"num_workers" -> 1
)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
// Get a model
val model = new XGBoostRegressor(params).fit(originalDf)
val left = model.transform(testDf).collect()
val right = model.transform(testDf).collect()
// The left should be the same as the right
assert(compareResults(true, 0.000001, left, right))
}
}
test("use weight") {
withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val getWeightFromF1 = udf({ f1: Float => if (f1.toInt % 2 == 0) 1.0f else 0.001f })
val dfWithWeight = originalDf.withColumn("weight", getWeightFromF1(col("f1")))
val model = new XGBoostRegressor(xgbParam)
.fit(originalDf)
val model2 = new XGBoostRegressor(xgbParam)
.setWeightCol("weight")
.fit(dfWithWeight)
val left = model.transform(testDf).collect()
val right = model2.transform(testDf).collect()
// The left should differ from the right
assert(!compareResults(true, 0.000001, left, right))
}
}
test("Save model and transform GPU dataset") {
// Train a model on GPU
val (gpuModel, testDf) = withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val classifier = new XGBoostRegressor(xgbParam)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.setTreeMethod("hist")
.setDevice("cuda")
(classifier.fit(rawInput), testDf)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
gpuModel.write.overwrite().save(xgbrModel)
val gpuModelFromFile = XGBoostRegressionModel.load(xgbrModel)
// transform on GPU
withGpuSparkSession() { spark =>
val left = gpuModel
.transform(testDf)
.select(labelName, "prediction")
.collect()
val right = gpuModelFromFile
.transform(testDf)
.select(labelName, "prediction")
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
test("Model trained on CPU can transform GPU dataset") {
// Train a model on CPU
val cpuModel = withCpuSparkSession() { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
.setInputCols(featureNames)
.setOutputCol("features")
val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName)
val classifier = new XGBoostRegressor(xgbParam)
.setFeaturesCol("features")
.setLabelCol(labelName)
.setTreeMethod("auto")
classifier.fit(trainingDf)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
cpuModel.write.overwrite().save(xgbrModel)
val cpuModelFromFile = XGBoostRegressionModel.load(xgbrModel)
// transform on GPU
withGpuSparkSession() { spark =>
val Array(_, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
// The CPU model does not carry the feature-column information that the GPU transform
// pipeline requires, so the end user must call setFeaturesCol(features: Array[String])
// on the model manually
val thrown = intercept[NoSuchElementException](cpuModel
.transform(testDf)
.collect())
assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))
val left = cpuModel
.setFeaturesCol(featureNames)
.transform(testDf)
.collect()
val right = cpuModelFromFile
.setFeaturesCol(featureNames)
.transform(testDf)
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
test("Model trained on GPU can transform CPU dataset") {
// Train a model on GPU
val gpuModel = withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
"num_round" -> 10, "num_workers" -> 1)
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val classifier = new XGBoostRegressor(xgbParam)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.setDevice("cuda")
classifier.fit(rawInput)
}
val xgbrModel = new File(tempDir.toFile, "xgbrModel").getPath
gpuModel.write.overwrite().save(xgbrModel)
val gpuModelFromFile = XGBoostRegressionModel.load(xgbrModel)
// transform on CPU
withCpuSparkSession() { spark =>
val Array(_, rawInput) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val featureColName = "feature_col"
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
.setInputCols(featureNames)
.setOutputCol(featureColName)
val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName)
// The GPU model does not carry the feature-column name that the CPU transform
// pipeline requires, so the end user must call setFeaturesCol on the model manually
intercept[IllegalArgumentException](
gpuModel
.transform(testDf)
.collect())
val left = gpuModel
.setFeaturesCol(featureColName)
.transform(testDf)
.select(labelName, "prediction")
.collect()
val right = gpuModelFromFile
.setFeaturesCol(featureColName)
.transform(testDf)
.select(labelName, "prediction")
.collect()
assert(compareResults(true, 0.000001, left, right))
}
}
test("Ranking: train with Group") {
withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:ndcg",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(trainingDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
val model = new XGBoostRegressor(xgbParam)
.setGroupCol(groupName)
.fit(trainingDf)
val ret = model.transform(testDf).collect()
assert(testDf.count() === ret.length)
}
}
}

View File

@ -0,0 +1,145 @@
/*
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rapids.spark
import java.nio.file.{Files, Path}
import java.sql.{Date, Timestamp}
import java.util.{Locale, TimeZone}
import org.apache.spark.{GpuTestUtils, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
trait GpuTestSuite extends AnyFunSuite with TmpFolderSuite {
import SparkSessionHolder.withSparkSession
protected def getResourcePath(resource: String): String = {
require(resource.startsWith("/"), "resource must start with /")
getClass.getResource(resource).getPath
}
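// The RAPIDS Accelerator does not enable GPU CSV parsing of float/double columns by
// default, so the CSV-based tests turn these configs on explicitly.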
def enableCsvConf(): SparkConf = {
new SparkConf()
.set("spark.rapids.sql.csv.read.float.enabled", "true")
.set("spark.rapids.sql.csv.read.double.enabled", "true")
}
def withGpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
// set "spark.rapids.sql.explain" to "ALL" to check if the operators
// can be replaced by GPU
val c = conf.clone()
.set("spark.rapids.sql.enabled", "true")
withSparkSession(c, f)
}
def withCpuSparkSession[U](conf: SparkConf = new SparkConf())(f: SparkSession => U): U = {
val c = conf.clone()
.set("spark.rapids.sql.enabled", "false") // Just to be sure
withSparkSession(c, f)
}
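// Typical usage from a suite (a sketch; "MySuite" is hypothetical):
//   class MySuite extends GpuTestSuite {
//     test("fits on the GPU") {
//       withGpuSparkSession(enableCsvConf()) { spark => /* build a DataFrame, fit, assert */ }
//     }
//   }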
}
trait TmpFolderSuite extends BeforeAndAfterAll {
self: AnyFunSuite =>
protected var tempDir: Path = _
override def beforeAll(): Unit = {
super.beforeAll()
tempDir = Files.createTempDirectory(getClass.getName)
}
override def afterAll(): Unit = {
JavaUtils.deleteRecursively(tempDir.toFile)
super.afterAll()
}
protected def createTmpFolder(prefix: String): Path = {
Files.createTempDirectory(tempDir, prefix)
}
}
object SparkSessionHolder extends Logging {
private var spark = createSparkSession()
private var origConf = spark.conf.getAll
private var origConfKeys = origConf.keys.toSet
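// Apply each conf entry, skipping keys whose value already matches the current session conf.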
private def setAllConfs(confs: Array[(String, String)]): Unit = confs.foreach {
case (key, value) if spark.conf.get(key, null) != value =>
spark.conf.set(key, value)
case _ => // No need to modify it
}
private def createSparkSession(): SparkSession = {
GpuTestUtils.cleanupAnyExistingSession()
// Timezone is fixed to UTC to allow timestamps to work by default
TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
// Fix the locale to Locale.US so that string formatting is deterministic
Locale.setDefault(Locale.US)
val builder = SparkSession.builder()
.master("local[2]")
.config("spark.sql.adaptive.enabled", "false")
.config("spark.rapids.sql.test.enabled", "false")
.config("spark.stage.maxConsecutiveAttempts", "1")
.config("spark.plugins", "com.nvidia.spark.SQLPlugin")
.config("spark.rapids.memory.gpu.pooling.enabled", "false") // Disable RMM for unit tests.
.config("spark.sql.files.maxPartitionBytes", "1000")
.appName("XGBoost4j-Spark-Gpu unit test")
builder.getOrCreate()
}
private def reinitSession(): Unit = {
spark = createSparkSession()
origConf = spark.conf.getAll
origConfKeys = origConf.keys.toSet
}
def sparkSession: SparkSession = {
if (SparkSession.getActiveSession.isEmpty) {
reinitSession()
}
spark
}
def resetSparkSessionConf(): Unit = {
if (SparkSession.getActiveSession.isEmpty) {
reinitSession()
} else {
setAllConfs(origConf.toArray)
val currentKeys = spark.conf.getAll.keys.toSet
val toRemove = currentKeys -- origConfKeys
toRemove.foreach(spark.conf.unset)
}
logDebug(s"RESET CONF TO: ${spark.conf.getAll}")
}
def withSparkSession[U](conf: SparkConf, f: SparkSession => U): U = {
resetSparkSessionConf()
logDebug(s"SETTING CONF: ${conf.getAll.toMap}")
setAllConfs(conf.getAll)
logDebug(s"RUN WITH CONF: ${spark.conf.getAll}\n")
spark.sparkContext.setLogLevel("WARN")
f(spark)
}
}

View File

@ -0,0 +1,523 @@
/*
Copyright (c) 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ai.rapids.cudf.Table
import ml.dmlc.xgboost4j.java.CudfColumnBatch
import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix, XGBoost => ScalaXGBoost}
import ml.dmlc.xgboost4j.scala.rapids.spark.GpuTestSuite
import ml.dmlc.xgboost4j.scala.rapids.spark.SparkSessionHolder.withSparkSession
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.SparkConf
import java.io.File
import scala.collection.mutable.ArrayBuffer
class GpuXGBoostPluginSuite extends GpuTestSuite {
test("params") {
withGpuSparkSession() { spark =>
import spark.implicits._
val df = Seq((1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
(3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
(5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f)
).toDF("c1", "c2", "weight", "margin", "label", "other")
val xgbParams: Map[String, Any] = Map(
"max_depth" -> 5,
"eta" -> 0.2,
"objective" -> "binary:logistic"
)
val features = Array("c1", "c2")
val estimator = new XGBoostClassifier(xgbParams)
.setFeaturesCol(features)
.setMissing(0.2f)
.setAlpha(0.97)
.setLeafPredictionCol("leaf")
.setContribPredictionCol("contrib")
.setNumRound(3)
.setDevice("cuda")
assert(estimator.getMaxDepth === 5)
assert(estimator.getEta === 0.2)
assert(estimator.getObjective === "binary:logistic")
assert(estimator.getFeaturesCols === features)
assert(estimator.getMissing === 0.2f)
assert(estimator.getAlpha === 0.97)
assert(estimator.getDevice === "cuda")
assert(estimator.getNumRound === 3)
estimator.setEta(0.66).setMaxDepth(7)
assert(estimator.getMaxDepth === 7)
assert(estimator.getEta === 0.66)
val model = estimator.fit(df)
assert(model.getMaxDepth === 7)
assert(model.getEta === 0.66)
assert(model.getObjective === "binary:logistic")
assert(model.getFeaturesCols === features)
assert(model.getMissing === 0.2f)
assert(model.getAlpha === 0.97)
assert(model.getLeafPredictionCol === "leaf")
assert(model.getContribPredictionCol === "contrib")
assert(model.getDevice === "cuda")
assert(model.getNumRound === 3)
}
}
test("isEnabled") {
def checkIsEnabled(spark: SparkSession, expected: Boolean): Unit = {
import spark.implicits._
val df = Seq((1.0f, 2.0f, 0.0f),
(2.0f, 3.0f, 1.0f)
).toDF("c1", "c2", "label")
val classifier = new XGBoostClassifier()
assert(classifier.getPlugin.isDefined)
assert(classifier.getPlugin.get.isEnabled(df) === expected)
}
// spark.rapids.sql.enabled is not set explicitly; it defaults to true
withSparkSession(new SparkConf(), spark => {checkIsEnabled(spark, true)})
// set spark.rapids.sql.enabled to false
withCpuSparkSession() { spark =>
checkIsEnabled(spark, false)
}
// set spark.rapids.sql.enabled to true
withGpuSparkSession() { spark =>
checkIsEnabled(spark, true)
}
}
test("parameter validation") {
withGpuSparkSession() { spark =>
import spark.implicits._
val df = Seq((1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
(3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
(5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f)
).toDF("c1", "c2", "weight", "margin", "label", "other")
val classifier = new XGBoostClassifier()
val plugin = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
intercept[IllegalArgumentException] {
plugin.validate(classifier, df)
}
classifier.setDevice("cuda")
plugin.validate(classifier, df)
classifier.setDevice("gpu")
plugin.validate(classifier, df)
classifier.setDevice("cpu")
classifier.setTreeMethod("gpu_hist")
plugin.validate(classifier, df)
}
}
test("preprocess") {
withGpuSparkSession() { spark =>
import spark.implicits._
val df = Seq((1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
(3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
(5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f)
).toDF("c1", "c2", "weight", "margin", "label", "other")
.repartition(5)
assert(df.schema.names.contains("other"))
assert(df.rdd.getNumPartitions === 5)
val features = Array("c1", "c2")
var classifier = new XGBoostClassifier()
.setNumWorkers(3)
.setFeaturesCol(features)
assert(classifier.getPlugin.isDefined)
assert(classifier.getPlugin.get.isInstanceOf[GpuXGBoostPlugin])
var out = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
.preprocess(classifier, df)
assert(out.schema.names.contains("c1") && out.schema.names.contains("c2"))
assert(out.schema.names.contains(classifier.getLabelCol))
assert(!out.schema.names.contains("weight") && !out.schema.names.contains("margin"))
assert(out.rdd.getNumPartitions === 3)
classifier = new XGBoostClassifier()
.setNumWorkers(4)
.setFeaturesCol(features)
.setWeightCol("weight")
.setBaseMarginCol("margin")
.setDevice("cuda")
out = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
.preprocess(classifier, df)
assert(out.schema.names.contains("c1") && out.schema.names.contains("c2"))
assert(out.schema.names.contains(classifier.getLabelCol))
assert(out.schema.names.contains("weight") && out.schema.names.contains("margin"))
assert(out.rdd.getNumPartitions === 4)
}
}
// Test the distributed path of building the RDD of Watches
test("build RDD Watches") {
withGpuSparkSession() { spark =>
import spark.implicits._
// dataPoint -> (missing, rowNum, nonMissing)
Map(0.0f -> (0.0f, 5, 9), Float.NaN -> (0.0f, 5, 9)).foreach {
case (data, (missing, expectedRowNum, expectedNonMissing)) =>
val df = Seq(
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
(3.0f, data, 5.0f, 6.0f, 0.0f, 0.1f),
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
(5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f)
).toDF("c1", "c2", "weight", "margin", "label", "other")
val features = Array("c1", "c2")
val classifier = new XGBoostClassifier()
.setNumWorkers(2)
.setWeightCol("weight")
.setBaseMarginCol("margin")
.setFeaturesCol(features)
.setDevice("cuda")
.setMissing(missing)
val rdd = classifier.getPlugin.get.buildRddWatches(classifier, df)
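// Each partition yields one Watches, which here holds a single training DMatrix;
// collect its labels, weights, margins, row count, and non-missing count.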
val result = rdd.mapPartitions { iter =>
val watches = iter.next()
val size = watches.size
val labels = watches.datasets(0).getLabel
val weight = watches.datasets(0).getWeight
val margins = watches.datasets(0).getBaseMargin
val rowNumber = watches.datasets(0).rowNum
val nonMissing = watches.datasets(0).nonMissingNum
Iterator.single(size, rowNumber, nonMissing, labels, weight, margins)
}.collect()
val labels: ArrayBuffer[Float] = ArrayBuffer.empty
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
val margins: ArrayBuffer[Float] = ArrayBuffer.empty
val rowNumber: ArrayBuffer[Long] = ArrayBuffer.empty
val nonMissing: ArrayBuffer[Long] = ArrayBuffer.empty
for (row <- result) {
assert(row._1 === 1)
rowNumber.append(row._2)
nonMissing.append(row._3)
labels.append(row._4: _*)
weight.append(row._5: _*)
margins.append(row._6: _*)
}
assert(labels.sorted === Array(0.0f, 1.0f, 0.0f, 0.0f, 1.0f).sorted)
assert(weight.sorted === Array(1.0f, 2.0f, 5.0f, 6.0f, 7.0f).sorted)
assert(margins.sorted === Array(2.0f, 3.0f, 6.0f, 7.0f, 8.0f).sorted)
assert(rowNumber.sum === expectedRowNum)
assert(nonMissing.sum === expectedNonMissing)
}
}
}
test("build RDD Watches with Eval") {
withGpuSparkSession() { spark =>
import spark.implicits._
val train = Seq(
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f)
).toDF("c1", "c2", "weight", "margin", "label", "other")
// dataPoint -> (missing, rowNum, nonMissing)
Map(0.0f -> (0.0f, 5, 9), Float.NaN -> (0.0f, 5, 9)).foreach {
case (data, (missing, expectedRowNum, expectedNonMissing)) =>
val eval = Seq(
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
(3.0f, data, 5.0f, 6.0f, 0.0f, 0.1f),
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
(5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f)
).toDF("c1", "c2", "weight", "margin", "label", "other")
val features = Array("c1", "c2")
val classifier = new XGBoostClassifier()
.setNumWorkers(2)
.setWeightCol("weight")
.setBaseMarginCol("margin")
.setFeaturesCol(features)
.setDevice("cuda")
.setMissing(missing)
.setEvalDataset(eval)
val rdd = classifier.getPlugin.get.buildRddWatches(classifier, train)
val result = rdd.mapPartitions { iter =>
val watches = iter.next()
val size = watches.size
val labels = watches.datasets(1).getLabel
val weight = watches.datasets(1).getWeight
val margins = watches.datasets(1).getBaseMargin
val rowNumber = watches.datasets(1).rowNum
val nonMissing = watches.datasets(1).nonMissingNum
Iterator.single(size, rowNumber, nonMissing, labels, weight, margins)
}.collect()
val labels: ArrayBuffer[Float] = ArrayBuffer.empty
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
val margins: ArrayBuffer[Float] = ArrayBuffer.empty
val rowNumber: ArrayBuffer[Long] = ArrayBuffer.empty
val nonMissing: ArrayBuffer[Long] = ArrayBuffer.empty
for (row <- result) {
assert(row._1 === 2)
rowNumber.append(row._2)
nonMissing.append(row._3)
labels.append(row._4: _*)
weight.append(row._5: _*)
margins.append(row._6: _*)
}
assert(labels.sorted === Array(0.0f, 1.0f, 0.0f, 0.0f, 1.0f).sorted)
assert(weight.sorted === Array(1.0f, 2.0f, 5.0f, 6.0f, 7.0f).sorted)
assert(margins.sorted === Array(2.0f, 3.0f, 6.0f, 7.0f, 8.0f).sorted)
assert(rowNumber.sum === expectedRowNum)
assert(nonMissing.sum === expectedNonMissing)
}
}
}
test("transformed schema") {
withGpuSparkSession() { spark =>
import spark.implicits._
val df = Seq(
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
(3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
(5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f)
).toDF("c1", "c2", "weight", "margin", "label", "other")
val estimator = new XGBoostClassifier()
.setNumWorkers(1)
.setNumRound(2)
.setFeaturesCol(Array("c1", "c2"))
.setLabelCol("label")
.setDevice("cuda")
assert(estimator.getPlugin.isDefined && estimator.getPlugin.get.isEnabled(df))
val out = estimator.fit(df).transform(df)
// Transform should not discard the other columns of the input dataframe
Seq("c1", "c2", "weight", "margin", "label", "other").foreach { v =>
assert(out.schema.names.contains(v))
}
// Transform for XGBoostClassifier needs to add extra columns
Seq("rawPrediction", "probability", "prediction").foreach { v =>
assert(out.schema.names.contains(v))
}
assert(out.schema.names.length === 9)
val out1 = estimator.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
.fit(df)
.transform(df)
Seq("leaf", "contrib").foreach { v =>
assert(out1.schema.names.contains(v))
}
}
}
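// Compare two prediction matrices element-wise within an absolute tolerance.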
private def checkEqual(left: Array[Array[Float]],
right: Array[Array[Float]],
epsilon: Float = 1e-4f): Unit = {
assert(left.size === right.size)
left.zip(right).foreach { case (leftValue, rightValue) =>
leftValue.zip(rightValue).foreach { case (l, r) =>
assert(math.abs(l - r) < epsilon)
}
}
}
Seq("binary:logistic", "multi:softprob").foreach { case objective =>
test(s"$objective: XGBoost-Spark should match xgboost4j") {
withGpuSparkSession() { spark =>
import spark.implicits._
val numRound = 100
var xgboostParams: Map[String, Any] = Map(
"objective" -> objective,
"device" -> "cuda"
)
val (trainPath, testPath) = if (objective == "binary:logistic") {
(writeFile(Classification.train.toDF("label", "weight", "c1", "c2", "c3")),
writeFile(Classification.test.toDF("label", "weight", "c1", "c2", "c3")))
} else {
xgboostParams = xgboostParams ++ Map("num_class" -> 6)
(writeFile(MultiClassification.train.toDF("label", "weight", "c1", "c2", "c3")),
writeFile(MultiClassification.test.toDF("label", "weight", "c1", "c2", "c3")))
}
val df = spark.read.parquet(trainPath)
val testdf = spark.read.parquet(testPath)
val features = Array("c1", "c2", "c3")
val featuresIndices = features.map(df.schema.fieldIndex)
val label = "label"
val classifier = new XGBoostClassifier(xgboostParams)
.setFeaturesCol(features)
.setLabelCol(label)
.setNumRound(numRound)
.setLeafPredictionCol("leaf")
.setContribPredictionCol("contrib")
.setDevice("cuda")
val xgb4jModel = withResource(new GpuColumnBatch(
Table.readParquet(new File(trainPath)))) { batch =>
val cb = new CudfColumnBatch(batch.select(featuresIndices),
batch.select(df.schema.fieldIndex(label)), null, null, null
)
val qdm = new QuantileDMatrix(Seq(cb).iterator, classifier.getMissing,
classifier.getMaxBins, classifier.getNthread)
ScalaXGBoost.train(qdm, xgboostParams, numRound)
}
val (xgb4jLeaf, xgb4jContrib, xgb4jProb, xgb4jRaw) = withResource(new GpuColumnBatch(
Table.readParquet(new File(testPath)))) { batch =>
val cb = new CudfColumnBatch(batch.select(featuresIndices), null, null, null, null
)
val qdm = new DMatrix(cb, classifier.getMissing, classifier.getNthread)
(xgb4jModel.predictLeaf(qdm), xgb4jModel.predictContrib(qdm),
xgb4jModel.predict(qdm), xgb4jModel.predict(qdm, outPutMargin = true))
}
val rows = classifier.fit(df).transform(testdf).collect()
// Check Leaf
val xgbSparkLeaf = rows.map(row => row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))
checkEqual(xgb4jLeaf, xgbSparkLeaf)
// Check contrib
val xgbSparkContrib = rows.map(row =>
row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))
checkEqual(xgb4jContrib, xgbSparkContrib)
// Check probability
var xgbSparkProb = rows.map(row =>
row.getAs[DenseVector]("probability").toArray.map(_.toFloat))
if (objective == "binary:logistic") {
xgbSparkProb = xgbSparkProb.map(v => Array(v(1)))
}
checkEqual(xgb4jProb, xgbSparkProb)
// Check raw
var xgbSparkRaw = rows.map(row =>
row.getAs[DenseVector]("rawPrediction").toArray.map(_.toFloat))
if (objective == "binary:logistic") {
xgbSparkRaw = xgbSparkRaw.map(v => Array(v(1)))
}
checkEqual(xgb4jRaw, xgbSparkRaw)
}
}
}
test(s"Regression: XGBoost-Spark should match xgboost4j") {
withGpuSparkSession() { spark =>
import spark.implicits._
val trainPath = writeFile(Regression.train.toDF("label", "weight", "c1", "c2", "c3"))
val testPath = writeFile(Regression.test.toDF("label", "weight", "c1", "c2", "c3"))
val df = spark.read.parquet(trainPath)
val testdf = spark.read.parquet(testPath)
val features = Array("c1", "c2", "c3")
val featuresIndices = features.map(df.schema.fieldIndex)
val label = "label"
val numRound = 100
val xgboostParams: Map[String, Any] = Map(
"device" -> "cuda"
)
val regressor = new XGBoostRegressor(xgboostParams)
.setFeaturesCol(features)
.setLabelCol(label)
.setNumRound(numRound)
.setLeafPredictionCol("leaf")
.setContribPredictionCol("contrib")
.setDevice("cuda")
val xgb4jModel = withResource(new GpuColumnBatch(
Table.readParquet(new File(trainPath)))) { batch =>
val cb = new CudfColumnBatch(batch.select(featuresIndices),
batch.select(df.schema.fieldIndex(label)), null, null, null
)
val qdm = new QuantileDMatrix(Seq(cb).iterator, regressor.getMissing,
regressor.getMaxBins, regressor.getNthread)
ScalaXGBoost.train(qdm, xgboostParams, numRound)
}
val (xgb4jLeaf, xgb4jContrib, xgb4jPred) = withResource(new GpuColumnBatch(
Table.readParquet(new File(testPath)))) { batch =>
val cb = new CudfColumnBatch(batch.select(featuresIndices), null, null, null, null
)
val qdm = new DMatrix(cb, regressor.getMissing, regressor.getNthread)
(xgb4jModel.predictLeaf(qdm), xgb4jModel.predictContrib(qdm),
xgb4jModel.predict(qdm))
}
val rows = regressor.fit(df).transform(testdf).collect()
// Check Leaf
val xgbSparkLeaf = rows.map(row => row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))
checkEqual(xgb4jLeaf, xgbSparkLeaf)
// Check contrib
val xgbSparkContrib = rows.map(row =>
row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))
checkEqual(xgb4jContrib, xgbSparkContrib)
// Check prediction
val xgbSparkPred = rows.map(row =>
Array(row.getAs[Double]("prediction").toFloat))
checkEqual(xgb4jPred, xgbSparkPred)
}
}
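// Write the dataset as a single parquet part file and return its path, so the same
// data can be read back both by Spark and by cudf's Table.readParquet.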
def writeFile(df: Dataset[_]): String = {
def listFiles(directory: String): Array[String] = {
val dir = new File(directory)
if (dir.exists && dir.isDirectory) {
dir.listFiles.filter(f => f.isFile && f.getName.startsWith("part-")).map(_.getName)
} else {
Array.empty[String]
}
}
val dir = createTmpFolder("gpu_").toAbsolutePath.toString
df.coalesce(1).write.parquet(s"$dir/data")
val file = listFiles(s"$dir/data")(0)
s"$dir/data/$file"
}
}

View File

@ -0,0 +1,86 @@
/*
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.util.Random
trait TrainTestData {
protected def generateClassificationDataset(
numRows: Int,
numClass: Int,
seed: Int = 1): Seq[(Int, Float, Float, Float, Float)] = {
val random = new Random()
random.setSeed(seed)
(1 to numRows).map { _ =>
val label = random.nextInt(numClass)
// label, weight, c1, c2, c3
(label, random.nextFloat().abs, random.nextGaussian().toFloat, random.nextGaussian().toFloat,
random.nextGaussian().toFloat)
}
}
protected def generateRegressionDataset(
numRows: Int,
seed: Int = 11): Seq[(Float, Float, Float, Float, Float)] = {
val random = new Random()
random.setSeed(seed)
(1 to numRows).map { _ =>
// label, weight, c1, c2, c3
(random.nextFloat(), random.nextFloat().abs, random.nextGaussian().toFloat,
random.nextGaussian().toFloat,
random.nextGaussian().toFloat)
}
}
protected def generateRankDataset(
numRows: Int,
numClass: Int,
maxGroup: Int = 12,
seed: Int = 99): Seq[(Int, Float, Int, Float, Float, Float)] = {
val random = new Random()
random.setSeed(seed)
(1 to numRows).map { _ =>
val group = random.nextInt(maxGroup)
// label, weight, group, c1, c2, c3
(random.nextInt(numClass), group.toFloat, group,
random.nextGaussian().toFloat,
random.nextGaussian().toFloat,
random.nextGaussian().toFloat)
}
}
}
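// Small synthetic train/test datasets shared by the XGBoost-Spark test suites.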
object Classification extends TrainTestData {
val train = generateClassificationDataset(300, 2, 3)
val test = generateClassificationDataset(150, 2, 5)
}
object MultiClassification extends TrainTestData {
val train = generateClassificationDataset(300, 4, 11)
val test = generateClassificationDataset(150, 4, 12)
}
object Regression extends TrainTestData {
val train = generateRegressionDataset(300, 222)
val test = generateRegressionDataset(150, 223)
}
object Ranking extends TrainTestData {
val train = generateRankDataset(300, 10, 555)
val test = generateRankDataset(150, 10, 556)
}

View File

@ -1,602 +0,0 @@
/*
Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import java.util.ServiceLoader
import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
/**
* PreXGBoost handles data preparation before training and transformation
*/
object PreXGBoost extends PreXGBoostProvider {
private val logger = LogFactory.getLog("XGBoostSpark")
private lazy val defaultBaseMarginColumn = lit(Float.NaN)
private lazy val defaultWeightColumn = lit(1.0)
private lazy val defaultGroupColumn = lit(-1)
// Find the correct PreXGBoostProvider by ServiceLoader
private val optionProvider: Option[PreXGBoostProvider] = {
val classLoader = Option(Thread.currentThread().getContextClassLoader)
.getOrElse(getClass.getClassLoader)
val serviceLoader = ServiceLoader.load(classOf[PreXGBoostProvider], classLoader)
// For now, we only trust GpuPreXGBoost.
serviceLoader.asScala.filter(x => x.getClass.getName.equals(
"ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost")).toList match {
case Nil => None
case head::Nil =>
Some(head)
case _ => None
}
}
/**
* Transform schema
*
* @param xgboostEstimator supporting XGBoostClassifier/XGBoostClassificationModel and
* XGBoostRegressor/XGBoostRegressionModel
* @param schema the input schema
* @return the transformed schema
*/
override def transformSchema(
xgboostEstimator: XGBoostEstimatorCommon,
schema: StructType): StructType = {
if (optionProvider.isDefined && optionProvider.get.providerEnabled(None)) {
return optionProvider.get.transformSchema(xgboostEstimator, schema)
}
xgboostEstimator match {
case est: XGBoostClassifier => est.transformSchemaInternal(schema)
case model: XGBoostClassificationModel => model.transformSchemaInternal(schema)
case reg: XGBoostRegressor => reg.transformSchemaInternal(schema)
case model: XGBoostRegressionModel => model.transformSchemaInternal(schema)
case _ => throw new RuntimeException("Unsupporting " + xgboostEstimator)
}
}
/**
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
*
* @param estimator supports XGBoostClassifier and XGBoostRegressor
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (RDD[() => Watches], Option[RDD[_]])
* RDD[() => Watches] will be used as the training input
* Option[RDD[_]\] is the optional cached RDD
*/
override def buildDatasetToRDD(
estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]): XGBoostExecutionParams =>
(RDD[() => Watches], Option[RDD[_]]) = {
if (optionProvider.isDefined && optionProvider.get.providerEnabled(Some(dataset))) {
return optionProvider.get.buildDatasetToRDD(estimator, dataset, params)
}
val (packedParams, evalSet, xgbInput) = estimator match {
case est: XGBoostEstimatorCommon =>
// Get the weight column; if weight is not defined, default to lit(1.0)
val weight = if (!est.isDefined(est.weightCol) || est.getWeightCol.isEmpty) {
defaultWeightColumn
} else col(est.getWeightCol)
// Get the base-margin column; if base-margin is not defined, default to lit(Float.NaN)
val baseMargin = if (!est.isDefined(est.baseMarginCol) || est.getBaseMarginCol.isEmpty) {
defaultBaseMarginColumn
} else col(est.getBaseMarginCol)
val group = est match {
case regressor: XGBoostRegressor =>
// Get the group column; if group is not defined, default to lit(-1)
Some(
if (!regressor.isDefined(regressor.groupCol) || regressor.getGroupCol.isEmpty) {
defaultGroupColumn
} else col(regressor.getGroupCol)
)
case _ => None
}
val (xgbInput, featuresName) = est.vectorize(dataset)
val evalSets = est.getEvalSets(params).transform((_, df) => {
val (dfTransformed, _) = est.vectorize(df)
dfTransformed
})
(PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
est.getNumWorkers, est.needDeterministicRepartitioning), evalSets, xgbInput)
case _ => throw new RuntimeException("Unsupporting " + estimator)
}
// transform the training Dataset[_] to RDD[XGBLabeledPoint]
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
packedParams, xgbInput.asInstanceOf[DataFrame]).head
// transform the eval Dataset[_] to RDD[XGBLabeledPoint]
val evalRDDMap = evalSet.map {
case (name, dataFrame) => (name,
DataUtils.convertDataFrameToXGBLabeledPointRDDs(packedParams,
dataFrame.asInstanceOf[DataFrame]).head)
}
val hasGroup = packedParams.group.map(_ != defaultGroupColumn).getOrElse(false)
xgbExecParams: XGBoostExecutionParams =>
composeInputData(trainingSet, hasGroup, packedParams.numWorkers) match {
case Left(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
case Right(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
}
}
/**
* Transform Dataset
*
* @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
* @param dataset the input Dataset to transform
* @return the transformed DataFrame
*/
override def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame = {
if (optionProvider.isDefined && optionProvider.get.providerEnabled(Some(dataset))) {
return optionProvider.get.transformDataset(model, dataset)
}
// Get the necessary parameters
val (booster, inferBatchSize, xgbInput, featuresCol, useExternalMemory, missing,
allowNonZeroForMissing, predictFunc, schema) =
model match {
case m: XGBoostClassificationModel =>
val (xgbInput, featuresName) = m.vectorize(dataset)
// predict and turn to Row
val predictFunc =
(booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
m.producePredictionItrs(booster, dm)
m.produceResultIterator(originalRowItr, rawPredictionItr, probabilityItr,
predLeafItr, predContribItr)
}
// prepare the final Schema
var schema = StructType(xgbInput.schema.fields ++
Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)) ++
Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))
if (m.isDefined(m.leafPredictionCol)) {
schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
if (m.isDefined(m.contribPredictionCol)) {
schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
(m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory,
m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema)
case m: XGBoostRegressionModel =>
// predict and turn to Row
val (xgbInput, featuresName) = m.vectorize(dataset)
val predictFunc =
(booster: Booster, dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, predLeafItr, predContribItr) =
m.producePredictionItrs(booster, dm)
m.produceResultIterator(originalRowItr, rawPredictionItr, predLeafItr, predContribItr)
}
// prepare the final Schema
var schema = StructType(xgbInput.schema.fields ++
Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))
if (m.isDefined(m.leafPredictionCol)) {
schema = schema.add(StructField(name = m.getLeafPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
if (m.isDefined(m.contribPredictionCol)) {
schema = schema.add(StructField(name = m.getContribPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
(m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory,
m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema)
}
val bBooster = xgbInput.sparkSession.sparkContext.broadcast(booster)
val appName = xgbInput.sparkSession.sparkContext.appName
val resultRDD = xgbInput.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
new AbstractIterator[Row] {
private var batchCnt = 0
private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val cacheInfo = {
if (useExternalMemory) {
s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
s"${TaskContext.getPartitionId()}-batch-$batchCnt"
} else {
null
}
}
val dm = new DMatrix(
processMissingValues(features.map(_.asXGB), missing, allowNonZeroForMissing),
cacheInfo)
try {
predictFunc(bBooster.value, dm, batchRow.iterator)
} finally {
batchCnt += 1
dm.delete()
}
}
override def hasNext: Boolean = batchIterImpl.hasNext
override def next(): Row = batchIterImpl.next()
}
}
bBooster.unpersist(blocking = false)
xgbInput.sparkSession.createDataFrame(resultRDD, schema)
}
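// --- Illustrative sketch, not part of this change; names below are made up ---
// transformDataset above pulls rows through the iterator in chunks of
// `inferBatchSize` (Iterator.grouped) and turns each chunk into one DMatrix.
// The chunk-and-flatMap pattern in isolation, with plain Scala types and a
// hypothetical predictBatch standing in for booster.predict:
object BatchedInferenceSketch {
  // stand-in for a per-batch prediction call
  private def predictBatch(batch: Seq[Double]): Seq[Double] = batch.map(_ * 2.0)

  def main(args: Array[String]): Unit = {
    val rows: Iterator[Double] = (1 to 10).iterator.map(_.toDouble)
    val batchSize = 4
    // grouped() keeps at most batchSize rows buffered; flatMap stitches the
    // per-batch predictions back into a single lazy iterator of results.
    val predictions: Iterator[Double] = rows.grouped(batchSize).flatMap(predictBatch)
    println(predictions.toList) // List(2.0, 4.0, ..., 20.0)
  }
}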
/**
* Convert the RDD[XGBLabeledPoint] into a function that builds RDD[() => Watches]
*
* @param trainingSet the input training RDD[XGBLabeledPoint]
* @param evalRDDMap the eval set
* @param hasGroup whether the input data has a group column
* @return function to build (RDD[() => Watches], the cached RDD)
*/
private[spark] def buildRDDLabeledPointToRDDWatches(
trainingSet: RDD[XGBLabeledPoint],
evalRDDMap: Map[String, RDD[XGBLabeledPoint]] = Map(),
hasGroup: Boolean = false):
XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]) = {
xgbExecParams: XGBoostExecutionParams =>
composeInputData(trainingSet, hasGroup, xgbExecParams.numWorkers) match {
case Left(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
case Right(trainingData) =>
val cachedRDD = if (xgbExecParams.cacheTrainingSet) {
Some(trainingData.persist(StorageLevel.MEMORY_AND_DISK))
} else None
(trainForNonRanking(trainingData, xgbExecParams, evalRDDMap), cachedRDD)
}
}
/**
* Transform RDD according to group column
*
* @param trainingData the input XGBLabeledPoint RDD
* @param hasGroup whether the data has a group column
* @param nWorkers total number of XGBoost workers used to run the training tasks
* @return Either: the left is RDD with group, and the right is RDD without group
*/
private def composeInputData(
trainingData: RDD[XGBLabeledPoint],
hasGroup: Boolean,
nWorkers: Int): Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]] = {
if (hasGroup) {
Left(repartitionForTrainingGroup(trainingData, nWorkers))
} else {
Right(trainingData)
}
}
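// --- Illustrative sketch, not part of this change; names below are made up ---
// composeInputData returns an Either so callers branch exactly once between the
// ranking path (grouped rows, Left) and the plain path (flat rows, Right).
// The same dispatch, stripped down to plain Scala collections:
object EitherDispatchSketch {
  private def compose(hasGroup: Boolean, rows: Seq[Int]): Either[Seq[Seq[Int]], Seq[Int]] =
    if (hasGroup) Left(rows.grouped(2).toSeq) else Right(rows)

  def main(args: Array[String]): Unit = {
    compose(hasGroup = true, Seq(1, 2, 3, 4)) match {
      case Left(grouped) => println(s"ranking path: ${grouped.size} groups")
      case Right(flat)   => println(s"non-ranking path: ${flat.size} rows")
    }
  }
}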
/**
* Repartitioning trainingData directly when it contains groups may scramble the data,
* since rows belonging to the same group can be split across different partitions.
*
* The first step aggregates each group into a single array so a group can no longer be split;
* the second step repartitions the aggregated groups to nWorkers partitions.
*
* TODO, Could we repartition trainingData on group?
*/
private[spark] def repartitionForTrainingGroup(trainingData: RDD[XGBLabeledPoint],
nWorkers: Int): RDD[Array[XGBLabeledPoint]] = {
val allGroups = aggByGroupInfo(trainingData)
logger.info(s"repartitioning training group set to $nWorkers partitions")
allGroups.repartition(nWorkers)
}
/**
* Build RDD[() => Watches] for Ranking
* @param trainingData the training data RDD
* @param xgbExecutionParam xgboost execution params
* @param evalSetsMap the eval RDD
* @return RDD[() => Watches]
*/
private def trainForRanking(
trainingData: RDD[Array[XGBLabeledPoint]],
xgbExecutionParam: XGBoostExecutionParams,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[() => Watches] = {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPointGroups => {
val buildWatches = () => Watches.buildWatchesWithGroup(xgbExecutionParam,
DataUtils.processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
xgbExecutionParam.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParam.useExternalMemory))
Iterator.single(buildWatches)
}).cache()
} else {
coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
labeledPointGroupSets => {
val buildWatches = () => Watches.buildWatchesWithGroup(
labeledPointGroupSets.map {
case (name, iter) => (name, DataUtils.processMissingValuesWithGroup(iter,
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParam.useExternalMemory))
Iterator.single(buildWatches)
}).cache()
}
}
private def coPartitionGroupSets(
aggedTrainingSet: RDD[Array[XGBLabeledPoint]],
evalSets: Map[String, RDD[XGBLabeledPoint]],
nWorkers: Int): RDD[(String, Iterator[Array[XGBLabeledPoint]])] = {
val repartitionedDatasets = Map("train" -> aggedTrainingSet) ++ evalSets.map {
case (name, rdd) => {
val aggedRdd = aggByGroupInfo(rdd)
if (aggedRdd.getNumPartitions != nWorkers) {
name -> aggedRdd.repartition(nWorkers)
} else {
name -> aggedRdd
}
}
}
repartitionedDatasets.foldLeft(aggedTrainingSet.sparkContext.parallelize(
Array.fill[(String, Iterator[Array[XGBLabeledPoint]])](nWorkers)(null), nWorkers)) {
case (rddOfIterWrapper, (name, rddOfIter)) =>
rddOfIterWrapper.zipPartitions(rddOfIter) {
(itrWrapper, itr) =>
if (!itr.hasNext) {
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
"the number of elements in each dataframe is larger than the number of workers")
throw new Exception("too few elements in evaluation sets")
}
val itrArray = itrWrapper.toArray
if (itrArray.head != null) {
new IteratorWrapper(itrArray :+ (name -> itr))
} else {
new IteratorWrapper(Array(name -> itr))
}
}
}
}
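// --- Illustrative sketch, not part of this change; names below are made up ---
// Both coPartition* helpers rely on RDD.zipPartitions: once every dataset has
// exactly nWorkers partitions, the i-th partitions of all datasets are handled
// by the same task, so one worker sees its slice of the training set plus its
// slice of every eval set. A minimal, self-contained example on local Spark:
object ZipPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = org.apache.spark.sql.SparkSession.builder()
      .master("local[2]").appName("zip-partitions-sketch").getOrCreate()
    val sc = spark.sparkContext
    val train = sc.parallelize(1 to 4, numSlices = 2)
    val eval = sc.parallelize(101 to 104, numSlices = 2)
    // Each task receives the matching partitions of both RDDs.
    val pairs = train.zipPartitions(eval) { (trainIt, evalIt) =>
      Iterator((trainIt.toList, evalIt.toList))
    }
    pairs.collect().foreach(println)
    spark.stop()
  }
}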
private def aggByGroupInfo(trainingData: RDD[XGBLabeledPoint]) = {
val normalGroups: RDD[Array[XGBLabeledPoint]] = trainingData.mapPartitions(
// LabeledPointGroupIterator returns (Boolean, Array[XGBLabeledPoint])
new LabeledPointGroupIterator(_)).filter(!_.isEdgeGroup).map(_.points)
// edge groups with partition id.
val edgeGroups: RDD[(Int, XGBLabeledPointGroup)] = trainingData.mapPartitions(
new LabeledPointGroupIterator(_)).filter(_.isEdgeGroup).map(
group => (TaskContext.getPartitionId(), group))
// group chunks from different partitions together by group id in XGBLabeledPoint.
// use groupBy instead of aggregateBy since all groups within a partition have unique group ids.
val stitchedGroups: RDD[Array[XGBLabeledPoint]] = edgeGroups.groupBy(_._2.groupId).map(
groups => {
val it: Iterable[(Int, XGBLabeledPointGroup)] = groups._2
// sorted by partition id and merge list of Array[XGBLabeledPoint] into one array
it.toArray.sortBy(_._1).flatMap(_._2.points)
})
normalGroups.union(stitchedGroups)
}
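// --- Illustrative sketch, not part of this change; names below are made up ---
// aggByGroupInfo keeps complete groups where they are and stitches "edge"
// groups (those cut by a partition boundary) back together by group id,
// sorting the pieces by partition id so row order inside the group survives.
// The merge step on plain (partitionId, groupId, rows) tuples:
object EdgeGroupStitchSketch {
  def main(args: Array[String]): Unit = {
    // group 7 was split across partitions 0 and 1
    val edgePieces = Seq((0, 7, Array("a", "b")), (1, 7, Array("c")))
    val stitched: Map[Int, Array[String]] = edgePieces.groupBy(_._2).map {
      case (groupId, pieces) => groupId -> pieces.sortBy(_._1).flatMap(_._3).toArray
    }
    println(stitched(7).mkString(",")) // a,b,c
  }
}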
/**
* Build RDD[() => Watches] for Non-Ranking
* @param trainingData the training data RDD
* @param xgbExecutionParams xgboost execution params
* @param evalSetsMap the eval RDD
* @return RDD[() => Watches]
*/
private def trainForNonRanking(
trainingData: RDD[XGBLabeledPoint],
xgbExecutionParams: XGBoostExecutionParams,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[() => Watches] = {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions { labeledPoints => {
val buildWatches = () => Watches.buildWatches(xgbExecutionParams,
DataUtils.processMissingValues(labeledPoints, xgbExecutionParams.missing,
xgbExecutionParams.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParams.useExternalMemory))
Iterator.single(buildWatches)
}}.cache()
} else {
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
mapPartitions {
nameAndLabeledPointSets =>
val buildWatches = () => Watches.buildWatches(
nameAndLabeledPointSets.map {
case (name, iter) => (name, DataUtils.processMissingValues(iter,
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParams.useExternalMemory))
Iterator.single(buildWatches)
}.cache()
}
}
private def coPartitionNoGroupSets(
trainingData: RDD[XGBLabeledPoint],
evalSets: Map[String, RDD[XGBLabeledPoint]],
nWorkers: Int) = {
// eval_sets is supposed to be set by the caller of [[trainDistributed]]
val allDatasets = Map("train" -> trainingData) ++ evalSets
val repartitionedDatasets = allDatasets.map { case (name, rdd) =>
if (rdd.getNumPartitions != nWorkers) {
(name, rdd.repartition(nWorkers))
} else {
(name, rdd)
}
}
repartitionedDatasets.foldLeft(trainingData.sparkContext.parallelize(
Array.fill[(String, Iterator[XGBLabeledPoint])](nWorkers)(null), nWorkers)) {
case (rddOfIterWrapper, (name, rddOfIter)) =>
rddOfIterWrapper.zipPartitions(rddOfIter) {
(itrWrapper, itr) =>
if (!itr.hasNext) {
logger.error("when specifying eval sets as dataframes, you have to ensure that " +
"the number of elements in each dataframe is larger than the number of workers")
throw new Exception("too few elements in evaluation sets")
}
val itrArray = itrWrapper.toArray
if (itrArray.head != null) {
new IteratorWrapper(itrArray :+ (name -> itr))
} else {
new IteratorWrapper(Array(name -> itr))
}
}
}
}
private[scala] def getCacheDirName(useExternalMemory: Boolean): Option[String] = {
val taskId = TaskContext.getPartitionId().toString
if (useExternalMemory) {
val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
Some(dir.toAbsolutePath.toString)
} else {
None
}
}
}
class IteratorWrapper[T](arrayOfXGBLabeledPoints: Array[(String, Iterator[T])])
extends Iterator[(String, Iterator[T])] {
private var currentIndex = 0
override def hasNext: Boolean = currentIndex <= arrayOfXGBLabeledPoints.length - 1
override def next(): (String, Iterator[T]) = {
currentIndex += 1
arrayOfXGBLabeledPoints(currentIndex - 1)
}
}
/**
* Training data group in an RDD partition.
*
* @param groupId The group id
* @param points Array of XGBLabeledPoint within the same group.
* @param isEdgeGroup whether it is the first or last group in an RDD partition.
*/
private[spark] case class XGBLabeledPointGroup(
groupId: Int,
points: Array[XGBLabeledPoint],
isEdgeGroup: Boolean)
/**
* Within each RDD partition, groups the <code>XGBLabeledPoint</code>s by group id.
* The first and the last groups may be incomplete because of how the data is partitioned.
* <code>LabeledPointGroupIterator</code> organizes the data in a tuple format:
* (isFirstGroup || isLastGroup, Array[XGBLabeledPoint]).
* The edge groups across partitions can be stitched together later.
* @param base collection of <code>XGBLabeledPoint</code>
*/
private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
extends AbstractIterator[XGBLabeledPointGroup] {
private var firstPointOfNextGroup: XGBLabeledPoint = null
private var isNewGroup = false
override def hasNext: Boolean = {
base.hasNext || isNewGroup
}
override def next(): XGBLabeledPointGroup = {
val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
var isFirstGroup = true
if (firstPointOfNextGroup != null) {
builder += firstPointOfNextGroup
isFirstGroup = false
}
isNewGroup = false
while (!isNewGroup && base.hasNext) {
val point = base.next()
val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
firstPointOfNextGroup = point
if (point.group == groupId) {
// add to current group
builder += point
} else {
// start a new group
isNewGroup = true
}
}
val isLastGroup = !isNewGroup
val result = builder.result()
val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)
group
}
}
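// --- Illustrative sketch, not part of this change; names below are made up ---
// LabeledPointGroupIterator walks rows already ordered by group id and emits
// one Array per consecutive run, flagging the first and last run of the
// partition as "edge" groups. The same run-splitting on plain (groupId, value)
// pairs, using only the standard library:
object GroupRunSketch {
  def main(args: Array[String]): Unit = {
    val rows = List((1, "a"), (1, "b"), (2, "c"), (3, "d"), (3, "e"))
    // Split into runs of equal group id.
    val runs: List[List[(Int, String)]] =
      rows.foldRight(List.empty[List[(Int, String)]]) {
        case (row, run :: rest) if run.head._1 == row._1 => (row :: run) :: rest
        case (row, acc) => List(row) :: acc
      }
    runs.zipWithIndex.foreach { case (run, i) =>
      val isEdge = i == 0 || i == runs.size - 1
      println(s"group ${run.head._1}: ${run.map(_._2).mkString(",")} edge=$isEdge")
    }
  }
}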


@@ -1,72 +0,0 @@
/*
Copyright (c) 2021-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset}
/**
* PreXGBoost implementation provider
*/
private[scala] trait PreXGBoostProvider {
/**
* Whether the provider is enabled or not
* @param dataset the input dataset
* @return Boolean
*/
def providerEnabled(dataset: Option[Dataset[_]]): Boolean = false
/**
* Transform schema
* @param xgboostEstimator supporting XGBoostClassifier/XGBoostClassificationModel and
* XGBoostRegressor/XGBoostRegressionModel
* @param schema the input schema
* @return the transformed schema
*/
def transformSchema(xgboostEstimator: XGBoostEstimatorCommon, schema: StructType): StructType
/**
* Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
*
* @param estimator supports XGBoostClassifier and XGBoostRegressor
* @param dataset the training data
* @param params all user defined and defaulted params
* @return [[XGBoostExecutionParams]] => (RDD[[() => Watches]], Option[ RDD[_] ])
* RDD[() => Watches] will be used as the training input to build DMatrix
* Option[ RDD[_] ] is the optional cached RDD
*/
def buildDatasetToRDD(
estimator: Estimator[_],
dataset: Dataset[_],
params: Map[String, Any]):
XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]])
/**
* Transform Dataset
*
* @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
* @param dataset the input Dataset to transform
* @return the transformed DataFrame
*/
def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame
}


@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,12 +14,49 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.util
package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}
// based on a copy/paste of org.apache.spark.util
object Utils {
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
private[scala] object Utils {
private[spark] implicit class XGBLabeledPointFeatures(
val labeledPoint: XGBLabeledPoint
) extends AnyVal {
/** Converts the point to [[MLLabeledPoint]]. */
private[spark] def asML: MLLabeledPoint = {
MLLabeledPoint(labeledPoint.label, labeledPoint.features)
}
/**
* Returns feature of the point as [[org.apache.spark.ml.linalg.Vector]].
*/
def features: Vector = if (labeledPoint.indices == null) {
Vectors.dense(labeledPoint.values.map(_.toDouble))
} else {
Vectors.sparse(labeledPoint.size, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
}
}
private[spark] implicit class MLVectorToXGBLabeledPoint(val v: Vector) extends AnyVal {
/**
* Converts a [[Vector]] to a data point with a dummy label.
*
* This is needed for constructing a [[ml.dmlc.xgboost4j.scala.DMatrix]]
* for prediction.
*/
// TODO: support SparseVector without densifying its values
def asXGB: XGBLabeledPoint = v match {
case v: DenseVector =>
XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
case v: SparseVector =>
XGBLabeledPoint(0.0f, v.size, v.indices, v.toDense.values.map(_.toFloat))
}
}
def getSparkClassLoader: ClassLoader = getClass.getClassLoader
@@ -27,6 +64,7 @@ object Utils {
Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)
// scalastyle:off classforname
/** Preferred alternative to Class.forName(className) */
def classForName(className: String): Class[_] = {
Class.forName(className, true, getContextOrSparkClassLoader)
@@ -35,6 +73,7 @@ object Utils {
/**
* Get the TypeHints according to the value
*
* @param value the instance of class to be serialized
* @return if value is null,
* return NoTypeHints
@@ -53,6 +92,7 @@ object Utils {
/**
* Get the TypeHints according to the saved jsonClass field
*
* @param json
* @return TypeHints
*/
@@ -68,4 +108,17 @@ object Utils {
FullTypeHints(List(Utils.classForName(className)))
}.getOrElse(NoTypeHints)
}
val TRAIN_NAME = "train"
val VALIDATION_NAME = "eval"
/** Executes the provided code block and then closes the resource */
def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
}
}
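// --- Illustrative sketch, not part of this change; names below are made up ---
// withResource above is a small loan pattern: the resource is always closed,
// even when the block throws. A standalone equivalent plus a usage example
// (java.io.StringReader is just a convenient AutoCloseable):
object WithResourceSketch {
  def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V =
    try {
      block(r)
    } finally {
      r.close()
    }

  def main(args: Array[String]): Unit = {
    val firstChar = withResource(new java.io.StringReader("hello")) { reader =>
      reader.read().toChar
    }
    println(firstChar) // h
  }
}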


@@ -18,227 +18,30 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import scala.collection.mutable
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker}
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}
import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.sql.SparkSession
/**
* Rabit tracker configurations.
*
* @param timeout The number of seconds before timing out while waiting for workers to connect
* and for the tracker to shut down.
* @param hostIp The Rabit Tracker host IP address.
* This is only needed if the host IP cannot be automatically guessed.
* @param port The port number for the tracker to listen to. Use a system allocated one by
* default.
*/
case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0)
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
object TrackerConf {
def apply(): TrackerConf = TrackerConf(0)
}
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
private[scala] case class XGBoostExecutionParams(
private[spark] case class RuntimeParams(
numWorkers: Int,
numRounds: Int,
useExternalMemory: Boolean,
obj: ObjectiveTrait,
eval: EvalTrait,
missing: Float,
allowNonZeroForMissing: Boolean,
trackerConf: TrackerConf,
checkpointParam: Option[ExternalCheckpointParams],
xgbInputParams: XGBoostExecutionInputParams,
earlyStoppingRounds: Int,
cacheTrainingSet: Boolean,
device: Option[String],
device: String,
isLocal: Boolean,
featureNames: Option[Array[String]],
featureTypes: Option[Array[String]],
runOnGpu: Boolean) {
private var rawParamMap: Map[String, Any] = _
def setRawParamMap(inputMap: Map[String, Any]): Unit = {
rawParamMap = inputMap
}
def toMap: Map[String, Any] = {
rawParamMap
}
}
private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext){
private val logger = LogFactory.getLog("XGBoostSpark")
private val isLocal = sc.isLocal
private val overridedParams = overrideParams(rawParams, sc)
validateSparkSslConf()
/**
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
* If so, throw an exception unless this safety measure has been explicitly overridden
* via conf `xgboost.spark.ignoreSsl`.
*/
private def validateSparkSslConf(): Unit = {
val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
SparkSession.getActiveSession match {
case Some(ss) =>
(ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean,
ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean)
case None =>
(sc.getConf.getBoolean("spark.ssl.enabled", false),
sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false))
}
if (sparkSslEnabled) {
if (xgboostSparkIgnoreSsl) {
logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
} else {
throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
"in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
"To override this protection and still use xgboost-spark at your own risk, " +
"you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
}
}
}
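// --- Illustrative sketch, not part of this change ---
// How a caller would acknowledge the check above: the conf key comes straight
// from validateSparkSslConf; everything else (master, app name, object name) is
// a made-up local example. Opting out means data is sent over the wire unencrypted.
object IgnoreSslSketch {
  def main(args: Array[String]): Unit = {
    val spark = org.apache.spark.sql.SparkSession.builder()
      .master("local[1]")
      .appName("xgboost-ssl-optout")
      .config("xgboost.spark.ignoreSsl", "true") // explicit, at-your-own-risk override
      .getOrCreate()
    println(spark.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false"))
    spark.stop()
  }
}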
/**
* We should not include any nested structure in the output of this function, as the map is
* eventually fed to the xgboost4j layer.
*/
private def overrideParams(
params: Map[String, Any],
sc: SparkContext): Map[String, Any] = {
val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
var overridedParams = params
if (overridedParams.contains("nthread")) {
val nThread = overridedParams("nthread").toString.toInt
require(nThread <= coresPerTask,
s"the nthread configuration ($nThread) must be no larger than " +
s"spark.task.cpus ($coresPerTask)")
} else {
overridedParams = overridedParams + ("nthread" -> coresPerTask)
}
val numEarlyStoppingRounds = overridedParams.getOrElse(
"num_early_stopping_rounds", 0).asInstanceOf[Int]
overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
if (numEarlyStoppingRounds > 0 && overridedParams.getOrElse("custom_eval", null) != null) {
throw new IllegalArgumentException("custom_eval does not support early stopping")
}
overridedParams
}
/**
* The Map parameters accepted by the estimator's constructor may be of string type,
* e.g. Map("num_workers" -> "6", "num_round" -> 5); this function converts such
* parameters into their correct types.
*
* @return XGBoostExecutionParams
*/
def buildXGBRuntimeParams: XGBoostExecutionParams = {
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
if (obj != null) {
require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
"is not defined, you have to specify the objective type as classification or regression" +
" with a customized objective function")
}
var trainTestRatio = 1.0
if (overridedParams.contains("train_test_ratio")) {
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
"'eval_set_names'")
trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double]
}
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
val round = overridedParams("num_round").asInstanceOf[Int]
val useExternalMemory = overridedParams
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
val allowNonZeroForMissing = overridedParams
.getOrElse("allow_non_zero_for_missing", false)
.asInstanceOf[Boolean]
val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
val device: Option[String] = overridedParams.get("device").map(_.toString)
val deviceIsGpu = device.exists(_ == "cuda")
require(!(treeMethod.exists(_ == "approx") && deviceIsGpu),
"The tree method \"approx\" is not yet supported for Spark GPU cluster")
// back-compatible with "gpu_hist"
val runOnGpu = treeMethod.exists(_ == "gpu_hist") || deviceIsGpu
val trackerConf = overridedParams.get("tracker_conf") match {
case None => TrackerConf()
case Some(conf: TrackerConf) => conf
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
"instance of TrackerConf.")
}
val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams)
val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)
val earlyStoppingRounds = overridedParams.getOrElse(
"num_early_stopping_rounds", 0).asInstanceOf[Int]
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
.asInstanceOf[Boolean]
val featureNames = if (overridedParams.contains("feature_names")) {
Some(overridedParams("feature_names").asInstanceOf[Array[String]])
} else None
val featureTypes = if (overridedParams.contains("feature_types")){
Some(overridedParams("feature_types").asInstanceOf[Array[String]])
} else None
val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
missing, allowNonZeroForMissing, trackerConf,
checkpointParam,
inputParams,
earlyStoppingRounds,
cacheTrainingSet,
device,
isLocal,
featureNames,
featureTypes,
runOnGpu
)
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
}
runOnGpu: Boolean,
obj: Option[ObjectiveTrait] = None,
eval: Option[EvalTrait] = None)
/**
* A trait to manage stage-level scheduling
*/
private[spark] trait XGBoostStageLevel extends Serializable {
private[spark] trait StageLevelScheduling extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
private[spark] def isStandaloneOrLocalCluster(conf: SparkConf): Boolean = {
@@ -255,8 +58,7 @@ private[spark] trait XGBoostStageLevel extends Serializable {
* @param conf spark configurations
* @return Boolean to skip stage-level scheduling or not
*/
private[spark] def skipStageLevelScheduling(
sparkVersion: String,
private[spark] def skipStageLevelScheduling(sparkVersion: String,
runOnGpu: Boolean,
conf: SparkConf): Boolean = {
if (runOnGpu) {
@@ -313,14 +115,13 @@ private[spark] trait XGBoostStageLevel extends Serializable {
* on a single executor simultaneously.
*
* @param sc the spark context
* @param rdd which rdd to be applied with new resource profile
* @return the original rdd or the changed rdd
* @param rdd the rdd to be applied with new resource profile
* @return the original rdd or the modified rdd
*/
private[spark] def tryStageLevelScheduling(
sc: SparkContext,
xgbExecParams: XGBoostExecutionParams,
rdd: RDD[(Booster, Map[String, Array[Float]])]
): RDD[(Booster, Map[String, Array[Float]])] = {
private[spark] def tryStageLevelScheduling[T](sc: SparkContext,
xgbExecParams: RuntimeParams,
rdd: RDD[T]
): RDD[T] = {
val conf = sc.getConf
if (skipStageLevelScheduling(sc.version, xgbExecParams.runOnGpu, conf)) {
@@ -360,7 +161,7 @@ private[spark] trait XGBoostStageLevel extends Serializable {
}
}
object XGBoost extends XGBoostStageLevel {
private[spark] object XGBoost extends StageLevelScheduling {
private val logger = LogFactory.getLog("XGBoostSpark")
def getGPUAddrFromResources: Int = {
@@ -383,172 +184,118 @@ object XGBoost extends XGBoostStageLevel {
}
}
private def buildWatchesAndCheck(buildWatchesFun: () => Watches): Watches = {
val watches = buildWatchesFun()
// to work around empty partitions in the training dataset;
// this might not be the most efficient implementation, see
// (https://github.com/dmlc/xgboost/issues/1277)
if (!watches.toMap.contains("train")) {
throw new XGBoostError(
s"detected an empty partition in the training data, partition ID:" +
s" ${TaskContext.getPartitionId()}")
}
watches
}
private def buildDistributedBooster(
buildWatches: () => Watches,
xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, Object],
obj: ObjectiveTrait,
eval: EvalTrait,
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
/**
* Train an XGBoost Booster on the dataset held in the Watches
*
* @param watches holds the dataset to be trained
* @param runtimeParams XGBoost runtime parameters
* @param xgboostParams XGBoost library parameters
* @return a booster and the metrics
*/
private def trainBooster(watches: Watches,
runtimeParams: RuntimeParams,
xgboostParams: Map[String, Any]
): (Booster, Array[Array[Float]]) = {
var watches: Watches = null
val taskId = TaskContext.getPartitionId().toString
val attempt = TaskContext.get().attemptNumber.toString
rabitEnv.put("DMLC_TASK_ID", taskId)
val numRounds = xgbExecutionParam.numRounds
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
val numEarlyStoppingRounds = runtimeParams.earlyStoppingRounds
val metrics = Array.tabulate(watches.size)(_ =>
Array.ofDim[Float](runtimeParams.numRounds))
try {
Communicator.init(rabitEnv)
watches = buildWatchesAndCheck(buildWatches)
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
val externalCheckpointParams = xgbExecutionParam.checkpointParam
var params = xgbExecutionParam.toMap
if (xgbExecutionParam.runOnGpu) {
val gpuId = if (xgbExecutionParam.isLocal) {
// For local mode, force gpu id to primary device
0
var params = xgboostParams
if (runtimeParams.runOnGpu) {
val gpuId = if (runtimeParams.isLocal) {
TaskContext.get().partitionId() % runtimeParams.numWorkers
} else {
getGPUAddrFromResources
}
logger.info("Leveraging gpu device " + gpuId + " to train")
params = params + ("device" -> s"cuda:$gpuId")
}
val booster = if (makeCheckpoint) {
SXGBoost.trainAndSaveCheckpoint(
watches.toMap("train"), params, numRounds,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
} else {
SXGBoost.train(watches.toMap("train"), params, numRounds,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
val booster = SXGBoost.train(watches.toMap("train"), params, runtimeParams.numRounds,
watches.toMap, metrics, runtimeParams.obj.getOrElse(null),
runtimeParams.eval.getOrElse(null), earlyStoppingRound = numEarlyStoppingRounds)
(booster, metrics)
}
if (TaskContext.get().partitionId() == 0) {
/**
* Train an XGBoost booster with the given parameters on the dataset
*
* @param input the input dataset for training
* @param runtimeParams the runtime parameters for the JVM layer
* @param xgboostParams the xgboost parameters to pass to the xgboost library
* @return the booster and the metrics
*/
def train(input: RDD[Watches],
runtimeParams: RuntimeParams,
xgboostParams: Map[String, Any]): (Booster, Map[String, Array[Float]]) = {
val sc = input.sparkContext
logger.info(s"Running XGBoost ${spark.VERSION} with parameters: $xgboostParams")
// TODO Rabit tracker exception handling.
val trackerConf = runtimeParams.trackerConf
val tracker = new RabitTracker(runtimeParams.numWorkers,
trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
require(tracker.start(), "FAULT: Failed to start tracker")
try {
val rabitEnv = tracker.getWorkerArgs()
val boostersAndMetrics = input.barrier().mapPartitions { iter =>
val partitionId = TaskContext.getPartitionId()
rabitEnv.put("DMLC_TASK_ID", partitionId.toString)
try {
Communicator.init(rabitEnv)
require(iter.hasNext, "Failed to create DMatrix")
val watches = iter.next()
try {
val (booster, metrics) = trainBooster(watches, runtimeParams, xgboostParams)
if (partitionId == 0) {
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
} else {
Iterator.empty
}
} catch {
case xgbException: XGBoostError =>
logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
throw xgbException
} finally {
if (watches != null) {
watches.delete()
}
}
} finally {
// If shutdown throws exception, then the real exception for
// training will be swallowed,
try {
Communicator.shutdown()
if (watches != null) watches.delete()
} catch {
case e: Throwable =>
logger.error("Communicator.shutdown error: ", e)
}
}
}
// Executes the provided code block inside a tracker and then stops the tracker
private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
require(tracker.start(), "FAULT: Failed to start tracker")
try {
block(tracker)
} finally {
tracker.stop()
}
}
/**
* @return A tuple of the booster and the metrics used to build training summary
*/
@throws(classOf[XGBoostError])
private[spark] def trainDistributed(
sc: SparkContext,
buildTrainingData: XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]),
params: Map[String, Any]):
(Booster, Map[String, Array[Float]]) = {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams
val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
val checkpointManager = new ExternalCheckpointManager(
checkpointParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
checkpointManager.loadCheckpointAsScalaBooster()
}.orNull
// Get the training data RDD and the cachedRDD
val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)
try {
val (booster, metrics) = withTracker(
runtimeParams.numWorkers,
runtimeParams.trackerConf
) { tracker =>
val rabitEnv = tracker.getWorkerArgs()
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
var optionWatches: Option[() => Watches] = None
// take the first Watches to train
if (iter.hasNext) {
optionWatches = Some(iter.next())
}
optionWatches.map { buildWatches =>
buildDistributedBooster(buildWatches,
runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
}.getOrElse(throw new RuntimeException("No Watches to train"))
}
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
boostersAndMetrics)
val rdd = tryStageLevelScheduling(sc, runtimeParams, boostersAndMetrics)
// The repartition step turns the training stage into a ShuffleMapStage, so that when one
// of the training tasks fails the stage can be retried. A ResultStage is not retried when
// it fails.
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
(booster, metrics)
}
// we should delete the checkpoint directory after a successful training
runtimeParams.checkpointParam.foreach {
cpParam =>
if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
val checkpointManager = new ExternalCheckpointManager(
cpParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
checkpointManager.cleanPath()
}
}
val (booster, metrics) = rdd.repartition(1).collect()(0)
(booster, metrics)
} catch {
case t: Throwable =>
// if the job was aborted due to an exception
logger.error("the job was aborted due to ", t)
logger.error("XGBoost job was aborted due to ", t)
throw t
} finally {
optionalCachedRDD.foreach(_.unpersist())
try {
tracker.stop()
} catch {
case t: Throwable => logger.error(t)
}
}
}
}
class Watches private[scala] (
val datasets: Array[DMatrix],
class Watches private[scala](val datasets: Array[DMatrix],
val names: Array[String],
val cacheDirName: Option[String]) {
@@ -568,211 +315,14 @@ class Watches private[scala] (
override def toString: String = toMap.toString
}
private object Watches {
private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
val builder = new mutable.ArrayBuilder.ofFloat()
var nTotal = 0
var nUndefined = 0
while (baseMargins.hasNext) {
nTotal += 1
val baseMargin = baseMargins.next()
if (baseMargin.isNaN) {
nUndefined += 1 // don't waste space for all-NaNs.
} else {
builder += baseMargin
}
}
if (nUndefined == nTotal) {
None
} else if (nUndefined == 0) {
Some(builder.result())
} else {
throw new IllegalArgumentException(
s"Encountered a partition with $nUndefined NaN base margin values. " +
s"If you want to specify base margin, ensure all values are non-NaN.")
}
}
def buildWatches(
nameAndLabeledPointSets: Iterator[(String, Iterator[XGBLabeledPoint])],
cachedDirName: Option[String]): Watches = {
val dms = nameAndLabeledPointSets.map {
case (name, labeledPoints) =>
val baseMargins = new mutable.ArrayBuilder.ofFloat
val duplicatedItr = labeledPoints.map(labeledPoint => {
baseMargins += labeledPoint.baseMargin
labeledPoint
})
val dMatrix = new DMatrix(duplicatedItr, cachedDirName.map(_ + s"/$name").orNull)
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
if (baseMargin.isDefined) {
dMatrix.setBaseMargin(baseMargin.get)
}
(name, dMatrix)
}.toArray
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
}
def buildWatches(
xgbExecutionParams: XGBoostExecutionParams,
labeledPoints: Iterator[XGBLabeledPoint],
cacheDirName: Option[String]): Watches = {
val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
val seed = xgbExecutionParams.xgbInputParams.seed
val r = new Random(seed)
val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
val testBaseMargins = new mutable.ArrayBuilder.ofFloat
val trainPoints = labeledPoints.filter { labeledPoint =>
val accepted = r.nextDouble() <= trainTestRatio
if (!accepted) {
testPoints += labeledPoint
testBaseMargins += labeledPoint.baseMargin
} else {
trainBaseMargins += labeledPoint.baseMargin
}
accepted
}
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
if (xgbExecutionParams.featureNames.isDefined) {
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
}
if (xgbExecutionParams.featureTypes.isDefined) {
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
}
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
}
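// --- Illustrative sketch, not part of this change; names below are made up ---
// The (now removed) train_test_ratio handling above splits an iterator with an
// independent Bernoulli draw per element. The same idea on a plain collection:
object RandomSplitSketch {
  def main(args: Array[String]): Unit = {
    val r = new scala.util.Random(42L)
    val ratio = 0.8
    val (train, test) = (1 to 100).partition(_ => r.nextDouble() <= ratio)
    println(s"train=${train.size} test=${test.size}") // roughly an 80/20 split
  }
}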
def buildWatchesWithGroup(
nameAndlabeledPointGroupSets: Iterator[(String, Iterator[Array[XGBLabeledPoint]])],
cachedDirName: Option[String]): Watches = {
val dms = nameAndlabeledPointGroupSets.map {
case (name, labeledPointsGroups) =>
val baseMargins = new mutable.ArrayBuilder.ofFloat
val groupsInfo = new mutable.ArrayBuilder.ofInt
val weights = new mutable.ArrayBuilder.ofFloat
val iter = labeledPointsGroups.filter(labeledPointGroup => {
var groupWeight = -1.0f
var groupSize = 0
labeledPointGroup.map { labeledPoint => {
if (groupWeight < 0) {
groupWeight = labeledPoint.weight
} else if (groupWeight != labeledPoint.weight) {
throw new IllegalArgumentException("the instances in the same group have to be" +
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
}
baseMargins += labeledPoint.baseMargin
groupSize += 1
labeledPoint
}
}
weights += groupWeight
groupsInfo += groupSize
true
})
val dMatrix = new DMatrix(iter.flatMap(_.iterator), cachedDirName.map(_ + s"/$name").orNull)
val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
if (baseMargin.isDefined) {
dMatrix.setBaseMargin(baseMargin.get)
}
dMatrix.setGroup(groupsInfo.result())
dMatrix.setWeight(weights.result())
(name, dMatrix)
}.toArray
new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
}
def buildWatchesWithGroup(
xgbExecutionParams: XGBoostExecutionParams,
labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
cacheDirName: Option[String]): Watches = {
val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
val seed = xgbExecutionParams.xgbInputParams.seed
val r = new Random(seed)
val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
val testBaseMargins = new mutable.ArrayBuilder.ofFloat
val trainGroups = new mutable.ArrayBuilder.ofInt
val testGroups = new mutable.ArrayBuilder.ofInt
val trainWeights = new mutable.ArrayBuilder.ofFloat
val testWeights = new mutable.ArrayBuilder.ofFloat
val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
val accepted = r.nextDouble() <= trainTestRatio
if (!accepted) {
var groupWeight = -1.0f
var groupSize = 0
labeledPointGroup.foreach(labeledPoint => {
testPoints += labeledPoint
testBaseMargins += labeledPoint.baseMargin
if (groupWeight < 0) {
groupWeight = labeledPoint.weight
} else if (labeledPoint.weight != groupWeight) {
throw new IllegalArgumentException("the instances in the same group have to be" +
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
}
groupSize += 1
})
testWeights += groupWeight
testGroups += groupSize
} else {
var groupWeight = -1.0f
var groupSize = 0
labeledPointGroup.foreach { labeledPoint => {
if (groupWeight < 0) {
groupWeight = labeledPoint.weight
} else if (labeledPoint.weight != groupWeight) {
throw new IllegalArgumentException("the instances in the same group have to be" +
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
}
trainBaseMargins += labeledPoint.baseMargin
groupSize += 1
}}
trainWeights += groupWeight
trainGroups += groupSize
}
accepted
}
val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
trainMatrix.setGroup(trainGroups.result())
trainMatrix.setWeight(trainWeights.result())
val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
if (trainTestRatio < 1.0) {
testMatrix.setGroup(testGroups.result())
testMatrix.setWeight(testWeights.result())
}
val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
if (xgbExecutionParams.featureNames.isDefined) {
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
}
if (xgbExecutionParams.featureTypes.isDefined) {
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
}
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
}
}
/**
* Rabit tracker configurations.
*
* @param timeout The number of seconds before timing out while waiting for workers to connect
* and for the tracker to shut down.
* @param hostIp The Rabit Tracker host IP address.
* This is only needed if the host IP cannot be automatically guessed.
* @param port The port number for the tracker to listen to. Use a system allocated one by
* default.
*/
private[spark] case class TrackerConf(timeout: Int = 0, hostIp: String = "", port: Int = 0)
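// --- Illustrative sketch, not part of this change ---
// Constructing the configuration above. TrackerConf is private[spark], so this
// only compiles inside ml.dmlc.xgboost4j.scala.spark, and the host IP below is
// a made-up example. The defaults are timeout = 0, an automatically guessed
// host IP, and a system-allocated port.
object TrackerConfSketch {
  val conf: TrackerConf = TrackerConf(timeout = 600, hostIp = "192.168.1.10")
}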


@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,490 +16,190 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import scala.collection.{Iterator, mutable}
import scala.collection.mutable
import org.apache.spark.ml.classification.{ProbabilisticClassificationModel, ProbabilisticClassifier}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.{col, udf}
import org.json4s.DefaultFormats
class XGBoostClassifier (
override val uid: String,
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{BINARY_CLASSIFICATION_OBJS, MULTICLASSIFICATION_OBJS}
class XGBoostClassifier(override val uid: String,
private[spark] val xgboostParams: Map[String, Any])
extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel]
with XGBoostClassifierParams with DefaultParamsWritable {
with XGBoostEstimator[XGBoostClassifier, XGBoostClassificationModel]
with XGBProbabilisticClassifierParams[XGBoostClassifier] {
def this() = this(Identifiable.randomUID("xgbc"), Map[String, Any]())
def this() = this(XGBoostClassifier._uid, Map.empty)
def this(uid: String) = this(uid, Map[String, Any]())
def this(uid: String) = this(uid, Map.empty)
def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbc"), xgboostParams)
def this(xgboostParams: Map[String, Any]) = this(XGBoostClassifier._uid, xgboostParams)
XGBoost2MLlibParams(xgboostParams)
xgboost2SparkParams(xgboostParams)
def setWeightCol(value: String): this.type = set(weightCol, value)
private var numberClasses = 0
def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
def setNumClass(value: Int): this.type = set(numClass, value)
// setters for general params
def setNumRound(value: Int): this.type = set(numRound, value)
def setNumWorkers(value: Int): this.type = set(numWorkers, value)
def setNthread(value: Int): this.type = set(nthread, value)
def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)
def setSilent(value: Int): this.type = set(silent, value)
def setMissing(value: Float): this.type = set(missing, value)
def setCheckpointPath(value: String): this.type = set(checkpointPath, value)
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
def setSeed(value: Long): this.type = set(seed, value)
def setEta(value: Double): this.type = set(eta, value)
def setGamma(value: Double): this.type = set(gamma, value)
def setMaxDepth(value: Int): this.type = set(maxDepth, value)
def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)
def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)
def setSubsample(value: Double): this.type = set(subsample, value)
def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)
def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)
def setLambda(value: Double): this.type = set(lambda, value)
def setAlpha(value: Double): this.type = set(alpha, value)
def setTreeMethod(value: String): this.type = set(treeMethod, value)
def setDevice(value: String): this.type = set(device, value)
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
def setMaxBins(value: Int): this.type = set(maxBins, value)
def setMaxLeaves(value: Int): this.type = set(maxLeaves, value)
def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)
def setSampleType(value: String): this.type = set(sampleType, value)
def setNormalizeType(value: String): this.type = set(normalizeType, value)
def setRateDrop(value: Double): this.type = set(rateDrop, value)
def setSkipDrop(value: Double): this.type = set(skipDrop, value)
def setLambdaBias(value: Double): this.type = set(lambdaBias, value)
// setters for learning params
def setObjective(value: String): this.type = set(objective, value)
def setObjectiveType(value: String): this.type = set(objectiveType, value)
def setBaseScore(value: Double): this.type = set(baseScore, value)
def setEvalMetric(value: String): this.type = set(evalMetric, value)
def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
def setMaximizeEvaluationMetrics(value: Boolean): this.type =
set(maximizeEvaluationMetrics, value)
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
def setAllowNonZeroForMissing(value: Boolean): this.type = set(
allowNonZeroForMissing,
value
)
def setSinglePrecisionHistogram(value: Boolean): this.type =
set(singlePrecisionHistogram, value)
def setFeatureNames(value: Array[String]): this.type =
set(featureNames, value)
def setFeatureTypes(value: Array[String]): this.type =
set(featureTypes, value)
// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
if ($(objective).startsWith("multi")) {
// multi
"mlogloss"
private def validateObjective(dataset: Dataset[_]): Unit = {
// If the objective is set explicitly, it must be one of BINARY_CLASSIFICATION_OBJS or
// MULTICLASSIFICATION_OBJS
val obj = if (isSet(objective)) {
val tmpObj = getObjective
val supportedObjs = BINARY_CLASSIFICATION_OBJS.toSeq ++ MULTICLASSIFICATION_OBJS.toSeq
require(supportedObjs.contains(tmpObj),
s"Wrong objective for XGBoostClassifier, supported objs: ${supportedObjs.mkString(",")}")
Some(tmpObj)
} else {
// binary
"logloss"
}
None
}
// Callback from PreXGBoost
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
if (isFeaturesColSet(schema)) {
// User has vectorized the features into VectorUDT.
super.transformSchema(schema)
def inferNumClasses: Int = {
var num = getNumClass
// Infer the number of classes if it is not set explicitly.
// Note that if the user sets the number of classes explicitly, we do not re-validate it.
if (num == 0) {
num = SparkUtils.getNumClasses(dataset, getLabelCol)
}
require(num > 0)
num
}
// objective is set explicitly.
if (obj.isDefined) {
if (MULTICLASSIFICATION_OBJS.contains(getObjective)) {
numberClasses = inferNumClasses
setNumClass(numberClasses)
} else {
transformSchemaWithFeaturesCols(true, schema)
numberClasses = 2
// binary classification doesn't require num_class be set
require(!isSet(numClass), "num_class is not allowed for binary classification")
}
}
override def transformSchema(schema: StructType): StructType = {
PreXGBoost.transformSchema(this, schema)
}
override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = {
val _numClasses = getNumClasses(dataset)
if (isDefined(numClass) && $(numClass) != _numClasses) {
throw new Exception("The number of classes in dataset doesn't match " +
"\'num_class\' in xgboost params.")
}
if (_numClasses == 2) {
if (!isDefined(objective)) {
// If user doesn't set objective, force it to binary:logistic
} else {
// infer the objective according to the num_class
numberClasses = inferNumClasses
if (numberClasses <= 2) {
setObjective("binary:logistic")
}
} else if (_numClasses > 2) {
if (!isDefined(objective)) {
// If user doesn't set objective, force it to multi:softprob
logger.warn("Inferred for binary classification, set the objective to binary:logistic")
require(!isSet(numClass), "num_class is not allowed for binary classification")
} else {
logger.warn("Inferred for multi classification, set the objective to multi:softprob")
setObjective("multi:softprob")
setNumClass(numberClasses)
}
}
}
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
set(evalMetric, setupDefaultEvalMetric())
/**
* Validate the parameters before training, throw exception if possible
*/
override protected[spark] def validate(dataset: Dataset[_]): Unit = {
super.validate(dataset)
validateObjective(dataset)
}
if (isDefined(customObj) && $(customObj) != null) {
set(objectiveType, "classification")
override protected def createModel(booster: Booster, summary: XGBoostTrainingSummary):
XGBoostClassificationModel = {
new XGBoostClassificationModel(uid, numberClasses, booster, Option(summary))
}
// Packing with all params plus params user defined
val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)
transformSchema(dataset.schema, logging = true)
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(dataset.sparkSession.sparkContext,
buildTrainingData, derivedXGBParamMap)
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
model
}
override def copy(extra: ParamMap): XGBoostClassifier = defaultCopy(extra)
}
object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
override def load(path: String): XGBoostClassifier = super.load(path)
private val _uid = Identifiable.randomUID("xgbc")
}
class XGBoostClassificationModel private[ml](
override val uid: String,
override val numClasses: Int,
private[scala] val _booster: Booster)
extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
with XGBoostClassifierParams with InferenceParams
with MLWritable with Serializable {
val uid: String,
val numClasses: Int,
val nativeBooster: Booster,
val summary: Option[XGBoostTrainingSummary] = None
) extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
with XGBoostModel[XGBoostClassificationModel]
with XGBProbabilisticClassifierParams[XGBoostClassificationModel] {
import XGBoostClassificationModel._
def this(uid: String) = this(uid, 0, null)
// only called in copy()
def this(uid: String) = this(uid, 2, null)
override protected[spark] def postTransform(dataset: Dataset[_],
pred: PredictedColumns): Dataset[_] = {
var output = super.postTransform(dataset, pred)
/**
* Get the native booster instance of this model.
* This is used to call low-level APIs on native booster, such as "getFeatureScore".
*/
def nativeBooster: Booster = _booster
// Always use probability col to get the prediction
private var trainingSummary: Option[XGBoostTrainingSummary] = None
/**
* Returns summary (e.g. train/test objective history) of model on the
* training set. An exception is thrown if no summary is available.
*/
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
throw new IllegalStateException("No training summary available for this XGBoostModel")
}
private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
trainingSummary = Some(summary)
this
}
def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)
def setTreeLimit(value: Int): this.type = set(treeLimit, value)
def setMissing(value: Float): this.type = set(missing, value)
def setAllowNonZeroForMissing(value: Boolean): this.type = set(
allowNonZeroForMissing,
value
)
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
/**
* Single instance prediction.
* Note: The performance is not ideal, use it carefully!
*/
override def predict(features: Vector): Double = {
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val dm = new DMatrix(processMissingValues(
Iterator(features.asXGB),
$(missing),
$(allowNonZeroForMissing)
))
val probability = _booster.predict(data = dm)(0).map(_.toDouble)
if (numClasses == 2) {
math.round(probability(0))
} else {
probability2prediction(Vectors.dense(probability))
}
}
// This function is never used; it exists only to satisfy the compiler.
override def predictRaw(features: Vector): Vector = {
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
}
// This function is never used; it exists only to satisfy the compiler.
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
}
private[scala] def produceResultIterator(
originalRowItr: Iterator[Row],
rawPredictionItr: Iterator[Row],
probabilityItr: Iterator[Row],
predLeafItr: Iterator[Row],
predContribItr: Iterator[Row]): Iterator[Row] = {
// the following implementation is to be improved
if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).zip(predContribItr).
map { case ((((originals: Row, rawPrediction: Row), probability: Row), leaves: Row),
contribs: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq ++
contribs.toSeq)
}
} else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
(!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).
map { case (((originals: Row, rawPrediction: Row), probability: Row), leaves: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq)
}
} else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predContribItr).
map { case (((originals: Row, rawPrediction: Row), probability: Row), contribs: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ contribs.toSeq)
}
} else {
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).map {
case ((originals: Row, rawPrediction: Row), probability: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
}
}
}
private[scala] def producePredictionItrs(booster: Booster, dm: DMatrix):
Array[Iterator[Row]] = {
val rawPredictionItr = {
booster.predict(dm, outPutMargin = true, $(treeLimit)).
map(Row(_)).iterator
}
val probabilityItr = {
booster.predict(dm, outPutMargin = false, $(treeLimit)).
map(Row(_)).iterator
}
val predLeafItr = {
if (isDefined(leafPredictionCol)) {
booster.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator
} else {
Iterator()
}
}
val predContribItr = {
if (isDefined(contribPredictionCol)) {
booster.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator
} else {
Iterator()
}
}
Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
}
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
if (isFeaturesColSet(schema)) {
// User has vectorized the features into VectorUDT.
super.transformSchema(schema)
} else {
transformSchemaWithFeaturesCols(false, schema)
}
}
override def transformSchema(schema: StructType): StructType = {
PreXGBoost.transformSchema(this, schema)
}
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".transform() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = PreXGBoost.transformDataset(this, dataset)
var numColsOutput = 0
val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
val raw = rawPrediction.map(_.toDouble).toArray
val rawPredictions = if (numClasses == 2) Array(-raw(0), raw(0)) else raw
Vectors.dense(rawPredictions)
}
if ($(rawPredictionCol).nonEmpty) {
outputData = outputData
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
numColsOutput += 1
}
if (getObjective.equals("multi:softmax")) {
if (isDefinedNonEmpty(predictionCol) && pred.predTmp) {
if (getObjective == "multi:softmax") {
// For objective=multi:softmax scenario, there is no probability predicted from xgboost.
// Instead, the probability column will be filled with real prediction
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
probability(0)
}
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
numColsOutput += 1
output = output.withColumn(getPredictionCol, predictUDF(col(TMP_TRANSFORMED_COL)))
} else {
val predCol = udf { probability: mutable.WrappedArray[Float] =>
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
probability2prediction(Vectors.dense(probabilities))
}
output = output.withColumn(getPredictionCol, predCol(col(TMP_TRANSFORMED_COL)))
}
}
} else {
if (isDefinedNonEmpty(probabilityCol) && pred.predTmp) {
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
Vectors.dense(probabilities)
}
if ($(probabilityCol).nonEmpty) {
outputData = outputData
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
numColsOutput += 1
output = output.withColumn(TMP_TRANSFORMED_COL,
probabilityUDF(output.col(TMP_TRANSFORMED_COL)))
.withColumnRenamed(TMP_TRANSFORMED_COL, getProbabilityCol)
}
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
// From XGBoost probability to MLlib prediction
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
probability2prediction(Vectors.dense(probabilities))
}
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
numColsOutput += 1
if (pred.predRaw) {
val rawPredictionUDF = udf { raw: mutable.WrappedArray[Float] =>
val rawF = raw.map(_.toDouble).toArray
val rawPredictions = if (numClasses == 2) Array(-rawF(0), rawF(0)) else rawF
Vectors.dense(rawPredictions)
}
output = output.withColumn(getRawPredictionCol,
rawPredictionUDF(output.col(getRawPredictionCol)))
}
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
outputData
.toDF
.drop(col(_rawPredictionCol))
.drop(col(_probabilityCol))
output.drop(TMP_TRANSFORMED_COL)
}
override def copy(extra: ParamMap): XGBoostClassificationModel = {
val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses, _booster), extra)
newModel.setSummary(summary).setParent(parent)
val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses,
nativeBooster, summary), extra)
newModel.setParent(parent)
}
override def write: MLWriter =
new XGBoostClassificationModel.XGBoostClassificationModelWriter(this)
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
}
override def predictRaw(features: Vector): Vector =
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
}
object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
private[scala] val _rawPredictionCol = "_rawPrediction"
private[scala] val _probabilityCol = "_probability"
override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader
override def load(path: String): XGBoostClassificationModel = super.load(path)
private[XGBoostClassificationModel]
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel)
extends XGBoostWriter {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
instance._booster.saveModel(outputStream, getModelFormat())
outputStream.close()
}
}
private class XGBoostClassificationModelReader extends MLReader[XGBoostClassificationModel] {
/** Checked against metadata when loading model */
private val className = classOf[XGBoostClassificationModel].getName
override def read: MLReader[XGBoostClassificationModel] = new ModelReader
private class ModelReader extends XGBoostModelReader[XGBoostClassificationModel] {
override def load(path: String): XGBoostClassificationModel = {
implicit val sc = super.sparkSession.sparkContext
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
val numClasses = DefaultXGBoostParamsReader.getNumClass(metadata, dataInStream)
val booster = SXGBoost.loadModel(dataInStream)
val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
val xgbModel = loadBooster(path)
val meta = SparkUtils.loadMetadata(path, sc)
implicit val format = DefaultFormats
val numClasses = (meta.params \ "numClass").extractOpt[Int].getOrElse(2)
val model = new XGBoostClassificationModel(meta.uid, numClasses, xgbModel)
meta.getAndSetParams(model)
model
}
}

View File

@ -0,0 +1,641 @@
/*
Copyright (c) 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.util.ServiceLoader
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.functions.array_to_vector
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsWritable, MLReader, MLWritable, MLWriter}
import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types._
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.spark.Utils.MLVectorToXGBLabeledPoint
import ml.dmlc.xgboost4j.scala.spark.params._
/**
* Hold the column index
*/
private[spark] case class ColumnIndices(
labelId: Int,
featureId: Option[Int], // the feature type is VectorUDT or Array
featureIds: Option[Seq[Int]], // the feature type is columnar
weightId: Option[Int],
marginId: Option[Int],
groupId: Option[Int])
private[spark] trait NonParamVariables[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]] {
private var dataset: Option[Dataset[_]] = None
def setEvalDataset(ds: Dataset[_]): T = {
this.dataset = Some(ds)
this.asInstanceOf[T]
}
def getEvalDataset(): Option[Dataset[_]] = {
this.dataset
}
}
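// Illustrative usage (names are examples, not part of this file): because the eval dataset
// is held as a plain field rather than a Spark Param, it is not persisted with the estimator:
//   val estimator = new XGBoostClassifier().setEvalDataset(evalDf)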
private[spark] trait PluginMixin {
// Find the XGBoostPlugin by ServiceLoader
private val plugin: Option[XGBoostPlugin] = {
val classLoader = Option(Thread.currentThread().getContextClassLoader)
.getOrElse(getClass.getClassLoader)
val serviceLoader = ServiceLoader.load(classOf[XGBoostPlugin], classLoader)
// For now, we only trust GpuXGBoostPlugin.
serviceLoader.asScala.filter(x => x.getClass.getName.equals(
"ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin")).toList match {
case Nil => None
case head :: Nil =>
Some(head)
case _ => None
}
}
/** Visible for testing */
protected[spark] def getPlugin: Option[XGBoostPlugin] = plugin
protected def isPluginEnabled(dataset: Dataset[_]): Boolean = {
plugin.map(_.isEnabled(dataset)).getOrElse(false)
}
}
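// Plugin discovery sketch: ServiceLoader picks up implementations registered in a
// META-INF/services provider file named after the XGBoostPlugin interface (standard JDK
// mechanism, noted here for context). Per the filter above, only GpuXGBoostPlugin is
// honored; any other or ambiguous registration yields None and the CPU pipeline is used.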
private[spark] trait XGBoostEstimator[
Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M]
with XGBoostParams[Learner] with SparkParams[Learner] with ParamUtils[Learner]
with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable
with PluginMixin {
protected val logger = LogFactory.getLog("XGBoostSpark")
/**
* Cast the named field in the schema to the desired data type if needed.
*
* @param schema the input schema
* @param name the column to be cast
* @param targetType the target data type
* @return the column, cast to targetType when its current type differs
*/
private[spark] def castIfNeeded(schema: StructType,
name: String,
targetType: DataType = FloatType): Column = {
if (!(schema(name).dataType == targetType)) {
val meta = schema(name).metadata
col(name).as(name, meta).cast(targetType)
} else {
col(name)
}
}
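// For example, castIfNeeded(schema, "weight") returns col("weight") unchanged when the column
// is already FloatType, and col("weight").cast(FloatType) (keeping the column metadata) otherwise.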
/**
* Repartition the dataset to the numWorkers if needed.
*
* @param dataset the dataset to be repartitioned
* @return the repartitioned dataset
*/
private[spark] def repartitionIfNeeded(dataset: Dataset[_]): Dataset[_] = {
val numPartitions = dataset.rdd.getNumPartitions
if (getForceRepartition || getNumWorkers != numPartitions) {
dataset.repartition(getNumWorkers)
} else {
dataset
}
}
/**
* Build the columns indices.
*/
private[spark] def buildColumnIndices(schema: StructType): ColumnIndices = {
// Get feature id(s)
val (featureIds: Option[Seq[Int]], featureId: Option[Int]) =
if (getFeaturesCols.length != 0) {
(Some(getFeaturesCols.map(schema.fieldIndex).toSeq), None)
} else {
(None, Some(schema.fieldIndex(getFeaturesCol)))
}
// function to get the column id according to the parameter
def columnId(param: Param[String]): Option[Int] = {
if (isDefinedNonEmpty(param)) {
Some(schema.fieldIndex($(param)))
} else {
None
}
}
// Special handle for group
val groupId: Option[Int] = this match {
case p: HasGroupCol => columnId(p.groupCol)
case _ => None
}
ColumnIndices(
labelId = columnId(labelCol).get,
featureId = featureId,
featureIds = featureIds,
columnId(weightCol),
columnId(baseMarginCol),
groupId)
}
/**
* Preprocess the dataset to meet the XGBoost input requirements.
*
* @param dataset the input dataset
* @return the selected (and possibly repartitioned) dataset along with its column indices
*/
private[spark] def preprocess(dataset: Dataset[_]): (Dataset[_], ColumnIndices) = {
// Columns to be selected for XGBoost training
val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty
val schema = dataset.schema
def selectCol(c: Param[String], targetType: DataType) = {
if (isDefinedNonEmpty(c)) {
// Validation col should be a boolean column.
if (c == featuresCol) {
selectedCols.append(col($(c)))
} else {
selectedCols.append(castIfNeeded(schema, $(c), targetType))
}
}
}
Seq(labelCol, featuresCol, weightCol, baseMarginCol).foreach(p => selectCol(p, FloatType))
this match {
case p: HasGroupCol => selectCol(p.groupCol, IntegerType)
case _ =>
}
val input = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*))
val columnIndices = buildColumnIndices(input.schema)
(input, columnIndices)
}
/** visible for testing */
private[spark] def toXGBLabeledPoint(dataset: Dataset[_],
columnIndexes: ColumnIndices): RDD[XGBLabeledPoint] = {
val isSetMissing = isSet(missing)
dataset.toDF().rdd.map { row =>
val features = row.getAs[Vector](columnIndexes.featureId.get)
val label = row.getFloat(columnIndexes.labelId)
val weight = columnIndexes.weightId.map(row.getFloat).getOrElse(1.0f)
val baseMargin = columnIndexes.marginId.map(row.getFloat).getOrElse(Float.NaN)
val group = columnIndexes.groupId.map(row.getInt).getOrElse(-1)
// To make "0" meaningful, we convert sparse vector if possible to dense to create DMatrix.
features match {
case _: SparseVector => if (!isSetMissing) {
throw new IllegalArgumentException("We've detected sparse vectors in the dataset that " +
"need conversion to dense format. However, we can't assume 0 for missing values as " +
"it may be meaningful. Please specify the missing value explicitly to ensure " +
"accurate data representation for analysis.")
}
case _ =>
}
val values = features.toArray.map(_.toFloat)
XGBLabeledPoint(label, values.length, null, values, weight, group, baseMargin)
}
}
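// Missing-value note (illustrative): when the features column holds SparseVectors, the user
// must set the missing parameter explicitly (e.g. estimator.setMissing(0.0f), a hypothetical
// call shown only for illustration) before the vectors are densified here; otherwise the
// IllegalArgumentException above is thrown.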
/**
* Convert the dataframe to RDD[Watches], visible for testing
*
* @param dataset the input dataset
* @param columnIndices the indices of the label/features/weight/base margin/group columns
* @return RDD[Watches]
*/
private[spark] def toRdd(dataset: Dataset[_],
columnIndices: ColumnIndices): RDD[Watches] = {
val trainRDD = toXGBLabeledPoint(dataset, columnIndices)
val featureNames = if (getFeatureNames.isEmpty) None else Some(getFeatureNames)
val featureTypes = if (getFeatureTypes.isEmpty) None else Some(getFeatureTypes)
val missing = getMissing
// Transform the labeled points to collect margins/groups and build the DMatrix
// TODO: support base margin for multi-class classification
// TODO: as an optimization, move this into JNI
def buildDMatrix(iter: Iterator[XGBLabeledPoint]) = {
val dmatrix = if (columnIndices.marginId.isDefined || columnIndices.groupId.isDefined) {
val margins = new mutable.ArrayBuilder.ofFloat
val groups = new mutable.ArrayBuilder.ofInt
val groupWeights = new mutable.ArrayBuilder.ofFloat
var prevGroup = -101010
var prevWeight = -1.0f
var groupSize = 0
val transformedIter = iter.map { labeledPoint =>
if (columnIndices.marginId.isDefined) {
margins += labeledPoint.baseMargin
}
if (columnIndices.groupId.isDefined) {
if (prevGroup != labeledPoint.group) {
// starting with new group
if (prevGroup != -101010) {
// write the previous group
groups += groupSize
groupWeights += prevWeight
}
groupSize = 1
prevWeight = labeledPoint.weight
prevGroup = labeledPoint.group
} else {
// for the same group
if (prevWeight != labeledPoint.weight) {
throw new IllegalArgumentException("the instances in the same group have to be" +
s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
}
groupSize = groupSize + 1
}
}
labeledPoint
}
val dm = new DMatrix(transformedIter, null, missing)
columnIndices.marginId.foreach(_ => dm.setBaseMargin(margins.result()))
if (columnIndices.groupId.isDefined) {
if (prevGroup != -101011) {
// write the last group
groups += groupSize
groupWeights += prevWeight
}
dm.setGroup(groups.result())
// The DMatrix constructor sets a weight for each instance, but ranking requires
// one weight per group, so the weights need to be reset here.
// This could be optimized by moving the group/base margin setup into JNI.
dm.setWeight(groupWeights.result())
}
dm
} else {
new DMatrix(iter, null, missing)
}
featureTypes.foreach(dmatrix.setFeatureTypes)
featureNames.foreach(dmatrix.setFeatureNames)
dmatrix
}
getEvalDataset().map { eval =>
val (evalDf, _) = preprocess(eval)
val evalRDD = toXGBLabeledPoint(evalDf, columnIndices)
trainRDD.zipPartitions(evalRDD) { (left, right) =>
new Iterator[Watches] {
override def hasNext: Boolean = left.hasNext
override def next(): Watches = {
val trainDMatrix = buildDMatrix(left)
val evalDMatrix = buildDMatrix(right)
new Watches(Array(trainDMatrix, evalDMatrix),
Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)
}
}
}
}.getOrElse(
trainRDD.mapPartitions { iter =>
new Iterator[Watches] {
override def hasNext: Boolean = iter.hasNext
override def next(): Watches = {
val dm = buildDMatrix(iter)
new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
}
}
}
)
}
protected def createModel(booster: Booster, summary: XGBoostTrainingSummary): M
private[spark] def getRuntimeParameters(isLocal: Boolean): RuntimeParams = {
val runOnGpu = if (getDevice != "cpu" || getTreeMethod == "gpu_hist") true else false
RuntimeParams(
getNumWorkers,
getNumRound,
TrackerConf(getRabitTrackerTimeout, getRabitTrackerHostIp, getRabitTrackerPort),
getNumEarlyStoppingRounds,
getDevice,
isLocal,
runOnGpu,
Option(getCustomObj),
Option(getCustomEval)
)
}
/**
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
* If so, throw an exception unless this safety measure has been explicitly overridden
* via conf `xgboost.spark.ignoreSsl`.
*/
private def validateSparkSslConf(spark: SparkSession): Unit = {
val sparkSslEnabled = spark.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean
val xgbIgnoreSsl = spark.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean
if (sparkSslEnabled) {
if (xgbIgnoreSsl) {
logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
} else {
throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
"in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
"To override this protection and still use xgboost-spark at your own risk, " +
"you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
}
}
}
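// Illustrative override of the SSL check above (at the user's own risk), using the conf key
// validated in validateSparkSslConf:
//   spark.conf.set("xgboost.spark.ignoreSsl", "true")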
/**
* Validate the parameters before training, throwing an exception if any are invalid
*/
protected[spark] def validate(dataset: Dataset[_]): Unit = {
validateSparkSslConf(dataset.sparkSession)
val schema = dataset.schema
SparkUtils.checkNumericType(schema, $(labelCol))
if (isDefinedNonEmpty(weightCol)) {
SparkUtils.checkNumericType(schema, $(weightCol))
}
if (isDefinedNonEmpty(baseMarginCol)) {
SparkUtils.checkNumericType(schema, $(baseMarginCol))
}
val taskCpus = dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", 1)
if (isDefined(nthread)) {
require(getNthread <= taskCpus,
s"the nthread configuration ($getNthread) must be no larger than " +
s"spark.task.cpus ($taskCpus)")
} else {
setNthread(taskCpus)
}
}
protected def train(dataset: Dataset[_]): M = {
validate(dataset)
val rdd = if (isPluginEnabled(dataset)) {
getPlugin.get.buildRddWatches(this, dataset)
} else {
val (input, columnIndexes) = preprocess(dataset)
toRdd(input, columnIndexes)
}
val xgbParams = getXGBoostParams
val runtimeParams = getRuntimeParameters(dataset.sparkSession.sparkContext.isLocal)
val (booster, metrics) = XGBoost.train(rdd, runtimeParams, xgbParams)
val summary = XGBoostTrainingSummary(metrics)
copyValues(createModel(booster, summary))
}
override def copy(extra: ParamMap): Learner = defaultCopy(extra).asInstanceOf[Learner]
}
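// Summary of the train() flow above: validate(dataset) -> plugin.buildRddWatches or
// preprocess + toRdd to produce RDD[Watches] -> XGBoost.train on the executors ->
// createModel(booster, summary), with the estimator params copied onto the model.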
/**
* Indicates which outputs are to be predicted
*
* @param predLeaf whether to predict leaf indices
* @param predContrib whether to predict feature contributions
* @param predRaw whether to predict the raw margin
* @param predTmp whether to predict the transformed value: probability for classification, raw prediction for regression
*/
private[spark] case class PredictedColumns(
predLeaf: Boolean,
predContrib: Boolean,
predRaw: Boolean,
predTmp: Boolean)
/**
* XGBoost base model
*/
private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with MLWritable
with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] with PluginMixin {
protected val TMP_TRANSFORMED_COL = "_tmp_xgb_transformed_col"
override def copy(extra: ParamMap): M = defaultCopy(extra).asInstanceOf[M]
/**
* Get the native XGBoost Booster
*
* @return the native Booster
*/
def nativeBooster: Booster
def summary: Option[XGBoostTrainingSummary]
protected[spark] def postTransform(dataset: Dataset[_], pred: PredictedColumns): Dataset[_] = {
var output = dataset
// Convert leaf/contrib to the vector from array
if (pred.predLeaf) {
output = output.withColumn(getLeafPredictionCol,
array_to_vector(output.col(getLeafPredictionCol)))
}
if (pred.predContrib) {
output = output.withColumn(getContribPredictionCol,
array_to_vector(output.col(getContribPredictionCol)))
}
output
}
/**
* Preprocess the schema before transforming.
*
* @return the transformed schema and the columns to be predicted
*/
private[spark] def preprocess(dataset: Dataset[_]): (StructType, PredictedColumns) = {
// Be careful about the order of columns
var schema = dataset.schema
/** If the parameter is defined and non-empty, add it to the schema and return true */
def addToSchema(param: Param[String], colName: Option[String] = None): Boolean = {
if (isDefinedNonEmpty(param)) {
val name = colName.getOrElse($(param))
schema = schema.add(StructField(name, ArrayType(FloatType)))
true
} else {
false
}
}
val predLeaf = addToSchema(leafPredictionCol)
val predContrib = addToSchema(contribPredictionCol)
var predRaw = false
// For classification case, the transformed col is probability,
// while for others, it's the prediction value.
var predTmp = false
this match {
case p: XGBProbabilisticClassifierParams[_] => // classification case
predRaw = addToSchema(p.rawPredictionCol)
predTmp = addToSchema(p.probabilityCol, Some(TMP_TRANSFORMED_COL))
if (isDefinedNonEmpty(predictionCol)) {
// Let's use transformed col to calculate the prediction
if (!predTmp) {
// Add the transformed col for prediction
schema = schema.add(
StructField(TMP_TRANSFORMED_COL, ArrayType(FloatType)))
predTmp = true
}
}
case _ =>
// Rename TMP_TRANSFORMED_COL to prediction in the postTransform.
predTmp = addToSchema(predictionCol, Some(TMP_TRANSFORMED_COL))
}
(schema, PredictedColumns(predLeaf, predContrib, predRaw, predTmp))
}
/** Predict */
private[spark] def predictInternal(booster: Booster, dm: DMatrix, pred: PredictedColumns,
batchRow: Iterator[Row]): Seq[Row] = {
var tmpOut = batchRow.toSeq.map(_.toSeq)
val zip = (left: Seq[Seq[_]], right: Array[Array[Float]]) => left.zip(right).map {
case (a, b) => a ++ Seq(b)
}
if (pred.predLeaf) {
tmpOut = zip(tmpOut, booster.predictLeaf(dm))
}
if (pred.predContrib) {
tmpOut = zip(tmpOut, booster.predictContrib(dm))
}
if (pred.predRaw) {
tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = true))
}
if (pred.predTmp) {
tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = false))
}
tmpOut.map(Row.fromSeq)
}
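// The arrays appended above follow the same order in which preprocess() added the output
// columns to the schema: leaf predictions, contributions, raw margin, then the temporary
// probability/prediction column.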
override def transform(dataset: Dataset[_]): DataFrame = {
if (getPlugin.isDefined) {
return getPlugin.get.transform(this, dataset)
}
val (schema, pred) = preprocess(dataset)
// Broadcast the booster to each executor.
val bBooster = dataset.sparkSession.sparkContext.broadcast(nativeBooster)
// TODO: make the inference batch size configurable
val inferBatchSize = 32 << 10
val featureName = getFeaturesCol
val missing = getMissing
val output = dataset.toDF().mapPartitions { rowIter =>
rowIter.grouped(inferBatchSize).flatMap { batchRow =>
val features = batchRow.iterator.map(row => row.getAs[Vector](
row.fieldIndex(featureName)))
// DMatrix used for prediction
val dm = new DMatrix(features.map(_.asXGB), null, missing)
try {
predictInternal(bBooster.value, dm, pred, batchRow.toIterator)
} finally {
dm.delete()
}
}
}(Encoders.row(schema))
bBooster.unpersist(blocking = false)
postTransform(output, pred).toDF()
}
override def write: MLWriter = new XGBoostModelWriter(this)
protected def predictSingleInstance(features: Vector): Array[Float] = {
if (nativeBooster == null) {
throw new IllegalArgumentException("The model has not been trained")
}
val dm = new DMatrix(Iterator(features.asXGB), null, getMissing)
nativeBooster.predict(data = dm)(0)
}
}
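// Illustrative usage of a fitted model (names are examples): model.transform(df) appends the
// configured prediction columns in a single pass, scoring rows in batches of inferBatchSize
// (currently fixed at 32 << 10) per partition.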
/**
* Class to write the model
*
* @param instance model to be written
*/
private[spark] class XGBoostModelWriter(instance: XGBoostModel[_]) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
if (Option(instance.nativeBooster).isEmpty) {
throw new RuntimeException("The XGBoost model has not been trained")
}
SparkUtils.saveMetadata(instance, path, sc)
// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "model")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
val format = optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
try {
instance.nativeBooster.saveModel(outputStream, format)
} finally {
outputStream.close()
}
}
}
private[spark] abstract class XGBoostModelReader[M <: XGBoostModel[M]] extends MLReader[M] {
protected def loadBooster(path: String): Booster = {
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "model")
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
try {
SXGBoost.loadModel(dataInStream)
} finally {
dataInStream.close()
}
}
}
// Trait for Ranker and Regressor Model
private[spark] trait RankerRegressorBaseModel[M <: XGBoostModel[M]] extends XGBoostModel[M] {
override protected[spark] def postTransform(dataset: Dataset[_],
pred: PredictedColumns): Dataset[_] = {
var output = super.postTransform(dataset, pred)
if (isDefinedNonEmpty(predictionCol) && pred.predTmp) {
val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
originalPrediction(0).toDouble
}
output = output
.withColumn($(predictionCol), predictUDF(col(TMP_TRANSFORMED_COL)))
.drop(TMP_TRANSFORMED_COL)
}
output
}
}

View File

@ -0,0 +1,49 @@
/*
Copyright (c) 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.Serializable
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
trait XGBoostPlugin extends Serializable {
/**
* Whether the plugin is enabled or not; if not enabled, fall back
* to the regular CPU pipeline
*
* @param dataset the input dataset
* @return Boolean
*/
def isEnabled(dataset: Dataset[_]): Boolean
/**
* Convert Dataset to RDD[Watches] which will be fed into XGBoost
*
* @param estimator the estimator to be handled.
* @param dataset the dataset to be converted.
* @return RDD[Watches]
*/
def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
estimator: XGBoostEstimator[T, M],
dataset: Dataset[_]): RDD[Watches]
/**
* Transform the dataset
*/
def transform[M <: XGBoostModel[M]](model: XGBoostModel[M], dataset: Dataset[_]): DataFrame
}
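// A concrete implementation (e.g. GpuXGBoostPlugin, referenced from PluginMixin) is expected
// to provide these three methods and register itself for ServiceLoader discovery; when no
// enabled plugin is found, the CPU code paths in XGBoostEstimator/XGBoostModel are used.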

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,405 +16,90 @@
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.{Iterator, mutable}
import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util._
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
import org.apache.spark.ml.xgboost.SparkUtils
import org.apache.spark.sql.Dataset
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
import org.apache.spark.sql.types.StructType
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor._uid
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.REGRESSION_OBJS
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
class XGBoostRegressor (
override val uid: String,
class XGBoostRegressor(override val uid: String,
private val xgboostParams: Map[String, Any])
extends Predictor[Vector, XGBoostRegressor, XGBoostRegressionModel]
with XGBoostRegressorParams with DefaultParamsWritable {
with XGBoostEstimator[XGBoostRegressor, XGBoostRegressionModel] {
def this() = this(Identifiable.randomUID("xgbr"), Map[String, Any]())
def this() = this(_uid, Map[String, Any]())
def this(uid: String) = this(uid, Map[String, Any]())
def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbr"), xgboostParams)
def this(xgboostParams: Map[String, Any]) = this(_uid, xgboostParams)
XGBoost2MLlibParams(xgboostParams)
xgboost2SparkParams(xgboostParams)
def setWeightCol(value: String): this.type = set(weightCol, value)
/**
* Validate the parameters before training, throwing an exception if any are invalid
*/
override protected[spark] def validate(dataset: Dataset[_]): Unit = {
super.validate(dataset)
def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
def setGroupCol(value: String): this.type = set(groupCol, value)
// setters for general params
def setNumRound(value: Int): this.type = set(numRound, value)
def setNumWorkers(value: Int): this.type = set(numWorkers, value)
def setNthread(value: Int): this.type = set(nthread, value)
def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)
def setSilent(value: Int): this.type = set(silent, value)
def setMissing(value: Float): this.type = set(missing, value)
def setCheckpointPath(value: String): this.type = set(checkpointPath, value)
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
def setSeed(value: Long): this.type = set(seed, value)
def setEta(value: Double): this.type = set(eta, value)
def setGamma(value: Double): this.type = set(gamma, value)
def setMaxDepth(value: Int): this.type = set(maxDepth, value)
def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)
def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)
def setSubsample(value: Double): this.type = set(subsample, value)
def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)
def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)
def setLambda(value: Double): this.type = set(lambda, value)
def setAlpha(value: Double): this.type = set(alpha, value)
def setTreeMethod(value: String): this.type = set(treeMethod, value)
def setDevice(value: String): this.type = set(device, value)
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
def setMaxBins(value: Int): this.type = set(maxBins, value)
def setMaxLeaves(value: Int): this.type = set(maxLeaves, value)
def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)
def setSampleType(value: String): this.type = set(sampleType, value)
def setNormalizeType(value: String): this.type = set(normalizeType, value)
def setRateDrop(value: Double): this.type = set(rateDrop, value)
def setSkipDrop(value: Double): this.type = set(skipDrop, value)
def setLambdaBias(value: Double): this.type = set(lambdaBias, value)
// setters for learning params
def setObjective(value: String): this.type = set(objective, value)
def setObjectiveType(value: String): this.type = set(objectiveType, value)
def setBaseScore(value: Double): this.type = set(baseScore, value)
def setEvalMetric(value: String): this.type = set(evalMetric, value)
def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
def setMaximizeEvaluationMetrics(value: Boolean): this.type =
set(maximizeEvaluationMetrics, value)
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
def setAllowNonZeroForMissing(value: Boolean): this.type = set(
allowNonZeroForMissing,
value
)
def setSinglePrecisionHistogram(value: Boolean): this.type =
set(singlePrecisionHistogram, value)
def setFeatureNames(value: Array[String]): this.type =
set(featureNames, value)
def setFeatureTypes(value: Array[String]): this.type =
set(featureTypes, value)
// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
if ($(objective).startsWith("rank")) {
"map"
} else {
"rmse"
// If the objective is set explicitly, it must be in REGRESSION_OBJS
if (isSet(objective)) {
val tmpObj = getObjective
require(REGRESSION_OBJS.contains(tmpObj),
s"Wrong objective for XGBoostRegressor, supported objs: ${REGRESSION_OBJS.mkString(",")}")
}
}
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
if (isFeaturesColSet(schema)) {
// User has vectorized the features into VectorUDT.
super.transformSchema(schema)
} else {
transformSchemaWithFeaturesCols(false, schema)
}
override protected def createModel(
booster: Booster,
summary: XGBoostTrainingSummary): XGBoostRegressionModel = {
new XGBoostRegressionModel(uid, booster, Option(summary))
}
override def transformSchema(schema: StructType): StructType = {
PreXGBoost.transformSchema(this, schema)
}
override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = {
if (!isDefined(objective)) {
// If user doesn't set objective, force it to reg:squarederror
setObjective("reg:squarederror")
}
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
set(evalMetric, setupDefaultEvalMetric())
}
if (isDefined(customObj) && $(customObj) != null) {
set(objectiveType, "regression")
}
transformSchema(dataset.schema, logging = true)
// Packing with all params plus params user defined
val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)
// All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(dataset.sparkSession.sparkContext,
buildTrainingData, derivedXGBParamMap)
val model = new XGBoostRegressionModel(uid, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
model
}
override def copy(extra: ParamMap): XGBoostRegressor = defaultCopy(extra)
override protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType =
SparkUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
object XGBoostRegressor extends DefaultParamsReadable[XGBoostRegressor] {
override def load(path: String): XGBoostRegressor = super.load(path)
private val _uid = Identifiable.randomUID("xgbr")
}
class XGBoostRegressionModel private[ml] (
override val uid: String,
private[scala] val _booster: Booster)
class XGBoostRegressionModel private[ml](val uid: String,
val nativeBooster: Booster,
val summary: Option[XGBoostTrainingSummary] = None)
extends PredictionModel[Vector, XGBoostRegressionModel]
with XGBoostRegressorParams with InferenceParams
with MLWritable with Serializable {
with RankerRegressorBaseModel[XGBoostRegressionModel] {
import XGBoostRegressionModel._
// only called in copy()
def this(uid: String) = this(uid, null)
/**
* Get the native booster instance of this model.
* This is used to call low-level APIs on native booster, such as "getFeatureScore".
*/
def nativeBooster: Booster = _booster
private var trainingSummary: Option[XGBoostTrainingSummary] = None
/**
* Returns summary (e.g. train/test objective history) of model on the
* training set. An exception is thrown if no summary is available.
*/
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
throw new IllegalStateException("No training summary available for this XGBoostModel")
}
private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
trainingSummary = Some(summary)
this
}
def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)
def setTreeLimit(value: Int): this.type = set(treeLimit, value)
def setMissing(value: Float): this.type = set(missing, value)
def setAllowNonZeroForMissing(value: Boolean): this.type = set(
allowNonZeroForMissing,
value
)
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
/**
* Single instance prediction.
* Note: The performance is not ideal, use it carefully!
*/
override def predict(features: Vector): Double = {
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val dm = new DMatrix(processMissingValues(
Iterator(features.asXGB),
$(missing),
$(allowNonZeroForMissing)
))
_booster.predict(data = dm)(0)(0)
}
private[scala] def produceResultIterator(
originalRowItr: Iterator[Row],
predictionItr: Iterator[Row],
predLeafItr: Iterator[Row],
predContribItr: Iterator[Row]): Iterator[Row] = {
// the following implementation is to be improved
if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
originalRowItr.zip(predictionItr).zip(predLeafItr).zip(predContribItr).
map { case (((originals: Row, prediction: Row), leaves: Row), contribs: Row) =>
Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq ++ contribs.toSeq)
}
} else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
(!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
originalRowItr.zip(predictionItr).zip(predLeafItr).
map { case ((originals: Row, prediction: Row), leaves: Row) =>
Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq)
}
} else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
originalRowItr.zip(predictionItr).zip(predContribItr).
map { case ((originals: Row, prediction: Row), contribs: Row) =>
Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ contribs.toSeq)
}
} else {
originalRowItr.zip(predictionItr).map {
case (originals: Row, originalPrediction: Row) =>
Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
}
}
}
private[scala] def producePredictionItrs(booster: Booster, dm: DMatrix):
Array[Iterator[Row]] = {
val originalPredictionItr = {
booster.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
}
val predLeafItr = {
if (isDefined(leafPredictionCol)) {
booster.predictLeaf(dm, $(treeLimit)).
map(Row(_)).iterator
} else {
Iterator()
}
}
val predContribItr = {
if (isDefined(contribPredictionCol)) {
booster.predictContrib(dm, $(treeLimit)).
map(Row(_)).iterator
} else {
Iterator()
}
}
Array(originalPredictionItr, predLeafItr, predContribItr)
}
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
if (isFeaturesColSet(schema)) {
// User has vectorized the features into VectorUDT.
super.transformSchema(schema)
} else {
transformSchemaWithFeaturesCols(false, schema)
}
}
override def transformSchema(schema: StructType): StructType = {
PreXGBoost.transformSchema(this, schema)
}
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = PreXGBoost.transformDataset(this, dataset)
var numColsOutput = 0
val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
originalPrediction(0).toDouble
}
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_originalPredictionCol)))
numColsOutput += 1
}
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
outputData.toDF.drop(col(_originalPredictionCol))
}
override def copy(extra: ParamMap): XGBoostRegressionModel = {
val newModel = copyValues(new XGBoostRegressionModel(uid, _booster), extra)
newModel.setSummary(summary).setParent(parent)
val newModel = copyValues(new XGBoostRegressionModel(uid, nativeBooster, summary), extra)
newModel.setParent(parent)
}
override def write: MLWriter =
new XGBoostRegressionModel.XGBoostRegressionModelWriter(this)
override def predict(features: Vector): Double = {
val values = predictSingleInstance(features)
values(0)
}
}
object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
override def read: MLReader[XGBoostRegressionModel] = new ModelReader
private[scala] val _originalPredictionCol = "_originalPrediction"
override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader
override def load(path: String): XGBoostRegressionModel = super.load(path)
private[XGBoostRegressionModel]
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends XGBoostWriter {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
instance._booster.saveModel(outputStream, getModelFormat())
outputStream.close()
}
}
private class XGBoostRegressionModelReader extends MLReader[XGBoostRegressionModel] {
/** Checked against metadata when loading model */
private val className = classOf[XGBoostRegressionModel].getName
private class ModelReader extends XGBoostModelReader[XGBoostRegressionModel] {
override def load(path: String): XGBoostRegressionModel = {
implicit val sc = super.sparkSession.sparkContext
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
val booster = SXGBoost.loadModel(dataInStream)
val model = new XGBoostRegressionModel(metadata.uid, booster)
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
val xgbModel = loadBooster(path)
val meta = SparkUtils.loadMetadata(path, sc)
val model = new XGBoostRegressionModel(meta.uid, xgbModel, None)
meta.getAndSetParams(model)
model
}
}

View File

@ -22,17 +22,17 @@ class XGBoostTrainingSummary private(
override def toString: String = {
val train = trainObjectiveHistory.mkString(",")
val vaidationObjectiveHistoryString = {
val validationObjectiveHistoryString = {
validationObjectiveHistory.map {
case (name, metrics) =>
s"${name}ObjectiveHistory=${metrics.mkString(",")}"
}.mkString(";")
}
s"XGBoostTrainingSummary(trainObjectiveHistory=$train; $vaidationObjectiveHistoryString)"
s"XGBoostTrainingSummary(trainObjectiveHistory=$train; $validationObjectiveHistoryString)"
}
}
private[xgboost4j] object XGBoostTrainingSummary {
private[spark] object XGBoostTrainingSummary {
def apply(metrics: Map[String, Array[Float]]): XGBoostTrainingSummary = {
new XGBoostTrainingSummary(
trainObjectiveHistory = metrics("train"),

View File

@ -1,295 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import scala.collection.immutable.HashSet
import org.apache.spark.ml.param.{DoubleParam, IntParam, BooleanParam, Param, Params}
private[spark] trait BoosterParams extends Params {
/**
* step size shrinkage used in update to prevents overfitting. After each boosting step, we
* can directly get the weights of new features and eta actually shrinks the feature weights
* to make the boosting process more conservative. [default=0.3] range: [0,1]
*/
final val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevents" +
" overfitting. After each boosting step, we can directly get the weights of new features." +
" and eta actually shrinks the feature weights to make the boosting process more conservative.",
(value: Double) => value >= 0 && value <= 1)
final def getEta: Double = $(eta)
/**
* minimum loss reduction required to make a further partition on a leaf node of the tree.
* the larger, the more conservative the algorithm will be. [default=0] range: [0,
* Double.MaxValue]
*/
final val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a " +
"further partition on a leaf node of the tree. the larger, the more conservative the " +
"algorithm will be.", (value: Double) => value >= 0)
final def getGamma: Double = $(gamma)
/**
* maximum depth of a tree, increase this value will make model more complex / likely to be
* overfitting. [default=6] range: [1, Int.MaxValue]
*/
final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree, increase this " +
"value will make model more complex/likely to be overfitting.", (value: Int) => value >= 0)
final def getMaxDepth: Int = $(maxDepth)
/**
* Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.
*/
final val maxLeaves = new IntParam(this, "maxLeaves",
"Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.",
(value: Int) => value >= 0)
final def getMaxLeaves: Int = $(maxLeaves)
/**
* minimum sum of instance weight(hessian) needed in a child. If the tree partition step results
* in a leaf node with the sum of instance weight less than min_child_weight, then the building
* process will give up further partitioning. In linear regression mode, this simply corresponds
* to minimum number of instances needed to be in each node. The larger, the more conservative
* the algorithm will be. [default=1] range: [0, Double.MaxValue]
*/
final val minChildWeight = new DoubleParam(this, "minChildWeight", "minimum sum of instance" +
" weight(hessian) needed in a child. If the tree partition step results in a leaf node with" +
" the sum of instance weight less than min_child_weight, then the building process will" +
" give up further partitioning. In linear regression mode, this simply corresponds to minimum" +
" number of instances needed to be in each node. The larger, the more conservative" +
" the algorithm will be.", (value: Double) => value >= 0)
final def getMinChildWeight: Double = $(minChildWeight)
/**
* Maximum delta step we allow each tree's weight estimation to be. If the value is set to 0, it
* means there is no constraint. If it is set to a positive value, it can help making the update
* step more conservative. Usually this parameter is not needed, but it might help in logistic
* regression when class is extremely imbalanced. Set it to value of 1-10 might help control the
* update. [default=0] range: [0, Double.MaxValue]
*/
final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep", "Maximum delta step we allow " +
"each tree's weight" +
" estimation to be. If the value is set to 0, it means there is no constraint. If it is set" +
" to a positive value, it can help making the update step more conservative. Usually this" +
" parameter is not needed, but it might help in logistic regression when class is extremely" +
" imbalanced. Set it to value of 1-10 might help control the update",
(value: Double) => value >= 0)
final def getMaxDeltaStep: Double = $(maxDeltaStep)
/**
* subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly
* collected half of the data instances to grow trees and this will prevent overfitting.
* [default=1] range:(0,1]
*/
final val subsample = new DoubleParam(this, "subsample", "subsample ratio of the training " +
"instance. Setting it to 0.5 means that XGBoost randomly collected half of the data " +
"instances to grow trees and this will prevent overfitting.",
(value: Double) => value <= 1 && value > 0)
final def getSubsample: Double = $(subsample)
/**
* subsample ratio of columns when constructing each tree. [default=1] range: (0,1]
*/
final val colsampleBytree = new DoubleParam(this, "colsampleBytree", "subsample ratio of " +
"columns when constructing each tree.", (value: Double) => value <= 1 && value > 0)
final def getColsampleBytree: Double = $(colsampleBytree)
/**
* subsample ratio of columns for each split, in each level. [default=1] range: (0,1]
*/
final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel", "subsample ratio of " +
"columns for each split, in each level.", (value: Double) => value <= 1 && value > 0)
final def getColsampleBylevel: Double = $(colsampleBylevel)
/**
* L2 regularization term on weights, increase this value will make model more conservative.
* [default=1]
*/
final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, " +
"increase this value will make model more conservative.", (value: Double) => value >= 0)
final def getLambda: Double = $(lambda)
/**
* L1 regularization term on weights, increase this value will make model more conservative.
* [default=0]
*/
final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase " +
"this value will make model more conservative.", (value: Double) => value >= 0)
final def getAlpha: Double = $(alpha)
/**
* The tree construction algorithm used in XGBoost. options:
* {'auto', 'exact', 'approx','gpu_hist'} [default='auto']
*/
final val treeMethod = new Param[String](this, "treeMethod",
"The tree construction algorithm used in XGBoost, options: " +
"{'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
final def getTreeMethod: String = $(treeMethod)
/**
* The device for running XGBoost algorithms, options: cpu, cuda
*/
final val device = new Param[String](
this, "device", "The device for running XGBoost algorithms, options: cpu, cuda",
(value: String) => BoosterParams.supportedDevices.contains(value)
)
final def getDevice: String = $(device)
/**
* growth policy for fast histogram algorithm
*/
final val growPolicy = new Param[String](this, "growPolicy",
"Controls a way new nodes are added to the tree. Currently supported only if" +
" tree_method is set to hist. Choices: depthwise, lossguide. depthwise: split at nodes" +
" closest to the root. lossguide: split at nodes with highest loss change.",
(value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
final def getGrowPolicy: String = $(growPolicy)
/**
* maximum number of bins in histogram
*/
final val maxBins = new IntParam(this, "maxBin", "maximum number of bins in histogram",
(value: Int) => value > 0)
final def getMaxBins: Int = $(maxBins)
/**
* whether to build histograms using single precision floating point values
*/
final val singlePrecisionHistogram = new BooleanParam(this, "singlePrecisionHistogram",
"whether to use single precision to build histograms")
final def getSinglePrecisionHistogram: Boolean = $(singlePrecisionHistogram)
/**
* Control the balance of positive and negative weights, useful for unbalanced classes. A typical
* value to consider: sum(negative cases) / sum(positive cases). [default=1]
*/
final val scalePosWeight = new DoubleParam(this, "scalePosWeight", "Control the balance of " +
"positive and negative weights, useful for unbalanced classes. A typical value to consider:" +
" sum(negative cases) / sum(positive cases)")
final def getScalePosWeight: Double = $(scalePosWeight)
// Dart boosters
/**
* Parameter for Dart booster.
* Type of sampling algorithm. "uniform": dropped trees are selected uniformly.
* "weighted": dropped trees are selected in proportion to weight. [default="uniform"]
*/
final val sampleType = new Param[String](this, "sampleType", "type of sampling algorithm, " +
"options: {'uniform', 'weighted'}",
(value: String) => BoosterParams.supportedSampleType.contains(value))
final def getSampleType: String = $(sampleType)
/**
* Parameter of Dart booster.
* type of normalization algorithm, options: {'tree', 'forest'}. [default="tree"]
*/
final val normalizeType = new Param[String](this, "normalizeType", "type of normalization" +
" algorithm, options: {'tree', 'forest'}",
(value: String) => BoosterParams.supportedNormalizeType.contains(value))
final def getNormalizeType: String = $(normalizeType)
/**
* Parameter of Dart booster.
* dropout rate. [default=0.0] range: [0.0, 1.0]
*/
final val rateDrop = new DoubleParam(this, "rateDrop", "dropout rate", (value: Double) =>
value >= 0 && value <= 1)
final def getRateDrop: Double = $(rateDrop)
/**
* Parameter of Dart booster.
* probability of skip dropout. If a dropout is skipped, new trees are added in the same manner
* as gbtree. [default=0.0] range: [0.0, 1.0]
*/
final val skipDrop = new DoubleParam(this, "skipDrop", "probability of skip dropout. If" +
" a dropout is skipped, new trees are added in the same manner as gbtree.",
(value: Double) => value >= 0 && value <= 1)
final def getSkipDrop: Double = $(skipDrop)
// linear booster
/**
* Parameter of linear booster
* L2 regularization term on bias, default 0(no L1 reg on bias because it is not important)
*/
final val lambdaBias = new DoubleParam(this, "lambdaBias", "L2 regularization term on bias, " +
"default 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)
final def getLambdaBias: Double = $(lambdaBias)
final val treeLimit = new IntParam(this, name = "treeLimit",
doc = "number of trees used in the prediction; defaults to 0 (use all trees).")
setDefault(treeLimit, 0)
final def getTreeLimit: Int = $(treeLimit)
final val monotoneConstraints = new Param[String](this, name = "monotoneConstraints",
doc = "a list in length of number of features, 1 indicate monotonic increasing, - 1 means " +
"decreasing, 0 means no constraint. If it is shorter than number of features, 0 will be " +
"padded ")
final def getMonotoneConstraints: String = $(monotoneConstraints)
final val interactionConstraints = new Param[String](this,
name = "interactionConstraints",
doc = "Constraints for interaction representing permitted interactions. The constraints" +
" must be specified in the form of a nest list, e.g. [[0, 1], [2, 3, 4]]," +
" where each inner list is a group of indices of features that are allowed to interact" +
" with each other. See tutorial for more information")
final def getInteractionConstraints: String = $(interactionConstraints)
}
private[scala] object BoosterParams {
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")
val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
val supportedSampleType = HashSet("uniform", "weighted")
val supportedNormalizeType = HashSet("tree", "forest")
val supportedDevices = HashSet("cpu", "cuda")
}

View File

@ -16,20 +16,18 @@
package ml.dmlc.xgboost4j.scala.spark.params
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import ml.dmlc.xgboost4j.scala.spark.util.Utils
import org.apache.spark.ml.param.{Param, ParamPair, Params}
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
import org.json4s.{DefaultFormats, Extraction}
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.json4s.jackson.Serialization
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.Utils
/**
* General spark parameter that includes TypeHints for (de)serialization using json4s.
*/
class CustomGeneralParam[T: Manifest](
parent: Params,
class CustomGeneralParam[T: Manifest](parent: Params,
name: String,
doc: String) extends Param[T](parent, name, doc) {
@ -52,33 +50,10 @@ class CustomGeneralParam[T: Manifest](
}
}
class CustomEvalParam(
parent: Params,
class CustomEvalParam(parent: Params,
name: String,
doc: String) extends CustomGeneralParam[EvalTrait](parent, name, doc)
class CustomObjParam(
parent: Params,
class CustomObjParam(parent: Params,
name: String,
doc: String) extends CustomGeneralParam[ObjectiveTrait](parent, name, doc)
class TrackerConfParam(
parent: Params,
name: String,
doc: String) extends Param[TrackerConf](parent, name, doc) {
/** Creates a param pair with the given value (for Java). */
override def w(value: TrackerConf): ParamPair[TrackerConf] = super.w(value)
override def jsonEncode(value: TrackerConf): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}
override def jsonDecode(json: String): TrackerConf = {
implicit val formats = DefaultFormats
val parsedValue = parse(json)
parsedValue.extract[TrackerConf]
}
}

View File

@ -0,0 +1,61 @@
/*
Copyright (c) 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.param._
/**
* Dart booster parameters, more details can be found at
* https://xgboost.readthedocs.io/en/stable/parameter.html#
* additional-parameters-for-dart-booster-booster-dart
*/
private[spark] trait DartBoosterParams extends Params {
final val sampleType = new Param[String](this, "sample_type", "Type of sampling algorithm, " +
"options: {'uniform', 'weighted'}", ParamValidators.inArray(Array("uniform", "weighted")))
final def getSampleType: String = $(sampleType)
final val normalizeType = new Param[String](this, "normalize_type", "type of normalization" +
" algorithm, options: {'tree', 'forest'}",
ParamValidators.inArray(Array("tree", "forest")))
final def getNormalizeType: String = $(normalizeType)
final val rateDrop = new DoubleParam(this, "rate_drop", "Dropout rate (a fraction of previous " +
"trees to drop during the dropout)",
ParamValidators.inRange(0, 1, true, true))
final def getRateDrop: Double = $(rateDrop)
final val oneDrop = new BooleanParam(this, "one_drop", "When this flag is enabled, at least " +
"one tree is always dropped during the dropout (allows Binomial-plus-one or epsilon-dropout " +
"from the original DART paper)")
final def getOneDrop: Boolean = $(oneDrop)
final val skipDrop = new DoubleParam(this, "skip_drop", "Probability of skipping the dropout " +
"procedure during a boosting iteration.\nIf a dropout is skipped, new trees are added " +
"in the same manner as gbtree.\nNote that non-zero skip_drop has higher priority than " +
"rate_drop or one_drop.",
ParamValidators.inRange(0, 1, true, true))
final def getSkipDrop: Double = $(skipDrop)
setDefault(sampleType -> "uniform", normalizeType -> "tree", rateDrop -> 0, skipDrop -> 0)
}
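The DART parameters above are normally configured through the estimator setters defined later in XGBoostParams. A minimal sketch, assuming an XGBoostClassifier estimator that mixes in these params; the values are illustrative only:

.. code-block:: scala

    // Sketch only, not part of this patch: assumes XGBoostClassifier exposes the
    // DartBoosterParams setters defined in XGBoostParams.
    val dartClassifier = new XGBoostClassifier()
      .setBooster("dart")         // enable the DART booster
      .setSampleType("weighted")  // dropped trees selected in proportion to weight
      .setNormalizeType("forest")
      .setRateDrop(0.1)           // fraction of previous trees to drop per iteration
      .setSkipDrop(0.5)           // probability of skipping the dropout procedure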

View File

@ -16,303 +16,45 @@
package ml.dmlc.xgboost4j.scala.spark.params
import com.google.common.base.CaseFormat
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.apache.spark.ml.param._
import scala.collection.mutable
/**
* General xgboost parameters, more details can be found
* at https://xgboost.readthedocs.io/en/stable/parameter.html#general-parameters
*/
private[spark] trait GeneralParams extends Params {
/**
* The number of rounds for boosting
*/
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
ParamValidators.gtEq(1))
setDefault(numRound, 1)
final val booster = new Param[String](this, "booster", "Which booster to use. Can be gbtree, " +
"gblinear or dart; gbtree and dart use tree based models while gblinear uses linear " +
"functions.", ParamValidators.inArray(Array("gbtree", "dart")))
final def getNumRound: Int = $(numRound)
final def getBooster: String = $(booster)
/**
* number of workers used to train xgboost model. default: 1
*/
final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
ParamValidators.gtEq(1))
setDefault(numWorkers, 1)
final val device = new Param[String](this, "device", "Device for XGBoost to run. User can " +
"set it to one of the following values: {cpu, cuda, gpu}",
ParamValidators.inArray(Array("cpu", "cuda", "gpu")))
final def getNumWorkers: Int = $(numWorkers)
final def getDevice: String = $(device)
/**
* number of threads used by per worker. default 1
*/
final val nthread = new IntParam(this, "nthread", "number of threads used by per worker",
ParamValidators.gtEq(1))
setDefault(nthread, 1)
final def getNthread: Int = $(nthread)
/**
* whether to use external memory as cache. default: false
*/
final val useExternalMemory = new BooleanParam(this, "useExternalMemory",
"whether to use external memory as cache")
setDefault(useExternalMemory, false)
final def getUseExternalMemory: Boolean = $(useExternalMemory)
/**
* Deprecated. Please use verbosity instead.
* 0 means printing running messages, 1 means silent mode. default: 0
*/
final val silent = new IntParam(this, "silent",
"Deprecated. Please use verbosity instead. " +
"0 means printing running messages, 1 means silent mode.",
(value: Int) => value >= 0 && value <= 1)
final def getSilent: Int = $(silent)
/**
* Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), 2 (info), 3 (debug).
* default: 1
*/
final val verbosity = new IntParam(this, "verbosity",
"Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), 2 (info), " +
"3 (debug).",
(value: Int) => value >= 0 && value <= 3)
final val verbosity = new IntParam(this, "verbosity", "Verbosity of printing messages. Valid " +
"values are 0 (silent), 1 (warning), 2 (info), 3 (debug). Sometimes XGBoost tries to change " +
"configurations based on heuristics, which is displayed as warning message. If there's " +
"unexpected behaviour, please try to increase value of verbosity.",
ParamValidators.inRange(0, 3, true, true))
final def getVerbosity: Int = $(verbosity)
/**
* customized objective function provided by user. default: null
*/
final val customObj = new CustomObjParam(this, "customObj", "customized objective function " +
"provided by user")
final val validateParameters = new BooleanParam(this, "validate_parameters", "When set to " +
"True, XGBoost will perform validation of input parameters to check whether a parameter " +
"is used or not. A warning is emitted when there's unknown parameter.")
/**
* customized evaluation function provided by user. default: null
*/
final val customEval = new CustomEvalParam(this, "customEval",
"customized evaluation function provided by user")
final def getValidateParameters: Boolean = $(validateParameters)
/**
* the value treated as missing. default: Float.NaN
*/
final val missing = new FloatParam(this, "missing", "the value treated as missing")
setDefault(missing, Float.NaN)
final val nthread = new IntParam(this, "nthread", "Number of threads used by per worker",
ParamValidators.gtEq(1))
final def getMissing: Float = $(missing)
final def getNthread: Int = $(nthread)
/**
* Allows for having a non-zero value for missing when training or predicting
* on a Sparse or Empty vector.
*/
final val allowNonZeroForMissing = new BooleanParam(
this,
"allowNonZeroForMissing",
"Allow to have a non-zero value for missing when training or " +
"predicting on a Sparse or Empty vector. Should only be used if did " +
"not use Spark's VectorAssembler class to construct the feature vector " +
"but instead used a method that preserves zeros in your vector."
)
setDefault(allowNonZeroForMissing, false)
final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissing)
/**
* The hdfs folder to load and save checkpoint boosters. default: `empty_string`
*/
final val checkpointPath = new Param[String](this, "checkpointPath", "the hdfs folder to load " +
"and save checkpoints. If there are existing checkpoints in checkpoint_path. The job will " +
"load the checkpoint with highest version as the starting point for training. If " +
"checkpoint_interval is also set, the job will save a checkpoint every a few rounds.")
final def getCheckpointPath: String = $(checkpointPath)
/**
* Param for set checkpoint interval (&gt;= 1) or disable checkpoint (-1). E.g. 10 means that
* the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path` must
* also be set if the checkpoint interval is greater than 0.
*/
final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval",
"set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained " +
"model will get checkpointed every 10 iterations. Note: `checkpoint_path` must also be " +
"set if the checkpoint interval is greater than 0.",
(interval: Int) => interval == -1 || interval >= 1)
final def getCheckpointInterval: Int = $(checkpointInterval)
/**
* Rabit tracker configurations. The parameter must be provided as an instance of the
* TrackerConf class, which has the following definition:
*
* case class TrackerConf(timeout: Int, hostIp: String, port: Int)
*
* See below for detailed explanations.
*
* - timeout : The maximum wait time for all workers to connect to the tracker. (in seconds)
* default: 0 (no timeout)
*
* Timeout for constructing the communication group and waiting for the tracker to
* shutdown when it's instructed to, doesn't apply to communication when tracking
* is running.
* The timeout value should take the time of data loading and pre-processing into account,
* due to potential lazy execution. Alternatively, you may force Spark to
* perform data transformation before calling XGBoost.train(), so that this timeout truly
* reflects the connection delay. Set a reasonable timeout value to prevent model
* training/testing from hanging indefinitely, possibly due to network issues.
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
*
* - hostIp : The Rabit Tracker host IP address. This is only needed if the host IP
* cannot be automatically guessed.
*
* - port : The port number for the tracker to listen to. Use a system allocated one by
* default.
*/
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
setDefault(trackerConf, TrackerConf())
/** Random seed for the C++ part of XGBoost and train/test splitting. */
final val seed = new LongParam(this, "seed", "random seed")
setDefault(seed, 0L)
final def getSeed: Long = $(seed)
/** Feature's name, it will be set to DMatrix and Booster, and in the final native json model.
* In native code, the parameter name is feature_name.
* */
final val featureNames = new StringArrayParam(this, "feature_names",
"an array of feature names")
final def getFeatureNames: Array[String] = $(featureNames)
/** Feature types, q is numeric and c is categorical.
* In native code, the parameter name is feature_type
* */
final val featureTypes = new StringArrayParam(this, "feature_types",
"an array of feature types")
final def getFeatureTypes: Array[String] = $(featureTypes)
}
trait HasLeafPredictionCol extends Params {
/**
* Param for leaf prediction column name.
* @group param
*/
final val leafPredictionCol: Param[String] = new Param[String](this, "leafPredictionCol",
"name of the predictLeaf results")
/** @group getParam */
final def getLeafPredictionCol: String = $(leafPredictionCol)
}
trait HasContribPredictionCol extends Params {
/**
* Param for contribution prediction column name.
* @group param
*/
final val contribPredictionCol: Param[String] = new Param[String](this, "contribPredictionCol",
"name of the predictContrib results")
/** @group getParam */
final def getContribPredictionCol: String = $(contribPredictionCol)
}
trait HasBaseMarginCol extends Params {
/**
* Param for initial prediction (aka base margin) column name.
* @group param
*/
final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol",
"Initial prediction (aka base margin) column name.")
/** @group getParam */
final def getBaseMarginCol: String = $(baseMarginCol)
}
trait HasGroupCol extends Params {
/**
* Param for group column name.
* @group param
*/
final val groupCol: Param[String] = new Param[String](this, "groupCol", "group column name.")
/** @group getParam */
final def getGroupCol: String = $(groupCol)
}
trait HasNumClass extends Params {
/**
* number of classes
*/
final val numClass = new IntParam(this, "numClass", "number of classes")
/** @group getParam */
final def getNumClass: Int = $(numClass)
}
/**
* Trait for shared param featuresCols.
*/
trait HasFeaturesCols extends Params {
/**
* Param for the names of feature columns.
* @group param
*/
final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
"an array of feature column names.")
/** @group getParam */
final def getFeaturesCols: Array[String] = $(featuresCols)
/** Check if featuresCols is valid */
def isFeaturesColsValid: Boolean = {
isDefined(featuresCols) && $(featuresCols) != Array.empty
}
}
private[spark] trait ParamMapFuncs extends Params {
def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
for ((paramName, paramValue) <- xgboostParams) {
if ((paramName == "booster" && paramValue != "gbtree") ||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker or" +
s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
}
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
params.find(_.name == name).foreach {
case _: DoubleParam =>
set(name, paramValue.toString.toDouble)
case _: BooleanParam =>
set(name, paramValue.toString.toBoolean)
case _: IntParam =>
set(name, paramValue.toString.toInt)
case _: FloatParam =>
set(name, paramValue.toString.toFloat)
case _: LongParam =>
set(name, paramValue.toString.toLong)
case _: Param[_] =>
set(name, paramValue)
}
}
}
def MLlib2XGBoostParams: Map[String, Any] = {
val xgboostParams = new mutable.HashMap[String, Any]()
for (param <- params) {
if (isDefined(param)) {
val name = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, param.name)
xgboostParams += name -> $(param)
}
}
xgboostParams.toMap
}
setDefault(booster -> "gbtree", device -> "cpu", verbosity -> 1, validateParameters -> false,
nthread -> 1)
}
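The general parameters default to booster=gbtree, device=cpu, verbosity=1, validateParameters=false and nthread=1. A short sketch of overriding them through the setters, assuming an XGBoostClassifier estimator as elsewhere in this change:

.. code-block:: scala

    // Sketch only: assumes XGBoostClassifier exposes the GeneralParams setters
    // defined in XGBoostParams; values are illustrative.
    val clf = new XGBoostClassifier()
      .setBooster("gbtree")
      .setDevice("cuda")            // "cpu" is the default
      .setNthread(4)                // threads per worker
      .setVerbosity(2)              // 0 silent, 1 warning, 2 info, 3 debug
      .setValidateParameters(true)  // warn on unknown parameters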

View File

@ -1,32 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.param.{IntParam, Params}
private[spark] trait InferenceParams extends Params {
/**
* batch size of inference iteration
*/
final val inferBatchSize = new IntParam(this, "batchSize", "batch size of inference iteration")
/** @group getParam */
final def getInferBatchSize: Int = $(inferBatchSize)
setDefault(inferBatchSize, 32 << 10)
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -20,98 +20,124 @@ import scala.collection.immutable.HashSet
import org.apache.spark.ml.param._
/**
* Specify the learning task and the corresponding learning objective.
* More details can be found at
* https://xgboost.readthedocs.io/en/stable/parameter.html#learning-task-parameters
*/
private[spark] trait LearningTaskParams extends Params {
/**
* Specify the learning task and the corresponding learning objective.
* options: reg:squarederror, reg:squaredlogerror, reg:logistic, binary:logistic, binary:logitraw,
* count:poisson, multi:softmax, multi:softprob, rank:ndcg, reg:gamma.
* default: reg:squarederror
*/
final val objective = new Param[String](this, "objective",
"objective function used for training")
"Objective function used for training",
ParamValidators.inArray(LearningTaskParams.SUPPORTED_OBJECTIVES.toArray))
final def getObjective: String = $(objective)
/**
* The learning objective type of the specified custom objective and eval.
* Corresponding type will be assigned if custom objective is defined
* options: regression, classification. default: null
*/
final val objectiveType = new Param[String](this, "objectiveType", "objective type used for " +
s"training, options: {${LearningTaskParams.supportedObjectiveType.mkString(",")}",
(value: String) => LearningTaskParams.supportedObjectiveType.contains(value))
final val numClass = new IntParam(this, "num_class", "Number of classes, used by " +
"multi:softmax and multi:softprob objectives", ParamValidators.gtEq(0))
final def getObjectiveType: String = $(objectiveType)
final def getNumClass: Int = $(numClass)
/**
* the initial prediction score of all instances, global bias. default=0.5
*/
final val baseScore = new DoubleParam(this, "baseScore", "the initial prediction score of all" +
" instances, global bias")
final val baseScore = new DoubleParam(this, "base_score", "The initial prediction score of " +
"all instances, global bias. The parameter is automatically estimated for selected " +
"objectives before training. To disable the estimation, specify a real number argument. " +
"For sufficient number of iterations, changing this value will not have too much effect.")
final def getBaseScore: Double = $(baseScore)
/**
* evaluation metrics for validation data, a default metric will be assigned according to
* objective(rmse for regression, and error for classification, mean average precision for
* ranking). options: rmse, rmsle, mae, mape, logloss, error, merror, mlogloss, auc, aucpr, ndcg,
* map, gamma-deviance
*/
final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
"validation data, a default metric will be assigned according to objective " +
"(rmse for regression, and error for classification, mean average precision for ranking)")
final val evalMetric = new Param[String](this, "eval_metric", "Evaluation metrics for " +
"validation data, a default metric will be assigned according to objective (rmse for " +
"regression, and logloss for classification, mean average precision for rank:map, etc.)" +
"User can add multiple evaluation metrics. Python users: remember to pass the metrics in " +
"as list of parameters pairs instead of map, so that latter eval_metric won't override " +
"previous ones", ParamValidators.inArray(LearningTaskParams.SUPPORTED_EVAL_METRICS.toArray))
final def getEvalMetric: String = $(evalMetric)
/**
* Fraction of training points to use for testing.
*/
@Deprecated
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
"fraction of training points to use for testing",
ParamValidators.inRange(0, 1))
setDefault(trainTestRatio, 1.0)
final val seed = new LongParam(this, "seed", "Random number seed.")
@Deprecated
final def getTrainTestRatio: Double = $(trainTestRatio)
final def getSeed: Long = $(seed)
/**
* whether caching training data
*/
final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
"whether caching training data")
final val seedPerIteration = new BooleanParam(this, "seed_per_iteration", "Seed PRNG " +
"determnisticly via iterator number..")
/**
* whether cleaning checkpoint, always cleaning by default, having this parameter majorly for
* testing
*/
final val skipCleanCheckpoint = new BooleanParam(this, "skipCleanCheckpoint",
"whether cleaning checkpoint data")
final def getSeedPerIteration: Boolean = $(seedPerIteration)
/**
* If non-zero, the training will be stopped after a specified number
* of consecutive increases in any evaluation metric.
*/
final val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
"number of rounds of decreasing eval metric to tolerate before " +
"stopping the training",
(value: Int) => value == 0 || value > 1)
// Parameters for Tweedie Regression (objective=reg:tweedie)
final val tweedieVariancePower = new DoubleParam(this, "tweedie_variance_power", "Parameter " +
"that controls the variance of the Tweedie distribution var(y) ~ E(y)^tweedie_variance_power.",
ParamValidators.inRange(1, 2, false, false))
final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
final def getTweedieVariancePower: Double = $(tweedieVariancePower)
// Parameter for using Pseudo-Huber (reg:pseudohubererror)
final val huberSlope = new DoubleParam(this, "huber_slope", "A parameter used for Pseudo-Huber " +
"loss to define the (delta) term.")
final val maximizeEvaluationMetrics = new BooleanParam(this, "maximizeEvaluationMetrics",
"define the expected optimization to the evaluation metrics, true to maximize otherwise" +
" minimize it")
final def getHuberSlope: Double = $(huberSlope)
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
// Parameter for using Quantile Loss (reg:quantileerror) TODO
// Parameter for using AFT Survival Loss (survival:aft) and Negative
// Log Likelihood of AFT metric (aft-nloglik)
final val aftLossDistribution = new Param[String](this, "aft_loss_distribution", "Probability " +
"Density Function",
ParamValidators.inArray(Array("normal", "logistic", "extreme")))
final def getAftLossDistribution: String = $(aftLossDistribution)
// Parameters for learning to rank (rank:ndcg, rank:map, rank:pairwise)
final val lambdarankPairMethod = new Param[String](this, "lambdarank_pair_method", "pairs for " +
"pair-wise learning",
ParamValidators.inArray(Array("mean", "topk")))
final def getLambdarankPairMethod: String = $(lambdarankPairMethod)
final val lambdarankNumPairPerSample = new IntParam(this, "lambdarank_num_pair_per_sample",
"It specifies the number of pairs sampled for each document when pair method is mean, or" +
" the truncation level for queries when the pair method is topk. For example, to train " +
"with ndcg@6, set lambdarank_num_pair_per_sample to 6 and lambdarank_pair_method to topk",
ParamValidators.gtEq(1))
final def getLambdarankNumPairPerSample: Int = $(lambdarankNumPairPerSample)
final val lambdarankUnbiased = new BooleanParam(this, "lambdarank_unbiased", "Specify " +
"whether do we need to debias input click data.")
final def getLambdarankUnbiased: Boolean = $(lambdarankUnbiased)
final val lambdarankBiasNorm = new DoubleParam(this, "lambdarank_bias_norm", "Lp " +
"normalization for position debiasing, default is L2. Only relevant when " +
"lambdarankUnbiased is set to true.")
final def getLambdarankBiasNorm: Double = $(lambdarankBiasNorm)
final val ndcgExpGain = new BooleanParam(this, "ndcg_exp_gain", "Whether we should " +
"use exponential gain function for NDCG.")
final def getNdcgExpGain: Boolean = $(ndcgExpGain)
setDefault(objective -> "reg:squarederror", numClass -> 0, seed -> 0, seedPerIteration -> false,
tweedieVariancePower -> 1.5, huberSlope -> 1, lambdarankPairMethod -> "mean",
lambdarankUnbiased -> false, lambdarankBiasNorm -> 2, ndcgExpGain -> true)
}
private[spark] object LearningTaskParams {
val SUPPORTED_OBJECTIVES = HashSet("reg:squarederror", "reg:squaredlogerror", "reg:logistic",
"reg:pseudohubererror", "reg:absoluteerror", "reg:quantileerror", "binary:logistic",
"binary:logitraw", "binary:hinge", "count:poisson", "survival:cox", "survival:aft",
"multi:softmax", "multi:softprob", "rank:ndcg", "rank:map", "rank:pairwise", "reg:gamma",
"reg:tweedie")
val supportedObjectiveType = HashSet("regression", "classification")
val BINARY_CLASSIFICATION_OBJS = HashSet("binary:logistic", "binary:hinge", "binary:logitraw")
val MULTICLASSIFICATION_OBJS = HashSet("multi:softmax", "multi:softprob")
val RANKER_OBJS = HashSet("rank:ndcg", "rank:map", "rank:pairwise")
val REGRESSION_OBJS = SUPPORTED_OBJECTIVES -- BINARY_CLASSIFICATION_OBJS --
MULTICLASSIFICATION_OBJS -- RANKER_OBJS
val SUPPORTED_EVAL_METRICS = HashSet("rmse", "rmsle", "mae", "mape", "mphe", "logloss", "error",
"error@t", "merror", "mlogloss", "auc", "aucpr", "pre", "ndcg", "map", "ndcg@n", "map@n",
"pre@n", "ndcg-", "map-", "ndcg@n-", "map@n-", "poisson-nloglik", "gamma-nloglik",
"cox-nloglik", "gamma-deviance", "tweedie-nloglik", "aft-nloglik",
"interval-regression-accuracy")
}
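A sketch of configuring the learning task through these parameters, assuming an XGBoostClassifier estimator; the objective must come from SUPPORTED_OBJECTIVES and the metric from SUPPORTED_EVAL_METRICS above, otherwise the validators reject the value:

.. code-block:: scala

    // Sketch only; values are illustrative.
    val clf = new XGBoostClassifier()
      .setObjective("multi:softprob")
      .setNumClass(3)               // required by the multi:* objectives
      .setEvalMetric("mlogloss")
      .setSeed(42L)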

View File

@ -1,36 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.sql.DataFrame
trait NonParamVariables {
protected var evalSetsMap: Map[String, DataFrame] = Map.empty
def setEvalSets(evalSets: Map[String, DataFrame]): this.type = {
evalSetsMap = evalSets
this
}
def getEvalSets(params: Map[String, Any]): Map[String, DataFrame] = {
if (params.contains("eval_sets")) {
params("eval_sets").asInstanceOf[Map[String, DataFrame]]
} else {
evalSetsMap
}
}
}

View File

@ -0,0 +1,65 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import scala.collection.mutable
import org.apache.spark.ml.param._
private[spark] trait ParamMapConversion extends NonXGBoostParams {
/**
* Convert XGBoost parameters to Spark Parameters
*
* @param xgboostParams XGBoost style parameters
*/
def xgboost2SparkParams(xgboostParams: Map[String, Any]): Unit = {
for ((name, paramValue) <- xgboostParams) {
params.find(_.name == name).foreach {
case _: DoubleParam =>
set(name, paramValue.toString.toDouble)
case _: BooleanParam =>
set(name, paramValue.toString.toBoolean)
case _: IntParam =>
set(name, paramValue.toString.toInt)
case _: FloatParam =>
set(name, paramValue.toString.toFloat)
case _: LongParam =>
set(name, paramValue.toString.toLong)
case _: Param[_] =>
set(name, paramValue)
}
}
}
/**
* Convert the user-supplied parameters to the XGBoost parameters.
*
* Note that this also contains jvm-specific parameters.
*/
def getXGBoostParams: Map[String, Any] = {
val xgboostParams = new mutable.HashMap[String, Any]()
// Only pass user-supplied parameters to xgboost.
for (param <- params) {
if (isSet(param) && !nonXGBoostParams.contains(param.name)) {
xgboostParams += param.name -> $(param)
}
}
xgboostParams.toMap
}
}
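A sketch of the intended round trip, assuming an estimator that mixes in ParamMapConversion; the trait is package-private, so this is illustrative only:

.. code-block:: scala

    // Sketch: names must match the Spark Param names declared on the estimator,
    // which now follow the native snake_case spelling (e.g. "max_depth").
    val estimator = new XGBoostClassifier()
    estimator.xgboost2SparkParams(Map("max_depth" -> 8, "eta" -> 0.1))
    // Only user-set params that are not Spark-specific are handed to native XGBoost:
    val nativeParams: Map[String, Any] = estimator.getXGBoostParams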

View File

@ -18,25 +18,27 @@ package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.param._
private[spark] trait RabitParams extends Params {
/**
* Rabit parameters passed through Rabit.Init into native layer
* rabit_ring_reduce_threshold - minimal threshold to enable ring based allreduce operation
* rabit_timeout - wait interval before exiting after rabit observes failures; set -1 to disable
* dmlc_worker_connect_retry - number of retries to the tracker
* dmlc_worker_stop_process_on_error - exit the process when rabit sees an assert/error
*/
final val rabitRingReduceThreshold = new IntParam(this, "rabitRingReduceThreshold",
"threshold count to enable allreduce/broadcast with ring based topology",
ParamValidators.gtEq(1))
setDefault(rabitRingReduceThreshold, (32 << 10))
private[spark] trait RabitParams extends Params with NonXGBoostParams {
final def rabitTimeout: IntParam = new IntParam(this, "rabitTimeout",
"timeout threshold after rabit observed failures")
setDefault(rabitTimeout, -1)
final val rabitTrackerTimeout = new IntParam(this, "rabitTrackerTimeout", "The number of " +
"seconds before timeout waiting for workers to connect. and for the tracker to shutdown.",
ParamValidators.gtEq(0))
final def rabitConnectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
"number of retry worker do before fail", ParamValidators.gtEq(1))
setDefault(rabitConnectRetry, 5)
final def getRabitTrackerTimeout: Int = $(rabitTrackerTimeout)
final val rabitTrackerHostIp = new Param[String](this, "rabitTrackerHostIp", "The Rabit " +
"Tracker host IP address. This is only needed if the host IP cannot be automatically " +
"guessed.")
final def getRabitTrackerHostIp: String = $(rabitTrackerHostIp)
final val rabitTrackerPort = new IntParam(this, "rabitTrackerPort", "The port number for the " +
"tracker to listen to. Use a system allocated one by default.",
ParamValidators.gtEq(0))
final def getRabitTrackerPort: Int = $(rabitTrackerPort)
setDefault(rabitTrackerTimeout -> 0, rabitTrackerHostIp -> "", rabitTrackerPort -> 0)
addNonXGBoostParam(rabitTrackerTimeout, rabitTrackerHostIp, rabitTrackerPort)
}
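These tracker settings stay on the Spark side (they are registered as non-XGBoost params). A sketch of setting them, assuming an XGBoostClassifier estimator; the host IP is a placeholder:

.. code-block:: scala

    // Sketch only.
    val clf = new XGBoostClassifier()
      .setRabitTrackerTimeout(120)        // seconds to wait for workers/shutdown, 0 = no timeout
      .setRabitTrackerHostIp("10.0.0.1")  // only needed if auto-detection fails
      .setRabitTrackerPort(0)             // 0 = system-allocated port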

View File

@ -0,0 +1,238 @@
/*
Copyright (c) 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import scala.collection.immutable.HashSet
import org.apache.spark.ml.param._
/**
* TreeBoosterParams defines the XGBoost TreeBooster parameters for Spark
*
* The details can be found at
* https://xgboost.readthedocs.io/en/stable/parameter.html#parameters-for-tree-booster
*/
private[spark] trait TreeBoosterParams extends Params {
final val eta = new DoubleParam(this, "eta", "Step size shrinkage used in update to prevents " +
"overfitting. After each boosting step, we can directly get the weights of new features, " +
"and eta shrinks the feature weights to make the boosting process more conservative.",
ParamValidators.inRange(0, 1, lowerInclusive = true, upperInclusive = true))
final def getEta: Double = $(eta)
final val gamma = new DoubleParam(this, "gamma", "Minimum loss reduction required to make a " +
"further partition on a leaf node of the tree. The larger gamma is, the more conservative " +
"the algorithm will be.",
ParamValidators.gtEq(0))
final def getGamma: Double = $(gamma)
final val maxDepth = new IntParam(this, "max_depth", "Maximum depth of a tree. Increasing this " +
"value will make the model more complex and more likely to overfit. 0 indicates no limit " +
"on depth. Beware that XGBoost aggressively consumes memory when training a deep tree. " +
"exact tree method requires non-zero value.",
ParamValidators.gtEq(0))
final def getMaxDepth: Int = $(maxDepth)
final val minChildWeight = new DoubleParam(this, "min_child_weight", "Minimum sum of instance " +
"weight (hessian) needed in a child. If the tree partition step results in a leaf node " +
"with the sum of instance weight less than min_child_weight, then the building process " +
"will give up further partitioning. In linear regression task, this simply corresponds " +
"to minimum number of instances needed to be in each node. The larger min_child_weight " +
"is, the more conservative the algorithm will be.",
ParamValidators.gtEq(0))
final def getMinChildWeight: Double = $(minChildWeight)
final val maxDeltaStep = new DoubleParam(this, "max_delta_step", "Maximum delta step we allow " +
"each leaf output to be. If the value is set to 0, it means there is no constraint. If it " +
"is set to a positive value, it can help making the update step more conservative. Usually " +
"this parameter is not needed, but it might help in logistic regression when class is " +
"extremely imbalanced. Set it to value of 1-10 might help control the update.",
ParamValidators.gtEq(0))
final def getMaxDeltaStep: Double = $(maxDeltaStep)
final val subsample = new DoubleParam(this, "subsample", "Subsample ratio of the training " +
"instances. Setting it to 0.5 means that XGBoost would randomly sample half of the " +
"training data prior to growing trees. and this will prevent overfitting. Subsampling " +
"will occur once in every boosting iteration.",
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
final def getSubsample: Double = $(subsample)
final val samplingMethod = new Param[String](this, "sampling_method", "The method to use to " +
"sample the training instances. The supported sampling methods" +
"uniform: each training instance has an equal probability of being selected. Typically set " +
"subsample >= 0.5 for good results.\n" +
"gradient_based: the selection probability for each training instance is proportional to " +
"the regularized absolute value of gradients. subsample may be set to as low as 0.1 " +
"without loss of model accuracy. Note that this sampling method is only supported when " +
"tree_method is set to hist and the device is cuda; other tree methods only support " +
"uniform sampling.",
ParamValidators.inArray(Array("uniform", "gradient_based")))
final def getSamplingMethod: String = $(samplingMethod)
final val colsampleBytree = new DoubleParam(this, "colsample_bytree", "Subsample ratio of " +
"columns when constructing each tree. Subsampling occurs once for every tree constructed.",
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
final def getColsampleBytree: Double = $(colsampleBytree)
final val colsampleBylevel = new DoubleParam(this, "colsample_bylevel", "Subsample ratio of " +
"columns for each level. Subsampling occurs once for every new depth level reached in a " +
"tree. Columns are subsampled from the set of columns chosen for the current tree.",
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
final def getColsampleBylevel: Double = $(colsampleBylevel)
final val colsampleBynode = new DoubleParam(this, "colsample_bynode", "Subsample ratio of " +
"columns for each node (split). Subsampling occurs once every time a new split is " +
"evaluated. Columns are subsampled from the set of columns chosen for the current level.",
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
final def getColsampleBynode: Double = $(colsampleBynode)
/**
* L2 regularization term on weights, increase this value will make model more conservative.
* [default=1]
*/
final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights. " +
"Increasing this value will make model more conservative.", ParamValidators.gtEq(0))
final def getLambda: Double = $(lambda)
final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights. " +
"Increasing this value will make model more conservative.", ParamValidators.gtEq(0))
final def getAlpha: Double = $(alpha)
final val treeMethod = new Param[String](this, "tree_method", "The tree construction " +
"algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
ParamValidators.inArray(BoosterParams.supportedTreeMethods.toArray))
final def getTreeMethod: String = $(treeMethod)
final val scalePosWeight = new DoubleParam(this, "scale_pos_weight", "Control the balance of " +
"positive and negative weights, useful for unbalanced classes. A typical value to consider: " +
"sum(negative instances) / sum(positive instances)")
final def getScalePosWeight: Double = $(scalePosWeight)
final val updater = new Param[String](this, "updater", "A comma separated string defining the " +
"sequence of tree updaters to run, providing a modular way to construct and to modify the " +
"trees. This is an advanced parameter that is usually set automatically, depending on some " +
"other parameters. However, it could be also set explicitly by a user. " +
"The following updaters exist:\n" +
"grow_colmaker: non-distributed column-based construction of trees.\n" +
"grow_histmaker: distributed tree construction with row-based data splitting based on " +
"global proposal of histogram counting.\n" +
"grow_quantile_histmaker: Grow tree using quantized histogram.\n" +
"grow_gpu_hist: Enabled when tree_method is set to hist along with device=cuda.\n" +
"grow_gpu_approx: Enabled when tree_method is set to approx along with device=cuda.\n" +
"sync: synchronizes trees in all distributed nodes.\n" +
"refresh: refreshes tree's statistics and or leaf values based on the current data. Note " +
"that no random subsampling of data rows is performed.\n" +
"prune: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth " +
"greater than max_depth.",
(value: String) => value.split(",").forall(
ParamValidators.inArray(BoosterParams.supportedUpdaters.toArray)))
final def getUpdater: String = $(updater)
final val refreshLeaf = new BooleanParam(this, "refresh_leaf", "This is a parameter of the " +
"refresh updater. When this flag is 1, tree leafs as well as tree nodes' stats are updated. " +
"When it is 0, only node stats are updated.")
final def getRefreshLeaf: Boolean = $(refreshLeaf)
// TODO set updater/refreshLeaf default value
final val processType = new Param[String](this, "process_type", "A type of boosting process to " +
"run. options: {default, update}",
ParamValidators.inArray(Array("default", "update")))
final def getProcessType: String = $(processType)
final val growPolicy = new Param[String](this, "grow_policy", "Controls a way new nodes are " +
"added to the tree. Currently supported only if tree_method is set to hist or approx. " +
"Choices: depthwise, lossguide. depthwise: split at nodes closest to the root. " +
"lossguide: split at nodes with highest loss change.",
ParamValidators.inArray(Array("depthwise", "lossguide")))
final def getGrowPolicy: String = $(growPolicy)
final val maxLeaves = new IntParam(this, "max_leaves", "Maximum number of nodes to be added. " +
"Not used by exact tree method", ParamValidators.gtEq(0))
final def getMaxLeaves: Int = $(maxLeaves)
final val maxBins = new IntParam(this, "max_bin", "Maximum number of discrete bins to bucket " +
"continuous features. Increasing this number improves the optimality of splits at the cost " +
"of higher computation time. Only used if tree_method is set to hist or approx.",
ParamValidators.gt(0))
final def getMaxBins: Int = $(maxBins)
final val numParallelTree = new IntParam(this, "num_parallel_tree", "Number of parallel trees " +
"constructed during each iteration. This option is used to support boosted random forest.",
ParamValidators.gt(0))
final def getNumParallelTree: Int = $(numParallelTree)
final val monotoneConstraints = new IntArrayParam(this, "monotone_constraints", "Constraint of " +
"variable monotonicity.")
final def getMonotoneConstraints: Array[Int] = $(monotoneConstraints)
final val interactionConstraints = new Param[String](this,
name = "interaction_constraints",
doc = "Constraints for interaction representing permitted interactions. The constraints" +
" must be specified in the form of a nest list, e.g. [[0, 1], [2, 3, 4]]," +
" where each inner list is a group of indices of features that are allowed to interact" +
" with each other. See tutorial for more information")
final def getInteractionConstraints: String = $(interactionConstraints)
final val maxCachedHistNode = new IntParam(this, "max_cached_hist_node", "Maximum number of " +
"cached nodes for CPU histogram.",
ParamValidators.gt(0))
final def getMaxCachedHistNode: Int = $(maxCachedHistNode)
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6, minChildWeight -> 1, maxDeltaStep -> 0,
subsample -> 1, samplingMethod -> "uniform", colsampleBytree -> 1, colsampleBylevel -> 1,
colsampleBynode -> 1, lambda -> 1, alpha -> 0, treeMethod -> "auto", scalePosWeight -> 1,
processType -> "default", growPolicy -> "depthwise", maxLeaves -> 0, maxBins -> 256,
numParallelTree -> 1, maxCachedHistNode -> 65536)
}
private[spark] object BoosterParams {
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")
val supportedUpdaters = HashSet("grow_colmaker", "grow_histmaker", "grow_quantile_histmaker",
"grow_gpu_hist", "grow_gpu_approx", "sync", "refresh", "prune")
}
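A sketch of a typical tree-booster configuration through the setters defined in XGBoostParams, assuming an XGBoostClassifier estimator; the values are illustrative and are checked by the validators declared above:

.. code-block:: scala

    // Sketch only.
    val clf = new XGBoostClassifier()
      .setTreeMethod("hist")
      .setMaxDepth(8)
      .setEta(0.1)
      .setSubsample(0.8)
      .setColsampleBytree(0.8)
      .setMaxBins(256)
      .setGrowPolicy("lossguide")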

View File

@ -1,119 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.param.{Param, ParamValidators}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol}
import org.apache.spark.ml.util.XGBoostSchemaUtils
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol
with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
with HasLabelCol with HasFeaturesCols with HasHandleInvalid {
def needDeterministicRepartitioning: Boolean = {
isDefined(checkpointPath) && getCheckpointPath != null && getCheckpointPath.nonEmpty &&
isDefined(checkpointInterval) && getCheckpointInterval > 0
}
/**
* Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
* invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
* output). Column lengths are taken from the size of ML Attribute Group, which can be set using
* `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
* from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
* Default: "error"
* @group param
*/
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"""Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out
|rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN
|in the output). Column lengths are taken from the size of ML Attribute Group, which can be
|set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also
|be inferred from first rows of the data since it is safe to do so but only in case of 'error'
|or 'skip'.""".stripMargin.replaceAll("\n", " "),
ParamValidators.inArray(Array("skip", "error", "keep")))
setDefault(handleInvalid, "error")
/**
* Specify an array of feature column names which must be numeric types.
*/
def setFeaturesCol(value: Array[String]): this.type = set(featuresCols, value)
/** Set the handleInvalid for VectorAssembler */
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
/**
* Check if schema has a field named with the value of "featuresCol" param and it's data type
* must be VectorUDT
*/
def isFeaturesColSet(schema: StructType): Boolean = {
schema.fieldNames.contains(getFeaturesCol) &&
XGBoostSchemaUtils.isVectorUDFType(schema(getFeaturesCol).dataType)
}
/** check the features columns type */
def transformSchemaWithFeaturesCols(fit: Boolean, schema: StructType): StructType = {
if (isFeaturesColsValid) {
if (fit) {
XGBoostSchemaUtils.checkNumericType(schema, $(labelCol))
}
$(featuresCols).foreach(feature =>
XGBoostSchemaUtils.checkFeatureColumnType(schema(feature).dataType))
schema
} else {
throw new IllegalArgumentException("featuresCol or featuresCols must be specified")
}
}
/**
* Vectorize the features columns if necessary.
*
* @param input the input dataset
* @return (output dataset and the feature column name)
*/
def vectorize(input: Dataset[_]): (Dataset[_], String) = {
val schema = input.schema
if (isFeaturesColSet(schema)) {
// Dataset already has vectorized.
(input, getFeaturesCol)
} else if (isFeaturesColsValid) {
val featuresName = if (!schema.fieldNames.contains(getFeaturesCol)) {
getFeaturesCol
} else {
"features_" + uid
}
val vectorAssembler = new VectorAssembler()
.setHandleInvalid($(handleInvalid))
.setInputCols(getFeaturesCols)
.setOutputCol(featuresName)
(vectorAssembler.transform(input).select(featuresName, getLabelCol), featuresName)
} else {
// never reach here, since transformSchema will take care of the case
// that featuresCols is invalid
(input, getFeaturesCol)
}
}
}
private[scala] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass
private[scala] trait XGBoostRegressorParams extends XGBoostEstimatorCommon with HasGroupCol

View File

@ -0,0 +1,359 @@
/*
Copyright (c) 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.sql.types.StructType
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
trait HasLeafPredictionCol extends Params {
/**
* Param for leaf prediction column name.
*
* @group param
*/
final val leafPredictionCol: Param[String] = new Param[String](this, "leafPredictionCol",
"name of the predictLeaf results")
/** @group getParam */
final def getLeafPredictionCol: String = $(leafPredictionCol)
}
trait HasContribPredictionCol extends Params {
/**
* Param for contribution prediction column name.
*
* @group param
*/
final val contribPredictionCol: Param[String] = new Param[String](this, "contribPredictionCol",
"name of the predictContrib results")
/** @group getParam */
final def getContribPredictionCol: String = $(contribPredictionCol)
}
trait HasBaseMarginCol extends Params {
/**
* Param for initial prediction (aka base margin) column name.
*
* @group param
*/
final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol",
"Initial prediction (aka base margin) column name.")
/** @group getParam */
final def getBaseMarginCol: String = $(baseMarginCol)
}
trait HasGroupCol extends Params {
final val groupCol: Param[String] = new Param[String](this, "groupCol", "group column name.")
/** @group getParam */
final def getGroupCol: String = $(groupCol)
}
/**
* Trait for shared param featuresCols.
*/
trait HasFeaturesCols extends Params {
/**
* Param for the names of feature columns.
*
* @group param
*/
final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
"An array of feature column names.")
/** @group getParam */
final def getFeaturesCols: Array[String] = $(featuresCols)
/** Check if featuresCols is valid */
def isFeaturesColsValid: Boolean = {
isDefined(featuresCols) && $(featuresCols) != Array.empty
}
}
/**
* A trait to hold non-xgboost parameters
*/
trait NonXGBoostParams extends Params {
private val paramNames: ArrayBuffer[String] = ArrayBuffer.empty
protected def addNonXGBoostParam(ps: Param[_]*): Unit = {
ps.foreach(p => paramNames.append(p.name))
}
protected lazy val nonXGBoostParams: Array[String] = paramNames.toSet.toArray
}
/**
* XGBoost spark-specific parameters which should not be passed
* into the xgboost library
*
* @tparam T should be the XGBoost estimators or models
*/
private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFeaturesCol
with HasLabelCol with HasBaseMarginCol with HasWeightCol with HasPredictionCol
with HasLeafPredictionCol with HasContribPredictionCol
with RabitParams with NonXGBoostParams with SchemaValidationTrait {
final val numWorkers = new IntParam(this, "numWorkers", "Number of workers used to train xgboost",
ParamValidators.gtEq(1))
final def getNumRound: Int = $(numRound)
final val forceRepartition = new BooleanParam(this, "forceRepartition", "If the number of " +
"partitions is equal to numWorkers, xgboost won't repartition the dataset. Set " +
"forceRepartition to true to force a repartition.")
final def getForceRepartition: Boolean = $(forceRepartition)
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
ParamValidators.gtEq(1))
final val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds", "Number of " +
"rounds with no improvement in the eval metric to tolerate before stopping the training",
ParamValidators.gtEq(0))
final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
final val inferBatchSize = new IntParam(this, "inferBatchSize", "batch size in rows " +
"to be grouped for inference",
ParamValidators.gtEq(1))
/** @group getParam */
final def getInferBatchSize: Int = $(inferBatchSize)
/**
* the value treated as missing. default: Float.NaN
*/
final val missing = new FloatParam(this, "missing", "The value treated as missing")
final def getMissing: Float = $(missing)
final val customObj = new CustomObjParam(this, "customObj", "customized objective function " +
"provided by user")
final def getCustomObj: ObjectiveTrait = $(customObj)
final val customEval = new CustomEvalParam(this, "customEval",
"customized evaluation function provided by user")
final def getCustomEval: EvalTrait = $(customEval)
/** Feature names; they will be set on the DMatrix and Booster, and stored in the final native JSON model.
* In native code, the parameter name is feature_name.
* */
final val featureNames = new StringArrayParam(this, "feature_names",
"an array of feature names")
final def getFeatureNames: Array[String] = $(featureNames)
/** Feature types, q is numeric and c is categorical.
* In native code, the parameter name is feature_type
* */
final val featureTypes = new StringArrayParam(this, "feature_types",
"an array of feature types")
final def getFeatureTypes: Array[String] = $(featureTypes)
setDefault(numRound -> 100, numWorkers -> 1, inferBatchSize -> (32 << 10),
numEarlyStoppingRounds -> 0, forceRepartition -> false, missing -> Float.NaN,
featuresCols -> Array.empty, customObj -> null, customEval -> null,
featureNames -> Array.empty, featureTypes -> Array.empty)
addNonXGBoostParam(numWorkers, numRound, numEarlyStoppingRounds, inferBatchSize, featuresCol,
labelCol, baseMarginCol, weightCol, predictionCol, leafPredictionCol, contribPredictionCol,
forceRepartition, featuresCols, customEval, customObj, featureTypes, featureNames)
final def getNumWorkers: Int = $(numWorkers)
def setNumWorkers(value: Int): T = set(numWorkers, value).asInstanceOf[T]
def setForceRepartition(value: Boolean): T = set(forceRepartition, value).asInstanceOf[T]
def setNumRound(value: Int): T = set(numRound, value).asInstanceOf[T]
def setFeaturesCol(value: Array[String]): T = set(featuresCols, value).asInstanceOf[T]
def setBaseMarginCol(value: String): T = set(baseMarginCol, value).asInstanceOf[T]
def setWeightCol(value: String): T = set(weightCol, value).asInstanceOf[T]
def setLeafPredictionCol(value: String): T = set(leafPredictionCol, value).asInstanceOf[T]
def setContribPredictionCol(value: String): T = set(contribPredictionCol, value).asInstanceOf[T]
def setInferBatchSize(value: Int): T = set(inferBatchSize, value).asInstanceOf[T]
def setMissing(value: Float): T = set(missing, value).asInstanceOf[T]
def setCustomObj(value: ObjectiveTrait): T = set(customObj, value).asInstanceOf[T]
def setCustomEval(value: EvalTrait): T = set(customEval, value).asInstanceOf[T]
def setRabitTrackerTimeout(value: Int): T = set(rabitTrackerTimeout, value).asInstanceOf[T]
def setRabitTrackerHostIp(value: String): T = set(rabitTrackerHostIp, value).asInstanceOf[T]
def setRabitTrackerPort(value: Int): T = set(rabitTrackerPort, value).asInstanceOf[T]
def setFeatureNames(value: Array[String]): T = set(featureNames, value).asInstanceOf[T]
def setFeatureTypes(value: Array[String]): T = set(featureTypes, value).asInstanceOf[T]
}
private[spark] trait SchemaValidationTrait {
def validateAndTransformSchema(schema: StructType,
fitting: Boolean): StructType = schema
}
/**
* XGBoost ranking spark-specific parameters
*
* @tparam T should be XGBoostRanker or XGBoostRankingModel
*/
private[spark] trait RankerParams[T <: Params] extends HasGroupCol with NonXGBoostParams {
def setGroupCol(value: String): T = set(groupCol, value).asInstanceOf[T]
addNonXGBoostParam(groupCol)
}
/**
* XGBoost-specific parameters to pass into the xgboost library
*
* @tparam T should be the XGBoost estimators or models
*/
private[spark] trait XGBoostParams[T <: Params] extends TreeBoosterParams
with LearningTaskParams with GeneralParams with DartBoosterParams {
// Setters for TreeBoosterParams
def setEta(value: Double): T = set(eta, value).asInstanceOf[T]
def setGamma(value: Double): T = set(gamma, value).asInstanceOf[T]
def setMaxDepth(value: Int): T = set(maxDepth, value).asInstanceOf[T]
def setMinChildWeight(value: Double): T = set(minChildWeight, value).asInstanceOf[T]
def setMaxDeltaStep(value: Double): T = set(maxDeltaStep, value).asInstanceOf[T]
def setSubsample(value: Double): T = set(subsample, value).asInstanceOf[T]
def setSamplingMethod(value: String): T = set(samplingMethod, value).asInstanceOf[T]
def setColsampleBytree(value: Double): T = set(colsampleBytree, value).asInstanceOf[T]
def setColsampleBylevel(value: Double): T = set(colsampleBylevel, value).asInstanceOf[T]
def setColsampleBynode(value: Double): T = set(colsampleBynode, value).asInstanceOf[T]
def setLambda(value: Double): T = set(lambda, value).asInstanceOf[T]
def setAlpha(value: Double): T = set(alpha, value).asInstanceOf[T]
def setTreeMethod(value: String): T = set(treeMethod, value).asInstanceOf[T]
def setScalePosWeight(value: Double): T = set(scalePosWeight, value).asInstanceOf[T]
def setUpdater(value: String): T = set(updater, value).asInstanceOf[T]
def setRefreshLeaf(value: Boolean): T = set(refreshLeaf, value).asInstanceOf[T]
def setProcessType(value: String): T = set(processType, value).asInstanceOf[T]
def setGrowPolicy(value: String): T = set(growPolicy, value).asInstanceOf[T]
def setMaxLeaves(value: Int): T = set(maxLeaves, value).asInstanceOf[T]
def setMaxBins(value: Int): T = set(maxBins, value).asInstanceOf[T]
def setNumParallelTree(value: Int): T = set(numParallelTree, value).asInstanceOf[T]
def setInteractionConstraints(value: String): T =
set(interactionConstraints, value).asInstanceOf[T]
def setMaxCachedHistNode(value: Int): T = set(maxCachedHistNode, value).asInstanceOf[T]
// Setters for LearningTaskParams
def setObjective(value: String): T = set(objective, value).asInstanceOf[T]
def setNumClass(value: Int): T = set(numClass, value).asInstanceOf[T]
def setBaseScore(value: Double): T = set(baseScore, value).asInstanceOf[T]
def setEvalMetric(value: String): T = set(evalMetric, value).asInstanceOf[T]
def setSeed(value: Long): T = set(seed, value).asInstanceOf[T]
def setSeedPerIteration(value: Boolean): T = set(seedPerIteration, value).asInstanceOf[T]
def setTweedieVariancePower(value: Double): T = set(tweedieVariancePower, value).asInstanceOf[T]
def setHuberSlope(value: Double): T = set(huberSlope, value).asInstanceOf[T]
def setAftLossDistribution(value: String): T = set(aftLossDistribution, value).asInstanceOf[T]
def setLambdarankPairMethod(value: String): T = set(lambdarankPairMethod, value).asInstanceOf[T]
def setLambdarankNumPairPerSample(value: Int): T =
set(lambdarankNumPairPerSample, value).asInstanceOf[T]
def setLambdarankUnbiased(value: Boolean): T = set(lambdarankUnbiased, value).asInstanceOf[T]
def setLambdarankBiasNorm(value: Double): T = set(lambdarankBiasNorm, value).asInstanceOf[T]
def setNdcgExpGain(value: Boolean): T = set(ndcgExpGain, value).asInstanceOf[T]
// Setters for Dart
def setSampleType(value: String): T = set(sampleType, value).asInstanceOf[T]
def setNormalizeType(value: String): T = set(normalizeType, value).asInstanceOf[T]
def setRateDrop(value: Double): T = set(rateDrop, value).asInstanceOf[T]
def setOneDrop(value: Boolean): T = set(oneDrop, value).asInstanceOf[T]
def setSkipDrop(value: Double): T = set(skipDrop, value).asInstanceOf[T]
// Setters for GeneralParams
def setBooster(value: String): T = set(booster, value).asInstanceOf[T]
def setDevice(value: String): T = set(device, value).asInstanceOf[T]
def setVerbosity(value: Int): T = set(verbosity, value).asInstanceOf[T]
def setValidateParameters(value: Boolean): T = set(validateParameters, value).asInstanceOf[T]
def setNthread(value: Int): T = set(nthread, value).asInstanceOf[T]
}
private[spark] trait ParamUtils[T <: Params] extends Params {
def isDefinedNonEmpty(param: Param[String]): Boolean = {
isDefined(param) && $(param).nonEmpty
}
}
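The traits above expose every parameter through a typed setter that returns the concrete estimator type, so configuration can be chained. A minimal usage sketch, assuming an estimator such as XGBoostClassifier mixes in XGBoostParams together with the Spark-specific parameter traits above (the parameter values are illustrative only):

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Sketch only: values below are illustrative, not recommendations.
val classifier = new XGBoostClassifier()
  .setNumWorkers(2)                 // Spark-side parameter: number of distributed workers
  .setNumRound(100)                 // Spark-side parameter: number of boosting rounds
  .setMissing(Float.NaN)            // value treated as missing in the feature data
  .setObjective("binary:logistic")  // XGBoost parameter, forwarded to the native library
  .setMaxDepth(6)
  .setEta(0.3)

Because each setter ends with asInstanceOf[T], the chain keeps the concrete estimator type instead of degrading to the generic Params type.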

View File

@ -1,229 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.util
import scala.collection.mutable
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.HashPartitioner
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{FloatType, IntegerType}
import org.apache.spark.sql.{Column, DataFrame, Row}
object DataUtils extends Serializable {
private[spark] implicit class XGBLabeledPointFeatures(
val labeledPoint: XGBLabeledPoint
) extends AnyVal {
/** Converts the point to [[MLLabeledPoint]]. */
private[spark] def asML: MLLabeledPoint = {
MLLabeledPoint(labeledPoint.label, labeledPoint.features)
}
/**
* Returns the features of the point as [[org.apache.spark.ml.linalg.Vector]].
*/
def features: Vector = if (labeledPoint.indices == null) {
Vectors.dense(labeledPoint.values.map(_.toDouble))
} else {
Vectors.sparse(labeledPoint.size, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
}
}
private[spark] implicit class MLLabeledPointToXGBLabeledPoint(
val labeledPoint: MLLabeledPoint
) extends AnyVal {
/** Converts an [[MLLabeledPoint]] to an [[XGBLabeledPoint]]. */
def asXGB: XGBLabeledPoint = {
labeledPoint.features.asXGB.copy(label = labeledPoint.label.toFloat)
}
}
private[spark] implicit class MLVectorToXGBLabeledPoint(val v: Vector) extends AnyVal {
/**
* Converts a [[Vector]] to a data point with a dummy label.
*
* This is needed for constructing a [[ml.dmlc.xgboost4j.scala.DMatrix]]
* for prediction.
*/
def asXGB: XGBLabeledPoint = v match {
case v: DenseVector =>
XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
case v: SparseVector =>
XGBLabeledPoint(0.0f, v.size, v.indices, v.values.map(_.toFloat))
}
}
private def attachPartitionKey(
row: Row,
deterministicPartition: Boolean,
numWorkers: Int,
xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
if (deterministicPartition) {
(math.abs(row.hashCode() % numWorkers), xgbLp)
} else {
(1, xgbLp)
}
}
private def repartitionRDDs(
deterministicPartition: Boolean,
numWorkers: Int,
arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
if (deterministicPartition) {
arrayOfRDDs.map {rdd => rdd.partitionBy(new HashPartitioner(numWorkers))}.map {
rdd => rdd.map(_._2)
}
} else {
arrayOfRDDs.map(rdd => {
if (rdd.getNumPartitions != numWorkers) {
rdd.map(_._2).repartition(numWorkers)
} else {
rdd.map(_._2)
}
})
}
}
/** Packed parameters used by [[convertDataFrameToXGBLabeledPointRDDs]] */
private[spark] case class PackedParams(labelCol: Column,
featuresCol: Column,
weight: Column,
baseMargin: Column,
group: Option[Column],
numWorkers: Int,
deterministicPartition: Boolean)
/**
* convertDataFrameToXGBLabeledPointRDDs converts DataFrames to an array of RDD[XGBLabeledPoint]
*
* First, it converts each instance of the input into an XGBLabeledPoint.
* Second, it repartitions the RDD to the number of workers.
*
*/
private[spark] def convertDataFrameToXGBLabeledPointRDDs(
packedParams: PackedParams,
dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
packedParams match {
case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers,
deterministicPartition) =>
val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
groupCol.cast(IntegerType),
baseMargin.cast(FloatType))).getOrElse(Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
baseMargin.cast(FloatType)))
val arrayOfRDDs = dataFrames.toArray.map {
df => df.select(selectedColumns: _*).rdd.map {
case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
baseMargin: Float) =>
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
}
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
}
val xgbLp = XGBLabeledPoint(label, size, indices, values, weight,
baseMargin = baseMargin)
attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
}
}
repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
case _ => throw new IllegalArgumentException("Wrong PackedParams") // never reach here
}
}
private[spark] def processMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
missing, (v: Float) => v != missing)
} else {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
missing, (v: Float) => !v.isNaN)
}
}
private[spark] def processMissingValuesWithGroup(
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[Array[XGBLabeledPoint]] = {
if (!missing.isNaN) {
xgbLabelPointGroups.map {
labeledPoints => processMissingValues(
labeledPoints.iterator,
missing,
allowNonZeroMissing
).toArray
}
} else {
xgbLabelPointGroups
}
}
private def removeMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
xgbLabelPoints.map { labeledPoint =>
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
valuesBuilder += value
}
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
}
}
private def verifyMissingSetting(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
if (missing != 0.0f && !allowNonZeroMissing) {
xgbLabelPoints.map(labeledPoint => {
if (labeledPoint.indices != null) {
throw new RuntimeException(s"you can only specify missing value as 0.0 (the current" +
s" setting is $missing) when you have a SparseVector or an empty vector as your feature" +
s" format. If you didn't use Spark's VectorAssembler class to build your feature " +
s"vector but instead did so in a way that preserves zeros in your feature vector, " +
s"you can avoid this check by using the 'allow_non_zero_for_missing' parameter" +
s" (only use it if you know what you are doing)")
}
labeledPoint
})
} else {
xgbLabelPoints
}
}
}
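For reference, a standalone sketch of the filtering rule that removeMissingValues implements above (this DataUtils helper is removed by this commit; the snippet only restates its keepCondition logic and is not part of any API):

// Sketch only: mirrors the keepCondition predicates chosen in processMissingValues.
val values = Array(1.0f, 0.0f, 3.0f, Float.NaN)

// missing = 0.0f: entries equal to zero are dropped (NaN != 0.0f, so NaN survives here)
val keptForZeroMissing = values.zipWithIndex.filter { case (v, _) => v != 0.0f }
// -> Array((1.0, 0), (3.0, 2), (NaN, 3))

// missing = Float.NaN: NaN entries are dropped
val keptForNaNMissing = values.zipWithIndex.filter { case (v, _) => !v.isNaN }
// -> Array((1.0, 0), (0.0, 1), (3.0, 2))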

View File

@ -1,147 +0,0 @@
/*
Copyright (c) 2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.apache.spark.ml.util
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import ml.dmlc.xgboost4j.scala.spark
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.FSDataInputStream
import org.json4s.DefaultFormats
import org.json4s.JsonAST.JObject
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
abstract class XGBoostWriter extends MLWriter {
def getModelFormat(): String = {
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
}
}
object DefaultXGBoostParamsWriter {
val XGBOOST_VERSION_TAG = "xgboostVersion"
/**
* Saves metadata + Params to: path + "/metadata" using [[DefaultParamsWriter.saveMetadata]]
*/
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext): Unit = {
// save xgboost version to distinguish the old model.
val extraMetadata: JObject = Map(XGBOOST_VERSION_TAG -> ml.dmlc.xgboost4j.scala.spark.VERSION)
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
}
}
object DefaultXGBoostParamsReader {
private val logger = LogFactory.getLog("XGBoostSpark")
/**
* Load metadata saved using [[DefaultParamsWriter.saveMetadata()]]
*
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
}
/**
* Extract Params from metadata, and set them in the instance.
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
*
* Parameters that are not defined on the instance are skipped automatically.
*
* This API is mainly copied from DefaultParamsReader
*/
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
// XGBoost didn't set the default parameters since the save/load code is copied
// from Spark 2.3.x, which means the defaults of the running XGBoost version were
// used instead of the ones stored in the model.
// For compatibility, we still don't set the default parameters here.
// setParams(instance, metadata, isDefault = true)
setParams(instance, metadata, isDefault = false)
}
/** This API is only for XGBoostClassificationModel */
def getNumClass(metadata: Metadata, dataInStream: FSDataInputStream): Int = {
implicit val format = DefaultFormats
// The xgboostVersion in the metadata indicates whether the model was saved by an old,
// incompatible XGBoost version or by a new, compatible one.
val xgbVerOpt = (metadata.metadata \ DefaultXGBoostParamsWriter.XGBOOST_VERSION_TAG)
.extractOpt[String]
// For binary:logistic, the numClass parameter is either absent or effectively 2.
// For multi:softprob or multi:softmax, the numClass parameter must be set correctly,
// otherwise XGBoost will throw an exception.
// So it's safe to get numClass from the metadata.
xgbVerOpt
.map { _ => (metadata.params \ "numClass").extractOpt[Int].getOrElse(2) }
.getOrElse(dataInStream.readInt())
}
private def setParams(
instance: Params,
metadata: Metadata,
isDefault: Boolean): Unit = {
val paramsToSet = if (isDefault) metadata.defaultParams else metadata.params
paramsToSet match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val finalName = handleBrokenlyChangedName(paramName)
// For deleted parameters, it is better to skip them than to throw an exception,
// so we check whether the parameter exists instead of blindly setting it.
if (instance.hasParam(finalName)) {
val param = instance.getParam(finalName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, handleBrokenlyChangedValue(paramName, value))
} else {
logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
}
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
}
private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
/** It is not ideal to do this transformation, but it is needed since there are
* some tests based on a 0.82 saved model in which the objective is "reg:linear" */
private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
Map("objective" -> Map("reg:linear" -> "reg:squarederror"))
private def handleBrokenlyChangedName(paramName: String): String = {
paramNameCompatibilityMap.getOrElse(paramName, paramName)
}
private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T]
}
}
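A short standalone sketch of how the two compatibility maps above behave when an old model is loaded (illustrative only; nameMap and valueMap are local names for this sketch):

// Sketch only: mirrors handleBrokenlyChangedName and handleBrokenlyChangedValue above.
val nameMap  = Map("silent" -> "verbosity")
val valueMap = Map("objective" -> Map[Any, Any]("reg:linear" -> "reg:squarederror"))

def renamed(param: String): String = nameMap.getOrElse(param, param)
def remapped(param: String, value: Any): Any =
  valueMap.getOrElse(param, Map.empty[Any, Any]).getOrElse(value, value)

renamed("silent")                          // "verbosity"
remapped("objective", "reg:linear")        // "reg:squarederror"
remapped("objective", "reg:squarederror")  // unchanged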

View File

@ -1,50 +0,0 @@
/*
Copyright (c) 2022-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.apache.spark.ml.util
import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType}
import org.apache.spark.ml.linalg.VectorUDT
object XGBoostSchemaUtils {
/** check if the dataType is VectorUDT */
def isVectorUDFType(dataType: DataType): Boolean = {
dataType match {
case _: VectorUDT => true
case _ => false
}
}
/** The feature columns will be vectorized by VectorAssembler first, which only
* supports Numeric, Boolean and VectorUDT types */
def checkFeatureColumnType(dataType: DataType): Unit = {
dataType match {
case _: NumericType | BooleanType =>
case _: VectorUDT =>
case d => throw new UnsupportedOperationException(s"featuresCols only supports Numeric, " +
s"boolean and VectorUDT types, found: ${d}")
}
}
def checkNumericType(
schema: StructType,
colName: String,
msg: String = ""): Unit = {
SchemaUtils.checkNumericType(schema, colName, msg)
}
}
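To make the check above concrete, an illustrative sketch of which feature column types pass checkFeatureColumnType (this utility is removed by this commit; the calls below only restate its pattern match):

import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types.{BooleanType, DoubleType, StringType}

XGBoostSchemaUtils.checkFeatureColumnType(DoubleType)     // ok: NumericType
XGBoostSchemaUtils.checkFeatureColumnType(BooleanType)    // ok: BooleanType
XGBoostSchemaUtils.checkFeatureColumnType(new VectorUDT)  // ok: assembled vector column
XGBoostSchemaUtils.checkFeatureColumnType(StringType)     // throws UnsupportedOperationException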

View File

@ -0,0 +1,93 @@
/*
Copyright (c) 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.apache.spark.ml.xgboost
import org.apache.spark.SparkContext
import org.apache.spark.ml.classification.ProbabilisticClassifierParams
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{DatasetUtils, DefaultParamsReader, DefaultParamsWriter, SchemaUtils}
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.json4s.{JObject, JValue}
import ml.dmlc.xgboost4j.scala.spark.params.NonXGBoostParams
/**
* XGBoost classification spark-specific parameters which should not be passed
* into the xgboost library
*
* @tparam T should be XGBoostClassifier or XGBoostClassificationModel
*/
trait XGBProbabilisticClassifierParams[T <: Params]
extends ProbabilisticClassifierParams with NonXGBoostParams {
/**
* XGBoost doesn't use Spark's default validateAndTransformSchema since that implementation
* requires the features column to be of vector type
*/
override protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
var outputSchema = SparkUtils.appendColumn(schema, $(predictionCol), DoubleType)
outputSchema = SparkUtils.appendVectorUDTColumn(outputSchema, $(rawPredictionCol))
outputSchema = SparkUtils.appendVectorUDTColumn(outputSchema, $(probabilityCol))
outputSchema
}
addNonXGBoostParam(rawPredictionCol, probabilityCol, thresholds)
}
/** Utils to access the spark internal functions */
object SparkUtils {
def getNumClasses(dataset: Dataset[_], labelCol: String, maxNumClasses: Int = 100): Int = {
DatasetUtils.getNumClasses(dataset, labelCol, maxNumClasses)
}
def checkNumericType(schema: StructType, colName: String, msg: String = ""): Unit = {
SchemaUtils.checkNumericType(schema, colName, msg)
}
def saveMetadata(instance: Params,
path: String,
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, paramMap)
}
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
}
def appendColumn(schema: StructType,
colName: String,
dataType: DataType,
nullable: Boolean = false): StructType = {
SchemaUtils.appendColumn(schema, colName, dataType, nullable)
}
def appendVectorUDTColumn(schema: StructType,
colName: String,
dataType: DataType = new VectorUDT,
nullable: Boolean = false): StructType = {
SchemaUtils.appendColumn(schema, colName, dataType, nullable)
}
}
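To make the schema handling above concrete, a small sketch of the output schema produced by validateAndTransformSchema through the SparkUtils helpers (default column names assumed; illustrative only):

import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

val input = StructType(Seq(
  StructField("features", new VectorUDT),
  StructField("label", DoubleType)))

// Equivalent to the appendColumn / appendVectorUDTColumn calls above.
val output = StructType(input.fields ++ Seq(
  StructField("prediction", DoubleType, nullable = false),        // $(predictionCol)
  StructField("rawPrediction", new VectorUDT, nullable = false),  // $(rawPredictionCol)
  StructField("probability", new VectorUDT, nullable = false)))   // $(probabilityCol)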

View File

@ -16,21 +16,11 @@
package ml.dmlc.xgboost4j.scala.spark
import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.scalatest.funsuite.AnyFunSuite
class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
val classifier = new XGBoostClassifier(paramMap)
val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc)
xgbParamsFactory.buildXGBRuntimeParams
}
class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
@ -113,9 +103,11 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic")
val trainingDF = buildDataFrame(Classification.train)
val model = new XGBoostClassifier(paramMap ++ Array("num_round" -> 10,
"num_workers" -> numWorkers)).fit(trainingDF)
val trainingDF = smallBinaryClassificationVector
val model = new XGBoostClassifier(paramMap)
.setNumWorkers(numWorkers)
.setNumRound(10)
.fit(trainingDF)
val prediction = model.transform(trainingDF)
// a partial evaluation of dataframe will cause rabit initialized but not shutdown in some
// threads

View File

@ -16,10 +16,12 @@
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable.ListBuffer
import org.apache.commons.logging.LogFactory
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait}
import org.apache.commons.logging.LogFactory
import scala.collection.mutable.ListBuffer
/**

View File

@ -1,114 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.linalg.Vectors
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import org.apache.spark.sql.functions._
class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("perform deterministic partitioning when checkpointInternal and" +
" checkpointPath is set (Classifier)") {
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2, "num_workers" -> numWorkers)
val xgbClassifier = new XGBoostClassifier(paramMap)
assert(xgbClassifier.needDeterministicRepartitioning)
}
test("perform deterministic partitioning when checkpointInternal and" +
" checkpointPath is set (Regressor)") {
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2, "num_workers" -> numWorkers)
val xgbRegressor = new XGBoostRegressor(paramMap)
assert(xgbRegressor.needDeterministicRepartitioning)
}
test("deterministic partitioning takes effect with various parts of data") {
val trainingDF = buildDataFrame(Classification.train)
// the idea of the test is that we apply a chain of repartitions over the training DataFrames,
// and they all have to produce identical RDDs
val transformedDFs = (1 until 6).map(shuffleCount => {
var resultDF = trainingDF
for (i <- 0 until shuffleCount) {
resultDF = resultDF.repartition(numWorkers)
}
resultDF
})
val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
PackedParams(col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
numWorkers,
deterministicPartition = true),
df
).head)
val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
case (partitionIndex, labelPoints) =>
Iterator((partitionIndex, labelPoints.toList))
}.collect().toMap)
resultsMaps.foldLeft(resultsMaps.head) { case (map1, map2) =>
assert(map1.keys.toSet === map2.keys.toSet)
for ((parIdx, labeledPoints) <- map1) {
val sortedA = labeledPoints.sortBy(_.hashCode())
val sortedB = map2(parIdx).sortBy(_.hashCode())
assert(sortedA.length === sortedB.length)
assert(sortedA.indices.forall(idx =>
sortedA(idx).values.toSet === sortedB(idx).values.toSet))
}
map2
}
}
test("deterministic partitioning has a uniform repartition on dataset with missing values") {
val N = 10000
val dataset = (0 until N).map{ n =>
(n, n % 2, Vectors.sparse(3, Array(0, 1, 2), Array(Double.NaN, n, Double.NaN)))
}
val df = ss.createDataFrame(sc.parallelize(dataset)).toDF("id", "label", "features")
val dfRepartitioned = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
PackedParams(col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
10,
deterministicPartition = true), df
).head
val partitionsSizes = dfRepartitioned
.mapPartitions(iter => Array(iter.size.toDouble).iterator, true)
.collect()
val partitionMean = partitionsSizes.sum / partitionsSizes.length
val squaredDiffSum = partitionsSizes
.map(partitionSize => Math.pow(partitionSize - partitionMean, 2))
val standardDeviation = math.sqrt(squaredDiffSum.sum / squaredDiffSum.length)
assert(standardDeviation < math.sqrt(N.toDouble))
}
}

View File

@ -16,9 +16,10 @@
package ml.dmlc.xgboost4j.scala.spark
import org.apache.commons.logging.LogFactory
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
import org.apache.commons.logging.LogFactory
class EvalError extends EvalTrait {

View File

@ -1,131 +0,0 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
import org.scalatest.funsuite.AnyFunSuite
import org.apache.hadoop.fs.{FileSystem, Path}
class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
Map[String, Any] = {
Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
"checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
}
private def createNewModels():
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val (model2, model4) = {
val training = buildDataFrame(Classification.train)
val paramMap = produceParamMap(tmpPath, 2)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
}
(tmpPath, model2, model4)
}
test("test update/load models") {
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model2._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "1.ubj")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
manager.updateCheckpoint(model4._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "3.ubj")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
}
test("test cleanUpHigherVersions") {
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model4._booster)
manager.cleanUpHigherVersions(3)
assert(new File(s"$tmpPath/3.ubj").exists())
manager.cleanUpHigherVersions(2)
assert(!new File(s"$tmpPath/3.ubj").exists())
}
test("test checkpoint rounds") {
import scala.collection.JavaConverters._
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
}
private def trainingWithCheckpoint(cacheData: Boolean, skipCleanCheckpoint: Boolean): Unit = {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
val paramMap = produceParamMap(tmpPath, 2)
val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
val skipCleanCheckpointMap =
if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)
if (skipCleanCheckpoint) {
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.ubj")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.ubj")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) >= error(prevModel._booster))
assert(error(prevModel._booster) > error(nextModel._booster))
assert(error(nextModel._booster) < 0.1)
} else {
assert(!FileSystem.get(sc.hadoopConfiguration).exists(new Path(tmpPath)))
}
}
test("training with checkpoint boosters") {
trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = true)
}
test("training with checkpoint boosters with cached training dataset") {
trainingWithCheckpoint(cacheData = true, skipCleanCheckpoint = true)
}
test("the checkpoint file should be cleaned after a successful training") {
trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = false)
}
}

View File

@ -1,70 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.sql.functions._
import scala.util.Random
class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {
test("transform throwing exception if feature size of dataset is greater than model's") {
val modelPath = getClass.getResource("/model/0.82/model").getPath
val model = XGBoostClassificationModel.read.load(modelPath)
val r = new Random(0)
// 0.82/model was trained with 251 features, and transform will throw an exception
// if the feature size of the data is not equal to 251
var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
for (x <- 1 to 252) {
df = df.withColumn(s"feature_${x}", lit(1))
}
val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features")
val thrown = intercept[Exception] {
model.transform(assembler.transform(df)).show()
}
assert(thrown.getMessage.contains(
"Number of columns does not match number of features in booster"))
}
test("train throwing exception if feature size of dataset is different on distributed train") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val sparkSession = ss
import sparkSession.implicits._
val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)
.map(lp => (lp.label, lp)).partitionBy(
new Partitioner {
override def numPartitions: Int = 2
override def getPartition(key: Any): Int = key.asInstanceOf[Float].toInt
}
).map(_._2).zipWithIndex().map {
case (lp, id) =>
(id, lp.label, lp.features)
}.toDF("id", "label", "features")
val xgb = new XGBoostClassifier(paramMap)
xgb.fit(repartitioned)
}
}

View File

@ -1,235 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.scalatest.funsuite.AnyFunSuite
import scala.util.Random
import org.apache.spark.SparkException
class MissingValueHandlingSuite extends AnyFunSuite with PerTest {
test("dense vectors containing missing value") {
def buildDenseDataFrame(): DataFrame = {
val numRows = 100
val numCols = 5
val data = (0 until numRows).map { x =>
val label = Random.nextInt(2)
val values = Array.tabulate[Double](numCols) { c =>
if (c == numCols - 1) 0 else Random.nextDouble
}
(label, Vectors.dense(values))
}
ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
}
val denseDF = buildDenseDataFrame().repartition(4)
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
val model = new XGBoostClassifier(paramMap).fit(denseDF)
model.transform(denseDF).collect()
}
test("handle Float.NaN as missing value correctly") {
val spark = ss
import spark.implicits._
val testDF = Seq(
(1.0f, 0.0f, Float.NaN, 1.0),
(1.0f, 0.0f, 1.0f, 1.0),
(0.0f, 1.0f, 0.0f, 0.0),
(1.0f, 0.0f, 1.0f, 1.0),
(1.0f, Float.NaN, 0.0f, 0.0),
(0.0f, 1.0f, 0.0f, 1.0),
(Float.NaN, 0.0f, 0.0f, 1.0)
).toDF("col1", "col2", "col3", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3"))
.setOutputCol("features")
.setHandleInvalid("keep")
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
val model = new XGBoostClassifier(paramMap).fit(inputDF)
model.transform(inputDF).collect()
}
test("specify a non-zero missing value but with dense vector does not stop" +
" application") {
val spark = ss
import spark.implicits._
// Spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether to use a sparse or
// dense vector.
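// Illustration, assuming the heuristic stated above: with 3 feature columns, a row becomes
// sparse only if 1.5 * (nnz + 1.0) < 3, i.e. nnz = 0; every row below has at least one
// non-zero value, so the assembled vectors stay dense and the -1.0f entries are preserved.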
val testDF = Seq(
(1.0f, 0.0f, -1.0f, 1.0),
(1.0f, 0.0f, 1.0f, 1.0),
(0.0f, 1.0f, 0.0f, 0.0),
(1.0f, 0.0f, 1.0f, 1.0),
(1.0f, -1.0f, 0.0f, 0.0),
(0.0f, 1.0f, 0.0f, 1.0),
(-1.0f, 0.0f, 0.0f, 1.0)
).toDF("col1", "col2", "col3", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3"))
.setOutputCol("features")
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
val model = new XGBoostClassifier(paramMap).fit(inputDF)
model.transform(inputDF).collect()
}
test("specify a non-zero missing value and meet an empty vector we should" +
" stop the application") {
val spark = ss
import spark.implicits._
val testDF = Seq(
(1.0f, 0.0f, -1.0f, 1.0),
(1.0f, 0.0f, 1.0f, 1.0),
(0.0f, 1.0f, 0.0f, 0.0),
(1.0f, 0.0f, 1.0f, 1.0),
(1.0f, -1.0f, 0.0f, 0.0),
(0.0f, 0.0f, 0.0f, 1.0),// empty vector
(-1.0f, 0.0f, 0.0f, 1.0)
).toDF("col1", "col2", "col3", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3"))
.setOutputCol("features")
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
intercept[SparkException] {
new XGBoostClassifier(paramMap).fit(inputDF)
}
}
test("specify a non-zero missing value and meet a Sparse vector we should" +
" stop the application") {
val spark = ss
import spark.implicits._
// Spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether to use a sparse or
// dense vector.
val testDF = Seq(
(1.0f, 0.0f, -1.0f, 1.0f, 1.0),
(1.0f, 0.0f, 1.0f, 1.0f, 1.0),
(0.0f, 1.0f, 0.0f, 1.0f, 0.0),
(1.0f, 0.0f, 1.0f, 1.0f, 1.0),
(1.0f, -1.0f, 0.0f, 1.0f, 0.0),
(0.0f, 0.0f, 0.0f, 1.0f, 1.0),
(-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
).toDF("col1", "col2", "col3", "col4", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3", "col4"))
.setOutputCol("features")
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
inputDF.show()
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
intercept[SparkException] {
new XGBoostClassifier(paramMap).fit(inputDF)
}
}
test("specify a non-zero missing value but set allow_non_zero_for_missing " +
"does not stop application") {
val spark = ss
import spark.implicits._
// Spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether to use a sparse or
// dense vector.
val testDF = Seq(
(7.0f, 0.0f, -1.0f, 1.0f, 1.0),
(1.0f, 0.0f, 1.0f, 1.0f, 1.0),
(0.0f, 1.0f, 0.0f, 1.0f, 0.0),
(1.0f, 0.0f, 1.0f, 1.0f, 1.0),
(1.0f, -1.0f, 0.0f, 1.0f, 0.0),
(0.0f, 0.0f, 0.0f, 1.0f, 1.0),
(-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
).toDF("col1", "col2", "col3", "col4", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3", "col4"))
.setOutputCol("features")
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
inputDF.show()
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> -1.0f,
"num_workers" -> 1, "allow_non_zero_for_missing" -> "true").toMap
val model = new XGBoostClassifier(paramMap).fit(inputDF)
model.transform(inputDF).collect()
}
// https://github.com/dmlc/xgboost/pull/5929
test("handle the empty last row correctly with a missing value as 0") {
val spark = ss
import spark.implicits._
// Spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether to use a sparse or
// dense vector.
val testDF = Seq(
(7.0f, 0.0f, -1.0f, 1.0f, 1.0),
(1.0f, 0.0f, 1.0f, 1.0f, 1.0),
(0.0f, 1.0f, 0.0f, 1.0f, 0.0),
(1.0f, 0.0f, 1.0f, 1.0f, 1.0),
(1.0f, -1.0f, 0.0f, 1.0f, 0.0),
(0.0f, 0.0f, 0.0f, 1.0f, 1.0),
(0.0f, 0.0f, 0.0f, 0.0f, 0.0)
).toDF("col1", "col2", "col3", "col4", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3", "col4"))
.setOutputCol("features")
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
inputDF.show()
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> 0.0f,
"num_workers" -> 1, "allow_non_zero_for_missing" -> "true").toMap
val model = new XGBoostClassifier(paramMap).fit(inputDF)
model.transform(inputDF).collect()
}
test("Getter and setter for AllowNonZeroForMissingValue works") {
{
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
val training = buildDataFrame(Classification.train)
val classifier = new XGBoostClassifier(paramMap)
classifier.setAllowNonZeroForMissing(true)
assert(classifier.getAllowNonZeroForMissingValue)
classifier.setAllowNonZeroForMissing(false)
assert(!classifier.getAllowNonZeroForMissingValue)
val model = classifier.fit(training)
model.setAllowNonZeroForMissing(true)
assert(model.getAllowNonZeroForMissingValue)
model.setAllowNonZeroForMissing(false)
assert(!model.getAllowNonZeroForMissingValue)
}
{
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
val training = buildDataFrame(Regression.train)
val regressor = new XGBoostRegressor(paramMap)
regressor.setAllowNonZeroForMissing(true)
assert(regressor.getAllowNonZeroForMissingValue)
regressor.setAllowNonZeroForMissing(false)
assert(!regressor.getAllowNonZeroForMissingValue)
val model = regressor.fit(training)
model.setAllowNonZeroForMissing(true)
assert(model.getAllowNonZeroForMissingValue)
model.setAllowNonZeroForMissing(false)
assert(!model.getAllowNonZeroForMissingValue)
}
}
}

View File

@ -1,104 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.SparkException
import org.apache.spark.ml.param.ParamMap
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
"objective_type" -> "classification")
// from xgboost params to spark params
val xgb = new XGBoostClassifier(xgbParamMap)
assert(xgb.getEta === 1.0)
assert(xgb.getObjective === "binary:logistic")
assert(xgb.getObjectiveType === "classification")
// from spark to xgboost params
val xgbCopy = xgb.copy(ParamMap.empty)
assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
}
test("fail training elegantly with unsupported objective function") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "wrong_objective_function", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers)
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
intercept[SparkException] {
xgb.fit(trainingDF)
}
}
test("fail training elegantly with unsupported eval metrics") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "eval_metric" -> "wrong_eval_metrics")
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
intercept[SparkException] {
xgb.fit(trainingDF)
}
}
test("custom_eval does not support early stopping") {
val paramMap = Map("eta" -> "0.1", "custom_eval" -> new EvalError, "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2)
val trainingDF = buildDataFrame(MultiClassification.train)
val thrown = intercept[IllegalArgumentException] {
new XGBoostClassifier(paramMap).fit(trainingDF)
}
assert(thrown.getMessage.contains("custom_eval does not support early stopping"))
}
test("early stopping should work without custom_eval setting") {
val paramMap = Map("eta" -> "0.1", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2)
val trainingDF = buildDataFrame(MultiClassification.train)
new XGBoostClassifier(paramMap).fit(trainingDF)
}
test("Default parameters") {
val classifier = new XGBoostClassifier()
intercept[NoSuchElementException] {
classifier.getBaseScore
}
}
test("approx can't be used for gpu train") {
val paramMap = Map("tree_method" -> "approx", "device" -> "cuda")
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val thrown = intercept[IllegalArgumentException] {
xgb.fit(trainingDF)
}
assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " +
"for Spark GPU cluster"))
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -18,24 +18,25 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.{File, FileInputStream}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.IOUtils
import org.apache.spark.SparkContext
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql._
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import scala.math.min
import scala.util.Random
import org.apache.commons.io.IOUtils
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.spark.Utils.{withResource, XGBLabeledPointFeatures}
trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
trait PerTest extends BeforeAndAfterEach {
self: AnyFunSuite =>
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
protected val numWorkers: Int = 4
@transient private var currentSession: SparkSession = _
def ss: SparkSession = getOrCreateSession
implicit def sc: SparkContext = ss.sparkContext
protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
@ -45,10 +46,11 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
.config("spark.driver.memory", "512m")
.config("spark.barrier.sync.timeout", 10)
.config("spark.task.cpus", 1)
.config("spark.stage.maxConsecutiveAttempts", 1)
override def beforeEach(): Unit = getOrCreateSession
override def afterEach() {
override def afterEach(): Unit = {
if (currentSession != null) {
currentSession.stop()
cleanExternalCache(currentSession.sparkContext.appName)
@ -74,42 +76,25 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
protected def buildDataFrame(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features)
(id, labeledPoint.label, labeledPoint.features, labeledPoint.weight)
}
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features")
}
protected def buildDataFrameWithRandSort(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
val df = buildDataFrame(labeledPoints, numPartitions)
val rndSortedRDD = df.rdd.mapPartitions { iter =>
iter.map(_ -> Random.nextDouble()).toList
.sortBy(_._2)
.map(_._1).iterator
}
ss.createDataFrame(rndSortedRDD, df.schema)
.toDF("id", "label", "features", "weight")
}
protected def buildDataFrameWithGroup(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features, labeledPoint.group)
(id, labeledPoint.label, labeledPoint.features, labeledPoint.group, labeledPoint.weight)
}
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features", "group")
.toDF("id", "label", "features", "group", "weight")
}
protected def compareTwoFiles(lhs: String, rhs: String): Boolean = {
withResource(new FileInputStream(lhs)) { lfis =>
withResource(new FileInputStream(rhs)) { rfis =>
@ -118,12 +103,32 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
}
}
/** Executes the provided code block and then closes the resource */
protected def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
}
def smallBinaryClassificationVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0)),
(0.0, 0.4, -3.0, Vectors.dense(0.0, 0.0, 0.0)),
(0.0, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
(1.0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
(0.0, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
(1.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
))).toDF("label", "margin", "weight", "features")
def smallMultiClassificationVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0)),
(0.0, 0.4, -3.0, Vectors.dense(0.0, 0.0, 0.0)),
(2.0, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
(1.0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
(0.0, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
(2.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
))).toDF("label", "margin", "weight", "features")
def smallGroupVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0, 0.5, 2.0, Vectors.dense(1.0, 2.0, 3.0)),
(0.0, 1, 0.4, 1.0, Vectors.dense(0.0, 0.0, 0.0)),
(0.0, 1, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
(1.0, 0, 1.2, 2.0, Vectors.dense(2.0, 0.0, 4.0)),
(1.0, 2, -0.5, 3.0, Vectors.dense(0.2, 1.2, 2.0)),
(0.0, 2, -0.4, 3.0, Vectors.dense(0.5, 2.2, 1.7))
))).toDF("label", "group", "margin", "weight", "features")
}
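The group-aware fixture above pairs naturally with the new ranking estimator. A minimal sketch, assuming XGBoostRanker obtains its group column handling from RankerParams and that the fixture's "label" and "features" columns match the Spark defaults:

// Sketch only: round and worker counts are illustrative.
val ranker = new XGBoostRanker()
  .setGroupCol("group")   // query/group ids from the smallGroupVector fixture
  .setNumRound(10)
  .setNumWorkers(1)
val model = ranker.fit(smallGroupVector)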

View File

@ -1,195 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.util.Arrays
import ml.dmlc.xgboost4j.scala.DMatrix
import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.functions._
import org.scalatest.funsuite.AnyFunSuite
class PersistenceSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
val eval = new EvalError()
val trainingDF = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
val xgbc = new XGBoostClassifier(paramMap)
val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
xgbc.write.overwrite().save(xgbcPath)
val xgbc2 = XGBoostClassifier.load(xgbcPath)
val paramMap2 = xgbc2.MLlib2XGBoostParams
paramMap.foreach {
case (k, v) => assert(v.toString == paramMap2(k).toString)
}
val model = xgbc.fit(trainingDF)
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults < 0.1)
val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
model.write.overwrite.save(xgbcModelPath)
val model2 = XGBoostClassificationModel.load(xgbcModelPath)
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
assert(model.getEta === model2.getEta)
assert(model.getNumRound === model2.getNumRound)
assert(model.getRawPredictionCol === model2.getRawPredictionCol)
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults === evalResults2)
}
test("test persistence of XGBoostRegressor and XGBoostRegressionModel") {
val eval = new EvalError()
val trainingDF = buildDataFrame(Regression.train)
val testDM = new DMatrix(Regression.test.iterator)
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> "10", "num_workers" -> numWorkers)
val xgbr = new XGBoostRegressor(paramMap)
val xgbrPath = new File(tempDir.toFile, "xgbr").getPath
xgbr.write.overwrite().save(xgbrPath)
val xgbr2 = XGBoostRegressor.load(xgbrPath)
val paramMap2 = xgbr2.MLlib2XGBoostParams
paramMap.foreach {
case (k, v) => assert(v.toString == paramMap2(k).toString)
}
val model = xgbr.fit(trainingDF)
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults < 0.1)
val xgbrModelPath = new File(tempDir.toFile, "xgbrModel").getPath
model.write.overwrite.save(xgbrModelPath)
val model2 = XGBoostRegressionModel.load(xgbrModelPath)
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
assert(model.getEta === model2.getEta)
assert(model.getNumRound === model2.getNumRound)
assert(model.getPredictionCol === model2.getPredictionCol)
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults === evalResults2)
}
test("test persistence of MLlib pipeline with XGBoostClassificationModel") {
val r = new Random(0)
// maybe move to shared context, but requires session to import implicits
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features")
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
val xgb = new XGBoostClassifier(paramMap)
// Construct MLlib pipeline, save and load
val pipeline = new Pipeline().setStages(Array(assembler, xgb))
val pipePath = new File(tempDir.toFile, "pipeline").getPath
pipeline.write.overwrite().save(pipePath)
val pipeline2 = Pipeline.read.load(pipePath)
val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier]
val paramMap2 = xgb2.MLlib2XGBoostParams
paramMap.foreach {
case (k, v) => assert(v.toString == paramMap2(k).toString)
}
// Model training, save and load
val pipeModel = pipeline.fit(df)
val pipeModelPath = new File(tempDir.toFile, "pipelineModel").getPath
pipeModel.write.overwrite.save(pipeModelPath)
val pipeModel2 = PipelineModel.load(pipeModelPath)
val xgbModel = pipeModel.stages(1).asInstanceOf[XGBoostClassificationModel]
val xgbModel2 = pipeModel2.stages(1).asInstanceOf[XGBoostClassificationModel]
assert(Arrays.equals(xgbModel._booster.toByteArray, xgbModel2._booster.toByteArray))
assert(xgbModel.getEta === xgbModel2.getEta)
assert(xgbModel.getNumRound === xgbModel2.getNumRound)
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
}
test("test persistence of XGBoostClassifier and XGBoostClassificationModel " +
"using custom Eval and Obj") {
val trainingDF = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"custom_eval" -> new EvalError, "custom_obj" -> new CustomObj(1),
"num_round" -> "10", "num_workers" -> numWorkers, "objective" -> "binary:logistic")
val xgbc = new XGBoostClassifier(paramMap)
val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
xgbc.write.overwrite().save(xgbcPath)
val xgbc2 = XGBoostClassifier.load(xgbcPath)
val paramMap2 = xgbc2.MLlib2XGBoostParams
paramMap.foreach {
case ("custom_eval", v) => assert(v.isInstanceOf[EvalError])
case ("custom_obj", v) =>
assert(v.isInstanceOf[CustomObj])
assert(v.asInstanceOf[CustomObj].customParameter ==
paramMap2("custom_obj").asInstanceOf[CustomObj].customParameter)
case (_, _) =>
}
val eval = new EvalError()
val model = xgbc.fit(trainingDF)
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults < 0.1)
val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
model.write.overwrite.save(xgbcModelPath)
val model2 = XGBoostClassificationModel.load(xgbcModelPath)
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
assert(model.getEta === model2.getEta)
assert(model.getNumRound === model2.getNumRound)
assert(model.getRawPredictionCol === model2.getRawPredictionCol)
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults === evalResults2)
}
test("cross-version model loading (0.82)") {
val modelPath = getClass.getResource("/model/0.82/model").getPath
val model = XGBoostClassificationModel.read.load(modelPath)
val r = new Random(0)
var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
// The 0.82 model was trained with 251 features, and transform will throw an exception
// if the feature size of the data does not equal 251
for (x <- 1 to 250) {
df = df.withColumn(s"feature_${x}", lit(1))
}
val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features")
df = assembler.transform(df)
for (x <- 1 to 250) {
df = df.drop(s"feature_${x}")
}
model.transform(df).show()
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,8 +16,9 @@
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import scala.io.Source
import scala.util.Random
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
trait TrainTestData {
@ -31,8 +32,8 @@ trait TrainTestData {
Source.fromInputStream(is).getLines()
}
protected def getLabeledPoints(resource: String, featureSize: Int, zeroBased: Boolean):
Seq[XGBLabeledPoint] = {
protected def getLabeledPoints(resource: String, featureSize: Int,
zeroBased: Boolean): Seq[XGBLabeledPoint] = {
getResourceLines(resource).map { line =>
val labelAndFeatures = line.split(" ")
val label = labelAndFeatures.head.toFloat
@ -65,10 +66,32 @@ trait TrainTestData {
object Classification extends TrainTestData {
val train: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.train", 126, zeroBased = false)
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.test", 126, zeroBased = false)
Random.setSeed(10)
val randomWeights = Array.fill(train.length)(Random.nextFloat())
val trainWithWeight = train.zipWithIndex.map { case (v, index) =>
XGBLabeledPoint(v.label, v.size, v.indices, v.values,
randomWeights(index), v.group, v.baseMargin)
}
}
object MultiClassification extends TrainTestData {
val train: Seq[XGBLabeledPoint] = getLabeledPoints("/dermatology.data")
private def split(): (Seq[XGBLabeledPoint], Seq[XGBLabeledPoint]) = {
val tmp: Seq[XGBLabeledPoint] = getLabeledPoints("/dermatology.data")
Random.setSeed(100)
val randomizedTmp = Random.shuffle(tmp)
val splitIndex = (randomizedTmp.length * 0.8).toInt
(randomizedTmp.take(splitIndex), randomizedTmp.drop(splitIndex))
}
val (train, test) = split()
Random.setSeed(10)
val randomWeights = Array.fill(train.length)(Random.nextFloat())
val trainWithWeight = train.zipWithIndex.map { case (v, index) =>
XGBLabeledPoint(v.label, v.size, v.indices, v.values,
randomWeights(index), v.group, v.baseMargin)
}
private def getLabeledPoints(resource: String): Seq[XGBLabeledPoint] = {
getResourceLines(resource).map { line =>
@ -92,31 +115,25 @@ object Regression extends TrainTestData {
"/machine.txt.train", MACHINE_COL_NUM, zeroBased = true)
val test: Seq[XGBLabeledPoint] = getLabeledPoints(
"/machine.txt.test", MACHINE_COL_NUM, zeroBased = true)
}
object Ranking extends TrainTestData {
Random.setSeed(10)
val randomWeights = Array.fill(train.length)(Random.nextFloat())
val trainWithWeight = train.zipWithIndex.map { case (v, index) =>
XGBLabeledPoint(v.label, v.size, v.indices, v.values,
randomWeights(index), v.group, v.baseMargin)
}
object Ranking extends TrainTestData {
val RANK_COL_NUM = 3
val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
// use the group as the weight
val trainWithWeight = train.map { labelPoint =>
XGBLabeledPoint(labelPoint.label, labelPoint.size, labelPoint.indices, labelPoint.values,
labelPoint.group, labelPoint.group, labelPoint.baseMargin)
}
val trainGroups = train.map(_.group)
val test: Seq[XGBLabeledPoint] = getLabeledPoints(
"/rank.test.txt", RANK_COL_NUM, zeroBased = false)
private def getGroups(resource: String): Seq[Int] = {
getResourceLines(resource).map(_.toInt).toList
}
}
object Synthetic extends {
val TRAIN_COL_NUM = 3
val TRAIN_WRONG_COL_NUM = 2
val train: Seq[XGBLabeledPoint] = Seq(
XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f))
)
val trainWithDiffFeatureSize: Seq[XGBLabeledPoint] = Seq(
XGBLabeledPoint(1.0f, TRAIN_WRONG_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f))
)
}

View File

@ -16,241 +16,212 @@
package ml.dmlc.xgboost4j.scala.spark
import java.io.{File, FileInputStream}
import java.io.File
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
import org.scalatest.funsuite.AnyFunSuite
import org.apache.commons.io.IOUtils
import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
import org.json4s.{DefaultFormats, Formats}
import org.json4s.jackson.parseJson
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{BINARY_CLASSIFICATION_OBJS, MULTICLASSIFICATION_OBJS}
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"
test("XGBoostClassifier copy") {
val classifier = new XGBoostClassifier().setNthread(2).setNumWorkers(10)
val classifierCopied = classifier.copy(ParamMap.empty)
test("Set params in XGBoost and MLlib way should produce same model") {
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val round = 5
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> round,
"tree_method" -> treeMethod,
"num_workers" -> numWorkers)
// Set params in XGBoost way
val model1 = new XGBoostClassifier(paramMap).fit(trainingDF)
// Set params in MLlib way
val model2 = new XGBoostClassifier()
.setEta(1)
.setMaxDepth(6)
.setSilent(1)
.setObjective("binary:logistic")
.setNumRound(round)
.setNumWorkers(numWorkers)
.fit(trainingDF)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val prediction2 = model2.transform(testDF).select("prediction").collect()
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(p1 === p2)
}
assert(classifier.uid === classifierCopied.uid)
assert(classifier.getNthread === classifierCopied.getNthread)
assert(classifier.getNumWorkers === classifier.getNumWorkers)
}
test("test schema of XGBoostClassificationModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val model = new XGBoostClassifier(paramMap).fit(trainingDF)
model.setRawPredictionCol("raw_prediction")
.setProbabilityCol("probability_prediction")
.setPredictionCol("final_prediction")
var predictionDF = model.transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("probability_prediction"))
assert(predictionDF.columns.contains("final_prediction"))
model.setRawPredictionCol("").setPredictionCol("final_prediction")
predictionDF = model.transform(testDF)
assert(predictionDF.columns.contains("raw_prediction") === false)
assert(predictionDF.columns.contains("final_prediction"))
model.setRawPredictionCol("raw_prediction").setPredictionCol("")
predictionDF = model.transform(testDF)
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("final_prediction") === false)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.validationObjectiveHistory.isEmpty)
test("XGBoostClassification copy") {
val model = new XGBoostClassificationModel("hello").setNthread(2).setNumWorkers(10)
val modelCopied = model.copy(ParamMap.empty)
assert(model.uid === modelCopied.uid)
assert(model.getNthread === modelCopied.getNthread)
assert(model.getNumWorkers === modelCopied.getNumWorkers)
}
test("multi class classification") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "tree_method" -> treeMethod)
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
assert(model.getEta == 0.1)
assert(model.getMaxDepth == 6)
assert(model.numClasses == 6)
val transformedDf = model.transform(trainingDF)
assert(!transformedDf.columns.contains("probability"))
test("read/write") {
val trainDf = smallBinaryClassificationVector
val xgbParams: Map[String, Any] = Map(
"max_depth" -> 5,
"eta" -> 0.2,
"objective" -> "binary:logistic"
)
def check(xgboostParams: XGBoostParams[_]): Unit = {
assert(xgboostParams.getMaxDepth === 5)
assert(xgboostParams.getEta === 0.2)
assert(xgboostParams.getObjective === "binary:logistic")
}
test("objective will be set if not specifying it") {
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val xgb = new XGBoostClassifier(paramMap)
assert(!xgb.isDefined(xgb.objective))
xgb.fit(training)
assert(xgb.getObjective == "binary:logistic")
val classifierPath = new File(tempDir.toFile, "classifier").getPath
val classifier = new XGBoostClassifier(xgbParams).setNumRound(2)
check(classifier)
val trainingDF = buildDataFrame(MultiClassification.train)
val paramMap1 = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"num_class" -> "6", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val xgb1 = new XGBoostClassifier(paramMap1)
assert(!xgb1.isDefined(xgb1.objective))
xgb1.fit(trainingDF)
assert(xgb1.getObjective == "multi:softprob")
classifier.write.overwrite().save(classifierPath)
val loadedClassifier = XGBoostClassifier.load(classifierPath)
check(loadedClassifier)
// shouldn't change user's objective setting
val paramMap2 = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"num_class" -> "6", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod, "objective" -> "multi:softmax")
val xgb2 = new XGBoostClassifier(paramMap2)
assert(xgb2.getObjective == "multi:softmax")
xgb2.fit(trainingDF)
assert(xgb2.getObjective == "multi:softmax")
val model = loadedClassifier.fit(trainDf)
check(model)
assert(model.numClasses === 2)
val modelPath = new File(tempDir.toFile, "model").getPath
model.write.overwrite().save(modelPath)
val modelLoaded = XGBoostClassificationModel.load(modelPath)
assert(modelLoaded.numClasses === 2)
check(modelLoaded)
}
test("use base margin") {
val training1 = buildDataFrame(Classification.train)
val training2 = training1.withColumn("margin", functions.rand())
val test = buildDataFrame(Classification.test)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
test("XGBoostClassificationModel transformed schema") {
val trainDf = smallBinaryClassificationVector
val classifier = new XGBoostClassifier().setNumRound(1)
val model = classifier.fit(trainDf)
var out = model.transform(trainDf)
val xgb = new XGBoostClassifier(paramMap)
val model1 = xgb.fit(training1)
val model2 = xgb.setBaseMarginCol("margin").fit(training2)
val prediction1 = model1.transform(test).select(model1.getProbabilityCol)
.collect().map(row => row.getAs[Vector](0))
val prediction2 = model2.transform(test).select(model2.getProbabilityCol)
.collect().map(row => row.getAs[Vector](0))
var count = 0
for ((r1, r2) <- prediction1.zip(prediction2)) {
if (!r1.equals(r2)) count = count + 1
}
assert(count != 0)
// Transform should not discard the other columns of the transforming dataframe
Seq("label", "margin", "weight", "features").foreach { v =>
assert(out.schema.names.contains(v))
}
test("test predictionLeaf") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Classification.train)
val test = buildDataFrame(Classification.test)
val groundTruth = test.count()
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(training)
model.setLeafPredictionCol("predictLeaf")
val resultDF = model.transform(test)
assert(resultDF.count == groundTruth)
assert(resultDF.columns.contains("predictLeaf"))
// Transform needs to add extra columns
Seq("rawPrediction", "probability", "prediction").foreach { v =>
assert(out.schema.names.contains(v))
}
test("test predictionLeaf with empty column name") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Classification.train)
val test = buildDataFrame(Classification.test)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(training)
model.setLeafPredictionCol("")
val resultDF = model.transform(test)
assert(!resultDF.columns.contains("predictLeaf"))
assert(out.schema.names.length === 7)
model.setRawPredictionCol("").setProbabilityCol("")
out = model.transform(trainDf)
// rawPrediction="", probability=""
Seq("rawPrediction", "probability").foreach { v =>
assert(!out.schema.names.contains(v))
}
test("test predictionContrib") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Classification.train)
val test = buildDataFrame(Classification.test)
val groundTruth = test.count()
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(training)
model.setContribPredictionCol("predictContrib")
val resultDF = model.transform(buildDataFrame(Classification.test))
assert(resultDF.count == groundTruth)
assert(resultDF.columns.contains("predictContrib"))
assert(out.schema.names.contains("prediction"))
model.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
out = model.transform(trainDf)
assert(out.schema.names.contains("leaf"))
assert(out.schema.names.contains("contrib"))
val out1 = classifier.setLeafPredictionCol("leaf1")
.setContribPredictionCol("contrib1")
.fit(trainDf).transform(trainDf)
assert(out1.schema.names.contains("leaf1"))
assert(out1.schema.names.contains("contrib1"))
}
test("test predictionContrib with empty column name") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Classification.train)
val test = buildDataFrame(Classification.test)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(training)
model.setContribPredictionCol("")
val resultDF = model.transform(test)
assert(!resultDF.columns.contains("predictContrib"))
test("Supported objectives") {
val classifier = new XGBoostClassifier()
val df = smallMultiClassificationVector
(BINARY_CLASSIFICATION_OBJS.toSeq ++ MULTICLASSIFICATION_OBJS.toSeq).foreach { obj =>
classifier.setObjective(obj)
classifier.validate(df)
}
test("test predictionLeaf and predictionContrib") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Classification.train)
val test = buildDataFrame(Classification.test)
val groundTruth = test.count()
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(training)
model.setLeafPredictionCol("predictLeaf")
model.setContribPredictionCol("predictContrib")
val resultDF = model.transform(buildDataFrame(Classification.test))
assert(resultDF.count == groundTruth)
assert(resultDF.columns.contains("predictLeaf"))
assert(resultDF.columns.contains("predictContrib"))
classifier.setObjective("reg:squaredlogerror")
intercept[IllegalArgumentException](
classifier.validate(df)
)
}
test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
test("BinaryClassification infer objective and num_class") {
val trainDf = smallBinaryClassificationVector
var classifier = new XGBoostClassifier()
assert(classifier.getObjective === "reg:squarederror")
assert(classifier.getNumClass === 0)
classifier.validate(trainDf)
assert(classifier.getObjective === "binary:logistic")
assert(!classifier.isSet(classifier.numClass))
// Infer the objective according to num_class
classifier = new XGBoostClassifier()
classifier.setNumClass(2)
intercept[IllegalArgumentException](
classifier.validate(trainDf)
)
// Infer num_class according to the objective
classifier = new XGBoostClassifier()
classifier.setObjective("binary:logistic")
classifier.validate(trainDf)
assert(classifier.getObjective === "binary:logistic")
assert(!classifier.isSet(classifier.numClass))
}
test("MultiClassification infer objective and num_class") {
val trainDf = smallMultiClassificationVector
var classifier = new XGBoostClassifier()
assert(classifier.getObjective === "reg:squarederror")
assert(classifier.getNumClass === 0)
classifier.validate(trainDf)
assert(classifier.getObjective === "multi:softprob")
assert(classifier.getNumClass === 3)
// Infer the objective according to num_class
classifier = new XGBoostClassifier()
classifier.setNumClass(3)
classifier.validate(trainDf)
assert(classifier.getObjective === "multi:softprob")
assert(classifier.getNumClass === 3)
// Infer num_class according to the objective
classifier = new XGBoostClassifier()
classifier.setObjective("multi:softmax")
classifier.validate(trainDf)
assert(classifier.getObjective === "multi:softmax")
assert(classifier.getNumClass === 3)
}
test("XGBoost-Spark binary classification output should match XGBoost4j") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
val paramMap = Map("objective" -> "binary:logistic")
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
}
test("XGBoostClassifier should make correct predictions after upstream random sort") {
val trainingDM = new DMatrix(Classification.train.iterator)
test("XGBoost-Spark binary classification output with weight should match XGBoost4j") {
val trainingDM = new DMatrix(Classification.trainWithWeight.iterator)
trainingDM.setWeight(Classification.randomWeights)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrameWithRandSort(Classification.train)
val testDF = buildDataFrameWithRandSort(Classification.test)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
val trainingDF = buildDataFrame(Classification.trainWithWeight)
val testDF = buildDataFrame(Classification.test)
val paramMap = Map("objective" -> "binary:logistic")
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF,
5, paramMap, Some("weight"))
}
Seq("multi:softprob", "multi:softmax").foreach { objective =>
test(s"XGBoost-Spark multi classification with $objective output should match XGBoost4j") {
val trainingDM = new DMatrix(MultiClassification.train.iterator)
val testDM = new DMatrix(MultiClassification.test.iterator)
val trainingDF = buildDataFrame(MultiClassification.train)
val testDF = buildDataFrame(MultiClassification.test)
val paramMap = Map("objective" -> "multi:softprob", "num_class" -> 6)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
}
}
test("XGBoost-Spark multi classification output with weight should match XGBoost4j") {
val trainingDM = new DMatrix(MultiClassification.trainWithWeight.iterator)
trainingDM.setWeight(MultiClassification.randomWeights)
val testDM = new DMatrix(MultiClassification.test.iterator)
val trainingDF = buildDataFrame(MultiClassification.trainWithWeight)
val testDF = buildDataFrame(MultiClassification.test)
val paramMap = Map("objective" -> "multi:softprob", "num_class" -> 6)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap, Some("weight"))
}
private def checkResultsWithXGBoost4j(
@ -258,223 +229,73 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
testDM: DMatrix,
trainingDF: DataFrame,
testDF: DataFrame,
round: Int = 5): Unit = {
round: Int = 5,
xgbParams: Map[String, Any] = Map.empty,
weightCol: Option[String] = None): Unit = {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"base_score" -> 0.5,
"objective" -> "binary:logistic",
"tree_method" -> treeMethod,
"max_bin" -> 16)
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
val prediction1 = model1.predict(testDM)
"max_bin" -> 16) ++ xgbParams
val xgb4jModel = ScalaXGBoost.train(trainingDM, paramMap, round)
val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
"num_workers" -> numWorkers)).fit(trainingDF)
val classifier = new XGBoostClassifier(paramMap)
.setNumRound(round)
.setNumWorkers(numWorkers)
.setLeafPredictionCol("leaf")
.setContribPredictionCol("contrib")
weightCol.foreach(weight => classifier.setWeightCol(weight))
val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
assert(testDF.count() === prediction2.size)
// the vector length in the probability column is 2 so that it is compatible with the Spark evaluator
for (i <- prediction1.indices) {
assert(prediction1(i).length === prediction2(i).values.length - 1)
for (j <- prediction1(i).indices) {
assert(prediction1(i)(j) === prediction2(i)(j + 1))
def checkEqual(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
assert(left.size === right.size)
left.zipWithIndex.foreach { case (leftValue, index) =>
assert(leftValue.sameElements(right(index)))
}
}
val prediction3 = model1.predict(testDM, outPutMargin = true)
val prediction4 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
val xgbSparkModel = classifier.fit(trainingDF)
val rows = xgbSparkModel.transform(testDF).collect()
assert(testDF.count() === prediction4.size)
// the vector length in the rawPrediction column is 2 so that it is compatible with the Spark evaluator
for (i <- prediction3.indices) {
assert(prediction3(i).length === prediction4(i).values.length - 1)
for (j <- prediction3(i).indices) {
assert(prediction3(i)(j) === prediction4(i)(j + 1))
// Check Leaf
val xgb4jLeaf = xgb4jModel.predictLeaf(testDM)
val xgbSparkLeaf = rows.map(row =>
(row.getAs[Int]("id"), row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))).toMap
checkEqual(xgb4jLeaf, xgbSparkLeaf)
// Check contrib
val xgb4jContrib = xgb4jModel.predictContrib(testDM)
val xgbSparkContrib = rows.map(row =>
(row.getAs[Int]("id"), row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))).toMap
checkEqual(xgb4jContrib, xgbSparkContrib)
def checkEqualForBinary(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
assert(left.size === right.size)
left.zipWithIndex.foreach { case (leftValue, index) =>
assert(leftValue.length === 1)
assert(leftValue.length === right(index).length - 1)
assert(leftValue(0) === right(index)(1))
}
}
// check the equality of single instance prediction
val firstOfDM = testDM.slice(Array(0))
val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
.head()
.getAs[Vector]("features")
val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
val prediction6 = model2.predict(firstOfDF)
assert(prediction5 === prediction6)
// Check probability
val xgb4jProb = xgb4jModel.predict(testDM)
val xgbSparkProb = rows.map(row =>
(row.getAs[Int]("id"), row.getAs[DenseVector]("probability").toArray.map(_.toFloat))).toMap
if (BINARY_CLASSIFICATION_OBJS.contains(classifier.getObjective)) {
checkEqualForBinary(xgb4jProb, xgbSparkProb)
} else {
checkEqual(xgb4jProb, xgbSparkProb)
}
test("infrequent features") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> 2, "missing" -> 0)
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._
val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
new Partitioner {
override def numPartitions: Int = 2
override def getPartition(key: Any): Int = key.asInstanceOf[Float].toInt
// Check rawPrediction
val xgb4jRawPred = xgb4jModel.predict(testDM, outPutMargin = true)
val xgbSparkRawPred = rows.map(row =>
(row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction").toArray.map(_.toFloat))).toMap
if (BINARY_CLASSIFICATION_OBJS.contains(classifier.getObjective)) {
checkEqualForBinary(xgb4jRawPred, xgbSparkRawPred)
} else {
checkEqual(xgb4jRawPred, xgbSparkRawPred)
}
).map(_._2).zipWithIndex().map {
case (lp, id) =>
(id, lp.label, lp.features)
}.toDF("id", "label", "features")
val xgb = new XGBoostClassifier(paramMap)
xgb.fit(repartitioned)
}
test("infrequent features (use_external_memory)") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._
val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
new Partitioner {
override def numPartitions: Int = 2
override def getPartition(key: Any): Int = key.asInstanceOf[Float].toInt
}
).map(_._2).zipWithIndex().map {
case (lp, id) =>
(id, lp.label, lp.features)
}.toDF("id", "label", "features")
val xgb = new XGBoostClassifier(paramMap)
xgb.fit(repartitioned)
}
test("featuresCols with features column can work") {
val spark = ss
import spark.implicits._
val xgbInput = Seq(
(Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
(Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
.toDF("f1", "f2", "f3", "features", "label")
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1)
val featuresName = Array("f1", "f2", "f3", "features")
val xgbClassifier = new XGBoostClassifier(paramMap)
.setFeaturesCol(featuresName)
.setLabelCol("label")
val model = xgbClassifier.fit(xgbInput)
assert(model.getFeaturesCols.sameElements(featuresName))
val df = model.transform(xgbInput)
assert(df.schema.fieldNames.contains("features_" + model.uid))
df.show()
val newFeatureName = "features_new"
// transform can also work on a vectorized dataset
val vectorizedInput = new VectorAssembler()
.setInputCols(featuresName)
.setOutputCol(newFeatureName)
.transform(xgbInput)
.select(newFeatureName, "label")
val df1 = model
.setFeaturesCol(newFeatureName)
.transform(vectorizedInput)
assert(df1.schema.fieldNames.contains(newFeatureName))
df1.show()
}
test("featuresCols without features column can work") {
val spark = ss
import spark.implicits._
val xgbInput = Seq(
(Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
(Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
.toDF("f1", "f2", "f3", "f4", "label")
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1)
val featuresName = Array("f1", "f2", "f3", "f4")
val xgbClassifier = new XGBoostClassifier(paramMap)
.setFeaturesCol(featuresName)
.setLabelCol("label")
.setEvalSets(Map("eval" -> xgbInput))
val model = xgbClassifier.fit(xgbInput)
assert(model.getFeaturesCols.sameElements(featuresName))
// transform should work for the dataset which includes the feature column names.
val df = model.transform(xgbInput)
assert(df.schema.fieldNames.contains("features"))
df.show()
// transform can also work on a vectorized dataset
val vectorizedInput = new VectorAssembler()
.setInputCols(featuresName)
.setOutputCol("features")
.transform(xgbInput)
.select("features", "label")
val df1 = model.transform(vectorizedInput)
df1.show()
}
test("XGBoostClassificationModel should be compatible") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "tree_method" -> treeMethod)
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
// test json
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
model.nativeBooster.saveModel(nativeJsonModelPath)
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
nativeJsonModelPath))
// test ubj
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath))
// the json file should differ from the ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath1 = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath1)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath1))
}
test("native json model file should store feature_name and feature_type") {
val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
val featureTypes = (1 to 33).map(idx => "q").toArray
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "tree_method" -> treeMethod
)
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
.setFeatureNames(featureNames)
.setFeatureTypes(featureTypes)
val model = xgb.fit(trainingDF)
val modelStr = new String(model._booster.toByteArray("json"))
val jsonModel = parseJson(modelStr)
implicit val formats: Formats = DefaultFormats
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
assert(featureNamesInModel.length == 33)
assert(featureTypesInModel.length == 33)
}
}

View File

@ -1,75 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.Booster
import scala.collection.JavaConverters._
import org.apache.spark.sql._
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.SparkException
class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val predictionErrorMin = 0.00001f
val maxFailure = 2
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.kryo.classesToRegister", classOf[Booster].getName)
.master(s"local[${numWorkers},${maxFailure}]")
test("test classification prediction parity w/o ring reduce") {
val training = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
val model1 = new XGBoostClassifier(xgbSettings).fit(training)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
.fit(training)
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(p1 == p2)
}
}
test("test regression prediction parity w/o ring reduce") {
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
val model1 = new XGBoostRegressor(xgbSettings).fit(training)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
).fit(training)
// check the equality of single instance prediction
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(math.abs(p1 - p2) < predictionErrorMin)
}
}
}

View File

@ -1,81 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.spark.sql._
import org.scalatest.funsuite.AnyFunSuite
class XGBoostConfigureSuite extends AnyFunSuite with PerTest {
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.kryo.classesToRegister", classOf[Booster].getName)
test("nthread configuration must be no larger than spark.task.cpus") {
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_workers" -> numWorkers,
"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
intercept[IllegalArgumentException] {
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training)
}
}
test("kryoSerializer test") {
// TODO write an isolated test for Booster.
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator, null)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
val eval = new EvalError()
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("Check for Spark encryption over-the-wire") {
val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
ss.conf.set("spark.ssl.enabled", true)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
val training = buildDataFrame(Classification.train)
withClue("xgboost-spark should throw an exception when spark.ssl.enabled = true but " +
"xgboost.spark.ignoreSsl != true") {
val thrown = intercept[Exception] {
new XGBoostClassifier(paramMap).fit(training)
}
assert(thrown.getMessage.contains("xgboost.spark.ignoreSsl") &&
thrown.getMessage.contains("spark.ssl.enabled"))
}
// Confirm that this check can be overridden.
ss.conf.set("xgboost.spark.ignoreSsl", true)
new XGBoostClassifier(paramMap).fit(training)
originalSslConfOpt match {
case None =>
ss.conf.unset("spark.ssl.enabled")
case Some(originalSslConf) =>
ss.conf.set("spark.ssl.enabled", originalSslConf)
}
ss.conf.unset("xgboost.spark.ignoreSsl")
}
}

View File

@ -0,0 +1,512 @@
/*
Copyright (c) 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.SparkException
import org.json4s.{DefaultFormats, Formats}
import org.json4s.jackson.parseJson
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.Utils.TRAIN_NAME
class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
test("params") {
val df = smallBinaryClassificationVector
val xgbParams: Map[String, Any] = Map(
"max_depth" -> 5,
"eta" -> 0.2,
"objective" -> "binary:logistic"
)
val estimator = new XGBoostClassifier(xgbParams)
.setFeaturesCol("features")
.setMissing(0.2f)
.setAlpha(0.97)
.setLeafPredictionCol("leaf")
.setContribPredictionCol("contrib")
.setNumRound(1)
assert(estimator.getMaxDepth === 5)
assert(estimator.getEta === 0.2)
assert(estimator.getObjective === "binary:logistic")
assert(estimator.getFeaturesCol === "features")
assert(estimator.getMissing === 0.2f)
assert(estimator.getAlpha === 0.97)
estimator.setEta(0.66).setMaxDepth(7)
assert(estimator.getMaxDepth === 7)
assert(estimator.getEta === 0.66)
val model = estimator.fit(df)
assert(model.getMaxDepth === 7)
assert(model.getEta === 0.66)
assert(model.getObjective === "binary:logistic")
assert(model.getFeaturesCol === "features")
assert(model.getMissing === 0.2f)
assert(model.getAlpha === 0.97)
assert(model.getLeafPredictionCol === "leaf")
assert(model.getContribPredictionCol === "contrib")
}
test("nthread") {
val classifier = new XGBoostClassifier().setNthread(100)
intercept[IllegalArgumentException](
classifier.validate(smallBinaryClassificationVector)
)
}
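// The assertions below expect runOnGpu to be true whenever "device" is "cuda"
// or "tree_method" is "gpu_hist", and false for a plain CPU configuration.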
test("RuntimeParameter") {
var runtimeParams = new XGBoostClassifier(
Map("device" -> "cpu"))
.getRuntimeParameters(true)
assert(!runtimeParams.runOnGpu)
runtimeParams = new XGBoostClassifier(
Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1)
.getRuntimeParameters(true)
assert(runtimeParams.runOnGpu)
runtimeParams = new XGBoostClassifier(
Map("device" -> "cpu", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1)
.getRuntimeParameters(true)
assert(runtimeParams.runOnGpu)
runtimeParams = new XGBoostClassifier(
Map("device" -> "cuda", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1)
.getRuntimeParameters(true)
assert(runtimeParams.runOnGpu)
}
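// Sparse input without an explicitly configured missing value is expected to fail
// during conversion; setting missing to NaN afterwards makes the same conversion succeed.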
test("missing value exception for sparse vector") {
val sparse1 = Vectors.dense(0.0, 0.0, 0.0).toSparse
assert(sparse1.isInstanceOf[SparseVector])
val sparse2 = Vectors.dense(0.5, 2.2, 1.7).toSparse
assert(sparse2.isInstanceOf[SparseVector])
val sparseInput = ss.createDataFrame(sc.parallelize(Seq(
(1.0, sparse1),
(2.0, sparse2)
))).toDF("label", "features")
val classifier = new XGBoostClassifier()
val (input, columnIndexes) = classifier.preprocess(sparseInput)
val rdd = classifier.toXGBLabeledPoint(input, columnIndexes)
val exception = intercept[SparkException] {
rdd.collect()
}
assert(exception.getMessage.contains("We've detected sparse vectors in the dataset " +
"that need conversion to dense format"))
// explicitly setting the missing value avoids the exception
classifier.setMissing(Float.NaN)
val rdd1 = classifier.toXGBLabeledPoint(input, columnIndexes)
rdd1.collect()
}
test("missing value for dense vector no need to set missing explicitly") {
val dense1 = Vectors.dense(0.0, 0.0, 0.0)
assert(dense1.isInstanceOf[DenseVector])
val dense2 = Vectors.dense(0.5, 2.2, 1.7)
assert(dense2.isInstanceOf[DenseVector])
val sparseInput = ss.createDataFrame(sc.parallelize(Seq(
(1.0, dense1),
(2.0, dense2)
))).toDF("label", "features")
val classifier = new XGBoostClassifier()
val (input, columnIndexes) = classifier.preprocess(sparseInput)
val rdd = classifier.toXGBLabeledPoint(input, columnIndexes)
rdd.collect()
}
test("test persistence of XGBoostClassifier and XGBoostClassificationModel " +
"using custom Eval and Obj") {
val trainingDF = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6",
"verbosity" -> "1", "objective" -> "binary:logistic")
val xgbc = new XGBoostClassifier(paramMap)
.setCustomObj(new CustomObj(1))
.setCustomEval(new EvalError)
.setNumRound(10)
.setNumWorkers(numWorkers)
val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
xgbc.write.overwrite().save(xgbcPath)
val xgbc2 = XGBoostClassifier.load(xgbcPath)
assert(xgbc.getCustomObj.asInstanceOf[CustomObj].customParameter === 1)
assert(xgbc2.getCustomObj.asInstanceOf[CustomObj].customParameter === 1)
val eval = new EvalError()
val model = xgbc.fit(trainingDF)
val evalResults = eval.eval(model.nativeBooster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults < 0.1)
val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
model.write.overwrite.save(xgbcModelPath)
val model2 = XGBoostClassificationModel.load(xgbcModelPath)
assert(Arrays.equals(model.nativeBooster.toByteArray, model2.nativeBooster.toByteArray))
assert(model.getEta === model2.getEta)
assert(model.getNumRound === model2.getNumRound)
assert(model.getRawPredictionCol === model2.getRawPredictionCol)
val evalResults2 = eval.eval(model2.nativeBooster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults === evalResults2)
}
test("Check for Spark encryption over-the-wire") {
val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
ss.conf.set("spark.ssl.enabled", true)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic")
val training = smallBinaryClassificationVector
withClue("xgboost-spark should throw an exception when spark.ssl.enabled = true but " +
"xgboost.spark.ignoreSsl != true") {
val thrown = intercept[Exception] {
new XGBoostClassifier(paramMap).setNumRound(2).setNumWorkers(numWorkers).fit(training)
}
assert(thrown.getMessage.contains("xgboost.spark.ignoreSsl") &&
thrown.getMessage.contains("spark.ssl.enabled"))
}
// Confirm that this check can be overridden.
ss.conf.set("xgboost.spark.ignoreSsl", true)
new XGBoostClassifier(paramMap).setNumRound(2).setNumWorkers(numWorkers).fit(training)
originalSslConfOpt match {
case None =>
ss.conf.unset("spark.ssl.enabled")
case Some(originalSslConf) =>
ss.conf.set("spark.ssl.enabled", originalSslConf)
}
ss.conf.unset("xgboost.spark.ignoreSsl")
}
test("nthread configuration must be no larger than spark.task.cpus") {
val training = smallBinaryClassificationVector
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic")
intercept[IllegalArgumentException] {
new XGBoostClassifier(paramMap)
.setNumWorkers(numWorkers)
.setNumRound(2)
.setNthread(sc.getConf.getInt("spark.task.cpus", 1) + 1)
.fit(training)
}
}
test("preprocess dataset") {
val dataset = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0), "a"),
(0.0, 2, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0), "b"),
(2.0, 2, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7), "c")
))).toDF("label", "group", "margin", "weight", "features", "other")
val classifier = new XGBoostClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setBaseMarginCol("margin")
.setWeightCol("weight")
val (df, indices) = classifier.preprocess(dataset)
var schema = df.schema
assert(!schema.names.contains("group") && !schema.names.contains("other"))
assert(indices.labelId == schema.fieldIndex("label") &&
indices.groupId.isEmpty &&
indices.marginId.get == schema.fieldIndex("margin") &&
indices.weightId.get == schema.fieldIndex("weight") &&
indices.featureId.get == schema.fieldIndex("features") &&
indices.featureIds.isEmpty)
classifier.setWeightCol("")
val (df1, indices1) = classifier.preprocess(dataset)
schema = df1.schema
Seq("weight", "group", "other").foreach(v => assert(!schema.names.contains(v)))
assert(indices1.labelId == schema.fieldIndex("label") &&
indices1.groupId.isEmpty &&
indices1.marginId.get == schema.fieldIndex("margin") &&
indices1.weightId.isEmpty &&
indices1.featureId.get == schema.fieldIndex("features") &&
indices1.featureIds.isEmpty)
}
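// Round-trip check: labels, weights and feature values of each row should survive
// the conversion to XGBLabeledPoint, for both dense and sparse feature vectors.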
test("to XGBoostLabeledPoint RDD") {
val data = Array(
Array(1.0, 2.0, 3.0, 4.0, 5.0),
Array(0.0, 0.0, 0.0, 0.0, 2.0),
Array(12.0, 13.0, 14.0, 14.0, 15.0),
Array(20.5, 21.2, 0.0, 0.0, 2.0)
)
val dataset = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
(2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
(3.0, 2, -0.5, 0.0, Vectors.dense(data(2)), "b"),
(4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c")
))).toDF("label", "group", "margin", "weight", "features", "other")
val classifier = new XGBoostClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setWeightCol("weight")
.setNumWorkers(2)
.setMissing(Float.NaN)
val (df, indices) = classifier.preprocess(dataset)
val rdd = classifier.toXGBLabeledPoint(df, indices)
val result = rdd.collect().sortBy(x => x.label)
assert(result.length == data.length)
def toArray(index: Int): Array[Float] = {
val labelPoint = result(index)
if (labelPoint.indices != null) {
Vectors.sparse(labelPoint.size,
labelPoint.indices,
labelPoint.values.map(_.toDouble)).toArray.map(_.toFloat)
} else {
labelPoint.values
}
}
assert(result(0).label === 1.0f && result(0).baseMargin.isNaN &&
result(0).weight === 1.0f && toArray(0) === data(0).map(_.toFloat))
assert(result(1).label == 2.0f && result(1).baseMargin.isNaN &&
result(1).weight === 0.0f && toArray(1) === data(1).map(_.toFloat))
assert(result(2).label === 3.0f && result(2).baseMargin.isNaN &&
result(2).weight == 0.0f && toArray(2) === data(2).map(_.toFloat))
assert(result(3).label === 4.0f && result(3).baseMargin.isNaN &&
result(3).weight === -2.1f && toArray(3) === data(3).map(_.toFloat))
}
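// Each case below is (missing value, expected number of entries treated as missing);
// the two NaN entries in the data are counted in addition to any values matching the missing value.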
Seq((Float.NaN, 2), (0.0f, 7 + 2), (15.0f, 1 + 2), (10101011.0f, 0 + 2)).foreach {
case (missing, expectedMissingValue) =>
test(s"to RDD watches with missing $missing") {
val data = Array(
Array(1.0, 2.0, 3.0, 4.0, 5.0),
Array(1.0, Float.NaN, 0.0, 0.0, 2.0),
Array(12.0, 13.0, Float.NaN, 14.0, 15.0),
Array(0.0, 0.0, 0.0, 0.0, 0.0)
)
val dataset = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
(2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
(3.0, 3, -0.5, 0.0, Vectors.dense(data(2)), "b"),
(4.0, 4, -0.4, -2.1, Vectors.dense(data(3)), "c")
))).toDF("label", "group", "margin", "weight", "features", "other")
val classifier = new XGBoostClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setWeightCol("weight")
.setBaseMarginCol("margin")
.setMissing(missing)
.setNumWorkers(2)
val (df, indices) = classifier.preprocess(dataset)
val rdd = classifier.toRdd(df, indices)
val result = rdd.mapPartitions { iter =>
if (iter.hasNext) {
val watches = iter.next()
val size = watches.size
val trainDM = watches.toMap(TRAIN_NAME)
val rowNum = trainDM.rowNum
val labels = trainDM.getLabel
val weight = trainDM.getWeight
val margins = trainDM.getBaseMargin
val nonMissing = trainDM.nonMissingNum
watches.delete()
Iterator.single((size, rowNum, labels, weight, margins, nonMissing))
} else {
Iterator.empty
}
}.collect()
val labels: ArrayBuffer[Float] = ArrayBuffer.empty
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
val margins: ArrayBuffer[Float] = ArrayBuffer.empty
var nonMissingValues = 0L
var totalRows = 0L
for (row <- result) {
assert(row._1 === 1)
totalRows = totalRows + row._2
labels.append(row._3: _*)
weight.append(row._4: _*)
margins.append(row._5: _*)
nonMissingValues = nonMissingValues + row._6
}
assert(totalRows === 4)
assert(nonMissingValues === data.size * data(0).length - expectedMissingValue)
assert(labels.toArray.sorted === Array(1.0f, 2.0f, 3.0f, 4.0f).sorted)
assert(weight.toArray.sorted === Array(0.0f, 0.0f, 1.0f, -2.1f).sorted)
assert(margins.toArray.sorted === Array(-0.5f, -0.5f, -0.4f, 0.5f).sorted)
}
}
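// With an eval dataset configured, each partition's watches should hold two DMatrix
// entries (train and validation), hence the size === 2 assertion below.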
test("to RDD watches with eval") {
val trainData = Array(
Array(-1.0, -2.0, -3.0, -4.0, -5.0),
Array(2.0, 2.0, 2.0, 3.0, -2.0),
Array(-12.0, -13.0, -14.0, -14.0, -15.0),
Array(-20.5, -21.2, 0.0, 0.0, 2.0)
)
val trainDataset = ss.createDataFrame(sc.parallelize(Seq(
(11.0, 0, 0.15, 11.0, Vectors.dense(trainData(0)), "a"),
(12.0, 12, -0.15, 10.0, Vectors.dense(trainData(1)).toSparse, "b"),
(13.0, 12, -0.15, 10.0, Vectors.dense(trainData(2)), "b"),
(14.0, 12, -0.14, -12.1, Vectors.dense(trainData(3)), "c")
))).toDF("label", "group", "margin", "weight", "features", "other")
val evalData = Array(
Array(1.0, 2.0, 3.0, 4.0, 5.0),
Array(0.0, 0.0, 0.0, 0.0, 2.0),
Array(12.0, 13.0, 14.0, 14.0, 15.0),
Array(20.5, 21.2, 0.0, 0.0, 2.0)
)
val evalDataset = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0, 0.5, 1.0, Vectors.dense(evalData(0)), "a"),
(2.0, 2, -0.5, 0.0, Vectors.dense(evalData(1)).toSparse, "b"),
(3.0, 2, -0.5, 0.0, Vectors.dense(evalData(2)), "b"),
(4.0, 2, -0.4, -2.1, Vectors.dense(evalData(3)), "c")
))).toDF("label", "group", "margin", "weight", "features", "other")
val classifier = new XGBoostClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setWeightCol("weight")
.setBaseMarginCol("margin")
.setEvalDataset(evalDataset)
.setNumWorkers(2)
.setMissing(Float.NaN)
val (df, indices) = classifier.preprocess(trainDataset)
val rdd = classifier.toRdd(df, indices)
val result = rdd.mapPartitions { iter =>
if (iter.hasNext) {
val watches = iter.next()
val size = watches.size
val evalDM = watches.toMap(Utils.VALIDATION_NAME)
val rowNum = evalDM.rowNum
val labels = evalDM.getLabel
val weight = evalDM.getWeight
val margins = evalDM.getBaseMargin
watches.delete()
Iterator.single((size, rowNum, labels, weight, margins))
} else {
Iterator.empty
}
}.collect()
val labels: ArrayBuffer[Float] = ArrayBuffer.empty
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
val margins: ArrayBuffer[Float] = ArrayBuffer.empty
var totalRows = 0L
for (row <- result) {
assert(row._1 === 2)
totalRows = totalRows + row._2
labels.append(row._3: _*)
weight.append(row._4: _*)
margins.append(row._5: _*)
}
assert(totalRows === 4)
assert(labels.toArray.sorted === Array(1.0f, 2.0f, 3.0f, 4.0f).sorted)
assert(weight.toArray.sorted === Array(0.0f, 0.0f, 1.0f, -2.1f).sorted)
assert(margins.toArray.sorted === Array(-0.5f, -0.5f, -0.4f, 0.5f).sorted)
}
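// The model saved through the Spark writer should be byte-identical to the natively
// saved booster for matching formats (json vs json, ubj vs ubj) and differ across formats.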
test("XGBoost-Spark model format should match xgboost4j") {
val trainingDF = buildDataFrame(MultiClassification.train)
Seq(new XGBoostClassifier()).foreach { est =>
est.setNumRound(5)
val model = est.fit(trainingDF)
// test json
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.overwrite().option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
model.nativeBooster.saveModel(nativeJsonModelPath)
assert(compareTwoFiles(new File(modelPath, "data/model").getPath,
nativeJsonModelPath))
// test ubj
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.overwrite().save(modelUbjPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/model").getPath,
nativeUbjModelPath))
// the json file should differ from the ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.overwrite().option("format", "json").save(modelJsonPath)
val nativeUbjModelPath1 = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath1)
assert(!compareTwoFiles(new File(modelJsonPath, "data/model").getPath,
nativeUbjModelPath1))
}
}
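// feature_names and feature_types set on the estimator should end up in the native
// json model under learner.feature_names and learner.feature_types.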
test("native json model file should store feature_name and feature_type") {
val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
val featureTypes = (1 to 33).map(idx => "q").toArray
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier()
.setNumWorkers(numWorkers)
.setFeatureNames(featureNames)
.setFeatureTypes(featureTypes)
.setNumRound(2)
val model = xgb.fit(trainingDF)
val modelStr = new String(model.nativeBooster.toByteArray("json"))
val jsonModel = parseJson(modelStr)
implicit val formats: Formats = DefaultFormats
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
assert(featureNamesInModel.length == 33)
assert(featureTypesInModel.length == 33)
assert(featureNames sameElements featureNamesInModel)
assert(featureTypes sameElements featureTypesInModel)
}
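// Setting num_class=2 with multi:softprob on a three-class dataset should surface
// the native "label must be in [0, num_class)" error through the SparkException.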
test("Exception with clear message") {
val df = smallMultiClassificationVector
val classifier = new XGBoostClassifier()
.setNumRound(2)
.setObjective("multi:softprob")
.setNumClass(2)
val exception = intercept[SparkException] {
classifier.fit(df)
}
assert(exception.getMessage.contains("SoftmaxMultiClassObj: label must be in [0, num_class)."))
}
}
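
A note on the format tests above: they show that the Spark model writer stores the raw booster under <modelPath>/data/model, byte-identical to a natively saved booster. Below is a minimal sketch, not part of the suite, of loading that file directly with the xgboost4j Scala API; `modelPath` is a hypothetical directory previously produced by model.write.save(modelPath).

import java.io.File
import ml.dmlc.xgboost4j.scala.{XGBoost => ScalaXGBoost}

// `modelPath` is a hypothetical directory written by model.write.save(modelPath);
// the data/model file layout follows the format tests above.
val nativeBooster = ScalaXGBoost.loadModel(new File(modelPath, "data/model").getPath)
// The loaded booster behaves like any native booster, e.g. it can be re-serialized.
val jsonBytes = nativeBooster.toByteArray("json")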

View File

@ -1,376 +0,0 @@
/*
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.util.Random
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.{SparkException, TaskContext}
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.lit
class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("distributed training with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD)
val (booster, metrics) = XGBoost.trainDistributed(
sc,
buildTrainingRDD,
List("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap)
assert(booster != null)
}
test("training with external memory cache") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"use_external_memory" -> true)
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test with quantile hist with monotone_constraints (lossguide)") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1",
"max_depth" -> "6",
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
"num_round" -> 5, "num_workers" -> numWorkers, "monotone_constraints" -> "(1, 0)")
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test with quantile hist with interaction_constraints (lossguide)") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1",
"max_depth" -> "6",
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
"num_round" -> 5, "num_workers" -> numWorkers, "interaction_constraints" -> "[[1,2],[2,3,4]]")
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test with quantile hist with monotone_constraints (depthwise)") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1",
"max_depth" -> "6",
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
"num_round" -> 5, "num_workers" -> numWorkers, "monotone_constraints" -> "(1, 0)")
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test with quantile hist with interaction_constraints (depthwise)") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1",
"max_depth" -> "6",
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
"num_round" -> 5, "num_workers" -> numWorkers, "interaction_constraints" -> "[[1,2],[2,3,4]]")
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test with quantile hist depthwise") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1",
"max_depth" -> "6",
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
"num_round" -> 5, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test with quantile hist lossguide") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
"max_leaves" -> "8", "num_round" -> 5,
"num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
test("test with quantile hist lossguide with max bin") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
"eval_metric" -> "error", "num_round" -> 5, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
test("test with quantile hist depthwidth with max depth") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2",
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
test("test with quantile hist depthwidth with max depth and max bin") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
test("repartitionForTrainingGroup with group data") {
// test different splits to cover the corner cases.
for (split <- 1 to 20) {
val trainingRDD = sc.parallelize(Ranking.train, split)
val traingGroupsRDD = PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
val trainingGroups: Array[Array[XGBLabeledPoint]] = traingGroupsRDD.collect()
// check the the order of the groups with group id.
// Ranking.train has 20 groups
assert(trainingGroups.length == 20)
// compare all points
val allPoints = trainingGroups.sortBy(_(0).group).flatten
assert(allPoints.length == Ranking.train.size)
for (i <- 0 to Ranking.train.size - 1) {
assert(allPoints(i).group == Ranking.train(i).group)
assert(allPoints(i).label == Ranking.train(i).label)
assert(allPoints(i).values.sameElements(Ranking.train(i).values))
}
}
}
test("repartitionForTrainingGroup with group data which has empty partition") {
val trainingRDD = sc.parallelize(Ranking.train, 5).mapPartitions(it => {
// make one partition empty for testing
it.filter(_ => TaskContext.getPartitionId() != 3)
})
PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
}
test("distributed training with group data") {
val trainingRDD = sc.parallelize(Ranking.train, 5)
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD, hasGroup = true)
val (booster, _) = XGBoost.trainDistributed(
sc,
buildTrainingRDD,
List("eta" -> "1", "max_depth" -> "6",
"objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap)
assert(booster != null)
}
test("training summary") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
val trainingDF = buildDataFrame(Classification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.validationObjectiveHistory.isEmpty)
}
test("train/test split") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
val training = buildDataFrame(Classification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(training)
assert(model.summary.validationObjectiveHistory.length === 1)
assert(model.summary.validationObjectiveHistory(0)._1 === "test")
assert(model.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model.summary.trainObjectiveHistory !== model.summary.validationObjectiveHistory(0))
}
test("train with multiple validation datasets (non-ranking)") {
val training = buildDataFrame(Classification.train)
val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2))
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> numWorkers)
val xgb1 = new XGBoostClassifier(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
val model1 = xgb1.fit(train)
assert(model1.summary.validationObjectiveHistory.length === 2)
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
val paramMap2 = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic",
"num_round" -> 5, "num_workers" -> numWorkers,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgb2 = new XGBoostClassifier(paramMap2)
val model2 = xgb2.fit(train)
assert(model2.summary.validationObjectiveHistory.length === 2)
assert(model2.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
assert(model2.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model2.summary.validationObjectiveHistory(1)._2.length === 5)
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
}
test("train with multiple validation datasets (ranking)") {
val training = buildDataFrameWithGroup(Ranking.train, 5)
val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2), 0)
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "rank:ndcg",
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
val model1 = xgb1.fit(train)
assert(model1 != null)
assert(model1.summary.validationObjectiveHistory.length === 2)
assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
val paramMap2 = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "rank:ndcg",
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group",
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgb2 = new XGBoostRegressor(paramMap2)
val model2 = xgb2.fit(train)
assert(model2 != null)
assert(model2.summary.validationObjectiveHistory.length === 2)
assert(model2.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
assert(model2.summary.validationObjectiveHistory(0)._2.length === 5)
assert(model2.summary.validationObjectiveHistory(1)._2.length === 5)
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
}
test("infer with different batch sizes") {
val regModel = new XGBoostRegressor(Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:squarederror",
"num_round" -> 5,
"num_workers" -> numWorkers))
.fit(buildDataFrame(Regression.train))
val regDF = buildDataFrame(Regression.test)
val regRet1 = regModel.transform(regDF).collect()
val regRet2 = regModel.setInferBatchSize(1).transform(regDF).collect()
val regRet3 = regModel.setInferBatchSize(10).transform(regDF).collect()
val regRet4 = regModel.setInferBatchSize(32 << 15).transform(regDF).collect()
assert(regRet1 sameElements regRet2)
assert(regRet1 sameElements regRet3)
assert(regRet1 sameElements regRet4)
val clsModel = new XGBoostClassifier(Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5,
"num_workers" -> numWorkers))
.fit(buildDataFrame(Classification.train))
val clsDF = buildDataFrame(Classification.test)
val clsRet1 = clsModel.transform(clsDF).collect()
val clsRet2 = clsModel.setInferBatchSize(1).transform(clsDF).collect()
val clsRet3 = clsModel.setInferBatchSize(10).transform(clsDF).collect()
val clsRet4 = clsModel.setInferBatchSize(32 << 15).transform(clsDF).collect()
assert(clsRet1 sameElements clsRet2)
assert(clsRet1 sameElements clsRet3)
assert(clsRet1 sameElements clsRet4)
}
test("chaining the prediction") {
val modelPath = getClass.getResource("/model/0.82/model").getPath
val model = XGBoostClassificationModel.read.load(modelPath)
val r = new Random(0)
var df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))).
toDF("feature", "label").repartition(5)
// 0.82/model was trained with 251 features. and transform will throw exception
// if feature size of data is not equal to 251
for (x <- 1 to 250) {
df = df.withColumn(s"feature_${x}", lit(1))
}
val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features")
df = assembler.transform(df)
for (x <- 1 to 250) {
df = df.drop(s"feature_${x}")
}
val df1 = model.transform(df).withColumnRenamed(
"prediction", "prediction1").withColumnRenamed(
"rawPrediction", "rawPrediction1").withColumnRenamed(
"probability", "probability1")
val df2 = model.transform(df1)
df1.collect()
df2.collect()
}
test("throw exception for empty partition in trainingset") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_class" -> "2", "num_round" -> 5,
"num_workers" -> numWorkers, "tree_method" -> "auto", "allow_non_zero_for_missing" -> true)
// The Dmatrix will be empty
val trainingDF = buildDataFrame(Seq(XGBLabeledPoint(1.0f, 4,
Array(0, 1, 2, 3), Array(0, 1, 2, 3))))
val xgb = new XGBoostClassifier(paramMap)
intercept[SparkException] {
xgb.fit(trainingDF)
}
}
}

View File

@ -18,32 +18,116 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.ml.feature.VectorAssembler
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.REGRESSION_OBJS
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams
class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"
test("XGBoostRegressor copy") {
val regressor = new XGBoostRegressor().setNthread(2).setNumWorkers(10)
val regressortCopied = regressor.copy(ParamMap.empty)
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
assert(regressor.uid === regressortCopied.uid)
assert(regressor.getNthread === regressortCopied.getNthread)
assert(regressor.getNumWorkers === regressortCopied.getNumWorkers)
}
test("XGBoostRegressionModel copy") {
val model = new XGBoostRegressionModel("hello").setNthread(2).setNumWorkers(10)
val modelCopied = model.copy(ParamMap.empty)
assert(model.uid === modelCopied.uid)
assert(model.getNthread === modelCopied.getNthread)
assert(model.getNumWorkers === modelCopied.getNumWorkers)
}
test("read/write") {
val trainDf = smallBinaryClassificationVector
val xgbParams: Map[String, Any] = Map(
"max_depth" -> 5,
"eta" -> 0.2
)
def check(xgboostParams: XGBoostParams[_]): Unit = {
assert(xgboostParams.getMaxDepth === 5)
assert(xgboostParams.getEta === 0.2)
assert(xgboostParams.getObjective === "reg:squarederror")
}
val regressorPath = new File(tempDir.toFile, "regressor").getPath
val regressor = new XGBoostRegressor(xgbParams).setNumRound(1)
check(regressor)
regressor.write.overwrite().save(regressorPath)
val loadedRegressor = XGBoostRegressor.load(regressorPath)
check(loadedRegressor)
val model = loadedRegressor.fit(trainDf)
check(model)
val modelPath = new File(tempDir.toFile, "model").getPath
model.write.overwrite().save(modelPath)
val modelLoaded = XGBoostRegressionModel.load(modelPath)
check(modelLoaded)
}
test("XGBoostRegressionModel transformed schema") {
val trainDf = smallBinaryClassificationVector
val regressor = new XGBoostRegressor().setNumRound(1)
val model = regressor.fit(trainDf)
var out = model.transform(trainDf)
// Transform should not discard the other columns of the transforming dataframe
Seq("label", "margin", "weight", "features").foreach { v =>
assert(out.schema.names.contains(v))
}
// Regressor does not have extra columns
Seq("rawPrediction", "probability").foreach { v =>
assert(!out.schema.names.contains(v))
}
assert(out.schema.names.contains("prediction"))
assert(out.schema.names.length === 5)
model.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
out = model.transform(trainDf)
assert(out.schema.names.contains("leaf"))
assert(out.schema.names.contains("contrib"))
}
test("Supported objectives") {
val regressor = new XGBoostRegressor()
val df = smallMultiClassificationVector
REGRESSION_OBJS.foreach { obj =>
regressor.setObjective(obj)
regressor.validate(df)
}
regressor.setObjective("binary:logistic")
intercept[IllegalArgumentException](
regressor.validate(df)
)
}
test("XGBoost-Spark output should match XGBoost4j") {
val trainingDM = new DMatrix(Regression.train.iterator)
val testDM = new DMatrix(Regression.test.iterator)
val trainingDF = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
val paramMap = Map("objective" -> "reg:squarederror")
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
}
test("XGBoostRegressor should make correct predictions after upstream random sort") {
val trainingDM = new DMatrix(Regression.train.iterator)
test("XGBoost-Spark output with weight should match XGBoost4j") {
val trainingDM = new DMatrix(Regression.trainWithWeight.iterator)
trainingDM.setWeight(Regression.randomWeights)
val testDM = new DMatrix(Regression.test.iterator)
val trainingDF = buildDataFrameWithRandSort(Regression.train)
val testDF = buildDataFrameWithRandSort(Regression.test)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
val trainingDF = buildDataFrame(Regression.trainWithWeight)
val testDF = buildDataFrame(Regression.test)
val paramMap = Map("objective" -> "reg:squarederror")
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF,
5, paramMap, Some("weight"))
}
private def checkResultsWithXGBoost4j(
@ -51,306 +135,51 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
testDM: DMatrix,
trainingDF: DataFrame,
testDF: DataFrame,
round: Int = 5): Unit = {
round: Int = 5,
xgbParams: Map[String, Any] = Map.empty,
weightCol: Option[String] = None): Unit = {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:squarederror",
"max_bin" -> 64,
"tree_method" -> treeMethod)
"base_score" -> 0.5,
"max_bin" -> 16) ++ xgbParams
val xgb4jModel = ScalaXGBoost.train(trainingDM, paramMap, round)
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
val prediction1 = model1.predict(testDM)
val model2 = new XGBoostRegressor(paramMap ++ Array("num_round" -> round,
"num_workers" -> numWorkers)).fit(trainingDF)
val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap
assert(prediction1.indices.count { i =>
math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
} < prediction1.length * 0.1)
// check the equality of single instance prediction
val firstOfDM = testDM.slice(Array(0))
val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
.head()
.getAs[Vector]("features")
val prediction3 = model1.predict(firstOfDM)(0)(0)
val prediction4 = model2.predict(firstOfDF)
assert(math.abs(prediction3 - prediction4) <= 0.01f)
}
test("Set params in XGBoost and MLlib way should produce same model") {
val trainingDF = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val round = 5
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:squarederror",
"num_round" -> round,
"tree_method" -> treeMethod,
"num_workers" -> numWorkers)
// Set params in XGBoost way
val model1 = new XGBoostRegressor(paramMap).fit(trainingDF)
// Set params in MLlib way
val model2 = new XGBoostRegressor()
.setEta(1)
.setMaxDepth(6)
.setSilent(1)
.setObjective("reg:squarederror")
val regressor = new XGBoostRegressor(paramMap)
.setNumRound(round)
.setTreeMethod(treeMethod)
.setNumWorkers(numWorkers)
.fit(trainingDF)
.setLeafPredictionCol("leaf")
.setContribPredictionCol("contrib")
weightCol.foreach(weight => regressor.setWeightCol(weight))
val prediction1 = model1.transform(testDF).select("prediction").collect()
val prediction2 = model2.transform(testDF).select("prediction").collect()
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(math.abs(p1 - p2) <= 0.01f)
def checkEqual(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
assert(left.size === right.size)
left.zipWithIndex.foreach { case (leftValue, index) =>
assert(leftValue.sameElements(right(index)))
}
}
test("ranking: use group data") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:ndcg", "num_workers" -> numWorkers, "num_round" -> 5,
"group_col" -> "group", "tree_method" -> treeMethod)
val xgbSparkModel = regressor.fit(trainingDF)
val rows = xgbSparkModel.transform(testDF).collect()
val trainingDF = buildDataFrameWithGroup(Ranking.train)
val testDF = buildDataFrame(Ranking.test)
val model = new XGBoostRegressor(paramMap).fit(trainingDF)
// Check Leaf
val xgb4jLeaf = xgb4jModel.predictLeaf(testDM)
val xgbSparkLeaf = rows.map(row =>
(row.getAs[Int]("id"), row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))).toMap
checkEqual(xgb4jLeaf, xgbSparkLeaf)
val prediction = model.transform(testDF).collect()
assert(testDF.count() === prediction.length)
// Check contrib
val xgb4jContrib = xgb4jModel.predictContrib(testDM)
val xgbSparkContrib = rows.map(row =>
(row.getAs[Int]("id"), row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))).toMap
checkEqual(xgb4jContrib, xgbSparkContrib)
// Check prediction
val xgb4jPred = xgb4jModel.predict(testDM)
val xgbSparkPred = rows.map(row => {
val pred = row.getAs[Double]("prediction").toFloat
(row.getAs[Int]("id"), Array(pred))}).toMap
checkEqual(xgb4jPred, xgbSparkPred)
}
test("use weight") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f})
val trainingDF = buildDataFrame(Regression.train)
.withColumn("weight", getWeightFromId(col("id")))
val testDF = buildDataFrame(Regression.test)
val model = new XGBoostRegressor(paramMap).setWeightCol("weight").fit(trainingDF)
val prediction = model.transform(testDF).collect()
val first = prediction.head.getAs[Double]("prediction")
prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
}
test("objective will be set if not specifying it") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val xgb = new XGBoostRegressor(paramMap)
assert(!xgb.isDefined(xgb.objective))
xgb.fit(training)
assert(xgb.getObjective == "reg:squarederror")
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod,
"objective" -> "reg:squaredlogerror")
val xgb1 = new XGBoostRegressor(paramMap1)
assert(xgb1.getObjective == "reg:squaredlogerror")
xgb1.fit(training)
assert(xgb1.getObjective == "reg:squaredlogerror")
}
test("test predictionLeaf") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val groundTruth = testDF.count()
val xgb = new XGBoostRegressor(paramMap)
val model = xgb.fit(training)
model.setLeafPredictionCol("predictLeaf")
val resultDF = model.transform(testDF)
assert(resultDF.count === groundTruth)
assert(resultDF.columns.contains("predictLeaf"))
}
test("test predictionLeaf with empty column name") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val xgb = new XGBoostRegressor(paramMap)
val model = xgb.fit(training)
model.setLeafPredictionCol("")
val resultDF = model.transform(testDF)
assert(!resultDF.columns.contains("predictLeaf"))
}
test("test predictionContrib") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val groundTruth = testDF.count()
val xgb = new XGBoostRegressor(paramMap)
val model = xgb.fit(training)
model.setContribPredictionCol("predictContrib")
val resultDF = model.transform(testDF)
assert(resultDF.count === groundTruth)
assert(resultDF.columns.contains("predictContrib"))
}
test("test predictionContrib with empty column name") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val xgb = new XGBoostRegressor(paramMap)
val model = xgb.fit(training)
model.setContribPredictionCol("")
val resultDF = model.transform(testDF)
assert(!resultDF.columns.contains("predictContrib"))
}
test("test predictionLeaf and predictionContrib") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val groundTruth = testDF.count()
val xgb = new XGBoostRegressor(paramMap)
val model = xgb.fit(training)
model.setLeafPredictionCol("predictLeaf")
model.setContribPredictionCol("predictContrib")
val resultDF = model.transform(testDF)
assert(resultDF.count === groundTruth)
assert(resultDF.columns.contains("predictLeaf"))
assert(resultDF.columns.contains("predictContrib"))
}
test("featuresCols with features column can work") {
val spark = ss
import spark.implicits._
val xgbInput = Seq(
(Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
(Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
.toDF("f1", "f2", "f3", "features", "label")
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1)
val featuresName = Array("f1", "f2", "f3", "features")
val xgbClassifier = new XGBoostRegressor(paramMap)
.setFeaturesCol(featuresName)
.setLabelCol("label")
val model = xgbClassifier.fit(xgbInput)
assert(model.getFeaturesCols.sameElements(featuresName))
val df = model.transform(xgbInput)
assert(df.schema.fieldNames.contains("features_" + model.uid))
df.show()
val newFeatureName = "features_new"
// transform also can work for vectorized dataset
val vectorizedInput = new VectorAssembler()
.setInputCols(featuresName)
.setOutputCol(newFeatureName)
.transform(xgbInput)
.select(newFeatureName, "label")
val df1 = model
.setFeaturesCol(newFeatureName)
.transform(vectorizedInput)
assert(df1.schema.fieldNames.contains(newFeatureName))
df1.show()
}
test("featuresCols without features column can work") {
val spark = ss
import spark.implicits._
val xgbInput = Seq(
(Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
(Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
.toDF("f1", "f2", "f3", "f4", "label")
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1)
val featuresName = Array("f1", "f2", "f3", "f4")
val xgbClassifier = new XGBoostRegressor(paramMap)
.setFeaturesCol(featuresName)
.setLabelCol("label")
.setEvalSets(Map("eval" -> xgbInput))
val model = xgbClassifier.fit(xgbInput)
assert(model.getFeaturesCols.sameElements(featuresName))
// transform should work for the dataset which includes the feature column names.
val df = model.transform(xgbInput)
assert(df.schema.fieldNames.contains("features"))
df.show()
// transform also can work for vectorized dataset
val vectorizedInput = new VectorAssembler()
.setInputCols(featuresName)
.setOutputCol("features")
.transform(xgbInput)
.select("features", "label")
val df1 = model.transform(vectorizedInput)
df1.show()
}
test("XGBoostRegressionModel should be compatible") {
val trainingDF = buildDataFrame(Regression.train)
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:squarederror",
"num_round" -> 5,
"tree_method" -> treeMethod,
"num_workers" -> numWorkers)
val model = new XGBoostRegressor(paramMap).fit(trainingDF)
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
model.nativeBooster.saveModel(nativeJsonModelPath)
assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
nativeJsonModelPath))
// test default "ubj"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
nativeUbjModelPath))
// test the deprecated format
val modelDeprecatedPath = new File(tempDir.toFile, "modelDeprecated").getPath
model.write.option("format", "deprecated").save(modelDeprecatedPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel.deprecated").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelDeprecatedPath, "data/XGBoostRegressionModel").getPath,
nativeDeprecatedModelPath))
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2023 by Contributors
Copyright (c) 2023-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,40 +16,18 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.scala.Booster
class XGBoostSuite extends AnyFunSuite with PerTest {
// Do not create spark context
override def beforeEach(): Unit = {}
test("XGBoost execution parameters") {
var xgbExecutionParams = new XGBoostExecutionParamsFactory(
Map("device" -> "cpu", "num_workers" -> 1, "num_round" -> 1), sc)
.buildXGBRuntimeParams
assert(!xgbExecutionParams.runOnGpu)
xgbExecutionParams = new XGBoostExecutionParamsFactory(
Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1), sc)
.buildXGBRuntimeParams
assert(xgbExecutionParams.runOnGpu)
xgbExecutionParams = new XGBoostExecutionParamsFactory(
Map("device" -> "cpu", "tree_method" -> "gpu_hist", "num_workers" -> 1, "num_round" -> 1), sc)
.buildXGBRuntimeParams
assert(xgbExecutionParams.runOnGpu)
xgbExecutionParams = new XGBoostExecutionParamsFactory(
Map("device" -> "cuda", "tree_method" -> "gpu_hist",
"num_workers" -> 1, "num_round" -> 1), sc)
.buildXGBRuntimeParams
assert(xgbExecutionParams.runOnGpu)
}
test("skip stage-level scheduling") {
val conf = new SparkConf()
.setMaster("spark://foo")
@ -101,7 +79,7 @@ class XGBoostSuite extends AnyFunSuite with PerTest {
}
object FakedXGBoost extends XGBoostStageLevel {
object FakedXGBoost extends StageLevelScheduling {
// Do not skip stage-level scheduling for testing purposes.
override private[spark] def skipStageLevelScheduling(
@ -129,12 +107,12 @@ class XGBoostSuite extends AnyFunSuite with PerTest {
val df = ss.range(1, 10)
val rdd = df.rdd
val xgbExecutionParams = new XGBoostExecutionParamsFactory(
Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1), sc)
.buildXGBRuntimeParams
assert(xgbExecutionParams.runOnGpu)
val runtimeParams = new XGBoostClassifier(
Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1)
.getRuntimeParameters(true)
assert(runtimeParams.runOnGpu)
val finalRDD = FakedXGBoost.tryStageLevelScheduling(ss.sparkContext, xgbExecutionParams,
val finalRDD = FakedXGBoost.tryStageLevelScheduling(ss.sparkContext, runtimeParams,
rdd.asInstanceOf[RDD[(Booster, Map[String, Array[Float]])]])
val taskResources = finalRDD.getResourceProfile().taskResources

View File

@ -519,4 +519,39 @@ public class DMatrix {
CSR,
CSC
}
/**
* A class to hold the quantile information
*/
public class QuantileCut {
// cut ptr
long[] indptr;
// cut values
float[] values;
QuantileCut(long[] indptr, float[] values) {
this.indptr = indptr;
this.values = values;
}
public long[] getIndptr() {
return indptr;
}
public float[] getValues() {
return values;
}
}
/**
* Get the Quantile Cut.
* @return QuantileCut
* @throws XGBoostError
*/
public QuantileCut getQuantileCut() throws XGBoostError {
long[][] indptr = new long[1][];
float[][] values = new float[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetQuantileCut(this.handle, indptr, values));
return new QuantileCut(indptr[0], values[0]);
}
}
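
The QuantileCut returned by getQuantileCut() above packs the histogram cut points in CSR form: indptr holds one offset per feature boundary, and values holds the concatenated cut values. A minimal Scala sketch of reading it, assuming a hypothetical DMatrix `xy` that has already been used for training (training is what materializes the cuts, as the DMatrixTest further below does):

import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}

// Prints the cut points built for each feature of an already-trained DMatrix.
// `xy` is a hypothetical JDMatrix; its quantile cuts must already exist.
@throws[XGBoostError]
def printQuantileCuts(xy: JDMatrix): Unit = {
  val cuts = xy.getQuantileCut()
  val indptr = cuts.getIndptr // offsets: indptr(f) .. indptr(f + 1) slices `values`
  val values = cuts.getValues // concatenated, per-feature ascending cut values
  for (f <- 0 until indptr.length - 1) {
    val begin = indptr(f).toInt
    val end = indptr(f + 1).toInt
    if (end > begin) {
      println(s"feature $f: ${end - begin} cuts in [${values(begin)}, ${values(end - 1)}]")
    } else {
      println(s"feature $f: no cuts")
    }
  }
}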

View File

@ -1,75 +0,0 @@
package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
/**
* QuantileDMatrix will only be used to train
*/
public class QuantileDMatrix extends DMatrix {
/**
* Create QuantileDMatrix from iterator based on the cuda array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @throws XGBoostError
*/
public QuantileDMatrix(
Iterator<ColumnBatch> iter,
float missing,
int maxBin,
int nthread) throws XGBoostError {
super(0);
long[] out = new long[1];
String conf = getConfig(missing, maxBin, nthread);
XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
iter, (java.util.Iterator<ColumnBatch>)null, conf, out));
handle = out[0];
}
@Override
public void setLabel(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
}
@Override
public void setWeight(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
}
@Override
public void setBaseMargin(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setLabel(float[] labels) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
}
@Override
public void setWeight(float[] weights) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
}
@Override
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setGroup(int[] group) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setGroup.");
}
private String getConfig(float missing, int maxBin, int nthread) {
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
missing, maxBin, nthread);
}
}

View File

@ -172,7 +172,7 @@ class XGBoostJNI {
long handle, String field, String json);
public final static native int XGQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);
java.util.Iterator<ColumnBatch> iter, long[] ref, String config, long[] out);
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
String featureJson, float missing, int nthread, long[] out);
@ -180,4 +180,7 @@ class XGBoostJNI {
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
public final static native int XGDMatrixGetQuantileCut(long handle, long[][] outIndptr, float[][] outValues);
}

View File

@ -365,4 +365,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
override def read(kryo: Kryo, input: Input): Unit = {
booster = kryo.readObject(input, classOf[JBooster])
}
// a flag to indicate if the device is set for the GPU transform
var deviceIsSet = false
}

View File

@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j.scala
import _root_.scala.collection.JavaConverters._
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DMatrix => JDMatrix, XGBoostError}

View File

@ -1,7 +1,6 @@
//
// Created by bobwang on 2021/9/8.
//
/**
* Copyright 2021-2024, XGBoost Contributors
*/
#ifndef XGBOOST_USE_CUDA
#include <jni.h>
@ -21,7 +20,7 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
API_END();
}
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
jobject jdata_iter, jlongArray jref,
char const *config, jlongArray jout) {
API_BEGIN();
common::AssertGPUSupport();

View File

@ -1,10 +1,13 @@
/**
* Copyright 2021-2024, XGBoost Contributors
*/
#include <jni.h>
#include <xgboost/c_api.h>
#include "../../../../src/common/device_helpers.cuh"
#include "../../../../src/common/cuda_pinned_allocator.h"
#include "../../../../src/common/device_vector.cuh" // for device_vector
#include "../../../../src/data/array_interface.h"
#include "jvm_utils.h"
#include <xgboost/c_api.h>
namespace xgboost {
namespace jni {
@ -396,6 +399,9 @@ void Reset(DataIterHandle self) {
int Next(DataIterHandle self) {
return static_cast<xgboost::jni::DataIteratorProxy *>(self)->Next();
}
template <typename T>
using Deleter = std::function<void(T *)>;
} // anonymous namespace
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
@ -413,17 +419,23 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
}
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
jobject jdata_iter, jlongArray jref,
char const *config, jlongArray jout) {
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
DMatrixHandle result;
DMatrixHandle ref{nullptr};
std::unique_ptr<xgboost::jni::DataIteratorProxy> ref_proxy{nullptr};
if (jref_iter) {
ref_proxy = std::make_unique<xgboost::jni::DataIteratorProxy>(jref_iter);
if (jref != NULL) {
std::unique_ptr<jlong, Deleter<jlong>> refptr{jenv->GetLongArrayElements(jref, nullptr),
[&](jlong *ptr) {
jenv->ReleaseLongArrayElements(jref, ptr, 0);
jenv->DeleteLocalRef(jref);
}};
ref = reinterpret_cast<DMatrixHandle>(refptr.get()[0]);
}
auto ret = XGQuantileDMatrixCreateFromCallback(
&proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result);
&proxy, proxy.GetDMatrixHandle(), ref, Reset, Next, config, &result);
setHandle(jenv, jout, result);
return ret;
}

View File

@ -20,6 +20,7 @@
#include <xgboost/c_api.h>
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <xgboost/string_view.h> // for StringView
#include <algorithm> // for copy_n
#include <cstddef>
@ -30,8 +31,9 @@
#include <type_traits>
#include <vector>
#include "../../../src/c_api/c_api_error.h"
#include "../../../src/c_api/c_api_utils.h"
#include "../../../../src/c_api/c_api_error.h"
#include "../../../../src/c_api/c_api_utils.h"
#include "../../../../src/data/array_interface.h" // for ArrayInterface
#define JVM_CHECK_CALL(__expr) \
{ \
@ -1330,16 +1332,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback(
JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf,
JNIEnv *jenv, jclass jcls, jobject jdata_iter, jlongArray jref, jstring jconf,
jlongArray jout) {
std::unique_ptr<char const, Deleter<char const>> conf{jenv->GetStringUTFChars(jconf, nullptr),
[&](char const *ptr) {
jenv->ReleaseStringUTFChars(jconf, ptr);
}};
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref_iter,
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref,
conf.get(), jout);
}
@ -1517,3 +1519,44 @@ Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetQuantileCut
* Signature: (J[[J[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut(
JNIEnv *jenv, jclass, jlong jhandle, jobjectArray j_indptr, jobjectArray j_values) {
using namespace xgboost; // NOLINT
auto handle = reinterpret_cast<DMatrixHandle>(jhandle);
char const *str_indptr;
char const *str_data;
Json config{Object{}};
auto str_config = Json::Dump(config);
auto ret = XGDMatrixGetQuantileCut(handle, str_config.c_str(), &str_indptr, &str_data);
ArrayInterface<1> indptr{StringView{str_indptr}};
ArrayInterface<1> data{StringView{str_data}};
CHECK_GE(indptr.Shape(0), 2);
// Cut ptr
auto j_indptr_array = jenv->NewLongArray(indptr.Shape(0));
CHECK_EQ(indptr.type, ArrayInterfaceHandler::Type::kU8);
CHECK_LT(indptr(indptr.Shape(0) - 1),
static_cast<std::uint64_t>(std::numeric_limits<std::int64_t>::max()));
static_assert(sizeof(jlong) == sizeof(std::uint64_t));
jenv->SetLongArrayRegion(j_indptr_array, 0, indptr.Shape(0),
static_cast<jlong const *>(indptr.data));
jenv->SetObjectArrayElement(j_indptr, 0, j_indptr_array);
// Cut values
auto n_cuts = indptr(indptr.Shape(0) - 1);
jfloatArray jcuts_array = jenv->NewFloatArray(n_cuts);
CHECK_EQ(data.type, ArrayInterfaceHandler::Type::kF4);
jenv->SetFloatArrayRegion(jcuts_array, 0, n_cuts, static_cast<float const *>(data.data));
jenv->SetObjectArrayElement(j_values, 0, jcuts_array);
return ret;
}

View File

@ -402,10 +402,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFr
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback
(JNIEnv *, jclass, jobject, jobject, jstring, jlongArray);
(JNIEnv *, jclass, jobject, jlongArray, jstring, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
@ -431,6 +431,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFea
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetQuantileCut
* Signature: (J[[J[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut
(JNIEnv *, jclass, jlong, jobjectArray, jobjectArray);
#ifdef __cplusplus
}
#endif

View File

@ -258,8 +258,7 @@ public class DMatrixTest {
TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
}
@Test
public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
private DMatrix createFromDenseMatrix() throws XGBoostError {
//create DMatrix from 10*5 dense matrix
int nrow = 10;
int ncol = 5;
@ -280,12 +279,17 @@ public class DMatrixTest {
label0[i] = random.nextFloat();
}
DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f);
dmat0.setLabel(label0);
DMatrix dm = new DMatrix(data0, nrow, ncol, -0.1f);
dm.setLabel(label0);
return dm;
}
@Test
public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
DMatrix dm = createFromDenseMatrix();
//check
TestCase.assertTrue(dmat0.rowNum() == 10);
TestCase.assertTrue(dmat0.getLabel().length == 10);
TestCase.assertTrue(dm.rowNum() == 10);
TestCase.assertTrue(dm.getLabel().length == 10);
}
@Test
@ -493,4 +497,28 @@ public class DMatrixTest {
TestCase.assertTrue(Arrays.equals(qidExpected1, dmat0.getGroup()));
}
@Test
public void testGetQuantileCut() throws XGBoostError {
DMatrix Xy = createFromDenseMatrix();
Map<String, Object> params = new HashMap<String, Object>();
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", Xy);
XGBoost.train(Xy, params, 1, watches, null, null); // Create the cuts
DMatrix.QuantileCut cuts = Xy.getQuantileCut();
TestCase.assertEquals(cuts.indptr.length, 6);
for (int i = 1; i < cuts.indptr.length; ++i) {
// Number of bins for each feature + min value.
TestCase.assertTrue(cuts.indptr[i] - cuts.indptr[i - 1] >= 5);
TestCase.assertTrue(cuts.indptr[i] - cuts.indptr[i - 1] <= Xy.rowNum() + 1);
}
TestCase.assertEquals(cuts.values.length, cuts.indptr[cuts.indptr.length - 1]);
for (int i = 1; i < cuts.indptr.length; ++i) {
long begin = cuts.indptr[i - 1];
long end = cuts.indptr[i];
for (long j = begin + 1; j < end; ++j) {
TestCase.assertTrue(cuts.values[(int) j] > cuts.values[(int) j - 1]);
}
}
}
}