# R/utils.R

################################################################################
#            Function to check if arguments to auto_MrP() are valid            #
################################################################################

#' Catches user input errors
#'
#' \code{error_checks()} checks for incorrect data entry in the
#' \code{autoMrP()} call.
#'
#' @inheritParams auto_MrP
#' @return No return value; called to detect errors in the \code{autoMrP()}
#'   call.

error_checks <- function(y, L1.x, L2.x, L2.unit, L2.reg, L2.x.scale, pcs,
                         folds, bin.proportion, bin.size, survey, census,
                         ebma.size, k.folds, cv.sampling, loss.unit, loss.fun,
                         best.subset, lasso, pca, gb, svm, mrp, forward.select,
                         best.subset.L2.x, lasso.L2.x, gb.L2.x, svm.L2.x,
                         mrp.L2.x, gb.L2.unit, gb.L2.reg, lasso.lambda,
                         lasso.n.iter, uncertainty, boot.iter) {


  # Check if y is a character scalar
  if (!(is.character(y) & length(y) == 1)) {
    stop(paste("The argument 'y', specifying the outcome variable, must be a",
               " character scalar.", sep = ""))
  }

  # Check if y is in survey data
  if (!(y %in% colnames(survey))) {
    stop(paste("Outcome '", y,
               "' is not in your survey data.", sep = ""))
  }

  # Check if L1.x is a character vector
  if (!is.character(L1.x)) {
    stop(paste("The argument 'L1.x', specifying the individual-level variables",
               " to be used to predict y, must be a character vector.",
               sep = ""))
  }

  # Check if L1.x is in survey data
  if (!all(L1.x %in% colnames(survey))) {
    stop(paste0("Individual-level variable '",
                paste(L1.x[!(L1.x %in% colnames(survey))], collapse = "', '"),
                "', specified in argument 'L1.x', is not in your survey",
                " data."))
  }

  # Check if L1.x is in census data
  if (!all(L1.x %in% colnames(census))) {
    stop(paste0("Individual-level variable '",
                paste(L1.x[!(L1.x %in% colnames(census))], collapse = "', '"),
                "', specified in argument 'L1.x', is not in your census",
                " data."))
  }

  # Check if L2.x is a character vector
  if (!is.character(L2.x) & !is.null(L2.x)) {
    stop(paste("The argument 'L2.x', specifying the context-level variables to",
               " be used to predict y, must be a character vector.", sep = ""))
  }

  # Check if L2.x is in survey data (unless it is empty)
  if (all(L2.x != "")) {
    if (!all(L2.x %in% colnames(survey))) {
      stop(paste0("Context-level variable '",
                  paste(L2.x[!(L2.x %in% colnames(survey))], collapse = "', '"),
                  "', specified in argument 'L2.x', is not in your survey",
                  " data."))
    }

    # Check if L2.x is in census data
    if (!all(L2.x %in% colnames(census))) {
      stop(paste0("Context-level variable '",
                  paste(L2.x[!(L2.x %in% colnames(census))], collapse = "', '"),
                  "', specified in argument 'L2.x', is not in your census",
                  " data."))
    }
  }

  # Check if L2.unit is a character scalar
  if (!(is.character(L2.unit) & length(L2.unit) == 1)) {
    stop(paste("The argument 'L2.unit', specifying the geographic unit,",
               " must be a character scalar.", sep = ""))
  }

  # Check if L2.unit is in survey data
  if (!(L2.unit %in% colnames(survey))) {
    stop(paste("The geographic unit '", L2.unit,
               "' is not in your survey data.", sep = ""))
  }

  # Check if L2.unit is in census data
  if (!(L2.unit %in% colnames(census))) {
    stop(paste("The geographic unit '", L2.unit,
               "' is not in your census data.", sep = ""))
  }

  # Check if L2.reg is NULL
  if (!is.null(L2.reg)) {
    # Check if L2.reg is a character scalar
    if (!(is.character(L2.reg) & length(L2.reg) == 1)) {
      stop(paste("The argument 'L2.reg', specifying the geographic region,",
                 " must be a character scalar.", sep = ""))
    }

    # Check if L2.reg is in survey data
    if (!(L2.reg %in% colnames(survey))) {
      stop(paste("The geographic region '", L2.reg,
                 "' is not in your survey data.", sep = ""))
    }

    # Check if L2.reg is in census data
    if (!(L2.reg %in% colnames(census))) {
      stop(paste("The geographic region '", L2.reg,
                 "' is not in your census data.", sep = ""))
    }

    # Check if each geographic unit is nested in only one geographic region
    # in survey data
    nested_survey <- survey %>%
      dplyr::distinct(.data[[L2.unit]], .data[[L2.reg]]) %>%
      dplyr::count(.data[[L2.unit]]) %>%
      dplyr::filter(.data$n > 1)
    if (nrow(nested_survey) > 0) {
      stop(paste0("The geographic unit '",
                  paste(dplyr::pull(nested_survey, 1), collapse = "', '"),
                  "' is nested in multiple regions in your survey data."))
    }

    # Check if each geographic unit is nested in only one geographic region
    # in census data
    nested_census <- census %>%
      dplyr::distinct(.data[[L2.unit]], .data[[L2.reg]]) %>%
      dplyr::count(.data[[L2.unit]]) %>%
      dplyr::filter(.data$n > 1)
    if (nrow(nested_census) > 0) {
      stop(paste0("The geographic unit '",
                  paste(dplyr::pull(nested_census, 1), collapse = "', '"),
                  "' is nested in multiple regions in your census data."))
    }
  }

  # Check if L2.x.scale is logical
  if (!is.logical(L2.x.scale)) {
    stop(paste("The logical argument 'L2.x.scale' must be either TRUE or",
               " FALSE.", sep = ""))
  }

  # Check if folds is NULL
  if (!is.null(folds)) {
    # Check if folds is a character scalar
    if (!(is.character(folds) & length(folds) == 1)) {
      stop(paste("The argument 'folds', specifying the fold to which each",
                 " observation is to be allocated, must be a character scalar.",
                 sep = ""))
    }

    # Check if folds is in survey data
    if (!(folds %in% colnames(survey))) {
      stop(paste("Fold variable '", folds,
                 "' is not in your survey data.", sep = ""))
    }

    # Check if folds contains a sequence of integer numbers
    folds_var <- survey %>%
      dplyr::select_at(.vars = folds) %>%
      dplyr::pull() %>%
      unique() %>%
      sort()

    if (isFALSE(all(dplyr::near(folds_var, as.integer(folds_var))))) {
      stop(paste("Fold variable '", folds,
                 "' must contain integer numbers only.", sep = ""))
    }

    if (!identical(as.integer(folds_var), 1:max(folds_var))) {
      stop(paste("Fold variable '", folds,
                 "' must contain a sequence of integers running from 1 to ",
                 max(folds_var), ".", sep = ""))
    }
  } else {
    # Check if ebma.size is NULL
    if (is.null(ebma.size)) {
      stop(paste("If argument 'folds' is NULL, then argument 'ebma.size' must",
                 " be specified.", sep = ""))
    } else {
      # Check if ebma.size is a proportion in [0, 1)
      if (!(is.numeric(ebma.size) & ebma.size >= 0 & ebma.size < 1)) {
        stop(paste0("The argument 'ebma.size', specifying the share of",
                    " respondents to be allocated to the EBMA fold, must be",
                    " a number in the half-open interval [0, 1)."))
      }
      # Check if ebma.size is 0
      if (ebma.size == 0){
        # Check that at most one classifier is turned on
        if (sum(isTRUE(best.subset), isTRUE(lasso), isTRUE(pca), isTRUE(gb),
                isTRUE(svm), isTRUE(mrp)) > 1) {
          stop(paste0("If ebma.size = 0, then only 1 classifier out of",
                      " best.subset, lasso, pca, gb, svm, and mrp can be set",
                      " to TRUE and all other classifiers must be set to",
                      " FALSE. Try setting these explicitly."))
        }
      }
    }

    # Check if k.folds is NULL
    if (is.null(k.folds)) {
      stop(paste("If argument 'folds' is NULL, then argument 'k.folds' must",
                 " be specified.", sep = ""))
    } else {
      # Check if k.folds is an integer-valued scalar
      if (!(dplyr::near(k.folds, as.integer(k.folds)) &
            length(k.folds) == 1)) {
        stop(paste("The argument 'k.folds', specifying the number of folds to",
                   " be used in cross-validation, must be an integer-valued",
                   " scalar.",
                   sep = ""))
      } else {
        # Check if k.folds is less than or equal to the number of survey
        # respondents
        if (k.folds > nrow(survey)) {
          stop(paste("The argument 'k.folds', specifying the number of folds",
                     " to be used in cross-validation, cannot be larger than",
                     " the number of survey respondents, ", nrow(survey), ".",
                     sep = ""))
        }
      }
    }

    # Check if cv.sampling is NULL
    if (is.null(cv.sampling)) {
      stop(paste("If argument 'folds' is NULL, then argument 'cv.sampling'",
                 " must be specified.", sep = ""))
    } else {
      # Check if cv.sampling is either "individuals" or "L2 units"
      if (!cv.sampling %in% c("individuals", "L2 units")) {
        stop(paste("The argument 'cv.sampling', specifying the sampling method",
                   " to be used for cross-validation, must be either",
                   " 'individuals' or 'L2 units'.", sep = ""))
      }
    }
  }

  # Check if bin.proportion is NULL
  if (!is.null(bin.proportion)) {
    # Check if bin.proportion is a character scalar
    if (!(is.character(bin.proportion) & length(bin.proportion) == 1)) {
      stop(paste("The argument 'bin.proportion', specifying the variable that",
                 " indicates the proportion of ideal types in the census data,",
                 " must be a character scalar.", sep = ""))
    }

    # Check if bin.proportion is in census data
    if (!(bin.proportion %in% colnames(census))) {
      stop(paste("Variable '", bin.proportion,
                 "', indicating the proportion of ideal types, is not in your",
                 " census data.", sep = ""))
    }

    # Check if bin.proportion is a proportion
    bin_proportion_var <- census %>%
      dplyr::select_at(.vars = bin.proportion) %>%
      dplyr::pull() %>%
      unique()

    if (!(is.numeric(bin_proportion_var) &
          min(bin_proportion_var) >= 0 &
          max(bin_proportion_var) <= 1)) {
      stop(paste("Variable '", bin.proportion,
                 "', indicating the proportion of ideal types, can only take",
                 " values lying in the unit interval.", sep = ""))
    }
  } else {
    # Check if bin.size is NULL
    if (!is.null(bin.size)) {
      # Check if bin.size is a character scalar
      if (!(is.character(bin.size) & length(bin.size) == 1)) {
        stop(paste("The argument 'bin.size', specifying the variable that",
                   " indicates the bin size of ideal types in the census data,",
                   " must be a character scalar.", sep = ""))
      }

      # Check if bin.size is in census data
      if (!(bin.size %in% colnames(census))) {
        stop(paste("Variable '", bin.size,
                   "', indicating the bin size of ideal types, is not in your",
                   " census data.", sep = ""))
      }

      # Check if bin.size contains only non-negative numbers
      bin_size_var <- census %>%
        dplyr::select_at(.vars = bin.size) %>%
        dplyr::pull() %>%
        unique()

      if (!is.numeric(bin_size_var)) {
        stop(paste("Variable '", bin.size,
                   "', indicating the bin size of ideal types, must be numeric.",
                   sep = ""))
      }

      if (min(bin_size_var) < 0) {
        stop(paste("Variable '", bin.size,
                   "', indicating the bin size of ideal types, can only take",
                   " non-negative values.", sep = ""))
      }
    } else {
      stop(paste("Either argument 'bin.proportion' or argment 'bin.size' must",
                 " be specified to perform post-stratification.", sep = ""))
    }
  }

  # Check if survey data is provided as a data.frame
  if (is.null(survey)) {
    stop(paste("Argument 'survey' cannot be NULL. Please provide survey data.",
               sep = ""))
  } else {
    if (!is.data.frame(survey)) {
      stop(paste("The argument 'survey', specifying the survey data,",
                 " must be a data.frame.", sep = ""))
    }
  }

  # Check if census data is provided as a data.frame
  if (is.null(census)) {
    stop(paste("Argument 'census' cannot be NULL. Please provide census data.",
               sep = ""))
  } else {
    if (!is.data.frame(census)) {
      stop(paste("The argument 'census', specifying the census data,",
                 " must be a data.frame.", sep = ""))
    }
  }

  # Check if loss.unit is either "individuals" or "L2 units"
  if (!all(loss.unit %in% c("individuals", "L2 units"))) {
    stop(paste("The argument 'loss.unit', specifying the level at which to",
               " evaluate prediction performance, must be either",
               " 'individuals' or 'L2 units'.", sep = ""))
  }

  # Check if loss.fun is one of "MSE", "MAE", "cross-entropy", "f1", or "msfe"
  if (!all(loss.fun %in% c("MSE", "MAE", "cross-entropy", "f1", "msfe"))) {
    stop(paste("The argument 'loss.fun', specifying the loss function used",
               " to measure prediction performance, must be either",
               " 'MSE', 'MAE', 'cross-entropy', 'f1', or 'msfe'.", sep = ""))
  }

  # Check if best.subset is logical
  if (is.logical(best.subset)) {
    # Check if best.subset is TRUE
    if (isTRUE(best.subset)) {
      # Check if best.subset.L2.x is NULL
      if (!is.null(best.subset.L2.x)) {
        # Check if best.subset.L2.x is a character vector
        if (!is.character(best.subset.L2.x)) {
          stop(paste("The argument 'best.subset.L2.x', specifying the context-level",
                     " variables to be used by the best subset classifier, must be",
                     " a character vector.", sep = ""))
        }

        # Check if best.subset.L2.x is in survey data
        if (!all(best.subset.L2.x %in% colnames(survey))) {
          stop(paste0("Context-level variable '",
                      paste(best.subset.L2.x[!(best.subset.L2.x %in% colnames(survey))],
                            collapse = "', '"),
                      "', specified in argument 'best.subset.L2.x' to be used",
                      " by the best subset classifier, is not in your survey",
                      " data."))
        }

        # Check if best.subset.L2.x is in census data
        if (!all(best.subset.L2.x %in% colnames(census))) {
          stop(paste0("Context-level variable '",
                      paste(best.subset.L2.x[!(best.subset.L2.x %in% colnames(census))],
                            collapse = "', '"),
                      "', specified in argument 'best.subset.L2.x' to be used",
                      " by the best subset classifier, is not in your census",
                      " data."))
        }
      }
    } else {
      # Check if best.subset.L2.x is NULL
      if (!is.null(best.subset.L2.x)) {
        warning(paste("The argument 'best.subset.L2.x', specifying the context-level",
                      " variables to be used by the best subset classifier, will be",
                      " ignored because 'best.subset' is set to FALSE.", sep = ""))
      }
    }
  } else {
    stop(paste("The logical argument 'best.subset', indicating whether the",
               " best subset classifier is to be used for predicting y,",
               " must be either TRUE or FALSE.", sep = ""))
  }

  # Check if lasso is logical
  if (is.logical(lasso)) {
    # Check if lasso is TRUE
    if (isTRUE(lasso)) {
      # Check if lasso.L2.x is NULL
      if (!is.null(lasso.L2.x)) {
        # Check if lasso.L2.x is a character vector
        if (!is.character(lasso.L2.x)) {
          stop(paste("The argument 'lasso.L2.x', specifying the context-level",
                     " variables to be used by the lasso classifier, must be",
                     " a character vector.", sep = ""))
        }

        # Check if lasso.L2.x is in survey data
        if (!all(lasso.L2.x %in% colnames(survey))) {
          stop(paste0("Context-level variable '",
                      paste(lasso.L2.x[!(lasso.L2.x %in% colnames(survey))],
                            collapse = "', '"),
                      "', specified in argument 'lasso.L2.x' to be used by",
                      " the lasso classifier, is not in your survey data."))
        }

        # Check if lasso.L2.x is in census data
        if (!all(lasso.L2.x %in% colnames(census))) {
          stop(paste0("Context-level variable '",
                      paste(lasso.L2.x[!(lasso.L2.x %in% colnames(census))],
                            collapse = "', '"),
                      "', specified in argument 'lasso.L2.x' to be used by",
                      " the lasso classifier, is not in your census data."))
        }
      }

      # Check if lasso.lambda is provided but not a numeric vector
      if (!is.null(lasso.lambda)) {
        if (!is.numeric(lasso.lambda)) {
          stop("lasso.lambda must be 'NULL' or a non-negative numeric vector.")
        } else {
          # Check that lasso.lambda contains no negative values
          if (any(lasso.lambda < 0)) {
            stop("lasso.lambda must not contain negative values.")
          }
        }
      }

      # Check if lasso.n.iter is NULL
      if (!is.null(lasso.n.iter)) {
        if (!(dplyr::near(lasso.n.iter, as.integer(lasso.n.iter)) &
              length(lasso.n.iter) == 1)) {
          stop("lasso.n.iter specifies the Lasso grid size. It must be a non-negative integer valued scalar.")
        }
      }
    } else {
      # Check if lasso.L2.x is NULL
      if (!is.null(lasso.L2.x)) {
        warning(paste("The argument 'lasso.L2.x', specifying the context-level",
                      " variables to be used by the lasso classifier, will be",
                      " ignored because 'lasso' is set to FALSE.", sep = ""))
      }
    }
  } else {
    stop(paste("The logical argument 'lasso', indicating whether the lasso",
               " classifier is to be used for predicting y,",
               " must be either TRUE or FALSE.", sep = ""))
  }

  # Check if pca is logical
  if (is.logical(pca)) {
    # Check if pca is TRUE
    if (isTRUE(pca)) {
      # Check if pcs is NULL
      if (!is.null(pcs)) {
        # Check if pcs is a character vector
        if (!is.character(pcs)) {
          stop(paste("The argument 'pcs', specifying the principal components of",
                     " the context-level variables, must be a character vector.",
                     sep = ""))
        }

        # Check if pcs is in survey data
        if (!all(pcs %in% colnames(survey))) {
          stop(paste0("Principal component '",
                      paste(pcs[!(pcs %in% colnames(survey))], collapse = "', '"),
                      "', specified in argument 'pcs', is not in your survey",
                      " data."))
        }

        # Check if pcs is in census data
        if (!all(pcs %in% colnames(census))) {
          stop(paste0("Principal component '",
                      paste(pcs[!(pcs %in% colnames(census))], collapse = "', '"),
                      "', specified in argument 'pcs', is not in your census",
                      " data."))
        }
      } else {
        # pcs not specified: check if column names contain "PC" followed by
        # at least one digit
        if (any(grepl(pattern = "PC[0-9]+", x = names(survey)))) {
          stop(paste0("Survey contains the column names: ",
                      paste(names(survey)[grepl(pattern = "PC[0-9]+", x = names(survey))],
                            collapse = ", "),
                      ". These must be specified in the argument 'pcs' or",
                      " removed from or renamed in the survey data."))
        }
        if (any(grepl(pattern = "PC[0-9]+", x = names(census)))) {
          stop(paste0("Census contains the column names: ",
                      paste(names(census)[grepl(pattern = "PC[0-9]+", x = names(census))],
                            collapse = ", "),
                      ". These must be specified in the argument 'pcs' or",
                      " removed from or renamed in the census data."))
        }
      }
    } else {
      # Check if pcs is NULL
      if (!is.null(pcs)) {
        warning(paste("The argument 'pcs', specifying the principal components",
                      " of the context-level variables, will be ignored because",
                      " 'pca' is set to FALSE.", sep = ""))
      }
    }
  } else {
    stop(paste("The logical argument 'pca', indicating whether the PCA",
               " classifier is to be used for predicting y,",
               " must be either TRUE or FALSE.", sep = ""))
  }

  # Check if gb is logical
  if (is.logical(gb)) {
    # Check if gb is TRUE
    if (isTRUE(gb)) {
      # Check if gb.L2.x is NULL
      if (!is.null(gb.L2.x)) {
        # Check if gb.L2.x is a character vector
        if (!is.character(gb.L2.x)) {
          stop(paste("The argument 'gb.L2.x', specifying the context-level",
                     " variables to be used by the GB classifier, must be",
                     " a character vector.", sep = ""))
        }

        # Check if gb.L2.x is in survey data
        if (!all(gb.L2.x %in% colnames(survey))) {
          stop(paste0("Context-level variable '",
                      paste(gb.L2.x[!(gb.L2.x %in% colnames(survey))],
                            collapse = "', '"),
                      "', specified in argument 'gb.L2.x' to be used by the",
                      " GB classifier, is not in your survey data."))
        }

        # Check if gb.L2.x is in census data
        if (!all(gb.L2.x %in% colnames(census))) {
          stop(paste0("Context-level variable '",
                      paste(gb.L2.x[!(gb.L2.x %in% colnames(census))],
                            collapse = "', '"),
                      "', specified in argument 'gb.L2.x' to be used by the",
                      " GB classifier, is not in your census data."))
        }
      }

      # Check if gb.L2.unit is logical
      if (!is.logical(gb.L2.unit)) {
        stop(paste0("The logical argument 'gb.L2.unit', indicating whether",
                    " 'L2.unit' should be included in the GB classifier,",
                    " must be either TRUE or FALSE."))
      }

      # Check if gb.L2.reg is logical
      if (!is.logical(gb.L2.reg)) {
        stop(paste0("The logical argument 'gb.L2.reg', indicating whether",
                    " 'L2.reg' should be included in the GB classifier,",
                    " must be either TRUE or FALSE."))
      }
    } else {
      # Check if gb.L2.x is NULL
      if (!is.null(gb.L2.x)) {
        warning(paste("The argument 'gb.L2.x', specifying the context-level",
                      " variables to be used by the GB classifier, will be",
                      " ignored because 'gb' is set to FALSE.", sep = ""))
      }

      # Check if gb.L2.unit has a value other than the default
      # if (!isFALSE(gb.L2.unit)) {
      #   stop(paste("The argument 'gb.L2.unit', indicating whether 'L2.unit'",
      #              " should be included in the GB classifier, will be",
      #              " ignored because 'gb' is set to FALSE.", sep = ""))
      # }

      # Check if gb.L2.reg has a value other than the default
      if (!isFALSE(gb.L2.reg)) {
        warning(paste0("The argument 'gb.L2.reg', indicating whether 'L2.reg'",
                       " should be included in the GB classifier, will be",
                       " ignored because 'gb' is set to FALSE."))
      }
    }
  } else {
    stop(paste("The logical argument 'gb', indicating whether the GB",
               " classifier is to be used for predicting y,",
               " must be either TRUE or FALSE.", sep = ""))
  }

  # Check if svm is logical
  if (is.logical(svm)) {
    # Check if svm is TRUE
    if (isTRUE(svm)) {
      # Check if svm.L2.x is NULL
      if (!is.null(svm.L2.x)) {
        # Check if svm.L2.x is a character vector
        if (!is.character(svm.L2.x)) {
          stop(paste("The argument 'svm.L2.x', specifying the context-level",
                     " variables to be used by the SVM classifier, must be",
                     " a character vector.", sep = ""))
        }

        # Check if svm.L2.x is in survey data
        if (!all(svm.L2.x %in% colnames(survey))) {
          stop(paste0("Context-level variable '",
                      paste(svm.L2.x[!(svm.L2.x %in% colnames(survey))],
                            collapse = "', '"),
                      "', specified in argument 'svm.L2.x' to be used by the",
                      " SVM classifier, is not in your survey data."))
        }

        # Check if svm.L2.x is in census data
        if (!all(svm.L2.x %in% colnames(census))) {
          stop(paste0("Context-level variable '",
                      paste(svm.L2.x[!(svm.L2.x %in% colnames(census))],
                            collapse = "', '"),
                      "', specified in argument 'svm.L2.x' to be used by the",
                      " SVM classifier, is not in your census data."))
        }
      }
    } else {
      # Check if svm.L2.x is NULL
      if (!is.null(svm.L2.x)) {
        warning(paste("The argument 'svm.L2.x', specifying the context-level",
                      " variables to be used by the SVM classifier, will be",
                      " ignored because 'svm' is set to FALSE.", sep = ""))
      }
    }
  } else {
    stop(paste("The logical argument 'svm', indicating whether the SVM",
               " classifier is to be used for predicting y,",
               " must be either TRUE or FALSE.", sep = ""))
  }

  # Check if mrp is logical
  if (is.logical(mrp)) {
    # Check if mrp is TRUE
    if (isTRUE(mrp)) {
      # Check if mrp.L2.x is NULL
      if (!is.null(mrp.L2.x)) {
        # Check if mrp.L2.x is a character vector
        if (!is.character(mrp.L2.x)) {
          stop(paste("The argument 'mrp.L2.x', specifying the context-level",
                     " variables to be used by the standard MRP classifier, must",
                     " be a character vector.", sep = ""))
        }

        # Check if mrp.L2.x is in survey data
        if (all(mrp.L2.x != "empty")) {
          if (!all(mrp.L2.x %in% colnames(survey))) {
            stop(paste0("Context-level variable '",
                        paste(mrp.L2.x[!(mrp.L2.x %in% colnames(survey))],
                              collapse = "', '"),
                        "', specified in argument 'mrp.L2.x' to be used by",
                        " the standard MRP classifier, is not in your survey",
                        " data."))
          }

          # Check if mrp.L2.x is in census data
          if (!all(mrp.L2.x %in% colnames(census))) {
            stop(paste0("Context-level variable '",
                        paste(mrp.L2.x[!(mrp.L2.x %in% colnames(census))],
                              collapse = "', '"),
                        "', specified in argument 'mrp.L2.x' to be used by",
                        " the standard MRP classifier, is not in your census",
                        " data."))
          }
        }
      }
    } else {
      # Check if mrp.L2.x is NULL
      if (!is.null(mrp.L2.x)) {
        warning(paste("The argument 'mrp.L2.x', specifying the context-level",
                      " variables to be used by the standard MRP classifier,",
                      " will be ignored because 'mrp' is set to FALSE.", sep = ""))
      }
    }
  } else {
    stop(paste("The logical argument 'mrp', indicating whether the standard",
               " MRP classifier is to be used for predicting y,",
               " must be either TRUE or FALSE.", sep = ""))
  }

  # Check if forward.select is logical
  if (!is.logical(forward.select)) {
    stop(paste("The logical argument 'forward.select', indicating whether to",
               " use forward selection instead of best subset selection,",
               " must be either TRUE or FALSE.", sep = ""))
  }

  # Check if boot.iter corresponds to uncertainty, i.e. NULL if uncertainty = FALSE
  if (!uncertainty){
    if(!is.null(boot.iter)) {
      warning("boot.iter is ignored unless uncertainty = TRUE.")
    }
  }

  # # Check if supplied seed is integer
  # if (!is.null(seed)){
  #   if (isFALSE(dplyr::near(seed, as.integer(seed)))) {
  #     stop("Seed must be either NULL or an integer-valued scalar.")
  #   }
  # }
}


################################################################################
#                    Function to create EBMA hold-out fold                     #
################################################################################

#' Generates data fold to be used for EBMA tuning
#'
#' \code{ebma_folding()} generates a data fold that will not be used in
#' classifier tuning. This held-out data is needed to determine the optimal
#' tolerance for EBMA.
#'
#' @param data The full survey data. A tibble.
#' @param L2.unit Geographic unit. A character scalar containing the column name
#'   of the geographic unit in \code{survey} and \code{census} at which outcomes
#'   should be aggregated.
#' @param ebma.size EBMA fold size. The number of respondents to be allocated
#'   to the EBMA fold. Note that \code{autoMrP()} takes \code{ebma.size} as a
#'   proportion and converts it to a respondent count before this function is
#'   called.
#' @return Returns a list with two elements, both tibbles. The first element,
#'   named \code{ebma_fold}, contains the tibble used in Ensemble Bayesian
#'   Model Averaging (EBMA) tuning. The second element, named \code{cv_data},
#'   contains the tibble used for classifier tuning.
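#' @examples
#' # A minimal sketch (not run), assuming a survey tibble `cces` with a factor
#' # column `state`; 300 is the EBMA fold size expressed in respondents:
#' \dontrun{
#' folds <- ebma_folding(data = cces, L2.unit = "state", ebma.size = 300)
#' folds$ebma_fold # tibble used for EBMA tuning
#' folds$cv_data   # tibble used for classifier tuning
#' }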

ebma_folding <- function(data, L2.unit, ebma.size) {
  # Add row number to data frame
  data <- data %>%
    dplyr::mutate(index = dplyr::row_number())

  # Split data by geographic unit into a list of data frames
  data_list <- data %>%
    dplyr::group_split(.data[[L2.unit]])

  # Sample one respondent per geographic unit
  one_per_unit <- lapply(data_list, function(x) {
    sample(x$index, size = 1, replace = FALSE)
  }) %>%
    unlist()

  # Create EBMA hold-out fold
  ebma_fold <- data %>%
    dplyr::filter(index %in% one_per_unit)

  data <- data %>%
    dplyr::filter(!(index %in% one_per_unit))

  remainder <- sample(data$index, size = ebma.size - length(one_per_unit),
                      replace = FALSE)

  ebma_remainder <- data %>%
    dplyr::filter(index %in% remainder)

  ebma_fold <- ebma_fold %>%
    dplyr::bind_rows(ebma_remainder)

  # Extract EBMA hold-out fold from survey sample
  cv_data <- data %>%
    dplyr::filter(!(index %in% ebma_fold$index))

  # Remove index variable
  ebma_fold <- ebma_fold %>%
    dplyr::select(-index)

  cv_data <- cv_data %>%
    dplyr::select(-index)

  # Function output
  out <- list(ebma_fold = ebma_fold,
              cv_data = cv_data)
  return(out)
}


################################################################################
#                         Function to create CV folds                          #
################################################################################

#' Generates folds for cross-validation
#'
#' \code{cv_folding} creates folds used in classifier training within the survey
#' data.
#'
#' @param data The survey data; must be a tibble.
#' @param L2.unit The column name of the factor variable identifying the
#'   context-level unit.
#' @param k.folds An integer value indicating the number of folds to be
#'   generated.
#' @param cv.sampling Cross-validation sampling method. A character-valued
#'   scalar indicating whether cross-validation folds should be created by
#'   sampling individual respondents (\code{individuals}) or geographic units
#'   (\code{L2 units}). Default is \code{L2 units}. \emph{Note:} ignored if
#'   \code{folds} is provided, but must be specified otherwise.
#' @return Returns a list with length specified by \code{k.folds} argument. Each
#'   element is a tibble with a fold used in k-fold cross-validation.
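#' @examples
#' # A minimal sketch (not run), assuming a survey tibble `cces` with a factor
#' # column `state`:
#' \dontrun{
#' folds <- cv_folding(data = cces, L2.unit = "state", k.folds = 5,
#'                     cv.sampling = "L2 units")
#' length(folds) # 5 tibbles, each holding the rows of roughly 1/5 of states
#' }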

cv_folding <- function(data, L2.unit, k.folds,
                       cv.sampling = c("individuals", "L2 units")) {

  if (cv.sampling == "individuals") {
    # Add row number to data frame
    data <- data %>%
      dplyr::mutate(index = dplyr::row_number())

    # Randomize indices of individuals
    indices <- sample(data$index, size = length(data$index), replace = FALSE)

    # Define number of units per fold
    no_floor <- floor(length(indices) / k.folds)
    no_remaining <- length(indices) - no_floor * k.folds

    no_fold <- rep(no_floor, times = k.folds)

    if (no_remaining > 0) {
      no_fold[1:no_remaining] <- no_fold[1:no_remaining] + 1
    }

    # Split indices into folds
    fold_indices <- split(indices, rep(1:k.folds, times = no_fold))

    # Partition data according to indices
    out <- lapply(fold_indices, function(x) {
      data %>%
        dplyr::filter(index %in% x) %>%
        dplyr::select(-index)
    })
  } else {
    # Extract indices of geographic units
    indices <- data[[L2.unit]] %>%
      unique()

    # Randomize order of indices
    indices <- sample(indices, size = length(indices), replace = FALSE)

    # Define number of units per fold
    no_floor <- floor(length(indices) / k.folds)
    no_remaining <- length(indices) - no_floor * k.folds

    no_fold <- rep(no_floor, times = k.folds)

    if (no_remaining > 0) {
      no_fold[1:no_remaining] <- no_fold[1:no_remaining] + 1
    }

    # Split indices into folds
    fold_indices <- split(indices, rep(1:k.folds, times = no_fold))

    # Partition data according to indices
    out <- lapply(fold_indices, function(x) {
      data %>%
        dplyr::filter(.data[[L2.unit]] %in% x)
    })
  }

  # Function output
  return(out)
}


################################################################################
#           Function to create model list for best subset classifier           #
################################################################################

#' A list of models for the best subset selection.
#'
#' \code{model_list()} generates an exhaustive list of lme4 model formulas from
#' the individual level and context level variables as well as geographic unit
#' variables to be iterated over in best subset selection.
#'
#' @param y Outcome variable. A character vector containing the column names of
#'   the outcome variable.
#' @param L1.x Individual-level covariates. A character vector containing the
#'   column names of the individual-level variables in \code{survey} and
#'   \code{census} used to predict outcome \code{y}. Note that geographic unit
#'   is specified in argument \code{L2.unit}.
#' @param L2.x Context-level covariates. A character vector containing the
#'   column names of the context-level variables in \code{survey} and
#'   \code{census} used to predict outcome \code{y}.
#' @param L2.unit Geographic unit. A character scalar containing the column name
#'   of the geographic unit in \code{survey} and \code{census} at which outcomes
#'   should be aggregated.
#' @param L2.reg Geographic region. A character scalar containing the column
#'   name of the geographic region in \code{survey} and \code{census} by which
#'   geographic units are grouped (\code{L2.unit} must be nested within
#'   \code{L2.reg}). Default is \code{NULL}.
#' @return Returns a list with 2^k elements, where k is the number of
#'   context-level variables. Each element is of class formula.
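#' @examples
#' # A sketch with hypothetical column names; two context-level variables
#' # yield 2^2 = 4 model formulas:
#' \dontrun{
#' model_list(y = "y", L1.x = c("age", "educ"),
#'            L2.x = c("pop_density", "median_income"), L2.unit = "state")
#' # [[1]] y ~ (1 | age) + (1 | educ) + (1 | state)
#' # [[2]] y ~ pop_density + (1 | age) + (1 | educ) + (1 | state)
#' # ...
#' }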

model_list <- function(y, L1.x, L2.x, L2.unit, L2.reg = NULL) {
  # Individual-level random effects
  L1_re <- paste(paste("(1 | ", L1.x, ")", sep = ""), collapse = " + ")

  # Geographic unit or geographic unit-geographic region random effects
  if (is.null(L2.reg)) {
    L2_re <- paste("(1 | ", L2.unit, ")", sep = "")
  } else {
    L2_re <- paste(paste("(1 | ", L2.reg, "/", L2.unit, ")", sep = ""),
                   collapse = " + ")
  }

  # Combine all random effects
  all_re <- paste(c(L1_re, L2_re), collapse = " + ")

  # Empty model
  empty_model <- list(as.formula(paste(y, " ~ ", all_re, sep = "")))

  # Remaining models
  L2_list <- lapply(seq_along(L2.x), function(x) {combn(L2.x, x)})
  L2_list <- lapply(L2_list, function(x) {
    apply(x, 2, function(c) {
      as.formula(paste(y, " ~ ", paste(c, collapse = " + "), " + ", all_re, sep = ""))
    })
  }) %>%
    unlist()

  # Combine models in list
  out <- c(empty_model, L2_list)

  # Function output
  return(out)
}


################################################################################
#               Function to create model list for PCA classifier               #
################################################################################

#' A list of models for the best subset selection with PCA.
#'
#' \code{model_list_pca()} generates an exhaustive list of lme4 model formulas
#' from the individual level and context level principal components as well as
#' geographic unit variables to be iterated over in best subset selection with
#' principal components.
#'
#' @param y Outcome variable. A character vector containing the column names of
#'   the outcome variable.
#' @param L1.x Individual-level covariates. A character vector containing the
#'   column names of the individual-level variables in \code{survey} and
#'   \code{census} used to predict outcome \code{y}. Note that geographic unit
#'   is specified in argument \code{L2.unit}.
#' @param L2.x Context-level covariates. A character vector containing the
#'   column names of the context-level variables in \code{survey} and
#'   \code{census} used to predict outcome \code{y}.
#' @param L2.unit Geographic unit. A character scalar containing the column name
#'   of the geographic unit in \code{survey} and \code{census} at which outcomes
#'   should be aggregated.
#' @param L2.reg Geographic region. A character scalar containing the column
#'   name of the geographic region in \code{survey} and \code{census} by which
#'   geographic units are grouped (\code{L2.unit} must be nested within
#'   \code{L2.reg}). Default is \code{NULL}.
#' @return Returns a list of k+1 elements, where k is the number of
#'   context-level variables. Each element is of class formula. The first
#'   element is a model without context-level variables and each subsequent
#'   model adds one more principal component as a context-level variable.
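#' @examples
#' # A sketch with hypothetical principal component names; k = 3 components
#' # yield k + 1 = 4 nested model formulas:
#' \dontrun{
#' model_list_pca(y = "y", L1.x = c("age", "educ"),
#'                L2.x = c("PC1", "PC2", "PC3"), L2.unit = "state")
#' # [[1]] y ~ (1 | age) + (1 | educ) + (1 | state)
#' # [[2]] y ~ PC1 + (1 | age) + (1 | educ) + (1 | state)
#' # [[3]] y ~ PC1 + PC2 + (1 | age) + (1 | educ) + (1 | state)
#' # ...
#' }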

model_list_pca <- function(y, L1.x, L2.x, L2.unit, L2.reg = NULL) {
  # Individual-level random effects
  L1_re <- paste(paste("(1 | ", L1.x, ")", sep = ""), collapse = " + ")

  # Geographic unit or Geographic unit-Geographic region random effects
  if (is.null(L2.reg)) {
    L2_re <- paste("(1 | ", L2.unit, ")", sep = "")
  } else {
    L2_re <- paste(paste("(1 | ", L2.reg, "/", L2.unit, ")", sep = ""),
                   collapse = " + ")
  }

  # Combine all random effects
  all_re <- paste(c(L1_re, L2_re), collapse = " + ")

  # Empty model
  empty_model <- list(as.formula(paste(y, " ~ ", all_re, sep = "")))

  # Remaining models
  L2_list <- lapply(seq_along(L2.x), function(x) {L2.x[1:x]})
  L2_list <- lapply(L2_list, function(x) {
    as.formula(paste(y, " ~ ", paste(x, collapse = " + "), " + ", all_re, sep = ""))
  })

  # Combine models in list
  out <- c(empty_model, L2_list)

  # Function output
  return(out)
}


################################################################################
#                           Prediction loss function                           #
################################################################################

#' Estimates loss value.
#'
#' \code{loss_function()} estimates the loss based on a loss function.
#'
#' @param pred Predictions of outcome. A numeric vector of outcome predictions.
#' @param data.valid Test data set. A tibble of data that was not used for
#'   prediction.
#' @param loss.unit Loss function unit. A character-valued scalar indicating
#'   whether performance loss should be evaluated at the level of individual
#'   respondents (\code{individuals}) or geographic units (\code{L2 units}).
#'   Default is \code{individuals}.
#' @param loss.fun Loss function. A character-valued scalar indicating whether
#'   prediction loss should be measured by the mean squared error (\code{MSE}),
#'   the mean absolute error (\code{MAE}), binary cross-entropy
#'   (\code{cross-entropy}), the inverse F1 score (\code{f1}), or the mean
#'   squared false error (\code{msfe}). Default is \code{MSE}.
#' @param y Outcome variable. A character vector containing the column names of
#'   the outcome variable.
#' @param L2.unit Geographic unit. A character scalar containing the column name
#'   of the geographic unit in \code{survey} and \code{census} at which outcomes
#'   should be aggregated.
#' @return Returns a tibble with one row per loss function in \code{loss.fun}
#'   (by default 4: cross-entropy, f1, MSE, and msfe) and two columns: the
#'   first, \code{measure}, contains the loss function names and the second,
#'   \code{value}, contains the loss function scores.
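#' @examples
#' # A toy sketch (not run): 4 respondents in 2 hypothetical states
#' \dontrun{
#' toy <- dplyr::tibble(y = c(1, 0, 1, 0), state = c("A", "A", "B", "B"))
#' loss_function(pred = c(0.8, 0.3, 0.6, 0.4), data.valid = toy,
#'               loss.unit = "individuals", loss.fun = "MSE",
#'               y = "y", L2.unit = "state")
#' }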

loss_function <- function(pred, data.valid,
                          loss.unit = c("individuals", "L2 units"),
                          loss.fun = c("MSE", "MAE", "cross-entropy"),
                          y, L2.unit) {

  ## Loss functions
  # MSE
  mse <- mean_squared_error(
    pred = pred, data.valid = data.valid,
    y = y, L2.unit = L2.unit)

  # MAE
  mae <- mean_absolute_error(
    pred = pred, data.valid = data.valid,
    y = y, L2.unit = L2.unit)

  # binary cross-entropy
  bce <- binary_cross_entropy(
    pred = pred, data.valid = data.valid,
    y = y, L2.unit = L2.unit)

  # f1 score
  f1 <- f1_score(
    pred = pred, data.valid = data.valid,
    L2.unit = L2.unit, y = y)

  # mean squared false error
  msfe <- mean_squared_false_error(
    pred = pred, data.valid = data.valid,
    y = y, L2.unit = L2.unit)

  # Combine loss functions
  score <- mse %>%
    dplyr::bind_rows(mae) %>%
    dplyr::bind_rows(bce) %>%
    dplyr::bind_rows(f1) %>%
    dplyr::bind_rows(msfe)

  # Filter score table by loss function and loss unit
  score <- score %>%
    dplyr::filter(measure %in% loss.fun) %>%
    dplyr::filter(level %in% loss.unit) %>%
    dplyr::group_by(measure) %>%
    dplyr::summarise(value = mean(value), .groups = "drop" )

  # Function output
  return(score)
}


###########################################################################
# Mean squared error / Brier score -----------------------------------------
###########################################################################

#' Estimates the mean squared prediction error.
#'
#' \code{mean_squared_error()} estimates the mean squared error for the desired
#' loss unit.
#' @inheritParams loss_function
#' @return Returns a tibble containing two mean squared prediction errors. The
#'   first is measured at the level of individuals and the second is measured at
#'   the context level. The tibble dimensions are 2x3 with variables: measure,
#'   value and level.
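#' @details The individual-level value, which for a binary outcome is the
#'   Brier score, is
#'   \deqn{MSE = \frac{1}{n} \sum_{i = 1}^{n} (y_i - p_i)^2}
#'   where \eqn{y_i} is the observed outcome and \eqn{p_i} the prediction. The
#'   L2-unit value averages the per-unit mean squared errors.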

mean_squared_error <- function(pred, data.valid, y, L2.unit){

  # outcome
  out <- dplyr::tibble(
    measure = rep("MSE", 2),
    value = rep(NA, 2),
    level = c( "individuals", "L2 units")
  )

  # mse values
  values <- rep(NA, 2)

  # loss unit = "individual"
  values[1] <- mean((data.valid[[y]] - pred)^2)

  # loss unit = "L2 units"
  l2 <- data.valid %>%
    dplyr::mutate(pred = pred) %>%
    dplyr::rowwise() %>%
    dplyr::mutate(sqe = (.data[[y]] - pred)^2 ) %>%
    dplyr::ungroup() %>%
    dplyr::group_by(!! rlang::sym(L2.unit)) %>%
    dplyr::summarise(mse = mean(sqe), .groups = "drop")

  values[2] <- mean(dplyr::pull(.data = l2, var = mse))

  out <- dplyr::mutate(out, value = values)

  return(out)
}


###########################################################################
# mean absolute error -----------------------------------------------------
###########################################################################

#' Estimates the mean absolute prediction error.
#'
#' \code{mean_absolute_error()} estimates the mean absolute error for the
#' desired loss unit.
#' @inheritParams loss_function
#' @return Returns a tibble containing two mean absolute prediction errors. The
#'   first is measured at the level of individuals and the second is measured at
#'   the context level. The tibble dimensions are 2x3 with variables: measure,
#'   value and level.


mean_absolute_error <- function(pred, data.valid, y, L2.unit){

  # outcome
  out <- dplyr::tibble(
    measure = rep("MAE", 2),
    value = rep(NA, 2),
    level = c( "individuals", "L2 units"))

  # mae values
  values <- rep(NA, 2)

  # loss unit = "individual"
  values[1] <- mean(abs(data.valid[[y]] - pred))

  # loss unit = "L2 units"
  l2 <- data.valid %>%
    dplyr::mutate(pred = pred) %>%
    dplyr::rowwise() %>%
    dplyr::mutate(ae = abs(.data[[y]] - pred)) %>%
    dplyr::ungroup() %>%
    dplyr::group_by(!! rlang::sym(L2.unit)) %>%
    dplyr::summarise(mae = mean(ae), .groups = "drop")

  values[2] <- mean(dplyr::pull(.data = l2, var = mae))

  out <- dplyr::mutate(out, value = values)

  return(out)
}


###########################################################################
# binary cross-entropy ----------------------------------------------------
###########################################################################

#' Estimates the inverse binary cross-entropy, i.e. a loss where 0 is the best
#' score and larger values are worse.
#'
#' \code{binary_cross_entropy()} estimates the inverse binary cross-entropy on
#' the individual and state levels.
#' @inheritParams loss_function
#' @return Returns a tibble containing two binary cross-entropy prediction
#'   errors. The first is measured at the level of individuals and the second is
#'   measured at the context level. The tibble dimensions are 2x3 with
#'   variables: measure, value and level.
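#' @details The individual-level value is the negated mean log-likelihood of
#'   the observed outcomes \eqn{y_i} under the predictions \eqn{p_i}:
#'   \deqn{BCE = -\frac{1}{n} \sum_{i = 1}^{n} \left[ y_i \log(p_i) +
#'   (1 - y_i) \log(1 - p_i) \right]}
#'   The L2-unit value averages the per-unit means of the same terms.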


binary_cross_entropy <- function(pred, data.valid,
                                 loss.unit = c("individuals", "L2 units"),
                                 y, L2.unit){

  # outcome
  out <- dplyr::tibble(
    measure = rep("cross-entropy", 2),
    value = rep(NA, 2),
    level = c( "individuals", "L2 units")
  )

  # cross-entropy values
  values <- rep(NA, 2)

  # loss unit = "individual"
  values[1] <- (mean( data.valid[[y]] * log(pred) + (1 - data.valid[[y]]) * log(1 - pred)))*-1

  # loss unit = "L2 units"
  l2 <- data.valid %>%
    dplyr::mutate(pred = pred) %>%
    dplyr::rowwise() %>%
    dplyr::mutate(ce = .data[[y]] * log(pred) + (1 - .data[[y]]) * log(1 - pred) ) %>%
    dplyr::ungroup() %>%
    dplyr::group_by(!! rlang::sym(L2.unit)) %>%
    dplyr::summarise(bce = mean(ce), .groups = "drop")

  values[2] <- -mean(dplyr::pull(.data = l2, var = bce))

  out <- dplyr::mutate(out, value = values)
  return(out)
}


###########################################################################
# F1 score ----------------------------------------------------------------
###########################################################################
#' Estimates the inverse f1 score, i.e. 0 is the best score and 1 the worst.
#'
#' \code{f1_score()} estimates the inverse f1 scores on the individual and state
#' levels.
#' @inheritParams loss_function
#' @return Returns a tibble containing two f1 prediction errors. The first is
#'   measured at the level of individuals and the second is measured at the
#'   context level. The tibble dimensions are 2x3 with variables: measure, value
#'   and level.
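#' @details With predictions dichotomized at 0.5, the function computes
#'   \deqn{F_1 = \frac{tp}{tp + 0.5 (fp + fn)}}
#'   from the true positives (tp), false positives (fp), and false negatives
#'   (fn), and reports \eqn{1 - F_1}, so that lower values indicate better
#'   predictions.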



f1_score <- function(pred, data.valid, y, L2.unit){

  ## individual level

  # true positives
  tp_ind <- data.valid %>%
    dplyr::mutate(pval = ifelse(test = pred > 0.5, yes = 1, no = 0)) %>%
    dplyr::select( !! rlang::sym(y), pval ) %>%
    dplyr::filter(pval == 1 & !! rlang::sym(y) == 1) %>%
    dplyr::summarise(tp = sum(pval)) %>%
    dplyr::pull(var = tp)

  # false positives
  fp_ind <- data.valid %>%
    dplyr::mutate(pval = ifelse(test = pred > 0.5, yes = 1, no = 0)) %>%
    dplyr::select( !! rlang::sym(y), pval ) %>%
    dplyr::filter(pval == 1 & !! rlang::sym(y) == 0 ) %>%
    dplyr::summarise(fp = sum(pval)) %>%
    dplyr::pull(var = fp)

  # false negatives
  fn_ind <- data.valid %>%
    dplyr::mutate(pval = ifelse(test = pred > 0.5, yes = 1, no = 0)) %>%
    dplyr::select( !! rlang::sym(y), pval ) %>%
    dplyr::filter(pval == 0 & !! rlang::sym(y) == 1 ) %>%
    dplyr::summarise(fn = sum(!! rlang::sym(y))) %>%
    dplyr::pull(var = fn)

  # f1 score
  f1 <- tp_ind / (tp_ind + 0.5 * (fp_ind + fn_ind) )

  # state-level f1 score
  state_out <- data.valid %>%
    # predicted values
    dplyr::mutate(pval = ifelse(test = pred > 0.5, yes = 1, no = 0)) %>%
    # select L2.unit, y, and predicted values
    dplyr::select( !! rlang::sym(L2.unit), !! rlang::sym(y), pval ) %>%
    # group by L2.unit
    dplyr::group_by( !! rlang::sym(L2.unit) ) %>%
    # nest data
    tidyr::nest() %>%
    # new column with state-level f1 values
    dplyr::mutate(
      f1 = purrr::map(data, function(x){
        # true positives
        tp <- x %>%
          dplyr::select( !! rlang::sym(y), pval ) %>%
          dplyr::filter(pval == 1 & !! rlang::sym(y) == 1) %>%
          dplyr::summarise(tp = sum(pval)) %>%
          dplyr::pull(var = tp)
        # false positives
        fp <- x %>%
          dplyr::select( !! rlang::sym(y), pval ) %>%
          dplyr::filter(pval == 1 & !! rlang::sym(y) == 0 ) %>%
          dplyr::summarise(fp = sum(pval)) %>%
          dplyr::pull(var = fp)
        # false negatives
        fn <- x %>%
          dplyr::select( !! rlang::sym(y), pval ) %>%
          dplyr::filter(pval == 0 & !! rlang::sym(y) == 1 ) %>%
          dplyr::summarise(fn = sum(!! rlang::sym(y))) %>%
          dplyr::pull(var = fn)
        # f1 score
        f1 <- tp / (tp + 0.5 * (fp + fn) ) })) %>%
    # unnest f1 values
    tidyr::unnest(f1) %>%
    dplyr::select( !! rlang::sym(L2.unit), f1 ) %>%
    dplyr::ungroup() %>%
    dplyr::summarise(f1 = mean(f1, na.rm = TRUE), .groups = "drop")

  # return
  out <- dplyr::tibble(
    measure = c("f1", "f1"),
    value = c(1 - f1, 1 - dplyr::pull(.data = state_out, var = f1)),
    level = c("individuals", "L2 units"))

  return(out)

}


###########################################################################
# Mean squared false error-------------------------------------------------
###########################################################################
#' Estimates the mean squared false error.
#'
#' \code{mean_squared_false_error()} estimates the mean squared false error on
#' the individual and state levels.
#' @inheritParams loss_function
#' @return Returns a tibble containing two mean squared false prediction errors.
#'   The first is measured at the level of individuals and the second is
#'   measured at the context level. The tibble dimensions are 2x3 with
#'   variables: measure, value and level.
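#' @details With predictions dichotomized at 0.5, the function computes the
#'   mean prediction error separately within each outcome class and sums the
#'   squared class-specific error rates:
#'   \deqn{MSFE = \bar{e}_{y = 0}^{2} + \bar{e}_{y = 1}^{2}}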



mean_squared_false_error <- function(pred, data.valid, y, L2.unit){

  ## individual level
  msfe_l1 <- data.valid %>%
    dplyr::mutate(pval = ifelse(test = pred > 0.5, yes = 1, no = 0)) %>%
    dplyr::select( !! rlang::sym(y), pval ) %>%
    dplyr::group_by( !! rlang::sym(y) ) %>%
    dplyr::mutate(err = (!! rlang::sym(y) - pval) ) %>%
    dplyr::summarise(err_rates = mean(err), .groups = "drop") %>%
    dplyr::mutate(err_rates = err_rates^2) %>%
    dplyr::summarise( msfe = sum(err_rates)) %>%
    dplyr::pull(var = msfe)

  ## group level
  msfe_l2 <- data.valid %>%
    dplyr::mutate(pval = ifelse(test = pred > 0.5, yes = 1, no = 0)) %>%
    dplyr::select( !! rlang::sym(L2.unit), !! rlang::sym(y), pval ) %>%
    dplyr::group_by( !! rlang::sym(L2.unit) ) %>%
    tidyr::nest() %>%
    dplyr::mutate(msfe = purrr::map(data, function(x){
      msfe <- x %>%
        dplyr::group_by( !! rlang::sym(y) ) %>%
        dplyr::mutate(err = (!! rlang::sym(y) - pval) ) %>%
        dplyr::summarise(err_rates = mean(err), .groups = "drop") %>%
        dplyr::mutate(err_rates = err_rates^2) %>%
        dplyr::summarise( msfe = sum(err_rates)) %>%
        dplyr::pull(var = msfe)
    })) %>%
    tidyr::unnest(msfe) %>%
    dplyr::ungroup() %>%
    dplyr::summarise(msfe = mean(msfe), .groups = "drop") %>%
    dplyr::pull(var = msfe)

  # return
  out <- dplyr::tibble(
    measure = c("msfe", "msfe"),
    value = c(msfe_l1, msfe_l2),
    level = c("individuals", "L2 units"))

  return(out)

}

###########################################################################
# Loss score ranking ------------------------------------------------------
###########################################################################

#' Ranks tuning parameters according to loss functions
#'
#' \code{loss_score_ranking()} ranks tuning parameters according to the scores
#' received in multiple loss functions.
#'
#' @inheritParams loss_function
#' @param score A data set containing loss function names, the loss function
#'   values, and the tuning parameter values.
#' @return Returns a tibble containing the parameter grid as well as a rank
#'   column that corresponds to the cross-validation rank of a parameter
#'   combination across all loss function scores.
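#' @examples
#' # A toy sketch (not run): rank two hypothetical lambda values by their
#' # summed ranks across two loss functions
#' \dontrun{
#' score <- dplyr::tibble(
#'   measure = c("MSE", "MSE", "MAE", "MAE"),
#'   value = c(0.02, 0.04, 0.09, 0.11),
#'   lambda = c(0.1, 0.5, 0.1, 0.5)
#' )
#' loss_score_ranking(score = score, loss.fun = c("MSE", "MAE"))
#' # lambda = 0.1 gets summed rank 2 (1 + 1); lambda = 0.5 gets 4 (2 + 2)
#' }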

loss_score_ranking <- function(score, loss.fun){

  # tuning parameter names
  params <- names(score)[!names(score) %in% c("measure", "value")]

  ranking <- lapply(loss.fun, function(x){
    score %>%
      dplyr::filter(measure == x) %>%
      dplyr::arrange(value) %>%
      dplyr::mutate(rank = dplyr::row_number())
  })

  ranking <- dplyr::bind_rows(ranking) %>%
    dplyr::group_by( !!!rlang::syms(params) ) %>%
    dplyr::summarise(rank = sum(rank), .groups = "drop") %>%
    dplyr::arrange(rank)

  return(ranking)

}

################################################################################
#                   Suppress cat in external package                           #
################################################################################

#' Suppress cat in external package
#'
#' \code{quiet()} suppresses cat output.
#'
#' @param x An expression to evaluate; its printed output is suppressed.
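#' @examples
#' # A minimal sketch: suppress the cat() output of a chatty expression
#' \dontrun{
#' chatty <- function() {
#'   cat("fitting...\n")
#'   42
#' }
#' quiet(chatty()) # returns 42 invisibly, without printing "fitting..."
#' }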

quiet <- function(x) {
  sink(tempfile())
  on.exit(sink())
  invisible(force(x))
}

################################################################################
#                   Register Cores for multicore                               #
################################################################################

#' Register cores for multicore computing
#'
#' \code{multicore()} registers cores for parallel processing.
#'
#' @param cores Number of cores to be used. An integer. Default is \code{1}.
#' @param type Whether to start or end parallel processing. A character string.
#'   The possible values are \code{open}, \code{close}.
#' @param cl The registered cluster. Default is \code{NULL}.
#' @return No return value, called to register or un-register clusters for
#'   parallel processing.
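#' @examples
#' # A minimal sketch (not run): register 2 cores, compute in parallel, then
#' # release the cluster
#' \dontrun{
#' library(foreach)
#' cl <- multicore(cores = 2, type = "open")
#' res <- foreach(i = 1:10) %dopar% sqrt(i)
#' multicore(type = "close", cl = cl)
#' }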

multicore <- function(cores = 1, type, cl = NULL) {

  # Start parallel processing
  if (type == "open"){
    # register clusters for windows
    if( Sys.info()["sysname"] == "Windows" ){
      cl <- parallel::makeCluster(cores)
      doParallel::registerDoParallel(cl)
      parallel::clusterCall(cl, function(x) .libPaths(x), .libPaths())
    } else {
      cl <- parallel::makeForkCluster(cores)
      doParallel::registerDoParallel(cl)
    }
    return(cl)
  }

  # Stop parallel processing
  if (type == "close"){
    parallel::stopCluster(cl)
  }
}


################################################################################
#                 Predict function for glmmLasso                               #
################################################################################

#' Predicts on newdata from glmmLasso objects
#'
#' \code{predict_glmmLasso()} predicts on newdata objects from a
#' \code{glmmLasso} object.
#'
#' @inheritParams auto_MrP
#' @param m A \code{glmmLasso()} object.
#' @return Returns a numeric vector of predictions from a \code{glmmLasso()}
#'   object.

predict_glmmLasso <- function(census, m, L1.x, lasso.L2.x, L2.unit, L2.reg) {

  # Fixed effects
  fixed_effects <- as.matrix(
    cbind(1, as.data.frame(census)[, lasso.L2.x])) %*% cbind(m$coefficients)

  # Individual-level random effects
  ind_ranef <- dplyr::select(.data = census, dplyr::one_of(L1.x))
  ind_ranef[] <- base::Map(paste, names(ind_ranef), ind_ranef, sep = '')
  ind_ranef <- cbind(apply(ind_ranef, 1, function(x){
    sum(m$ranef[which(names(m$ranef) %in% x)])
  }))

  # State random effects
  state_ranef <- cbind(paste(
    L2.unit, as.character(dplyr::pull(.data = census, var = L2.unit)),
    sep = ""))
  state_ranef <- cbind(apply(state_ranef, 1, function(x){
    if (x %in% names(m$ranef)){
      m$ranef[names(m$ranef) == x]
    } else{
      0
    }}))

  # Region random effect
  if(!is.null(L2.reg)){
    region_ranef <- cbind(paste(
      L2.reg, as.character(as.data.frame(census)[, L2.reg]), sep = ""))
    region_ranef <- cbind(apply(region_ranef, 1, function(x){
      m$ranef[names(m$ranef) == x]
    }))
  }

  # Predictions
  if(!is.null(L2.reg)){
    lasso_preds <- cbind(fixed_effects, ind_ranef, state_ranef, region_ranef)
  } else{
    lasso_preds <- cbind(fixed_effects, ind_ranef, state_ranef)
  }
  lasso_preds <- base::apply(X = lasso_preds, MARGIN = 1, FUN = sum)
  lasso_preds <- stats::pnorm(lasso_preds)

  return(lasso_preds)
}
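
# Note on the lookup above: the fitted model stores its random effects in the
# named vector m$ranef, keyed by the variable name pasted to the factor level.
# A sketch of the matching for a hypothetical individual-level variable "L1x1"
# at level 3:
#
#   key <- paste0("L1x1", 3)          # "L1x13"
#   m$ranef[names(m$ranef) == key]    # random effect for that cell
#
# State keys that do not occur in names(m$ranef), i.e. units unseen in
# training, fall back to a random effect of 0.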


################################################################################
#                 Plot method for autoMrP                                      #
################################################################################

#' A plot method for autoMrP objects. Plots unit-level preference estimates.
#'
#' \code{plot.autoMrP()} plots unit-level preference estimates and error bars.
#'
#' @param x An \code{autoMrP()} object.
#' @param algorithm The algorithm/classifier for which preference estimates are
#'   desired. A character-valued scalar indicating either \code{ebma} or the
#'   classifier to be used. Allowed choices are: "ebma", "best_subset", "lasso",
#'   "pca", "gb", "svm", and "mrp". Default is \code{ebma}.
#' @param ci.lvl The level of the confidence intervals. A proportion. Default is
#'   \code{0.95}. Confidence intervals are based on bootstrapped estimates and
#'   will not be printed if bootstrapping was not carried out.
#' @param ... Additional arguments affecting the summary produced.
#' @return Returns a \code{ggplot2} object of the preference estimates for the
#' selected classifier.
#' @export
#' @export plot.autoMrP

plot.autoMrP <- function(x, algorithm = "ebma", ci.lvl = 0.95, ...){

  # L2.unit identifier
  L2.unit <- names(x$classifiers)[1]


  # Error if requested algorithm was not fitted
  if (!algorithm %in% names(x$classifiers) & algorithm != "ebma") {
    stop('The ', algorithm, ' classifier was not run. Re-run autoMrP() with ',
         'the requested algorithm. Allowed choices are: "ebma", ',
         '"best_subset", "lasso", "pca", "gb", "svm", and "mrp".')
  }

  # plot classifier if EBMA was not estimated
  if( "EBMA step skipped (only 1 classifier run)" %in% x$ebma ) {
    algorithm <- names(x$classifiers)[-1]
  }

  # plot data
  if(algorithm == "ebma"){
    # EBMA summary
    plot_data <- x$ebma %>%
      dplyr::group_by(!! rlang::sym(L2.unit)) %>%
      dplyr::summarise(median = stats::median(ebma, na.rm = TRUE),
                       lb = stats::quantile(
                         x = ebma, probs = (1 - ci.lvl) * .5, na.rm = TRUE),
                       ub = stats::quantile(
                         x = ebma, probs = ci.lvl + (1 - ci.lvl) * .5,
                         na.rm = TRUE),
                       .groups = "drop") %>%
      dplyr::arrange(median) %>%
      dplyr::mutate(rank = dplyr::row_number()) %>%
      dplyr::mutate(rank = as.factor(rank)) %>%
      dplyr::mutate(!!rlang::sym(L2.unit) := forcats::fct_reorder(!!rlang::sym(L2.unit), median))
  } else{
    # One of the classifiers
    plot_data <- x$classifiers %>%
      dplyr::group_by(!! rlang::sym(L2.unit)) %>%
      dplyr::select(dplyr::all_of(L2.unit), dplyr::contains(algorithm)) %>%
      dplyr::summarise_all(.funs = list(
        median = ~ stats::quantile(x = ., probs = 0.5, na.rm = TRUE),
        lb = ~ stats::quantile(x = ., probs = (1 - ci.lvl) * .5, na.rm = TRUE),
        ub = ~ stats::quantile(
          x = ., probs = ci.lvl + (1 - ci.lvl) * .5, na.rm = TRUE))) %>%
      dplyr::arrange(median) %>%
      dplyr::mutate(!!rlang::sym(L2.unit) := forcats::fct_reorder(!!rlang::sym(L2.unit), median))
  }

  # y axis tick labels
  ylabs <- as.character(dplyr::pull(.data = plot_data, var = L2.unit))

  # plot (with/without error bars)
  if (all(plot_data$median == plot_data$lb)) {
    ggplot2::ggplot(
      data = plot_data,
      mapping = ggplot2::aes(x = median, y = .data[[L2.unit]])) +
      ggplot2::geom_point() +
      ggplot2::labs(x = "Estimates")
  } else {
    ggplot2::ggplot(
      data = plot_data,
      mapping = ggplot2::aes(x = median, y = .data[[L2.unit]])) +
      ggplot2::geom_point() +
      ggplot2::labs(x = "Estimates") +
      ggplot2::geom_errorbarh(mapping = ggplot2::aes(xmin = lb, xmax = ub))
  }
}
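
# Example calls (assuming `out` is an object returned by autoMrP(); kept as a
# comment so it is not executed when the package is sourced):
#
#   plot(out)                                     # EBMA estimates
#   plot(out, algorithm = "lasso", ci.lvl = 0.9)  # one classifier, 90% CIs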


################################################################################
#                 Summary method for autoMrP                                   #
################################################################################

#' A summary method for autoMrP objects.
#'
#' \code{summary.autoMrP()} prints summaries of the context-level preference
#' estimates or, for weights objects, of the EBMA classifier weights.
#'
#' @param object An \code{autoMrP()} object for which a summary is desired.
#' @param ci.lvl The level of the confidence intervals. A proportion. Default is
#'   \code{0.95}. Confidence intervals are based on bootstrapped estimates and
#'   will not be printed if bootstrapping was not carried out.
#' @param digits The number of digits to be displayed. An integer scalar.
#'   Default is \code{4}.
#' @param format The table format. A character string passed to
#'   \code{\link[knitr]{kable}}. Default is \code{simple}.
#' @param classifiers Summarize a single classifier. A character string. Must be
#'   one of \code{best_subset}, \code{lasso}, \code{pca}, \code{gb}, \code{svm},
#'   or \code{mrp}. Default is \code{NULL}.
#' @param n Number of rows to be printed. An integer scalar. Default is
#'   \code{10}.
#' @param ... Additional arguments affecting the summary produced.
#' @return No return value, prints a summary of the context level preference
#'  estimates to the console.
#' @export
#' @export summary.autoMrP

summary.autoMrP <- function(object, ci.lvl = 0.95, digits = 4, format = "simple",
                            classifiers = NULL, n = 10, ...){

  # weights
  if ( all(c("autoMrP", "weights") %in% class(object)) ){

    # error message if weights summary called without running multiple classifiers
    if (any(object == "EBMA step skipped (only 1 classifier run)")){
      stop("Weights are not reported if the EBMA step was skipped. Re-run autoMrP with multiple classifiers.")
    }

    # weights vector to tibble
    if( is.null(dim(object)) ){
      object <- dplyr::tibble(!!!object)
    }

    # summary statistics
    s_data <- object %>%
      tidyr::pivot_longer(
        cols = dplyr::everything(),
        names_to = "method",
        values_to = "estimates") %>%
      dplyr::group_by(method) %>%
      dplyr::summarise(
        min = base::min(estimates, na.rm = TRUE),
        quart1 = stats::quantile(x = estimates, probs = 0.25, na.rm = TRUE),
        median = stats::median(estimates, na.rm = TRUE),
        mean = base::mean(estimates, na.rm = TRUE),
        quart3 = stats::quantile(x = estimates, probs = 0.75, na.rm = TRUE),
        max = base::max(estimates, na.rm = TRUE),
        .groups = "drop") %>%
      dplyr::arrange(dplyr::desc(median))

    # weights with uncertainty
    if ( all(s_data$median != s_data$min) ){
      n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
      cat("\n# EBMA classifier weights:")
      # output table
      output_table(
        object = s_data[1:n, ],
        col.names = c(
          "Classifier",
          "Min.",
          "1st Qu.",
          "Median",
          "Mean",
          "3rd Qu.",
          "Max"),
        format = format,
        digits = digits)
      if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
    } else{
      s_data <- dplyr::select(.data = s_data, method, median)
      n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
      cat("\n# EBMA classifier weights:")
      output_table(
        object = s_data[1:n, ],
        col.names = c(
          "Classifier",
          "Weight"),
        format = format,
        digits = digits)
      if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
    }
  }

  # ensemble summary
  else if ( all(c("autoMrP", "ensemble") %in% class(object)) ) {

    # unit identifier
    L2.unit <- names(object)[1]

    # summary statistics
    s_data <- object %>%
      dplyr::group_by(!! rlang::sym(L2.unit)) %>%
      dplyr::summarise(
        min = base::min(ebma, na.rm = TRUE),
        lb = stats::quantile(x = ebma, probs = (1 - ci.lvl)*.5, na.rm = TRUE),
        median = stats::quantile(x = ebma, probs = .5, na.rm = TRUE),
        ub = stats::quantile(x = ebma, probs = ci.lvl + (1 - ci.lvl)*.5, na.rm = TRUE),
        max = base::max(ebma, na.rm = TRUE),
        .groups = "drop"
      )

    # with or without uncertainty
    if ( all(s_data$median != s_data$lb) ){
      n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data))
      cat("\n# EBMA estimates:")
      # output table
      output_table(
        object = s_data[1:n, ],
        col.names = c(
          L2.unit,
          "Min.",
          "Lower bound",
          "Median",
          "Upper bound",
          "Max"),
        format = format,
        digits = digits)
      if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))

    } else{
      s_data <- dplyr::select(.data = s_data, dplyr::one_of(L2.unit), median)
      n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
      cat("\n# EBMA estimates:")
      output_table(
        object = s_data[1:n, ],
        col.names = c(L2.unit, "Estimate"),
        format = format,
        digits = digits)
      if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
    }
  }

  # classifier summary
  else if ( all(c("autoMrP", "classifiers") %in% class(object)) ){

    # unit identifier
    L2.unit <- names(object)[1]

    # multiple classifiers
    if (base::is.null(classifiers)){

      # point estimates for all classifiers
      s_data <- object %>%
        dplyr::group_by(!! rlang::sym(L2.unit)) %>%
        dplyr::summarise_all(.funs = median )

      # output table
      ests <- paste(names(object)[-1], collapse = ", ")
      n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
      cat(paste0("\n# estimates of classifiers: ", ests))
      output_table(object = s_data[1:n, ],
                   col.names = names(s_data),
                   format = format,
                   digits = digits)
      if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
    } else{

      # summary statistics
      s_data <- object %>%
        dplyr::select(dplyr::one_of(L2.unit, classifiers)) %>%
        dplyr::group_by(!! rlang::sym(L2.unit)) %>%
        dplyr::summarise_all(.funs = list(
          min = ~ base::min(x = ., na.rm = TRUE),
          lb = ~ stats::quantile(x = ., probs = (1 - ci.lvl)*.5, na.rm = TRUE),
          median = ~ stats::median(x = ., na.rm = TRUE),
          ub = ~ stats::quantile(x = ., probs = ci.lvl + (1 - ci.lvl)*.5, na.rm = TRUE),
          max = ~ base::max(x = ., na.rm = TRUE)
          ))

      # with or without uncertainty
      if( all(s_data$median != s_data$lb) ){
        n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
        cat(paste("\n# estimates of", classifiers, "classifier"))
        output_table(
          object = s_data[1:n, ],
          col.names = c(
            L2.unit,
            "Min.",
            "Lower bound",
            "Median",
            "Upper bound",
            "Max"),
          format = format,
          digits = digits)
        if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
      } else{
        s_data <- dplyr::select(.data = s_data, dplyr::one_of(L2.unit), "median")
        n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
        cat(paste("\n# estimates of", classifiers, "classifier"))
        output_table(
          object = s_data[1:n, ],
          col.names = c(L2.unit, "Estimate"),
          format = format,
          digits = digits)
        if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
      }
    }
  }

  # autoMrP list object
  else if ( all(c("autoMrP", "list") %in% class(object)) ){

    # unit identifier
    L2.unit <- names(object$classifiers)[1]

    # if EBMA was run
    if( !"EBMA step skipped (only 1 classifier run)" %in% object$ebma ){

      # Summarize EBMA or classifier specified in classifiers argument
      if( is.null(classifiers)) {
        s_data <- object$ebma
      } else{
        # check whether classifier was fitted
        if( !classifiers %in% names(object$classifiers) ){
          stop(classifiers, " was not fitted. Summary available for: ",
               paste(names(object$classifiers)[-1], collapse = ", "))
        }
        s_data <- object$classifiers %>%
          dplyr::select( dplyr::one_of(L2.unit, classifiers) )
      }

      # summary statistics
      s_data <- s_data %>%
        dplyr::group_by(!! rlang::sym(L2.unit)) %>%
        dplyr::summarise_all(.funs = list(
          min = ~ base::min(x = ., na.rm = TRUE),
          lb = ~ stats::quantile(x = ., probs = (1 - ci.lvl)*.5, na.rm = TRUE),
          median = ~ stats::median(x = ., na.rm = TRUE ),
          ub = ~ stats::quantile(x = ., probs = ci.lvl + (1 - ci.lvl)*.5, na.rm = TRUE),
          max = ~ base::max(x = ., na.rm = TRUE)
        ))

      # with or without uncertainty
      if( all(s_data$median != s_data$lb) ){
        n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
        if( is.null(classifiers) ){
          cat("\n# EBMA estimates:")
        } else{
          cat( paste("\n", "# ", classifiers, " estimates", sep = ""))
        }
        output_table(
          object = s_data[1:n, ],
          col.names = c(L2.unit, "Min.", "Lower bound", "Median",
                        "Upper bound", "Max"),
          format = format,
          digits = digits)
        if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
      } else{
        s_data <- dplyr::select(.data = s_data, dplyr::one_of(L2.unit), median)
        n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
        if (is.null(classifiers)) {
          cat("\n# EBMA estimates:")
        } else {
          cat(paste0("\n# ", classifiers, " estimates:"))
        }
        output_table(object = s_data[1:n, ], col.names = c(L2.unit, "Median"),
                     format = format, digits = digits)
        if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
      }
    } else{

      # Summarize all classifiers or classifier specified in classifiers argument
      # or MrP model
      if( is.null(classifiers) ) {
        s_data <- object$classifiers
      } else{
        s_data <- object$classifiers %>%
          dplyr::select(dplyr::one_of(L2.unit, classifiers) )
      }

      # summary statistics
      s_data <- s_data %>%
        dplyr::group_by(!! rlang::sym(L2.unit)) %>%
        dplyr::summarise_all(.funs = list(
          min = ~ base::min(x = ., na.rm = TRUE),
          lb = ~ stats::quantile(x = ., probs = (1 - ci.lvl)*.5, na.rm = TRUE),
          median = ~ stats::median(x = ., na.rm = TRUE),
          ub = ~ stats::quantile(x = ., probs = ci.lvl + (1 - ci.lvl)*.5, na.rm = TRUE),
          max = ~ base::max(x = ., na.rm = TRUE)
        ))

      # with or without uncertainty
      comparison <- s_data %>%
        dplyr::select(dplyr::one_of(grep(
          pattern = "median", x = names(s_data), value = TRUE)[1],
          grep(pattern = "lb", x = names(s_data), value = TRUE)[1]))

        # summarize one classifier
        if ( sum(grepl(pattern = "best_subset|pca|lasso|gb|svm|mrp", x = names(s_data))) < 4 ){

          # with uncertainty (bounds differ from the median)
          if( all(comparison[,1] != comparison[,2]) ){

            n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
            cat(paste0("\n# ", names(object$classifiers)[2], " estimates:"))
            output_table(
              object = s_data[1:n, ],
              col.names = c(
                L2.unit,
                "Min.",
                "Lower bound",
                "Median",
                "Upper bound",
                "Max"),
              format = format,
              digits = digits)
            if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
          } else{
            n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
            cat(paste0("\n# estimates of: ", paste(names(object$classifiers)[-1], collapse = ", ")))
            s_data <- s_data %>%
              dplyr::select(dplyr::one_of( L2.unit), contains("median"))
            output_table(
              object = s_data[1:n, ],
              col.names = names(s_data),
              format = format,
              digits = digits)
            if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
          }

      } else {
        # drop uncertainty columns
        if ( ncol(s_data) < 5 ){
          s_data <- dplyr::select(.data = s_data, dplyr::one_of(L2.unit), median)
          n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
          cat(paste0("\n# ", names(object$classifiers)[2], " estimates:"))
          output_table(object = s_data[1:n, ],
                       col.names = c(L2.unit, "Estimate"),
                       format = format, digits = digits)
          if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
        } else{
          n <- ifelse(n <= nrow(s_data), yes = n, no = nrow(s_data) )
          s_data <- s_data %>%
            dplyr::select(dplyr::one_of( L2.unit), contains("median"))
          output_table(
            object = s_data[1:n, ],
            col.names = names(s_data),
            format = format,
            digits = digits)
          if (n < nrow(s_data)) cat(paste0("... with ", nrow(s_data) - n, " more rows\n\n"))
        }
      }
    }
  }
}
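
# Example calls (assuming `out` is an object returned by autoMrP(); kept as a
# comment so it is not executed when the package is sourced):
#
#   summary(out)                       # EBMA estimates, 95% intervals
#   summary(out, classifiers = "gb")   # estimates of a single classifier
#   summary(out, n = 25, digits = 2)   # more rows, fewer digits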

################################################################################
#                 Output table for summary                                     #
################################################################################

#' A table for the summary function
#'
#' \code{output_table()} prints a table of summary statistics to the console
#' via \code{knitr::kable()}.
#'
#' @inheritParams summary.autoMrP
#' @param col.names The column names of the table. A character vector.
#' @return No return value, prints a table to the console.

output_table <- function(object, col.names, format, digits){

  # output table
  print( knitr::kable(x = object,
                      col.names = col.names,
                      format = format,
                      digits = digits))

}
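
# Example (illustrative data; kept as a comment): prints a kable table to the
# console.
#
#   output_table(
#     object = dplyr::tibble(state = c("AL", "AK"), estimate = c(0.42, 0.37)),
#     col.names = c("State", "Estimate"), format = "simple", digits = 2)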


################################################################################
#                 Equal spacing on the log scale                               #
################################################################################

#' Sequence that is equally spaced on the log scale
#'
#' @param min The minimum value of the sequence. A positive numeric scalar (min
#'   > 0).
#' @param max The maximum value of the sequence. A positive numeric scalar (max
#'   > 0).
#' @param n The length of the sequence. An integer-valued scalar.
#' @return Returns a numeric vector with length specified in argument \code{n}.
#'   The vector elements are equally spaced on the log-scale.

log_spaced <- function(min, max, n){
  return(base::exp(base::seq(
    from = base::log(min), to = base::log(max), length.out = n)))
}
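
# Example: five values equally spaced on the log scale between 0.1 and 10.
#
#   log_spaced(min = 0.1, max = 10, n = 5)
#   # 0.1000000  0.3162278  1.0000000  3.1622777 10.0000000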



################################################################################
#         Runs operations inside bootstrapping loop                            #
################################################################################

#' Runs one bootstrap iteration
#'
#' \code{boot_fun()} draws a cluster bootstrap sample from the survey and runs
#' all classifiers on it; it is called once per iteration of the foreach
#' bootstrapping loop.
#'
#' @inheritParams auto_MrP

boot_fun <- function(survey, L2.unit,
                     y, L1.x, L2.x, mrp.L2.x, L2.reg, L2.x.scale,
                     pcs, folds, bin.proportion, bin.size, census,
                     k.folds, cv.sampling, loss.unit, loss.fun, best.subset,
                     lasso, pca, gb, svm, mrp, forward.select, best.subset.L2.x,
                     lasso.L2.x, pca.L2.x, pc.names, gb.L2.x, svm.L2.x,
                     svm.L2.unit, svm.L2.reg, gb.L2.unit, gb.L2.reg,
                     lasso.lambda, lasso.n.iter, gb.interaction.depth,
                     gb.shrinkage, gb.n.trees.init, gb.n.trees.increase,
                     gb.n.trees.max, gb.n.minobsinnode, svm.kernel,
                     svm.gamma, svm.cost, ebma.tol, ebma.size, cores, verbose) {

  # Bootstrap sample --------------------------------------------------------

  # simple sampling with replacement
  # boot_sample <- dplyr::slice_sample(
  #   .data = survey, n = nrow(survey), replace = TRUE)

  # number of states to draw in cluster balanced bootstrap
  # avg_n <- survey %>% dplyr::mutate(nrows = dplyr::n()) %>%
  #   dplyr::group_by(!!rlang::sym(L2.unit)) %>%
  #   dplyr::mutate(state_n = dplyr::n(),
  #                 state_proportion =  (dplyr::n() / nrows))%>%
  #   dplyr::summarise(state_n = mean(state_n),
  #                    state_proportion = mean(state_proportion),
  #                    .groups = 'drop') %>%
  #   dplyr::summarise(n = weighted.mean(x = state_n, w = state_proportion)) %>%
  #   as.numeric
  #
  # #avg_n <- floor(x = (nrow(survey) / avg_n) )
  # avg_n <- round(x = (nrow(survey) / avg_n), digits = 0)
  #
  # # balanced cluster bootstrap
  # boot_sample <- survey %>%
  #   dplyr::group_by(!!rlang::sym(L2.unit)) %>%
  #   tidyr::nest() %>%
  #   dplyr::left_join(
  #     y = survey %>%
  #       dplyr::mutate(nrows = dplyr::n()) %>%
  #       dplyr::group_by(!!rlang::sym(L2.unit)) %>%
  #       dplyr::mutate(state_proportion =  (dplyr::n() / nrows)) %>%
  #       dplyr::summarise(state_proportion = mean(state_proportion),
  #                        .groups = 'drop'), by = L2.unit) %>%
  #   dplyr::ungroup() %>%
  #   dplyr::slice_sample(n = avg_n, weight_by = state_proportion,
  #                       replace = TRUE) %>%
  #   dplyr::select(-state_proportion) %>%
  #   tidyr::unnest(data)

  # # drop observations if the bootstrap sample is too large
  # if (nrow(boot_sample) > nrow(survey)){
  #   boot_sample <- dplyr::slice_sample(.data = boot_sample, n = nrow(survey),
  #                                      replace = TRUE)
  # }

  # # Sample 1) regions; 2) states; 3) individuals
  # if (!is.null(L2.reg)) {
  #   # Step 1: Sample regions but sample at least 2 different regions
  #   boot_sample <- dplyr::bind_rows(
  #       # Sample at least 2 different regions
  #       dplyr::slice_sample(.data = survey %>%
  #                             dplyr::group_by(!!rlang::sym(L2.reg)) %>%
  #                             tidyr::nest() %>%
  #                             dplyr::ungroup()
  #                           , n = 2, replace = FALSE),
  #       # Sample from all regions with replacement
  #       dplyr::slice_sample(.data = survey %>%
  #                             dplyr::group_by(!!rlang::sym(L2.reg)) %>%
  #                             tidyr::nest() %>%
  #                             dplyr::ungroup(),
  #                           n = (length(unique(unlist(survey[, L2.reg]))) - 2),
  #                           replace = TRUE)) %>%
  #     tidyr::unnest(data) %>%
  #     dplyr::ungroup() %>%
  #   # Step 2: sample states with replacement
  #     dplyr::group_by(!!rlang::sym(L2.unit)) %>%
  #     tidyr::nest() %>%
  #     dplyr::ungroup() %>%
  #     dplyr::slice_sample(n = nrow(.), replace = TRUE) %>%
  #     tidyr::unnest(data) %>%
  #     dplyr::ungroup() %>%
  #     # Step 3: Sample individuals with replacement
  #     dplyr::group_by(!!rlang::sym(L2.unit)) %>%
  #     tidyr::nest() %>%
  #     dplyr::mutate(data = purrr::map(data, function(x){
  #       data = dplyr::slice_sample(.data = x, n = nrow(x), replace = TRUE)})) %>%
  #     tidyr::unnest(data) %>%
  #     dplyr::ungroup()
  # } else {
  #   # Step 1: sample states with replacement
  #   boot_sample <- survey %>%
  #     dplyr::group_by(!!rlang::sym(L2.unit)) %>%
  #     tidyr::nest() %>%
  #     dplyr::ungroup() %>%
  #     dplyr::slice_sample(n = nrow(.), replace = TRUE) %>%
  #     tidyr::unnest(data) %>%
  #     dplyr::ungroup() %>%
  #     # Step 2: Sample individuals with replacement within states
  #     dplyr::group_by(!!rlang::sym(L2.unit)) %>%
  #     tidyr::nest() %>%
  #     dplyr::mutate(data = purrr::map(data, function(x){
  #       data = dplyr::slice_sample(.data = x, n = nrow(x), replace = TRUE)})) %>%
  #     tidyr::unnest(data) %>%
  #     dplyr::ungroup()
  # }

  # sample states with replacement & within states sample individuals with
  # replacement
  # Step 1: sample states with replacement
  # boot_sample <- survey %>%
  #   dplyr::group_by(!!rlang::sym(L2.unit)) %>%
  #   tidyr::nest() %>%
  #   dplyr::ungroup() %>%
  #   dplyr::slice_sample(n = nrow(.), replace = TRUE) %>%
  #   tidyr::unnest(data) %>%
  #   dplyr::ungroup() %>%
  #   # Step 2: Sample individuals with replacement within states
  #   dplyr::group_by(!!rlang::sym(L2.unit)) %>%
  #   tidyr::nest() %>%
  #   dplyr::mutate(data = purrr::map(data, function(x){
  #     data = dplyr::slice_sample(.data = x, n = nrow(x), replace = TRUE)})) %>%
  #   tidyr::unnest(data) %>%
  #   dplyr::ungroup()

  # cluster bootstrap - sample L2 units with replacement
  boot_sample <- survey %>%
    dplyr::group_by(!!rlang::sym(L2.unit)) %>%
    tidyr::nest() %>%
    dplyr::ungroup() %>%
    dplyr::slice_sample(n = nrow(.), replace = TRUE) %>%
    tidyr::unnest(data) %>%
    dplyr::ungroup()
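
  # Illustration of the nest/slice_sample/unnest pattern above: with three
  # states A, B, C, slice_sample(n = 3, replace = TRUE) may draw B, B, C, so
  # the bootstrap sample then contains every respondent of B twice and of C
  # once, preserving the within-state dependence structure.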

  # state stratified sample
  # boot_sample <- survey %>%
  #   dplyr::group_by( !! rlang::sym(L2.unit) ) %>%
  #   tidyr::nest() %>%
  #   dplyr::mutate(data = purrr::map(data, function(x){
  #     data = dplyr::slice_sample(.data = x, n = nrow(x), replace = TRUE)
  #   })) %>%
  #   tidyr::unnest(data) %>%
  #   dplyr::ungroup()

  # at least one observation per state and simple sample
  # boot_sample <- survey %>%
  #   dplyr::group_by(!! rlang::sym(L2.unit)) %>%
  #   dplyr::mutate(state_pick = ifelse(
  #     test = dplyr::row_number() == sample(x = 1:dplyr::n(), size = 1),
  #     yes = 1, no = 0)) %>%
  #   dplyr::ungroup() %>%
  #   dplyr::group_by(state_pick) %>%
  #   tidyr::nest() %>%
  #   dplyr::mutate(data = ifelse(
  #     test = state_pick == 0,
  #     yes = purrr::map(data, function(x){
  #       dplyr::slice_sample(.data = x, n = nrow(x), replace = TRUE)
  #     }),
  #     no = data
  #   )) %>%
  #   tidyr::unnest(data) %>%
  #   dplyr::ungroup()

  # no bootstrapping (same data as the original survey; for testing only)
  # boot_sample <- survey


  # do not predict outcomes for states that are not in boot_sample ----------
  # boot_census <- census %>%
  #   dplyr::filter( !!rlang::sym(L2.unit) %in% unique(dplyr::pull(
  #     .data = boot_sample, var = !!rlang::sym(L2.unit))) )

  # Create folds ------------------------------------------------------------

  if (is.null(folds)) {

    # EBMA hold-out fold
    ebma.size <- round(nrow(boot_sample) * ebma.size, digits = 0)

    if (ebma.size > 0){
      ebma_folding_out <- ebma_folding(data = boot_sample,
                                       L2.unit = L2.unit,
                                       ebma.size = ebma.size)
      ebma_fold <- ebma_folding_out$ebma_fold
      cv_data <- ebma_folding_out$cv_data
    } else{
      ebma_fold <- NULL
      cv_data <- boot_sample
    }

    # K folds for cross-validation
    cv_folds <- cv_folding(
      data = cv_data,
      L2.unit = L2.unit,
      k.folds = k.folds,
      cv.sampling = cv.sampling)
  } else {

    if (ebma.size > 0){
      # EBMA hold-out fold
      ebma_fold <- boot_sample %>%
        dplyr::filter_at(dplyr::vars(dplyr::one_of(folds)),
                         dplyr::any_vars(. == k.folds + 1))
    } else{
      # no EBMA hold-out fold; run_classifiers() receives NULL
      ebma_fold <- NULL
    }

    # K folds for cross-validation
    cv_data <- boot_sample %>%
      dplyr::filter_at(dplyr::vars(dplyr::one_of(folds)),
                       dplyr::any_vars(. != k.folds + 1))

    cv_folds <- cv_data %>%
      dplyr::group_split(.data[[folds]])
  }



  # Run classifiers ---------------------------------------------------------

  # run all classifiers on this bootstrap sample (a single EBMA draw)
  boot_mrp <- run_classifiers(
    census = census,
    cv.folds = cv_folds,
    cv.data = cv_data,
    ebma.fold = ebma_fold,
    ebma.n.draws = 1,
    verbose = FALSE,
    cores = 1,
    y = y,
    L1.x = L1.x,
    L2.x = L2.x,
    mrp.L2.x = mrp.L2.x,
    L2.unit = L2.unit,
    L2.reg = L2.reg,
    L2.x.scale = L2.x.scale,
    pcs = pcs,
    folds = folds,
    bin.proportion = bin.proportion,
    bin.size = bin.size,
    ebma.size = ebma.size,
    k.folds = k.folds,
    cv.sampling = cv.sampling,
    loss.unit = loss.unit,
    loss.fun = loss.fun,
    best.subset = best.subset,
    lasso = lasso,
    pca = pca,
    gb = gb,
    svm = svm,
    mrp = mrp,
    forward.select = forward.select,
    best.subset.L2.x = best.subset.L2.x,
    lasso.L2.x = lasso.L2.x,
    pca.L2.x = pca.L2.x,
    pc.names = pc.names,
    gb.L2.x = gb.L2.x,
    svm.L2.x = svm.L2.x,
    svm.L2.unit = svm.L2.unit,
    svm.L2.reg = svm.L2.reg,
    gb.L2.unit = gb.L2.unit,
    gb.L2.reg = gb.L2.reg,
    lasso.lambda = lasso.lambda,
    lasso.n.iter = lasso.n.iter,
    gb.interaction.depth = gb.interaction.depth,
    gb.shrinkage = gb.shrinkage,
    gb.n.trees.init = gb.n.trees.init,
    gb.n.trees.increase = gb.n.trees.increase,
    gb.n.trees.max = gb.n.trees.max,
    gb.n.minobsinnode = gb.n.minobsinnode,
    svm.kernel = svm.kernel,
    svm.gamma = svm.gamma,
    svm.cost = svm.cost,
    ebma.tol = ebma.tol
  )
}
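
# Sketch of how boot_fun() is consumed (assumed loop structure; the argument
# list is abbreviated here):
#
#   boot_out <- foreach::foreach(idx = seq_len(boot.iter)) %dopar% {
#     boot_fun(survey = survey, L2.unit = L2.unit, ...)
#   }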
