/*! * Copyright 2017-2024 by Contributors * \file updater_quantile_hist.cc */ #include #include #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wtautological-constant-compare" #pragma GCC diagnostic ignored "-W#pragma-messages" #include "xgboost/tree_updater.h" #pragma GCC diagnostic pop #include "xgboost/logging.h" #include "updater_quantile_hist.h" #include "../data.h" namespace xgboost { namespace sycl { namespace tree { DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl); DMLC_REGISTER_PARAMETER(HistMakerTrainParam); void QuantileHistMaker::Configure(const Args& args) { const DeviceOrd device_spec = ctx_->Device(); qu_ = device_manager.GetQueue(device_spec); param_.UpdateAllowUnknown(args); hist_maker_param_.UpdateAllowUnknown(args); bool has_fp64_support = qu_.get_device().has(::sycl::aspect::fp64); if (hist_maker_param_.single_precision_histogram || !has_fp64_support) { if (!hist_maker_param_.single_precision_histogram) { LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True"; } hist_precision_ = HistPrecision::fp32; } else { hist_precision_ = HistPrecision::fp64; } } template void QuantileHistMaker::SetPimpl(std::unique_ptr>* pimpl, DMatrix *dmat) { pimpl->reset(new HistUpdater( ctx_, qu_, param_, int_constraint_, dmat)); if (collective::IsDistributed()) { (*pimpl)->SetHistSynchronizer(new DistributedHistSynchronizer()); (*pimpl)->SetHistRowsAdder(new DistributedHistRowsAdder()); } else { (*pimpl)->SetHistSynchronizer(new BatchHistSynchronizer()); (*pimpl)->SetHistRowsAdder(new BatchHistRowsAdder()); } } template void QuantileHistMaker::CallUpdate( const std::unique_ptr>& pimpl, xgboost::tree::TrainParam const *param, linalg::Matrix *gpair, DMatrix *dmat, xgboost::common::Span> out_position, const std::vector &trees) { const auto* gpair_h = gpair->Data(); gpair_device_.Resize(&qu_, gpair_h->Size()); qu_.memcpy(gpair_device_.Data(), gpair_h->HostPointer(), gpair_h->Size() * sizeof(GradientPair)); qu_.wait(); for (auto tree : trees) { pimpl->Update(param, gmat_, gpair_device_, dmat, out_position, tree); } } void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, linalg::Matrix* gpair, DMatrix *dmat, xgboost::common::Span> out_position, const std::vector &trees) { if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) { updater_monitor_.Start("DeviceMatrixInitialization"); sycl::DeviceMatrix dmat_device; dmat_device.Init(qu_, dmat); updater_monitor_.Stop("DeviceMatrixInitialization"); updater_monitor_.Start("GmatInitialization"); gmat_.Init(qu_, ctx_, dmat_device, static_cast(param_.max_bin)); updater_monitor_.Stop("GmatInitialization"); is_gmat_initialized_ = true; } // rescale learning rate according to size of trees float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); int_constraint_.Configure(param_, dmat->Info().num_col_); // build tree if (hist_precision_ == HistPrecision::fp32) { if (!pimpl_fp32) { SetPimpl(&pimpl_fp32, dmat); } CallUpdate(pimpl_fp32, param, gpair, dmat, out_position, trees); } else { if (!pimpl_fp64) { SetPimpl(&pimpl_fp64, dmat); } CallUpdate(pimpl_fp64, param, gpair, dmat, out_position, trees); } param_.learning_rate = lr; p_last_dmat_ = dmat; } bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data, linalg::MatrixView out_preds) { if (param_.subsample < 1.0f) return false; if (hist_precision_ == HistPrecision::fp32) { if (pimpl_fp32) { return pimpl_fp32->UpdatePredictionCache(data, out_preds); } else { return false; } } else { if (pimpl_fp64) { return pimpl_fp64->UpdatePredictionCache(data, out_preds); } else { return false; } } } XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl") .describe("Grow tree using quantized histogram with SYCL.") .set_body( [](Context const* ctx, ObjInfo const * task) { return new QuantileHistMaker(ctx, task); }); } // namespace tree } // namespace sycl } // namespace xgboost