[jvm-packages] Allow for bypassing spark missing value check (#4805)
* Allow for bypassing spark missing value check
* Update documentation for dealing with missing values in spark xgboost
This commit is contained in:
parent 27b3646d29
commit bc9d88259f
@@ -156,24 +156,9 @@ labels. A DataFrame like this (containing vector-represented features and numeric
 Dealing with missing values
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Strategies to handle missing values (and therefore overcome issues as above):
-
-In the case that a feature column contains missing values for any reason (business logic, a wrong data ingestion process, etc.), the user should decide on a strategy for handling them.
-
-The choice of approach depends on the value representing 'missing', which falls into four categories:
-
-1. 0
-2. NaN
-3. Null
-4. Any other value not mentioned in (1) / (2) / (3)
-
-We introduce the following approaches for dealing with missing values, and the scenarios they fit:
-
-1. Skip VectorAssembler (using setHandleInvalid = "skip") directly. Used in (2), (3).
-2. Keep it (using setHandleInvalid = "keep"), and set the "missing" parameter in XGBoostClassifier/XGBoostRegressor to the value representing missing. Used in (2) and (4).
-3. Keep it (using setHandleInvalid = "keep") and transform the value to some other irregular value. Used in (3).
-4. Nothing to be done. Used in (1).
-
-Then, XGBoost will automatically learn the ideal direction to go when a value is missing, based on that value and strategy.
+XGBoost supports missing values by default (`as described here <https://xgboost.readthedocs.io/en/latest/faq.html#how-to-deal-with-missing-value>`_).
+If given a SparseVector, XGBoost will treat any values absent from the SparseVector as missing. You are also able to
+tell XGBoost to treat a specific value in your Dataset as missing. By default, XGBoost treats NaN as the value representing missing.
 
 Example of setting a missing value (e.g. -999) to the "missing" parameter in XGBoostClassifier:
 
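For context, the documentation example that the next hunk's context lines belong to looks roughly like the sketch below; the parameter map contents are illustrative, while XGBoostClassifier, setFeaturesCol, setLabelCol, and setMissing are the APIs this commit documents:

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // Illustrative training parameters; only the missing-value handling matters here.
    val xgbParam = Map(
      "eta" -> 0.1,
      "objective" -> "binary:logistic",
      "num_round" -> 10,
      "num_workers" -> 2)

    // Any feature value equal to -999 in "features" is treated as missing.
    val xgbClassifier = new XGBoostClassifier(xgbParam).
        setFeaturesCol("features").
        setLabelCol("classIndex").
        setMissing(-999)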
@@ -190,11 +175,37 @@ Example of setting a missing value (e.g. -999) to the "missing" parameter in XGBoostClassifier:
     setFeaturesCol("features").
     setLabelCol("classIndex")
 
-.. note:: Using 0 to represent a meaningful value
+.. note:: Missing values with Spark's VectorAssembler
 
-  Because Spark's VectorAssembler transformer only accepts 0 as a missing value, a problem arises when the user has 0 as a meaningful value and there are enough 0's for a SparseVector to be used (however, when the dataset is represented by a DenseVector, the 0 is kept).
-
-  In this case, users are also expected to transform 0 to some other value to avoid the issue.
+  If given a Dataset with enough features having a value of 0, Spark's VectorAssembler transformer class will return a
+  SparseVector where the absent values are meant to indicate a value of 0. This conflicts with XGBoost's default of
+  treating values absent from the SparseVector as missing: the model would effectively be treating 0 as missing
+  without declaring it to be so, which can lead to confusion when using the trained model on other platforms. To
+  avoid this, XGBoost will raise an exception if it receives a SparseVector and the "missing" parameter has not been
+  explicitly set to 0. To work around this issue the user has three options:
+
+  1. Explicitly convert the Vector returned from VectorAssembler to a DenseVector to return the zeros to the dataset.
+     If doing this with missing values encoded as NaN, you will want to set ``setHandleInvalid = "keep"`` on
+     VectorAssembler in order to keep the NaN values in the dataset. You would then set the "missing" parameter to
+     whatever you want to be treated as missing. However, this may cause a large amount of memory use if your dataset
+     is very sparse.
+
+  2. Before calling VectorAssembler, transform the values you want to represent missing into an irregular value that
+     is not 0, NaN, or Null, and set the "missing" parameter to 0. The irregular value should ideally be chosen to be
+     outside the range of values that your features take.
+
+  3. Do not use the VectorAssembler class, and instead construct the SparseVector yourself in a way that allows
+     sparsity to indicate a non-zero value. You can then set the "missing" parameter to whatever sparsity indicates
+     in your Dataset. If this approach is taken you can pass the parameter
+     ``"allow_non_zero_for_missing_value" -> true`` to bypass XGBoost's assertion that "missing" must be zero when
+     given a SparseVector.
+
+  Option 1 is recommended if memory constraints are not an issue. Option 3 requires more work to set up but is
+  guaranteed to give you correct results, while option 2 is quicker to set up but it may be difficult to find an
+  irregular value that does not conflict with your feature values.
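As a sketch of options 1 and 2 above, under stated assumptions: ``df`` is a DataFrame with numeric columns col1..col3, and the densifying UDF and the sentinel value are illustrative rather than part of the commit; only VectorAssembler, setHandleInvalid, and the "missing" parameter come from the documentation above:

    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.ml.linalg.Vector
    import org.apache.spark.sql.functions.{col, udf}

    // Option 1: keep NaNs through assembly, then densify so explicit zeros survive.
    val assembler = new VectorAssembler()
      .setInputCols(Array("col1", "col2", "col3"))
      .setOutputCol("features")
      .setHandleInvalid("keep") // keep rows containing NaN instead of failing

    val toDense = udf((v: Vector) => v.toDense: Vector)
    val denseDF = assembler.transform(df)
      .withColumn("features", toDense(col("features")))
    // Train on denseDF with the default missing = NaN, or setMissing as needed.

    // Option 2: replace the marker for missing (here NaN/null) with an
    // out-of-range sentinel before assembly, then set "missing" to 0.
    val sentinel = -99999.0 // illustrative; must not collide with real feature values
    val patchedDF = df.na.fill(sentinel, Seq("col1", "col2", "col3"))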
+
+.. note:: Using a non-default missing value when using other bindings of XGBoost
+
+  When XGBoost is saved in native format, only the booster itself is saved; the value of the missing parameter is not
+  saved alongside the model. Thus, if a non-default missing parameter is used to train the model in Spark, the user
+  should take care to use the same missing parameter when using the saved model in another binding.
 
 Training
 ========
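The precaution in this note can be made concrete with a short sketch; nativeBooster is the accessor the XGBoost4J-Spark tutorial uses for the underlying booster, while the path, parameter map, and DataFrame are illustrative:

    // Train in Spark with a non-default missing value...
    val model = new XGBoostClassifier(xgbParam)
      .setFeaturesCol("features")
      .setLabelCol("classIndex")
      .setMissing(-999)
      .fit(trainDF)

    // ...and save only the booster in native format.
    model.nativeBooster.saveModel("/tmp/xgb-model.bin")

    // The saved file does not record missing = -999; any other binding loading
    // it (e.g. Python's xgboost.DMatrix(data, missing=-999)) must pass the same
    // value again.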
@@ -69,6 +69,7 @@ private[this] case class XGBoostExecutionParams(
     obj: ObjectiveTrait,
     eval: EvalTrait,
     missing: Float,
+    allowNonZeroForMissing: Boolean,
     trackerConf: TrackerConf,
     timeoutRequestWorkers: Long,
     checkpointParam: CheckpointParam,
@@ -162,6 +163,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext)
     val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
     val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
     val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
+    val allowNonZeroForMissing = overridedParams
+      .getOrElse("allow_non_zero_for_missing_value", false)
+      .asInstanceOf[Boolean]
     validateSparkSslConf
 
     if (overridedParams.contains("tree_method")) {
@@ -212,7 +214,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext)
       .asInstanceOf[Boolean]
 
     val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
-      missing, trackerConf,
+      missing, allowNonZeroForMissing, trackerConf,
       timeoutRequestWorkers,
       checkpointParam,
       inputParams,
@@ -255,14 +257,19 @@ private[spark] case class XGBLabeledPointGroup(
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
 
-  private def verifyMissingSetting(xgbLabelPoints: Iterator[XGBLabeledPoint], missing: Float):
-      Iterator[XGBLabeledPoint] = {
-    if (missing != 0.0f) {
+  private def verifyMissingSetting(
+      xgbLabelPoints: Iterator[XGBLabeledPoint],
+      missing: Float,
+      allowNonZeroMissingValue: Boolean): Iterator[XGBLabeledPoint] = {
+    if (missing != 0.0f && !allowNonZeroMissingValue) {
       xgbLabelPoints.map(labeledPoint => {
         if (labeledPoint.indices != null) {
           throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
             s" set value $missing) when you have SparseVector or Empty vector as your feature" +
-            " format")
+            " format. If you didn't use Spark's VectorAssembler class to build your feature " +
+            "vector but instead did so in a way that preserves zeros in your feature vector " +
+            "you can avoid this check by using the 'allow_non_zero_for_missing_value' parameter" +
+            " (only use if you know what you are doing)")
         }
         labeledPoint
       })
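To make the behavior of the new check concrete, here is a self-contained sketch of the same condition; Point is a hypothetical stand-in for XGBLabeledPoint, where a non-null indices array marks a sparse vector, exactly as in the diff:

    // Hypothetical stand-in for XGBLabeledPoint.
    case class Point(label: Float, indices: Array[Int], values: Array[Float])

    def verify(points: Iterator[Point], missing: Float, allowNonZero: Boolean): Iterator[Point] =
      if (missing != 0.0f && !allowNonZero) {
        points.map { p =>
          // Same condition as the diff: sparse features plus a non-zero "missing"
          // are rejected unless the caller opts in to bypass the check.
          if (p.indices != null) {
            throw new RuntimeException(s"missing=$missing requires a dense feature format")
          }
          p
        }
      } else {
        points
      }

    val dense = Point(1.0f, indices = null, values = Array(0.0f, 2.0f))
    val sparse = Point(0.0f, indices = Array(1), values = Array(2.0f))

    verify(Iterator(dense), missing = -1.0f, allowNonZero = false).toList  // passes
    verify(Iterator(sparse), missing = -1.0f, allowNonZero = true).toList  // check bypassed
    // verify(Iterator(sparse), -1.0f, allowNonZero = false).toList        // would throw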
@@ -288,22 +295,28 @@ object XGBoost extends Serializable {
 
   private[spark] def processMissingValues(
       xgbLabelPoints: Iterator[XGBLabeledPoint],
-      missing: Float): Iterator[XGBLabeledPoint] = {
+      missing: Float,
+      allowNonZeroMissingValue: Boolean): Iterator[XGBLabeledPoint] = {
     if (!missing.isNaN) {
-      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
+      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissingValue),
         missing, (v: Float) => v != missing)
     } else {
-      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
+      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissingValue),
         missing, (v: Float) => !v.isNaN)
     }
   }
 
   private def processMissingValuesWithGroup(
       xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
-      missing: Float): Iterator[Array[XGBLabeledPoint]] = {
+      missing: Float,
+      allowNonZeroMissingValue: Boolean): Iterator[Array[XGBLabeledPoint]] = {
     if (!missing.isNaN) {
       xgbLabelPointGroups.map {
-        labeledPoints => XGBoost.processMissingValues(labeledPoints.iterator, missing).toArray
+        labeledPoints => XGBoost.processMissingValues(
+          labeledPoints.iterator,
+          missing,
+          allowNonZeroMissingValue
+        ).toArray
       }
     } else {
       xgbLabelPointGroups
@@ -428,7 +441,7 @@ object XGBoost extends Serializable {
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPoints => {
         val watches = Watches.buildWatches(xgbExecutionParams,
-          processMissingValues(labeledPoints, xgbExecutionParams.missing),
+          processMissingValues(labeledPoints, xgbExecutionParams.missing,
+            xgbExecutionParams.allowNonZeroForMissing),
           getCacheDirName(xgbExecutionParams.useExternalMemory))
         buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
           xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
@@ -440,7 +453,7 @@ object XGBoost extends Serializable {
         val watches = Watches.buildWatches(
           nameAndLabeledPointSets.map {
             case (name, iter) => (name, processMissingValues(iter,
-              xgbExecutionParams.missing))
+              xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
           },
           getCacheDirName(xgbExecutionParams.useExternalMemory))
         buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
@@ -459,7 +472,7 @@ object XGBoost extends Serializable {
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPointGroups => {
         val watches = Watches.buildWatchesWithGroup(xgbExecutionParam,
-          processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing),
+          processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
+            xgbExecutionParam.allowNonZeroForMissing),
           getCacheDirName(xgbExecutionParam.useExternalMemory))
         buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
           xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
@@ -470,7 +483,7 @@ object XGBoost extends Serializable {
         val watches = Watches.buildWatchesWithGroup(
           labeledPointGroupSets.map {
             case (name, iter) => (name, processMissingValuesWithGroup(iter,
-              xgbExecutionParam.missing))
+              xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
           },
           getCacheDirName(xgbExecutionParam.useExternalMemory))
         buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
@@ -245,6 +245,11 @@ class XGBoostClassificationModel private[ml](
 
   def setMissing(value: Float): this.type = set(missing, value)
 
+  def setAllowNonZeroForMissingValue(value: Boolean): this.type = set(
+    allowNonZeroForMissingValue,
+    value
+  )
+
   def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
 
   /**
@@ -253,7 +258,11 @@ class XGBoostClassificationModel private[ml](
    */
   override def predict(features: Vector): Double = {
     import DataUtils._
-    val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
+    val dm = new DMatrix(XGBoost.processMissingValues(
+      Iterator(features.asXGB),
+      $(missing),
+      $(allowNonZeroForMissingValue)
+    ))
     val probability = _booster.predict(data = dm)(0).map(_.toDouble)
     if (numClasses == 2) {
       math.round(probability(0))
@@ -309,7 +318,11 @@ class XGBoostClassificationModel private[ml](
         }
 
         val dm = new DMatrix(
-          XGBoost.processMissingValues(features.map(_.asXGB), $(missing)),
+          XGBoost.processMissingValues(
+            features.map(_.asXGB),
+            $(missing),
+            $(allowNonZeroForMissingValue)
+          ),
           cacheInfo)
         try {
           val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
@@ -241,6 +241,11 @@ class XGBoostRegressionModel private[ml] (
 
   def setMissing(value: Float): this.type = set(missing, value)
 
+  def setAllowNonZeroForMissingValue(value: Boolean): this.type = set(
+    allowNonZeroForMissingValue,
+    value
+  )
+
   def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
 
   /**
@@ -249,7 +254,11 @@ class XGBoostRegressionModel private[ml] (
    */
   override def predict(features: Vector): Double = {
     import DataUtils._
-    val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
+    val dm = new DMatrix(XGBoost.processMissingValues(
+      Iterator(features.asXGB),
+      $(missing),
+      $(allowNonZeroForMissingValue)
+    ))
     _booster.predict(data = dm)(0)(0)
   }
 
@@ -287,7 +296,11 @@ class XGBoostRegressionModel private[ml] (
         }
 
         val dm = new DMatrix(
-          XGBoost.processMissingValues(features.map(_.asXGB), $(missing)),
+          XGBoost.processMissingValues(
+            features.map(_.asXGB),
+            $(missing),
+            $(allowNonZeroForMissingValue)
+          ),
           cacheInfo)
         try {
           val Array(rawPredictionItr, predLeafItr, predContribItr) =
@@ -105,6 +105,21 @@ private[spark] trait GeneralParams extends Params {
 
   final def getMissing: Float = $(missing)
 
+  /**
+   * Allows a non-zero value for missing when training or predicting
+   * on a Sparse or Empty vector.
+   */
+  final val allowNonZeroForMissingValue = new BooleanParam(
+    this,
+    "allowNonZeroForMissingValue",
+    "Allow a non-zero value for missing when training or " +
+      "predicting on a Sparse or Empty vector. Should only be used if you did " +
+      "not use Spark's VectorAssembler class to construct the feature vector " +
+      "but instead used a method that preserves zeros in your vector."
+  )
+
+  final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissingValue)
+
   /**
    * the maximum time to wait for the job requesting new workers. default: 30 minutes
    */
@@ -175,7 +190,8 @@ private[spark] trait GeneralParams extends Params {
     useExternalMemory -> false, silent -> 0, verbosity -> 1,
     customObj -> null, customEval -> null, missing -> Float.NaN,
     trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
-    checkpointPath -> "", checkpointInterval -> -1)
+    checkpointPath -> "", checkpointInterval -> -1,
+    allowNonZeroForMissingValue -> false)
 }
 
 trait HasLeafPredictionCol extends Params {
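The new param can be reached two ways, shown here as a brief sketch: via the snake_case key in an estimator parameter map (the route the new test below exercises), or via the typed setter added to the models in this commit:

    // XGBoost-style parameter map on the estimator:
    val clf = new XGBoostClassifier(Map(
      "objective" -> "binary:logistic",
      "missing" -> -1.0f,
      "num_workers" -> 1,
      "allow_non_zero_for_missing_value" -> "true"))

    // Typed setter on a trained model before prediction:
    // model.setAllowNonZeroForMissingValue(true)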
@@ -150,4 +150,32 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
       new XGBoostClassifier(paramMap).fit(inputDF)
     }
   }
+
+  test("specify a non-zero missing value but set allow_non_zero_for_missing_value " +
+    "does not stop application") {
+    val spark = ss
+    import spark.implicits._
+    ss.sparkContext.setLogLevel("INFO")
+    // Spark uses 1.5 * (nnz + 1.0) < size as the condition for choosing a sparse over a
+    // dense vector, so this data yields a mix of SparseVector and DenseVector rows.
+    val testDF = Seq(
+      (7.0f, 0.0f, -1.0f, 1.0f, 1.0),
+      (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
+      (0.0f, 1.0f, 0.0f, 1.0f, 0.0),
+      (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
+      (1.0f, -1.0f, 0.0f, 1.0f, 0.0),
+      (0.0f, 0.0f, 0.0f, 1.0f, 1.0),
+      (-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
+    ).toDF("col1", "col2", "col3", "col4", "label")
+    val vectorAssembler = new VectorAssembler()
+      .setInputCols(Array("col1", "col2", "col3", "col4"))
+      .setOutputCol("features")
+    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
+    inputDF.show()
+    val paramMap = List("eta" -> "1", "max_depth" -> "2",
+      "objective" -> "binary:logistic", "missing" -> -1.0f,
+      "num_workers" -> 1, "allow_non_zero_for_missing_value" -> "true").toMap
+    val model = new XGBoostClassifier(paramMap).fit(inputDF)
+    model.transform(inputDF).collect()
+  }
 }
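The sparse-versus-dense condition quoted in the new test's comment can be observed directly with Vector.compressed, which picks whichever representation uses less storage:

    import org.apache.spark.ml.linalg.Vectors

    // size = 4, nnz = 1: 1.5 * (1 + 1.0) = 3.0 < 4, so sparse storage wins.
    val mostlyZeros = Vectors.dense(0.0, 0.0, 0.0, 1.0).compressed
    // size = 4, nnz = 3: 1.5 * (3 + 1.0) = 6.0 >= 4, so dense storage wins.
    val mostlyNonZero = Vectors.dense(1.0, 2.0, 3.0, 0.0).compressed

    println(mostlyZeros.getClass.getSimpleName)   // SparseVector
    println(mostlyNonZero.getClass.getSimpleName) // DenseVector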