# Source file: R/sparse_linear.R

# Defines the function .sparse_linear

# Sparse estimation of the linear correlation matrix between the columns
# (taxa) of `mat`.
#
# Pipeline: sort the columns by name, winsorize each column, count pairwise
# co-occurrences, build a regularized correlation matrix, choose a threshold
# (and optional sparsity weight) by repeated random train/validation splits,
# then return the thresholded and the p-value-filtered correlation matrices.
#
# Arguments:
#   mat          numeric matrix with column names. NA and 0 entries are both
#                treated as "not observed" when counting co-occurrences.
#   wins_quant   probabilities passed to stats::quantile() to obtain the
#                winsorization bounds of each column.
#   method       correlation type, passed to stats::cor() and Hmisc::rcorr();
#                "spearman" additionally rank-transforms every column first.
#   soft         TRUE for soft thresholding, FALSE for hard thresholding.
#   alpha_grid   candidate weights of the sparsity penalty; NULL means 0
#                (no penalty).
#   thresh_len   number of candidate thresholds in the grid.
#   n_cv         number of random train/validation splits.
#   thresh_hard  hard threshold applied to the final matrices.
#   max_p        largest p-value for which a correlation is kept in `corr_fl`.
#
# Value: a list with elements `cv_error` (mean loss per grid point),
# `thresh_grid`, `thresh_opt`, `alpha_opt`, `mat_cooccur`, `corr`, `corr_p`,
# `corr_th` (thresholded), `corr_fl` (p-value filtered) and `corr_reg`
# (correlation derived from the regularized covariance).
.sparse_linear = function(mat, wins_quant, method, soft, alpha_grid,
                          thresh_len, n_cv, thresh_hard, max_p) {
    # Thresholding: zero every entry with |value| <= th. With `soft = TRUE`
    # the surviving entries are additionally shrunk towards zero by `th`.
    # Positions are selected with which(), so NA entries simply stay NA
    # instead of aborting the soft-threshold assignment with
    # "NAs are not allowed in subscripted assignments".
    mat_thresh = function(mat, th, soft){
        mat_th = mat
        mat_th[which(abs(mat) <= th)] = 0
        if (soft) {
            keep = which(abs(mat) > th)
            mat_th[keep] = sign(mat[keep]) * (abs(mat[keep]) - th)
        }
        return(mat_th)
    }

    # Threshold loss function: half the Frobenius distance between the
    # thresholded training correlation and the validation correlation, plus
    # an optional weighted L1 penalty on the thresholded training correlation.
    thresh_loss = function(mat1, mat2, method, soft, th, alpha) {
        corr1 = stats::cor(mat1, method = method,
                           use = "pairwise.complete.obs")
        corr2 = stats::cor(mat2, method = method,
                           use = "pairwise.complete.obs")
        corr1_th = mat_thresh(corr1, th, soft)
        corr_diff = corr1_th - corr2
        # Pairs whose correlation is undefined contribute nothing to the loss
        corr_diff[is.na(corr_diff)] = 0

        # Compute Frobenius norm
        fro_norm = norm(corr_diff, type = "F")

        # Compute sparsity penalty if alpha > 0
        if (alpha > 0) {
            epsilon = 1e-10  # Small constant to prevent division by zero
            weight = 1/abs(corr2 + epsilon)
            weight[which(corr2 == 0)] = 0
            diag(weight) = 0
            # na.rm = TRUE: undefined pairs are skipped here as well, so one
            # NA correlation cannot turn the whole loss into NA
            sparsity_penalty = alpha * sum(abs(weight * corr1_th),
                                           na.rm = TRUE)
        } else {
            sparsity_penalty = 0
        }

        # Total loss
        loss = 0.5 * fro_norm + sparsity_penalty
        return(loss)
    }

    # Filtering based on p-values: keep entries with p <= max_p, zero the rest
    p_filter = function(mat, mat_p, max_p){
        ind_p = mat_p
        ind_p[mat_p > max_p] = 0
        ind_p[mat_p <= max_p] = 1

        mat_filter = mat * ind_p
        return(mat_filter)
    }

    # Sort taxa
    sort_taxa = sort(colnames(mat))
    mat = mat[, sort_taxa, drop = FALSE]

    # Winsorization
    mat = apply(mat, 2, function(x)
        DescTools::Winsorize(x, val = stats::quantile(x, probs = wins_quant, na.rm = TRUE)))

    # Co-occurrence matrix
    # Indicator of "observed": 1 for a non-missing, non-zero value, else 0
    mat_occur = (!is.na(mat) & mat != 0) * 1
    # Entry (i, j) counts the samples in which taxa i and j are both observed;
    # the diagonal equals the number of samples in which each taxon occurs
    mat_cooccur = crossprod(mat_occur)

    if (any(mat_cooccur < 10)) {
        warn_txt = paste("There are some pairs of taxa that have insufficient (< 10) overlapping samples",
                         "Proceed with caution since the point estimates for these pairs are unstable",
                         "For pairs of taxa with no overlapping samples, the point estimates will be replaced with 0s,",
                         "and the corresponding p-values will be replaced with 1s",
                         "Please check `mat_cooccur` for details about the co-occurrence pattern",
                         sep = "\n")
        warning(warn_txt)
    }

    # Pairs with fewer than two overlapping samples cannot support an estimate
    low_support = mat_cooccur < 2

    # Regularization of the covariance matrix
    if(method == "spearman"){
        # Convert to rank (NA entries are left in place)
        mat = apply(mat,2,function(x) {
            r = rank(x, na.last = NA)
            x[!is.na(x)] = r
            return(x)
        }
        )
    }
    # Covariance matrix
    cov_mat = stats::cov(mat, use = "pairwise.complete.obs")
    cov_mat[is.na(cov_mat)] = 0

    # Regularize the covariance matrix
    cov_mat_pos = .regularize_eigenvalues(cov_mat)
    cov_mat_pos[low_support] = 0
    cov_mat_pos[is.infinite(cov_mat_pos)] = 0

    # Check if it is positive semi-definite, if not repeat the regularization process
    # (zeroing entries above can undo what the regularization achieved)
    while(!.is_psd(cov_mat_pos)) {
        cov_mat_pos = .regularize_eigenvalues(cov_mat_pos)
        cov_mat_pos[low_support] = 0
        cov_mat_pos[is.infinite(cov_mat_pos)] = 0
    }
    # Convert to correlation coefficient
    corr_reg = stats::cov2cor(cov_mat_pos)

    # Sample size of the training set; the remaining rows form the test set
    n = nrow(mat)
    n1 = n - floor(n/log(n))

    # Correlation matrix (used for the threshold grid and the p-values)
    corr_list = suppressWarnings(Hmisc::rcorr(x = mat, type = method))
    corr = corr_list$r
    corr[low_support] = 0
    corr[is.infinite(corr)] = 0

    # Cross-Validation
    # Entries equal to 1 (the diagonal) are left out of the grid's upper bound
    max_thresh = max(abs(corr[corr != 1]), na.rm = TRUE)
    thresh_grid = seq(from = 0, to = max_thresh, length.out = thresh_len)
    if (is.null(alpha_grid)) alpha_grid = 0
    param_grid = expand.grid(thresh = thresh_grid, alpha = alpha_grid)

    loss_mat = foreach(i = seq_len(n_cv), .combine = rbind) %dorng% {
        # Create training and validation splits
        index = sample(seq_len(n), size = n1, replace = FALSE)
        mat1 = mat[index, , drop = FALSE]
        mat2 = mat[-index, , drop = FALSE]

        # Calculate loss for each parameter combination
        apply(param_grid, 1, function(params) {
            thresh_loss(
                mat1 = mat1,
                mat2 = mat2,
                method = method,
                soft = soft,
                th = params[["thresh"]],
                alpha = params[["alpha"]]
            )
        })
    }
    # With a single split `.combine = rbind` is never invoked and a plain
    # vector comes back; colMeans() needs one row per split
    if (is.null(dim(loss_mat))) {
        loss_mat = matrix(loss_mat, nrow = 1)
    }

    # Calculate mean loss across CV folds
    mean_losses = colMeans(loss_mat)

    # Find optimal parameters
    opt_index = which.min(mean_losses)
    thresh_opt = param_grid$thresh[opt_index]
    alpha_opt = param_grid$alpha[opt_index]

    # Apply optimal thresholding
    corr = stats::cor(mat, method = method, use = "pairwise.complete.obs")
    # Same clean-up as for the first correlation matrix: estimates for pairs
    # without enough overlapping samples are replaced with 0, as announced in
    # the warning above
    corr[low_support] = 0
    corr[is.infinite(corr)] = 0
    corr_th = mat_thresh(mat = corr, th = thresh_opt, soft = soft)
    corr_th = mat_thresh(mat = corr_th, th = thresh_hard, soft = FALSE)

    # Correlation matrix after filtering
    corr_p = corr_list$P
    diag(corr_p) = 0
    corr_p[low_support] = 1
    corr_p[is.na(corr_p)] = 1
    corr_p[is.infinite(corr_p)] = 1
    corr_fl = p_filter(mat = corr, mat_p = corr_p, max_p = max_p)
    corr_fl = mat_thresh(mat = corr_fl, th = thresh_hard, soft = FALSE)

    # Output
    result = list(cv_error = mean_losses,
                  thresh_grid = thresh_grid,
                  thresh_opt = thresh_opt,
                  alpha_opt = alpha_opt,
                  mat_cooccur = mat_cooccur,
                  corr = corr,
                  corr_p = corr_p,
                  corr_th = corr_th,
                  corr_fl = corr_fl,
                  corr_reg = corr_reg)
    return(result)
}
# FrederickHuangLin/ANCOMBC documentation built on June 11, 2025, 6:22 p.m.