R/fstats.R

Defines functions indpairs_to_f2blocks xmats_to_pairarrs mats_to_fstarr mats_to_ctarr mats_to_aparr mats_to_f2arr make_chunks afs_to_f2_blocks

Documented in afs_to_f2_blocks

#' Compute all pairwise f2 statistics
#'
#' This function takes allele frequency data and computes blocked f2 statistics for all population pairs,
#' which are written to `outdir`. \eqn{f2} for each SNP is computed as
#' \eqn{(p1 - p2)^2 - p1 (1 - p1)/(n1 - 1) - p2 (1 - p2)/(n2 - 1)}, where \eqn{p1} and \eqn{p2} are
#' allele frequencies in populations \eqn{1} and \eqn{2}, and \eqn{n1} and \eqn{n2} are the numbers of
#' non-missing haplotypes in populations \eqn{1} and \eqn{2}. See \code{details}
#' @param afdat A list with three items with the same SNP in each row (generated by \code{\link{packedancestrymap_to_afs}} or \code{\link{plink_to_afs}})
#' \itemize{
#' \item `afs` A matrix of allele frequencies for all populations (columns) and SNPs (rows)
#' \item `counts` A matrix of allele counts for all populations (columns) and SNPs (rows)
#' \item `snpfile` A data frame with SNP metadata. Its `poly` column selects the SNPs used for the statistics listed in `poly_only`
#' }
#' @param maxmem split up allele frequency data into blocks, if memory requirements exceed `maxmem` MB.
#' @param blgsize SNP block size in Morgan. Default is 0.05 (5 cM). If `blgsize` is 100 or greater, it will be interpreted as base pair distance rather than centimorgan distance.
#' @param pops1 `pops1` and `pops2` can be specified if only a subset of pairs should be computed.
#' @param pops2 `pops1` and `pops2` can be specified if only a subset of pairs should be computed.
#' @param outpop If specified, f2-statistics will be weighted by heterozygosity in this population
#' @param outdir Directory into which to write f2 data (if `NULL`, data is returned instead)
#' @param overwrite Should existing files be overwritten? Only relevant if `outdir` is not `NULL`
#' @param afprod Also compute blocked allele frequency products
#' @param fst Also compute blocked fst estimates
#' @param poly_only Exclude sites with identical allele frequencies in all populations. Can be different for f2-statistics, allele frequency products, and fst. Should be a character vector of length three, with some subset of `c("f2", "ap", "fst")`
#' @param apply_corr Passed on to the f2 computation. If `FALSE`, the low sample size correction term is not subtracted.
#' @param n_cores Number of cores. If greater than one, population pair blocks are processed in parallel (`doParallel`/`foreach`), and each block is sized for `maxmem / n_cores` MB.
#' @param verbose Print progress updates
#' @return If `outdir` is `NULL`, the result of `namedList(f2_blocks, ap_blocks, fst_blocks)`, where `ap_blocks` and
#' `fst_blocks` are `NULL` unless `afprod` and `fst` are set. Otherwise the data is written to `outdir`.
#' @details For each population pair, each of the \eqn{i = 1, \ldots, n} resulting values
#' (\eqn{n} is around 700 in practice) is the mean \eqn{f2} estimate across all SNPs except the ones in block \eqn{i}.
#'
#' \eqn{- p1 (1 - p1)/(n1 - 1) - p2 (1 - p2)/(n2 - 1)} is a correction term which makes the estimates
#' unbiased at low sample sizes.
#' @examples
#' \dontrun{
#' afdat = plink_to_afs('/my/geno/prefix')
#' afs_to_f2_blocks(afdat, outdir = '/my/f2/data/')
#' }
afs_to_f2_blocks = function(afdat, maxmem = 8000, blgsize = 0.05,
                            pops1 = NULL, pops2 = NULL, outpop = NULL, outdir = NULL,
                            overwrite = FALSE, afprod = TRUE, fst = TRUE, poly_only = c('f2'),
                            apply_corr = TRUE, n_cores = 1, verbose = TRUE) {

  # splits afmat into blocks by column, computes snp blocks on each pair of population blocks,
  #   and combines into 3d array

  # inside this function `[` never drops dimensions, so single-column subsets stay matrices
  old = `[`; `[` = function(...) old(..., drop = FALSE)
  afmat = afdat$afs
  countmat = afdat$counts
  mem = lobstr::obj_size(afmat)
  pops = colnames(afmat)
  if(is.null(pops1) || is.null(pops2) || isTRUE(all.equal(pops, pops1)) && isTRUE(all.equal(pops1, pops2))) {
    pops1 = pops2 = pops
    square = TRUE
  } else square = FALSE

  if(verbose) alert_info(paste0('Allele frequency matrix for ', nrow(afmat), ' SNPs and ',
                         length(pops), ' populations is ', round(mem/1e6), ' MB\n'))

  chunks = make_chunks(pops, mem, maxmem/n_cores, pops1, pops2, verbose = verbose)
  popvecs1 = chunks$popvecs1
  popvecs2 = chunks$popvecs2

  # sp selects the SNPs flagged in snpfile$poly, sn selects all SNPs
  sp = afdat$snpfile$poly
  sn = seq_len(nrow(afdat$snpfile))

  block_lengths_p = get_block_lengths(afdat$snpfile[sp,], blgsize = blgsize)
  block_lengths_n = get_block_lengths(afdat$snpfile[sn,], blgsize = blgsize)

  if(!is.null(outpop)) {
    p = afmat[,outpop]
    snpwt = 1/(p*(1-p))
  } else snpwt = NULL

  if(is.null(outdir)) {
    dim_p = c(length(pops1), length(pops2), length(block_lengths_p))
    dim_n = c(length(pops1), length(pops2), length(block_lengths_n))
    nam_p = list(pops1, pops2, paste0('l', block_lengths_p))
    nam_n = list(pops1, pops2, paste0('l', block_lengths_n))
    f2_blocks = if('f2' %in% poly_only) array(NA, dim_p, nam_p) else array(NA, dim_n, nam_n)
    ap_blocks = fst_blocks = NULL
    if(afprod) ap_blocks = if('ap' %in% poly_only) array(NA, dim_p, nam_p) else array(NA, dim_n, nam_n)
    if(fst) fst_blocks = if('fst' %in% poly_only) array(NA, dim_p, nam_p) else array(NA, dim_n, nam_n)
  }

  # SNP selection and matching block lengths for each of the three statistics
  if('f2' %in% poly_only) {
    sf2 = sp
    block_lengths_f2 = block_lengths_p
  } else {
    sf2 = sn
    block_lengths_f2 = block_lengths_n
  }
  if('ap' %in% poly_only) {
    sap = sp
    block_lengths_ap = block_lengths_p
  } else {
    sap = sn
    block_lengths_ap = block_lengths_n
  }
  if('fst' %in% poly_only) {
    sfst = sp
    block_lengths_fst = block_lengths_p
  } else {
    sfst = sn
    block_lengths_fst = block_lengths_n
  }

  # SNP weights have to cover exactly the SNPs of the matrices they are applied to
  snpwt_f2 = if(is.null(snpwt)) NULL else as.vector(snpwt)[sf2]
  snpwt_fst = if(is.null(snpwt)) NULL else as.vector(snpwt)[sfst]

  # NOTE(review): with n_cores > 1 the loop body runs through %dopar%. Assignments to
  # f2_blocks/ap_blocks/fst_blocks made in worker processes are presumably not visible here,
  # so n_cores > 1 together with outdir = NULL may return arrays which are still NA - confirm.
  if(n_cores > 1) {
    doParallel::registerDoParallel(n_cores)
    `%do%` = foreach::`%dopar%`
  } else {
    `%do%` = foreach::`%do%`
  }
  #for(i in 1:length(popvecs1)) {
  foreach::foreach(i=seq_along(popvecs1)) %do% {
    if(length(popvecs1) > 1 && verbose) cat(paste0('\rpop pair block ', i, ' out of ', length(popvecs1)))
    s1 = popvecs1[[i]]
    s2 = popvecs2[[i]]
    am1 = afmat[, s1]
    am2 = afmat[, s2]
    cm1 = countmat[, s1]
    cm2 = countmat[, s2]

    f2 = mats_to_f2arr(am1[sf2,], am2[sf2,], cm1[sf2,], cm2[sf2,], block_lengths_f2, snpwt_f2, apply_corr = apply_corr)
    counts = mats_to_ctarr(am1[sf2,], am2[sf2,], cm1[sf2,], cm2[sf2,], block_lengths_f2)
    # f2 of a population with itself is zero by definition
    if(isTRUE(all.equal(s1, s2))) for(j in seq_len(dim(f2)[1])) f2[j, j, ] = 0

    if(afprod) {
      aparr = mats_to_aparr(am1[sap,], am2[sap,], cm1[sap,], cm2[sap,], block_lengths_ap)
      # reuse the count array if it is based on the same SNPs
      if(isTRUE(all.equal(sap, sf2))) {
        countsap = counts
      } else countsap = mats_to_ctarr(am1[sap,], am2[sap,], cm1[sap,], cm2[sap,], block_lengths_ap)
    }
    if(fst) {
      fstarr = mats_to_fstarr(am1[sfst,], am2[sfst,], cm1[sfst,], cm2[sfst,], block_lengths_fst, snpwt_fst)
      if(isTRUE(all.equal(sfst, sf2))) {
        countsfst = counts
      } else if(isTRUE(all.equal(sfst, sap)) && afprod) {
        countsfst = countsap
      } else countsfst = mats_to_ctarr(am1[sfst,], am2[sfst,], cm1[sfst,], cm2[sfst,], block_lengths_fst)
      if(isTRUE(all.equal(s1, s2))) for(j in seq_len(dim(fstarr)[1])) fstarr[j, j, ] = 0
    }
    rm(am1, am2, cm1, cm2); gc()


    if(!is.null(outdir)) {
      if(!isTRUE(all.equal(dim(f2), dim(counts)))) stop('Dimensions of f2 and count arrays differ!')
      write_f2(f2, counts, outdir = outdir, id = 'f2', overwrite = overwrite)
      bl = paste0(outdir, '/block_lengths_f2.rds')
      if(!file.exists(bl) || overwrite) saveRDS(block_lengths_f2, file = bl)

      if(afprod) {
        write_f2(aparr, countsap, outdir = outdir, id = 'ap', overwrite = overwrite)
        bl = paste0(outdir, '/block_lengths_ap.rds')
        if(!file.exists(bl) || overwrite) saveRDS(block_lengths_ap, file = bl)
      }
      if(fst) {
        write_f2(fstarr, countsfst, outdir = outdir, id = 'fst', overwrite = overwrite)
        bl = paste0(outdir, '/block_lengths_fst.rds')
        if(!file.exists(bl) || overwrite) saveRDS(block_lengths_fst, file = bl)
      }

    } else {
      f2_blocks[s1, s2, ] = f2
      if(afprod) ap_blocks[s1, s2, ] = aparr
      if(fst) fst_blocks[s1, s2, ] = fstarr
      # in the square case only one of (s1, s2) and (s2, s1) is computed; mirror it
      if(square && !isTRUE(all.equal(s1, s2))) {
        f2_blocks[s2, s1, ] = aperm(f2, c(2,1,3))
        if(afprod) ap_blocks[s2, s1, ] = aperm(aparr, c(2,1,3))
        if(fst) fst_blocks[s2, s1, ] = aperm(fstarr, c(2,1,3))
      }
    }
    rm(counts, f2); gc()
  }
  if(length(popvecs1) > 1 && verbose) cat('\n')
  if(is.null(outdir)) namedList(f2_blocks, ap_blocks, fst_blocks)
}


make_chunks = function(pops, mem, maxmem, pops1 = pops, pops2 = pops, verbose = TRUE) {
  # determines start and end positions of each chunk
  # pops: all population labels; mem: size of the allele frequency matrix in bytes;
  # maxmem: memory limit in MB; pops1, pops2: populations for which pairs are requested
  # returns two parallel lists (popvecs1, popvecs2); element i of each holds the
  # populations of the i-th chunk pair

  if(isTRUE(all.equal(pops1, pops2))) {
    # square case: split pops into chunks and form all chunk pairs
    stopifnot(isTRUE(all.equal(pops, pops1)))
    npops = length(pops)
    mem2 = mem*npops*2
    numsplits = ceiling(mem2/1e6/maxmem)
    width = ceiling(npops/numsplits)
    starts = seq(1, npops, width)
    numsplits2 = length(starts)
    ends = c(lead(starts)[-numsplits2]-1, npops)
    # all chunk index pairs (i, j) with i <= j, including each chunk paired with itself
    cmb = combn(0:numsplits2, 2)+(1:0)

    popvecs1 = map(seq_len(ncol(cmb)), ~pops[starts[cmb[1,.]]:ends[cmb[1,.]]])
    popvecs2 = map(seq_len(ncol(cmb)), ~pops[starts[cmb[2,.]]:ends[cmb[2,.]]])

    npair = choose(numsplits2+1,2)

  } else {

    # non-square case: keep pops1 whole and split only pops2 into chunks
    stopifnot(all(pops1 %in% pops) && all(pops2 %in% pops))
    ntot = length(pops)
    n1 = length(pops1)
    n2 = length(pops2)
    mem2 = mem*n1*n2*2/ntot
    numsplits = ceiling(mem2/1e6/maxmem)
    popind2 = match(pops2, pops)

    width = ceiling(n2/numsplits)
    starts2 = seq(1, n2, width)
    numsplits2 = npair = length(starts2)
    ends2 = c(lead(starts2)[-numsplits2]-1, n2)

    # one copy of pops1 per chunk of pops2 (base R replacement for the deprecated purrr::rerun)
    popvecs1 = rep(list(pops1), numsplits2)
    popvecs2 = map(seq_len(numsplits2), ~pops2[starts2[.]:ends2[.]])
  }
  if(verbose) {
    reqmem = round(mem2/1e6)
    alert_info(paste0('Computing pairwise f2 for all SNPs and population pairs requires ',
                      reqmem, ' MB RAM without splitting\n'))
    if(numsplits2 > 1) alert_info(paste0('Splitting into ', numsplits2, ' chunks of ',
                                         width, ' populations and up to ', maxmem, ' MB (',
                                         npair, ' chunk pairs)\n'))
    else alert_info(paste0('Computing without splitting since ', reqmem, ' < ', maxmem, ' (maxmem)...\n'))
  }
  namedList(popvecs1, popvecs2)
}


mats_to_f2arr = function(afmat1, afmat2, countmat1, countmat2, block_lengths,
                         snpwt = NULL, cpp = TRUE, apply_corr = TRUE) {
  # Computes blocked f2 for all pairs of columns of afmat1 and afmat2
  # afmat1, afmat2: allele frequency matrices (SNPs x populations)
  # countmat1, countmat2: matching allele count matrices
  # block_lengths: number of SNPs in each block
  # snpwt: optional vector with one weight per SNP
  # cpp: use the compiled implementation instead of the R implementation
  # apply_corr: subtract the low sample size correction term
  # returns an array (columns of afmat1 x columns of afmat2 x blocks)

  nc1 = ncol(countmat1)
  nc2 = ncol(countmat2)
  nr = nrow(afmat1)

  # all.equal() compares only its first two arguments (further positional arguments are
  # taken as tolerance and scale), so each matrix is compared to the first one separately
  stopifnot(isTRUE(all.equal(nr, nrow(afmat2))),
            isTRUE(all.equal(nr, nrow(countmat1))),
            isTRUE(all.equal(nr, nrow(countmat2))))
  stopifnot(isTRUE(all.equal(ncol(afmat1), nc1)))
  stopifnot(isTRUE(all.equal(ncol(afmat2), nc2)))

  if(cpp) {
    out = cpp_mats_to_f2_arr(afmat1, afmat2, countmat1, countmat2, apply_corr)
  } else {
    # denominators are floored at 1 so that counts of 0 or 1 do not produce division by <= 0
    denom1 = matrix(pmax(1, countmat1-1), nrow(countmat1))
    denom2 = matrix(pmax(1, countmat2-1), nrow(countmat2))
    #denom1 = countmat1-1
    #denom2 = countmat2-1
    out = outer_array(afmat1, afmat2, `-`)^2
    if(apply_corr) {
      corr1 = afmat1*(1-afmat1)/denom1
      corr2 = afmat2*(1-afmat2)/denom2
      out = (out - outer_array(corr1, corr2, `+`))
    }
  }
  if(!is.null(snpwt)) {
    stopifnot(length(snpwt) == nrow(afmat1))
    # SNPs are the last array dimension, so each weight is repeated for all population pairs
    out = out * rep(snpwt, each = nc1*nc2)
  }
  out %<>% block_arr_mean(block_lengths)
  dimnames(out) = list(colnames(afmat1), colnames(afmat2), paste0('l', block_lengths))
  out
}

mats_to_aparr = function(afmat1, afmat2, countmat1, countmat2, block_lengths,
                         snpwt = NULL, cpp = TRUE, apply_corr = TRUE) {
  # Computes blocked allele frequency products for all pairs of columns of afmat1 and afmat2
  # The per SNP value is (p1*p2 + (1-p1)*(1-p2))/2
  # countmat1 and countmat2 are only used for dimension checks; snpwt and apply_corr are not used
  # returns an array (columns of afmat1 x columns of afmat2 x blocks)

  nr = nrow(afmat1)
  # all.equal() compares only its first two arguments (further positional arguments are
  # taken as tolerance and scale), so each matrix is compared to the first one separately
  stopifnot(isTRUE(all.equal(nr, nrow(afmat2))),
            isTRUE(all.equal(nr, nrow(countmat1))),
            isTRUE(all.equal(nr, nrow(countmat2))))
  stopifnot(isTRUE(all.equal(ncol(afmat1), ncol(countmat1))))
  stopifnot(isTRUE(all.equal(ncol(afmat2), ncol(countmat2))))

  if(cpp) {
    out = (cpp_outer_array_mul(afmat1, afmat2) + cpp_outer_array_mul(1-afmat1, 1-afmat2))/2
  } else {
    out = (outer_array(afmat1, afmat2) + outer_array(1-afmat1, 1-afmat2))/2
  }
  out %<>% block_arr_mean(block_lengths)

  dimnames(out) = list(colnames(afmat1), colnames(afmat2), paste0('l', block_lengths))
  out
}

mats_to_ctarr = function(afmat1, afmat2, countmat1, countmat2, block_lengths, cpp = TRUE) {
  # Computes, for all pairs of columns of afmat1 and afmat2, the blocked mean of the indicator
  # that the allele frequency is non-missing in both columns
  # countmat1 and countmat2 are only used for dimension checks
  # returns an array (columns of afmat1 x columns of afmat2 x blocks)

  nr = nrow(afmat1)
  # all.equal() compares only its first two arguments (further positional arguments are
  # taken as tolerance and scale), so each matrix is compared to the first one separately
  stopifnot(isTRUE(all.equal(nr, nrow(afmat2))),
            isTRUE(all.equal(nr, nrow(countmat1))),
            isTRUE(all.equal(nr, nrow(countmat2))))
  stopifnot(isTRUE(all.equal(ncol(afmat1), ncol(countmat1))))
  stopifnot(isTRUE(all.equal(ncol(afmat2), ncol(countmat2))))

  if(cpp) {
    out = cpp_outer_array_mul(!is.na(afmat1), !is.na(afmat2))
  } else {
    out = outer_array(!is.na(afmat1), !is.na(afmat2))
  }
  out %<>% block_arr_mean(block_lengths)

  dimnames(out) = list(colnames(afmat1), colnames(afmat2), paste0('l', block_lengths))
  out
}

mats_to_fstarr = function(afmat1, afmat2, countmat1, countmat2, block_lengths,
                          snpwt = NULL, cpp = FALSE, apply_corr = TRUE) {
  # Computes blocked fst for all pairs of columns of afmat1 and afmat2
  # Numerator and denominator are averaged within blocks separately before taking the ratio
  # snpwt and apply_corr are accepted but not used; only cpp = FALSE is implemented
  # returns an array (columns of afmat1 x columns of afmat2 x blocks)

  nc1 = ncol(countmat1)
  nc2 = ncol(countmat2)
  nr = nrow(afmat1)

  # all.equal() compares only its first two arguments (further positional arguments are
  # taken as tolerance and scale), so each matrix is compared to the first one separately
  stopifnot(isTRUE(all.equal(nr, nrow(afmat2))),
            isTRUE(all.equal(nr, nrow(countmat1))),
            isTRUE(all.equal(nr, nrow(countmat2))))
  stopifnot(isTRUE(all.equal(ncol(afmat1), nc1)))
  stopifnot(isTRUE(all.equal(ncol(afmat2), nc2)))

  # todo: cpp implementation
  if(cpp) stop("'cpp = TRUE' is not implemented for fst; use 'cpp = FALSE'")

  denom1 = countmat1-1
  denom2 = countmat2-1
  # biased heterozygosity estimates, and their sample size corrected versions
  het1b = afmat1*(1-afmat1)
  het2b = afmat2*(1-afmat2)
  corr1 = het1b/denom1
  corr2 = het2b/denom2
  het1ub = het1b * countmat1 / denom1
  het2ub = het2b * countmat2 / denom2
  # numerator is the corrected f2; denominator adds the corrected heterozygosities
  num = (outer_array(afmat1, afmat2, `-`)^2 - outer_array(corr1, corr2, `+`))
  fstdenom = num + outer_array(het1ub, het2ub, `+`)
  num %<>% block_arr_mean(block_lengths)
  fstdenom %<>% block_arr_mean(block_lengths)
  out = num/fstdenom

  dimnames(out) = list(colnames(afmat1), colnames(afmat2), paste0('l', block_lengths))
  out
}


xmats_to_pairarrs = function(xmat1, xmat2) {
  # Builds the arrays aa and nn for all SNPs and all pairs of columns of xmat1 and xmat2
  # aa is the pairwise product of the genotypes (missing genotypes count as 0),
  # nn the pairwise product of the allele numbers (ploidy where observed, 0 where missing)

  nr = nrow(xmat1)
  stopifnot(nr == nrow(xmat2))

  # rescale pseudohaploid columns and obtain the per column ploidy
  xmat1 = fix_ploidy(xmat1)
  xmat2 = fix_ploidy(xmat2)

  # the ploidy vector is stretched to one value per matrix cell (column by column)
  n1 = (!is.na(xmat1)) * rep(attr(xmat1, 'ploidy'), each = nr)
  n2 = (!is.na(xmat2)) * rep(attr(xmat2, 'ploidy'), each = nr)

  xmat1 = replace_na(xmat1, 0)
  xmat2 = replace_na(xmat2, 0)

  aa = prodarray(xmat1, xmat2)
  nn = prodarray(n1, n2)

  # label the first two dimensions of both arrays with the column names
  pairnames = list(colnames(xmat1), colnames(xmat2))
  dimnames(nn)[1:2] = pairnames
  dimnames(aa)[1:2] = pairnames
  namedList(aa, nn)
}


indpairs_to_f2blocks = function(indivs, pairs, poplist, block_lengths,
                                afprod = FALSE, return_array = TRUE, apply_corr = TRUE) {
  # creates f2_blocks from per individual data
  # make fast version in Rcpp and use tibbles here for readability
  # indivs: data frame with columns ind, bl, a, n
  # pairs: data frame with columns ind1, ind2, bl, aa, nn
  # poplist: data frame with columns ind, pop
  # block_lengths
  # a is number of alt alleles, n number of ref + alt alleles.
  # afprod: if TRUE, the returned array holds the pairwise products 'pp' instead of 'f2'
  # return_array: if FALSE, the data frame 'f2dat' is returned instead of a 3d array
  # apply_corr: if FALSE, the correction terms are set to zero
  stopifnot('ind' %in% names(poplist) && 'pop' %in% names(poplist))

  # the following line is a shortcut which only makes sense as long as I only want to return a df
  # when I don't care about blocks
  if(!return_array) {indivs$bl = 1; pairs$bl = 1}

  # per population and block: mean of the individual a/n ratios (individuals with n > 0 only),
  # and the sums of a and n
  indsums = indivs %>% left_join(poplist, by='ind') %>%
    group_by(pop, bl) %>% summarize(p = mean(a[n>0]/n[n>0]), a = sum(a), n = sum(n)) %>% ungroup

  # add each pair in the reverse orientation as well, then drop rows which are exact duplicates
  pairs %<>% bind_rows(rename(., ind1 = ind2, ind2 = ind1)) %>% filter(!duplicated(.))
  # per population pair and block: mean of the aa/nn ratios (pairs with nn > 0 only),
  # and the sums of aa and nn
  pairsums = pairs %>%
    left_join(poplist %>% transmute(ind1 = ind, pop1 = pop), by = 'ind1') %>%
    left_join(poplist %>% transmute(ind2 = ind, pop2 = pop), by = 'ind2') %>%
    group_by(pop1, pop2, bl) %>%
    #mutate(aa2 = ifelse(ind1 == ind2, mean(aa[ind1 != ind2], na.rm=T), aa),
    #       nn2 = ifelse(ind1 == ind2, mean(nn[ind1 != ind2], na.rm=T), nn)) %>%
    summarize(pp = mean(aa[nn>0]/nn[nn>0]), aa = sum(aa), nn = sum(nn)) %>% ungroup
    #summarize(pp = weighted.mean(aa[nn>0]/nn[nn>0], nn2), aa = sum(aa), nn = sum(nn)) %>% ungroup

  # rows in which a population is paired with itself
  pairsums_samepop = pairsums %>% filter(pop1 == pop2) %>% transmute(pop = pop1, bl, aa, nn, pp)

  # uncorrected f2 is the sum of both within population terms minus twice the between population term
  main = pairsums %>%
    left_join(pairsums_samepop %>% transmute(pop1 = pop, bl, pp1 = pp), by = c('pop1', 'bl')) %>%
    left_join(pairsums_samepop %>% transmute(pop2 = pop, bl, pp2 = pp), by = c('pop2', 'bl')) %>%
    mutate(f2uncorr = pp1 + pp2 - 2*pp)

  # corr = pairsums_samepop %>% left_join(indsums, by=c('bl', 'pop')) %>%
  #   mutate(den1 = pmax(nn - n, n),
  #          n3unfix = n*nn, n3fix = nn/n^2, n3 = n3unfix * n3fix,
  #          den2 = pmax(n3 - nn, nn),
  #          # used to be den2 = pmax(n3 - nn, 1),
  #          corr = a/den1 - aa/den2,
  #          corr = pmax(0, corr))

  # per population and block correction term; both the numerator (floored at 0) and the
  # denominator (floored at 1) are bounded
  corr = pairsums_samepop %>%
    left_join(indsums, by=c('bl', 'pop')) %>%
    mutate(corr = pmax(0, (p-pp))/pmax(1, n-1))

  if(!apply_corr) corr$corr = 0

  # subtract the correction terms of both populations; f2 of a population with itself is set to 0
  # rows are sorted by block, then pop2, then pop1, i.e. pop1 varies fastest, which is the
  # order in which array() below fills its first, second and third dimension
  f2dat = main %>%
    left_join(corr %>% transmute(pop1 = pop, bl, corr1 = corr), by = c('bl', 'pop1')) %>%
    left_join(corr %>% transmute(pop2 = pop, bl, corr2 = corr), by = c('bl', 'pop2')) %>%
    mutate(f2 = f2uncorr - corr1 - corr2, f2 = ifelse(pop1 == pop2, 0, f2),
           pop1pop2 = paste(pmin(pop1, pop2), pmax(pop1, pop2))) %>%
    #group_by(pop1pop2, bl) %>% mutate(cnt = n()) %>% ungroup %>%
    #bind_rows(filter(., pop1 != pop2 & cnt == 1) %>%
    #            rename(pop1 = pop2, pop2 = pop1, pp1 = pp2,
    #                   pp2 = pp1, corr1 = corr2, corr2 = corr1)) %>%
    arrange(bl, pop2, pop1)

  if(!return_array) return(f2dat)

  popnames = unique(poplist$pop)
  popnames2 = unique(pairsums_samepop$pop)
  npop = length(popnames)
  nblock = length(block_lengths)

  # NOTE(review): the positional fill assumes that f2dat has exactly one row for every
  # (pop1, pop2, block) combination and that popnames2 is in the same order as the sorted
  # pop1/pop2 values - TODO confirm
  # the array is then reordered to the population order of poplist, and NaN is replaced by NA
  if(afprod) col = 'pp' else col = 'f2'
  array(f2dat[[col]], dim = c(npop, npop, nblock),
        dimnames = list(pop1 = popnames2,
                        pop2 = popnames2,
                        bl = paste0('l', block_lengths))) %>%
    `[`(popnames, popnames, ) %>%
    ifelse(is.nan(.), NA, .)
}


fix_ploidy = function(xmat) {
  # halves the genotypes in pseudohaploid columns
  # returns the updated matrix, with the per column ploidy vector attached as attribute 'ploidy'

  # ploidy is taken to be the number of distinct non-missing values in a column, minus one
  n_distinct = apply(xmat, 2, function(gt) length(unique(na.omit(gt))))
  ploidy = n_distinct - 1
  maxgt = apply(xmat, 2, max, na.rm = TRUE)

  # columns with two distinct values and a maximum of 2 are rescaled so that their maximum is 1
  pseudohaploid = ploidy == 1 & maxgt == 2
  xmat[, pseudohaploid] = xmat[, pseudohaploid]/2

  structure(xmat, ploidy = ploidy)
}


#' Turns f2_data into f2_blocks
#'
#' Accepts f2 data in any of three forms and returns the requested populations as a 3d array.
#' A character `f2_data` has to be an existing directory or a genotype file prefix;
#' anything else is used as f2 blocks directly.
#'
#' @param f2_data f2 data as genotype file prefix, f2 directory, or f2 blocks
#' @param pops Populations for which to extract f2-stats (defaults to all)
#' @param pops2 Optional second vector of populations. Can result in non-square array.
#' Defaults to `pops` if only `pops` is provided.
#' @param afprod Get allele frequency products
#' @param verbose Print progress updates
#' @param ... Additional arguments passed to \code{\link{f2_from_precomp}} or \code{\link{f2_from_geno}}
#' @return A 3d array with f2-statistics
get_f2 = function(f2_data, pops = NULL, pops2 = NULL, afprod = FALSE, verbose = TRUE, ...) {

  stopifnot(!is.character(f2_data) || dir.exists(f2_data) || is_geno_prefix(f2_data))

  # arguments in `...` have to be known to at least one of the downstream functions
  known_args = c(names(formals(f2_from_geno)), names(formals(f2_from_precomp)), names(formals(qpfstats)))
  notused = setdiff(...names(), known_args)
  if(length(notused) > 0) {
    stop(paste0("The following arguments are not recognized: '", paste0(notused, collapse = "', '"), "'"))
  }

  if(is.null(pops2) && !is.null(pops)) pops2 = pops

  # obtain the blocks from a directory, from genotype files, or take them as they are
  if(is.character(f2_data)) {
    if(dir.exists(f2_data)) {
      f2_blocks = f2_from_precomp(f2_data, pops = pops, pops2 = pops2, afprod = afprod, verbose = verbose, ...)
    } else {
      f2_blocks = f2_from_geno(f2_data, pops = union(pops, pops2), afprod = afprod, verbose = verbose, ...)
    }
  } else {
    f2_blocks = f2_data
  }

  blocknames1 = dimnames(f2_blocks)[[1]]
  blocknames2 = dimnames(f2_blocks)[[2]]
  if(is.null(pops)) pops = blocknames1
  if(is.null(pops2)) pops2 = blocknames2

  # every requested population has to be present in the blocks
  missingpops = setdiff(union(pops, pops2), union(blocknames1, blocknames2))
  if(length(missingpops) > 0) {
    stop(paste0('Requested, but not in f2_data: ', paste(missingpops, collapse = ', ')))
  }
  f2_blocks[pops, pops2, , drop = FALSE]
}



test_structutured_missingness = function(mat) {
  # For all pairs of columns, tests whether the values in one column (rows) are correlated
  # with the missingness indicator of another column (columns)
  # Returns a matrix of two-sided p-values from a t-test of the correlation coefficients

  # change so that it tests afdiff, not af

  is_missing = is.na(mat)
  r = cor(mat, is_missing, use = 'p')

  # number of observations behind each coefficient: non-missing values in the row's column
  n_obs = rep(colSums(!is_missing), ncol(mat))
  df = n_obs - 2

  tstat = r / sqrt((1-r^2) / df)
  pt(-abs(tstat), df)*2
}



average_f4blockdat = function(f4blockdat, checkcomplete = FALSE) {
  # takes a data frame generated from all popcombs by f4blockdat_from_geno and adds columns est_avg, n_avg
  # est_avg is based on means of est, weighted by sqrt(n), across all rows with two or more pops in common
  # f4(A, B; C, D) = f4(A, X; C, Y) + f4(A, X; Y, D) + f4(X, B; C, Y) + f4(X, B, Y, D)
  # * choose(npop-2, 2) / choose(npop-1, 2)
  # assumes all combinations are present
  # checkcomplete: stop if the data doesn't appear to have all combinations

  t1 = table(f4blockdat$pop1)
  t2 = table(f4blockdat$pop2)
  t3 = table(f4blockdat$pop3)
  t4 = table(f4blockdat$pop4)
  nblocks = length(unique(f4blockdat$block))

  # TRUE if two count tables list the same populations with the same counts
  # (comparing tables with `!=` inside `||` is an error for more than one population in R >= 4.3)
  same_counts = function(a, b) identical(names(a), names(b)) && identical(as.vector(a), as.vector(b))

  incomplete = FALSE
  if(nrow(f4blockdat) == 4*choose(length(t1),2)*choose(length(t3),2)*nblocks) {
    denom = 2
    if(!same_counts(t1, t2) || !same_counts(t3, t4)) incomplete = TRUE
  } else if(nrow(f4blockdat) == choose(length(t1), 4)*factorial(4)*nblocks) {
    if(length(unique(c(t1, t2, t3, t4))) != 1) incomplete = TRUE
    npop = f4blockdat$pop1 %>% unique %>% length
    denom = choose(npop-1, 2) / choose(npop-2, 2)
  } else {
    denom = 2
    incomplete = TRUE
  }
  if(checkcomplete && incomplete) stop("Data doesn't appear to have all combinations!")

  # e13 etc. are the weighted means across all rows which share the given two populations and block
  f4blockdat %>%
    group_by(block, pop1, pop3) %>% mutate(e13 = weighted.mean(est, sqrt(n), na.rm=TRUE), n13 = mean(n)) %>%
    group_by(block, pop2, pop4) %>% mutate(e24 = weighted.mean(est, sqrt(n), na.rm=TRUE), n24 = mean(n)) %>%
    group_by(block, pop2, pop3) %>% mutate(e23 = weighted.mean(est, sqrt(n), na.rm=TRUE), n23 = mean(n)) %>%
    group_by(block, pop1, pop4) %>% mutate(e14 = weighted.mean(est, sqrt(n), na.rm=TRUE), n14 = mean(n)) %>%
    # group_by(block, pop1, pop3) %>% mutate(e13 = mean(est, na.rm=T), n13 = mean(n)) %>%
    # group_by(block, pop2, pop4) %>% mutate(e24 = mean(est, na.rm=T), n24 = mean(n)) %>%
    # group_by(block, pop2, pop3) %>% mutate(e23 = mean(est, na.rm=T), n23 = mean(n)) %>%
    # group_by(block, pop1, pop4) %>% mutate(e14 = mean(est, na.rm=T), n14 = mean(n)) %>%
    ungroup %>%
    mutate(est_avg = (e13 + e24 + e23 + e14) / denom,
           n_avg =   (n13 + n24 + n23 + n14) / 4) %>%
    select(-e13:-n14)
}


#' Estimate joint allele frequency spectrum
#'
#' For every pattern of alleles across populations (`0` or `1` for each population), the expected
#' proportion of that pattern is computed for each SNP as the product of the allele frequencies
#' (populations with `1`) and of one minus the allele frequencies (populations with `0`),
#' and then averaged across SNPs.
#'
#' @export
#' @param afs A matrix or data frame of allele frequencies
#' @return A data frame with columns `pattern`, `proportion`, and `total`
#' (the number of SNPs with non-missing data for that pattern)
#' @examples
#' \dontrun{
#' dat = plink_to_afs('/my/plink/file', pops = c('pop1', 'pop2', 'pop3', 'pop4', 'pop5'))
#'
#' # Spectrum across all SNPs
#' joint_spectrum(dat$afs)
#'
#' # Stratify by allele frequency in one population
#' dat$afs %>% as_tibble %>% select(1:4) %>%
#'   group_by(grp = cut(pop1, 10)) %>%
#'   group_modify(~joint_spectrum(.x)) %>%
#'   ungroup
#'
#' # Stratify by mutation class
#' dat$afs %>% as_tibble %>% select(1:4) %>%
#'   mutate(mut = paste(dat$snpfile$A1, dat$snpfile$A2)) %>%
#'   group_by(mut) %>%
#'   group_modify(~joint_spectrum(.x)) %>%
#'   ungroup
#' }
joint_spectrum = function(afs) {

  afs = as.matrix(afs)
  # is.numeric accepts integer as well as double matrices
  if(!is.numeric(afs)) stop("'afs' should only have numeric columns!")
  npop = ncol(afs)
  ps = power_set(seq_len(npop))
  allanc = rep('0', npop)
  # pattern labels: one character per population, starting with the all '0' pattern
  patterns = ps %>% map(~{x = allanc; x[.] = '1'; x}) %>%
    prepend(list(allanc)) %>%
    map_chr(~paste(., collapse=''))
  # per SNP probability of each pattern, in the same order as the labels
  snpcounts = ps %>%
    map(~row_prods(afs[,.,drop=FALSE]) * row_prods(1-afs[,-.,drop=FALSE])) %>%
    prepend(list(row_prods(1-afs))) %>%
    bind_cols(.name_repair = ~patterns)
  # number of SNPs with a non-missing value for each pattern
  obs = snpcounts %>% as.matrix %>% is.na %>% `!` %>% colSums %>% enframe('pattern', 'total')
  snpcounts %>%
    colMeans(na.rm = TRUE) %>%
    enframe('pattern', 'proportion') %>%
    left_join(obs, by = 'pattern') %>%
    arrange(pattern)
}



#' Joint site frequency spectrum
#'
#' This function computes the joint site frequency spectrum from genotype files or allele frequency and count matrices.
#' The joint site frequency spectrum lists how often each combination of haplotypes is observed.
#' The number of combinations is equal to the product of one plus the number of haplotypes in each population.
#' For example, five populations with a single diploid individual each have 3^5 possible combinations.
#' @export
#' @param afs A named list of length two where the first element (`afs`) contains the allele frequency matrix,
#' and the second element (`counts`) contains the allele count matrix.
#' @param pref Instead of `afs`, the prefix of genotype files can be provided.
#' @return A data frame with the number of times each possible combination of allele counts is observed.
#' The counts are in column `n`; combinations which are not observed have `n = 0`.
#' @examples
#' \dontrun{
#' dat = plink_to_afs('/my/plink/file', pops = c('pop1', 'pop2', 'pop3', 'pop4', 'pop5'))
#' joint_sfs(dat)
#' }
joint_sfs = function(afs = NULL, pref = NULL) {

  # `afs` defaults to NULL so that `joint_sfs(pref = ...)` works without passing `afs`
  if(!is.null(pref)) {
    if(!is.null(afs)) stop("'afs' and 'pref' can't be provided at the same time!")
    afs = geno_to_afs(pref)
  } else if(is.null(afs)) {
    stop("Either 'afs' or 'pref' has to be provided!")
  }
  # the largest allele count in each population determines the range of possible values
  popcounts = apply(afs$counts, 2, max)
  # frequency times count is the number of observed alleles; count how often each combination occurs
  obs = (afs$afs * afs$counts) %>% as_tibble() %>% group_by_all %>% count %>% ungroup
  # all possible combinations, with zero for the ones which were never observed
  expand_grid(!!!popcounts %>% map(~0:.)) %>% left_join(obs) %>% mutate(n = replace_na(n, 0)) %>% suppressMessages()
}

sfs_to_f2 = function(sfs) {
  # Computes f2 and fst for all population pairs from a joint site frequency spectrum
  # sfs: data frame with one allele count column per population, followed by the column 'n'
  #   (how often each combination is observed), which is expected to be the last column
  # returns a data frame with columns pop1, pop2, f2, fst

  # turn n into proportions, and allele counts into allele frequencies
  sfs2 = sfs %>% mutate(across(all_of('n'), ~./sum(.)), across(!all_of('n'), ~./max(.)))
  pops = colnames(sfs)[-ncol(sfs)]
  # largest allele count per population is used as its number of haplotypes
  nvec = apply(sfs[-ncol(sfs)], 2, max, na.rm=TRUE)
  cmb = t(combn(seq_along(pops), 2))
  map(seq_len(nrow(cmb)), ~{
    x1 = cmb[.,1]
    x2 = cmb[.,2]
    n1 = nvec[x1]
    n2 = nvec[x2]
    # all_of() selects by position unambiguously, even if a population is named 'x1' or 'x2'
    sfs2 %>% select(all_of(c(x1, x2)), n) %>% set_colnames(c('p1', 'p2', 'n')) %>%
      mutate(f2 = (p1-p2)^2 - p1*(1-p1)/max(1,n1-1) - p2*(1-p2)/max(1,n2-1),
             denom = p1 + p2 - 2*p1*p2) %>%
      # fst comes first so that it is computed from the per row f2, not from the summed f2
      summarize(fst = sum(n*f2)/sum(n*denom), f2 = sum(n*f2)) %>%
      mutate(pop1 = pops[x1], pop2 = pops[x2])
  }) %>% bind_rows() %>% select(pop1, pop2, f2, fst)
}

# f2_from_sfs = function(sfs, nblocks = 100) {
#
#   pops = setdiff(colnames(sfs), 'n')
#   nsnps = sum(sfs$n)
#   n2 = map(sfs$n, ~if(.==0) rep(0, nblocks) else diff(sort(c(0, sample(seq_len(.), nblocks-1, replace=T), .)))) %>%
#     do.call(rbind, .)
#   x = map(seq_len(nblocks), ~mutate(sfs, n = n2[,.])) %>% map(sfs_to_f2) %>% bind_rows(.id = 'block') %>%
#     transmute(block = as.numeric(block), pop1, pop2, f2) %>% bind_rows(rename(., pop1 = pop2, pop2 = pop1)) %>%
#     bind_rows(expand_grid(block = seq_len(nblocks), pop1 = pops, f2 = 0) %>% mutate(pop2 = pop1)) %>%
#     arrange(block, pop1, pop2)
#   array(x$f2, c(length(pops), length(pops), nblocks), list(pops, pops, rep(paste0('l', round(nsnps/nblocks)), nblocks)))
# }


#' Turn f2 data to f4 data
#'
#' f4-statistics are computed from f2-statistics as
#' \eqn{f4(1, 2; 3, 4) = (f2(1, 4) + f2(2, 3) - f2(1, 3) - f2(2, 4)) / 2}.
#'
#' @export
#' @param f2dat A data frame of f2-statistics with columns `pop1`, `pop2`, `f2`
#' @param popcomb An optional data frame with columns `pop1`, `pop2`, `pop3`, `pop4` with the population
#' combinations for which f4 should be computed. By default, all combinations of the populations in `f2dat`
#' with `pop1 != pop2` and `pop3 != pop4` are used.
#' @return A data frame with f4-statistics
f2dat_f4dat = function(f2dat, popcomb = NULL) {

  pops = unique(c(f2dat$pop1, f2dat$pop2))
  if(is.null(popcomb)) {
    allcomb = expand_grid(pop1 = pops, pop2 = pops, pop3 = pops, pop4 = pops)
    #filter(pop1 < pop2, pop1 < pop3, pop1 < pop4, pop3 < pop4, pop2 != pop3, pop2 != pop4) %>%
    popcomb = filter(allcomb, pop1 != pop2, pop3 != pop4)
  }

  # f2 lookup table with both orientations of every pair, and f2 = 0 for identical populations
  mirrored = rename(f2dat, pop1 = pop2, pop2 = pop1)
  selfpairs = tibble(pop1 = pops, pop2 = pops, f2 = 0)
  f2lookup = bind_rows(f2dat, mirrored, selfpairs) %>%
    rename(p1 = pop1, p2 = pop2) %>%
    distinct

  # attach the four f2-statistics which make up each f4-statistic
  popcomb %>%
    left_join(transmute(f2lookup, pop1 = p1, pop3 = p2, f13 = f2)) %>%
    left_join(transmute(f2lookup, pop2 = p1, pop4 = p2, f24 = f2)) %>%
    left_join(transmute(f2lookup, pop1 = p1, pop4 = p2, f14 = f2)) %>%
    left_join(transmute(f2lookup, pop2 = p1, pop3 = p2, f23 = f2)) %>%
    mutate(f4 = (f14 + f23 - f13 - f24)/2) %>%
    ungroup %>%
    select(pop1:pop4, f4) %>%
    suppressMessages
}

#' Count SNPs in an f2-statistics array
#'
#' The names along the third dimension of the array store the block lengths (the number of SNPs in each
#' SNP block). This function parses these names and adds them up. If the f2-statistics were computed with
#' `maxmiss` greater than 0, not all f2-statistics are necessarily based on the same number of SNPs,
#' and the returned value is only an upper bound on the number of SNPs used.
#' @export
#' @param f2_blocks A 3d array of per block f2-statistics
#' @return The total number of SNPs across all blocks
#' @seealso \code{\link{f2_from_geno}}, \code{\link{f2_from_precomp}}
count_snps = function(f2_blocks) {
  block_labels = dimnames(f2_blocks)[[3]]
  block_lengths = parse_number(block_labels)
  sum(block_lengths)
}


complete_f4dat = function(dat) {
  # dat has columns pop1:pop4, est
  # Adds the rows for the equivalent orderings of the four populations and drops duplicate rows

  # swapping pop1 and pop2 flips the sign of the statistic
  flipped = dat %>%
    rename(pop1 = pop2, pop2 = pop1) %>%
    mutate(across(any_of(c('est', 'z', 'Z', 'f4')), ~-.))
  with_flipped = bind_rows(dat, flipped)

  # swapping within both pairs at once keeps the values as they are
  within_swapped = rename(with_flipped, pop1 = pop2, pop2 = pop1, pop3 = pop4, pop4 = pop3)
  with_within = bind_rows(with_flipped, within_swapped)

  # exchanging the first pair with the second pair keeps the values as well
  pairs_swapped = rename(with_within, pop1 = pop3, pop3 = pop1, pop2 = pop4, pop4 = pop2)
  distinct(bind_rows(with_within, pairs_swapped))
}
uqrmaie1/admixtools documentation built on March 20, 2024, 8:24 a.m.