Cleanup set info. (#10139)
- Use the array interface internally. - Deprecate `XGDMatrixSetDenseInfo`. - Deprecate `XGDMatrixSetUIntInfo`. - Move the handling of `DataType` into the deprecated C function. --------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2024 by XGBoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
@@ -614,8 +614,8 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
xgboost_CHECK_C_ARG_PTR(field);
|
||||
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo(field, linalg::Make1dInterface(info, len));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -634,8 +634,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
xgboost_CHECK_C_ARG_PTR(field);
|
||||
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface");
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
|
||||
p_fmat->SetInfo(field, linalg::Make1dInterface(info, len));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -679,19 +680,52 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
|
||||
xgboost::bst_ulong size, int type) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface");
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
CHECK(type >= 1 && type <= 4);
|
||||
xgboost_CHECK_C_ARG_PTR(field);
|
||||
p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len);
|
||||
Context ctx;
|
||||
auto dtype = static_cast<DataType>(type);
|
||||
std::string str;
|
||||
auto proc = [&](auto cast_d_ptr) {
|
||||
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
|
||||
auto t = linalg::TensorView<T, 1>(
|
||||
common::Span<T>{cast_d_ptr, static_cast<typename common::Span<T>::index_type>(size)},
|
||||
{size}, DeviceOrd::CPU());
|
||||
CHECK(t.CContiguous());
|
||||
Json interface{linalg::ArrayInterface(t)};
|
||||
CHECK(ArrayInterface<1>{interface}.is_contiguous);
|
||||
str = Json::Dump(interface);
|
||||
return str;
|
||||
};
|
||||
|
||||
// Legacy code using XGBoost dtype, which is a small subset of array interface types.
|
||||
switch (dtype) {
|
||||
case xgboost::DataType::kFloat32: {
|
||||
auto cast_ptr = reinterpret_cast<const float *>(data);
|
||||
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
case xgboost::DataType::kDouble: {
|
||||
auto cast_ptr = reinterpret_cast<const double *>(data);
|
||||
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
case xgboost::DataType::kUInt32: {
|
||||
auto cast_ptr = reinterpret_cast<const uint32_t *>(data);
|
||||
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
case xgboost::DataType::kUInt64: {
|
||||
auto cast_ptr = reinterpret_cast<const uint64_t *>(data);
|
||||
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype);
|
||||
}
|
||||
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -987,7 +1021,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, bs
|
||||
bst_float *hess, xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter");
|
||||
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter");
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
auto ctx = learner->Ctx()->MakeCPU();
|
||||
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_C_API_C_API_UTILS_H_
|
||||
#define XGBOOST_C_API_C_API_UTILS_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <tuple> // for make_tuple
|
||||
#include <utility> // for move
|
||||
#include <vector>
|
||||
#include <algorithm> // for min
|
||||
#include <cstddef> // for size_t
|
||||
#include <functional> // for multiplies
|
||||
#include <memory> // for shared_ptr
|
||||
#include <numeric> // for accumulate
|
||||
#include <string> // for string
|
||||
#include <tuple> // for make_tuple
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/json_utils.h" // for TypeCheck
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
Reference in New Issue
Block a user