Support instance weights for xgboost4j-spark (#2642)

* Support instance weights for xgboost4j-spark

* Use 0.001 instead of 0 for weights

* Address CR comments
Authored by Yun Ni on 2017-08-28 09:03:20 -07:00; committed by Nan Zhu
parent ba16475c3a
commit a00157543d
4 changed files with 42 additions and 11 deletions
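
In user terms, the change lets a DataFrame column carry per-instance weights into training. A minimal sketch of the intended call pattern, hedged: `trainingDF` is an assumed DataFrame with the usual "features"/"label" columns, and the entry point and parameter spelling follow the test added below.

    import org.apache.spark.sql.functions.lit
    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    // "weightCol" names the DataFrame column carrying instance weights.
    val paramMap = Map(
      "eta" -> "1", "objective" -> "reg:linear",
      "weightCol" -> "weight")

    // Attach per-instance weights; a constant here, normally per-row values.
    val weighted = trainingDF.withColumn("weight", lit(1.0f))

    val model = XGBoost.trainWithDataFrame(weighted, paramMap, round = 5, nWorkers = 2)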

File: XGBoostEstimator.scala

@@ -108,11 +108,14 @@ class XGBoostEstimator private[spark](
   }
 
   private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
-    if (trainingSet.columns.contains($(baseMarginCol))) {
-      trainingSet
-    } else {
-      trainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
+    var newTrainingSet = trainingSet
+    if (!trainingSet.columns.contains($(baseMarginCol))) {
+      newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
     }
+    if (!trainingSet.columns.contains($(weightCol))) {
+      newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0))
+    }
+    newTrainingSet
   }
 
   /**
@@ -122,13 +125,14 @@ class XGBoostEstimator private[spark](
     val instances = ensureColumns(trainingSet).select(
       col($(featuresCol)),
       col($(labelCol)).cast(FloatType),
-      col($(baseMarginCol)).cast(FloatType)
-    ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float) =>
+      col($(baseMarginCol)).cast(FloatType),
+      col($(weightCol)).cast(FloatType)
+    ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
       val (indices, values) = features match {
         case v: SparseVector => (v.indices, v.values.map(_.toFloat))
         case v: DenseVector => (null, v.values.map(_.toFloat))
       }
-      XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin)
+      XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
     }
     transformSchema(trainingSet.schema, logging = true)
     val derivedXGBoosterParamMap = fromParamsToXGBParamMap
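
Read together, the two hunks above make the weight column optional: ensureColumns backfills a constant before the select, so the four-field Row pattern always matches. A standalone sketch of that backfill idea, under assumed toy names (not the estimator's exact code):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.lit

    // Weight 1.0 is the multiplicative identity for an instance's gradient
    // statistics, so defaulting to it leaves unweighted training unchanged.
    def withDefaultWeight(df: DataFrame, weightCol: String = "weight"): DataFrame =
      if (df.columns.contains(weightCol)) df
      else df.withColumn(weightCol, lit(1.0f))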

File: LearningTaskParams.scala

@@ -65,8 +65,13 @@ trait LearningTaskParams extends Params {
    */
   val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name")
 
+  /**
+   * Instance weights column name.
+   */
+  val weightCol = new Param[String](this, "weightCol", "weight column name")
+
   setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
-    baseMarginCol -> "baseMargin")
+    baseMarginCol -> "baseMargin", weightCol -> "weight")
 }
 
 private[spark] object LearningTaskParams {
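
The "weight" default matches the constant column ensureColumns creates, which keeps existing callers working unchanged. Pointing at a differently named column goes through the string parameter map, exactly as the test below does ("sample_weight" here is an invented example name):

    val paramMap = Map(
      "objective" -> "reg:linear",
      "weightCol" -> "sample_weight") // any numeric column; it is cast to FloatType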

File: XGBoostDFSuite.scala

@@ -18,11 +18,11 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DataTypes
 import org.scalatest.FunSuite
 
 class XGBoostDFSuite extends FunSuite with PerTest {
@@ -213,4 +213,24 @@ class XGBoostDFSuite extends FunSuite with PerTest {
     val predWithMargin = trainPredict(trainingDfWithMargin)
     assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
   }
+
+  test("test use weight") {
+    import DataUtils._
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "reg:linear", "weightCol" -> "weight")
+    val getWeightFromId = udf({ id: Int => if (id == 0) 1.0f else 0.001f }, DataTypes.FloatType)
+    val trainingDF = buildDataFrame(Regression.train)
+      .withColumn("weight", getWeightFromId(col("id")))
+
+    val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
+      nWorkers = numWorkers, useExternalMemory = true)
+      .setPredictionCol("final_prediction")
+      .setExternalMemory(true)
+    val testRDD = sc.parallelize(Regression.test.map(_.features))
+    val predictions = model.predict(testRDD).collect().flatten
+
+    // The predictions rely heavily on the first training instance, and thus are very close.
+    predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
+  }
 }
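
A note on the second commit-message bullet: the downweighted instances get 0.001 rather than 0, presumably because a zero weight zeroes an instance's gradient and hessian outright, whereas a tiny positive weight keeps every row minimally present while still letting instance 0 dominate, which is what the near-constant-prediction assertion checks. For reg:linear the influence scales linearly with the weight (sketch of standard weighted squared error, not project code):

    // (gradient, hessian) of weighted squared error w/2 * (pred - label)^2:
    // both scale linearly with w, so w = 0.001 makes a row 1000x less influential.
    def weightedGradHess(pred: Float, label: Float, w: Float): (Float, Float) =
      (w * (pred - label), w)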

File: DataBatch.java

@@ -63,12 +63,14 @@ class DataBatch {
     float[] label = new float[numRows];
     int[] featureIndex = new int[numElem];
     float[] featureValue = new float[numElem];
+    float[] weight = new float[numRows];
 
     int offset = 0;
     for (int i = 0; i < batch.size(); i++) {
       LabeledPoint labeledPoint = batch.get(i);
       rowOffset[i] = offset;
       label[i] = labeledPoint.label();
+      weight[i] = labeledPoint.weight();
       if (labeledPoint.indices() != null) {
         System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
           labeledPoint.indices().length);
@@ -84,7 +86,7 @@ class DataBatch {
     }
 
     rowOffset[batch.size()] = offset;
-    return new DataBatch(rowOffset, null, label, featureIndex, featureValue);
+    return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
   }
 
   @Override
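
On the JVM side, DataBatch now forwards the weights it previously dropped (the null second argument) into the native DMatrix. A hedged sketch of the producing side, following the LabeledPoint call pattern from the estimator hunk above (constructor details beyond what the diff shows are an assumption):

    import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
    import ml.dmlc.xgboost4j.scala.DMatrix

    // Dense rows: null indices, a values array, and explicit per-instance weights.
    val points = Iterator(
      XGBLabeledPoint(1.0f, null, Array(0.1f, 0.2f), weight = 1.0f),
      XGBLabeledPoint(0.0f, null, Array(0.3f, 0.4f), weight = 0.001f))
    val dmat = new DMatrix(points) // batched internally through DataBatch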