[jvm-package] Clean up the legacy gpu support tests (#7523)

This commit is contained in:
Bobby Wang
2021-12-21 09:15:51 +08:00
committed by GitHub
parent 59bd1ab17e
commit e8c1eb99e4
6 changed files with 10 additions and 101 deletions

View File

@@ -16,14 +16,13 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.GpuTestSuite
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.apache.spark.Partitioner
abstract class XGBoostClassifierSuiteBase extends FunSuite with PerTest {
class XGBoostClassifierSuite extends FunSuite with PerTest {
protected val treeMethod: String = "auto"
@@ -200,9 +199,6 @@ abstract class XGBoostClassifierSuiteBase extends FunSuite with PerTest {
assert(resultDF.columns.contains("predictContrib"))
}
}
class XGBoostCpuClassifierSuite extends XGBoostClassifierSuiteBase {
test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
@@ -220,11 +216,11 @@ class XGBoostCpuClassifierSuite extends XGBoostClassifierSuiteBase {
}
private def checkResultsWithXGBoost4j(
trainingDM: DMatrix,
testDM: DMatrix,
trainingDF: DataFrame,
testDF: DataFrame,
round: Int = 5): Unit = {
trainingDM: DMatrix,
testDM: DMatrix,
trainingDF: DataFrame,
testDF: DataFrame,
round: Int = 5): Unit = {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
@@ -315,10 +311,5 @@ class XGBoostCpuClassifierSuite extends XGBoostClassifierSuiteBase {
val xgb = new XGBoostClassifier(paramMap)
xgb.fit(repartitioned)
}
}
@GpuTestSuite
class XGBoostGpuClassifierSuite extends XGBoostClassifierSuiteBase {
override protected val treeMethod: String = "gpu_hist"
override protected val numWorkers: Int = 1
}

View File

@@ -16,7 +16,6 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.GpuTestSuite
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions._
@@ -24,7 +23,7 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
abstract class XGBoostRegressorSuiteBase extends FunSuite with PerTest {
class XGBoostRegressorSuite extends FunSuite with PerTest {
protected val treeMethod: String = "auto"
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
@@ -218,13 +217,3 @@ abstract class XGBoostRegressorSuiteBase extends FunSuite with PerTest {
assert(resultDF.columns.contains("predictContrib"))
}
}
class XGBoostCpuRegressorSuite extends XGBoostRegressorSuiteBase {
}
@GpuTestSuite
class XGBoostGpuRegressorSuite extends XGBoostRegressorSuiteBase {
override protected val treeMethod: String = "gpu_hist"
override protected val numWorkers: Int = 1
}