[R] Support multi-class custom objective. (#9526)
This commit is contained in:
parent
90ef250ea1
commit
be6a552956
@ -151,14 +151,30 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
|
|||||||
if (is.null(obj)) {
|
if (is.null(obj)) {
|
||||||
.Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
|
.Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
|
||||||
} else {
|
} else {
|
||||||
pred <- predict(booster_handle, dtrain, outputmargin = TRUE, training = TRUE,
|
pred <- predict(
|
||||||
ntreelimit = 0)
|
booster_handle,
|
||||||
|
dtrain,
|
||||||
|
outputmargin = TRUE,
|
||||||
|
training = TRUE,
|
||||||
|
reshape = TRUE
|
||||||
|
)
|
||||||
gpair <- obj(pred, dtrain)
|
gpair <- obj(pred, dtrain)
|
||||||
n_samples <- dim(dtrain)[1]
|
n_samples <- dim(dtrain)[1]
|
||||||
# We still require row-major in R as I'm not quite sure sure how to get the stride of
|
|
||||||
# the matrix in C.
|
msg <- paste(
|
||||||
gpair$grad <- matrix(gpair$grad, nrow = n_samples, byrow = TRUE)
|
"Since 2.1.0, the shape of the gradient and hessian is required to be ",
|
||||||
gpair$hess <- matrix(gpair$hess, nrow = n_samples, byrow = TRUE)
|
"(n_samples, n_targets) or (n_samples, n_classes).",
|
||||||
|
sep = ""
|
||||||
|
)
|
||||||
|
if (is.matrix(gpair$grad) && dim(gpair$grad)[1] != n_samples) {
|
||||||
|
warning(msg)
|
||||||
|
}
|
||||||
|
if (is.numeric(gpair$grad) && length(gpair$grad) != n_samples) {
|
||||||
|
warning(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
gpair$grad <- matrix(gpair$grad, nrow = n_samples)
|
||||||
|
gpair$hess <- matrix(gpair$hess, nrow = n_samples)
|
||||||
.Call(
|
.Call(
|
||||||
XGBoosterBoostOneIter_R, booster_handle, dtrain, iter, gpair$grad, gpair$hess
|
XGBoosterBoostOneIter_R, booster_handle, dtrain, iter, gpair$grad, gpair$hess
|
||||||
)
|
)
|
||||||
|
|||||||
@ -403,7 +403,7 @@ XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
|
|||||||
|
|
||||||
XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP grad, SEXP hess) {
|
XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP grad, SEXP hess) {
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
CHECK_EQ(length(grad), length(hess)) << "gradient and hess must have same length";
|
CHECK_EQ(length(grad), length(hess)) << "gradient and hess must have same length.";
|
||||||
SEXP gdim = getAttrib(grad, R_DimSymbol);
|
SEXP gdim = getAttrib(grad, R_DimSymbol);
|
||||||
auto n_samples = static_cast<std::size_t>(INTEGER(gdim)[0]);
|
auto n_samples = static_cast<std::size_t>(INTEGER(gdim)[0]);
|
||||||
auto n_targets = static_cast<std::size_t>(INTEGER(gdim)[1]);
|
auto n_targets = static_cast<std::size_t>(INTEGER(gdim)[1]);
|
||||||
@ -415,8 +415,8 @@ XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP g
|
|||||||
double const *d_hess = REAL(hess);
|
double const *d_hess = REAL(hess);
|
||||||
|
|
||||||
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
|
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
|
||||||
auto [s_grad, s_hess] =
|
auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface(
|
||||||
xgboost::detail::MakeGradientInterface(ctx, d_grad, d_hess, n_samples, n_targets);
|
ctx, d_grad, d_hess, xgboost::linalg::kF, n_samples, n_targets);
|
||||||
CHECK_CALL(XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain),
|
CHECK_CALL(XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain),
|
||||||
asInteger(iter), s_grad.c_str(), s_hess.c_str()));
|
asInteger(iter), s_grad.c_str(), s_hess.c_str()));
|
||||||
|
|
||||||
@ -435,7 +435,7 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
|
|||||||
std::vector<const char*> vec_sptr;
|
std::vector<const char*> vec_sptr;
|
||||||
for (int i = 0; i < len; ++i) {
|
for (int i = 0; i < len; ++i) {
|
||||||
vec_dmats.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
|
vec_dmats.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
|
||||||
vec_names.push_back(std::string(CHAR(asChar(VECTOR_ELT(evnames, i)))));
|
vec_names.emplace_back(CHAR(asChar(VECTOR_ELT(evnames, i))));
|
||||||
}
|
}
|
||||||
for (int i = 0; i < len; ++i) {
|
for (int i = 0; i < len; ++i) {
|
||||||
vec_sptr.push_back(vec_names[i].c_str());
|
vec_sptr.push_back(vec_names[i].c_str());
|
||||||
|
|||||||
@ -64,23 +64,80 @@ test_that("custom objective using DMatrix attr works", {
|
|||||||
expect_equal(class(bst), "xgb.Booster")
|
expect_equal(class(bst), "xgb.Booster")
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("custom objective with multi-class works", {
|
test_that("custom objective with multi-class shape", {
|
||||||
data <- as.matrix(iris[, -5])
|
data <- as.matrix(iris[, -5])
|
||||||
label <- as.numeric(iris$Species) - 1
|
label <- as.numeric(iris$Species) - 1
|
||||||
dtrain <- xgb.DMatrix(data = data, label = label)
|
dtrain <- xgb.DMatrix(data = data, label = label)
|
||||||
nclasses <- 3
|
n_classes <- 3
|
||||||
|
|
||||||
fake_softprob <- function(preds, dtrain) {
|
fake_softprob <- function(preds, dtrain) {
|
||||||
expect_true(all(matrix(preds) == 0.5))
|
expect_true(all(matrix(preds) == 0.5))
|
||||||
grad <- rnorm(dim(as.matrix(preds))[1])
|
## use numeric vector here to test compatibility with XGBoost < 2.1
|
||||||
expect_equal(dim(data)[1] * nclasses, dim(as.matrix(preds))[1])
|
grad <- rnorm(length(as.matrix(preds)))
|
||||||
hess <- rnorm(dim(as.matrix(preds))[1])
|
expect_equal(dim(data)[1] * n_classes, dim(as.matrix(preds))[1] * n_classes)
|
||||||
return (list(grad = grad, hess = hess))
|
hess <- rnorm(length(as.matrix(preds)))
|
||||||
|
return(list(grad = grad, hess = hess))
|
||||||
}
|
}
|
||||||
fake_merror <- function(preds, dtrain) {
|
fake_merror <- function(preds, dtrain) {
|
||||||
expect_equal(dim(data)[1] * nclasses, dim(as.matrix(preds))[1])
|
expect_equal(dim(data)[1] * n_classes, dim(as.matrix(preds))[1])
|
||||||
}
|
}
|
||||||
param$objective <- fake_softprob
|
param$objective <- fake_softprob
|
||||||
param$eval_metric <- fake_merror
|
param$eval_metric <- fake_merror
|
||||||
bst <- xgb.train(param, dtrain, 1, num_class = nclasses)
|
bst <- xgb.train(param, dtrain, 1, num_class = n_classes)
|
||||||
|
})
|
||||||
|
|
||||||
|
softmax <- function(values) {
|
||||||
|
values <- as.numeric(values)
|
||||||
|
exps <- exp(values)
|
||||||
|
den <- sum(exps)
|
||||||
|
return(exps / den)
|
||||||
|
}
|
||||||
|
|
||||||
|
softprob <- function(predt, dtrain) {
|
||||||
|
y <- getinfo(dtrain, "label")
|
||||||
|
|
||||||
|
n_samples <- dim(predt)[1]
|
||||||
|
n_classes <- dim(predt)[2]
|
||||||
|
|
||||||
|
grad <- matrix(nrow = n_samples, ncol = n_classes)
|
||||||
|
hess <- matrix(nrow = n_samples, ncol = n_classes)
|
||||||
|
|
||||||
|
for (i in seq_len(n_samples)) {
|
||||||
|
t <- y[i]
|
||||||
|
p <- softmax(predt[i, ])
|
||||||
|
for (c in seq_len(n_classes)) {
|
||||||
|
g <- if (c - 1 == t) {
|
||||||
|
p[c] - 1.0
|
||||||
|
} else {
|
||||||
|
p[c]
|
||||||
|
}
|
||||||
|
h <- max((2.0 * p[c] * (1.0 - p[c])), 1e-6)
|
||||||
|
grad[i, c] <- g
|
||||||
|
hess[i, c] <- h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return(list(grad = grad, hess = hess))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
test_that("custom objective with multi-class works", {
|
||||||
|
data <- as.matrix(iris[, -5])
|
||||||
|
label <- as.numeric(iris$Species) - 1
|
||||||
|
|
||||||
|
dtrain <- xgb.DMatrix(data = data, label = label)
|
||||||
|
|
||||||
|
param$num_class <- 3
|
||||||
|
param$objective <- softprob
|
||||||
|
param$eval_metric <- "merror"
|
||||||
|
param$base_score <- 0.5
|
||||||
|
|
||||||
|
custom_bst <- xgb.train(param, dtrain, 2)
|
||||||
|
custom_predt <- predict(custom_bst, dtrain)
|
||||||
|
|
||||||
|
param$objective <- "multi:softmax"
|
||||||
|
builtin_bst <- xgb.train(param, dtrain, 2)
|
||||||
|
builtin_predt <- predict(builtin_bst, dtrain)
|
||||||
|
|
||||||
|
expect_equal(custom_predt, builtin_predt)
|
||||||
})
|
})
|
||||||
|
|||||||
@ -602,6 +602,13 @@ auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
|
|||||||
return MakeTensorView(ctx->gpu_id, data, std::forward<S>(shape)...);
|
return MakeTensorView(ctx->gpu_id, data, std::forward<S>(shape)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename... S>
|
||||||
|
auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&...shape) {
|
||||||
|
std::size_t in_shape[sizeof...(S)];
|
||||||
|
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
|
||||||
|
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Ordinal(), order};
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename... S>
|
template <typename T, typename... S>
|
||||||
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
|
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
|
||||||
auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan();
|
auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan();
|
||||||
|
|||||||
@ -607,8 +607,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneI
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto ctx = xgboost::detail::BoosterCtx(handle);
|
auto ctx = xgboost::detail::BoosterCtx(handle);
|
||||||
auto [s_grad, s_hess] =
|
auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface(
|
||||||
xgboost::detail::MakeGradientInterface(ctx, grad, hess, n_samples, n_targets);
|
ctx, grad, hess, xgboost::linalg::kC, n_samples, n_targets);
|
||||||
int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast<std::int32_t>(jiter), s_grad.c_str(),
|
int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast<std::int32_t>(jiter), s_grad.c_str(),
|
||||||
s_hess.c_str());
|
s_hess.c_str());
|
||||||
|
|
||||||
|
|||||||
@ -354,12 +354,12 @@ void MakeSparseFromPtr(PtrT const *p_indptr, I const *p_indices, T const *p_data
|
|||||||
* @brief Make array interface for other language bindings.
|
* @brief Make array interface for other language bindings.
|
||||||
*/
|
*/
|
||||||
template <typename G, typename H>
|
template <typename G, typename H>
|
||||||
auto MakeGradientInterface(Context const *ctx, G const *grad, H const *hess, std::size_t n_samples,
|
auto MakeGradientInterface(Context const *ctx, G const *grad, H const *hess, linalg::Order order,
|
||||||
std::size_t n_targets) {
|
std::size_t n_samples, std::size_t n_targets) {
|
||||||
auto t_grad =
|
auto t_grad = linalg::MakeTensorView(ctx, order, common::Span{grad, n_samples * n_targets},
|
||||||
linalg::MakeTensorView(ctx, common::Span{grad, n_samples * n_targets}, n_samples, n_targets);
|
n_samples, n_targets);
|
||||||
auto t_hess =
|
auto t_hess = linalg::MakeTensorView(ctx, order, common::Span{hess, n_samples * n_targets},
|
||||||
linalg::MakeTensorView(ctx, common::Span{hess, n_samples * n_targets}, n_samples, n_targets);
|
n_samples, n_targets);
|
||||||
auto s_grad = linalg::ArrayInterfaceStr(t_grad);
|
auto s_grad = linalg::ArrayInterfaceStr(t_grad);
|
||||||
auto s_hess = linalg::ArrayInterfaceStr(t_hess);
|
auto s_hess = linalg::ArrayInterfaceStr(t_hess);
|
||||||
return std::make_tuple(s_grad, s_hess);
|
return std::make_tuple(s_grad, s_hess);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user