R/JM.R

Defines functions propose_jm_mu_inter_Matrix_block propose_jm_mu_inter_Matrix propose_jm_mu_inter rowTensorProduct update_jm_mu_inter_Matrix update_jm_mu_inter constrain modSplineDesign propose_jm_alpha_inter update_jm_alpha_inter sm_time_transform2 predict.bamlss.jm.td jm.survplot jm.predict rJM simJM param_time_transform2 Predict.matrix.Random2.effect smooth.construct.Re2.smooth.spec center.X Predict.matrix.Random.effect smooth.construct.Re.smooth.spec propose_jm_sigma propose_jm_gamma propose_jm_dalpha propose_jm_alpha propose_jm_mu_Matrix propose_jm_mu_simple propose_jm_mu propose_jm_lambda propose_jm_slice uni.slice_tau2_logPost uni.slice_beta_logPost get_jm_prop_fun jm.mcmc update_jm_dalpha update_jm_alpha update_jm_sigma update_jm_dmu update_jm_mu_Matrix update_jm_mu update_jm_lambda update_jm_gamma jm.mode sparse_Matrix_setup jm.transform jm_bamlss

Documented in jm_bamlss jm.mcmc jm.mode jm.predict jm.survplot jm.transform rJM simJM

#######################################
## Joint model model fitting engine. ##
#######################################
## (1) The family object.
## Family constructor for the joint model ("jm").
##
## Returns an object of class "family.bamlss" holding the predictor names,
## their link functions, a transform() hook that prepares the model objects
## via jm.transform() (optionally computing starting values), and the
## optimizer / sampler / predict engines.
##
## Arguments:
##   ...  not used inside this constructor.
jm_bamlss <- function(...)
{
  par_names <- c("lambda", "gamma", "mu", "sigma", "alpha", "dalpha")
  links <- setNames(
    c("log", "log", "identity", "log", "identity", "identity"),
    par_names
  )

  ## Transform hook: x is the bamlss setup object; its components x, y, terms,
  ## knots, formula, family and model.frame are handed on to jm.transform().
  ## With init.models = TRUE, a Gaussian model (mu, sigma) and a Cox model
  ## (lambda, gamma) are fitted first and their parameters are used as
  ## starting values; plot = TRUE additionally plots both initial fits.
  transform_fun <- function(x, jm.start = NULL, timevar = NULL, idvar = NULL, init.models = FALSE, 
                            plot = FALSE, interaction = FALSE, edf_alt = FALSE, start_mu = NULL, k_mu = 6, ...) {
    out <- jm.transform(x = x$x, y = x$y, terms = x$terms, knots = x$knots,
                        formula = x$formula, family = x$family, data = x$model.frame,
                        jm.start = jm.start, timevar = timevar, idvar = idvar, 
                        interaction = interaction, edf_alt = edf_alt, start_mu = start_mu, k_mu = k_mu, ...)

    ## Nothing more to do without initial models.
    if(!init.models)
      return(out)

    resp <- out$y[[1]]

    ## Longitudinal part: all terms of mu and sigma are updated with bfit_iwls.
    long_x <- out$x[c("mu", "sigma")]
    for(part in names(long_x)) {
      n_terms <- length(long_x[[part]]$smooth.construct)
      for(idx in seq_len(n_terms))
        long_x[[part]]$smooth.construct[[idx]]$update <- bfit_iwls
    }
    attr(long_x, "bamlss.engine.setup") <- TRUE
    cat2("generating starting model for longitudinal part...\n")
    long_par <- bfit(x = long_x, y = resp[, "obs", drop = FALSE],
                     family = gF2(gaussian), start = NULL, maxit = 100)$parameters
    if(plot) {
      long_fit <- structure(
        list(x = long_x, model.frame = x$model.frame, parameters = long_par),
        class = "bamlss"
      )
      plot(results.bamlss.default(long_fit))
    }

    ## Survival part: only the rows flagged by the "take" attribute are used;
    ## the grid attributes needed by the Cox engine are copied over.
    surv_x <- out$x[c("lambda", "gamma")]
    attr(surv_x, "bamlss.engine.setup") <- TRUE
    surv_y <- resp[attr(resp, "take"), , drop = FALSE]
    attr(surv_y, "width") <- attr(resp, "width")
    attr(surv_y, "subdivisions") <- attr(resp, "subdivisions")
    cat2("generating starting model for survival part...\n")
    surv_par <- cox.mode(x = surv_x, y = list(surv_y), family = gF2("cox"),
                         start = NULL, maxit = 100)$parameters
    if(plot) {
      surv_fit <- structure(
        list(x = surv_x, model.frame = x$model.frame, parameters = surv_par),
        class = "bamlss"
      )
      plot(results.bamlss.default(surv_fit))
    }

    out$x <- set.starting.values(out$x, start = c(unlist(long_par), unlist(surv_par)))
    return(out)
  }

  structure(
    list(
      "family" = "jm",
      "names" = par_names,
      "links" = links,
      "transform" = transform_fun,
      "optimizer" = jm.mode,
      "sampler" = jm.mcmc,
      "predict" = jm.predict
    ),
    class = "family.bamlss"
  )
}


## (2) The transformer function.
## Prepare all model objects for the joint model engines.
##
## Arguments:
##   x              list of predictor setups, one per family parameter; each
##                  element carries $formula, $terms and $smooth.construct.
##   y              named list holding the response matrix; the columns
##                  "time", "obs" and "status" are accessed.
##   data           the model frame.
##   terms, knots, formula, family
##                  corresponding components of the bamlss setup object.
##   subdivisions   number of points of the per-subject time grid, which runs
##                  from 0 to the subject's survival time.
##   timedependent  names of predictors that receive time grid functions.
##   timevar        name of the longitudinal time variable in data (required).
##   idvar          name of the id column in data, or a vector of ids
##                  (required).
##   alpha, mu, sigma
##                  starting values written to the first coefficient of the
##                  respective model matrix term; alpha/mu equal to 0 are
##                  replaced by .Machine$double.eps. If NULL, mu defaults to
##                  mean(obs) and sigma to log(sd(obs)).
##   sparse         passed to sparse_Matrix_setup() for mu/dmu terms.
##   interaction    if TRUE, alpha is set up as a P-spline in mu (optionally
##                  by factor, or as tensor product with a smooth covariate).
##   edf_alt        only stored as attribute of y here.
##   start_mu       optional values used as covariate "mu" when constructing
##                  the alpha(mu) basis; by default shrunken observations.
##   k_mu           basis dimension of the P-spline in mu.
##   ...            passed to bamlss.engine.setup().
##
## Returns a list with elements "x", "y", "family", "terms" and "formula";
## family gains a p2logLik() function and y a set of attributes (time grid,
## widths, subject indicators, ...) used by the engines.
jm.transform <- function(x, y, data, terms, knots, formula, family,
                         subdivisions = 25, timedependent = c("lambda", "mu", "alpha", "dalpha"), 
                         timevar = NULL, idvar = NULL, alpha = .Machine$double.eps, mu = NULL, sigma = NULL,
                         sparse = TRUE, interaction = FALSE, edf_alt = FALSE, start_mu = NULL, k_mu = 6, ...)
{
  rn <- names(y)
  y0 <- y
  y <- y[[rn]]
  ntd <- timedependent
  if(!all(ntd %in% names(x)))
    stop("the time dependent predictors specified are different from family object names!")
  
  ## Create the time grid.  
  grid <- function(upper, length){
    seq(from = 0, to = upper, length = length)
  }
  take <- NULL
  ## Fail early with a clear message instead of an obscure error further down.
  if(is.null(idvar))
    stop("the id variable is not specified!")
  if(is.character(idvar)) {
    if(!(idvar %in% names(data)))
      stop(paste("variable", idvar, "not in supplied data set!"))
    y2 <- cbind(y[, "time"], data[[idvar]])
  } else {
    ## idvar supplied as a vector: store it in data under the name recovered
    ## from the call (the part after "$" if given as e.g. d$id).
    did <- as.factor(idvar)
    y2 <- cbind(y[, "time"], did)
    idvar <- deparse(substitute(idvar), backtick = TRUE, width.cutoff = 500)
    idvar <- strsplit(idvar, "$", fixed = TRUE)[[1]]
    if(length(idvar) > 1)
      idvar <- idvar[2]
    data[[idvar]] <- did
  }
  colnames(y2) <- c("time", idvar)
  ## First/last row of each unique (time, id) combination.
  take <- !duplicated(y2)
  take_last <- !duplicated(y2, fromLast = TRUE)
  y2 <- y2[take, , drop = FALSE]
  nobs <- nrow(y2)
  grid <- lapply(y2[, "time"], grid, length = subdivisions)
  ## Grid starts at 0, so the second grid point equals the interval width.
  width <- rep(NA, nobs)
  for(i in seq_len(nobs))
    width[i] <- grid[[i]][2]
  ## Replace infinite observations by large finite values of the same sign.
  y[y[, "obs"] == Inf, "obs"] <- .Machine$double.xmax * 0.9
  y[y[, "obs"] == -Inf, "obs"] <- -.Machine$double.xmax * 0.9
  attr(y, "width") <- width
  attr(y, "subdivisions") <- subdivisions
  attr(y, "grid") <- grid
  attr(y, "nobs") <- nobs
  attr(y, "take") <- take
  attr(y, "take_last") <- take_last
  yname <- all.names(x[[ntd[1]]]$formula[2])[2]
  timevar_mu <- timevar
  if(is.null(timevar_mu))
    stop("the time variable is not specified, needed for mu!")
  ## From here on timevar is the time variable of the survival part.
  timevar <- yname
  attr(y, "timevar") <- c("lambda" = timevar, "mu" = timevar_mu)
  attr(y, "idvar") <- idvar
  attr(y, "interaction") <- interaction
  attr(y, "edf_alt") <- edf_alt
  attr(y, "tp") <- tp <-  FALSE
  attr(y, "fac") <- fac <-  FALSE
  attr(y, "obstime") <- data[[timevar_mu]]
  attr(y, "ids") <- data[[idvar]]
  dalpha <- has_pterms(x$dalpha$terms) | (length(x$dalpha$smooth.construct) > 0)
  
  ## update alpha for interaction term
  if(interaction){
    if(is.null(start_mu)){
      yobs <- y[, "obs"]
      center <- median(yobs)
      # scale y towards center to account for outliers
      data$mu <- (0.9 * (yobs - center)) + center
    } else {
      data$mu <- start_mu
    }

    ## set up knots for spline in interaction
    nk <- k_mu - 2   # 2 : second order P-spline basis (cubic spline)
    bs_mu <- "ps"
    xu <- max(data$mu)
    xl <- min(data$mu)
    xr <- xu - xl
    xl <- xl - xr * 0.05
    xu <- xu + xr * 0.05
    dx <- (xu - xl)/(nk - 1)
    k <- seq(xl - dx * (2 + 1), xu + dx * (2 + 1), length = nk + 2 * 2 + 2)
    ygrid <- quantile(y[, "obs"], probs = seq(0.025, 0.975, 0.025))
    ## constraint matrix for sum-to-zero constraint on a fixed grid
    C <- matrix(apply(modSplineDesign(k, ygrid, derivs = 0), 2, mean), nrow = 1)

    interaction_vars <- all.vars(x$alpha$formula)[!all.vars(x$alpha$formula) %in% "alpha"]
    if(!is.null(x$alpha$smooth.construct)) {
      attr(y, "tp") <- tp <- TRUE
    } else {
      if(length(interaction_vars) > 0){
        if(length(interaction_vars) > 1) 
          stop("Only one interaction covariate is allowed in alpha.\n Please adjust your model specification.")
        if(!is.factor(data[[interaction_vars]])) 
          stop("Please provide the covariate in alpha as a factor or as a smooth term.")
        attr(y, "fac") <- fac <- TRUE
      }
    }

    ## updating the respective s() term
    if(!tp & !fac){
      ## smooth function of mu
      f_alpha_mu <- paste0("s(mu, bs = \"ps\", k = ", k_mu, ")")
      alpha_mu <- eval(parse(text = f_alpha_mu))
      alpha_mu$C <- C
      x$alpha$smooth.construct <- smoothCon(alpha_mu, data = data, knots = list(mu = k), 
                                            absorb.cons = TRUE, n = nrow(data), scale.penalty=FALSE) 
    }
    if(fac){
      ## smooth function of mu per factor
      f_alpha_mu <- paste0("s(mu, bs = \"ps\", by = ", interaction_vars, ", k = ", k_mu,  ")")
      alpha_mu <- eval(parse(text = f_alpha_mu))
      alpha_mu$C <- C
      x$alpha$smooth.construct <- smoothCon(alpha_mu, data = data, knots = list(mu = k), 
                                            absorb.cons = TRUE, n = nrow(data), scale.penalty=FALSE) 
      # insert factor again
      f_alpha_mu <- paste0(interaction_vars, " + s(mu, bs = \"ps\", by = ", interaction_vars, ", k = ", k_mu, ")")
    }
    if(tp){
      ## tensor spline over mu and covariate: under construction, not tested yet
      term <- x$alpha$smooth.construct[[1]]$term
      k_alpha <- x$alpha$smooth.construct[[1]]$bs.dim
      bs_alpha <- substr(unlist(strsplit(class(x$alpha$smooth.construct[[1]])[1], "[.]"))[1], 1, 2)
      f_alpha_mu <- paste0("te(mu, ", term, ", bs = c(\"", bs_mu, "\",\"", bs_alpha, 
                           "\"), k = c(", k_mu,",", k_alpha, "))")
      alpha_mu <- eval(parse(text = f_alpha_mu))
      alpha_mu$C <- C
      x$alpha$smooth.construct <- smoothCon(alpha_mu, data = data, knots = list(mu = k), 
                                            absorb.cons = FALSE, n = nrow(data), scale.penalty=FALSE) 
    }
    names(x$alpha$smooth.construct) <- alpha_mu$label
    formula$alpha$formula <- update(formula$alpha$formula, paste0("alpha ~ ", f_alpha_mu))
    terms$alpha <- terms(update(terms$alpha, paste0("alpha ~ ", f_alpha_mu)), specials = c("s", "te", "t2", "s2", "ti"))
  }

  ## Recompute design matrixes for lambda, gamma, alpha.
  ## The alpha(mu) interaction uses the last row per subject and the
  ## mu knots/constraint; all others use the first row per subject.
  for(j in c("lambda", "gamma", "alpha", if(dalpha) "dalpha" else NULL)) {
    if(interaction & (j =="alpha")){
      x[[j]]<- design.construct(terms, data = data[take_last, , drop = FALSE], knots = list(mu = k),
                                model.matrix = TRUE, smooth.construct = TRUE, model = j,
                                scale.x = FALSE, absorb.cons = TRUE, C = C)[[j]]
    } else {
      x[[j]]<- design.construct(terms, data = data[take, , drop = FALSE], knots = knots,
                                model.matrix = TRUE, smooth.construct = TRUE, model = j,
                                scale.x = FALSE)[[j]]
    }
  }
  
  ## Degrees of freedom to 3 for lambda, gamma and alpha (and dalpha) smooths.
  for(j in c("lambda", "gamma", "alpha", if(dalpha) "dalpha" else NULL)) {
    for(sj in names(x[[j]]$smooth.construct)) {
      if(sj != "model.matrix")
        x[[j]]$smooth.construct[[sj]]$xt$df <- 3
    }
  }
  
  ## dalpha dmu initialize.
  ## dmu is derived from the mu specification; only terms that involve the
  ## longitudinal time variable are kept and the intercept is removed.
  if(dalpha) {
    terms <- c(terms, "dmu" = terms$mu)
    terms$dmu <- terms(update(terms$dmu, dmu ~ .), specials = c("s", "te", "t2", "s2", "ti"))
    formula$dmu <- formula$mu
    formula$dmu$formula <- update(formula$dmu$formula, dmu ~ .)
    i <- which(!grepl(timevar_mu, labels(terms$dmu), fixed = TRUE))
    if(length(i)) {
      terms$dmu <- drop.terms(terms$dmu, i, keep.response = TRUE)
      formula$dmu$formula <- formula(terms$dmu)
    }
    x$dmu <- design.construct(terms, data = data, knots = knots,
                              model.matrix = TRUE, smooth.construct = TRUE, model = "mu",
                              scale.x = FALSE)$mu
    if(has_sterms(terms$dmu)) {
      i <- grep(timevar_mu, names(x$dmu$smooth.construct))
      x$dmu$smooth.construct <- x$dmu$smooth.construct[i]
    }
    if(has_intercept(terms$dmu)) {
      x$dmu$model.matrix <- x$dmu$model.matrix[, -grep("Intercept", colnames(x$dmu$model.matrix)), drop = FALSE]
      x$dmu$terms <- terms(update(x$dmu$terms, . ~ -1 + .), specials = c("s", "te", "t2", "s2", "ti"))
      terms$dmu <- terms(update(terms$dmu, dmu ~ -1 + .), specials = c("s", "te", "t2", "s2", "ti"))
      formula$dmu$formula <- update(formula$dmu$formula, dmu ~ -1 + .)
      if(ncol(x$dmu$model.matrix) < 1)
        x$dmu$model.matrix <- NULL
    }
    i <- which(!grepl(timevar_mu, labels(x$dmu$terms), fixed = TRUE))
    if(length(i)) {
      x$dmu$terms <- drop.terms(x$dmu$terms, i, keep.response = TRUE)
      x$dmu$formula <- formula(x$dmu$terms)
    }
    ntd <- c(ntd, "dmu")
  }

  ## The basic setup.
  if(is.null(attr(x, "bamlss.engine.setup")))
    x <- bamlss.engine.setup(x, ...)

  ## Remove intercept from lambda and alpha.
  for (j in c("lambda", if(interaction) "alpha" else NULL)){
    if(!is.null(x[[j]]$smooth.construct$model.matrix)) {
      attr(terms[[j]], "intercept") <- 0
      cn <- colnames(x[[j]]$smooth.construct$model.matrix$X)
      if("(Intercept)" %in% cn)
        x[[j]]$smooth.construct$model.matrix$X <- x[[j]]$smooth.construct$model.matrix$X[, cn != "(Intercept)", drop = FALSE]
      if(ncol(x[[j]]$smooth.construct$model.matrix$X) < 1) {
        ## Nothing but the intercept: drop the parametric term altogether.
        x[[j]]$smooth.construct$model.matrix <- NULL
        x[[j]]$terms <- drop.terms.bamlss(x[[j]]$terms, pterms = FALSE, keep.intercept = FALSE)
      } else {
        x[[j]]$smooth.construct$model.matrix$term <- gsub("(Intercept)+", "",
                                                          x[[j]]$smooth.construct$model.matrix$term, fixed = TRUE)
        x[[j]]$smooth.construct$model.matrix$state$parameters <- x[[j]]$smooth.construct$model.matrix$state$parameters[-1]
        x[[j]]$terms <- drop.terms.bamlss(x[[j]]$terms, keep.intercept = FALSE) # instead of attr(x[[j]]$terms, "intercept") <- 0 
        # additional changes to remove intercept
        x[[j]]$smooth.construct$model.matrix$label <- gsub("(Intercept)+", "",
                                                           x[[j]]$smooth.construct$model.matrix$label, fixed = TRUE)
        x[[j]]$smooth.construct$model.matrix$bs.dim <- as.integer(x[[j]]$smooth.construct$model.matrix$bs.dim - 1)

        ## Recompute the parameter index: coefficients vs. tau2/edf entries.
        pid <- !grepl("tau", names(x[[j]]$smooth.construct$model.matrix$state$parameters)) &
          !grepl("edf", names(x[[j]]$smooth.construct$model.matrix$state$parameters))
        x[[j]]$smooth.construct$model.matrix$pid <- list("b" = which(pid), "tau2" = which(!pid))
        if(!length(x[[j]]$smooth.construct$model.matrix$pid$tau2))
          x[[j]]$smooth.construct$model.matrix$pid$tau2 <- NULL
        x[[j]]$smooth.construct$model.matrix$sparse.setup$matrix <- NULL
      }
    }
  }
  
  ## Assign time grid predict functions.
  ## Parametric terms of lambda (and alpha with interaction) go through
  ## param_time_transform2(), all other parametric terms through
  ## param_time_transform(); smooth terms use sm_time_transform2(). For
  ## dmu the derivative flag derivMat is switched on.
  for(i in seq_along(ntd)) {
    if(has_pterms(x[[ntd[i]]]$terms)) {
      if(ntd[i] %in% c("lambda", if(interaction) "alpha")) {
        x[[ntd[i]]]$smooth.construct$model.matrix <- param_time_transform2(x[[ntd[i]]]$smooth.construct$model.matrix,
                                                                           drop.terms.bamlss(x[[ntd[i]]]$terms, sterms = FALSE, keep.response = FALSE), data, grid, yname,
                                                                           timevar_mu, take, timevar2 = timevar_mu, idvar = idvar, delete.intercept = TRUE)
      } else {
        x[[ntd[i]]]$smooth.construct$model.matrix <- param_time_transform(x[[ntd[i]]]$smooth.construct$model.matrix,
                                                                          drop.terms.bamlss(x[[ntd[i]]]$terms, sterms = FALSE, keep.response = FALSE), data, grid, yname, 
                                                                          if(ntd[i] != "mu" & ntd[i] != "dmu") timevar else timevar_mu, take, derivMat = (ntd[i] == "dmu"))
      }
    }
    if(length(x[[ntd[i]]]$smooth.construct)) {
      for(j in names(x[[ntd[i]]]$smooth.construct)) {
        if(j != "model.matrix") {
          xterm <- x[[ntd[i]]]$smooth.construct[[j]]$term
          by <- if(x[[ntd[i]]]$smooth.construct[[j]]$by != "NA") x[[ntd[i]]]$smooth.construct[[j]]$by else NULL
          x[[ntd[i]]]$smooth.construct[[j]] <- sm_time_transform2(x[[ntd[i]]]$smooth.construct[[j]],
                                                                  data[, unique(c(xterm, yname, by, timevar, timevar_mu, idvar)), drop = FALSE], grid, yname,
                                                                  if(ntd[i] != "mu" & ntd[i] != "dmu") timevar else timevar_mu, take_last, derivMat = (ntd[i] == "dmu"))
        }
      }
    }
  }
  ## Assign time grid predict function to margin of alpha
  if(tp){
    xterm <- x[["alpha"]]$smooth.construct[[1]]$margin[[1]]$term
    by <- NULL
    x[["alpha"]]$smooth.construct[[1]]$margin[[1]] <- sm_time_transform2(x[["alpha"]]$smooth.construct[[1]]$margin[[1]],
                                                                         data[, unique(c(xterm, yname, by, timevar, timevar_mu, idvar)), drop = FALSE], grid, yname,
                                                                         timevar, take_last, derivMat = FALSE)
    
    xterm <- x[["alpha"]]$smooth.construct[[1]]$margin[[2]]$term
    by <- NULL
    x[["alpha"]]$smooth.construct[[1]]$margin[[2]] <- sm_time_transform2(x[["alpha"]]$smooth.construct[[1]]$margin[[2]],
                                                                         data[, unique(c(xterm, yname, by, timevar, timevar_mu, idvar)), drop = FALSE], grid, yname,
                                                                         timevar, take_last, derivMat = FALSE)
  }


  ## Set alpha/mu/sigma intercept starting value.
  if(!is.null(x$alpha$smooth.construct$model.matrix)) {
    if(alpha == 0)
      alpha <- .Machine$double.eps
    x$alpha$smooth.construct$model.matrix$state$parameters[1] <- alpha
    x$alpha$smooth.construct$model.matrix$state$fitted.values <- x$alpha$smooth.construct$model.matrix$X %*% x$alpha$smooth.construct$model.matrix$state$parameters
  }
  if(!is.null(x$mu$smooth.construct$model.matrix)) {
    if(!is.null(mu)) {
      if(mu == 0)
        mu <- .Machine$double.eps
    }
    x$mu$smooth.construct$model.matrix$state$parameters[1] <- if(is.null(mu)) mean(y[, "obs"], na.rm = TRUE) else mu
    x$mu$smooth.construct$model.matrix$state$fitted.values <- x$mu$smooth.construct$model.matrix$X %*% x$mu$smooth.construct$model.matrix$state$parameters
  }
  if(!is.null(x$sigma$smooth.construct$model.matrix)) {
    x$sigma$smooth.construct$model.matrix$state$parameters[1] <- if(is.null(sigma)) log(sd(y[, "obs"], na.rm = TRUE)) else sigma
    x$sigma$smooth.construct$model.matrix$state$fitted.values <- x$sigma$smooth.construct$model.matrix$X %*% x$sigma$smooth.construct$model.matrix$state$parameters
  }
  
  ## Make fixed effects term the last term in mu.
  if(!is.null(x$mu$smooth.construct$model.matrix) & (length(x$mu$smooth.construct) > 1)) {
    nmu <- names(x$mu$smooth.construct)
    nmu <- c(nmu[nmu != "model.matrix"], "model.matrix")
    x$mu$smooth.construct <- x$mu$smooth.construct[nmu]
    if(dalpha) {
      if(!is.null(x$dmu$smooth.construct$model.matrix) & (length(x$dmu$smooth.construct) > 1)) {
        nmu <- names(x$dmu$smooth.construct)
        nmu <- c(nmu[nmu != "model.matrix"], "model.matrix")
        x$dmu$smooth.construct <- x$dmu$smooth.construct[nmu]
      }
    }
  }
  
  ## Sparse matrix setup for mu/dmu.
  for(j in c("mu", if(dalpha) "dmu" else NULL)) {
    for(sj in seq_along(x[[j]]$smooth.construct)) {
      x[[j]]$smooth.construct[[sj]] <- sparse_Matrix_setup(x[[j]]$smooth.construct[[sj]], sparse = sparse, take = take, interaction = interaction)
    }
  }
  
  ## Update prior/grad/hess functions
  for(j in names(x)) {
    for(sj in names(x[[j]]$smooth.construct)) {
      priors <- make.prior(x[[j]]$smooth.construct[[sj]])
      x[[j]]$smooth.construct[[sj]]$prior <- priors$prior
      x[[j]]$smooth.construct[[sj]]$grad <- priors$grad
      x[[j]]$smooth.construct[[sj]]$hess <- priors$hess
    }
  }
  
  ## Add small constant to penalty diagonal for stabilization
  ## (only for penalties stored as plain matrices, not lists or functions).
  for(j in names(x)) {
    for(sj in names(x[[j]]$smooth.construct)) {
      if(length(x[[j]]$smooth.construct[[sj]]$S)){
        for(k in seq_along(x[[j]]$smooth.construct[[sj]]$S)){
          if(!is.list(x[[j]]$smooth.construct[[sj]]$S[[k]]) & !is.function(x[[j]]$smooth.construct[[sj]]$S[[k]])) {
            nc <- ncol(x[[j]]$smooth.construct[[sj]]$S[[k]])
            x[[j]]$smooth.construct[[sj]]$S[[k]] <- x[[j]]$smooth.construct[[sj]]$S[[k]] + diag(1e-08, nc, nc)
          }
        }
      }
    }
  }
  

  y0[[rn]] <- y
  
  ## Log-likelihood (or log-posterior if logPost = TRUE) as a function of the
  ## named parameter vector par. Predictors are rebuilt from par, both at the
  ## observations (eta) and on the time grid (eta_timegrid).
  family$p2logLik <- function(par, logPost = FALSE, ...) {
    eta <- eta_timegrid <- list()
    lprior <- 0
    nx <- c("lambda", "mu", "gamma", "sigma", "alpha")
    if(dalpha)
      nx <- c(nx, "dalpha", "dmu")
    for(j in nx) {
      eta[[j]] <- 0
      if(!(j %in% c("gamma", "sigma")))
        eta_timegrid[[j]] <- 0
      for(sj in names(x[[j]]$smooth.construct)) {
        ## Parameter names follow <predictor>.<s|p>.<term>.<coefficient>.
        pn <- paste(j, if(sj != "model.matrix") "s" else "p", sep = ".")
        cn <- colnames(x[[j]]$smooth.construct[[sj]]$X)
        if(is.null(cn))
          cn <- paste("b", 1:ncol(x[[j]]$smooth.construct[[sj]]$X), sep = "")
        pn0 <- paste(pn, sj, sep = ".")
        pn <- paste(pn0, cn, sep = ".")
        ## Fall back to the short naming scheme of parametric terms.
        if(all(is.na(par[pn]))) {
          if(sj == "model.matrix")
            pn <- gsub(".p.model.matrix.", ".p.", pn, fixed = TRUE)
        }
        eta[[j]] <- eta[[j]] + x[[j]]$smooth.construct[[sj]]$fit.fun(x[[j]]$smooth.construct[[sj]]$X, par[pn])
        if(!(j %in% c("gamma", "sigma")))
          eta_timegrid[[j]] <- eta_timegrid[[j]] + x[[j]]$smooth.construct[[sj]]$fit.fun_timegrid(par[pn])
        if(logPost) {
          pn2 <- paste(pn0, "tau2", sep = ".")
          tpar <- par[grep2(c(pn, pn2), names(par), fixed = TRUE)]
          lprior <- lprior + x[[j]]$smooth.construct[[sj]]$prior(tpar)
        }
      }
    }
    eta_timegrid <- if(dalpha) {
      eta_timegrid$lambda + eta_timegrid$alpha * eta_timegrid$mu + eta_timegrid$dalpha * eta_timegrid$dmu
    } else {
      if(interaction){
        eta_timegrid$lambda + eta_timegrid$alpha
      } else {
        eta_timegrid$lambda + eta_timegrid$alpha * eta_timegrid$mu
      }
    }
    eeta <- exp(eta_timegrid)
    ## Trapezoidal rule over the time grid.
    int <- width * (0.5 * (eeta[, 1] + eeta[, subdivisions]) + apply(eeta[, 2:(subdivisions - 1)], 1, sum))
    logLik <- sum((eta_timegrid[, ncol(eta_timegrid)] + eta$gamma) * y[attr(y, "take"), "status"], na.rm = TRUE) -
      exp(eta$gamma) %*% int + sum(dnorm(y[, "obs"], mean = eta$mu, sd = exp(eta$sigma), log = TRUE))
    if(logPost)
      logLik <- logLik + lprior
    return(drop(logLik))
  }
  
  return(list("x" = x, "y" = y0, "family" = family, "terms" = terms, "formula" = formula))
}


## Attach update/propose functions to a single term of mu (or dmu) and,
## where worthwhile, convert its matrices to sparse Matrix objects.
##
## Arguments:
##   x            the term object; x$X, x$XT, x$S and x$sparse.setup are used.
##   sparse       allow the sparse representation at all?
##   force        use the sparse representation regardless of the size check?
##   take         row index applied to x$sparse.setup$matrix to build the
##                "mu.matrix" entry.
##   interaction  if TRUE the *_inter variants of the update/propose
##                functions are assigned.
##
## The sparse variant is chosen if sparse is TRUE and either force is TRUE or
## x$sparse.setup$crossprod has fewer than half as many columns as x$X.
## A term without sparse.setup$crossprod is treated as failing the size check
## (dense variant unless forced) instead of stopping with a zero-length
## condition in if().
##
## Returns the modified term object.
sparse_Matrix_setup <- function(x, sparse = TRUE, force = FALSE, take, interaction)
{
  cp <- x$sparse.setup$crossprod
  narrow <- !is.null(cp) && (ncol(cp) < (ncol(x$X) * 0.5))

  if(sparse && (force || narrow)) {
    x$X <- Matrix(x$X, sparse = TRUE)
    x$XT <- Matrix(x$XT, sparse = TRUE)
    for(j in seq_along(x$S))
      x$S[[j]] <- Matrix(x$S[[j]], sparse = TRUE)
    if(interaction) {
      x$update <- update_jm_mu_inter_Matrix
      x$propose <- propose_jm_mu_inter_Matrix
    } else {
      x$update <- update_jm_mu_Matrix
      x$propose <- propose_jm_mu_Matrix
    }
  } else {
    if(interaction) {
      x$update <- update_jm_mu_inter
      x$propose <- propose_jm_mu_inter
    } else {
      x$update <- update_jm_mu
      x$propose <- propose_jm_mu_simple
    }
  }

  ## Keep a row subset of the sparse index matrix under "mu.matrix".
  if(!is.null(x$sparse.setup$matrix)) {
    x$sparse.setup[["mu.matrix"]] <- x$sparse.setup$matrix[take, , drop = FALSE]
  }
  return(x)
}


jm.mode <- function(x, y, start = NULL, weights = NULL, offset = NULL,
                    criterion = c("AICc", "BIC", "AIC"), maxit = c(100, 1),
                    nu = c("lambda" = 0.1, "gamma" = 0.1, "mu" = 0.1, "sigma" = 0.1, "alpha" = 0.1, "dalpha" = 0.1),
                    update.nu = TRUE, eps = 0.0001, alpha.eps = 0.001, ic.eps = 1e-08, nback = 40,
                    verbose = TRUE, digits = 4, ...)
{
  ## Hard coded.
  fix.lambda <- FALSE
  fix.gamma <- FALSE
  fix.alpha <- FALSE
  fix.mu <- FALSE
  fix.sigma <- FALSE
  fix.dalpha <- FALSE
  nu.eps <- -Inf
  edf.eps <- Inf
  
  dalpha <- has_pterms(x$dalpha$terms) | (length(x$dalpha$smooth.construct) > 0)
  
  if(!is.null(start)) {
    if(is.matrix(start)) {
      if(any(i <- grepl("Mean", colnames(start))))
        start <- start[, i]
      else stop("the starting values should be a vector not a matrix!")
    }
    if(dalpha) {
      if(!any(grepl("dmu.", names(start), fixed = TRUE))) {
        start.dmu <- start[grep("mu.", names(start), fixed = TRUE)]
        names(start.dmu) <- gsub("mu.", "dmu.", names(start.dmu), fixed = TRUE)
        start <- c(start, start.dmu)
      }
    }
    x <- set.starting.values(x, start)
  }
  
  criterion <- match.arg(criterion)
  ia <- interactive()
  
  ## Names of parameters/predictors.
  nx <- names(x)
  
  ## Extract y.
  y <- y[[1]]
  
  ## Number of observations.
  nobs <- attr(y, "nobs")
  
  ## Number of subdivions used for the time grid.
  sub <- attr(y, "subdivisions")
  
  ## The interval width from subdivisons.
  width <- attr(y, "width")
  
  ## Subject specific indicator
  take <- attr(y, "take")
  take_last <- attr(y, "take_last")
  nlong <- length(take)
  
  ## Extract the status for individual i.
  status <- y[take, "status"]
  
  ## Interaction setup
  interaction <- attr(y, "interaction")
  tp <- attr(y, "tp") 
  fac <- attr(y, "fac")
  edf_alt <- attr(y, "edf_alt")
  
  ## Make id for individual i.
  id <- which(take)
  id <- append(id[-1], nlong + 1) - id 
  id <- rep(1:nobs, id)
  
  ## Compute additive predictors.
  eta <- get.eta(x, expand = FALSE)
  
  ## Correct gamma predictor if NULL.
  if(!length(x$gamma$smooth.construct)) {
    eta$gamma <- rep(0, length = nobs)
  }
  
  nu <- rep(nu, length.out = 6)
  if(is.null(names(nu)))
    names(nu) <- c("lambda", "gamma", "mu", "sigma", "alpha", "dalpha")
  
  ## For the time dependent part, compute
  ## predictors based on the time grid.
  eta_timegrid_mu <- 0
  if(length(x$mu$smooth.construct)) {
    for(j in names(x$mu$smooth.construct)) {
      b <- get.par(x$mu$smooth.construct[[j]]$state$parameters, "b")
      x$mu$smooth.construct[[j]]$state$fitted_timegrid <- x$mu$smooth.construct[[j]]$fit.fun_timegrid(b)
      eta_timegrid_mu <- eta_timegrid_mu + x$mu$smooth.construct[[j]]$fit.fun_timegrid(b)
      x$mu$smooth.construct[[j]]$state$nu <- nu["mu"]
    }
  }
  
  eta_timegrid_lambda <- 0
  if(length(x$lambda$smooth.construct)) {
    for(j in names(x$lambda$smooth.construct)) {
      b <- get.par(x$lambda$smooth.construct[[j]]$state$parameters, "b")
      eta_timegrid_lambda <- eta_timegrid_lambda + x$lambda$smooth.construct[[j]]$fit.fun_timegrid(b)
      x$lambda$smooth.construct[[j]]$state$nu <- nu["lambda"]
    }
  }
  
  eta_timegrid_dmu <- eta_timegrid_dalpha <- 0
  if(dalpha) {
    if(length(x$dalpha$smooth.construct)) {
      for(j in names(x$dalpha$smooth.construct)) {
        b <- get.par(x$dalpha$smooth.construct[[j]]$state$parameters, "b")
        eta_timegrid_dalpha <- eta_timegrid_dalpha + x$dalpha$smooth.construct[[j]]$fit.fun_timegrid(b)
        x$dalpha$smooth.construct[[j]]$state$nu <- nu["dalpha"]
      }
    }
    if(length(x$dmu$smooth.construct)) {
      for(j in names(x$dmu$smooth.construct)) {
        b <- get.par(x$mu$smooth.construct[[j]]$state$parameters, "b")
        eta_timegrid_dmu <- eta_timegrid_dmu + x$dmu$smooth.construct[[j]]$fit.fun_timegrid(b)
        eta$dmu <- x$dmu$smooth.construct[[j]]$fit.fun(x$dmu$smooth.construct[[j]]$X, b)
      }
    }
  }
  
  eta_timegrid_alpha <- 0
  if(length(x$alpha$smooth.construct)) {
    for(j in names(x$alpha$smooth.construct)) {
      b <- get.par(x$alpha$smooth.construct[[j]]$state$parameters, "b")
      # update alpha according to mu
      if(interaction & !("model.matrix" %in% attr(x$alpha$smooth.construct[[j]], "class"))){
        Xmu <- as.vector(t(eta_timegrid_mu))
        if(!tp) {
          X <- x$alpha$smooth.construct[[j]]$update(Xmu, "mu")    
        } else {
          Xalpha <- x$alpha$smooth.construct[[j]]$margin[[1]]$update(Xmu, "mu")
          Xalpha2 <- x$alpha$smooth.construct[[j]]$margin[[2]]$fit.fun_timegrid(NULL)
          X <- rowTensorProduct(Xalpha, Xalpha2) 
        }
        fit_timegrid <- matrix(drop(X %*% b), nrow = nrow(eta_timegrid_mu), 
                               ncol = ncol(eta_timegrid_mu), byrow = TRUE)
        x$alpha$smooth.construct[[j]]$state$fitted.values <- fit_timegrid[, ncol(fit_timegrid)]
        x$alpha$smooth.construct[[j]]$state$fitted_timegrid <- fit_timegrid
        eta_timegrid_alpha <- eta_timegrid_alpha + fit_timegrid
      } else {
        x$alpha$smooth.construct[[j]]$state$fitted_timegrid <- x$alpha$smooth.construct[[j]]$fit.fun_timegrid(b)
        eta_timegrid_alpha <- eta_timegrid_alpha + x$alpha$smooth.construct[[j]]$fit.fun_timegrid(b)   
      }
      x$alpha$smooth.construct[[j]]$state$nu <- nu["alpha"]
    }
    eta$alpha <- eta_timegrid_alpha[, ncol(eta_timegrid_alpha)]
  } 

  # for interaction effect eta_mu is within eta_alpha
  if(interaction){
    eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha + eta_timegrid_dalpha * eta_timegrid_dmu
  } else{
  eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
  }
  knots <- if(interaction){
              if(tp){x$alpha$smooth.construct[[1]]$margin[[1]]$knots} 
              else {x$alpha$smooth.construct[[1]]$knots}} 
           else {NULL}
  
  for(k in c("gamma", "sigma")) {
    if(length(x[[k]]$smooth.construct)) {
      for(j in names(x[[k]]$smooth.construct)) {
        x[[k]]$smooth.construct[[j]]$state$nu <- nu[k]
      }
    }
  }
  
  ## Set edf to zero 4all terms.
  for(k in names(x)) {
    if(length(x[[k]]$smooth.construct)) {
      for(j in names(x[[k]]$smooth.construct)) {
        x[[k]]$smooth.construct[[j]]$state$edf <- 0
      }
    }
  }
  
  ## Extract current value of the log-posterior.
  get_LogPost <- function(eta_timegrid, eta, log.priors) {
    eeta <- exp(eta_timegrid)
    int0 <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + apply(eeta[, 2:(sub - 1)], 1, sum))
    logLik <- sum((eta_timegrid[,ncol(eta_timegrid)] + eta$gamma) * status, na.rm = TRUE) -
      drop(exp(eta$gamma)) %*% int0 + sum(dnorm(y[, "obs"], mean = eta$mu, sd = exp(eta$sigma), log = TRUE))
    logPost <- as.numeric(logLik + log.priors)
    return(logPost)
  }
  
  eps <- rep(eps, length.out = 2)
  if(length(maxit) < 2)
    maxit <- c(maxit, 1)
  
  ## Start the backfitting algorithm.
  eps0 <- eps0_surv <- eps0_long <- eps0_alpha <- eps0_dalpha <- eps[1] + 1
  iter <- 0; ic_contrib <- NULL
  edf <- 0; ok.alpha <- FALSE
  ptm <- proc.time()
  while((eps0 > eps[1]) & (iter < maxit[1])) {
    if(eps0 < nu.eps)
      update.nu <- FALSE
    
    eta0_surv <- do.call("cbind", eta[c("lambda", "gamma")])
    eta0_long <- rowsum(do.call("cbind", eta[c("mu", "sigma")]), id)
    eta0_alpha <- if(is.null(dalpha)) eta$alpha else do.call("cbind", eta[c("alpha", "dalpha")])
    
    if(max(c(eps0_surv, eps0_long)) < alpha.eps) 
      ok.alpha <- TRUE
    #checks# Setup for checks while running
    myplots <- FALSE
    if(myplots){
      times <- matrix(unlist(attr(y, "grid")), ncol = ncol(eta_timegrid), byrow = TRUE)
      obstimes <- attr(y, "obstime")
      ids <- attr(y, "ids")
      par(mfrow = c(2,2))
    }
    mydiag <- FALSE
    
    ################
    ## Alpha part ##
    ################
    if(!fix.alpha & ok.alpha) {
      eps00 <- eps[2] + 1; iter00 <- 0
      eta00 <- eta$alpha
      while(eps00 > eps[2] & iter00 < maxit[2]) {
        for(sj in seq_along(x$alpha$smooth.construct)) {
          # #checks#
          # cat("\n iteration: ", iter, ", predictor: alpha", ", term: ", sj)
          if(interaction){
            state <- update_jm_alpha_inter(x$alpha$smooth.construct[[sj]], eta, eta_timegrid,
                                           eta_timegrid_lambda, eta_timegrid_alpha, eta_timegrid_mu, eta_timegrid_dalpha, eta_timegrid_dmu,
                                           status, update.nu, width, criterion, get_LogPost, nobs, eps0_alpha < edf.eps, edf = edf, edf_alt, 
                                           tp = tp, ...)
          } else {
          state <- update_jm_alpha(x$alpha$smooth.construct[[sj]], eta, eta_timegrid,
                                   eta_timegrid_lambda, eta_timegrid_alpha, eta_timegrid_mu, eta_timegrid_dalpha, eta_timegrid_dmu,
                                   status, update.nu, width, criterion, get_LogPost, nobs, eps0_alpha < edf.eps, edf = edf, ...)
          }
          eta_timegrid_alpha <- eta_timegrid_alpha - x$alpha$smooth.construct[[sj]]$state$fitted_timegrid + state$fitted_timegrid
          if(interaction){
            eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha + eta_timegrid_dalpha * eta_timegrid_dmu
          } else {
          eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
          }       
          eta$alpha <- eta$alpha - fitted(x$alpha$smooth.construct[[sj]]$state) + fitted(state)
          #checks#
          if(myplots){
            plot(eta_timegrid_mu, state$fitted_timegrid, 
                 main=paste0("alpha: iteration ", iter, ", edf ", round(state$edf, 2), ", term: ", names(x$alpha$smooth.construct)[sj]))
            abline(h = 0, col = "red")
            abline(v = median(y[, 3]), col = "red")
            plot(eta_timegrid_mu, eta_timegrid_alpha, main=paste0("alpha: iteration ", iter, ", edf ", round(state$edf, 2)))
            plot(eta_timegrid_mu[, ncol(eta_timegrid_mu)], eta$alpha)
            matplot(t(times), t(eta_timegrid_alpha), main=paste0("alpha: iteration ", iter, ", edf ", round(state$edf, 2)), type = "l")
          }
          if(mydiag){
            if(min(eta_timegrid_alpha) < -10 | max(eta_timegrid_alpha) > 10) browser()
          }
          edf <- edf - x$alpha$smooth.construct[[sj]]$state$edf + state$edf
          x$alpha$smooth.construct[[sj]]$state <- state
        }
        
        if(maxit[2] > 1) {
          eps00 <- mean(abs((eta$alpha - eta00) / eta$alpha), na.rm = TRUE)
          if(is.na(eps00) | !is.finite(eps00)) eps00 <- eps[2] + 1
        }
        
        iter00 <- iter00 + 1
      }
    }
    
    if(!fix.dalpha & dalpha & ok.alpha) {
      eps00 <- eps[2] + 1; iter00 <- 0
      eta00 <- eta$dalpha
      while(eps00 > eps[2] & iter00 < maxit[2]) {
        for(sj in seq_along(x$dalpha$smooth.construct)) {
          state <- update_jm_dalpha(x$dalpha$smooth.construct[[sj]], eta, eta_timegrid,
                                    eta_timegrid_lambda, eta_timegrid_alpha, eta_timegrid_mu, eta_timegrid_dalpha, eta_timegrid_dmu,
                                    status, update.nu, width, criterion, get_LogPost, nobs, eps0_alpha < edf.eps, edf = edf, ...)
          eta_timegrid_dalpha <- eta_timegrid_dalpha - x$dalpha$smooth.construct[[sj]]$state$fitted_timegrid + state$fitted_timegrid
          eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
          eta$dalpha <- eta$dalpha - fitted(x$dalpha$smooth.construct[[sj]]$state) + fitted(state)
          edf <- edf - x$dalpha$smooth.construct[[sj]]$state$edf + state$edf
          x$dalpha$smooth.construct[[sj]]$state <- state
        }
        
        if(maxit[2] > 1) {
          eps00 <- mean(abs((eta$dalpha - eta00) / eta$dalpha), na.rm = TRUE)
          if(is.na(eps00) | !is.finite(eps00)) eps00 <- eps[2] + 1
        }
        
        iter00 <- iter00 + 1
      }
    }
    
    #######################
    ## Longitudinal part ##
    #######################
    eps00 <- eps[2] + 1; iter00 <- 0
    eta00 <- do.call("cbind", eta[c("mu", "sigma")])
    while(eps00 > eps[2] & iter00 < maxit[2]) {
      if(!fix.mu) {
        for(sj in names(x$mu$smooth.construct)) {
          # #checks#
          # cat("\n iteration: ", iter, ", predictor: mu", ", term: ", sj)
          state <- x$mu$smooth.construct[[sj]]$update(x$mu$smooth.construct[[sj]], y, eta, eta_timegrid,
                                                      eta_timegrid_lambda, eta_timegrid_alpha, eta_timegrid_mu, eta_timegrid_dalpha, eta_timegrid_dmu,
                                                      status, update.nu, width, criterion, get_LogPost, nobs, eps0_long < edf.eps, edf = edf,
                                                      dx = if(dalpha) x$dmu$smooth.construct[[sj]] else NULL, 
                                                      xsmalpha = if(interaction) x$alpha$smooth.construct else NULL, 
                                                      knots = knots, tp = if(interaction) tp else FALSE, fac = if(interaction) fac else FALSE, ...)
          eta_timegrid_mu <- eta_timegrid_mu - x$mu$smooth.construct[[sj]]$state$fitted_timegrid + state$fitted_timegrid

          if(interaction){
            for (i in names(x$alpha$smooth.construct)){
              if(i != "model.matrix"){   # only smooth.constructs need to be updated
                g_a <- get.par(x$alpha$smooth.construct[[i]]$state$parameters, "b")
                Xalpha <- modSplineDesign(knots, as.vector(t(eta_timegrid_mu)), derivs = 0)
                Xalpha <- constrain(x$alpha$smooth.construct[[i]], Xalpha)
                if(fac)
                  Xalpha <- Xalpha * x$alpha$smooth.construct[[i]]$by_timegrid
                if(tp)  
                  Xalpha <- rowTensorProduct(Xalpha, Xalpha2) 
                
                alpha_state <- matrix(Xalpha %*% g_a, nrow = nrow(eta_timegrid), ncol = ncol(eta_timegrid), byrow = TRUE)
                eta_timegrid_alpha <- eta_timegrid_alpha - x$alpha$smooth.construct[[i]]$state$fitted_timegrid + alpha_state
                x$alpha$smooth.construct[[i]]$state$fitted_timegrid <- alpha_state
                eta$alpha <- eta$alpha - fitted(x$alpha$smooth.construct[[i]]$state) + alpha_state[, ncol(eta_timegrid)]
                x$alpha$smooth.construct[[i]]$state$fitted.values <- alpha_state[, ncol(eta_timegrid)]
              } 
            }

            eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha + eta_timegrid_dalpha * eta_timegrid_dmu
          } else {
          eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
          }
          eta$mu <- eta$mu - fitted(x$mu$smooth.construct[[sj]]$state) + fitted(state)
          edf <- edf - x$mu$smooth.construct[[sj]]$state$edf + state$edf
          x$mu$smooth.construct[[sj]]$state <- state
          #checks# 
          if(myplots){
            matplot(t(times),t(state$fitted_timegrid), type = "l", main = paste0("mu: iteration ", iter, ", edf ", round(state$edf, 2), " ", sj))
            matplot(t(times), t(eta_timegrid_mu), type = "l", main = paste0("mu: iteration ", iter, ", edf ", round(state$edf, 2), " ", sj))
            plot(eta_timegrid_mu, eta_timegrid_alpha, main=paste0("mu: iteration ", iter, ", edf ", round(state$edf, 2), "alpha"))
            plot(eta_timegrid_mu[, ncol(eta_timegrid_mu)], eta$alpha, main=paste0("mu: iteration ", iter, ", edf ", round(state$edf, 2), "alpha"))
            # print(xyplot(eta$mu ~ obstimes, group = ids, type = "l", main = paste0("mu: iteration ", iter, ", edf ", round(state$edf, 2), " ", sj)))
          }
          
          if(mydiag){
            if(min(eta_timegrid_alpha) < -10 | max(eta_timegrid_alpha) > 10) browser()
            if(min(eta_timegrid_mu) < -10 | max(eta_timegrid_alpha) > 20) browser()
          }
          
          #!# Danger: dalpha and interaction not properly set up
          if(dalpha & (sj %in% names(x$dmu$smooth.construct))) {
            state <- update_jm_dmu(x$dmu$smooth.construct[[sj]], x$mu$smooth.construct[[sj]])
            eta_timegrid_dmu <- eta_timegrid_dmu - x$dmu$smooth.construct[[sj]]$state$fitted_timegrid + state$fitted_timegrid
            eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
            eta$dmu <- eta$dmu - fitted(x$dmu$smooth.construct[[sj]]$state) + fitted(state)
            x$dmu$smooth.construct[[sj]]$state <- state
          }
        }
      }
      
      if(!fix.sigma) {
        for(sj in seq_along(x$sigma$smooth.construct)) {
          state <- update_jm_sigma(x$sigma$smooth.construct[[sj]], y, eta,
                                   eta_timegrid, update.nu, criterion, get_LogPost, nobs, eps0_long < edf.eps, edf = edf, ...)
          eta$sigma <- eta$sigma - fitted(x$sigma$smooth.construct[[sj]]$state) + fitted(state)
          edf <- edf - x$sigma$smooth.construct[[sj]]$state$edf + state$edf
          x$sigma$smooth.construct[[sj]]$state <- state
        }
      }
      
      if(maxit[2] > 1) {
        eps00 <- do.call("cbind", eta[c("mu", "sigma")])
        eps00 <- mean(abs((eps00 - eta00) / eps00), na.rm = TRUE)
        if(is.na(eps00) | !is.finite(eps00)) eps00 <- eps[2] + 1
      }
      
      iter00 <- iter00 + 1
    }
    
    eps00 <- eps[2] + 1; iter00 <- 0
    eta00 <- do.call("cbind", eta[c("lambda", "gamma")])
    while(eps00 > eps[2] & iter00 < maxit[2]) {
      ###################
      ## Survival part ##
      ###################
      eeta <- exp(eta_timegrid)
      int0 <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + apply(eeta[, 2:(sub - 1)], 1, sum))
      
      if(!fix.lambda) {
        for(sj in seq_along(x$lambda$smooth.construct)) {
          # cat("\n iteration: ", iter, ", predictor: lambda", ", term: ", sj)
          state <- update_jm_lambda(x$lambda$smooth.construct[[sj]], eta, eta_timegrid,
                                    eta_timegrid_lambda, eta_timegrid_alpha, eta_timegrid_mu, eta_timegrid_dalpha, eta_timegrid_dmu,
                                    status, update.nu, width, criterion, get_LogPost, nobs, eps0_surv < edf.eps, edf = edf, interaction, ...)
          eta_timegrid_lambda <- eta_timegrid_lambda - x$lambda$smooth.construct[[sj]]$state$fitted_timegrid + state$fitted_timegrid
          if(interaction){
            eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha + eta_timegrid_dalpha * eta_timegrid_dmu
          } else {
          eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
          }
          eta$lambda <- eta$lambda - fitted(x$lambda$smooth.construct[[sj]]$state) + fitted(state)
          edf <- edf - x$lambda$smooth.construct[[sj]]$state$edf + state$edf
          x$lambda$smooth.construct[[sj]]$state <- state
        }
      }
      if(mydiag){
        if(min(eta_timegrid) < -100 | max(eta_timegrid) > 100) browser()
      }
      
      if(!fix.gamma) {
        if(length(x$gamma$smooth.construct)) {
          for(sj in seq_along(x$gamma$smooth.construct)) {
            # cat("\n iteration: ", iter, ", predictor: gamma", ", term: ", sj)
            state <- update_jm_gamma(x$gamma$smooth.construct[[sj]], eta, eta_timegrid, int0,
                                     status, update.nu, criterion, get_LogPost, nobs, eps0_surv < edf.eps, edf = edf, ...)
            eta$gamma <- eta$gamma - fitted(x$gamma$smooth.construct[[sj]]$state) + fitted(state)
            edf <- edf - x$gamma$smooth.construct[[sj]]$state$edf + state$edf
            x$gamma$smooth.construct[[sj]]$state <- state
          }
        }
      }
      
      if(maxit[2] > 1) {
        eps00 <- do.call("cbind", eta[c("lambda", "gamma")])
        eps00 <- mean(abs((eps00 - eta00) / eps00), na.rm = TRUE)
        if(is.na(eps00) | !is.finite(eps00)) eps00 <- eps[2] + 1
      }
      
      iter00 <- iter00 + 1
    }
    
    eps0_surv <- do.call("cbind", eta[c("lambda", "gamma")])
    eps0_long <- rowsum(do.call("cbind", eta[c("mu", "sigma")]), id)
    eps0_alpha <- if(is.null(dalpha)) eta$alpha else do.call("cbind", eta[c("alpha", "dalpha")])
    
    eps0_surv <- if(!(fix.lambda & fix.gamma)) {
      mean(abs((eps0_surv - eta0_surv) / eps0_surv), na.rm = TRUE)
    } else 0
    eps0_long <- if(!(fix.mu & fix.sigma)) {
      mean(abs((eps0_long - eta0_long) / eps0_long), na.rm = TRUE)
    } else 0
    eps0_alpha <- if(!fix.alpha) {
      if(ok.alpha) {
        mean(abs((eps0_alpha - eta0_alpha) / eps0_alpha), na.rm = TRUE)
      } else 1 + eps[1]
    } else 0
    
    eps0 <- mean(c(eps0_surv, eps0_long, eps0_alpha))
    if(is.na(eps0) | !is.finite(eps0)) eps0 <- eps[1] + 1
    
    eeta <- exp(eta_timegrid)
    int0 <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + apply(eeta[, 2:(sub - 1)], 1, sum))
    logLik <- sum((eta_timegrid[,ncol(eta_timegrid)] + eta$gamma) * status, na.rm = TRUE) -
      exp(eta$gamma) %*% int0 + sum(dnorm(y[, "obs"], mean = eta$mu, sd = exp(eta$sigma), log = TRUE))
    edf <- get.edf(x, type = 2)
    IC <- get.ic2(logLik, edf, length(eta$mu), criterion)
    
    ic_contrib <- c(ic_contrib, IC)
    if(iter > nback) {
      adic <- abs(diff(tail(ic_contrib, nback))) / tail(ic_contrib, nback - 1)
      if(all(adic < ic.eps))
        eps0 <- 0
    }
    
    iter <- iter + 1
    
    if(verbose) {
      cat(if(ia) "\r" else "\n")
      vtxt <- paste(
        criterion, " ", fmt(IC, width = 8, digits = digits),
        " logLik ", fmt(logLik, width = 8, digits = digits),
        " edf ", fmt(edf, width = 6, digits = digits),
        " eps ", paste(fmt(eps0, width = 6, digits = digits + 2), "|S",
                       fmt(eps0_surv, width = 6, digits = digits + 2), "|L",
                       fmt(eps0_long, width = 6, digits = digits + 2), "|A",
                       fmt(eps0_alpha, width = 6, digits = digits + 2), sep = ""),
        " iteration ", formatC(iter, width = nchar(maxit[1])), sep = ""
      )
      cat(vtxt)
      if(.Platform$OS.type != "unix" & ia) flush.console()
    }
  }
  
  elapsed <- c(proc.time() - ptm)[3]
  
  if(verbose) {
    et <- if(elapsed > 60) {
      paste(formatC(format(round(elapsed / 60, 2), nsmall = 2), width = 5), "min", sep = "")
    } else paste(formatC(format(round(elapsed, 2), nsmall = 2), width = 5), "sec", sep = "")
    cat("\nelapsed time: ", et, "\n", sep = "")
  }
  
  eeta <- exp(eta_timegrid)
  int0 <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + apply(eeta[, 2:(sub - 1)], 1, sum))
  logLik <- sum((eta_timegrid[,ncol(eta_timegrid)] + eta$gamma) * status, na.rm = TRUE) -
    exp(eta$gamma) %*% int0 + sum(dnorm(y[, "obs"], mean = eta$mu, sd = exp(eta$sigma), log = TRUE))
  logPost <- as.numeric(logLik + get.log.prior(x))
  
  if(iter == maxit[1])
    warning("the backfitting algorithm did not converge, please check argument eps and maxit!")
  
  return(list("fitted.values" = eta, "parameters" = get.all.par(x),
              "logLik" = logLik, "logPost" = logPost, "hessian" = get.hessian(x),
              "converged" = iter < maxit[1], "time" = elapsed))
}


### Backfitting updating functions.
## Backfitting update for one model term of the predictor gamma.
##
## Arguments (as used by this function):
##   x            model term; uses x$X (design matrix), x$state, x$pid,
##                x$fixed, x$fxsp, x$sparse.setup and the functions x$grad(),
##                x$hess(), x$prior() and x$fit.fun().
##   eta          named list of predictors. eta$gamma is replaced in local
##                copies when candidate coefficients are evaluated;
##                length(eta$mu) is handed to get.ic2().
##   eta_timegrid predictor on the time grid, only passed on to get_LogPost().
##   int          vector weighting the contribution of each observation to
##                gradient and hessian (the visible caller passes the
##                trapezoid integral of exp(eta_timegrid)).
##   status       vector entering the gradient through t(status) %*% x$X.
##   update.nu    if TRUE the step length nu is optimised on [0, 1] with
##                optimize(), otherwise x$state$nu is used.
##   criterion    passed to get.ic2().
##   get_LogPost  function called as get_LogPost(eta_timegrid, eta, log_prior).
##   nobs         number of leading rows of x$X (and entries of eta$gamma and
##                int) that enter the gradient and hessian.
##   do.optim2    additional switch; tau2 is only optimised if this is TRUE
##                and the term itself allows it (x$state$do.optim, not
##                x$fixed, not x$fxsp).
##   edf          current total edf or NULL.
##
## Returns a state list for the term. If the tau2 search finds a candidate
## that improves the information criterion, that candidate state is returned
## directly; otherwise the state after one step with the tau2 delivered by
## tau2.optim() is returned.
update_jm_gamma <- function(x, eta, eta_timegrid, int,
                            status, update.nu, criterion, get_LogPost, nobs, do.optim2, edf, ...)
{
  ## Gradient and hessian.
  ## Vectorised form of
  ##   tint   = sum_i exp(gamma_i) * int_i * X[i, ]
  ##   xhess0 = sum_i exp(gamma_i) * int_i * X[i, ] %*% t(X[i, ])
  ## seq_len() keeps the sums empty when nobs is 0.
  obs <- seq_len(nobs)
  Xobs <- x$X[obs, , drop = FALSE]
  w <- exp(eta$gamma[obs]) * int[obs]
  tint <- crossprod(w, Xobs)
  xhess0 <- crossprod(Xobs * w, Xobs)
  xgrad <- drop(t(status) %*% x$X - tint)
  
  ## Environment used by objfun1() to remember the best candidate state.
  env <- new.env()
  
  if((!(!x$state$do.optim | x$fixed | x$fxsp)) & do.optim2) {
    par <- x$state$parameters
    
    ## edf of all other model terms.
    edf0 <- if(is.null(edf)) {
      0
    } else {
      edf - x$state$edf
    }
    
    ## Information criterion as a function of tau2; all assignments below
    ## only change copies local to this function.
    objfun1 <- function(tau2) {
      par[x$pid$tau2] <- tau2
      xgrad <- xgrad + x$grad(score = NULL, par, full = FALSE)
      xhess <- xhess0 + x$hess(score = NULL, par, full = FALSE)
      ## NOTE(review): the sparse setup is passed as index here while the
      ## final step below uses index = NULL -- confirm this is intended.
      Sigma <- matrix_inv(xhess, index = x$sparse.setup)
      Hs <- Sigma %*% xgrad
      g <- par[x$pid$b]
      if(update.nu) {
        ## Negative log-posterior as a function of the step length.
        objfun.nu <- function(nu) {
          g2 <- drop(g + nu * Hs)
          par[x$pid$b] <- g2
          fit <- x$fit.fun(x$X, g2)
          eta$gamma <- eta$gamma - fitted(x$state) + fit
          lp <- get_LogPost(eta_timegrid, eta, x$prior(par))
          return(-1 * lp)
        }
        nu <- optimize(f = objfun.nu, interval = c(0, 1))$minimum
      } else {
        nu <- x$state$nu
      }
      g2 <- drop(g + nu * Hs)
      fit <- x$fit.fun(x$X, g2)
      eta$gamma <- eta$gamma - fitted(x$state) + fit
      edf1 <- sum_diag(xhess0 %*% Sigma)
      edf <- edf0 + edf1
      
      logLik <- get_LogPost(eta_timegrid, eta, 0)
      ic <- get.ic2(logLik, edf, length(eta$mu), criterion)
      ## The first call only stores its criterion value; later calls store a
      ## candidate state if it beats both the best value so far and the value
      ## at the starting tau2.
      if(!is.null(env$ic_val)) {
        if((ic < env$ic_val) & (ic < env$ic00_val)) {
          par[x$pid$b] <- g2
          opt_state <- list("parameters" = par,
                            "fitted.values" = fit, "edf" = edf1, "hessian" = xhess,
                            "nu" = nu, "do.optim" = x$state$do.optim)
          assign("state", opt_state, envir = env)
          assign("ic_val", ic, envir = env)
        }
      } else assign("ic_val", ic, envir = env)
      return(ic)
    }
    
    ## Criterion at the current tau2, then the search over tau2.
    assign("ic00_val", objfun1(get.state(x, "tau2")), envir = env)
    tau2 <- tau2.optim(objfun1, start = get.state(x, "tau2"))
    
    if(!is.null(env$state))
      return(env$state)
    
    x$state$parameters <- set.par(x$state$parameters, tau2, "tau2")
  }
  
  ## Add the contributions of x$grad() and x$hess() at the current parameters.
  xgrad <- drop(xgrad + x$grad(score = NULL, x$state$parameters, full = FALSE))
  xhess <- xhess0 + x$hess(score = NULL, x$state$parameters, full = FALSE)
  
  ## Compute the inverse of the hessian.
  Sigma <- matrix_inv(xhess, index = NULL)
  Hs <- Sigma %*% xgrad
  
  ## Update regression coefficients.
  g <- get.state(x, "b")
  
  if(update.nu) {
    ## Negative log-posterior as a function of the step length; x and eta are
    ## only modified in copies local to objfun2().
    objfun2 <- function(nu) {
      g2 <- drop(g + nu * Hs)
      names(g2) <- names(g)
      x$state$parameters <- set.par(x$state$parameters, g2, "b")
      
      ## Update additive predictors.
      fit <- x$fit.fun(x$X, g2)
      eta$gamma <- eta$gamma - fitted(x$state) + fit
      lp <- get_LogPost(eta_timegrid, eta, x$prior(x$state$parameters))
      return(-1 * lp)
    }
    
    x$state$nu <- optimize(f = objfun2, interval = c(0, 1))$minimum
  }
  
  g2 <- drop(g + x$state$nu * Hs)
  names(g2) <- names(g)
  x$state$parameters <- set.par(x$state$parameters, g2, "b")
  
  ## Update additive predictors.
  fit <- x$fit.fun(x$X, g2)
  x$state$fitted.values <- fit
  x$state$edf <- sum_diag(xhess0 %*% Sigma)
  x$state$hessian <- xhess
  
  return(x$state)
}


## Backfitting update for one model term of the predictor lambda.
##
## The term is evaluated on the time grid through x$fit.fun_timegrid(). If
## the smoothing variance may be optimised (x$state$do.optim, not x$fixed,
## not x$fxsp, and do.optim2), candidate values of tau2 are compared through
## get.ic2(); a candidate state that improves the criterion is returned
## directly. Otherwise one step using the tau2 returned by tau2.optim() is
## carried out. With update.nu = TRUE the step length is optimised on [0, 1],
## else x$state$nu is used. Returns the state list of the term.
update_jm_lambda <- function(x, eta, eta_timegrid,
                             eta_timegrid_lambda, eta_timegrid_alpha, eta_timegrid_mu, eta_timegrid_dalpha, eta_timegrid_dmu,
                             status, update.nu, width, criterion, get_LogPost, nobs, do.optim2, edf, interaction, ...)
{
  ## Lambda grid predictor with this term's stored grid fit replaced by a
  ## candidate grid fit.
  swap_lambda_grid <- function(cand_grid) {
    eta_timegrid_lambda - x$state$fitted_timegrid + cand_grid
  }
  
  ## Complete grid predictor built from a lambda grid predictor; alpha enters
  ## additively if interaction is TRUE and multiplied by mu otherwise.
  full_grid <- function(lambda_grid) {
    if(interaction) {
      lambda_grid + eta_timegrid_alpha + eta_timegrid_dalpha * eta_timegrid_dmu
    } else {
      lambda_grid + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
    }
  }
  
  ## Design matrix on the time grid and the integrals for gradient and
  ## hessian (per the original note, survint() returns the negative Hessian).
  grid_design <- x$fit.fun_timegrid(NULL)
  surv_int <- survint(grid_design, exp(eta_timegrid), width, exp(eta$gamma))
  score0 <- drop(t(status) %*% x$XT - surv_int$grad)
  
  ## Storage for the best candidate found during the tau2 search.
  best <- new.env()
  
  tune_tau2 <- x$state$do.optim & !x$fixed & !x$fxsp & do.optim2
  if(tune_tau2) {
    theta0 <- x$state$parameters
    
    ## edf of all other model terms.
    edf_rest <- if(is.null(edf)) 0 else edf - x$state$edf
    
    ## Information criterion for a candidate tau2.
    ic_of_tau2 <- function(tau2) {
      theta <- set.par(theta0, tau2, "tau2")
      score <- score0 + x$grad(score = NULL, theta, full = FALSE)
      info <- surv_int$hess + x$hess(score = NULL, theta, full = FALSE)
      info_inv <- matrix_inv(info, index = NULL)
      direction <- info_inv %*% score
      b_old <- get.par(theta, "b")
      
      step <- if(update.nu) {
        ## Negative log-posterior along the step direction.
        neg_logpost <- function(nu) {
          b_try <- drop(b_old + nu * direction)
          names(b_try) <- names(b_old)
          theta_try <- set.par(x$state$parameters, b_try, "b")
          theta_try <- set.par(theta_try, tau2, "tau2")
          grid_try <- full_grid(swap_lambda_grid(x$fit.fun_timegrid(b_try)))
          -1 * get_LogPost(grid_try, eta, x$prior(theta_try))
        }
        optimize(f = neg_logpost, interval = c(0, 1))$minimum
      } else {
        x$state$nu
      }
      
      b_new <- drop(b_old + step * direction)
      names(b_new) <- names(b_old)
      fit_obs <- x$fit.fun(x$X, b_new, expand = FALSE)
      fit_grid <- x$fit.fun_timegrid(b_new)
      
      ## Candidate predictors, held in local copies only.
      eta_cand <- eta
      eta_cand$lambda <- eta$lambda - fitted(x$state) + fit_obs
      grid_cand <- full_grid(swap_lambda_grid(fit_grid))
      
      edf_term <- sum_diag(surv_int$hess %*% info_inv)
      log_lik <- get_LogPost(grid_cand, eta_cand, 0)
      ic <- get.ic2(log_lik, edf_rest + edf_term, length(eta_cand$mu), criterion)
      
      ## The first evaluation only records its criterion value; afterwards a
      ## candidate is kept if it beats the best value so far and the value
      ## at the starting tau2.
      if(is.null(best$ic_val)) {
        best$ic_val <- ic
      } else if((ic < best$ic_val) & (ic < best$ic00_val)) {
        best$state <- list("parameters" = set.par(theta, b_new, "b"),
                           "fitted.values" = fit_obs, "fitted_timegrid" = fit_grid,
                           "edf" = edf_term, "hessian" = info,
                           "nu" = step, "do.optim" = x$state$do.optim)
        best$ic_val <- ic
      }
      ic
    }
    
    best$ic00_val <- ic_of_tau2(get.state(x, "tau2"))
    tau2 <- tau2.optim(ic_of_tau2, start = get.state(x, "tau2"))
    
    if(!is.null(best$state))
      return(best$state)
    
    x$state$parameters <- set.par(x$state$parameters, tau2, "tau2")
  }
  
  ## Gradient, hessian and step direction at the current parameters.
  score <- score0 + x$grad(score = NULL, x$state$parameters, full = FALSE)
  info <- surv_int$hess + x$hess(score = NULL, x$state$parameters, full = FALSE)
  info_inv <- matrix_inv(info, index = NULL)
  direction <- info_inv %*% score
  
  ## Current regression coefficients.
  b_old <- get.state(x, "b")
  
  if(update.nu) {
    ## Negative log-posterior along the step direction.
    neg_logpost <- function(nu) {
      b_try <- drop(b_old + nu * direction)
      names(b_try) <- names(b_old)
      theta_try <- set.par(x$state$parameters, b_try, "b")
      grid_try <- full_grid(swap_lambda_grid(x$fit.fun_timegrid(b_try)))
      -1 * get_LogPost(grid_try, eta, x$prior(theta_try))
    }
    x$state$nu <- optimize(f = neg_logpost, interval = c(0, 1))$minimum
  }
  
  b_new <- drop(b_old + x$state$nu * direction)
  names(b_new) <- names(b_old)
  
  ## Store coefficients, fitted values, edf and hessian in the state.
  x$state$parameters <- set.par(x$state$parameters, b_new, "b")
  x$state$fitted_timegrid <- x$fit.fun_timegrid(b_new)
  x$state$fitted.values <- x$fit.fun(x$X, b_new, expand = FALSE)
  x$state$edf <- sum_diag(surv_int$hess %*% info_inv)
  x$state$hessian <- info
  
  x$state
}


## Backfitting update for one model term of the predictor mu.
##
## The term enters both the observed responses y[, "obs"] (through x$X) and
## the grid predictor (through x$fit.fun_timegrid() and, if dx is supplied,
## through dx$fit.fun_timegrid() for the dmu part). If the smoothing variance
## may be optimised (x$state$do.optim, not x$fixed, not x$fxsp, and
## do.optim2), candidate values of tau2 are compared through get.ic2(); a
## candidate state that improves the criterion is returned directly.
## Otherwise one step using the tau2 returned by tau2.optim() is carried out.
## With update.nu = TRUE the step length is optimised on [0, 1], else
## x$state$nu is used. Arguments xsmalpha, knots and nobs are accepted but
## not used in this function. Returns the state list of the term.
update_jm_mu <- function(x, y, eta, eta_timegrid,
                         eta_timegrid_lambda, eta_timegrid_alpha, eta_timegrid_mu, eta_timegrid_dalpha, eta_timegrid_dmu,
                         status, update.nu, width, criterion, get_LogPost, nobs, do.optim2, edf, 
                         dx = NULL, xsmalpha = NULL, knots = NULL, ...)
{
  has_dx <- !is.null(dx)
  
  ## Complete grid predictor for candidate coefficients b, given the
  ## candidate grid fit of this term; the dmu part is only touched if a dx
  ## term is present.
  grid_with <- function(mu_fit_grid, b) {
    mu_grid <- eta_timegrid_mu - x$state$fitted_timegrid + mu_fit_grid
    dmu_grid <- if(has_dx) {
      eta_timegrid_dmu - dx$state$fitted_timegrid + dx$fit.fun_timegrid(b)
    } else {
      eta_timegrid_dmu
    }
    eta_timegrid_lambda + eta_timegrid_alpha * mu_grid + eta_timegrid_dalpha * dmu_grid
  }
  
  ## Copy of eta with this term's fitted values replaced by a candidate fit.
  eta_with <- function(mu_fit) {
    eta$mu <- eta$mu - fitted(x$state) + mu_fit
    eta
  }
  
  ## Design matrices on the time grid.
  grid_design <- x$fit.fun_timegrid(NULL)
  grid_design_d <- if(has_dx) dx$fit.fun_timegrid(NULL) else NULL
  
  ## Integrals for gradient and hessian.
  hazard <- exp(eta_timegrid)
  d_weight1 <- if(is.null(grid_design_d)) NULL else hazard * eta_timegrid_dalpha
  d_weight2 <- if(is.null(grid_design_d)) NULL else hazard * eta_timegrid_dalpha^2
  surv_int <- survint(grid_design, hazard * eta_timegrid_alpha, width, exp(eta$gamma),
                      hazard * eta_timegrid_alpha^2, index = x$sparse.setup[["mu.matrix"]],
                      grid_design_d, d_weight1, d_weight2)
  
  ## Gradient and hessian parts that do not depend on tau2.
  sigma2 <- exp(eta$sigma)^2
  score0 <- drop(t(x$X) %*% drop((y[, "obs"] - eta$mu) / sigma2) +
                   t(x$XT) %*% drop(eta$alpha * status) - surv_int$grad)
  if(has_dx)
    score0 <- drop(score0 + t(dx$XT) %*% drop(eta$dalpha * status))
  XWX <- if(is.null(x$sparse.setup$matrix)) {
    crossprod(x$X * (1 / sigma2), x$X)
  } else {
    do.XWX(x$X, sigma2, index = x$sparse.setup$matrix)
  }
  hess0 <- -1 * XWX - surv_int$hess
  
  ## Storage for the best candidate found during the tau2 search.
  best <- new.env()
  
  tune_tau2 <- x$state$do.optim & !x$fixed & !x$fxsp & do.optim2
  if(tune_tau2) {
    theta0 <- x$state$parameters
    
    ## edf of all other model terms.
    edf_rest <- if(is.null(edf)) 0 else edf - x$state$edf
    
    ## Information criterion for a candidate tau2.
    ic_of_tau2 <- function(tau2) {
      theta <- set.par(theta0, tau2, "tau2")
      score <- score0 + x$grad(score = NULL, theta, full = FALSE)
      hess <- hess0 - x$hess(score = NULL, theta, full = FALSE)
      info_inv <- matrix_inv(-1 * hess, index = NULL)
      direction <- info_inv %*% score
      b_old <- get.par(theta, "b")
      
      step <- if(update.nu) {
        ## Negative log-posterior along the step direction.
        neg_logpost <- function(nu) {
          b_try <- drop(b_old + nu * direction)
          names(b_try) <- names(b_old)
          theta_try <- set.par(x$state$parameters, b_try, "b")
          theta_try <- set.par(theta_try, tau2, "tau2")
          grid_try <- grid_with(x$fit.fun_timegrid(b_try), b_try)
          eta_try <- eta_with(x$fit.fun(x$X, b_try))
          -1 * get_LogPost(grid_try, eta_try, x$prior(theta_try))
        }
        optimize(f = neg_logpost, interval = c(0, 1))$minimum
      } else {
        x$state$nu
      }
      
      b_new <- drop(b_old + step * direction)
      names(b_new) <- names(b_old)
      fit_obs <- x$fit.fun(x$X, b_new)
      fit_grid <- x$fit.fun_timegrid(b_new)
      
      ## Candidate predictors, held in local copies only.
      eta_cand <- eta_with(fit_obs)
      grid_cand <- grid_with(fit_grid, b_new)
      
      edf_term <- sum_diag((-1 * hess0) %*% info_inv)
      log_lik <- get_LogPost(grid_cand, eta_cand, 0)
      ic <- get.ic2(log_lik, edf_rest + edf_term, length(eta_cand$mu), criterion)
      
      ## The first evaluation only records its criterion value; afterwards a
      ## candidate is kept if it beats the best value so far and the value
      ## at the starting tau2.
      if(is.null(best$ic_val)) {
        best$ic_val <- ic
      } else if((ic < best$ic_val) & (ic < best$ic00_val)) {
        best$state <- list("parameters" = set.par(theta, b_new, "b"),
                           "fitted.values" = fit_obs, "fitted_timegrid" = fit_grid,
                           "edf" = edf_term, "hessian" = -1 * hess,
                           "nu" = step, "do.optim" = x$state$do.optim)
        best$ic_val <- ic
      }
      ic
    }
    
    best$ic00_val <- ic_of_tau2(get.state(x, "tau2"))
    tau2 <- tau2.optim(ic_of_tau2, start = get.state(x, "tau2"), maxit = 1)
    
    if(!is.null(best$state))
      return(best$state)
    
    x$state$parameters <- set.par(x$state$parameters, tau2, "tau2")
  }
  
  ## Gradient, hessian and step direction at the current parameters.
  score <- score0 + x$grad(score = NULL, x$state$parameters, full = FALSE)
  hess <- hess0 - x$hess(score = NULL, x$state$parameters, full = FALSE)
  info_inv <- matrix_inv(-1 * hess, index = NULL)
  direction <- info_inv %*% score
  
  ## Current regression coefficients.
  b_old <- get.state(x, "b")
  
  if(update.nu) {
    ## Negative log-posterior along the step direction.
    neg_logpost <- function(nu) {
      b_try <- drop(b_old + nu * direction)
      names(b_try) <- names(b_old)
      theta_try <- set.par(x$state$parameters, b_try, "b")
      grid_try <- grid_with(x$fit.fun_timegrid(b_try), b_try)
      eta_try <- eta_with(x$fit.fun(x$X, b_try))
      -1 * get_LogPost(grid_try, eta_try, x$prior(theta_try))
    }
    x$state$nu <- optimize(f = neg_logpost, interval = c(0, 1))$minimum
  }
  
  b_new <- drop(b_old + x$state$nu * direction)
  names(b_new) <- names(b_old)
  
  ## Store coefficients, fitted values, edf and hessian in the state.
  x$state$parameters <- set.par(x$state$parameters, b_new, "b")
  x$state$fitted_timegrid <- x$fit.fun_timegrid(b_new)
  x$state$fitted.values <- x$fit.fun(x$X, b_new)
  x$state$edf <- sum_diag((-1 * hess0) %*% info_inv)
  x$state$hessian <- -1 * hess
  
  x$state
}


update_jm_mu_Matrix <- function(x, y, eta, eta_timegrid,
                                eta_timegrid_lambda, eta_timegrid_alpha, eta_timegrid_mu, eta_timegrid_dalpha, eta_timegrid_dmu,
                                status, update.nu, width, criterion, get_LogPost, nobs, do.optim2, edf, 
                                dx = NULL, xsmalpha = NULL, knots = NULL, tp = FALSE, ...)
{
  ## The time-dependent design matrix for the grid.
  X <- x$fit.fun_timegrid(NULL)
  dX <- if(!is.null(dx)) dx$fit.fun_timegrid(NULL) else NULL
  
  ## Timegrid predictor.
  eeta <- exp(eta_timegrid)
  
  ## Compute gradient and hessian integrals.
  int <- survint(X, eeta * eta_timegrid_alpha, width, exp(eta$gamma),
                 eeta * eta_timegrid_alpha^2, index = x$sparse.setup[["mu.matrix"]],
                 dX, if(!is.null(dX)) eeta * eta_timegrid_dalpha else NULL,
                 if(!is.null(dX)) eeta * eta_timegrid_dalpha^2 else NULL)
  
  xgrad <- crossprod(x$X, drop((y[, "obs"]  - eta$mu) / exp(eta$sigma)^2)) +
    crossprod(x$XT,  drop(eta$alpha * status)) - int$grad
  if(!is.null(dx))
    xgrad <- xgrad + crossprod(dx$XT, drop(eta$dalpha * status))
  XWX <- crossprod(Diagonal(x = 1 / exp(eta$sigma)^2) %*% x$X, x$X)
  xhess0 <- -1 * XWX - int$hess
  
  env <- new.env()
  
  if((!(!x$state$do.optim | x$fixed | x$fxsp)) & do.optim2) {
    par <- x$state$parameters
    
    edf0 <- if(is.null(edf)) {
      0
    } else {
      edf - x$state$edf
    }
    
    objfun1 <- function(tau2) {
      par[x$pid$tau2] <- tau2
      xgrad <- xgrad + x$grad(score = NULL, par, full = FALSE)
      xhess <- xhess0 - x$hess(score = NULL, par, full = FALSE)
      Sigma <- matrix_inv(-1 * xhess, index = NULL)
      Hs <- Sigma %*% xgrad
      g <- par[x$pid$b]
      if(update.nu) {
        objfun.nu <- function(nu) {
          g2 <- drop(g + nu * Hs)
          par[x$pid$b] <- g2
          eta_timegrid_mu <- eta_timegrid_mu - x$state$fitted_timegrid + x$fit.fun_timegrid(g2)
          if(!is.null(dx))
            eta_timegrid_dmu <- eta_timegrid_dmu - dx$state$fitted_timegrid + dx$fit.fun_timegrid(g2)
          eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
          fit <- x$fit.fun(x$X, g2)
          eta$mu <- eta$mu - fitted(x$state) + fit
          lp <- get_LogPost(eta_timegrid, eta, x$prior(par))
          return(-1 * lp)
        }
        nu <- optimize(f = objfun.nu, interval = c(0, 1))$minimum
      } else {
        nu <- x$state$nu
      }
      g2 <- drop(g + nu * Hs)
      names(g2) <- names(g)
      fit <- x$fit.fun(x$X, g2)
      fit_timegrid <- x$fit.fun_timegrid(g2)
      eta$mu <- eta$mu - fitted(x$state) + fit
      eta_timegrid_mu <- eta_timegrid_mu - x$state$fitted_timegrid + fit_timegrid
      if(!is.null(dx))
        eta_timegrid_dmu <- eta_timegrid_dmu - dx$state$fitted_timegrid + dx$fit.fun_timegrid(g2)
      eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
      edf1 <- sum_diag((-1 * xhess0) %*% Sigma)
      edf <- edf0 + edf1
      logLik <- get_LogPost(eta_timegrid, eta, 0)
      ic <- get.ic2(logLik, edf, length(eta$mu), criterion)
      if(!is.null(env$ic_val)) {
        if((ic < env$ic_val) & (ic < env$ic00_val)) {
          par[x$pid$b] <- g2
          opt_state <- list("parameters" = par,
                            "fitted.values" = fit, "fitted_timegrid" = fit_timegrid,
                            "edf" = edf1, "hessian" = -1 * xhess,
                            "nu" = nu, "do.optim" = x$state$do.optim)
          assign("state", opt_state, envir = env)
          assign("ic_val", ic, envir = env)
        }
      } else assign("ic_val", ic, envir = env)
      return(ic)
    }
    
    assign("ic00_val", objfun1(get.state(x, "tau2")), envir = env)
    tau2 <- tau2.optim(objfun1, start = get.state(x, "tau2"), maxit = 1)
    
    if(!is.null(env$state))
      return(env$state)
    
    x$state$parameters[x$pid$tau2] <- tau2
  }
  
  xgrad <- xgrad + x$grad(score = NULL, x$state$parameters, full = FALSE)
  xhess <- xhess0 - x$hess(score = NULL, x$state$parameters, full = FALSE)
  
  ## Compute the inverse of the hessian.
  Sigma <- matrix_inv(-1 * xhess, index = NULL)
  Hs <- Sigma %*% xgrad
  
  ## Update regression coefficients.
  g <- get.state(x, "b")
  
  if(update.nu) {
    objfun2 <- function(nu) {
      g2 <- drop(g + nu * Hs)
      x$state$parameters[x$pid$b] <- g2
      fit_timegrid <- x$fit.fun_timegrid(g2)
      eta_timegrid_mu <- eta_timegrid_mu - x$state$fitted_timegrid + fit_timegrid
      if(!is.null(dx))
        eta_timegrid_dmu <- eta_timegrid_dmu - dx$state$fitted_timegrid + dx$fit.fun_timegrid(g2)
      eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
      fit <- x$fit.fun(x$X, g2)
      eta$mu <- eta$mu - fitted(x$state) + fit
      lp <- get_LogPost(eta_timegrid, eta, x$prior(x$state$parameters))
      return(-1 * lp)
    }
    x$state$nu <- optimize(f = objfun2, interval = c(0, 1))$minimum
  }
  
  g2 <- drop(g + x$state$nu * Hs)
  x$state$parameters[x$pid$b] <- g2
  
  ## Update fitted values..
  x$state$fitted_timegrid <- x$fit.fun_timegrid(g2)
  x$state$fitted.values <- x$fit.fun(x$X, g2)
  x$state$edf <- sum_diag((-1 * xhess0) %*% Sigma)
  x$state$hessian <- -1 * xhess
  
  return(x$state)
}


update_jm_dmu <- function(dx, x)
{
  ## Rebuild the state of a "dmu" term from its companion "mu" term:
  ## the coefficients are taken from x, while the fitted values are
  ## evaluated with the fitting functions (and design matrix) of dx.
  coefs <- get.state(x, "b")

  ## Start from a copy of the mu state and overwrite the fitted parts.
  new_state <- x[["state"]]
  new_state[["fitted.values"]] <- dx$fit.fun(dx$X, coefs)
  new_state[["fitted_timegrid"]] <- dx$fit.fun_timegrid(coefs)

  ## The derivative term contributes no degrees of freedom of its own.
  new_state[["edf"]] <- 0

  new_state
}


## Updating step for one model term of the "sigma" predictor.
##
## Arguments (as used in the body):
##   x             model term; reads x$X, x$state, x$pid, x$fixed, x$fxsp and
##                 calls x$grad(), x$hess(), x$prior(), x$fit.fun().
##   y             response matrix; only column "obs" is read here.
##   eta           list of predictors; eta$mu and eta$sigma are read.
##   eta_timegrid  passed on unchanged to get_LogPost().
##   update.nu     if TRUE the step length nu is chosen with optimize() on [0, 1],
##                 otherwise x$state$nu is used.
##   criterion     passed on to get.ic2().
##   get_LogPost   function(eta_timegrid, eta, prior) returning the log-posterior.
##   nobs          not referenced in this function.
##   do.optim2     switch that, together with the flags of x, enables the
##                 tau2 optimisation branch.
##   edf           current total edf of the model or NULL.
##
## Returns a state list (parameters, fitted.values, edf, hessian, nu, ...).
update_jm_sigma <- function(x, y, eta, eta_timegrid,
                            update.nu, criterion, get_LogPost, nobs, do.optim2, edf, ...)
{
  ## Score and (negative definite) curvature term of the Gaussian
  ## log-likelihood with respect to the coefficients of the sigma predictor:
  ## residuals are y[, "obs"] - eta$mu and the standard deviation is exp(eta$sigma).
  xgrad <- crossprod(x$X, -1 + (y[, "obs"] - eta$mu)^2 / exp(eta$sigma)^2)
  xhess0 <- -2 * crossprod(x$X*drop((y[, "obs"] - eta$mu) / exp(eta$sigma)^2),
                           x$X*drop(y[, "obs"] - eta$mu))
  
  ## Optimise the smoothing variance(s) only if the term asks for it
  ## (state$do.optim), is neither fixed nor has fixed smoothing parameters,
  ## and the caller enabled it via do.optim2.
  if((!(!x$state$do.optim | x$fixed | x$fxsp)) & do.optim2) {
    par <- x$state$parameters
    
    ## edf of all other terms (total edf minus the edf of this term).
    edf0 <- if(is.null(edf)) {
      0
    } else {
      edf - x$state$edf
    }
    
    ## Environment used to carry the best state found so far out of objfun1().
    env <- new.env()
    
    ## Objective for the tau2 search: for a candidate tau2 perform one
    ## updating step of the coefficients and return the information criterion.
    ## Note: assignments to par, xgrad and eta below only modify local copies,
    ## the objects of the enclosing function are left untouched.
    objfun1 <- function(tau2) {
      par <- set.par(par, tau2, "tau2")
      ## Add the prior (penalty) contributions to gradient and hessian.
      xgrad <- xgrad + x$grad(score = NULL, par, full = FALSE)
      xhess <- xhess0 - x$hess(score = NULL, par, full = FALSE)
      Sigma <- matrix_inv(-1 * xhess, index = NULL)
      Hs <- Sigma %*% xgrad
      g <- get.par(par, "b")
      if(update.nu) {
        ## Negative log-posterior as a function of the step length nu.
        objfun.nu <- function(nu) {
          g2 <- drop(g + nu * Hs)
          names(g2) <- names(g)
          x$state$parameters <- set.par(x$state$parameters, g2, "b")
          x$state$parameters <- set.par(x$state$parameters, tau2, "tau2")
          fit <- x$fit.fun(x$X, g2)
          eta$sigma <- eta$sigma - fitted(x$state) + fit
          lp <- get_LogPost(eta_timegrid, eta, x$prior(x$state$parameters))
          return(-1 * lp)
        }
        nu <- optimize(f = objfun.nu, interval = c(0, 1))$minimum
      } else {
        nu <- x$state$nu
      }
      ## Updated coefficients and the resulting sigma predictor.
      g2 <- drop(g + nu * Hs)
      names(g2) <- names(g)
      fit <- x$fit.fun(x$X, g2)
      eta$sigma <- eta$sigma - fitted(x$state) + fit
      edf1 <- sum_diag((-1 * xhess0) %*% Sigma)
      edf <- edf0 + edf1
      ## The prior argument is 0 here, i.e., no prior contribution is added.
      logLik <- get_LogPost(eta_timegrid, eta, 0)
      ic <- get.ic2(logLik, edf, length(eta$mu), criterion)
      ## The first evaluation (baseline, see below) only initialises ic_val.
      ## Afterwards a candidate state is stored only if it improves on both
      ## the best value so far and the baseline value ic00_val.
      if(!is.null(env$ic_val)) {
        if((ic < env$ic_val) & (ic < env$ic00_val)) {
          par[x$pid$b] <- g2
          opt_state <- list("parameters" = par,
                            "fitted.values" = fit,
                            "edf" = edf1, "hessian" = -1 * xhess,
                            "nu" = nu, "do.optim" = x$state$do.optim)
          assign("state", opt_state, envir = env)
          assign("ic_val", ic, envir = env)
        }
      } else assign("ic_val", ic, envir = env)
      return(ic)
    }
    
    ## Baseline: information criterion at the current tau2.
    assign("ic00_val", objfun1(get.state(x, "tau2")), envir = env)
    ## Search over tau2 (tau2.optim() is defined elsewhere in the package).
    tau2 <- tau2.optim(objfun1, start = get.state(x, "tau2"), maxit = 1)
    
    ## If an improving state was recorded during the search, return it directly.
    if(!is.null(env$state))
      return(env$state)
    
    ## Otherwise keep the tau2 returned by the search and do a plain update below.
    x$state$parameters[x$pid$tau2] <- tau2
  }
  
  ## Gradient and hessian including the prior contributions at the current parameters.
  xhess <- xhess0 - x$hess(score = NULL, x$state$parameters, full = FALSE)
  xgrad <- xgrad + x$grad(score = NULL, x$state$parameters, full = FALSE)
  
  ## Compute the inverse of the hessian.
  Sigma <- matrix_inv(-1 * xhess, index = NULL)
  Hs <- Sigma %*% xgrad
  
  ## Update regression coefficients.
  g <- get.state(x, "b")
  
  if(update.nu) {
    ## Negative log-posterior as a function of the step length nu;
    ## x and eta are modified locally only.
    objfun2 <- function(nu) {
      g2 <- drop(g + nu * Hs)
      names(g2) <- names(g)
      x$state$parameters <- set.par(x$state$parameters, g2, "b")
      fit <- x$fit.fun(x$X, g2)
      eta$sigma <- eta$sigma - fitted(x$state) + fit
      lp <- get_LogPost(eta_timegrid, eta, x$prior(x$state$parameters))
      return(-1 * lp)
    }
    
    x$state$nu <- optimize(f = objfun2, interval = c(0, 1))$minimum
  }
  
  
  ## Step of length nu along the direction Hs.
  g2 <- drop(g + x$state$nu * Hs)
  names(g2) <- names(g)
  x$state$parameters <- set.par(x$state$parameters, g2, "b")
  
  ## Update fitted values.
  x$state$fitted.values <- x$fit.fun(x$X, g2)
  x$state$edf <- sum_diag((-1 * xhess0) %*% Sigma)
  x$state$hessian <- -1 * xhess
  
  return(x$state)
}


## Updating step for one model term of the "alpha" (association) predictor.
##
## Arguments (as used in the body):
##   x                    model term; reads x$XT, x$X, x$state, x$sparse.setup,
##                        x$fixed, x$fxsp and calls x$grad(), x$hess(),
##                        x$prior(), x$fit.fun(), x$fit.fun_timegrid().
##   eta                  list of predictors; eta$gamma, eta$alpha, eta$mu are read.
##   eta_timegrid         full time grid predictor (exponentiated for the integrals).
##   eta_timegrid_*       time grid predictors of the single components; the full
##                        predictor is rebuilt as
##                        lambda + alpha * mu + dalpha * dmu.
##   status               multiplied with the last time grid column of mu in the
##                        gradient (presumably the event indicator - confirm).
##   update.nu            if TRUE the step length nu is chosen with optimize() on [0, 1].
##   width                passed on to survint().
##   criterion            passed on to get.ic2().
##   get_LogPost          function(eta_timegrid, eta, prior) returning the log-posterior.
##   nobs                 not referenced in this function.
##   do.optim2, edf       control the tau2 optimisation branch / total model edf.
##
## Returns a state list (parameters, fitted.values, fitted_timegrid, edf,
## hessian, nu, ...).
update_jm_alpha <- function(x, eta, eta_timegrid,
                            eta_timegrid_lambda, eta_timegrid_alpha, eta_timegrid_mu, eta_timegrid_dalpha, eta_timegrid_dmu,
                            status, update.nu, width, criterion, get_LogPost, nobs, do.optim2, edf, ...)
{
  ## The time-dependent design matrix for the grid.
  X <- x$fit.fun_timegrid(NULL)
  
  ## Timegrid predictor.
  eeta <- exp(eta_timegrid)
  
  ## Compute gradient and hessian integrals.
  ## (survint() is defined elsewhere; its $grad and $hess elements are used below.)
  int <- survint(X, eeta * eta_timegrid_mu, width, exp(eta$gamma),
                 eeta * (eta_timegrid_mu^2), index = x$sparse.setup$matrix)
  ## Contribution at the last time grid column (weighted by status) minus the integral part.
  xgrad <- t(x$XT) %*% drop(eta_timegrid_mu[, ncol(eeta)] * status) - int$grad
  
  ## Environment used to carry the best state found so far out of objfun1().
  env <- new.env()
  
  ## Optimise the smoothing variance(s) only if the term asks for it, is neither
  ## fixed nor has fixed smoothing parameters, and do.optim2 is set.
  if((!(!x$state$do.optim | x$fixed | x$fxsp)) & do.optim2) {
    par <- x$state$parameters
    
    ## edf of all other terms (total edf minus the edf of this term).
    edf0 <- if(is.null(edf)) {
      0
    } else {
      edf - x$state$edf
    }
    
    ## Objective for the tau2 search: one updating step for a candidate tau2,
    ## returns the information criterion. Assignments to par, xgrad, eta and the
    ## eta_timegrid objects inside only modify local copies.
    objfun1 <- function(tau2) {
      par <- set.par(par, tau2, "tau2")
      xgrad <- xgrad + x$grad(score = NULL, par, full = FALSE)
      ## NOTE(review): here int$hess + x$hess() is inverted directly, whereas
      ## update_jm_sigma() inverts -1 * xhess; presumably survint() and x$hess()
      ## already return negative hessian contributions - confirm.
      xhess <- int$hess + x$hess(score = NULL, par, full = FALSE)
      Sigma <- matrix_inv(xhess, index = NULL)
      Hs <- Sigma %*% xgrad
      g <- get.par(par, "b")
      if(update.nu) {
        ## Negative log-posterior as a function of the step length nu.
        objfun.nu <- function(nu) {
          g2 <- drop(g + nu * Hs)
          names(g2) <- names(g)
          x$state$parameters <- set.par(x$state$parameters, g2, "b")
          x$state$parameters <- set.par(x$state$parameters, tau2, "tau2")
          fit_timegrid <- x$fit.fun_timegrid(g2)
          eta_timegrid_alpha <- eta_timegrid_alpha - x$state$fitted_timegrid + fit_timegrid
          eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
          lp <- get_LogPost(eta_timegrid, eta, x$prior(x$state$parameters))
          return(-1 * lp)
        }
        nu <- optimize(f = objfun.nu, interval = c(0, 1))$minimum
      } else {
        nu <- x$state$nu
      }
      ## Updated coefficients, fitted values and rebuilt time grid predictor.
      g2 <- drop(g + nu * Hs)
      names(g2) <- names(g)
      fit <- x$fit.fun(x$X, g2)
      fit_timegrid <- x$fit.fun_timegrid(g2)
      eta$alpha <- eta$alpha - fitted(x$state) + fit
      eta_timegrid_alpha <- eta_timegrid_alpha - x$state$fitted_timegrid + fit_timegrid
      eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
      edf1 <- sum_diag(int$hess %*% Sigma)
      edf <- edf0 + edf1
      ## The prior argument is 0 here, i.e., no prior contribution is added.
      logLik <- get_LogPost(eta_timegrid, eta, 0)
      ic <- get.ic2(logLik, edf, length(eta$mu), criterion)
      ## The first evaluation (baseline) only initialises ic_val; afterwards a
      ## state is stored only if it beats both the best value so far and ic00_val.
      if(!is.null(env$ic_val)) {
        if((ic < env$ic_val) & (ic < env$ic00_val)) {
          opt_state <- list("parameters" = set.par(par, g2, "b"),
                            "fitted.values" = fit, "fitted_timegrid" = fit_timegrid,
                            "edf" = edf1, "hessian" = xhess,
                            "nu" = nu, "do.optim" = x$state$do.optim)
          assign("state", opt_state, envir = env)
          assign("ic_val", ic, envir = env)
        }
      } else assign("ic_val", ic, envir = env)
      return(ic)
    }
    
    ## Baseline: information criterion at the current tau2.
    assign("ic00_val", objfun1(get.state(x, "tau2")), envir = env)
    ## Search over tau2 (tau2.optim() is defined elsewhere in the package).
    tau2 <- tau2.optim(objfun1, start = get.state(x, "tau2"))
    
    ## If an improving state was recorded during the search, return it directly.
    if(!is.null(env$state))
      return(env$state)
    
    ## Otherwise keep the tau2 returned by the search and do a plain update below.
    x$state$parameters <- set.par(x$state$parameters, tau2, "tau2")
  }
  
  ## Gradient and hessian including the prior contributions at the current parameters.
  xgrad <- xgrad + x$grad(score = NULL, x$state$parameters, full = FALSE)
  xhess <- int$hess + x$hess(score = NULL, x$state$parameters, full = FALSE)
  
  ## Compute the inverse of the hessian.
  Sigma <- matrix_inv(xhess, x$sparse.setup)
  Hs <- Sigma %*% xgrad
  
  ## Update regression coefficients.
  g <- get.state(x, "b")
  
  if(update.nu) {
    ## Negative log-posterior as a function of the step length nu;
    ## x and the eta_timegrid objects are modified locally only.
    objfun2 <- function(nu) {
      g2 <- drop(g + nu * Hs)
      names(g2) <- names(g)
      x$state$parameters <- set.par(x$state$parameters, g2, "b")
      fit_timegrid <- x$fit.fun_timegrid(g2)
      eta_timegrid_alpha <- eta_timegrid_alpha - x$state$fitted_timegrid + fit_timegrid
      eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
      lp <- get_LogPost(eta_timegrid, eta, x$prior(x$state$parameters))
      return(-1 * lp)
    }
    x$state$nu <- optimize(f = objfun2, interval = c(0, 1))$minimum
  }
  
  ## Step of length nu along the direction Hs.
  g2 <- drop(g + x$state$nu * Hs)        
  names(g2) <- names(g)
  x$state$parameters <- set.par(x$state$parameters, g2, "b")
  
  ## Update fitted values.
  x$state$fitted_timegrid <- x$fit.fun_timegrid(g2)
  ## NOTE(review): expand = FALSE is passed here but not in objfun1() above.
  x$state$fitted.values <- x$fit.fun(x$X, g2, expand = FALSE)
  x$state$edf <- sum_diag(int$hess %*% Sigma)

  x$state$hessian <- xhess
  
  return(x$state)
}


## Updating step for one model term of the "dalpha" predictor, i.e., the
## component that is multiplied with eta_timegrid_dmu in the time grid predictor.
## The structure mirrors update_jm_alpha(), with eta_timegrid_dmu taking the
## role of eta_timegrid_mu.
##
## Arguments (as used in the body):
##   x                    model term; reads x$XT, x$X, x$state, x$sparse.setup,
##                        x$fixed, x$fxsp and calls x$grad(), x$hess(),
##                        x$prior(), x$fit.fun(), x$fit.fun_timegrid().
##   eta                  list of predictors; eta$gamma, eta$dalpha, eta$mu are read.
##   eta_timegrid         full time grid predictor (exponentiated for the integrals).
##   eta_timegrid_*       time grid predictors of the single components; the full
##                        predictor is rebuilt as
##                        lambda + alpha * mu + dalpha * dmu.
##   status               multiplied with the last time grid column of dmu in the
##                        gradient (presumably the event indicator - confirm).
##   update.nu            if TRUE the step length nu is chosen with optimize() on [0, 1].
##   width                passed on to survint().
##   criterion            passed on to get.ic2().
##   get_LogPost          function(eta_timegrid, eta, prior) returning the log-posterior.
##   nobs                 not referenced in this function.
##   do.optim2, edf       control the tau2 optimisation branch / total model edf.
##
## Returns a state list (parameters, fitted.values, fitted_timegrid, edf,
## hessian, nu, ...).
update_jm_dalpha <- function(x, eta, eta_timegrid,
                             eta_timegrid_lambda, eta_timegrid_alpha, eta_timegrid_mu, eta_timegrid_dalpha, eta_timegrid_dmu,
                             status, update.nu, width, criterion, get_LogPost, nobs, do.optim2, edf, ...)
{
  ## The time-dependent design matrix for the grid.
  X <- x$fit.fun_timegrid(NULL)
  
  ## Timegrid predictor.
  eeta <- exp(eta_timegrid)
  
  ## Compute gradient and hessian integrals.
  ## (survint() is defined elsewhere; its $grad and $hess elements are used below.)
  int <- survint(X, eeta * eta_timegrid_dmu, width, exp(eta$gamma),
                 eeta * (eta_timegrid_dmu^2), index = x$sparse.setup$matrix)
  ## Contribution at the last time grid column (weighted by status) minus the integral part.
  xgrad <- t(x$XT) %*% drop(eta_timegrid_dmu[, ncol(eeta)] * status) - int$grad
  
  ## Environment used to carry the best state found so far out of objfun1().
  env <- new.env()
  
  ## Optimise the smoothing variance(s) only if the term asks for it, is neither
  ## fixed nor has fixed smoothing parameters, and do.optim2 is set.
  if((!(!x$state$do.optim | x$fixed | x$fxsp)) & do.optim2) {
    par <- x$state$parameters
    
    ## edf of all other terms (total edf minus the edf of this term).
    edf0 <- if(is.null(edf)) {
      0
    } else {
      edf - x$state$edf
    }
    
    ## Objective for the tau2 search: one updating step for a candidate tau2,
    ## returns the information criterion. Assignments to par, xgrad, eta and the
    ## eta_timegrid objects inside only modify local copies.
    objfun1 <- function(tau2) {
      par <- set.par(par, tau2, "tau2")
      xgrad <- xgrad + x$grad(score = NULL, par, full = FALSE)
      xhess <- int$hess + x$hess(score = NULL, par, full = FALSE)
      Sigma <- matrix_inv(xhess, index = NULL)
      Hs <- Sigma %*% xgrad
      g <- get.par(par, "b")
      if(update.nu) {
        ## Negative log-posterior as a function of the step length nu.
        objfun.nu <- function(nu) {
          g2 <- drop(g + nu * Hs)
          names(g2) <- names(g)
          x$state$parameters <- set.par(x$state$parameters, g2, "b")
          x$state$parameters <- set.par(x$state$parameters, tau2, "tau2")
          fit_timegrid <- x$fit.fun_timegrid(g2)
          eta_timegrid_dalpha <- eta_timegrid_dalpha - x$state$fitted_timegrid + fit_timegrid
          eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
          lp <- get_LogPost(eta_timegrid, eta, x$prior(x$state$parameters))
          return(-1 * lp)
        }
        nu <- optimize(f = objfun.nu, interval = c(0, 1))$minimum
      } else {
        nu <- x$state$nu
      }
      ## Updated coefficients, fitted values and rebuilt time grid predictor.
      g2 <- drop(g + nu * Hs)
      names(g2) <- names(g)
      fit <- x$fit.fun(x$X, g2)
      fit_timegrid <- x$fit.fun_timegrid(g2)
      eta$dalpha <- eta$dalpha - fitted(x$state) + fit
      eta_timegrid_dalpha <- eta_timegrid_dalpha - x$state$fitted_timegrid + fit_timegrid
      eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
      edf1 <- sum_diag(int$hess %*% Sigma)
      edf <- edf0 + edf1
      ## The prior argument is 0 here, i.e., no prior contribution is added.
      logLik <- get_LogPost(eta_timegrid, eta, 0)
      ic <- get.ic2(logLik, edf, length(eta$mu), criterion)
      ## The first evaluation (baseline) only initialises ic_val; afterwards a
      ## state is stored only if it beats both the best value so far and ic00_val.
      if(!is.null(env$ic_val)) {
        if((ic < env$ic_val) & (ic < env$ic00_val)) {
          opt_state <- list("parameters" = set.par(par, g2, "b"),
                            "fitted.values" = fit, "fitted_timegrid" = fit_timegrid,
                            "edf" = edf1, "hessian" = xhess,
                            "nu" = nu, "do.optim" = x$state$do.optim)
          assign("state", opt_state, envir = env)
          assign("ic_val", ic, envir = env)
        }
      } else assign("ic_val", ic, envir = env)
      return(ic)
    }
    
    ## Baseline: information criterion at the current tau2.
    assign("ic00_val", objfun1(get.state(x, "tau2")), envir = env)
    ## Search over tau2 (tau2.optim() is defined elsewhere in the package).
    tau2 <- tau2.optim(objfun1, start = get.state(x, "tau2"))
    
    ## If an improving state was recorded during the search, return it directly.
    if(!is.null(env$state))
      return(env$state)
    
    ## Otherwise keep the tau2 returned by the search and do a plain update below.
    x$state$parameters <- set.par(x$state$parameters, tau2, "tau2")
  }
  
  ## Gradient and hessian including the prior contributions at the current parameters.
  xgrad <- xgrad + x$grad(score = NULL, x$state$parameters, full = FALSE)
  xhess <- int$hess + x$hess(score = NULL, x$state$parameters, full = FALSE)
  
  ## Compute the inverse of the hessian.
  Sigma <- matrix_inv(xhess, x$sparse.setup)
  Hs <- Sigma %*% xgrad
  
  ## Update regression coefficients.
  g <- get.state(x, "b")
  
  if(update.nu) {
    ## Negative log-posterior as a function of the step length nu;
    ## x and the eta_timegrid objects are modified locally only.
    objfun2 <- function(nu) {
      g2 <- drop(g + nu * Hs)
      names(g2) <- names(g)
      x$state$parameters <- set.par(x$state$parameters, g2, "b")
      fit_timegrid <- x$fit.fun_timegrid(g2)
      eta_timegrid_dalpha <- eta_timegrid_dalpha - x$state$fitted_timegrid + fit_timegrid
      eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
      lp <- get_LogPost(eta_timegrid, eta, x$prior(x$state$parameters))
      return(-1 * lp)
    }
    x$state$nu <- optimize(f = objfun2, interval = c(0, 1))$minimum
  }
  nu <- x$state$nu
  
  ## Step of length nu along the direction Hs.
  g2 <- drop(g + nu * Hs)
  
  names(g2) <- names(g)
  x$state$parameters <- set.par(x$state$parameters, g2, "b")
  
  ## Update fitted values.
  x$state$fitted_timegrid <- x$fit.fun_timegrid(g2)
  ## NOTE(review): expand = FALSE is passed here but not in objfun1() above.
  x$state$fitted.values <- x$fit.fun(x$X, g2, expand = FALSE)
  x$state$edf <- sum_diag(int$hess %*% Sigma)
  x$state$hessian <- xhess
  
  return(x$state)
}


## (5) Joint model MCMC.
## MCMC sampler for the joint model.
##
## Arguments (as used in the body):
##   x        list of predictors ("lambda", "gamma", "mu", "sigma", "alpha",
##            optionally "dalpha"/"dmu"), each with $smooth.construct terms.
##   y        list whose first element is the response matrix (columns "obs"
##            and "status" are read) carrying the attributes "nobs",
##            "subdivisions", "width", "take", "interaction", "tp", "fac".
##   family   family object; only family$p2logLik() is called (for the DIC).
##   start    optional named vector of starting values; a matrix is accepted
##            if it has a column whose name contains "Mean".
##   weights, offset  not referenced in this function.
##   n.iter, burnin, thin  iteration control; samples are stored at
##            seq(burnin, n.iter, by = thin).
##   verbose  print progress information.
##   digits   not referenced in this function.
##   step     number of steps passed to the progress bar function barfun().
##   ...      passed on to the proposal functions.
##
## Returns an object created by as.mcmc() holding, per term, the sampled
## parameters, edf, acceptance probability ("alpha") and acceptance indicator,
## plus the columns logLik, logPost, DIC and pd.
jm.mcmc <- function(x, y, family, start = NULL, weights = NULL, offset = NULL,
                    n.iter = 1200, burnin = 200, thin = 1, verbose = TRUE, digits = 4, step = 20, ...)
{
  ## Hard coded.
  ## With fixed = NULL no predictor is excluded from updating and with
  ## slice = NULL slice sampling is switched off for all predictors (see below).
  fixed <- NULL
  slice <- NULL
  
  ## Is there a "dalpha" predictor, i.e., parametric terms or smooth terms?
  dalpha <- has_pterms(x$dalpha$terms) | (length(x$dalpha$smooth.construct) > 0)
  
  nu <- 1
  
  if(!is.null(start)) {
    ## A matrix of starting values is reduced to its "Mean" column(s).
    if(is.matrix(start)) {
      if(any(i <- grepl("Mean", colnames(start))))
        start <- start[, i]
      else stop("the starting values should be a vector not a matrix!")
    }
    ## If no "dmu." starting values are supplied, copy those whose
    ## name contains "mu." and rename them to "dmu.".
    if(dalpha) {
      if(!any(grepl("dmu.", names(start), fixed = TRUE))) {
        start.dmu <- start[grep("mu.", names(start), fixed = TRUE)]
        names(start.dmu) <- gsub("mu.", "dmu.", names(start.dmu), fixed = TRUE)
        start <- c(start, start.dmu)
      }
    }
    x <- set.starting.values(x, start)
  }
  
  ## Names of parameters/predictors.
  nx <- names(x)
  
  ## Extract y.
  y <- y[[1]]
  
  ## Number of observations.
  nobs <- attr(y, "nobs")
  
  ## Number of subdivisions used for the time grid.
  sub <- attr(y, "subdivisions")
  
  ## The interval width from subdivisions.
  width <- attr(y, "width")
  
  ## Subject specific indicator
  take <- attr(y, "take")
  nlong <- length(take)
  
  ## Extract the status for individual i.
  status <- y[take, "status"]
  
  ## Interaction setup
  interaction <- attr(y, "interaction")
  tp <- attr(y, "tp") 
  fac <- attr(y, "fac")
  
  ## Make id for individual i.
  ## NOTE(review): this id vector is not referenced again in this function,
  ## the proposal functions below are called with id = i (the predictor name).
  id <- which(take)
  id <- append(id[-1], nlong + 1) - id 
  id <- rep(1:nobs, id)
  
  ## Compute additive predictors.
  eta <- get.eta(x, expand = FALSE) # potentially include alpha_mu
  
  ## Correct gamma predictor if NULL.
  if(!length(x$gamma$smooth.construct)) {
    eta$gamma <- rep(0, length = nobs)
  }
  
  ## For the time dependent part, compute
  ## predictors based on the time grid.
  eta_timegrid_mu <- 0
  if(length(x$mu$smooth.construct)) {
    for(j in names(x$mu$smooth.construct)) {
      b <- get.par(x$mu$smooth.construct[[j]]$state$parameters, "b")
      x$mu$smooth.construct[[j]]$state$fitted_timegrid <- x$mu$smooth.construct[[j]]$fit.fun_timegrid(b)
      eta_timegrid_mu <- eta_timegrid_mu + x$mu$smooth.construct[[j]]$state$fitted_timegrid
    }
  }
  
  
  eta_timegrid_alpha <- 0
  if(length(x$alpha$smooth.construct)) {
    for(j in names(x$alpha$smooth.construct)) {
      b <- get.par(x$alpha$smooth.construct[[j]]$state$parameters, "b")
      # update alpha according to mu
      if(interaction & !("model.matrix" %in% attr(x$alpha$smooth.construct[[j]], "class"))){
        ## The design matrix of the alpha term is rebuilt from the current
        ## mu time grid predictor (row-wise vectorised).
        Xmu <- as.vector(t(eta_timegrid_mu))
        if(!tp) {
          X <- x$alpha$smooth.construct[[j]]$update(Xmu, "mu")    
        } else {
          ## Tensor product case: first margin depends on mu, second margin is
          ## evaluated on the time grid. Xalpha2 is reused in the sampling loop.
          Xalpha <- x$alpha$smooth.construct[[j]]$margin[[1]]$update(Xmu, "mu")
          Xalpha2 <- x$alpha$smooth.construct[[j]]$margin[[2]]$fit.fun_timegrid(NULL)
          X <- rowTensorProduct(Xalpha, Xalpha2) 
        }
        fit_timegrid <- matrix(drop(X %*% b), nrow = nrow(eta_timegrid_mu), 
                               ncol = ncol(eta_timegrid_mu), byrow = TRUE)
        ## The fitted values are the last column of the time grid fit.
        x$alpha$smooth.construct[[j]]$state$fitted.values <- fit_timegrid[, ncol(fit_timegrid)]
        x$alpha$smooth.construct[[j]]$state$fitted_timegrid <- fit_timegrid
        eta_timegrid_alpha <- eta_timegrid_alpha + fit_timegrid
      } else {
        x$alpha$smooth.construct[[j]]$state$fitted_timegrid <- x$alpha$smooth.construct[[j]]$fit.fun_timegrid(b)
        eta_timegrid_alpha <- eta_timegrid_alpha + x$alpha$smooth.construct[[j]]$fit.fun_timegrid(b)   
      }
      ## NOTE(review): nu is the unnamed scalar 1 defined above, so nu["alpha"]
      ## evaluates to NA here - confirm whether state$nu is used by the sampler.
      x$alpha$smooth.construct[[j]]$state$nu <- nu["alpha"]
    }
    eta$alpha <- eta_timegrid_alpha[, ncol(eta_timegrid_alpha)]
  }     
  
  
  eta_timegrid_lambda <- 0
  if(length(x$lambda$smooth.construct)) {
    for(j in names(x$lambda$smooth.construct)) {
      b <- get.par(x$lambda$smooth.construct[[j]]$state$parameters, "b")
      x$lambda$smooth.construct[[j]]$state$fitted_timegrid <- x$lambda$smooth.construct[[j]]$fit.fun_timegrid(b)
      eta_timegrid_lambda <- eta_timegrid_lambda + x$lambda$smooth.construct[[j]]$state$fitted_timegrid
    }
  }
  
  eta_timegrid_dmu <- eta_timegrid_dalpha <- 0
  if(dalpha) {
    if(length(x$dalpha$smooth.construct)) {
      for(j in names(x$dalpha$smooth.construct)) {
        b <- get.par(x$dalpha$smooth.construct[[j]]$state$parameters, "b")
        x$dalpha$smooth.construct[[j]]$state$fitted_timegrid <- x$dalpha$smooth.construct[[j]]$fit.fun_timegrid(b)
        eta_timegrid_dalpha <- eta_timegrid_dalpha + x$dalpha$smooth.construct[[j]]$fit.fun_timegrid(b)
      }
    }
    if(length(x$dmu$smooth.construct)) {
      for(j in names(x$dmu$smooth.construct)) {
        ## The dmu terms are evaluated with the coefficients of the
        ## mu term of the same name.
        b <- get.par(x$mu$smooth.construct[[j]]$state$parameters, "b")
        x$dmu$smooth.construct[[j]]$state$fitted_timegrid <- x$dmu$smooth.construct[[j]]$fit.fun_timegrid(b)
        eta_timegrid_dmu <- eta_timegrid_dmu + x$dmu$smooth.construct[[j]]$fit.fun_timegrid(b)
      }
    }
  }
  
  # for interaction effect eta_mu is within eta_alpha
  if(interaction){
    eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha + eta_timegrid_dalpha * eta_timegrid_dmu
  } else{
  eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
  }
  ## Knots of the (first margin of the) first alpha term, only needed
  ## for the interaction model.
  knots <- if(interaction){
    if(tp){x$alpha$smooth.construct[[1]]$margin[[1]]$knots} 
    else {x$alpha$smooth.construct[[1]]$knots}} 
  else {NULL}
  ## Debugging plots, hard coded to be switched off.
  myplots <- FALSE
  if(myplots){
    times <- matrix(unlist(attr(y, "grid")), ncol = ncol(eta_timegrid), byrow = TRUE)
    obstimes <- attr(y, "obstime")
    ids <- attr(y, "ids")
    par(mfrow = c(2,2))
  }
  
  ## Process iterations.
  if(burnin < 1) burnin <- 1
  if(burnin > n.iter) burnin <- floor(n.iter * 0.1)
  if(thin < 1) thin <- 1
  iterthin <- as.integer(seq(burnin, n.iter, by = thin))
  
  ## Samples.
  ## One entry per predictor and term: parameter samples, edf, acceptance
  ## probability ("alpha") and acceptance indicator, one row per saved iteration.
  samps <- list()
  for(i in nx) {
    samps[[i]] <- list()
    for(j in names(x[[i]]$smooth.construct)) {
      samps[[i]][[j]] <- list(
        "samples" = matrix(NA, nrow = length(iterthin), ncol = length(x[[i]]$smooth.construct[[j]]$state$parameters)),
        "edf" = rep(NA, length = length(iterthin)),
        "alpha" = rep(NA, length = length(iterthin)),
        "accepted" = rep(NA, length = length(iterthin))
      )
      colnames(samps[[i]][[j]]$samples) <- names(x[[i]]$smooth.construct[[j]]$state$parameters)
    }
  }
  logLik.samps <- logPost.samps <- rep(NA, length = length(iterthin))
  
  ## Convert a log acceptance probability into a probability in [0, 1];
  ## NULL and NA are mapped to 0.
  foo <- function(x) {
    if(is.null(x)) return(0)
    if(is.na(x)) return(0)
    x <- exp(x)
    if(x < 0)
      x <- 0
    if(x > 1)
      x <- 1
    x
  }
  
  ## Integrals.
  ## Trapezoidal rule over the sub time grid columns of exp(eta_timegrid).
  eeta <- exp(eta_timegrid)
  int0 <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + apply(eeta[, 2:(sub - 1)], 1, sum))
  
  ## Slice sampling?
  ## slice is hard coded to NULL above, so the first branch is always taken;
  ## the else branch completes a user supplied (named) vector.
  if(is.null(slice)) {
    slice <- c("lambda" = FALSE, "gamma" = FALSE, "mu" = FALSE,
               "sigma" = FALSE, "alpha" = FALSE, "dmu" = FALSE, "dalpha" = FALSE)
  } else {
    if(is.null(names(slice))) {
      slice <- rep(slice, length.out = 7)
      names(slice) <- c("lambda", "gamma", "mu", "sigma", "alpha", "dmu", "dalpha")
    } else {
      npar <- c("lambda", "gamma", "mu", "sigma", "alpha", "dmu", "dalpha")
      if(!all(j <- npar %in% names(slice))) {
        nnot <- npar[!j]
        for(j in npar) {
          if(j %in% nnot) {
            nslice <- c(names(slice), j)
            slice <- c(slice, FALSE)
            names(slice) <- nslice
          }
        }
      }
    }
  }
  
  ## Predictors that are updated, in this order. "dmu" is not sampled itself,
  ## it is refreshed whenever a "mu" proposal is accepted.
  nx2 <- if(dalpha) {
    c("lambda", "gamma", "mu", "sigma", "alpha", "dalpha")
  } else {
    c("lambda", "gamma", "mu", "sigma", "alpha")
  }
  
  ## Start sampling.
  if(verbose) {
    cat2("Starting the sampler...")
    if(!interactive())
      cat("\n")
  }
  
  nstep <- step
  step <- floor(n.iter / step)
  
  ptm <- proc.time()
  for(iter in 1:n.iter) {
    ## Is this iteration saved and, if so, in which row?
    if(save <- iter %in% iterthin)
      js <- which(iterthin == iter)
    
    for(i in nx2) {
      ## Recompute the integral before the gamma terms are updated.
      if(i == "gamma") {
        eeta <- exp(eta_timegrid)
        int0 <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + apply(eeta[, 2:(sub - 1)], 1, sum))
      }
      
      prop_fun <- get_jm_prop_fun(i, slice[i], interaction)
      # 3jump
      for(sj in names(x[[i]]$smooth.construct)) {
        
        ## Proposed state of term sj; p.state$alpha is used as the log
        ## acceptance probability below.
        p.state <- prop_fun(x[[i]]$smooth.construct[[sj]],
                            y, eta, eta_timegrid, eta_timegrid_lambda, eta_timegrid_mu, eta_timegrid_alpha,
                            eta_timegrid_dmu, eta_timegrid_dalpha,
                            width, sub, nu, status, id = i, int0, nobs,
                            dx = if(dalpha & (i == "mu")) x[["dmu"]]$smooth.construct[[sj]] else NULL,
                            xsmalpha = if(interaction & (i == "mu")) x$alpha$smooth.construct else NULL, 
                            knots = if(interaction & (i == "mu")) knots else NULL, tp = tp, fac = fac, ...)
        ## If accepted, set current state to proposed state.
        accepted <- if(is.na(p.state$alpha)) FALSE else log(runif(1)) <= p.state$alpha
        
        #checks# if(i == "mu") cat("\n", sj, round(exp(p.state$alpha), 2))
        if(i %in% fixed)
          accepted <- FALSE
        
        if(accepted) {
          ## Predictors with a time grid component: swap the old time grid fit
          ## of the term for the proposed one and rebuild eta_timegrid.
          if(i %in% c("lambda", "mu", "alpha", "dalpha")) {
            if(i == "lambda")
              eta_timegrid_lambda <- eta_timegrid_lambda - x[[i]]$smooth.construct[[sj]]$state$fitted_timegrid + p.state$fitted_timegrid
            if(i == "mu") {
              # cat("\n iteration: ", iter, ", predictor: ", i, ", term: ", sj)
              eta_timegrid_mu <- eta_timegrid_mu - x[[i]]$smooth.construct[[sj]]$state$fitted_timegrid + p.state$fitted_timegrid
              ## In the interaction model the alpha terms depend on mu, so their
              ## design matrices and fits are recomputed with the new mu.
              if(interaction){
                for (j in names(x$alpha$smooth.construct)){
                  if(j != "model.matrix"){   # only smooth.constructs need to be updated
                    g_a <- get.par(x$alpha$smooth.construct[[j]]$state$parameters, "b")
                    Xalpha <- modSplineDesign(knots, as.vector(t(eta_timegrid_mu)), derivs = 0)
                    Xalpha <- constrain(x$alpha$smooth.construct[[j]], Xalpha)
                    if(fac)
                      Xalpha <- Xalpha * x$alpha$smooth.construct[[j]]$by_timegrid
                    ## Xalpha2 was computed in the setup of the alpha terms above.
                    if(tp)  
                      Xalpha <- rowTensorProduct(Xalpha, Xalpha2) 
                    alpha_state <- matrix(Xalpha %*% g_a, nrow = nrow(eta_timegrid), ncol = ncol(eta_timegrid), byrow = TRUE)
                    eta_timegrid_alpha <- eta_timegrid_alpha - x$alpha$smooth.construct[[j]]$state$fitted_timegrid + alpha_state
                    x$alpha$smooth.construct[[j]]$state$fitted_timegrid <- alpha_state
                    eta$alpha <- eta$alpha - fitted(x$alpha$smooth.construct[[j]]$state) + alpha_state[, ncol(eta_timegrid)]
                    x$alpha$smooth.construct[[j]]$state$fitted.values <- alpha_state[, ncol(eta_timegrid)]
                  } 
                }
              }
              ## Keep the matching dmu term in sync with the accepted mu term.
              ## NOTE(review): update_jm_dmu() reads the coefficients from
              ## x[["mu"]]...[[sj]], whose state is replaced by p.state only
              ## further below - confirm that this ordering is intended.
              if(dalpha & (sj %in% names(x[["dmu"]]$smooth.construct))) {
                p.state.dmu <- update_jm_dmu(x[["dmu"]]$smooth.construct[[sj]], x[["mu"]]$smooth.construct[[sj]])
                eta_timegrid_dmu <- eta_timegrid_dmu - x[["dmu"]]$smooth.construct[[sj]]$state$fitted_timegrid + p.state.dmu$fitted_timegrid
                eta[["dmu"]] <- eta[["dmu"]] - fitted(x[["dmu"]]$smooth.construct[[sj]]$state) + fitted(p.state.dmu)
                x[["dmu"]]$smooth.construct[[sj]]$state <- p.state.dmu
              }
            }
            if(i == "alpha"){
              eta_timegrid_alpha <- eta_timegrid_alpha - x[[i]]$smooth.construct[[sj]]$state$fitted_timegrid + p.state$fitted_timegrid
            }
            if(i == "dalpha")
              eta_timegrid_dalpha <- eta_timegrid_dalpha - x[[i]]$smooth.construct[[sj]]$state$fitted_timegrid + p.state$fitted_timegrid
            if(interaction){
              eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha + eta_timegrid_dalpha * eta_timegrid_dmu
            } else {
            eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
          }
          }
          ## Swap the fitted values of the term in the predictor.
          eta[[i]] <- eta[[i]] - fitted(x[[i]]$smooth.construct[[sj]]$state) + fitted(p.state)
          ## Debugging plots (only active if myplots is TRUE).
          if(i == "alpha"){
            if(myplots & iter == round(iter, -1)){
              par(mfrow=c(2,2))
              plot(eta_timegrid_mu, eta_timegrid_alpha, main=paste0("alpha: iteration ", iter, ", edf ", round(p.state$edf, 2)))
              plot(eta_timegrid_mu[, ncol(eta_timegrid_mu)], eta$alpha, main=paste0("alpha: iteration ", iter, ", edf ", round(p.state$edf, 2)))
              plot(eta_timegrid_mu, p.state$fitted_timegrid, 
                   main=paste0("alpha: iteration ", iter, ", edf ", round(p.state$edf, 2), ", term: ", sj))
              abline(h = 0, col = "red")
              abline(v = median(y[, 3]), col = "red")
              matplot(t(times), t(eta_timegrid_alpha), main=paste0("alpha: iteration ", iter, ", edf ", round(p.state$edf, 2)), type = "l")
             Sys.sleep(2)
              }
          }
            if(i == "mu"){
              if(myplots & iter == round(iter, -1)){
                par(mfrow=c(2,2))
                matplot(t(times),t(p.state$fitted_timegrid), type = "l", main = paste0("mu: iteration ", iter, ", edf ", round(p.state$edf, 2), " ", sj))
                matplot(t(times), t(eta_timegrid_mu), type = "l", main = paste0("mu: iteration ", iter, ", edf ", round(p.state$edf, 2), " ", sj))
                plot(eta_timegrid_mu, eta_timegrid_alpha, main=paste0("mu: iteration ", iter, ", edf ", round(p.state$edf, 2), "alpha"))
                plot(eta_timegrid_mu[, ncol(eta_timegrid_mu)], eta$alpha, main=paste0("mu: iteration ", iter, ", edf ", round(p.state$edf, 2), "alpha"))
                Sys.sleep(1)
              }
            }
          ## The proposed state becomes the current state of the term.
          x[[i]]$smooth.construct[[sj]]$state <- p.state 
        }
        
        ## Save the samples and acceptance.
        if(save) {
          samps[[i]][[sj]]$samples[js, ] <- x[[i]]$smooth.construct[[sj]]$state$parameters
          samps[[i]][[sj]]$edf[js] <- x[[i]]$smooth.construct[[sj]]$state$edf
          samps[[i]][[sj]]$alpha[js] <- foo(p.state$alpha)
          samps[[i]][[sj]]$accepted[js] <- accepted
          ## The dmu term shares acceptance information with its mu term.
          if(dalpha & (i == "mu") & (sj %in% names(x[["dmu"]]$smooth.construct))) {
            samps[["dmu"]][[sj]]$samples[js, ] <- x[["dmu"]]$smooth.construct[[sj]]$state$parameters
            samps[["dmu"]][[sj]]$edf[js] <- x[["dmu"]]$smooth.construct[[sj]]$state$edf
            samps[["dmu"]][[sj]]$alpha[js] <- foo(p.state$alpha)
            samps[["dmu"]][[sj]]$accepted[js] <- accepted
          }
        }
      }
    }
    
    ## Log-likelihood (survival part using int0 plus Gaussian longitudinal
    ## part) and log-posterior of the current state.
    if(save) {
      logLik.samps[js] <- sum((eta_timegrid[,ncol(eta_timegrid)] + eta$gamma) * status, na.rm = TRUE) -
        exp(eta$gamma) %*% int0 + sum(dnorm(y[, "obs"], mean = eta$mu, sd = exp(eta$sigma), log = TRUE))
      logPost.samps[js] <- as.numeric(logLik.samps[js] + get.log.prior(x))
    }
    
    if(verbose) barfun(ptm, n.iter, iter, step, nstep)
  }
  
  if(verbose) cat("\n")
  
  ## Collapse the sample lists into one matrix; columns are named
  ## <predictor>.p.<term>.<name> for "model.matrix" terms and
  ## <predictor>.s.<term>.<name> for all other terms.
  for(i in names(samps)) {
    for(j in names(samps[[i]])) {
      samps[[i]][[j]] <- do.call("cbind", samps[[i]][[j]])
      cn <- if(j == "model.matrix") {
        paste(i, "p", j, colnames(samps[[i]][[j]]), sep = ".")
      } else {
        paste(i, "s", j, colnames(samps[[i]][[j]]), sep = ".")
      }
      colnames(samps[[i]][[j]]) <- cn
    }
    samps[[i]] <- do.call("cbind", samps[[i]])
  }
  samps$logLik <- logLik.samps
  samps$logPost <- logPost.samps
  samps <- do.call("cbind", samps)
  
  ## Compute DIC. #
  ## pd is the mean deviance minus the deviance at the column means of the samples.
  dev <- -2 * logLik.samps
  mpar <- apply(samps, 2, mean, na.rm = TRUE)
  names(mpar) <- colnames(samps)
  ll <- family$p2logLik(mpar)
  mdev <- -2 * ll
  pd <- mean(dev) - mdev
  DIC <- mdev + 2 * pd
  samps <- cbind(samps,
                 "DIC" = rep(DIC, length.out = nrow(samps)),
                 "pd" = rep(pd, length.out = nrow(samps))
  )
  
  return(as.mcmc(samps))
}


## Get proposal function.
get_jm_prop_fun <- function(i, slice = FALSE, interaction = FALSE) {
  ## Build the proposal function for predictor i.
  ##
  ## Arguments:
  ##   i            name of the predictor, e.g. "lambda", "mu", "alpha".
  ##   slice        if TRUE, propose_jm_slice() is used.
  ##   interaction  if TRUE and i is "alpha", propose_jm_alpha_inter() is used.
  ##
  ## Returns a function(...) that forwards its arguments to the selected
  ## propose_jm_* function and returns the proposed state. If the proposal
  ## function fails, a warning is issued and list(alpha = -Inf) is returned,
  ## i.e., a proposal with log acceptance probability -Inf.
  function(...) {
    prop_fun <- if(slice) {
      propose_jm_slice
    } else if(isTRUE(interaction) && (i == "alpha")) {
      get("propose_jm_alpha_inter")
    } else {
      get(paste("propose_jm", i, sep = "_"))
    }
    ## The call has to be wrapped in try(), otherwise an error is thrown
    ## before the "try-error" check below can ever be reached.
    state <- try(prop_fun(...), silent = TRUE)
    if(inherits(state, "try-error")) {
      ## Report the cause as well, so that the failure is not hidden.
      warning(paste("problems sampling parameter ", i, "! (",
                    conditionMessage(attr(state, "condition")), ")", sep = ""))
      return(list("alpha" = -Inf))
    }
    state
  }
}


## JM proposal functions.
## Log-posterior of the coefficients of one model term, evaluated at the
## candidate value `g`; passed as `logPost` to uni.slice() by
## propose_jm_slice().
##
## Arguments:
##   g             candidate parameters, handed to x$fit.fun_timegrid(),
##                 x$fit.fun() and x$prior().
##   x             term object; x$state, x$X, x$fit.fun, x$fit.fun_timegrid
##                 and x$prior are used.
##   family        not used here.
##   y             response, column "obs" holds the longitudinal observations.
##   eta           named list of predictors; "gamma", "mu" and "sigma" enter
##                 the log-likelihood, element `id` is updated with `g`.
##   id            name of the predictor the term belongs to.
##   eta_timegrid* predictors evaluated on the time grid (one column per
##                 grid point).
##   width, sub    factor multiplying the trapezoid sum and number of grid
##                 columns used for the integral.
##   status        event indicator.
##   dx            optional derivative term, only used for id == "mu".
##   interaction   logical, selects how the time grid predictor is assembled.
##
## Returns log-likelihood plus log-prior. Because of the `%*%` product the
## value is typically a 1 x 1 matrix.
uni.slice_beta_logPost <- function(g, x, family, y = NULL, eta = NULL, id,
                                   eta_timegrid, eta_timegrid_lambda, eta_timegrid_mu, eta_timegrid_alpha,
                                   eta_timegrid_dmu, eta_timegrid_dalpha,
                                   width, sub, status, dx = NULL, interaction = FALSE, ...)
{
  if(id %in% c("lambda", "mu", "alpha", "dalpha")) {
    ## Replace the current time grid contribution of the term by the one
    ## implied by the candidate `g`.
    fit_timegrid <- x$fit.fun_timegrid(g)
    if(id == "lambda")
      eta_timegrid_lambda <- eta_timegrid_lambda - x$state$fitted_timegrid + fit_timegrid
    if(id == "mu") {
      eta_timegrid_mu <- eta_timegrid_mu - x$state$fitted_timegrid + fit_timegrid
      if(!is.null(dx)) {
        eta_timegrid_dmu <- eta_timegrid_dmu - dx$state$fitted_timegrid + dx$fit.fun_timegrid(g)
      }
    }
    if(id == "alpha")
      eta_timegrid_alpha <- eta_timegrid_alpha - x$state$fitted_timegrid + fit_timegrid
    if(id == "dalpha")
      eta_timegrid_dalpha <- eta_timegrid_dalpha - x$state$fitted_timegrid + fit_timegrid

    eta_timegrid <- if(interaction) {
      eta_timegrid_lambda + eta_timegrid_alpha + eta_timegrid_dalpha * eta_timegrid_dmu
    } else {
      eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
    }
  }

  eta[[id]] <- eta[[id]] - fitted(x$state) + x$fit.fun(x$X, g, expand = FALSE)

  ## Trapezoidal rule over the first `sub` grid columns. The interior columns
  ## are selected by negative indexing and with drop = FALSE, so that a short
  ## grid (sub <= 3) or a single row does not break the row sums.
  eeta <- exp(eta_timegrid)
  interior <- seq_len(sub)[-c(1, sub)]
  int <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + rowSums(eeta[, interior, drop = FALSE]))

  ## Survival part (last grid column and gamma, weighted by status, minus the
  ## integral term) plus the Gaussian part of the longitudinal observations.
  ll <- sum((eta_timegrid[, ncol(eta_timegrid)] + eta$gamma) * status, na.rm = TRUE) -
    exp(eta$gamma) %*% int + sum(dnorm(y[, "obs"], mean = eta$mu, sd = exp(eta$sigma), log = TRUE))
  lp <- x$prior(g)

  return(ll + lp)
}


## Log-posterior used when slice sampling a variance parameter: the log-prior
## of the term evaluated at `g`, shifted by the constant `ll`.
## `family`, `y`, `eta` and `id` are accepted but not used.
uni.slice_tau2_logPost <- function(g, x, family, y = NULL, eta = NULL, id, ll = 0)
{
  log_prior <- x$prior(g)
  ll + log_prior
}


## Slice sampling update of one model term.
##
## Every coefficient ("b" parameters of x$state$parameters) is updated in
## turn with uni.slice(), using uni.slice_beta_logPost() as target. The
## fitted values (and, for time grid predictors, the fitted time grid) are
## then refreshed and, if the term has unfixed smoothing variances, these are
## sampled as well.
##
## Arguments:
##   x             term object (state, X, fit.fun_timegrid, prior, S, ...).
##   y, eta, eta_timegrid*, width, sub, status, dx, interaction
##                 passed on to uni.slice_beta_logPost().
##   nu            not used by the slice sampler.
##   id            name of the predictor the term belongs to.
##   family        optional family object, only passed through to uni.slice().
##
## Returns the updated x$state; `alpha` is set to 0.
propose_jm_slice <- function(x, y,
                             eta, eta_timegrid, eta_timegrid_lambda, eta_timegrid_mu, eta_timegrid_alpha,
                             eta_timegrid_dmu, eta_timegrid_dalpha,
                             width, sub, nu, status, id, dx = NULL, interaction = FALSE, ...,
                             family = NULL)
{
  ## `family` is an explicit argument (default NULL) so that it no longer
  ## resolves to whatever object of that name is found on the search path.
  for(j in seq_along(get.par(x$state$parameters, "b"))) {
    x$state$parameters <- uni.slice(x$state$parameters, x, family, y,
                                    eta, id = id, j, logPost = uni.slice_beta_logPost, eta_timegrid = eta_timegrid,
                                    eta_timegrid_lambda = eta_timegrid_lambda, eta_timegrid_mu = eta_timegrid_mu,
                                    eta_timegrid_alpha = eta_timegrid_alpha, eta_timegrid_dmu = eta_timegrid_dmu,
                                    eta_timegrid_dalpha = eta_timegrid_dalpha, width = width, sub = sub, status = status,
                                    dx = dx, interaction = interaction)
  }

  ## Refresh the fitted values with the sampled coefficients.
  g <- get.par(x$state$parameters, "b")
  x$state$fitted.values <- drop(x$X %*% g)

  if(id %in% c("lambda", "mu", "alpha", "dalpha"))
    x$state$fitted_timegrid <- x$fit.fun_timegrid(g)

  ## Sample variance parameter.
  if(!x$fixed && is.null(x$sp) && (length(x$S) > 0)) {
    if((length(x$S) < 2) && (attr(x$prior, "var_prior") == "ig")) {
      ## Single penalty with inverse gamma prior: draw tau2 directly.
      a <- x$rank / 2 + x$a
      b <- 0.5 * crossprod(g, x$S[[1]]) %*% g + x$b
      tau2 <- 1 / rgamma(1, a, b)
      x$state$parameters <- set.par(x$state$parameters, tau2, "tau2")
    } else {
      ## Otherwise slice sample every tau2 on (0, Inf) from its prior.
      i <- grep("tau2", names(x$state$parameters))
      for(j in i) {
        x$state$parameters <- uni.slice(x$state$parameters, x, NULL, NULL,
                                        NULL, id = "mu", j, logPost = uni.slice_tau2_logPost, lower = 0, ll = 0)
      }
    }
  }

  ## Acceptance probability: fixed at 0 for the slice sampler (presumably on
  ## the log scale, i.e. the draw is always accepted -- confirm in caller).
  x$state$alpha <- 0

  return(x$state)
}


## Metropolis-Hastings proposal for the coefficients of one term of the
## "lambda" predictor.
##
## The candidate is drawn from a multivariate normal centred at a Newton-type
## step (step length `nu`) from the current coefficients, with the inverse of
## `xhess` as covariance. Gradient and Hessian are rebuilt at the candidate
## to evaluate the density of the reverse move.
##
## Arguments:
##   x             term object; uses fit.fun_timegrid, prior, grad, hess, X,
##                 XT, sparse.setup, S, rank, a, b, fixed, sp and state.
##   y             response; column "obs" and attribute "interaction" are
##                 read.
##   eta           named list of predictors ("lambda", "gamma", "mu",
##                 "sigma" are used).
##   eta_timegrid* predictors evaluated on the time grid (one column per
##                 grid point).
##   width, sub    factor multiplying the trapezoid sum and number of grid
##                 columns used for the integral.
##   nu            step length of the Newton-type step.
##   status        event indicator.
##   id            not used here.
##
## Returns x$state holding the candidate parameters, the matching fitted
## values / fitted time grid, the edf and `alpha`, the log acceptance ratio.
propose_jm_lambda <- function(x, y,
                              eta, eta_timegrid, eta_timegrid_lambda, eta_timegrid_mu, eta_timegrid_alpha,
                              eta_timegrid_dmu, eta_timegrid_dalpha,
                              width, sub, nu, status, id, ...)
{
  interaction <- attr(y, "interaction")
  ## The time-dependent design matrix for the grid.
  X <- x$fit.fun_timegrid(NULL)

  ## Interior grid columns of the trapezoidal rule. Negative indexing (and
  ## drop = FALSE below) keeps the row sums well defined for a short grid
  ## (sub <= 3) or a single row.
  interior <- seq_len(sub)[-c(1, sub)]

  ## Timegrid predictor.
  eeta <- exp(eta_timegrid)

  ## Old logLik and prior.
  int0 <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + rowSums(eeta[, interior, drop = FALSE]))
  pibeta <- sum((eta_timegrid[, ncol(eta_timegrid)] + eta$gamma) * status, na.rm = TRUE) -
    exp(eta$gamma) %*% int0 + sum(dnorm(y[, "obs"], mean = eta$mu, sd = exp(eta$sigma), log = TRUE))
  p1 <- x$prior(x$state$parameters)

  ## Compute gradient and hessian integrals.
  int <- survint(X, eeta, width, exp(eta$gamma), index = x$sparse.setup$matrix)
  xgrad <- drop(t(status) %*% x$XT - int$grad)
  xgrad <- xgrad + x$grad(score = NULL, x$state$parameters, full = FALSE)
  xhess <- int$hess + x$hess(score = NULL, x$state$parameters, full = FALSE)

  ## Compute the inverse of the hessian.
  Sigma <- matrix_inv(xhess, index = NULL)

  ## Save old coefficients.
  g0 <- get.state(x, "b")

  ## Get new position.
  mu <- drop(g0 + nu * Sigma %*% xgrad)

  ## Sample new parameters.
  g <- drop(rmvnorm(n = 1, mean = mu, sigma = Sigma, method = "chol"))
  names(g) <- names(g0)
  x$state$parameters <- set.par(x$state$parameters, g, "b")

  ## Log prior at the candidate and forward proposal density.
  p2 <- x$prior(x$state$parameters)
  qbetaprop <- dmvnorm(g, mean = mu, sigma = Sigma, log = TRUE)

  ## Update additive predictors.
  fit_timegrid <- x$fit.fun_timegrid(g)
  eta_timegrid_lambda <- eta_timegrid_lambda - x$state$fitted_timegrid + fit_timegrid
  eta_timegrid <- if(interaction) {
    eta_timegrid_lambda + eta_timegrid_alpha + eta_timegrid_dalpha * eta_timegrid_dmu
  } else {
    eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
  }
  x$state$fitted_timegrid <- fit_timegrid

  fit <- drop(x$X %*% g)
  eta$lambda <- eta$lambda - fitted(x$state) + fit
  x$state$fitted.values <- fit

  ## New logLik.
  eeta <- exp(eta_timegrid)
  int0 <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + rowSums(eeta[, interior, drop = FALSE]))
  pibetaprop <- sum((eta_timegrid[, ncol(eta_timegrid)] + eta$gamma) * status, na.rm = TRUE) -
    exp(eta$gamma) %*% int0 + sum(dnorm(y[, "obs"], mean = eta$mu, sd = exp(eta$sigma), log = TRUE))

  ## Gradient and hessian at the candidate, for the reverse proposal density.
  int <- survint(X, eeta, width, exp(eta$gamma), index = x$sparse.setup$matrix)
  xgrad <- drop(t(status) %*% x$XT - int$grad)
  xgrad <- xgrad + x$grad(score = NULL, x$state$parameters, full = FALSE)
  xhess <- int$hess + x$hess(score = NULL, x$state$parameters, full = FALSE)

  Sigma2 <- matrix_inv(xhess, index = NULL)
  mu2 <- drop(g + nu * Sigma2 %*% xgrad)
  qbeta <- dmvnorm(g0, mean = mu2, sigma = Sigma2, log = TRUE)

  ## Save edf.
  x$state$edf <- sum_diag(int$hess %*% Sigma2)

  ## Sample variance parameter.
  if(!x$fixed && is.null(x$sp) && (length(x$S) > 0)) {
    if((length(x$S) < 2) && (attr(x$prior, "var_prior") == "ig")) {
      ## Single penalty with inverse gamma prior: draw tau2 directly.
      g <- get.par(x$state$parameters, "b")
      a <- x$rank / 2 + x$a
      b <- 0.5 * crossprod(g, x$S[[1]]) %*% g + x$b
      tau2 <- 1 / rgamma(1, a, b)
      x$state$parameters <- set.par(x$state$parameters, tau2, "tau2")
    } else {
      ## Otherwise slice sample every tau2 on (0, Inf) from its prior.
      i <- grep("tau2", names(x$state$parameters))
      for(j in i) {
        x$state$parameters <- uni.slice(x$state$parameters, x, NULL, NULL,
                                        NULL, id = "mu", j, logPost = uni.slice_tau2_logPost, lower = 0, ll = 0)
      }
    }
  }

  ## Compute acceptance probability (log scale): candidate logLik, reverse
  ## proposal density and prior minus their counterparts at the old state.
  x$state$alpha <- drop((pibetaprop + qbeta + p2) - (pibeta + qbetaprop + p1))
  # cat(paste("\n",pibeta, pibetaprop, qbeta, qbetaprop, p1, p2, x$state$alpha, sep=";"))
  return(x$state)
}


## Proposal for a term of the "mu" predictor: delegates to the proposal
## function stored in the term itself (x$propose), handing over the term and
## all further arguments unchanged.
propose_jm_mu <- function(x, ...)
{
  term_proposal <- x$propose
  term_proposal(x, ...)
}

propose_jm_mu_simple <- function(x, y,
                                 eta, eta_timegrid, eta_timegrid_lambda, eta_timegrid_mu, eta_timegrid_alpha,
                                 eta_timegrid_dmu, eta_timegrid_dalpha,
                                 width, sub, nu, status, id, dx = NULL, ...)
{
  ## The time-dependent design matrix for the grid.
  X <- x$fit.fun_timegrid(NULL)
  dX <- if(!is.null(dx)) dx$fit.fun_timegrid(NULL) else NULL
  
  ## Timegrid lambda.
  eeta <- exp(eta_timegrid)
  
  ## Old logLik and prior.
  int0 <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + apply(eeta[, 2:(sub - 1)], 1, sum))
  pibeta <- sum((eta_timegrid[,ncol(eta_timegrid)] + eta$gamma) * status, na.rm = TRUE) -
    exp(eta$gamma) %*% int0 + sum(dnorm(y[, "obs"], mean = eta$mu, sd = exp(eta$sigma), log = TRUE))
  p1 <- x$prior(x$state$parameters)
  
  ## Compute gradient and hessian integrals.
  int <- survint(X, eeta * eta_timegrid_alpha, width, exp(eta$gamma),
                 eeta * eta_timegrid_alpha^2, index = x$sparse.setup[["mu.matrix"]],
                 dX, if(!is.null(dX)) eeta * eta_timegrid_dalpha else NULL,
                 if(!is.null(dX)) eeta * eta_timegrid_dalpha^2 else NULL)
  
  xgrad <- drop(t(x$X) %*% drop((y[, "obs"]  - eta$mu) / exp(eta$sigma)^2) +
                  t(x$XT) %*% drop(eta$alpha * status) - int$grad)
  if(!is.null(dx))
    xgrad <- drop(xgrad + t(dx$XT) %*% drop(eta$dalpha * status))
  xgrad <- xgrad + x$grad(score = NULL, x$state$parameters, full = FALSE)
  XWX <- if(is.null(x$sparse.setup$matrix)) {
    crossprod(x$X * (1 / exp(eta$sigma)^2), x$X)
  } else do.XWX(x$X, exp(eta$sigma)^2, index = x$sparse.setup$matrix)
  xhess <- -1 * XWX - int$hess
  xhess <- xhess - x$hess(score = NULL, x$state$parameters, full = FALSE)
  
  ## Compute the inverse of the hessian.
  Sigma <- matrix_inv(-1 * xhess, index = NULL)
  
  ## Save old coefficients.
  g0 <- get.state(x, "b")
  
  ## Get new position.
  mu <- drop(g0 + nu * Sigma %*% xgrad)
  
  ## Sample new parameters.
  g <- drop(rmvnorm(n = 1, mean = mu, sigma = Sigma, method="chol"))
  names(g) <- names(g0)
  x$state$parameters <- set.par(x$state$parameters, g, "b")
  
  ## Compute log priors.
  p2 <- x$prior(x$state$parameters)
  qbetaprop <- dmvnorm(g, mean = mu, sigma = Sigma, log = TRUE)
  
  ## Update additive predictors.
  fit_timegrid <- x$fit.fun_timegrid(g)
  eta_timegrid_mu <- eta_timegrid_mu - x$state$fitted_timegrid + fit_timegrid
  if(!is.null(dx))
    eta_timegrid_dmu <- eta_timegrid_dmu - dx$state$fitted_timegrid + dx$fit.fun_timegrid(g)
  eta_timegrid <- eta_timegrid_lambda + eta_timegrid_alpha * eta_timegrid_mu + eta_timegrid_dalpha * eta_timegrid_dmu
  x$state$fitted_timegrid <- fit_timegrid
  
  fit <- drop(x$X %*% g)
  eta$mu <- eta$mu - fitted(x$state) + fit
  x$state$fitted.values <- fit
  
  ## New logLik.
  eeta <- exp(eta_timegrid)
  int0 <- width * (0.5 * (eeta[, 1] + eeta[, sub]) + apply(eeta[, 2:(sub - 1)], 1, sum))
  pibetaprop <- sum((eta_timegrid[,ncol(eta_timegrid)] + eta$gamma) * status, na.rm = TRUE) -
    exp(eta$gamma) %*% int0 + sum(dnorm(y[, "obs"], mean = eta$mu, sd = exp(eta$sigma), log = TRUE))
  
  ## Prior prob.
  int <- survint(X, eeta * eta_timegrid_alpha, width, exp(eta$gamma),
                 eeta * eta_timegrid_alpha^2, index = x$sparse.setup[["mu.matrix"]],
                 dX, if(!is.null(dX)) eeta * eta_timegrid_dalpha else NULL,
                 if(!is.null(dX)) eeta * eta_timegrid_dalpha^2 else NULL)
  
  xgrad <- drop(t(x$X) %*% drop((y[, "obs"]  - eta$mu) / exp(eta$sigma)^2) +
                  t(x$XT) %*% drop(eta$alpha * status) - int$grad)
  if(!is.null(dx))
    xgrad <- drop(xgrad + t(dx$XT) %*% drop(eta$dalpha * status))
  xgrad <- xgrad + x$grad(score = NULL, x$state$parameters, full = FALSE)
  XWX <- if(is.null(x$sparse.setup$matrix)) {
    crossprod(x$X * (1 / exp(eta$sigma)^2), x$X)
  } else do.XWX(x$X, exp(eta$sigma)^2, index = x$sparse.setup$matrix)
  xhess <- -1 * XWX - int$hess
  xhess <- xhess - x$hess(score = NULL, x$state$parameters, full = FALSE)
  Sigma2 <- matrix_inv(-1 * xhess, index = NULL)
  mu2 <- drop(g + nu * Sigma2 %*% xgrad)
  qbeta <- dmvnorm(g0, mean = mu2, sigma = Sigma2, log = TRUE)
  
  ## Save edf.
  x$state$edf <- sum_diag((-1 * int$hess) %*% Sigma2)
  
  ## Sample variance parameter.
  if(!x$fixed & is.null(x$sp) & length(x$S)) {
    if((length(x$S) < 2) & (attr(x$prior, "var_prior") == "ig")) {
      g <- get.par(x$state$parameters, "b")
      a <- x$rank / 2 + x$a
      b <- 0.5 * crossprod(g, x$S[[1]]) %*% g + x$b
      tau2 <- 1 / rgamma(1, a, b)
      x$state$parameters <- set.par(x