[jvm-packages] Comply with scala style convention + fix broken unit test (#5134)
* Fix scala style check
* Fix broken unit test
parent bc9d88259f
commit 37fdfa03f8
@@ -163,7 +163,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
     val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
     val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
     val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
-    val allowNonZeroForMissing = overridedParams.getOrElse("allow_non_zero_for_missing", false).asInstanceOf[Boolean]
+    val allowNonZeroForMissing = overridedParams
+      .getOrElse("allow_non_zero_for_missing", false)
+      .asInstanceOf[Boolean]
     validateSparkSslConf

     if (overridedParams.contains("tree_method")) {
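
The change above only re-wraps an over-long line to satisfy scalastyle; the flag is still read from the raw parameter map with a default of false. A self-contained sketch of that retrieval pattern (the map literal below is illustrative, not the library's code):

    // Illustrative sketch of the getOrElse/asInstanceOf retrieval used above.
    val overridedParams: Map[String, Any] = Map("allow_non_zero_for_missing" -> true)
    val allowNonZeroForMissing = overridedParams
      .getOrElse("allow_non_zero_for_missing", false)  // defaults to false when the key is absent
      .asInstanceOf[Boolean]
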
@@ -260,15 +262,15 @@ object XGBoost extends Serializable {
   private def verifyMissingSetting(
       xgbLabelPoints: Iterator[XGBLabeledPoint],
       missing: Float,
-      allowNonZeroMissingValue: Boolean): Iterator[XGBLabeledPoint] = {
-    if (missing != 0.0f && !allowNonZeroMissingValue) {
+      allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
+    if (missing != 0.0f && !allowNonZeroMissing) {
       xgbLabelPoints.map(labeledPoint => {
         if (labeledPoint.indices != null) {
           throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
             s" set value $missing) when you have SparseVector or Empty vector as your feature" +
             s" format. If you didn't use Spark's VectorAssembler class to build your feature " +
             s"vector but instead did so in a way that preserves zeros in your feature vector " +
-            s"you can avoid this check by using the 'allow_non_zero_missing_value parameter'" +
+            s"you can avoid this check by using the 'allow_non_zero_missing parameter'" +
             s" (only use if you know what you are doing)")
         }
         labeledPoint
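
The rename above does not alter the guard's behaviour: a non-zero missing value is only accepted for dense features unless the caller explicitly opts in. A minimal standalone restatement of that rule; the helper name checkMissingSetting and its isSparse parameter are made up for illustration:

    // Hypothetical, simplified restatement of the guard in verifyMissingSetting.
    def checkMissingSetting(isSparse: Boolean, missing: Float, allowNonZeroMissing: Boolean): Unit = {
      if (missing != 0.0f && !allowNonZeroMissing && isSparse) {
        throw new RuntimeException(
          s"missing=$missing: a non-zero missing value requires dense feature vectors " +
            "unless allow_non_zero_for_missing is enabled")
      }
    }

    checkMissingSetting(isSparse = true, missing = -1.0f, allowNonZeroMissing = true)   // passes
    // checkMissingSetting(isSparse = true, missing = -1.0f, allowNonZeroMissing = false)  // would throw
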
@@ -296,12 +298,12 @@ object XGBoost extends Serializable {
   private[spark] def processMissingValues(
       xgbLabelPoints: Iterator[XGBLabeledPoint],
       missing: Float,
-      allowNonZeroMissingValue: Boolean): Iterator[XGBLabeledPoint] = {
+      allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
     if (!missing.isNaN) {
-      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissingValue),
+      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
         missing, (v: Float) => v != missing)
     } else {
-      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissingValue),
+      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
         missing, (v: Float) => !v.isNaN)
     }
   }
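
The two branches above differ only in the predicate handed to removeMissingValues: when missing is a concrete value, features equal to it are dropped; when missing stays at the NaN default, NaNs are dropped. A tiny illustration of the two predicates on plain floats (not library code):

    // Illustration only: the keep-predicates from the two branches, applied to raw values.
    val missing = -1.0f
    val values = Array(0.0f, -1.0f, 2.5f, Float.NaN)
    val keptWhenMissingIsValue = values.filter((v: Float) => v != missing)  // keeps 0.0, 2.5, NaN
    val keptWhenMissingIsNaN   = values.filter((v: Float) => !v.isNaN)      // keeps 0.0, -1.0, 2.5
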
@@ -309,13 +311,13 @@ object XGBoost extends Serializable {
   private def processMissingValuesWithGroup(
       xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
       missing: Float,
-      allowNonZeroMissingValue: Boolean): Iterator[Array[XGBLabeledPoint]] = {
+      allowNonZeroMissing: Boolean): Iterator[Array[XGBLabeledPoint]] = {
     if (!missing.isNaN) {
       xgbLabelPointGroups.map {
         labeledPoints => XGBoost.processMissingValues(
           labeledPoints.iterator,
           missing,
-          allowNonZeroMissingValue
+          allowNonZeroMissing
         ).toArray
       }
     } else {
@@ -441,7 +443,8 @@ object XGBoost extends Serializable {
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPoints => {
         val watches = Watches.buildWatches(xgbExecutionParams,
-          processMissingValues(labeledPoints, xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing),
+          processMissingValues(labeledPoints, xgbExecutionParams.missing,
+            xgbExecutionParams.allowNonZeroForMissing),
           getCacheDirName(xgbExecutionParams.useExternalMemory))
         buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
           xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
@@ -472,7 +475,8 @@ object XGBoost extends Serializable {
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPointGroups => {
         val watches = Watches.buildWatchesWithGroup(xgbExecutionParam,
-          processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing),
+          processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
+            xgbExecutionParam.allowNonZeroForMissing),
           getCacheDirName(xgbExecutionParam.useExternalMemory))
         buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
           xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
@@ -246,7 +246,7 @@ class XGBoostClassificationModel private[ml](
   def setMissing(value: Float): this.type = set(missing, value)

   def setAllowZeroForMissingValue(value: Boolean): this.type = set(
-    allowNonZeroForMissingValue,
+    allowNonZeroForMissing,
     value
   )

@@ -261,7 +261,7 @@ class XGBoostClassificationModel private[ml](
       val dm = new DMatrix(XGBoost.processMissingValues(
         Iterator(features.asXGB),
         $(missing),
-        $(allowNonZeroForMissingValue)
+        $(allowNonZeroForMissing)
       ))
       val probability = _booster.predict(data = dm)(0).map(_.toDouble)
       if (numClasses == 2) {
@@ -321,7 +321,7 @@ class XGBoostClassificationModel private[ml](
         XGBoost.processMissingValues(
           features.map(_.asXGB),
           $(missing),
-          $(allowNonZeroForMissingValue)
+          $(allowNonZeroForMissing)
         ),
         cacheInfo)
       try {
@@ -242,7 +242,7 @@ class XGBoostRegressionModel private[ml] (
   def setMissing(value: Float): this.type = set(missing, value)

   def setAllowZeroForMissingValue(value: Boolean): this.type = set(
-    allowNonZeroForMissingValue,
+    allowNonZeroForMissing,
     value
   )

@@ -257,7 +257,7 @@ class XGBoostRegressionModel private[ml] (
       val dm = new DMatrix(XGBoost.processMissingValues(
         Iterator(features.asXGB),
         $(missing),
-        $(allowNonZeroForMissingValue)
+        $(allowNonZeroForMissing)
       ))
       _booster.predict(data = dm)(0)(0)
     }
@@ -299,7 +299,7 @@ class XGBoostRegressionModel private[ml] (
         XGBoost.processMissingValues(
           features.map(_.asXGB),
           $(missing),
-          $(allowNonZeroForMissingValue)
+          $(allowNonZeroForMissing)
         ),
         cacheInfo)
       try {
@@ -109,16 +109,16 @@ private[spark] trait GeneralParams extends Params {
    * Allows for having a non-zero value for missing when training on prediction
    * on a Sparse or Empty vector.
    */
-  final val allowNonZeroForMissingValue = new BooleanParam(
+  final val allowNonZeroForMissing = new BooleanParam(
     this,
-    "allowNonZeroForMissingValue",
+    "allowNonZeroForMissing",
     "Allow to have a non-zero value for missing when training or " +
       "predicting on a Sparse or Empty vector. Should only be used if did " +
       "not use Spark's VectorAssembler class to construct the feature vector " +
       "but instead used a method that preserves zeros in your vector."
   )

-  final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissingValue)
+  final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissing)

   /**
    * the maximum time to wait for the job requesting new workers. default: 30 minutes
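
With this rename the Spark Param is registered as "allowNonZeroForMissing" while the getter keeps its old name, so existing getter calls keep compiling. A hedged usage sketch based only on names that appear elsewhere in this diff:

    // Map-style key, as exercised by the updated unit test in this commit.
    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    val paramMap = Map(
      "eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> -1.0f,
      "num_workers" -> 1, "allow_non_zero_for_missing" -> "true")
    val classifier = new XGBoostClassifier(paramMap)

    // The trained models expose chained setters (see the XGBoostClassificationModel /
    // XGBoostRegressionModel hunks above); the setter name itself is unchanged:
    // model.setMissing(-1.0f).setAllowZeroForMissingValue(true)
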
@@ -191,7 +191,7 @@ private[spark] trait GeneralParams extends Params {
     customObj -> null, customEval -> null, missing -> Float.NaN,
     trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
     checkpointPath -> "", checkpointInterval -> -1,
-    allowNonZeroForMissingValue -> false)
+    allowNonZeroForMissing -> false)
 }

 trait HasLeafPredictionCol extends Params {
@@ -151,7 +151,7 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
     }
   }

-  test("specify a non-zero missing value but set allow_non_zero_missing_value " +
+  test("specify a non-zero missing value but set allow_non_zero_missing " +
     "does not stop application") {
     val spark = ss
     import spark.implicits._
@@ -174,7 +174,7 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
     inputDF.show()
     val paramMap = List("eta" -> "1", "max_depth" -> "2",
       "objective" -> "binary:logistic", "missing" -> -1.0f,
-      "num_workers" -> 1, "allow_non_zero_for_missing_value" -> "true").toMap
+      "num_workers" -> 1, "allow_non_zero_for_missing" -> "true").toMap
     val model = new XGBoostClassifier(paramMap).fit(inputDF)
     model.transform(inputDF).collect()
   }
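
For contrast, a hypothetical negative case that is not part of this commit: with the same non-zero missing value but without the allow_non_zero_for_missing opt-in, sparse feature vectors are expected to hit the verifyMissingSetting guard shown earlier:

    // Hypothetical counterpart, not in this diff: omitting the opt-in with a
    // non-zero missing value should make training fail on SparseVector input with
    // the "you can only specify missing value as 0.0 ..." RuntimeException.
    val strictParamMap = Map(
      "eta" -> "1", "max_depth" -> "2",
      "objective" -> "binary:logistic", "missing" -> -1.0f,
      "num_workers" -> 1)
    // new XGBoostClassifier(strictParamMap).fit(inputDF)  // expected to fail on sparse features
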
|
|||||||
Loading…
x
Reference in New Issue
Block a user