[jvm-packages] xgboost4j-spark should work when featuresCols is specified (#7789)

Bobby Wang 2022-04-08 13:21:04 +08:00 committed by GitHub
parent 729d227b89
commit 118192f116
12 changed files with 377 additions and 109 deletions
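The upshot of the change: plain xgboost4j-spark now accepts an array of numeric feature column names directly and assembles them into a vector internally, so a separate VectorAssembler stage is no longer required. A minimal sketch of the resulting user-facing flow (the DataFrame df and its column names are illustrative, not from this diff):

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // df is assumed to have numeric columns f1, f2, f3 and a numeric "label" column.
    val classifier = new XGBoostClassifier(Map(
        "objective" -> "binary:logistic",
        "num_round" -> 5,
        "num_workers" -> 1))
      .setFeaturesCol(Array("f1", "f2", "f3")) // the new Array[String] overload
      .setLabelCol("label")
    val model = classifier.fit(df) // features are vectorized internally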

View File

@@ -126,7 +126,7 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
- .setInputCols(featureNames.toArray)
+ .setInputCols(featureNames)
.setOutputCol("features")
val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName)
@@ -149,11 +149,10 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
// Since the CPU model does not know about the feature columns that the GPU transform
// pipeline requires, the end user needs to call setFeaturesCol(features: Array[String])
// on the model manually
- val thrown = intercept[IllegalArgumentException](cpuModel
+ val thrown = intercept[NoSuchElementException](cpuModel
.transform(testDf)
.collect())
- assert(thrown.getMessage.contains("Gpu transform requires features columns. " +
- "please refer to `setFeaturesCol(value: Array[String])`"))
+ assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))
val left = cpuModel
.setFeaturesCol(featureNames)
@@ -196,17 +195,16 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
val featureColName = "feature_col"
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
- .setInputCols(featureNames.toArray)
+ .setInputCols(featureNames)
.setOutputCol(featureColName)
val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName)
// Since the GPU model does not know the features column name that the CPU transform
// pipeline requires, the end user needs to call setFeaturesCol on the model manually
- val thrown = intercept[IllegalArgumentException](
+ intercept[IllegalArgumentException](
gpuModel
.transform(testDf)
.collect())
- assert(thrown.getMessage.contains("features does not exist"))
val left = gpuModel
.setFeaturesCol(featureColName)
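The hunk is cut off above; for context, the recovery the test goes on to exercise is simply setting the feature column names on the model, after which the GPU transform succeeds. A hedged sketch using the test's own names:

    // cpuModel, featureNames and testDf as in the test above
    val recovered = cpuModel
      .setFeaturesCol(featureNames) // Array[String] overload from the common trait
      .transform(testDf)
      .collect()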

View File

@@ -108,12 +108,15 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
val trainingDf = trainingData.toDF(allColumnNames: _*)
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
- val thrown = intercept[IllegalArgumentException] {
+ // GPU training requires featuresCols; if it is not specified,
+ // a NoSuchElementException will be thrown
+ val thrown = intercept[NoSuchElementException] {
new XGBoostClassifier(xgbParam)
.setLabelCol(labelName)
.fit(trainingDf)
}
- assert(thrown.getMessage.contains("Gpu train requires features columns."))
+ assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))
val thrown1 = intercept[IllegalArgumentException] {
new XGBoostClassifier(xgbParam)
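The exception type and message change here is user-visible: callers that previously matched on IllegalArgumentException for a missing featuresCols will now see Spark ML's generic param error instead. A hedged sketch of the new behavior, reusing the names from the test above:

    try {
      new XGBoostClassifier(xgbParam).setLabelCol(labelName).fit(trainingDf)
    } catch {
      case e: NoSuchElementException =>
        // previously: IllegalArgumentException("Gpu train requires features columns.")
        assert(e.getMessage.contains("Failed to find a default value for featuresCols"))
    }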

View File

@@ -1,5 +1,5 @@
/*
- Copyright (c) 2021 by Contributors
+ Copyright (c) 2021-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -122,7 +122,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
- .setInputCols(featureNames.toArray)
+ .setInputCols(featureNames)
.setOutputCol("features")
val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName)
@@ -145,11 +145,10 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
// Since the CPU model does not know about the feature columns that the GPU transform
// pipeline requires, the end user needs to call setFeaturesCol(features: Array[String])
// on the model manually
- val thrown = intercept[IllegalArgumentException](cpuModel
+ val thrown = intercept[NoSuchElementException](cpuModel
.transform(testDf)
.collect())
- assert(thrown.getMessage.contains("Gpu transform requires features columns. " +
- "please refer to `setFeaturesCol(value: Array[String])`"))
+ assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))
val left = cpuModel
.setFeaturesCol(featureNames)
@@ -192,17 +191,16 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
val featureColName = "feature_col"
val vectorAssembler = new VectorAssembler()
.setHandleInvalid("keep")
- .setInputCols(featureNames.toArray)
+ .setInputCols(featureNames)
.setOutputCol(featureColName)
val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName)
// Since the GPU model does not know the features column name that the CPU transform
// pipeline requires, the end user needs to call setFeaturesCol on the model manually
- val thrown = intercept[IllegalArgumentException](
+ intercept[IllegalArgumentException](
gpuModel
.transform(testDf)
.collect())
- assert(thrown.getMessage.contains("features does not exist"))
val left = gpuModel
.setFeaturesCol(featureColName)

View File

@@ -1,5 +1,5 @@
/*
- Copyright (c) 2021 by Contributors
+ Copyright (c) 2021-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -35,8 +35,10 @@ import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
+ import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.linalg.Vector
+ import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
@@ -112,7 +114,7 @@ object PreXGBoost extends PreXGBoostProvider {
return optionProvider.get.buildDatasetToRDD(estimator, dataset, params)
}
- val (packedParams, evalSet) = estimator match {
+ val (packedParams, evalSet, xgbInput) = estimator match {
case est: XGBoostEstimatorCommon =>
// get weight column, if weight is not defined, default to lit(1.0)
val weight = if (!est.isDefined(est.weightCol) || est.getWeightCol.isEmpty) {
@@ -136,15 +138,18 @@ object PreXGBoost extends PreXGBoostProvider {
}
- (PackedParams(col(est.getLabelCol), col(est.getFeaturesCol), weight, baseMargin, group,
- est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params))
+ val (xgbInput, featuresName) = est.vectorize(dataset)
+ (PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
+ est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params),
+ xgbInput)
case _ => throw new RuntimeException("Unsupporting " + estimator)
}
// transform the training Dataset[_] to RDD[XGBLabeledPoint]
val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
- packedParams, dataset.asInstanceOf[DataFrame]).head
+ packedParams, xgbInput.asInstanceOf[DataFrame]).head
// transform the eval Dataset[_] to RDD[XGBLabeledPoint]
val evalRDDMap = evalSet.map {
@ -184,11 +189,11 @@ object PreXGBoost extends PreXGBoostProvider {
}
/** get the necessary parameters */
val (booster, inferBatchSize, featuresCol, useExternalMemory, missing, allowNonZeroForMissing,
predictFunc, schema) =
val (booster, inferBatchSize, xgbInput, featuresCol, useExternalMemory, missing,
allowNonZeroForMissing, predictFunc, schema) =
model match {
case m: XGBoostClassificationModel =>
val (xgbInput, featuresName) = m.vectorize(dataset)
// predict and turn to Row
val predictFunc =
(broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
@@ -199,7 +204,7 @@ object PreXGBoost extends PreXGBoostProvider {
}
// prepare the final Schema
- var schema = StructType(dataset.schema.fields ++
+ var schema = StructType(xgbInput.schema.fields ++
Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)) ++
Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType =
@@ -214,11 +219,12 @@ object PreXGBoost extends PreXGBoostProvider {
ArrayType(FloatType, containsNull = false), nullable = false))
}
- (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing,
- m.getAllowNonZeroForMissingValue, predictFunc, schema)
+ (m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory,
+ m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema)
case m: XGBoostRegressionModel =>
+ val (xgbInput, featuresName) = m.vectorize(dataset)
// predict and turn to Row
val predictFunc =
(broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
val Array(rawPredictionItr, predLeafItr, predContribItr) =
@@ -227,7 +233,7 @@ object PreXGBoost extends PreXGBoostProvider {
}
// prepare the final Schema
- var schema = StructType(dataset.schema.fields ++
+ var schema = StructType(xgbInput.schema.fields ++
Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))
@@ -240,14 +246,14 @@ object PreXGBoost extends PreXGBoostProvider {
ArrayType(FloatType, containsNull = false), nullable = false))
}
- (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing,
- m.getAllowNonZeroForMissingValue, predictFunc, schema)
+ (m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory,
+ m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema)
}
- val bBooster = dataset.sparkSession.sparkContext.broadcast(booster)
- val appName = dataset.sparkSession.sparkContext.appName
+ val bBooster = xgbInput.sparkSession.sparkContext.broadcast(booster)
+ val appName = xgbInput.sparkSession.sparkContext.appName
- val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+ val resultRDD = xgbInput.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
new AbstractIterator[Row] {
private var batchCnt = 0
@@ -295,7 +301,7 @@ object PreXGBoost extends PreXGBoostProvider {
}
bBooster.unpersist(blocking = false)
- dataset.sparkSession.createDataFrame(resultRDD, schema)
+ xgbInput.sparkSession.createDataFrame(resultRDD, schema)
}

View File

@@ -144,13 +144,6 @@ class XGBoostClassifier (
def setSinglePrecisionHistogram(value: Boolean): this.type =
set(singlePrecisionHistogram, value)
- /**
- * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
- * all feature columns must be numeric types.
- */
- def setFeaturesCol(value: Array[String]): this.type =
- set(featuresCols, value)
// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
@@ -165,7 +158,12 @@ class XGBoostClassifier (
// Callback from PreXGBoost
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
- super.transformSchema(schema)
+ if (isFeaturesColSet(schema)) {
+ // User has vectorized the features into VectorUDT.
+ super.transformSchema(schema)
+ } else {
+ transformSchemaWithFeaturesCols(true, schema)
+ }
}
override def transformSchema(schema: StructType): StructType = {
@@ -260,13 +258,6 @@ class XGBoostClassificationModel private[ml](
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
- /**
- * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
- * all feature columns must be numeric types.
- */
- def setFeaturesCol(value: Array[String]): this.type =
- set(featuresCols, value)
/**
* Single instance prediction.
* Note: The performance is not ideal, use it carefully!
@@ -359,7 +350,12 @@ class XGBoostClassificationModel private[ml](
}
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
- super.transformSchema(schema)
+ if (isFeaturesColSet(schema)) {
+ // User has vectorized the features into VectorUDT.
+ super.transformSchema(schema)
+ } else {
+ transformSchemaWithFeaturesCols(false, schema)
+ }
}
override def transformSchema(schema: StructType): StructType = {
@@ -385,8 +381,6 @@ class XGBoostClassificationModel private[ml](
Vectors.dense(rawPredictions)
}
- if ($(rawPredictionCol).nonEmpty) {
outputData = outputData
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
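For intuition, these are the two input shapes the new transformSchemaInternal branch distinguishes; a hedged sketch with illustrative column names:

    import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
    import org.apache.spark.sql.types._

    // Path 1: featuresCol already vectorized -> super.transformSchema(schema)
    val vectorized = StructType(Seq(
      StructField("features", VectorType), StructField("label", IntegerType)))

    // Path 2: raw numeric columns listed in featuresCols
    //         -> transformSchemaWithFeaturesCols(true, schema)
    val raw = StructType(Seq(
      StructField("f1", DoubleType), StructField("f2", DoubleType),
      StructField("label", IntegerType)))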

View File

@@ -146,13 +146,6 @@ class XGBoostRegressor (
def setSinglePrecisionHistogram(value: Boolean): this.type =
set(singlePrecisionHistogram, value)
- /**
- * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
- * all feature columns must be numeric types.
- */
- def setFeaturesCol(value: Array[String]): this.type =
- set(featuresCols, value)
// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
@@ -164,7 +157,12 @@ class XGBoostRegressor (
}
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
- super.transformSchema(schema)
+ if (isFeaturesColSet(schema)) {
+ // User has vectorized the features into VectorUDT.
+ super.transformSchema(schema)
+ } else {
+ transformSchemaWithFeaturesCols(false, schema)
+ }
}
override def transformSchema(schema: StructType): StructType = {
@@ -253,13 +251,6 @@ class XGBoostRegressionModel private[ml] (
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
- /**
- * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
- * all feature columns must be numeric types.
- */
- def setFeaturesCol(value: Array[String]): this.type =
- set(featuresCols, value)
/**
* Single instance prediction.
* Note: The performance is not ideal, use it carefully!
@@ -331,7 +322,12 @@ class XGBoostRegressionModel private[ml] (
}
private[spark] def transformSchemaInternal(schema: StructType): StructType = {
- super.transformSchema(schema)
+ if (isFeaturesColSet(schema)) {
+ // User has vectorized the features into VectorUDT.
+ super.transformSchema(schema)
+ } else {
+ transformSchemaWithFeaturesCols(false, schema)
+ }
}
override def transformSchema(schema: StructType): StructType = {

View File

@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014,2021 by Contributors
+ Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -247,6 +247,27 @@ trait HasNumClass extends Params {
final def getNumClass: Int = $(numClass)
}
+ /**
+ * Trait for shared param featuresCols.
+ */
+ trait HasFeaturesCols extends Params {
+ /**
+ * Param for the names of feature columns.
+ * @group param
+ */
+ final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
+ "an array of feature column names.")
+ /** @group getParam */
+ final def getFeaturesCols: Array[String] = $(featuresCols)
+ /** Check if featuresCols is valid */
+ def isFeaturesColsValid: Boolean = {
+ isDefined(featuresCols) && $(featuresCols) != Array.empty
+ }
+ }
private[spark] trait ParamMapFuncs extends Params {
def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
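Unlike the deleted GpuParams (below), HasFeaturesCols sets no default value, so reading the param before it is set fails with Spark ML's generic NoSuchElementException("Failed to find a default value for featuresCols"), which is exactly the message the updated GPU tests assert. A hedged sketch of the param's behavior on an estimator mixing in the trait:

    val est = new XGBoostClassifier()
    assert(!est.isFeaturesColsValid)  // featuresCols not defined yet
    // est.getFeaturesCols            // would throw NoSuchElementException
    est.setFeaturesCol(Array("f1", "f2"))
    assert(est.getFeaturesCols.sameElements(Array("f1", "f2")))
    assert(est.isFeaturesColsValid)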

View File

@@ -1,34 +0,0 @@
- /*
- Copyright (c) 2021-2022 by Contributors
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package ml.dmlc.xgboost4j.scala.spark.params
- import org.apache.spark.ml.param.{Params, StringArrayParam}
- trait GpuParams extends Params {
- /**
- * Param for the names of feature columns for GPU pipeline.
- * @group param
- */
- final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
- "an array of feature column names for GPU pipeline.")
- setDefault(featuresCols, Array.empty[String])
- /** @group getParam */
- final def getFeaturesCols: Array[String] = $(featuresCols)
- }

View File

@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014,2021 by Contributors
+ Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,16 +16,101 @@
package ml.dmlc.xgboost4j.scala.spark.params
- import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
+ import org.apache.spark.ml.feature.VectorAssembler
+ import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
+ import org.apache.spark.ml.param.{Param, ParamValidators}
+ import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol}
+ import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol
with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
- with HasLabelCol with GpuParams {
+ with HasLabelCol with HasFeaturesCols with HasHandleInvalid {
def needDeterministicRepartitioning: Boolean = {
getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
}
+ /**
+ * Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out
+ * rows with invalid data), 'error' (throw an error), or 'keep' (return the relevant number of
+ * NaN in the output). Column lengths are taken from the size of the ML Attribute Group, which
+ * can be set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can
+ * also be inferred from the first rows of the data, since it is safe to do so, but only in the
+ * case of 'error' or 'skip'.
+ * Default: "error"
+ * @group param
+ */
+ override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+ """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out
+ |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN
+ |in the output). Column lengths are taken from the size of ML Attribute Group, which can be
+ |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also
+ |be inferred from first rows of the data since it is safe to do so but only in case of 'error'
+ |or 'skip'.""".stripMargin.replaceAll("\n", " "),
+ ParamValidators.inArray(Array("skip", "error", "keep")))
+ setDefault(handleInvalid, "error")
+ /**
+ * Specify an array of feature column names, all of which must be numeric types.
+ */
+ def setFeaturesCol(value: Array[String]): this.type = set(featuresCols, value)
+ /** Set handleInvalid for the internal VectorAssembler */
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+ /**
+ * Check whether the schema has a field whose name is the value of the "featuresCol" param
+ * and whose data type is VectorUDT.
+ */
+ def isFeaturesColSet(schema: StructType): Boolean = {
+ schema.fieldNames.contains(getFeaturesCol) &&
+ XGBoostSchemaUtils.isVectorUDFType(schema(getFeaturesCol).dataType)
+ }
+ /** Check the types of the feature columns */
+ def transformSchemaWithFeaturesCols(fit: Boolean, schema: StructType): StructType = {
+ if (isFeaturesColsValid) {
+ if (fit) {
+ XGBoostSchemaUtils.checkNumericType(schema, $(labelCol))
+ }
+ $(featuresCols).foreach(feature =>
+ XGBoostSchemaUtils.checkFeatureColumnType(schema(feature).dataType))
+ schema
+ } else {
+ throw new IllegalArgumentException("featuresCol or featuresCols must be specified")
+ }
+ }
+ /**
+ * Vectorize the feature columns if necessary.
+ *
+ * @param input the input dataset
+ * @return (the output dataset, the name of the features column)
+ */
+ def vectorize(input: Dataset[_]): (Dataset[_], String) = {
+ val schema = input.schema
+ if (isFeaturesColSet(schema)) {
+ // The dataset has already been vectorized.
+ (input, getFeaturesCol)
+ } else if (isFeaturesColsValid) {
+ val featuresName = if (!schema.fieldNames.contains(getFeaturesCol)) {
+ getFeaturesCol
+ } else {
+ "features_" + uid
+ }
+ val vectorAssembler = new VectorAssembler()
+ .setHandleInvalid($(handleInvalid))
+ .setInputCols(getFeaturesCols)
+ .setOutputCol(featuresName)
+ (vectorAssembler.transform(input).select(featuresName, getLabelCol), featuresName)
+ } else {
+ // Never reached: transformSchema already handles the case where
+ // featuresCols is invalid
+ (input, getFeaturesCol)
+ }
+ }
}
private[scala] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass
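Note the select(featuresName, getLabelCol) in vectorize: when assembly happens, only the assembled features column and the label survive into the returned dataset, which is why the new tests below see a "features_" + uid column in the transform output instead of the original input columns. A hedged sketch of the three outcomes (df and column names illustrative):

    val est = new XGBoostClassifier().setFeaturesCol(Array("f1", "f2", "f3"))
    val (xgbInput, featuresName) = est.vectorize(df)
    // 1. df already has a VectorUDT column named getFeaturesCol: pass-through
    // 2. no column named "features" in df: assemble f1..f3 into "features"
    // 3. a non-vector column named "features" exists: assemble into
    //    "features_" + est.uid to avoid the name collision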

View File

@@ -0,0 +1,51 @@
+ /*
+ Copyright (c) 2022 by Contributors
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+ package org.apache.spark.ml.linalg.xgboost
+ import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType}
+ import org.apache.spark.ml.linalg.VectorUDT
+ import org.apache.spark.ml.util.SchemaUtils
+ object XGBoostSchemaUtils {
+ /** Check if the dataType is VectorUDT */
+ def isVectorUDFType(dataType: DataType): Boolean = {
+ dataType match {
+ case _: VectorUDT => true
+ case _ => false
+ }
+ }
+ /** The feature columns will be vectorized by VectorAssembler first, which only
+ * supports Numeric, Boolean and VectorUDT types */
+ def checkFeatureColumnType(dataType: DataType): Unit = {
+ dataType match {
+ case _: NumericType | BooleanType =>
+ case _: VectorUDT =>
+ case d => throw new UnsupportedOperationException(s"featuresCols only supports Numeric, " +
+ s"Boolean and VectorUDT types, found: ${d}")
+ }
+ }
+ def checkNumericType(
+ schema: StructType,
+ colName: String,
+ msg: String = ""): Unit = {
+ SchemaUtils.checkNumericType(schema, colName, msg)
+ }
+ }
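A hedged sketch of what these checks accept and reject:

    import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
    import org.apache.spark.sql.types._

    XGBoostSchemaUtils.isVectorUDFType(VectorType)          // true
    XGBoostSchemaUtils.checkFeatureColumnType(DoubleType)   // ok: numeric
    XGBoostSchemaUtils.checkFeatureColumnType(BooleanType)  // ok
    XGBoostSchemaUtils.checkFeatureColumnType(VectorType)   // ok: VectorUDT
    // XGBoostSchemaUtils.checkFeatureColumnType(StringType)
    //   -> UnsupportedOperationException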

View File

@@ -23,6 +23,7 @@ import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.apache.spark.Partitioner
+ import org.apache.spark.ml.feature.VectorAssembler
class XGBoostClassifierSuite extends FunSuite with PerTest {
@@ -316,4 +317,77 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
xgb.fit(repartitioned)
}
+ test("featuresCols with features column can work") {
+ val spark = ss
+ import spark.implicits._
+ val xgbInput = Seq(
+ (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
+ (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
+ .toDF("f1", "f2", "f3", "features", "label")
+ val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+ "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1)
+ val featuresName = Array("f1", "f2", "f3", "features")
+ val xgbClassifier = new XGBoostClassifier(paramMap)
+ .setFeaturesCol(featuresName)
+ .setLabelCol("label")
+ val model = xgbClassifier.fit(xgbInput)
+ assert(model.getFeaturesCols.sameElements(featuresName))
+ val df = model.transform(xgbInput)
+ assert(df.schema.fieldNames.contains("features_" + model.uid))
+ df.show()
+ val newFeatureName = "features_new"
+ // transform also works for a vectorized dataset
+ val vectorizedInput = new VectorAssembler()
+ .setInputCols(featuresName)
+ .setOutputCol(newFeatureName)
+ .transform(xgbInput)
+ .select(newFeatureName, "label")
+ val df1 = model
+ .setFeaturesCol(newFeatureName)
+ .transform(vectorizedInput)
+ assert(df1.schema.fieldNames.contains(newFeatureName))
+ df1.show()
+ }
+ test("featuresCols without features column can work") {
+ val spark = ss
+ import spark.implicits._
+ val xgbInput = Seq(
+ (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
+ (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
+ .toDF("f1", "f2", "f3", "f4", "label")
+ val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+ "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1)
+ val featuresName = Array("f1", "f2", "f3", "f4")
+ val xgbClassifier = new XGBoostClassifier(paramMap)
+ .setFeaturesCol(featuresName)
+ .setLabelCol("label")
+ val model = xgbClassifier.fit(xgbInput)
+ assert(model.getFeaturesCols.sameElements(featuresName))
+ // transform should work for a dataset that contains the named feature columns
+ val df = model.transform(xgbInput)
+ assert(df.schema.fieldNames.contains("features"))
+ df.show()
+ // transform also works for a vectorized dataset
+ val vectorizedInput = new VectorAssembler()
+ .setInputCols(featuresName)
+ .setOutputCol("features")
+ .transform(xgbInput)
+ .select("features", "label")
+ val df1 = model.transform(vectorizedInput)
+ df1.show()
+ }
}
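Because vectorization now happens inside the estimator, it also composes directly in a Spark ML Pipeline without an assembler stage; a hedged sketch reusing the names from the first test above:

    import org.apache.spark.ml.Pipeline

    val pipeline = new Pipeline().setStages(Array(xgbClassifier))
    val pipelineModel = pipeline.fit(xgbInput)
    pipelineModel.transform(xgbInput).show()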

View File

@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,12 +17,15 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
- import org.apache.spark.ml.linalg.Vector
+ import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row}
+ import org.apache.spark.sql.types._
import org.scalatest.FunSuite
+ import org.apache.spark.ml.feature.VectorAssembler
class XGBoostRegressorSuite extends FunSuite with PerTest {
protected val treeMethod: String = "auto"
@@ -216,4 +219,77 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
assert(resultDF.columns.contains("predictLeaf"))
assert(resultDF.columns.contains("predictContrib"))
}
+ test("featuresCols with features column can work") {
+ val spark = ss
+ import spark.implicits._
+ val xgbInput = Seq(
+ (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
+ (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
+ .toDF("f1", "f2", "f3", "features", "label")
+ val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+ "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1)
+ val featuresName = Array("f1", "f2", "f3", "features")
+ val xgbRegressor = new XGBoostRegressor(paramMap)
+ .setFeaturesCol(featuresName)
+ .setLabelCol("label")
+ val model = xgbRegressor.fit(xgbInput)
+ assert(model.getFeaturesCols.sameElements(featuresName))
+ val df = model.transform(xgbInput)
+ assert(df.schema.fieldNames.contains("features_" + model.uid))
+ df.show()
+ val newFeatureName = "features_new"
+ // transform also works for a vectorized dataset
+ val vectorizedInput = new VectorAssembler()
+ .setInputCols(featuresName)
+ .setOutputCol(newFeatureName)
+ .transform(xgbInput)
+ .select(newFeatureName, "label")
+ val df1 = model
+ .setFeaturesCol(newFeatureName)
+ .transform(vectorizedInput)
+ assert(df1.schema.fieldNames.contains(newFeatureName))
+ df1.show()
+ }
+ test("featuresCols without features column can work") {
+ val spark = ss
+ import spark.implicits._
+ val xgbInput = Seq(
+ (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
+ (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
+ .toDF("f1", "f2", "f3", "f4", "label")
+ val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+ "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1)
+ val featuresName = Array("f1", "f2", "f3", "f4")
+ val xgbRegressor = new XGBoostRegressor(paramMap)
+ .setFeaturesCol(featuresName)
+ .setLabelCol("label")
+ val model = xgbRegressor.fit(xgbInput)
+ assert(model.getFeaturesCols.sameElements(featuresName))
+ // transform should work for a dataset that contains the named feature columns
+ val df = model.transform(xgbInput)
+ assert(df.schema.fieldNames.contains("features"))
+ df.show()
+ // transform also works for a vectorized dataset
+ val vectorizedInput = new VectorAssembler()
+ .setInputCols(featuresName)
+ .setOutputCol("features")
+ .transform(xgbInput)
+ .select("features", "label")
+ val df1 = model.transform(vectorizedInput)
+ df1.show()
+ }
}