R/pred_se.R

#' @title Predictions and prediction intervals for models
#' @description Create point predictions and approximate prediction-interval
#'   bounds for a list of models and matching \code{newdata} data frames.
#'   Predictions are specific to \code{merMod-class} objects and to this
#'   project: each \code{newdata} is first passed through the project helper
#'   \code{trans_newdata()}, and the output keeps project-specific identifier
#'   columns.
#' @details For each row the prediction variance is
#'   \eqn{z_i^\top \Sigma z_i + \sigma^2}. Here \eqn{\Sigma} is the
#'   covariance matrix of the first random-effects term returned by
#'   \code{VarCorr()}, \eqn{z_i} is the row's random-effects design vector
#'   (see \code{R.slope}), and \eqn{\sigma} is the residual standard deviation.
#'   Uncertainty in the fixed-effect estimates is not included. Only the
#'   diagonal of the covariance is needed, so no \eqn{n \times n} matrix is
#'   built.
#' @param obj.list A list of fitted \code{merMod} model objects.
#' @param newdata.list A list of data frames, one per model in \code{obj.list},
#'   on which predictions are desired.
#' @param allow.new.levels Logical, one value per model (a constant vector or a
#'   single value is recycled). Passed to \code{predict()}: should new
#'   random-effects levels be allowed? Defaults to TRUE.
#' @param trans.func A list of back-transformation functions, one per model (a
#'   single function is recycled). Each is applied to the point prediction and
#'   to both interval bounds.
#' @param alpha A numeric scalar in (0, 1). The significance level of the
#'   interval: bounds are \code{yhat -/+ qnorm(1 - alpha/2) * se} before
#'   back-transformation. Defaults to 0.05.
#' @param R.slope A character vector, one value per model, naming each model's
#'   random slopes. One of \code{c("none", "xmas", "mf_day", "both")}. For
#'   \code{"both"} the design columns are ordered intercept, \code{mf_day},
#'   \code{xmas}.
#' @param title A logical vector: is the level-2 grouping variable at the
#'   'title' level? Passed to \code{trans_newdata()}. Defaults to FALSE.
#' @param level3 A logical vector: are the models 3-level models? Passed to
#'   \code{trans_newdata()}. When TRUE the output also keeps the
#'   \code{geography} column. Defaults to FALSE.
#' @param usa A logical vector. When TRUE for a model, the interval uses only
#'   the first (intercept) random-effect variance plus the residual variance,
#'   ignoring \code{R.slope}. Defaults to FALSE.
#' @return A \code{list} with one data frame per model. Each contains the
#'   identifier columns of the transformed newdata plus \code{l}, \code{y_hat}
#'   and \code{u}: lower bound, point prediction and upper bound, all on the
#'   back-transformed scale. Rows whose \code{y_hat} is \code{NA} are dropped.
#' @export
pred_se <- function(obj.list, newdata.list, allow.new.levels = rep(TRUE, 3),
                    trans.func, alpha = 0.05, R.slope = rep("none", 3),
                    title = rep(FALSE, 3), level3 = rep(FALSE, 3),
                    usa = rep(FALSE, 3)) {
  if (length(obj.list) != length(newdata.list)) {
    stop("Must input equal length obj.list and newdata.list.")
  }
  # predict() dispatches to lme4's merMod method and VarCorr() comes from
  # lme4, so fail loudly if it is unavailable (require() would only return
  # FALSE and the failure would surface later, obscurely).
  if (!requireNamespace("lme4", quietly = TRUE)) {
    stop("Package 'lme4' is required by pred_se().", call. = FALSE)
  }
  stopifnot(is.numeric(alpha), length(alpha) == 1L, alpha > 0, alpha < 1)

  p <- length(obj.list)
  # Per-model arguments: the defaults have length 3, so recycle constant
  # vectors to one value per model instead of indexing past their end.
  allow.new.levels <- .pred_se_per_model(allow.new.levels, p, "allow.new.levels")
  R.slope <- .pred_se_per_model(R.slope, p, "R.slope")
  title <- .pred_se_per_model(title, p, "title")
  level3 <- .pred_se_per_model(level3, p, "level3")
  usa <- .pred_se_per_model(usa, p, "usa")
  if (is.function(trans.func)) {
    trans.func <- list(trans.func)
  }
  trans.func <- .pred_se_per_model(trans.func, p, "trans.func")

  z.crit <- stats::qnorm(1 - alpha / 2)
  pred.dat <- vector("list", p)

  for (i in seq_len(p)) {
    ### 01. Project-specific preprocessing (per the original note: removes obs
    ### with new FE levels in format2 and flags shrinkage; trans_newdata() is
    ### defined elsewhere).
    newdata2 <- trans_newdata(dat = newdata.list[[i]], object = obj.list[[i]],
                              title = title[i], level3 = level3[i])

    ### 02. Point prediction and interval half-width.
    yhat <- stats::predict(obj.list[[i]], newdata = newdata2,
                           allow.new.levels = allow.new.levels[i])
    vc <- lme4::VarCorr(obj.list[[i]])
    sigma2 <- attr(vc, "sc")^2 # residual (level-1) variance

    if (isTRUE(usa[i])) {
      # Intercept-only variance: first random-effect stddev of the first term.
      y_se <- sqrt(sigma2 + attr(vc[[1]], "stddev")[1L]^2) * z.crit
    } else {
      # Slope columns must come from newdata2 (the rows actually predicted),
      # not from the untransformed newdata, or lengths can disagree.
      Z <- .pred_se_design(newdata2, R.slope[i])
      Sigma <- vc[[1]]
      if (ncol(Z) != ncol(Sigma)) {
        stop(sprintf(paste0("Model %d: R.slope = \"%s\" gives %d random-effect ",
                            "column(s) but the first VarCorr() term has %d."),
                     i, R.slope[i], ncol(Z), ncol(Sigma)), call. = FALSE)
      }
      # diag(Z Sigma Z') computed row-wise: O(n * q) instead of an n x n matrix.
      y_se <- sqrt(rowSums((Z %*% Sigma) * Z) + sigma2) * z.crit
    }

    ### 03. Keep identifying info and back-transform to the original scale.
    if (isTRUE(level3[i])) {
      p_out <- data.frame(newdata2[, c("geography", "albumid", "title", "format2",
                                       "title_format2", "date2", "shrinkage")])
    } else {
      p_out <- data.frame(newdata2[, c("albumid", "title", "format2",
                                       "title_format2", "date2", "shrinkage")])
    }
    p_out[, "l"] <- trans.func[[i]](yhat - y_se)
    p_out[, "y_hat"] <- trans.func[[i]](yhat)
    p_out[, "u"] <- trans.func[[i]](yhat + y_se)

    # Drop rows without a usable point prediction.
    pred.dat[[i]] <- p_out[!is.na(p_out$y_hat), ]
  }

  pred.dat
}

# Build the dense n x q random-effects design matrix for one model: an
# intercept column followed by the slope column(s) selected by `slope`.
.pred_se_design <- function(dat, slope) {
  slope.cols <- list(none = character(0), xmas = "xmas",
                     mf_day = "mf_day", both = c("mf_day", "xmas"))
  if (length(slope) != 1L || is.na(slope) || !(slope %in% names(slope.cols))) {
    stop("R.slope must be one of \"none\", \"xmas\", \"mf_day\", \"both\".",
         call. = FALSE)
  }
  cols <- slope.cols[[slope]]
  absent <- setdiff(cols, names(dat))
  if (length(absent) > 0L) {
    stop("newdata is missing random-slope column(s): ",
         paste(absent, collapse = ", "), call. = FALSE)
  }
  do.call(cbind, c(list(rep(1, nrow(dat))), lapply(cols, function(cl) dat[[cl]])))
}

# Return `x` with exactly one entry per model: its first `p` entries when long
# enough, otherwise a single or constant value recycled to length `p`.
.pred_se_per_model <- function(x, p, name) {
  if (length(x) >= p) {
    return(x[seq_len(p)])
  }
  if (length(x) == 1L || length(unique(x)) == 1L) {
    return(rep(x[1L], p))
  }
  stop(sprintf("`%s` must have one value per model (%d) or a single value.",
               name, p), call. = FALSE)
}
alexWhitworth/concord documentation built on May 11, 2019, 11:25 p.m.