Unify test helpers for creating ctx. (#9274)
This commit is contained in:
@@ -9,7 +9,7 @@
|
||||
namespace xgboost {
|
||||
|
||||
TEST(Plugin, LinearRegressionGPairOneAPI) {
|
||||
Context tparam = CreateEmptyGenericParam(0);
|
||||
Context tparam = MakeCUDACtx(0);
|
||||
std::vector<std::pair<std::string, std::string>> args;
|
||||
|
||||
std::unique_ptr<ObjFunction> obj {
|
||||
@@ -33,7 +33,7 @@ TEST(Plugin, LinearRegressionGPairOneAPI) {
|
||||
}
|
||||
|
||||
TEST(Plugin, SquaredLogOneAPI) {
|
||||
Context tparam = CreateEmptyGenericParam(0);
|
||||
Context tparam = MakeCUDACtx(0);
|
||||
std::vector<std::pair<std::string, std::string>> args;
|
||||
|
||||
std::unique_ptr<ObjFunction> obj { ObjFunction::Create("reg:squaredlogerror_oneapi", &tparam) };
|
||||
@@ -56,7 +56,7 @@ TEST(Plugin, SquaredLogOneAPI) {
|
||||
}
|
||||
|
||||
TEST(Plugin, LogisticRegressionGPairOneAPI) {
|
||||
Context tparam = CreateEmptyGenericParam(0);
|
||||
Context tparam = MakeCUDACtx(0);
|
||||
std::vector<std::pair<std::string, std::string>> args;
|
||||
std::unique_ptr<ObjFunction> obj { ObjFunction::Create("reg:logistic_oneapi", &tparam) };
|
||||
|
||||
@@ -72,7 +72,7 @@ TEST(Plugin, LogisticRegressionGPairOneAPI) {
|
||||
}
|
||||
|
||||
TEST(Plugin, LogisticRegressionBasicOneAPI) {
|
||||
Context lparam = CreateEmptyGenericParam(0);
|
||||
Context lparam = MakeCUDACtx(0);
|
||||
std::vector<std::pair<std::string, std::string>> args;
|
||||
std::unique_ptr<ObjFunction> obj {
|
||||
ObjFunction::Create("reg:logistic_oneapi", &lparam)
|
||||
@@ -103,7 +103,7 @@ TEST(Plugin, LogisticRegressionBasicOneAPI) {
|
||||
}
|
||||
|
||||
TEST(Plugin, LogisticRawGPairOneAPI) {
|
||||
Context lparam = CreateEmptyGenericParam(0);
|
||||
Context lparam = MakeCUDACtx(0);
|
||||
std::vector<std::pair<std::string, std::string>> args;
|
||||
std::unique_ptr<ObjFunction> obj {
|
||||
ObjFunction::Create("binary:logitraw_oneapi", &lparam)
|
||||
@@ -120,7 +120,7 @@ TEST(Plugin, LogisticRawGPairOneAPI) {
|
||||
}
|
||||
|
||||
TEST(Plugin, CPUvsOneAPI) {
|
||||
Context ctx = CreateEmptyGenericParam(0);
|
||||
Context ctx = MakeCUDACtx(0);
|
||||
|
||||
ObjFunction * obj_cpu =
|
||||
ObjFunction::Create("reg:squarederror", &ctx);
|
||||
@@ -140,8 +140,8 @@ TEST(Plugin, CPUvsOneAPI) {
|
||||
}
|
||||
auto& info = pdmat->Info();
|
||||
|
||||
info.labels_.Resize(kRows);
|
||||
auto& h_labels = info.labels_.HostVector();
|
||||
info.labels.Reshape(kRows, 1);
|
||||
auto& h_labels = info.labels.Data()->HostVector();
|
||||
for (size_t i = 0; i < h_labels.size(); ++i) {
|
||||
h_labels[i] = 1 / static_cast<float>(i+1);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user