diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
index 94c99d41b..19b93cbb1 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
@@ -108,11 +108,14 @@ class XGBoostEstimator private[spark](
   }
 
   private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
-    if (trainingSet.columns.contains($(baseMarginCol))) {
-      trainingSet
-    } else {
-      trainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
+    var newTrainingSet = trainingSet
+    if (!trainingSet.columns.contains($(baseMarginCol))) {
+      newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
     }
+    if (!trainingSet.columns.contains($(weightCol))) {
+      newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0))
+    }
+    newTrainingSet
   }
 
   /**
@@ -122,13 +125,14 @@ class XGBoostEstimator private[spark](
     val instances = ensureColumns(trainingSet).select(
       col($(featuresCol)),
       col($(labelCol)).cast(FloatType),
-      col($(baseMarginCol)).cast(FloatType)
-    ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float) =>
+      col($(baseMarginCol)).cast(FloatType),
+      col($(weightCol)).cast(FloatType)
+    ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
       val (indices, values) = features match {
         case v: SparseVector => (v.indices, v.values.map(_.toFloat))
         case v: DenseVector => (null, v.values.map(_.toFloat))
       }
-      XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin)
+      XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
     }
     transformSchema(trainingSet.schema, logging = true)
     val derivedXGBoosterParamMap = fromParamsToXGBParamMap
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index 5be9173ab..2981246f4 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -65,8 +65,13 @@ trait LearningTaskParams extends Params {
    */
  val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name")
 
+  /**
+   * Instance weights column name.
+   */
+  val weightCol = new Param[String](this, "weightCol", "weight column name")
+
   setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
-    baseMarginCol -> "baseMargin")
+    baseMarginCol -> "baseMargin", weightCol -> "weight")
 }
 
 private[spark] object LearningTaskParams {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
index 8a5813d3c..f7bfba7c6 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
@@ -18,11 +18,11 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DataTypes
 import org.scalatest.FunSuite
 
 class XGBoostDFSuite extends FunSuite with PerTest {
@@ -213,4 +213,24 @@ class XGBoostDFSuite extends FunSuite with PerTest {
     val predWithMargin = trainPredict(trainingDfWithMargin)
     assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
   }
+
+  test("test use weight") {
+    import DataUtils._
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "reg:linear", "weightCol" -> "weight")
+
+    val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
+    val trainingDF = buildDataFrame(Regression.train)
+      .withColumn("weight", getWeightFromId(col("id")))
+
+    val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
+      nWorkers = numWorkers, useExternalMemory = true)
+      .setPredictionCol("final_prediction")
+      .setExternalMemory(true)
+    val testRDD = sc.parallelize(Regression.test.map(_.features))
+    val predictions = model.predict(testRDD).collect().flatten
+
+    // The predictions heavily rely on the first training instance, and thus are very close.
+    predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
+  }
 }
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
index 4a0ff2380..61aeabd98 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java
@@ -63,12 +63,14 @@ class DataBatch {
     float[] label = new float[numRows];
     int[] featureIndex = new int[numElem];
     float[] featureValue = new float[numElem];
+    float[] weight = new float[numRows];
 
     int offset = 0;
     for (int i = 0; i < batch.size(); i++) {
       LabeledPoint labeledPoint = batch.get(i);
       rowOffset[i] = offset;
       label[i] = labeledPoint.label();
+      weight[i] = labeledPoint.weight();
       if (labeledPoint.indices() != null) {
         System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
             labeledPoint.indices().length);
@@ -84,7 +86,7 @@ class DataBatch {
     }
 
     rowOffset[batch.size()] = offset;
-    return new DataBatch(rowOffset, null, label, featureIndex, featureValue);
+    return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
   }
 
   @Override
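
For orientation, a minimal usage sketch of the `weightCol` parameter introduced by this patch. The `trainWithDataFrame` and `predict` calls mirror the test added in `XGBoostDFSuite` above; the SparkSession setup, toy data, the column name `weight`, and the training parameter values (`eta`, `max_depth`, `round`, `nWorkers`) are illustrative assumptions, not part of the patch.

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, when}

object WeightColSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("weightCol-sketch")
      .getOrCreate()
    import spark.implicits._

    // Toy training data; "features" and "label" match the estimator's default column names.
    val trainingDF = Seq(
      (Vectors.dense(1.0, 2.0), 1.0),
      (Vectors.dense(0.5, 1.5), 0.0),
      (Vectors.dense(2.0, 0.1), 1.0)
    ).toDF("features", "label")
      // Per-instance weights: upweight label-1 rows; any column name works
      // as long as it matches the "weightCol" parameter below.
      .withColumn("weight", when(col("label") === 1.0, 2.0f).otherwise(1.0f))

    val paramMap = Map(
      "eta" -> "0.3",
      "max_depth" -> "3",
      "objective" -> "reg:linear",
      "weightCol" -> "weight") // point the estimator at the weight column

    // Same entry point as the test above; round/nWorkers are illustrative.
    val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = 1)

    // Predict on raw feature vectors, as in the test.
    val testRDD = spark.sparkContext.parallelize(Seq(Vectors.dense(1.5, 1.0)))
    model.predict(testRDD).collect().flatten.foreach(println)

    spark.stop()
  }
}
```

If the caller does not supply a weight column, `ensureColumns` injects one with a constant value of 1.0, so unweighted training keeps its previous behavior.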