Support instance weights for xgboost4j-spark (#2642)
* Support instance weights for xgboost4j-spark
* Use 0.001 instead of 0 for weights
* Address CR comments
This commit is contained in:
parent ba16475c3a
commit a00157543d
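For context, a minimal sketch of how the new parameter is meant to be used, mirroring the test added below. The `trainWithDataFrame` call shape and the "weight"/"weightCol" names come from this patch; the surrounding setup is assumed:

import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.sql.DataFrame

// Sketch: trainingDF is assumed to carry "features", "label" and a
// per-instance "weight" column; "weightCol" points the estimator at it.
def trainWeighted(trainingDF: DataFrame, numWorkers: Int) = {
  val paramMap = Map(
    "eta" -> "1", "max_depth" -> "6",
    "objective" -> "reg:linear",
    "weightCol" -> "weight")
  XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers)
}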
@@ -108,11 +108,14 @@ class XGBoostEstimator private[spark](
   }

   private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
-    if (trainingSet.columns.contains($(baseMarginCol))) {
-      trainingSet
-    } else {
-      trainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
+    var newTrainingSet = trainingSet
+    if (!trainingSet.columns.contains($(baseMarginCol))) {
+      newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
     }
+    if (!trainingSet.columns.contains($(weightCol))) {
+      newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0))
+    }
+    newTrainingSet
   }

   /**
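The reworked ensureColumns back-fills both optional columns: a missing base margin column becomes Float.NaN and a missing weight column becomes a constant 1.0, so an unweighted DataFrame trains exactly as before. A standalone sketch of the weight half of that behaviour, with the column name hard-coded for illustration (note that lit(1.0) yields a DoubleType column, which the select in the next hunk casts to FloatType):

import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.lit

// Sketch: every row gets weight 1.0 when no weight column is present,
// so weighted and unweighted training coincide by default.
def withDefaultWeight(ds: Dataset[_]): Dataset[Row] =
  if (ds.columns.contains("weight")) ds.toDF()
  else ds.withColumn("weight", lit(1.0))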
@@ -122,13 +125,14 @@ class XGBoostEstimator private[spark](
     val instances = ensureColumns(trainingSet).select(
       col($(featuresCol)),
       col($(labelCol)).cast(FloatType),
-      col($(baseMarginCol)).cast(FloatType)
-    ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float) =>
+      col($(baseMarginCol)).cast(FloatType),
+      col($(weightCol)).cast(FloatType)
+    ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
       val (indices, values) = features match {
         case v: SparseVector => (v.indices, v.values.map(_.toFloat))
         case v: DenseVector => (null, v.values.map(_.toFloat))
       }
-      XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin)
+      XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
     }
     transformSchema(trainingSet.schema, logging = true)
     val derivedXGBoosterParamMap = fromParamsToXGBParamMap
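The per-row mapping now threads the weight into XGBLabeledPoint as a named argument. For illustration, a hypothetical dense instance built the same way; this assumes the weight parameter defaults to 1.0f when omitted, which is what keeps the unweighted path unchanged:

import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

// Sketch: indices == null marks a dense row, mirroring the DenseVector
// branch above; label and values are made up for illustration.
val downWeighted = XGBLabeledPoint(
  0.5f,                     // label
  null,                     // indices (dense representation)
  Array(1.0f, 2.0f, 3.0f),  // feature values
  weight = 0.001f)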
@@ -65,8 +65,13 @@ trait LearningTaskParams extends Params {
    */
   val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name")

+  /**
+   * Instance weights column name.
+   */
+  val weightCol = new Param[String](this, "weightCol", "weight column name")
+
   setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
-    baseMarginCol -> "baseMargin")
+    baseMarginCol -> "baseMargin", weightCol -> "weight")
 }

 private[spark] object LearningTaskParams {
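Since setDefault wires weightCol to "weight", a DataFrame whose weights live under another name only needs the parameter overridden. A sketch with a hypothetical column name:

// Sketch: point the estimator at a custom weight column.
val paramMap = Map(
  "objective" -> "reg:linear",
  "weightCol" -> "instance_weight")  // hypothetical column name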
@@ -18,11 +18,11 @@ package ml.dmlc.xgboost4j.scala.spark

 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DataTypes
 import org.scalatest.FunSuite

 class XGBoostDFSuite extends FunSuite with PerTest {
@@ -213,4 +213,24 @@ class XGBoostDFSuite extends FunSuite with PerTest {
     val predWithMargin = trainPredict(trainingDfWithMargin)
     assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
   }
+
+  test("test use weight") {
+    import DataUtils._
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "reg:linear", "weightCol" -> "weight")
+
+    val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
+    val trainingDF = buildDataFrame(Regression.train)
+      .withColumn("weight", getWeightFromId(col("id")))
+
+    val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
+      nWorkers = numWorkers, useExternalMemory = true)
+      .setPredictionCol("final_prediction")
+      .setExternalMemory(true)
+    val testRDD = sc.parallelize(Regression.test.map(_.features))
+    val predictions = model.predict(testRDD).collect().flatten
+
+    // The predictions rely heavily on the first training instance, and thus are very close.
+    predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
+  }
 }
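The test weights the first instance at 1.0 and all others at 0.001 (rather than exactly 0, per the commit message, presumably to keep the per-instance gradient statistics well defined), so the fit collapses onto the first instance and every prediction lands within 0.01 of the first. For comparison, a sketch of the same weighting against the low-level DMatrix API that the Spark layer ultimately feeds; the file path and values are illustrative:

import ml.dmlc.xgboost4j.scala.DMatrix

// Sketch: per-row weights set directly on a DMatrix.
val trainMat = new DMatrix("agaricus.txt.train")
val weights = Array.fill(trainMat.rowNum.toInt)(0.001f)
weights(0) = 1.0f
trainMat.setWeight(weights)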
@@ -63,12 +63,14 @@ class DataBatch {
     float[] label = new float[numRows];
     int[] featureIndex = new int[numElem];
     float[] featureValue = new float[numElem];
+    float[] weight = new float[numRows];

     int offset = 0;
     for (int i = 0; i < batch.size(); i++) {
       LabeledPoint labeledPoint = batch.get(i);
       rowOffset[i] = offset;
       label[i] = labeledPoint.label();
+      weight[i] = labeledPoint.weight();
       if (labeledPoint.indices() != null) {
         System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
             labeledPoint.indices().length);
@@ -84,7 +86,7 @@
     }

     rowOffset[batch.size()] = offset;
-    return new DataBatch(rowOffset, null, label, featureIndex, featureValue);
+    return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
   }

   @Override