Support instance weights for xgboost4j-spark (#2642)

* Support instance weights for xgboost4j-spark

* Use 0.001 instead of 0 for weights

* Address CR comments
This commit is contained in:
Yun Ni
2017-08-28 09:03:20 -07:00
committed by Nan Zhu
parent ba16475c3a
commit a00157543d
4 changed files with 42 additions and 11 deletions

View File

@@ -18,11 +18,11 @@ package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import org.scalatest.FunSuite
class XGBoostDFSuite extends FunSuite with PerTest {
@@ -213,4 +213,24 @@ class XGBoostDFSuite extends FunSuite with PerTest {
val predWithMargin = trainPredict(trainingDfWithMargin)
assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
}
test("test use weight") {
import DataUtils._
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear", "weightCol" -> "weight")
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
val trainingDF = buildDataFrame(Regression.train)
.withColumn("weight", getWeightFromId(col("id")))
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
.setPredictionCol("final_prediction")
.setExternalMemory(true)
val testRDD = sc.parallelize(Regression.test.map(_.features))
val predictions = model.predict(testRDD).collect().flatten
// The predictions heavily relies on the first training instance, and thus are very close.
predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
}
}