R/fit_models.R

Defines functions fit_models mlist

Documented in fit_models

#' Fit models
#'
#' Fits one model (or one list of models) per training dataset stored in the
#' "train" column of `data`. Each training dataset is coerced with
#' [as.data.frame()], optionally stripped of the columns named in `drop_cols`
#' and of zero-variance columns, and then `model_call` is evaluated with the
#' prepared dataset bound to the name `train`.
#'
#' @param data A data frame with column "train" containing the training datasets.
#' @param model_call A quoted call for fitting the models (for example
#'   `quote(glm(y ~ ., data = train))`), or a list of such quoted calls. The
#'   call(s) can refer to the prepared training data as `train`. Other names
#'   are resolved through `fit_models()`'s own environment chain, not through
#'   local variables of the calling function.
#' @param drop_cols Names of columns to drop before fitting. Names that are
#'   not present in a training dataset are reported by [dplyr::one_of()].
#' @param drop_zero_var_preds Logical. Whether to drop columns with at most
#'   one distinct value (as counted by [unique()], so `NA` counts as a value).
#'   This is applied to every remaining column, including any response column.
#'
#' @return `data` with an added list column "model". If `model_call` is a
#'   single call, each element is the value of that call. If it is a list,
#'   each element is an object of class "mlist" holding one value per call.
#' @export
fit_models <- function(data, model_call, drop_cols = NULL,
                       drop_zero_var_preds = TRUE) {

  if (!is.call(model_call) && !is.list(model_call)) {
    stop("Wrong input for model_call.")
  }

  # Coerce one training dataset to a data.frame and apply the configured
  # column drops. Defined outside the mutate() data mask so that columns of
  # `data` cannot shadow `drop_cols` / `drop_zero_var_preds`.
  prepare_train <- function(train) {
    train <- as.data.frame(train)
    if (!is.null(drop_cols)) {
      train <- dplyr::select(train, -dplyr::one_of(drop_cols))
    }
    if (drop_zero_var_preds) {
      # Keep only columns with more than one distinct value; plain subsetting
      # also handles the "nothing to drop" and zero-column cases.
      keep <- vapply(train, function(x) length(unique(x)) > 1L, logical(1))
      train <- train[keep]
    }
    train
  }

  # Fit the model(s) for one training dataset. eval() defaults to the calling
  # frame, so `model_call` sees the prepared data under the name `train`.
  fit_one <- function(train) {
    train <- prepare_train(train)
    if (is.call(model_call)) {
      eval(model_call)
    } else {
      mlist(purrr::map(model_call, function(model_call) {
        eval(model_call)
      }))
    }
  }

  # fit_one() is called in function position, so a non-function column of
  # `data` with the same name cannot shadow it inside the data mask.
  # Returning the mutate() result directly makes the value visible.
  dplyr::mutate(data, model = purrr::map(train, function(x) fit_one(x)))
}


mlist <- function(models) {
  # Tag a list of fitted models with the S3 class "mlist" (fit_models() uses
  # this to wrap the results of a list of model calls).
  #
  # models: a list; anything else is rejected with an error.
  # Returns `models` with its class attribute replaced by "mlist".
  if (is.list(models)) {
    class(models) <- "mlist"
    return(models)
  }
  stop("models must be a list")
}
juoe/sdmflow documentation built on Feb. 23, 2020, 7:38 p.m.