Support instance weights for xgboost4j-spark (#2642)
* Support instance weights for xgboost4j-spark
* Use 0.001 instead of 0 for weights
* Address CR comments
parent ba16475c3a
commit a00157543d
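With this change, a caller opts into instance weights by adding a weight column to the training DataFrame; frames without that column fall back to a neutral weight of 1.0 (see the ensureColumns hunk below). A minimal usage sketch modeled on the test added in this commit (trainingDF and the constant weight are placeholders, not part of the diff):

    import org.apache.spark.sql.functions._

    // paramMap mirrors the new test: "weightCol" names the DataFrame column to read weights from.
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:linear", "weightCol" -> "weight")
    // Hypothetical input frame; a constant 1.0f weight reproduces the unweighted behavior.
    val weighted = trainingDF.withColumn("weight", lit(1.0f))
    val model = XGBoost.trainWithDataFrame(weighted, paramMap, round = 5,
      nWorkers = 2, useExternalMemory = false)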
@@ -108,11 +108,14 @@ class XGBoostEstimator private[spark](
   }

   private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
-    if (trainingSet.columns.contains($(baseMarginCol))) {
-      trainingSet
-    } else {
-      trainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
+    var newTrainingSet = trainingSet
+    if (!trainingSet.columns.contains($(baseMarginCol))) {
+      newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
     }
+    if (!trainingSet.columns.contains($(weightCol))) {
+      newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0))
+    }
+    newTrainingSet
   }

   /**
@@ -122,13 +125,14 @@ class XGBoostEstimator private[spark](
     val instances = ensureColumns(trainingSet).select(
       col($(featuresCol)),
       col($(labelCol)).cast(FloatType),
-      col($(baseMarginCol)).cast(FloatType)
-    ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float) =>
+      col($(baseMarginCol)).cast(FloatType),
+      col($(weightCol)).cast(FloatType)
+    ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
       val (indices, values) = features match {
         case v: SparseVector => (v.indices, v.values.map(_.toFloat))
         case v: DenseVector => (null, v.values.map(_.toFloat))
       }
-      XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin)
+      XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
     }
     transformSchema(trainingSet.schema, logging = true)
     val derivedXGBoosterParamMap = fromParamsToXGBParamMap
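The refactor makes ensureColumns additive: both optional columns are backfilled before the select, so the four-field Row pattern match in the new rdd.map can never hit a missing column. The same defaulting idiom in isolation (a sketch with literal column names standing in for the params):

    import org.apache.spark.sql.{DataFrame, functions => F}

    // Backfill optional columns so downstream code can select them unconditionally.
    def backfill(df: DataFrame): DataFrame = {
      var out = df
      if (!out.columns.contains("baseMargin")) out = out.withColumn("baseMargin", F.lit(Float.NaN))
      if (!out.columns.contains("weight")) out = out.withColumn("weight", F.lit(1.0))
      out
    }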
@@ -65,8 +65,13 @@ trait LearningTaskParams extends Params {
    */
   val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name")

+  /**
+   * Instance weights column name.
+   */
+  val weightCol = new Param[String](this, "weightCol", "weight column name")
+
   setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
-    baseMarginCol -> "baseMargin")
+    baseMarginCol -> "baseMargin", weightCol -> "weight")
 }

 private[spark] object LearningTaskParams {
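Since weightCol defaults to "weight", existing jobs whose DataFrames lack such a column keep their behavior (every row is backfilled with weight 1.0). To read weights from a differently named column, pass the param through the same Map-based config the new test uses (the column name here is hypothetical):

    val paramMap = Map("objective" -> "reg:linear", "weightCol" -> "instance_weight")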
@@ -18,11 +18,11 @@ package ml.dmlc.xgboost4j.scala.spark

 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DataTypes
 import org.scalatest.FunSuite

 class XGBoostDFSuite extends FunSuite with PerTest {
@@ -213,4 +213,24 @@ class XGBoostDFSuite extends FunSuite with PerTest {
     val predWithMargin = trainPredict(trainingDfWithMargin)
     assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
   }
+
+  test("test use weight") {
+    import DataUtils._
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "reg:linear", "weightCol" -> "weight")
+
+    val getWeightFromId = udf({ id: Int => if (id == 0) 1.0f else 0.001f }, DataTypes.FloatType)
+    val trainingDF = buildDataFrame(Regression.train)
+      .withColumn("weight", getWeightFromId(col("id")))
+
+    val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
+      nWorkers = numWorkers, useExternalMemory = true)
+      .setPredictionCol("final_prediction")
+      .setExternalMemory(true)
+    val testRDD = sc.parallelize(Regression.test.map(_.features))
+    val predictions = model.predict(testRDD).collect().flatten
+
+    // The predictions heavily rely on the first training instance, and thus are very close.
+    predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
+  }
 }
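Note that the down-weighted rows receive 0.001 rather than 0, per the commit message's second bullet. The assertion then checks that every prediction lands within 0.01 of the first one, i.e. the single full-weight instance dominates the fit.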
@@ -63,12 +63,14 @@ class DataBatch {
     float[] label = new float[numRows];
     int[] featureIndex = new int[numElem];
     float[] featureValue = new float[numElem];
+    float[] weight = new float[numRows];

     int offset = 0;
     for (int i = 0; i < batch.size(); i++) {
       LabeledPoint labeledPoint = batch.get(i);
       rowOffset[i] = offset;
       label[i] = labeledPoint.label();
+      weight[i] = labeledPoint.weight();
       if (labeledPoint.indices() != null) {
         System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
           labeledPoint.indices().length);
@@ -84,7 +86,7 @@ class DataBatch {
     }

     rowOffset[batch.size()] = offset;
-    return new DataBatch(rowOffset, null, label, featureIndex, featureValue);
+    return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
   }

   @Override
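On the JVM side, DataBatch now forwards each LabeledPoint's weight into the per-row weight array handed to the native DMatrix instead of passing null. Constructed from Scala, a weighted point looks like the estimator's call above (a sketch; it assumes the remaining constructor fields keep the defaults implied by the older four-argument call):

    import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

    // A dense row (indices == null, so values are taken positionally) with half the default weight.
    val point = XGBLabeledPoint(1.0f, null, Array(0.5f, 2.0f), weight = 0.5f)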