From 118192f11678aeaab510eaa725e815e4e706279e Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Fri, 8 Apr 2022 13:21:04 +0800 Subject: [PATCH] [jvm-packages] xgboost4j-spark should work when featuresCols is specified (#7789) --- .../spark/GpuXGBoostClassifierSuite.scala | 12 +-- .../rapids/spark/GpuXGBoostGeneralSuite.scala | 7 +- .../spark/GpuXGBoostRegressorSuite.scala | 14 ++- .../xgboost4j/scala/spark/PreXGBoost.scala | 42 +++++---- .../scala/spark/XGBoostClassifier.scala | 30 +++--- .../scala/spark/XGBoostRegressor.scala | 28 +++--- .../scala/spark/params/GeneralParams.scala | 23 ++++- .../scala/spark/params/GpuParams.scala | 34 ------- .../spark/params/XGBoostEstimatorCommon.scala | 91 ++++++++++++++++++- .../linalg/xgboost/XGBoostSchemaUtils.scala | 51 +++++++++++ .../scala/spark/XGBoostClassifierSuite.scala | 74 +++++++++++++++ .../scala/spark/XGBoostRegressorSuite.scala | 80 +++++++++++++++- 12 files changed, 377 insertions(+), 109 deletions(-) delete mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GpuParams.scala create mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/linalg/xgboost/XGBoostSchemaUtils.scala diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostClassifierSuite.scala index 6ff1947b3..fc26b2985 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostClassifierSuite.scala @@ -126,7 +126,7 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite { val vectorAssembler = new VectorAssembler() .setHandleInvalid("keep") - .setInputCols(featureNames.toArray) + .setInputCols(featureNames) .setOutputCol("features") val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName) @@ -149,11 +149,10 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite { // Since CPU model does not know the information about the features cols that GPU transform // pipeline requires. End user needs to setFeaturesCol(features: Array[String]) in the model // manually - val thrown = intercept[IllegalArgumentException](cpuModel + val thrown = intercept[NoSuchElementException](cpuModel .transform(testDf) .collect()) - assert(thrown.getMessage.contains("Gpu transform requires features columns. " + - "please refer to `setFeaturesCol(value: Array[String])`")) + assert(thrown.getMessage.contains("Failed to find a default value for featuresCols")) val left = cpuModel .setFeaturesCol(featureNames) @@ -196,17 +195,16 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite { val featureColName = "feature_col" val vectorAssembler = new VectorAssembler() .setHandleInvalid("keep") - .setInputCols(featureNames.toArray) + .setInputCols(featureNames) .setOutputCol(featureColName) val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName) // Since GPU model does not know the information about the features col name that CPU // transform pipeline requires. End user needs to setFeaturesCol in the model manually - val thrown = intercept[IllegalArgumentException]( + intercept[IllegalArgumentException]( gpuModel .transform(testDf) .collect()) - assert(thrown.getMessage.contains("features does not exist")) val left = gpuModel .setFeaturesCol(featureColName) diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala index 53cdcb923..3d643761a 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala @@ -108,12 +108,15 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite { val trainingDf = trainingData.toDF(allColumnNames: _*) val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob", "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist") - val thrown = intercept[IllegalArgumentException] { + + // GPU train requires featuresCols. If not specified, + // then NoSuchElementException will be thrown + val thrown = intercept[NoSuchElementException] { new XGBoostClassifier(xgbParam) .setLabelCol(labelName) .fit(trainingDf) } - assert(thrown.getMessage.contains("Gpu train requires features columns.")) + assert(thrown.getMessage.contains("Failed to find a default value for featuresCols")) val thrown1 = intercept[IllegalArgumentException] { new XGBoostClassifier(xgbParam) diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala index 2777c2ea6..5342aa563 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021 by Contributors + Copyright (c) 2021-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -122,7 +122,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { val vectorAssembler = new VectorAssembler() .setHandleInvalid("keep") - .setInputCols(featureNames.toArray) + .setInputCols(featureNames) .setOutputCol("features") val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName) @@ -145,11 +145,10 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { // Since CPU model does not know the information about the features cols that GPU transform // pipeline requires. End user needs to setFeaturesCol(features: Array[String]) in the model // manually - val thrown = intercept[IllegalArgumentException](cpuModel + val thrown = intercept[NoSuchElementException](cpuModel .transform(testDf) .collect()) - assert(thrown.getMessage.contains("Gpu transform requires features columns. " + - "please refer to `setFeaturesCol(value: Array[String])`")) + assert(thrown.getMessage.contains("Failed to find a default value for featuresCols")) val left = cpuModel .setFeaturesCol(featureNames) @@ -192,17 +191,16 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { val featureColName = "feature_col" val vectorAssembler = new VectorAssembler() .setHandleInvalid("keep") - .setInputCols(featureNames.toArray) + .setInputCols(featureNames) .setOutputCol(featureColName) val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName) // Since GPU model does not know the information about the features col name that CPU // transform pipeline requires. End user needs to setFeaturesCol in the model manually - val thrown = intercept[IllegalArgumentException]( + intercept[IllegalArgumentException]( gpuModel .transform(testDf) .collect()) - assert(thrown.getMessage.contains("features does not exist")) val left = gpuModel .setFeaturesCol(featureColName) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala index 8baaafba7..67deb6979 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021 by Contributors + Copyright (c) 2021-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,8 +35,10 @@ import org.apache.commons.logging.LogFactory import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.{Estimator, Model, PipelineStage} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType} import org.apache.spark.storage.StorageLevel @@ -112,7 +114,7 @@ object PreXGBoost extends PreXGBoostProvider { return optionProvider.get.buildDatasetToRDD(estimator, dataset, params) } - val (packedParams, evalSet) = estimator match { + val (packedParams, evalSet, xgbInput) = estimator match { case est: XGBoostEstimatorCommon => // get weight column, if weight is not defined, default to lit(1.0) val weight = if (!est.isDefined(est.weightCol) || est.getWeightCol.isEmpty) { @@ -136,15 +138,18 @@ object PreXGBoost extends PreXGBoostProvider { } - (PackedParams(col(est.getLabelCol), col(est.getFeaturesCol), weight, baseMargin, group, - est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params)) + val (xgbInput, featuresName) = est.vectorize(dataset) + + (PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group, + est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params), + xgbInput) case _ => throw new RuntimeException("Unsupporting " + estimator) } // transform the training Dataset[_] to RDD[XGBLabeledPoint] val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs( - packedParams, dataset.asInstanceOf[DataFrame]).head + packedParams, xgbInput.asInstanceOf[DataFrame]).head // transform the eval Dataset[_] to RDD[XGBLabeledPoint] val evalRDDMap = evalSet.map { @@ -184,11 +189,11 @@ object PreXGBoost extends PreXGBoostProvider { } /** get the necessary parameters */ - val (booster, inferBatchSize, featuresCol, useExternalMemory, missing, allowNonZeroForMissing, - predictFunc, schema) = + val (booster, inferBatchSize, xgbInput, featuresCol, useExternalMemory, missing, + allowNonZeroForMissing, predictFunc, schema) = model match { case m: XGBoostClassificationModel => - + val (xgbInput, featuresName) = m.vectorize(dataset) // predict and turn to Row val predictFunc = (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => { @@ -199,7 +204,7 @@ object PreXGBoost extends PreXGBoostProvider { } // prepare the final Schema - var schema = StructType(dataset.schema.fields ++ + var schema = StructType(xgbInput.schema.fields ++ Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType = ArrayType(FloatType, containsNull = false), nullable = false)) ++ Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType = @@ -214,11 +219,12 @@ object PreXGBoost extends PreXGBoostProvider { ArrayType(FloatType, containsNull = false), nullable = false)) } - (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing, - m.getAllowNonZeroForMissingValue, predictFunc, schema) + (m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory, + m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema) case m: XGBoostRegressionModel => // predict and turn to Row + val (xgbInput, featuresName) = m.vectorize(dataset) val predictFunc = (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => { val Array(rawPredictionItr, predLeafItr, predContribItr) = @@ -227,7 +233,7 @@ object PreXGBoost extends PreXGBoostProvider { } // prepare the final Schema - var schema = StructType(dataset.schema.fields ++ + var schema = StructType(xgbInput.schema.fields ++ Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType = ArrayType(FloatType, containsNull = false), nullable = false))) @@ -240,14 +246,14 @@ object PreXGBoost extends PreXGBoostProvider { ArrayType(FloatType, containsNull = false), nullable = false)) } - (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing, - m.getAllowNonZeroForMissingValue, predictFunc, schema) + (m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory, + m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema) } - val bBooster = dataset.sparkSession.sparkContext.broadcast(booster) - val appName = dataset.sparkSession.sparkContext.appName + val bBooster = xgbInput.sparkSession.sparkContext.broadcast(booster) + val appName = xgbInput.sparkSession.sparkContext.appName - val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator => + val resultRDD = xgbInput.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator => new AbstractIterator[Row] { private var batchCnt = 0 @@ -295,7 +301,7 @@ object PreXGBoost extends PreXGBoostProvider { } bBooster.unpersist(blocking = false) - dataset.sparkSession.createDataFrame(resultRDD, schema) + xgbInput.sparkSession.createDataFrame(resultRDD, schema) } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index c8635d93c..3e62e9946 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -144,13 +144,6 @@ class XGBoostClassifier ( def setSinglePrecisionHistogram(value: Boolean): this.type = set(singlePrecisionHistogram, value) - /** - * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires - * all feature columns must be numeric types. - */ - def setFeaturesCol(value: Array[String]): this.type = - set(featuresCols, value) - // called at the start of fit/train when 'eval_metric' is not defined private def setupDefaultEvalMetric(): String = { require(isDefined(objective), "Users must set \'objective\' via xgboostParams.") @@ -165,7 +158,12 @@ class XGBoostClassifier ( // Callback from PreXGBoost private[spark] def transformSchemaInternal(schema: StructType): StructType = { - super.transformSchema(schema) + if (isFeaturesColSet(schema)) { + // User has vectorized the features into VectorUDT. + super.transformSchema(schema) + } else { + transformSchemaWithFeaturesCols(true, schema) + } } override def transformSchema(schema: StructType): StructType = { @@ -260,13 +258,6 @@ class XGBoostClassificationModel private[ml]( def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value) - /** - * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires - * all feature columns must be numeric types. - */ - def setFeaturesCol(value: Array[String]): this.type = - set(featuresCols, value) - /** * Single instance prediction. * Note: The performance is not ideal, use it carefully! @@ -359,7 +350,12 @@ class XGBoostClassificationModel private[ml]( } private[spark] def transformSchemaInternal(schema: StructType): StructType = { - super.transformSchema(schema) + if (isFeaturesColSet(schema)) { + // User has vectorized the features into VectorUDT. + super.transformSchema(schema) + } else { + transformSchemaWithFeaturesCols(false, schema) + } } override def transformSchema(schema: StructType): StructType = { @@ -385,8 +381,6 @@ class XGBoostClassificationModel private[ml]( Vectors.dense(rawPredictions) } - - if ($(rawPredictionCol).nonEmpty) { outputData = outputData .withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol))) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index 617aedfad..9af52d165 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -146,13 +146,6 @@ class XGBoostRegressor ( def setSinglePrecisionHistogram(value: Boolean): this.type = set(singlePrecisionHistogram, value) - /** - * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires - * all feature columns must be numeric types. - */ - def setFeaturesCol(value: Array[String]): this.type = - set(featuresCols, value) - // called at the start of fit/train when 'eval_metric' is not defined private def setupDefaultEvalMetric(): String = { require(isDefined(objective), "Users must set \'objective\' via xgboostParams.") @@ -164,7 +157,12 @@ class XGBoostRegressor ( } private[spark] def transformSchemaInternal(schema: StructType): StructType = { - super.transformSchema(schema) + if (isFeaturesColSet(schema)) { + // User has vectorized the features into VectorUDT. + super.transformSchema(schema) + } else { + transformSchemaWithFeaturesCols(false, schema) + } } override def transformSchema(schema: StructType): StructType = { @@ -253,13 +251,6 @@ class XGBoostRegressionModel private[ml] ( def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value) - /** - * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires - * all feature columns must be numeric types. - */ - def setFeaturesCol(value: Array[String]): this.type = - set(featuresCols, value) - /** * Single instance prediction. * Note: The performance is not ideal, use it carefully! @@ -331,7 +322,12 @@ class XGBoostRegressionModel private[ml] ( } private[spark] def transformSchemaInternal(schema: StructType): StructType = { - super.transformSchema(schema) + if (isFeaturesColSet(schema)) { + // User has vectorized the features into VectorUDT. + super.transformSchema(schema) + } else { + transformSchemaWithFeaturesCols(false, schema) + } } override def transformSchema(schema: StructType): StructType = { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index a75f64dd8..2416df0b3 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014,2021 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -247,6 +247,27 @@ trait HasNumClass extends Params { final def getNumClass: Int = $(numClass) } +/** + * Trait for shared param featuresCols. + */ +trait HasFeaturesCols extends Params { + /** + * Param for the names of feature columns. + * @group param + */ + final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols", + "an array of feature column names.") + + /** @group getParam */ + final def getFeaturesCols: Array[String] = $(featuresCols) + + /** Check if featuresCols is valid */ + def isFeaturesColsValid: Boolean = { + isDefined(featuresCols) && $(featuresCols) != Array.empty + } + +} + private[spark] trait ParamMapFuncs extends Params { def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GpuParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GpuParams.scala deleted file mode 100644 index 9ab4c7357..000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GpuParams.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - Copyright (c) 2021-2022 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.spark.params - -import org.apache.spark.ml.param.{Params, StringArrayParam} - -trait GpuParams extends Params { - /** - * Param for the names of feature columns for GPU pipeline. - * @group param - */ - final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols", - "an array of feature column names for GPU pipeline.") - - setDefault(featuresCols, Array.empty[String]) - - /** @group getParam */ - final def getFeaturesCols: Array[String] = $(featuresCols) - -} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala index 025757021..5d2a1c04e 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014,2021 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,101 @@ package ml.dmlc.xgboost4j.scala.spark.params -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils +import org.apache.spark.ml.param.{Param, ParamValidators} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.types.StructType private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol - with HasLabelCol with GpuParams { + with HasLabelCol with HasFeaturesCols with HasHandleInvalid { def needDeterministicRepartitioning: Boolean = { getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0 } + + /** + * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + * output). Column lengths are taken from the size of ML Attribute Group, which can be set using + * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + * Default: "error" + * @group param + */ + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out + |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN + |in the output). Column lengths are taken from the size of ML Attribute Group, which can be + |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also + |be inferred from first rows of the data since it is safe to do so but only in case of 'error' + |or 'skip'.""".stripMargin.replaceAll("\n", " "), + ParamValidators.inArray(Array("skip", "error", "keep"))) + + setDefault(handleInvalid, "error") + + /** + * Specify an array of feature column names which must be numeric types. + */ + def setFeaturesCol(value: Array[String]): this.type = set(featuresCols, value) + + /** Set the handleInvalid for VectorAssembler */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** + * Check if schema has a field named with the value of "featuresCol" param and it's data type + * must be VectorUDT + */ + def isFeaturesColSet(schema: StructType): Boolean = { + schema.fieldNames.contains(getFeaturesCol) && + XGBoostSchemaUtils.isVectorUDFType(schema(getFeaturesCol).dataType) + } + + /** check the features columns type */ + def transformSchemaWithFeaturesCols(fit: Boolean, schema: StructType): StructType = { + if (isFeaturesColsValid) { + if (fit) { + XGBoostSchemaUtils.checkNumericType(schema, $(labelCol)) + } + $(featuresCols).foreach(feature => + XGBoostSchemaUtils.checkFeatureColumnType(schema(feature).dataType)) + schema + } else { + throw new IllegalArgumentException("featuresCol or featuresCols must be specified") + } + } + + /** + * Vectorize the features columns if necessary. + * + * @param input the input dataset + * @return (output dataset and the feature column name) + */ + def vectorize(input: Dataset[_]): (Dataset[_], String) = { + val schema = input.schema + if (isFeaturesColSet(schema)) { + // Dataset already has vectorized. + (input, getFeaturesCol) + } else if (isFeaturesColsValid) { + val featuresName = if (!schema.fieldNames.contains(getFeaturesCol)) { + getFeaturesCol + } else { + "features_" + uid + } + val vectorAssembler = new VectorAssembler() + .setHandleInvalid($(handleInvalid)) + .setInputCols(getFeaturesCols) + .setOutputCol(featuresName) + (vectorAssembler.transform(input).select(featuresName, getLabelCol), featuresName) + } else { + // never reach here, since transformSchema will take care of the case + // that featuresCols is invalid + (input, getFeaturesCol) + } + } } private[scala] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/linalg/xgboost/XGBoostSchemaUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/linalg/xgboost/XGBoostSchemaUtils.scala new file mode 100644 index 000000000..0976067ec --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/linalg/xgboost/XGBoostSchemaUtils.scala @@ -0,0 +1,51 @@ +/* + Copyright (c) 2022 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package org.apache.spark.ml.linalg.xgboost + +import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType} +import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.ml.util.SchemaUtils + +object XGBoostSchemaUtils { + + /** check if the dataType is VectorUDT */ + def isVectorUDFType(dataType: DataType): Boolean = { + dataType match { + case _: VectorUDT => true + case _ => false + } + } + + /** The feature columns will be vectorized by VectorAssembler first, which only + * supports Numeric, Boolean and VectorUDT types */ + def checkFeatureColumnType(dataType: DataType): Unit = { + dataType match { + case _: NumericType | BooleanType => + case _: VectorUDT => + case d => throw new UnsupportedOperationException(s"featuresCols only supports Numeric, " + + s"boolean and VectorUDT types, found: ${d}") + } + } + + def checkNumericType( + schema: StructType, + colName: String, + msg: String = ""): Unit = { + SchemaUtils.checkNumericType(schema, colName, msg) + } + +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala index 7940a51e5..91f4a4cfa 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql._ import org.scalatest.FunSuite import org.apache.spark.Partitioner +import org.apache.spark.ml.feature.VectorAssembler class XGBoostClassifierSuite extends FunSuite with PerTest { @@ -316,4 +317,77 @@ class XGBoostClassifierSuite extends FunSuite with PerTest { xgb.fit(repartitioned) } + test("featuresCols with features column can work") { + val spark = ss + import spark.implicits._ + val xgbInput = Seq( + (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0), + (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1)) + .toDF("f1", "f2", "f3", "features", "label") + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1) + + val featuresName = Array("f1", "f2", "f3", "features") + val xgbClassifier = new XGBoostClassifier(paramMap) + .setFeaturesCol(featuresName) + .setLabelCol("label") + + val model = xgbClassifier.fit(xgbInput) + assert(model.getFeaturesCols.sameElements(featuresName)) + + val df = model.transform(xgbInput) + assert(df.schema.fieldNames.contains("features_" + model.uid)) + df.show() + + val newFeatureName = "features_new" + // transform also can work for vectorized dataset + val vectorizedInput = new VectorAssembler() + .setInputCols(featuresName) + .setOutputCol(newFeatureName) + .transform(xgbInput) + .select(newFeatureName, "label") + + val df1 = model + .setFeaturesCol(newFeatureName) + .transform(vectorizedInput) + assert(df1.schema.fieldNames.contains(newFeatureName)) + df1.show() + } + + test("featuresCols without features column can work") { + val spark = ss + import spark.implicits._ + val xgbInput = Seq( + (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0), + (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1)) + .toDF("f1", "f2", "f3", "f4", "label") + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1) + + val featuresName = Array("f1", "f2", "f3", "f4") + val xgbClassifier = new XGBoostClassifier(paramMap) + .setFeaturesCol(featuresName) + .setLabelCol("label") + + val model = xgbClassifier.fit(xgbInput) + assert(model.getFeaturesCols.sameElements(featuresName)) + + // transform should work for the dataset which includes the feature column names. + val df = model.transform(xgbInput) + assert(df.schema.fieldNames.contains("features")) + df.show() + + // transform also can work for vectorized dataset + val vectorizedInput = new VectorAssembler() + .setInputCols(featuresName) + .setOutputCol("features") + .transform(xgbInput) + .select("features", "label") + + val df1 = model.transform(vectorizedInput) + df1.show() + } + } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index b06ffc939..04e510640 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,12 +17,15 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} -import org.apache.spark.ml.linalg.Vector + +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.sql.functions._ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types._ import org.scalatest.FunSuite +import org.apache.spark.ml.feature.VectorAssembler + class XGBoostRegressorSuite extends FunSuite with PerTest { protected val treeMethod: String = "auto" @@ -216,4 +219,77 @@ class XGBoostRegressorSuite extends FunSuite with PerTest { assert(resultDF.columns.contains("predictLeaf")) assert(resultDF.columns.contains("predictContrib")) } + + test("featuresCols with features column can work") { + val spark = ss + import spark.implicits._ + val xgbInput = Seq( + (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0), + (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1)) + .toDF("f1", "f2", "f3", "features", "label") + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1) + + val featuresName = Array("f1", "f2", "f3", "features") + val xgbClassifier = new XGBoostRegressor(paramMap) + .setFeaturesCol(featuresName) + .setLabelCol("label") + + val model = xgbClassifier.fit(xgbInput) + assert(model.getFeaturesCols.sameElements(featuresName)) + + val df = model.transform(xgbInput) + assert(df.schema.fieldNames.contains("features_" + model.uid)) + df.show() + + val newFeatureName = "features_new" + // transform also can work for vectorized dataset + val vectorizedInput = new VectorAssembler() + .setInputCols(featuresName) + .setOutputCol(newFeatureName) + .transform(xgbInput) + .select(newFeatureName, "label") + + val df1 = model + .setFeaturesCol(newFeatureName) + .transform(vectorizedInput) + assert(df1.schema.fieldNames.contains(newFeatureName)) + df1.show() + } + + test("featuresCols without features column can work") { + val spark = ss + import spark.implicits._ + val xgbInput = Seq( + (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0), + (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1)) + .toDF("f1", "f2", "f3", "f4", "label") + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1) + + val featuresName = Array("f1", "f2", "f3", "f4") + val xgbClassifier = new XGBoostRegressor(paramMap) + .setFeaturesCol(featuresName) + .setLabelCol("label") + + val model = xgbClassifier.fit(xgbInput) + assert(model.getFeaturesCols.sameElements(featuresName)) + + // transform should work for the dataset which includes the feature column names. + val df = model.transform(xgbInput) + assert(df.schema.fieldNames.contains("features")) + df.show() + + // transform also can work for vectorized dataset + val vectorizedInput = new VectorAssembler() + .setInputCols(featuresName) + .setOutputCol("features") + .transform(xgbInput) + .select("features", "label") + + val df1 = model.transform(vectorizedInput) + df1.show() + } }