R/predict_glm.R

Defines functions get_wts_varw get_pred_glm predict.ssn_glm

Documented in predict.ssn_glm

#' @param type The scale (\code{response} or \code{link}) of predictions obtained
#'   using \code{ssn_glm} objects.
#' @param newdata_size The \code{size} value for each observation in \code{newdata}
#'   used when predicting for the binomial family.
#' @param var_correct A logical indicating whether to return the corrected prediction
#'   variances when predicting via models fit using \code{ssn_glm}. The default is
#'   \code{TRUE}.
#' @rdname predict.SSN2
#' @method predict ssn_glm
#' @export
predict.ssn_glm <- function(object, newdata, type = c("link", "response"), se.fit = FALSE, interval = c("none", "confidence", "prediction"),
                            newdata_size, level = 0.95, var_correct = TRUE, ...) {
  # Overview of what is returned:
  #   * several prediction data sets requested -> a named list with one element
  #     per data set, each produced by a recursive call to predict();
  #   * the requested data set has zero rows   -> NULL;
  #   * interval = "none"                      -> a vector of fits, or
  #     list(fit, se.fit) when se.fit = TRUE;
  #   * interval = "confidence"/"prediction"   -> a matrix with columns
  #     fit, lwr, upr, or list(fit, se.fit) when se.fit = TRUE.

  # resolve the enumerated arguments to a single value each
  type <- match.arg(type)
  interval <- match.arg(interval)

  # newdata_size is optional; NULL is what gets handed to invlink() if absent
  if (missing(newdata_size)) newdata_size <- NULL

  # handle local for now (local prediction is switched off)
  local <- FALSE

  # newdata is the *name* of a prediction data set stored in
  # object$ssn.object$preds, not a data frame
  if (missing(newdata)) newdata <- NULL
  newdata_name <- newdata

  # NULL or "all" expands to every stored prediction data set. The length
  # check keeps `||` scalar so that a character vector of several names can
  # reach the multi-set branch below (R >= 4.3 errors on a length > 1 operand).
  if (is.null(newdata_name) || (length(newdata_name) == 1 && newdata_name == "all")) {
    newdata_name <- names(object$ssn.object$preds)
  }
  if (length(newdata_name) == 0) {
    stop("No prediction data sets were found in the fitted model object.", call. = FALSE)
  }

  # several data sets: predict each one separately, forwarding every
  # user-facing argument so that e.g. type = "response" or var_correct = FALSE
  # are honoured for each set rather than silently reset to their defaults
  if (length(newdata_name) > 1) {
    pred_list <- lapply(newdata_name, function(x) {
      predict(object, x,
        type = type, se.fit = se.fit, interval = interval,
        newdata_size = newdata_size, level = level,
        var_correct = var_correct, local = local, ...
      )
    })
    names(pred_list) <- newdata_name
    return(pred_list)
  }

  # ".missing" is the special set whose results are labelled with
  # object$missing_index further down
  add_newdata_rows <- newdata_name == ".missing"

  # rename relevant quantities
  obdata <- object$ssn.object$obs
  newdata <- object$ssn.object$preds[[newdata_name]]

  # nothing to predict
  if (NROW(newdata) == 0) {
    return(NULL)
  }

  dispersion_params_val <- as.vector(coef(object, type = "dispersion"))

  # covariance object for this prediction set, split so that each prediction
  # row gets its own element
  # (presumably one row per prediction site -- confirm against covmatrix())
  cov_vector <- covmatrix(object, newdata_name)
  cov_vector_list <- split(cov_vector, seq_len(NROW(cov_vector)))

  formula_newdata <- delete.response(terms(object))
  # fix model frame bug with degree 2 basic polynomial and one prediction row
  # e.g. poly(x, y, degree = 2) and newdata has one row: the row is duplicated
  # for the model frame / model matrix and reduced back to one row afterwards
  single_row_poly <- any(grepl("nmatrix.", attributes(formula_newdata)$dataClasses, fixed = TRUE)) &&
    NROW(newdata) == 1
  if (single_row_poly) {
    newdata <- newdata[c(1, 1), , drop = FALSE]
  }
  newdata_model_frame <- model.frame(formula_newdata, newdata, drop.unused.levels = FALSE, na.action = na.pass, xlev = object$xlevels)
  # assumes that predicted observations are not outside the factor levels
  newdata_model <- model.matrix(formula_newdata, newdata_model_frame, contrasts = object$contrasts)
  # find offset
  offset <- model.offset(newdata_model_frame)
  # capture the attributes before any subsetting, because `[` drops them
  attr_assign <- attr(newdata_model, "assign")
  attr_contrasts <- attr(newdata_model, "contrasts")
  if (single_row_poly) {
    newdata_model <- newdata_model[1, , drop = FALSE]
    if (!is.null(offset)) {
      offset <- offset[1]
    }
    newdata <- newdata[1, , drop = FALSE]
  }

  # keep only the columns that are also present in the fitted design matrix
  obs_model <- model.matrix(object)
  keep_cols <- which(colnames(newdata_model) %in% colnames(obs_model))
  newdata_model <- newdata_model[, keep_cols, drop = FALSE]
  attr(newdata_model, "assign") <- attr_assign[keep_cols]
  attr(newdata_model, "contrasts") <- attr_contrasts

  # one list element per prediction row: the data row, its design row (x0) and
  # its covariance row (c0)
  newdata_rows_list <- split(newdata, seq_len(NROW(newdata)))
  newdata_model_list <- split(newdata_model, seq_len(NROW(newdata)))
  newdata_list <- mapply(
    x = newdata_rows_list, y = newdata_model_list, c = cov_vector_list,
    FUN = function(x, y, c) list(row = x, x0 = y, c0 = c), SIMPLIFY = FALSE
  )

  # covariance matrix of the fitted object; computed once and reused below
  cov_matrix_val <- covmatrix(object)

  # total var (could do sums params object)
  total_var <- cov_matrix_val[1, 1]

  if (interval == "confidence") {
    # fitted values of the mean parameters on the link scale
    fit <- as.numeric(newdata_model %*% coef(object))
    # apply offset
    if (!is.null(offset)) {
      fit <- fit + offset
    }
    # NOTE(review): var_correct is not passed to vcov() here, so the confidence
    # branch always uses vcov()'s default -- confirm this is intended.
    cov_betahat <- vcov(object) # hoisted: identical for every prediction row
    vars <- as.numeric(vapply(newdata_model_list, function(x) crossprod(x, cov_betahat %*% x), numeric(1)))
    se <- sqrt(vars)
  } else {
    # interval is "none" or "prediction"
    local_list <- get_local_list_prediction(local)

    # matrix cholesky: with method "all" the lower Cholesky factor is shared by
    # every row and the variance adjustment is applied once for all rows;
    # otherwise the adjustment is delegated to get_pred_glm() row by row
    if (local_list$method == "all") {
      cov_lowchol <- t(chol(cov_matrix_val))
      predvar_adjust_ind <- FALSE
      predvar_adjust_all <- TRUE
    } else {
      cov_lowchol <- NULL
      predvar_adjust_ind <- TRUE
      predvar_adjust_all <- FALSE
    }

    # change predvar adjust based on var correct
    if (!var_correct) {
      predvar_adjust_ind <- FALSE
      predvar_adjust_all <- FALSE
    }

    link_fitted <- fitted(object, type = "link")

    # until big data is back, prediction rows are processed sequentially
    # (the parallel::parLapply() variant of this call is disabled)
    pred_val <- lapply(newdata_list, get_pred_glm,
      se.fit = se.fit, interval = interval, formula = object$formula,
      obdata = obdata, cov_matrix_val = cov_matrix_val, total_var = total_var, cov_lowchol = cov_lowchol,
      Xmat = obs_model, y = model.response(model.frame(object)),
      betahat = coefficients(object), cov_betahat = vcov(object, var_correct = FALSE),
      contrasts = object$contrasts, local = local_list,
      family = object$family, w = link_fitted,
      size = object$size, dispersion = dispersion_params_val,
      predvar_adjust_ind = predvar_adjust_ind, xlevels = object$xlevels
    )

    fit <- vapply(pred_val, function(x) x$fit, numeric(1))
    # apply offset
    if (!is.null(offset)) {
      fit <- fit + offset
    }

    # get_pred_glm() only returns a variance under this same condition
    se <- NULL
    if (se.fit || interval == "prediction") {
      vars <- vapply(pred_val, function(x) x$var, numeric(1))
      if (predvar_adjust_all) {
        # predvar_adjust is for the local function so FALSE there is TRUE here
        # NOTE(review): y is object$y here but model.response(model.frame(object))
        # in the get_pred_glm() call above -- confirm both are the same vector.
        vars_adj <- get_wts_varw(
          family = object$family,
          Xmat = obs_model,
          y = object$y,
          w = link_fitted,
          size = object$size,
          dispersion = dispersion_params_val,
          cov_lowchol = cov_lowchol,
          x0 = newdata_model,
          c0 = cov_vector
        )
        vars <- vars_adj + vars
      }
      se <- sqrt(vars)
    }

    if (interval == "none") {
      if (type == "response") {
        fit <- invlink(fit, object$family, newdata_size)
      }
      if (add_newdata_rows) {
        names(fit) <- object$missing_index
      }
      if (!se.fit) {
        return(fit)
      }
      if (add_newdata_rows) {
        names(se) <- object$missing_index
      }
      return(list(fit = fit, se.fit = se))
    }
  }

  # shared by "confidence" and "prediction": normal-quantile interval built on
  # the link scale, then mapped through invlink() when type = "response"
  # tstar <- qt(1 - (1 - level) / 2, df = object$n - object$p)
  tstar <- qnorm(1 - (1 - level) / 2)
  lwr <- fit - tstar * se
  upr <- fit + tstar * se
  if (type == "response") {
    fit <- invlink(fit, object$family, newdata_size)
    lwr <- invlink(lwr, object$family, newdata_size)
    upr <- invlink(upr, object$family, newdata_size)
  }
  fit <- cbind(fit, lwr, upr)
  row.names(fit) <- seq_len(NROW(fit))
  if (add_newdata_rows) {
    row.names(fit) <- object$missing_index
  }
  if (!se.fit) {
    return(fit)
  }
  if (add_newdata_rows) {
    names(se) <- object$missing_index
  }
  list(fit = fit, se.fit = se)
}






# Prediction for a single row of new data.
#
# newdata_list is a list whose elements x0 (design row) and c0 (covariance
# with the observed data) are read here. The function returns list(fit = ...)
# and, when se.fit is TRUE or interval is "prediction", also var = ....
# When predvar_adjust_ind is TRUE the value of get_wts_varw() is added to the
# variance. The arguments formula, obdata, cov_matrix_val, contrasts and
# xlevels are accepted but not used by the active code path.
get_pred_glm <- function(newdata_list, se.fit, interval,
                         formula, obdata, cov_matrix_val, total_var, cov_lowchol,
                         Xmat, y, betahat, cov_betahat, contrasts, local,
                         family, w, size, dispersion, predvar_adjust_ind, xlevels) {
  site_cov <- newdata_list$c0

  if (local$method == "covariance") {
    # deliberately empty: the covariance-based neighbourhood subsetting that
    # used to live here is disabled; the condition itself is still evaluated
  }

  c0 <- as.numeric(site_cov)
  x0 <- newdata_list$x0

  # forward-solve against the lower Cholesky factor so that cross products of
  # the results correspond to products with the inverse covariance matrix
  half_X <- forwardsolve(cov_lowchol, Xmat)
  half_w <- forwardsolve(cov_lowchol, w)
  half_resid <- half_w - half_X %*% betahat
  half_c0 <- forwardsolve(cov_lowchol, c0)

  # fixed-effect part plus the covariance-weighted residual part
  resid_part <- Matrix::crossprod(half_c0, half_resid)
  fit <- as.numeric(x0 %*% betahat + resid_part)
  H <- x0 - Matrix::crossprod(half_c0, half_X)

  # no variance requested: only the fit is returned
  if (!(se.fit || interval == "prediction")) {
    return(list(fit = fit))
  }

  explained_var <- Matrix::crossprod(half_c0, half_c0)
  beta_var <- H %*% Matrix::tcrossprod(cov_betahat, H)
  var <- as.numeric(total_var - explained_var + beta_var)
  if (predvar_adjust_ind) {
    var_adj <- get_wts_varw(family, Xmat, y, w, size, dispersion, cov_lowchol, x0, c0)
    var <- var_adj + var
  }
  list(fit = fit, var = var)
}

# Variance adjustment that is added to the prediction variance.
#
# Arguments
#   family, w, y, size, dispersion: passed unchanged to get_D().
#   Xmat:        design matrix of the observed data.
#   cov_lowchol: lower Cholesky factor of the covariance matrix.
#   x0, c0:      design row(s) and covariance row(s) of the prediction site(s);
#                either plain vectors (a single site) or matrices with one row
#                per site.
# Value
#   A single number when x0 is a vector or has one row, otherwise a numeric
#   vector with one element per row of x0.
get_wts_varw <- function(family, Xmat, y, w, size, dispersion, cov_lowchol, x0, c0) {
  SigInv <- chol2inv(t(cov_lowchol)) # works on upchol
  SigInv_X <- SigInv %*% Xmat
  cov_betahat <- chol2inv(chol(Matrix::forceSymmetric(crossprod(Xmat, SigInv_X))))
  wts_beta <- tcrossprod(cov_betahat, SigInv_X)
  Ptheta <- SigInv - SigInv_X %*% wts_beta

  # Only D enters H; the vector from get_d() was never used in this function,
  # so it is no longer computed.
  D <- get_D(family, w, y, size, dispersion)
  H <- D - Ptheta
  mHInv <- solve(-H) # chol2inv(chol(Matrix::forceSymmetric(-H))) # solve(-H)

  # adjustment for one site: wts_pred %*% mHInv %*% t(wts_pred)
  var_adj_one <- function(x0_row, c0_row) {
    wts_pred <- x0_row %*% wts_beta + c0_row %*% SigInv - (c0_row %*% SigInv_X) %*% wts_beta
    as.numeric(wts_pred %*% tcrossprod(mHInv, wts_pred))
  }

  # a single site arrives either as plain vectors (how the per-row predict
  # path calls this function) or as one-row matrices
  if (is.vector(x0) || NROW(x0) == 1) {
    return(var_adj_one(x0, c0))
  }

  # several sites: evaluate row by row so that only the diagonal of the full
  # product is ever formed
  vapply(seq_len(NROW(x0)), function(i) {
    var_adj_one(x0[i, , drop = FALSE], c0[i, , drop = FALSE])
  }, numeric(1))
}

Try the SSN2 package in your browser

Any scripts or data that you put into this service are public.

SSN2 documentation built on May 29, 2024, 4:41 a.m.