# R/emfrail.R

# Defines functions emfrail

# Documented in emfrail

#' Fitting semi-parametric shared frailty models with the EM algorithm
#'
#' @importFrom survival Surv coxph cox.zph
#' @importFrom stats approx coef model.frame model.matrix pchisq printCoefmat nlm uniroot cor optimize
#' @importFrom magrittr "%>%"
#' @importFrom Rcpp evalCpp
#' @importFrom Matrix bdiag
#' @importFrom numDeriv hessian
#' @useDynLib frailtyEM, .registration=TRUE
#' @include em_fit.R
#' @include emfrail_aux.R
#'
#' @param formula A formula that contains on the left hand side an object of the type \code{Surv}
#' and on the right hand side a \code{+cluster(id)} statement. Two special statements may also be used:
#' \code{+strata()} for specifying a grouping column that will represent different strata and
#' \code{+terminal()} for specifying a column that contains an indicator for dependent censoring.
#' @param data A \code{data.frame} in which the formula argument can be evaluated
#' @param distribution An object as created by \code{\link{emfrail_dist}}
#' @param control An object as created by \code{\link{emfrail_control}}
#' @param model Logical. Should the model frame be returned?
#' @param model.matrix Logical. Should the model matrix be returned?
#' @param ... Other arguments, currently used to warn about deprecated argument names
#' @export
#'
#' @details The \code{emfrail} function fits shared frailty models for processes which have intensity
#' \deqn{\lambda(t) = z \lambda_0(t) \exp(\beta' \mathbf{x})}
#' with a non-parametric (Breslow) baseline intensity \eqn{\lambda_0(t)}. The outcome
#' (left hand side of the \code{formula}) must be a \code{Surv} object.
#'
#' If the object is \code{Surv(tstop, status)} then the usual failure time data is represented.
#' Gap-times between recurrent events are represented in the same way.
#' If the left hand side of the formula is created as \code{Surv(tstart, tstop, status)}, this may represent a number of things:
#' (a) recurrent event episodes in calendar time, where a recurrent event episode starts at \code{tstart} and ends at \code{tstop};
#' (b) failure time data with time-dependent covariates, where \code{tstop} is the time of a change in covariates or censoring
#' (\code{status = 0}) or an event time (\code{status = 1}); or (c) clustered failure time with left truncation, where
#' \code{tstart} is the individual's left truncation time. Unlike regular Cox models, a major distinction is that in case (c) the
#' distribution of the frailty must be considered conditional on survival up to the left truncation time
#' (see the \code{left_truncation} argument of \code{\link{emfrail_dist}}).
#'
#' The \code{+cluster()} statement specifies the column that determines the grouping (the observations that share the same frailty).
#' The \code{+strata()} statement specifies a column that determines different strata, for which different baseline hazards are calculated.
#' The \code{+terminal()} statement specifies a column that contains an indicator for dependent censoring; a score test for
#' dependence between the recurrent event and the terminal event is then performed (see \code{cens_test} below).
#'
#' The \code{distribution} argument must be generated by a call to \code{\link{emfrail_dist}}. This determines the
#' frailty distribution, which may be one of gamma, positive stable or PVF (power-variance-function), and the starting
#' value for the maximum likelihood estimation. The PVF family
#' also includes a tuning parameter that differentiates between inverse Gaussian and compound Poisson distributions.
#' Note that, with univariate data (at most one event per individual, no clusters), only distributions with finite expectation
#' are identifiable. This means that the positive stable distribution should have a maximum likelihood on the edge of the parameter
#' space (\eqn{\theta = +\infty}, corresponding to a Cox model for independent observations).
#'
#' The \code{control} argument must be generated by a call to \code{\link{emfrail_control}}. Several parameters
#' may be adjusted that control the precision of the convergence criteria or suppress the calculation of different
#' quantities.
#'
#' @return An object of class \code{emfrail} that contains the following fields:
#' \item{coefficients}{A named vector of the estimated regression coefficients}
#' \item{hazard}{The Breslow estimate of the baseline hazard at each event time point, in chronological order}
#' \item{var}{The variance-covariance matrix corresponding to the coefficients and hazard, assuming \eqn{\theta} constant}
#' \item{var_adj}{The variance-covariance matrix corresponding to the
#' coefficients and hazard, adjusted for the estimation of \eqn{\theta}}
#' \item{logtheta}{The logarithm of the point estimate of \eqn{\theta}. For the gamma and
#' PVF families of distributions, \eqn{\theta} is the inverse of the estimated frailty variance.}
#' \item{var_logtheta}{The variance of the estimated logarithm of \eqn{\theta}}
#' \item{ci_logtheta}{The likelihood-based 95\% confidence interval for the logarithm of \eqn{\theta}}
#' \item{frail}{The posterior (empirical Bayes) estimates of the frailty for each cluster}
#' \item{residuals}{A list with two elements: \code{cluster}, a vector that contains the sum of the
#' cumulative hazards from each cluster for a frailty value of 1, and
#' \code{individual}, a vector that contains the cumulative hazard corresponding to each row of the data,
#' multiplied by the corresponding frailty estimate}
#' \item{tev}{The time points of the events in the data set, this is the same length as hazard}
#' \item{nevents_id}{The number of events for each cluster}
#' \item{loglik}{A vector of length two with the log-likelihood of the starting Cox model
#' and the maximized log-likelihood}
#' \item{ca_test}{The results of the Commenges-Andersen test for heterogeneity}
#' \item{cens_test}{The results of the test for dependence between a recurrent event and a terminal event,
#' if the \code{+terminal()} statement is specified and the frailty distribution is gamma}
#' \item{zph}{The result of \code{cox.zph} called on a model with the estimated log-frailties as offset}
#' \item{formula, distribution, control}{The original arguments}
#' \item{nobs, fitted}{Number of observations and fitted values (i.e. \eqn{z \exp(\beta^T x)})}
#' \item{mf}{The \code{model.frame}, if \code{model = TRUE}}
#' \item{mm}{The \code{model.matrix}, if \code{model.matrix = TRUE}}
#'
#' @md
#' @note Several options in the \code{control} argument shorten the running time for \code{emfrail} significantly.
#' These are disabling the adjustment of the standard errors (\code{se_adj = FALSE}), disabling the likelihood-based confidence intervals (\code{lik_ci = FALSE}) or
#' disabling the score test for heterogeneity (\code{ca_test = FALSE}).
#'
#' The algorithm is detailed in the package vignette. For the gamma frailty,
#' the results should be identical with those from \code{coxph} with \code{ties = "breslow"}.
#'
#' @seealso \code{\link{plot.emfrail}} and \code{\link{autoplot.emfrail}} for plot functions directly available, \code{\link{emfrail_pll}} for calculating \eqn{\widehat{L}(\theta)} at specific values of \eqn{\theta},
#' \code{\link{summary.emfrail}} for transforming the \code{emfrail} object into a more human-readable format and for
#' visualizing the frailty (empirical Bayes) estimates,
#' \code{\link{predict.emfrail}} for calculating and visualizing conditional and marginal survival and cumulative
#' hazard curves, \code{\link{residuals.emfrail}} for extracting martingale residuals and \code{\link{logLik.emfrail}} for extracting
#' the log-likelihood of the fitted model.
#'
#' @examples
#'
#' m_gamma <- emfrail(formula = Surv(time, status) ~  rx + sex + cluster(litter),
#'                    data =  rats)
#'
#' # Inverse Gaussian distribution
#' m_ig <- emfrail(formula = Surv(time, status) ~  rx + sex + cluster(litter),
#'                 data =  rats,
#'                 distribution = emfrail_dist(dist = "pvf"))
#'
#' # for the PVF distribution with m = 0.75
#' m_pvf <- emfrail(formula = Surv(time, status) ~  rx + sex + cluster(litter),
#'                  data =  rats,
#'                  distribution = emfrail_dist(dist = "pvf", pvfm = 0.75))
#'
#' # for the positive stable distribution
#' m_ps <- emfrail(formula = Surv(time, status) ~  rx + sex + cluster(litter),
#'                 data =  rats,
#'                 distribution = emfrail_dist(dist = "stable"))
#' \dontrun{
#' # Compare marginal log-likelihoods
#' models <- list(m_gamma, m_ig, m_pvf, m_ps)
#'
#' models
#' logliks <- lapply(models, logLik)
#'
#' names(logliks) <- lapply(models,
#'                          function(x) with(x$distribution,
#'                                           ifelse(dist == "pvf",
#'                                                  paste(dist, "/", pvfm),
#'                                                  dist))
#' )
#'
#' logliks
#' }
#'
#' # Stratified analysis
#' \dontrun{
#'   m_strat <- emfrail(formula = Surv(time, status) ~  rx + strata(sex) + cluster(litter),
#'                      data =  rats)
#' }
#'
#'
#' # Test for conditional proportional hazards (log-frailty as offset)
#' \dontrun{
#' m_gamma <- emfrail(formula = Surv(time, status) ~  rx + sex + cluster(litter),
#'   data =  rats, control = emfrail_control(zph = TRUE))
#' par(mfrow = c(1,2))
#' plot(m_gamma$zph)
#' }
#'
#' # Draw the profile log-likelihood
#' \dontrun{
#'   fr_var <- seq(from = 0.01, to = 1.4, length.out = 20)
#'
#'   # For gamma the variance is 1/theta (see parametrizations)
#'   pll_gamma <- emfrail_pll(formula = Surv(time, status) ~  rx + sex + cluster(litter),
#'                            data =  rats,
#'                            values = 1/fr_var )
#'   plot(fr_var, pll_gamma,
#'        type = "l",
#'        xlab = "Frailty variance",
#'        ylab = "Profile log-likelihood")
#'
#'
#'   # Recurrent events
#'   mod_rec <- emfrail(Surv(start, stop, status) ~ treatment + cluster(id), bladder1)
#'   # The warnings appear from the Surv object, they also appear in coxph.
#'
#'   plot(mod_rec, type = "hist")
#' }
#'
#' # Left truncation
#' \dontrun{
#'   # We simulate some data with truncation times
#'   set.seed(2018)
#'   nclus <- 300
#'   nind <- 5
#'   x <- sample(c(0,1), nind * nclus, TRUE)
#'   u <- rep(rgamma(nclus,1,1), each = nind)
#'
#'   stime <- rexp(nind * nclus, rate = u * exp(0.5 * x))
#'
#'   status <- ifelse(stime > 5, 0, 1)
#'   stime[status == 0] <- 5
#'
#'   # truncate uniform between 0 and 2
#'   ltime <- runif(nind * nclus, min = 0, max = 2)
#'
#'   d <- data.frame(id = rep(1:nclus, each = nind),
#'                   x = x,
#'                   stime = stime,
#'                   u = u,
#'                   ltime = ltime,
#'                   status = status)
#'   d_left <- d[d$stime > d$ltime,]
#'
#'   mod <- emfrail(Surv(stime, status)~ x + cluster(id), d)
#'   # This model ignores the left truncation, 0.378 frailty variance:
#'   mod_1 <- emfrail(Surv(stime, status)~ x + cluster(id), d_left)
#'
#'   # This model takes left truncation into account,
#'   # but it considers the distribution of the frailty unconditional on the truncation
#'   mod_2 <- emfrail(Surv(ltime, stime, status)~ x + cluster(id), d_left)
#'
#'   # This is identical with:
#'   mod_cox <- coxph(Surv(ltime, stime, status)~ x + frailty(id), data = d_left)
#'
#'
#'   # The correct thing is to consider the distribution of the frailty given the truncation
#'   mod_3 <- emfrail(Surv(ltime, stime, status)~ x + cluster(id), d_left,
#'                    distribution = emfrail_dist(left_truncation = TRUE))
#'
#'   summary(mod_1)
#'   summary(mod_2)
#'   summary(mod_3)
#' }
#' @author Theodor Balan \email{hello@@tbalan.com}
#' @references Balan TA, Putter H (2019) "frailtyEM: An R Package for Estimating Semiparametric Shared Frailty Models", \emph{Journal of Statistical Software} \strong{90}(7) 1-29. doi:10.18637/jss.v090.i07

emfrail <- function(formula,
                    data,
                    distribution = emfrail_dist(),
                    control = emfrail_control(),
                    model = FALSE, model.matrix = FALSE,
                    ...) {


  # browser()
  # This part is because the update breaks old code
  extraargs <- list(...)

  if(!inherits(formula, "formula")) {
    if(inherits(formula, "data.frame")) warning("You gave a data.frame instead of a formula.
                                                Argument order has changed; now it's emfrail(formula, data, etc..).")
      stop("formula is not an object of type formula")
  }

  if(!inherits(data, "data.frame")) {
    if(inherits(data, "formula")) warning("You gave a formula instead of a data.frame.
                                            Argument order has changed; now it's emfrail(formula, data, etc..).")
    stop("data is not an object of type data.frame")
  }

  if(!inherits(distribution, "emfrail_dist"))
    stop("distribution argument misspecified; see ?emfrail_dist()")

  if(!inherits(control, "emfrail_control"))
    stop("control argument misspecified; see ?emfrail_control()")

  if(isTRUE(control$em_control$fast_fit)) {
    if(!(distribution$dist %in% c("gamma", "pvf"))) {
      #message("fast_fit option only available for gamma and pvf with m=-1/2 distributions")
      control$em_control$fast_fit <- FALSE
    }

    # version 0.5.6, the IG fast fit gets super sensitive at small frailty variance...
    if(distribution$dist == "pvf")
      control$em_control$fast_fit <- FALSE

  }


  Call <- match.call()


  if(missing(formula) | missing(data)) stop("Missing arguments")

  cluster <- function(x) x
  terminal <- function(x) x
  strata <- function(x) x

  mf <- model.frame(formula, data)


  # Identify the cluster and the ID column
  pos_cluster <- grep("cluster", names(mf))
  if(length(pos_cluster) != 1) stop("misspecified or non-specified cluster")
  id <- mf[[pos_cluster]]

  pos_terminal <- grep("terminal", names(mf))
  if(length(pos_terminal) > 1) stop("misspecified terminal()")

  pos_strata <- grep("strata", names(mf))
  if(length(pos_strata) > 0) {
    if(length(pos_strata) > 1) stop("only one strata() variable allowed")
    strats <- as.numeric(mf[[pos_strata]])
    label_strats <- levels(mf[[pos_strata]])
  } else {
    # else, everyone is in the same strata
    strats <- NULL
    label_strats <- "1"
  }


  Y <- mf[[1]]
  if(!inherits(Y, "Surv")) stop("left hand side not a survival object")
  if(ncol(Y) != 3) {
    # making it all in (tstart, tstop) format
    Y <- Surv(rep(0, nrow(Y)), Y[,1], Y[,2])
  }


  X1 <- model.matrix(formula, data)

  pos_cluster_X1 <- grep("cluster", colnames(X1))
  pos_terminal_X1 <- grep("terminal", colnames(X1))
  pos_strata_X1 <- grep("strata", colnames(X1))

  X <- X1[,-c(1, pos_cluster_X1, pos_terminal_X1, pos_strata_X1), drop=FALSE]
  # note: X has no attributes, in coxph it does.

  # mcox also works with empty matrices, but also with NULL as x.

  mcox <- survival::agreg.fit(x = X, y = Y, strata = strats, offset = NULL, init = NULL,
                              control = survival::coxph.control(),
                              weights = NULL, method = "breslow", rownames = NULL)
  # order(strat, -Y[,2])
  # the "baseline" case // this will stay constant

  if(length(X) == 0) {
    newrisk <- 1
    exp_g_x <- matrix(rep(1, length(mcox$linear.predictors)), nrow = 1)
    g <- 0
    g_x <- t(matrix(rep(0, length(mcox$linear.predictors)), nrow = 1))
  } else {
    x2 <- matrix(rep(0, ncol(X)), nrow = 1, dimnames = list(123, dimnames(X)[[2]]))
    x2 <- scale(x2, center = mcox$means, scale = FALSE)
    newrisk <- exp(c(x2 %*% mcox$coefficients) + 0)
    exp_g_x <- exp(mcox$coefficients %*% t(X))
    g <- mcox$coefficients
    g_x <- t(mcox$coefficients %*% t(X))

  }

  explp <- exp(mcox$linear.predictors) # these are with centered covariates

  # now thing is that maybe this is not very necessary,
  # but it keeps track of which row belongs to which cluster
  # and then we don't have to keep on doing this
  order_id <- match(id, unique(id))

  nev_id <- as.numeric(rowsum(Y[,3], order_id, reorder = FALSE)) # nevent per cluster
  names(nev_id) <- unique(id)

  # nrisk has the sum with every tstop and the sum of elp at risk at that tstop
  # esum has the sum of elp who enter at every tstart
  # indx groups which esum is right after each nrisk;
  # the difference between the two is the sum of elp really at risk at that time point.

  if(!is.null(strats)) {

    explp_str <- split(explp, strats)
    tstop_str <- split(Y[,2], strats)
    tstart_str <- split(Y[,1], strats)

    ord_tstop_str <- lapply(tstop_str, function(x) match(x, sort(unique(x))))
    ord_tstart_str <- lapply(tstart_str, function(x) match(x, sort(unique(x))))

    nrisk <- mapply(FUN = function(explp, y) rowsum_vec(explp, y, max(y)),
                                       explp_str,
                                       ord_tstop_str,
                                       SIMPLIFY = FALSE)

    # nrisk <- mapply(FUN = function(explp, y) rev(cumsum(rev(rowsum(explp, y[,2])))),
    #                 split(explp, strats),
    #                 split.data.frame(Y, strats),
    #                 SIMPLIFY = FALSE)


    esum <- mapply(FUN = function(explp, y) rowsum_vec(explp, y, max(y)),
                    explp_str,
                    ord_tstart_str,
                    SIMPLIFY = FALSE)

    # esum <-  mapply(FUN = function(explp, y) rev(cumsum(rev(rowsum(explp, y[,1])))),
    #                 split(explp, strats),
    #                 split.data.frame(Y, strats),
    #                 SIMPLIFY = FALSE)

    death <- lapply(
      X = split.default(Y[,3], strats),
      FUN = function(y) (y == 1)
    )

    nevent <- mapply(
      FUN = function(y, d)
        as.vector(rowsum(1 * d, y)),
      tstop_str,
      death,
      SIMPLIFY = FALSE
    )

    time_str <- lapply(
      X = tstop_str,
      FUN = function(y) sort(unique(y))
    )

    delta <- min(diff(sort(unique(Y[,2]))))/2

    time <- sort(unique(Y[,2])) # unique tstops

    etime <- lapply(
      X = tstart_str,
      FUN = function(y) c(0, sort(unique(y)),  max(y) + delta)
    )

    indx <-
      mapply(FUN = function(time, etime) findInterval(time, etime, left.open = TRUE),
             time_str,
             etime,
             SIMPLIFY = FALSE
      )

    indx2 <-
      mapply(FUN = function(y, time) findInterval(y, time),
             tstart_str,
             time_str,
             SIMPLIFY = FALSE
      )

    time_to_stop <-
      mapply(FUN = function(y, time) match(y, time),
             tstop_str,
             time_str,
             SIMPLIFY = FALSE
      )

    positions_strata <- do.call(c,split(1:nrow(Y), strats))

    atrisk <- list(death = death, nevent = nevent, nev_id = nev_id,
                   order_id = order_id, time = time, indx = indx, indx2 = indx2,
                   time_to_stop = time_to_stop,
                   ord_tstart_str = ord_tstart_str,
                   ord_tstop_str = ord_tstop_str,
                   positions_strata = positions_strata,
                   strats = strats)

    nrisk <- mapply(FUN = function(nrisk, esum, indx)  nrisk - c(esum, 0,0)[indx],
                    nrisk,
                    esum,
                    indx,
                    SIMPLIFY = FALSE)

    if(newrisk == 0) warning("Hazard ratio very extreme; please check (and/or rescale) your data")

    haz <- mapply(FUN = function(nevent, nrisk) nevent/nrisk * newrisk,
                  nevent,
                  nrisk,
                  SIMPLIFY = FALSE)

    basehaz_line <- mapply(FUN = function(haz, time_to_stop) haz[time_to_stop],
                           haz,
                           time_to_stop,
                           SIMPLIFY = FALSE)

    cumhaz <- lapply(haz, cumsum)

    cumhaz_0_line <- mapply(FUN = function(cumhaz, time_to_stop) cumhaz[time_to_stop],
                            cumhaz,
                            time_to_stop,
                            SIMPLIFY = FALSE)

    cumhaz_tstart <- mapply(FUN = function(cumhaz, indx2) c(0, cumhaz)[indx2 + 1],
                            cumhaz,
                            indx2,
                            SIMPLIFY = FALSE)

    cumhaz_line <- mapply(FUN = function(cumhaz_0_line, cumhaz_tstart, explp)
      (cumhaz_0_line - cumhaz_tstart) * explp / newrisk,
      cumhaz_0_line,
      cumhaz_tstart,
      split(explp, strats),
      SIMPLIFY = FALSE)

    cumhaz_line <- do.call(c, cumhaz_line)[order(positions_strata)]


  } else {

    ord_tstop <- match(Y[,2], sort(unique(Y[,2])))
    ord_tstart <- match(Y[,1], sort(unique(Y[,1])))

    nrisk <- rowsum_vec(explp, ord_tstop, max(ord_tstop))
    # nrisk <- rev(cumsum(rev(rowsum(explp, Y[, ncol(Y) - 1]))))
    esum <- rowsum_vec(explp, ord_tstart, max(ord_tstart))
    # esum <- rev(cumsum(rev(rowsum(explp, Y[, 1]))))

    death <- (Y[, 3] == 1)

    nevent <- as.vector(rowsum(1 * death, Y[, ncol(Y) - 1])) # per time point

    time <- sort(unique(Y[,2])) # unique tstops

    etime <- c(0, sort(unique(Y[, 1])),  max(Y[, 1]) + min(diff(time))/2)
    indx <- findInterval(time, etime, left.open = TRUE) # left.open  = TRUE is very important

    # this gives for every tstart (line variable), after which event time did it come
    indx2 <- findInterval(Y[,1], time)

    time_to_stop <- match(Y[,2], time)

    atrisk <- list(death = death, nevent = nevent, nev_id = nev_id,
                   order_id = order_id,
                   time = time, indx = indx, indx2 = indx2,
                   time_to_stop = time_to_stop,
                   ord_tstart = ord_tstart, ord_tstop = ord_tstop,
                   strats = NULL)
    nrisk <- nrisk - c(esum, 0,0)[indx]

    if(newrisk == 0) warning("Hazard ratio very extreme; please check (and/or rescale) your data")

    haz <- nevent/nrisk * newrisk
    basehaz_line <- haz[atrisk$time_to_stop]
    cumhaz <- cumsum(haz)

    cumhaz_0_line <- cumhaz[atrisk$time_to_stop]
    cumhaz_tstart <- c(0, cumhaz)[atrisk$indx2 + 1]
    cumhaz_line <- (cumhaz[atrisk$time_to_stop] - c(0, cumhaz)[atrisk$indx2 + 1]) * explp / newrisk
  }


  Cvec <- rowsum(cumhaz_line, order_id, reorder = FALSE)

  # browser()
  # this part under construction!
  # if(distribution$basehaz != "breslow") {
  #   if(any(Y[,1] != 0))
  #     stop("(tstart, tstop) not supported for parametric models")
  #
  #   dlist <- survival::survreg.distributions[[distribution$basehaz]]
  #
  #   # this stuff is more or less copied from survreg
  #
  #   logcorrect <- 0  #correction to the loglik due to transformations
  #   Ysave <- Y  # for use in the y component
  #   if (!is.null(dlist$trans)) {
  #     tranfun <- dlist$trans
  #     exactsurv <- Y[, ncol(Y)] == 1
  #     if (any(exactsurv)) {
  #       logcorrect <- sum(log(dlist$dtrans(Y[exactsurv, 1])))
  #     }
  #     if (!all(is.finite(Y)))
  #       stop("Invalid survival times for this distribution")
  #
  #   }
  #
  #   if (!is.null(dlist$scale)) {
  #     if (!missing(scale))
  #       warning(paste(dlist$name, "has a fixed scale, user specified value ignored"))
  #     scale <- dlist$scale
  #   } else scale <- 0
  #
  #   if (!is.null(dlist$dist))
  #     if (is.atomic(dlist$dist))
  #       dlist <- survreg.distributions[[dlist$dist]] else dlist <- dlist$dist
  #
  #   if (any(scale < 0))
  #     stop("Invalid scale value")
  #   # Now we convert this to a format that survreg.fit likes
  #
  #   fit <- survreg.fit(cbind(1, X), cbind(tranfun(Y[,2]), Y[,3]),
  #                      weights = NULL,
  #                      offset = NULL,
  #                      init = NULL,
  #                      controlvals = survreg.control() ,
  #                      dist = dlist, scale = scale, nstrat = 1, strata = 0, parms= NULL)
  #
  #
  # }
  #
  #
  #
  # srg1 <- survreg_simple(formula = Surv(time, status) ~ rx + sex, data = rats, dist = "weibull")
  #
  # srg1$coefficients
  # fit$coefficients
  #
  # fit <- survreg.fit(X, Y, weights = NULL, offset = NULL, init = NULL,
  #                    controlvals = survreg.control(), dist = dlist, scale = 0)
  #
  # distribution

  # dlist <- survival::survreg.distributions[[distribution$]]

  # ca_test <- NULL

  # ca_test_fit does not know strata ?!?
  if(isTRUE(control$ca_test)) {

    if(!is.null(strats)) ca_test <- NULL else
      ca_test <- ca_test_fit(mcox, X, atrisk, exp_g_x, cumhaz)

  }

  if(isTRUE(distribution$left_truncation)) {

    if(!is.null(strats))
      cumhaz_tstart <- do.call(c, cumhaz_tstart)[order(atrisk$positions_strata)]

    Cvec_lt <- rowsum(cumhaz_tstart, atrisk$order_id, reorder = FALSE)

  } else Cvec_lt <- 0 * Cvec


  # a fit just for the log-likelihood;
  if(!isTRUE(control$opt_fit)) {
    return(
      em_fit(logfrailtypar = log(distribution$theta),
           dist = distribution$dist, pvfm = distribution$pvfm,
           Y = Y, Xmat = X, atrisk = atrisk,
           basehaz_line = basehaz_line,
           mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
           Cvec = Cvec, lt = distribution$left_truncation,
           Cvec_lt = Cvec_lt, se = FALSE,
           em_control = control$em_control)
      )
  }

  # browser()

  if(distribution$dist == "stable") {
    # thing is: with stable small values of theta mean high dependence
    # I have yet to see a very high dependence there; furthermore,
    # the likelihood is pretty flat there.
    # therefore I would rather drag this towards "no dependence".
    distribution$theta <- distribution$theta + 1
  }

  outer_m <- do.call(nlm, args = c(list(f = em_fit,
                      p = log(distribution$theta),
                      hessian = TRUE,
                      dist = distribution$dist,
                      pvfm = distribution$pvfm,
                      Y = Y, Xmat = X,
                      atrisk = atrisk,
                      basehaz_line = basehaz_line,
                      mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
                      Cvec = Cvec,
                      lt = distribution$left_truncation,
                      Cvec_lt = Cvec_lt, se = FALSE,
                      em_control = control$em_control), control$nlm_control))
  # control$lik_interval_stable

  if(outer_m$hessian < 1) {
    outer_m_opt <- do.call(optimize,
                           args = c(list(f = em_fit,
                                         dist = distribution$dist,
                                         pvfm = distribution$pvfm,
                                         Y = Y, Xmat = X,
                                         atrisk = atrisk,
                                         basehaz_line = basehaz_line,
                                         mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
                                         Cvec = Cvec,
                                         lt = distribution$left_truncation,
                                         Cvec_lt = Cvec_lt, se = FALSE,
                                         em_control = control$em_control), lower = log(control$lik_interval)[1],
                                    upper = log(control$lik_interval)[2]))



    if(outer_m_opt$objective < outer_m$minimum) {
      hess <- numDeriv::hessian(func = em_fit, x = outer_m_opt$minimum,
                                dist = distribution$dist,
                                pvfm = distribution$pvfm,
                                Y = Y, Xmat = X,
                                atrisk = atrisk,
                                basehaz_line = basehaz_line,
                                mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
                                Cvec = Cvec,
                                lt = distribution$left_truncation,
                                Cvec_lt = Cvec_lt, se = FALSE,
                                em_control = control$em_control)

      outer_m <- list(minimum = outer_m_opt$objective,
                      estimate = outer_m_opt$minimum,
                      hessian = hess)
    }
  }

  if(outer_m$hessian == 0) warning("Hessian virtually 0; frailty variance might be at the edge of the parameter space.")
  if(outer_m$hessian <= 0) hessian <- NA else hessian <- outer_m$hessian

  # likelihood-based confidence intervals
  theta_low <- theta_high <- NULL
  if(isTRUE(control$lik_ci)) {

    # With the stable distribution, a problem pops up for small values, i.e. very large association (tau large)
    # So there I use another interval for this
    if(distribution$dist == "stable") {
      control$lik_interval <- control$lik_interval_stable
    }

   skip_ci <- FALSE

   lower_llik <- try(em_fit(log(control$lik_interval[1]),
                       dist = distribution$dist,
                       pvfm = distribution$pvfm,
                       Y = Y, Xmat = X, atrisk = atrisk, basehaz_line = basehaz_line,
                       mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
                       Cvec = Cvec, lt = distribution$left_truncation,
                       Cvec_lt = Cvec_lt, se = FALSE,
                       em_control = control$em_control), silent = TRUE)
  if(class(lower_llik) == "try-error") {
    warning("likelihood-based CI could not be calcuated; disable or change lik_interval[1] in emfrail_control")
    lower_llik <- NA
    log_theta_low <- log_theta_high <- NA
    skip_ci <- TRUE
  }

  upper_llik <- try(em_fit(log(control$lik_interval[2]),
         dist = distribution$dist,
         pvfm = distribution$pvfm,
         Y = Y, Xmat = X, atrisk = atrisk, basehaz_line = basehaz_line,
         mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
         Cvec = Cvec, lt = distribution$left_truncation,
         Cvec_lt = Cvec_lt, se = FALSE,
         em_control = control$em_control), silent = TRUE)

  if(class(upper_llik) == "try-error") {
    warning("likelihood-based CI could not be calcuated; disable or lik_interval[2] in emfrail_control")
    upper_llik <- NA
    log_theta_low <- log_theta_high <- NA
    skip_ci <- TRUE
  }


  if(!isTRUE(skip_ci)) {
  if(lower_llik - outer_m$minimum < 1.92) {
    log_theta_low <- log(control$lik_interval[1])
    warning("Likelihood-based confidence interval lower limit reached, probably 0;
You can try a lower value for control$lik_interval[1].")
  } else
  log_theta_low <- uniroot(function(x, ...) outer_m$minimum - em_fit(x, ...) + 1.92,
                       interval = c(log(control$lik_interval[1]), outer_m$estimate),
                       f.lower = outer_m$minimum - lower_llik + 1.92, f.upper = 1.92,
                       tol = .Machine$double.eps^0.1,
                       dist = distribution$dist,
                       pvfm = distribution$pvfm,
                       Y = Y, Xmat = X, atrisk = atrisk, basehaz_line = basehaz_line,
                       mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
                       Cvec = Cvec, lt = distribution$left_truncation,
                       Cvec_lt = Cvec_lt, se = FALSE,
                       em_control = control$em_control,
                       maxiter = 100)$root

  # this says that if I can't get a significant difference on the right side, then it's infinity
  if(upper_llik  - outer_m$minimum < 1.92) log_theta_high <- Inf else
    log_theta_high <- uniroot(function(x, ...) outer_m$minimum - em_fit(x, ...) + 1.92,
                          interval = c(outer_m$estimate, log(control$lik_interval[2])),
                          f.lower = 1.92, f.upper = outer_m$minimum - upper_llik + 1.92,
                          extendInt = c("downX"),
                          dist = distribution$dist,
                          pvfm = distribution$pvfm,
                          Y = Y, Xmat = X, atrisk = atrisk, basehaz_line = basehaz_line,
                          mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
                          Cvec = Cvec, lt = distribution$left_truncation,
                          Cvec_lt = Cvec_lt, se  = FALSE,
                          em_control = control$em_control)$root
  }
  } else
    log_theta_low <- log_theta_high <- NA


  if(isTRUE(control$se))  {
    inner_m <- em_fit(logfrailtypar = outer_m$estimate,
                      dist = distribution$dist, pvfm = distribution$pvfm,
                      Y = Y, Xmat = X, atrisk = atrisk, basehaz_line = basehaz_line,
                      mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
                      Cvec = Cvec, lt = distribution$left_truncation,
                      Cvec_lt = Cvec_lt, se = TRUE,
                      em_control = control$em_control,
                      return_loglik = FALSE)
  } else
    inner_m <- em_fit(logfrailtypar = outer_m$estimate,
                      dist = distribution$dist, pvfm = distribution$pvfm,
                      Y = Y, Xmat = X, atrisk = atrisk, basehaz_line = basehaz_line,
                      mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
                      Cvec = Cvec, lt = distribution$left_truncation,
                      Cvec_lt = Cvec_lt, se = FALSE,
                      em_control = control$em_control,
                      return_loglik = FALSE)


  # Cox.ZPH stuff
  if(isTRUE(control$zph)) {

    # Here just fit a Cox model with the log-frailty as offset
    # note: the cox.zph function in the new version of the survival package has a terms = TRUE argument.
    # This part here is added so that emfrail continues to work with both versions.
    # do.call is needed because otherwise R CMD check returns a warning when ran on a system with the
    # old version of cox.zph
    if("terms" %in% names(formals(cox.zph))) {
    if(!is.null(strats)) {
      zph <- do.call(cox.zph, args = list(fit = coxph(Y ~ X + strata(strats) + offset(inner_m$logz), ties = "breslow"),
                                   transform = control$zph_transform,
                                   terms = FALSE))
      # zph <- cox.zph(coxph(Y ~ X + strata(strats) + offset(inner_m$logz), ties = "breslow"),
      #                transform =  control$zph_transform, terms = FALSE)
    } else {
      zph <- do.call(cox.zph, args = list(fit = coxph(Y ~ X + offset(inner_m$logz), ties = "breslow"),
                                   transform = control$zph_transform,
                                   terms = FALSE))
      # zph <- cox.zph(coxph(Y ~ X + offset(inner_m$logz), ties = "breslow"),
      #                transform =  control$zph_transform, terms = FALSE)
    }
    } else {
      if(!is.null(strats)) {
        zph <- do.call(cox.zph, args = list(fit = coxph(Y ~ X + strata(strats) + offset(inner_m$logz), ties = "breslow"),
                                            transform = control$zph_transform))

      } else
        zph <- do.call(cox.zph, args = list(fit = coxph(Y ~ X + offset(inner_m$logz), ties = "breslow"),
                                            transform = control$zph_transform))
      }



      # fix the names for nice output
      # if there is only one covariate there is not "GLOBAL" test
      attr(zph$table, "dimnames")[[1]][1:length(inner_m$coef)] <- names(inner_m$coef)
      attr(zph$y, "dimnames")[[2]] <- names(mcox$coef)
  } else zph <- NULL


  # adjusted standard error

  if(isTRUE(control$se) & isTRUE(attr(inner_m$Vcov, "class") == "try-error")) {
    inner_m$Vcov <- matrix(NA, length(inner_m$coef) + length(inner_m$haz))
    warning("Information matrix is singular")
  }

  # adjusted SE: only go on if requested and if Vcov was calculated
  if(isTRUE(control$se) &
     isTRUE(control$se_adj) &
     !all(is.na(inner_m$Vcov))) {

    # absolute value should be redundant. but sometimes the "hessian" might be 0.
    # in that case it might appear negative; this happened only on Linux...
    # h <- as.numeric(sqrt(abs(1/(attr(outer_m, "details")[[3]])))/2)
    h<- as.numeric(sqrt(abs(1/hessian))/2)
    lfp_minus <- max(outer_m$estimate - h , outer_m$estimate - 5, na.rm = TRUE)
    lfp_plus <- min(outer_m$estimate + h , outer_m$estimate + 5, na.rm = TRUE)


    final_fit_minus <- em_fit(logfrailtypar = lfp_minus,
                              dist = distribution$dist, pvfm = distribution$pvfm,
                              Y = Y, Xmat = X, atrisk = atrisk, basehaz_line = basehaz_line,
                              mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
                              Cvec = Cvec, lt = distribution$left_truncation,
                              Cvec_lt = Cvec_lt, se = FALSE,
                              em_control = control$em_control,
                              return_loglik = FALSE)

    final_fit_plus <- em_fit(logfrailtypar = lfp_plus,
                             dist = distribution$dist, pvfm = distribution$pvfm,
                             Y = Y, Xmat = X, atrisk = atrisk, basehaz_line = basehaz_line,
                             mcox = list(coefficients = g, loglik = mcox$loglik),  # a "fake" cox model
                             Cvec = Cvec, lt = distribution$left_truncation,
                             Cvec_lt = Cvec_lt, se = FALSE,
                             em_control = control$em_control, return_loglik = FALSE)


    # instructional: this should be more or less equal to the
    # -(final_fit_plus$loglik + final_fit_minus$loglik - 2 * inner_m$loglik)/h^2

    # se_logtheta^2 / (2 * (final_fit$loglik -final_fit_plus$loglik ))

    if(!is.null(atrisk$strats))
      deta_dtheta <- (c(final_fit_plus$coef, do.call(c, final_fit_plus$haz)) -
                      c(final_fit_minus$coef, do.call(c, final_fit_minus$haz))) / (2*h) else
                        deta_dtheta <- (c(final_fit_plus$coef, final_fit_plus$haz) -
                                          c(final_fit_minus$coef, final_fit_minus$haz)) / (2*h)



    #adj_se <- sqrt(diag(deta_dtheta %*% (1/(attr(opt_object, "details")[[3]])) %*% t(deta_dtheta)))

    # vcov_adj = inner_m$Vcov + deta_dtheta %*% (1/(attr(outer_m, "details")[[3]])) %*% t(deta_dtheta)
   vcov_adj = inner_m$Vcov + deta_dtheta %*% (1/outer_m$hessian) %*% t(deta_dtheta)

  } else
    if(all(is.na(inner_m$Vcov)))
      vcov_adj <- inner_m$Vcov else
        vcov_adj = matrix(NA, nrow(inner_m$Vcov), nrow(inner_m$Vcov))


  if(length(pos_terminal_X1) > 0 & distribution$dist == "gamma") {
    Y[,3] <- X1[,pos_terminal_X1]

    Mres <- survival::agreg.fit(x = X, y = Y, strata = atrisk$strats, offset = NULL, init = NULL,
                        control = survival::coxph.control(),
                        weights = NULL, method = "breslow", rownames = NULL)$residuals
    Mres_id <- rowsum(Mres, atrisk$order_id, reorder = FALSE)

    theta <- exp(outer_m$estimate)

    fr <- with(inner_m, estep[,1] / estep[,2])

    numerator <- theta + inner_m$nev_id
    denominator <- numerator / fr

    lfr <- digamma(numerator) - log(denominator)

    lfr2 <- (digamma(numerator))^2 + trigamma(numerator) - (log(denominator))^2 - 2 * log(denominator) * lfr

    r <- cor(lfr, Mres_id)
    tr <- r* sqrt((length(fr) - 2) / (1 - r^2))
    p.cor <- pchisq(tr^2, df = 1, lower.tail = F)

    cens_test = c(tstat = tr, pval = p.cor)
  } else cens_test = NULL


  # Prepare some things for the output object

  if(!isTRUE(model)) model_frame <- NULL else
    model_frame <- mf
  if(!isTRUE(model.matrix)) X <- NULL


  frail <- inner_m$frail
  names(frail) <- unique(id)

  residuals <- list(group = as.numeric(inner_m$Cvec),
                    individual = as.numeric(inner_m$cumhaz_line * inner_m$fitted))

  names(residuals$group) <- unique(id)


  haz <- inner_m$haz
  tev <- inner_m$tev

  if(!is.null(atrisk$strats)) {

    names(haz) <- label_strats
    names(tev) <- label_strats

  }


  res <- list(coefficients = inner_m$coef, #
               hazard = haz,
               var = inner_m$Vcov,
               var_adj = vcov_adj,
               logtheta = outer_m$estimate,
               var_logtheta = 1/hessian,
               ci_logtheta = c(log_theta_low, log_theta_high),
               frail = frail,
               residuals = residuals,
               tev = tev,
               nevents_id = inner_m$nev_id,
               loglik = c(mcox$loglik[length(mcox$loglik)], -outer_m$minimum),
               ca_test = ca_test,
               cens_test = cens_test,
               zph = zph,
               formula = formula,
               distribution = distribution,
               control = control,
               nobs = nrow(mf),
               fitted = as.numeric(inner_m$fitted),
               mf = model_frame,
               mm = X)

  # these are things that make the predict work and other methods
  terms_2 <- delete.response(attr(mf, "terms"))
  pos_cluster_2 <- grep("cluster", attr(terms_2, "term.labels"))
  if(!is.null(mcox$coefficients)) {
    terms <- drop.terms(terms_2, pos_cluster_2)
    myxlev <- .getXlevels(terms, mf)
    attr(res, "metadata") <- list(terms, myxlev)
  }
  attr(res, "call") <-  Call
  attr(res, "class") <- "emfrail"


  res


}
teddybalan/frailtoys documentation built on Sept. 21, 2019, 6:22 p.m.