[jvm-packages] xgboost4j-spark should work when featuresCols is specified (#7789)
commit 118192f116 (parent 729d227b89)
GpuXGBoostClassifierSuite.scala

@@ -126,7 +126,7 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {

     val vectorAssembler = new VectorAssembler()
       .setHandleInvalid("keep")
-      .setInputCols(featureNames.toArray)
+      .setInputCols(featureNames)
       .setOutputCol("features")
     val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName)

@@ -149,11 +149,10 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
     // Since CPU model does not know the information about the features cols that GPU transform
     // pipeline requires. End user needs to setFeaturesCol(features: Array[String]) in the model
     // manually
-    val thrown = intercept[IllegalArgumentException](cpuModel
+    val thrown = intercept[NoSuchElementException](cpuModel
       .transform(testDf)
       .collect())
-    assert(thrown.getMessage.contains("Gpu transform requires features columns. " +
-      "please refer to `setFeaturesCol(value: Array[String])`"))
+    assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))

     val left = cpuModel
       .setFeaturesCol(featureNames)
@@ -196,17 +195,16 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
     val featureColName = "feature_col"
     val vectorAssembler = new VectorAssembler()
       .setHandleInvalid("keep")
-      .setInputCols(featureNames.toArray)
+      .setInputCols(featureNames)
       .setOutputCol(featureColName)
     val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName)

     // Since GPU model does not know the information about the features col name that CPU
     // transform pipeline requires. End user needs to setFeaturesCol in the model manually
-    val thrown = intercept[IllegalArgumentException](
+    intercept[IllegalArgumentException](
       gpuModel
         .transform(testDf)
         .collect())
-    assert(thrown.getMessage.contains("features does not exist"))

     val left = gpuModel
       .setFeaturesCol(featureColName)
GpuXGBoostGeneralSuite.scala

@@ -108,12 +108,15 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
     val trainingDf = trainingData.toDF(allColumnNames: _*)
     val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
       "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
-    val thrown = intercept[IllegalArgumentException] {
+
+    // GPU train requires featuresCols. If not specified,
+    // then NoSuchElementException will be thrown
+    val thrown = intercept[NoSuchElementException] {
       new XGBoostClassifier(xgbParam)
         .setLabelCol(labelName)
         .fit(trainingDf)
     }
-    assert(thrown.getMessage.contains("Gpu train requires features columns."))
+    assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))

     val thrown1 = intercept[IllegalArgumentException] {
       new XGBoostClassifier(xgbParam)
GpuXGBoostRegressorSuite.scala

@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2021 by Contributors
+ Copyright (c) 2021-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -122,7 +122,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {

     val vectorAssembler = new VectorAssembler()
       .setHandleInvalid("keep")
-      .setInputCols(featureNames.toArray)
+      .setInputCols(featureNames)
       .setOutputCol("features")
     val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName)

@@ -145,11 +145,10 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
     // Since CPU model does not know the information about the features cols that GPU transform
     // pipeline requires. End user needs to setFeaturesCol(features: Array[String]) in the model
     // manually
-    val thrown = intercept[IllegalArgumentException](cpuModel
+    val thrown = intercept[NoSuchElementException](cpuModel
       .transform(testDf)
       .collect())
-    assert(thrown.getMessage.contains("Gpu transform requires features columns. " +
-      "please refer to `setFeaturesCol(value: Array[String])`"))
+    assert(thrown.getMessage.contains("Failed to find a default value for featuresCols"))

     val left = cpuModel
       .setFeaturesCol(featureNames)
@@ -192,17 +191,16 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
     val featureColName = "feature_col"
     val vectorAssembler = new VectorAssembler()
       .setHandleInvalid("keep")
-      .setInputCols(featureNames.toArray)
+      .setInputCols(featureNames)
       .setOutputCol(featureColName)
     val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName)

     // Since GPU model does not know the information about the features col name that CPU
     // transform pipeline requires. End user needs to setFeaturesCol in the model manually
-    val thrown = intercept[IllegalArgumentException](
+    intercept[IllegalArgumentException](
       gpuModel
         .transform(testDf)
         .collect())
-    assert(thrown.getMessage.contains("features does not exist"))

     val left = gpuModel
       .setFeaturesCol(featureColName)
PreXGBoost.scala

@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2021 by Contributors
+ Copyright (c) 2021-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -35,8 +35,10 @@ import org.apache.commons.logging.LogFactory

 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.feature.VectorAssembler
 import org.apache.spark.ml.{Estimator, Model, PipelineStage}
 import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
 import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel

@@ -112,7 +114,7 @@ object PreXGBoost extends PreXGBoostProvider {
       return optionProvider.get.buildDatasetToRDD(estimator, dataset, params)
     }

-    val (packedParams, evalSet) = estimator match {
+    val (packedParams, evalSet, xgbInput) = estimator match {
       case est: XGBoostEstimatorCommon =>
         // get weight column, if weight is not defined, default to lit(1.0)
         val weight = if (!est.isDefined(est.weightCol) || est.getWeightCol.isEmpty) {
@@ -136,15 +138,18 @@ object PreXGBoost extends PreXGBoostProvider {

         }

-        (PackedParams(col(est.getLabelCol), col(est.getFeaturesCol), weight, baseMargin, group,
-          est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params))
+        val (xgbInput, featuresName) = est.vectorize(dataset)
+
+        (PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
+          est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params),
+          xgbInput)

       case _ => throw new RuntimeException("Unsupporting " + estimator)
     }

     // transform the training Dataset[_] to RDD[XGBLabeledPoint]
     val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
-      packedParams, dataset.asInstanceOf[DataFrame]).head
+      packedParams, xgbInput.asInstanceOf[DataFrame]).head

     // transform the eval Dataset[_] to RDD[XGBLabeledPoint]
     val evalRDDMap = evalSet.map {
@@ -184,11 +189,11 @@ object PreXGBoost extends PreXGBoostProvider {
     }

     /** get the necessary parameters */
-    val (booster, inferBatchSize, featuresCol, useExternalMemory, missing, allowNonZeroForMissing,
-      predictFunc, schema) =
+    val (booster, inferBatchSize, xgbInput, featuresCol, useExternalMemory, missing,
+      allowNonZeroForMissing, predictFunc, schema) =
       model match {
         case m: XGBoostClassificationModel =>
+          val (xgbInput, featuresName) = m.vectorize(dataset)
           // predict and turn to Row
           val predictFunc =
             (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
@@ -199,7 +204,7 @@ object PreXGBoost extends PreXGBoostProvider {
             }

           // prepare the final Schema
-          var schema = StructType(dataset.schema.fields ++
+          var schema = StructType(xgbInput.schema.fields ++
             Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
               ArrayType(FloatType, containsNull = false), nullable = false)) ++
             Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType =
@@ -214,11 +219,12 @@ object PreXGBoost extends PreXGBoostProvider {
               ArrayType(FloatType, containsNull = false), nullable = false))
           }

-          (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing,
-            m.getAllowNonZeroForMissingValue, predictFunc, schema)
+          (m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory,
+            m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema)

         case m: XGBoostRegressionModel =>
           // predict and turn to Row
+          val (xgbInput, featuresName) = m.vectorize(dataset)
           val predictFunc =
             (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => {
               val Array(rawPredictionItr, predLeafItr, predContribItr) =
@@ -227,7 +233,7 @@ object PreXGBoost extends PreXGBoostProvider {
             }

           // prepare the final Schema
-          var schema = StructType(dataset.schema.fields ++
+          var schema = StructType(xgbInput.schema.fields ++
             Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType =
               ArrayType(FloatType, containsNull = false), nullable = false)))

@@ -240,14 +246,14 @@ object PreXGBoost extends PreXGBoostProvider {
             ArrayType(FloatType, containsNull = false), nullable = false))
           }

-          (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing,
-            m.getAllowNonZeroForMissingValue, predictFunc, schema)
+          (m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory,
+            m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema)
       }

-    val bBooster = dataset.sparkSession.sparkContext.broadcast(booster)
-    val appName = dataset.sparkSession.sparkContext.appName
+    val bBooster = xgbInput.sparkSession.sparkContext.broadcast(booster)
+    val appName = xgbInput.sparkSession.sparkContext.appName

-    val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val resultRDD = xgbInput.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       new AbstractIterator[Row] {
         private var batchCnt = 0

@@ -295,7 +301,7 @@ object PreXGBoost extends PreXGBoostProvider {
       }

       bBooster.unpersist(blocking = false)
-      dataset.sparkSession.createDataFrame(resultRDD, schema)
+      xgbInput.sparkSession.createDataFrame(resultRDD, schema)
     }
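Note: the pattern running through the PreXGBoost diff above is that both the train and transform paths now consume the pair returned by vectorize (the dataset with an assembled vector column, plus that column's name) instead of the raw input dataset. A minimal sketch of what this enables from the user's side, assuming a SparkSession named `spark` and illustrative column names:

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    val df = spark.createDataFrame(Seq(
      (1.0, 2.0, 3.0, 0),
      (4.0, 5.0, 6.0, 1)
    )).toDF("f1", "f2", "f3", "label")

    val classifier = new XGBoostClassifier(Map(
      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1))
      .setFeaturesCol(Array("f1", "f2", "f3")) // raw numeric columns, no VectorAssembler stage
      .setLabelCol("label")

    // buildDatasetToRDD calls est.vectorize(df) internally, so fit sees an
    // assembled vector column without any user-side feature engineering.
    val model = classifier.fit(df)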
XGBoostClassifier.scala

@@ -144,13 +144,6 @@ class XGBoostClassifier (
   def setSinglePrecisionHistogram(value: Boolean): this.type =
     set(singlePrecisionHistogram, value)

-  /**
-   * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
-   * all feature columns must be numeric types.
-   */
-  def setFeaturesCol(value: Array[String]): this.type =
-    set(featuresCols, value)
-
   // called at the start of fit/train when 'eval_metric' is not defined
   private def setupDefaultEvalMetric(): String = {
     require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
@@ -165,7 +158,12 @@ class XGBoostClassifier (

   // Callback from PreXGBoost
   private[spark] def transformSchemaInternal(schema: StructType): StructType = {
-    super.transformSchema(schema)
+    if (isFeaturesColSet(schema)) {
+      // User has vectorized the features into VectorUDT.
+      super.transformSchema(schema)
+    } else {
+      transformSchemaWithFeaturesCols(true, schema)
+    }
   }

   override def transformSchema(schema: StructType): StructType = {
@@ -260,13 +258,6 @@ class XGBoostClassificationModel private[ml](

   def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)

-  /**
-   * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
-   * all feature columns must be numeric types.
-   */
-  def setFeaturesCol(value: Array[String]): this.type =
-    set(featuresCols, value)
-
   /**
    * Single instance prediction.
    * Note: The performance is not ideal, use it carefully!
@@ -359,7 +350,12 @@ class XGBoostClassificationModel private[ml](
   }

   private[spark] def transformSchemaInternal(schema: StructType): StructType = {
-    super.transformSchema(schema)
+    if (isFeaturesColSet(schema)) {
+      // User has vectorized the features into VectorUDT.
+      super.transformSchema(schema)
+    } else {
+      transformSchemaWithFeaturesCols(false, schema)
+    }
   }

   override def transformSchema(schema: StructType): StructType = {
@@ -385,8 +381,6 @@ class XGBoostClassificationModel private[ml](
       Vectors.dense(rawPredictions)
     }

-
-
     if ($(rawPredictionCol).nonEmpty) {
       outputData = outputData
         .withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
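Note: across the four transformSchemaInternal rewrites (here and in XGBoostRegressor.scala below), the boolean handed to transformSchemaWithFeaturesCols is true only in the XGBoostClassifier override, which additionally validates that the label column is numeric at fit time. A condensed sketch of the shared branch, not a literal excerpt:

    // fitPath = true adds a numeric-type check on the label column.
    def validate(schema: StructType, fitPath: Boolean): StructType =
      if (isFeaturesColSet(schema)) {
        super.transformSchema(schema) // a pre-assembled VectorUDT column wins
      } else {
        transformSchemaWithFeaturesCols(fitPath, schema) // fall back to raw featuresCols
      }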
XGBoostRegressor.scala

@@ -146,13 +146,6 @@ class XGBoostRegressor (
   def setSinglePrecisionHistogram(value: Boolean): this.type =
     set(singlePrecisionHistogram, value)

-  /**
-   * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
-   * all feature columns must be numeric types.
-   */
-  def setFeaturesCol(value: Array[String]): this.type =
-    set(featuresCols, value)
-
   // called at the start of fit/train when 'eval_metric' is not defined
   private def setupDefaultEvalMetric(): String = {
     require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
@@ -164,7 +157,12 @@ class XGBoostRegressor (
   }

   private[spark] def transformSchemaInternal(schema: StructType): StructType = {
-    super.transformSchema(schema)
+    if (isFeaturesColSet(schema)) {
+      // User has vectorized the features into VectorUDT.
+      super.transformSchema(schema)
+    } else {
+      transformSchemaWithFeaturesCols(false, schema)
+    }
   }

   override def transformSchema(schema: StructType): StructType = {
@@ -253,13 +251,6 @@ class XGBoostRegressionModel private[ml] (

   def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)

-  /**
-   * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires
-   * all feature columns must be numeric types.
-   */
-  def setFeaturesCol(value: Array[String]): this.type =
-    set(featuresCols, value)
-
   /**
    * Single instance prediction.
    * Note: The performance is not ideal, use it carefully!
@@ -331,7 +322,12 @@ class XGBoostRegressionModel private[ml] (
   }

   private[spark] def transformSchemaInternal(schema: StructType): StructType = {
-    super.transformSchema(schema)
+    if (isFeaturesColSet(schema)) {
+      // User has vectorized the features into VectorUDT.
+      super.transformSchema(schema)
+    } else {
+      transformSchemaWithFeaturesCols(false, schema)
+    }
   }

   override def transformSchema(schema: StructType): StructType = {
GeneralParams.scala

@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014,2021 by Contributors
+ Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -247,6 +247,27 @@ trait HasNumClass extends Params {
   final def getNumClass: Int = $(numClass)
 }

+/**
+ * Trait for shared param featuresCols.
+ */
+trait HasFeaturesCols extends Params {
+  /**
+   * Param for the names of feature columns.
+   * @group param
+   */
+  final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
+    "an array of feature column names.")
+
+  /** @group getParam */
+  final def getFeaturesCols: Array[String] = $(featuresCols)
+
+  /** Check if featuresCols is valid */
+  def isFeaturesColsValid: Boolean = {
+    isDefined(featuresCols) && $(featuresCols) != Array.empty
+  }
+
+}
+
 private[spark] trait ParamMapFuncs extends Params {

   def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
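Note: a minimal sketch of the new trait in isolation; the host class FeaturesColsDemo is hypothetical and exists only to exercise the param:

    import ml.dmlc.xgboost4j.scala.spark.params.HasFeaturesCols
    import org.apache.spark.ml.param.{ParamMap, Params}
    import org.apache.spark.ml.util.Identifiable

    // Hypothetical host, since Params.set is protected.
    class FeaturesColsDemo(override val uid: String) extends Params with HasFeaturesCols {
      def this() = this(Identifiable.randomUID("demo"))
      def setFeaturesCols(value: Array[String]): this.type = set(featuresCols, value)
      override def copy(extra: ParamMap): FeaturesColsDemo = defaultCopy(extra)
    }

    val demo = new FeaturesColsDemo()
    assert(!demo.isFeaturesColsValid) // undefined: the trait sets no default
    demo.setFeaturesCols(Array("f1", "f2"))
    assert(demo.getFeaturesCols.sameElements(Array("f1", "f2")))
    assert(demo.isFeaturesColsValid)
    // Caveat: `$(featuresCols) != Array.empty` compares array references, so any
    // value passes once the param is defined, including an empty array.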
GpuParams.scala (deleted)

@@ -1,34 +0,0 @@
-/*
- Copyright (c) 2021-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark.params
-
-import org.apache.spark.ml.param.{Params, StringArrayParam}
-
-trait GpuParams extends Params {
-  /**
-   * Param for the names of feature columns for GPU pipeline.
-   * @group param
-   */
-  final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
-    "an array of feature column names for GPU pipeline.")
-
-  setDefault(featuresCols, Array.empty[String])
-
-  /** @group getParam */
-  final def getFeaturesCols: Array[String] = $(featuresCols)
-
-}
XGBoostEstimatorCommon.scala

@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014,2021 by Contributors
+ Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,16 +16,101 @@

 package ml.dmlc.xgboost4j.scala.spark.params

-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
+import org.apache.spark.ml.param.{Param, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol}
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.types.StructType

 private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
   with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol
   with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
-  with HasLabelCol with GpuParams {
+  with HasLabelCol with HasFeaturesCols with HasHandleInvalid {

   def needDeterministicRepartitioning: Boolean = {
     getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
   }

+  /**
+   * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+   * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+   * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+   * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+   * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+   * Default: "error"
+   * @group param
+   */
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out
+      |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN
+      |in the output). Column lengths are taken from the size of ML Attribute Group, which can be
+      |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also
+      |be inferred from first rows of the data since it is safe to do so but only in case of 'error'
+      |or 'skip'.""".stripMargin.replaceAll("\n", " "),
+    ParamValidators.inArray(Array("skip", "error", "keep")))
+
+  setDefault(handleInvalid, "error")
+
+  /**
+   * Specify an array of feature column names which must be numeric types.
+   */
+  def setFeaturesCol(value: Array[String]): this.type = set(featuresCols, value)
+
+  /** Set the handleInvalid for VectorAssembler */
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /**
+   * Check if schema has a field named with the value of "featuresCol" param and it's data type
+   * must be VectorUDT
+   */
+  def isFeaturesColSet(schema: StructType): Boolean = {
+    schema.fieldNames.contains(getFeaturesCol) &&
+      XGBoostSchemaUtils.isVectorUDFType(schema(getFeaturesCol).dataType)
+  }
+
+  /** check the features columns type */
+  def transformSchemaWithFeaturesCols(fit: Boolean, schema: StructType): StructType = {
+    if (isFeaturesColsValid) {
+      if (fit) {
+        XGBoostSchemaUtils.checkNumericType(schema, $(labelCol))
+      }
+      $(featuresCols).foreach(feature =>
+        XGBoostSchemaUtils.checkFeatureColumnType(schema(feature).dataType))
+      schema
+    } else {
+      throw new IllegalArgumentException("featuresCol or featuresCols must be specified")
+    }
+  }
+
+  /**
+   * Vectorize the features columns if necessary.
+   *
+   * @param input the input dataset
+   * @return (output dataset and the feature column name)
+   */
+  def vectorize(input: Dataset[_]): (Dataset[_], String) = {
+    val schema = input.schema
+    if (isFeaturesColSet(schema)) {
+      // Dataset already has vectorized.
+      (input, getFeaturesCol)
+    } else if (isFeaturesColsValid) {
+      val featuresName = if (!schema.fieldNames.contains(getFeaturesCol)) {
+        getFeaturesCol
+      } else {
+        "features_" + uid
+      }
+      val vectorAssembler = new VectorAssembler()
+        .setHandleInvalid($(handleInvalid))
+        .setInputCols(getFeaturesCols)
+        .setOutputCol(featuresName)
+      (vectorAssembler.transform(input).select(featuresName, getLabelCol), featuresName)
+    } else {
+      // never reach here, since transformSchema will take care of the case
+      // that featuresCols is invalid
+      (input, getFeaturesCol)
+    }
+  }
 }

 private[scala] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass
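Note: the column-naming rule inside vectorize deserves a callout, since the suites below assert on it. A sketch assuming a SparkSession named `spark`:

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
    import spark.implicits._

    val clf = new XGBoostClassifier()
      .setFeaturesCol(Array("f1", "f2"))
      .setLabelCol("label")

    // Case 1: no column named "features" exists, so vectorize assembles f1/f2
    // into the default getFeaturesCol name, "features".
    val df1 = Seq((1.0, 2.0, 0)).toDF("f1", "f2", "label")

    // Case 2: a plain Double column named "features" already exists (not
    // VectorUDT, so isFeaturesColSet is false); vectorize then writes the
    // assembled vector to "features_" + clf.uid to avoid the name clash.
    val df2 = Seq((1.0, 2.0, 3.0, 0)).toDF("f1", "f2", "features", "label")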
XGBoostSchemaUtils.scala (new file)

@@ -0,0 +1,51 @@
+/*
+ Copyright (c) 2022 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package org.apache.spark.ml.linalg.xgboost
+
+import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType}
+import org.apache.spark.ml.linalg.VectorUDT
+import org.apache.spark.ml.util.SchemaUtils
+
+object XGBoostSchemaUtils {
+
+  /** check if the dataType is VectorUDT */
+  def isVectorUDFType(dataType: DataType): Boolean = {
+    dataType match {
+      case _: VectorUDT => true
+      case _ => false
+    }
+  }
+
+  /** The feature columns will be vectorized by VectorAssembler first, which only
+   * supports Numeric, Boolean and VectorUDT types */
+  def checkFeatureColumnType(dataType: DataType): Unit = {
+    dataType match {
+      case _: NumericType | BooleanType =>
+      case _: VectorUDT =>
+      case d => throw new UnsupportedOperationException(s"featuresCols only supports Numeric, " +
+        s"boolean and VectorUDT types, found: ${d}")
+    }
+  }
+
+  def checkNumericType(
+      schema: StructType,
+      colName: String,
+      msg: String = ""): Unit = {
+    SchemaUtils.checkNumericType(schema, colName, msg)
+  }
+
+}
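Note: a quick sketch of what the new helper accepts and rejects:

    import org.apache.spark.ml.linalg.SQLDataTypes
    import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
    import org.apache.spark.sql.types.{BooleanType, DoubleType, StringType}

    XGBoostSchemaUtils.checkFeatureColumnType(DoubleType)              // ok: numeric
    XGBoostSchemaUtils.checkFeatureColumnType(BooleanType)             // ok: boolean
    XGBoostSchemaUtils.checkFeatureColumnType(SQLDataTypes.VectorType) // ok: VectorUDT
    // StringType is none of the three, so this throws UnsupportedOperationException.
    XGBoostSchemaUtils.checkFeatureColumnType(StringType)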
XGBoostClassifierSuite.scala

@@ -23,6 +23,7 @@ import org.apache.spark.sql._
 import org.scalatest.FunSuite

 import org.apache.spark.Partitioner
+import org.apache.spark.ml.feature.VectorAssembler

 class XGBoostClassifierSuite extends FunSuite with PerTest {

@@ -316,4 +317,77 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     xgb.fit(repartitioned)
   }

+  test("featuresCols with features column can work") {
+    val spark = ss
+    import spark.implicits._
+    val xgbInput = Seq(
+      (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
+      (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
+      .toDF("f1", "f2", "f3", "features", "label")
+
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1)
+
+    val featuresName = Array("f1", "f2", "f3", "features")
+    val xgbClassifier = new XGBoostClassifier(paramMap)
+      .setFeaturesCol(featuresName)
+      .setLabelCol("label")
+
+    val model = xgbClassifier.fit(xgbInput)
+    assert(model.getFeaturesCols.sameElements(featuresName))
+
+    val df = model.transform(xgbInput)
+    assert(df.schema.fieldNames.contains("features_" + model.uid))
+    df.show()
+
+    val newFeatureName = "features_new"
+    // transform also can work for vectorized dataset
+    val vectorizedInput = new VectorAssembler()
+      .setInputCols(featuresName)
+      .setOutputCol(newFeatureName)
+      .transform(xgbInput)
+      .select(newFeatureName, "label")
+
+    val df1 = model
+      .setFeaturesCol(newFeatureName)
+      .transform(vectorizedInput)
+    assert(df1.schema.fieldNames.contains(newFeatureName))
+    df1.show()
+  }
+
+  test("featuresCols without features column can work") {
+    val spark = ss
+    import spark.implicits._
+    val xgbInput = Seq(
+      (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
+      (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
+      .toDF("f1", "f2", "f3", "f4", "label")
+
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1)
+
+    val featuresName = Array("f1", "f2", "f3", "f4")
+    val xgbClassifier = new XGBoostClassifier(paramMap)
+      .setFeaturesCol(featuresName)
+      .setLabelCol("label")
+
+    val model = xgbClassifier.fit(xgbInput)
+    assert(model.getFeaturesCols.sameElements(featuresName))
+
+    // transform should work for the dataset which includes the feature column names.
+    val df = model.transform(xgbInput)
+    assert(df.schema.fieldNames.contains("features"))
+    df.show()
+
+    // transform also can work for vectorized dataset
+    val vectorizedInput = new VectorAssembler()
+      .setInputCols(featuresName)
+      .setOutputCol("features")
+      .transform(xgbInput)
+      .select("features", "label")
+
+    val df1 = model.transform(vectorizedInput)
+    df1.show()
+  }
+
 }
XGBoostRegressorSuite.scala

@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2022 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,12 +17,15 @@
 package ml.dmlc.xgboost4j.scala.spark

 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types._
 import org.scalatest.FunSuite

+import org.apache.spark.ml.feature.VectorAssembler
+
 class XGBoostRegressorSuite extends FunSuite with PerTest {
   protected val treeMethod: String = "auto"

@@ -216,4 +219,77 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
     assert(resultDF.columns.contains("predictLeaf"))
     assert(resultDF.columns.contains("predictContrib"))
   }

+  test("featuresCols with features column can work") {
+    val spark = ss
+    import spark.implicits._
+    val xgbInput = Seq(
+      (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
+      (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
+      .toDF("f1", "f2", "f3", "features", "label")
+
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1)
+
+    val featuresName = Array("f1", "f2", "f3", "features")
+    val xgbClassifier = new XGBoostRegressor(paramMap)
+      .setFeaturesCol(featuresName)
+      .setLabelCol("label")
+
+    val model = xgbClassifier.fit(xgbInput)
+    assert(model.getFeaturesCols.sameElements(featuresName))
+
+    val df = model.transform(xgbInput)
+    assert(df.schema.fieldNames.contains("features_" + model.uid))
+    df.show()
+
+    val newFeatureName = "features_new"
+    // transform also can work for vectorized dataset
+    val vectorizedInput = new VectorAssembler()
+      .setInputCols(featuresName)
+      .setOutputCol(newFeatureName)
+      .transform(xgbInput)
+      .select(newFeatureName, "label")
+
+    val df1 = model
+      .setFeaturesCol(newFeatureName)
+      .transform(vectorizedInput)
+    assert(df1.schema.fieldNames.contains(newFeatureName))
+    df1.show()
+  }
+
+  test("featuresCols without features column can work") {
+    val spark = ss
+    import spark.implicits._
+    val xgbInput = Seq(
+      (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
+      (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
+      .toDF("f1", "f2", "f3", "f4", "label")
+
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1)
+
+    val featuresName = Array("f1", "f2", "f3", "f4")
+    val xgbClassifier = new XGBoostRegressor(paramMap)
+      .setFeaturesCol(featuresName)
+      .setLabelCol("label")
+
+    val model = xgbClassifier.fit(xgbInput)
+    assert(model.getFeaturesCols.sameElements(featuresName))
+
+    // transform should work for the dataset which includes the feature column names.
+    val df = model.transform(xgbInput)
+    assert(df.schema.fieldNames.contains("features"))
+    df.show()
+
+    // transform also can work for vectorized dataset
+    val vectorizedInput = new VectorAssembler()
+      .setInputCols(featuresName)
+      .setOutputCol("features")
+      .transform(xgbInput)
+      .select("features", "label")
+
+    val df1 = model.transform(vectorizedInput)
+    df1.show()
+  }
+
 }
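Note: taken together, the suites above imply the following end-to-end usage without an explicit VectorAssembler stage; a sketch assuming a SparkSession named `spark` and illustrative data:

    import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor

    val train = spark.createDataFrame(Seq(
      (1.0, 7.0, 10.1, 100.2),
      (2.0, 20.0, 2.1, 2.2)
    )).toDF("f1", "f2", "f3", "label")

    val regressor = new XGBoostRegressor(Map(
      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1))
      .setFeaturesCol(Array("f1", "f2", "f3"))
      .setLabelCol("label")

    val model = regressor.fit(train)
    // transform works on raw numeric columns (re-assembled internally) as well as
    // on a pre-assembled vector column selected via model.setFeaturesCol(String).
    model.transform(train).show()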