[jvm-packages] fix the prediction issue for multi:softmax (#7694)

This commit is contained in:
Bobby Wang
2022-02-24 01:09:45 +08:00
committed by GitHub
parent 6762c45494
commit 89aa8ddf52
2 changed files with 39 additions and 21 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,9 +17,11 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.apache.spark.Partitioner
class XGBoostClassifierSuite extends FunSuite with PerTest {
@@ -102,6 +104,8 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
assert(model.getEta == 0.1)
assert(model.getMaxDepth == 6)
assert(model.numClasses == 6)
val transformedDf = model.transform(trainingDF)
assert(!transformedDf.columns.contains("probability"))
}
test("use base margin") {