R/predict.R

Defines functions compute_Nsurv.SurvPredict compute_Ninit compute_Nsurv predict.SurvFit predict

Documented in compute_Ninit compute_Nsurv compute_Nsurv.SurvPredict predict predict.SurvFit

#' @name PredictSurvFit
#' 
#' @title Prediction based on \code{SurvFit} objects
#' 
#' @description
#' Generic \code{predict} S3 function and its method for the \code{SurvFit}
#' class. The method provides the predicted survival rate for "SD" or "IT"
#' models under a constant or a time-variable exposure profile.
#' 
#' Note: the \code{SurvFit} method first checks whether the exposure profile
#' is constant and then calls the dedicated prediction routine (constant or
#' time-variable exposure). For constant exposure profiles the result is
#' explicit (exact).
#'
#' @param fit an object of class \code{SurvFit}
#' @param display.exposure exposure profile (concentration points) on which
#' the prediction is done. If \code{NULL}, a \code{data.frame} with columns
#' \code{time}, \code{conc} and \code{replicate} is built from the data stored
#' in \code{fit}.
#' @param display.parameters parameters of the specific model.
#' @param hb_value a numeric used as `hb_value` (can be set to 0 to remove
#' background mortality and take only effect parameters).
#' @param interpolate_length if \code{display.time} is \code{NULL}, the argument
#' \code{interpolate_length} can be used to provide a sequence from 0 to the
#' maximum of the time of exposure in the original dataset (used for fitting).
#' @param interpolate_method The interpolation method for concentration,
#' only passed on for time-variable exposure profiles.
#' See package \code{deSolve} for details. Default is \code{linear}.
#' @param \dots Further arguments to be passed to generic methods
#' 
#' @return a \code{list} of \code{data.frame}: the quantiles of the outputs
#' and the full set of MCMC draws.
#' 
#' @export
predict <- function(fit, ...) UseMethod("predict")


#' @name PredictSurvFit
#' @export
predict.SurvFit <- function(fit,
                            display.exposure = NULL,
                            hb_value = NULL,
                            interpolate_length = NULL,
                            interpolate_method = "linear", ...){
    # Exposure profile: fall back on the data stored in the fit when the
    # caller does not supply one.
    exposure <- display.exposure
    if (is.null(exposure)) {
        exposure <- data.frame(
            time = fit$standata$time_X,
            conc = fit$standata$conc,
            replicate = fit$standata$replicate_X)
    }
    # Dispatch on the kind of profile: time-variable exposure additionally
    # needs the interpolation method, constant exposure does not.
    if (!is_exposure_constant(exposure)) {
        prediction <- predict_SurvFitVarExp(
            fit = fit,
            display.exposure = exposure,
            hb_value = hb_value,
            interpolate_length = interpolate_length,
            interpolate_method = interpolate_method
        )
    } else {
        prediction <- predict_SurvFitCstExp(
            fit = fit,
            display.exposure = exposure,
            hb_value = hb_value,
            interpolate_length = interpolate_length
        )
    }
    return(prediction)
}

#' @name ComputePredictSurvFit
#' 
#' @title Compute derived quantities from a prediction object
#' 
#' @description
#' `compute_Nsurv`: simulate the number of survivors `Nsurv` from the
#' predicted survival probabilities.
#' 
#' `compute_Ninit`: extract the `Nsurv` values recorded at `time == 0`.
#' 
#' @param x an object of class \code{SurvPredict}
#' @param Ninit initial number of individuals: either a single value or one
#' value per replicate. Default is NULL.
#' @param \dots Further arguments to be passed to generic methods
#' 
#' @return For `compute_Nsurv`, a \code{list} with `df_quantile` (columns
#' `conc`, `time`, `replicate`, `q50`, `qinf95`, `qsup95` computed on the
#' simulated number of survivors) and `df_mcmc` (the simulated number of
#' survivors for every MCMC draw). For `compute_Ninit`, the values of `Nsurv`
#' at `time == 0`.
#' 
#' @export
compute_Nsurv <- function(x, ...) UseMethod("compute_Nsurv")

#' @name ComputePredictSurvFit
#' @export
compute_Ninit <- function(x, ...){
    # Keep the survivor counts observed at the initial time point (time 0).
    at_start <- x[["time"]] == 0
    n_surv <- x[["Nsurv"]]
    n_surv[at_start]
}

#' @name ComputePredictSurvFit
#' @export
compute_Nsurv.SurvPredict <- function(x, Ninit = NULL, ...){
    # Simulate the number of survivors for every MCMC draw (columns of
    # `df_mcmc`) and every row of the prediction, replicate by replicate,
    # then summarise the draws by their 2.5%, 50% and 97.5% quantiles.
    #
    # x     : list with `df_quantile` (columns `conc`, `time`, `replicate`)
    #         and `df_mcmc` (one row per row of `df_quantile`, one column per
    #         MCMC draw, holding probabilities passed to `rbinom`).
    # Ninit : initial number of individuals; one value per replicate when its
    #         length equals the number of replicates, otherwise the whole
    #         vector is used as-is for every replicate.
    # Returns a list with `df_quantile` and `df_mcmc`, classed "SurvPredict".
    if (is.null(Ninit)) {
        stop("`Ninit` must be provided: a single value or one value per replicate.",
             call. = FALSE)
    }
    df_quantile <- x[["df_quantile"]]
    replicate <- df_quantile[["replicate"]]
    df_mcmc <- x[["df_mcmc"]]
    if (length(replicate) != nrow(df_mcmc)) {
        stop("`df_quantile` and `df_mcmc` must have the same number of rows.",
             call. = FALSE)
    }

    replicate_ids <- unique(replicate)
    Ninit_per_replicate <- length(Ninit) == length(replicate_ids)
    n_draws <- ncol(df_mcmc)

    # Results are written back at the original row positions so that they
    # stay aligned with `df_quantile` even if replicates are not contiguous.
    Nsurv_mcmc <- matrix(NA_integer_, nrow = nrow(df_mcmc), ncol = n_draws)
    for (k in seq_along(replicate_ids)) {
        rows <- which(replicate == replicate_ids[k])
        # drop = FALSE keeps a 2-d object even with a single MCMC draw.
        psurv_mcmc <- df_mcmc[rows, , drop = FALSE]
        size <- if (Ninit_per_replicate) Ninit[k] else Ninit
        # Chain of binomial draws: the survivors of one row are the trials
        # of the next one. seq_along() also handles single-row replicates.
        # NOTE(review): each step uses `psurv_mcmc[j, ]` directly as the
        # success probability; if `df_mcmc` holds cumulative survival
        # probabilities, the conditional probability
        # psurv[j] / psurv[j - 1] may be intended -- confirm.
        for (j in seq_along(rows)) {
            Nsurv_mcmc[rows[j], ] <- stats::rbinom(
                n_draws,
                size = size,
                prob = as.numeric(psurv_mcmc[j, ])
            )
            size <- Nsurv_mcmc[rows[j], ]
        }
    }

    Nsurv_quant <- apply(Nsurv_mcmc, 1, stats::quantile,
                         probs = c(0.025, 0.5, 0.975), na.rm = TRUE)
    df_quant <- data.frame(
        conc = df_quantile[["conc"]],
        time = df_quantile[["time"]],
        replicate = replicate,
        q50 = Nsurv_quant[2, ],
        qinf95 = Nsurv_quant[1, ],
        qsup95 = Nsurv_quant[3, ]
    )

    return_object <- list(
        df_quantile = df_quant,
        df_mcmc = as.data.frame(Nsurv_mcmc)
    )
    # Set the class before returning: the former `return()` call made this
    # assignment unreachable, so the result was a bare list.
    class(return_object) <- c("SurvPredict", class(return_object))
    return_object
}

Try the morseTKTD package in your browser

Any scripts or data that you put into this service are public.

morseTKTD documentation built on June 8, 2025, 10:28 a.m.