Align device id in predict transform with predictor. (#6662)

This commit is contained in:
Jiaming Yuan 2021-02-02 08:33:29 +08:00 committed by GitHub
parent d8ec7aad5a
commit a9ec0ea6da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 30 additions and 8 deletions

View File

@ -108,7 +108,7 @@ class AFTObj : public ObjFunction {
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) { [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = exp(_preds[_idx]); _preds[_idx] = exp(_preds[_idx]);
}, common::Range{0, static_cast<int64_t>(io_preds->Size())}, }, common::Range{0, static_cast<int64_t>(io_preds->Size())},
tparam_->gpu_id) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }

View File

@ -74,7 +74,7 @@ class HingeObj : public ObjFunction {
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0; _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, common::Range{0, static_cast<int64_t>(io_preds->Size()), 1},
tparam_->gpu_id) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }

View File

@ -136,7 +136,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass); const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
max_preds_.Resize(ndata); max_preds_.Resize(ndata);
auto device = tparam_->gpu_id; auto device = io_preds->DeviceIdx();
if (prob) { if (prob) {
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) { [=] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {

View File

@ -113,7 +113,7 @@ class RegLossObj : public ObjFunction {
[] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) { [] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
_preds[_idx] = Loss::PredTransform(_preds[_idx]); _preds[_idx] = Loss::PredTransform(_preds[_idx]);
}, common::Range{0, static_cast<int64_t>(io_preds->Size())}, }, common::Range{0, static_cast<int64_t>(io_preds->Size())},
tparam_->gpu_id) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }
@ -238,7 +238,7 @@ class PoissonRegression : public ObjFunction {
_preds[_idx] = expf(_preds[_idx]); _preds[_idx] = expf(_preds[_idx]);
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, common::Range{0, static_cast<int64_t>(io_preds->Size())},
tparam_->gpu_id) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override { void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
@ -426,7 +426,7 @@ class GammaRegression : public ObjFunction {
_preds[_idx] = expf(_preds[_idx]); _preds[_idx] = expf(_preds[_idx]);
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, common::Range{0, static_cast<int64_t>(io_preds->Size())},
tparam_->gpu_id) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override { void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
@ -529,7 +529,7 @@ class TweedieRegression : public ObjFunction {
_preds[_idx] = expf(_preds[_idx]); _preds[_idx] = expf(_preds[_idx]);
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, common::Range{0, static_cast<int64_t>(io_preds->Size())},
tparam_->gpu_id) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }

View File

@ -17,3 +17,25 @@ TEST(Objective, UnknownFunction) {
delete obj; delete obj;
} }
} }
namespace xgboost {
TEST(Objective, PredTransform) {
// Test that show PredTransform uses the same device with predictor.
xgboost::GenericParameter tparam;
tparam.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
size_t n = 100;
for (const auto &entry :
::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
std::unique_ptr<xgboost::ObjFunction> obj{
xgboost::ObjFunction::Create(entry->name, &tparam)};
obj->Configure(Args{{"num_class", "2"}});
HostDeviceVector<float> predts;
predts.Resize(n, 3.14f); // prediction is performed on host.
ASSERT_FALSE(predts.DeviceCanRead());
obj->PredTransform(&predts);
ASSERT_FALSE(predts.DeviceCanRead());
ASSERT_TRUE(predts.HostCanWrite());
}
}
} // namespace xgboost