Support instance weights for xgboost4j-spark (#2642)
* Support instance weights for xgboost4j-spark * Use 0.001 instead of 0 for weights * Address CR comments
This commit is contained in:
@@ -108,11 +108,14 @@ class XGBoostEstimator private[spark](
|
||||
}
|
||||
|
||||
private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
|
||||
if (trainingSet.columns.contains($(baseMarginCol))) {
|
||||
trainingSet
|
||||
} else {
|
||||
trainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
|
||||
var newTrainingSet = trainingSet
|
||||
if (!trainingSet.columns.contains($(baseMarginCol))) {
|
||||
newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
|
||||
}
|
||||
if (!trainingSet.columns.contains($(weightCol))) {
|
||||
newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0))
|
||||
}
|
||||
newTrainingSet
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -122,13 +125,14 @@ class XGBoostEstimator private[spark](
|
||||
val instances = ensureColumns(trainingSet).select(
|
||||
col($(featuresCol)),
|
||||
col($(labelCol)).cast(FloatType),
|
||||
col($(baseMarginCol)).cast(FloatType)
|
||||
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float) =>
|
||||
col($(baseMarginCol)).cast(FloatType),
|
||||
col($(weightCol)).cast(FloatType)
|
||||
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
|
||||
val (indices, values) = features match {
|
||||
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
|
||||
case v: DenseVector => (null, v.values.map(_.toFloat))
|
||||
}
|
||||
XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin)
|
||||
XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
|
||||
}
|
||||
transformSchema(trainingSet.schema, logging = true)
|
||||
val derivedXGBoosterParamMap = fromParamsToXGBParamMap
|
||||
|
||||
@@ -65,8 +65,13 @@ trait LearningTaskParams extends Params {
|
||||
*/
|
||||
val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name")
|
||||
|
||||
/**
|
||||
* Instance weights column name.
|
||||
*/
|
||||
val weightCol = new Param[String](this, "weightCol", "weight column name")
|
||||
|
||||
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
|
||||
baseMarginCol -> "baseMargin")
|
||||
baseMarginCol -> "baseMargin", weightCol -> "weight")
|
||||
}
|
||||
|
||||
private[spark] object LearningTaskParams {
|
||||
|
||||
@@ -18,11 +18,11 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.DataTypes
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class XGBoostDFSuite extends FunSuite with PerTest {
|
||||
@@ -213,4 +213,24 @@ class XGBoostDFSuite extends FunSuite with PerTest {
|
||||
val predWithMargin = trainPredict(trainingDfWithMargin)
|
||||
assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
|
||||
}
|
||||
|
||||
test("test use weight") {
|
||||
import DataUtils._
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:linear", "weightCol" -> "weight")
|
||||
|
||||
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
|
||||
val trainingDF = buildDataFrame(Regression.train)
|
||||
.withColumn("weight", getWeightFromId(col("id")))
|
||||
|
||||
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = true)
|
||||
.setPredictionCol("final_prediction")
|
||||
.setExternalMemory(true)
|
||||
val testRDD = sc.parallelize(Regression.test.map(_.features))
|
||||
val predictions = model.predict(testRDD).collect().flatten
|
||||
|
||||
// The predictions heavily relies on the first training instance, and thus are very close.
|
||||
predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user