R/predict.R

Defines functions get_ipred_data integrate predict.regsurv

Documented in predict.regsurv

#' Predictions for a regsurv model
#'
#' @param object regsurv object
#' @param prep survprep object used to fit the regsurv model
#' @param lambda.index lambda index for the required predictions (e.g. optimal value based on cross-validation)
#' @param newdata if provided, this should be a matrix with untransformed time-to-event in the first column and the
#'   model.matrix for the main predictor effects as would be provided to survprep()
#' @param type type of prediction ("cumhazard" or "surv"); defaults to "cumhazard"
#' @param ... as for predict()
#'
#' @return prediction of the requested type: a numeric vector for "loghazard" models (obtained by
#'   Gauss-Legendre quadrature of the hazard) or a one-column matrix for "logHazard" models
#' @export
#' @method predict regsurv
predict.regsurv <- function(object, prep, lambda.index, newdata=NULL, type=c("cumhazard", "surv"), ...){

  if(!inherits(object, "regsurv")){
    stop("predict.regsurv only takes objects of class regsurv as a first argument")
  }

  if(!inherits(prep, "survprep")){
    stop("predict.regsurv only takes objects of class survprep as a second argument")
  }

  if(prep$survprep.id != object$survprep.id){
    stop("regsurv object and survprep object do not match")
  }

  # Resolve the default c("cumhazard", "surv") to a single value; without this the
  # scalar `type ==` comparisons below fail when `type` is not supplied.
  type <- match.arg(type)

  spline.type <- prep$spline.type
  time.scale <- prep$time.scale
  knots <- prep$knots
  iknots <- prep$iknots
  time.type <- prep$time.type
  itime.type <- prep$itime.type
  tv <- prep$tv
  scales <- prep$scales
  shifts <- prep$shifts
  betahat <- as.matrix(object$betahat.scaled)[ ,lambda.index]

  if(is.null(newdata)){
    X <- prep$mm$d[ ,prep$which.param[[2]], drop=FALSE]
    tte <- prep$tte
  } else {
    X <- newdata[ ,-1, drop=FALSE]
    tte <- as.vector(newdata[ ,1])
    if(time.scale == "logtime"){
      tte <- log(tte)
    }
  }

  # Apply the same shift/scale used at fit time to every basis column except the first.
  scale_basis <- function(d){
    d[ ,-1] <- t((t(d[ ,-1]) - shifts) / scales)
    d
  }

  if(prep$model.scale == "loghazard"){
    qpoints <- prep$qpoints
    rule <- legendre.quadrature.rules(qpoints)[[qpoints]]
    # Anything other than "time" is handled as log-time, as in the fitting code.
    on_logtime <- time.scale != "time"

    glsbi <- lapply(seq_len(nrow(X)), function(i){
      lower <- 0 # time zero; might alter in case of left-truncation
      # tte is stored on the log scale for logtime models; integrate on the original time scale
      upper <- if(on_logtime) exp(tte[i]) else tte[i]
      lambda <- (upper - lower)/2 # half-width of the integration interval
      mu <- (upper + lower)/2 # midpoint of the integration interval
      y <- lambda * rule$x + mu # quadrature nodes mapped from [-1, 1] to [lower, upper]
      # drop=FALSE keeps an (nodes x p) matrix even for a single node or a single column
      Xi <- as.matrix(X[rep(i, length(y)), , drop=FALSE])
      colnames(Xi) <- colnames(X)

      z <- sbi(t=if(on_logtime) log(y) else y, X=Xi, time.type=time.type, itime.type=itime.type, tv=tv,
               knots=knots, iknots=iknots, spline.type=spline.type)$d
      list(w=rule$w, lambda=lambda, z=scale_basis(z))
    })
    cumhaz <- vapply(glsbi, integrate, numeric(1), param=betahat)
  } else if(prep$model.scale == "logHazard"){
    mm <- sbi(t=tte, X=X, time.type=time.type, itime.type=itime.type, tv=tv,
              knots=knots, iknots=iknots, spline.type=spline.type)
    cumhaz <- exp(scale_basis(mm$d) %*% betahat + tte)
  } else {
    stop("unsupported model.scale in survprep object: ", prep$model.scale)
  }

  if(type == "cumhazard") cumhaz else exp(-cumhaz)
}


# Quadrature approximation of the cumulative hazard for one subject.
# `x` is a list with quadrature weights `w`, interval half-width `lambda` and the
# (scaled) design matrix `z` evaluated at the quadrature nodes; `param` holds the
# coefficients. Returns lambda * sum_k w_k * exp(z_k %*% param).
integrate <- function(x, param){
  log_hazard <- x$z %*% param
  weighted_hazard <- x$w * exp(log_hazard)
  x$lambda * sum(weighted_hazard)
}


# Build a long-format prediction grid: every row of `Xnew` is repeated once per
# value in `times`. Returns a matrix with columns "id" (row index of Xnew),
# "time" and the columns of Xnew.
get_ipred_data <- function(times, Xnew){
  n <- nrow(Xnew)
  n_times <- length(times)
  # drop=FALSE keeps a matrix (and its column names) when Xnew has a single column
  Xrep <- Xnew[rep(seq_len(n), each=n_times), , drop=FALSE]
  cbind("id"=rep(seq_len(n), each=n_times),
        "time"=rep(times, times=n),
        as.matrix(Xrep))
}
jeroenhoogland/regsurv documentation built on March 20, 2023, 3:37 a.m.