enable ROCm on latest XGBoost
This commit is contained in:
@@ -42,7 +42,7 @@ class AFTObj : public ObjFunction {
|
||||
|
||||
template <typename Distribution>
|
||||
void GetGradientImpl(const HostDeviceVector<bst_float>& preds, const MetaInfo& info,
|
||||
linalg::Matrix<GradientPair>* out_gpair, size_t ndata, int device,
|
||||
linalg::Matrix<GradientPair>* out_gpair, size_t ndata, DeviceOrd device,
|
||||
bool is_null_weight, float aft_loss_distribution_scale) {
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
@@ -75,7 +75,7 @@ class AFTObj : public ObjFunction {
|
||||
CHECK_EQ(info.labels_upper_bound_.Size(), ndata);
|
||||
out_gpair->SetDevice(ctx_->Device());
|
||||
out_gpair->Reshape(ndata, 1);
|
||||
const int device = ctx_->gpu_id;
|
||||
const auto device = ctx_->Device();
|
||||
const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale;
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
@@ -108,7 +108,7 @@ class AFTObj : public ObjFunction {
|
||||
_preds[_idx] = exp(_preds[_idx]);
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),
|
||||
io_preds->DeviceIdx())
|
||||
io_preds->Device())
|
||||
.Eval(io_preds);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user