[R] address some lintr warnings (#8609)

This commit is contained in:
James Lamb 2022-12-17 04:36:14 -06:00 committed by GitHub
parent 53e6e32718
commit 17ce1f26c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 137 additions and 116 deletions

View File

@ -629,7 +629,7 @@ xgb.attributes <- function(object) {
#' @export #' @export
xgb.config <- function(object) { xgb.config <- function(object) {
handle <- xgb.get.handle(object) handle <- xgb.get.handle(object)
.Call(XGBoosterSaveJsonConfig_R, handle); .Call(XGBoosterSaveJsonConfig_R, handle)
} }
#' @rdname xgb.config #' @rdname xgb.config

View File

@ -404,7 +404,7 @@ test_that("Configuration works", {
config <- xgb.config(bst) config <- xgb.config(bst)
xgb.config(bst) <- config xgb.config(bst) <- config
reloaded_config <- xgb.config(bst) reloaded_config <- xgb.config(bst)
expect_equal(config, reloaded_config); expect_equal(config, reloaded_config)
}) })
test_that("strict_shape works", { test_that("strict_shape works", {

View File

@ -28,7 +28,9 @@ Package loading:
require(xgboost) require(xgboost)
require(Matrix) require(Matrix)
require(data.table) require(data.table)
if (!require('vcd')) install.packages('vcd') if (!require('vcd')) {
install.packages('vcd')
}
``` ```
> **VCD** package is used for one of its embedded datasets only. > **VCD** package is used for one of its embedded datasets only.

View File

@ -163,9 +163,10 @@ We were able to get the log-odds to agree, so now let's manually calculate the s
bst_preds <- predict(bst, as.matrix(data$dates)) bst_preds <- predict(bst, as.matrix(data$dates))
# calculate the predictions casting doubles to floats # calculate the predictions casting doubles to floats
bst_from_json_preds <- ifelse(fl(data$dates)<fl(node$split_condition), bst_from_json_preds <- ifelse(
as.numeric(1/(1+exp(-1*fl(node$children[[1]]$leaf)))), fl(data$dates) < fl(node$split_condition)
as.numeric(1/(1+exp(-1*fl(node$children[[2]]$leaf)))) , as.numeric(1 / (1 + exp(-1 * fl(node$children[[1]]$leaf))))
, as.numeric(1 / (1 + exp(-1 * fl(node$children[[2]]$leaf))))
) )
# test that values are equal # test that values are equal
@ -177,9 +178,10 @@ None are exactly equal again. What is going on here? Well, since we are using
How do we fix this? We have to ensure we use the correct data types everywhere and the correct operators. If we use only floats, the float library that we have loaded will ensure the 32-bit float exponentiation operator is applied. How do we fix this? We have to ensure we use the correct data types everywhere and the correct operators. If we use only floats, the float library that we have loaded will ensure the 32-bit float exponentiation operator is applied.
```{r} ```{r}
# calculate the predictions casting doubles to floats # calculate the predictions casting doubles to floats
bst_from_json_preds <- ifelse(fl(data$dates)<fl(node$split_condition), bst_from_json_preds <- ifelse(
as.numeric(fl(1)/(fl(1)+exp(fl(-1)*fl(node$children[[1]]$leaf)))), fl(data$dates) < fl(node$split_condition)
as.numeric(fl(1)/(fl(1)+exp(fl(-1)*fl(node$children[[2]]$leaf)))) , as.numeric(fl(1) / (fl(1) + exp(fl(-1) * fl(node$children[[1]]$leaf))))
, as.numeric(fl(1) / (fl(1) + exp(fl(-1) * fl(node$children[[2]]$leaf))))
) )
# test that values are equal # test that values are equal

View File

@ -1,8 +1,10 @@
site <- 'http://cran.r-project.org' site <- 'http://cran.r-project.org'
if (!require('dummies')) if (!require('dummies')) {
install.packages('dummies', repos = site) install.packages('dummies', repos = site)
if (!require('insuranceData')) }
if (!require('insuranceData')) {
install.packages('insuranceData', repos = site) install.packages('insuranceData', repos = site)
}
library(dummies) library(dummies)
library(insuranceData) library(insuranceData)
@ -14,5 +16,16 @@ data$STATE <- as.factor(data$STATE)
data$CLASS <- as.factor(data$CLASS) data$CLASS <- as.factor(data$CLASS)
data$GENDER <- as.factor(data$GENDER) data$GENDER <- as.factor(data$GENDER)
data.dummy <- dummy.data.frame(data, dummy.class='factor', omit.constants=TRUE); data.dummy <- dummy.data.frame(
write.table(data.dummy, 'autoclaims.csv', sep=',', row.names=F, col.names=F, quote=F) data
, dummy.class = 'factor'
, omit.constants = TRUE
)
write.table(
data.dummy
, 'autoclaims.csv'
, sep = ','
, row.names = FALSE
, col.names = FALSE
, quote = FALSE
)

View File

@ -25,7 +25,7 @@ param <- list("objective" = "binary:logitraw",
watchlist <- list("train" = xgmat) watchlist <- list("train" = xgmat)
nrounds <- 120 nrounds <- 120
print ("loading data end, start to boost trees") print ("loading data end, start to boost trees")
bst <- xgb.train(param, xgmat, nrounds, watchlist ); bst <- xgb.train(param, xgmat, nrounds, watchlist)
# save out model # save out model
xgb.save(bst, "higgs.model") xgb.save(bst, "higgs.model")
print ('finish training') print ('finish training')

View File

@ -40,7 +40,7 @@ for (i in 1:length(threads)){
watchlist <- list("train" = xgmat) watchlist <- list("train" = xgmat)
nrounds <- 120 nrounds <- 120
print ("loading data end, start to boost trees") print ("loading data end, start to boost trees")
bst <- xgb.train(param, xgmat, nrounds, watchlist ); bst <- xgb.train(param, xgmat, nrounds, watchlist)
# save out model # save out model
xgb.save(bst, "higgs.model") xgb.save(bst, "higgs.model")
print ('finish training') print ('finish training')
@ -67,4 +67,3 @@ xgboost.time
# [[5]] # [[5]]
# user system elapsed # user system elapsed
# 157.390 5.988 40.949 # 157.390 5.988 40.949

View File

@ -24,8 +24,13 @@ param <- list("objective" = "multi:softprob",
# Run Cross Validation # Run Cross Validation
cv.nrounds <- 50 cv.nrounds <- 50
bst.cv <- xgb.cv(param=param, data = x[trind,], label = y, bst.cv <- xgb.cv(
nfold = 3, nrounds=cv.nrounds) param = param
, data = x[trind, ]
, label = y
, nfold = 3
, nrounds = cv.nrounds
)
# Train the model # Train the model
nrounds <- 50 nrounds <- 50
@ -37,7 +42,7 @@ pred <- matrix(pred,9,length(pred)/9)
pred <- t(pred) pred <- t(pred)
# Output submission # Output submission
pred <- format(pred, digits=2,scientific=F) # shrink the size of submission pred <- format(pred, digits = 2, scientific = FALSE) # shrink the size of submission
pred <- data.frame(1:nrow(pred), pred) pred <- data.frame(1:nrow(pred), pred)
names(pred) <- c('id', paste0('Class_', 1:9)) names(pred) <- c('id', paste0('Class_', 1:9))
write.csv(pred, file = 'submission.csv', quote = FALSE, row.names = FALSE) write.csv(pred, file = 'submission.csv', quote = FALSE, row.names = FALSE)