R/predict_models.R

Defines functions predict_models

Documented in predict_models

#' Predict models
#'
#' Applies each fitted model in column "model" of \code{data} to the matching
#' test set in column "test" and stores the result in a new list column
#' "prediction".
#'
#' If the first model is an \code{mlist} object (a list of fitted models), every
#' row is treated as such: each member model is used to predict the test set and
#' the member predictions are combined according to \code{blend}. Otherwise each
#' model is passed to \code{predict} together with \code{...}.
#'
#' Predictions that are not plain vectors (e.g. data frames or matrices of class
#' probabilities) are reduced to their column named "1" and converted to numeric.
#'
#' @param data A data frame with columns "model" and "test" containing
#'   the fitted models and test datasets, respectively.
#' @param select_cols Columns to select from each test set before predicting.
#'   Default is NULL (use all columns).
#' @param blend Default is NULL. Must be set when mlist objects have been fitted.
#'   Defines how model predictions are blended. Options: "average" for averaging
#'   predictions, or "logistic" for model stacking with logistic regression.
#'   Note that the stacking regression is fitted on, and predicts, the same
#'   test set.
#' @param passed_args Default is NULL. Only used when mlist objects have been
#'   fitted: a list with one element per member model, each a list of extra
#'   arguments passed on to \code{predict} (as used in \code{do.call}).
#'   NULL passes no extra arguments to any member model.
#' @param occ_col Name of the response (occurrence) column in the test sets.
#'   Only used when \code{blend = "logistic"}; it is read from the test set
#'   before \code{select_cols} is applied.
#' @param keep_models Whether to keep the model objects in the output data frame.
#'   Can be set to FALSE to reduce object size.
#' @param ... Arguments passed on to \code{predict} for ordinary (non-mlist)
#'   models. Ignored for mlist objects; use \code{passed_args} instead.
#'
#' @return \code{data} with an added list column "prediction" containing the
#'   model predictions (and without column "model" if \code{keep_models = FALSE}).
#' @export
predict_models <- function(data, select_cols = NULL, blend = NULL,
                           passed_args = NULL, occ_col = "occ",
                           keep_models = TRUE, ...) {
  # Reduce non-vector predictions (e.g. class-probability tables) to the
  # numeric column named "1"; plain vectors are returned unchanged.
  to_numeric_prediction <- function(pred) {
    if (is.vector(pred)) {
      return(pred)
    }
    as.numeric(as.data.frame(pred)[["1"]])
  }

  # Convert a test set to a data frame and optionally keep only select_cols.
  prepare_test <- function(test) {
    test <- as.data.frame(test)
    if (!is.null(select_cols)) {
      test <- dplyr::select(test, !!select_cols)
    }
    test
  }

  # The per-row workers are defined here rather than inline in mutate() so
  # their free variables resolve in this function's environment and cannot
  # be shadowed by same-named columns of `data` through the data mask.

  # Ordinary model: a single predict() call, forwarding `...`.
  predict_plain <- function(model, test) {
    to_numeric_prediction(predict(model, prepare_test(test), ...))
  }

  # mlist: predict with every member model, then blend the predictions.
  predict_blended <- function(model, test) {
    full_test <- as.data.frame(test)
    newdata <- prepare_test(full_test)

    # NULL passed_args means "no extra arguments" for every member model.
    member_args <- if (is.null(passed_args)) {
      vector("list", length(model))
    } else {
      passed_args
    }
    member_preds <- purrr::map2(model, member_args, function(member, extra_args) {
      do.call(predict, append(list(object = member, newdata = newdata), extra_args))
    })

    # cbind() keeps a matrix even for single-row test sets (simplify2array
    # would collapse those to a vector and break rowMeans()).
    pred_matrix <- do.call(cbind, purrr::map(member_preds, to_numeric_prediction))

    if (blend == "average") {
      return(unname(rowMeans(pred_matrix)))
    }

    # Stacking: regress the observed response on the member predictions.
    # The response is read before column selection so that select_cols does
    # not need to include it. The stacker is fitted and evaluated in-sample.
    if (!occ_col %in% names(full_test)) {
      stop(sprintf("Response column '%s' not found in test data.", occ_col),
           call. = FALSE)
    }
    stack_df <- cbind(as.data.frame(pred_matrix), y = full_test[[occ_col]])
    predict(glm(y ~ ., family = "binomial", data = stack_df), type = "response")
  }

  # Guard on length() so an empty `data` does not fail on [[1]].
  is_mlist <- length(data$model) > 0 && inherits(data$model[[1]], "mlist")

  if (is_mlist) {
    valid_blend <- is.character(blend) && length(blend) == 1L &&
      blend %in% c("average", "logistic")
    if (!valid_blend) {
      stop('`blend` must be "average" or "logistic" when mlist objects have been fitted.',
           call. = FALSE)
    }
  }

  predict_row <- if (is_mlist) predict_blended else predict_plain
  out <- dplyr::mutate(data, prediction = purrr::map2(model, test, predict_row))

  if (!keep_models) {
    out <- dplyr::select(out, -model)
  }

  out
}
juoe/sdmflow documentation built on Feb. 23, 2020, 7:38 p.m.