[jvm-packages] Comply with scala style convention + fix broken unit test (#5134)

* Fix scala style check

* fix messed unit test
This commit is contained in:
Philip Hyunsu Cho 2019-12-18 17:26:58 -08:00 committed by GitHub
parent bc9d88259f
commit 37fdfa03f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 27 additions and 23 deletions

View File

@ -163,7 +163,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait] val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait] val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float] val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
val allowNonZeroForMissing = overridedParams.getOrElse("allow_non_zero_for_missing", false).asInstanceOf[Boolean] val allowNonZeroForMissing = overridedParams
.getOrElse("allow_non_zero_for_missing", false)
.asInstanceOf[Boolean]
validateSparkSslConf validateSparkSslConf
if (overridedParams.contains("tree_method")) { if (overridedParams.contains("tree_method")) {
@ -260,15 +262,15 @@ object XGBoost extends Serializable {
private def verifyMissingSetting( private def verifyMissingSetting(
xgbLabelPoints: Iterator[XGBLabeledPoint], xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float, missing: Float,
allowNonZeroMissingValue: Boolean): Iterator[XGBLabeledPoint] = { allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
if (missing != 0.0f && !allowNonZeroMissingValue) { if (missing != 0.0f && !allowNonZeroMissing) {
xgbLabelPoints.map(labeledPoint => { xgbLabelPoints.map(labeledPoint => {
if (labeledPoint.indices != null) { if (labeledPoint.indices != null) {
throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" + throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
s" set value $missing) when you have SparseVector or Empty vector as your feature" + s" set value $missing) when you have SparseVector or Empty vector as your feature" +
s" format. If you didn't use Spark's VectorAssembler class to build your feature " + s" format. If you didn't use Spark's VectorAssembler class to build your feature " +
s"vector but instead did so in a way that preserves zeros in your feature vector " + s"vector but instead did so in a way that preserves zeros in your feature vector " +
s"you can avoid this check by using the 'allow_non_zero_missing_value parameter'" + s"you can avoid this check by using the 'allow_non_zero_missing parameter'" +
s" (only use if you know what you are doing)") s" (only use if you know what you are doing)")
} }
labeledPoint labeledPoint
@ -296,12 +298,12 @@ object XGBoost extends Serializable {
private[spark] def processMissingValues( private[spark] def processMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint], xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float, missing: Float,
allowNonZeroMissingValue: Boolean): Iterator[XGBLabeledPoint] = { allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) { if (!missing.isNaN) {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissingValue), removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
missing, (v: Float) => v != missing) missing, (v: Float) => v != missing)
} else { } else {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissingValue), removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
missing, (v: Float) => !v.isNaN) missing, (v: Float) => !v.isNaN)
} }
} }
@ -309,13 +311,13 @@ object XGBoost extends Serializable {
private def processMissingValuesWithGroup( private def processMissingValuesWithGroup(
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]], xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
missing: Float, missing: Float,
allowNonZeroMissingValue: Boolean): Iterator[Array[XGBLabeledPoint]] = { allowNonZeroMissing: Boolean): Iterator[Array[XGBLabeledPoint]] = {
if (!missing.isNaN) { if (!missing.isNaN) {
xgbLabelPointGroups.map { xgbLabelPointGroups.map {
labeledPoints => XGBoost.processMissingValues( labeledPoints => XGBoost.processMissingValues(
labeledPoints.iterator, labeledPoints.iterator,
missing, missing,
allowNonZeroMissingValue allowNonZeroMissing
).toArray ).toArray
} }
} else { } else {
@ -441,7 +443,8 @@ object XGBoost extends Serializable {
if (evalSetsMap.isEmpty) { if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPoints => { trainingData.mapPartitions(labeledPoints => {
val watches = Watches.buildWatches(xgbExecutionParams, val watches = Watches.buildWatches(xgbExecutionParams,
processMissingValues(labeledPoints, xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing), processMissingValues(labeledPoints, xgbExecutionParams.missing,
xgbExecutionParams.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParams.useExternalMemory)) getCacheDirName(xgbExecutionParams.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound, buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster) xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
@ -472,7 +475,8 @@ object XGBoost extends Serializable {
if (evalSetsMap.isEmpty) { if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPointGroups => { trainingData.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(xgbExecutionParam, val watches = Watches.buildWatchesWithGroup(xgbExecutionParam,
processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing), processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
xgbExecutionParam.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParam.useExternalMemory)) getCacheDirName(xgbExecutionParam.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound, buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster) xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)

View File

@ -246,7 +246,7 @@ class XGBoostClassificationModel private[ml](
def setMissing(value: Float): this.type = set(missing, value) def setMissing(value: Float): this.type = set(missing, value)
def setAllowZeroForMissingValue(value: Boolean): this.type = set( def setAllowZeroForMissingValue(value: Boolean): this.type = set(
allowNonZeroForMissingValue, allowNonZeroForMissing,
value value
) )
@ -261,7 +261,7 @@ class XGBoostClassificationModel private[ml](
val dm = new DMatrix(XGBoost.processMissingValues( val dm = new DMatrix(XGBoost.processMissingValues(
Iterator(features.asXGB), Iterator(features.asXGB),
$(missing), $(missing),
$(allowNonZeroForMissingValue) $(allowNonZeroForMissing)
)) ))
val probability = _booster.predict(data = dm)(0).map(_.toDouble) val probability = _booster.predict(data = dm)(0).map(_.toDouble)
if (numClasses == 2) { if (numClasses == 2) {
@ -321,7 +321,7 @@ class XGBoostClassificationModel private[ml](
XGBoost.processMissingValues( XGBoost.processMissingValues(
features.map(_.asXGB), features.map(_.asXGB),
$(missing), $(missing),
$(allowNonZeroForMissingValue) $(allowNonZeroForMissing)
), ),
cacheInfo) cacheInfo)
try { try {

View File

@ -242,7 +242,7 @@ class XGBoostRegressionModel private[ml] (
def setMissing(value: Float): this.type = set(missing, value) def setMissing(value: Float): this.type = set(missing, value)
def setAllowZeroForMissingValue(value: Boolean): this.type = set( def setAllowZeroForMissingValue(value: Boolean): this.type = set(
allowNonZeroForMissingValue, allowNonZeroForMissing,
value value
) )
@ -257,7 +257,7 @@ class XGBoostRegressionModel private[ml] (
val dm = new DMatrix(XGBoost.processMissingValues( val dm = new DMatrix(XGBoost.processMissingValues(
Iterator(features.asXGB), Iterator(features.asXGB),
$(missing), $(missing),
$(allowNonZeroForMissingValue) $(allowNonZeroForMissing)
)) ))
_booster.predict(data = dm)(0)(0) _booster.predict(data = dm)(0)(0)
} }
@ -299,7 +299,7 @@ class XGBoostRegressionModel private[ml] (
XGBoost.processMissingValues( XGBoost.processMissingValues(
features.map(_.asXGB), features.map(_.asXGB),
$(missing), $(missing),
$(allowNonZeroForMissingValue) $(allowNonZeroForMissing)
), ),
cacheInfo) cacheInfo)
try { try {

View File

@ -109,16 +109,16 @@ private[spark] trait GeneralParams extends Params {
* Allows for having a non-zero value for missing when training on prediction * Allows for having a non-zero value for missing when training on prediction
* on a Sparse or Empty vector. * on a Sparse or Empty vector.
*/ */
final val allowNonZeroForMissingValue = new BooleanParam( final val allowNonZeroForMissing = new BooleanParam(
this, this,
"allowNonZeroForMissingValue", "allowNonZeroForMissing",
"Allow to have a non-zero value for missing when training or " + "Allow to have a non-zero value for missing when training or " +
"predicting on a Sparse or Empty vector. Should only be used if did " + "predicting on a Sparse or Empty vector. Should only be used if did " +
"not use Spark's VectorAssembler class to construct the feature vector " + "not use Spark's VectorAssembler class to construct the feature vector " +
"but instead used a method that preserves zeros in your vector." "but instead used a method that preserves zeros in your vector."
) )
final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissingValue) final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissing)
/** /**
* the maximum time to wait for the job requesting new workers. default: 30 minutes * the maximum time to wait for the job requesting new workers. default: 30 minutes
@ -191,7 +191,7 @@ private[spark] trait GeneralParams extends Params {
customObj -> null, customEval -> null, missing -> Float.NaN, customObj -> null, customEval -> null, missing -> Float.NaN,
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L, trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
checkpointPath -> "", checkpointInterval -> -1, checkpointPath -> "", checkpointInterval -> -1,
allowNonZeroForMissingValue -> false) allowNonZeroForMissing -> false)
} }
trait HasLeafPredictionCol extends Params { trait HasLeafPredictionCol extends Params {

View File

@ -151,7 +151,7 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
} }
} }
test("specify a non-zero missing value but set allow_non_zero_missing_value " + test("specify a non-zero missing value but set allow_non_zero_missing " +
"does not stop application") { "does not stop application") {
val spark = ss val spark = ss
import spark.implicits._ import spark.implicits._
@ -174,7 +174,7 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
inputDF.show() inputDF.show()
val paramMap = List("eta" -> "1", "max_depth" -> "2", val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> -1.0f, "objective" -> "binary:logistic", "missing" -> -1.0f,
"num_workers" -> 1, "allow_non_zero_for_missing_value" -> "true").toMap "num_workers" -> 1, "allow_non_zero_for_missing" -> "true").toMap
val model = new XGBoostClassifier(paramMap).fit(inputDF) val model = new XGBoostClassifier(paramMap).fit(inputDF)
model.transform(inputDF).collect() model.transform(inputDF).collect()
} }