R/ComBat.R

#' Adjust for batch effects using an empirical Bayes framework
#'
#' ComBat allows users to adjust for batch effects in datasets where the batch
#' covariate is known, using methodology described in Johnson et al. 2007. It
#' uses either parametric or non-parametric empirical Bayes frameworks for
#' adjusting data for batch effects. Users are returned an expression matrix
#' that has been corrected for batch effects. The input data are assumed to be
#' cleaned and normalized before batch effect removal.
#'
#' Cpp rewrite, and we finally get a 100x speedup
#' (for 485512 x 6 expression array matrix = 9622.983 secs = 2 hrs 40 mins).
#' The original source code of the pure R version is from the minif package.
#'
#' @details
#' Progress messages are written to the console with \code{cat}. The function
#' stops with an error when \code{batch} has more than one dimension, contains
#' missing values, does not have one entry per column of \code{dat}, defines
#' fewer than two batches, contains a batch with fewer than two samples, or
#' when the covariates in \code{mod} are confounded with each other or with
#' \code{batch}. Unused factor levels of \code{batch} are dropped. When
#' \code{prior.plots} is used, the graphical parameters changed for the
#' 2 x 2 plot layout are restored when the function exits.
#'
#' @param dat Genomic measure matrix (dimensions probe x sample) - for example, expression matrix
#' @param batch Batch covariate (only one batch allowed)
#' @param mod Model matrix for outcome of interest and other covariates besides batch
#' @param par.prior (Optional) TRUE indicates parametric adjustments will be used, FALSE indicates non-parametric adjustments will be used
#' @param prior.plots (Optional) TRUE give prior plots with black as a kernel estimate of the empirical batch effect density and red as the parametric estimate
#' @return A probe x sample genomic measure matrix, adjusted for batch effects.
#'         If the adjusted matrix contains missing values, the rows selected by
#'         \code{Index.NA} are copied back unchanged from \code{dat}.
#' @importFrom genefilter rowVars
#' @export
#' @author Xin Zhou \email{xinchoubiology@@gmail.com}

ComBat <- function(dat, batch, mod = NULL, par.prior = TRUE, prior.plots = FALSE) {
  # ---- Validate the batch variable -----------------------------------------
  if (length(dim(batch)) > 1) {
    stop("This version of ComBat only allows one batch variable")  ## to be updated soon!
  }
  if (anyNA(batch)) {
    stop("batch must not contain missing values")
  }
  if (length(batch) != NCOL(dat)) {
    stop("length of batch must equal the number of samples (columns) of dat")
  }

  # make batch a factor and make a set of indicators for batch;
  # unused levels are dropped because they would yield all-zero indicator
  # columns and a singular design
  batch <- droplevels(as.factor(batch))
  batchmod <- model.matrix(~ -1 + batch)
  cat("Found", nlevels(batch), 'batches\n')

  # A few other characteristics on the batches
  n.batch <- nlevels(batch)
  if (n.batch < 2) {
    stop("ComBat requires at least two batches")
  }
  # list of sample (column) indices in each batch
  batches <- lapply(levels(batch), function(lev) which(batch == lev))
  n.batches <- vapply(batches, length, integer(1))
  n.array <- sum(n.batches)
  if (any(n.batches < 2)) {
    # a within-batch variance cannot be computed from a single sample
    stop("ComBat requires at least two samples in every batch")
  }
  batch.cols <- seq_len(n.batch)

  # combine batch variable and covariates
  design <- cbind(batchmod, mod)

  # check for intercept in covariates, and drop if present
  check <- apply(design, 2, function(x) all(x == 1))
  design <- as.matrix(design[, !check])

  # Number of covariates or covariate levels
  cat("Adjusting for", ncol(design) - ncol(batchmod), 'covariate(s) or covariate level(s)\n')

  # Check if the design is confounded
  if (qr(design)$rank < ncol(design)) {
    if (ncol(design) == (n.batch + 1)) {
      stop("The covariate is confounded with batch! Remove the covariate and rerun ComBat")
    }
    if (ncol(design) > (n.batch + 1)) {
      covariates <- design[, -batch.cols]
      if (qr(covariates)$rank < ncol(covariates)) {
        stop('The covariates are confounded! Please remove one or more of the covariates so the design is not confounded')
      } else {
        stop("At least one covariate is confounded with batch! Please remove confounded covariates and rerun ComBat")
      }
    }
  }

  ## Check for missing values
  has.na <- any(is.na(dat))
  if (has.na) {
    cat(c('Found', sum(is.na(dat)), 'Missing Data Values\n'), sep = ' ')
  }

  ## Standardize Data across genes
  cat('Standardizing Data across genes\n')
  # Standardization model: per-probe least-squares coefficients on the design.
  # With missing data the fit is delegated to Beta.NA (defined elsewhere).
  if (!has.na) {
    B.hat <- solve(t(design) %*% design) %*% t(design) %*% t(as.matrix(dat))
  } else {
    B.hat <- apply(dat, 1, Beta.NA, design)
  }
  # batch-size weighted mean of the batch coefficients, one value per probe
  grand.mean <- t(n.batches / n.array) %*% B.hat[batch.cols, , drop = FALSE]
  resid <- dat - t(design %*% B.hat)
  if (!has.na) {
    var.pooled <- (resid^2) %*% rep(1 / n.array, n.array)
  } else {
    var.pooled <- apply(resid, 1, var, na.rm = TRUE)
  }

  # stand.mean = grand mean plus the covariate part of the fit (batch columns
  # of the design are zeroed so batch effects stay in the standardized data)
  stand.mean <- t(grand.mean) %*% t(rep(1, n.array))
  covariate.design <- design
  covariate.design[, batch.cols] <- 0
  stand.mean <- stand.mean + t(covariate.design %*% B.hat)
  # pooled standard deviation replicated across samples (probe x sample)
  sd.mat <- sqrt(var.pooled) %*% t(rep(1, n.array))
  s.data <- (dat - stand.mean) / sd.mat

  ## Get regression batch effect parameters
  cat("Fitting L/S model and finding priors\n")
  batch.design <- design[, batch.cols]
  if (!has.na) {
    gamma.hat <- solve(t(batch.design) %*% batch.design) %*% t(batch.design) %*% t(as.matrix(s.data))
  } else {
    gamma.hat <- apply(s.data, 1, Beta.NA, batch.design)
  }
  # NOTE(review): na.omit() removes probes whose within-batch variance is NA,
  # so with missing data a row of delta.hat can be shorter than a row of
  # gamma.hat -- confirm that it.sol / int.eprior expect this.
  delta.hat <- NULL
  for (idx in batches) {
    delta.hat <- rbind(delta.hat, na.omit(rowVars(s.data[, idx])))
  }

  ## Find Priors (aprior / bprior are defined elsewhere)
  gamma.bar <- rowMeans(gamma.hat)
  t2 <- rowVars(gamma.hat)
  a.prior <- apply(delta.hat, 1, aprior)
  b.prior <- apply(delta.hat, 1, bprior)

  ## Plot empirical and parametric priors (first batch only)
  if (prior.plots && par.prior) {
    # restore the caller's graphics layout when the function exits
    old.par <- par(mfrow = c(2, 2))
    on.exit(par(old.par), add = TRUE)

    tmp <- density(gamma.hat[1, ])
    plot(tmp, type = 'l', main = "Density Plot")
    xx <- seq(min(tmp$x), max(tmp$x), length.out = 100)
    lines(xx, dnorm(xx, gamma.bar[1], sqrt(t2[1])), col = 2)
    qqnorm(gamma.hat[1, ])
    qqline(gamma.hat[1, ], col = 2)

    tmp <- density(delta.hat[1, ])
    invgam <- 1 / rgamma(ncol(delta.hat), a.prior[1], b.prior[1])
    tmp1 <- density(invgam)
    plot(tmp, type = 'l', main = "Density Plot", ylim = c(0, max(tmp$y, tmp1$y)))
    lines(tmp1, col = 2)
    qqplot(delta.hat[1, ], invgam, xlab = "Sample Quantiles", ylab = 'Theoretical Quantiles')
    lines(c(0, max(invgam)), c(0, max(invgam)), col = 2)
    title('Q-Q Plot')
  }

  ## Find EB batch adjustments
  # it.sol / int.eprior are defined elsewhere; row 1 of their result is used
  # below as the location adjustment and row 2 as the scale adjustment.
  gamma.star <- delta.star <- NULL
  if (par.prior) {
    cat("Finding parametric adjustments\n")
    for (i in batch.cols) {
      temp <- it.sol(s.data[, batches[[i]]], gamma.hat[i, ],
                     delta.hat[i, ], gamma.bar[i], t2[i], a.prior[i], b.prior[i])
      gamma.star <- rbind(gamma.star, temp[1, ])
      delta.star <- rbind(delta.star, temp[2, ])
    }
  } else {
    cat("Finding nonparametric adjustments\n")
    for (i in batch.cols) {
      temp <- int.eprior(as.matrix(s.data[, batches[[i]]]), gamma.hat[i, ], delta.hat[i, ])
      gamma.star <- rbind(gamma.star, temp[1, ])
      delta.star <- rbind(delta.star, temp[2, ])
    }
  }

  ### Normalize the Data ###
  cat("Adjusting the Data\n")

  # remove the adjusted batch location and scale from each batch's samples
  bayesdata <- s.data
  for (j in batch.cols) {
    idx <- batches[[j]]
    bayesdata[, idx] <- (bayesdata[, idx] - t(batch.design[idx, ] %*% gamma.star)) /
      (sqrt(delta.star[j, ]) %*% t(rep(1, n.batches[j])))
  }

  # undo the standardization
  bayesdata <- (bayesdata * sd.mat) + stand.mean

  # rows that still contain missing values fall back to the unadjusted input
  if (any(is.na(bayesdata))) {
    na.rows <- Index.NA(bayesdata, by = "row")
    bayesdata[na.rows, ] <- dat[na.rows, ]
  }

  bayesdata
}
xinchoubiology/Rcppsva documentation built on May 4, 2019, 1:06 p.m.