[jvm-package] Clean up the legacy gpu support tests (#7523)
This commit is contained in:
@@ -16,14 +16,13 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.GpuTestSuite
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
import org.apache.spark.Partitioner
|
||||
|
||||
abstract class XGBoostClassifierSuiteBase extends FunSuite with PerTest {
|
||||
class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
|
||||
protected val treeMethod: String = "auto"
|
||||
|
||||
@@ -200,9 +199,6 @@ abstract class XGBoostClassifierSuiteBase extends FunSuite with PerTest {
|
||||
assert(resultDF.columns.contains("predictContrib"))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
class XGBoostCpuClassifierSuite extends XGBoostClassifierSuiteBase {
|
||||
test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
|
||||
val trainingDM = new DMatrix(Classification.train.iterator)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
@@ -220,11 +216,11 @@ class XGBoostCpuClassifierSuite extends XGBoostClassifierSuiteBase {
|
||||
}
|
||||
|
||||
private def checkResultsWithXGBoost4j(
|
||||
trainingDM: DMatrix,
|
||||
testDM: DMatrix,
|
||||
trainingDF: DataFrame,
|
||||
testDF: DataFrame,
|
||||
round: Int = 5): Unit = {
|
||||
trainingDM: DMatrix,
|
||||
testDM: DMatrix,
|
||||
trainingDF: DataFrame,
|
||||
testDF: DataFrame,
|
||||
round: Int = 5): Unit = {
|
||||
val paramMap = Map(
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
@@ -315,10 +311,5 @@ class XGBoostCpuClassifierSuite extends XGBoostClassifierSuiteBase {
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
xgb.fit(repartitioned)
|
||||
}
|
||||
}
|
||||
|
||||
@GpuTestSuite
|
||||
class XGBoostGpuClassifierSuite extends XGBoostClassifierSuiteBase {
|
||||
override protected val treeMethod: String = "gpu_hist"
|
||||
override protected val numWorkers: Int = 1
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.GpuTestSuite
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.sql.functions._
|
||||
@@ -24,7 +23,7 @@ import org.apache.spark.sql.{DataFrame, Row}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
abstract class XGBoostRegressorSuiteBase extends FunSuite with PerTest {
|
||||
class XGBoostRegressorSuite extends FunSuite with PerTest {
|
||||
protected val treeMethod: String = "auto"
|
||||
|
||||
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
|
||||
@@ -218,13 +217,3 @@ abstract class XGBoostRegressorSuiteBase extends FunSuite with PerTest {
|
||||
assert(resultDF.columns.contains("predictContrib"))
|
||||
}
|
||||
}
|
||||
|
||||
class XGBoostCpuRegressorSuite extends XGBoostRegressorSuiteBase {
|
||||
|
||||
}
|
||||
|
||||
@GpuTestSuite
|
||||
class XGBoostGpuRegressorSuite extends XGBoostRegressorSuiteBase {
|
||||
override protected val treeMethod: String = "gpu_hist"
|
||||
override protected val numWorkers: Int = 1
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user