xgboost/src/data/proxy_dmatrix.cc
Jiaming Yuan 3fde9361d7
[backport] Fix inplace predict with fallback when base margin is used. (#9536) (#9548)
- Copy meta info from proxy DMatrix.
- Use `std::call_once` to emit less warnings.
2023-09-05 23:38:06 +08:00

62 lines
2.2 KiB
C++

/**
* Copyright 2021-2023, XGBoost Contributors
* \file proxy_dmatrix.cc
*/
#include "proxy_dmatrix.h"

#include <memory>  // for make_shared, shared_ptr
namespace xgboost::data {
/**
 * \brief Point the proxy at a dense array described by an array-interface string.
 *
 * The adapter built from `interface_str` becomes the proxy's current batch, its row and
 * column counts are recorded in the proxy's MetaInfo, and the proxy context is set to the
 * CPU device (`Context::kCpuId`).
 *
 * \param interface_str Array-interface description of the data (presumably the JSON
 *                      `__array_interface__` form consumed by ArrayAdapter -- TODO confirm).
 */
void DMatrixProxy::SetArrayData(StringView interface_str) {
  // Keep the stored type exactly std::shared_ptr<ArrayAdapter>: the batch is later
  // recovered by matching on that type (NOTE(review): assumed from HostAdapterDispatch usage).
  auto adapter = std::make_shared<ArrayAdapter>(interface_str);
  this->batch_ = adapter;
  this->Info().num_col_ = adapter->NumColumns();
  this->Info().num_row_ = adapter->NumRows();
  this->ctx_.gpu_id = Context::kCpuId;
}
void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices,
char const *c_values, bst_feature_t n_features, bool on_host) {
CHECK(on_host) << "Not implemented on device.";
std::shared_ptr<CSRArrayAdapter> adapter{new CSRArrayAdapter(
StringView{c_indptr}, StringView{c_indices}, StringView{c_values}, n_features)};
this->batch_ = adapter;
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
this->ctx_.gpu_id = Context::kCpuId;
}
namespace cuda_impl {
// Device-side counterpart of data::CreateDMatrixFromProxy. It is declared unconditionally so
// the dispatcher can call it in every build configuration.
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
                                                std::shared_ptr<DMatrixProxy> proxy, float missing);

#ifndef XGBOOST_USE_CUDA
// CPU-only build: no device implementation exists, so signal failure with an empty pointer and
// let the caller decide how to report it.
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const * /*ctx*/,
                                                std::shared_ptr<DMatrixProxy> /*proxy*/,
                                                float /*missing*/) {
  return {};
}
#endif  // XGBOOST_USE_CUDA
}  // namespace cuda_impl
/**
 * \brief Materialize a concrete DMatrix from the data currently held by a proxy DMatrix.
 *
 * This is the fallback path used when the proxy cannot be consumed directly. Behavior depends
 * on the proxy's context:
 * - CPU: the proxy's host adapter is dispatched and handed to DMatrix::Create.
 * - Otherwise: the proxy is forwarded to cuda_impl::CreateDMatrixFromProxy.
 *
 * Afterwards the proxy's MetaInfo (e.g. the base margin) is copied onto the result.
 *
 * \param ctx     Context supplying the thread count on the CPU path; forwarded on the device path.
 * \param proxy   Proxy holding the input batch and its meta info.
 * \param missing Value treated as missing while building the DMatrix.
 * \return A non-null DMatrix. Failure to build one is reported through a CHECK (fatal error).
 */
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
                                                std::shared_ptr<DMatrixProxy> proxy,
                                                float missing) {
  bool type_error{false};
  std::shared_ptr<DMatrix> p_fmat{nullptr};
  if (proxy->Ctx()->IsCPU()) {
    // NOTE(review): assumes HostAdapterDispatch sets `type_error` (instead of aborting itself)
    // when the proxy holds an adapter type it does not handle; the flag is checked below.
    p_fmat = data::HostAdapterDispatch<false>(
        proxy.get(),
        [&](auto const &adapter) {
          return std::shared_ptr<DMatrix>{
              DMatrix::Create(adapter.get(), missing, ctx->Threads())};
        },
        &type_error);
  } else {
    p_fmat = cuda_impl::CreateDMatrixFromProxy(ctx, proxy, missing);
  }
  // Report an unsupported input type explicitly instead of only the generic failure below.
  CHECK(!type_error) << "Failed to fallback: the proxy DMatrix holds an unsupported data type.";
  CHECK(p_fmat) << "Failed to fallback.";
  // Carry the meta info attached to the proxy (e.g. base margin) over to the new DMatrix.
  p_fmat->Info() = proxy->Info().Copy();
  return p_fmat;
}
} // namespace xgboost::data