# R/ld_prune_pval_format3.R

#'@title Rank SNPs by pval, prune for LD
#'@description Selects a set of top SNPs, ranked by p-value, that are in
#'minimal LD with each other. Performs the same function as
#'top_snps_pval_ld but with a different input format.
#'
#'For each file pair (\code{ld_files[i]}, \code{snp_info_files[i]}) and each
#'p-value column, SNPs below the threshold are visited in increasing p-value
#'order; each visited SNP is kept and every remaining SNP with
#'\code{r2 >= r2_thresh} to it is discarded. Files are processed in parallel
#'on a PSOCK cluster.
#'@param ld_files Vector of paths to RDS files, one for each chromosome, in
#'the same order as \code{snp_info_files}. Each file holds a table of
#'pairwise LD with columns \code{rowsnp_col}, \code{colsnp_col} and
#'\code{r2}.
#'@param snp_info_files Vector of paths to RDS files with snp info, one for
#'each chromosome. Each holds an object whose \code{snp_info_snp_col}
#'element lists the SNPs present in the LD data for that file.
#'@param snps Data frame of SNPs to prune. Must contain the column
#'\code{data_snp_col} and every non-NA entry of \code{pval_cols}.
#'@param pval_cols Names of p-value columns in \code{snps} (optional). Each
#'column is pruned separately. An \code{NA} entry ranks SNPs in a random
#'order instead of by p-value.
#'@param data_snp_col Name of the SNP ID column in \code{snps}.
#'@param snp_info_snp_col Name of the SNP ID element in each snp info object.
#'@param rowsnp_col,colsnp_col Names of the two SNP ID columns in the LD
#'tables.
#'@param pval_thresh Maximum p-value (strict: SNPs need \code{p < pval_thresh}),
#'one per entry of \code{pval_cols}. Must not be finite where
#'\code{pval_cols} is \code{NA}.
#'@param r2_thresh r^2 threshold for pruning: pairs with
#'\code{r2 >= r2_thresh} are treated as being in LD.
#'@param cores Number of worker processes to use. Missing or values below 1
#'fall back to 1, and no more workers than files are started.
#'@return A list with one element per entry of \code{pval_cols}; each element
#'is a flattened list of the SNP IDs retained across all files.
#'@export
ld_prune_pval_format3 <- function(ld_files, snp_info_files,
                                  snps, pval_cols = NA, data_snp_col = "SNP",
                                  snp_info_snp_col = "SNP",
                                  rowsnp_col = "rowsnp",
                                  colsnp_col = "colsnp",
                                  pval_thresh = Inf, r2_thresh = 0.01,
                                  cores = parallel::detectCores() - 1) {

  stopifnot(length(pval_thresh) == length(pval_cols))
  if (any(is.finite(pval_thresh) & is.na(pval_cols))) {
    stop("pval_thresh is given but pval_cols is missing.\n")
  }

  nfiles <- length(snp_info_files)
  stopifnot(length(ld_files) == nfiles)

  # detectCores() may return 1 (default becomes 0) or NA; makeCluster needs at
  # least one worker, and workers beyond one per file would sit idle.
  if (length(cores) != 1 || is.na(cores) || cores < 1) {
    cores <- 1L
  }
  cores <- min(cores, max(nfiles, 1L))

  p1 <- nrow(snps)
  stopifnot(data_snp_col %in% names(snps))

  # NA p-value columns are replaced by a random permutation scaled into
  # (0, 1] with an Inf threshold, so those SNPs are visited in random order.
  if (any(is.na(pval_cols))) {
    na_col_names <- paste0("pNA_", seq_len(sum(is.na(pval_cols))))
    pval_thresh[is.na(pval_cols)] <- Inf
    pval_cols[is.na(pval_cols)] <- na_col_names
    for (n in na_col_names) {
      snps[[n]] <- sample(seq_len(p1), size = p1, replace = FALSE) / p1
    }
  }
  # Fail early rather than silently ranking a non-existent column.
  stopifnot(all(pval_cols %in% names(snps)))
  cat("You have provided information for ", p1, " variants.\n")

  s_in_d <- map(snp_info_files, function(x) {
    dat <- readRDS(x)
    dat[[snp_info_snp_col]]
  })
  snps_in_ld_data <- unlist(s_in_d)

  cat("There are ", length(snps_in_ld_data), " variants in the LD data.\n")
  snps <- snps %>%
    filter(.data[[data_snp_col]] %in% snps_in_ld_data)
  cat("Of the provided variants, ", nrow(snps), " are in the LD data.\n")

  cl <- makeCluster(cores, type = "PSOCK")
  # Release the workers even if a worker or readRDS() errors.
  on.exit(stopCluster(cl), add = TRUE)
  clusterExport(cl, varlist = c("snps", "data_snp_col", "rowsnp_col",
                                "colsnp_col", "pval_cols", "ld_files",
                                "s_in_d", "r2_thresh", "pval_thresh"),
                envir = environment())
  clusterEvalQ(cl, library(dplyr))
  keep <- parLapply(cl, seq_along(ld_files), fun = function(i) {
    # SNPs belonging to this file, i.e. listed in the matching snp info file.
    my_snps <- filter(snps, .data[[data_snp_col]] %in% s_in_d[[i]])
    if (nrow(my_snps) == 0) return(c())
    my_snp_ids <- my_snps[[data_snp_col]]

    # Only pairs between candidate SNPs that are in LD matter for pruning.
    ld <- readRDS(ld_files[i]) %>%
      filter(.data[[colsnp_col]] %in% my_snp_ids &
               .data[[rowsnp_col]] %in% my_snp_ids) %>%
      filter(r2 >= r2_thresh)

    # For each p-value column: row indices of SNPs below the threshold, most
    # significant first (NA p-values sort last and are not counted).
    o_c <- lapply(seq_along(pval_cols), function(k) {
      p <- my_snps[[pval_cols[k]]]
      ct <- sum(p < pval_thresh[k], na.rm = TRUE)
      # seq_len(), not seq(): seq(0) is c(1, 0), which would keep the top SNP
      # even when none passes the threshold.
      order(p, decreasing = FALSE)[seq_len(ct)]
    })

    # Greedy pruning: keep the best remaining SNP, then drop every remaining
    # SNP that shares an LD row with it.
    k_c <- vector("list", length(pval_cols))
    for (j in seq_along(pval_cols)) {
      remaining <- o_c[[j]]
      while (length(remaining) > 0) {
        snp <- my_snp_ids[remaining[1]]
        k_c[[j]] <- c(k_c[[j]], snp)
        remaining <- remaining[-1]
        if (length(remaining) == 0) break
        myld <- filter(ld, .data[[rowsnp_col]] == snp |
                         .data[[colsnp_col]] == snp)
        if (nrow(myld) > 0) {
          partners <- unique(c(myld[[rowsnp_col]], myld[[colsnp_col]]))
          remaining <- remaining[!my_snp_ids[remaining] %in% partners]
        }
      }
    }
    k_c
  })

  # Regroup per-file results (keep[[file]][[column]]) into one flattened
  # list of retained SNP IDs per p-value column.
  map(seq_along(pval_cols), function(x) {
    map(keep, x) %>% flatten()
  })
}
# jean997/sherlockAsh documentation built on May 18, 2019, 11:45 p.m.