Nothing
#' Cross validation check for spline in time, spline in space time and GAM in
#' order to select the most appropriate number of knots when creating basis
#' functions.
#'
#' @param data Raw input data.
#' @param prediction_grid_res Resolution of the prediction grid. The default predicts every 50 years; larger values reduce computational run time.
#' @param spline_nseg Number of segments used to create basis functions for the Noisy Input Spline models.
#' @param spline_nseg_t Number of segments used to create basis functions for the time component of the Noisy Input Generalised Additive Model.
#' @param spline_nseg_st Number of segments used to create basis functions for the space-time component of the Noisy Input Generalised Additive Model.
#' @param model_type The user selects their statistical model type. The user can select a Noisy Input Spline in Time using "ni_spline_t". The user can select a Noisy Input Spline in Space Time using "ni_spline_st". The user can select a Noisy Input Generalised Additive Model using "ni_gam_decomp".
#' @param n_iterations Number of MCMC iterations. Increasing this value will increase the computational run time.
#' @param n_burnin Size of burn-in. This number removes a certain number of samples at the beginning of each chain.
#' @param n_thin Amount of thinning.
#' @param n_chains Number of MCMC chains. The number of times the model will be run.
#' @param n_fold Number of folds required in the cross validation. The default is 5 fold cross validation.
#' @param seed Optional integer giving reproducible fold assignment; if NULL the fold assignment is random on each call.
#' @param CI Size of the credible interval required by the user. The default is 0.95 corresponding to 95%.
#'
#' @return A list containing the model comparison measures, e.g. Root Mean Square Error (RMSE), empirical coverage, prediction interval widths, and a plot of true vs predicted values.
#' @export
#'
#' @examples
#' \donttest{
#' data <- NAACproxydata %>% dplyr::filter(Site == "Cedar Island")
#' cross_val_check(data = data, model_type = "ni_spline_t", n_fold = 2)
#' }
cross_val_check <- function(data,
                            prediction_grid_res = 50,
                            spline_nseg = NULL,
                            spline_nseg_t = 20,
                            spline_nseg_st = 6,
                            n_iterations = 1000,
                            n_burnin = 100,
                            n_thin = 5,
                            n_chains = 2,
                            model_type,
                            n_fold = 5,
                            seed = NULL,
                            CI = 0.95) {
  # Cross Validation tests-----------------
  # Silence R CMD check notes for non-standard-evaluation column names
  CV_fold_number <- RSL <- Region <- Site <- SiteName <- lwr_PI <- obs_in_PI <- pred <- true_RSL <- upr_PI <- xl <- xr <- y_post_pred <- NULL
  # Reproducible fold assignment only when the user supplies a seed
  if (!is.null(seed)) {
    base::set.seed(seed)
  }
  # Input data: one factor level per site/region combination
  data <- data %>%
    dplyr::mutate(SiteName = as.factor(paste0(Site, ",", "\n", " ", Region)))
  # Each site needs at least n_fold observations to fill every fold
  nrow_by_site <- data %>%
    dplyr::group_by(SiteName) %>%
    dplyr::reframe(nrow_site = dplyr::n())
  if (any(nrow_by_site$nrow_site < n_fold)) {
    stop("Not enough observations in each site for the number of folds required.")
  }
  # Assign every observation to a fold, stratified by site
  df_split_index <- kfold_fun(data$Age,
    k = n_fold,
    by = data$SiteName
  )
  data$CV_fold <- df_split_index

  # Shared error summaries: RMSE, MAE and mean error. Fixes the original
  # formula which squared the sum of residuals instead of summing squared
  # residuals. The column is kept as "RSME" for backward compatibility.
  error_summary <- function(df) {
    df %>%
      dplyr::reframe(
        RSME = sqrt(sum((RSL - pred)^2) / dplyr::n()),
        MAE = sum(abs(RSL - pred)) / dplyr::n(),
        ME = mean(RSL - pred)
      )
  }

  # Fit the model once per fold, holding that fold out as the test set
  model_run_list <- vector("list", n_fold)
  for (i in seq_len(n_fold)) {
    # Row indices belonging to the held-out fold
    CV_fold <- base::which(df_split_index == i, arr.ind = TRUE)
    if (model_type == "ni_gam_decomp") {
      # Remove SiteName column as it causes errors in reslr_load
      data_update <- data %>% dplyr::select(!SiteName)
      # Run reslr_load to get GIA rates
      data_GIA_df <- reslr::reslr_load(data_update,
        prediction_grid_res = prediction_grid_res,
        include_linear_rate = TRUE
      )
      # NOTE(review): `data` is overwritten on every iteration here; this
      # assumes reslr_load preserves row order so the fold indices computed
      # from df_split_index still line up — confirm against reslr_load.
      data <- data_GIA_df$data
      # Remove SiteName column as it causes errors in next step
      data <- data %>% dplyr::select(!SiteName)
      # Test and training split
      test_set <- data[CV_fold, ]
      training_set <- data[-CV_fold, ]
      # reslr_load on the training set, carrying the test set alongside
      input_train <- reslr::reslr_load(training_set,
        prediction_grid_res = prediction_grid_res,
        cross_val = TRUE,
        test_set = test_set,
        include_linear_rate = TRUE
      )
      # reslr_mcmc: the GAM uses separate time and space-time knot settings
      train_output <- reslr::reslr_mcmc(input_train,
        model_type = model_type,
        spline_nseg_t = spline_nseg_t,
        spline_nseg_st = spline_nseg_st,
        n_iterations = n_iterations,
        n_burnin = n_burnin,
        n_thin = n_thin,
        n_chains = n_chains,
        CI = CI
      )
    } else {
      # Spline models: split directly; SiteName is dropped from training only
      test_set <- data[CV_fold, ]
      training_set <- data[-CV_fold, ] %>% dplyr::select(-SiteName)
      # reslr_load
      input_train <- reslr::reslr_load(training_set,
        prediction_grid_res = prediction_grid_res,
        cross_val = TRUE,
        test_set = test_set
      )
      # reslr_mcmc: splines share a single knot setting
      train_output <- reslr::reslr_mcmc(input_train,
        model_type = model_type,
        spline_nseg = spline_nseg,
        n_iterations = n_iterations,
        n_burnin = n_burnin,
        n_thin = n_thin,
        n_chains = n_chains,
        CI = CI
      )
    }
    # Print convergence diagnostics for this fold's model
    summary(train_output)
    # Take out the dataframe with true & predicted values
    if (model_type == "ni_gam_decomp") {
      output_df <- train_output$output_dataframes$total_model_df
    } else {
      output_df <- train_output$output_dataframes
    }
    # Column to identify the fold number in the loop
    output_df$CV_fold_number <- as.character(i)
    model_run_list[[i]] <- output_df
  }
  # Combining all the fold dataframes (bind_rows warns on factor mismatch)
  CV_model_run_df <- suppressWarnings(
    dplyr::bind_rows(model_run_list)
  )
  # Keep only the held-out (test set) rows, flagged by a non-NA CV_fold
  CV_model_df <- CV_model_run_df %>%
    dplyr::filter(!is.na(CV_fold))
  # Mean Error, Mean Absolute Error & Root Mean Square Error for each fold
  ME_MAE_RSME_fold <- CV_model_df %>%
    dplyr::mutate(CV_fold_number = as.factor(CV_fold_number)) %>%
    dplyr::group_by(CV_fold_number) %>%
    error_summary()
  # ...overall
  ME_MAE_RSME_overall <- CV_model_df %>%
    error_summary()
  # ...for each fold & site
  ME_MAE_RSME_fold_site <- CV_model_df %>%
    dplyr::mutate(CV_fold_number = as.factor(CV_fold_number)) %>%
    dplyr::group_by(SiteName, CV_fold_number) %>%
    error_summary()
  # ...for each site
  ME_MAE_RSME_site <- CV_model_df %>%
    dplyr::group_by(SiteName) %>%
    error_summary()
  # Rename for the true-vs-predicted outputs
  CV_model_df <- CV_model_df %>%
    dplyr::rename(
      true_RSL = RSL,
      pred_RSL = pred
    )
  # Empirical coverage: is the true value inside the prediction interval?
  # dplyr::between requires the lower bound first — the original passed
  # (upr_PI, lwr_PI), which made obs_in_PI always FALSE.
  CV_model_df <- CV_model_df %>%
    dplyr::mutate(
      obs_in_PI = dplyr::between(true_RSL, lwr_PI, upr_PI)
    )
  # Total coverage is the proportion of observations inside their PI
  total_coverage <- sum(CV_model_df$obs_in_PI, na.rm = TRUE) / nrow(CV_model_df)
  # Coverage by site
  coverage_by_site <- CV_model_df %>%
    dplyr::group_by(SiteName) %>%
    dplyr::reframe(
      coverage_by_site = sum(obs_in_PI, na.rm = TRUE) / dplyr::n()
    )
  # Average prediction interval width per site
  prediction_interval_size <- CV_model_df %>%
    dplyr::group_by(SiteName) %>%
    dplyr::reframe(PI_width = mean(upr_PI - lwr_PI))
  # True vs Predicted plot, one panel per site
  true_pred_plot <- ggplot2::ggplot(data = CV_model_df, ggplot2::aes(
    x = true_RSL,
    y = y_post_pred,
    colour = "PI"
  )) +
    ggplot2::geom_errorbar(
      data = CV_model_df,
      ggplot2::aes(
        x = true_RSL,
        ymin = lwr_PI,
        ymax = upr_PI
      ),
      colour = "red3",
      width = 0, alpha = 0.5
    ) +
    ggplot2::geom_point() +
    ggplot2::geom_abline(
      data = CV_model_df,
      ggplot2::aes(intercept = 0, slope = 1, colour = "True = Predicted")
    ) +
    ggplot2::theme_bw() +
    ggplot2::theme(
      axis.title = ggplot2::element_text(size = 9, face = "bold"),
      axis.text = ggplot2::element_text(size = 9),
      strip.background = ggplot2::element_rect(fill = c("white")),
      strip.text = ggplot2::element_text(size = 10),
      legend.text = ggplot2::element_text(size = 7),
      legend.title = ggplot2::element_blank(),
      axis.text.x = ggplot2::element_text(size = 8),
      axis.text.y = ggplot2::element_text(size = 8)
    ) +
    ggplot2::theme(legend.box = "horizontal", legend.position = "bottom") +
    ggplot2::labs(
      x = "True Relative Sea Level (m)",
      y = "Predicted Relative Sea Level (m)"
    ) +
    ggplot2::scale_colour_manual("",
      values = c(
        "PI" = "red3",
        "True = Predicted" = "black"
      ),
      labels = c(
        "PI" = paste0(unique(CV_model_df$CI), " Prediction Interval"),
        "True = Predicted" = "True = Predicted"
      )
    ) +
    ggplot2::facet_wrap(~SiteName, scales = "free") +
    ggplot2::guides(
      colour = ggplot2::guide_legend(override.aes = list(
        linetype = c(1, 1),
        shape = c(NA, NA),
        size = 2
      ))
    )
  # Return a list of CV tests
  cross_validation_tests <- list(
    ME_MAE_RSME_fold_site = ME_MAE_RSME_fold_site,
    ME_MAE_RSME_site = ME_MAE_RSME_site,
    ME_MAE_RSME_overall = ME_MAE_RSME_overall,
    ME_MAE_RSME_fold = ME_MAE_RSME_fold,
    true_pred_plot = true_pred_plot,
    CV_model_df = CV_model_df,
    total_coverage = total_coverage,
    prediction_interval_size = prediction_interval_size,
    coverage_by_site = coverage_by_site
  )
  return(cross_validation_tests)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.