Support instance weights for xgboost4j-spark (#2642)

* Support instance weights for xgboost4j-spark

* Use 0.001 instead of 0 for weights

* Address CR comments
This commit is contained in:
Yun Ni 2017-08-28 09:03:20 -07:00 committed by Nan Zhu
parent ba16475c3a
commit a00157543d
4 changed files with 42 additions and 11 deletions

View File

@ -108,11 +108,14 @@ class XGBoostEstimator private[spark](
}
private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
if (trainingSet.columns.contains($(baseMarginCol))) {
trainingSet
} else {
trainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
var newTrainingSet = trainingSet
if (!trainingSet.columns.contains($(baseMarginCol))) {
newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
}
if (!trainingSet.columns.contains($(weightCol))) {
newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0))
}
newTrainingSet
}
/**
@ -122,13 +125,14 @@ class XGBoostEstimator private[spark](
val instances = ensureColumns(trainingSet).select(
col($(featuresCol)),
col($(labelCol)).cast(FloatType),
col($(baseMarginCol)).cast(FloatType)
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float) =>
col($(baseMarginCol)).cast(FloatType),
col($(weightCol)).cast(FloatType)
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin)
XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
}
transformSchema(trainingSet.schema, logging = true)
val derivedXGBoosterParamMap = fromParamsToXGBParamMap

View File

@ -65,8 +65,13 @@ trait LearningTaskParams extends Params {
*/
val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name")
/**
* Instance weights column name.
*/
val weightCol = new Param[String](this, "weightCol", "weight column name")
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
baseMarginCol -> "baseMargin")
baseMarginCol -> "baseMargin", weightCol -> "weight")
}
private[spark] object LearningTaskParams {

View File

@ -18,11 +18,11 @@ package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import org.scalatest.FunSuite
class XGBoostDFSuite extends FunSuite with PerTest {
@ -213,4 +213,24 @@ class XGBoostDFSuite extends FunSuite with PerTest {
val predWithMargin = trainPredict(trainingDfWithMargin)
assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
}
test("test use weight") {
import DataUtils._
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear", "weightCol" -> "weight")
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
val trainingDF = buildDataFrame(Regression.train)
.withColumn("weight", getWeightFromId(col("id")))
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
.setPredictionCol("final_prediction")
.setExternalMemory(true)
val testRDD = sc.parallelize(Regression.test.map(_.features))
val predictions = model.predict(testRDD).collect().flatten
// The predictions heavily relies on the first training instance, and thus are very close.
predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
}
}

View File

@ -63,12 +63,14 @@ class DataBatch {
float[] label = new float[numRows];
int[] featureIndex = new int[numElem];
float[] featureValue = new float[numElem];
float[] weight = new float[numRows];
int offset = 0;
for (int i = 0; i < batch.size(); i++) {
LabeledPoint labeledPoint = batch.get(i);
rowOffset[i] = offset;
label[i] = labeledPoint.label();
weight[i] = labeledPoint.weight();
if (labeledPoint.indices() != null) {
System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
labeledPoint.indices().length);
@ -84,7 +86,7 @@ class DataBatch {
}
rowOffset[batch.size()] = offset;
return new DataBatch(rowOffset, null, label, featureIndex, featureValue);
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
}
@Override