Support instance weights for xgboost4j-spark (#2642)
* Support instance weights for xgboost4j-spark
* Use 0.001 instead of 0 for weights
* Address CR comments
This commit is contained in:
@@ -18,11 +18,11 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.DataTypes
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class XGBoostDFSuite extends FunSuite with PerTest {
|
||||
@@ -213,4 +213,24 @@ class XGBoostDFSuite extends FunSuite with PerTest {
|
||||
val predWithMargin = trainPredict(trainingDfWithMargin)
|
||||
assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
|
||||
}
|
||||
|
||||
test("test use weight") {
  import DataUtils._

  // Instance id 0 is given weight 1.0; every other id is given weight 0.001.
  val weightForId: Int => Float = {
    case 0 => 1.0f
    case _ => 0.001f
  }
  val weightUdf = udf(weightForId, DataTypes.FloatType)
  val weightedTrainingDF =
    buildDataFrame(Regression.train).withColumn("weight", weightUdf(col("id")))

  // "weightCol" names the DataFrame column that holds the per-instance weights.
  val trainingParams = Map(
    "eta" -> "1",
    "max_depth" -> "6",
    "silent" -> "1",
    "objective" -> "reg:linear",
    "weightCol" -> "weight")

  val trainedModel = XGBoost.trainWithDataFrame(
    weightedTrainingDF, trainingParams, round = 5, nWorkers = numWorkers,
    useExternalMemory = true)
  val model = trainedModel
    .setPredictionCol("final_prediction")
    .setExternalMemory(true)

  val testFeatures = sc.parallelize(Regression.test.map(point => point.features))
  val predictions = model.predict(testFeatures).collect().flatten

  // The first training instance carries nearly all of the weight, so the predictions are
  // expected to be very close to one another: each must lie within 0.01 of the first.
  predictions.headOption.foreach { anchor =>
    for (pred <- predictions) assert(math.abs(pred - anchor) <= 0.01f)
  }
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user