## R/functions.R

### y is an nx3 matrix, where n is the number of trios
### 1st column = M, 2nd column = F, 3rd column = Offspring
### S is the number of gibbs sampler scans
### mu is the mean of the normal prior put on each class mean
### tau is the standard deviation (not variance) of the same distribution
### Put a Gamma(nu/2, nu/2 * lambda^2) on the class precisions 
### Assume prior of dirichlet(a,b,c) for the joint prior on the 
###   probabilities of class membership
### This model does not assume independence of the offspring copy number state
###   from the parental copy number states
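## A minimal sketch of a single draw from the priors described above, using the
## default hyperparameters noted in the documentation below (illustration only;
## rdirichlet is from gtools):
##   mu <- -0.02; tau <- sqrt(200); nu <- 3; lambda <- 0.05; a <- c(1, 1, 1)
##   theta <- rnorm(3, mean=mu, sd=tau)                      # class means
##   prec  <- rgamma(3, shape=nu/2, rate=nu*lambda^2/2)      # class precisions
##   sigma <- sqrt(1/prec)                                   # class SDs
##   p     <- gtools::rdirichlet(1, a)[1, ]                  # class membership probabilities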
#' A Gibbs Sampler for genotyping copy number in trios
#' 
#' This function obtains a posterior probability distribution for copy number genotypes on Illumina arrays under a fixed genetic model where tau1=1, tau2=0.5 and tau3=0.
#' @param y Data$response where data is as generated by the simulate_data function.
#' @param S The number of iterations for the Gibbs Sampler (suggest default 10000).
#' @param mu Defaults to -.02.
#' @param xi Defaults to sqrt(200).
#' @param nu Defaults to 3.
#' @param lambda Defaults to 0.05.
#' @param alpha alpha1, alpha2 and psi refer to the HWE parameters as specified in Cardin et al. (alpha=w, psi=lambda).
#' @keywords marimba
#'

initialize_chains <- function(states, N, K, S) {
  z <- matrix(0, N, K)
  theta <- matrix(0, S, K)
  sigma <- matrix(0, S, K)
  tau <- matrix(0, S, 3)
  logll <- rep(0, S)
  p <- matrix(0, nrow = S, ncol = K)
  pi.child <- matrix(0, nrow = S, ncol = K)
  colnames(p) <- states
  colnames(pi.child) <- states
  colnames(tau) <- c(1,0.5,0)
  colnames(sigma) <- colnames(theta) <- states
  list(z=z,
       theta=theta,
       sigma=sigma,
       logll=logll,
       tau=tau,
       p=p,
       pi.child=pi.child)
}

# create rds matrix file

mprob.matrix <-  function(tau=c(0.5, 0.5, 0.5), gp){
  states <- gp$states
  tau1 <- tau[1]
  tau2 <- tau[2]
  tau3 <- tau[3]
  mendelian.probs <- array(dim=c(25,6))
  colnames(mendelian.probs) <- c("parents", "p(0|f,m)", "p(1|f,m)", "p(2|f,m)", "p(3|f,m)", "p(4|f,m)")
  
  mendelian.probs[1, 2:6] <- c(1,0,0,0,0)
  mendelian.probs[2, 2:6] <- c(tau1, 1 - tau1, 0,0,0)
  mendelian.probs[3, 2:6] <- c(tau2 / 2, 0.5, (1 - tau2) / 2, 0,0)
  mendelian.probs[4, 2:6] <- c(0, tau3, 1 - tau3, 0,0)
  mendelian.probs[5, 2:6] <- c(0,0,1,0,0)
  mendelian.probs[6, 2:6] <- c(tau1, 1 - tau1, 0,0,0)
  mendelian.probs[7, 2:6] <- c((tau1^2), 2 * (tau1 * (1 - tau1)), ((1 - tau1)^2), 0,0)
  mendelian.probs[8, 2:6] <- c((tau1 * tau2) / 2, (tau2 * (1 - tau1) + tau1) / 2, (tau1 * (1-tau2) + (1 - tau1)) / 2, ((1 - tau1) * (1 - tau2)) / 2, 0)
  mendelian.probs[9, 2:6] <- c(0, tau1 * tau3, tau1 * (1 - tau3) + (1 - tau1) * tau3, (1- tau1) * (1 - tau3), 0)
  mendelian.probs[10, 2:6] <- c(0, 0, tau1, (1 - tau1), 0)
  mendelian.probs[11, 2:6] <- c(tau2 / 2, 0.5, (1 - tau2) / 2, 0, 0)
  mendelian.probs[12, 2:6] <- c((tau1 * tau2) / 2, (tau1 + tau2 * (1 - tau1)) / 2, ((1 - tau1) + (tau1 * (1-tau2)) ) / 2, (1 - tau1) * (1 - tau2) / 2, 0)
  mendelian.probs[13, 2:6] <- c(tau2^2 / 4, tau2 / 2, (0.5 + tau2 * (1 - tau2)) / 2, (1 - tau2) / 2, (1 - tau2)^2 / 4)
  mendelian.probs[14, 2:6] <- c(0, tau2 * tau3 / 2, (tau3 + tau2 * (1 - tau3)) / 2, (((1 - tau3) + (1 - tau2) * tau3) / 2), (1 - tau2) * (1 - tau3) / 2)
  mendelian.probs[15, 2:6] <- c(0, 0, tau2 / 2, 0.5, (1 - tau2) / 2)
  mendelian.probs[16, 2:6] <- c(0, tau3, (1-tau3), 0, 0)
  mendelian.probs[17, 2:6] <- c(0, tau1 * tau3, tau1 * (1 - tau3) + (1 - tau1) * tau3, (1 - tau1) * (1 - tau3), 0)
  mendelian.probs[18, 2:6] <- c(0, tau2 * tau3 / 2, (tau3 + tau2 * (1 - tau3)) / 2, ((1 - tau3) + (1 - tau2) * tau3) / 2, (1 - tau2) * (1 - tau3) / 2)
  mendelian.probs[19, 2:6] <- c(0,0, tau3^2, 2 * tau3 * (1 - tau3), (1 - tau3)^2)
  mendelian.probs[20, 2:6] <- c(0,0,0, tau3, 1-tau3)
  mendelian.probs[21, 2:6] <- c(0,0,1,0,0)
  mendelian.probs[22, 2:6] <- c(0,0, tau1, 1 - tau1, 0)
  mendelian.probs[23, 2:6] <- c(0,0, tau2 / 2, 0.5, (1 - tau2) / 2)
  mendelian.probs[24, 2:6] <- c(0,0,0, tau3, 1 - tau3)
  mendelian.probs[25, 2:6] <- c(0,0,0,0,1)
  
  if(!all(rowSums(mendelian.probs, na.rm=TRUE) == 1)) stop("mendelian matrix is incorrect")
  
  mendelian.probs[, 1] <- c(00, 01, 02, 03, 04, 
                            10, 11, 12, 13, 14,
                            20, 21, 22, 23, 24,
                            30, 31, 32, 33, 34,
                            40, 41, 42, 43, 44)
  
  mprob.mat <- as.tibble(mendelian.probs)

  
  mprob.mat[, 1] <- c("00", "01", "02", "03", "04", 
                 "10", "11", "12", "13", "14",
                 "20", "21", "22", "23", "24",
                 "30", "31", "32", "33", "34",
                 "40", "41", "42", "43", "44")
  
  mprob.mat <- mprob.subset(mprob.mat, gp)

  extdata <- system.file("extdata", package="marimba2")
  saveRDS(mprob.mat, file.path(extdata, "mendelian_probs2.rds"))
  mprob.mat
}

mprob.subset <- function(mprob.mat, gp) {
  K <- gp$K
  states <- gp$states
  col.a <- states[1] + 2
  col.b <- states[K] + 2
  
  ref.geno <- c("00", "01", "02", "03", "04", 
                "10", "11", "12", "13", "14",
                "20", "21", "22", "23", "24",
                "30", "31", "32", "33", "34",
                "40", "41", "42", "43", "44")
  index <- mprob.label(gp)
  rows <- match(index, ref.geno)
    
  mprob.subset <- mprob.mat[rows, c(col.a:col.b)]
  mprob.rows <- mprob.mat[rows, 1]
  mprob.subset <- cbind(mprob.rows, mprob.subset)
  mprob.subset
}

mprob.label <- function(gp){
  n <- gp$K
  v <- gp$states
  combo <- permutations(n=n, r=2, v=v, repeats.allowed=T)
  geno.combo <- paste0(combo[,1], combo[,2])
  geno.combo
}
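## Example (illustration only): for a 3-state model, mprob.label() enumerates
## the 9 ordered parental genotype pairs used to index the Mendelian table
## (permutations() comes from gtools):
##   mprob.label(list(K=3, states=0:2))
##   ## [1] "00" "01" "02" "10" "11" "12" "20" "21" "22"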

# deprecate

gMendelian <- function(tau=c(1, 0.5, 0), err=1e-4){
  tau.one <- tau[1]
  tau.two <- tau[2]
  tau.three <- tau[3]
  tau1 <- tau.one-err/2
  tau2 <- tau.two-err/2
  tau3 <- tau.three
  mendelian.probs <- array(dim=c(3, 3, 3))
  genotypes <- c("BB", "AB", "AA")
  dimnames(mendelian.probs) <- list(paste0("O_", genotypes),
                                    paste0("F_", genotypes),
                                    paste0("M_", genotypes))
  mendelian.probs[, 1, 1] <- c((1 - tau3)^2, 2 * tau3 * (1 - tau3), tau3^2)
  mendelian.probs[, 2, 1] <- c((1 - tau2) * (1 - tau3), tau2 * (1 - tau3) + (1 - tau2) * tau3, tau2 * tau3)
  mendelian.probs[, 3, 1] <- c((1 - tau1) * (1 - tau3), tau1 * (1 - tau3) + tau3 * (1 - tau1), tau1 * tau3)
  mendelian.probs[, 1, 2] <- c((1 - tau2) * (1 - tau3), tau2 * (1 - tau3) + (1 - tau2) * tau3, tau2 * tau3)
  mendelian.probs[, 2, 2] <- c((1 - tau2)^2, 2 * tau2 * (1 - tau2), tau2^2)
  mendelian.probs[, 3, 2] <- c((1 - tau1) * (1 - tau2), tau1 * (1 - tau2) + tau2 * (1 - tau1), tau1 * tau2)
  mendelian.probs[, 1, 3] <- c((1 - tau1) * (1 - tau3), tau1 * (1 - tau3) + tau3 * (1 - tau1), tau1 * tau3)
  mendelian.probs[, 2, 3] <- c((1 - tau1) * (1 - tau2), tau1 * (1 - tau2) + tau2 * (1 - tau1), tau1 * tau2)
  mendelian.probs[, 3, 3] <- c((1 - tau1)^2, 2 * tau1 * (1 - tau1), tau1^2)
  mendelian.probs
}
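## Example (illustration only): with the default tau = c(1, 0.5, 0) and the
## small error term, an AB x AB mating gives the familiar 1:2:1 offspring
## genotype distribution:
##   probs <- gMendelian()
##   round(probs[, "F_AB", "M_AB"], 3)
##   ##  O_BB  O_AB  O_AA
##   ## 0.250 0.500 0.250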

kmean_clusters <- function(dat, params){
  K <- params$K
  y.mf <- dat[, 1:2]
  y.mfdf <- melt(y.mf)
  y.mfresponse <- y.mfdf$value
  kmeans.result <- kmeans(y.mfresponse, K)
  kmeans.center <- kmeans.result$centers
  kmeans.order <- order(kmeans.center)
  kmeans.clust <- kmeans.result$cluster
  kmeans.clust <- sapply(kmeans.clust, function(x) which(kmeans.order==x))
  y.mfdf$cn <- kmeans.clust - 1
  y.mfdf
}

kmean_clusters.add <- function(kmean.df){
  # use y.mfdf as input from kmean_clusters
  kmean.clust <- kmean.df$cn
  kmean.df$cn <- kmean.clust + 1
  kmean.df
}

p_offspring <- function(cn, mendelian.probs, theta){
  t(apply(cn, 1, lookup_prob, mendelian.probs=mendelian.probs, theta=theta))
}

lookup_prob <- function(cn, mendelian.probs, theta) {
  att <- names(theta)[1]
  if((as.numeric(att))==0) {
    mendelian.probs[, cn[2] + 1, cn[1] + 1]
  } else if((as.numeric(att))==1) {
    mendelian.probs[, cn[2], cn[1]]
  } else {
    mendelian.probs[, cn[2] - 1, cn[1] - 1]
  }
}

cnProb <- function(current, tbl, p){
  sigma <- current$sigma
  theta <- current$theta
  K <- length(theta)
  N <- nrow(tbl)
  ##
  ## Repeat each y K times so that each y is evaluated for the K different values of theta and sigma
  ##
  d.y <- tbl$log_ratio %>% rep(each=K) %>%
    dnorm(mean=theta, sd=sigma) %>%
    matrix(N, K, byrow=TRUE) %>%
    "*"(p)
  cn.denom <- rowSums(d.y, na.rm=TRUE)
  cn.prob <- d.y/cn.denom
  if(any(cn.denom == 0)) stop("zeros in denominator")
  colnames(cn.prob) <- NULL
  stopifnot(all.equal(rowSums(cn.prob), rep(1, nrow(cn.prob))))
  cn.prob
}
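## The rep()/recycling step in cnProb() can be seen in a small standalone
## example (illustration only): each observation is evaluated under all K
## component densities, weighted by p, and each row is normalised to 1.
##   y <- c(-1, 0, 1)
##   theta <- c(-1, 0, 1); sigma <- rep(0.2, 3)
##   p <- matrix(c(0.2, 0.5, 0.3), nrow=3, ncol=3, byrow=TRUE)
##   d <- dnorm(rep(y, each=3), mean=theta, sd=sigma)
##   d <- matrix(d, nrow=3, ncol=3, byrow=TRUE) * p
##   d / rowSums(d)   ## rows sum to 1; most weight on the nearest component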

# this function initialises the model and loads the mendelian transmission probabilities
.init <- function(data, K, tau, e){
  extdata <- system.file("extdata", package="marimba2")
  mprob <- readRDS(file.path(extdata, "mendelian_probs2.rds"))
  if(e > 0){
    mprob2 <- mprob %>% select(starts_with("p("))
    mprob2 <- (mprob2 + e)/(rowSums(mprob2 + e))
    mprob[, -1] <- mprob2
  }
  lr <- data$log_ratio
  tmp <- capture.output(mmod <- Mclust(lr, G=K))
  data$z <- mmod$classification
  data$copy_number <- data$z - 1
  theta <- mmod$parameters$mean
  sigma <- sqrt(mmod$parameters$variance$sigmasq)
  p <- mmod$parameters$pro
  pi.child <- p
  current <- list(data=data,
                  theta=theta,
                  sigma=sigma,
                  p=p,
                  pi.child = pi.child,
                  tau=tau,
                  mprob=mprob)
  current$logll <- compute_loglik(current)
  current
}

.init2 <- function(gp){
  K <- gp$K
  theta <- sort(rnorm(K, gp$mu, gp$xi))
  logsigma <- rnorm(K, log(0.3), 1)
  sigma <- exp(logsigma)
  p <- rdirichlet(1, gp$a)
  pi.child <- rdirichlet(1, gp$a)
  e <- gp$error
  extdata <- system.file("extdata", package="marimba2")
  mprob <- readRDS(file.path(extdata, "mendelian_probs2.rds"))
  if(e > 0){
    mprob2 <- mprob %>% select(starts_with("p("))
    mprob2 <- (mprob2 + e)/(rowSums(mprob2 + e))
    mprob[, -1] <- mprob2
  }
  list(theta=theta,
       sigma=sigma,
       p=p,
       pi.child = pi.child,
       mprob=mprob,
       tau=c(1, 0.5, 0))
}

# the difference between gmodel and gmodel2 is in `current`: the first includes the data, the second holds parameters only
# the variable `model` in function calls usually refers to gmodel
# the difference is .init vs .init2: gmodel calls Mclust and gmodel2 draws starting values from the priors (gp$mu, gp$xi, gp$a)

gmodel2 <- function(data,
                    mp=mcmcParams(),
                    gp=geneticParams()){
  current <- .init2(gp)
  chains <- initialize_chains(states=gp$states,
                              N=nrow(data),
                              K=gp$K,
                              S=mp$iter + 1L)
  list(current=current,
       chains=chains,
       mp=mp,
       gp=gp)
}

gmodel <- function(data,
                   mp=mcmcParams(),
                   gp=geneticParams()){
  current <- .init(data, gp$K, gp$tau, gp$error)
  chains <- initialize_chains(states=gp$states,
                              N=nrow(data),
                              K=gp$K,
                              S=mp$iter + 1L)
  list(current=current,
       chains=chains,
       mp=mp,
       gp=gp)
}

initialize_gmodel <- function(dat, params, comp=1){
  y.mf <- multi_K(dat, params)
  y.mf <- y.mf[[comp]]
  y.o <- dat[, "o"]
  y.odf <- data.frame(value = y.o)
  ## Assume that Z=0 corresponds to class 1, Z=1 corresponds to class 2
  ## Random guesses for Z
  cn.mf <- dcast(y.mf, Var1 ~ Var2, value.var="cn")[, -1]
  y.mf.list <- split(y.mf$value, y.mf$cn)
  ## Calculate sufficient statistics
  Ns <- sapply(y.mf.list, length)
  mus <- sapply(y.mf.list, mean, na.rm=TRUE)
  vars <- sapply(y.mf.list, var, na.rm=TRUE)
  theta <- mus
  sigma2 <- vars
  sigma <- sqrt(sigma2)
  a <- params$a
  d <- length(table(y.mf$cn))
  a <- a[1:d]
  pi <- rdirichlet(1, Ns+a)[1, ]
  names(pi) <- names(theta)
  
  # update when making it for alternative types of models
  #eta1 <- params$eta1
  #eta2 <- params$eta2
  # tau <- rbeta(1, eta1, eta2)
  
  ## create initial mendelian matrix
  mendelian.probs <- gMendelian()
  
  ## Draw for offspring
  p.o <- p_offspring(cn.mf, mendelian.probs, theta)

  # process p.o into appropriate dimensions for < 3 components
  p.o <- dim.reduct(p.o, comp)

  cn.prob <- cnProb(p.o, y.o, theta, sigma)
  cn.prob <- as.matrix(cn.prob)
  ##
  ## initialize copy number for offspring
  ##
  y.odf <- init.offspring.cn(cn.prob, cn.mf, comp, y.odf)
  tab <- table(y.odf$cn)
  check<-length(unique(y.mf$cn))
  if(length(tab) != check) stop("Some components unobserved")
  cn.all <- as.matrix(cbind(cn.mf, y.odf$cn))
  colnames(cn.all) <- c("m", "f", "o")
  list(y=dat,
       cn=cn.all,
       m.probs=mendelian.probs,
       theta=theta,
       sigma=sigma,
       pi=pi,
       tau=tau)
}

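# `comp` indexes the candidate model as in the deprecated wrapper below:
# 1 = the full three-component model, 2 and 3 = the two two-component models,
# 4 and 5 = the two one-component models; the offspring probability matrix is
# subset to the corresponding columns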
dim.reduct <- function(offspring.probs, comp){
  if (comp==2){
    offspring.probs <- offspring.probs[,1:2]
  } else if (comp==3){
    offspring.probs <- offspring.probs[,2:3]
  } else if (comp==4) {
    offspring.probs <- as.matrix(offspring.probs[,2])
  } else if (comp==5) {
    offspring.probs <- as.matrix(offspring.probs[,3])
  } else offspring.probs <- offspring.probs
  return (offspring.probs)
}

init.offspring.cn <- function(cn.prob, cn.mf, comp, y.odf){
  if (comp < 3){
    cn.o <- rMultinom(cn.prob, 1)[, 1]
    y.odf$cn <- cn.o - 1
  } else {
    if (comp==3){
      cn.o <- rMultinom(cn.prob, 1)[, 1]
      y.odf$cn <- cn.o
    } else {
      cn.o <- cn.mf$m
      y.odf$cn <- cn.o
    }
  }
  return (y.odf)
}

compute_loglik <- function(current){
  mdat <- current$data
  lr <- mdat$log_ratio
  z <- mdat$z
  theta <- current$theta[z]
  sigma <- current$sigma[z]
  sum(dnorm(lr, theta, sigma, log=TRUE))
}

componentStats <- function(y, cn){
  yy <- as.numeric(y)
  cn2 <- as.numeric(cn)
  ylist <- split(yy, cn2)
  means <- sapply(ylist, mean, na.rm=TRUE)
  sds <- sapply(ylist, sd, na.rm=TRUE)

  ## record just the parents for the mix probs
  y.p <- as.numeric(y[, 1:2])
  cn.p <- as.numeric(cn[, 1:2])
  y.p.list <- split(y.p, cn.p)
  ns <- sapply(y.p.list, length)
  list(means=means, Ns=ns, sds=sds)
}

rMultinom2 <- function(x, ...){
  ## rMultinom will not return an integer if columns are named
  colnames(x) <- NULL
  xx <- rMultinom(x, ...)
  xx
}

rMultinom3 <- function(tbl){
  tbl %>% select(starts_with("p(")) %>%
    rMultinom2(1)
}

update_parents <- function(model){
  current <- model$current
  tbl <- current$data
  K <- length(current$theta)
  tbl.parents <- tbl %>%
    filter(family_member %in% c("m", "f"))
  ##
  ## parents:  p-matrix is the population mixture probabilities
  ##
  p <- current$p
  N <- nrow(tbl.parents)
  p <- matrix(p, N, K, byrow=TRUE)
  tbl.parents$z <- cnProb(current, tbl.parents, p) %>%
    rMultinom(m=1)  %>% "["(, 1)
  M <- cn_adjust(model)
  tbl.parents$copy_number <- tbl.parents$z + M
  tbl.parents
}

update_offspring <- function(model, tbl.parents){
  ##
  ## Similar to update for parents, but plug in mendelian probabilities for the population mixture probabilities
  ##
  gp <- model$gp
  current <- model$current
  mprob <- current$mprob
  K <- gp$K
  p <- tbl.parents %>%
    group_by(id) %>%
    arrange(family_member) %>%
    summarize(parents=paste0(copy_number, collapse="")) %>%
    mutate(parents=substr(parents, 1, 2)) %>%
    left_join(mprob, by="parents") 
  tbl.child <- current$data %>%
    filter(family_member=="o", id %in% tbl.parents$id) %>%
    arrange(id)
  p <- p %>% arrange(id) %>%
    select(starts_with("p("))
  tbl.child$z <- cnProb(current, tbl.child, p) %>%
    rMultinom(m=1) %>% "["(, 1)
  M <- cn_adjust(model)
  tbl.child$copy_number <- tbl.child$z + M
  tbl.child
}

cn_adjust <- function(model){
  gp <- model$gp
  states <- gp$states
  start.state <- states[1]
  ## offset mapping the 1-based component label z back to a copy number state
  M <- -1
  M <- ifelse(start.state==1, 0, M)
  M <- ifelse(start.state==2, 1, M)
  M <- ifelse(start.state==3, 2, M)
  M
}

update_cn <- function(model){
  tbl.parents <- update_parents(model)
  ## offspring copy number is conditional on the updated parental copy numbers
  tbl.child <- update_offspring(model, tbl.parents)
  tbl <- bind_rows(tbl.parents, tbl.child) %>%
    arrange(id, family_member)
  ## check:  should be hemizygous but not called hemizygous
  tmp <- tbl %>%
    filter(log_ratio < -0.5 & log_ratio > -1.5 & copy_number != 1)
  if(nrow(tmp) > 0){
    ggplot(tbl, aes(copy_number, log_ratio)) +
      geom_jitter(width=0.05, aes(color=family_member))
    problem.id <- tmp$id
    tbl %>% filter(id %in% problem.id, family_member %in% c("f", "m")) %>%
      group_by(id) %>%
      summarize(fm=paste0(family_member, collapse=""),
                copy_number=paste0(copy_number, collapse=""))
    tbl %>% filter(id %in% problem.id, family_member %in% c("f", "m", "o"))

    tbl.parents2 <- tbl %>% filter(id %in% tmp$id, family_member %in% c("f", "m"))
    tbl.child2 <- update_offspring(model, tbl.parents2)
  }
  tbl
}

missingCnState <- function(cn.states, expected=as.character(c(0, 1, 2))){
  observed.cnstates <- names(table(cn.states))
  not.observed <- expected[ !expected %in% observed.cnstates ]
  ## states observed only once are also flagged (a single observation is not
  ## enough to compute a component standard deviation)
  not.two <- names(table(cn.states)[table(cn.states)==1])
  not.observed <- c(not.observed, not.two)
  unique(not.observed)
}

replaceIfMissing <- function(cn.states, expected=as.character(c(0, 1, 2))){
  missing.states <- missingCnState(cn.states)
  if(length(missing.states) > 0){
    cn.states[sample(seq_along(cn.states), 4)] <- as.numeric(missing.states)
  }
  cn.states
}

balance_cn <- function(stats, current, gp){
  K <- gp$K
  states <- gp$states
  index <- which(stats$n < K)
  if(nrow(stats) != K){
    stats <- tibble(copy_number=states)  %>%
      left_join(stats, by="copy_number")
    index <- which(is.na(stats$mean))
  }
  tbl <- current$data
  theta <- current$theta
  sigma <- current$sigma
  lr <- tbl$log_ratio
  ##
  ## we need at least 2 observations to compute a standard deviation
  ## - re-assign the K observations most probable under the missing component
  ##
  for(i in c(index)){
    d <- dnorm(lr, theta[i], sigma[i], log=TRUE)
    ix <- order(d, decreasing=TRUE)[1:K]
    tbl$z[ix] <- i
  }
  tbl$copy_number <- tbl$z - 1
  tbl
}

update_theta <- function(ymns, sigma, ns, params){
  mu.0 <- params$mu[1]
  xi <- params$xi[1] ## prior variance
  ## let phi denote inverse variance
  data.precision <- ns * 1/sigma^2
  prior.precision <- 1/xi
  post.prec <- prior.precision + data.precision
  mu.n <- (mu.0 * prior.precision + ymns * data.precision) / post.prec
  K <- length(ymns)
  rnorm(K, mu.n, sqrt(1/post.prec))
}
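## Worked example (illustration only): with prior mean 0 and prior variance 1,
## a single component with ybar = 1, sigma = 0.5 and n = 10 gives posterior
## precision 1 + 10/0.25 = 41 and posterior mean (0*1 + 1*40)/41 ~= 0.976, so
##   update_theta(ymns=1, sigma=0.5, ns=10, params=list(mu=0, xi=1))
## returns one draw from N(0.976, sd = sqrt(1/41)).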

component_stats <- function(tbl){
  tbl %>% group_by(copy_number) %>%
    summarize(mean=mean(log_ratio),
              sd=sd(log_ratio),
              n=n())
}

parent_stats <- function(tbl){
  tbl %>%
    filter(family_member %in% c("f", "m")) %>%
    component_stats
}

child_stats <- function(tbl){
  tbl %>%
    filter(family_member %in% c("o")) %>%
    component_stats
}

##ybar <- function(stats){
##  stats$
##  ##yy <- unlist(model$logr, use.names=FALSE)
##  ##cn <- factor(unlist(model$z, use.names=FALSE))
##  ##mns <- tapply(yy, cn, mean, na.rm=TRUE)
##  if(any(is.na(mns))) stop("NAs in variances")
##  mns
##}
##
##yvar <- function(model){
##  yy <- unlist(model$logr, use.names=FALSE)
##  cn <- factor(unlist(model$cn, use.names=FALSE))
##  vars <- tapply(yy, cn, var, na.rm=TRUE)
##  if(any(is.na(vars))) stop("NAs in variances")
##  vars
##}

## n_child <- function(model){
##   number_cnstate(model, "o")
## }
## 
## number_cnstate <- function(model, member=c("f", "m", "o")){
##   current <- model$current
##   freq <- current$z %>% select(member) %>%
##     gather(key="member", value="copy_number") %>%
##     group_by(copy_number) %>%
##     summarize(n=n())
##   freq
## }
## 
## n_parents <- function(model){
##   number_cnstate(model, c("f", "m"))
## }
## 
## n_all <- function(model){
##   number_cnstate(model)
## }

##n_child <- function(gmodel){
##  yy <- as.numeric(gmodel$y[, c("o")])
##  cn <- as.numeric(gmodel$cn[, 3])
##  tab <- tapply(yy, cn, length)
##  freq <- setNames(rep(0L, length(gmodel$theta)),
##                   attributes(tab)$dimnames[[1]])
##  freq[names(tab)] <- tab
##  freq
##}

## mean: shape/rate
## var:  shape/rate^2
update_sigma <- function(params, ns, theta, ymns, yvars){
##  if(any(sqrt(yvars) > 0.25)){
##    browser()
##  }
  nu.0 <- params$nu
  s2.0 <- params$sigma2.0
  nu.n <- nu.0 + ns
  s2.n <- 1/nu.n * (nu.0*s2.0 + (ns-1)*yvars + ns*(ymns-theta)^2)
  K <- length(ns)
  prec <- rgamma(K, nu.n/2, nu.n*s2.n/2)
  s <- sqrt(1/prec)
  s
}
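## Quick check of the shape/rate parameterisation noted above (illustration only):
##   x <- rgamma(1e5, shape=2, rate=4)
##   c(mean(x), var(x))   ## approximately 2/4 = 0.5 and 2/4^2 = 0.125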

# key module for different models

gmodel_oneiter <- function(model){
  gp <- model$gp
  K <- gp$K
  current <- model$current
  tbl <- current$data
  N <- nrow(tbl)
  p <- current$p
  pi.child <- current$pi.child
  theta <- current$theta
  sigma <- current$sigma
  states <- gp$states
  ##
  ## update component labels
  ##
  current$data <- update_cn(model)
  ##
  ## update theta
  ##
  stats <- component_stats(current$data)
  if(any(stats$n < K) | nrow(stats) < K){
    current$data <- balance_cn(stats, current, gp)
    stats <- component_stats(current$data)
  }
  ymeans <- stats$mean
  yvars <- stats$sd^2
  ns <- stats$n
  current$theta <- update_theta(params=gp,
                                ymns=ymeans,
                                sigma=sigma,
                                ns=ns)
  ##
  ## update sigma
  ##
  current$sigma <- update_sigma(params=gp,
                                ns=ns,
                                theta=theta,
                                ymns=ymeans,
                                yvars=yvars)
  ## update mixture probabilities
  ## -- only for parents
  parents <- parent_stats(current$data)
  if(nrow(parents) != K){
    parents2 <- tibble(copy_number=gp$states[1]:gp$states[K])  %>%
      left_join(parents, by="copy_number")
    parents2$n[is.na(parents2$n)] <- 0
    ##    pdat <- current$data %>%
    ##      filter(family_member %in% c("f", "m"))
    ##    current.parents <- current
    ##    current.parents$data <- pdat
    ##    pdat <- balance_cn(parents2, current.parents)
    ## now need
    parents <- parents2
  }
  current$p <- update_p(parents$n, gp)
  
  # update child stats
  offspring <- child_stats(current$data)
  if(nrow(offspring) != K){
    offspring2 <- tibble(copy_number=gp$states[1]:gp$states[K])  %>%
      left_join(offspring, by="copy_number")
    offspring2$n[is.na(offspring2$n)] <- 0
    offspring <- offspring2
  }
  current$pi.child <- offspring$n / (sum(offspring$n))
  
  ##
  ## update transmission matrix
  ## TODO: update_tau
  ## model <- updateTransmissionProb(model, params)
  ## update log likelihood
  current$logll <- compute_loglik(current)
  model$current <- current
  model
}

update_p <- function(ns, params){
  a <- params$a
  ##ns <- n_parents(gmodel)
  d <- length(ns)
  a <- a[1:d]
  alpha.n <- ns + a
  p <- tryCatch(rdirichlet(1, alpha.n), error=function(e) NULL)
  if(is.null(p)) browser()
  ##if(length(p) < 3) browser()
  p
}

update_tau <- function(gmodel, params){
  a <- params$a
  N <- params$N
  ns <- n_parents(gmodel)
  d <- length(ns) - 1
  bvec0 <- a[1:d]
  nvec <- as.numeric(n_child(gmodel))
  allele.freq <- DirichSampHWE(nvec,bvec0, 1)
  pvec <- allele.freq$pvec[2]
}

update_tau.env <- function(gmodel, params){
  a <- params$a
  N <- params$N
  ns <- n_parents(gmodel)
  d <- length(ns) - 1
  bvec0 <- a[1:d]
  nvec <- as.numeric(n_child(gmodel))
  # the line below will throw an error if balance.cn.all is not working properly to balance 3 components
  # bvec0 should always be c(1,1) and nvec a length 3 vector for 3 component model
  allele.freq <- DirichSampHWE(nvec,bvec0, 1)
  # to generalise to multi-components, the line below will need to be revised.
  pvec <- allele.freq$pvec[2]
  denom <- 2 * ns[[3]] + ns[[2]]
  numer <- pvec * 2* N
  tau.env <- numer / denom
  gmodel$tau[] <- tau.env
  tau <- gmodel$tau
  tau
}

update_tau.intermed <- function(gmodel, params){
  # see notes in update_tau.env
  a <- params$a
  N <- params$N
  ns <- n_parents(gmodel)
  d <- length(ns) - 1
  bvec0 <- a[1:d]
  nvec <- as.numeric(n_child(gmodel))
  allele.freq <- DirichSampHWE(nvec,bvec0, 1)
  pvec <- allele.freq$pvec[2]
  denom <- ns[[2]]
  # numerator here is just for heterozygotes in kids so calculated accordingly
  numer <- pvec * 2 * N * (nvec[2] / (nvec[2] + 2 * nvec[3]))
  tau.intermed <- numer / denom
  gmodel$tau[] <- tau.intermed
  tau <- gmodel$tau
  tau
}

update_mendel <- function(gmodel){
  m.probs <- gmodel$m.probs
  att <- attributes(gmodel$theta)
  dim <- as.numeric(att$names) + 1
  dim.l <- length(dim)
  subset <- c(dim[1]:dim[dim.l])
  m.probs <- m.probs[subset, subset, subset]
  gmodel$m.probs <- m.probs
  gmodel
 }

# this module updates both the Mendelian transmission matrix as well as the taus where relevant
updateTransmissionProb <- function(gmodel, params){

  #ncp <- params$ncp

  if(params$model == "Genetic") {
    m.probs <- gMendelian(tau=c(params$tau.one, params$tau.two, params$tau.three),
                          err=params$error)
  } else if(params$model == "Environmental") {
    ## all transmission parameters are set to the same value under the
    ## environmental model
    tau <- update_tau.env(gmodel, params)
    gmodel$tau <- tau
    m.probs <- gMendelian(tau=rep(tau[1], 3), err=params$error)
  } else {
    ## intermediate model: only the heterozygote transmission parameter is updated
    tau <- update_tau.intermed(gmodel, params)
    gmodel$tau <- tau
    m.probs <- gMendelian(tau=c(params$tau.one, tau[1], params$tau.three),
                          err=params$error)
  }
  gmodel$m.probs <- m.probs
  gmodel
}

# move_chains <- function(chains, gmodel, i){
  
#generalised chain movement
 #K <- length(gmodel$theta)
  
  #chains$C$C0 <- chains$C$C0 + 1L * (gmodel$cn == 0)
  #chains$C$C1 <- chains$C$C1 + 1L * (gmodel$cn == 1)
  #chains$C$C2 <- chains$C$C2 + 1L * (gmodel$cn == 2)
  ##
  
  ##
  #chains$theta[i, ] <- gmodel$theta
  #chains$sigma[i, ] <- gmodel$sigma
  #chains$logll[i] <- compute_loglik(gmodel)
  #chains$pi[i, ] <- gmodel$pi
  #chains$tau[i] <- gmodel$tau
  #chains
#}

update_chains <- function(model, i){
  chains <- model$chains
  current <- model$current
  dat <- current$data
  gp <- model$gp
  K <- gp$K
  z <- dat$copy_number + 1
  zchain <- chains$z
  for(j in seq_len(K)){
    zchain[, j] <- zchain[, j] + (z == j)*1L
  }
  chains$theta[i, ] <- current$theta
  chains$sigma[i, ] <- current$sigma
  chains$p[i, ] <- current$p
  chains$pi.child[i, ] <- current$pi.child
  chains$tau[i, ] <- current$tau
  chains$logll[i] <- current$logll
  chains$z <- zchain
  chains
}

.gibbs_burnin <- function(model){
  mp <- model$mp
  S <- mp$burnin
  for(i in seq_len(S)){
    model <- gmodel_oneiter(model)
  }
  model
}

##############################
### the customised functions
##############################

###########################
### original, working well
###########################
# call chain: wrapper -> gibbs_genetic -> gmodel_mcmc -> gmodel_oneiter (make changes to gmodel_oneiter)

gmodel_mcmc <- function(model){
  mp <- model$mp
  S <- mp$iter
  n.thin <- max(mp$thin, 1L)
  for(i in seq_len(S)){
    for(j in seq_len(n.thin)){
      model <- gmodel_oneiter(model)
    }
    ## the initial values are in position 1
    model$chains <- update_chains(model, i + 1)
  }
  model
}

gibbs_genetic <- function(model){
  mp <- model$mp
  S <- mp$iter
  B <- mp$burnin
  ##model$chains <- update_chains(model, 1)
  if(B > 0){
    model <- .gibbs_burnin(model)
    model$chains <- update_chains(model, 1)
  }
  if(S > 1){
    model <- gmodel_mcmc(model)
  }
  model
}
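## A usage sketch (illustration only, not called by the package): fit a single
## chain to a tibble `dat` with columns id, family_member ("m", "f", "o") and
## log_ratio.  Assumes a 3-state model and that the mendelian_probs2.rds lookup
## table in extdata has already been generated (see mprob.matrix).
fit_single_chain_sketch <- function(dat){
  gp <- geneticParams(K=3, states=0:2,
                      mu=c(-3, -0.5, 0), xi=c(1.5, 1, 1))
  mp <- mcmcParams(iter=1000, burnin=200, thin=1)
  model <- gmodel(dat, mp=mp, gp=gp)
  gibbs_genetic(model)
}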

# new wrapper function
# no need for this, see functions below
gibbs.cnv.call <- function(K, states, tau, xi, mu, nu, sigma2.0, a, eta, error, ncp, model, dat){
  gparams <- geneticParams(K=K, states = states,
                           tau = tau,
                           xi = xi,
                           mu = mu, nu = nu,
                           sigma2.0 = sigma2.0, a = a,
                           eta = eta, error = error,
                           ncp = ncp, model = model)
  
  gmodel <- gmodel(dat$data, gp=gparams)
  gmodel.fit <- gibbs_genetic(gmodel)
  gmodel.map <- map_cn2(gmodel.fit)
  gmodel.stats <- model.compare(gmodel.fit)
  
  return(list(
    theta = gmodel.fit$chains$theta,
    sigma = gmodel.fit$chains$sigma,
    mixture.probs = gmodel.fit$chains$p,
    tau = gmodel.fit$chains$tau,
    cn.ind = gmodel.map,
    #cn.ind = gmodel.map$cn,
    cn.state.prob = gmodel.map$prob,
    logll = gmodel.fit$chains$logll,
    BIC = gmodel.stats[1],
    DIC = gmodel.stats[2]
  ))
}

#########################
# old wrapper function
# deprecated
#########################

# gibbs.cnv.wrapper <- function(K, states, tau, xi, mu, nu, sigma2.0, a, eta, error, ncp, model, y=y){
  # gparams <- geneticParams(K=K, states = states,
  #                         tau = tau,
   #                        xi = xi,
    #                       mu = mu, nu = nu,
     #                      sigma2.0 = sigma2.0, a = a,
      #                     eta = eta, error = error,
       #                    ncp = ncp, model = model)
 ## gmodel.k3 <- initialize_gmodel(y, gparams, 1)
  #gmodel.k2a <- initialize_gmodel(y, gparams, 2)
  #gmodel.k2b <- initialize_gmodel(y, gparams, 3)
  #gmodel.k1a <- initialize_gmodel(y, gparams, 4)
  #gmodel.k1b <- initialize_gmodel(y, gparams, 5)
  
  ## fit.k3 <- gibbs_genetic(gmodel.k3, gparams)
  #fit.k2a <- gibbs_genetic(gmodel.k2a, gparams)
  #fit.k2b <- gibbs_genetic(gmodel.k2b, gparams)
  #fit.k1a <- gibbs_genetic(gmodel.k1a, gparams)
  #fit.k1b <- gibbs_genetic(gmodel.k1b, gparams)
  
 ## fit.chains.k3 <- fit.k3$chains$C
  #fit.chains.k2a <- fit.k2a$chains$C
  #fit.chains.k2b <- fit.k2b$chains$C
  #fit.chains.k1a <- fit.k1a$chains$C
  #fit.chains.k1b <- fit.k1b$chains$C
  
 ## cn.stats.k3 <- map_cn(gmodel.k3, fit.chains.k3, gparams)
  #cn.stats.k2a <- map_cn(gmodel.k2a, ch, gparams)
  #cn.stats.k2b <- map_cn(gmodel.k2b, ch, gparams)
  #cn.stats.k1a <- map_cn(gmodel.k1a, ch, gparams)
  #cn.stats.k1b <- map_cn(gmodel.k1b, ch, gparams)
  
 ## model.metrics <- model.comparison(fit.k3)
  
 # return(list(post.thetas = fit.k3$chains$theta,
  #            post.sigmas = fit.k3$chains$sigma,
   #           post.pi = fit.k3$chains$pi,
    #          taus = fit.k3$chains$tau,
     #         post.cn0 = fit.k3$chains$C$C0,
      #        post.cn1 = fit.k3$chains$C$C1,
       #       post.cn2 = fit.k3$chains$C$C2,
        #      trio.cn = cn.stats.k3$cn,
         #     trio.cn.probs = cn.stats.k3$prob,
          #    logll = fit.k3$chains$logll,
           #   model.BIC = model.metrics[1],
            #  model.DIC = model.metrics[2])
         #)
# }

# model comparison metrics
# calculation of BIC and DIC

model.compare <- function (fit.model) {
  BIC <- calc.BIC(fit.model)
  DIC <- calc.DIC(fit.model)
  c(model.BIC = BIC,
    model.DIC = DIC)
}

calc.BIC <- function (fit.model) {
  n.sample <- length(fit.model$current$data$z)
  ## number of free parameters: K means, K standard deviations and K - 1
  ## mixture proportions (tau is fixed in the genetic model)
  K <- ncol(fit.model$chains$theta)
  npar <- 3 * K - 1
  logll <- fit.model$chains$logll
  bic.calc <- log(n.sample)*npar - 2*logll
  bic.calc <- mean(bic.calc)
  bic.calc
}

calc.DIC <- function (fit.model) {
  S <- nrow(fit.model$chains$theta)
  logll <- fit.model$chains$logll
  logll.sum <- sum(fit.model$chains$logll)
  npar <- 2*(logll - (1/S * logll.sum))
  dic.calc <- -2*(logll-npar)
  dic.calc <- mean(dic.calc)
  dic.calc
}

# output results summary
gibbs.results <- function(sim.truth, gibbs.cnv, params){
  N <- params$N
  
  # compare the truth set with call set
  truth <- sim.truth["cn"]
  truth <- truth$cn
  map.cn <- gibbs.cnv$trio.cn
  pr.cn <- gibbs.cnv$trio.cn.probs
  contig <- table(map.cn, truth)
  index <- which(rowSums(map.cn != truth) > 0)
  mistake.true.cn <- truth[index, ]
  mistake.cn.called <- map.cn[index, ]
  mistake.cn.probs <- pr.cn[index, ]
  
  # effective size
  p <- gibbs.cnv$post.pi
  eff.size <- effectiveSize(p)
  
  # compare posterior probabilities for different components
  truth.p <- sim.truth$p
  empirical.p <- table(truth)/(N*3)
  p.mns <- colMeans(p)
  attributes(p.mns)$names <- c("0","1","2")
  
  # taus
  posterior.tau <- median(gibbs.cnv$taus)
  
  # taus from truth set
  parents.count <- n_parents(sim.truth)
  child.count <- n_child(sim.truth)
  
  numer.env <- child.count[[2]] + 2 * child.count[[3]]
  denom.env <- parents.count[[2]] + 2 * parents.count[[3]]
  tau.env <- numer.env / denom.env
  
  numer.intermed <- child.count[[2]]
  denom.intermed <- parents.count[[2]]
  tau.intermed <- numer.intermed / denom.intermed
  
  # truth table
  truth.tb <- smartbind(truth.p, empirical.p, p.mns)
  truth.tb[is.na(truth.tb)] <- 0
  row.names(truth.tb) <- c("truth", "empirical", "estimated")
  
  truth.tb.2 <- data.frame(cn=c("C0", "C1", "C2"),
                        truth=as.numeric(truth.p),
                        empirical=as.numeric(empirical.p),
                        estimated = as.numeric(p.mns))
  
  return(list(truth.table = contig,
              missed.true.cn = mistake.true.cn,
              missed.called.cn = mistake.cn.called,
              missed.called.cn.probs = mistake.cn.probs,
              taus = posterior.tau,
              true.tau.env = tau.env,
              true.tau.intermed = tau.intermed,
              effective.Size = eff.size,
              cn.chain = p,
              mixture.prob.tb1 = truth.tb,
              mixture.prob.tb2 = truth.tb.2))
}

## are the mixing probs in agreement with the empirical values estimated from
## the simulated data?
gg_prob.comparison <- function (gibbs.results) {
  p.df <- melt(gibbs.results$cn.chain)
  colnames(p.df)[2] <- "cn"
  ggplot(p.df, aes(Var1, value)) +
    geom_line(color="gray") +
    geom_hline(data=gibbs.results$mixture.prob.tb2,
               mapping=aes(yintercept=empirical)) +
    facet_wrap(~cn)
}

gg_inten.comparison1 <- function (sim.truth, gibbs.cnv, params) {
  
  p <- cbind(as.numeric(gibbs.cnv$post.cn0), as.numeric(gibbs.cnv$post.cn1), as.numeric(gibbs.cnv$post.cn2))
  cn.model <- apply(p, 1, which.max) - 1
  cn.trio <- matrix(cn.model, params$N, params$K)
  colnames(cn.trio) <- c("m", "f", "o")
  df <- data.frame(y=as.numeric(sim.truth$y),
                   cn=factor(cn.model),
                   truth=factor(as.numeric(sim.truth$cn)),
                   concordant=cn.model==as.numeric(sim.truth$cn))
  
  ggplot(df, aes(cn, y)) +
    geom_jitter(width=0.1,
                aes(color=concordant)) +
    geom_boxplot(alpha=0.2, outlier.fill="transparent",
                 outlier.color="transparent")
}

gg_inten.comparison2 <- function (sim.truth, gibbs.cnv, params) {
  
  p <- cbind(as.numeric(gibbs.cnv$post.cn0), as.numeric(gibbs.cnv$post.cn1), as.numeric(gibbs.cnv$post.cn2))
  cn.model <- apply(p, 1, which.max) - 1
  cn.trio <- matrix(cn.model, params$N, params$K)
  colnames(cn.trio) <- c("m", "f", "o")
  df <- data.frame(y=as.numeric(sim.truth$y),
                   cn=factor(cn.model),
                   truth=factor(as.numeric(sim.truth$cn)),
                   concordant=cn.model==as.numeric(sim.truth$cn))
  
  ggplot(df, aes(truth, y)) + geom_jitter(width=0.1) +
    geom_boxplot(alpha=0.2, outlier.fill="transparent",
                 outlier.color="transparent")
}

## We will use this function to simulate data
## We will simulate using assumed distributions, which is definitely not the 
##   actual case
## Output: a list of two  n x 3 matrices, where each row is a trio, 
##   1st column = M, 2nd column = F, 3rd column = O
##   1st matrix holds the fluorescence values (what we would observe)
##   2nd matrix holds the true CN states (for evaluation)
## Inputs:
##   n -- the number of trios
##   p -- a vector of length three giving the probabilities of having
##        CN 0, 1, 2 respectively (for mother and father)
##   theta -- a vector of length three giving the distribution means for
##        CN 0, 1, 2 respectively 
##   sigma -- a vector of length three giving the distribution SDs for
##        CN 0, 1, 2 respectively 

statistics_1000g <- function(region){
  path <- system.file("extdata", package="marimba")
  vcf.del <- read.table(file.path(path, "parameter.p.1000gp.v2.txt"),
                        header=TRUE, sep="\t")
  median.sd <- read.table(file.path(path, "medians.sd.txt"),
                          header=TRUE, sep="\t")
  p <- as.numeric(vcf.del[region, c(9:11)])
  theta <- as.numeric(median.sd[region, c(1:3)])
  sigma <- as.numeric(median.sd[region, c(4:6)])
  params <- cbind(p, theta, sigma)
  colnames(params) <- c("p", "theta", "sigma")
  as.tibble(params)
}

# simulate from the environmental model
# set the taus to an arbitrary but common value
simulate_data.env1 <- function(params, N, error=0){
  ##mendelian.probs <- mendelianProb(epsilon=error)
  ##mendelian.probs <- gMendelian(tau.one=0.1, tau.two=0.1, tau.three=0.1, err=0)
  mendelian.probs <- gMendelian(tau=rep(0.1, 3), err=0)
  ##stats <- statistics_1000g(region)
  dat <- simulate_data(params, N, mendelian.probs)
  dat$cn <- dat$cn - 1
  colnames(dat$cn) <- c("m", "f", "o")
  colnames(dat$response) <- gsub("y_", "", colnames(dat$response))
  names(dat) <- c("y", "cn")
  dat$theta <- setNames(params[, "theta"], 0:2)
  dat$sigma <- setNames(params[, "sigma"], 0:2)
  dat$p <- setNames(params[, "p"], 0:2)
  dat$logll <- compute_loglik(dat)
  dat
}

simulate_data.env2 <- function(params, N, error=0){
  ##mendelian.probs <- mendelianProb(epsilon=error)
  ##mendelian.probs <- gMendelian(tau.one=0.8, tau.two=0.8, tau.three=0.8, err=0)
  mendelian.probs <- gMendelian(tau=rep(0.8, 3), err=0)
  ##stats <- statistics_1000g(region)
  dat <- simulate_data(params, N, mendelian.probs)
  dat$cn <- dat$cn - 1
  colnames(dat$cn) <- c("m", "f", "o")
  colnames(dat$response) <- gsub("y_", "", colnames(dat$response))
  names(dat) <- c("y", "cn")
  dat$theta <- setNames(params[, "theta"], 0:2)
  dat$sigma <- setNames(params[, "sigma"], 0:2)
  dat$p <- setNames(params[, "p"], 0:2)
  dat$logll <- compute_loglik(dat)
  dat
}

mendelianProb <- function(epsilon=0){
  ## Mendelian probability array for child
  ## for ease of referencing of row/ column names, let (0,0) = BB, (0,1)= AB and (1,1) = AA
  ## here BB, AB, AA represent from mother dim for array creation
  ## the probability array is labelled with .c, .f and .m representing child, father, mother respectively
  ##  BB <- matrix(c(1, 0, 0,
  ##                 0.5, 0.5, 0,
  ##                 0, 1, 0) ,nrow=3, ncol=3,
  ##               dimnames=list(c("BB","AB","AA"),c("BB","AB","AA")))
  ##  AB <- matrix(c(0.5, 0.5, 0,
  ##                 0.25, 0.5, 0.25,
  ##                 0, 0.5, 0.5) ,nrow=3, ncol=3,
  ##               dimnames=list(c("BB","AB","AA"),c("BB","AB","AA")))
  ##  AA <- matrix(c(0, 1, 0,
  ##                 0, 0.5, 0.5,
  ##                 0, 0, 1) ,nrow=3, ncol=3,
  ##               dimnames=list(c("BB","AB","AA"),c("BB","AB","AA")))
  ##  mendelian.probs <- abind(BB, AB, AA, along=3,
  ##                           new.names=list(c("BB.c","AB.c","AA.c"),
  ##                                          c("BB.f","AB.f","AA.f"),
  ##                                          c("BB.m","AB.m","AA.m")))
  .Deprecated("gMendelian")
  mendelian.probs <- array(dim=c(3, 3, 3))
  mendelian.probs[, 1, 1] <- c(1 - epsilon, epsilon/2, epsilon/2)
  mendelian.probs[, 2, 1] <- c(.5 - epsilon/2, .5 - epsilon/2, epsilon)
  mendelian.probs[, 3, 1] <- c(epsilon/2, 1 - epsilon, epsilon/2)
  mendelian.probs[, 1, 2] <- c(.5 - epsilon/2 , .5 - epsilon/2, epsilon)
  mendelian.probs[, 2, 2] <- c(.25, .5, .25)
  mendelian.probs[, 3, 2] <- c(epsilon, .5 - epsilon/2, .5 - epsilon/2)
  mendelian.probs[, 1, 3] <- c(epsilon/2, 1 - epsilon, epsilon/2)
  mendelian.probs[, 2, 3] <- c(epsilon, .5 - epsilon/2, .5 - epsilon/2)
  mendelian.probs[, 3, 3] <- c(epsilon/2, epsilon/2, 1 - epsilon)
  mendelian.probs
}

gg_cnp <- function(dat){
  response.df <- melt(dat$y)
  cn.df <- melt(dat$cn)
  df <- data.frame(logr=response.df$value,
                   cn=cn.df$value)
  df$cn <- as.factor(df$cn)
  p <- ggplot(df, aes(logr, ..count.., fill = cn)) + 
    geom_density(alpha = .5) + xlab("LRR")
  p
}

mcmcParams <- function(iter=1000,
                       burnin=100,
                       thin=10,
                       nstarts=50,
                       max_burnin=50000){
  list(iter=iter, burnin=burnin, thin=thin, nstarts=nstarts,
       max_burnin=max_burnin)
}


## log_ratio ~ normal(theta_z, sigma_z)
## theta ~ normal(mu, xi)
## sigma_z ~ IG(nu/2, 1/2*nu*sigma2.0)
geneticParams <- function(K=5,
                          states=0:4,
                          tau=c(1, 0.5, 0),
                          xi=c(1.5, 1, 1, 1, 1), 
                          mu=c(-3, -0.5, 0, 1, 2), ## theoretical
                          nu=1,
                          sigma2.0=0.001,
                          a=rep(1, K),
                          eta=c(0.5, 0.5),
                          error=1e-4,
                          ncp=30,
                          model="Genetic"){
  ##m.probs <- gMendelian(tau)
  list(K=K,
       states=states,
       tau=tau,
       mu=mu,
       xi=xi,
       nu=nu,
       sigma2.0=sigma2.0,
       a=a,
       eta=eta,
       error=error,
       ncp=ncp,
       ##m.probs=m.probs,
       model=model)
}

gg_chains <- function(model, expected){
  chains <- model$chains
  L <- nrow(chains$theta)
  K <- ncol(chains$theta)
  tbl <- tibble(theta0=chains$theta[, 1],
               theta1=chains$theta[, 2],
               theta2=chains$theta[, 3],
               sigma0=chains$sigma[, 1],
               sigma1=chains$sigma[, 2],
               sigma2=chains$sigma[, 3],
               pi0=chains$p[, 1],
               pi1=chains$p[, 2],
               pi2=chains$p[, 3],
               # pi.child0=gibbs.cnv$post.pi.child[, 1],
               # pi.child1=gibbs.cnv$post.pi.child[, 2],
               # pi.child2=gibbs.cnv$post.pi.child[, 3],
               loglik=chains$logll,
               iter=seq_len(L))
  tbl2 <- gather(tbl, key=parameter, value=monte_carlo, -iter)
  if(!missing(expected)){
    eparams <- expected$params
    expect <- tibble(parameter=c(paste0("theta", 0:2),
                                 paste0("sigma", 0:2),
                                 paste0("pi", 0:2),
                                 "loglik"),
                     truth=c(eparams$theta,
                             eparams$sigma,
                             eparams$p,
                             expected$loglik))
  }
  p <- ggplot(tbl2, aes(iter, monte_carlo)) +
    geom_line(color="gray") +
    facet_wrap(~parameter, scales="free_y")
  if(!missing(expected)){
    p <- p + geom_hline(data=expect, color="steelblue",
                        aes(yintercept=truth))
  }
  p
}


#' Return the maximum a posteriori copy number from a model
#'
#' @param model a fitted model
#' @return a vector of the maximum a posteriori copy number for each sample
map_cn2 <- function(model){
  freq <- model$chains$z
  iter <- model$mp$iter + 1
  prob <- freq/iter
  max.col(prob) - 1
  #list(prob = prob,
   #    cn = cn)
}


dnorm_quantiles <- function(y, mean, sd){
  x <- list()
  for(i in seq_along(mean)){
    x[[i]] <-   qnorm(c(0.001, 0.999), mean=mean[i], sd=sd[i])
  }
  xx <- unlist(x)
  limits <- c(min(y, min(xx)),
              max(y, max(xx)))
  seq(limits[1], limits[2], length.out=500)
}


.dnorm_poly <- function(x, p, mean, sd){
  y <- p*dnorm(x, mean=mean, sd=sd)
  yy <- c(y, rep(0, length(y)))
  xx <- c(x, rev(x))
  data.frame(y=yy, x=xx)
}

dnorm_model <- function(qtiles, p, mean, sd){
  df.list <- list()
  for(i in seq_along(mean)){
    dat <- .dnorm_poly(qtiles, p[i], mean[i], sd[i])
    if(i == 1){
      overall <- dat$y
    } else{
      overall <- overall + dat$y
    }
    df.list[[i]] <- dat
  }
  df <- do.call(rbind, df.list)
  L <- sapply(df.list, nrow)
  df$component <- factor(rep(seq_along(mean), L))
  df.overall <- data.frame(y=overall, x=dat$x)
  df.overall$component <- "marginal"
  df <- rbind(df, df.overall)
  df$component <- factor(df$component, levels=c("marginal", seq_along(mean)))
  if(length(mean) == 1){
    df <- df[df$component != "marginal", ]
    df$component <- factor(df$component)
  }
  df
}

dnorm_poly <- function(model){
  current <- model$current
  tbl <- current$data
  lr <- tbl$log_ratio
  mixprob <- current$p
  means <- current$theta
  sds <- current$sigma
  qtiles <- dnorm_quantiles(lr, means, sds)
  df <- dnorm_model(qtiles, mixprob, means, sds)
  df
}

gg_model <- function(model, bins){
  colors <- c("#999999", "#56B4E9", "#E69F00", "#0072B2",
              "#D55E00", "#CC79A7",  "#009E73")
  df <- model$current$data
  if(missing(bins))
    bins <- nrow(df)/2
  dat <- dnorm_poly(model) %>% as.tibble
  component <- x <- y <- ..density.. <- NULL
  ggplot(dat, aes(x, y, group=component)) +
    geom_histogram(data=df, aes(log_ratio, ..density..),
                   bins=bins,
                   inherit.aes=FALSE) +
    geom_polygon(aes(fill=component, color=component), alpha=0.4) +
    xlab("quantiles") + ylab("density") +
    scale_color_manual(values=colors) +
    scale_y_sqrt() +
    scale_fill_manual(values=colors) +
    guides(fill=guide_legend(""), color=guide_legend(""))
}

gg_truth <- function(truth, bins){
  colors <- c("#999999", "#56B4E9", "#E69F00", "#0072B2",
              "#D55E00", "#CC79A7",  "#009E73")
  df <- truth$data
  if(missing(bins))
    bins <- nrow(df)/2
  mixprob <- truth$params$p
  means <- truth$params$theta
  sds <- truth$params$sigma
  qtiles <- dnorm_quantiles(df$log_ratio,
                            means,
                            sds)
  dat <- dnorm_model(qtiles, mixprob, means, sds)
  component <- x <- y <- ..density.. <- NULL
  ggplot(dat, aes(x, y, group=component)) +
    geom_histogram(data=df, aes(log_ratio, ..density..),
                   bins=bins,
                   inherit.aes=FALSE) +
    geom_polygon(aes(fill=component, color=component), alpha=0.4) +
    xlab("quantiles") + ylab("density") +
    scale_color_manual(values=colors) +
    scale_y_sqrt() +
    scale_fill_manual(values=colors) +
    guides(fill=guide_legend(""), color=guide_legend(""))
}

startAtTrueValues <- function(truth, model){
  current <- model$current
  current$theta <- truth$params$theta
  current$sigma <- truth$params$sigma
  current$p <- truth$params$p
  current$data$copy_number <- truth$data$copy_number
  current$data$z <- truth$data$copy_number + 1
  model$current <- current
  model
}

longFormat <- function(logr, cn){
  logr <- gather(logr, key="family_member",
                 value="logr", -id)
  cn <- gather(cn, key="family_member",
               value="copy_number", -id) 
  tbl <- left_join(logr, cn, by=c("id", "family_member")) %>%
    mutate(copy_number=factor(copy_number))
  tbl
}

posterior_summary <- function(model){
  tbl <- model$current$data %>%
    mutate(copy_number=map_cn2(model)) %>%
    mutate(z=copy_number + 1)
  ch <- model$chains
  K <- model$gp$K
  params <- tibble(p=colMeans(ch$p),
                   theta=colMeans(ch$theta),
                   sigma=colMeans(ch$sigma))
  ##
  ## convergence diagnostics
  ##
  eff.p <- effectiveSize(ch$p)
  eff.theta <- effectiveSize(ch$theta)
  eff.sigma <- effectiveSize(ch$sigma)
  convergence <- tibble(param=c("p", "theta", "sigma"),
                        effective_size=c(min(eff.p),
                                         min(eff.theta),
                                         min(eff.sigma)))
  list(data=tbl, posterior_means=params, convergence=convergence)
}

current_summary <- function(model){
  current <- model$current
  tbl <- current$data
  params <- tibble(p=as.numeric(current$p),
                   theta=current$theta,
                   sigma=current$sigma)
  list(data=tbl, posterior_means=params)
}

posterior_difference <- function(model.summary, truth){
  cn.model <- model.summary$data
  cn.truth <- truth$data %>% select(c(id, family_member, copy_number)) %>%
    mutate(true_cn=copy_number) %>%
    select(id, family_member, true_cn)
  cn.delta <- left_join(cn.model, cn.truth) %>%
    filter(copy_number != true_cn)
  cn.delta$log_ratio <- round(cn.delta$log_ratio, 2)
  p <- tibble(model=model.summary$posterior_means$p,
              truth=truth$params$p)
  theta <- tibble(model=model.summary$posterior_means$theta,
                  truth=truth$params$theta)
  sigma <- tibble(model=model.summary$posterior_means$sigma,
                  truth=truth$params$sigma)
  K <- nrow(truth$params)
  params <- tibble(name=c(paste0("p_", seq_len(K)),
                          paste0("theta_", seq_len(K)),
                          paste0("sigma_", seq_len(K))),
                   model=round(c(p$model, theta$model, sigma$model), 2),
                   truth=round(c(p$truth, theta$truth, sigma$truth), 2)) %>%
    mutate(diff=abs(model-truth))
  list(cn.diff=cn.delta, params=params)
}

multipleStarts <- function(dat, nstarts, top=10, burnin=50, iter=0, thin=1){
  fitModel <- function(dat, burnin, iter, thin){
    gmodel(dat,
           mp=mcmcParams(burnin=burnin, iter=iter, thin=thin)) %>%
      gibbs_genetic
  }
  mlist <- replicate(nstarts, dat %>% fitModel(burnin, iter, thin), simplify=FALSE)
  getLL <- function(x) x$current$logll
  ll <- mlist %>% map_dbl(getLL)
  ix <- order(ll, decreasing=TRUE)
  ix <- head(ix, top)
  mlist[ix]
}

setMcmcParams <- function(model, mp){
  model$mp <- mp
  gp <- model$gp
  dat <- model$current$data
  chains <- initialize_chains(states=gp$states,
                              N=nrow(dat),
                              K=gp$K,
                              S=mp$iter + 1L)
  model$chains <- chains
  model
}

set_param_names <- function(x, nm){
  K <- seq_len(ncol(x)) - 1
  set_colnames(x, paste0(nm, K))
}

mcmcList <- function(model.list){
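  ## collect the chains from each start and convert them to a coda::mcmc.list;
  ## each run is split into a first and second half below, doubling the number
  ## of chains passed to the convergence diagnostics in gibbs()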
  ch.list <- map(model.list, "chains")
  theta.list <- map(ch.list, "theta")
  theta.list <- map(theta.list, set_param_names, "theta")
  sigma.list <- map(ch.list, "sigma")
  sigma.list <- map(sigma.list, set_param_names, "sigma")
  p.list <- map(ch.list, "p")
  p.list <- map(p.list, set_param_names, "p")
  half <- floor(nrow(theta.list[[1]])/2)
  first_half <- function(x, half){
    x[seq_len(half), ]
  }
  last_half <- function(x, half){
    i <- (half + 1):(half*2)
    x[i, ]
  }
  theta.list <- c(map(theta.list, first_half, half),
                  map(theta.list, last_half, half))
  sigma.list <- c(map(sigma.list, first_half, half),
                  map(sigma.list, last_half, half))
  p.list <- c(map(p.list, first_half, half),
              map(p.list, last_half, half))
  vars.list <- vector("list", length(p.list))
  for(i in seq_along(vars.list)){
    vars.list[[i]] <- cbind(theta.list[[i]], sigma.list[[i]], p.list[[i]])
  }
  vars.list <- map(vars.list, mcmc)
  mlist <- mcmc.list(vars.list)
  mlist
}

label_switch <- function(model){
  th <- model$current$theta
  ix <- order(th)
  if(!identical(ix, seq_along(ix))){
    model$current$theta <- th[ix]
    model$current$sigma <- model$current$sigma[ix]
    model$current$p <- model$current$p[ix]
    ch <- model$chains
    ch$theta <- ch$theta[, ix]
    ch$sigma <- ch$sigma[, ix]
    ch$p <- ch$p[, ix]
    ch$z <- ch$z[, ix]
    model$chains <- ch
  }
  model
}

gibbs_genetic2 <- function(starts, chains, dat, gp, mp){
  starts$data <- dat
  model <- list(gp=gp,
                mp=mp,
                current=starts,
                chains=chains)
  model <- gibbs_genetic(model)
  model <- label_switch(model)
  model
}

gelman_rubin <- function(mcmc_list, gp){
  #anyNA <- function(x){
  #  any(is.na(x))
 # }
 # any_nas <- map_lgl(mcmc_list, anyNA)
 # mcmc_list <- mcmc_list[ !any_nas ]
  if(length(mcmc_list) < 2 ) stop("Need at least two MCMC chains")
  r <- tryCatch(gelman.diag(mcmc_list, autoburnin=FALSE), error=function(e) NULL)
  if(is.null(r)){
    ## gelman rubin can fail if p is not positive definite
    pcolumn <- match(paste0("p", gp$K-1), colnames(mcmc_list[[1]]))
    f <- function(x, pcolumn){
      x[, -pcolumn]
    }
    mcmc_list <- map(mcmc_list, f, pcolumn) %>%
      as.mcmc.list
    r <- gelman.diag(mcmc_list, autoburnin=FALSE)
  }
  r
}

gibbs <- function(mp, gp, dat){
  chains <- initialize_chains(states=gp$states,
                              N=nrow(dat),
                              K=gp$K,
                              S=mp$iter+1L)
  max_burnin <- mp$max_burnin
  while(mp$burnin < max_burnin){
    message("burnin: ", mp$burnin)
    start.list <- replicate(mp$nstarts, .init2(gp), simplify=FALSE)
    model.list <- map(start.list, gibbs_genetic2, chains, dat, gp, mp)
    keep <- selectModels(model.list)
    model.list2 <- model.list[ keep ]
    mlist <- mcmcList(model.list2)
    neff <- effectiveSize(mlist)
    
    # removing the last column here still threw an error; instead columns are dropped only when needed (see gelman_rubin), similar to CNPBayes
    ### last column of p-matrix is not needed since p's constrained to sum to 1
    #last_col <- ncol(mlist[[1]])
    #mlist <- mlist[, -last_col]
    r <- gelman_rubin(mlist, gp)
    # r <- gelman.diag(mlist)
    message("   r: ", r$mpsrf)
    if(r$mpsrf < 2 && all(neff > 500)) break()
    mp$burnin <- mp$burnin*2
  }
  model.list
}

selectModels <- function(model.list){
  ## cluster models in two groups by mean of the log likelihood
  ## discard models that cluster in a group of low log likelihoods
  ch.list <- map(model.list, "chains")
  ll <- map(ch.list, "logll")
  mns <- map_dbl(ll, mean)
  cl <- kmeans(mns, centers=2)$cluster
  mean.ll <- sort(map_dbl(split(mns, cl), mean))
  keep <- cl == names(mean.ll)[2]
  if(sum(keep) < 3){
    ## keep all models
    keep <- rep(TRUE, length(model.list))
  }
  keep
}

diagnostics <- function(model.list){
  mlist <- mcmcList(model.list)
  neff <- effectiveSize(mlist)
  r <- gelman_rubin(mlist, model.list[[1]]$gp)
  list(neff=neff, r=r)
}

unlistModels <- function(model.list){
  ch.list <- map(model.list, "chains")
  z <- map(ch.list, "z") %>%  Reduce("+", .)
  theta <- map(ch.list, "theta") %>% do.call(rbind, .)
  sigma <- map(ch.list, "sigma") %>% do.call(rbind, .)
  logll <- map(ch.list, "logll") %>% unlist
  tau <- map(ch.list, "tau") %>% do.call(rbind, .)
  p <- map(ch.list, "p") %>% do.call(rbind, .)
  chains <- list(z=z,
                 theta=theta,
                 sigma=sigma,
                 logll=logll,
                 tau=tau,
                 p=p)
  mp <- model.list[[1]]$mp
  mp$iter <- nrow(theta)
  model <- list(gp=model.list[[1]]$gp,
                mp=mp,
                current=model.list[[1]]$current,
                chains=chains)
  model
}
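
## A usage sketch (illustration only, not called by the package): run several
## independently initialised chains, pool the retained chains, and summarise.
## Assumes `dat` has columns id, family_member and log_ratio and that the
## mendelian_probs2.rds lookup table is available in extdata.
fit_multistart_sketch <- function(dat){
  gp <- geneticParams(K=3, states=0:2,
                      mu=c(-3, -0.5, 0), xi=c(1.5, 1, 1))
  mp <- mcmcParams(iter=1000, burnin=200, nstarts=4)
  model.list <- gibbs(mp, gp, dat)
  fit <- unlistModels(model.list)
  list(map_copy_number=map_cn2(fit),   # maximum a posteriori copy number per sample
       density_plot=gg_model(fit))     # fitted component densities over the data
}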


# function to check if the real parameter is captured in the credible interval of the posterior chains
credible.interval <- function (gibbs.model, gibbs.result, real.params) {
  
  ci.mat <- matrix(0, ncol=nrow(real.params)*(ncol(real.params)+1), nrow=1)
  mixture.ci.parents <- data.frame(HPDinterval(mcmc(gibbs.model$post.pi), 0.95))
  mixture.ci.offspring <- data.frame(HPDinterval(mcmc(gibbs.model$post.pi.child), 0.95))
  theta.ci <- data.frame(HPDinterval(mcmc(gibbs.model$post.thetas), 0.95))
  sigma.ci <- data.frame(HPDinterval(mcmc(gibbs.model$post.sigmas), 0.95))
  
  #w <- real.params[,1]
  #x <- real.params[,1]
  #y <- real.params[,2]
  #z <- real.params[,3]
  
  w <- gibbs.result$mixture.prob.tb2$empirical.parents
  x <- gibbs.result$mixture.prob.tb2$empirical.offspring
  y <- gibbs.result$mixture.prob.tb2$empirical.thetas
  z <- (gibbs.result$mixture.prob.tb2$empirical.sigmas)#^0.5
  
  mix.parents.hit <- mixture.ci.parents %>%
    filter(lower<=w, w<=upper)
  mix.par <- mixture.ci.parents[,1] %in% mix.parents.hit[,1]
  
  mix.offspring.hit <- mixture.ci.offspring %>%
    filter(lower<=x, x<=upper)
  mix.off <- mixture.ci.offspring[,1] %in% mix.offspring.hit[,1]
  
  theta.hit <- theta.ci %>%
    filter(lower<=y, y<=upper)
  theta <- theta.ci[,1] %in% theta.hit[,1]
  
  sigma.hit <- sigma.ci %>%
    filter(lower<=z, z<=upper)
  sigma <- sigma.ci[,1] %in% sigma.hit[,1]
  
  col <- nrow(real.params)  
  ci.mat[1,1:col] <- mix.par
  ci.mat[1,(col+1) : (2*col)] <- mix.off
  ci.mat[1,(2*col+1) : (3*col)] <- theta
  ci.mat[1,(3*col+1) : (4*col)] <- sigma
  
  return (ci.mat)
}

credible.interval.quantile <- function (gibbs.model, gibbs.result, real.params) {
  
  ci.mat <- matrix(0, ncol=nrow(real.params)*(ncol(real.params)+1), nrow=1)
  mixture.ci.parents <- apply(gibbs.model$post.pi, 2, quantile, probs = c(0.025, 0.975))
  mixture.ci.offspring <- apply(gibbs.model$post.pi.child, 2, quantile, probs = c(0.025, 0.975))
  theta.ci <- apply(gibbs.model$post.thetas, 2, quantile, probs = c(0.025, 0.975))
  sigma.ci <- apply(gibbs.model$post.sigmas, 2, quantile, probs = c(0.025, 0.975))
  
  # reshape
  mixture.ci.parents <- data.frame(t(mixture.ci.parents))
  colnames(mixture.ci.parents) <- c("lower", "upper")
  mixture.ci.offspring <- data.frame(t(mixture.ci.offspring))
  colnames(mixture.ci.offspring) <- c("lower", "upper")
  theta.ci <- data.frame(t(theta.ci))
  colnames(theta.ci) <- c("lower", "upper")
  sigma.ci <- data.frame(t(sigma.ci))
  colnames(sigma.ci) <- c("lower", "upper")
  
  w <- gibbs.result$mixture.prob.tb2$empirical.parents
  x <- gibbs.result$mixture.prob.tb2$empirical.offspring
  y <- gibbs.result$mixture.prob.tb2$empirical.thetas
  z <- (gibbs.result$mixture.prob.tb2$empirical.sigmas)#^0.5
  
  mix.parents.hit <- mixture.ci.parents %>%
    filter(lower<=w, w<=upper)
  mix.par <- mixture.ci.parents[,1] %in% mix.parents.hit[,1]
  
  mix.offspring.hit <- mixture.ci.offspring %>%
    filter(lower<=x, x<=upper)
  mix.off <- mixture.ci.offspring[,1] %in% mix.offspring.hit[,1]
  
  theta.hit <- theta.ci %>%
    filter(lower<=y, y<=upper)
  theta <- theta.ci[,1] %in% theta.hit[,1]
  
  sigma.hit <- sigma.ci %>%
    filter(lower<=z, z<=upper)
  sigma <- sigma.ci[,1] %in% sigma.hit[,1]
  
  col <- nrow(real.params)  
  ci.mat[1,1:col] <- mix.par
  ci.mat[1,(col+1) : (2*col)] <- mix.off
  ci.mat[1,(2*col+1) : (3*col)] <- theta
  ci.mat[1,(3*col+1) : (4*col)] <- sigma
  
  return (ci.mat)
}

init.pi.child <- function(dat.y, cn.mf, params, theta){
  N.half <- params$N / 2
  yy <- as.numeric(dat.y[, c("o")])
  cn.mf.mat <- as.matrix(cn.mf)
  cn.mf.mat <- cn.mf.mat[1:N.half,]
  cn <- as.numeric(cn.mf.mat[, 1:2])
  tab <- tapply(yy, cn, length)
  freq <- setNames(rep(0L, length(theta)),
                   attributes(tab)$dimnames[[1]])
  freq[names(tab)] <- tab
  freq <- freq / params$N
}

ess.compile <- function(gibbs.result, real.params){
  ess.mat <- matrix(0, ncol=nrow(real.params)*(ncol(real.params)+1), nrow=1)
  
  col <- nrow(real.params)  
  ess.mat[1,1:col] <- gibbs.result$effective.size.pi
  ess.mat[1,(col+1) : (2*col)] <- gibbs.result$effective.size.pi.child
  ess.mat[1,(2*col+1) : (3*col)] <- gibbs.result$effective.size.thetas
  ess.mat[1,(3*col+1) : (4*col)] <- gibbs.result$effective.size.sigmas
  
  return(ess.mat)
}