Use the new DeviceOrd in the linalg module. (#9527)
This commit is contained in:
@@ -226,7 +226,7 @@ TEST(GPUPredictor, ShapStump) {
|
||||
auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
|
||||
gpu_predictor->PredictContribution(dmat.get(), &predictions, model);
|
||||
auto& phis = predictions.HostVector();
|
||||
auto base_score = mparam.BaseScore(Context::kCpuId)(0);
|
||||
auto base_score = mparam.BaseScore(DeviceOrd::CPU())(0);
|
||||
EXPECT_EQ(phis[0], 0.0);
|
||||
EXPECT_EQ(phis[1], base_score);
|
||||
EXPECT_EQ(phis[2], 0.0);
|
||||
|
||||
@@ -287,7 +287,7 @@ void TestCategoricalPrediction(Context const* ctx, bool is_column_split) {
|
||||
|
||||
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
||||
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
||||
auto score = mparam.BaseScore(Context::kCpuId)(0);
|
||||
auto score = mparam.BaseScore(DeviceOrd::CPU())(0);
|
||||
ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
|
||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
|
||||
right_weight + score); // go to right for matching cat
|
||||
|
||||
Reference in New Issue
Block a user