Support instance weights for xgboost4j-spark (#2642)

* Support instance weights for xgboost4j-spark

* Use 0.001 instead of 0 for weights

* Address CR comments
This commit is contained in:
Yun Ni
2017-08-28 09:03:20 -07:00
committed by Nan Zhu
parent ba16475c3a
commit a00157543d
4 changed files with 42 additions and 11 deletions

View File

@@ -108,11 +108,14 @@ class XGBoostEstimator private[spark](
}
/**
 * Ensures the training dataset contains both optional columns that the
 * conversion to `XGBLabeledPoint` reads: the base-margin column and the
 * instance-weight column.
 *
 * Missing columns are appended with neutral defaults:
 *  - base margin defaults to `Float.NaN` (treated as "no base margin" downstream)
 *  - instance weight defaults to `1.0` (every row weighted equally)
 *
 * @param trainingSet input dataset, which may lack either optional column
 * @return a dataset guaranteed to contain both columns
 */
private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
  var newTrainingSet = trainingSet
  if (!trainingSet.columns.contains($(baseMarginCol))) {
    newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
  }
  if (!trainingSet.columns.contains($(weightCol))) {
    newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0))
  }
  newTrainingSet
}
/**
@@ -122,13 +125,14 @@ class XGBoostEstimator private[spark](
val instances = ensureColumns(trainingSet).select(
col($(featuresCol)),
col($(labelCol)).cast(FloatType),
col($(baseMarginCol)).cast(FloatType),
col($(weightCol)).cast(FloatType)
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
}
transformSchema(trainingSet.schema, logging = true)
val derivedXGBoosterParamMap = fromParamsToXGBParamMap

View File

@@ -65,8 +65,13 @@ trait LearningTaskParams extends Params {
*/
val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name")
/**
* Instance weights column name.
*/
val weightCol = new Param[String](this, "weightCol", "weight column name")
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
baseMarginCol -> "baseMargin", weightCol -> "weight")
}
private[spark] object LearningTaskParams {