[jvm-packages] fix the prediction issue for multi:softmax (#7694)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -17,9 +17,11 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.Partitioner
|
||||
|
||||
class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
@@ -102,6 +104,8 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
assert(model.getEta == 0.1)
|
||||
assert(model.getMaxDepth == 6)
|
||||
assert(model.numClasses == 6)
|
||||
val transformedDf = model.transform(trainingDF)
|
||||
assert(!transformedDf.columns.contains("probability"))
|
||||
}
|
||||
|
||||
test("use base margin") {
|
||||
|
||||
Reference in New Issue
Block a user