R/Utils.R

Defines functions alpha.sample nu.sample sigma2.sample beta.sample MCMC.param.check BASSLINE_convert Trace_plot

Documented in BASSLINE_convert Trace_plot

#' @title Produce a trace plot of a variable's MCMC chain
#' @description Plots the values of one variable of the chain against the
#'   (non-discarded) iteration index.
#' @param variable Index of the column of \code{chain} to plot. A column
#'   name is also accepted.
#' @param chain MCMC chains generated by a BASSLINE MCMC function
#'
#' @return A ggplot2 object
#' @examples
#' library(BASSLINE)
#'
#' # Please note: N=1000 is not enough to reach convergence.
#' # This is only an illustration. Run longer chains for more accurate
#' # estimations.
#'
#' LN <- MCMC_LN(
#'   N = 1000, thin = 20, burn = 40, Time = cancer[, 1],
#'   Cens = cancer[, 2], X = cancer[, 3:11]
#' )
#' Trace_plot(1, LN)
#'
#' @export
Trace_plot <- function(variable = NULL, chain = NULL) {
  # Both tests are scalar, so use the short-circuiting `||`.
  if (is.null(variable) || is.null(chain)) {
    stop("variable and chain must be provided\n")
  }

  # Bound to NULL so the bare names used inside aes() below are not reported
  # as undefined global variables by R CMD check.
  Iteration <- Value <- NULL

  # When the column is selected by name, use that name directly: indexing
  # colnames() with a string would give NA.
  if (is.character(variable)) {
    label <- variable
  } else {
    label <- colnames(chain)[variable]
  }
  title <- paste("Trace Plot for Variable", label)

  # seq_len() copes with a chain that has no rows; seq(0) would give c(1, 0).
  df <- data.frame(
    Iteration = seq_len(nrow(chain)),
    Value = chain[, variable]
  )

  p <- ggplot2::ggplot(ggplot2::aes(x = Iteration, y = Value), data = df)
  p <- p + ggplot2::geom_line() + ggplot2::theme_bw() + ggplot2::ggtitle(title)
  return(p)
}

#' @title Convert dataframe with mixed variables to a numeric matrix
#' @description BASSLINE's functions require a numeric matrix be provided.
#'   This function converts a dataframe of mixed variable types (numeric and
#'   factors) to a matrix. Non-factor columns are copied as they are. A factor
#'   is replaced by one 0/1 indicator column for each level that occurs in the
#'   data (in order of first appearance), named
#'   \code{<column name>.<level>}. Missing factor values give \code{NA} in
#'   every indicator column of that factor.
#' @param df A dataframe intended for conversion
#' @return A numeric matrix suitable for BASSLINE functions, with row names
#'   \code{1, ..., nrow(df)}
#' @examples
#' library(BASSLINE)
#' Time <- c(5, 15, 15)
#' Cens <- c(1, 0, 1)
#' experiment <- as.factor(c("chem1", "chem2", "chem3"))
#' age <- c(15, 35, 20)
#' df <- data.frame(Time, Cens, experiment, age)
#' converted <- BASSLINE_convert(df)
#'
#' @export
BASSLINE_convert <- function(df) {
  if (!is.data.frame(df)) {
    stop("df must be a data frame.\n")
  }

  n.obs <- nrow(df)
  var.names <- colnames(df)

  # One matrix per input column; they are bound together once at the end
  # instead of growing the result inside the loop.
  pieces <- vector("list", ncol(df))

  for (j in seq_along(pieces)) {
    column <- df[[j]]

    if (!is.factor(column)) {
      # Not a factor: keep the column as it is, under its original name.
      piece <- cbind(column)
      colnames(piece)[ncol(piece)] <- var.names[j]
    } else {
      # Factor: one indicator column per observed level. Comparing whole
      # vectors lets NA propagate into the indicators instead of failing
      # inside a scalar `if`.
      values <- as.character(column)
      observed <- unique(values[!is.na(values)])
      piece <- outer(values, observed, "==") + 0
      if (length(observed) > 0) {
        colnames(piece) <- paste0(var.names[j], ".", observed)
      }
    }

    pieces[[j]] <- piece
  }

  # A data frame without columns has nothing to bind.
  if (length(pieces) == 0) {
    return(matrix(numeric(0), nrow = n.obs, ncol = 0))
  }

  BASSLINE.mat <- do.call(cbind, pieces)
  rownames(BASSLINE.mat) <- seq_len(n.obs)
  BASSLINE.mat
}


#### Input check for MCMC functions
# Validates the arguments handed to an MCMC function and stops with an error
# message on the first violation found. Returns NULL invisibly when every
# check passes.
#
# Checks performed:
#   X       must be a matrix; its rows/columns fix the expected lengths below
#   N       integer > 0, a multiple of thin, not smaller than burn
#   thin    integer >= 2
#   burn    integer >= 0 and a multiple of thin (a note is printed if it is 0)
#   Time    strictly positive, one value per row of X
#   Cens    0 or 1, one value per row of X
#   beta0   one value per column of X
#   prior   one of 1, 2, 3
#   set     one of 1, 2
# sigma20, eps_l and eps_r are accepted but not checked here.
MCMC.param.check <- function(N,
                             thin,
                             burn,
                             Time,
                             Cens,
                             X,
                             beta0,
                             sigma20,
                             prior,
                             set,
                             eps_l,
                             eps_r) {
  # Check X first: the expected lengths below are read from its dimensions.
  if (!is.matrix(X)) {
    stop("X is not a matrix.\n")
  }

  num.obs <- nrow(X)
  num.covariates <- ncol(X)

  # Check N is a  0 < integer
  if (N <= 0 || N %% 1 != 0) {
    stop("N should be an integer greater than zero.\n")
  }

  # thin = 2 is accepted, so the message says ">= 2".
  if (thin < 2 || thin %% 1 != 0) {
    stop("thin should be an integer >= 2.\n")
  }

  if (burn < 0 || burn %% 1 != 0) {
    stop("burn should be a non-negative integer.\n")
  }

  # NOTE(review): N == burn passes this check although the message says
  # "greater than" - confirm which of the two is intended.
  if (N < burn) {
    stop("N must be greater than burn.\n")
  }

  if (burn %% thin != 0) {
    stop("burn must a multiple of thin.\n")
  }

  if (N %% thin != 0) {
    stop("N must a multiple of thin.\n")
  }

  # isTRUE() turns an NA result (missing values in Time) into a failed check
  # rather than an uninformative error from `if`. Zero is rejected too.
  if (!isTRUE(all(Time > 0))) {
    stop("All values in Time should be positive.\n")
  }

  if (length(Time) != num.obs) {
    stop("Time is not the correct length.\n")
  }

  # %in% never returns NA, so missing values in Cens fail the check cleanly.
  if (!all(Cens %in% c(0, 1))) {
    stop("Cens should be either 0 or 1 for each observation\n")
  }
  if (length(Cens) != num.obs) {
    stop("Cens is not the correct length.\n")
  }

  if (length(beta0) != num.covariates) {
    stop("beta0 is not the correct length.\n")
  }

  if (!(prior %in% c(1, 2, 3))) {
    stop("prior should be 1, 2 or 3. See documentation\n")
  }

  if (!(set %in% c(1, 2))) {
    stop("set should be 1 or 2. See documentation\n")
  }

  if (burn == 0) {
    cat(paste0("Note! No burn-in period is being used!\n"))
  }

  invisible(NULL)
}

#### Initial value samplers
beta.sample <- function(n) {
  # Draw n starting values for beta from a standard normal and echo each
  # one (rounded to two decimals) to the console before returning them.
  cat("Sampling initial betas from a Normal(0, 1) distribution\n")
  draws <- stats::rnorm(n, mean = 0, sd = 1)
  report <- paste("Initial beta", seq_along(draws), ":", round(draws, 2), "\n")
  cat(report)
  cat("\n")
  draws
}

sigma2.sample <- function() {
  # Draw one starting value for sigma^2 from a Gamma with shape 2 and rate 2,
  # echo it (rounded to two decimals) and return it.
  cat("Sampling initial sigma^2 from a Gamma(2, 2) distribution\n")
  draw <- stats::rgamma(1, shape = 2, rate = 2)
  report <- paste("Initial sigma^2 :", round(draw, 2), "\n\n")
  cat(report)
  draw
}

nu.sample <- function() {
  # Draw one starting value for nu from a Gamma with shape 2 and rate 2,
  # echo it (rounded to two decimals) and return it.
  cat("Sampling initial nu from a Gamma(2, 2) distribution\n")
  draw <- stats::rgamma(1, shape = 2, rate = 2)
  report <- paste("Initial nu :", round(draw, 2), "\n\n")
  cat(report)
  draw
}

alpha.sample <- function() {
  # Draw one starting value for alpha uniformly between 1 and 2, echo it
  # (rounded to two decimals) and return it.
  cat("Sampling initial alpha from a Uniform(1, 2) distribution\n")
  draw <- stats::runif(1, min = 1, max = 2)
  report <- paste("Initial alpha :", round(draw, 2), "\n\n")
  cat(report)
  draw
}
nathansam/SMLN documentation built on May 14, 2022, 9:07 p.m.