[jvm-packages] call setGroup for ranking task (#2066)
* [jvm-packages] call setGroup for ranking task * passing groupData through xgBoostConfMap * fix original comment position * make groupData param * remove groupData variable, use xgBoostConfMap directly * set default groupData value * add use groupData tests * reduce rank-demo size * use TaskContext.getPartitionId() instead of mapPartitionsWithIndex * add DF use groupData test * remove unused varable
This commit is contained in:
@@ -123,6 +123,12 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
|
||||
val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
|
||||
if (xgBoostConfMap.isDefinedAt("groupData")
|
||||
&& xgBoostConfMap.get("groupData").get != null) {
|
||||
trainingSet.setGroup(
|
||||
xgBoostConfMap.get("groupData").get.asInstanceOf[Seq[Seq[Int]]](
|
||||
TaskContext.getPartitionId()).toArray)
|
||||
}
|
||||
booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
|
||||
watches = new mutable.HashMap[String, DMatrix] {
|
||||
put("train", trainingSet)
|
||||
|
||||
@@ -53,7 +53,14 @@ trait LearningTaskParams extends Params {
|
||||
s" {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
|
||||
(value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
|
||||
|
||||
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2)
|
||||
/**
|
||||
* group data specify each group sizes for ranking task. To correspond to partition of
|
||||
* training data, it is nested.
|
||||
*/
|
||||
val groupData = new Param[Seq[Seq[Int]]](this, "groupData", "group data specify each group size" +
|
||||
" for ranking task. To correspond to partition of training data, it is nested.")
|
||||
|
||||
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null)
|
||||
}
|
||||
|
||||
private[spark] object LearningTaskParams {
|
||||
|
||||
Reference in New Issue
Block a user