R/survspat.R

Defines functions survspat

Documented in survspat

##' survspat function
##'
##' A function to run a Bayesian analysis on censored spatial survival data assuming a proportional hazards model using an adaptive Metropolis-adjusted
##' Langevin algorithm.
##'
##' Initial values are obtained by maximum likelihood on a non-spatial version of the model; these are then used to calibrate the
##' proposal variance of the sampler. The step size of the sampler is adapted at every iteration.
##'
##' @param formula the model formula in a format compatible with the function flexsurvreg from the flexsurv package
##' @param data a SpatialPointsDataFrame object containing the survival data as one of the columns OR for polygonal data a data.frame, in which case, the argument shape must also be supplied
##' @param dist choice of distribution function for baseline hazard. Current options are: exponentialHaz, weibullHaz, gompertzHaz, makehamHaz, tpowHaz
##' @param cov.model an object of class covmodel, see ?covmodel ?ExponentialCovFct or ?SpikedExponentialCovFct
##' @param mcmc.control mcmc control parameters, see ?mcmcpars
##' @param priors an object of class Priors, see ?mcmcPriors
##' @param shape when data is a data.frame, this can be a SpatialPolygonsDataFrame, or a SpatialPointsDataFrame, used to model spatial variation at the small region level. The regions are the polygons, or they represent the (possibly weighted) centroids of the polygons.
##' @param ids named list entry shpid character string giving name of variable in shape to be matched to variable dataid in data. dataid is the second entry of the named list.
##' @param control additional control parameters, see ?inference.control
##' @param boundingbox optional bounding box over which to construct computational grid, supplied as an object on which the function 'bbox' returns the bounding box. It is transformed to the projection of 'data' or, failing that, of 'shape'; if neither transformation succeeds it is used untransformed and a warning is issued.
##' @return an object inheriting class 'mcmcspatsurv' for which there exist methods for printing, summarising and making inference from.
##' @seealso \link{tpowHaz}, \link{exponentialHaz}, \link{gompertzHaz}, \link{makehamHaz}, \link{weibullHaz},
##' \link{covmodel}, \link{ExponentialCovFct}, \code{SpikedExponentialCovFct},
##' \link{mcmcpars}, \link{mcmcPriors}, \link{inference.control}
##' @references
##' \enumerate{
##'     \item Benjamin M. Taylor and Barry S. Rowlingson (2017). spatsurv: An R Package for Bayesian Inference with Spatial Survival Models. Journal of Statistical Software, 77(4), 1-32, doi:10.18637/jss.v077.i04.
##' }
##' @export


survspat <- function(   formula,
                        data,
                        dist,
                        cov.model,
                        mcmc.control,
                        priors,
                        shape=NULL,
                        ids=list(shpid=NULL,dataid=NULL),
                        control=inference.control(gridded=FALSE),
                        boundingbox=NULL){

    # Local helper: (re)build the lists in 'control' that map each latent
    # location index in control$uqidx to the observations falling in it, split
    # by censoring status. It reads control$idx, control$uqidx,
    # control$censoringtype and the censoring flags/tests already stored in
    # 'control', and returns the updated 'control'. It is called once before
    # the MCMC run and again whenever coordinates are re-imputed.
    rebuildIndexLists <- function(control){
        # which() is wrapped in try() because list positions not listed in
        # uqidx are NULL, on which which() fails
        towhich <- function(x){try(which(x),silent=TRUE)}

        if(control$censoringtype=="left" || control$censoringtype=="right"){
            control$idxi <- list()
            control$idxicensored <- list()
            control$idxinotcensored <- list()
            for(i in control$uqidx){
                control$idxi[[i]] <- which(control$idx==i)
            }
            for(i in control$uqidx){
                control$idxicensored[[i]] <- control$censored & control$idx==i
            }
            for(i in control$uqidx){
                control$idxinotcensored[[i]] <- control$notcensored & control$idx==i
            }

            if(control$Ctest){
                control$idxicensored <- lapply(control$idxicensored,towhich)
            }
            if(control$Utest){
                control$sumidxinotcensored <- lapply(control$idxinotcensored,function(x){try(sum(x),silent=TRUE)})
                control$idxinotcensored <- lapply(control$idxinotcensored,towhich)
            }
        }
        else{
            control$idxirightcensored <- list()
            control$idxinotcensored <- list()
            control$idxileftcensored <- list()
            control$idxiintervalcensored <- list()

            for(i in control$uqidx){
                control$idxirightcensored[[i]] <- control$rightcensored & control$idx==i
            }
            for(i in control$uqidx){
                control$idxinotcensored[[i]] <- control$notcensored & control$idx==i
            }
            for(i in control$uqidx){
                control$idxileftcensored[[i]] <- control$leftcensored & control$idx==i
            }
            for(i in control$uqidx){
                control$idxiintervalcensored[[i]] <- control$intervalcensored & control$idx==i
            }

            if(control$Rtest){
                control$idxirightcensored <- lapply(control$idxirightcensored,towhich)
            }
            if(control$Utest){
                control$idxinotcensored <- lapply(control$idxinotcensored,towhich)
            }
            if(control$Ltest){
                control$idxileftcensored <- lapply(control$idxileftcensored,towhich)
            }
            if(control$Itest){
                control$idxiintervalcensored <- lapply(control$idxiintervalcensored,towhich)
            }
        }

        return(control)
    }

    if(!is.null(control$nis) && !control$gridded){
        stop("If control$nis is not null, then you must use gridded inference, please set control$gridded to TRUE and select an appropriate cell width.")
    }

    # imputeMode: coordinates are generated from control$nis / control$olinfo
    # (via dat2coords) rather than supplied with 'data'
    imputeMode <- FALSE

    if(!is.null(control$nis)){
        if(!is.list(control$nis)){
            stop("control$nis must be a list of matrices.")
        }
        if(is.null(control$olinfo)){
            stop("Require overlay information for control$nis, generated by function prepDataSpatial")
        }

        imputeMode <- TRUE
    }

    if(!is.null(boundingbox)){
        # The result of try() must be captured: an assignment made *inside*
        # try() is simply skipped on failure, so the fallback below would
        # never be triggered.
        bbtrans <- try(spTransform(boundingbox,CRS(proj4string(data))),silent=TRUE)
        if(inherits(bbtrans,"try-error") && !is.null(shape)){
            bbtrans <- try(spTransform(boundingbox,CRS(proj4string(shape))),silent=TRUE)
        }
        if(inherits(bbtrans,"try-error")){
            warning("Could not transform 'boundingbox' to the projection of 'data' or 'shape', using it untransformed.",immediate.=TRUE)
        }
        else{
            boundingbox <- bbtrans
        }
    }

    formula <- as.formula(formula)

    # initial checks
    if(!(inherits(data,"SpatialPointsDataFrame") || inherits(data,"data.frame"))){
        stop("'data' must be of class 'SpatialPointsDataFrame' or 'data.frame'.")
    }
    if(inherits(data,"data.frame") && is.null(shape)){
        stop("If inherits(data,'data.frame')==TRUE, shape cannot be NULL")
    }

    # latentmode selects which family of helper functions is used below:
    # "points", "polygons" or "SPDE"
    latentmode <- "points"
    if(inherits(data,"data.frame")){
        latentmode <- "polygons"
        if(!is.null(control$imputation)){
            latentmode <- "points"
        }
    }
    if(inherits(cov.model,"SPDEmodel")){
        latentmode <- "SPDE"
    }

    if(imputeMode){
        latentmode <- "points"

        # promote 'data' to a spatial object using imputed coordinates
        coordinates(data) <- dat2coords(control$nis,control$olinfo,data[[ids$dataid]])
        proj4string(data) <- CRS(proj4string(shape))
    }

    if(inherits(data,"SpatialPointsDataFrame") && !is.null(shape)){
        if(latentmode!="SPDE"){
            warning("Non NULL shape with spatial points survival data, are you sure you want to proceed ???",immediate.=TRUE)
            Sys.sleep(2) # pause so that the warning can be read
        }
    }

    if(latentmode=="polygons" && control$gridded){
        stop("Cannot have control$gridded==TRUE and !is.null(shape)==TRUE at the same time")
    }

    if(latentmode=="SPDE" && control$gridded){
        stop("Cannot use SPDE mode and have !is.null(shape)==TRUE at the same time. Please provide a polygon object on which to produce spatial predictions.")
    }

    # the response is taken to be the left hand side of the formula
    responsename <- as.character(formula[[2]])
    if(latentmode=="points" && !is.null(control$imputation)){
        survivaldata <- data@data[[responsename]]
    }
    else{
        survivaldata <- data[[responsename]]
    }
    checkSurvivalData(survivaldata)

    if(latentmode=="SPDE"){
        if(!is.na(proj4string(data)) && !is.na(proj4string(shape))){
            if(proj4string(data)!=proj4string(shape)){
                stop("'shape' and 'data' must have the same proj4string.")
            }
        }

        if(is.null(control$cellwidth)){
            stop("Must specify 'cellwidth' in inference.control in SPDE mode.")
        }
    }


    # if(!inherits(data,"SpatialPointsDataFrame")){
    #     stop("data must be an object of class SpatialPointsDataFrame")
    # }

    # okay, start the MCMC!

    # start timing here unless only the MCMC loop itself is to be timed
    if(!control$timeonlyMCMC){
        start <- Sys.time()
    }

    if(latentmode=="points" || latentmode=="SPDE"){
        coords <- coordinates(data)
    }
    else{
        coords <- coordinates(shape)
    }

    control$dist <- dist

    # suffix used to look up the mode-specific 'proposalVariance' and
    # 'logPosterior' functions by name
    funtxt <- ""
    if(control$gridded){
        funtxt <- "_gridded"
    }
    if(latentmode=="polygons"){
        funtxt <- "_polygonal"
    }
    if(latentmode=="SPDE"){
        funtxt <- "_SPDE"
    }

    gridobj <- NULL

    if(control$gridded){
        gridobj <- FFTgrid(spatialdata=data,cellwidth=control$cellwidth,ext=control$ext,boundingbox=boundingbox)
        del1 <- gridobj$del1
        del2 <- gridobj$del2
        Mext <- gridobj$Mext
        Next <- gridobj$Next
        mcens <- gridobj$mcens
        ncens <- gridobj$ncens
        ## COMPUTE GRID DISTANCES ##
        # distances from the first cell to every cell of the extended grid,
        # taking the shorter way round in each direction (wrap-around)
        x <- gridobj$mcens
        y <- gridobj$ncens
        xidx <- rep(1:Mext,Next)
        yidx <- rep(1:Next,each=Mext)
        dxidx <- pmin(abs(xidx-xidx[1]),Mext-abs(xidx-xidx[1]))
        dyidx <- pmin(abs(yidx-yidx[1]),Next-abs(yidx-yidx[1]))
        u <- sqrt(((x[2]-x[1])*dxidx)^2+((y[2]-y[1])*dyidx)^2)

        spix <- grid2spix(xgrid=mcens,ygrid=ncens,proj4string=CRS(proj4string(data)))

        control$fftgrid <- gridobj
        control$idx <- over(data,geometry(spix)) # grid cell containing each observation
        control$Mext <- Mext
        control$Next <- Next
        control$uqidx <- unique(control$idx)

        cat("Output grid size: ",Mext/control$ext," x ",Next/control$ext,"\n")
    }
    else{
        if(latentmode=="polygons"){
            # NOTE(review): 'dist' is also the name of an argument of this
            # function; this call reaches the distance function only because
            # R skips non-function objects when resolving a call - confirm
            # 'dist' is never supplied as a bare function.
            u <- as.vector(as.matrix(dist(coords)))
            control$idx <- match(data[,ids$dataid],shape@data[,ids$shpid])
            control$n <- nrow(shape)
            control$uqidx <- unique(control$idx)
        }
        else if(latentmode=="SPDE"){

            matobj <- setupPrecMatStruct(shape=shape,cellwidth=control$cellwidth,no=cov.model$order)
            control$precmat <- matobj$f
            control$grid <- matobj$grid

            # require(INLA)
            # now reorder the indices for faster computation
            # tempprec <- control$precmat(SPDEprec(0.1,cov.model$order))
            # reord <- inla.qreordering(tempprec)
            # env <- environment(control$precmat)
            # image.plot(as.matrix(tempprec))
            # image.plot(as.matrix(tempprec)[reord$reordering,reord$reordering])
            # rord <- function(oldidx,newidx){
            #     ans <- rep(NA,length(oldidx))
            #     for(i in 1:length(newidx)){
            #         ans[oldidx==i] <- newidx[i]
            #     }
            #     return(ans)
            # }
            # env$index[,1] <- rord(env$index[,1],reord$reordering)
            # env$index[,2] <- rord(env$index[,2],reord$reordering)
            # shuff <- function(x){
            #     if(x[2]>x[1]){
            #         return(c(x[2:1],x[3]))
            #     }
            #     else{
            #         return(x)
            #     }
            # }
            # env$index <- t(apply(env$index,1,shuff))
            # env$grid <- env$grid[reord$reordering]
            # tempprec1 <- control$precmat(SPDEprec(0.1,cov.model$order))

            # NOTE(review): 'u' is not defined in this branch although it is
            # passed on below; presumably the SPDE helpers never evaluate it.
            control$idx <- over(data,geometry(control$grid))
            control$n <- length(control$grid)
            control$uqidx <- unique(control$idx)
        }
        else{
            u <- as.vector(as.matrix(dist(coords)))
        }
    }

    # keep the data as supplied (or as promoted above) for the return value
    DATA <- data

    if(latentmode=="points" || latentmode=="SPDE"){
        data <- data@data
    }

    ##########
    # This chunk of code borrowed from flexsurvreg
    ##########

    call <- match.call()
    indx <- match(c("formula", "data"), names(call), nomatch = 0)
    if (indx[1] == 0){
        stop("A \"formula\" argument is required")
    }
    temp <- call[c(1, indx)]
    temp[[1]] <- as.name("model.frame")
    m <- eval(temp, parent.frame())

    Terms <- attr(m, "terms")
    X <- model.matrix(Terms, m)

    ##########
    # End of borrowed code
    ##########

    X <- X[, -1, drop = FALSE] # drop the first column of the design matrix

    mcmcloop <- mcmcLoop(N=mcmc.control$nits,burnin=mcmc.control$burn,thin=mcmc.control$thin,progressor=mcmcProgressTextBar)

    info <- distinfo(dist)()

    control$omegatrans <- info$trans
    control$omegaitrans <- info$itrans
    control$omegajacobian <- info$jacobian # used in computing the derivative of the log posterior with respect to the transformed omega (since it is easier to compute with respect to omega)
    control$omegahessian <- info$hessian


    #######

    control$censoringtype <- attr(survivaldata,"type")

    # censoring flags and "is any observation of this kind present" tests
    if(control$censoringtype=="left" || control$censoringtype=="right"){
        control$censored <- survivaldata[,"status"]==0
        control$notcensored <- !control$censored

        control$Ctest <- any(control$censored)
        control$Utest <- any(control$notcensored)
    }
    else{
        control$rightcensored <- survivaldata[,"status"] == 0
        control$notcensored <- survivaldata[,"status"] == 1
        control$leftcensored <- survivaldata[,"status"] == 2
        control$intervalcensored <- survivaldata[,"status"] == 3

        control$Rtest <- any(control$rightcensored)
        control$Utest <- any(control$notcensored)
        control$Ltest <- any(control$leftcensored)
        control$Itest <- any(control$intervalcensored)
    }

    control <- rebuildIndexLists(control)

    #######

    if(control$nugget){
        control$U <- 0
    }


    cat("\n","Getting initial estimates of model parameters using maximum likelihood on non-spatial version of the model","\n")
    mlmod <- maxlikparamPHsurv(surv=survivaldata,X=X,control=control)
    estim <- mlmod$par
    print(mlmod)
    cat("Done.\n")

    cat("Calibrating MCMC algorithm and finding initial values ...\n")

    # the first ncol(X) estimates are the covariate effects, the remainder
    # are the baseline hazard parameters
    betahat <- estim[1:ncol(X)]
    omegahat <- estim[(ncol(X)+1):length(estim)]

    if(latentmode!="SPDE"){
        control$sigmaidx <- match("sigma",cov.model$parnames)
        if(is.na(control$sigmaidx)){
            stop("At least one of the parameters must be the variance of Y, it should be named sigma")
        }
    }

    Yhat <- estimateY(  X=X,
                        betahat=betahat,
                        omegahat=omegahat,
                        surv=survivaldata,
                        control=control)

    calibrate <- get(paste0("proposalVariance",funtxt))

    if(control$nugget){ # dummy for calibration purposes
        control$U <- rep(0,nrow(X))
        control$Ugamma <- rep(0,nrow(X))
        control$Usigma <- 0
        control$logUsigma <- 0
    }

    other <- calibrate( X=X,
                        surv=survivaldata,
                        betahat=betahat,
                        omegahat=omegahat,
                        Yhat=Yhat,
                        priors=priors,
                        cov.model=cov.model,
                        u=u,
                        control=control)



    #gammahat <- other$gammahat
    etahat <- other$etahat
    SIGMA <- other$sigma

    beta <- betahat
    omega <- omegahat
    eta <- etahat

    # gamma starts at zero; its size depends on the inference mode
    gamma <- rep(0,nrow(X))
    if(control$gridded){
        gamma <- matrix(0,control$Mext,control$Next)
    }
    if(latentmode=="polygons" || latentmode=="SPDE"){
        gamma <- rep(0,control$n)
    }

    lenbeta <- length(beta)
    lenomega <- length(omega)
    leneta <- length(eta)
    lengamma <- length(gamma)

    npars <- lenbeta + lenomega + leneta + lengamma

    # positions of each parameter group in the stacked vector (beta,omega,eta,gamma)
    boidx <- 1:(lenbeta+lenomega)
    etaidx <- (lenbeta+lenomega+1):(lenbeta+lenomega+leneta)
    parsidx <- 1:(lenbeta+lenomega+leneta)
    gammaidx <- (lenbeta+lenomega+leneta+1):npars

    #SIGMA <- diag(1e-4,npars)

    # rescale each diagonal block of the calibrated proposal matrix
    SIGMA[boidx,boidx] <- (1.65^2/((lenbeta+lenomega)^(1/3)))*SIGMA[boidx,boidx]
    SIGMA[etaidx,etaidx] <- 0.4*(2.38^2/leneta)*SIGMA[etaidx,etaidx]
    SIGMA[gammaidx,gammaidx] <- (1.65^2/(lengamma^(1/3)))*SIGMA[gammaidx,gammaidx]

    if(control$gridded){
        matidx <- matrix(0,control$Mext,control$Next)
        matidx[1:(control$Mext/control$ext),1:(control$Next/control$ext)] <- 1
        matidx <- as.logical(matidx) # used to select which Y's to save
    }

    # two-column index matrix (i,i) used to pull out the diagonal of SIGMA
    diagidx <- 1:npars
    diagidx <- matrix(diagidx,nrow=npars,ncol=2)

    SIGMApars <- as.matrix(SIGMA[parsidx,parsidx])
    if(any(eigen(SIGMApars)$values<0)){
        SIGMApars <- as.matrix(nearPD(SIGMApars)$mat)
        warning("Fixing parameter proposal matrix ... maybe check model priors?",immediate.=TRUE)
    }

    SIGMAparsINV <- solve(SIGMApars)
    cholSIGMApars <- t(chol(SIGMApars))

    # only the diagonal of the gamma block is used
    SIGMAgamma <- SIGMA[diagidx][gammaidx]
    SIGMAgammaINV <- 1/SIGMAgamma
    cholSIGMAgamma <- sqrt(SIGMAgamma)

    if(control$nugget){

        Ugamma <- rep(0,nrow(X))
        control$Ugamma <- Ugamma

        #SIGMAgamma <- control$split * SIGMA[diagidx][(lenbeta+lenomega+leneta+1):npars]
        #SIGMAgammaINV <- 1/SIGMAgamma
        #cholSIGMAgamma <- sqrt(SIGMAgamma)

        #SIGMA_Ugamma <- mean((1-control$split) * SIGMA[diagidx][(lenbeta+lenomega+leneta+1):npars]) # proposal variance for Ugamma
        SIGMA_Ugamma <- mean(SIGMA[diagidx][gammaidx])
        SIGMA_UgammaINV <- 1/SIGMA_Ugamma
        cholSIGMA_Ugamma <- sqrt(SIGMA_Ugamma)

        etatemp <- cov.model$itrans(eta)
        #Usigma <- (1-control$split) * etatemp[control$sigmaidx] # initialise Usigma where U = Usigma * Ugamma
        #Usigma <- etatemp[control$sigmaidx]
        Usigma <- exp(control$logUsigma_priormean)
        logUsigma <- log(Usigma)
        #etatemp[control$sigmaidx] <- control$split * etatemp[control$sigmaidx]
        #eta <- cov.model$trans(etatemp)

        #SIGMA_logUsigma <- (1-control$split) * SIGMA[(lenbeta+lenomega+1):(lenbeta+lenomega+leneta),(lenbeta+lenomega+1):(lenbeta+lenomega+leneta)][control$sigmaidx,control$sigmaidx]
        SIGMA_logUsigma <- SIGMA[etaidx,etaidx][control$sigmaidx,control$sigmaidx]
        SIGMA_logUsigmaINV <- 1/SIGMA_logUsigma
        cholSIGMA_logUsigma <- sqrt(SIGMA_logUsigma)

        control$U <- -Usigma^2/2 # U this will be updated using a random walk, not worrying about approximately optimal scaling
        control$Usigma <- Usigma
        control$logUsigma <- logUsigma

    }


    cat("Running MCMC ...\n")

    h <- 1 # step size, adapted at every iteration of the loop below


    LOGPOST <- get(paste0("logPosterior",funtxt))

    oldlogpost <- LOGPOST(  surv=survivaldata,
                            X=X,
                            beta=beta,
                            omega=omega,
                            eta=eta,
                            gamma=gamma,
                            priors=priors,
                            cov.model=cov.model,
                            u=u,
                            control=control,
                            gradient=TRUE)

    if(control$nugget){
        oldUprior <- sum(dnorm(control$Ugamma,log=TRUE))
        oldUsigmaprior <- dnorm(control$logUsigma,control$logUsigma_priormean,control$logUsigma_priorsd,log=TRUE)
    }


    betasamp <- c()
    omegasamp <- c()
    etasamp <- c()
    Ysamp <- c()

    gamma <- c(gamma) # turn gamma into a vector

    # record of the target (log posterior plus any nugget priors)
    if(control$nugget){
        tarrec <- oldlogpost$logpost + oldUprior + oldUsigmaprior
    }
    else{
        tarrec <- oldlogpost$logpost
    }

    # show the leading corner of the proposal matrix; guard against matrices
    # with fewer than 8 rows
    nshow <- min(8,nrow(SIGMA))
    print(SIGMA[1:nshow,1:nshow])

    loglik <- c()

    gammamean <- 0
    count <- 1

    # start timing here if only the MCMC loop is to be timed
    if(control$timeonlyMCMC){
        start <- Sys.time()
    }

    bad <-  c()

    #gammasave <- c()

    indiv_loglik <- c()
    Umean <- 0
    Uvar <- 0
    Usigma_mean <- 0
    Usigma_var <- 0

    U_save <- c()
    Usigma_save <- c()

    while(nextStep(mcmcloop)){

        if(imputeMode){
            # periodically re-impute the coordinates and refresh everything
            # that depends on which grid cell each observation falls in
            if(iteration(mcmcloop)%%101==0){
                data <- as.data.frame(data)
                coordinates(data) <- dat2coords(control$nis,control$olinfo,data[[ids$dataid]])
                proj4string(data) <- CRS(proj4string(shape))
                control$idx <- over(data,geometry(spix))
                control$uqidx <- unique(control$idx)

                control <- rebuildIndexLists(control)

                oldlogpost <- LOGPOST(  surv=survivaldata, # re-compute likelihood with new data
                                        X=X,
                                        beta=beta,
                                        omega=omega,
                                        eta=eta,
                                        gamma=gamma,
                                        priors=priors,
                                        cov.model=cov.model,
                                        u=u,
                                        control=control,
                                        gradient=TRUE)

            }
        }

        # Langevin proposal for (beta,omega,eta): drift along the gradient
        # plus correlated Gaussian noise
        stuffpars <- c(beta,omega,eta)
        propmeanpars <- stuffpars + (h/2)*SIGMApars%*%oldlogpost$grad[parsidx]
        newstuffpars <- propmeanpars + sqrt(h)*cholSIGMApars%*%rnorm(lenbeta+lenomega+leneta)


        # Langevin proposal for gamma with a diagonal proposal variance
        propmeangamma <- gamma + (h/2)*SIGMAgamma*oldlogpost$grad[gammaidx]
        newstuffgamma <- propmeangamma + sqrt(h)*cholSIGMAgamma*rnorm(lengamma)
        ngam <- newstuffgamma
        if(control$gridded){
            ngam <- matrix(ngam,control$Mext,control$Next)
        }

        if(control$nugget){
            # remember the current nugget state so it can be restored on rejection
            oldU <- control$U
            oldUgamma <- control$Ugamma
            oldUsigma <- control$Usigma
            oldlogUsigma <- control$logUsigma

            propmeanUgamma <- oldUgamma + (h/2)*SIGMA_Ugamma*oldlogpost$dP_dUgamma
            newUgamma <- propmeanUgamma + sqrt(h)*cholSIGMA_Ugamma*rnorm(nrow(X))

            propmeanlogUsigma <- oldlogUsigma + (h/2)*SIGMA_logUsigma*oldlogpost$dP_dlogUsigma
            newlogUsigma <- propmeanlogUsigma + sqrt(h)*cholSIGMA_logUsigma*rnorm(1)

            # the proposed nugget is passed to LOGPOST through 'control'
            control$U <- -exp(newlogUsigma)^2/2 + exp(newlogUsigma) * newUgamma
            control$Ugamma <- newUgamma
            control$Usigma <- exp(newlogUsigma)
            control$logUsigma <- newlogUsigma

        }

        newlogpost <- LOGPOST(  surv=survivaldata,
                                X=X,
                                beta=newstuffpars[1:lenbeta],
                                omega=newstuffpars[(lenbeta+1):(lenbeta+lenomega)],
                                eta=newstuffpars[etaidx],
                                gamma=ngam,
                                priors=priors,
                                cov.model=cov.model,
                                u=u,
                                control=control,
                                gradient=TRUE)

        # means of the reverse proposals, needed for the acceptance ratio
        revmeanpars <- newstuffpars + (h/2)*SIGMApars%*%newlogpost$grad[parsidx]
        revmeangamma <- newstuffgamma + (h/2)*SIGMAgamma*newlogpost$grad[gammaidx]

        if(control$nugget){
            revmeanUgamma <- newUgamma + (h/2)*SIGMA_Ugamma*newlogpost$dP_dUgamma
            revmeanlogUsigma <- newlogUsigma + (h/2)*SIGMA_logUsigma*newlogpost$dP_dlogUsigma
        }

        revdiffpars <- as.matrix(stuffpars-revmeanpars)
        forwdiffpars <- as.matrix(newstuffpars-propmeanpars)
        revdiffgamma <- as.matrix(gamma-revmeangamma)
        forwdiffgamma <- as.matrix(newstuffgamma-propmeangamma)


        # log acceptance ratio: difference in log posterior plus the log
        # ratio of reverse to forward proposal densities
        logfrac <- newlogpost$logpost - oldlogpost$logpost -
                            (0.5/h)*t(revdiffpars)%*%SIGMAparsINV%*%revdiffpars +
                            (0.5/h)*t(forwdiffpars)%*%SIGMAparsINV%*%forwdiffpars -
                            (0.5/h)*sum(revdiffgamma*SIGMAgammaINV*revdiffgamma) +
                            (0.5/h)*sum(forwdiffgamma*SIGMAgammaINV*forwdiffgamma)

        if(control$nugget){

            revdifflogUsigma <- as.matrix(oldlogUsigma-revmeanlogUsigma)
            forwdifflogUsigma <- as.matrix(control$logUsigma-propmeanlogUsigma)
            revdiffUgamma <- as.matrix(oldUgamma-revmeanUgamma)
            forwdiffUgamma <- as.matrix(control$Ugamma-propmeanUgamma)

            propUprior <- sum(dnorm(control$Ugamma,log=TRUE))
            propUsigmaprior <- dnorm(control$logUsigma,control$logUsigma_priormean,control$logUsigma_priorsd,log=TRUE)

            logfrac <- logfrac + propUprior - oldUprior +
                                propUsigmaprior - oldUsigmaprior -
                                (0.5/h)*t(revdifflogUsigma)%*%SIGMA_logUsigmaINV%*%revdifflogUsigma +
                                (0.5/h)*t(forwdifflogUsigma)%*%SIGMA_logUsigmaINV%*%forwdifflogUsigma -
                                (0.5/h)*sum(revdiffUgamma*SIGMA_UgammaINV*revdiffUgamma) +
                                (0.5/h)*sum(forwdiffUgamma*SIGMA_UgammaINV*forwdiffUgamma)
        }

        ac <- min(1,exp(as.numeric(logfrac)))
        if(is.na(ac) || is.nan(ac)){
            # treat an incalculable acceptance probability as a rejection and
            # record the iteration
            ac <- 0
            bad <- c(bad,iteration(mcmcloop))
            warning("An acceptance probability could not be calculated for this iteration, this is likely because the spatial decay parameter was too big for this choice of 'ext'. Either increase ext, or tighten prior on spatial decay parameter. At the end of the run check $bad to see which iterations this affected. Stop the run if this problem persists.",immediate.=TRUE)
        }



        if(ac>runif(1)){
            # accept
            beta <- newstuffpars[1:lenbeta]
            omega <- newstuffpars[(lenbeta+1):(lenbeta+lenomega)]
            eta <- newstuffpars[etaidx]
            gamma <- newstuffgamma

            if(control$nugget){
                oldUprior <- propUprior
                oldUsigmaprior <- propUsigmaprior
            }

            oldlogpost <- newlogpost
        }
        else{
            if(control$nugget){ # revert back to old
                control$U <- oldU
                control$Ugamma <- oldUgamma
                control$Usigma <- oldUsigma
                control$logUsigma <- oldlogUsigma
            }
        }

        # adapt the step size: h grows when the acceptance probability is
        # above 0.574 and shrinks when below, with a decaying adaptation rate
        h <- exp(log(h) + (1/(iteration(mcmcloop)^0.5))*(ac-0.574))
        if(iteration(mcmcloop)%%100==0){
            cat("\n","h =",h,"\n")
        }


        if(is.retain(mcmcloop)){
            betasamp <- rbind(betasamp,as.vector(beta))
            omegasamp <- rbind(omegasamp,as.vector(omega))
            etasamp <- rbind(etasamp,as.vector(eta))
            if(control$gridded){
                Ysamp <- rbind(Ysamp,as.vector(oldlogpost$Y[matidx]))
            }
            else{
                Ysamp <- rbind(Ysamp,as.vector(oldlogpost$Y))
            }

            if(control$nugget){
                tarrec <- c(tarrec,oldlogpost$logpost + oldUprior + oldUsigmaprior)
            }
            else{
                tarrec <- c(tarrec,oldlogpost$logpost)
            }
            loglik <- c(loglik,oldlogpost$loglik)
            # running mean of gamma over the retained iterations
            gammamean <- ((count-1)/count)*gammamean + (1/count)*gamma
            if(control$nugget){
                # running means and variances of the nugget quantities
                Umean <- ((count-1)/count)*Umean + (1/count)*control$U
                Usigma_mean <- ((count-1)/count)*Usigma_mean + (1/count)*control$Usigma
                if(count>1){
                    Uvar <- ((count-2)/(count-1))*Uvar + (count/(count-1)^2)*(control$U-Umean)^2
                    Usigma_var <- ((count-2)/(count-1))*Usigma_var + (count/(count-1)^2)*(control$Usigma-Usigma_mean)^2
                }

                if(control$savenugget){
                    U_save <- rbind(U_save,control$U)
                    Usigma_save <- rbind(Usigma_save,control$Usigma)
                }

            }
            #gammasave <- rbind(gammasave,gamma)
            indiv_loglik <- rbind(indiv_loglik,oldlogpost$indiv_loglik)
            count <- count + 1
        }
    }

    colnames(Ysamp) <- paste0("Y",1:ncol(Ysamp))

    if(control$nugget){
        # plug the running means into 'control' for the DIC evaluation below
        control$U <- Umean
        control$Ugamma <- Umean / Usigma_mean
        control$Usigma <- Usigma_mean
        control$logUsigma <- log(Usigma_mean)
        warning("For models with a nugget effect, the DIC is approximate. Use WAIC instead.")
    }

    # Compute DIC
    Dhat <- -2*LOGPOST(  surv=survivaldata,
                                X=X,
                                beta=colMeans(betasamp),
                                omega=colMeans(omegasamp),
                                eta=colMeans(etasamp),
                                gamma=gammamean,
                                priors=priors,
                                cov.model=cov.model,
                                u=u,
                                control=control,
                                gradient=TRUE)$loglik
    pD <- -2*mean(loglik) - Dhat
    DIC <- Dhat + 2*pD

    retlist <- list()
    retlist$formula <- formula
    retlist$data <- DATA
    retlist$dist <- dist
    retlist$cov.model <- cov.model
    retlist$mcmc.control <- mcmc.control
    retlist$priors <- priors
    retlist$control <- control

    retlist$terms <- Terms
    retlist$mlmod <- mlmod

    ####
    #   Back transform for output
    ####

    # apply() returns a plain vector when there is a single parameter, hence
    # the separate handling to obtain a one-column matrix
    if(length(omega)>1){
        omegasamp <- t(apply(omegasamp,1,control$omegaitrans))
    }
    else{
        omegasamp <- t(t(apply(omegasamp,1,control$omegaitrans)))
    }
    colnames(omegasamp) <- info$parnames

    if(length(eta)>1){
        etasamp <- t(apply(etasamp,1,cov.model$itrans))
    }
    else{
        etasamp <- t(t(apply(etasamp,1,cov.model$itrans)))
    }
    colnames(etasamp) <- cov.model$parnames

    ####

    colnames(betasamp) <- colnames(model.matrix(formula,data))[-1] #attr(Terms,"term.labels")
    retlist$betasamp <- betasamp
    retlist$omegasamp <- omegasamp
    retlist$etasamp <- etasamp
    retlist$Ysamp <- Ysamp

    retlist$loglik <- loglik
    retlist$Dhat <- Dhat
    retlist$pD <- pD
    retlist$DIC <- DIC

    # WAIC from the per-observation log likelihoods of the retained iterations
    retlist$lpd_hat <- sum(log(colMeans(exp(indiv_loglik))))
    retlist$phat_waic <- sum(apply(indiv_loglik,2,var))
    retlist$WAIC <- -2 * (retlist$lpd_hat - retlist$phat_waic)

    retlist$X <- X
    retlist$survivaldata <- survivaldata

    retlist$gridded <- control$gridded
    if(latentmode=="SPDE"){
        retlist$precmat <- control$precmat
        retlist$grid <- control$grid
        retlist$idx <- control$idx
        retlist$uqidx <- control$uqidx
    }
    else if(control$gridded){
        retlist$M <- Mext/control$ext
        retlist$N <- Next/control$ext
        retlist$xvals <- mcens[1:retlist$M]
        retlist$yvals <- ncens[1:retlist$N]
        # translate cell indices on the extended grid into indices on the output grid
        lookup <- as.vector(matrix(1:(Mext*Next),Mext,Next)[1:retlist$M,1:retlist$N])
        ref <- as.vector(matrix(1:(retlist$M*retlist$N),retlist$M,retlist$N))
        retlist$cellidx <- sapply(control$idx,function(i){ref[lookup==i]})
    }

    retlist$tarrec <- tarrec
    retlist$lasth <- h
    retlist$bad <- bad

    retlist$omegatrans <- control$omegatrans
    retlist$omegaitrans <- control$omegaitrans

    retlist$censoringtype <- attr(survivaldata,"type")

    retlist$shape <- shape
    retlist$ids <- ids
    retlist$latentmode <- latentmode

    retlist$nugget <- control$nugget
    if(control$nugget){
        retlist$nugget_mean <- Umean
        retlist$nugget_var <- Uvar
        retlist$Usigma_mean <- Usigma_mean
        retlist$Usigma_var <- Usigma_var

        if(control$savenugget){
            retlist$U <- U_save
            retlist$Usigma <- Usigma_save
        }
    }

    retlist$time.taken <- Sys.time() - start

    # format() keeps the units of the difftime, which cat() alone would drop
    cat("Time taken:",format(retlist$time.taken),"\n")

    class(retlist) <- c("list","mcmcspatsurv")

    return(retlist)
}

Try the spatsurv package in your browser

Any scripts or data that you put into this service are public.

spatsurv documentation built on Oct. 19, 2023, 9:07 a.m.